sqlspec 0.11.1__py3-none-any.whl → 0.12.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic. Click here for more details.

Files changed (155):
  1. sqlspec/__init__.py +16 -3
  2. sqlspec/_serialization.py +3 -10
  3. sqlspec/_sql.py +1147 -0
  4. sqlspec/_typing.py +343 -41
  5. sqlspec/adapters/adbc/__init__.py +2 -6
  6. sqlspec/adapters/adbc/config.py +474 -149
  7. sqlspec/adapters/adbc/driver.py +330 -621
  8. sqlspec/adapters/aiosqlite/__init__.py +2 -6
  9. sqlspec/adapters/aiosqlite/config.py +143 -57
  10. sqlspec/adapters/aiosqlite/driver.py +269 -431
  11. sqlspec/adapters/asyncmy/__init__.py +3 -8
  12. sqlspec/adapters/asyncmy/config.py +247 -202
  13. sqlspec/adapters/asyncmy/driver.py +218 -436
  14. sqlspec/adapters/asyncpg/__init__.py +4 -7
  15. sqlspec/adapters/asyncpg/config.py +329 -176
  16. sqlspec/adapters/asyncpg/driver.py +417 -487
  17. sqlspec/adapters/bigquery/__init__.py +2 -2
  18. sqlspec/adapters/bigquery/config.py +407 -0
  19. sqlspec/adapters/bigquery/driver.py +600 -553
  20. sqlspec/adapters/duckdb/__init__.py +4 -1
  21. sqlspec/adapters/duckdb/config.py +432 -321
  22. sqlspec/adapters/duckdb/driver.py +392 -406
  23. sqlspec/adapters/oracledb/__init__.py +3 -8
  24. sqlspec/adapters/oracledb/config.py +625 -0
  25. sqlspec/adapters/oracledb/driver.py +548 -921
  26. sqlspec/adapters/psqlpy/__init__.py +4 -7
  27. sqlspec/adapters/psqlpy/config.py +372 -203
  28. sqlspec/adapters/psqlpy/driver.py +197 -533
  29. sqlspec/adapters/psycopg/__init__.py +3 -8
  30. sqlspec/adapters/psycopg/config.py +725 -0
  31. sqlspec/adapters/psycopg/driver.py +734 -694
  32. sqlspec/adapters/sqlite/__init__.py +2 -6
  33. sqlspec/adapters/sqlite/config.py +146 -81
  34. sqlspec/adapters/sqlite/driver.py +242 -405
  35. sqlspec/base.py +220 -784
  36. sqlspec/config.py +354 -0
  37. sqlspec/driver/__init__.py +22 -0
  38. sqlspec/driver/_async.py +252 -0
  39. sqlspec/driver/_common.py +338 -0
  40. sqlspec/driver/_sync.py +261 -0
  41. sqlspec/driver/mixins/__init__.py +17 -0
  42. sqlspec/driver/mixins/_pipeline.py +523 -0
  43. sqlspec/driver/mixins/_result_utils.py +122 -0
  44. sqlspec/driver/mixins/_sql_translator.py +35 -0
  45. sqlspec/driver/mixins/_storage.py +993 -0
  46. sqlspec/driver/mixins/_type_coercion.py +131 -0
  47. sqlspec/exceptions.py +299 -7
  48. sqlspec/extensions/aiosql/__init__.py +10 -0
  49. sqlspec/extensions/aiosql/adapter.py +474 -0
  50. sqlspec/extensions/litestar/__init__.py +1 -6
  51. sqlspec/extensions/litestar/_utils.py +1 -5
  52. sqlspec/extensions/litestar/config.py +5 -6
  53. sqlspec/extensions/litestar/handlers.py +13 -12
  54. sqlspec/extensions/litestar/plugin.py +22 -24
  55. sqlspec/extensions/litestar/providers.py +37 -55
  56. sqlspec/loader.py +528 -0
  57. sqlspec/service/__init__.py +3 -0
  58. sqlspec/service/base.py +24 -0
  59. sqlspec/service/pagination.py +26 -0
  60. sqlspec/statement/__init__.py +21 -0
  61. sqlspec/statement/builder/__init__.py +54 -0
  62. sqlspec/statement/builder/_ddl_utils.py +119 -0
  63. sqlspec/statement/builder/_parsing_utils.py +135 -0
  64. sqlspec/statement/builder/base.py +328 -0
  65. sqlspec/statement/builder/ddl.py +1379 -0
  66. sqlspec/statement/builder/delete.py +80 -0
  67. sqlspec/statement/builder/insert.py +274 -0
  68. sqlspec/statement/builder/merge.py +95 -0
  69. sqlspec/statement/builder/mixins/__init__.py +65 -0
  70. sqlspec/statement/builder/mixins/_aggregate_functions.py +151 -0
  71. sqlspec/statement/builder/mixins/_case_builder.py +91 -0
  72. sqlspec/statement/builder/mixins/_common_table_expr.py +91 -0
  73. sqlspec/statement/builder/mixins/_delete_from.py +34 -0
  74. sqlspec/statement/builder/mixins/_from.py +61 -0
  75. sqlspec/statement/builder/mixins/_group_by.py +119 -0
  76. sqlspec/statement/builder/mixins/_having.py +35 -0
  77. sqlspec/statement/builder/mixins/_insert_from_select.py +48 -0
  78. sqlspec/statement/builder/mixins/_insert_into.py +36 -0
  79. sqlspec/statement/builder/mixins/_insert_values.py +69 -0
  80. sqlspec/statement/builder/mixins/_join.py +110 -0
  81. sqlspec/statement/builder/mixins/_limit_offset.py +53 -0
  82. sqlspec/statement/builder/mixins/_merge_clauses.py +405 -0
  83. sqlspec/statement/builder/mixins/_order_by.py +46 -0
  84. sqlspec/statement/builder/mixins/_pivot.py +82 -0
  85. sqlspec/statement/builder/mixins/_returning.py +37 -0
  86. sqlspec/statement/builder/mixins/_select_columns.py +60 -0
  87. sqlspec/statement/builder/mixins/_set_ops.py +122 -0
  88. sqlspec/statement/builder/mixins/_unpivot.py +80 -0
  89. sqlspec/statement/builder/mixins/_update_from.py +54 -0
  90. sqlspec/statement/builder/mixins/_update_set.py +91 -0
  91. sqlspec/statement/builder/mixins/_update_table.py +29 -0
  92. sqlspec/statement/builder/mixins/_where.py +374 -0
  93. sqlspec/statement/builder/mixins/_window_functions.py +86 -0
  94. sqlspec/statement/builder/protocols.py +20 -0
  95. sqlspec/statement/builder/select.py +206 -0
  96. sqlspec/statement/builder/update.py +178 -0
  97. sqlspec/statement/filters.py +571 -0
  98. sqlspec/statement/parameters.py +736 -0
  99. sqlspec/statement/pipelines/__init__.py +67 -0
  100. sqlspec/statement/pipelines/analyzers/__init__.py +9 -0
  101. sqlspec/statement/pipelines/analyzers/_analyzer.py +649 -0
  102. sqlspec/statement/pipelines/base.py +315 -0
  103. sqlspec/statement/pipelines/context.py +119 -0
  104. sqlspec/statement/pipelines/result_types.py +41 -0
  105. sqlspec/statement/pipelines/transformers/__init__.py +8 -0
  106. sqlspec/statement/pipelines/transformers/_expression_simplifier.py +256 -0
  107. sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +623 -0
  108. sqlspec/statement/pipelines/transformers/_remove_comments.py +66 -0
  109. sqlspec/statement/pipelines/transformers/_remove_hints.py +81 -0
  110. sqlspec/statement/pipelines/validators/__init__.py +23 -0
  111. sqlspec/statement/pipelines/validators/_dml_safety.py +275 -0
  112. sqlspec/statement/pipelines/validators/_parameter_style.py +297 -0
  113. sqlspec/statement/pipelines/validators/_performance.py +703 -0
  114. sqlspec/statement/pipelines/validators/_security.py +990 -0
  115. sqlspec/statement/pipelines/validators/base.py +67 -0
  116. sqlspec/statement/result.py +527 -0
  117. sqlspec/statement/splitter.py +701 -0
  118. sqlspec/statement/sql.py +1198 -0
  119. sqlspec/storage/__init__.py +15 -0
  120. sqlspec/storage/backends/__init__.py +0 -0
  121. sqlspec/storage/backends/base.py +166 -0
  122. sqlspec/storage/backends/fsspec.py +315 -0
  123. sqlspec/storage/backends/obstore.py +464 -0
  124. sqlspec/storage/protocol.py +170 -0
  125. sqlspec/storage/registry.py +315 -0
  126. sqlspec/typing.py +157 -36
  127. sqlspec/utils/correlation.py +155 -0
  128. sqlspec/utils/deprecation.py +3 -6
  129. sqlspec/utils/fixtures.py +6 -11
  130. sqlspec/utils/logging.py +135 -0
  131. sqlspec/utils/module_loader.py +45 -43
  132. sqlspec/utils/serializers.py +4 -0
  133. sqlspec/utils/singleton.py +6 -8
  134. sqlspec/utils/sync_tools.py +15 -27
  135. sqlspec/utils/text.py +58 -26
  136. {sqlspec-0.11.1.dist-info → sqlspec-0.12.1.dist-info}/METADATA +97 -26
  137. sqlspec-0.12.1.dist-info/RECORD +145 -0
  138. sqlspec/adapters/bigquery/config/__init__.py +0 -3
  139. sqlspec/adapters/bigquery/config/_common.py +0 -40
  140. sqlspec/adapters/bigquery/config/_sync.py +0 -87
  141. sqlspec/adapters/oracledb/config/__init__.py +0 -9
  142. sqlspec/adapters/oracledb/config/_asyncio.py +0 -186
  143. sqlspec/adapters/oracledb/config/_common.py +0 -131
  144. sqlspec/adapters/oracledb/config/_sync.py +0 -186
  145. sqlspec/adapters/psycopg/config/__init__.py +0 -19
  146. sqlspec/adapters/psycopg/config/_async.py +0 -169
  147. sqlspec/adapters/psycopg/config/_common.py +0 -56
  148. sqlspec/adapters/psycopg/config/_sync.py +0 -168
  149. sqlspec/filters.py +0 -331
  150. sqlspec/mixins.py +0 -305
  151. sqlspec/statement.py +0 -378
  152. sqlspec-0.11.1.dist-info/RECORD +0 -69
  153. {sqlspec-0.11.1.dist-info → sqlspec-0.12.1.dist-info}/WHEEL +0 -0
  154. {sqlspec-0.11.1.dist-info → sqlspec-0.12.1.dist-info}/licenses/LICENSE +0 -0
  155. {sqlspec-0.11.1.dist-info → sqlspec-0.12.1.dist-info}/licenses/NOTICE +0 -0
