sqlspec 0.14.1__py3-none-any.whl → 0.16.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic. Click here for more details.

Files changed (159) hide show
  1. sqlspec/__init__.py +50 -25
  2. sqlspec/__main__.py +1 -1
  3. sqlspec/__metadata__.py +1 -3
  4. sqlspec/_serialization.py +1 -2
  5. sqlspec/_sql.py +480 -121
  6. sqlspec/_typing.py +278 -142
  7. sqlspec/adapters/adbc/__init__.py +4 -3
  8. sqlspec/adapters/adbc/_types.py +12 -0
  9. sqlspec/adapters/adbc/config.py +115 -260
  10. sqlspec/adapters/adbc/driver.py +462 -367
  11. sqlspec/adapters/aiosqlite/__init__.py +18 -3
  12. sqlspec/adapters/aiosqlite/_types.py +13 -0
  13. sqlspec/adapters/aiosqlite/config.py +199 -129
  14. sqlspec/adapters/aiosqlite/driver.py +230 -269
  15. sqlspec/adapters/asyncmy/__init__.py +18 -3
  16. sqlspec/adapters/asyncmy/_types.py +12 -0
  17. sqlspec/adapters/asyncmy/config.py +80 -168
  18. sqlspec/adapters/asyncmy/driver.py +260 -225
  19. sqlspec/adapters/asyncpg/__init__.py +19 -4
  20. sqlspec/adapters/asyncpg/_types.py +17 -0
  21. sqlspec/adapters/asyncpg/config.py +82 -181
  22. sqlspec/adapters/asyncpg/driver.py +285 -383
  23. sqlspec/adapters/bigquery/__init__.py +17 -3
  24. sqlspec/adapters/bigquery/_types.py +12 -0
  25. sqlspec/adapters/bigquery/config.py +191 -258
  26. sqlspec/adapters/bigquery/driver.py +474 -646
  27. sqlspec/adapters/duckdb/__init__.py +14 -3
  28. sqlspec/adapters/duckdb/_types.py +12 -0
  29. sqlspec/adapters/duckdb/config.py +415 -351
  30. sqlspec/adapters/duckdb/driver.py +343 -413
  31. sqlspec/adapters/oracledb/__init__.py +19 -5
  32. sqlspec/adapters/oracledb/_types.py +14 -0
  33. sqlspec/adapters/oracledb/config.py +123 -379
  34. sqlspec/adapters/oracledb/driver.py +507 -560
  35. sqlspec/adapters/psqlpy/__init__.py +13 -3
  36. sqlspec/adapters/psqlpy/_types.py +11 -0
  37. sqlspec/adapters/psqlpy/config.py +93 -254
  38. sqlspec/adapters/psqlpy/driver.py +505 -234
  39. sqlspec/adapters/psycopg/__init__.py +19 -5
  40. sqlspec/adapters/psycopg/_types.py +17 -0
  41. sqlspec/adapters/psycopg/config.py +143 -403
  42. sqlspec/adapters/psycopg/driver.py +706 -872
  43. sqlspec/adapters/sqlite/__init__.py +14 -3
  44. sqlspec/adapters/sqlite/_types.py +11 -0
  45. sqlspec/adapters/sqlite/config.py +202 -118
  46. sqlspec/adapters/sqlite/driver.py +264 -303
  47. sqlspec/base.py +105 -9
  48. sqlspec/{statement/builder → builder}/__init__.py +12 -14
  49. sqlspec/{statement/builder → builder}/_base.py +120 -55
  50. sqlspec/{statement/builder → builder}/_column.py +17 -6
  51. sqlspec/{statement/builder → builder}/_ddl.py +46 -79
  52. sqlspec/{statement/builder → builder}/_ddl_utils.py +5 -10
  53. sqlspec/{statement/builder → builder}/_delete.py +6 -25
  54. sqlspec/{statement/builder → builder}/_insert.py +18 -65
  55. sqlspec/builder/_merge.py +56 -0
  56. sqlspec/{statement/builder → builder}/_parsing_utils.py +8 -11
  57. sqlspec/{statement/builder → builder}/_select.py +11 -56
  58. sqlspec/{statement/builder → builder}/_update.py +12 -18
  59. sqlspec/{statement/builder → builder}/mixins/__init__.py +10 -14
  60. sqlspec/{statement/builder → builder}/mixins/_cte_and_set_ops.py +48 -59
  61. sqlspec/{statement/builder → builder}/mixins/_insert_operations.py +34 -18
  62. sqlspec/{statement/builder → builder}/mixins/_join_operations.py +1 -3
  63. sqlspec/{statement/builder → builder}/mixins/_merge_operations.py +19 -9
  64. sqlspec/{statement/builder → builder}/mixins/_order_limit_operations.py +3 -3
  65. sqlspec/{statement/builder → builder}/mixins/_pivot_operations.py +4 -8
  66. sqlspec/{statement/builder → builder}/mixins/_select_operations.py +25 -38
  67. sqlspec/{statement/builder → builder}/mixins/_update_operations.py +15 -16
  68. sqlspec/{statement/builder → builder}/mixins/_where_clause.py +210 -137
  69. sqlspec/cli.py +4 -5
  70. sqlspec/config.py +180 -133
  71. sqlspec/core/__init__.py +63 -0
  72. sqlspec/core/cache.py +873 -0
  73. sqlspec/core/compiler.py +396 -0
  74. sqlspec/core/filters.py +830 -0
  75. sqlspec/core/hashing.py +310 -0
  76. sqlspec/core/parameters.py +1209 -0
  77. sqlspec/core/result.py +664 -0
  78. sqlspec/{statement → core}/splitter.py +321 -191
  79. sqlspec/core/statement.py +666 -0
  80. sqlspec/driver/__init__.py +7 -10
  81. sqlspec/driver/_async.py +387 -176
  82. sqlspec/driver/_common.py +527 -289
  83. sqlspec/driver/_sync.py +390 -172
  84. sqlspec/driver/mixins/__init__.py +2 -19
  85. sqlspec/driver/mixins/_result_tools.py +164 -0
  86. sqlspec/driver/mixins/_sql_translator.py +6 -3
  87. sqlspec/exceptions.py +5 -252
  88. sqlspec/extensions/aiosql/adapter.py +93 -96
  89. sqlspec/extensions/litestar/cli.py +1 -1
  90. sqlspec/extensions/litestar/config.py +0 -1
  91. sqlspec/extensions/litestar/handlers.py +15 -26
  92. sqlspec/extensions/litestar/plugin.py +18 -16
  93. sqlspec/extensions/litestar/providers.py +17 -52
  94. sqlspec/loader.py +424 -105
  95. sqlspec/migrations/__init__.py +12 -0
  96. sqlspec/migrations/base.py +92 -68
  97. sqlspec/migrations/commands.py +24 -106
  98. sqlspec/migrations/loaders.py +402 -0
  99. sqlspec/migrations/runner.py +49 -51
  100. sqlspec/migrations/tracker.py +31 -44
  101. sqlspec/migrations/utils.py +64 -24
  102. sqlspec/protocols.py +7 -183
  103. sqlspec/storage/__init__.py +1 -1
  104. sqlspec/storage/backends/base.py +37 -40
  105. sqlspec/storage/backends/fsspec.py +136 -112
  106. sqlspec/storage/backends/obstore.py +138 -160
  107. sqlspec/storage/capabilities.py +5 -4
  108. sqlspec/storage/registry.py +57 -106
  109. sqlspec/typing.py +136 -115
  110. sqlspec/utils/__init__.py +2 -3
  111. sqlspec/utils/correlation.py +0 -3
  112. sqlspec/utils/deprecation.py +6 -6
  113. sqlspec/utils/fixtures.py +6 -6
  114. sqlspec/utils/logging.py +0 -2
  115. sqlspec/utils/module_loader.py +7 -12
  116. sqlspec/utils/singleton.py +0 -1
  117. sqlspec/utils/sync_tools.py +17 -38
  118. sqlspec/utils/text.py +12 -51
  119. sqlspec/utils/type_guards.py +443 -232
  120. {sqlspec-0.14.1.dist-info → sqlspec-0.16.0.dist-info}/METADATA +7 -2
  121. sqlspec-0.16.0.dist-info/RECORD +134 -0
  122. sqlspec/adapters/adbc/transformers.py +0 -108
  123. sqlspec/driver/connection.py +0 -207
  124. sqlspec/driver/mixins/_cache.py +0 -114
  125. sqlspec/driver/mixins/_csv_writer.py +0 -91
  126. sqlspec/driver/mixins/_pipeline.py +0 -508
  127. sqlspec/driver/mixins/_query_tools.py +0 -796
  128. sqlspec/driver/mixins/_result_utils.py +0 -138
  129. sqlspec/driver/mixins/_storage.py +0 -912
  130. sqlspec/driver/mixins/_type_coercion.py +0 -128
  131. sqlspec/driver/parameters.py +0 -138
  132. sqlspec/statement/__init__.py +0 -21
  133. sqlspec/statement/builder/_merge.py +0 -95
  134. sqlspec/statement/cache.py +0 -50
  135. sqlspec/statement/filters.py +0 -625
  136. sqlspec/statement/parameters.py +0 -956
  137. sqlspec/statement/pipelines/__init__.py +0 -210
  138. sqlspec/statement/pipelines/analyzers/__init__.py +0 -9
  139. sqlspec/statement/pipelines/analyzers/_analyzer.py +0 -646
  140. sqlspec/statement/pipelines/context.py +0 -109
  141. sqlspec/statement/pipelines/transformers/__init__.py +0 -7
  142. sqlspec/statement/pipelines/transformers/_expression_simplifier.py +0 -88
  143. sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +0 -1247
  144. sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +0 -76
  145. sqlspec/statement/pipelines/validators/__init__.py +0 -23
  146. sqlspec/statement/pipelines/validators/_dml_safety.py +0 -290
  147. sqlspec/statement/pipelines/validators/_parameter_style.py +0 -370
  148. sqlspec/statement/pipelines/validators/_performance.py +0 -714
  149. sqlspec/statement/pipelines/validators/_security.py +0 -967
  150. sqlspec/statement/result.py +0 -435
  151. sqlspec/statement/sql.py +0 -1774
  152. sqlspec/utils/cached_property.py +0 -25
  153. sqlspec/utils/statement_hashing.py +0 -203
  154. sqlspec-0.14.1.dist-info/RECORD +0 -145
  155. /sqlspec/{statement/builder → builder}/mixins/_delete_operations.py +0 -0
  156. {sqlspec-0.14.1.dist-info → sqlspec-0.16.0.dist-info}/WHEEL +0 -0
  157. {sqlspec-0.14.1.dist-info → sqlspec-0.16.0.dist-info}/entry_points.txt +0 -0
  158. {sqlspec-0.14.1.dist-info → sqlspec-0.16.0.dist-info}/licenses/LICENSE +0 -0
  159. {sqlspec-0.14.1.dist-info → sqlspec-0.16.0.dist-info}/licenses/NOTICE +0 -0
