sqlspec 0.11.1__py3-none-any.whl → 0.12.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic. Click here for more details.

Files changed (155) hide show
  1. sqlspec/__init__.py +16 -3
  2. sqlspec/_serialization.py +3 -10
  3. sqlspec/_sql.py +1147 -0
  4. sqlspec/_typing.py +343 -41
  5. sqlspec/adapters/adbc/__init__.py +2 -6
  6. sqlspec/adapters/adbc/config.py +474 -149
  7. sqlspec/adapters/adbc/driver.py +330 -621
  8. sqlspec/adapters/aiosqlite/__init__.py +2 -6
  9. sqlspec/adapters/aiosqlite/config.py +143 -57
  10. sqlspec/adapters/aiosqlite/driver.py +269 -431
  11. sqlspec/adapters/asyncmy/__init__.py +3 -8
  12. sqlspec/adapters/asyncmy/config.py +247 -202
  13. sqlspec/adapters/asyncmy/driver.py +218 -436
  14. sqlspec/adapters/asyncpg/__init__.py +4 -7
  15. sqlspec/adapters/asyncpg/config.py +329 -176
  16. sqlspec/adapters/asyncpg/driver.py +417 -487
  17. sqlspec/adapters/bigquery/__init__.py +2 -2
  18. sqlspec/adapters/bigquery/config.py +407 -0
  19. sqlspec/adapters/bigquery/driver.py +600 -553
  20. sqlspec/adapters/duckdb/__init__.py +4 -1
  21. sqlspec/adapters/duckdb/config.py +432 -321
  22. sqlspec/adapters/duckdb/driver.py +392 -406
  23. sqlspec/adapters/oracledb/__init__.py +3 -8
  24. sqlspec/adapters/oracledb/config.py +625 -0
  25. sqlspec/adapters/oracledb/driver.py +548 -921
  26. sqlspec/adapters/psqlpy/__init__.py +4 -7
  27. sqlspec/adapters/psqlpy/config.py +372 -203
  28. sqlspec/adapters/psqlpy/driver.py +197 -533
  29. sqlspec/adapters/psycopg/__init__.py +3 -8
  30. sqlspec/adapters/psycopg/config.py +725 -0
  31. sqlspec/adapters/psycopg/driver.py +734 -694
  32. sqlspec/adapters/sqlite/__init__.py +2 -6
  33. sqlspec/adapters/sqlite/config.py +146 -81
  34. sqlspec/adapters/sqlite/driver.py +242 -405
  35. sqlspec/base.py +220 -784
  36. sqlspec/config.py +354 -0
  37. sqlspec/driver/__init__.py +22 -0
  38. sqlspec/driver/_async.py +252 -0
  39. sqlspec/driver/_common.py +338 -0
  40. sqlspec/driver/_sync.py +261 -0
  41. sqlspec/driver/mixins/__init__.py +17 -0
  42. sqlspec/driver/mixins/_pipeline.py +523 -0
  43. sqlspec/driver/mixins/_result_utils.py +122 -0
  44. sqlspec/driver/mixins/_sql_translator.py +35 -0
  45. sqlspec/driver/mixins/_storage.py +993 -0
  46. sqlspec/driver/mixins/_type_coercion.py +131 -0
  47. sqlspec/exceptions.py +299 -7
  48. sqlspec/extensions/aiosql/__init__.py +10 -0
  49. sqlspec/extensions/aiosql/adapter.py +474 -0
  50. sqlspec/extensions/litestar/__init__.py +1 -6
  51. sqlspec/extensions/litestar/_utils.py +1 -5
  52. sqlspec/extensions/litestar/config.py +5 -6
  53. sqlspec/extensions/litestar/handlers.py +13 -12
  54. sqlspec/extensions/litestar/plugin.py +22 -24
  55. sqlspec/extensions/litestar/providers.py +37 -55
  56. sqlspec/loader.py +528 -0
  57. sqlspec/service/__init__.py +3 -0
  58. sqlspec/service/base.py +24 -0
  59. sqlspec/service/pagination.py +26 -0
  60. sqlspec/statement/__init__.py +21 -0
  61. sqlspec/statement/builder/__init__.py +54 -0
  62. sqlspec/statement/builder/_ddl_utils.py +119 -0
  63. sqlspec/statement/builder/_parsing_utils.py +135 -0
  64. sqlspec/statement/builder/base.py +328 -0
  65. sqlspec/statement/builder/ddl.py +1379 -0
  66. sqlspec/statement/builder/delete.py +80 -0
  67. sqlspec/statement/builder/insert.py +274 -0
  68. sqlspec/statement/builder/merge.py +95 -0
  69. sqlspec/statement/builder/mixins/__init__.py +65 -0
  70. sqlspec/statement/builder/mixins/_aggregate_functions.py +151 -0
  71. sqlspec/statement/builder/mixins/_case_builder.py +91 -0
  72. sqlspec/statement/builder/mixins/_common_table_expr.py +91 -0
  73. sqlspec/statement/builder/mixins/_delete_from.py +34 -0
  74. sqlspec/statement/builder/mixins/_from.py +61 -0
  75. sqlspec/statement/builder/mixins/_group_by.py +119 -0
  76. sqlspec/statement/builder/mixins/_having.py +35 -0
  77. sqlspec/statement/builder/mixins/_insert_from_select.py +48 -0
  78. sqlspec/statement/builder/mixins/_insert_into.py +36 -0
  79. sqlspec/statement/builder/mixins/_insert_values.py +69 -0
  80. sqlspec/statement/builder/mixins/_join.py +110 -0
  81. sqlspec/statement/builder/mixins/_limit_offset.py +53 -0
  82. sqlspec/statement/builder/mixins/_merge_clauses.py +405 -0
  83. sqlspec/statement/builder/mixins/_order_by.py +46 -0
  84. sqlspec/statement/builder/mixins/_pivot.py +82 -0
  85. sqlspec/statement/builder/mixins/_returning.py +37 -0
  86. sqlspec/statement/builder/mixins/_select_columns.py +60 -0
  87. sqlspec/statement/builder/mixins/_set_ops.py +122 -0
  88. sqlspec/statement/builder/mixins/_unpivot.py +80 -0
  89. sqlspec/statement/builder/mixins/_update_from.py +54 -0
  90. sqlspec/statement/builder/mixins/_update_set.py +91 -0
  91. sqlspec/statement/builder/mixins/_update_table.py +29 -0
  92. sqlspec/statement/builder/mixins/_where.py +374 -0
  93. sqlspec/statement/builder/mixins/_window_functions.py +86 -0
  94. sqlspec/statement/builder/protocols.py +20 -0
  95. sqlspec/statement/builder/select.py +206 -0
  96. sqlspec/statement/builder/update.py +178 -0
  97. sqlspec/statement/filters.py +571 -0
  98. sqlspec/statement/parameters.py +736 -0
  99. sqlspec/statement/pipelines/__init__.py +67 -0
  100. sqlspec/statement/pipelines/analyzers/__init__.py +9 -0
  101. sqlspec/statement/pipelines/analyzers/_analyzer.py +649 -0
  102. sqlspec/statement/pipelines/base.py +315 -0
  103. sqlspec/statement/pipelines/context.py +119 -0
  104. sqlspec/statement/pipelines/result_types.py +41 -0
  105. sqlspec/statement/pipelines/transformers/__init__.py +8 -0
  106. sqlspec/statement/pipelines/transformers/_expression_simplifier.py +256 -0
  107. sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +623 -0
  108. sqlspec/statement/pipelines/transformers/_remove_comments.py +66 -0
  109. sqlspec/statement/pipelines/transformers/_remove_hints.py +81 -0
  110. sqlspec/statement/pipelines/validators/__init__.py +23 -0
  111. sqlspec/statement/pipelines/validators/_dml_safety.py +275 -0
  112. sqlspec/statement/pipelines/validators/_parameter_style.py +297 -0
  113. sqlspec/statement/pipelines/validators/_performance.py +703 -0
  114. sqlspec/statement/pipelines/validators/_security.py +990 -0
  115. sqlspec/statement/pipelines/validators/base.py +67 -0
  116. sqlspec/statement/result.py +527 -0
  117. sqlspec/statement/splitter.py +701 -0
  118. sqlspec/statement/sql.py +1198 -0
  119. sqlspec/storage/__init__.py +15 -0
  120. sqlspec/storage/backends/__init__.py +0 -0
  121. sqlspec/storage/backends/base.py +166 -0
  122. sqlspec/storage/backends/fsspec.py +315 -0
  123. sqlspec/storage/backends/obstore.py +464 -0
  124. sqlspec/storage/protocol.py +170 -0
  125. sqlspec/storage/registry.py +315 -0
  126. sqlspec/typing.py +157 -36
  127. sqlspec/utils/correlation.py +155 -0
  128. sqlspec/utils/deprecation.py +3 -6
  129. sqlspec/utils/fixtures.py +6 -11
  130. sqlspec/utils/logging.py +135 -0
  131. sqlspec/utils/module_loader.py +45 -43
  132. sqlspec/utils/serializers.py +4 -0
  133. sqlspec/utils/singleton.py +6 -8
  134. sqlspec/utils/sync_tools.py +15 -27
  135. sqlspec/utils/text.py +58 -26
  136. {sqlspec-0.11.1.dist-info → sqlspec-0.12.1.dist-info}/METADATA +97 -26
  137. sqlspec-0.12.1.dist-info/RECORD +145 -0
  138. sqlspec/adapters/bigquery/config/__init__.py +0 -3
  139. sqlspec/adapters/bigquery/config/_common.py +0 -40
  140. sqlspec/adapters/bigquery/config/_sync.py +0 -87
  141. sqlspec/adapters/oracledb/config/__init__.py +0 -9
  142. sqlspec/adapters/oracledb/config/_asyncio.py +0 -186
  143. sqlspec/adapters/oracledb/config/_common.py +0 -131
  144. sqlspec/adapters/oracledb/config/_sync.py +0 -186
  145. sqlspec/adapters/psycopg/config/__init__.py +0 -19
  146. sqlspec/adapters/psycopg/config/_async.py +0 -169
  147. sqlspec/adapters/psycopg/config/_common.py +0 -56
  148. sqlspec/adapters/psycopg/config/_sync.py +0 -168
  149. sqlspec/filters.py +0 -331
  150. sqlspec/mixins.py +0 -305
  151. sqlspec/statement.py +0 -378
  152. sqlspec-0.11.1.dist-info/RECORD +0 -69
  153. {sqlspec-0.11.1.dist-info → sqlspec-0.12.1.dist-info}/WHEEL +0 -0
  154. {sqlspec-0.11.1.dist-info → sqlspec-0.12.1.dist-info}/licenses/LICENSE +0 -0
  155. {sqlspec-0.11.1.dist-info → sqlspec-0.12.1.dist-info}/licenses/NOTICE +0 -0
