sqlspec 0.11.1__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/__init__.py +16 -3
- sqlspec/_serialization.py +3 -10
- sqlspec/_sql.py +1147 -0
- sqlspec/_typing.py +343 -41
- sqlspec/adapters/adbc/__init__.py +2 -6
- sqlspec/adapters/adbc/config.py +474 -149
- sqlspec/adapters/adbc/driver.py +330 -621
- sqlspec/adapters/aiosqlite/__init__.py +2 -6
- sqlspec/adapters/aiosqlite/config.py +143 -57
- sqlspec/adapters/aiosqlite/driver.py +269 -431
- sqlspec/adapters/asyncmy/__init__.py +3 -8
- sqlspec/adapters/asyncmy/config.py +247 -202
- sqlspec/adapters/asyncmy/driver.py +218 -436
- sqlspec/adapters/asyncpg/__init__.py +4 -7
- sqlspec/adapters/asyncpg/config.py +329 -176
- sqlspec/adapters/asyncpg/driver.py +417 -487
- sqlspec/adapters/bigquery/__init__.py +2 -2
- sqlspec/adapters/bigquery/config.py +407 -0
- sqlspec/adapters/bigquery/driver.py +600 -553
- sqlspec/adapters/duckdb/__init__.py +4 -1
- sqlspec/adapters/duckdb/config.py +432 -321
- sqlspec/adapters/duckdb/driver.py +392 -406
- sqlspec/adapters/oracledb/__init__.py +3 -8
- sqlspec/adapters/oracledb/config.py +625 -0
- sqlspec/adapters/oracledb/driver.py +548 -921
- sqlspec/adapters/psqlpy/__init__.py +4 -7
- sqlspec/adapters/psqlpy/config.py +372 -203
- sqlspec/adapters/psqlpy/driver.py +197 -533
- sqlspec/adapters/psycopg/__init__.py +3 -8
- sqlspec/adapters/psycopg/config.py +741 -0
- sqlspec/adapters/psycopg/driver.py +734 -694
- sqlspec/adapters/sqlite/__init__.py +2 -6
- sqlspec/adapters/sqlite/config.py +146 -81
- sqlspec/adapters/sqlite/driver.py +242 -405
- sqlspec/base.py +220 -784
- sqlspec/config.py +354 -0
- sqlspec/driver/__init__.py +22 -0
- sqlspec/driver/_async.py +252 -0
- sqlspec/driver/_common.py +338 -0
- sqlspec/driver/_sync.py +261 -0
- sqlspec/driver/mixins/__init__.py +17 -0
- sqlspec/driver/mixins/_pipeline.py +523 -0
- sqlspec/driver/mixins/_result_utils.py +122 -0
- sqlspec/driver/mixins/_sql_translator.py +35 -0
- sqlspec/driver/mixins/_storage.py +993 -0
- sqlspec/driver/mixins/_type_coercion.py +131 -0
- sqlspec/exceptions.py +299 -7
- sqlspec/extensions/aiosql/__init__.py +10 -0
- sqlspec/extensions/aiosql/adapter.py +474 -0
- sqlspec/extensions/litestar/__init__.py +1 -6
- sqlspec/extensions/litestar/_utils.py +1 -5
- sqlspec/extensions/litestar/config.py +5 -6
- sqlspec/extensions/litestar/handlers.py +13 -12
- sqlspec/extensions/litestar/plugin.py +22 -24
- sqlspec/extensions/litestar/providers.py +37 -55
- sqlspec/loader.py +528 -0
- sqlspec/service/__init__.py +3 -0
- sqlspec/service/base.py +24 -0
- sqlspec/service/pagination.py +26 -0
- sqlspec/statement/__init__.py +21 -0
- sqlspec/statement/builder/__init__.py +54 -0
- sqlspec/statement/builder/_ddl_utils.py +119 -0
- sqlspec/statement/builder/_parsing_utils.py +135 -0
- sqlspec/statement/builder/base.py +328 -0
- sqlspec/statement/builder/ddl.py +1379 -0
- sqlspec/statement/builder/delete.py +80 -0
- sqlspec/statement/builder/insert.py +274 -0
- sqlspec/statement/builder/merge.py +95 -0
- sqlspec/statement/builder/mixins/__init__.py +65 -0
- sqlspec/statement/builder/mixins/_aggregate_functions.py +151 -0
- sqlspec/statement/builder/mixins/_case_builder.py +91 -0
- sqlspec/statement/builder/mixins/_common_table_expr.py +91 -0
- sqlspec/statement/builder/mixins/_delete_from.py +34 -0
- sqlspec/statement/builder/mixins/_from.py +61 -0
- sqlspec/statement/builder/mixins/_group_by.py +119 -0
- sqlspec/statement/builder/mixins/_having.py +35 -0
- sqlspec/statement/builder/mixins/_insert_from_select.py +48 -0
- sqlspec/statement/builder/mixins/_insert_into.py +36 -0
- sqlspec/statement/builder/mixins/_insert_values.py +69 -0
- sqlspec/statement/builder/mixins/_join.py +110 -0
- sqlspec/statement/builder/mixins/_limit_offset.py +53 -0
- sqlspec/statement/builder/mixins/_merge_clauses.py +405 -0
- sqlspec/statement/builder/mixins/_order_by.py +46 -0
- sqlspec/statement/builder/mixins/_pivot.py +82 -0
- sqlspec/statement/builder/mixins/_returning.py +37 -0
- sqlspec/statement/builder/mixins/_select_columns.py +60 -0
- sqlspec/statement/builder/mixins/_set_ops.py +122 -0
- sqlspec/statement/builder/mixins/_unpivot.py +80 -0
- sqlspec/statement/builder/mixins/_update_from.py +54 -0
- sqlspec/statement/builder/mixins/_update_set.py +91 -0
- sqlspec/statement/builder/mixins/_update_table.py +29 -0
- sqlspec/statement/builder/mixins/_where.py +374 -0
- sqlspec/statement/builder/mixins/_window_functions.py +86 -0
- sqlspec/statement/builder/protocols.py +20 -0
- sqlspec/statement/builder/select.py +206 -0
- sqlspec/statement/builder/update.py +178 -0
- sqlspec/statement/filters.py +571 -0
- sqlspec/statement/parameters.py +736 -0
- sqlspec/statement/pipelines/__init__.py +67 -0
- sqlspec/statement/pipelines/analyzers/__init__.py +9 -0
- sqlspec/statement/pipelines/analyzers/_analyzer.py +649 -0
- sqlspec/statement/pipelines/base.py +315 -0
- sqlspec/statement/pipelines/context.py +119 -0
- sqlspec/statement/pipelines/result_types.py +41 -0
- sqlspec/statement/pipelines/transformers/__init__.py +8 -0
- sqlspec/statement/pipelines/transformers/_expression_simplifier.py +256 -0
- sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +623 -0
- sqlspec/statement/pipelines/transformers/_remove_comments.py +66 -0
- sqlspec/statement/pipelines/transformers/_remove_hints.py +81 -0
- sqlspec/statement/pipelines/validators/__init__.py +23 -0
- sqlspec/statement/pipelines/validators/_dml_safety.py +275 -0
- sqlspec/statement/pipelines/validators/_parameter_style.py +297 -0
- sqlspec/statement/pipelines/validators/_performance.py +703 -0
- sqlspec/statement/pipelines/validators/_security.py +990 -0
- sqlspec/statement/pipelines/validators/base.py +67 -0
- sqlspec/statement/result.py +527 -0
- sqlspec/statement/splitter.py +701 -0
- sqlspec/statement/sql.py +1198 -0
- sqlspec/storage/__init__.py +15 -0
- sqlspec/storage/backends/__init__.py +0 -0
- sqlspec/storage/backends/base.py +166 -0
- sqlspec/storage/backends/fsspec.py +315 -0
- sqlspec/storage/backends/obstore.py +464 -0
- sqlspec/storage/protocol.py +170 -0
- sqlspec/storage/registry.py +315 -0
- sqlspec/typing.py +157 -36
- sqlspec/utils/correlation.py +155 -0
- sqlspec/utils/deprecation.py +3 -6
- sqlspec/utils/fixtures.py +6 -11
- sqlspec/utils/logging.py +135 -0
- sqlspec/utils/module_loader.py +45 -43
- sqlspec/utils/serializers.py +4 -0
- sqlspec/utils/singleton.py +6 -8
- sqlspec/utils/sync_tools.py +15 -27
- sqlspec/utils/text.py +58 -26
- {sqlspec-0.11.1.dist-info → sqlspec-0.12.0.dist-info}/METADATA +97 -26
- sqlspec-0.12.0.dist-info/RECORD +145 -0
- sqlspec/adapters/bigquery/config/__init__.py +0 -3
- sqlspec/adapters/bigquery/config/_common.py +0 -40
- sqlspec/adapters/bigquery/config/_sync.py +0 -87
- sqlspec/adapters/oracledb/config/__init__.py +0 -9
- sqlspec/adapters/oracledb/config/_asyncio.py +0 -186
- sqlspec/adapters/oracledb/config/_common.py +0 -131
- sqlspec/adapters/oracledb/config/_sync.py +0 -186
- sqlspec/adapters/psycopg/config/__init__.py +0 -19
- sqlspec/adapters/psycopg/config/_async.py +0 -169
- sqlspec/adapters/psycopg/config/_common.py +0 -56
- sqlspec/adapters/psycopg/config/_sync.py +0 -168
- sqlspec/filters.py +0 -331
- sqlspec/mixins.py +0 -305
- sqlspec/statement.py +0 -378
- sqlspec-0.11.1.dist-info/RECORD +0 -69
- {sqlspec-0.11.1.dist-info → sqlspec-0.12.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.11.1.dist-info → sqlspec-0.12.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.11.1.dist-info → sqlspec-0.12.0.dist-info}/licenses/NOTICE +0 -0
|
@@ -0,0 +1,315 @@
|
|
|
1
|
+
"""SQL Processing Pipeline Base.
|
|
2
|
+
|
|
3
|
+
This module defines the core framework for constructing and executing a series of
|
|
4
|
+
SQL processing steps, such as transformations and validations.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import contextlib
|
|
8
|
+
from abc import ABC, abstractmethod
|
|
9
|
+
from typing import TYPE_CHECKING, Optional
|
|
10
|
+
|
|
11
|
+
import sqlglot # Added
|
|
12
|
+
from sqlglot import exp
|
|
13
|
+
from sqlglot.errors import ParseError as SQLGlotParseError # Added
|
|
14
|
+
from typing_extensions import TypeVar
|
|
15
|
+
|
|
16
|
+
from sqlspec.exceptions import RiskLevel, SQLValidationError
|
|
17
|
+
from sqlspec.statement.pipelines.context import PipelineResult
|
|
18
|
+
from sqlspec.statement.pipelines.result_types import ValidationError
|
|
19
|
+
from sqlspec.utils.correlation import CorrelationContext
|
|
20
|
+
from sqlspec.utils.logging import get_logger
|
|
21
|
+
|
|
22
|
+
if TYPE_CHECKING:
|
|
23
|
+
from collections.abc import Sequence
|
|
24
|
+
|
|
25
|
+
from sqlglot.dialects.dialect import DialectType
|
|
26
|
+
|
|
27
|
+
from sqlspec.statement.pipelines.context import SQLProcessingContext
|
|
28
|
+
from sqlspec.statement.sql import SQLConfig, Statement
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
__all__ = ("ProcessorProtocol", "SQLValidator", "StatementPipeline", "UsesExpression")
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
logger = get_logger("pipelines")
|
|
35
|
+
|
|
36
|
+
ExpressionT = TypeVar("ExpressionT", bound="exp.Expression")
|
|
37
|
+
ResultT = TypeVar("ResultT")
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
# Copied UsesExpression class here
|
|
41
|
+
class UsesExpression:
    """Utility mixin class to get a sqlglot expression from various inputs."""

    @staticmethod
    def get_expression(statement: "Statement", dialect: "DialectType" = None) -> "exp.Expression":
        """Convert SQL input to a sqlglot expression.

        Resolution order:

        1. An ``exp.Expression`` is returned as-is.
        2. A ``SQL`` object yields its ``expression`` attribute when that is not ``None``;
           otherwise its ``sql`` text is parsed.
        3. Anything else is converted with ``str()``; empty or whitespace-only text yields
           an empty ``exp.Select()``, any other text is parsed.

        Args:
            statement: The SQL statement to convert to an expression.
            dialect: The SQL dialect handed to ``sqlglot.parse_one`` as ``read``.

        Raises:
            SQLValidationError: If sqlglot fails to parse the SQL text (raised with
                ``RiskLevel.HIGH`` and chained to the original parse error).

        Returns:
            An exp.Expression.
        """
        if isinstance(statement, exp.Expression):
            return statement

        # Local import to avoid circular dependency at module level
        from sqlspec.statement.sql import SQL

        if isinstance(statement, SQL):
            expr = statement.expression
            if expr is not None:
                return expr
            # No expression on the SQL object: fall through to the guarded parse below so a
            # parse failure surfaces as SQLValidationError, exactly like the plain-string path.
            sql_str = statement.sql
        else:
            # Assuming statement is str hereafter
            sql_str = str(statement)
            if not sql_str or not sql_str.strip():
                return exp.Select()

        try:
            return sqlglot.parse_one(sql_str, read=dialect)
        except SQLGlotParseError as e:
            msg = f"SQL parsing failed: {e}"
            raise SQLValidationError(msg, sql_str, RiskLevel.HIGH) from e
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class ProcessorProtocol(ABC):
    """Abstract base for one step of the SQL pipeline.

    Concrete processors implement :meth:`process`; the class cannot be
    instantiated until that method is overridden.
    """

    @abstractmethod
    def process(
        self, expression: "Optional[exp.Expression]", context: "SQLProcessingContext"
    ) -> "Optional[exp.Expression]":
        """Run this step against ``expression``.

        Args:
            expression: The SQL expression to process (may be ``None``).
            context: The SQLProcessingContext carrying the current state and config.

        Returns:
            For transformers, the (possibly modified) SQL expression; validators and
            analyzers may return ``None``.
        """
        raise NotImplementedError
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
class StatementPipeline:
    """Orchestrates the processing of an SQL expression through transformers, validators, and analyzers.

    Stages run in that order. An ``Exception`` raised while parsing or by a processor is
    caught and recorded as a ``ValidationError`` on ``context.validation_errors``;
    processor failures are additionally logged with a traceback.
    """

    def __init__(
        self,
        transformers: Optional[list[ProcessorProtocol]] = None,
        validators: Optional[list[ProcessorProtocol]] = None,
        analyzers: Optional[list[ProcessorProtocol]] = None,
    ) -> None:
        """Store the processors for each stage; a missing or empty list becomes ``[]``."""
        self.transformers = transformers or []
        self.validators = validators or []
        self.analyzers = analyzers or []

    def execute_pipeline(self, context: "SQLProcessingContext") -> "PipelineResult":
        """Executes the full pipeline (transform, validate, analyze) using the SQLProcessingContext.

        Args:
            context: Processing state. ``current_expression`` is parsed from
                ``initial_sql_string`` when it is ``None`` and parsing is enabled.

        Returns:
            A PipelineResult wrapping ``context`` and the final expression. An empty
            ``exp.Select()`` stands in when no expression could be obtained.
        """
        # NOTE(review): the return value is discarded; kept because it is unclear from this
        # module whether the call is needed for a side effect - confirm before removing.
        CorrelationContext.get()

        if context.current_expression is None and not self._parse_initial_sql(context):
            return PipelineResult(expression=exp.Select(), context=context)

        # 1. Transformation Stage
        if context.config.enable_transformations:
            self._run_transformers(context)

        # 2. Validation Stage
        if context.config.enable_validation:
            self._run_validators(context)

        # 3. Analysis Stage
        if context.config.enable_analysis and context.current_expression is not None:
            self._run_analyzers(context)

        return PipelineResult(expression=context.current_expression or exp.Select(), context=context)

    @staticmethod
    def _record_failure(
        context: "SQLProcessingContext",
        *,
        message: str,
        code: str,
        risk_level: "RiskLevel",
        processor: str,
        expression: "Optional[exp.Expression]" = None,
    ) -> None:
        """Append a ValidationError built from the given fields to ``context.validation_errors``."""
        context.validation_errors.append(
            ValidationError(
                message=message, code=code, risk_level=risk_level, processor=processor, expression=expression
            )
        )

    def _parse_initial_sql(self, context: "SQLProcessingContext") -> bool:
        """Fill ``context.current_expression`` from the raw SQL; return False after recording an error."""
        if not context.config.enable_parsing:
            # Parsing disabled and no expression supplied: a configuration problem for the
            # pipeline. Callers (the original notes name SQL._initialize_statement) are
            # expected to prevent this; record a critical error rather than raise.
            self._record_failure(
                context,
                message="Pipeline executed without an initial expression and parsing disabled.",
                code="no-expression",
                risk_level=RiskLevel.CRITICAL,
                processor="StatementPipeline",
            )
            return False

        try:
            context.current_expression = sqlglot.parse_one(context.initial_sql_string, dialect=context.dialect)
        except Exception as e:
            self._record_failure(
                context,
                message=f"SQL Parsing Error: {e}",
                code="parsing-error",
                risk_level=RiskLevel.CRITICAL,
                processor="StatementPipeline",
            )
            return False
        return True

    def _run_transformers(self, context: "SQLProcessingContext") -> None:
        """Apply each transformer in order, stopping at the first one that raises."""
        for transformer in self.transformers:
            transformer_name = transformer.__class__.__name__
            try:
                if context.current_expression is not None:
                    context.current_expression = transformer.process(context.current_expression, context)
            except Exception as e:
                # Log transformation failure as a validation error
                self._record_failure(
                    context,
                    message=f"Transformer {transformer_name} failed: {e}",
                    code="transformer-failure",
                    risk_level=RiskLevel.CRITICAL,
                    processor=transformer_name,
                    expression=context.current_expression,
                )
                logger.exception("Transformer %s failed", transformer_name)
                # Remaining transformers are skipped after a failure.
                break

    def _run_validators(self, context: "SQLProcessingContext") -> None:
        """Run every validator; a failing validator is recorded and the loop continues."""
        for validator_component in self.validators:
            validator_name = validator_component.__class__.__name__
            try:
                # Validators process and add errors to context
                if context.current_expression is not None:
                    validator_component.process(context.current_expression, context)
            except Exception as e:
                self._record_failure(
                    context,
                    message=f"Validator {validator_name} failed: {e}",
                    code="validator-failure",
                    risk_level=RiskLevel.CRITICAL,
                    processor=validator_name,
                    expression=context.current_expression,
                )
                logger.exception("Validator %s failed", validator_name)

    def _run_analyzers(self, context: "SQLProcessingContext") -> None:
        """Run every analyzer; failures are recorded at MEDIUM risk and the loop continues."""
        for analyzer_component in self.analyzers:
            analyzer_name = analyzer_component.__class__.__name__
            try:
                analyzer_component.process(context.current_expression, context)
            except Exception as e:
                self._record_failure(
                    context,
                    message=f"Analyzer {analyzer_name} failed: {e}",
                    code="analyzer-failure",
                    risk_level=RiskLevel.MEDIUM,
                    processor=analyzer_name,
                    expression=context.current_expression,
                )
                logger.exception("Analyzer %s failed", analyzer_name)
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
class SQLValidator(ProcessorProtocol, UsesExpression):
    """Composite validator that fans one expression out to a list of child validators.

    Acts as a validation pipeline runner: each child's ``process`` is invoked in turn
    and any exception it raises is converted into a ``ValidationError`` on the context.
    ``min_risk_to_raise`` is stored on the instance; nothing in this class reads it.
    """

    def __init__(
        self,
        validators: "Optional[Sequence[ProcessorProtocol]]" = None,
        min_risk_to_raise: "Optional[RiskLevel]" = RiskLevel.HIGH,
    ) -> None:
        """Copy ``validators`` into a private list and remember ``min_risk_to_raise``."""
        self.validators: list[ProcessorProtocol] = [] if validators is None else list(validators)
        self.min_risk_to_raise = min_risk_to_raise

    def add_validator(self, validator: "ProcessorProtocol") -> None:
        """Append ``validator`` to the end of the child validator list."""
        self.validators.append(validator)

    def process(
        self, expression: "Optional[exp.Expression]", context: "SQLProcessingContext"
    ) -> "Optional[exp.Expression]":
        """Run all configured validators against ``expression``.

        Args:
            expression: The SQL expression to validate.
            context: The SQLProcessingContext holding the current state and config.

        Returns:
            ``None`` when ``expression`` is ``None``; otherwise the expression itself,
            unchanged (validators don't transform).
        """
        if expression is None:
            return None

        # When validation is switched off in the config the expression passes straight through.
        if context.config.enable_validation:
            self._run_validators(expression, context)
        return expression

    @staticmethod
    def _validate_safely(
        validator_instance: "ProcessorProtocol", expression: "exp.Expression", context: "SQLProcessingContext"
    ) -> None:
        """Invoke one validator, turning any exception into a recorded ValidationError plus a warning log."""
        try:
            validator_instance.process(expression, context)
        except Exception as e:
            failed_name = validator_instance.__class__.__name__
            context.validation_errors.append(
                ValidationError(
                    message=f"Validator {failed_name} error: {e}",
                    code="validator-error",
                    risk_level=RiskLevel.CRITICAL,
                    processor=failed_name,
                    expression=expression,
                )
            )
            logger.warning("Individual validator %s failed: %s", failed_name, e)

    def _run_validators(self, expression: "Optional[exp.Expression]", context: "SQLProcessingContext") -> None:
        """Feed ``expression`` to every child validator; a falsy expression means nothing to do."""
        if expression:
            for child in self.validators:
                self._validate_safely(child, expression, context)

    def validate(
        self, sql: "Statement", dialect: "DialectType", config: "Optional[SQLConfig]" = None
    ) -> "list[ValidationError]":
        """Convenience method to validate a raw SQL string or expression.

        Args:
            sql: The statement to validate.
            dialect: The SQL dialect used for parsing and stored on the context.
            config: Optional configuration; a default ``SQLConfig()`` is used when omitted.

        Returns:
            List of ValidationError objects found during validation.
        """
        # Local imports (these names are only imported for type checking at module level).
        from sqlspec.statement.pipelines.context import SQLProcessingContext
        from sqlspec.statement.sql import SQLConfig

        current_config = config or SQLConfig()
        expression_to_validate = self.get_expression(sql, dialect=dialect)

        # Best-effort placeholder detection for raw strings: any failure while extracting
        # parameters is deliberately ignored and the flag simply stays False.
        placeholders_found = False
        if isinstance(sql, str):
            with contextlib.suppress(Exception):
                placeholders_found = bool(current_config.parameter_validator.extract_parameters(sql))

        # A standalone validate() call carries no parameter values, so the parameter-related
        # context fields are left at their defaults.
        validation_context = SQLProcessingContext(
            initial_sql_string=str(sql),
            dialect=dialect,
            config=current_config,
            current_expression=expression_to_validate,
            initial_expression=expression_to_validate,
            input_sql_had_placeholders=placeholders_found,
        )

        self.process(expression_to_validate, validation_context)

        # Hand back a copy so callers cannot mutate the context's own list.
        return list(validation_context.validation_errors)
|
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
from dataclasses import dataclass, field
|
|
2
|
+
from typing import TYPE_CHECKING, Any, Optional
|
|
3
|
+
|
|
4
|
+
from sqlglot import exp
|
|
5
|
+
|
|
6
|
+
from sqlspec.exceptions import RiskLevel
|
|
7
|
+
from sqlspec.statement.pipelines.result_types import AnalysisFinding, TransformationLog, ValidationError
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
from sqlglot.dialects.dialect import DialectType
|
|
11
|
+
|
|
12
|
+
from sqlspec.statement.parameters import ParameterInfo
|
|
13
|
+
from sqlspec.statement.sql import SQLConfig
|
|
14
|
+
from sqlspec.typing import SQLParameterType
|
|
15
|
+
|
|
16
|
+
__all__ = ("PipelineResult", "SQLProcessingContext")
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass
class SQLProcessingContext:
    """Mutable state object passed through the pipeline; processors read from and append to it."""

    # --- Input -----------------------------------------------------------------
    initial_sql_string: str
    """The SQL text exactly as the user supplied it."""

    dialect: "DialectType"
    """SQL dialect used for parsing and generation."""

    config: "SQLConfig"
    """SQL processing configuration that applies to this statement."""

    # --- Initial state ---------------------------------------------------------
    initial_expression: Optional[exp.Expression] = None
    """Expression as first parsed (kept for diffing/auditing)."""

    # --- Current state ---------------------------------------------------------
    current_expression: Optional[exp.Expression] = None
    """Working expression; transformers may replace it."""

    # --- Parameters ------------------------------------------------------------
    initial_parameters: "Optional[SQLParameterType]" = None
    """Parameters as given to the SQL object, before merging with kwargs."""
    initial_kwargs: "Optional[dict[str, Any]]" = None
    """Keyword arguments as given to the SQL object."""
    merged_parameters: "SQLParameterType" = field(default_factory=list)
    """Result of merging initial_parameters with initial_kwargs."""
    parameter_info: "list[ParameterInfo]" = field(default_factory=list)
    """Details of the parameters identified in initial_sql_string."""
    extracted_parameters_from_pipeline: list[Any] = field(default_factory=list)
    """Parameters pulled out by transformers (e.g., ParameterizeLiterals)."""

    # --- Collected results (processors append to these) --------------------------
    validation_errors: list[ValidationError] = field(default_factory=list)
    """Validation errors accumulated while processing."""
    analysis_findings: list[AnalysisFinding] = field(default_factory=list)
    """Analysis findings accumulated while processing."""
    transformations: list[TransformationLog] = field(default_factory=list)
    """Log of transformations applied while processing."""

    # --- General metadata --------------------------------------------------------
    metadata: dict[str, Any] = field(default_factory=dict)
    """Free-form metadata store."""

    # --- Flags -------------------------------------------------------------------
    input_sql_had_placeholders: bool = False
    """True when initial_sql_string already contained placeholders."""
    statement_type: Optional[str] = None
    """Detected statement type (e.g., SELECT, INSERT, DDL)."""
    extra_info: dict[str, Any] = field(default_factory=dict)
    """Extra information from parameter processing, including normalization state."""

    @property
    def has_errors(self) -> bool:
        """Whether at least one validation error has been recorded."""
        return bool(self.validation_errors)

    @property
    def risk_level(self) -> RiskLevel:
        """Overall risk: the highest ``risk_level`` among recorded errors, or ``RiskLevel.SAFE`` if none."""
        recorded = self.validation_errors
        if recorded:
            return max(item.risk_level for item in recorded)
        return RiskLevel.SAFE
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
@dataclass
class PipelineResult:
    """Outcome of a pipeline run: the final expression plus the context that produced it.

    Every property below is a read-through to the same-named attribute on ``context``.
    """

    expression: exp.Expression
    """SQL expression after all transformations have been applied."""

    context: SQLProcessingContext
    """Processing context holding every collected result."""

    @property
    def validation_errors(self) -> list[ValidationError]:
        """Validation errors recorded on the context (the same list object, not a copy)."""
        ctx = self.context
        return ctx.validation_errors

    @property
    def has_errors(self) -> bool:
        """Whether the context reports any validation errors."""
        ctx = self.context
        return ctx.has_errors

    @property
    def risk_level(self) -> RiskLevel:
        """Overall risk level as computed by the context."""
        ctx = self.context
        return ctx.risk_level

    @property
    def merged_parameters(self) -> "SQLParameterType":
        """Merged parameters stored on the context."""
        ctx = self.context
        return ctx.merged_parameters

    @property
    def parameter_info(self) -> "list[ParameterInfo]":
        """Parameter info stored on the context."""
        ctx = self.context
        return ctx.parameter_info
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
"""Specific result types for the SQL processing pipeline."""
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import TYPE_CHECKING, Any, Optional
|
|
5
|
+
|
|
6
|
+
if TYPE_CHECKING:
|
|
7
|
+
from sqlglot import exp
|
|
8
|
+
|
|
9
|
+
from sqlspec.exceptions import RiskLevel
|
|
10
|
+
|
|
11
|
+
__all__ = ("AnalysisFinding", "TransformationLog", "ValidationError")
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
class ValidationError:
    """One validation problem reported while a statement was being processed.

    Attributes:
        message: Human-readable description of the problem.
        code: Short identifier, e.g. "risky-delete" or "missing-where".
        risk_level: Severity assigned to the problem.
        processor: Name of the processor that reported it.
        expression: The problematic sub-expression, when one is available.
    """

    message: str
    code: str
    risk_level: "RiskLevel"
    processor: str
    expression: "Optional[exp.Expression]" = None
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass
class AnalysisFinding:
    """A single piece of metadata produced by an analysis step.

    Attributes:
        key: Name of the finding, e.g. "complexity_score" or "table_count".
        value: The finding's value (any type).
        processor: Name of the processor that produced it.
    """

    key: str
    value: Any
    processor: str
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@dataclass
class TransformationLog:
    """Audit entry describing one transformation that was applied.

    Attributes:
        description: What the transformation did.
        processor: Name of the processor that applied it.
        before: SQL before the transform, if captured.
        after: SQL after the transform, if captured.
    """

    description: str
    processor: str
    before: Optional[str] = None
    after: Optional[str] = None
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
"""SQL Transformers for the processing pipeline."""
|
|
2
|
+
|
|
3
|
+
from sqlspec.statement.pipelines.transformers._expression_simplifier import ExpressionSimplifier, SimplificationConfig
|
|
4
|
+
from sqlspec.statement.pipelines.transformers._literal_parameterizer import ParameterizeLiterals
|
|
5
|
+
from sqlspec.statement.pipelines.transformers._remove_comments import CommentRemover
|
|
6
|
+
from sqlspec.statement.pipelines.transformers._remove_hints import HintRemover
|
|
7
|
+
|
|
8
|
+
__all__ = ("CommentRemover", "ExpressionSimplifier", "HintRemover", "ParameterizeLiterals", "SimplificationConfig")
|