sqlspec/statement/sql.py DELETED
@@ -1,1774 +0,0 @@
1
- """SQL statement handling with centralized parameter management."""
2
-
3
- import operator
4
- from dataclasses import dataclass, field
5
- from typing import TYPE_CHECKING, Any, Callable, Optional, Union
6
-
7
- import sqlglot
8
- import sqlglot.expressions as exp
9
- from sqlglot.errors import ParseError
10
- from typing_extensions import TypeAlias
11
-
12
- from sqlspec.exceptions import RiskLevel, SQLParsingError, SQLValidationError
13
- from sqlspec.statement.cache import sql_cache
14
- from sqlspec.statement.filters import StatementFilter
15
- from sqlspec.statement.parameters import (
16
- SQLGLOT_INCOMPATIBLE_STYLES,
17
- ParameterConverter,
18
- ParameterStyle,
19
- ParameterValidator,
20
- )
21
- from sqlspec.statement.pipelines import SQLProcessingContext, StatementPipeline
22
- from sqlspec.statement.pipelines.transformers import CommentAndHintRemover, ParameterizeLiterals
23
- from sqlspec.statement.pipelines.validators import DMLSafetyValidator, ParameterStyleValidator, SecurityValidator
24
- from sqlspec.utils.logging import get_logger
25
- from sqlspec.utils.statement_hashing import hash_sql_statement
26
- from sqlspec.utils.type_guards import (
27
- can_append_to_statement,
28
- can_extract_parameters,
29
- has_parameter_value,
30
- has_risk_level,
31
- is_dict,
32
- is_expression,
33
- is_statement_filter,
34
- supports_where,
35
- )
36
-
37
- if TYPE_CHECKING:
38
- from sqlglot.dialects.dialect import DialectType
39
-
40
- from sqlspec.statement.parameters import ParameterStyleConversionState
41
-
42
- __all__ = ("SQL", "SQLConfig", "Statement")
43
-
44
- logger = get_logger("sqlspec.statement")
45
-
46
- Statement: TypeAlias = Union[str, exp.Expression, "SQL"]
47
-
48
- # Parameter naming constants
49
- PARAM_PREFIX = "param_"
50
- POS_PARAM_PREFIX = "pos_param_"
51
- KW_POS_PARAM_PREFIX = "kw_pos_param_"
52
- ARG_PREFIX = "arg_"
53
-
54
- # Cache and limit constants
55
- DEFAULT_CACHE_SIZE = 1000
56
-
57
- # Oracle/Colon style parameter constants
58
- COLON_PARAM_ONE = "1"
59
- COLON_PARAM_MIN_INDEX = 1
60
-
61
-
62
- @dataclass
63
- class _ProcessedState:
64
- """Cached state from pipeline processing."""
65
-
66
- processed_expression: exp.Expression
67
- processed_sql: str
68
- merged_parameters: Any
69
- validation_errors: list[Any] = field(default_factory=list)
70
- analysis_results: dict[str, Any] = field(default_factory=dict)
71
- transformation_results: dict[str, Any] = field(default_factory=dict)
72
-
73
-
74
- @dataclass
75
- class SQLConfig:
76
- """Configuration for SQL statement behavior.
77
-
78
- Uses conservative defaults that prioritize compatibility and robustness,
79
- making it easier to work with diverse SQL dialects and complex queries.
80
-
81
- Pipeline Configuration:
82
- enable_parsing: Parse SQL strings using sqlglot (default: True)
83
- enable_validation: Run SQL validators to check for safety issues (default: True)
84
- enable_transformations: Apply SQL transformers like literal parameterization (default: True)
85
- enable_analysis: Run SQL analyzers for metadata extraction (default: False)
86
- enable_expression_simplification: Apply expression simplification transformer (default: False)
87
- enable_parameter_type_wrapping: Wrap parameters with type information (default: True)
88
- parse_errors_as_warnings: Treat parse errors as warnings instead of failures (default: True)
89
- enable_caching: Cache processed SQL statements (default: True)
90
-
91
- Component Lists (Advanced):
92
- transformers: Optional list of SQL transformers for explicit staging
93
- validators: Optional list of SQL validators for explicit staging
94
- analyzers: Optional list of SQL analyzers for explicit staging
95
-
96
- Internal Configuration:
97
- parameter_converter: Handles parameter style conversions
98
- parameter_validator: Validates parameter usage and styles
99
- input_sql_had_placeholders: Populated by SQL.__init__ to track original SQL state
100
- dialect: SQL dialect to use for parsing and generation
101
-
102
- Parameter Style Configuration:
103
- allowed_parameter_styles: Allowed parameter styles (e.g., ('qmark', 'named_colon'))
104
- default_parameter_style: Target parameter style for SQL generation
105
- allow_mixed_parameter_styles: Whether to allow mixing parameter styles in same query
106
- """
107
-
108
- enable_parsing: bool = True
109
- enable_validation: bool = True
110
- enable_transformations: bool = True
111
- enable_analysis: bool = False
112
- enable_expression_simplification: bool = False
113
- enable_parameter_type_wrapping: bool = True
114
- parse_errors_as_warnings: bool = True
115
- enable_caching: bool = True
116
-
117
- transformers: "Optional[list[Any]]" = None
118
- validators: "Optional[list[Any]]" = None
119
- analyzers: "Optional[list[Any]]" = None
120
-
121
- parameter_converter: ParameterConverter = field(default_factory=ParameterConverter)
122
- parameter_validator: ParameterValidator = field(default_factory=ParameterValidator)
123
- input_sql_had_placeholders: bool = False
124
- dialect: "Optional[DialectType]" = None
125
-
126
- allowed_parameter_styles: "Optional[tuple[str, ...]]" = None
127
- default_parameter_style: "Optional[str]" = None
128
- allow_mixed_parameter_styles: bool = False
129
- analyzer_output_handler: "Optional[Callable[[Any], None]]" = None
130
-
131
- def validate_parameter_style(self, style: "Union[ParameterStyle, str]") -> bool:
132
- """Check if a parameter style is allowed.
133
-
134
- Args:
135
- style: Parameter style to validate (can be ParameterStyle enum or string)
136
-
137
- Returns:
138
- True if the style is allowed, False otherwise
139
- """
140
- if self.allowed_parameter_styles is None:
141
- return True
142
- style_str = str(style)
143
- return style_str in self.allowed_parameter_styles
144
-
145
- def get_statement_pipeline(self) -> StatementPipeline:
146
- """Get the configured statement pipeline.
147
-
148
- Returns:
149
- StatementPipeline configured with transformers, validators, and analyzers
150
- """
151
- transformers = []
152
- if self.transformers is not None:
153
- transformers = list(self.transformers)
154
- elif self.enable_transformations:
155
- placeholder_style = self.default_parameter_style or "?"
156
- transformers = [CommentAndHintRemover(), ParameterizeLiterals(placeholder_style=placeholder_style)]
157
- if self.enable_expression_simplification:
158
- from sqlspec.statement.pipelines.transformers import ExpressionSimplifier
159
-
160
- transformers.append(ExpressionSimplifier())
161
-
162
- validators = []
163
- if self.validators is not None:
164
- validators = list(self.validators)
165
- elif self.enable_validation:
166
- validators = [
167
- ParameterStyleValidator(fail_on_violation=not self.parse_errors_as_warnings),
168
- DMLSafetyValidator(),
169
- SecurityValidator(),
170
- ]
171
-
172
- analyzers = []
173
- if self.analyzers is not None:
174
- analyzers = list(self.analyzers)
175
- elif self.enable_analysis:
176
- from sqlspec.statement.pipelines.analyzers import StatementAnalyzer
177
-
178
- analyzers = [StatementAnalyzer()]
179
-
180
- return StatementPipeline(transformers=transformers, validators=validators, analyzers=analyzers) # pyright: ignore
181
-
182
-
183
- def default_analysis_handler(analysis: Any) -> None:
184
- """Default handler that logs analysis to debug."""
185
- logger.debug("SQL Analysis: %s", analysis)
186
-
187
-
188
- class SQL:
189
- """Immutable SQL statement with centralized parameter management.
190
-
191
- The SQL class is the single source of truth for:
192
- - SQL expression/statement
193
- - Positional parameters
194
- - Named parameters
195
- - Applied filters
196
-
197
- All methods that modify state return new SQL instances.
198
- """
199
-
200
- __slots__ = (
201
- "_builder_result_type",
202
- "_config",
203
- "_dialect",
204
- "_filters",
205
- "_is_many",
206
- "_is_script",
207
- "_named_params",
208
- "_original_parameters",
209
- "_original_sql",
210
- "_parameter_conversion_state",
211
- "_placeholder_mapping",
212
- "_positional_params",
213
- "_processed_state",
214
- "_processing_context",
215
- "_raw_sql",
216
- "_statement",
217
- )
218
-
219
- def __init__(
220
- self,
221
- statement: "Union[str, exp.Expression, 'SQL']",
222
- *parameters: "Union[Any, StatementFilter, list[Union[Any, StatementFilter]]]",
223
- _dialect: "DialectType" = None,
224
- _config: "Optional[SQLConfig]" = None,
225
- _builder_result_type: "Optional[type]" = None,
226
- _existing_state: "Optional[dict[str, Any]]" = None,
227
- **kwargs: Any,
228
- ) -> None:
229
- """Initialize SQL with centralized parameter management."""
230
- if "config" in kwargs and _config is None:
231
- _config = kwargs.pop("config")
232
- self._config = _config or SQLConfig()
233
- self._dialect = _dialect or self._config.dialect
234
- self._builder_result_type = _builder_result_type
235
- self._processed_state: Optional[_ProcessedState] = None
236
- self._processing_context: Optional[SQLProcessingContext] = None
237
- self._positional_params: list[Any] = []
238
- self._named_params: dict[str, Any] = {}
239
- self._filters: list[StatementFilter] = []
240
- self._statement: exp.Expression
241
- self._raw_sql: str = ""
242
- self._original_parameters: Any = None
243
- self._original_sql: str = ""
244
- self._placeholder_mapping: dict[str, Union[str, int]] = {}
245
- self._parameter_conversion_state: Optional[ParameterStyleConversionState] = None
246
- self._is_many: bool = False
247
- self._is_script: bool = False
248
-
249
- if isinstance(statement, SQL):
250
- self._init_from_sql_object(statement, _dialect, _config or SQLConfig(), _builder_result_type)
251
- else:
252
- self._init_from_str_or_expression(statement)
253
-
254
- if _existing_state:
255
- self._load_from_existing_state(_existing_state)
256
-
257
- if not isinstance(statement, SQL) and not _existing_state:
258
- self._set_original_parameters(*parameters)
259
-
260
- self._process_parameters(*parameters, **kwargs)
261
-
262
- def _init_from_sql_object(
263
- self, statement: "SQL", dialect: "DialectType", config: "SQLConfig", builder_result_type: "Optional[type]"
264
- ) -> None:
265
- """Initialize from an existing SQL object."""
266
- self._statement = statement._statement
267
- self._dialect = dialect or statement._dialect
268
- self._config = config or statement._config
269
- self._builder_result_type = builder_result_type or statement._builder_result_type
270
- self._is_many = statement._is_many
271
- self._is_script = statement._is_script
272
- self._raw_sql = statement._raw_sql
273
- self._original_parameters = statement._original_parameters
274
- self._original_sql = statement._original_sql
275
- self._placeholder_mapping = statement._placeholder_mapping.copy()
276
- self._parameter_conversion_state = statement._parameter_conversion_state
277
- self._positional_params.extend(statement._positional_params)
278
- self._named_params.update(statement._named_params)
279
- self._filters.extend(statement._filters)
280
-
281
- def _init_from_str_or_expression(self, statement: "Union[str, exp.Expression]") -> None:
282
- """Initialize from a string or expression."""
283
- if isinstance(statement, str):
284
- self._raw_sql = statement
285
- self._statement = self._to_expression(statement)
286
- else:
287
- self._raw_sql = statement.sql(dialect=self._dialect)
288
- self._statement = statement
289
-
290
- def _load_from_existing_state(self, existing_state: "dict[str, Any]") -> None:
291
- """Load state from a dictionary (used by copy)."""
292
- self._positional_params = list(existing_state.get("positional_params", self._positional_params))
293
- self._named_params = dict(existing_state.get("named_params", self._named_params))
294
- self._filters = list(existing_state.get("filters", self._filters))
295
- self._is_many = existing_state.get("is_many", self._is_many)
296
- self._is_script = existing_state.get("is_script", self._is_script)
297
- self._raw_sql = existing_state.get("raw_sql", self._raw_sql)
298
- self._original_parameters = existing_state.get("original_parameters", self._original_parameters)
299
-
300
- def _set_original_parameters(self, *parameters: Any) -> None:
301
- """Set the original parameters."""
302
- if not parameters or (len(parameters) == 1 and is_statement_filter(parameters[0])):
303
- self._original_parameters = None
304
- elif len(parameters) == 1 and isinstance(parameters[0], (list, tuple)):
305
- self._original_parameters = parameters[0]
306
- else:
307
- self._original_parameters = parameters
308
-
309
- def _process_parameters(self, *parameters: Any, **kwargs: Any) -> None:
310
- """Process and categorize parameters."""
311
- for param in parameters:
312
- self._process_parameter_item(param)
313
-
314
- if "parameters" in kwargs:
315
- param_value = kwargs.pop("parameters")
316
- if isinstance(param_value, (list, tuple)):
317
- self._positional_params.extend(param_value)
318
- elif is_dict(param_value):
319
- self._named_params.update(param_value)
320
- else:
321
- self._positional_params.append(param_value)
322
-
323
- self._named_params.update({k: v for k, v in kwargs.items() if not k.startswith("_")})
324
-
325
- def _cache_key(self) -> str:
326
- """Generate a cache key for the current SQL state."""
327
- return hash_sql_statement(self)
328
-
329
- def _process_parameter_item(self, item: Any) -> None:
330
- """Process a single item from the parameters list."""
331
- if is_statement_filter(item):
332
- self._filters.append(item)
333
- pos_params, named_params = self._extract_filter_parameters(item)
334
- self._positional_params.extend(pos_params)
335
- self._named_params.update(named_params)
336
- elif isinstance(item, list):
337
- for sub_item in item:
338
- self._process_parameter_item(sub_item)
339
- elif is_dict(item):
340
- self._named_params.update(item)
341
- elif isinstance(item, tuple):
342
- self._positional_params.extend(item)
343
- else:
344
- self._positional_params.append(item)
345
-
346
- def _ensure_processed(self) -> None:
347
- """Ensure the SQL has been processed through the pipeline (lazy initialization).
348
-
349
- This method implements the facade pattern with lazy processing.
350
- It's called by public methods that need processed state.
351
- """
352
- if self._processed_state is not None:
353
- return
354
-
355
- # Check cache first if caching is enabled
356
- cache_key = None
357
- if self._config.enable_caching:
358
- cache_key = self._cache_key()
359
- cached_state = sql_cache.get(cache_key)
360
-
361
- if cached_state is not None:
362
- self._processed_state = cached_state
363
- return
364
-
365
- final_expr, final_params = self._build_final_state()
366
- has_placeholders = self._detect_placeholders()
367
- initial_sql_for_context, final_params = self._prepare_context_sql(final_expr, final_params)
368
-
369
- context = self._create_processing_context(initial_sql_for_context, final_expr, final_params, has_placeholders)
370
- result = self._run_pipeline(context)
371
-
372
- processed_sql, merged_params = self._process_pipeline_result(result, final_params, context)
373
-
374
- self._finalize_processed_state(result, processed_sql, merged_params)
375
-
376
- # Store in cache if caching is enabled
377
- if self._config.enable_caching and cache_key is not None and self._processed_state is not None:
378
- sql_cache.set(cache_key, self._processed_state)
379
-
380
- def _detect_placeholders(self) -> bool:
381
- """Detect if the raw SQL has placeholders."""
382
- if self._raw_sql:
383
- validator = self._config.parameter_validator
384
- raw_param_info = validator.extract_parameters(self._raw_sql)
385
- has_placeholders = bool(raw_param_info)
386
- if has_placeholders:
387
- self._config.input_sql_had_placeholders = True
388
- return has_placeholders
389
- return self._config.input_sql_had_placeholders
390
-
391
- def _prepare_context_sql(self, final_expr: exp.Expression, final_params: Any) -> tuple[str, Any]:
392
- """Prepare SQL string and parameters for context."""
393
- initial_sql_for_context = self._raw_sql or final_expr.sql(dialect=self._dialect or self._config.dialect)
394
-
395
- if is_expression(final_expr) and self._placeholder_mapping:
396
- initial_sql_for_context = final_expr.sql(dialect=self._dialect or self._config.dialect)
397
- if self._placeholder_mapping:
398
- final_params = self._convert_parameters(final_params)
399
-
400
- return initial_sql_for_context, final_params
401
-
402
- def _convert_parameters(self, final_params: Any) -> Any:
403
- """Convert parameters based on placeholder mapping."""
404
- if is_dict(final_params):
405
- converted_params = {}
406
- for placeholder_key, original_name in self._placeholder_mapping.items():
407
- if str(original_name) in final_params:
408
- converted_params[placeholder_key] = final_params[str(original_name)]
409
- non_oracle_params = {
410
- key: value
411
- for key, value in final_params.items()
412
- if key not in {str(name) for name in self._placeholder_mapping.values()}
413
- }
414
- converted_params.update(non_oracle_params)
415
- return converted_params
416
- if isinstance(final_params, (list, tuple)):
417
- validator = self._config.parameter_validator
418
- param_info = validator.extract_parameters(self._raw_sql)
419
-
420
- all_numeric = all(p.name and p.name.isdigit() for p in param_info)
421
-
422
- if all_numeric:
423
- converted_params = {}
424
-
425
- min_param_num = min(int(p.name) for p in param_info if p.name)
426
-
427
- for i, param in enumerate(final_params):
428
- param_num = str(i + min_param_num)
429
- converted_params[param_num] = param
430
-
431
- return converted_params
432
- converted_params = {}
433
- for i, param in enumerate(final_params):
434
- if i < len(param_info):
435
- placeholder_key = f"{PARAM_PREFIX}{param_info[i].ordinal}"
436
- converted_params[placeholder_key] = param
437
- return converted_params
438
- return final_params
439
-
440
- def _create_processing_context(
441
- self, initial_sql_for_context: str, final_expr: exp.Expression, final_params: Any, has_placeholders: bool
442
- ) -> SQLProcessingContext:
443
- """Create SQL processing context."""
444
- context = SQLProcessingContext(
445
- initial_sql_string=initial_sql_for_context,
446
- dialect=self._dialect or self._config.dialect,
447
- config=self._config,
448
- initial_expression=final_expr,
449
- current_expression=final_expr,
450
- merged_parameters=final_params,
451
- input_sql_had_placeholders=has_placeholders or self._config.input_sql_had_placeholders,
452
- )
453
-
454
- if self._placeholder_mapping:
455
- context.extra_info["placeholder_map"] = self._placeholder_mapping
456
-
457
- # Set conversion state if available
458
- if self._parameter_conversion_state:
459
- context.parameter_conversion = self._parameter_conversion_state
460
-
461
- validator = self._config.parameter_validator
462
- context.parameter_info = validator.extract_parameters(context.initial_sql_string)
463
-
464
- return context
465
-
466
- def _run_pipeline(self, context: SQLProcessingContext) -> Any:
467
- """Run the SQL processing pipeline."""
468
- pipeline = self._config.get_statement_pipeline()
469
- result = pipeline.execute_pipeline(context)
470
- self._processing_context = result.context
471
- return result
472
-
473
- def _process_pipeline_result(
474
- self, result: Any, final_params: Any, context: SQLProcessingContext
475
- ) -> tuple[str, Any]:
476
- """Process the result from the pipeline."""
477
- processed_expr = result.expression
478
-
479
- if isinstance(processed_expr, exp.Anonymous):
480
- processed_sql = self._raw_sql or context.initial_sql_string
481
- else:
482
- # Use the initial expression that includes filters, not the processed one
483
- # The processed expression may have lost LIMIT/OFFSET during pipeline processing
484
- if hasattr(context, "initial_expression") and context.initial_expression != processed_expr:
485
- # Check if LIMIT/OFFSET was stripped during processing
486
- has_limit_in_initial = (
487
- context.initial_expression is not None
488
- and hasattr(context.initial_expression, "args")
489
- and "limit" in context.initial_expression.args
490
- )
491
- has_limit_in_processed = hasattr(processed_expr, "args") and "limit" in processed_expr.args
492
-
493
- if has_limit_in_initial and not has_limit_in_processed:
494
- # Restore LIMIT/OFFSET from initial expression
495
- processed_expr = context.initial_expression
496
-
497
- processed_sql = (
498
- processed_expr.sql(dialect=self._dialect or self._config.dialect, comments=False)
499
- if processed_expr
500
- else ""
501
- )
502
- if self._placeholder_mapping and self._original_sql:
503
- processed_sql, result = self._denormalize_sql(processed_sql, result)
504
-
505
- merged_params = self._merge_pipeline_parameters(result, final_params)
506
-
507
- return processed_sql, merged_params
508
-
509
- def _denormalize_sql(self, processed_sql: str, result: Any) -> tuple[str, Any]:
510
- """Denormalize SQL back to original parameter style."""
511
-
512
- original_sql = self._original_sql
513
- param_info = self._config.parameter_validator.extract_parameters(original_sql)
514
- target_styles = {p.style for p in param_info}
515
- if ParameterStyle.POSITIONAL_PYFORMAT in target_styles:
516
- processed_sql = self._config.parameter_converter._convert_sql_placeholders(
517
- processed_sql, param_info, ParameterStyle.POSITIONAL_PYFORMAT
518
- )
519
- elif ParameterStyle.NAMED_PYFORMAT in target_styles:
520
- processed_sql = self._config.parameter_converter._convert_sql_placeholders(
521
- processed_sql, param_info, ParameterStyle.NAMED_PYFORMAT
522
- )
523
- if (
524
- self._placeholder_mapping
525
- and result.context.merged_parameters
526
- and is_dict(result.context.merged_parameters)
527
- ):
528
- result.context.merged_parameters = self._denormalize_pyformat_params(result.context.merged_parameters)
529
- elif ParameterStyle.POSITIONAL_COLON in target_styles:
530
- processed_param_info = self._config.parameter_validator.extract_parameters(processed_sql)
531
- has_param_placeholders = any(p.name and p.name.startswith(PARAM_PREFIX) for p in processed_param_info)
532
-
533
- if not has_param_placeholders:
534
- processed_sql = self._config.parameter_converter._convert_sql_placeholders(
535
- processed_sql, param_info, ParameterStyle.POSITIONAL_COLON
536
- )
537
- if (
538
- self._placeholder_mapping
539
- and result.context.merged_parameters
540
- and is_dict(result.context.merged_parameters)
541
- ):
542
- result.context.merged_parameters = self._denormalize_colon_params(result.context.merged_parameters)
543
-
544
- return processed_sql, result
545
-
546
def _denormalize_colon_params(self, params: "dict[str, Any]") -> "dict[str, Any]":
    """Return colon-style parameters for positional-colon (``:1``) SQL.

    Positional-colon parameters are passed through unchanged. Keys that are
    already numeric need no remapping. Pipeline-extracted literals
    (``param_0``, ``param_1``, ...) must also stay in the mapping so the final
    parameter conversion can bind them.

    Args:
        params: Parameter mapping produced by the processing pipeline.

    Returns:
        The same mapping object, unmodified.
    """
    # Every input is returned as-is, so no key inspection is needed; checking
    # ``key.isdigit()`` here only risked AttributeError on non-string keys.
    return params