@@ -0,0 +1,81 @@
1
+ """Removes SQL hints from expressions."""
2
+
3
+ from typing import TYPE_CHECKING, Optional
4
+
5
+ from sqlglot import exp
6
+
7
+ from sqlspec.statement.pipelines.base import ProcessorProtocol
8
+
9
+ if TYPE_CHECKING:
10
+ from sqlspec.statement.pipelines.context import SQLProcessingContext
11
+
12
+ __all__ = ("HintRemover",)
13
+
14
+
15
class HintRemover(ProcessorProtocol):
    """Removes SQL hints from expressions using SQLGlot's AST traversal.

    The processor walks a copy of ``context.current_expression`` and:
    - removes formal hint expressions (``exp.Hint`` nodes)
    - removes comments containing a known Oracle hint keyword (``/*+ hint */`` style)
    - removes comments whose text starts with ``!`` (MySQL ``/*!50000 */`` style)
    - keeps every other comment attached to its node

    NOTE: Oracle hint detection is a case-insensitive substring match against
    ``_ORACLE_HINT_KEYWORDS``, so an ordinary comment that happens to contain one
    of those words is treated as a hint as well.

    Args:
        enabled: Whether hint removal is enabled.
        remove_oracle_hints: Whether to remove Oracle-style hints (/*+ hint */).
        remove_mysql_version_comments: Whether to remove MySQL /*!50000 */ style comments.
    """

    # Upper-case markers that identify a comment as an Oracle optimizer hint.
    _ORACLE_HINT_KEYWORDS = ("INDEX", "USE_NL", "USE_HASH", "PARALLEL", "FULL", "FIRST_ROWS", "ALL_ROWS")

    def __init__(
        self, enabled: bool = True, remove_oracle_hints: bool = True, remove_mysql_version_comments: bool = True
    ) -> None:
        self.enabled = enabled
        self.remove_oracle_hints = remove_oracle_hints
        self.remove_mysql_version_comments = remove_mysql_version_comments

    def _should_drop_comment(self, comment: object) -> bool:
        """Return True when ``comment`` is a hint that the current settings ask to remove.

        The two checks are independent: a MySQL version comment is still removed
        when Oracle hint removal is switched off, even if its text contains an
        Oracle hint keyword.
        """
        comment_text = str(comment).strip()
        if self.remove_oracle_hints:
            upper_text = comment_text.upper()
            if any(keyword in upper_text for keyword in self._ORACLE_HINT_KEYWORDS):
                return True
        return self.remove_mysql_version_comments and comment_text.startswith("!")

    def process(
        self, expression: "Optional[exp.Expression]", context: "SQLProcessingContext"
    ) -> "Optional[exp.Expression]":
        """Removes SQL hints from the expression using SQLGlot AST traversal.

        Args:
            expression: The expression handed to the processor; returned unchanged
                when the processor is disabled or there is nothing to process.
            context: Processing context. ``context.current_expression`` is the tree
                that is actually transformed and replaced; the number of removed
                hints is stored under ``context.metadata["hints_removed"]``.

        Returns:
            The transformed ``context.current_expression``, or ``expression`` as
            given when processing is skipped.
        """
        if not self.enabled or expression is None or context.current_expression is None:
            return expression

        hints_removed_count = 0

        def _remove_hint_node(node: exp.Expression) -> "Optional[exp.Expression]":
            nonlocal hints_removed_count
            if isinstance(node, exp.Hint):
                # Returning None from the transform callback drops the node from the tree.
                hints_removed_count += 1
                return None

            comments = getattr(node, "comments", None)
            if comments:
                comments_to_keep = [comment for comment in comments if not self._should_drop_comment(comment)]
                dropped = len(comments) - len(comments_to_keep)
                if dropped:
                    hints_removed_count += dropped
                    # Reset the node's comments, then re-attach only the survivors.
                    node.pop_comments()
                    if comments_to_keep:
                        node.add_comments(comments_to_keep)
            return node

        # copy=True keeps the tree previously stored on the context untouched.
        transformed_expression = context.current_expression.transform(_remove_hint_node, copy=True)
        context.current_expression = transformed_expression or exp.Anonymous(this="")

        context.metadata["hints_removed"] = hints_removed_count

        return context.current_expression