sqlspec/statement/pipelines/analyzers/_analyzer.py (new file, +649 -0)
@@ -0,0 +1,649 @@
1
+ """SQL statement analyzer for extracting metadata and complexity metrics."""
2
+
3
+ import time
4
+ from dataclasses import dataclass, field
5
+ from typing import TYPE_CHECKING, Any, Optional
6
+
7
+ from sqlglot import exp, parse_one
8
+ from sqlglot.errors import ParseError as SQLGlotParseError
9
+
10
+ from sqlspec.statement.pipelines.base import ProcessorProtocol
11
+ from sqlspec.statement.pipelines.result_types import AnalysisFinding
12
+ from sqlspec.utils.correlation import CorrelationContext
13
+ from sqlspec.utils.logging import get_logger
14
+
15
+ if TYPE_CHECKING:
16
+ from sqlglot.dialects.dialect import DialectType
17
+
18
+ from sqlspec.statement.pipelines.context import SQLProcessingContext
19
+ from sqlspec.statement.sql import SQLConfig
20
+
21
+ __all__ = ("StatementAnalysis", "StatementAnalyzer")
22
+
23
# Constants for statement analysis
#
# Each threshold below is compared with a strict ``>`` inside StatementAnalyzer,
# so a query only produces the corresponding warning once it *exceeds* the
# value. None of them is configurable per analyzer instance; the per-instance
# limits (joins, nesting depth, function calls, WHERE conditions) are
# constructor arguments of StatementAnalyzer instead.
HIGH_SUBQUERY_COUNT_THRESHOLD = 10
"""Threshold for flagging high number of subqueries."""