558
-
559
def _denormalize_pyformat_params(self, params: "dict[str, Any]") -> "dict[str, Any]":
    """Map pipeline placeholder keys back to the user's pyformat names.

    Keys recorded in ``self._placeholder_mapping`` are renamed to their original
    names (stringified). Every key that does not start with ``PARAM_PREFIX`` is
    then copied across, overriding a renamed entry with the same name.

    Args:
        params: Parameter mapping keyed by pipeline placeholder names.

    Returns:
        A new mapping keyed by the original parameter names.
    """
    restored = {
        str(original): params[placeholder]
        for placeholder, original in self._placeholder_mapping.items()
        if placeholder in params
    }
    # Parameters the pipeline never renamed keep their own keys.
    for key, value in params.items():
        if not key.startswith(PARAM_PREFIX):
            restored[key] = value
    return restored
570
-
571
def _merge_pipeline_parameters(self, result: Any, final_params: Any) -> Any:
    """Combine existing parameters with literals extracted by the pipeline.

    Args:
        result: Pipeline result. ``result.context`` supplies
            ``merged_parameters`` and ``extracted_parameters_from_pipeline``.
        final_params: Parameters built from the statement and its filters.
            Kept for interface compatibility; the merge does not depend on it.

    Returns:
        One of:

        - the context's parameters unchanged when nothing was extracted or
          they already equal the extracted values;
        - the extracted values when no parameters exist yet;
        - a new dict, with extracted values under ``PARAM_PREFIX<i>`` keys,
          or a new list, with extracted values appended.
    """
    merged_params = result.context.merged_parameters
    extracted = result.context.extracted_parameters_from_pipeline

    if not extracted:
        return merged_params
    if merged_params is None:
        # No existing parameters: the extracted literals are the parameters.
        return extracted
    if merged_params == extracted:
        # Already merged; adding them again would duplicate the literals.
        return merged_params

    if is_dict(merged_params):
        # Copy so the pipeline context's mapping is not mutated as a side
        # effect (repeated merges would otherwise accumulate state).
        merged_dict = dict(merged_params)
        for i, param in enumerate(extracted):
            merged_dict[f"{PARAM_PREFIX}{i}"] = param
        return merged_dict
    if isinstance(merged_params, (list, tuple)):
        # Build a new list rather than extending the context's list in place.
        return [*merged_params, *extracted]
    # Single scalar parameter: original value first, then the extracted ones.
    return [merged_params, *list(extracted)]