@@ -0,0 +1,23 @@
1
+ """SQL Validation Pipeline Components."""
2
+
3
+ from sqlspec.statement.pipelines.validators._dml_safety import DMLSafetyConfig, DMLSafetyValidator
4
+ from sqlspec.statement.pipelines.validators._parameter_style import ParameterStyleValidator
5
+ from sqlspec.statement.pipelines.validators._performance import PerformanceConfig, PerformanceValidator
6
+ from sqlspec.statement.pipelines.validators._security import (
7
+ SecurityIssue,
8
+ SecurityIssueType,
9
+ SecurityValidator,
10
+ SecurityValidatorConfig,
11
+ )
12
+
13
+ __all__ = (
14
+ "DMLSafetyConfig",
15
+ "DMLSafetyValidator",
16
+ "ParameterStyleValidator",
17
+ "PerformanceConfig",
18
+ "PerformanceValidator",
19
+ "SecurityIssue",
20
+ "SecurityIssueType",
21
+ "SecurityValidator",
22
+ "SecurityValidatorConfig",
23
+ )
@@ -0,0 +1,275 @@
1
+ # DML Safety Validator - Consolidates risky DML operations and DDL prevention
2
+ from dataclasses import dataclass, field
3
+ from enum import Enum
4
+ from typing import TYPE_CHECKING, Optional
5
+
6
+ from sqlglot import expressions as exp
7
+
8
+ from sqlspec.exceptions import RiskLevel
9
+ from sqlspec.statement.pipelines.validators.base import BaseValidator
10
+
11
+ if TYPE_CHECKING:
12
+ from sqlspec.statement.pipelines.context import SQLProcessingContext
13
+
14
+ __all__ = ("DMLSafetyConfig", "DMLSafetyValidator", "StatementCategory")
15
+
16
+
17
class StatementCategory(Enum):
    """Broad families a SQL statement can be classified into.

    Members:
        DDL: schema definition statements (CREATE, ALTER, DROP, TRUNCATE).
        DML: data modification statements (INSERT, UPDATE, DELETE, MERGE).
        DQL: query statements (SELECT).
        DCL: access control statements (GRANT, REVOKE).
        TCL: transaction control statements (COMMIT, ROLLBACK, SAVEPOINT).
    """

    DDL = "ddl"
    DML = "dml"
    DQL = "dql"
    DCL = "dcl"
    TCL = "tcl"
