sqlspec 0.12.2__py3-none-any.whl → 0.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic. Click here for more details.

Files changed (113)
  1. sqlspec/_sql.py +21 -180
  2. sqlspec/adapters/adbc/config.py +10 -12
  3. sqlspec/adapters/adbc/driver.py +120 -118
  4. sqlspec/adapters/aiosqlite/config.py +3 -3
  5. sqlspec/adapters/aiosqlite/driver.py +100 -130
  6. sqlspec/adapters/asyncmy/config.py +3 -4
  7. sqlspec/adapters/asyncmy/driver.py +123 -135
  8. sqlspec/adapters/asyncpg/config.py +3 -7
  9. sqlspec/adapters/asyncpg/driver.py +98 -140
  10. sqlspec/adapters/bigquery/config.py +4 -5
  11. sqlspec/adapters/bigquery/driver.py +125 -167
  12. sqlspec/adapters/duckdb/config.py +3 -6
  13. sqlspec/adapters/duckdb/driver.py +114 -111
  14. sqlspec/adapters/oracledb/config.py +6 -5
  15. sqlspec/adapters/oracledb/driver.py +242 -259
  16. sqlspec/adapters/psqlpy/config.py +3 -7
  17. sqlspec/adapters/psqlpy/driver.py +118 -93
  18. sqlspec/adapters/psycopg/config.py +18 -31
  19. sqlspec/adapters/psycopg/driver.py +283 -236
  20. sqlspec/adapters/sqlite/config.py +3 -3
  21. sqlspec/adapters/sqlite/driver.py +103 -97
  22. sqlspec/config.py +0 -4
  23. sqlspec/driver/_async.py +89 -98
  24. sqlspec/driver/_common.py +52 -17
  25. sqlspec/driver/_sync.py +81 -105
  26. sqlspec/driver/connection.py +207 -0
  27. sqlspec/driver/mixins/_csv_writer.py +91 -0
  28. sqlspec/driver/mixins/_pipeline.py +38 -49
  29. sqlspec/driver/mixins/_result_utils.py +27 -9
  30. sqlspec/driver/mixins/_storage.py +67 -181
  31. sqlspec/driver/mixins/_type_coercion.py +3 -4
  32. sqlspec/driver/parameters.py +138 -0
  33. sqlspec/exceptions.py +10 -2
  34. sqlspec/extensions/aiosql/adapter.py +0 -10
  35. sqlspec/extensions/litestar/handlers.py +0 -1
  36. sqlspec/extensions/litestar/plugin.py +0 -3
  37. sqlspec/extensions/litestar/providers.py +0 -14
  38. sqlspec/loader.py +25 -90
  39. sqlspec/protocols.py +542 -0
  40. sqlspec/service/__init__.py +3 -2
  41. sqlspec/service/_util.py +147 -0
  42. sqlspec/service/base.py +1116 -9
  43. sqlspec/statement/builder/__init__.py +42 -32
  44. sqlspec/statement/builder/_ddl_utils.py +0 -10
  45. sqlspec/statement/builder/_parsing_utils.py +10 -4
  46. sqlspec/statement/builder/base.py +67 -22
  47. sqlspec/statement/builder/column.py +283 -0
  48. sqlspec/statement/builder/ddl.py +91 -67
  49. sqlspec/statement/builder/delete.py +23 -7
  50. sqlspec/statement/builder/insert.py +29 -15
  51. sqlspec/statement/builder/merge.py +4 -4
  52. sqlspec/statement/builder/mixins/_aggregate_functions.py +113 -14
  53. sqlspec/statement/builder/mixins/_common_table_expr.py +0 -1
  54. sqlspec/statement/builder/mixins/_delete_from.py +1 -1
  55. sqlspec/statement/builder/mixins/_from.py +10 -8
  56. sqlspec/statement/builder/mixins/_group_by.py +0 -1
  57. sqlspec/statement/builder/mixins/_insert_from_select.py +0 -1
  58. sqlspec/statement/builder/mixins/_insert_values.py +0 -2
  59. sqlspec/statement/builder/mixins/_join.py +20 -13
  60. sqlspec/statement/builder/mixins/_limit_offset.py +3 -3
  61. sqlspec/statement/builder/mixins/_merge_clauses.py +3 -4
  62. sqlspec/statement/builder/mixins/_order_by.py +2 -2
  63. sqlspec/statement/builder/mixins/_pivot.py +4 -7
  64. sqlspec/statement/builder/mixins/_select_columns.py +6 -5
  65. sqlspec/statement/builder/mixins/_unpivot.py +6 -9
  66. sqlspec/statement/builder/mixins/_update_from.py +2 -1
  67. sqlspec/statement/builder/mixins/_update_set.py +11 -8
  68. sqlspec/statement/builder/mixins/_where.py +61 -34
  69. sqlspec/statement/builder/select.py +32 -17
  70. sqlspec/statement/builder/update.py +25 -11
  71. sqlspec/statement/filters.py +39 -14
  72. sqlspec/statement/parameter_manager.py +220 -0
  73. sqlspec/statement/parameters.py +210 -79
  74. sqlspec/statement/pipelines/__init__.py +166 -23
  75. sqlspec/statement/pipelines/analyzers/_analyzer.py +21 -20
  76. sqlspec/statement/pipelines/context.py +35 -39
  77. sqlspec/statement/pipelines/transformers/__init__.py +2 -3
  78. sqlspec/statement/pipelines/transformers/_expression_simplifier.py +19 -187
  79. sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +628 -58
  80. sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +76 -0
  81. sqlspec/statement/pipelines/validators/_dml_safety.py +33 -18
  82. sqlspec/statement/pipelines/validators/_parameter_style.py +87 -14
  83. sqlspec/statement/pipelines/validators/_performance.py +38 -23
  84. sqlspec/statement/pipelines/validators/_security.py +39 -62
  85. sqlspec/statement/result.py +37 -129
  86. sqlspec/statement/splitter.py +0 -12
  87. sqlspec/statement/sql.py +863 -391
  88. sqlspec/statement/sql_compiler.py +140 -0
  89. sqlspec/storage/__init__.py +10 -2
  90. sqlspec/storage/backends/fsspec.py +53 -8
  91. sqlspec/storage/backends/obstore.py +15 -19
  92. sqlspec/storage/capabilities.py +101 -0
  93. sqlspec/storage/registry.py +56 -83
  94. sqlspec/typing.py +6 -434
  95. sqlspec/utils/cached_property.py +25 -0
  96. sqlspec/utils/correlation.py +0 -2
  97. sqlspec/utils/logging.py +0 -6
  98. sqlspec/utils/sync_tools.py +0 -4
  99. sqlspec/utils/text.py +0 -5
  100. sqlspec/utils/type_guards.py +892 -0
  101. {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/METADATA +1 -1
  102. sqlspec-0.13.0.dist-info/RECORD +150 -0
  103. sqlspec/statement/builder/protocols.py +0 -20
  104. sqlspec/statement/pipelines/base.py +0 -315
  105. sqlspec/statement/pipelines/result_types.py +0 -41
  106. sqlspec/statement/pipelines/transformers/_remove_comments.py +0 -66
  107. sqlspec/statement/pipelines/transformers/_remove_hints.py +0 -81
  108. sqlspec/statement/pipelines/validators/base.py +0 -67
  109. sqlspec/storage/protocol.py +0 -173
  110. sqlspec-0.12.2.dist-info/RECORD +0 -145
  111. {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/WHEEL +0 -0
  112. {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/licenses/LICENSE +0 -0
  113. {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/licenses/NOTICE +0 -0
"""Removes SQL comments and hints from expressions."""

from typing import TYPE_CHECKING, Optional

from sqlglot import exp

from sqlspec.protocols import ProcessorProtocol

if TYPE_CHECKING:
    from sqlspec.statement.pipelines.context import SQLProcessingContext

__all__ = ("CommentAndHintRemover",)


class CommentAndHintRemover(ProcessorProtocol):
    """Removes SQL comments and hints from expressions using SQLGlot's AST traversal.

    Removal counts are recorded in ``context.metadata["comments_removed"]``
    and ``context.metadata["hints_removed"]``.
    """

    def __init__(self, enabled: bool = True, remove_comments: bool = True, remove_hints: bool = False) -> None:
        """Configure the remover.

        Args:
            enabled: When False, ``process`` returns the expression untouched.
            remove_comments: Strip plain (non-hint) comments attached to nodes.
            remove_hints: Strip optimizer hints — both ``exp.Hint`` nodes and
                hint-style comments (see ``_is_hint``).
        """
        self.enabled = enabled
        self.remove_comments = remove_comments
        self.remove_hints = remove_hints

    def process(
        self, expression: "Optional[exp.Expression]", context: "SQLProcessingContext"
    ) -> "Optional[exp.Expression]":
        """Return a copy of ``expression`` with comments/hints stripped per configuration.

        Args:
            expression: Parsed SQL expression, or None (returned as-is).
            context: Processing context; removal counts are written to its metadata.

        Returns:
            A transformed copy of the expression (the input tree is not mutated),
            or the input unchanged when disabled or None.
        """
        if not self.enabled or expression is None:
            return expression

        comments_removed_count = 0
        hints_removed_count = 0

        def _remove_comments_and_hints(node: exp.Expression) -> "Optional[exp.Expression]":
            nonlocal comments_removed_count, hints_removed_count

            # Returning None from transform() drops the node from the tree.
            if self.remove_hints and isinstance(node, exp.Hint):
                hints_removed_count += 1
                return None

            if hasattr(node, "comments") and node.comments:
                original_comment_count = len(node.comments)
                comments_to_keep = []
                for comment in node.comments:
                    comment_text = str(comment).strip()
                    is_hint = self._is_hint(comment_text)

                    # Keep a comment only when its category is not targeted for removal.
                    if is_hint:
                        if not self.remove_hints:
                            comments_to_keep.append(comment)
                    elif not self.remove_comments:
                        comments_to_keep.append(comment)

                removed_count = original_comment_count - len(comments_to_keep)
                if removed_count > 0:
                    # Tally removals by category; a category is only counted
                    # when its removal flag is set, matching the keep logic above.
                    if self.remove_hints:
                        hints_removed_count += sum(1 for c in node.comments if self._is_hint(str(c).strip()))
                    if self.remove_comments:
                        comments_removed_count += sum(1 for c in node.comments if not self._is_hint(str(c).strip()))

                node.pop_comments()
                if comments_to_keep:
                    node.add_comments(comments_to_keep)

            return node

        # copy=True keeps the caller's expression tree unmodified.
        cleaned_expression = expression.transform(_remove_comments_and_hints, copy=True)

        context.metadata["comments_removed"] = comments_removed_count
        context.metadata["hints_removed"] = hints_removed_count

        return cleaned_expression

    def _is_hint(self, comment_text: str) -> bool:
        """Heuristically classify a comment as an optimizer hint.

        A comment counts as a hint when it contains a known hint keyword
        (Oracle-style, case-insensitive) or starts with "!" (MySQL
        ``/*! ... */`` version-hint comments).

        NOTE(review): the previous implementation additionally checked
        ``comment_text.endswith("")``, which is vacuously True for every
        string; that dead clause has been removed without changing behavior.
        """
        hint_keywords = ("INDEX", "USE_NL", "USE_HASH", "PARALLEL", "FULL", "FIRST_ROWS", "ALL_ROWS")
        upper_text = comment_text.upper()
        return any(keyword in upper_text for keyword in hint_keywords) or comment_text.startswith("!")
@@ -6,7 +6,8 @@ from typing import TYPE_CHECKING, Optional
6
6
  from sqlglot import expressions as exp
7
7
 
8
8
  from sqlspec.exceptions import RiskLevel
9
- from sqlspec.statement.pipelines.validators.base import BaseValidator
9
+ from sqlspec.protocols import ProcessorProtocol
10
+ from sqlspec.statement.pipelines.context import ValidationError
10
11
 
11
12
  if TYPE_CHECKING:
12
13
  from sqlspec.statement.pipelines.context import SQLProcessingContext
@@ -36,7 +37,7 @@ class DMLSafetyConfig:
36
37
  max_affected_rows: "Optional[int]" = None # Limit for DML operations
37
38
 
38
39
 
39
- class DMLSafetyValidator(BaseValidator):
40
+ class DMLSafetyValidator(ProcessorProtocol):
40
41
  """Unified validator for DML/DDL safety checks.
41
42
 
42
43
  This validator consolidates:
@@ -52,9 +53,31 @@ class DMLSafetyValidator(BaseValidator):
52
53
  Args:
53
54
  config: Configuration for safety validation
54
55
  """
55
- super().__init__()
56
56
  self.config = config or DMLSafetyConfig()
57
57
 
58
+ def process(
59
+ self, expression: "Optional[exp.Expression]", context: "SQLProcessingContext"
60
+ ) -> "Optional[exp.Expression]":
61
+ """Process the expression for validation (implements ProcessorProtocol)."""
62
+ if expression is None:
63
+ return None
64
+ self.validate(expression, context)
65
+ return expression
66
+
67
+ def add_error(
68
+ self,
69
+ context: "SQLProcessingContext",
70
+ message: str,
71
+ code: str,
72
+ risk_level: RiskLevel,
73
+ expression: "Optional[exp.Expression]" = None,
74
+ ) -> None:
75
+ """Add a validation error to the context."""
76
+ error = ValidationError(
77
+ message=message, code=code, risk_level=risk_level, processor=self.__class__.__name__, expression=expression
78
+ )
79
+ context.validation_errors.append(error)
80
+
58
81
  def validate(self, expression: "exp.Expression", context: "SQLProcessingContext") -> None:
59
82
  """Validate SQL statement for safety issues.
60
83
 
@@ -66,7 +89,6 @@ class DMLSafetyValidator(BaseValidator):
66
89
  category = self._categorize_statement(expression)
67
90
  operation = self._get_operation_type(expression)
68
91
 
69
- # Check DDL restrictions
70
92
  if category == StatementCategory.DDL and self.config.prevent_ddl:
71
93
  if operation not in self.config.allowed_ddl_operations:
72
94
  self.add_error(
@@ -77,7 +99,6 @@ class DMLSafetyValidator(BaseValidator):
77
99
  expression=expression,
78
100
  )
79
101
 
80
- # Check DML safety
81
102
  elif category == StatementCategory.DML:
82
103
  if operation in self.config.require_where_clause and not self._has_where_clause(expression):
83
104
  self.add_error(
@@ -88,7 +109,6 @@ class DMLSafetyValidator(BaseValidator):
88
109
  expression=expression,
89
110
  )
90
111
 
91
- # Check affected row limits
92
112
  if self.config.max_affected_rows:
93
113
  estimated_rows = self._estimate_affected_rows(expression)
94
114
  if estimated_rows > self.config.max_affected_rows:
@@ -100,7 +120,6 @@ class DMLSafetyValidator(BaseValidator):
100
120
  expression=expression,
101
121
  )
102
122
 
103
- # Check DCL restrictions
104
123
  elif category == StatementCategory.DCL and self.config.prevent_dcl:
105
124
  self.add_error(
106
125
  context,
@@ -187,10 +206,8 @@ class DMLSafetyValidator(BaseValidator):
187
206
 
188
207
  where = expression.args.get("where")
189
208
  if where:
190
- # Check for primary key or unique conditions
191
209
  if self._has_unique_condition(where):
192
210
  return 1
193
- # Check for indexed conditions
194
211
  if self._has_indexed_condition(where):
195
212
  return 100 # Rough estimate
196
213
 
@@ -230,8 +247,10 @@ class DMLSafetyValidator(BaseValidator):
230
247
  return False
231
248
  # Look for common indexed column patterns
232
249
  for condition in where.find_all(exp.Predicate):
233
- if hasattr(condition, "left") and isinstance(condition.left, exp.Column): # pyright: ignore
234
- col_name = condition.left.name.lower() # pyright: ignore
250
+ if isinstance(condition, (exp.EQ, exp.GT, exp.GTE, exp.LT, exp.LTE, exp.NEQ)) and isinstance(
251
+ condition.left, exp.Column
252
+ ):
253
+ col_name = condition.left.name.lower()
235
254
  # Common indexed columns
236
255
  if col_name in {"created_at", "updated_at", "email", "username", "status", "type"}:
237
256
  return True
@@ -251,20 +270,16 @@ class DMLSafetyValidator(BaseValidator):
251
270
 
252
271
  # For DML statements
253
272
  if isinstance(expression, (exp.Insert, exp.Update, exp.Delete)):
254
- if hasattr(expression, "this") and expression.this:
273
+ if expression.this:
255
274
  table_expr = expression.this
256
275
  if isinstance(table_expr, exp.Table):
257
276
  tables.append(table_expr.name)
258
277
 
259
278
  # For DDL statements
260
- elif (
261
- isinstance(expression, (exp.Create, exp.Drop, exp.Alter))
262
- and hasattr(expression, "this")
263
- and expression.this
264
- ):
279
+ elif isinstance(expression, (exp.Create, exp.Drop, exp.Alter)) and expression.this:
265
280
  # For CREATE TABLE, the table is in expression.this.this
266
281
  if isinstance(expression, exp.Create) and isinstance(expression.this, exp.Schema):
267
- if hasattr(expression.this, "this") and expression.this.this:
282
+ if expression.this.this:
268
283
  table_expr = expression.this.this
269
284
  if isinstance(table_expr, exp.Table):
270
285
  tables.append(table_expr.name)
@@ -6,9 +6,9 @@ from typing import TYPE_CHECKING, Any, Optional, Union
6
6
  from sqlglot import exp
7
7
 
8
8
  from sqlspec.exceptions import MissingParameterError, RiskLevel, SQLValidationError
9
- from sqlspec.statement.pipelines.base import ProcessorProtocol
10
- from sqlspec.statement.pipelines.result_types import ValidationError
11
- from sqlspec.typing import is_dict
9
+ from sqlspec.protocols import ProcessorProtocol
10
+ from sqlspec.statement.pipelines.context import ValidationError
11
+ from sqlspec.utils.type_guards import is_dict
12
12
 
13
13
  if TYPE_CHECKING:
14
14
  from sqlspec.statement.pipelines.context import SQLProcessingContext
@@ -73,12 +73,15 @@ class ParameterStyleValidator(ProcessorProtocol):
73
73
  config = context.config
74
74
  param_info = context.parameter_info
75
75
 
76
- # First check parameter styles if configured
76
+ # Check if parameters were normalized by looking for param_ placeholders
77
+ # This happens when Oracle numeric parameters (:1, :2) are normalized
78
+ is_normalized = param_info and any(p.name and p.name.startswith("param_") for p in param_info)
79
+
80
+ # First check parameter styles if configured (skip if normalized)
77
81
  has_style_errors = False
78
- if config.allowed_parameter_styles is not None and param_info:
82
+ if not is_normalized and config.allowed_parameter_styles is not None and param_info:
79
83
  unique_styles = {p.style for p in param_info}
80
84
 
81
- # Check for mixed styles first (before checking individual styles)
82
85
  if len(unique_styles) > 1 and not config.allow_mixed_parameter_styles:
83
86
  detected_style_strs = [str(s) for s in unique_styles]
84
87
  detected_styles = ", ".join(sorted(detected_style_strs))
@@ -95,7 +98,6 @@ class ParameterStyleValidator(ProcessorProtocol):
95
98
  context.validation_errors.append(error)
96
99
  has_style_errors = True
97
100
 
98
- # Check for disallowed styles
99
101
  disallowed_styles = {str(s) for s in unique_styles if not config.validate_parameter_style(s)}
100
102
  if disallowed_styles:
101
103
  disallowed_str = ", ".join(sorted(disallowed_styles))
@@ -276,13 +278,84 @@ class ParameterStyleValidator(ProcessorProtocol):
276
278
  ) -> None:
277
279
  """Handle validation for named parameters."""
278
280
  missing: list[str] = []
279
- for p in param_info:
280
- param_name = p.name
281
- if param_name not in merged_params:
282
- is_synthetic = any(key.startswith(("_arg_", "param_")) for key in merged_params)
283
- is_named_style = p.style.value not in {"qmark", "numeric"}
284
- if (not is_synthetic or is_named_style) and param_name:
285
- missing.append(param_name)
281
+
282
+ # Check if we have normalized parameters (e.g., param_0)
283
+ is_normalized = any(p.name and p.name.startswith("param_") for p in param_info)
284
+
285
+ if is_normalized and hasattr(context, "extra_info"):
286
+ # For normalized parameters, we need to check against the original placeholder mapping
287
+ placeholder_map = context.extra_info.get("placeholder_map", {})
288
+
289
+ # Check if we have Oracle numeric keys in merged_params
290
+ all_numeric_keys = all(key.isdigit() for key in merged_params)
291
+
292
+ if all_numeric_keys:
293
+ # Parameters were provided as list and converted to Oracle numeric dict {"1": val1, "2": val2}
294
+ for i, _p in enumerate(param_info):
295
+ normalized_name = f"param_{i}"
296
+ original_key = placeholder_map.get(normalized_name)
297
+
298
+ if original_key is not None:
299
+ # Check using the original key (e.g., "1", "2" for Oracle)
300
+ original_key_str = str(original_key)
301
+ if original_key_str not in merged_params or merged_params[original_key_str] is None:
302
+ if original_key_str.isdigit():
303
+ missing.append(f":{original_key}")
304
+ else:
305
+ missing.append(f":{original_key}")
306
+ else:
307
+ # Check if all params follow param_N pattern
308
+ all_param_keys = all(key.startswith("param_") and key[6:].isdigit() for key in merged_params)
309
+
310
+ if all_param_keys:
311
+ # This was originally a list converted to dict with param_N keys
312
+ for i, _p in enumerate(param_info):
313
+ normalized_name = f"param_{i}"
314
+ if normalized_name not in merged_params or merged_params[normalized_name] is None:
315
+ # Get original parameter style from placeholder map
316
+ original_key = placeholder_map.get(normalized_name)
317
+ if original_key is not None:
318
+ original_key_str = str(original_key)
319
+ if original_key_str.isdigit():
320
+ missing.append(f":{original_key}")
321
+ else:
322
+ missing.append(f":{original_key}")
323
+ else:
324
+ # Mixed parameter names, check using placeholder map
325
+ for i, _p in enumerate(param_info):
326
+ normalized_name = f"param_{i}"
327
+ original_key = placeholder_map.get(normalized_name)
328
+
329
+ if original_key is not None:
330
+ # For mixed params, check both normalized and original keys
331
+ original_key_str = str(original_key)
332
+
333
+ # First check with normalized name
334
+ found = normalized_name in merged_params and merged_params[normalized_name] is not None
335
+
336
+ # If not found, check with original key
337
+ if not found:
338
+ found = (
339
+ original_key_str in merged_params and merged_params[original_key_str] is not None
340
+ )
341
+
342
+ if not found:
343
+ # Format the missing parameter based on original style
344
+ if original_key_str.isdigit():
345
+ # It was an Oracle numeric parameter (e.g., :1)
346
+ missing.append(f":{original_key}")
347
+ else:
348
+ # It was a named parameter (e.g., :status)
349
+ missing.append(f":{original_key}")
350
+ else:
351
+ # Regular parameter validation
352
+ for p in param_info:
353
+ param_name = p.name
354
+ if param_name not in merged_params or merged_params.get(param_name) is None:
355
+ is_synthetic = any(key.startswith(("arg_", "param_")) for key in merged_params)
356
+ is_named_style = p.style.value not in {"qmark", "numeric"}
357
+ if (not is_synthetic or is_named_style) and param_name:
358
+ missing.append(param_name)
286
359
 
287
360
  if missing:
288
361
  msg = f"Missing required parameters: {', '.join(missing)}"
@@ -18,7 +18,9 @@ from sqlglot.optimizer import (
18
18
  )
19
19
 
20
20
  from sqlspec.exceptions import RiskLevel
21
- from sqlspec.statement.pipelines.validators.base import BaseValidator
21
+ from sqlspec.protocols import ProcessorProtocol
22
+ from sqlspec.statement.pipelines.context import ValidationError
23
+ from sqlspec.utils.type_guards import has_expressions
22
24
 
23
25
  if TYPE_CHECKING:
24
26
  from sqlspec.statement.pipelines.context import SQLProcessingContext
@@ -126,7 +128,7 @@ class PerformanceAnalysis:
126
128
  potential_improvement: float = 0.0
127
129
 
128
130
 
129
- class PerformanceValidator(BaseValidator):
131
+ class PerformanceValidator(ProcessorProtocol):
130
132
  """Comprehensive query performance validator.
131
133
 
132
134
  Validates query performance by detecting:
@@ -143,9 +145,31 @@ class PerformanceValidator(BaseValidator):
143
145
  Args:
144
146
  config: Configuration for performance validation
145
147
  """
146
- super().__init__()
147
148
  self.config = config or PerformanceConfig()
148
149
 
150
+ def process(
151
+ self, expression: "Optional[exp.Expression]", context: "SQLProcessingContext"
152
+ ) -> "Optional[exp.Expression]":
153
+ """Process the expression for validation (implements ProcessorProtocol)."""
154
+ if expression is None:
155
+ return None
156
+ self.validate(expression, context)
157
+ return expression
158
+
159
+ def add_error(
160
+ self,
161
+ context: "SQLProcessingContext",
162
+ message: str,
163
+ code: str,
164
+ risk_level: RiskLevel,
165
+ expression: "Optional[exp.Expression]" = None,
166
+ ) -> None:
167
+ """Add a validation error to the context."""
168
+ error = ValidationError(
169
+ message=message, code=code, risk_level=risk_level, processor=self.__class__.__name__, expression=expression
170
+ )
171
+ context.validation_errors.append(error)
172
+
149
173
  def validate(self, expression: "exp.Expression", context: "SQLProcessingContext") -> None:
150
174
  """Validate SQL statement for performance issues.
151
175
 
@@ -167,7 +191,6 @@ class PerformanceValidator(BaseValidator):
167
191
  if self.config.enable_optimization_analysis:
168
192
  self._analyze_optimization_opportunities(expression, analysis, context)
169
193
 
170
- # Check for cartesian products
171
194
  if self.config.warn_on_cartesian:
172
195
  cartesian_issues = self._check_cartesian_products(analysis)
173
196
  for issue in cartesian_issues:
@@ -179,7 +202,6 @@ class PerformanceValidator(BaseValidator):
179
202
  expression=expression,
180
203
  )
181
204
 
182
- # Check join complexity
183
205
  if analysis.join_count > self.config.max_joins:
184
206
  self.add_error(
185
207
  context,
@@ -189,7 +211,6 @@ class PerformanceValidator(BaseValidator):
189
211
  expression=expression,
190
212
  )
191
213
 
192
- # Check subquery depth
193
214
  if analysis.max_subquery_depth > self.config.max_subqueries:
194
215
  self.add_error(
195
216
  context,
@@ -213,7 +234,6 @@ class PerformanceValidator(BaseValidator):
213
234
  # Calculate overall complexity score
214
235
  complexity_score = self._calculate_complexity(analysis)
215
236
 
216
- # Build metadata
217
237
  context.metadata[self.__class__.__name__] = {
218
238
  "complexity_score": complexity_score,
219
239
  "join_analysis": {
@@ -260,7 +280,6 @@ class PerformanceValidator(BaseValidator):
260
280
  analysis.current_subquery_depth = max(analysis.current_subquery_depth, depth + 1)
261
281
  analysis.max_subquery_depth = max(analysis.max_subquery_depth, analysis.current_subquery_depth)
262
282
 
263
- # Check if correlated
264
283
  if self._is_correlated_subquery(expr):
265
284
  analysis.correlated_subqueries += 1
266
285
 
@@ -270,7 +289,6 @@ class PerformanceValidator(BaseValidator):
270
289
  join_type = expr.args.get("kind", "INNER").upper()
271
290
  analysis.join_types[join_type] = analysis.join_types.get(join_type, 0) + 1
272
291
 
273
- # Extract join condition
274
292
  condition = expr.args.get("on")
275
293
  left_table = self._get_table_name(expr.parent) if expr.parent else "unknown"
276
294
  right_table = self._get_table_name(expr.this)
@@ -287,10 +305,10 @@ class PerformanceValidator(BaseValidator):
287
305
  analysis.where_conditions += len(list(expr.find_all(exp.Predicate)))
288
306
 
289
307
  elif isinstance(expr, exp.Group):
290
- analysis.group_by_columns += len(expr.expressions) if hasattr(expr, "expressions") else 0
308
+ analysis.group_by_columns += len(expr.expressions) if has_expressions(expr) else 0
291
309
 
292
310
  elif isinstance(expr, exp.Order):
293
- analysis.order_by_columns += len(expr.expressions) if hasattr(expr, "expressions") else 0
311
+ analysis.order_by_columns += len(expr.expressions) if has_expressions(expr) else 0
294
312
 
295
313
  elif isinstance(expr, exp.Distinct):
296
314
  analysis.distinct_operations += 1
@@ -302,13 +320,15 @@ class PerformanceValidator(BaseValidator):
302
320
  analysis.select_star_count += 1
303
321
 
304
322
  # Recursive traversal
305
- for child in expr.args.values():
306
- if isinstance(child, exp.Expression):
307
- self._analyze_expression(child, analysis, depth)
308
- elif isinstance(child, list):
309
- for item in child:
310
- if isinstance(item, exp.Expression):
311
- self._analyze_expression(item, analysis, depth)
323
+ expr_args = getattr(expr, "args", None)
324
+ if expr_args is not None and isinstance(expr_args, dict):
325
+ for child in expr_args.values():
326
+ if isinstance(child, exp.Expression):
327
+ self._analyze_expression(child, analysis, depth)
328
+ elif isinstance(child, list):
329
+ for item in child:
330
+ if isinstance(item, exp.Expression):
331
+ self._analyze_expression(item, analysis, depth)
312
332
 
313
333
  def _check_cartesian_products(self, analysis: PerformanceAnalysis) -> "list[PerformanceIssue]":
314
334
  """Detect potential cartesian products from join analysis.
@@ -335,11 +355,9 @@ class PerformanceValidator(BaseValidator):
335
355
  )
336
356
  )
337
357
  else:
338
- # Build join graph
339
358
  join_graph[condition.left_table].add(condition.right_table)
340
359
  join_graph[condition.right_table].add(condition.left_table)
341
360
 
342
- # Check for disconnected tables (implicit cartesian)
343
361
  if len(analysis.tables) > 1:
344
362
  connected = self._find_connected_components(join_graph, analysis.tables)
345
363
  if len(connected) > 1:
@@ -595,7 +613,6 @@ class PerformanceValidator(BaseValidator):
595
613
 
596
614
  for opt_type, optimizer, description in optimizations:
597
615
  try:
598
- # Apply the optimization
599
616
  optimized = optimizer(expression.copy(), dialect=context.dialect) # type: ignore[operator]
600
617
 
601
618
  if optimized is None:
@@ -623,7 +640,6 @@ class PerformanceValidator(BaseValidator):
623
640
  else:
624
641
  improvement = 0.0
625
642
 
626
- # Only add if improvement meets threshold
627
643
  if improvement >= self.config.optimization_threshold:
628
644
  opportunities.append(
629
645
  OptimizationOpportunity(
@@ -636,7 +652,6 @@ class PerformanceValidator(BaseValidator):
636
652
  )
637
653
  )
638
654
 
639
- # Update the best optimization if this is better
640
655
  if improvement > cumulative_improvement:
641
656
  best_optimized = optimized
642
657
  cumulative_improvement = improvement