603
-
604
def _finalize_processed_state(self, result: Any, processed_sql: str, merged_params: Any) -> None:
    """Record the processed state for this statement and apply validation policy.

    Steps:

    1. When ``enable_parameter_type_wrapping`` is on and parameters exist,
       wrap them with type information derived from ``processed_sql``.
    2. Collect every context-metadata entry whose key ends in ``"Analyzer"``.
    3. If an ``analyzer_output_handler`` is configured, pass it a summary of
       the ``StatementAnalyzer`` entry.
    4. Store a ``_ProcessedState`` on ``self._processed_state``.

    Raises:
        SQLValidationError: When validation errors were recorded and
            ``parse_errors_as_warnings`` is disabled. The error with the
            highest risk level is reported.
    """
    config = self._config
    context = result.context

    if config.enable_parameter_type_wrapping and merged_params is not None:
        # Types are inferred against the placeholders in the final SQL text.
        final_info = config.parameter_validator.extract_parameters(processed_sql)
        merged_params = config.parameter_converter.wrap_parameters_with_types(merged_params, final_info)

    analysis_results = {}
    if context.metadata:
        analysis_results = {name: data for name, data in context.metadata.items() if name.endswith("Analyzer")}

    handler = config.analyzer_output_handler
    if handler and analysis_results:
        statement_meta = analysis_results.get("StatementAnalyzer", {})
        if statement_meta:
            summary_fields = (
                "statement_type",
                "complexity_score",
                "table_count",
                "has_subqueries",
                "join_count",
                "duration_ms",
            )
            handler({field: statement_meta.get(field) for field in summary_fields})

    self._processed_state = _ProcessedState(
        processed_expression=result.expression,
        processed_sql=processed_sql,
        merged_parameters=merged_params,
        validation_errors=list(context.validation_errors),
        analysis_results=analysis_results,
        transformation_results={},
    )

    errors = self._processed_state.validation_errors
    if not config.parse_errors_as_warnings and errors:
        # Errors without a risk level rank lowest.
        worst = max(errors, key=lambda err: err.risk_level.value if has_risk_level(err) else 0)
        raise SQLValidationError(
            message=worst.message,
            sql=self._raw_sql or processed_sql,
            risk_level=getattr(worst, "risk_level", RiskLevel.HIGH),
        )
659
-
660
def _to_expression(self, statement: "Union[str, exp.Expression]") -> exp.Expression:
    """Convert ``statement`` to a sqlglot expression.

    Handling by input:

    - Existing expressions are returned as-is.
    - Empty or blank input becomes an empty ``exp.Select``.
    - With parsing disabled, or for non-string input, an ``exp.Anonymous``
      wrapper is returned.

    Placeholder styles that sqlglot cannot parse are rewritten first. The
    rewrite is recorded on ``self`` (``_original_sql``, ``_placeholder_mapping``,
    ``_parameter_conversion_state``).

    Args:
        statement: SQL text or an existing expression.

    Returns:
        The first parsed expression. An ``exp.Anonymous`` is returned instead
        when parsing yields nothing, or when it fails and
        ``parse_errors_as_warnings`` is enabled.

    Raises:
        SQLParsingError: If sqlglot fails to tokenize or parse the SQL and
            ``parse_errors_as_warnings`` is not enabled.
    """
    # Tokenizer failures (e.g. an unterminated string literal) raise
    # ``TokenError``, which is not a subclass of ``ParseError``.
    from sqlglot.errors import TokenError

    if is_expression(statement):
        return statement

    if not statement or (isinstance(statement, str) and not statement.strip()):
        return exp.Select()

    if not self._config.enable_parsing:
        return exp.Anonymous(this=statement)

    if not isinstance(statement, str):
        return exp.Anonymous(this="")

    param_info = self._config.parameter_validator.extract_parameters(statement)
    needs_conversion = any(p.style in SQLGLOT_INCOMPATIBLE_STYLES for p in param_info)

    converted_sql = statement
    if needs_conversion:
        # sqlglot cannot parse these placeholder styles: rewrite them and keep
        # the mapping so they can be restored later.
        converter = self._config.parameter_converter
        converted_sql, placeholder_mapping = converter._transform_sql_for_parsing(statement, param_info)
        self._original_sql = statement
        self._placeholder_mapping = placeholder_mapping

        from sqlspec.statement.parameters import ParameterStyleConversionState

        self._parameter_conversion_state = ParameterStyleConversionState(
            was_transformed=True,
            original_styles=list({p.style for p in param_info}),
            transformation_style=ParameterStyle.NAMED_COLON,
            placeholder_map=placeholder_mapping,
            original_param_info=param_info,
        )
    else:
        self._parameter_conversion_state = None

    try:
        expressions = sqlglot.parse(converted_sql, dialect=self._dialect)  # pyright: ignore
    except (ParseError, TokenError) as e:
        if getattr(self._config, "parse_errors_as_warnings", False):
            logger.warning(
                "Failed to parse SQL, returning Anonymous expression.", extra={"sql": statement, "error": str(e)}
            )
            return exp.Anonymous(this=statement)

        msg = f"Failed to parse SQL: {statement}"
        raise SQLParsingError(msg) from e

    if not expressions or expressions[0] is None:
        return exp.Anonymous(this=statement)
    return expressions[0]
719
-
720
@staticmethod
def _extract_filter_parameters(filter_obj: StatementFilter) -> tuple[list[Any], dict[str, Any]]:
    """Return the ``(positional, named)`` parameters a filter contributes.

    Filters that cannot extract parameters contribute an empty list and an
    empty dict.
    """
    if not can_extract_parameters(filter_obj):
        return [], {}
    return filter_obj.extract_parameters()