25
+
26
+
27
@dataclass
class DMLSafetyConfig:
    """Settings consumed by the DML safety validator.

    Field order is part of the positional constructor signature and must not change.
    """

    # Report DDL statements as errors.
    prevent_ddl: bool = True
    # Report DCL statements as errors.
    prevent_dcl: bool = True
    # Operation names that must carry a WHERE clause.
    require_where_clause: "set[str]" = field(default_factory=lambda: {"DELETE", "UPDATE"})
    # DDL operation names exempt from the DDL restriction.
    allowed_ddl_operations: "set[str]" = field(default_factory=set)
    # Allow DDL in migration contexts.
    migration_mode: bool = False
    # Limit for DML operations; None means no limit is configured.
    max_affected_rows: "Optional[int]" = None
37
+
38
+
39
class DMLSafetyValidator(BaseValidator):
    """Unified validator for DML/DDL safety checks.

    This validator consolidates:
    - DDL prevention (CREATE, ALTER, DROP, etc.), skipped while ``migration_mode`` is on
    - Risky DML detection (DELETE/UPDATE without WHERE)
    - DCL restrictions (GRANT)
    - Row limit enforcement based on a rough, name-driven estimate
    """

    def __init__(self, config: "Optional[DMLSafetyConfig]" = None) -> None:
        """Initialize the DML safety validator.

        Args:
            config: Configuration for safety validation; a default
                ``DMLSafetyConfig`` is used when omitted.
        """
        super().__init__()
        self.config = config or DMLSafetyConfig()

    def validate(self, expression: "exp.Expression", context: "SQLProcessingContext") -> None:
        """Validate SQL statement for safety issues.

        Findings are reported through ``add_error``; a summary of the statement
        is stored in ``context.metadata`` under this class's name.

        Args:
            expression: The SQL expression to validate
            context: The SQL processing context
        """
        config = self.config
        category = self._categorize_statement(expression)
        operation = self._get_operation_type(expression)
        is_dml = category == StatementCategory.DML
        # Evaluated once and reused for both the check and the metadata entry.
        has_where = self._has_where_clause(expression) if is_dml else None

        if category == StatementCategory.DDL:
            # migration_mode exists to allow DDL in migration contexts, so it lifts the restriction.
            ddl_restricted = config.prevent_ddl and not config.migration_mode
            if ddl_restricted and operation not in config.allowed_ddl_operations:
                self.add_error(
                    context,
                    message=f"DDL operation '{operation}' is not allowed",
                    code="ddl-not-allowed",
                    risk_level=RiskLevel.CRITICAL,
                    expression=expression,
                )

        elif is_dml:
            if operation in config.require_where_clause and not has_where:
                self.add_error(
                    context,
                    message=f"{operation} without WHERE clause affects all rows",
                    code=f"{operation.lower()}-without-where",
                    risk_level=RiskLevel.HIGH,
                    expression=expression,
                )

            # Compare against None so that an explicit limit of 0 is still enforced.
            limit = config.max_affected_rows
            if limit is not None:
                estimated_rows = self._estimate_affected_rows(expression)
                if estimated_rows > limit:
                    self.add_error(
                        context,
                        message=f"Operation may affect {estimated_rows:,} rows (limit: {limit:,})",
                        code="excessive-rows-affected",
                        risk_level=RiskLevel.MEDIUM,
                        expression=expression,
                    )

        elif category == StatementCategory.DCL and config.prevent_dcl:
            self.add_error(
                context,
                message=f"DCL operation '{operation}' is not allowed",
                code="dcl-not-allowed",
                risk_level=RiskLevel.HIGH,
                expression=expression,
            )

        context.metadata[self.__class__.__name__] = {
            "statement_category": category.value,
            "operation": operation,
            "has_where_clause": has_where,
            "affected_tables": self._extract_affected_tables(expression),
            "migration_mode": config.migration_mode,
        }

    @staticmethod
    def _categorize_statement(expression: "exp.Expression") -> StatementCategory:
        """Categorize SQL statement type.

        Args:
            expression: The SQL expression to categorize

        Returns:
            The statement category; anything unrecognized is treated as a query (DQL).
        """
        if isinstance(expression, (exp.Create, exp.Alter, exp.Drop, exp.TruncateTable, exp.Comment)):
            return StatementCategory.DDL

        if isinstance(expression, (exp.Select, exp.Union, exp.Intersect, exp.Except)):
            return StatementCategory.DQL

        if isinstance(expression, (exp.Insert, exp.Update, exp.Delete, exp.Merge)):
            return StatementCategory.DML

        if isinstance(expression, (exp.Grant,)):
            return StatementCategory.DCL

        if isinstance(expression, (exp.Commit, exp.Rollback)):
            return StatementCategory.TCL

        return StatementCategory.DQL  # Default to query

    @staticmethod
    def _get_operation_type(expression: "exp.Expression") -> str:
        """Get specific operation name.

        Args:
            expression: The SQL expression

        Returns:
            The upper-cased class name of the expression (e.g. ``"DELETE"``).
        """
        return expression.__class__.__name__.upper()

    @staticmethod
    def _has_where_clause(expression: "exp.Expression") -> bool:
        """Check if DML statement has WHERE clause.

        Args:
            expression: The SQL expression to check

        Returns:
            For DELETE/UPDATE, whether a WHERE clause is present; True for every
            other statement because they do not require one.
        """
        if isinstance(expression, (exp.Delete, exp.Update)):
            return expression.args.get("where") is not None
        return True  # Other statements don't require WHERE

    def _estimate_affected_rows(self, expression: "exp.Expression") -> int:
        """Estimate number of rows affected by DML operation.

        This is a simple heuristic based on statement shape and column names,
        not on table statistics.

        Args:
            expression: The SQL expression

        Returns:
            Estimated number of affected rows
        """
        if not self._has_where_clause(expression):
            return 999999999  # Large number to indicate all rows

        where = expression.args.get("where")
        if where:
            # Check for primary key or unique conditions
            if self._has_unique_condition(where):
                return 1
            # Check for indexed conditions
            if self._has_indexed_condition(where):
                return 100  # Rough estimate

        if isinstance(expression, exp.Insert):
            # INSERT ... VALUES states its row count explicitly: one row per tuple.
            source = expression.args.get("expression")
            if isinstance(source, exp.Values):
                return len(source.expressions)

        return 10000  # Conservative estimate

    @staticmethod
    def _has_unique_condition(where: "Optional[exp.Expression]") -> bool:
        """Check if WHERE clause uses unique columns.

        Args:
            where: The WHERE expression

        Returns:
            True if any equality inside ``where`` has, on its left side, a column
            named like an identifier (id, uuid, guid, pk, primary_key).
        """
        if where is None:
            return False
        # Look for id = value patterns
        for condition in where.find_all(exp.EQ):
            if isinstance(condition.left, exp.Column):
                col_name = condition.left.name.lower()
                if col_name in {"id", "uuid", "guid", "pk", "primary_key"}:
                    return True
        return False

    @staticmethod
    def _has_indexed_condition(where: "Optional[exp.Expression]") -> bool:
        """Check if WHERE clause uses indexed columns.

        Args:
            where: The WHERE expression

        Returns:
            True if any predicate inside ``where`` has, on its left side, a column
            whose name is in the list of commonly indexed names.
        """
        if where is None:
            return False
        # Look for common indexed column patterns
        for condition in where.find_all(exp.Predicate):
            if hasattr(condition, "left") and isinstance(condition.left, exp.Column):  # pyright: ignore
                col_name = condition.left.name.lower()  # pyright: ignore
                # Common indexed columns
                if col_name in {"created_at", "updated_at", "email", "username", "status", "type"}:
                    return True
        return False

    @staticmethod
    def _extract_affected_tables(expression: "exp.Expression") -> "list[str]":
        """Extract table names affected by the statement.

        Args:
            expression: The SQL expression

        Returns:
            List with the target table name for INSERT/UPDATE/DELETE and
            CREATE/DROP/ALTER statements; empty when no table can be identified.
        """
        is_dml = isinstance(expression, (exp.Insert, exp.Update, exp.Delete))
        is_ddl = isinstance(expression, (exp.Create, exp.Drop, exp.Alter))
        if not (is_dml or is_ddl):
            return []

        target = expression.this
        if not target:
            return []

        # A column list wraps the table in a Schema node: CREATE TABLE t (...) and
        # INSERT INTO t (a, b) ... both keep the table in target.this.
        if isinstance(target, exp.Schema) and (is_dml or isinstance(expression, exp.Create)):
            inner = target.this
            return [inner.name] if isinstance(inner, exp.Table) else []

        if is_dml:
            return [target.name] if isinstance(target, exp.Table) else []

        # For DROP/ALTER (and CREATE without a column list) the table is directly in expression.this
        return [target.name] if isinstance(target, (exp.Table, exp.Identifier)) else []
