sqlspec-0.12.1-py3-none-any.whl → sqlspec-0.13.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic. Click here for more details.

Files changed (113)
  1. sqlspec/_sql.py +21 -180
  2. sqlspec/adapters/adbc/config.py +10 -12
  3. sqlspec/adapters/adbc/driver.py +120 -118
  4. sqlspec/adapters/aiosqlite/config.py +3 -3
  5. sqlspec/adapters/aiosqlite/driver.py +116 -141
  6. sqlspec/adapters/asyncmy/config.py +3 -4
  7. sqlspec/adapters/asyncmy/driver.py +123 -135
  8. sqlspec/adapters/asyncpg/config.py +3 -7
  9. sqlspec/adapters/asyncpg/driver.py +98 -140
  10. sqlspec/adapters/bigquery/config.py +4 -5
  11. sqlspec/adapters/bigquery/driver.py +231 -181
  12. sqlspec/adapters/duckdb/config.py +3 -6
  13. sqlspec/adapters/duckdb/driver.py +132 -124
  14. sqlspec/adapters/oracledb/config.py +6 -5
  15. sqlspec/adapters/oracledb/driver.py +242 -259
  16. sqlspec/adapters/psqlpy/config.py +3 -7
  17. sqlspec/adapters/psqlpy/driver.py +118 -93
  18. sqlspec/adapters/psycopg/config.py +34 -30
  19. sqlspec/adapters/psycopg/driver.py +342 -214
  20. sqlspec/adapters/sqlite/config.py +3 -3
  21. sqlspec/adapters/sqlite/driver.py +150 -104
  22. sqlspec/config.py +0 -4
  23. sqlspec/driver/_async.py +89 -98
  24. sqlspec/driver/_common.py +52 -17
  25. sqlspec/driver/_sync.py +81 -105
  26. sqlspec/driver/connection.py +207 -0
  27. sqlspec/driver/mixins/_csv_writer.py +91 -0
  28. sqlspec/driver/mixins/_pipeline.py +38 -49
  29. sqlspec/driver/mixins/_result_utils.py +27 -9
  30. sqlspec/driver/mixins/_storage.py +149 -216
  31. sqlspec/driver/mixins/_type_coercion.py +3 -4
  32. sqlspec/driver/parameters.py +138 -0
  33. sqlspec/exceptions.py +10 -2
  34. sqlspec/extensions/aiosql/adapter.py +0 -10
  35. sqlspec/extensions/litestar/handlers.py +0 -1
  36. sqlspec/extensions/litestar/plugin.py +0 -3
  37. sqlspec/extensions/litestar/providers.py +0 -14
  38. sqlspec/loader.py +31 -118
  39. sqlspec/protocols.py +542 -0
  40. sqlspec/service/__init__.py +3 -2
  41. sqlspec/service/_util.py +147 -0
  42. sqlspec/service/base.py +1116 -9
  43. sqlspec/statement/builder/__init__.py +42 -32
  44. sqlspec/statement/builder/_ddl_utils.py +0 -10
  45. sqlspec/statement/builder/_parsing_utils.py +10 -4
  46. sqlspec/statement/builder/base.py +70 -23
  47. sqlspec/statement/builder/column.py +283 -0
  48. sqlspec/statement/builder/ddl.py +102 -65
  49. sqlspec/statement/builder/delete.py +23 -7
  50. sqlspec/statement/builder/insert.py +29 -15
  51. sqlspec/statement/builder/merge.py +4 -4
  52. sqlspec/statement/builder/mixins/_aggregate_functions.py +113 -14
  53. sqlspec/statement/builder/mixins/_common_table_expr.py +0 -1
  54. sqlspec/statement/builder/mixins/_delete_from.py +1 -1
  55. sqlspec/statement/builder/mixins/_from.py +10 -8
  56. sqlspec/statement/builder/mixins/_group_by.py +0 -1
  57. sqlspec/statement/builder/mixins/_insert_from_select.py +0 -1
  58. sqlspec/statement/builder/mixins/_insert_values.py +0 -2
  59. sqlspec/statement/builder/mixins/_join.py +20 -13
  60. sqlspec/statement/builder/mixins/_limit_offset.py +3 -3
  61. sqlspec/statement/builder/mixins/_merge_clauses.py +3 -4
  62. sqlspec/statement/builder/mixins/_order_by.py +2 -2
  63. sqlspec/statement/builder/mixins/_pivot.py +4 -7
  64. sqlspec/statement/builder/mixins/_select_columns.py +6 -5
  65. sqlspec/statement/builder/mixins/_unpivot.py +6 -9
  66. sqlspec/statement/builder/mixins/_update_from.py +2 -1
  67. sqlspec/statement/builder/mixins/_update_set.py +11 -8
  68. sqlspec/statement/builder/mixins/_where.py +61 -34
  69. sqlspec/statement/builder/select.py +32 -17
  70. sqlspec/statement/builder/update.py +25 -11
  71. sqlspec/statement/filters.py +39 -14
  72. sqlspec/statement/parameter_manager.py +220 -0
  73. sqlspec/statement/parameters.py +210 -79
  74. sqlspec/statement/pipelines/__init__.py +166 -23
  75. sqlspec/statement/pipelines/analyzers/_analyzer.py +22 -25
  76. sqlspec/statement/pipelines/context.py +35 -39
  77. sqlspec/statement/pipelines/transformers/__init__.py +2 -3
  78. sqlspec/statement/pipelines/transformers/_expression_simplifier.py +19 -187
  79. sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +667 -43
  80. sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +76 -0
  81. sqlspec/statement/pipelines/validators/_dml_safety.py +33 -18
  82. sqlspec/statement/pipelines/validators/_parameter_style.py +87 -14
  83. sqlspec/statement/pipelines/validators/_performance.py +38 -23
  84. sqlspec/statement/pipelines/validators/_security.py +39 -62
  85. sqlspec/statement/result.py +37 -129
  86. sqlspec/statement/splitter.py +0 -12
  87. sqlspec/statement/sql.py +885 -379
  88. sqlspec/statement/sql_compiler.py +140 -0
  89. sqlspec/storage/__init__.py +10 -2
  90. sqlspec/storage/backends/fsspec.py +82 -35
  91. sqlspec/storage/backends/obstore.py +66 -49
  92. sqlspec/storage/capabilities.py +101 -0
  93. sqlspec/storage/registry.py +56 -83
  94. sqlspec/typing.py +6 -434
  95. sqlspec/utils/cached_property.py +25 -0
  96. sqlspec/utils/correlation.py +0 -2
  97. sqlspec/utils/logging.py +0 -6
  98. sqlspec/utils/sync_tools.py +0 -4
  99. sqlspec/utils/text.py +0 -5
  100. sqlspec/utils/type_guards.py +892 -0
  101. {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/METADATA +1 -1
  102. sqlspec-0.13.0.dist-info/RECORD +150 -0
  103. sqlspec/statement/builder/protocols.py +0 -20
  104. sqlspec/statement/pipelines/base.py +0 -315
  105. sqlspec/statement/pipelines/result_types.py +0 -41
  106. sqlspec/statement/pipelines/transformers/_remove_comments.py +0 -66
  107. sqlspec/statement/pipelines/transformers/_remove_hints.py +0 -81
  108. sqlspec/statement/pipelines/validators/base.py +0 -67
  109. sqlspec/storage/protocol.py +0 -170
  110. sqlspec-0.12.1.dist-info/RECORD +0 -145
  111. {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/WHEEL +0 -0
  112. {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/licenses/LICENSE +0 -0
  113. {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/licenses/NOTICE +0 -0
sqlspec/statement/sql.py CHANGED
@@ -1,28 +1,63 @@
1
1
  """SQL statement handling with centralized parameter management."""
2
2
 
3
- from dataclasses import dataclass, field, replace
4
- from typing import Any, Optional, Union
3
+ import operator
4
+ from dataclasses import dataclass, field
5
+ from typing import TYPE_CHECKING, Any, Optional, Union
5
6
 
6
7
  import sqlglot
7
8
  import sqlglot.expressions as exp
8
- from sqlglot.dialects.dialect import DialectType
9
9
  from sqlglot.errors import ParseError
10
+ from typing_extensions import TypeAlias
10
11
 
11
- from sqlspec.exceptions import RiskLevel, SQLValidationError
12
+ from sqlspec.exceptions import RiskLevel, SQLParsingError, SQLValidationError
12
13
  from sqlspec.statement.filters import StatementFilter
13
- from sqlspec.statement.parameters import ParameterConverter, ParameterStyle, ParameterValidator
14
- from sqlspec.statement.pipelines.base import StatementPipeline
15
- from sqlspec.statement.pipelines.context import SQLProcessingContext
16
- from sqlspec.statement.pipelines.transformers import CommentRemover, ParameterizeLiterals
14
+ from sqlspec.statement.parameters import (
15
+ SQLGLOT_INCOMPATIBLE_STYLES,
16
+ ParameterConverter,
17
+ ParameterStyle,
18
+ ParameterValidator,
19
+ )
20
+ from sqlspec.statement.pipelines import SQLProcessingContext, StatementPipeline
21
+ from sqlspec.statement.pipelines.transformers import CommentAndHintRemover, ParameterizeLiterals
17
22
  from sqlspec.statement.pipelines.validators import DMLSafetyValidator, ParameterStyleValidator
18
- from sqlspec.typing import is_dict
19
23
  from sqlspec.utils.logging import get_logger
24
+ from sqlspec.utils.type_guards import (
25
+ can_append_to_statement,
26
+ can_extract_parameters,
27
+ has_parameter_value,
28
+ has_risk_level,
29
+ is_dict,
30
+ is_expression,
31
+ is_statement_filter,
32
+ supports_limit,
33
+ supports_offset,
34
+ supports_order_by,
35
+ supports_where,
36
+ )
37
+
38
+ if TYPE_CHECKING:
39
+ from sqlglot.dialects.dialect import DialectType
40
+
41
+ from sqlspec.statement.parameters import ParameterNormalizationState
20
42
 
21
43
  __all__ = ("SQL", "SQLConfig", "Statement")
22
44
 
23
45
  logger = get_logger("sqlspec.statement")
24
46
 
25
- Statement = Union[str, exp.Expression, "SQL"]
47
+ Statement: TypeAlias = Union[str, exp.Expression, "SQL"]
48
+
49
+ # Parameter naming constants
50
+ PARAM_PREFIX = "param_"
51
+ POS_PARAM_PREFIX = "pos_param_"
52
+ KW_POS_PARAM_PREFIX = "kw_pos_param_"
53
+ ARG_PREFIX = "arg_"
54
+
55
+ # Cache and limit constants
56
+ DEFAULT_CACHE_SIZE = 1000
57
+
58
+ # Oracle/Colon style parameter constants
59
+ COLON_PARAM_ONE = "1"
60
+ COLON_PARAM_MIN_INDEX = 1
26
61
 
27
62
 
28
63
  @dataclass
@@ -39,9 +74,30 @@ class _ProcessedState:
39
74
 
40
75
  @dataclass
41
76
  class SQLConfig:
42
- """Configuration for SQL statement behavior."""
77
+ """Configuration for SQL statement behavior.
78
+
79
+ Uses conservative defaults that prioritize compatibility and robustness
80
+ over strict enforcement, making it easier to work with diverse SQL dialects
81
+ and complex queries.
82
+
83
+ Component Lists:
84
+ transformers: Optional list of SQL transformers for explicit staging
85
+ validators: Optional list of SQL validators for explicit staging
86
+ analyzers: Optional list of SQL analyzers for explicit staging
87
+
88
+ Configuration Options:
89
+ parameter_converter: Handles parameter style conversions
90
+ parameter_validator: Validates parameter usage and styles
91
+ analysis_cache_size: Cache size for analysis results
92
+ input_sql_had_placeholders: Populated by SQL.__init__ to track original SQL state
93
+ dialect: SQL dialect to use for parsing and generation
94
+
95
+ Parameter Style Configuration:
96
+ allowed_parameter_styles: Allowed parameter styles (e.g., ('qmark', 'named_colon'))
97
+ target_parameter_style: Target parameter style for SQL generation
98
+ allow_mixed_parameter_styles: Whether to allow mixing parameter styles in same query
99
+ """
43
100
 
44
- # Behavior flags
45
101
  enable_parsing: bool = True
46
102
  enable_validation: bool = True
47
103
  enable_transformations: bool = True
@@ -49,29 +105,23 @@ class SQLConfig:
49
105
  enable_normalization: bool = True
50
106
  strict_mode: bool = False
51
107
  cache_parsed_expression: bool = True
108
+ parse_errors_as_warnings: bool = True
52
109
 
53
- # Component lists for explicit staging
54
- transformers: Optional[list[Any]] = None
55
- validators: Optional[list[Any]] = None
56
- analyzers: Optional[list[Any]] = None
110
+ transformers: "Optional[list[Any]]" = None
111
+ validators: "Optional[list[Any]]" = None
112
+ analyzers: "Optional[list[Any]]" = None
57
113
 
58
- # Other configs
59
114
  parameter_converter: ParameterConverter = field(default_factory=ParameterConverter)
60
115
  parameter_validator: ParameterValidator = field(default_factory=ParameterValidator)
61
116
  analysis_cache_size: int = 1000
62
- input_sql_had_placeholders: bool = False # Populated by SQL.__init__
63
-
64
- # Parameter style configuration
65
- allowed_parameter_styles: Optional[tuple[str, ...]] = None
66
- """Allowed parameter styles for this SQL configuration (e.g., ('qmark', 'named_colon'))."""
67
-
68
- target_parameter_style: Optional[str] = None
69
- """Target parameter style for SQL generation."""
117
+ input_sql_had_placeholders: bool = False
118
+ dialect: "Optional[DialectType]" = None
70
119
 
120
+ allowed_parameter_styles: "Optional[tuple[str, ...]]" = None
121
+ target_parameter_style: "Optional[str]" = None
71
122
  allow_mixed_parameter_styles: bool = False
72
- """Whether to allow mixing named and positional parameters in same query."""
73
123
 
74
- def validate_parameter_style(self, style: Union[ParameterStyle, str]) -> bool:
124
+ def validate_parameter_style(self, style: "Union[ParameterStyle, str]") -> bool:
75
125
  """Check if a parameter style is allowed.
76
126
 
77
127
  Args:
@@ -81,7 +131,7 @@ class SQLConfig:
81
131
  True if the style is allowed, False otherwise
82
132
  """
83
133
  if self.allowed_parameter_styles is None:
84
- return True # No restrictions
134
+ return True
85
135
  style_str = str(style)
86
136
  return style_str in self.allowed_parameter_styles
87
137
 
@@ -91,36 +141,23 @@ class SQLConfig:
91
141
  Returns:
92
142
  StatementPipeline configured with transformers, validators, and analyzers
93
143
  """
94
- # Import here to avoid circular dependencies
95
-
96
- # Create transformers based on config
97
144
  transformers = []
98
145
  if self.transformers is not None:
99
- # Use explicit transformers if provided
100
146
  transformers = list(self.transformers)
101
- # Use default transformers
102
147
  elif self.enable_transformations:
103
- # Use target_parameter_style if available, otherwise default to "?"
104
148
  placeholder_style = self.target_parameter_style or "?"
105
- transformers = [CommentRemover(), ParameterizeLiterals(placeholder_style=placeholder_style)]
149
+ transformers = [CommentAndHintRemover(), ParameterizeLiterals(placeholder_style=placeholder_style)]
106
150
 
107
- # Create validators based on config
108
151
  validators = []
109
152
  if self.validators is not None:
110
- # Use explicit validators if provided
111
153
  validators = list(self.validators)
112
- # Use default validators
113
154
  elif self.enable_validation:
114
155
  validators = [ParameterStyleValidator(fail_on_violation=self.strict_mode), DMLSafetyValidator()]
115
156
 
116
- # Create analyzers based on config
117
157
  analyzers = []
118
158
  if self.analyzers is not None:
119
- # Use explicit analyzers if provided
120
159
  analyzers = list(self.analyzers)
121
- # Use default analyzers
122
160
  elif self.enable_analysis:
123
- # Currently no default analyzers
124
161
  analyzers = []
125
162
 
126
163
  return StatementPipeline(transformers=transformers, validators=validators, analyzers=analyzers)
@@ -139,36 +176,39 @@ class SQL:
139
176
  """
140
177
 
141
178
  __slots__ = (
142
- "_builder_result_type", # Optional[type] - for query builders
143
- "_config", # SQLConfig - configuration
144
- "_dialect", # DialectType - SQL dialect
145
- "_filters", # list[StatementFilter] - filters to apply
146
- "_is_many", # bool - for executemany operations
147
- "_is_script", # bool - for script execution
148
- "_named_params", # dict[str, Any] - named parameters
149
- "_original_parameters", # Any - original parameters as passed in
150
- "_original_sql", # str - original SQL before normalization
151
- "_placeholder_mapping", # dict[str, Union[str, int]] - placeholder normalization mapping
152
- "_positional_params", # list[Any] - positional parameters
153
- "_processed_state", # Cached processed state
154
- "_processing_context", # SQLProcessingContext - context from pipeline processing
155
- "_raw_sql", # str - original SQL string for compatibility
156
- "_statement", # exp.Expression - the SQL expression
179
+ "_builder_result_type",
180
+ "_config",
181
+ "_dialect",
182
+ "_filters",
183
+ "_is_many",
184
+ "_is_script",
185
+ "_named_params",
186
+ "_original_parameters",
187
+ "_original_sql",
188
+ "_parameter_normalization_state",
189
+ "_placeholder_mapping",
190
+ "_positional_params",
191
+ "_processed_state",
192
+ "_processing_context",
193
+ "_raw_sql",
194
+ "_statement",
157
195
  )
158
196
 
159
197
  def __init__(
160
198
  self,
161
- statement: Union[str, exp.Expression, "SQL"],
162
- *parameters: Union[Any, StatementFilter, list[Union[Any, StatementFilter]]],
163
- _dialect: DialectType = None,
164
- _config: Optional[SQLConfig] = None,
165
- _builder_result_type: Optional[type] = None,
166
- _existing_state: Optional[dict[str, Any]] = None,
199
+ statement: "Union[str, exp.Expression, 'SQL']",
200
+ *parameters: "Union[Any, StatementFilter, list[Union[Any, StatementFilter]]]",
201
+ _dialect: "DialectType" = None,
202
+ _config: "Optional[SQLConfig]" = None,
203
+ _builder_result_type: "Optional[type]" = None,
204
+ _existing_state: "Optional[dict[str, Any]]" = None,
167
205
  **kwargs: Any,
168
206
  ) -> None:
169
207
  """Initialize SQL with centralized parameter management."""
208
+ if "config" in kwargs and _config is None:
209
+ _config = kwargs.pop("config")
170
210
  self._config = _config or SQLConfig()
171
- self._dialect = _dialect
211
+ self._dialect = _dialect or (self._config.dialect if self._config else None)
172
212
  self._builder_result_type = _builder_result_type
173
213
  self._processed_state: Optional[_ProcessedState] = None
174
214
  self._processing_context: Optional[SQLProcessingContext] = None
@@ -180,6 +220,7 @@ class SQL:
180
220
  self._original_parameters: Any = None
181
221
  self._original_sql: str = ""
182
222
  self._placeholder_mapping: dict[str, Union[str, int]] = {}
223
+ self._parameter_normalization_state: Optional[ParameterNormalizationState] = None
183
224
  self._is_many: bool = False
184
225
  self._is_script: bool = False
185
226
 
@@ -191,13 +232,17 @@ class SQL:
191
232
  if _existing_state:
192
233
  self._load_from_existing_state(_existing_state)
193
234
 
194
- if not isinstance(statement, SQL):
235
+ if not isinstance(statement, SQL) and not _existing_state:
195
236
  self._set_original_parameters(*parameters)
196
237
 
197
238
  self._process_parameters(*parameters, **kwargs)
198
239
 
199
240
  def _init_from_sql_object(
200
- self, statement: "SQL", dialect: DialectType, config: Optional[SQLConfig], builder_result_type: Optional[type]
241
+ self,
242
+ statement: "SQL",
243
+ dialect: "DialectType",
244
+ config: "Optional[SQLConfig]",
245
+ builder_result_type: "Optional[type]",
201
246
  ) -> None:
202
247
  """Initialize attributes from an existing SQL object."""
203
248
  self._statement = statement._statement
@@ -210,24 +255,21 @@ class SQL:
210
255
  self._original_parameters = statement._original_parameters
211
256
  self._original_sql = statement._original_sql
212
257
  self._placeholder_mapping = statement._placeholder_mapping.copy()
258
+ self._parameter_normalization_state = statement._parameter_normalization_state
213
259
  self._positional_params.extend(statement._positional_params)
214
260
  self._named_params.update(statement._named_params)
215
261
  self._filters.extend(statement._filters)
216
262
 
217
- def _init_from_str_or_expression(self, statement: Union[str, exp.Expression]) -> None:
263
+ def _init_from_str_or_expression(self, statement: "Union[str, exp.Expression]") -> None:
218
264
  """Initialize attributes from a SQL string or expression."""
219
265
  if isinstance(statement, str):
220
266
  self._raw_sql = statement
221
- if self._raw_sql and not self._config.input_sql_had_placeholders:
222
- param_info = self._config.parameter_validator.extract_parameters(self._raw_sql)
223
- if param_info:
224
- self._config = replace(self._config, input_sql_had_placeholders=True)
225
267
  self._statement = self._to_expression(statement)
226
268
  else:
227
269
  self._raw_sql = statement.sql(dialect=self._dialect) # pyright: ignore
228
270
  self._statement = statement
229
271
 
230
- def _load_from_existing_state(self, existing_state: dict[str, Any]) -> None:
272
+ def _load_from_existing_state(self, existing_state: "dict[str, Any]") -> None:
231
273
  """Load state from a dictionary (used by copy)."""
232
274
  self._positional_params = list(existing_state.get("positional_params", self._positional_params))
233
275
  self._named_params = dict(existing_state.get("named_params", self._named_params))
@@ -235,15 +277,16 @@ class SQL:
235
277
  self._is_many = existing_state.get("is_many", self._is_many)
236
278
  self._is_script = existing_state.get("is_script", self._is_script)
237
279
  self._raw_sql = existing_state.get("raw_sql", self._raw_sql)
280
+ self._original_parameters = existing_state.get("original_parameters", self._original_parameters)
238
281
 
239
282
  def _set_original_parameters(self, *parameters: Any) -> None:
240
283
  """Store the original parameters for compatibility."""
241
- if len(parameters) == 1 and not isinstance(parameters[0], StatementFilter):
284
+ if len(parameters) == 0 or (len(parameters) == 1 and is_statement_filter(parameters[0])):
285
+ self._original_parameters = None
286
+ elif len(parameters) == 1 and isinstance(parameters[0], (list, tuple)):
242
287
  self._original_parameters = parameters[0]
243
- elif len(parameters) > 1:
244
- self._original_parameters = parameters
245
288
  else:
246
- self._original_parameters = None
289
+ self._original_parameters = parameters
247
290
 
248
291
  def _process_parameters(self, *parameters: Any, **kwargs: Any) -> None:
249
292
  """Process positional and keyword arguments for parameters and filters."""
@@ -254,7 +297,7 @@ class SQL:
254
297
  param_value = kwargs.pop("parameters")
255
298
  if isinstance(param_value, (list, tuple)):
256
299
  self._positional_params.extend(param_value)
257
- elif isinstance(param_value, dict):
300
+ elif is_dict(param_value):
258
301
  self._named_params.update(param_value)
259
302
  else:
260
303
  self._positional_params.append(param_value)
@@ -265,7 +308,7 @@ class SQL:
265
308
 
266
309
  def _process_parameter_item(self, item: Any) -> None:
267
310
  """Process a single item from the parameters list."""
268
- if isinstance(item, StatementFilter):
311
+ if is_statement_filter(item):
269
312
  self._filters.append(item)
270
313
  pos_params, named_params = self._extract_filter_parameters(item)
271
314
  self._positional_params.extend(pos_params)
@@ -273,7 +316,7 @@ class SQL:
273
316
  elif isinstance(item, list):
274
317
  for sub_item in item:
275
318
  self._process_parameter_item(sub_item)
276
- elif isinstance(item, dict):
319
+ elif is_dict(item):
277
320
  self._named_params.update(item)
278
321
  elif isinstance(item, tuple):
279
322
  self._positional_params.extend(item)
@@ -289,120 +332,255 @@ class SQL:
289
332
  if self._processed_state is not None:
290
333
  return
291
334
 
292
- # Get the final expression and parameters after filters
293
335
  final_expr, final_params = self._build_final_state()
336
+ has_placeholders = self._detect_placeholders()
337
+ initial_sql_for_context, final_params = self._prepare_context_sql(final_expr, final_params)
338
+
339
+ context = self._create_processing_context(initial_sql_for_context, final_expr, final_params, has_placeholders)
340
+ result = self._run_pipeline(context)
341
+
342
+ processed_sql, merged_params = self._process_pipeline_result(result, final_params, context)
343
+
344
+ self._finalize_processed_state(result, processed_sql, merged_params)
294
345
 
295
- # Check if the raw SQL has placeholders
346
+ def _detect_placeholders(self) -> bool:
347
+ """Detect if the raw SQL has placeholders."""
296
348
  if self._raw_sql:
297
349
  validator = self._config.parameter_validator
298
350
  raw_param_info = validator.extract_parameters(self._raw_sql)
299
351
  has_placeholders = bool(raw_param_info)
300
- else:
301
- has_placeholders = self._config.input_sql_had_placeholders
352
+ if has_placeholders:
353
+ self._config.input_sql_had_placeholders = True
354
+ return has_placeholders
355
+ return self._config.input_sql_had_placeholders
356
+
357
+ def _prepare_context_sql(self, final_expr: exp.Expression, final_params: Any) -> tuple[str, Any]:
358
+ """Prepare SQL string and parameters for context."""
359
+ initial_sql_for_context = self._raw_sql or final_expr.sql(dialect=self._dialect or self._config.dialect)
360
+
361
+ if is_expression(final_expr) and self._placeholder_mapping:
362
+ initial_sql_for_context = final_expr.sql(dialect=self._dialect or self._config.dialect)
363
+ if self._placeholder_mapping:
364
+ final_params = self._normalize_parameters(final_params)
365
+
366
+ return initial_sql_for_context, final_params
367
+
368
+ def _normalize_parameters(self, final_params: Any) -> Any:
369
+ """Normalize parameters based on placeholder mapping."""
370
+ if is_dict(final_params):
371
+ normalized_params = {}
372
+ for placeholder_key, original_name in self._placeholder_mapping.items():
373
+ if str(original_name) in final_params:
374
+ normalized_params[placeholder_key] = final_params[str(original_name)]
375
+ non_oracle_params = {
376
+ key: value
377
+ for key, value in final_params.items()
378
+ if key not in {str(name) for name in self._placeholder_mapping.values()}
379
+ }
380
+ normalized_params.update(non_oracle_params)
381
+ return normalized_params
382
+ if isinstance(final_params, (list, tuple)):
383
+ validator = self._config.parameter_validator
384
+ param_info = validator.extract_parameters(self._raw_sql)
385
+
386
+ all_numeric = all(p.name and p.name.isdigit() for p in param_info)
387
+
388
+ if all_numeric:
389
+ normalized_params = {}
302
390
 
303
- # Update config if we detected placeholders
304
- if has_placeholders and not self._config.input_sql_had_placeholders:
305
- self._config = replace(self._config, input_sql_had_placeholders=True)
391
+ min_param_num = min(int(p.name) for p in param_info if p.name)
306
392
 
307
- # Create processing context
393
+ for i, param in enumerate(final_params):
394
+ param_num = str(i + min_param_num)
395
+ normalized_params[param_num] = param
396
+
397
+ return normalized_params
398
+ normalized_params = {}
399
+ for i, param in enumerate(final_params):
400
+ if i < len(param_info):
401
+ placeholder_key = f"{PARAM_PREFIX}{param_info[i].ordinal}"
402
+ normalized_params[placeholder_key] = param
403
+ return normalized_params
404
+ return final_params
405
+
406
+ def _create_processing_context(
407
+ self, initial_sql_for_context: str, final_expr: exp.Expression, final_params: Any, has_placeholders: bool
408
+ ) -> SQLProcessingContext:
409
+ """Create SQL processing context."""
308
410
  context = SQLProcessingContext(
309
- initial_sql_string=self._raw_sql or final_expr.sql(dialect=self._dialect),
310
- dialect=self._dialect,
411
+ initial_sql_string=initial_sql_for_context,
412
+ dialect=self._dialect or self._config.dialect,
311
413
  config=self._config,
312
- current_expression=final_expr,
313
414
  initial_expression=final_expr,
415
+ current_expression=final_expr,
314
416
  merged_parameters=final_params,
315
- input_sql_had_placeholders=has_placeholders,
417
+ input_sql_had_placeholders=has_placeholders or self._config.input_sql_had_placeholders,
316
418
  )
317
419
 
318
- # Extract parameter info from the SQL
420
+ if self._placeholder_mapping:
421
+ context.extra_info["placeholder_map"] = self._placeholder_mapping
422
+
423
+ # Set normalization state if available
424
+ if self._parameter_normalization_state:
425
+ context.parameter_normalization = self._parameter_normalization_state
426
+
319
427
  validator = self._config.parameter_validator
320
428
  context.parameter_info = validator.extract_parameters(context.initial_sql_string)
321
429
 
322
- # Run the pipeline
430
+ return context
431
+
432
+ def _run_pipeline(self, context: SQLProcessingContext) -> Any:
433
+ """Run the SQL processing pipeline."""
323
434
  pipeline = self._config.get_statement_pipeline()
324
435
  result = pipeline.execute_pipeline(context)
325
-
326
- # Store the processing context for later use
327
436
  self._processing_context = result.context
437
+ return result
328
438
 
329
- # Extract processed state
439
+ def _process_pipeline_result(
440
+ self, result: Any, final_params: Any, context: SQLProcessingContext
441
+ ) -> tuple[str, Any]:
442
+ """Process the result from the pipeline."""
330
443
  processed_expr = result.expression
444
+
331
445
  if isinstance(processed_expr, exp.Anonymous):
332
446
  processed_sql = self._raw_sql or context.initial_sql_string
333
447
  else:
334
- processed_sql = processed_expr.sql(dialect=self._dialect, comments=False)
448
+ processed_sql = processed_expr.sql(dialect=self._dialect or self._config.dialect, comments=False)
335
449
  logger.debug("Processed expression SQL: '%s'", processed_sql)
336
450
 
337
- # Check if we need to denormalize pyformat placeholders
338
451
  if self._placeholder_mapping and self._original_sql:
339
- # We normalized pyformat placeholders before parsing, need to denormalize
340
- original_sql = self._original_sql
341
- # Extract parameter info from the original SQL to get the original styles
342
- param_info = self._config.parameter_validator.extract_parameters(original_sql)
343
-
344
- # Find the target style (should be pyformat)
345
- from sqlspec.statement.parameters import ParameterStyle
346
-
347
- target_styles = {p.style for p in param_info}
348
- logger.debug(
349
- "Denormalizing SQL: before='%s', original='%s', styles=%s",
350
- processed_sql,
351
- original_sql,
352
- target_styles,
353
- )
354
- if ParameterStyle.POSITIONAL_PYFORMAT in target_styles:
355
- # Denormalize back to %s
356
- processed_sql = self._config.parameter_converter._denormalize_sql(
357
- processed_sql, param_info, ParameterStyle.POSITIONAL_PYFORMAT
358
- )
359
- logger.debug("Denormalized SQL to: '%s'", processed_sql)
360
- elif ParameterStyle.NAMED_PYFORMAT in target_styles:
361
- # Denormalize back to %(name)s
362
- processed_sql = self._config.parameter_converter._denormalize_sql(
363
- processed_sql, param_info, ParameterStyle.NAMED_PYFORMAT
364
- )
365
- logger.debug("Denormalized SQL to: '%s'", processed_sql)
452
+ processed_sql, result = self._denormalize_sql(processed_sql, result)
453
+
454
+ merged_params = self._merge_pipeline_parameters(result, final_params)
455
+
456
+ return processed_sql, merged_params
457
+
458
+ def _denormalize_sql(self, processed_sql: str, result: Any) -> tuple[str, Any]:
459
+ """Denormalize SQL back to original parameter style."""
460
+
461
+ original_sql = self._original_sql
462
+ param_info = self._config.parameter_validator.extract_parameters(original_sql)
463
+ target_styles = {p.style for p in param_info}
464
+
465
+ logger.debug(
466
+ "Denormalizing SQL: before='%s', original='%s', styles=%s", processed_sql, original_sql, target_styles
467
+ )
468
+
469
+ if ParameterStyle.POSITIONAL_PYFORMAT in target_styles:
470
+ processed_sql = self._config.parameter_converter._convert_sql_placeholders(
471
+ processed_sql, param_info, ParameterStyle.POSITIONAL_PYFORMAT
472
+ )
473
+ logger.debug("Denormalized SQL to: '%s'", processed_sql)
474
+ elif ParameterStyle.NAMED_PYFORMAT in target_styles:
475
+ processed_sql = self._config.parameter_converter._convert_sql_placeholders(
476
+ processed_sql, param_info, ParameterStyle.NAMED_PYFORMAT
477
+ )
478
+ logger.debug("Denormalized SQL to: '%s'", processed_sql)
479
+ # Also denormalize the parameters back to their original names
480
+ if (
481
+ self._placeholder_mapping
482
+ and result.context.merged_parameters
483
+ and is_dict(result.context.merged_parameters)
484
+ ):
485
+ result.context.merged_parameters = self._denormalize_pyformat_params(result.context.merged_parameters)
486
+ elif ParameterStyle.POSITIONAL_COLON in target_styles:
487
+ processed_param_info = self._config.parameter_validator.extract_parameters(processed_sql)
488
+ has_param_placeholders = any(p.name and p.name.startswith(PARAM_PREFIX) for p in processed_param_info)
489
+
490
+ if has_param_placeholders:
491
+ logger.debug("Skipping denormalization for param_N placeholders")
366
492
  else:
367
- logger.debug(
368
- "No denormalization needed: mapping=%s, original=%s",
369
- bool(self._placeholder_mapping),
370
- bool(self._original_sql),
493
+ processed_sql = self._config.parameter_converter._convert_sql_placeholders(
494
+ processed_sql, param_info, ParameterStyle.POSITIONAL_COLON
371
495
  )
496
+ logger.debug("Denormalized SQL to: '%s'", processed_sql)
497
+ if (
498
+ self._placeholder_mapping
499
+ and result.context.merged_parameters
500
+ and is_dict(result.context.merged_parameters)
501
+ ):
502
+ result.context.merged_parameters = self._denormalize_colon_params(result.context.merged_parameters)
503
+ else:
504
+ logger.debug(
505
+ "No denormalization needed: mapping=%s, original=%s",
506
+ bool(self._placeholder_mapping),
507
+ bool(self._original_sql),
508
+ )
509
+
510
+ return processed_sql, result
372
511
 
373
- # Merge parameters from pipeline
374
- merged_params = final_params
375
- # Only merge extracted parameters if the original SQL didn't have placeholders
376
- # If it already had placeholders, the parameters should already be provided
377
- if result.context.extracted_parameters_from_pipeline and not context.input_sql_had_placeholders:
378
- if isinstance(merged_params, dict):
379
- for i, param in enumerate(result.context.extracted_parameters_from_pipeline):
380
- param_name = f"param_{i}"
381
- merged_params[param_name] = param
382
- elif isinstance(merged_params, list):
383
- merged_params.extend(result.context.extracted_parameters_from_pipeline)
384
- elif merged_params is None:
512
+ def _denormalize_colon_params(self, params: "dict[str, Any]") -> "dict[str, Any]":
513
+ """Denormalize colon-style parameters back to numeric format."""
514
+ # For positional colon style, all params should have numeric keys
515
+ # Just return the params as-is if they already have the right format
516
+ if all(key.isdigit() for key in params):
517
+ return params
518
+
519
+ # For positional colon, we need ALL parameters in the final result
520
+ # This includes both user parameters and extracted literals
521
+ # We should NOT filter out extracted parameters (param_0, param_1, etc)
522
+ # because they need to be included in the final parameter conversion
523
+ return params
524
+
525
+ def _denormalize_pyformat_params(self, params: "dict[str, Any]") -> "dict[str, Any]":
526
+ """Denormalize pyformat parameters back to their original names."""
527
+ denormalized_params = {}
528
+ for placeholder_key, original_name in self._placeholder_mapping.items():
529
+ if placeholder_key in params:
530
+ # For pyformat, the original_name is the actual parameter name (e.g., 'max_value')
531
+ denormalized_params[str(original_name)] = params[placeholder_key]
532
+ # Include any parameters that weren't normalized
533
+ non_normalized_params = {key: value for key, value in params.items() if not key.startswith(PARAM_PREFIX)}
534
+ denormalized_params.update(non_normalized_params)
535
+ return denormalized_params
536
+
537
+ def _merge_pipeline_parameters(self, result: Any, final_params: Any) -> Any:
538
+ """Merge parameters from the pipeline processing."""
539
+ merged_params = result.context.merged_parameters
540
+
541
+ # If we have extracted parameters from the pipeline, only merge them if:
542
+ # 1. We don't already have parameters in merged_params, OR
543
+ # 2. The original params were None and we need to use the extracted ones
544
+ if result.context.extracted_parameters_from_pipeline:
545
+ if merged_params is None:
546
+ # No existing parameters - use the extracted ones
385
547
  merged_params = result.context.extracted_parameters_from_pipeline
386
- else:
387
- # Single value, convert to list
388
- merged_params = [merged_params, *list(result.context.extracted_parameters_from_pipeline)]
548
+ elif merged_params == final_params and final_params is None:
549
+ # Both are None, use extracted parameters
550
+ merged_params = result.context.extracted_parameters_from_pipeline
551
+ elif merged_params != result.context.extracted_parameters_from_pipeline:
552
+ # Only merge if the extracted parameters are different from what we already have
553
+ # This prevents the duplication issue where the same parameters get added twice
554
+ if is_dict(merged_params):
555
+ for i, param in enumerate(result.context.extracted_parameters_from_pipeline):
556
+ param_name = f"{PARAM_PREFIX}{i}"
557
+ merged_params[param_name] = param
558
+ elif isinstance(merged_params, (list, tuple)):
559
+ # Only extend if we don't already have these parameters
560
+ # Convert to list and extend with extracted parameters
561
+ if isinstance(merged_params, tuple):
562
+ merged_params = list(merged_params)
563
+ merged_params.extend(result.context.extracted_parameters_from_pipeline)
564
+ else:
565
+ # Single parameter case - convert to list with original + extracted
566
+ merged_params = [merged_params, *list(result.context.extracted_parameters_from_pipeline)]
389
567
 
390
- # Cache the processed state
568
+ return merged_params
569
+
570
+ def _finalize_processed_state(self, result: Any, processed_sql: str, merged_params: Any) -> None:
571
+ """Finalize the processed state."""
391
572
  self._processed_state = _ProcessedState(
392
- processed_expression=processed_expr,
573
+ processed_expression=result.expression,
393
574
  processed_sql=processed_sql,
394
575
  merged_parameters=merged_params,
395
576
  validation_errors=list(result.context.validation_errors),
396
- analysis_results={}, # Can be populated from analysis_findings if needed
397
- transformation_results={}, # Can be populated from transformations if needed
577
+ analysis_results={},
578
+ transformation_results={},
398
579
  )
399
580
 
400
- # Check strict mode
401
581
  if self._config.strict_mode and self._processed_state.validation_errors:
402
- # Find the highest risk error
403
582
  highest_risk_error = max(
404
- self._processed_state.validation_errors,
405
- key=lambda e: e.risk_level.value if hasattr(e, "risk_level") else 0,
583
+ self._processed_state.validation_errors, key=lambda e: e.risk_level.value if has_risk_level(e) else 0
406
584
  )
407
585
  raise SQLValidationError(
408
586
  message=highest_risk_error.message,
@@ -410,81 +588,85 @@ class SQL:
410
588
  risk_level=getattr(highest_risk_error, "risk_level", RiskLevel.HIGH),
411
589
  )
412
590
 
413
- def _to_expression(self, statement: Union[str, exp.Expression]) -> exp.Expression:
591
+ def _to_expression(self, statement: "Union[str, exp.Expression]") -> exp.Expression:
414
592
  """Convert string to sqlglot expression."""
415
- if isinstance(statement, exp.Expression):
593
+ if is_expression(statement):
416
594
  return statement
417
595
 
418
- # Handle empty string
419
- if not statement or not statement.strip():
420
- # Return an empty select instead of Anonymous for empty strings
596
+ if not statement or (isinstance(statement, str) and not statement.strip()):
421
597
  return exp.Select()
422
598
 
423
- # Check if parsing is disabled
424
599
  if not self._config.enable_parsing:
425
- # Return an anonymous expression that preserves the raw SQL
426
600
  return exp.Anonymous(this=statement)
427
601
 
428
- # Check if SQL contains pyformat placeholders that need normalization
429
- from sqlspec.statement.parameters import ParameterStyle
430
-
602
+ if not isinstance(statement, str):
603
+ return exp.Anonymous(this="")
431
604
  validator = self._config.parameter_validator
432
605
  param_info = validator.extract_parameters(statement)
433
606
 
434
- # Check if we have pyformat placeholders
435
- has_pyformat = any(
436
- p.style in {ParameterStyle.POSITIONAL_PYFORMAT, ParameterStyle.NAMED_PYFORMAT} for p in param_info
437
- )
607
+ # Check if normalization is needed
608
+ needs_normalization = any(p.style in SQLGLOT_INCOMPATIBLE_STYLES for p in param_info)
438
609
 
439
610
  normalized_sql = statement
440
611
  placeholder_mapping: dict[str, Any] = {}
441
612
 
442
- if has_pyformat:
443
- # Normalize pyformat placeholders to named placeholders for SQLGlot
613
+ if needs_normalization:
444
614
  converter = self._config.parameter_converter
445
615
  normalized_sql, placeholder_mapping = converter._transform_sql_for_parsing(statement, param_info)
446
- # Store the original SQL before normalization
447
616
  self._original_sql = statement
448
617
  self._placeholder_mapping = placeholder_mapping
449
618
 
619
+ # Create normalization state
620
+ from sqlspec.statement.parameters import ParameterNormalizationState
621
+
622
+ self._parameter_normalization_state = ParameterNormalizationState(
623
+ was_normalized=True,
624
+ original_styles=list({p.style for p in param_info}),
625
+ normalized_style=ParameterStyle.NAMED_COLON,
626
+ placeholder_map=placeholder_mapping,
627
+ original_param_info=param_info,
628
+ )
629
+ else:
630
+ self._parameter_normalization_state = None
631
+
450
632
  try:
451
- # Parse with sqlglot
452
633
  expressions = sqlglot.parse(normalized_sql, dialect=self._dialect) # pyright: ignore
453
634
  if not expressions:
454
- # Empty statement
455
635
  return exp.Anonymous(this=statement)
456
636
  first_expr = expressions[0]
457
637
  if first_expr is None:
458
- # Could not parse
459
638
  return exp.Anonymous(this=statement)
460
639
 
461
640
  except ParseError as e:
462
- # If parsing fails, wrap in a RawString expression
463
- logger.debug("Failed to parse SQL: %s", e)
464
- return exp.Anonymous(this=statement)
641
+ if getattr(self._config, "parse_errors_as_warnings", False):
642
+ logger.warning(
643
+ "Failed to parse SQL, returning Anonymous expression.", extra={"sql": statement, "error": str(e)}
644
+ )
645
+ return exp.Anonymous(this=statement)
646
+
647
+ msg = f"Failed to parse SQL: {statement}"
648
+ raise SQLParsingError(msg) from e
465
649
  return first_expr
466
650
 
467
651
  @staticmethod
468
652
  def _extract_filter_parameters(filter_obj: StatementFilter) -> tuple[list[Any], dict[str, Any]]:
469
653
  """Extract parameters from a filter object."""
470
- if hasattr(filter_obj, "extract_parameters"):
654
+ if can_extract_parameters(filter_obj):
471
655
  return filter_obj.extract_parameters()
472
- # Fallback for filters that don't implement the new method yet
473
656
  return [], {}
474
657
 
475
658
  def copy(
476
659
  self,
477
- statement: Optional[Union[str, exp.Expression]] = None,
478
- parameters: Optional[Any] = None,
479
- dialect: DialectType = None,
480
- config: Optional[SQLConfig] = None,
660
+ statement: "Optional[Union[str, exp.Expression]]" = None,
661
+ parameters: "Optional[Any]" = None,
662
+ dialect: "DialectType" = None,
663
+ config: "Optional[SQLConfig]" = None,
481
664
  **kwargs: Any,
482
665
  ) -> "SQL":
483
666
  """Create a copy with optional modifications.
484
667
 
485
668
  This is the primary method for creating modified SQL objects.
486
669
  """
487
- # Prepare existing state
488
670
  existing_state = {
489
671
  "positional_params": list(self._positional_params),
490
672
  "named_params": dict(self._named_params),
@@ -493,25 +675,22 @@ class SQL:
493
675
  "is_script": self._is_script,
494
676
  "raw_sql": self._raw_sql,
495
677
  }
678
+ existing_state["original_parameters"] = self._original_parameters
496
679
 
497
- # Create new instance
498
680
  new_statement = statement if statement is not None else self._statement
499
681
  new_dialect = dialect if dialect is not None else self._dialect
500
682
  new_config = config if config is not None else self._config
501
683
 
502
- # If parameters are explicitly provided, they replace existing ones
503
684
  if parameters is not None:
504
- # Clear existing state so only new parameters are used
505
685
  existing_state["positional_params"] = []
506
686
  existing_state["named_params"] = {}
507
- # Pass parameters through normal processing
508
687
  return SQL(
509
688
  new_statement,
510
689
  parameters,
511
690
  _dialect=new_dialect,
512
691
  _config=new_config,
513
692
  _builder_result_type=self._builder_result_type,
514
- _existing_state=None, # Don't use existing state
693
+ _existing_state=None,
515
694
  **kwargs,
516
695
  )
517
696
 
@@ -524,14 +703,14 @@ class SQL:
524
703
  **kwargs,
525
704
  )
526
705
 
527
- def add_named_parameter(self, name: str, value: Any) -> "SQL":
706
+ def add_named_parameter(self, name: "str", value: Any) -> "SQL":
528
707
  """Add a named parameter and return a new SQL instance."""
529
708
  new_obj = self.copy()
530
709
  new_obj._named_params[name] = value
531
710
  return new_obj
532
711
 
533
712
  def get_unique_parameter_name(
534
- self, base_name: str, namespace: Optional[str] = None, preserve_original: bool = False
713
+ self, base_name: "str", namespace: "Optional[str]" = None, preserve_original: bool = False
535
714
  ) -> str:
536
715
  """Generate a unique parameter name.
537
716
 
@@ -543,21 +722,16 @@ class SQL:
543
722
  Returns:
544
723
  A unique parameter name
545
724
  """
546
- # Check both positional and named params
547
725
  all_param_names = set(self._named_params.keys())
548
726
 
549
- # Build the candidate name
550
727
  candidate = f"{namespace}_{base_name}" if namespace else base_name
551
728
 
552
- # If preserve_original and the name is unique, use it
553
729
  if preserve_original and candidate not in all_param_names:
554
730
  return candidate
555
731
 
556
- # If not preserving or name exists, generate unique name
557
732
  if candidate not in all_param_names:
558
733
  return candidate
559
734
 
560
- # Generate unique name with counter
561
735
  counter = 1
562
736
  while True:
563
737
  new_candidate = f"{candidate}_{counter}"
@@ -567,24 +741,19 @@ class SQL:
567
741
 
568
742
  def where(self, condition: "Union[str, exp.Expression, exp.Condition]") -> "SQL":
569
743
  """Apply WHERE clause and return new SQL instance."""
570
- # Convert condition to expression
571
744
  condition_expr = self._to_expression(condition) if isinstance(condition, str) else condition
572
745
 
573
- # Apply WHERE to statement
574
- if hasattr(self._statement, "where"):
746
+ if supports_where(self._statement):
575
747
  new_statement = self._statement.where(condition_expr) # pyright: ignore
576
748
  else:
577
- # Wrap in SELECT if needed
578
749
  new_statement = exp.Select().from_(self._statement).where(condition_expr) # pyright: ignore
579
750
 
580
751
  return self.copy(statement=new_statement)
581
752
 
582
753
  def filter(self, filter_obj: StatementFilter) -> "SQL":
583
754
  """Apply a filter and return a new SQL instance."""
584
- # Create a new SQL object with the filter added
585
755
  new_obj = self.copy()
586
756
  new_obj._filters.append(filter_obj)
587
- # Extract filter parameters
588
757
  pos_params, named_params = self._extract_filter_parameters(filter_obj)
589
758
  new_obj._positional_params.extend(pos_params)
590
759
  new_obj._named_params.update(named_params)
@@ -595,10 +764,9 @@ class SQL:
595
764
  new_obj = self.copy()
596
765
  new_obj._is_many = True
597
766
  if parameters is not None:
598
- # Replace parameters for executemany
599
767
  new_obj._positional_params = []
600
768
  new_obj._named_params = {}
601
- new_obj._positional_params = parameters
769
+ new_obj._original_parameters = parameters
602
770
  return new_obj
603
771
 
604
772
  def as_script(self) -> "SQL":
@@ -609,77 +777,82 @@ class SQL:
609
777
 
610
778
  def _build_final_state(self) -> tuple[exp.Expression, Any]:
611
779
  """Build final expression and parameters after applying filters."""
612
- # Start with current statement
613
780
  final_expr = self._statement
614
781
 
615
- # Apply all filters to the expression
616
782
  for filter_obj in self._filters:
617
- if hasattr(filter_obj, "append_to_statement"):
783
+ if can_append_to_statement(filter_obj):
618
784
  temp_sql = SQL(final_expr, config=self._config, dialect=self._dialect)
619
785
  temp_sql._positional_params = list(self._positional_params)
620
786
  temp_sql._named_params = dict(self._named_params)
621
787
  result = filter_obj.append_to_statement(temp_sql)
622
788
  final_expr = result._statement if isinstance(result, SQL) else result
623
789
 
624
- # Determine final parameters format
625
790
  final_params: Any
626
791
  if self._named_params and not self._positional_params:
627
- # Only named params
628
792
  final_params = dict(self._named_params)
629
793
  elif self._positional_params and not self._named_params:
630
- # Always return a list for positional params to maintain sequence type
631
794
  final_params = list(self._positional_params)
632
795
  elif self._positional_params and self._named_params:
633
- # Mixed - merge into dict
634
796
  final_params = dict(self._named_params)
635
- # Add positional params with generated names
636
797
  for i, param in enumerate(self._positional_params):
637
798
  param_name = f"arg_{i}"
638
799
  while param_name in final_params:
639
800
  param_name = f"arg_{i}_{id(param)}"
640
801
  final_params[param_name] = param
641
802
  else:
642
- # No parameters
643
803
  final_params = None
644
804
 
645
805
  return final_expr, final_params
646
806
 
647
- # Properties for compatibility
648
807
  @property
649
808
  def sql(self) -> str:
650
809
  """Get SQL string."""
651
- # Handle empty string case
652
810
  if not self._raw_sql or (self._raw_sql and not self._raw_sql.strip()):
653
811
  return ""
654
812
 
655
- # For scripts, always return the raw SQL to preserve multi-statement scripts
656
813
  if self._is_script and self._raw_sql:
657
814
  return self._raw_sql
658
- # If parsing is disabled, return the raw SQL
659
815
  if not self._config.enable_parsing and self._raw_sql:
660
816
  return self._raw_sql
661
817
 
662
- # Ensure processed
663
818
  self._ensure_processed()
664
- assert self._processed_state is not None
819
+ if self._processed_state is None:
820
+ msg = "Failed to process SQL statement"
821
+ raise RuntimeError(msg)
665
822
  return self._processed_state.processed_sql
666
823
 
667
824
  @property
668
- def expression(self) -> Optional[exp.Expression]:
825
+ def expression(self) -> "Optional[exp.Expression]":
669
826
  """Get the final expression."""
670
- # Return None if parsing is disabled
671
827
  if not self._config.enable_parsing:
672
828
  return None
673
829
  self._ensure_processed()
674
- assert self._processed_state is not None
830
+ if self._processed_state is None:
831
+ msg = "Failed to process SQL statement"
832
+ raise RuntimeError(msg)
675
833
  return self._processed_state.processed_expression
676
834
 
677
835
  @property
678
836
  def parameters(self) -> Any:
679
837
  """Get merged parameters."""
838
+ if self._is_many and self._original_parameters is not None:
839
+ return self._original_parameters
840
+
841
+ if (
842
+ self._original_parameters is not None
843
+ and isinstance(self._original_parameters, tuple)
844
+ and not self._named_params
845
+ ):
846
+ return self._original_parameters
847
+
680
848
  self._ensure_processed()
681
- assert self._processed_state is not None
682
- return self._processed_state.merged_parameters
849
+ if self._processed_state is None:
850
+ msg = "Failed to process SQL statement"
851
+ raise RuntimeError(msg)
852
+ params = self._processed_state.merged_parameters
853
+ if params is None:
854
+ return {}
855
+ return params
683
856
 
684
857
  @property
685
858
  def is_many(self) -> bool:
@@ -691,56 +864,173 @@ class SQL:
691
864
  """Check if this is a script."""
692
865
  return self._is_script
693
866
 
694
- def to_sql(self, placeholder_style: Optional[str] = None) -> str:
867
+ @property
868
+ def dialect(self) -> "Optional[DialectType]":
869
+ """Get the SQL dialect."""
870
+ return self._dialect
871
+
872
+ def to_sql(self, placeholder_style: "Optional[str]" = None) -> "str":
695
873
  """Convert to SQL string with given placeholder style."""
696
874
  if self._is_script:
697
875
  return self.sql
698
876
  sql, _ = self.compile(placeholder_style=placeholder_style)
699
877
  return sql
700
878
 
701
- def get_parameters(self, style: Optional[str] = None) -> Any:
879
+ def get_parameters(self, style: "Optional[str]" = None) -> Any:
702
880
  """Get parameters in the requested style."""
703
- # Get compiled parameters with style
704
881
  _, params = self.compile(placeholder_style=style)
705
882
  return params
706
883
 
707
- def compile(self, placeholder_style: Optional[str] = None) -> tuple[str, Any]:
884
+ def _compile_execute_many(self, placeholder_style: "Optional[str]") -> "tuple[str, Any]":
885
+ """Handle compilation for execute_many operations."""
886
+ sql = self.sql
887
+
888
+ self._ensure_processed()
889
+
890
+ params = self._original_parameters
891
+
892
+ extracted_params = self._get_extracted_parameters()
893
+
894
+ if extracted_params:
895
+ params = self._merge_extracted_params_with_sets(params, extracted_params)
896
+
897
+ if placeholder_style:
898
+ sql, params = self._convert_placeholder_style(sql, params, placeholder_style)
899
+
900
+ return sql, params
901
+
902
+ def _get_extracted_parameters(self) -> "list[Any]":
903
+ """Get extracted parameters from pipeline processing."""
904
+ extracted_params = []
905
+ if self._processed_state and self._processed_state.merged_parameters:
906
+ merged = self._processed_state.merged_parameters
907
+ if isinstance(merged, list):
908
+ if merged and not isinstance(merged[0], (tuple, list)):
909
+ extracted_params = merged
910
+ elif self._processing_context and self._processing_context.extracted_parameters_from_pipeline:
911
+ extracted_params = self._processing_context.extracted_parameters_from_pipeline
912
+ return extracted_params
913
+
914
+ def _merge_extracted_params_with_sets(self, params: Any, extracted_params: "list[Any]") -> "list[tuple[Any, ...]]":
915
+ """Merge extracted parameters with each parameter set."""
916
+ enhanced_params = []
917
+ for param_set in params:
918
+ if isinstance(param_set, (list, tuple)):
919
+ extracted_values = []
920
+ for extracted in extracted_params:
921
+ if has_parameter_value(extracted):
922
+ extracted_values.append(extracted.value)
923
+ else:
924
+ extracted_values.append(extracted)
925
+ enhanced_set = list(param_set) + extracted_values
926
+ enhanced_params.append(tuple(enhanced_set))
927
+ else:
928
+ extracted_values = []
929
+ for extracted in extracted_params:
930
+ if has_parameter_value(extracted):
931
+ extracted_values.append(extracted.value)
932
+ else:
933
+ extracted_values.append(extracted)
934
+ enhanced_params.append((param_set, *extracted_values))
935
+ return enhanced_params
936
+
937
+ def compile(self, placeholder_style: "Optional[str]" = None) -> "tuple[str, Any]":
708
938
  """Compile to SQL and parameters."""
709
- # For scripts, return raw SQL directly without processing
710
939
  if self._is_script:
711
940
  return self.sql, None
712
941
 
713
- # If parsing is disabled, return raw SQL without transformation
942
+ if self._is_many and self._original_parameters is not None:
943
+ return self._compile_execute_many(placeholder_style)
944
+
714
945
  if not self._config.enable_parsing and self._raw_sql:
715
946
  return self._raw_sql, self._raw_parameters
716
947
 
717
- # Ensure processed
718
948
  self._ensure_processed()
719
949
 
720
- # Get processed SQL and parameters
721
- assert self._processed_state is not None
950
+ if self._processed_state is None:
951
+ msg = "Failed to process SQL statement"
952
+ raise RuntimeError(msg)
722
953
  sql = self._processed_state.processed_sql
723
954
  params = self._processed_state.merged_parameters
724
955
 
725
- # Check if parameters were reordered during processing
726
- if params is not None and hasattr(self, "_processing_context") and self._processing_context:
956
+ if params is not None and self._processing_context:
727
957
  parameter_mapping = self._processing_context.metadata.get("parameter_position_mapping")
728
958
  if parameter_mapping:
729
- # Apply parameter reordering based on the mapping
730
959
  params = self._reorder_parameters(params, parameter_mapping)
731
960
 
732
- # If no placeholder style requested, return as-is
961
+ # Handle denormalization if needed
962
+ if self._processing_context and self._processing_context.parameter_normalization:
963
+ norm_state = self._processing_context.parameter_normalization
964
+
965
+ # If original SQL had incompatible styles, denormalize back to the original style
966
+ # when no specific style requested OR when the requested style matches the original
967
+ if norm_state.was_normalized and norm_state.original_styles:
968
+ original_style = norm_state.original_styles[0]
969
+ should_denormalize = placeholder_style is None or (
970
+ placeholder_style and ParameterStyle(placeholder_style) == original_style
971
+ )
972
+
973
+ if should_denormalize and original_style in SQLGLOT_INCOMPATIBLE_STYLES:
974
+ # Denormalize SQL back to original style
975
+ sql = self._config.parameter_converter._convert_sql_placeholders(
976
+ sql, norm_state.original_param_info, original_style
977
+ )
978
+ # Also denormalize parameters if needed
979
+ if original_style == ParameterStyle.POSITIONAL_COLON and is_dict(params):
980
+ params = self._denormalize_colon_params(params)
981
+
982
+ params = self._unwrap_typed_parameters(params)
983
+
733
984
  if placeholder_style is None:
734
985
  return sql, params
735
986
 
736
- # Convert to requested placeholder style
737
987
  if placeholder_style:
738
- sql, params = self._convert_placeholder_style(sql, params, placeholder_style)
988
+ sql, params = self._apply_placeholder_style(sql, params, placeholder_style)
739
989
 
740
- # Debug log the final SQL
741
- logger.debug("Final compiled SQL: '%s'", sql)
742
990
  return sql, params
743
991
 
992
+ def _apply_placeholder_style(self, sql: "str", params: Any, placeholder_style: "str") -> "tuple[str, Any]":
993
+ """Apply placeholder style conversion to SQL and parameters."""
994
+ # Just use the params passed in - they've already been processed
995
+ sql, params = self._convert_placeholder_style(sql, params, placeholder_style)
996
+ return sql, params
997
+
998
+ @staticmethod
999
+ def _unwrap_typed_parameters(params: Any) -> Any:
1000
+ """Unwrap TypedParameter objects to their actual values.
1001
+
1002
+ Args:
1003
+ params: Parameters that may contain TypedParameter objects
1004
+
1005
+ Returns:
1006
+ Parameters with TypedParameter objects unwrapped to their values
1007
+ """
1008
+ if params is None:
1009
+ return None
1010
+
1011
+ if is_dict(params):
1012
+ unwrapped_dict = {}
1013
+ for key, value in params.items():
1014
+ if has_parameter_value(value):
1015
+ unwrapped_dict[key] = value.value
1016
+ else:
1017
+ unwrapped_dict[key] = value
1018
+ return unwrapped_dict
1019
+
1020
+ if isinstance(params, (list, tuple)):
1021
+ unwrapped_list = []
1022
+ for value in params:
1023
+ if has_parameter_value(value):
1024
+ unwrapped_list.append(value.value)
1025
+ else:
1026
+ unwrapped_list.append(value)
1027
+ return type(params)(unwrapped_list)
1028
+
1029
+ if has_parameter_value(params):
1030
+ return params.value
1031
+
1032
+ return params
1033
+
744
1034
  @staticmethod
745
1035
  def _reorder_parameters(params: Any, mapping: dict[int, int]) -> Any:
746
1036
  """Reorder parameters based on the position mapping.
@@ -753,43 +1043,34 @@ class SQL:
753
1043
  Reordered parameters in the same format as input
754
1044
  """
755
1045
  if isinstance(params, (list, tuple)):
756
- # Create a new list with reordered parameters
757
1046
  reordered_list = [None] * len(params) # pyright: ignore
758
1047
  for new_pos, old_pos in mapping.items():
759
1048
  if old_pos < len(params):
760
1049
  reordered_list[new_pos] = params[old_pos] # pyright: ignore
761
1050
 
762
- # Handle any unmapped positions
763
1051
  for i, val in enumerate(reordered_list):
764
1052
  if val is None and i < len(params) and i not in mapping:
765
- # If position wasn't mapped, try to use original
766
1053
  reordered_list[i] = params[i] # pyright: ignore
767
1054
 
768
- # Return in same format as input
769
1055
  return tuple(reordered_list) if isinstance(params, tuple) else reordered_list
770
1056
 
771
- if isinstance(params, dict):
772
- # For dict parameters, we need to handle differently
773
- # If keys are like param_0, param_1, we can reorder them
774
- if all(key.startswith("param_") and key[6:].isdigit() for key in params):
1057
+ if is_dict(params):
1058
+ if all(key.startswith(PARAM_PREFIX) and key[len(PARAM_PREFIX) :].isdigit() for key in params):
775
1059
  reordered_dict: dict[str, Any] = {}
776
1060
  for new_pos, old_pos in mapping.items():
777
- old_key = f"param_{old_pos}"
778
- new_key = f"param_{new_pos}"
1061
+ old_key = f"{PARAM_PREFIX}{old_pos}"
1062
+ new_key = f"{PARAM_PREFIX}{new_pos}"
779
1063
  if old_key in params:
780
1064
  reordered_dict[new_key] = params[old_key]
781
1065
 
782
- # Add any unmapped parameters
783
1066
  for key, value in params.items():
784
- if key not in reordered_dict and key.startswith("param_"):
1067
+ if key not in reordered_dict and key.startswith(PARAM_PREFIX):
785
1068
  idx = int(key[6:])
786
1069
  if idx not in mapping:
787
1070
  reordered_dict[key] = value
788
1071
 
789
1072
  return reordered_dict
790
- # Can't reorder named parameters, return as-is
791
1073
  return params
792
- # Single value or unknown format, return as-is
793
1074
  return params
794
1075
 
795
1076
  def _convert_placeholder_style(self, sql: str, params: Any, placeholder_style: str) -> tuple[str, Any]:
@@ -803,27 +1084,119 @@ class SQL:
803
1084
  Returns:
804
1085
  Tuple of (converted_sql, converted_params)
805
1086
  """
806
- # Extract parameter info from current SQL
1087
+ if self._is_many and isinstance(params, list) and params and isinstance(params[0], (list, tuple)):
1088
+ converter = self._config.parameter_converter
1089
+ param_info = converter.validator.extract_parameters(sql)
1090
+
1091
+ if param_info:
1092
+ target_style = (
1093
+ ParameterStyle(placeholder_style) if isinstance(placeholder_style, str) else placeholder_style
1094
+ )
1095
+ sql = self._replace_placeholders_in_sql(sql, param_info, target_style)
1096
+
1097
+ return sql, params
1098
+
807
1099
  converter = self._config.parameter_converter
808
- param_info = converter.validator.extract_parameters(sql)
1100
+
1101
+ # For POSITIONAL_COLON style, use original parameter info if available to preserve numeric identifiers
1102
+ target_style = ParameterStyle(placeholder_style) if isinstance(placeholder_style, str) else placeholder_style
1103
+ if (
1104
+ target_style == ParameterStyle.POSITIONAL_COLON
1105
+ and self._processing_context
1106
+ and self._processing_context.parameter_normalization
1107
+ and self._processing_context.parameter_normalization.original_param_info
1108
+ ):
1109
+ param_info = self._processing_context.parameter_normalization.original_param_info
1110
+ else:
1111
+ param_info = converter.validator.extract_parameters(sql)
1112
+
1113
+ # CRITICAL FIX: For POSITIONAL_COLON, we need to ensure param_info reflects
1114
+ # all placeholders in the current SQL, not just the original ones.
1115
+ # This handles cases where transformers (like ParameterizeLiterals) add new placeholders.
1116
+ if target_style == ParameterStyle.POSITIONAL_COLON and param_info:
1117
+ # Re-extract from current SQL to get all placeholders
1118
+ current_param_info = converter.validator.extract_parameters(sql)
1119
+ if len(current_param_info) > len(param_info):
1120
+ # More placeholders in current SQL means transformers added some
1121
+ # Use the current info to ensure all placeholders get parameters
1122
+ param_info = current_param_info
809
1123
 
810
1124
  if not param_info:
811
1125
  return sql, params
812
1126
 
813
- # Use the internal denormalize method to convert to target style
814
- from sqlspec.statement.parameters import ParameterStyle
1127
+ if target_style == ParameterStyle.STATIC:
1128
+ return self._embed_static_parameters(sql, params, param_info)
815
1129
 
816
- target_style = ParameterStyle(placeholder_style) if isinstance(placeholder_style, str) else placeholder_style
1130
+ if param_info and all(p.style == target_style for p in param_info):
1131
+ converted_params = self._convert_parameters_format(params, param_info, target_style)
1132
+ return sql, converted_params
817
1133
 
818
- # Replace placeholders in SQL
819
1134
  sql = self._replace_placeholders_in_sql(sql, param_info, target_style)
820
1135
 
821
- # Convert parameters to appropriate format
822
1136
  params = self._convert_parameters_format(params, param_info, target_style)
823
1137
 
824
1138
  return sql, params
825
1139
 
826
- def _replace_placeholders_in_sql(self, sql: str, param_info: list[Any], target_style: "ParameterStyle") -> str:
1140
+ def _embed_static_parameters(self, sql: str, params: Any, param_info: list[Any]) -> tuple[str, Any]:
1141
+ """Embed parameter values directly into SQL for STATIC style.
1142
+
1143
+ This is used for scripts and other cases where parameters need to be
1144
+ embedded directly in the SQL string rather than passed separately.
1145
+
1146
+ Args:
1147
+ sql: The SQL string with placeholders
1148
+ params: The parameter values
1149
+ param_info: List of parameter information from extraction
1150
+
1151
+ Returns:
1152
+ Tuple of (sql_with_embedded_values, None)
1153
+ """
1154
+ param_list: list[Any] = []
1155
+ if is_dict(params):
1156
+ for p in param_info:
1157
+ if p.name and p.name in params:
1158
+ param_list.append(params[p.name])
1159
+ elif f"{PARAM_PREFIX}{p.ordinal}" in params:
1160
+ param_list.append(params[f"{PARAM_PREFIX}{p.ordinal}"])
1161
+ elif f"arg_{p.ordinal}" in params:
1162
+ param_list.append(params[f"arg_{p.ordinal}"])
1163
+ else:
1164
+ param_list.append(params.get(str(p.ordinal), None))
1165
+ elif isinstance(params, (list, tuple)):
1166
+ param_list = list(params)
1167
+ elif params is not None:
1168
+ param_list = [params]
1169
+
1170
+ sorted_params = sorted(param_info, key=lambda p: p.position, reverse=True)
1171
+
1172
+ for p in sorted_params:
1173
+ if p.ordinal < len(param_list):
1174
+ value = param_list[p.ordinal]
1175
+
1176
+ if has_parameter_value(value):
1177
+ value = value.value
1178
+
1179
+ if value is None:
1180
+ literal_str = "NULL"
1181
+ elif isinstance(value, bool):
1182
+ literal_str = "TRUE" if value else "FALSE"
1183
+ elif isinstance(value, str):
1184
+ literal_expr = sqlglot.exp.Literal.string(value)
1185
+ literal_str = literal_expr.sql(dialect=self._dialect)
1186
+ elif isinstance(value, (int, float)):
1187
+ literal_expr = sqlglot.exp.Literal.number(value)
1188
+ literal_str = literal_expr.sql(dialect=self._dialect)
1189
+ else:
1190
+ literal_expr = sqlglot.exp.Literal.string(str(value))
1191
+ literal_str = literal_expr.sql(dialect=self._dialect)
1192
+
1193
+ start = p.position
1194
+ end = start + len(p.placeholder_text)
1195
+ sql = sql[:start] + literal_str + sql[end:]
1196
+
1197
+ return sql, None
1198
+
1199
+ def _replace_placeholders_in_sql(self, sql: str, param_info: list[Any], target_style: ParameterStyle) -> str:
827
1200
  """Replace placeholders in SQL string with target style placeholders.
828
1201
 
829
1202
  Args:
@@ -834,12 +1207,10 @@ class SQL:
834
1207
  Returns:
835
1208
  SQL string with replaced placeholders
836
1209
  """
837
- # Sort by position in reverse to avoid position shifts
838
1210
  sorted_params = sorted(param_info, key=lambda p: p.position, reverse=True)
839
1211
 
840
1212
  for p in sorted_params:
841
1213
  new_placeholder = self._generate_placeholder(p, target_style)
842
- # Replace the placeholder in SQL
843
1214
  start = p.position
844
1215
  end = start + len(p.placeholder_text)
845
1216
  sql = sql[:start] + new_placeholder + sql[end:]
@@ -847,7 +1218,7 @@ class SQL:
847
1218
  return sql
848
1219
 
849
1220
  @staticmethod
850
- def _generate_placeholder(param: Any, target_style: "ParameterStyle") -> str:
1221
+ def _generate_placeholder(param: Any, target_style: ParameterStyle) -> str:
851
1222
  """Generate a placeholder string for the given parameter style.
852
1223
 
853
1224
  Args:
@@ -857,36 +1228,34 @@ class SQL:
857
1228
  Returns:
858
1229
  Placeholder string
859
1230
  """
860
- if target_style == ParameterStyle.QMARK:
1231
+ if target_style in {ParameterStyle.STATIC, ParameterStyle.QMARK}:
861
1232
  return "?"
862
1233
  if target_style == ParameterStyle.NUMERIC:
863
- # Use 1-based numbering for numeric style
864
1234
  return f"${param.ordinal + 1}"
865
1235
  if target_style == ParameterStyle.NAMED_COLON:
866
- # Use original name if available, otherwise generate one
867
- # Oracle doesn't like underscores at the start of parameter names
868
1236
  if param.name and not param.name.isdigit():
869
- # Use the name if it's not just a number
870
1237
  return f":{param.name}"
871
- # Generate a new name for numeric placeholders or missing names
872
1238
  return f":arg_{param.ordinal}"
873
1239
  if target_style == ParameterStyle.NAMED_AT:
874
- # Use @ prefix for BigQuery style
875
- # BigQuery requires parameter names to start with a letter, not underscore
876
1240
  return f"@{param.name or f'param_{param.ordinal}'}"
877
1241
  if target_style == ParameterStyle.POSITIONAL_COLON:
878
- # Use :1, :2, etc. for Oracle positional style
1242
+ # For Oracle positional colon, preserve the original numeric identifier if it was already :N style
1243
+ if (
1244
+ hasattr(param, "style")
1245
+ and param.style == ParameterStyle.POSITIONAL_COLON
1246
+ and hasattr(param, "name")
1247
+ and param.name
1248
+ and param.name.isdigit()
1249
+ ):
1250
+ return f":{param.name}"
879
1251
  return f":{param.ordinal + 1}"
880
1252
  if target_style == ParameterStyle.POSITIONAL_PYFORMAT:
881
- # Use %s for positional pyformat
882
1253
  return "%s"
883
1254
  if target_style == ParameterStyle.NAMED_PYFORMAT:
884
- # Use %(name)s for named pyformat
885
- return f"%({param.name or f'_arg_{param.ordinal}'})s"
886
- # Keep original for unknown styles
1255
+ return f"%({param.name or f'arg_{param.ordinal}'})s"
887
1256
  return str(param.placeholder_text)
888
1257
 
889
- def _convert_parameters_format(self, params: Any, param_info: list[Any], target_style: "ParameterStyle") -> Any:
1258
+ def _convert_parameters_format(self, params: Any, param_info: list[Any], target_style: ParameterStyle) -> Any:
890
1259
  """Convert parameters to the appropriate format for the target style.
891
1260
 
892
1261
  Args:
@@ -907,10 +1276,96 @@ class SQL:
907
1276
  return self._convert_to_named_pyformat_format(params, param_info)
908
1277
  return params
909
1278
 
1279
+ def _convert_list_to_colon_dict(
1280
+ self, params: "Union[list[Any], tuple[Any, ...]]", param_info: "list[Any]"
1281
+ ) -> "dict[str, Any]":
1282
+ """Convert list/tuple parameters to colon-style dict format."""
1283
+ result_dict: dict[str, Any] = {}
1284
+
1285
+ if param_info:
1286
+ all_numeric = all(p.name and p.name.isdigit() for p in param_info)
1287
+ if all_numeric:
1288
+ for i, value in enumerate(params):
1289
+ result_dict[str(i + 1)] = value
1290
+ else:
1291
+ for i, value in enumerate(params):
1292
+ if i < len(param_info):
1293
+ param_name = param_info[i].name or str(i + 1)
1294
+ result_dict[param_name] = value
1295
+ else:
1296
+ result_dict[str(i + 1)] = value
1297
+ else:
1298
+ for i, value in enumerate(params):
1299
+ result_dict[str(i + 1)] = value
1300
+
1301
+ return result_dict
1302
+
1303
+ def _convert_single_value_to_colon_dict(self, params: Any, param_info: "list[Any]") -> "dict[str, Any]":
1304
+ """Convert single value parameter to colon-style dict format."""
1305
+ result_dict: dict[str, Any] = {}
1306
+ if param_info and param_info[0].name and param_info[0].name.isdigit():
1307
+ result_dict[param_info[0].name] = params
1308
+ else:
1309
+ result_dict["1"] = params
1310
+ return result_dict
1311
+
1312
+ def _process_mixed_colon_params(self, params: "dict[str, Any]", param_info: "list[Any]") -> "dict[str, Any]":
1313
+ """Process mixed colon-style numeric and normalized parameters."""
1314
+ result_dict: dict[str, Any] = {}
1315
+
1316
+ # When we have mixed parameters (extracted literals + user oracle params),
1317
+ # we need to be careful about the ordering. The extracted literals should
1318
+ # fill positions based on where they appear in the SQL, not based on
1319
+ # matching parameter names.
1320
+
1321
+ # Separate extracted parameters and user oracle parameters
1322
+ extracted_params = []
1323
+ user_oracle_params = {}
1324
+ extracted_keys_sorted = []
1325
+
1326
+ for key, value in params.items():
1327
+ if has_parameter_value(value):
1328
+ extracted_params.append((key, value))
1329
+ elif key.isdigit():
1330
+ user_oracle_params[key] = value
1331
+ elif key.startswith("param_") and key[6:].isdigit():
1332
+ param_idx = int(key[6:])
1333
+ oracle_key = str(param_idx + 1)
1334
+ if oracle_key not in user_oracle_params:
1335
+ extracted_keys_sorted.append((param_idx, key, value))
1336
+ else:
1337
+ extracted_params.append((key, value))
1338
+
1339
+ extracted_keys_sorted.sort(key=operator.itemgetter(0))
1340
+ for _, key, value in extracted_keys_sorted:
1341
+ extracted_params.append((key, value))
1342
+
1343
+ # Build lists of parameter values in order
1344
+ extracted_values = []
1345
+ for _, value in extracted_params:
1346
+ if has_parameter_value(value):
1347
+ extracted_values.append(value.value)
1348
+ else:
1349
+ extracted_values.append(value)
1350
+
1351
+ user_values = [user_oracle_params[key] for key in sorted(user_oracle_params.keys(), key=int)]
1352
+
1353
+ # Now assign parameters based on position
1354
+ # Extracted parameters go first (they were literals in original positions)
1355
+ # User parameters follow
1356
+ all_values = extracted_values + user_values
1357
+
1358
+ for i, p in enumerate(sorted(param_info, key=lambda x: x.ordinal)):
1359
+ oracle_key = str(p.ordinal + 1)
1360
+ if i < len(all_values):
1361
+ result_dict[oracle_key] = all_values[i]
1362
+
1363
+ return result_dict
1364
+
910
1365
  def _convert_to_positional_colon_format(self, params: Any, param_info: list[Any]) -> Any:
911
- """Convert to dict format for Oracle positional colon style.
1366
+ """Convert to dict format for positional colon style.
912
1367
 
913
- Oracle's positional colon style uses :1, :2, etc. placeholders and expects
1368
+ Positional colon style uses :1, :2, etc. placeholders and expects
914
1369
  parameters as a dict with string keys "1", "2", etc.
915
1370
 
916
1371
  For execute_many operations, returns a list of parameter sets.
@@ -922,68 +1377,76 @@ class SQL:
922
1377
  Returns:
923
1378
  Dict of parameters with string keys "1", "2", etc., or list for execute_many
924
1379
  """
925
- # Special handling for execute_many
926
1380
  if self._is_many and isinstance(params, list) and params and isinstance(params[0], (list, tuple)):
927
- # This is execute_many - keep as list but process each item
928
1381
  return params
929
1382
 
930
- result_dict: dict[str, Any] = {}
931
-
932
1383
  if isinstance(params, (list, tuple)):
933
- # Convert list/tuple to dict with string keys based on param_info
934
- if param_info:
935
- # Check if all param names are numeric (positional colon style)
936
- all_numeric = all(p.name and p.name.isdigit() for p in param_info)
937
- if all_numeric:
938
- # Sort param_info by numeric name to match list order
939
- sorted_params = sorted(param_info, key=lambda p: int(p.name))
940
- for i, value in enumerate(params):
941
- if i < len(sorted_params):
942
- # Map based on numeric order, not SQL appearance order
943
- param_name = sorted_params[i].name
944
- result_dict[param_name] = value
945
- else:
946
- # Extra parameters
947
- result_dict[str(i + 1)] = value
948
- else:
949
- # Non-numeric names, map by ordinal
950
- for i, value in enumerate(params):
951
- if i < len(param_info):
952
- param_name = param_info[i].name or str(i + 1)
953
- result_dict[param_name] = value
954
- else:
955
- result_dict[str(i + 1)] = value
956
- else:
957
- # No param_info, default to 1-based indexing
958
- for i, value in enumerate(params):
959
- result_dict[str(i + 1)] = value
960
- return result_dict
1384
+ return self._convert_list_to_colon_dict(params, param_info)
961
1385
 
962
1386
  if not is_dict(params) and param_info:
963
- # Single value parameter
964
- if param_info and param_info[0].name and param_info[0].name.isdigit():
965
- # Use the actual parameter name from SQL (e.g., "0")
966
- result_dict[param_info[0].name] = params
967
- else:
968
- # Default to "1"
969
- result_dict["1"] = params
970
- return result_dict
1387
+ return self._convert_single_value_to_colon_dict(params, param_info)
971
1388
 
972
- if isinstance(params, dict):
973
- # Check if already in correct format (keys are "1", "2", etc.)
1389
+ if is_dict(params):
974
1390
  if all(key.isdigit() for key in params):
975
1391
  return params
976
1392
 
977
- # Convert from other dict formats
978
- for p in sorted(param_info, key=lambda x: x.ordinal):
979
- # Oracle uses 1-based indexing
980
- oracle_key = str(p.ordinal + 1)
981
- if p.name and p.name in params:
982
- result_dict[oracle_key] = params[p.name]
983
- elif f"arg_{p.ordinal}" in params:
984
- result_dict[oracle_key] = params[f"arg_{p.ordinal}"]
985
- elif f"param_{p.ordinal}" in params:
986
- result_dict[oracle_key] = params[f"param_{p.ordinal}"]
1393
+ if all(key.startswith("param_") for key in params):
1394
+ param_result_dict: dict[str, Any] = {}
1395
+ for p in sorted(param_info, key=lambda x: x.ordinal):
1396
+ # Use the parameter's ordinal to find the normalized key
1397
+ normalized_key = f"param_{p.ordinal}"
1398
+ if normalized_key in params:
1399
+ if p.name and p.name.isdigit():
1400
+ # For Oracle numeric parameters, preserve the original number
1401
+ param_result_dict[p.name] = params[normalized_key]
1402
+ else:
1403
+ # For other cases, use sequential numbering
1404
+ param_result_dict[str(p.ordinal + 1)] = params[normalized_key]
1405
+ return param_result_dict
1406
+
1407
+ has_oracle_numeric = any(key.isdigit() for key in params)
1408
+ has_param_normalized = any(key.startswith("param_") for key in params)
1409
+ has_typed_params = any(has_parameter_value(v) for v in params.values())
1410
+
1411
+ if (has_oracle_numeric and has_param_normalized) or has_typed_params:
1412
+ return self._process_mixed_colon_params(params, param_info)
1413
+
1414
+ result_dict: dict[str, Any] = {}
1415
+
1416
+ if param_info:
1417
+ # Process all parameters in order of their ordinals
1418
+ for p in sorted(param_info, key=lambda x: x.ordinal):
1419
+ oracle_key = str(p.ordinal + 1)
1420
+ value = None
1421
+
1422
+ # Try different ways to find the parameter value
1423
+ if p.name and (
1424
+ p.name in params
1425
+ or (p.name.isdigit() and p.name in params)
1426
+ or (p.name.startswith("param_") and p.name in params)
1427
+ ):
1428
+ value = params[p.name]
1429
+
1430
+ # If not found by name, try by ordinal-based keys
1431
+ if value is None:
1432
+ # Try param_N format (common for pipeline parameters)
1433
+ param_key = f"param_{p.ordinal}"
1434
+ if param_key in params:
1435
+ value = params[param_key]
1436
+ # Try arg_N format
1437
+ elif f"arg_{p.ordinal}" in params:
1438
+ value = params[f"arg_{p.ordinal}"]
1439
+ # For positional colon, also check if there's a numeric key
1440
+ # that matches the ordinal position
1441
+ elif str(p.ordinal + 1) in params:
1442
+ value = params[str(p.ordinal + 1)]
1443
+
1444
+ # Unwrap TypedParameter if needed
1445
+ if value is not None:
1446
+ if has_parameter_value(value):
1447
+ value = value.value
1448
+ result_dict[oracle_key] = value
1449
+
987
1450
  return result_dict
988
1451
 
989
1452
  return params
@@ -1001,33 +1464,79 @@ class SQL:
1001
1464
  """
1002
1465
  result_list: list[Any] = []
1003
1466
  if is_dict(params):
1467
+ param_values_by_ordinal: dict[int, Any] = {}
1468
+
1004
1469
  for p in param_info:
1005
1470
  if p.name and p.name in params:
1006
- # Named parameter - get from dict and extract value from TypedParameter if needed
1007
- val = params[p.name]
1008
- if hasattr(val, "value"):
1471
+ param_values_by_ordinal[p.ordinal] = params[p.name]
1472
+
1473
+ for p in param_info:
1474
+ if p.name is None and p.ordinal not in param_values_by_ordinal:
1475
+ arg_key = f"arg_{p.ordinal}"
1476
+ param_key = f"param_{p.ordinal}"
1477
+ if arg_key in params:
1478
+ param_values_by_ordinal[p.ordinal] = params[arg_key]
1479
+ elif param_key in params:
1480
+ param_values_by_ordinal[p.ordinal] = params[param_key]
1481
+
1482
+ remaining_params = {
1483
+ k: v
1484
+ for k, v in params.items()
1485
+ if k not in {p.name for p in param_info if p.name} and not k.startswith(("arg_", "param_"))
1486
+ }
1487
+
1488
+ unmatched_ordinals = [p.ordinal for p in param_info if p.ordinal not in param_values_by_ordinal]
1489
+
1490
+ for ordinal, (_, value) in zip(unmatched_ordinals, remaining_params.items()):
1491
+ param_values_by_ordinal[ordinal] = value
1492
+
1493
+ for p in param_info:
1494
+ val = param_values_by_ordinal.get(p.ordinal)
1495
+ if val is not None:
1496
+ if has_parameter_value(val):
1009
1497
  result_list.append(val.value)
1010
1498
  else:
1011
1499
  result_list.append(val)
1012
- elif p.name is None:
1013
- # Unnamed parameter (qmark style) - look for arg_N
1014
- arg_key = f"arg_{p.ordinal}"
1015
- if arg_key in params:
1016
- # Extract value from TypedParameter if needed
1017
- val = params[arg_key]
1018
- if hasattr(val, "value"):
1500
+ else:
1501
+ result_list.append(None)
1502
+
1503
+ return result_list
1504
+ if isinstance(params, (list, tuple)):
1505
+ # Special case: if params is empty, preserve it (don't create None values)
1506
+ # This is important for execute_many with empty parameter lists
1507
+ if not params:
1508
+ return params
1509
+
1510
+ # Handle mixed parameter styles correctly
1511
+ # For mixed styles, assign parameters in order of appearance, not by numeric reference
1512
+ if param_info and any(p.style == ParameterStyle.NUMERIC for p in param_info):
1513
+ # Create mapping from ordinal to parameter value
1514
+ param_mapping: dict[int, Any] = {}
1515
+
1516
+ # Sort parameter info by position to get order of appearance
1517
+ sorted_params = sorted(param_info, key=lambda p: p.position)
1518
+
1519
+ # Assign parameters sequentially in order of appearance
1520
+ for i, param_info_item in enumerate(sorted_params):
1521
+ if i < len(params):
1522
+ param_mapping[param_info_item.ordinal] = params[i]
1523
+
1524
+ # Build result list ordered by original ordinal values
1525
+ for i in range(len(param_info)):
1526
+ val = param_mapping.get(i)
1527
+ if val is not None:
1528
+ if has_parameter_value(val):
1019
1529
  result_list.append(val.value)
1020
1530
  else:
1021
1531
  result_list.append(val)
1022
1532
  else:
1023
1533
  result_list.append(None)
1024
- else:
1025
- # Named parameter not in dict
1026
- result_list.append(None)
1027
- return result_list
1028
- if isinstance(params, (list, tuple)):
1534
+
1535
+ return result_list
1536
+
1537
+ # Standard conversion for non-mixed styles
1029
1538
  for param in params:
1030
- if hasattr(param, "value"):
1539
+ if has_parameter_value(param):
1031
1540
  result_list.append(param.value)
1032
1541
  else:
1033
1542
  result_list.append(param)
@@ -1047,28 +1556,26 @@ class SQL:
1047
1556
  """
1048
1557
  result_dict: dict[str, Any] = {}
1049
1558
  if is_dict(params):
1050
- # For dict params with matching parameter names, return as-is
1051
- # Otherwise, remap to match the expected names
1052
1559
  if all(p.name in params for p in param_info if p.name):
1053
1560
  return params
1054
1561
  for p in param_info:
1055
1562
  if p.name and p.name in params:
1056
1563
  result_dict[p.name] = params[p.name]
1057
1564
  elif f"param_{p.ordinal}" in params:
1058
- # Handle param_N style names
1059
- # Oracle doesn't like underscores at the start of parameter names
1060
1565
  result_dict[p.name or f"arg_{p.ordinal}"] = params[f"param_{p.ordinal}"]
1061
1566
  return result_dict
1062
1567
  if isinstance(params, (list, tuple)):
1063
- # Convert list/tuple to dict with parameter names from param_info
1064
-
1065
1568
  for i, value in enumerate(params):
1569
+ if has_parameter_value(value):
1570
+ value = value.value
1571
+
1066
1572
  if i < len(param_info):
1067
1573
  p = param_info[i]
1068
- # Use the actual parameter name if available
1069
- # Oracle doesn't like underscores at the start of parameter names
1070
1574
  param_name = p.name or f"arg_{i}"
1071
1575
  result_dict[param_name] = value
1576
+ else:
1577
+ param_name = f"arg_{i}"
1578
+ result_dict[param_name] = value
1072
1579
  return result_dict
1073
1580
  return params
1074
1581
 
@@ -1084,7 +1591,6 @@ class SQL:
1084
1591
  Dict of parameters with names
1085
1592
  """
1086
1593
  if isinstance(params, (list, tuple)):
1087
- # Convert list to dict with generated names
1088
1594
  result_dict: dict[str, Any] = {}
1089
1595
  for i, p in enumerate(param_info):
1090
1596
  if i < len(params):
@@ -1093,14 +1599,15 @@ class SQL:
1093
1599
  return result_dict
1094
1600
  return params
1095
1601
 
1096
- # Validation properties for compatibility
1097
1602
  @property
1098
1603
  def validation_errors(self) -> list[Any]:
1099
1604
  """Get validation errors."""
1100
1605
  if not self._config.enable_validation:
1101
1606
  return []
1102
1607
  self._ensure_processed()
1103
- assert self._processed_state
1608
+ if not self._processed_state:
1609
+ msg = "Failed to process SQL statement"
1610
+ raise RuntimeError(msg)
1104
1611
  return self._processed_state.validation_errors
1105
1612
 
1106
1613
  @property
@@ -1113,25 +1620,30 @@ class SQL:
1113
1620
  """Check if statement is safe."""
1114
1621
  return not self.has_errors
1115
1622
 
1116
- # Additional compatibility methods
1117
1623
  def validate(self) -> list[Any]:
1118
1624
  """Validate the SQL statement and return validation errors."""
1119
1625
  return self.validation_errors
1120
1626
 
1121
1627
  @property
1122
1628
  def parameter_info(self) -> list[Any]:
1123
- """Get parameter information from the SQL statement."""
1629
+ """Get parameter information from the SQL statement.
1630
+
1631
+ Returns the original parameter info before any normalization.
1632
+ """
1124
1633
  validator = self._config.parameter_validator
1125
- if self._config.enable_parsing and self._processed_state:
1126
- sql_for_validation = self.expression.sql(dialect=self._dialect) if self.expression else self.sql # pyright: ignore
1127
- else:
1128
- sql_for_validation = self.sql
1129
- return validator.extract_parameters(sql_for_validation)
1634
+ if self._raw_sql:
1635
+ return validator.extract_parameters(self._raw_sql)
1636
+
1637
+ self._ensure_processed()
1638
+
1639
+ if self._processing_context:
1640
+ return self._processing_context.parameter_info
1641
+
1642
+ return []
1130
1643
 
1131
1644
  @property
1132
1645
  def _raw_parameters(self) -> Any:
1133
1646
  """Get raw parameters for compatibility."""
1134
- # Return the original parameters as passed in
1135
1647
  return self._original_parameters
1136
1648
 
1137
1649
  @property
@@ -1140,7 +1652,7 @@ class SQL:
1140
1652
  return self.sql
1141
1653
 
1142
1654
  @property
1143
- def _expression(self) -> Optional[exp.Expression]:
1655
+ def _expression(self) -> "Optional[exp.Expression]":
1144
1656
  """Get expression for compatibility."""
1145
1657
  return self.expression
1146
1658
 
@@ -1152,18 +1664,15 @@ class SQL:
1152
1664
  def limit(self, count: int, use_parameter: bool = False) -> "SQL":
1153
1665
  """Add LIMIT clause."""
1154
1666
  if use_parameter:
1155
- # Create a unique parameter name
1156
1667
  param_name = self.get_unique_parameter_name("limit")
1157
- # Add parameter to the SQL object
1158
1668
  result = self
1159
1669
  result = result.add_named_parameter(param_name, count)
1160
- # Use placeholder in the expression
1161
- if hasattr(result._statement, "limit"):
1670
+ if supports_limit(result._statement):
1162
1671
  new_statement = result._statement.limit(exp.Placeholder(this=param_name)) # pyright: ignore
1163
1672
  else:
1164
1673
  new_statement = exp.Select().from_(result._statement).limit(exp.Placeholder(this=param_name)) # pyright: ignore
1165
1674
  return result.copy(statement=new_statement)
1166
- if hasattr(self._statement, "limit"):
1675
+ if supports_limit(self._statement):
1167
1676
  new_statement = self._statement.limit(count) # pyright: ignore
1168
1677
  else:
1169
1678
  new_statement = exp.Select().from_(self._statement).limit(count) # pyright: ignore
@@ -1172,18 +1681,15 @@ class SQL:
1172
1681
  def offset(self, count: int, use_parameter: bool = False) -> "SQL":
1173
1682
  """Add OFFSET clause."""
1174
1683
  if use_parameter:
1175
- # Create a unique parameter name
1176
1684
  param_name = self.get_unique_parameter_name("offset")
1177
- # Add parameter to the SQL object
1178
1685
  result = self
1179
1686
  result = result.add_named_parameter(param_name, count)
1180
- # Use placeholder in the expression
1181
- if hasattr(result._statement, "offset"):
1687
+ if supports_offset(result._statement):
1182
1688
  new_statement = result._statement.offset(exp.Placeholder(this=param_name)) # pyright: ignore
1183
1689
  else:
1184
1690
  new_statement = exp.Select().from_(result._statement).offset(exp.Placeholder(this=param_name)) # pyright: ignore
1185
1691
  return result.copy(statement=new_statement)
1186
- if hasattr(self._statement, "offset"):
1692
+ if supports_offset(self._statement):
1187
1693
  new_statement = self._statement.offset(count) # pyright: ignore
1188
1694
  else:
1189
1695
  new_statement = exp.Select().from_(self._statement).offset(count) # pyright: ignore
@@ -1191,7 +1697,7 @@ class SQL:
1191
1697
 
1192
1698
  def order_by(self, expression: exp.Expression) -> "SQL":
1193
1699
  """Add ORDER BY clause."""
1194
- if hasattr(self._statement, "order_by"):
1700
+ if supports_order_by(self._statement):
1195
1701
  new_statement = self._statement.order_by(expression) # pyright: ignore
1196
1702
  else:
1197
1703
  new_statement = exp.Select().from_(self._statement).order_by(expression) # pyright: ignore