726
-
727
def copy(
    self,
    statement: "Optional[Union[str, exp.Expression]]" = None,
    parameters: "Optional[Any]" = None,
    dialect: "DialectType" = None,
    config: "Optional[SQLConfig]" = None,
    **kwargs: Any,
) -> "SQL":
    """Create a copy with optional modifications.

    This is the primary method for creating modified SQL objects.

    Args:
        statement: Replacement statement; defaults to this instance's statement.
        parameters: Replacement parameters. When given, the new object is built
            from these parameters alone. None of this instance's parameters,
            filters, or is_many/is_script flags are carried over.
        dialect: Replacement dialect; defaults to this instance's dialect.
        config: Replacement config; defaults to this instance's config.
        **kwargs: Forwarded unchanged to the ``SQL`` constructor.

    Returns:
        A new ``SQL`` instance.
    """
    new_statement = statement if statement is not None else self._statement
    new_dialect = dialect if dialect is not None else self._dialect
    new_config = config if config is not None else self._config

    if parameters is not None:
        # Explicit parameters replace all existing state, so nothing from this
        # instance is passed as ``_existing_state``.
        return SQL(
            new_statement,
            parameters,
            _dialect=new_dialect,
            _config=new_config,
            _builder_result_type=self._builder_result_type,
            _existing_state=None,
            **kwargs,
        )

    # Snapshot (shallow-copied containers) so the copy can be modified
    # independently of this instance.
    existing_state = {
        "positional_params": list(self._positional_params),
        "named_params": dict(self._named_params),
        "filters": list(self._filters),
        "is_many": self._is_many,
        "is_script": self._is_script,
        "raw_sql": self._raw_sql,
        "original_parameters": self._original_parameters,
    }
    return SQL(
        new_statement,
        _dialect=new_dialect,
        _config=new_config,
        _builder_result_type=self._builder_result_type,
        _existing_state=existing_state,
        **kwargs,
    )
774
-
775
def add_named_parameter(self, name: "str", value: Any) -> "SQL":
    """Return a copy of this statement with ``name`` bound to ``value``.

    The original instance is left untouched; only the copy's named parameters
    change.
    """
    clone = self.copy()
    clone._named_params.update({name: value})
    return clone
780
-
781
def get_unique_parameter_name(
    self, base_name: "str", namespace: "Optional[str]" = None, preserve_original: bool = False
) -> str:
    """Generate a parameter name not already used by this statement's named parameters.

    Args:
        base_name: The base parameter name.
        namespace: Optional namespace prefix (e.g., 'cte', 'subquery'); the
            candidate becomes ``f"{namespace}_{base_name}"``.
        preserve_original: Accepted for compatibility. The unprefixed
            candidate is always returned when it is free, so this flag does
            not change the result.

    Returns:
        The candidate itself if unused, otherwise the candidate with the
        smallest free ``_<n>`` suffix (n >= 1).
    """
    taken = set(self._named_params)
    candidate = f"{namespace}_{base_name}" if namespace else base_name

    if candidate not in taken:
        return candidate

    suffix = 1
    while f"{candidate}_{suffix}" in taken:
        suffix += 1
    return f"{candidate}_{suffix}"
810
-
811
def where(self, condition: "Union[str, exp.Expression, exp.Condition]") -> "SQL":
    """Return a copy of this statement filtered by ``condition``.

    String conditions are converted with ``_to_expression`` first. If the
    current statement does not support a WHERE clause, it is used as the FROM
    source of a new ``exp.Select``, and the condition is applied to that.
    """
    if isinstance(condition, str):
        condition_expr = self._to_expression(condition)
    else:
        condition_expr = condition

    if supports_where(self._statement):
        target = self._statement
    else:
        target = exp.Select().from_(self._statement)  # pyright: ignore

    return self.copy(statement=target.where(condition_expr))  # pyright: ignore
821
-
822
def filter(self, filter_obj: StatementFilter) -> "SQL":
    """Return a copy of this statement with ``filter_obj`` attached.

    Any positional and named parameters the filter exposes are added to the
    copy's parameters.
    """
    updated = self.copy()
    updated._filters.append(filter_obj)

    extra_positional, extra_named = self._extract_filter_parameters(filter_obj)
    updated._positional_params.extend(extra_positional)
    updated._named_params.update(extra_named)

    return updated
830
-
831
def as_many(self, parameters: "Optional[list[Any]]" = None) -> "SQL":
    """Return a copy flagged for ``executemany`` execution.

    When ``parameters`` is supplied it becomes the copy's original
    parameters, and the copy's positional and named parameters are cleared.
    """
    batch = self.copy()
    batch._is_many = True
    if parameters is None:
        return batch

    batch._positional_params, batch._named_params = [], {}
    batch._original_parameters = parameters
    return batch
840
-
841
def as_script(self) -> "SQL":
    """Return a copy of this statement flagged for script execution."""
    scripted = self.copy()
    scripted._is_script = True
    return scripted
846
-
847
def _build_final_state(self) -> tuple[exp.Expression, Any]:
    """Apply the attached filters and assemble the final parameters.

    Each filter that can append to a statement is applied to a temporary
    ``SQL`` built from the current expression and the parameters accumulated
    so far. If the filter returns an ``SQL``, its expression and parameters
    replace the accumulated ones; otherwise the returned value becomes the new
    expression.

    Returns:
        ``(final_expression, final_params)``. ``final_params`` is:

        - a dict when only named parameters exist;
        - a list when only positional parameters exist;
        - a dict holding both when both exist, with positional values stored
          under ``arg_<i>`` keys (suffixed ``_<n>`` on collision);
        - ``None`` when there are no parameters.
    """
    final_expr = self._statement

    # Accumulate parameters from both the original SQL and the filters.
    accumulated_positional = list(self._positional_params)
    accumulated_named = dict(self._named_params)

    for filter_obj in self._filters:
        if can_append_to_statement(filter_obj):
            temp_sql = SQL(final_expr, config=self._config, dialect=self._dialect)
            temp_sql._positional_params = list(accumulated_positional)
            temp_sql._named_params = dict(accumulated_named)
            result = filter_obj.append_to_statement(temp_sql)

            if isinstance(result, SQL):
                # Keep both the modified expression and any parameters the filter added.
                final_expr = result._statement
                accumulated_positional = list(result._positional_params)
                accumulated_named = dict(result._named_params)
            else:
                final_expr = result

    final_params: Any
    if accumulated_named and not accumulated_positional:
        final_params = dict(accumulated_named)
    elif accumulated_positional and not accumulated_named:
        final_params = list(accumulated_positional)
    elif accumulated_positional and accumulated_named:
        final_params = dict(accumulated_named)
        for i, param in enumerate(accumulated_positional):
            param_name = f"arg_{i}"
            # A deterministic, increasing suffix guarantees the loop terminates
            # and yields stable names across runs.
            suffix = 1
            while param_name in final_params:
                param_name = f"arg_{i}_{suffix}"
                suffix += 1
            final_params[param_name] = param
    else:
        final_params = None

    return final_expr, final_params
887
-
888
@property
def sql(self) -> str:
    """SQL text for this statement.

    Returns ``""`` for missing or blank SQL. Scripts, and statements with
    parsing disabled, return the raw SQL untouched. Everything else returns
    the processed SQL.

    Raises:
        RuntimeError: If processing did not produce a state.
    """
    raw = self._raw_sql
    if not raw or not raw.strip():
        return ""

    if self._is_script or not self._config.enable_parsing:
        return raw

    self._ensure_processed()
    state = self._processed_state
    if state is None:
        msg = "Failed to process SQL statement"
        raise RuntimeError(msg)
    return state.processed_sql
904
-
905
@property
def config(self) -> "SQLConfig":
    """The ``SQLConfig`` attached to this statement."""
    current_config = self._config
    return current_config
909
-
910
@property
def expression(self) -> "Optional[exp.Expression]":
    """The processed sqlglot expression, or ``None`` when parsing is disabled.

    Raises:
        RuntimeError: If processing did not produce a state.
    """
    if self._config.enable_parsing:
        self._ensure_processed()
        state = self._processed_state
        if state is None:
            msg = "Failed to process SQL statement"
            raise RuntimeError(msg)
        return state.processed_expression
    return None
920
-
921
@property
def parameters(self) -> Any:
    """Parameters for this statement.

    The original parameters are returned untouched in two cases:

    - an ``executemany`` batch;
    - a tuple when no named parameters were added.

    Otherwise the processed merged parameters are returned, with ``None``
    reported as ``{}``.

    Raises:
        RuntimeError: If processing did not produce a state.
    """
    originals = self._original_parameters
    if originals is not None:
        if self._is_many:
            return originals
        if isinstance(originals, tuple) and not self._named_params:
            return originals

    self._ensure_processed()
    state = self._processed_state
    if state is None:
        msg = "Failed to process SQL statement"
        raise RuntimeError(msg)
    merged = state.merged_parameters
    return {} if merged is None else merged
942
-
943
@property
def is_many(self) -> bool:
    """Whether this statement is flagged for ``executemany`` execution."""
    flagged = self._is_many
    return flagged
947
-
948
@property
def is_script(self) -> bool:
    """Whether this statement is flagged for script execution."""
    flagged = self._is_script
    return flagged
952
-
953
@property
def dialect(self) -> "Optional[DialectType]":
    """The SQL dialect this statement was created with (may be ``None``)."""
    current_dialect = self._dialect
    return current_dialect
957
-
958
def to_sql(self, placeholder_style: "Optional[str]" = None) -> "str":
    """Render the SQL text, optionally using ``placeholder_style``.

    Scripts are returned via ``self.sql`` without compilation.
    """
    if not self._is_script:
        compiled_sql, _unused_params = self.compile(placeholder_style=placeholder_style)
        return compiled_sql
    return self.sql
964
-
965
def get_parameters(self, style: "Optional[str]" = None) -> Any:
    """Return the compiled parameters for the requested placeholder ``style``."""
    _compiled_sql, compiled_params = self.compile(placeholder_style=style)
    return compiled_params