@@ -0,0 +1,297 @@
1
+ """Parameter style validation for SQL statements."""
2
+
3
+ import logging
4
+ from typing import TYPE_CHECKING, Any, Optional, Union
5
+
6
+ from sqlglot import exp
7
+
8
+ from sqlspec.exceptions import MissingParameterError, RiskLevel, SQLValidationError
9
+ from sqlspec.statement.pipelines.base import ProcessorProtocol
10
+ from sqlspec.statement.pipelines.result_types import ValidationError
11
+ from sqlspec.typing import is_dict
12
+
13
+ if TYPE_CHECKING:
14
+ from sqlspec.statement.pipelines.context import SQLProcessingContext
15
+
16
+ logger = logging.getLogger("sqlspec.validators.parameter_style")
17
+
18
+ __all__ = ("ParameterStyleValidator",)
19
+
20
+
21
class UnsupportedParameterStyleError(SQLValidationError):
    """Signals that a statement uses a parameter style the current database does not support."""
23
+
24
+
25
class MixedParameterStyleError(SQLValidationError):
    """Signals that a statement mixes several parameter styles although mixing is not allowed."""
27
+
28
+
29
class ParameterStyleValidator(ProcessorProtocol):
    """Validates that parameter styles are supported by the database configuration.

    This validator checks:
    1. Whether detected parameter styles are in the allowed list
    2. Whether mixed parameter styles are used when not allowed
    3. Whether every placeholder has a value supplied
    4. Provides helpful error messages about supported styles

    With ``fail_on_violation`` set, violations raise; otherwise they are
    appended to ``context.validation_errors``.
    """

    def __init__(self, risk_level: "RiskLevel" = RiskLevel.HIGH, fail_on_violation: bool = True) -> None:
        """Initialize the parameter style validator.

        Args:
            risk_level: Risk level for unsupported parameter styles
            fail_on_violation: Whether to raise exception on violation
        """
        self.risk_level = risk_level
        self.fail_on_violation = fail_on_violation

    @staticmethod
    def _append_error(
        context: "SQLProcessingContext",
        expression: "Optional[exp.Expression]",
        message: str,
        code: str,
        risk_level: "RiskLevel",
    ) -> None:
        """Record a validation error attributed to this processor on the context."""
        context.validation_errors.append(
            ValidationError(
                message=message,
                code=code,
                risk_level=risk_level,
                processor="ParameterStyleValidator",
                expression=expression,
            )
        )

    def process(self, expression: "Optional[exp.Expression]", context: "SQLProcessingContext") -> None:
        """Validate parameter styles in SQL.

        Nothing is returned; findings are either raised or appended to
        ``context.validation_errors`` depending on ``fail_on_violation``.

        Args:
            expression: The SQL expression being validated
            context: SQL processing context with config

        Raises:
            MixedParameterStyleError: Mixed styles found while not allowed (fail_on_violation only).
            UnsupportedParameterStyleError: A disallowed style found (fail_on_violation only).
            MissingParameterError: A placeholder has no value (fail_on_violation only).
        """
        if expression is None:
            return

        if context.current_expression is None:
            self._append_error(
                context, None, "ParameterStyleValidator received no expression.", "no-expression", RiskLevel.CRITICAL
            )
            return

        try:
            config = context.config
            param_info = context.parameter_info

            # First check parameter styles if configured
            has_style_errors = False
            if config.allowed_parameter_styles is not None and param_info:
                unique_styles = {p.style for p in param_info}

                # Check for mixed styles first (before checking individual styles)
                if len(unique_styles) > 1 and not config.allow_mixed_parameter_styles:
                    detected_styles = ", ".join(sorted(str(s) for s in unique_styles))
                    msg = f"Mixed parameter styles detected ({detected_styles}) but not allowed."
                    if self.fail_on_violation:
                        self._raise_mixed_style_error(msg)
                    self._append_error(context, expression, msg, "mixed-parameter-styles", self.risk_level)
                    has_style_errors = True

                # Check for disallowed styles
                disallowed_styles = {str(s) for s in unique_styles if not config.validate_parameter_style(s)}
                if disallowed_styles:
                    disallowed_str = ", ".join(sorted(disallowed_styles))
                    # Defensive handling to avoid "expected str instance, NoneType found"
                    if config.allowed_parameter_styles:
                        allowed_str = ", ".join(str(s) for s in config.allowed_parameter_styles)
                        msg = f"Parameter style(s) {disallowed_str} not supported. Allowed: {allowed_str}"
                    else:
                        msg = f"Parameter style(s) {disallowed_str} not supported."

                    if self.fail_on_violation:
                        self._raise_unsupported_style_error(msg)
                    self._append_error(context, expression, msg, "unsupported-parameter-style", self.risk_level)
                    has_style_errors = True

            # Check for missing parameters if:
            # 1. We have parameter info
            # 2. Style validation is enabled (allowed_parameter_styles is not None)
            # 3. No style errors were found
            # 4. We have merged parameters OR the original SQL had placeholders
            logger.debug(
                "Checking missing parameters: param_info=%s, extracted=%s, had_placeholders=%s, merged=%s",
                len(param_info) if param_info else 0,
                len(context.extracted_parameters_from_pipeline) if context.extracted_parameters_from_pipeline else 0,
                context.input_sql_had_placeholders,
                context.merged_parameters is not None,
            )
            # Skip validation if we have no merged parameters and the SQL didn't originally have placeholders
            # This handles the case where literals were parameterized by transformers
            if (
                param_info
                and config.allowed_parameter_styles is not None
                and not has_style_errors
                and (context.merged_parameters is not None or context.input_sql_had_placeholders)
            ):
                self._validate_missing_parameters(context, expression)

        except (UnsupportedParameterStyleError, MixedParameterStyleError, MissingParameterError):
            # Deliberate violations must reach the caller untouched.
            raise
        except Exception as e:
            # Best effort: an unexpected failure inside validation is downgraded to a low-risk finding.
            logger.warning("Parameter style validation failed: %s", e)
            self._append_error(
                context, expression, f"Parameter style validation failed: {e}", "validation-error", RiskLevel.LOW
            )

    @staticmethod
    def _raise_mixed_style_error(msg: "str") -> "None":
        """Raise MixedParameterStyleError with the given message."""
        raise MixedParameterStyleError(msg)

    @staticmethod
    def _raise_unsupported_style_error(msg: "str") -> "None":
        """Raise UnsupportedParameterStyleError with the given message."""
        raise UnsupportedParameterStyleError(msg)

    def _validate_missing_parameters(self, context: "SQLProcessingContext", expression: exp.Expression) -> None:
        """Validate that all required parameters have values provided.

        Dispatches on the shape of the supplied parameters: none, a sequence,
        a mapping, or a single scalar value.
        """
        param_info = context.parameter_info
        if not param_info:
            return

        merged_params = self._prepare_merged_parameters(context, param_info)

        if merged_params is None:
            self._handle_no_parameters(context, expression, param_info)
        elif isinstance(merged_params, (list, tuple)):
            self._handle_positional_parameters(context, expression, param_info, merged_params)
        elif is_dict(merged_params):
            self._handle_named_parameters(context, expression, param_info, merged_params)
        elif len(param_info) > 1:
            # A lone scalar can satisfy only one placeholder.
            self._handle_single_value_multiple_params(context, expression, param_info)

    @staticmethod
    def _prepare_merged_parameters(context: "SQLProcessingContext", param_info: list[Any]) -> Any:
        """Prepare merged parameters for validation.

        Returns:
            The parameter values to validate against: the pipeline-extracted
            values when the input SQL had no placeholders of its own, otherwise
            ``context.merged_parameters``. A scalar is wrapped in a list when any
            placeholder uses the ``positional_colon`` style.
        """
        merged_params = context.merged_parameters

        # If we have extracted parameters from transformers (like ParameterizeLiterals),
        # use those for validation instead of the original merged_parameters
        if context.extracted_parameters_from_pipeline and not context.input_sql_had_placeholders:
            # Use extracted parameters as they represent the actual values to be used
            merged_params = context.extracted_parameters_from_pipeline
        has_positional_colon = any(p.style.value == "positional_colon" for p in param_info)
        if has_positional_colon and not isinstance(merged_params, (list, tuple, dict)) and merged_params is not None:
            return [merged_params]
        return merged_params

    def _report_error(self, context: "SQLProcessingContext", expression: exp.Expression, message: str) -> None:
        """Report a missing parameter error.

        Raises:
            MissingParameterError: When ``fail_on_violation`` is set.
        """
        if self.fail_on_violation:
            raise MissingParameterError(message)
        self._append_error(context, expression, message, "missing-parameters", self.risk_level)

    def _handle_no_parameters(
        self, context: "SQLProcessingContext", expression: exp.Expression, param_info: list[Any]
    ) -> None:
        """Handle validation when no parameters are provided."""
        if context.extracted_parameters_from_pipeline:
            # Values extracted by the pipeline cover the placeholders.
            return
        missing = [p.name or p.placeholder_text or f"param_{p.ordinal}" for p in param_info]
        msg = f"Missing required parameters: {', '.join(str(m) for m in missing)}"
        self._report_error(context, expression, msg)

    def _handle_positional_parameters(
        self,
        context: "SQLProcessingContext",
        expression: exp.Expression,
        param_info: list[Any],
        merged_params: "Union[list[Any], tuple[Any, ...]]",
    ) -> None:
        """Handle validation for positional parameters."""
        named_styles = {"named_colon", "named_at"}
        # Named placeholders cannot be satisfied by a positional sequence.
        missing_named = [p.name or p.placeholder_text for p in param_info if p.style.value in named_styles]
        if missing_named:
            msg = f"Missing required parameters: {', '.join(str(m) for m in missing_named if m)}"
            self._report_error(context, expression, msg)
            return

        has_positional_colon = any(p.style.value == "positional_colon" for p in param_info)
        if has_positional_colon:
            self._validate_oracle_numeric_params(context, expression, param_info, merged_params)
        elif len(merged_params) < len(param_info):
            msg = f"Expected {len(param_info)} parameters but got {len(merged_params)}"
            self._report_error(context, expression, msg)

    def _validate_oracle_numeric_params(
        self,
        context: "SQLProcessingContext",
        expression: exp.Expression,
        param_info: list[Any],
        merged_params: "Union[list[Any], tuple[Any, ...]]",
    ) -> None:
        """Validate Oracle-style numeric parameters."""
        missing_indices: list[str] = []
        provided_count = len(merged_params)
        for p in param_info:
            if p.style.value != "positional_colon" or not p.name:
                continue
            try:
                idx = int(p.name)
            except (ValueError, TypeError):
                # Non-numeric names are not checked here.
                continue
            # Accept the index when it is valid either as a 0-based or as a 1-based position.
            if not (idx < provided_count or (idx > 0 and (idx - 1) < provided_count)):
                missing_indices.append(p.name)
        if missing_indices:
            msg = f"Missing required parameters: :{', :'.join(missing_indices)}"
            self._report_error(context, expression, msg)

    def _handle_named_parameters(
        self,
        context: "SQLProcessingContext",
        expression: exp.Expression,
        param_info: list[Any],
        merged_params: dict[str, Any],
    ) -> None:
        """Handle validation for named parameters."""
        missing: list[str] = []
        # Whether any supplied key looks generated ("_arg_*" / "param_*"). The answer does not
        # depend on the placeholder, so it is computed at most once, on first need.
        has_synthetic_keys: "Optional[bool]" = None
        for p in param_info:
            param_name = p.name
            if param_name in merged_params:
                continue
            if has_synthetic_keys is None:
                has_synthetic_keys = any(key.startswith(("_arg_", "param_")) for key in merged_params)
            is_named_style = p.style.value not in {"qmark", "numeric"}
            # Positional-style placeholders are tolerated when the mapping uses generated keys.
            if (not has_synthetic_keys or is_named_style) and param_name:
                missing.append(param_name)

        if missing:
            msg = f"Missing required parameters: {', '.join(missing)}"
            self._report_error(context, expression, msg)

    def _handle_single_value_multiple_params(
        self, context: "SQLProcessingContext", expression: exp.Expression, param_info: list[Any]
    ) -> None:
        """Handle validation for a single value provided for multiple parameters."""
        # The single value is credited to the first placeholder; the rest are missing.
        missing = [p.name or p.placeholder_text or f"param_{p.ordinal}" for p in param_info[1:]]
        msg = f"Missing required parameters: {', '.join(str(m) for m in missing)}"
        self._report_error(context, expression, msg)