sqlspec 0.12.1__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/_sql.py +21 -180
- sqlspec/adapters/adbc/config.py +10 -12
- sqlspec/adapters/adbc/driver.py +120 -118
- sqlspec/adapters/aiosqlite/config.py +3 -3
- sqlspec/adapters/aiosqlite/driver.py +116 -141
- sqlspec/adapters/asyncmy/config.py +3 -4
- sqlspec/adapters/asyncmy/driver.py +123 -135
- sqlspec/adapters/asyncpg/config.py +3 -7
- sqlspec/adapters/asyncpg/driver.py +98 -140
- sqlspec/adapters/bigquery/config.py +4 -5
- sqlspec/adapters/bigquery/driver.py +231 -181
- sqlspec/adapters/duckdb/config.py +3 -6
- sqlspec/adapters/duckdb/driver.py +132 -124
- sqlspec/adapters/oracledb/config.py +6 -5
- sqlspec/adapters/oracledb/driver.py +242 -259
- sqlspec/adapters/psqlpy/config.py +3 -7
- sqlspec/adapters/psqlpy/driver.py +118 -93
- sqlspec/adapters/psycopg/config.py +34 -30
- sqlspec/adapters/psycopg/driver.py +342 -214
- sqlspec/adapters/sqlite/config.py +3 -3
- sqlspec/adapters/sqlite/driver.py +150 -104
- sqlspec/config.py +0 -4
- sqlspec/driver/_async.py +89 -98
- sqlspec/driver/_common.py +52 -17
- sqlspec/driver/_sync.py +81 -105
- sqlspec/driver/connection.py +207 -0
- sqlspec/driver/mixins/_csv_writer.py +91 -0
- sqlspec/driver/mixins/_pipeline.py +38 -49
- sqlspec/driver/mixins/_result_utils.py +27 -9
- sqlspec/driver/mixins/_storage.py +149 -216
- sqlspec/driver/mixins/_type_coercion.py +3 -4
- sqlspec/driver/parameters.py +138 -0
- sqlspec/exceptions.py +10 -2
- sqlspec/extensions/aiosql/adapter.py +0 -10
- sqlspec/extensions/litestar/handlers.py +0 -1
- sqlspec/extensions/litestar/plugin.py +0 -3
- sqlspec/extensions/litestar/providers.py +0 -14
- sqlspec/loader.py +31 -118
- sqlspec/protocols.py +542 -0
- sqlspec/service/__init__.py +3 -2
- sqlspec/service/_util.py +147 -0
- sqlspec/service/base.py +1116 -9
- sqlspec/statement/builder/__init__.py +42 -32
- sqlspec/statement/builder/_ddl_utils.py +0 -10
- sqlspec/statement/builder/_parsing_utils.py +10 -4
- sqlspec/statement/builder/base.py +70 -23
- sqlspec/statement/builder/column.py +283 -0
- sqlspec/statement/builder/ddl.py +102 -65
- sqlspec/statement/builder/delete.py +23 -7
- sqlspec/statement/builder/insert.py +29 -15
- sqlspec/statement/builder/merge.py +4 -4
- sqlspec/statement/builder/mixins/_aggregate_functions.py +113 -14
- sqlspec/statement/builder/mixins/_common_table_expr.py +0 -1
- sqlspec/statement/builder/mixins/_delete_from.py +1 -1
- sqlspec/statement/builder/mixins/_from.py +10 -8
- sqlspec/statement/builder/mixins/_group_by.py +0 -1
- sqlspec/statement/builder/mixins/_insert_from_select.py +0 -1
- sqlspec/statement/builder/mixins/_insert_values.py +0 -2
- sqlspec/statement/builder/mixins/_join.py +20 -13
- sqlspec/statement/builder/mixins/_limit_offset.py +3 -3
- sqlspec/statement/builder/mixins/_merge_clauses.py +3 -4
- sqlspec/statement/builder/mixins/_order_by.py +2 -2
- sqlspec/statement/builder/mixins/_pivot.py +4 -7
- sqlspec/statement/builder/mixins/_select_columns.py +6 -5
- sqlspec/statement/builder/mixins/_unpivot.py +6 -9
- sqlspec/statement/builder/mixins/_update_from.py +2 -1
- sqlspec/statement/builder/mixins/_update_set.py +11 -8
- sqlspec/statement/builder/mixins/_where.py +61 -34
- sqlspec/statement/builder/select.py +32 -17
- sqlspec/statement/builder/update.py +25 -11
- sqlspec/statement/filters.py +39 -14
- sqlspec/statement/parameter_manager.py +220 -0
- sqlspec/statement/parameters.py +210 -79
- sqlspec/statement/pipelines/__init__.py +166 -23
- sqlspec/statement/pipelines/analyzers/_analyzer.py +22 -25
- sqlspec/statement/pipelines/context.py +35 -39
- sqlspec/statement/pipelines/transformers/__init__.py +2 -3
- sqlspec/statement/pipelines/transformers/_expression_simplifier.py +19 -187
- sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +667 -43
- sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +76 -0
- sqlspec/statement/pipelines/validators/_dml_safety.py +33 -18
- sqlspec/statement/pipelines/validators/_parameter_style.py +87 -14
- sqlspec/statement/pipelines/validators/_performance.py +38 -23
- sqlspec/statement/pipelines/validators/_security.py +39 -62
- sqlspec/statement/result.py +37 -129
- sqlspec/statement/splitter.py +0 -12
- sqlspec/statement/sql.py +885 -379
- sqlspec/statement/sql_compiler.py +140 -0
- sqlspec/storage/__init__.py +10 -2
- sqlspec/storage/backends/fsspec.py +82 -35
- sqlspec/storage/backends/obstore.py +66 -49
- sqlspec/storage/capabilities.py +101 -0
- sqlspec/storage/registry.py +56 -83
- sqlspec/typing.py +6 -434
- sqlspec/utils/cached_property.py +25 -0
- sqlspec/utils/correlation.py +0 -2
- sqlspec/utils/logging.py +0 -6
- sqlspec/utils/sync_tools.py +0 -4
- sqlspec/utils/text.py +0 -5
- sqlspec/utils/type_guards.py +892 -0
- {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/METADATA +1 -1
- sqlspec-0.13.0.dist-info/RECORD +150 -0
- sqlspec/statement/builder/protocols.py +0 -20
- sqlspec/statement/pipelines/base.py +0 -315
- sqlspec/statement/pipelines/result_types.py +0 -41
- sqlspec/statement/pipelines/transformers/_remove_comments.py +0 -66
- sqlspec/statement/pipelines/transformers/_remove_hints.py +0 -81
- sqlspec/statement/pipelines/validators/base.py +0 -67
- sqlspec/storage/protocol.py +0 -170
- sqlspec-0.12.1.dist-info/RECORD +0 -145
- {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/licenses/NOTICE +0 -0
|
@@ -12,56 +12,199 @@ Key Components:
|
|
|
12
12
|
- `ValidationError`: Represents a single issue found during validation.
|
|
13
13
|
"""
|
|
14
14
|
|
|
15
|
-
from
|
|
16
|
-
from
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
from
|
|
15
|
+
from dataclasses import dataclass
|
|
16
|
+
from typing import TYPE_CHECKING, Optional
|
|
17
|
+
|
|
18
|
+
import sqlglot
|
|
19
|
+
from sqlglot import exp
|
|
20
|
+
from typing_extensions import TypeVar
|
|
21
|
+
|
|
22
|
+
from sqlspec.exceptions import RiskLevel
|
|
23
|
+
from sqlspec.protocols import ProcessorProtocol
|
|
24
|
+
from sqlspec.statement.pipelines.context import (
|
|
25
|
+
AnalysisFinding,
|
|
26
|
+
SQLProcessingContext,
|
|
27
|
+
TransformationLog,
|
|
28
|
+
ValidationError,
|
|
29
|
+
)
|
|
20
30
|
from sqlspec.statement.pipelines.transformers import (
|
|
21
|
-
|
|
31
|
+
CommentAndHintRemover,
|
|
22
32
|
ExpressionSimplifier,
|
|
23
|
-
HintRemover,
|
|
24
33
|
ParameterizeLiterals,
|
|
25
34
|
SimplificationConfig,
|
|
26
35
|
)
|
|
27
36
|
from sqlspec.statement.pipelines.validators import (
|
|
28
37
|
DMLSafetyConfig,
|
|
29
38
|
DMLSafetyValidator,
|
|
39
|
+
ParameterStyleValidator,
|
|
30
40
|
PerformanceConfig,
|
|
31
41
|
PerformanceValidator,
|
|
42
|
+
SecurityValidator,
|
|
32
43
|
SecurityValidatorConfig,
|
|
33
44
|
)
|
|
45
|
+
from sqlspec.utils.correlation import CorrelationContext
|
|
46
|
+
from sqlspec.utils.logging import get_logger
|
|
47
|
+
|
|
48
|
+
if TYPE_CHECKING:
|
|
49
|
+
from sqlspec.statement.parameters import ParameterInfo
|
|
50
|
+
from sqlspec.typing import SQLParameterType
|
|
51
|
+
|
|
34
52
|
|
|
35
53
|
__all__ = (
|
|
36
|
-
# New Result Types
|
|
37
54
|
"AnalysisFinding",
|
|
38
|
-
|
|
39
|
-
"CommentRemover",
|
|
40
|
-
# Concrete Validators
|
|
55
|
+
"CommentAndHintRemover",
|
|
41
56
|
"DMLSafetyConfig",
|
|
42
57
|
"DMLSafetyValidator",
|
|
43
58
|
"ExpressionSimplifier",
|
|
44
|
-
"
|
|
59
|
+
"ParameterStyleValidator",
|
|
45
60
|
"ParameterizeLiterals",
|
|
46
61
|
"PerformanceConfig",
|
|
47
62
|
"PerformanceValidator",
|
|
48
|
-
# Core Pipeline Components
|
|
49
63
|
"PipelineResult",
|
|
50
64
|
"ProcessorProtocol",
|
|
51
65
|
"SQLProcessingContext",
|
|
52
|
-
|
|
53
|
-
"SQLValidator",
|
|
66
|
+
"SecurityValidator",
|
|
54
67
|
"SecurityValidatorConfig",
|
|
55
68
|
"SimplificationConfig",
|
|
56
|
-
# Concrete Analyzers
|
|
57
|
-
"StatementAnalysis",
|
|
58
|
-
"StatementAnalyzer",
|
|
59
|
-
# Core Pipeline & Context
|
|
60
69
|
"StatementPipeline",
|
|
61
70
|
"TransformationLog",
|
|
62
71
|
"ValidationError",
|
|
63
|
-
# Module exports
|
|
64
|
-
"analyzers",
|
|
65
|
-
"transformers",
|
|
66
|
-
"validators",
|
|
67
72
|
)
|
|
73
|
+
|
|
74
|
+
logger = get_logger("pipelines")
|
|
75
|
+
|
|
76
|
+
ExpressionT = TypeVar("ExpressionT", bound="exp.Expression")
|
|
77
|
+
ResultT = TypeVar("ResultT")
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
# Import from context module to avoid duplication
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
@dataclass
class PipelineResult:
    """Final result of pipeline execution.

    Pairs the final expression with the processing context. Every property
    below is a read-only view that forwards to the attribute of the same
    name on ``context``.
    """

    # The SQL expression after all transformations.
    expression: "exp.Expression"
    # Contains all collected results.
    context: "SQLProcessingContext"

    @property
    def validation_errors(self) -> "list[ValidationError]":
        """Get validation errors from context."""
        return self.context.validation_errors

    @property
    def has_errors(self) -> bool:
        """Check if any validation errors exist."""
        return self.context.has_errors

    @property
    def risk_level(self) -> "RiskLevel":
        """Get overall risk level."""
        return self.context.risk_level

    @property
    def merged_parameters(self) -> "SQLParameterType":
        """Get merged parameters from context."""
        return self.context.merged_parameters

    @property
    def parameter_info(self) -> "list[ParameterInfo]":
        """Get parameter info from context."""
        return self.context.parameter_info
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
class StatementPipeline:
    """Orchestrates the processing of an SQL expression through transformers, validators, and analyzers."""

    def __init__(
        self,
        transformers: "Optional[list[ProcessorProtocol]]" = None,
        validators: "Optional[list[ProcessorProtocol]]" = None,
        analyzers: "Optional[list[ProcessorProtocol]]" = None,
    ) -> None:
        """Store the processors for each stage.

        Args:
            transformers: Processors whose return value replaces the current expression.
            validators: Processors run after the transformers; their return value is ignored.
            analyzers: Processors run last; their return value is ignored.

        A stage given as ``None`` (or any falsy value) is stored as a new empty list.
        """
        self.transformers = transformers or []
        self.validators = validators or []
        self.analyzers = analyzers or []

    def _run_processors(
        self,
        processors: "list[ProcessorProtocol]",
        context: "SQLProcessingContext",
        processor_type: str,
        enable_flag: bool,
        error_risk_level: "RiskLevel",
    ) -> None:
        """Run one stage of processors against ``context.current_expression``.

        Args:
            processors: Processors to call, in order.
            context: Processing context; read for the current expression and
                config, and appended to when a processor fails.
            processor_type: Stage name. ``"transformer"`` makes each return
                value the new current expression and stops the stage on the
                first failure; any other value only calls the processors.
            enable_flag: When false the stage is skipped entirely.
            error_risk_level: Risk level attached to the ``ValidationError``
                recorded when a processor raises.

        Raises:
            MissingParameterError, MixedParameterStyleError,
            UnsupportedParameterStyleError: Re-raised unchanged when
                ``context.config.strict_mode`` is truthy. Every other
                exception is recorded on the context and logged instead.
        """
        if not enable_flag or context.current_expression is None:
            return

        # Both values depend only on processor_type, so compute them once.
        is_transformer = processor_type == "transformer"
        stage_label = processor_type.capitalize()

        for processor in processors:
            processor_name = processor.__class__.__name__
            try:
                outcome = processor.process(context.current_expression, context)
                if is_transformer:
                    context.current_expression = outcome
            except Exception as exc:
                # Function-scope imports: executed only when a processor raises.
                from sqlspec.exceptions import MissingParameterError
                from sqlspec.statement.pipelines.validators._parameter_style import (
                    MixedParameterStyleError,
                    UnsupportedParameterStyleError,
                )

                # In strict mode, re-raise validation exceptions
                strict_errors = (MissingParameterError, MixedParameterStyleError, UnsupportedParameterStyleError)
                if context.config.strict_mode and isinstance(exc, strict_errors):
                    raise

                context.validation_errors.append(
                    ValidationError(
                        message=f"{stage_label} {processor_name} failed: {exc}",
                        code=f"{processor_type}-failure",
                        risk_level=error_risk_level,
                        processor=processor_name,
                        expression=context.current_expression,
                    )
                )
                logger.exception("%s %s failed", stage_label, processor_name)
                if is_transformer:
                    break  # Stop further transformations if one fails

    def execute_pipeline(self, context: "SQLProcessingContext") -> "PipelineResult":
        """Executes the full pipeline (transform, validate, analyze) using the SQLProcessingContext.

        If the context has no current expression, ``context.initial_sql_string``
        is parsed with ``sqlglot.parse_one`` using ``context.dialect``. A parse
        failure is recorded as a ``"parsing-error"`` validation error with
        ``RiskLevel.CRITICAL`` and the pipeline returns immediately with an
        empty ``exp.Select()``; no processors run in that case.

        Returns:
            A ``PipelineResult`` wrapping the context and the final expression,
            or an empty ``exp.Select()`` when the final expression is falsy.
        """
        # NOTE(review): the return value is discarded; kept as-is because any
        # side effect of this call is not visible from here — confirm and drop if none.
        CorrelationContext.get()

        if context.current_expression is None:
            try:
                context.current_expression = sqlglot.parse_one(context.initial_sql_string, dialect=context.dialect)
            except Exception as exc:
                context.validation_errors.append(
                    ValidationError(
                        message=f"SQL Parsing Error: {exc}",
                        code="parsing-error",
                        risk_level=RiskLevel.CRITICAL,
                        processor="StatementPipeline",
                        expression=None,
                    )
                )
                return PipelineResult(expression=exp.Select(), context=context)

        # Stage order matters: transform first, then validate, then analyze.
        # Analyzer failures are recorded at a lower risk level than the others.
        stages = (
            (self.transformers, "transformer", RiskLevel.CRITICAL),
            (self.validators, "validator", RiskLevel.CRITICAL),
            (self.analyzers, "analyzer", RiskLevel.MEDIUM),
        )
        for stage_processors, stage_name, stage_risk in stages:
            if stage_processors:
                self._run_processors(
                    stage_processors, context, stage_name, enable_flag=True, error_risk_level=stage_risk
                )

        return PipelineResult(expression=context.current_expression or exp.Select(), context=context)
|
|
@@ -7,10 +7,11 @@ from typing import TYPE_CHECKING, Any, Optional
|
|
|
7
7
|
from sqlglot import exp, parse_one
|
|
8
8
|
from sqlglot.errors import ParseError as SQLGlotParseError
|
|
9
9
|
|
|
10
|
-
from sqlspec.
|
|
11
|
-
from sqlspec.statement.pipelines.
|
|
10
|
+
from sqlspec.protocols import ProcessorProtocol
|
|
11
|
+
from sqlspec.statement.pipelines.context import AnalysisFinding
|
|
12
12
|
from sqlspec.utils.correlation import CorrelationContext
|
|
13
13
|
from sqlspec.utils.logging import get_logger
|
|
14
|
+
from sqlspec.utils.type_guards import has_expressions
|
|
14
15
|
|
|
15
16
|
if TYPE_CHECKING:
|
|
16
17
|
from sqlglot.dialects.dialect import DialectType
|
|
@@ -146,7 +147,6 @@ class StatementAnalyzer(ProcessorProtocol):
|
|
|
146
147
|
|
|
147
148
|
duration = time.perf_counter() - start_time
|
|
148
149
|
|
|
149
|
-
# Add analysis findings to context
|
|
150
150
|
if analysis_result_obj.complexity_warnings:
|
|
151
151
|
for warning in analysis_result_obj.complexity_warnings:
|
|
152
152
|
finding = AnalysisFinding(key="complexity_warning", value=warning, processor=self.__class__.__name__)
|
|
@@ -194,7 +194,6 @@ class StatementAnalyzer(ProcessorProtocol):
|
|
|
194
194
|
if expr is None:
|
|
195
195
|
expr = parse_one(sql_string, dialect=dialect)
|
|
196
196
|
|
|
197
|
-
# Check if the parsed expression is a valid SQL statement type
|
|
198
197
|
# Simple expressions like Alias or Identifier are not valid SQL statements
|
|
199
198
|
valid_statement_types = (
|
|
200
199
|
exp.Select,
|
|
@@ -230,7 +229,6 @@ class StatementAnalyzer(ProcessorProtocol):
|
|
|
230
229
|
self, expression: exp.Expression, dialect: "DialectType" = None, config: "Optional[SQLConfig]" = None
|
|
231
230
|
) -> StatementAnalysis:
|
|
232
231
|
"""Analyze a SQLGlot expression directly, potentially using validation results for context."""
|
|
233
|
-
# Check cache first (using expression.sql() as key)
|
|
234
232
|
# This caching needs to be context-aware if analysis depends on prior steps (e.g. validation_result)
|
|
235
233
|
# For simplicity, let's assume for now direct expression analysis is cacheable if validation_result is not used deeply.
|
|
236
234
|
cache_key = expression.sql() # Simplified cache key
|
|
@@ -291,7 +289,7 @@ class StatementAnalyzer(ProcessorProtocol):
|
|
|
291
289
|
|
|
292
290
|
for select in expression.find_all(exp.Select):
|
|
293
291
|
from_clause = select.args.get("from")
|
|
294
|
-
if from_clause and
|
|
292
|
+
if from_clause and has_expressions(from_clause) and len(from_clause.expressions) > 1:
|
|
295
293
|
# This logic checks for multiple tables in FROM without explicit JOINs
|
|
296
294
|
# It's a simplified check for potential cartesian products
|
|
297
295
|
cartesian_products += 1
|
|
@@ -324,11 +322,7 @@ class StatementAnalyzer(ProcessorProtocol):
|
|
|
324
322
|
def _analyze_subqueries(self, expression: exp.Expression, analysis: StatementAnalysis) -> None:
|
|
325
323
|
"""Analyze subquery complexity and nesting depth."""
|
|
326
324
|
subqueries: list[exp.Expression] = list(expression.find_all(exp.Subquery))
|
|
327
|
-
subqueries
|
|
328
|
-
query
|
|
329
|
-
for in_clause in expression.find_all(exp.In)
|
|
330
|
-
if (query := in_clause.args.get("query")) and isinstance(query, exp.Select)
|
|
331
|
-
)
|
|
325
|
+
# Workaround for EXISTS clauses: sqlglot doesn't wrap EXISTS subqueries in Subquery nodes
|
|
332
326
|
subqueries.extend(
|
|
333
327
|
[
|
|
334
328
|
exists_clause.this
|
|
@@ -346,7 +340,6 @@ class StatementAnalyzer(ProcessorProtocol):
|
|
|
346
340
|
"""Calculate the maximum depth of nested SELECT statements."""
|
|
347
341
|
max_depth = 0
|
|
348
342
|
|
|
349
|
-
# Find all SELECT statements
|
|
350
343
|
select_statements = list(expr.find_all(exp.Select))
|
|
351
344
|
|
|
352
345
|
for select in select_statements:
|
|
@@ -354,7 +347,6 @@ class StatementAnalyzer(ProcessorProtocol):
|
|
|
354
347
|
depth = 0
|
|
355
348
|
current = select.parent
|
|
356
349
|
while current:
|
|
357
|
-
# Check if parent is a SELECT or if it's inside a SELECT via Subquery/IN/EXISTS
|
|
358
350
|
if isinstance(current, exp.Select):
|
|
359
351
|
depth += 1
|
|
360
352
|
elif isinstance(current, (exp.Subquery, exp.In, exp.Exists)):
|
|
@@ -481,18 +473,21 @@ class StatementAnalyzer(ProcessorProtocol):
|
|
|
481
473
|
def _extract_primary_table_name(expr: exp.Expression) -> "Optional[str]":
|
|
482
474
|
"""Extract the primary table name from an expression."""
|
|
483
475
|
if isinstance(expr, exp.Insert):
|
|
484
|
-
if expr.this
|
|
485
|
-
# Handle schema.table cases
|
|
476
|
+
if expr.this:
|
|
486
477
|
table = expr.this
|
|
487
478
|
if isinstance(table, exp.Table):
|
|
488
479
|
return table.name
|
|
489
|
-
if
|
|
480
|
+
if isinstance(table, (exp.Identifier, exp.Var)):
|
|
490
481
|
return str(table.name)
|
|
491
482
|
elif isinstance(expr, (exp.Update, exp.Delete)):
|
|
492
483
|
if expr.this:
|
|
493
|
-
|
|
484
|
+
if isinstance(expr.this, (exp.Table, exp.Identifier, exp.Var)):
|
|
485
|
+
return str(expr.this.name)
|
|
486
|
+
return str(expr.this)
|
|
494
487
|
elif isinstance(expr, exp.Select) and (from_clause := expr.find(exp.From)) and from_clause.this:
|
|
495
|
-
|
|
488
|
+
if isinstance(from_clause.this, (exp.Table, exp.Identifier, exp.Var)):
|
|
489
|
+
return str(from_clause.this.name)
|
|
490
|
+
return str(from_clause.this)
|
|
496
491
|
return None
|
|
497
492
|
|
|
498
493
|
@staticmethod
|
|
@@ -500,16 +495,19 @@ class StatementAnalyzer(ProcessorProtocol):
|
|
|
500
495
|
"""Extract column names from an expression."""
|
|
501
496
|
columns: list[str] = []
|
|
502
497
|
if isinstance(expr, exp.Insert):
|
|
503
|
-
if expr.this and
|
|
504
|
-
columns.extend(
|
|
498
|
+
if expr.this and has_expressions(expr.this):
|
|
499
|
+
columns.extend(
|
|
500
|
+
str(col_expr.name)
|
|
501
|
+
for col_expr in expr.this.expressions
|
|
502
|
+
if isinstance(col_expr, (exp.Column, exp.Identifier, exp.Var))
|
|
503
|
+
)
|
|
505
504
|
elif isinstance(expr, exp.Select):
|
|
506
|
-
# Extract selected columns
|
|
507
505
|
for projection in expr.expressions:
|
|
508
506
|
if isinstance(projection, exp.Column):
|
|
509
507
|
columns.append(str(projection.name))
|
|
510
|
-
elif
|
|
508
|
+
elif isinstance(projection, exp.Alias) and projection.alias:
|
|
511
509
|
columns.append(str(projection.alias))
|
|
512
|
-
elif
|
|
510
|
+
elif isinstance(projection, (exp.Identifier, exp.Var)):
|
|
513
511
|
columns.append(str(projection.name))
|
|
514
512
|
|
|
515
513
|
return columns
|
|
@@ -519,7 +517,7 @@ class StatementAnalyzer(ProcessorProtocol):
|
|
|
519
517
|
"""Extract all table names referenced in the expression."""
|
|
520
518
|
tables: list[str] = []
|
|
521
519
|
for table in expr.find_all(exp.Table):
|
|
522
|
-
if
|
|
520
|
+
if isinstance(table, exp.Table):
|
|
523
521
|
table_name = str(table.name)
|
|
524
522
|
if table_name not in tables:
|
|
525
523
|
tables.append(table_name)
|
|
@@ -567,7 +565,6 @@ class StatementAnalyzer(ProcessorProtocol):
|
|
|
567
565
|
# but exclude those within CTEs
|
|
568
566
|
select_statements = []
|
|
569
567
|
for select in expr.find_all(exp.Select):
|
|
570
|
-
# Check if this SELECT is inside a CTE
|
|
571
568
|
parent = select.parent
|
|
572
569
|
is_in_cte = False
|
|
573
570
|
while parent:
|
|
@@ -4,16 +4,45 @@ from typing import TYPE_CHECKING, Any, Optional
|
|
|
4
4
|
from sqlglot import exp
|
|
5
5
|
|
|
6
6
|
from sqlspec.exceptions import RiskLevel
|
|
7
|
-
from sqlspec.statement.pipelines.result_types import AnalysisFinding, TransformationLog, ValidationError
|
|
8
7
|
|
|
9
8
|
if TYPE_CHECKING:
|
|
10
9
|
from sqlglot.dialects.dialect import DialectType
|
|
11
10
|
|
|
12
|
-
from sqlspec.statement.parameters import ParameterInfo
|
|
11
|
+
from sqlspec.statement.parameters import ParameterInfo, ParameterNormalizationState
|
|
13
12
|
from sqlspec.statement.sql import SQLConfig
|
|
14
13
|
from sqlspec.typing import SQLParameterType
|
|
15
14
|
|
|
16
|
-
__all__ = ("
|
|
15
|
+
__all__ = ("AnalysisFinding", "SQLProcessingContext", "TransformationLog", "ValidationError")
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass
class ValidationError:
    """A specific validation issue found during processing."""

    # Human-readable description of the issue.
    message: str
    # Short machine-readable identifier, e.g., "risky-delete", "missing-where".
    code: str
    # Risk level assigned to this issue.
    risk_level: "RiskLevel"
    # Which processor found it.
    processor: str
    # Problematic sub-expression, when one is available.
    expression: "Optional[exp.Expression]" = None
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@dataclass
class TransformationLog:
    """Record of a transformation applied."""

    # What the transformation did.
    description: str
    # Name of the processor that applied it.
    processor: str
    # SQL before transform.
    before: Optional[str] = None
    # SQL after transform.
    after: Optional[str] = None
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@dataclass
class AnalysisFinding:
    """Metadata discovered during analysis."""

    # Name of the finding, e.g., "complexity_score", "table_count".
    key: str
    # Value recorded for that key; no type is imposed.
    value: Any
    # Name of the processor that produced the finding.
    processor: str
|
|
17
46
|
|
|
18
47
|
|
|
19
48
|
@dataclass
|
|
@@ -70,6 +99,9 @@ class SQLProcessingContext:
|
|
|
70
99
|
extra_info: dict[str, Any] = field(default_factory=dict)
|
|
71
100
|
"""Extra information from parameter processing, including normalization state."""
|
|
72
101
|
|
|
102
|
+
parameter_normalization: "Optional[ParameterNormalizationState]" = None
|
|
103
|
+
"""Single source of truth for parameter normalization tracking."""
|
|
104
|
+
|
|
73
105
|
@property
|
|
74
106
|
def has_errors(self) -> bool:
|
|
75
107
|
"""Check if any validation errors exist."""
|
|
@@ -81,39 +113,3 @@ class SQLProcessingContext:
|
|
|
81
113
|
if not self.validation_errors:
|
|
82
114
|
return RiskLevel.SAFE
|
|
83
115
|
return max(error.risk_level for error in self.validation_errors)
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
@dataclass
|
|
87
|
-
class PipelineResult:
|
|
88
|
-
"""Final result of pipeline execution."""
|
|
89
|
-
|
|
90
|
-
expression: exp.Expression
|
|
91
|
-
"""The SQL expression after all transformations."""
|
|
92
|
-
|
|
93
|
-
context: SQLProcessingContext
|
|
94
|
-
"""Contains all collected results."""
|
|
95
|
-
|
|
96
|
-
@property
|
|
97
|
-
def validation_errors(self) -> list[ValidationError]:
|
|
98
|
-
"""Get validation errors from context."""
|
|
99
|
-
return self.context.validation_errors
|
|
100
|
-
|
|
101
|
-
@property
|
|
102
|
-
def has_errors(self) -> bool:
|
|
103
|
-
"""Check if any validation errors exist."""
|
|
104
|
-
return self.context.has_errors
|
|
105
|
-
|
|
106
|
-
@property
|
|
107
|
-
def risk_level(self) -> RiskLevel:
|
|
108
|
-
"""Get overall risk level."""
|
|
109
|
-
return self.context.risk_level
|
|
110
|
-
|
|
111
|
-
@property
|
|
112
|
-
def merged_parameters(self) -> "SQLParameterType":
|
|
113
|
-
"""Get merged parameters from context."""
|
|
114
|
-
return self.context.merged_parameters
|
|
115
|
-
|
|
116
|
-
@property
|
|
117
|
-
def parameter_info(self) -> "list[ParameterInfo]":
|
|
118
|
-
"""Get parameter info from context."""
|
|
119
|
-
return self.context.parameter_info
|
|
@@ -2,7 +2,6 @@
|
|
|
2
2
|
|
|
3
3
|
from sqlspec.statement.pipelines.transformers._expression_simplifier import ExpressionSimplifier, SimplificationConfig
|
|
4
4
|
from sqlspec.statement.pipelines.transformers._literal_parameterizer import ParameterizeLiterals
|
|
5
|
-
from sqlspec.statement.pipelines.transformers.
|
|
6
|
-
from sqlspec.statement.pipelines.transformers._remove_hints import HintRemover
|
|
5
|
+
from sqlspec.statement.pipelines.transformers._remove_comments_and_hints import CommentAndHintRemover
|
|
7
6
|
|
|
8
|
-
__all__ = ("
|
|
7
|
+
__all__ = ("CommentAndHintRemover", "ExpressionSimplifier", "ParameterizeLiterals", "SimplificationConfig")
|