969
-
970
def _compile_execute_many(self, placeholder_style: "Optional[str]") -> "tuple[str, Any]":
    """Compile the SQL and every parameter set for an ``executemany`` call.

    The pipeline extracted literals while processing only the first parameter
    set, so those extracted values are appended to every set here:

    - list/tuple sets become tuples with the literals appended;
    - dict sets gain ``_literal_<i>`` keys;
    - scalar sets become ``(scalar, *literals)``.

    Args:
        placeholder_style: Optional target style. When truthy, the SQL and
            parameter sets are passed through ``_convert_placeholder_style``.

    Returns:
        ``(sql, parameter_sets)``.
    """

    def literal_value(item):
        # Extracted literals may be wrappers exposing ``.value``; unwrap them.
        return item.value if hasattr(item, "value") else item

    sql = self.sql
    self._ensure_processed()

    param_sets = self._original_parameters or []

    literals = None
    if self._processed_state and self._processing_context:
        literals = self._processing_context.extracted_parameters_from_pipeline

    if literals:
        expanded: list[Any] = []
        for row in param_sets:
            if isinstance(row, (list, tuple)):
                expanded.append((*row, *[literal_value(p) for p in literals]))
            elif isinstance(row, dict):
                named_row = dict(row)
                named_row.update({f"_literal_{idx}": literal_value(p) for idx, p in enumerate(literals)})
                expanded.append(named_row)
            else:
                expanded.append((row, *[literal_value(p) for p in literals]))
        param_sets = expanded

    if placeholder_style:
        sql, param_sets = self._convert_placeholder_style(sql, param_sets, placeholder_style)

    return sql, param_sets
1013
-
1014
def _get_extracted_parameters(self) -> "list[Any]":
    """Get extracted parameters from pipeline processing.

    When the processed ``merged_parameters`` are a flat list (first element
    not itself a list/tuple), that list is returned. When they are present
    but not a list, the pipeline's ``extracted_parameters_from_pipeline`` are
    returned if set. Otherwise an empty list is returned.
    """
    extracted_params = []
    # NOTE(review): the original indentation was lost; the ``elif`` below is
    # assumed to pair with ``isinstance(merged, list)`` - confirm against the
    # upstream source.
    if self._processed_state and self._processed_state.merged_parameters:
        merged = self._processed_state.merged_parameters
        if isinstance(merged, list):
            # A list whose first element is itself a sequence is presumably a
            # batch of parameter sets rather than extracted values - verify.
            if merged and not isinstance(merged[0], (tuple, list)):
                extracted_params = merged
        elif self._processing_context and self._processing_context.extracted_parameters_from_pipeline:
            extracted_params = self._processing_context.extracted_parameters_from_pipeline
    return extracted_params
1025
-
1026
def _merge_extracted_params_with_sets(self, params: Any, extracted_params: "list[Any]") -> "list[tuple[Any, ...]]":
    """Append the extracted parameter values to every parameter set.

    Args:
        params: Iterable of parameter sets. Each set is a list/tuple or a
            single scalar value.
        extracted_params: Extracted parameters; wrappers recognised by
            ``has_parameter_value`` are unwrapped to their ``.value``.

    Returns:
        One tuple per input set: the set's own values (or the scalar) followed
        by the extracted values.
    """
    # The extracted values are identical for every set, so unwrap them once
    # instead of once per set.
    extracted_values = [item.value if has_parameter_value(item) else item for item in extracted_params]

    enhanced_params: list[tuple[Any, ...]] = []
    for param_set in params:
        if isinstance(param_set, (list, tuple)):
            enhanced_params.append((*param_set, *extracted_values))
        else:
            # Scalar set: wrap it so the extracted values can follow it.
            enhanced_params.append((param_set, *extracted_values))
    return enhanced_params
1048
-
1049
def compile(self, placeholder_style: "Optional[str]" = None) -> "tuple[str, Any]":
    """Compile this statement to ``(sql, parameters)``.

    Handling by statement kind:

    - Scripts return ``(self.sql, None)``.
    - ``executemany`` batches are delegated to ``_compile_execute_many``.
    - With parsing disabled, the raw SQL and raw parameters are returned.

    Otherwise the processed SQL and parameters are used, in this order:

    1. Parameters are reordered per any ``parameter_position_mapping``.
    2. Placeholders are restored to their original sqlglot-incompatible style
       when no style (or that same style) is requested.
    3. Typed parameters are unwrapped.
    4. A truthy ``placeholder_style`` is applied.

    Raises:
        RuntimeError: If processing did not produce a state.
    """
    if self._is_script:
        return self.sql, None

    if self._is_many and self._original_parameters is not None:
        return self._compile_execute_many(placeholder_style)

    if not self._config.enable_parsing and self._raw_sql:
        return self._raw_sql, self._raw_parameters

    self._ensure_processed()
    state = self._processed_state
    if state is None:
        msg = "Failed to process SQL statement"
        raise RuntimeError(msg)

    sql, params = state.processed_sql, state.merged_parameters
    context = self._processing_context

    if params is not None and context:
        position_map = context.metadata.get("parameter_position_mapping")
        if position_map:
            params = self._reorder_parameters(params, position_map)

    conversion = context.parameter_conversion if context else None
    if conversion and conversion.was_transformed and conversion.original_styles:
        original_style = conversion.original_styles[0]
        # Restore the original style only when no specific style was requested
        # or the requested style is that same original style.
        wants_original = placeholder_style is None or (
            placeholder_style and ParameterStyle(placeholder_style) == original_style
        )
        if wants_original and original_style in SQLGLOT_INCOMPATIBLE_STYLES:
            sql = self._config.parameter_converter._convert_sql_placeholders(
                sql, conversion.original_param_info, original_style
            )
            if original_style == ParameterStyle.POSITIONAL_COLON and is_dict(params):
                params = self._denormalize_colon_params(params)

    params = self._unwrap_typed_parameters(params)

    if not placeholder_style:
        return sql, params

    sql, params = self._apply_placeholder_style(sql, params, placeholder_style)
    return sql, params
1103
-
1104
def _apply_placeholder_style(self, sql: "str", params: Any, placeholder_style: "str") -> "tuple[str, Any]":
    """Convert already-processed SQL and parameters to ``placeholder_style``."""
    converted_sql, converted_params = self._convert_placeholder_style(sql, params, placeholder_style)
    return converted_sql, converted_params
1109
-
1110
@staticmethod
def _unwrap_typed_parameters(params: Any) -> Any:
    """Replace typed-parameter wrappers with their underlying values.

    Args:
        params: ``None``, a dict, a list/tuple, or a single value. Any of
            these may contain wrappers recognised by ``has_parameter_value``.

    Returns:
        The same shape with wrappers unwrapped. Dicts come back as plain
        dicts; lists and tuples are rebuilt with ``type(params)``.
    """

    def unwrap(item):
        return item.value if has_parameter_value(item) else item

    if params is None:
        return None
    if is_dict(params):
        return {key: unwrap(val) for key, val in params.items()}
    if isinstance(params, (list, tuple)):
        return type(params)([unwrap(val) for val in params])
    return unwrap(params)
1145
-
1146
@staticmethod
def _reorder_parameters(params: Any, mapping: dict[int, int]) -> Any:
    """Reorder parameters based on the position mapping.

    Args:
        params: Original parameters (list, tuple, or dict).
        mapping: Maps each new (non-negative) position to the original
            position whose value it takes.

    Returns:
        Reordered parameters in the same container type as the input.

        - For lists/tuples, positions not targeted by the mapping keep their
          original value, and mapping entries outside the sequence are ignored.
        - Dicts are reordered only when every key is ``PARAM_PREFIX``
          followed by digits.
        - Anything else is returned unchanged.
    """
    if isinstance(params, (list, tuple)):
        size = len(params)
        reordered_list: list[Any] = [None] * size
        for new_pos, old_pos in mapping.items():
            # Skip entries pointing outside the sequence instead of raising IndexError.
            if 0 <= new_pos < size and old_pos < size:
                reordered_list[new_pos] = params[old_pos]

        # Fill untouched slots with their original values.
        for i in range(size):
            if reordered_list[i] is None and i not in mapping:
                reordered_list[i] = params[i]

        return tuple(reordered_list) if isinstance(params, tuple) else reordered_list

    if is_dict(params):
        prefix_len = len(PARAM_PREFIX)
        if all(key.startswith(PARAM_PREFIX) and key[prefix_len:].isdigit() for key in params):
            reordered_dict: dict[str, Any] = {}
            for new_pos, old_pos in mapping.items():
                old_key = f"{PARAM_PREFIX}{old_pos}"
                if old_key in params:
                    reordered_dict[f"{PARAM_PREFIX}{new_pos}"] = params[old_key]

            for key, value in params.items():
                # Parse the index after the real prefix length (previously a
                # hard-coded 6) so this stays correct if PARAM_PREFIX changes.
                if key not in reordered_dict and int(key[prefix_len:]) not in mapping:
                    reordered_dict[key] = value

            return reordered_dict
    return params
1187
-
1188
def _convert_placeholder_style(self, sql: str, params: Any, placeholder_style: str) -> tuple[str, Any]:
    """Convert SQL and parameters to the requested placeholder style.

    ``executemany`` batches (a list of list/tuple sets) only have their SQL
    placeholders rewritten; the sets are returned unchanged.

    For single executions:

    - Positional-colon targets prefer the original placeholder info (keeping
      numeric identifiers), unless the current SQL holds more placeholders.
    - ``STATIC`` embeds the values into the SQL.
    - If every placeholder already has the target style, only the parameters
      are reshaped; otherwise both SQL and parameters are converted.

    Args:
        sql: The SQL string to convert.
        params: The parameters to convert.
        placeholder_style: Target placeholder style.

    Returns:
        Tuple of (converted_sql, converted_params).
    """
    if self._is_many and isinstance(params, list) and params and isinstance(params[0], (list, tuple)):
        batch_info = self._config.parameter_converter.validator.extract_parameters(sql)
        if not batch_info:
            return sql, params
        batch_style = ParameterStyle(placeholder_style) if isinstance(placeholder_style, str) else placeholder_style
        return self._replace_placeholders_in_sql(sql, batch_info, batch_style), params

    validator = self._config.parameter_converter.validator
    target_style = ParameterStyle(placeholder_style) if isinstance(placeholder_style, str) else placeholder_style
    wants_positional_colon = target_style == ParameterStyle.POSITIONAL_COLON

    ctx = self._processing_context
    if (
        wants_positional_colon
        and ctx
        and ctx.parameter_conversion
        and ctx.parameter_conversion.original_param_info
    ):
        param_info = ctx.parameter_conversion.original_param_info
    else:
        param_info = validator.extract_parameters(sql)

    if wants_positional_colon and param_info:
        # Transformers (e.g. literal parameterization) may have added
        # placeholders beyond the original set; use whichever info covers more.
        refreshed = validator.extract_parameters(sql)
        if len(refreshed) > len(param_info):
            param_info = refreshed

    if not param_info:
        return sql, params

    if target_style == ParameterStyle.STATIC:
        return self._embed_static_parameters(sql, params, param_info)

    if all(info.style == target_style for info in param_info):
        return sql, self._convert_parameters_format(params, param_info, target_style)

    rewritten_sql = self._replace_placeholders_in_sql(sql, param_info, target_style)
    return rewritten_sql, self._convert_parameters_format(params, param_info, target_style)