HIGH_CORRELATED_SUBQUERY_THRESHOLD = 3
"""Threshold for flagging multiple correlated subqueries."""

EXPENSIVE_FUNCTION_THRESHOLD = 5
"""Threshold for flagging multiple expensive functions."""

NESTED_FUNCTION_THRESHOLD = 3
"""Threshold for flagging multiple nested function calls."""

# Module logger; used by StatementAnalyzer.analyze_statement to report SQL that
# cannot be parsed or that does not parse to a statement-level expression.
logger = get_logger("pipelines.analyzers")
37
+
38
+
39
@dataclass
class StatementAnalysis:
    """Analysis result for parsed SQL statements.

    Instances are built by ``StatementAnalyzer.analyze_expression``. Only
    ``statement_type`` and ``expression`` are required; every other field has
    a default so that a placeholder result (``statement_type="Unknown"``) can
    be returned when a SQL string cannot be parsed.

    The instance is mutable on purpose: the analyzer first constructs it with
    the structural fields and then fills in the complexity metrics, warnings
    and issues in place.
    """

    statement_type: str
    """Type of SQL statement (Insert, Select, Update, Delete, etc.)"""
    expression: exp.Expression
    """Parsed SQLGlot expression"""
    table_name: "Optional[str]" = None
    """Primary table name if detected"""
    columns: "list[str]" = field(default_factory=list)
    """Column names if detected"""
    has_returning: bool = False
    """Whether statement has RETURNING clause"""
    is_from_select: bool = False
    """Whether this is an INSERT FROM SELECT pattern"""
    parameters: "dict[str, Any]" = field(default_factory=dict)
    """Extracted parameters from the SQL"""
    tables: "list[str]" = field(default_factory=list)
    """All table names referenced in the query"""
    complexity_score: int = 0
    """Complexity score based on query structure"""
    uses_subqueries: bool = False
    """Whether the query uses subqueries"""
    join_count: int = 0
    """Number of joins in the query"""
    aggregate_functions: "list[str]" = field(default_factory=list)
    """List of aggregate functions used"""

    # Enhanced complexity metrics
    join_types: "dict[str, int]" = field(default_factory=dict)
    """Types and counts of joins"""
    max_subquery_depth: int = 0
    """Maximum subquery nesting depth"""
    correlated_subquery_count: int = 0
    """Number of correlated subqueries"""
    function_count: int = 0
    """Total number of function calls"""
    where_condition_count: int = 0
    """Number of WHERE conditions"""
    potential_cartesian_products: int = 0
    """Number of potential Cartesian products detected"""
    complexity_warnings: "list[str]" = field(default_factory=list)
    """Warnings about query complexity"""
    complexity_issues: "list[str]" = field(default_factory=list)
    """Issues with query complexity"""

    # Additional attributes for aggregator compatibility
    subquery_count: int = 0
    """Total number of subqueries"""
    operations: "list[str]" = field(default_factory=list)
    """SQL operations performed (SELECT, JOIN, etc.)"""
    has_aggregation: bool = False
    """Whether query uses aggregation functions"""
    has_window_functions: bool = False
    """Whether query uses window functions"""
    cte_count: int = 0
    """Number of CTEs (Common Table Expressions)"""
