sqlspec 0.11.0__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/__init__.py +16 -3
- sqlspec/_serialization.py +3 -10
- sqlspec/_sql.py +1147 -0
- sqlspec/_typing.py +343 -41
- sqlspec/adapters/adbc/__init__.py +2 -6
- sqlspec/adapters/adbc/config.py +474 -149
- sqlspec/adapters/adbc/driver.py +330 -644
- sqlspec/adapters/aiosqlite/__init__.py +2 -6
- sqlspec/adapters/aiosqlite/config.py +143 -57
- sqlspec/adapters/aiosqlite/driver.py +269 -462
- sqlspec/adapters/asyncmy/__init__.py +3 -8
- sqlspec/adapters/asyncmy/config.py +247 -202
- sqlspec/adapters/asyncmy/driver.py +217 -451
- sqlspec/adapters/asyncpg/__init__.py +4 -7
- sqlspec/adapters/asyncpg/config.py +329 -176
- sqlspec/adapters/asyncpg/driver.py +418 -498
- sqlspec/adapters/bigquery/__init__.py +2 -2
- sqlspec/adapters/bigquery/config.py +407 -0
- sqlspec/adapters/bigquery/driver.py +592 -634
- sqlspec/adapters/duckdb/__init__.py +4 -1
- sqlspec/adapters/duckdb/config.py +432 -321
- sqlspec/adapters/duckdb/driver.py +393 -436
- sqlspec/adapters/oracledb/__init__.py +3 -8
- sqlspec/adapters/oracledb/config.py +625 -0
- sqlspec/adapters/oracledb/driver.py +549 -942
- sqlspec/adapters/psqlpy/__init__.py +4 -7
- sqlspec/adapters/psqlpy/config.py +372 -203
- sqlspec/adapters/psqlpy/driver.py +197 -550
- sqlspec/adapters/psycopg/__init__.py +3 -8
- sqlspec/adapters/psycopg/config.py +741 -0
- sqlspec/adapters/psycopg/driver.py +732 -733
- sqlspec/adapters/sqlite/__init__.py +2 -6
- sqlspec/adapters/sqlite/config.py +146 -81
- sqlspec/adapters/sqlite/driver.py +243 -426
- sqlspec/base.py +220 -825
- sqlspec/config.py +354 -0
- sqlspec/driver/__init__.py +22 -0
- sqlspec/driver/_async.py +252 -0
- sqlspec/driver/_common.py +338 -0
- sqlspec/driver/_sync.py +261 -0
- sqlspec/driver/mixins/__init__.py +17 -0
- sqlspec/driver/mixins/_pipeline.py +523 -0
- sqlspec/driver/mixins/_result_utils.py +122 -0
- sqlspec/driver/mixins/_sql_translator.py +35 -0
- sqlspec/driver/mixins/_storage.py +993 -0
- sqlspec/driver/mixins/_type_coercion.py +131 -0
- sqlspec/exceptions.py +299 -7
- sqlspec/extensions/aiosql/__init__.py +10 -0
- sqlspec/extensions/aiosql/adapter.py +474 -0
- sqlspec/extensions/litestar/__init__.py +1 -6
- sqlspec/extensions/litestar/_utils.py +1 -5
- sqlspec/extensions/litestar/config.py +5 -6
- sqlspec/extensions/litestar/handlers.py +13 -12
- sqlspec/extensions/litestar/plugin.py +22 -24
- sqlspec/extensions/litestar/providers.py +37 -55
- sqlspec/loader.py +528 -0
- sqlspec/service/__init__.py +3 -0
- sqlspec/service/base.py +24 -0
- sqlspec/service/pagination.py +26 -0
- sqlspec/statement/__init__.py +21 -0
- sqlspec/statement/builder/__init__.py +54 -0
- sqlspec/statement/builder/_ddl_utils.py +119 -0
- sqlspec/statement/builder/_parsing_utils.py +135 -0
- sqlspec/statement/builder/base.py +328 -0
- sqlspec/statement/builder/ddl.py +1379 -0
- sqlspec/statement/builder/delete.py +80 -0
- sqlspec/statement/builder/insert.py +274 -0
- sqlspec/statement/builder/merge.py +95 -0
- sqlspec/statement/builder/mixins/__init__.py +65 -0
- sqlspec/statement/builder/mixins/_aggregate_functions.py +151 -0
- sqlspec/statement/builder/mixins/_case_builder.py +91 -0
- sqlspec/statement/builder/mixins/_common_table_expr.py +91 -0
- sqlspec/statement/builder/mixins/_delete_from.py +34 -0
- sqlspec/statement/builder/mixins/_from.py +61 -0
- sqlspec/statement/builder/mixins/_group_by.py +119 -0
- sqlspec/statement/builder/mixins/_having.py +35 -0
- sqlspec/statement/builder/mixins/_insert_from_select.py +48 -0
- sqlspec/statement/builder/mixins/_insert_into.py +36 -0
- sqlspec/statement/builder/mixins/_insert_values.py +69 -0
- sqlspec/statement/builder/mixins/_join.py +110 -0
- sqlspec/statement/builder/mixins/_limit_offset.py +53 -0
- sqlspec/statement/builder/mixins/_merge_clauses.py +405 -0
- sqlspec/statement/builder/mixins/_order_by.py +46 -0
- sqlspec/statement/builder/mixins/_pivot.py +82 -0
- sqlspec/statement/builder/mixins/_returning.py +37 -0
- sqlspec/statement/builder/mixins/_select_columns.py +60 -0
- sqlspec/statement/builder/mixins/_set_ops.py +122 -0
- sqlspec/statement/builder/mixins/_unpivot.py +80 -0
- sqlspec/statement/builder/mixins/_update_from.py +54 -0
- sqlspec/statement/builder/mixins/_update_set.py +91 -0
- sqlspec/statement/builder/mixins/_update_table.py +29 -0
- sqlspec/statement/builder/mixins/_where.py +374 -0
- sqlspec/statement/builder/mixins/_window_functions.py +86 -0
- sqlspec/statement/builder/protocols.py +20 -0
- sqlspec/statement/builder/select.py +206 -0
- sqlspec/statement/builder/update.py +178 -0
- sqlspec/statement/filters.py +571 -0
- sqlspec/statement/parameters.py +736 -0
- sqlspec/statement/pipelines/__init__.py +67 -0
- sqlspec/statement/pipelines/analyzers/__init__.py +9 -0
- sqlspec/statement/pipelines/analyzers/_analyzer.py +649 -0
- sqlspec/statement/pipelines/base.py +315 -0
- sqlspec/statement/pipelines/context.py +119 -0
- sqlspec/statement/pipelines/result_types.py +41 -0
- sqlspec/statement/pipelines/transformers/__init__.py +8 -0
- sqlspec/statement/pipelines/transformers/_expression_simplifier.py +256 -0
- sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +623 -0
- sqlspec/statement/pipelines/transformers/_remove_comments.py +66 -0
- sqlspec/statement/pipelines/transformers/_remove_hints.py +81 -0
- sqlspec/statement/pipelines/validators/__init__.py +23 -0
- sqlspec/statement/pipelines/validators/_dml_safety.py +275 -0
- sqlspec/statement/pipelines/validators/_parameter_style.py +297 -0
- sqlspec/statement/pipelines/validators/_performance.py +703 -0
- sqlspec/statement/pipelines/validators/_security.py +990 -0
- sqlspec/statement/pipelines/validators/base.py +67 -0
- sqlspec/statement/result.py +527 -0
- sqlspec/statement/splitter.py +701 -0
- sqlspec/statement/sql.py +1198 -0
- sqlspec/storage/__init__.py +15 -0
- sqlspec/storage/backends/__init__.py +0 -0
- sqlspec/storage/backends/base.py +166 -0
- sqlspec/storage/backends/fsspec.py +315 -0
- sqlspec/storage/backends/obstore.py +464 -0
- sqlspec/storage/protocol.py +170 -0
- sqlspec/storage/registry.py +315 -0
- sqlspec/typing.py +157 -36
- sqlspec/utils/correlation.py +155 -0
- sqlspec/utils/deprecation.py +3 -6
- sqlspec/utils/fixtures.py +6 -11
- sqlspec/utils/logging.py +135 -0
- sqlspec/utils/module_loader.py +45 -43
- sqlspec/utils/serializers.py +4 -0
- sqlspec/utils/singleton.py +6 -8
- sqlspec/utils/sync_tools.py +15 -27
- sqlspec/utils/text.py +58 -26
- {sqlspec-0.11.0.dist-info → sqlspec-0.12.0.dist-info}/METADATA +100 -26
- sqlspec-0.12.0.dist-info/RECORD +145 -0
- sqlspec/adapters/bigquery/config/__init__.py +0 -3
- sqlspec/adapters/bigquery/config/_common.py +0 -40
- sqlspec/adapters/bigquery/config/_sync.py +0 -87
- sqlspec/adapters/oracledb/config/__init__.py +0 -9
- sqlspec/adapters/oracledb/config/_asyncio.py +0 -186
- sqlspec/adapters/oracledb/config/_common.py +0 -131
- sqlspec/adapters/oracledb/config/_sync.py +0 -186
- sqlspec/adapters/psycopg/config/__init__.py +0 -19
- sqlspec/adapters/psycopg/config/_async.py +0 -169
- sqlspec/adapters/psycopg/config/_common.py +0 -56
- sqlspec/adapters/psycopg/config/_sync.py +0 -168
- sqlspec/filters.py +0 -330
- sqlspec/mixins.py +0 -306
- sqlspec/statement.py +0 -378
- sqlspec-0.11.0.dist-info/RECORD +0 -69
- {sqlspec-0.11.0.dist-info → sqlspec-0.12.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.11.0.dist-info → sqlspec-0.12.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.11.0.dist-info → sqlspec-0.12.0.dist-info}/licenses/NOTICE +0 -0
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
"""Removes SQL hints from expressions."""
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING, Optional
|
|
4
|
+
|
|
5
|
+
from sqlglot import exp
|
|
6
|
+
|
|
7
|
+
from sqlspec.statement.pipelines.base import ProcessorProtocol
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
from sqlspec.statement.pipelines.context import SQLProcessingContext
|
|
11
|
+
|
|
12
|
+
__all__ = ("HintRemover",)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class HintRemover(ProcessorProtocol):
    """Removes SQL hints from expressions using SQLGlot's AST traversal.

    This transformer removes SQL hints while preserving standard comments:
    - Removes Oracle-style hints (/*+ hint */)
    - Removes MySQL version comments (/*!50000 */)
    - Removes formal hint expressions (exp.Hint nodes)
    - Preserves standard comments (-- comment, /* comment */)
    - Uses SQLGlot's AST for reliable, context-aware hint detection

    Args:
        enabled: Whether hint removal is enabled.
        remove_oracle_hints: Whether to remove Oracle-style hints (/*+ hint */).
        remove_mysql_version_comments: Whether to remove MySQL /*!50000 */ style comments.
    """

    def __init__(
        self, enabled: bool = True, remove_oracle_hints: bool = True, remove_mysql_version_comments: bool = True
    ) -> None:
        self.enabled = enabled
        self.remove_oracle_hints = remove_oracle_hints
        self.remove_mysql_version_comments = remove_mysql_version_comments

    def process(
        self, expression: "Optional[exp.Expression]", context: "SQLProcessingContext"
    ) -> "Optional[exp.Expression]":
        """Removes SQL hints from the expression using SQLGlot AST traversal.

        Args:
            expression: The expression handed to this processor. It is returned
                untouched when the processor is disabled, when it is ``None``,
                or when ``context.current_expression`` is ``None``.
            context: Processing context. ``context.current_expression`` is the
                tree that is actually transformed; it is replaced with the
                cleaned copy and ``context.metadata["hints_removed"]`` is set
                to the number of hint nodes/comments dropped.

        Returns:
            The cleaned expression (also stored on the context), or the input
            ``expression`` unchanged when nothing was processed.
        """
        if not self.enabled or expression is None or context.current_expression is None:
            return expression

        # Built once per call instead of once per inspected comment.
        hint_keywords = ("INDEX", "USE_NL", "USE_HASH", "PARALLEL", "FULL", "FIRST_ROWS", "ALL_ROWS")
        hints_removed_count = 0

        def _is_removable(comment: object) -> bool:
            """Decide whether a single attached comment should be dropped."""
            comment_text = str(comment).strip()

            # MySQL version comments are checked independently of the Oracle
            # keyword match, so "/*!50000 FORCE INDEX ... */" is still removed
            # when only remove_mysql_version_comments is enabled.
            if self.remove_mysql_version_comments and comment_text.startswith("!"):
                return True

            if self.remove_oracle_hints:
                # A comment body starting with "+" comes from "/*+ ... */".
                if comment_text.startswith("+"):
                    return True
                upper_text = comment_text.upper()
                if any(keyword in upper_text for keyword in hint_keywords):
                    return True

            return False

        def _remove_hint_node(node: exp.Expression) -> "Optional[exp.Expression]":
            nonlocal hints_removed_count
            # Returning None from a transform callback removes the node.
            if isinstance(node, exp.Hint):
                hints_removed_count += 1
                return None

            comments = getattr(node, "comments", None)
            if comments:
                comments_to_keep = [comment for comment in comments if not _is_removable(comment)]
                removed = len(comments) - len(comments_to_keep)
                if removed:
                    hints_removed_count += removed
                    # Reset the node's comments, then re-attach the survivors.
                    node.pop_comments()
                    if comments_to_keep:
                        node.add_comments(comments_to_keep)
            return node

        # copy=True: the callback mutates comments on the copy, not on the
        # tree currently stored in the context.
        transformed_expression = context.current_expression.transform(_remove_hint_node, copy=True)
        context.current_expression = transformed_expression or exp.Anonymous(this="")

        context.metadata["hints_removed"] = hints_removed_count

        return context.current_expression
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
"""SQL Validation Pipeline Components."""
|
|
2
|
+
|
|
3
|
+
from sqlspec.statement.pipelines.validators._dml_safety import DMLSafetyConfig, DMLSafetyValidator
|
|
4
|
+
from sqlspec.statement.pipelines.validators._parameter_style import ParameterStyleValidator
|
|
5
|
+
from sqlspec.statement.pipelines.validators._performance import PerformanceConfig, PerformanceValidator
|
|
6
|
+
from sqlspec.statement.pipelines.validators._security import (
|
|
7
|
+
SecurityIssue,
|
|
8
|
+
SecurityIssueType,
|
|
9
|
+
SecurityValidator,
|
|
10
|
+
SecurityValidatorConfig,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
__all__ = (
|
|
14
|
+
"DMLSafetyConfig",
|
|
15
|
+
"DMLSafetyValidator",
|
|
16
|
+
"ParameterStyleValidator",
|
|
17
|
+
"PerformanceConfig",
|
|
18
|
+
"PerformanceValidator",
|
|
19
|
+
"SecurityIssue",
|
|
20
|
+
"SecurityIssueType",
|
|
21
|
+
"SecurityValidator",
|
|
22
|
+
"SecurityValidatorConfig",
|
|
23
|
+
)
|
|
@@ -0,0 +1,275 @@
|
|
|
1
|
+
# DML Safety Validator - Consolidates risky DML operations and DDL prevention
|
|
2
|
+
from dataclasses import dataclass, field
|
|
3
|
+
from enum import Enum
|
|
4
|
+
from typing import TYPE_CHECKING, Optional
|
|
5
|
+
|
|
6
|
+
from sqlglot import expressions as exp
|
|
7
|
+
|
|
8
|
+
from sqlspec.exceptions import RiskLevel
|
|
9
|
+
from sqlspec.statement.pipelines.validators.base import BaseValidator
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from sqlspec.statement.pipelines.context import SQLProcessingContext
|
|
13
|
+
|
|
14
|
+
__all__ = ("DMLSafetyConfig", "DMLSafetyValidator", "StatementCategory")
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class StatementCategory(Enum):
    """Coarse classification of a SQL statement by the kind of work it does."""

    # Data definition: CREATE, ALTER, DROP, TRUNCATE
    DDL = "ddl"
    # Data manipulation: INSERT, UPDATE, DELETE, MERGE
    DML = "dml"
    # Data query: SELECT
    DQL = "dql"
    # Data control: GRANT, REVOKE
    DCL = "dcl"
    # Transaction control: COMMIT, ROLLBACK, SAVEPOINT
    TCL = "tcl"
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass
class DMLSafetyConfig:
    """Settings consumed by the DML safety validator."""

    # Report DDL statements as errors unless explicitly allowed.
    prevent_ddl: bool = True
    # Report DCL statements as errors.
    prevent_dcl: bool = True
    # Operation names that must carry a WHERE clause.
    require_where_clause: "set[str]" = field(default_factory=lambda: set(("DELETE", "UPDATE")))
    # DDL operation names exempt from prevent_ddl.
    allowed_ddl_operations: "set[str]" = field(default_factory=lambda: set())
    # Allow DDL in migration contexts.
    migration_mode: bool = False
    # Upper bound on estimated rows touched by a DML statement (None = no limit).
    max_affected_rows: "Optional[int]" = None
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class DMLSafetyValidator(BaseValidator):
    """Unified validator for DML/DDL safety checks.

    This validator consolidates:
    - DDL prevention (CREATE, ALTER, DROP, etc.), lifted in migration mode
    - Risky DML detection (DELETE/UPDATE without WHERE)
    - DCL restrictions (only ``exp.Grant`` is recognised as DCL here)
    - Row limit enforcement (heuristic estimate, no table statistics)
    """

    def __init__(self, config: "Optional[DMLSafetyConfig]" = None) -> None:
        """Initialize the DML safety validator.

        Args:
            config: Configuration for safety validation. A default
                ``DMLSafetyConfig`` is used when omitted.
        """
        super().__init__()
        self.config = config or DMLSafetyConfig()

    def validate(self, expression: "exp.Expression", context: "SQLProcessingContext") -> None:
        """Validate SQL statement for safety issues.

        Findings are reported through ``self.add_error``; a summary dict is
        stored in ``context.metadata`` under this class's name.

        Args:
            expression: The SQL expression to validate
            context: The SQL processing context
        """
        category = self._categorize_statement(expression)
        operation = self._get_operation_type(expression)
        # Only meaningful for DML; computed once and reused in the metadata.
        has_where = self._has_where_clause(expression) if category == StatementCategory.DML else None

        if category == StatementCategory.DDL:
            # migration_mode is documented as "Allow DDL in migration
            # contexts", so it lifts the DDL block entirely.
            ddl_blocked = self.config.prevent_ddl and not self.config.migration_mode
            if ddl_blocked and operation not in self.config.allowed_ddl_operations:
                self.add_error(
                    context,
                    message=f"DDL operation '{operation}' is not allowed",
                    code="ddl-not-allowed",
                    risk_level=RiskLevel.CRITICAL,
                    expression=expression,
                )

        elif category == StatementCategory.DML:
            if operation in self.config.require_where_clause and not has_where:
                self.add_error(
                    context,
                    message=f"{operation} without WHERE clause affects all rows",
                    code=f"{operation.lower()}-without-where",
                    risk_level=RiskLevel.HIGH,
                    expression=expression,
                )

            # A falsy limit (None or 0) disables the row-count check.
            if self.config.max_affected_rows:
                estimated_rows = self._estimate_affected_rows(expression)
                if estimated_rows > self.config.max_affected_rows:
                    self.add_error(
                        context,
                        message=f"Operation may affect {estimated_rows:,} rows (limit: {self.config.max_affected_rows:,})",
                        code="excessive-rows-affected",
                        risk_level=RiskLevel.MEDIUM,
                        expression=expression,
                    )

        elif category == StatementCategory.DCL and self.config.prevent_dcl:
            self.add_error(
                context,
                message=f"DCL operation '{operation}' is not allowed",
                code="dcl-not-allowed",
                risk_level=RiskLevel.HIGH,
                expression=expression,
            )

        context.metadata[self.__class__.__name__] = {
            "statement_category": category.value,
            "operation": operation,
            "has_where_clause": has_where,
            "affected_tables": self._extract_affected_tables(expression),
            "migration_mode": self.config.migration_mode,
        }

    @staticmethod
    def _categorize_statement(expression: "exp.Expression") -> StatementCategory:
        """Categorize SQL statement type.

        Args:
            expression: The SQL expression to categorize

        Returns:
            The statement category; anything unrecognised is treated as DQL.
        """
        if isinstance(expression, (exp.Create, exp.Alter, exp.Drop, exp.TruncateTable, exp.Comment)):
            return StatementCategory.DDL

        if isinstance(expression, (exp.Select, exp.Union, exp.Intersect, exp.Except)):
            return StatementCategory.DQL

        if isinstance(expression, (exp.Insert, exp.Update, exp.Delete, exp.Merge)):
            return StatementCategory.DML

        if isinstance(expression, (exp.Grant,)):
            return StatementCategory.DCL

        if isinstance(expression, (exp.Commit, exp.Rollback)):
            return StatementCategory.TCL

        return StatementCategory.DQL  # Default to query

    @staticmethod
    def _get_operation_type(expression: "exp.Expression") -> str:
        """Get specific operation name.

        Args:
            expression: The SQL expression

        Returns:
            The upper-cased class name of the expression (e.g. ``"DELETE"``).
        """
        return expression.__class__.__name__.upper()

    @staticmethod
    def _has_where_clause(expression: "exp.Expression") -> bool:
        """Check if DML statement has WHERE clause.

        Args:
            expression: The SQL expression to check

        Returns:
            For DELETE/UPDATE, whether a ``where`` arg is present; True for
            every other statement type, since they do not require one.
        """
        if isinstance(expression, (exp.Delete, exp.Update)):
            return expression.args.get("where") is not None
        return True  # Other statements don't require WHERE

    def _estimate_affected_rows(self, expression: "exp.Expression") -> int:
        """Estimate number of rows affected by DML operation.

        Args:
            expression: The SQL expression

        Returns:
            A coarse estimate: 999999999 for DELETE/UPDATE without WHERE,
            1 for a unique-looking filter, 100 for an indexed-looking filter,
            otherwise 10000.
        """
        # Simple heuristic - can be enhanced with table statistics
        if not self._has_where_clause(expression):
            return 999999999  # Large number to indicate all rows

        where = expression.args.get("where")
        if where:
            if self._has_unique_condition(where):
                return 1
            if self._has_indexed_condition(where):
                return 100  # Rough estimate

        return 10000  # Conservative estimate

    @staticmethod
    def _has_unique_condition(where: "Optional[exp.Expression]") -> bool:
        """Check if WHERE clause uses unique columns.

        Args:
            where: The WHERE expression

        Returns:
            True if any equality has, on its left side, a column named like
            a key (id, uuid, guid, pk, primary_key).
        """
        if where is None:
            return False
        for condition in where.find_all(exp.EQ):
            if isinstance(condition.left, exp.Column):
                col_name = condition.left.name.lower()
                if col_name in {"id", "uuid", "guid", "pk", "primary_key"}:
                    return True
        return False

    @staticmethod
    def _has_indexed_condition(where: "Optional[exp.Expression]") -> bool:
        """Check if WHERE clause uses indexed columns.

        Args:
            where: The WHERE expression

        Returns:
            True if any predicate has, on its left side, a column whose name
            is in a fixed list of names assumed to be indexed.
        """
        if where is None:
            return False
        for condition in where.find_all(exp.Predicate):
            if hasattr(condition, "left") and isinstance(condition.left, exp.Column):  # pyright: ignore
                col_name = condition.left.name.lower()  # pyright: ignore
                if col_name in {"created_at", "updated_at", "email", "username", "status", "type"}:
                    return True
        return False

    @staticmethod
    def _extract_affected_tables(expression: "exp.Expression") -> "list[str]":
        """Extract table names affected by the statement.

        Args:
            expression: The SQL expression

        Returns:
            List of affected table names (empty when no target is found).
        """
        tables = []

        # DML: the target is in expression.this. MERGE is included because
        # it is categorised as DML above.
        if isinstance(expression, (exp.Insert, exp.Update, exp.Delete, exp.Merge)):
            target = expression.this
            # "INSERT INTO t (a, b) ..." wraps the table in a Schema node;
            # unwrap it so the table name is not silently dropped.
            if isinstance(target, exp.Schema):
                target = target.this
            if isinstance(target, exp.Table):
                tables.append(target.name)

        # DDL
        elif isinstance(expression, (exp.Create, exp.Drop, exp.Alter)) and expression.this:
            target = expression.this
            # For CREATE TABLE, the table is in expression.this.this
            if isinstance(expression, exp.Create) and isinstance(target, exp.Schema):
                inner = target.this
                if isinstance(inner, exp.Table):
                    tables.append(inner.name)
            # For DROP/ALTER, table is directly in expression.this
            elif isinstance(target, (exp.Table, exp.Identifier)):
                tables.append(target.name)

        return tables
|
|
@@ -0,0 +1,297 @@
|
|
|
1
|
+
"""Parameter style validation for SQL statements."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from typing import TYPE_CHECKING, Any, Optional, Union
|
|
5
|
+
|
|
6
|
+
from sqlglot import exp
|
|
7
|
+
|
|
8
|
+
from sqlspec.exceptions import MissingParameterError, RiskLevel, SQLValidationError
|
|
9
|
+
from sqlspec.statement.pipelines.base import ProcessorProtocol
|
|
10
|
+
from sqlspec.statement.pipelines.result_types import ValidationError
|
|
11
|
+
from sqlspec.typing import is_dict
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from sqlspec.statement.pipelines.context import SQLProcessingContext
|
|
15
|
+
|
|
16
|
+
logger = logging.getLogger("sqlspec.validators.parameter_style")
|
|
17
|
+
|
|
18
|
+
__all__ = ("ParameterStyleValidator",)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class UnsupportedParameterStyleError(SQLValidationError):
    """Signals a parameter style that the active database configuration does not accept."""
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class MixedParameterStyleError(SQLValidationError):
    """Signals that one statement combines several parameter styles while mixing is disallowed."""
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class ParameterStyleValidator(ProcessorProtocol):
|
|
30
|
+
"""Validates that parameter styles are supported by the database configuration.
|
|
31
|
+
|
|
32
|
+
This validator checks:
|
|
33
|
+
1. Whether detected parameter styles are in the allowed list
|
|
34
|
+
2. Whether mixed parameter styles are used when not allowed
|
|
35
|
+
3. Provides helpful error messages about supported styles
|
|
36
|
+
"""
|
|
37
|
+
|
|
38
|
+
def __init__(self, risk_level: "RiskLevel" = RiskLevel.HIGH, fail_on_violation: bool = True) -> None:
    """Store the validator's reporting policy.

    Args:
        risk_level: Risk level attached to recorded style/missing-parameter errors.
        fail_on_violation: When True, violations raise; otherwise they are
            appended to the context's validation errors.
    """
    self.fail_on_violation = fail_on_violation
    self.risk_level = risk_level
|
|
47
|
+
|
|
48
|
+
def process(self, expression: "Optional[exp.Expression]", context: "SQLProcessingContext") -> None:
    """Check the statement's parameter styles and parameter completeness.

    Nothing is returned: findings are either raised (when
    ``fail_on_violation`` is set) or appended to ``context.validation_errors``.

    Args:
        expression: The SQL expression being validated; ``None`` is a no-op.
        context: SQL processing context supplying the config, the detected
            parameter info and the supplied parameters.
    """
    processor_name = "ParameterStyleValidator"

    def _record(message: str, code: str, risk_level: "RiskLevel", expr: "Optional[exp.Expression]") -> None:
        # Build a ValidationError and attach it to the context.
        context.validation_errors.append(
            ValidationError(
                message=message, code=code, risk_level=risk_level, processor=processor_name, expression=expr
            )
        )

    if expression is None:
        return

    if context.current_expression is None:
        _record("ParameterStyleValidator received no expression.", "no-expression", RiskLevel.CRITICAL, None)
        return

    try:
        config = context.config
        param_info = context.parameter_info

        style_violation_found = False
        if config.allowed_parameter_styles is not None and param_info:
            styles_in_use = {info.style for info in param_info}

            # Mixed styles are reported before individual disallowed styles.
            if len(styles_in_use) > 1 and not config.allow_mixed_parameter_styles:
                detected_styles = ", ".join(sorted(str(style) for style in styles_in_use))
                msg = f"Mixed parameter styles detected ({detected_styles}) but not allowed."
                if self.fail_on_violation:
                    self._raise_mixed_style_error(msg)
                _record(msg, "mixed-parameter-styles", self.risk_level, expression)
                style_violation_found = True

            rejected = {str(style) for style in styles_in_use if not config.validate_parameter_style(style)}
            if rejected:
                disallowed_str = ", ".join(sorted(rejected))
                # An empty allow-list gets the shorter message without "Allowed:".
                if config.allowed_parameter_styles:
                    allowed_str = ", ".join([str(style) for style in config.allowed_parameter_styles])
                    msg = f"Parameter style(s) {disallowed_str} not supported. Allowed: {allowed_str}"
                else:
                    msg = f"Parameter style(s) {disallowed_str} not supported."

                if self.fail_on_violation:
                    self._raise_unsupported_style_error(msg)
                _record(msg, "unsupported-parameter-style", self.risk_level, expression)
                style_violation_found = True

        logger.debug(
            "Checking missing parameters: param_info=%s, extracted=%s, had_placeholders=%s, merged=%s",
            len(param_info) if param_info else 0,
            len(context.extracted_parameters_from_pipeline) if context.extracted_parameters_from_pipeline else 0,
            context.input_sql_had_placeholders,
            context.merged_parameters is not None,
        )

        # The missing-parameter check runs only when placeholders exist, style
        # validation is enabled and clean, and either parameters were supplied
        # or the original SQL had placeholders (i.e. they were not created
        # solely by literal parameterization).
        if (
            param_info
            and config.allowed_parameter_styles is not None
            and not style_violation_found
            and (context.merged_parameters is not None or context.input_sql_had_placeholders)
        ):
            self._validate_missing_parameters(context, expression)

    except (UnsupportedParameterStyleError, MixedParameterStyleError, MissingParameterError):
        # Deliberate violations propagate to the caller untouched.
        raise
    except Exception as e:
        # Any other failure is downgraded to a low-risk recorded error.
        logger.warning("Parameter style validation failed: %s", e)
        _record(f"Parameter style validation failed: {e}", "validation-error", RiskLevel.LOW, expression)
|
|
156
|
+
|
|
157
|
+
@staticmethod
def _raise_mixed_style_error(msg: "str") -> "None":
    """Always raises ``MixedParameterStyleError`` carrying ``msg``."""
    failure = MixedParameterStyleError(msg)
    raise failure
|
|
161
|
+
|
|
162
|
+
@staticmethod
def _raise_unsupported_style_error(msg: "str") -> "None":
    """Always raises ``UnsupportedParameterStyleError`` carrying ``msg``."""
    failure = UnsupportedParameterStyleError(msg)
    raise failure
|
|
166
|
+
|
|
167
|
+
def _validate_missing_parameters(self, context: "SQLProcessingContext", expression: exp.Expression) -> None:
    """Dispatch the missing-parameter check according to the shape of the supplied values.

    Args:
        context: Processing context holding the placeholder info and parameters.
        expression: Expression passed on to the shape-specific handlers.
    """
    placeholders = context.parameter_info
    if not placeholders:
        return

    supplied = self._prepare_merged_parameters(context, placeholders)

    # Nothing supplied at all.
    if supplied is None:
        self._handle_no_parameters(context, expression, placeholders)
        return

    # Positional values.
    if isinstance(supplied, (list, tuple)):
        self._handle_positional_parameters(context, expression, placeholders, supplied)
        return

    # Named values.
    if is_dict(supplied):
        self._handle_named_parameters(context, expression, placeholders, supplied)
        return

    # A single scalar only needs checking when several placeholders exist.
    if len(placeholders) > 1:
        self._handle_single_value_multiple_params(context, expression, placeholders)
|
|
183
|
+
|
|
184
|
+
@staticmethod
def _prepare_merged_parameters(context: "SQLProcessingContext", param_info: list[Any]) -> Any:
    """Pick the parameter values that the missing-parameter check should inspect.

    Args:
        context: Processing context holding merged and pipeline-extracted parameters.
        param_info: Placeholder descriptors; each exposes ``style.value``.

    Returns:
        The chosen parameters. A lone scalar is wrapped in a one-element list
        when any placeholder uses the ``positional_colon`` style.
    """
    candidate = context.merged_parameters

    # Values extracted by pipeline transformers take precedence when the
    # original SQL carried no placeholders of its own.
    if context.extracted_parameters_from_pipeline and not context.input_sql_had_placeholders:
        candidate = context.extracted_parameters_from_pipeline

    uses_positional_colon = False
    for info in param_info:
        if info.style.value == "positional_colon":
            uses_positional_colon = True
            break

    is_scalar = candidate is not None and not isinstance(candidate, (list, tuple, dict))
    if uses_positional_colon and is_scalar:
        return [candidate]
    return candidate
|
|
198
|
+
|
|
199
|
+
def _report_error(self, context: "SQLProcessingContext", expression: exp.Expression, message: str) -> None:
    """Raise or record a missing-parameter violation.

    Args:
        context: Processing context whose ``validation_errors`` list receives the record.
        expression: Expression attached to the recorded ``ValidationError``.
        message: Description of the violation.

    Raises:
        MissingParameterError: When ``self.fail_on_violation`` is truthy.
    """
    if not self.fail_on_violation:
        # Non-strict mode: collect the problem on the context instead of aborting.
        recorded = ValidationError(
            message=message,
            code="missing-parameters",
            risk_level=self.risk_level,
            processor="ParameterStyleValidator",
            expression=expression,
        )
        context.validation_errors.append(recorded)
        return

    raise MissingParameterError(message)
|
|
211
|
+
|
|
212
|
+
def _handle_no_parameters(
    self, context: "SQLProcessingContext", expression: exp.Expression, param_info: list[Any]
) -> None:
    """Report every placeholder as missing when no parameter values were supplied.

    Nothing is reported when ``context.extracted_parameters_from_pipeline`` is truthy.

    Args:
        context: Processing context passed through to ``_report_error``.
        expression: Expression passed through to ``_report_error``.
        param_info: Placeholder descriptors; each is labelled by its ``name``,
            else its ``placeholder_text``, else ``param_<ordinal>``.
    """
    if context.extracted_parameters_from_pipeline:
        return

    labels = []
    for info in param_info:
        label = info.name or info.placeholder_text or f"param_{info.ordinal}"
        labels.append(str(label))

    msg = f"Missing required parameters: {', '.join(labels)}"
    self._report_error(context, expression, msg)
|
|
221
|
+
|
|
222
|
+
def _handle_positional_parameters(
    self,
    context: "SQLProcessingContext",
    expression: exp.Expression,
    param_info: list[Any],
    merged_params: "Union[list[Any], tuple[Any, ...]]",
) -> None:
    """Validate a list/tuple of values against the statement's placeholders.

    Three cases are handled, in order:

    1. Any ``named_colon`` / ``named_at`` placeholder is present: all such
       placeholders are reported as missing and validation stops.
    2. Any ``positional_colon`` placeholder is present: delegated to
       ``_validate_oracle_numeric_params``.
    3. Otherwise: an error is reported when fewer values than placeholders exist.

    Args:
        context: Processing context passed through to the reporting helpers.
        expression: Expression passed through to the reporting helpers.
        param_info: Placeholder descriptors for the statement.
        merged_params: The positional values supplied by the caller.
    """
    named_styles = {"named_colon", "named_at"}
    named_labels = [
        info.name or info.placeholder_text for info in param_info if info.style.value in named_styles
    ]
    if named_labels:
        # Falsy labels (no name and no placeholder text) are left out of the message.
        msg = f"Missing required parameters: {', '.join(str(label) for label in named_labels if label)}"
        self._report_error(context, expression, msg)
        return

    if any(info.style.value == "positional_colon" for info in param_info):
        self._validate_oracle_numeric_params(context, expression, param_info, merged_params)
        return

    expected = len(param_info)
    provided = len(merged_params)
    if provided < expected:
        msg = f"Expected {expected} parameters but got {provided}"
        self._report_error(context, expression, msg)
|
|
246
|
+
|
|
247
|
+
def _validate_oracle_numeric_params(
    self,
    context: "SQLProcessingContext",
    expression: exp.Expression,
    param_info: list[Any],
    merged_params: "Union[list[Any], tuple[Any, ...]]",
) -> None:
    """Validate ``positional_colon`` placeholders against a positional value sequence.

    A placeholder's integer name counts as satisfied when it fits the sequence under
    either a zero-based reading (``idx < len``) or a one-based reading
    (``idx > 0`` and ``idx - 1 < len``). Placeholders of other styles, with a falsy
    name, or whose name is not convertible by ``int()`` are skipped.

    Args:
        context: Processing context passed through to ``_report_error``.
        expression: Expression passed through to ``_report_error``.
        param_info: Placeholder descriptors for the statement.
        merged_params: The positional values supplied by the caller.
    """
    available = len(merged_params)
    unsatisfied: list[str] = []

    for info in param_info:
        name = info.name
        if info.style.value != "positional_colon" or not name:
            continue

        try:
            index = int(name)
        except (ValueError, TypeError):
            # Non-numeric names cannot be range-checked; ignore them here.
            continue

        beyond_zero_based = index >= available
        beyond_one_based = index <= 0 or (index - 1) >= available
        if beyond_zero_based and beyond_one_based:
            unsatisfied.append(name)

    if unsatisfied:
        msg = f"Missing required parameters: :{', :'.join(unsatisfied)}"
        self._report_error(context, expression, msg)
|
|
269
|
+
|
|
270
|
+
def _handle_named_parameters(
    self,
    context: "SQLProcessingContext",
    expression: exp.Expression,
    param_info: list[Any],
    merged_params: dict[str, Any],
) -> None:
    """Validate a dict of values against the statement's placeholders.

    A placeholder whose ``name`` is not a key of ``merged_params`` is reported as
    missing, with two exceptions: placeholders with a falsy name are never reported,
    and ``qmark`` / ``numeric`` placeholders are not reported when any key of
    ``merged_params`` starts with ``"_arg_"`` or ``"param_"``.

    Args:
        context: Processing context passed through to ``_report_error``.
        expression: Expression passed through to ``_report_error``.
        param_info: Placeholder descriptors for the statement.
        merged_params: The keyword values supplied by the caller.
    """
    absent: list[str] = []

    for info in param_info:
        name = info.name
        if name in merged_params:
            continue

        has_prefixed_keys = any(key.startswith(("_arg_", "param_")) for key in merged_params)
        is_positional_style = info.style.value in {"qmark", "numeric"}
        if has_prefixed_keys and is_positional_style:
            # Presumably such keys are auto-generated for positional placeholders - TODO confirm.
            continue

        if name:
            absent.append(name)

    if absent:
        msg = f"Missing required parameters: {', '.join(absent)}"
        self._report_error(context, expression, msg)
|
|
290
|
+
|
|
291
|
+
def _handle_single_value_multiple_params(
    self, context: "SQLProcessingContext", expression: exp.Expression, param_info: list[Any]
) -> None:
    """Report all placeholders after the first as missing.

    Used when a single value was supplied for a statement with several placeholders;
    the first placeholder is treated as covered by that value.

    Args:
        context: Processing context passed through to ``_report_error``.
        expression: Expression passed through to ``_report_error``.
        param_info: Placeholder descriptors; each reported one is labelled by its
            ``name``, else its ``placeholder_text``, else ``param_<ordinal>``.
    """
    labels = []
    for info in param_info[1:]:
        label = info.name or info.placeholder_text or f"param_{info.ordinal}"
        labels.append(str(label))

    msg = f"Missing required parameters: {', '.join(labels)}"
    self._report_error(context, expression, msg)
|