1251
-
1252
def _embed_static_parameters(self, sql: str, params: Any, param_info: list[Any]) -> tuple[str, Any]:
    """Embed parameter values directly into SQL for STATIC style.

    Used for scripts and other cases where values must appear inline rather
    than be bound separately. For dict parameters, each value is looked up in
    this order:

    1. the placeholder name;
    2. ``PARAM_PREFIX<ordinal>``;
    3. ``arg_<ordinal>``;
    4. ``str(ordinal)``.

    Placeholders whose ordinal has no value are left untouched.

    Args:
        sql: The SQL string with placeholders.
        params: The parameter values.
        param_info: List of parameter information from extraction.

    Returns:
        Tuple of (sql_with_embedded_values, None).
    """

    def lookup(info):
        if info.name and info.name in params:
            return params[info.name]
        prefixed_key = f"{PARAM_PREFIX}{info.ordinal}"
        if prefixed_key in params:
            return params[prefixed_key]
        arg_key = f"arg_{info.ordinal}"
        if arg_key in params:
            return params[arg_key]
        return params.get(str(info.ordinal), None)

    def render(value):
        if value is None:
            return "NULL"
        if isinstance(value, bool):
            return "TRUE" if value else "FALSE"
        if isinstance(value, str):
            return sqlglot.exp.Literal.string(value).sql(dialect=self._dialect)
        if isinstance(value, (int, float)):
            return sqlglot.exp.Literal.number(value).sql(dialect=self._dialect)
        return sqlglot.exp.Literal.string(str(value)).sql(dialect=self._dialect)

    if is_dict(params):
        values = [lookup(info) for info in param_info]
    elif isinstance(params, (list, tuple)):
        values = list(params)
    elif params is not None:
        values = [params]
    else:
        values = []

    # Splice right-to-left so earlier offsets remain valid after each edit.
    for info in sorted(param_info, key=lambda item: item.position, reverse=True):
        if info.ordinal >= len(values):
            continue
        value = values[info.ordinal]
        if has_parameter_value(value):
            value = value.value
        literal_str = render(value)
        start = info.position
        sql = f"{sql[:start]}{literal_str}{sql[start + len(info.placeholder_text):]}"

    return sql, None
1310
-
1311
def _replace_placeholders_in_sql(self, sql: str, param_info: list[Any], target_style: ParameterStyle) -> str:
    """Replace each placeholder in ``sql`` with its ``target_style`` equivalent.

    Args:
        sql: The SQL string.
        param_info: Placeholder descriptions (``position``/``placeholder_text``).
        target_style: Target parameter style.

    Returns:
        SQL string with replaced placeholders.
    """
    rewritten = sql
    # Walk right-to-left so earlier offsets stay valid after each splice.
    for info in sorted(param_info, key=lambda item: item.position, reverse=True):
        replacement = self._generate_placeholder(info, target_style)
        begin = info.position
        finish = begin + len(info.placeholder_text)
        rewritten = rewritten[:begin] + replacement + rewritten[finish:]
    return rewritten
1331
-
1332
@staticmethod
def _generate_placeholder(param: Any, target_style: ParameterStyle) -> str:
    """Build the placeholder text for ``param`` in ``target_style``.

    Placeholder format by style:

    - ``STATIC`` / ``QMARK``: ``?``
    - ``POSITIONAL_PYFORMAT``: ``%s``
    - ``NUMERIC``: ``$<ordinal+1>``
    - ``POSITIONAL_COLON``: ``:<n>``, keeping an existing numeric name
    - ``NAMED_COLON``: ``:<name>``, or ``:arg_<ordinal>`` for missing or
      numeric names
    - ``NAMED_AT``: ``@<name>``, falling back to ``@param_<ordinal>``
    - ``NAMED_PYFORMAT``: ``%(<name>)s``, falling back to ``arg_<ordinal>``

    Any other style reuses the original placeholder text.
    """
    if target_style in {ParameterStyle.STATIC, ParameterStyle.QMARK}:
        return "?"
    if target_style == ParameterStyle.POSITIONAL_PYFORMAT:
        return "%s"
    if target_style == ParameterStyle.NUMERIC:
        return f"${param.ordinal + 1}"
    if target_style == ParameterStyle.POSITIONAL_COLON:
        # Oracle-style ``:N`` placeholders keep their original number.
        keeps_number = (
            hasattr(param, "style")
            and param.style == ParameterStyle.POSITIONAL_COLON
            and hasattr(param, "name")
            and param.name
            and param.name.isdigit()
        )
        return f":{param.name}" if keeps_number else f":{param.ordinal + 1}"
    if target_style == ParameterStyle.NAMED_COLON:
        has_word_name = param.name and not param.name.isdigit()
        return f":{param.name}" if has_word_name else f":arg_{param.ordinal}"
    if target_style == ParameterStyle.NAMED_AT:
        return f"@{param.name or f'param_{param.ordinal}'}"
    if target_style == ParameterStyle.NAMED_PYFORMAT:
        return f"%({param.name or f'arg_{param.ordinal}'})s"
    return str(param.placeholder_text)
1369
-
1370
def _convert_parameters_format(self, params: Any, param_info: list[Any], target_style: ParameterStyle) -> Any:
    """Reshape ``params`` into the container ``target_style`` expects.

    Dispatches to the matching ``_convert_to_*`` helper: positional colon,
    the list-based positional styles (QMARK, NUMERIC, POSITIONAL_PYFORMAT),
    named colon, or named pyformat. Any other style returns ``params``
    untouched.

    Args:
        params: Original parameters.
        param_info: Placeholder information for the SQL.
        target_style: Style the parameters must be converted for.

    Returns:
        The converted parameters, or ``params`` itself for unhandled styles.
    """
    if target_style == ParameterStyle.POSITIONAL_COLON:
        converter = self._convert_to_positional_colon_format
    elif target_style in {ParameterStyle.QMARK, ParameterStyle.NUMERIC, ParameterStyle.POSITIONAL_PYFORMAT}:
        converter = self._convert_to_positional_format
    elif target_style == ParameterStyle.NAMED_COLON:
        converter = self._convert_to_named_colon_format
    elif target_style == ParameterStyle.NAMED_PYFORMAT:
        converter = self._convert_to_named_pyformat_format
    else:
        return params
    return converter(params, param_info)
1390
-
1391
- def _convert_list_to_colon_dict(
1392
- self, params: "Union[list[Any], tuple[Any, ...]]", param_info: "list[Any]"
1393
- ) -> "dict[str, Any]":
1394
- """Convert list/tuple parameters to colon-style dict format."""
1395
- result_dict: dict[str, Any] = {}
1396
-
1397
- if param_info:
1398
- all_numeric = all(p.name and p.name.isdigit() for p in param_info)
1399
- if all_numeric:
1400
- for i, value in enumerate(params):
1401
- result_dict[str(i + 1)] = value
1402
- else:
1403
- for i, value in enumerate(params):
1404
- if i < len(param_info):
1405
- param_name = param_info[i].name or str(i + 1)
1406
- result_dict[param_name] = value
1407
- else:
1408
- result_dict[str(i + 1)] = value
1409
- else:
1410
- for i, value in enumerate(params):
1411
- result_dict[str(i + 1)] = value
1412
-
1413
- return result_dict
1414
-
1415
- def _convert_single_value_to_colon_dict(self, params: Any, param_info: "list[Any]") -> "dict[str, Any]":
1416
- """Convert single value parameter to colon-style dict format."""
1417
- result_dict: dict[str, Any] = {}
1418
- if param_info and param_info[0].name and param_info[0].name.isdigit():
1419
- result_dict[param_info[0].name] = params
1420
- else:
1421
- result_dict["1"] = params
1422
- return result_dict
1423
-
1424
def _process_mixed_colon_params(self, params: "dict[str, Any]", param_info: "list[Any]") -> "dict[str, Any]":
    """Build a positional-colon dict from a mix of extracted and user parameters.

    ``params`` may mix four kinds of entry:

    - values carrying a wrapped ``.value`` (detected with ``has_parameter_value``);
    - user-supplied numeric keys (``"1"``, ``"2"``, ...);
    - pipeline-converted ``param_N`` keys;
    - other named keys.

    Values are ordered as follows:

    1. Wrapped values and other keys, in dict order.
    2. ``param_N`` values sorted by N. A value is skipped when the user also
       supplied the numeric key ``str(N + 1)``.
    3. The user's numeric keys, in numeric order.

    That sequence is assigned to the placeholders in ordinal order under the
    key ``str(ordinal + 1)``. Surplus values or surplus placeholders are
    ignored.

    Args:
        params: Parameter dict with mixed key styles.
        param_info: Placeholder information (``ordinal``) for the SQL.

    Returns:
        Dict keyed by ``"1"``, ``"2"``, ... for positional colon binding.
    """
    # Collect the user's numeric keys before the main pass. This makes the
    # shadowing check for ``param_N`` keys independent of the order in which
    # keys appear in ``params``.
    user_oracle_params: "dict[str, Any]" = {
        key: value for key, value in params.items() if key.isdigit() and not has_parameter_value(value)
    }

    extracted_params: "list[tuple[str, Any]]" = []
    converted_params: "list[tuple[int, str, Any]]" = []
    for key, value in params.items():
        if has_parameter_value(value):
            extracted_params.append((key, value))
        elif key.isdigit():
            continue  # already collected in user_oracle_params
        elif key.startswith("param_") and key[6:].isdigit():
            param_idx = int(key[6:])
            # Skip a converted value when the user bound the same position explicitly.
            if str(param_idx + 1) not in user_oracle_params:
                converted_params.append((param_idx, key, value))
        else:
            extracted_params.append((key, value))

    converted_params.sort(key=operator.itemgetter(0))
    extracted_params.extend((key, value) for _, key, value in converted_params)

    # Unwrap type-carrying wrappers to their raw value.
    extracted_values = [value.value if has_parameter_value(value) else value for _, value in extracted_params]
    user_values = [user_oracle_params[key] for key in sorted(user_oracle_params, key=int)]
    all_values = extracted_values + user_values

    # zip stops at the shorter side: extra values or extra placeholders are ignored.
    return {
        str(p.ordinal + 1): value
        for p, value in zip(sorted(param_info, key=lambda x: x.ordinal), all_values)
    }