97
+
98
+
99
class StatementAnalyzer(ProcessorProtocol):
    """SQL statement analyzer that extracts metadata and insights from SQL statements.

    This processor analyzes SQL expressions to extract useful metadata without
    modifying the SQL itself. It can be used in pipelines to gather insights
    about query complexity, table usage, etc.

    Two in-memory caches are kept per instance, both capped at ``cache_size``
    entries (once full, new entries are simply not stored):

    * a parse cache keyed by ``(sql, dialect)`` used by :meth:`analyze_statement`;
    * an analysis cache keyed by the generated SQL of the analyzed expression.

    Cached :class:`StatementAnalysis` objects are returned as-is, so callers
    should treat them as read-only.
    """

    def __init__(
        self,
        cache_size: int = 1000,
        max_join_count: int = 10,
        max_subquery_depth: int = 3,
        max_function_calls: int = 20,
        max_where_conditions: int = 15,
    ) -> None:
        """Initialize the analyzer.

        Args:
            cache_size: Maximum number of parsed expressions to cache.
            max_join_count: Maximum allowed joins before flagging.
            max_subquery_depth: Maximum allowed subquery nesting depth.
            max_function_calls: Maximum allowed function calls.
            max_where_conditions: Maximum allowed WHERE conditions.

        For each ``max_*`` limit, exceeding the limit records an *issue* and
        exceeding half of it (integer division) records a *warning*.
        """
        self.cache_size = cache_size
        self.max_join_count = max_join_count
        self.max_subquery_depth = max_subquery_depth
        self.max_function_calls = max_function_calls
        self.max_where_conditions = max_where_conditions
        self._parse_cache: dict[tuple[str, Optional[str]], exp.Expression] = {}
        self._analysis_cache: dict[str, StatementAnalysis] = {}

    def process(
        self, expression: "Optional[exp.Expression]", context: "SQLProcessingContext"
    ) -> "Optional[exp.Expression]":
        """Process the SQL expression to extract analysis metadata and store it in the context.

        Args:
            expression: The expression to analyze; ``None`` is passed through.
            context: Pipeline context. Findings are appended to
                ``context.analysis_findings`` and a summary dict is stored in
                ``context.metadata`` under this class's name.

        Returns:
            The expression that was passed in, unmodified.
        """
        if expression is None:
            return None

        # NOTE(review): the return value is unused; the call is kept because
        # the original code performs it before anything else.
        CorrelationContext.get()

        if not context.config.enable_analysis:
            return expression

        # Start timing only once we know analysis will actually run.
        start_time = time.perf_counter()
        analysis_result_obj = self.analyze_expression(expression, context.dialect, context.config)
        duration = time.perf_counter() - start_time

        processor_name = self.__class__.__name__

        # Add analysis findings to context (warnings first, then issues).
        for warning in analysis_result_obj.complexity_warnings:
            context.analysis_findings.append(
                AnalysisFinding(key="complexity_warning", value=warning, processor=processor_name)
            )
        for issue in analysis_result_obj.complexity_issues:
            context.analysis_findings.append(
                AnalysisFinding(key="complexity_issue", value=issue, processor=processor_name)
            )

        # Store metadata in context
        context.metadata[processor_name] = {
            "duration_ms": duration * 1000,
            "statement_type": analysis_result_obj.statement_type,
            "table_count": len(analysis_result_obj.tables),
            "has_subqueries": analysis_result_obj.uses_subqueries,
            "join_count": analysis_result_obj.join_count,
            "complexity_score": analysis_result_obj.complexity_score,
        }
        return expression

    @staticmethod
    def _unknown_analysis() -> StatementAnalysis:
        """Build the placeholder result returned when SQL cannot be analyzed."""
        return StatementAnalysis(statement_type="Unknown", expression=exp.Anonymous(this="UNKNOWN"))

    def analyze_statement(self, sql_string: str, dialect: "DialectType" = None) -> StatementAnalysis:
        """Analyze SQL string and extract components efficiently.

        Args:
            sql_string: The SQL string to analyze
            dialect: SQL dialect for parsing

        Returns:
            StatementAnalysis with extracted components. If the string cannot
            be parsed, or parses to something that is not a statement-level
            expression, a placeholder with ``statement_type="Unknown"`` is
            returned and a warning is logged; no exception is raised.
        """
        # Check cache first
        cache_key = sql_string.strip()
        cached_analysis = self._analysis_cache.get(cache_key)
        if cached_analysis is not None:
            return cached_analysis

        # Use cache key for expression parsing performance
        parse_cache_key = (cache_key, str(dialect) if dialect else None)

        expr = self._parse_cache.get(parse_cache_key)
        if expr is None:
            try:
                expr = exp.maybe_parse(sql_string, dialect=dialect)
                if expr is None:
                    expr = parse_one(sql_string, dialect=dialect)

                # Check if the parsed expression is a valid SQL statement type
                # Simple expressions like Alias or Identifier are not valid SQL statements.
                # The tuple is built inside the ``try`` on purpose: a missing
                # attribute on ``exp`` is then handled like any parse failure.
                valid_statement_types = (
                    exp.Select,
                    exp.Insert,
                    exp.Update,
                    exp.Delete,
                    exp.Create,
                    exp.Drop,
                    exp.Alter,
                    exp.Merge,
                    exp.Command,
                    exp.Set,
                    exp.Show,
                    exp.Describe,
                    exp.Use,
                    exp.Union,
                    exp.Intersect,
                    exp.Except,
                )
                if not isinstance(expr, valid_statement_types):
                    logger.warning("Parsed expression is not a valid SQL statement: %s", type(expr).__name__)
                    return self._unknown_analysis()

                if len(self._parse_cache) < self.cache_size:
                    self._parse_cache[parse_cache_key] = expr
            except (SQLGlotParseError, Exception) as e:
                # Deliberately broad: analysis is best-effort and must never
                # break the caller because of unparseable SQL.
                logger.warning("Failed to parse SQL statement: %s", e)
                return self._unknown_analysis()

        return self.analyze_expression(expr)

    def analyze_expression(
        self, expression: exp.Expression, dialect: "DialectType" = None, config: "Optional[SQLConfig]" = None
    ) -> StatementAnalysis:
        """Analyze a SQLGlot expression directly, potentially using validation results for context.

        Args:
            expression: The parsed expression to analyze.
            dialect: Accepted for interface compatibility; currently unused.
            config: Accepted for interface compatibility; currently unused.

        Returns:
            The analysis for ``expression``. Results are cached by the SQL text
            generated from the expression, so two expressions that render to
            the same SQL share one (mutable) result object.
        """
        # Check cache first (using expression.sql() as key)
        # This caching needs to be context-aware if analysis depends on prior steps (e.g. validation_result)
        cache_key = expression.sql()  # Simplified cache key
        cached_analysis = self._analysis_cache.get(cache_key)
        if cached_analysis is not None:
            return cached_analysis

        analysis = StatementAnalysis(
            statement_type=type(expression).__name__,
            expression=expression,
            table_name=self._extract_primary_table_name(expression),
            columns=self._extract_columns(expression),
            has_returning=bool(expression.find(exp.Returning)),
            is_from_select=self._is_insert_from_select(expression),
            parameters=self._extract_parameters(expression),
            tables=self._extract_all_tables(expression),
            uses_subqueries=self._has_subqueries(expression),
            join_count=self._count_joins(expression),
            aggregate_functions=self._extract_aggregate_functions(expression),
        )

        # Calculate subquery_count and cte_count before complexity analysis.
        # IN/EXISTS subqueries that aren't wrapped in Subquery nodes are counted too.
        analysis.subquery_count = (
            sum(1 for _ in expression.find_all(exp.Subquery))
            + sum(
                1
                for in_clause in expression.find_all(exp.In)
                if (query := in_clause.args.get("query")) and isinstance(query, exp.Select)
            )
            + sum(
                1
                for exists_clause in expression.find_all(exp.Exists)
                if exists_clause.this and isinstance(exists_clause.this, exp.Select)
            )
        )
        analysis.cte_count = sum(1 for _ in expression.find_all(exp.CTE))

        self._analyze_complexity(expression, analysis)
        analysis.complexity_score = self._calculate_comprehensive_complexity_score(analysis)
        analysis.operations = self._extract_operations(expression)
        analysis.has_aggregation = len(analysis.aggregate_functions) > 0
        analysis.has_window_functions = self._has_window_functions(expression)

        if len(self._analysis_cache) < self.cache_size:
            self._analysis_cache[cache_key] = analysis
        return analysis

    def _analyze_complexity(self, expression: exp.Expression, analysis: StatementAnalysis) -> None:
        """Perform comprehensive complexity analysis.

        Runs the join, subquery, WHERE and function analyses in that order;
        each one mutates ``analysis`` in place.
        """
        self._analyze_joins(expression, analysis)
        self._analyze_subqueries(expression, analysis)
        self._analyze_where_clauses(expression, analysis)
        self._analyze_functions(expression, analysis)

    def _analyze_joins(self, expression: exp.Expression, analysis: StatementAnalysis) -> None:
        """Analyze JOIN operations for potential issues.

        Updates ``join_count``, ``join_types``, ``potential_cartesian_products``
        and appends to the warning/issue lists of ``analysis``.
        """
        join_nodes = list(expression.find_all(exp.Join))
        analysis.join_count = len(join_nodes)

        warnings: list[str] = []
        issues: list[str] = []

        # Multiple tables in a single FROM without explicit JOINs: a simplified
        # check for potential cartesian products.
        cartesian_products = sum(
            1
            for select in expression.find_all(exp.Select)
            if (from_clause := select.args.get("from"))
            and hasattr(from_clause, "expressions")
            and len(from_clause.expressions) > 1
        )
        if cartesian_products > 0:
            issues.append(
                f"Potential Cartesian product detected ({cartesian_products} instances from multiple FROM tables without JOIN)"
            )

        for join_node in join_nodes:
            # sqlglot splits a join's description over three args: ``kind``
            # (INNER/OUTER/CROSS/...), ``side`` (LEFT/RIGHT/FULL) and ``method``
            # (NATURAL/ASOF/...). Looking at ``kind`` alone labels a plain
            # ``LEFT JOIN`` as INNER and never sees NATURAL.
            kind = join_node.text("kind").upper()
            side = join_node.text("side").upper()
            method = join_node.text("method").upper()

            join_type = kind or side or "INNER"
            analysis.join_types[join_type] = analysis.join_types.get(join_type, 0) + 1

            is_natural = "NATURAL" in (method, join_type)
            has_condition = bool(join_node.args.get("on")) or bool(join_node.args.get("using"))

            if join_type == "CROSS":
                issues.append("Explicit CROSS JOIN found, potential Cartesian product.")
                cartesian_products += 1
            elif not has_condition and not is_natural:
                issues.append(f"JOIN ({join_node.sql()}) without ON/USING clause, potential Cartesian product.")
                cartesian_products += 1

        if analysis.join_count > self.max_join_count:
            issues.append(f"Excessive number of joins ({analysis.join_count}), may cause performance issues")
        elif analysis.join_count > self.max_join_count // 2:
            warnings.append(f"High number of joins ({analysis.join_count}), monitor performance")

        analysis.potential_cartesian_products = cartesian_products
        analysis.complexity_warnings.extend(warnings)
        analysis.complexity_issues.extend(issues)

    @staticmethod
    def _max_select_nesting_depth(expression: exp.Expression) -> int:
        """Return the largest number of SELECT ancestors any SELECT in the tree has."""
        deepest = 0
        for select in expression.find_all(exp.Select):
            depth = 0
            ancestor = select.parent
            while ancestor:
                if isinstance(ancestor, exp.Select):
                    depth += 1
                ancestor = ancestor.parent
            deepest = max(deepest, depth)
        return deepest

    def _analyze_subqueries(self, expression: exp.Expression, analysis: StatementAnalysis) -> None:
        """Analyze subquery complexity and nesting depth.

        Updates ``subquery_count``, ``max_subquery_depth`` and
        ``correlated_subquery_count`` and appends to the warning/issue lists.
        """
        subqueries: list[exp.Expression] = list(expression.find_all(exp.Subquery))
        subqueries.extend(
            query
            for in_clause in expression.find_all(exp.In)
            if (query := in_clause.args.get("query")) and isinstance(query, exp.Select)
        )
        subqueries.extend(
            exists_clause.this
            for exists_clause in expression.find_all(exp.Exists)
            if exists_clause.this and isinstance(exists_clause.this, exp.Select)
        )

        analysis.subquery_count = len(subqueries)
        max_depth = self._max_select_nesting_depth(expression)

        # A subquery is treated as correlated when one of its qualified columns
        # names a table (or alias) that exists in the statement but is NOT
        # introduced by the subquery itself. Without excluding the subquery's
        # own tables, every subquery with a qualified column would be counted.
        all_tables = {tbl.alias or tbl.name for tbl in expression.find_all(exp.Table)}
        correlated_count = 0
        for subquery in subqueries:
            own_tables = {tbl.alias or tbl.name for tbl in subquery.find_all(exp.Table)}
            outer_tables = all_tables - own_tables
            if any(col.table and col.table in outer_tables for col in subquery.find_all(exp.Column)):
                correlated_count += 1

        warnings: list[str] = []
        issues: list[str] = []

        if max_depth > self.max_subquery_depth:
            issues.append(f"Excessive subquery nesting depth ({max_depth})")
        elif max_depth > self.max_subquery_depth // 2:
            warnings.append(f"High subquery nesting depth ({max_depth})")

        if analysis.subquery_count > HIGH_SUBQUERY_COUNT_THRESHOLD:
            warnings.append(f"High number of subqueries ({analysis.subquery_count})")

        if correlated_count > HIGH_CORRELATED_SUBQUERY_THRESHOLD:
            warnings.append(f"Multiple correlated subqueries detected ({correlated_count})")

        analysis.max_subquery_depth = max_depth
        analysis.correlated_subquery_count = correlated_count
        analysis.complexity_warnings.extend(warnings)
        analysis.complexity_issues.extend(issues)

    def _analyze_where_clauses(self, expression: exp.Expression, analysis: StatementAnalysis) -> None:
        """Analyze WHERE clause complexity.

        The "condition count" is the number of ``AND``/``OR`` nodes found under
        every WHERE clause in the tree, i.e. it counts connectives rather than
        individual predicates (a WHERE with a single predicate contributes 0).
        """
        total_conditions = sum(
            1
            for where_clause in expression.find_all(exp.Where)
            for _ in where_clause.find_all(exp.And, exp.Or)
        )

        warnings: list[str] = []
        issues: list[str] = []

        if total_conditions > self.max_where_conditions:
            issues.append(f"Excessive WHERE conditions ({total_conditions})")
        elif total_conditions > self.max_where_conditions // 2:
            warnings.append(f"Complex WHERE clause ({total_conditions} conditions)")

        analysis.where_condition_count = total_conditions
        analysis.complexity_warnings.extend(warnings)
        analysis.complexity_issues.extend(issues)

    @staticmethod
    def _function_name(func: exp.Func) -> str:
        """Return the lower-cased SQL name of a function node.

        ``Expression.name`` is the text of the node's ``this`` argument, which
        is only the function name for ``exp.Anonymous``. For built-in function
        classes the SQL name comes from ``sql_names()``.
        """
        if isinstance(func, exp.Anonymous):
            raw_name = func.name
        else:
            try:
                sql_names = func.sql_names()
            except NotImplementedError:
                # Raised for the bare ``exp.Func`` base class.
                sql_names = []
            raw_name = sql_names[0] if sql_names else type(func).__name__
        return raw_name.lower() if raw_name else "unknown"

    @staticmethod
    def _has_function_argument(func: exp.Func) -> bool:
        """Check whether any direct argument of ``func`` is itself a function call."""
        for arg in func.args.values():
            if isinstance(arg, exp.Func):
                return True
            # Variadic arguments (e.g. ``expressions``) are stored as lists.
            if isinstance(arg, list) and any(isinstance(item, exp.Func) for item in arg):
                return True
        return False

    def _analyze_functions(self, expression: exp.Expression, analysis: StatementAnalysis) -> None:
        """Analyze function usage and complexity.

        Updates ``function_count`` and appends to the warning/issue lists.
        """
        function_types: dict[str, int] = {}
        nested_functions = 0
        function_count = 0
        for func in expression.find_all(exp.Func):
            func_name = self._function_name(func)
            function_types[func_name] = function_types.get(func_name, 0) + 1
            if self._has_function_argument(func):
                nested_functions += 1
            function_count += 1

        # "regexp_like" is the SQL name sqlglot gives its regular-expression
        # match function node; the remaining names are kept from the original set.
        expensive_functions = {"regexp", "regexp_like", "regex", "like", "concat_ws", "group_concat"}
        expensive_count = sum(function_types.get(func, 0) for func in expensive_functions)

        warnings: list[str] = []
        issues: list[str] = []

        if function_count > self.max_function_calls:
            issues.append(f"Excessive function calls ({function_count})")
        elif function_count > self.max_function_calls // 2:
            warnings.append(f"High number of function calls ({function_count})")

        if expensive_count > EXPENSIVE_FUNCTION_THRESHOLD:
            warnings.append(f"Multiple expensive functions used ({expensive_count})")

        if nested_functions > NESTED_FUNCTION_THRESHOLD:
            warnings.append(f"Multiple nested function calls ({nested_functions})")

        analysis.function_count = function_count
        analysis.complexity_warnings.extend(warnings)
        analysis.complexity_issues.extend(issues)

    @staticmethod
    def _calculate_comprehensive_complexity_score(analysis: StatementAnalysis) -> int:
        """Calculate an overall complexity score based on various metrics.

        The score is a weighted sum of the counters already stored on
        ``analysis``; higher means more complex.
        """
        weighted_metrics = (
            # Join complexity
            (analysis.join_count, 3),
            (analysis.potential_cartesian_products, 20),
            # Subquery complexity
            (analysis.subquery_count, 5),
            (analysis.max_subquery_depth, 10),
            (analysis.correlated_subquery_count, 8),
            # CTE complexity
            (analysis.cte_count, 7),
            # WHERE clause complexity
            (analysis.where_condition_count, 2),
            # Function complexity
            (analysis.function_count, 1),
        )
        return sum(value * weight for value, weight in weighted_metrics)

    @staticmethod
    def _extract_primary_table_name(expr: exp.Expression) -> "Optional[str]":
        """Extract the primary table name from an expression.

        Returns:
            The target table of an INSERT/UPDATE/DELETE, the first FROM source
            of a SELECT, or ``None`` for other statement types.
        """
        if isinstance(expr, exp.Insert):
            target = expr.this
            # ``INSERT INTO t (a, b)`` wraps the table in a Schema node whose
            # own ``name`` is empty; the table is the Schema's ``this``.
            if isinstance(target, exp.Schema):
                target = target.this
            if isinstance(target, exp.Table):
                return target.name
            if target is not None and hasattr(target, "name"):
                return str(target.name)
            return None

        if isinstance(expr, (exp.Update, exp.Delete)):
            target = expr.this
            if not target:
                return None
            return str(target.name) if hasattr(target, "name") else str(target)

        if isinstance(expr, exp.Select) and (from_clause := expr.find(exp.From)) and from_clause.this:
            source = from_clause.this
            return str(source.name) if hasattr(source, "name") else str(source)

        return None

    @staticmethod
    def _extract_columns(expr: exp.Expression) -> "list[str]":
        """Extract column names from an expression.

        For INSERT this is the explicit column list (if any); for SELECT it is
        one entry per projection: the column name, else the alias, else the
        projection's ``name``. Other statement types yield an empty list.
        """
        if isinstance(expr, exp.Insert):
            if expr.this and hasattr(expr.this, "expressions"):
                return [str(col_expr.name) for col_expr in expr.this.expressions if hasattr(col_expr, "name")]
            return []

        columns: list[str] = []
        if isinstance(expr, exp.Select):
            for projection in expr.expressions:
                if isinstance(projection, exp.Column):
                    columns.append(str(projection.name))
                elif hasattr(projection, "alias") and projection.alias:
                    columns.append(str(projection.alias))
                elif hasattr(projection, "name"):
                    columns.append(str(projection.name))
        return columns

    @staticmethod
    def _extract_all_tables(expr: exp.Expression) -> "list[str]":
        """Extract all table names referenced in the expression.

        Names are de-duplicated and returned in order of first appearance.
        """
        # dict.fromkeys keeps first-seen order while dropping duplicates.
        return list(dict.fromkeys(str(table.name) for table in expr.find_all(exp.Table) if hasattr(table, "name")))

    @staticmethod
    def _is_insert_from_select(expr: exp.Expression) -> bool:
        """Check if this is an INSERT FROM SELECT pattern."""
        return isinstance(expr, exp.Insert) and bool(expr.expression and isinstance(expr.expression, exp.Select))

    @staticmethod
    def _extract_parameters(_expr: exp.Expression) -> "dict[str, Any]":
        """Extract parameters from the expression.

        Placeholder: always returns an empty dict. ``_expr`` is accepted so the
        signature is ready for real placeholder extraction later.
        """
        _ = _expr
        return {}

    @staticmethod
    def _has_subqueries(expr: exp.Expression) -> bool:
        """Check if the expression contains subqueries.

        Note: Due to sqlglot parser inconsistency, subqueries in IN clauses
        are not wrapped in Subquery nodes, so we need additional detection.
        CTEs are not considered subqueries.
        """
        # Standard subquery detection
        if expr.find(exp.Subquery):
            return True

        # sqlglot compatibility: IN clauses with SELECT need explicit handling
        for in_clause in expr.find_all(exp.In):
            query_node = in_clause.args.get("query")
            if query_node and isinstance(query_node, exp.Select):
                return True

        # sqlglot compatibility: EXISTS clauses with SELECT need explicit handling
        for exists_clause in expr.find_all(exp.Exists):
            if exists_clause.this and isinstance(exists_clause.this, exp.Select):
                return True

        # More than one SELECT outside of any CTE also indicates a subquery.
        selects_outside_ctes = 0
        for select in expr.find_all(exp.Select):
            ancestor = select.parent
            while ancestor and not isinstance(ancestor, exp.CTE):
                ancestor = ancestor.parent
            if not ancestor:
                selects_outside_ctes += 1

        return selects_outside_ctes > 1

    @staticmethod
    def _count_joins(expr: exp.Expression) -> int:
        """Count the number of joins in the expression."""
        return sum(1 for _ in expr.find_all(exp.Join))

    @staticmethod
    def _extract_aggregate_functions(expr: exp.Expression) -> "list[str]":
        """Extract aggregate function names from the expression.

        Only COUNT, SUM, AVG, MIN and MAX are recognised; each is reported at
        most once, in that fixed order, using the lower-cased class name.
        """
        aggregate_types = (exp.Count, exp.Sum, exp.Avg, exp.Min, exp.Max)
        return [agg_type.__name__.lower() for agg_type in aggregate_types if expr.find(agg_type)]

    def clear_cache(self) -> None:
        """Clear both parse and analysis caches."""
        for cache in (self._parse_cache, self._analysis_cache):
            cache.clear()

    @staticmethod
    def _extract_operations(expr: exp.Expression) -> "list[str]":
        """Extract SQL operations performed.

        The list starts with the main operation of the root node (if it is one
        of SELECT/INSERT/UPDATE/DELETE/CREATE/DROP/ALTER) followed by a label
        for each clause type found anywhere in the tree.
        """
        operations: list[str] = []

        # Main operation (kept as a lazy if/elif chain so that later ``exp``
        # attributes are only looked up when earlier checks did not match).
        if isinstance(expr, exp.Select):
            operations.append("SELECT")
        elif isinstance(expr, exp.Insert):
            operations.append("INSERT")
        elif isinstance(expr, exp.Update):
            operations.append("UPDATE")
        elif isinstance(expr, exp.Delete):
            operations.append("DELETE")
        elif isinstance(expr, exp.Create):
            operations.append("CREATE")
        elif isinstance(expr, exp.Drop):
            operations.append("DROP")
        elif isinstance(expr, exp.Alter):
            operations.append("ALTER")

        clause_labels = (
            (exp.Join, "JOIN"),
            (exp.Group, "GROUP BY"),
            (exp.Order, "ORDER BY"),
            (exp.Having, "HAVING"),
            (exp.Union, "UNION"),
            (exp.Intersect, "INTERSECT"),
            (exp.Except, "EXCEPT"),
        )
        operations.extend(label for node_type, label in clause_labels if expr.find(node_type))

        return operations

    @staticmethod
    def _has_window_functions(expr: exp.Expression) -> bool:
        """Check if expression uses window functions."""
        return bool(expr.find(exp.Window))