1476
-
1477
def _convert_to_positional_colon_format(self, params: Any, param_info: list[Any]) -> Any:
    """Convert parameters to the dict shape used by positional colon (``:1``) style.

    Positional colon placeholders (``:1``, ``:2``, ...) are bound from a dict
    whose keys are the strings ``"1"``, ``"2"``, ... For ``execute_many`` with
    a list of row sequences, the rows are returned unchanged.

    Args:
        params: Original parameters: a sequence, a dict, or a single scalar.
        param_info: Placeholder information (``name``, ``ordinal``) for the SQL.

    Returns:
        - ``params`` unchanged for ``execute_many`` row lists, for dicts already
          keyed only by digits, and for a scalar when there are no placeholders.
        - Otherwise a dict keyed by ``"1"``, ``"2"``, ...

        In the general dict case, a placeholder is looked up first by its own
        name, then by ``param_<ordinal>``, ``arg_<ordinal>`` and
        ``str(ordinal + 1)``. It is left out only when none of those keys is
        present. An explicitly supplied ``None`` is kept, so it binds as NULL.
    """
    # execute_many: a list of row sequences is passed through untouched.
    if self._is_many and isinstance(params, list) and params and isinstance(params[0], (list, tuple)):
        return params

    if isinstance(params, (list, tuple)):
        return self._convert_list_to_colon_dict(params, param_info)

    if not is_dict(params):
        # Scalar value: bind it to the first placeholder when there is one.
        return self._convert_single_value_to_colon_dict(params, param_info) if param_info else params

    # Already keyed by position: nothing to do.
    if all(key.isdigit() for key in params):
        return params

    # Only pipeline-generated ``param_N`` keys: renumber them by placeholder ordinal.
    if all(key.startswith("param_") for key in params):
        renumbered: dict[str, Any] = {}
        for p in sorted(param_info, key=lambda x: x.ordinal):
            converted_key = f"param_{p.ordinal}"
            if converted_key in params:
                # Keep an explicit Oracle number (``:3`` stays "3"); otherwise use ordinal + 1.
                target_key = p.name if p.name and p.name.isdigit() else str(p.ordinal + 1)
                renumbered[target_key] = params[converted_key]
        return renumbered

    has_oracle_numeric = any(key.isdigit() for key in params)
    has_param_converted = any(key.startswith("param_") for key in params)
    has_typed_params = any(has_parameter_value(v) for v in params.values())
    if (has_oracle_numeric and has_param_converted) or has_typed_params:
        return self._process_mixed_colon_params(params, param_info)

    result_dict: dict[str, Any] = {}
    # Sentinel: tells "no key supplied" apart from an explicit None value.
    missing = object()
    for p in sorted(param_info, key=lambda x: x.ordinal):
        value: Any = missing
        if p.name and p.name in params:
            value = params[p.name]
        else:
            # Pipeline (param_N), generated (arg_N), then numeric-position keys.
            for fallback_key in (f"param_{p.ordinal}", f"arg_{p.ordinal}", str(p.ordinal + 1)):
                if fallback_key in params:
                    value = params[fallback_key]
                    break
        if value is missing:
            continue
        if has_parameter_value(value):
            value = value.value
        result_dict[str(p.ordinal + 1)] = value
    return result_dict
1565
-
1566
@staticmethod
def _convert_to_positional_format(params: Any, param_info: list[Any]) -> Any:
    """Convert parameters to a list for positional styles (QMARK, NUMERIC, POSITIONAL_PYFORMAT).

    Dict input is matched to placeholders in three passes:

    1. By placeholder name.
    2. For unnamed placeholders, by ``arg_<ordinal>`` or ``param_<ordinal>``.
    3. Remaining keys that are neither a placeholder name nor prefixed with
       ``arg_``/``param_`` fill the still-unmatched placeholders in dict order.

    The result has one entry per placeholder, in ``param_info`` order, with
    ``None`` where nothing matched.

    List/tuple input:

    - An empty sequence is returned unchanged, so ``execute_many`` with no
      rows keeps its shape.
    - If any placeholder is NUMERIC, values are assigned in order of
      appearance in the SQL (``position``). The result is laid out by ordinal
      ``0 .. len(param_info) - 1``.
    - Otherwise values are passed through in order.

    In every branch, values that carry a wrapped value (``has_parameter_value``)
    are unwrapped to ``.value``.

    Args:
        params: Original parameters (dict, list/tuple, or anything else).
        param_info: Placeholder information (``name``, ``ordinal``,
            ``position``, ``style``) for the SQL.

    Returns:
        A list of values. ``params`` is returned unchanged when it is an
        empty sequence or neither a dict nor a sequence.
    """

    def unwrap(value: Any) -> Any:
        # None passes through without consulting has_parameter_value.
        if value is not None and has_parameter_value(value):
            return value.value
        return value

    if is_dict(params):
        values_by_ordinal: dict[int, Any] = {}
        for p in param_info:
            if p.name and p.name in params:
                values_by_ordinal[p.ordinal] = params[p.name]

        for p in param_info:
            if p.name is None and p.ordinal not in values_by_ordinal:
                arg_key = f"arg_{p.ordinal}"
                param_key = f"param_{p.ordinal}"
                if arg_key in params:
                    values_by_ordinal[p.ordinal] = params[arg_key]
                elif param_key in params:
                    values_by_ordinal[p.ordinal] = params[param_key]

        # Build the name set once instead of once per key.
        placeholder_names = {p.name for p in param_info if p.name}
        leftover_values = [
            value
            for key, value in params.items()
            if key not in placeholder_names and not key.startswith(("arg_", "param_"))
        ]
        unmatched_ordinals = [p.ordinal for p in param_info if p.ordinal not in values_by_ordinal]
        values_by_ordinal.update(zip(unmatched_ordinals, leftover_values))

        return [unwrap(values_by_ordinal.get(p.ordinal)) for p in param_info]

    if isinstance(params, (list, tuple)):
        if not params:
            return params

        if param_info and any(p.style == ParameterStyle.NUMERIC for p in param_info):
            # Mixed/numeric styles: hand out values in order of appearance in the SQL.
            by_ordinal: dict[int, Any] = {}
            in_sql_order = sorted(param_info, key=lambda p: p.position)
            for p, value in zip(in_sql_order, params):
                by_ordinal[p.ordinal] = value
            return [unwrap(by_ordinal.get(i)) for i in range(len(param_info))]

        return [unwrap(value) for value in params]

    return params
1657
-
1658
@staticmethod
def _convert_to_named_colon_format(params: Any, param_info: list[Any]) -> Any:
    """Convert parameters to a dict for named colon (``:name``) placeholders.

    Dict input is returned as-is when every named placeholder already has a
    matching key. Otherwise a new dict is built from the placeholders that
    can be matched:

    - A match by the placeholder's own name is stored under that name.
    - A match through a ``param_<ordinal>`` key is stored under the
      placeholder name, or under ``arg_<ordinal>`` when the placeholder is
      unnamed.

    Dict values are copied as-is (not unwrapped).

    List/tuple values are unwrapped (``has_parameter_value`` -> ``.value``).
    Each is keyed by the name of the placeholder at the same index. The key
    falls back to ``arg_<index>`` for unnamed placeholders or surplus values.

    NOTE(review): ``_generate_placeholder`` renders a purely numeric name as
    ``:arg_<ordinal>``, while this method keeps numeric names as keys. Confirm
    the two agree for callers that rewrite the SQL.

    Args:
        params: Original parameters.
        param_info: Placeholder information (``name``, ``ordinal``) for the SQL.

    Returns:
        A dict keyed by placeholder name, or ``params`` unchanged when it is
        neither a dict nor a sequence.
    """
    if is_dict(params):
        if all(p.name in params for p in param_info if p.name):
            return params
        matched: dict[str, Any] = {}
        for p in param_info:
            generated_key = f"param_{p.ordinal}"
            if p.name and p.name in params:
                matched[p.name] = params[p.name]
            elif generated_key in params:
                matched[p.name or f"arg_{p.ordinal}"] = params[generated_key]
        return matched

    if isinstance(params, (list, tuple)):
        placeholder_count = len(param_info)
        keyed: dict[str, Any] = {}
        for index, raw in enumerate(params):
            value = raw.value if has_parameter_value(raw) else raw
            name = param_info[index].name if index < placeholder_count else None
            keyed[name or f"arg_{index}"] = value
        return keyed

    return params
1693
-
1694
- @staticmethod
1695
- def _convert_to_named_pyformat_format(params: Any, param_info: list[Any]) -> Any:
1696
- """Convert to dict format for named pyformat style.
1697
-
1698
- Args:
1699
- params: Original parameters
1700
- param_info: List of parameter information
1701
-
1702
- Returns:
1703
- Dict of parameters with names
1704
- """
1705
- if isinstance(params, (list, tuple)):
1706
- result_dict: dict[str, Any] = {}
1707
- for i, p in enumerate(param_info):
1708
- if i < len(params):
1709
- param_name = p.name or f"param_{i}"
1710
- result_dict[param_name] = params[i]
1711
- return result_dict
1712
- return params
1713
-
1714
@property
def validation_errors(self) -> list[Any]:
    """Validation errors found while processing the statement.

    When validation is disabled in the config, returns an empty list without
    processing anything. Otherwise ensures the statement has been processed
    and returns the processed state's errors.

    Raises:
        RuntimeError: If processing left no processed state behind.
    """
    if self._config.enable_validation:
        self._ensure_processed()
        state = self._processed_state
        if state:
            return state.validation_errors
        msg = "Failed to process SQL statement"
        raise RuntimeError(msg)
    return []
1724
-
1725
@property
def has_errors(self) -> bool:
    """Whether ``validation_errors`` reports at least one error."""
    return len(self.validation_errors) > 0
1729
-
1730
@property
def is_safe(self) -> bool:
    """Whether the statement passed validation, i.e. ``has_errors`` is falsy."""
    return False if self.has_errors else True
1734
-
1735
def validate(self) -> list[Any]:
    """Run validation (via ``validation_errors``) and return the errors found."""
    errors = self.validation_errors
    return errors
1738
-
1739
@property
def parameter_info(self) -> list[Any]:
    """Placeholder information for the statement, before any style conversion.

    When raw SQL text is available, it is re-scanned with the configured
    parameter validator. Otherwise the statement is processed (if needed)
    and the processing context's parameter info is returned. An empty list
    is returned when no context exists.
    """
    validator = self._config.parameter_validator
    raw_sql = self._raw_sql
    if raw_sql:
        return validator.extract_parameters(raw_sql)

    self._ensure_processed()
    context = self._processing_context
    return context.parameter_info if context else []
1755
-
1756
- @property
1757
- def _raw_parameters(self) -> Any:
1758
- """Get raw parameters for compatibility."""
1759
- return self._original_parameters
1760
-
1761
- @property
1762
- def _sql(self) -> str:
1763
- """Get SQL string for compatibility."""
1764
- return self.sql
1765
-
1766
- @property
1767
- def _expression(self) -> "Optional[exp.Expression]":
1768
- """Get expression for compatibility."""
1769
- return self.expression
1770
-
1771
@property
def statement(self) -> exp.Expression:
    """Backward-compatible accessor returning the stored ``_statement`` expression."""
    stored = self._statement
    return stored