sqlspec 0.14.0__py3-none-any.whl → 0.15.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/__init__.py +50 -25
- sqlspec/__main__.py +12 -0
- sqlspec/__metadata__.py +1 -3
- sqlspec/_serialization.py +1 -2
- sqlspec/_sql.py +256 -120
- sqlspec/_typing.py +278 -142
- sqlspec/adapters/adbc/__init__.py +4 -3
- sqlspec/adapters/adbc/_types.py +12 -0
- sqlspec/adapters/adbc/config.py +115 -248
- sqlspec/adapters/adbc/driver.py +462 -353
- sqlspec/adapters/aiosqlite/__init__.py +18 -3
- sqlspec/adapters/aiosqlite/_types.py +13 -0
- sqlspec/adapters/aiosqlite/config.py +199 -129
- sqlspec/adapters/aiosqlite/driver.py +230 -269
- sqlspec/adapters/asyncmy/__init__.py +18 -3
- sqlspec/adapters/asyncmy/_types.py +12 -0
- sqlspec/adapters/asyncmy/config.py +80 -168
- sqlspec/adapters/asyncmy/driver.py +260 -225
- sqlspec/adapters/asyncpg/__init__.py +19 -4
- sqlspec/adapters/asyncpg/_types.py +17 -0
- sqlspec/adapters/asyncpg/config.py +82 -181
- sqlspec/adapters/asyncpg/driver.py +285 -383
- sqlspec/adapters/bigquery/__init__.py +17 -3
- sqlspec/adapters/bigquery/_types.py +12 -0
- sqlspec/adapters/bigquery/config.py +191 -258
- sqlspec/adapters/bigquery/driver.py +474 -646
- sqlspec/adapters/duckdb/__init__.py +14 -3
- sqlspec/adapters/duckdb/_types.py +12 -0
- sqlspec/adapters/duckdb/config.py +415 -351
- sqlspec/adapters/duckdb/driver.py +343 -413
- sqlspec/adapters/oracledb/__init__.py +19 -5
- sqlspec/adapters/oracledb/_types.py +14 -0
- sqlspec/adapters/oracledb/config.py +123 -379
- sqlspec/adapters/oracledb/driver.py +507 -560
- sqlspec/adapters/psqlpy/__init__.py +13 -3
- sqlspec/adapters/psqlpy/_types.py +11 -0
- sqlspec/adapters/psqlpy/config.py +93 -254
- sqlspec/adapters/psqlpy/driver.py +505 -234
- sqlspec/adapters/psycopg/__init__.py +19 -5
- sqlspec/adapters/psycopg/_types.py +17 -0
- sqlspec/adapters/psycopg/config.py +143 -403
- sqlspec/adapters/psycopg/driver.py +706 -872
- sqlspec/adapters/sqlite/__init__.py +14 -3
- sqlspec/adapters/sqlite/_types.py +11 -0
- sqlspec/adapters/sqlite/config.py +202 -118
- sqlspec/adapters/sqlite/driver.py +264 -303
- sqlspec/base.py +105 -9
- sqlspec/{statement/builder → builder}/__init__.py +12 -14
- sqlspec/{statement/builder → builder}/_base.py +120 -55
- sqlspec/{statement/builder → builder}/_column.py +17 -6
- sqlspec/{statement/builder → builder}/_ddl.py +46 -79
- sqlspec/{statement/builder → builder}/_ddl_utils.py +5 -10
- sqlspec/{statement/builder → builder}/_delete.py +6 -25
- sqlspec/{statement/builder → builder}/_insert.py +6 -64
- sqlspec/builder/_merge.py +56 -0
- sqlspec/{statement/builder → builder}/_parsing_utils.py +3 -10
- sqlspec/{statement/builder → builder}/_select.py +11 -56
- sqlspec/{statement/builder → builder}/_update.py +12 -18
- sqlspec/{statement/builder → builder}/mixins/__init__.py +10 -14
- sqlspec/{statement/builder → builder}/mixins/_cte_and_set_ops.py +48 -59
- sqlspec/{statement/builder → builder}/mixins/_insert_operations.py +22 -16
- sqlspec/{statement/builder → builder}/mixins/_join_operations.py +1 -3
- sqlspec/{statement/builder → builder}/mixins/_merge_operations.py +3 -5
- sqlspec/{statement/builder → builder}/mixins/_order_limit_operations.py +3 -3
- sqlspec/{statement/builder → builder}/mixins/_pivot_operations.py +4 -8
- sqlspec/{statement/builder → builder}/mixins/_select_operations.py +21 -36
- sqlspec/{statement/builder → builder}/mixins/_update_operations.py +3 -14
- sqlspec/{statement/builder → builder}/mixins/_where_clause.py +52 -79
- sqlspec/cli.py +4 -5
- sqlspec/config.py +180 -133
- sqlspec/core/__init__.py +63 -0
- sqlspec/core/cache.py +873 -0
- sqlspec/core/compiler.py +396 -0
- sqlspec/core/filters.py +828 -0
- sqlspec/core/hashing.py +310 -0
- sqlspec/core/parameters.py +1209 -0
- sqlspec/core/result.py +664 -0
- sqlspec/{statement → core}/splitter.py +321 -191
- sqlspec/core/statement.py +651 -0
- sqlspec/driver/__init__.py +7 -10
- sqlspec/driver/_async.py +387 -176
- sqlspec/driver/_common.py +527 -289
- sqlspec/driver/_sync.py +390 -172
- sqlspec/driver/mixins/__init__.py +2 -19
- sqlspec/driver/mixins/_result_tools.py +168 -0
- sqlspec/driver/mixins/_sql_translator.py +6 -3
- sqlspec/exceptions.py +5 -252
- sqlspec/extensions/aiosql/adapter.py +93 -96
- sqlspec/extensions/litestar/config.py +0 -1
- sqlspec/extensions/litestar/handlers.py +15 -26
- sqlspec/extensions/litestar/plugin.py +16 -14
- sqlspec/extensions/litestar/providers.py +17 -52
- sqlspec/loader.py +424 -105
- sqlspec/migrations/__init__.py +12 -0
- sqlspec/migrations/base.py +92 -68
- sqlspec/migrations/commands.py +24 -106
- sqlspec/migrations/loaders.py +402 -0
- sqlspec/migrations/runner.py +49 -51
- sqlspec/migrations/tracker.py +31 -44
- sqlspec/migrations/utils.py +64 -24
- sqlspec/protocols.py +7 -183
- sqlspec/storage/__init__.py +1 -1
- sqlspec/storage/backends/base.py +37 -40
- sqlspec/storage/backends/fsspec.py +136 -112
- sqlspec/storage/backends/obstore.py +138 -160
- sqlspec/storage/capabilities.py +5 -4
- sqlspec/storage/registry.py +57 -106
- sqlspec/typing.py +136 -115
- sqlspec/utils/__init__.py +2 -3
- sqlspec/utils/correlation.py +0 -3
- sqlspec/utils/deprecation.py +6 -6
- sqlspec/utils/fixtures.py +6 -6
- sqlspec/utils/logging.py +0 -2
- sqlspec/utils/module_loader.py +7 -12
- sqlspec/utils/singleton.py +0 -1
- sqlspec/utils/sync_tools.py +16 -37
- sqlspec/utils/text.py +12 -51
- sqlspec/utils/type_guards.py +443 -232
- {sqlspec-0.14.0.dist-info → sqlspec-0.15.0.dist-info}/METADATA +7 -2
- sqlspec-0.15.0.dist-info/RECORD +134 -0
- sqlspec-0.15.0.dist-info/entry_points.txt +2 -0
- sqlspec/driver/connection.py +0 -207
- sqlspec/driver/mixins/_cache.py +0 -114
- sqlspec/driver/mixins/_csv_writer.py +0 -91
- sqlspec/driver/mixins/_pipeline.py +0 -508
- sqlspec/driver/mixins/_query_tools.py +0 -796
- sqlspec/driver/mixins/_result_utils.py +0 -138
- sqlspec/driver/mixins/_storage.py +0 -912
- sqlspec/driver/mixins/_type_coercion.py +0 -128
- sqlspec/driver/parameters.py +0 -138
- sqlspec/statement/__init__.py +0 -21
- sqlspec/statement/builder/_merge.py +0 -95
- sqlspec/statement/cache.py +0 -50
- sqlspec/statement/filters.py +0 -625
- sqlspec/statement/parameters.py +0 -996
- sqlspec/statement/pipelines/__init__.py +0 -210
- sqlspec/statement/pipelines/analyzers/__init__.py +0 -9
- sqlspec/statement/pipelines/analyzers/_analyzer.py +0 -646
- sqlspec/statement/pipelines/context.py +0 -115
- sqlspec/statement/pipelines/transformers/__init__.py +0 -7
- sqlspec/statement/pipelines/transformers/_expression_simplifier.py +0 -88
- sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +0 -1247
- sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +0 -76
- sqlspec/statement/pipelines/validators/__init__.py +0 -23
- sqlspec/statement/pipelines/validators/_dml_safety.py +0 -290
- sqlspec/statement/pipelines/validators/_parameter_style.py +0 -370
- sqlspec/statement/pipelines/validators/_performance.py +0 -714
- sqlspec/statement/pipelines/validators/_security.py +0 -967
- sqlspec/statement/result.py +0 -435
- sqlspec/statement/sql.py +0 -1774
- sqlspec/utils/cached_property.py +0 -25
- sqlspec/utils/statement_hashing.py +0 -203
- sqlspec-0.14.0.dist-info/RECORD +0 -143
- sqlspec-0.14.0.dist-info/entry_points.txt +0 -2
- sqlspec/{statement/builder → builder}/mixins/_delete_operations.py +0 -0
- {sqlspec-0.14.0.dist-info → sqlspec-0.15.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.14.0.dist-info → sqlspec-0.15.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.14.0.dist-info → sqlspec-0.15.0.dist-info}/licenses/NOTICE +0 -0
|
@@ -1,76 +0,0 @@
|
|
|
1
|
-
"""Removes SQL comments and hints from expressions."""
|
|
2
|
-
|
|
3
|
-
from typing import TYPE_CHECKING, Optional
|
|
4
|
-
|
|
5
|
-
from sqlglot import exp
|
|
6
|
-
|
|
7
|
-
from sqlspec.protocols import ProcessorProtocol
|
|
8
|
-
|
|
9
|
-
if TYPE_CHECKING:
|
|
10
|
-
from sqlspec.statement.pipelines.context import SQLProcessingContext
|
|
11
|
-
|
|
12
|
-
__all__ = ("CommentAndHintRemover",)
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
class CommentAndHintRemover(ProcessorProtocol):
    """Strip SQL comments and/or optimizer hints from a SQLGlot expression tree.

    The processor walks a copy of the expression, so the caller's tree is
    never mutated. Removal counts are published on ``context.metadata`` under
    the ``"comments_removed"`` and ``"hints_removed"`` keys.

    Args:
        enabled: When False, ``process`` returns its input untouched.
        remove_comments: Drop ordinary (non-hint) comments attached to nodes.
        remove_hints: Drop ``exp.Hint`` nodes and comments that look like hints.
    """

    # Keywords whose presence marks a comment as an optimizer hint.
    _HINT_KEYWORDS = ("INDEX", "USE_NL", "USE_HASH", "PARALLEL", "FULL", "FIRST_ROWS", "ALL_ROWS")

    def __init__(self, enabled: bool = True, remove_comments: bool = True, remove_hints: bool = False) -> None:
        self.enabled = enabled
        self.remove_comments = remove_comments
        self.remove_hints = remove_hints

    def process(
        self, expression: "Optional[exp.Expression]", context: "SQLProcessingContext"
    ) -> "Optional[exp.Expression]":
        """Return a cleaned copy of ``expression`` and record removal counts.

        Args:
            expression: Parsed statement, or None when nothing was parsed.
            context: Processing context whose ``metadata`` mapping receives
                the ``"comments_removed"`` and ``"hints_removed"`` counters.

        Returns:
            The input unchanged when the processor is disabled or the
            expression is None; otherwise the transformed copy.
        """
        if not self.enabled or expression is None:
            return expression

        comments_removed_count = 0
        hints_removed_count = 0

        def _remove_comments_and_hints(node: exp.Expression) -> "Optional[exp.Expression]":
            nonlocal comments_removed_count, hints_removed_count

            # Returning None from a transform callback drops the node.
            if self.remove_hints and isinstance(node, exp.Hint):
                hints_removed_count += 1
                return None

            if hasattr(node, "comments") and node.comments:
                comments_to_keep = []
                # Classify each comment once; count it at the moment it is dropped.
                for comment in node.comments:
                    if self._is_hint(str(comment).strip()):
                        if self.remove_hints:
                            hints_removed_count += 1
                        else:
                            comments_to_keep.append(comment)
                    elif self.remove_comments:
                        comments_removed_count += 1
                    else:
                        comments_to_keep.append(comment)

                # Only touch the node when something was actually removed.
                if len(comments_to_keep) != len(node.comments):
                    node.pop_comments()
                    if comments_to_keep:
                        node.add_comments(comments_to_keep)

            return node

        cleaned_expression = expression.transform(_remove_comments_and_hints, copy=True)

        context.metadata["comments_removed"] = comments_removed_count
        context.metadata["hints_removed"] = hints_removed_count

        return cleaned_expression

    def _is_hint(self, comment_text: str) -> bool:
        """Heuristically decide whether a comment body is an optimizer hint.

        A comment counts as a hint when it starts with ``!`` or contains one
        of the hint keywords as a standalone token. Keywords are matched
        case-insensitively and must not be glued to other ASCII letters, so
        prose such as "fully" or "reindex" is not mistaken for ``FULL`` or
        ``INDEX``, while compound hints such as ``NO_INDEX`` or
        ``INDEX(t idx)`` are still recognised.

        Args:
            comment_text: Comment body, already stripped of whitespace.

        Returns:
            True when the comment looks like a hint.
        """
        import re  # local import: keeps this block self-contained

        # NOTE(review): a leading "!" presumably targets MySQL-style
        # executable comments - confirm. The previous trailing check,
        # ``endswith("")``, was always true and has been dropped.
        if comment_text.startswith("!"):
            return True

        keyword_alternatives = "|".join(re.escape(keyword) for keyword in self._HINT_KEYWORDS)
        pattern = rf"(?<![A-Z])(?:{keyword_alternatives})(?![A-Z])"
        return re.search(pattern, comment_text.upper()) is not None
|
|
@@ -1,23 +0,0 @@
|
|
|
1
|
-
"""SQL Validation Pipeline Components."""
|
|
2
|
-
|
|
3
|
-
from sqlspec.statement.pipelines.validators._dml_safety import DMLSafetyConfig, DMLSafetyValidator
|
|
4
|
-
from sqlspec.statement.pipelines.validators._parameter_style import ParameterStyleValidator
|
|
5
|
-
from sqlspec.statement.pipelines.validators._performance import PerformanceConfig, PerformanceValidator
|
|
6
|
-
from sqlspec.statement.pipelines.validators._security import (
|
|
7
|
-
SecurityIssue,
|
|
8
|
-
SecurityIssueType,
|
|
9
|
-
SecurityValidator,
|
|
10
|
-
SecurityValidatorConfig,
|
|
11
|
-
)
|
|
12
|
-
|
|
13
|
-
__all__ = (
|
|
14
|
-
"DMLSafetyConfig",
|
|
15
|
-
"DMLSafetyValidator",
|
|
16
|
-
"ParameterStyleValidator",
|
|
17
|
-
"PerformanceConfig",
|
|
18
|
-
"PerformanceValidator",
|
|
19
|
-
"SecurityIssue",
|
|
20
|
-
"SecurityIssueType",
|
|
21
|
-
"SecurityValidator",
|
|
22
|
-
"SecurityValidatorConfig",
|
|
23
|
-
)
|
|
@@ -1,290 +0,0 @@
|
|
|
1
|
-
# DML Safety Validator - Consolidates risky DML operations and DDL prevention
|
|
2
|
-
from dataclasses import dataclass, field
|
|
3
|
-
from enum import Enum
|
|
4
|
-
from typing import TYPE_CHECKING, Optional
|
|
5
|
-
|
|
6
|
-
from sqlglot import expressions as exp
|
|
7
|
-
|
|
8
|
-
from sqlspec.exceptions import RiskLevel
|
|
9
|
-
from sqlspec.protocols import ProcessorProtocol
|
|
10
|
-
from sqlspec.statement.pipelines.context import ValidationError
|
|
11
|
-
|
|
12
|
-
if TYPE_CHECKING:
|
|
13
|
-
from sqlspec.statement.pipelines.context import SQLProcessingContext
|
|
14
|
-
|
|
15
|
-
__all__ = ("DMLSafetyConfig", "DMLSafetyValidator", "StatementCategory")
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
class StatementCategory(Enum):
    """High-level families a SQL statement can be classified into."""

    # Data definition: CREATE, ALTER, DROP, TRUNCATE
    DDL = "ddl"
    # Data manipulation: INSERT, UPDATE, DELETE, MERGE
    DML = "dml"
    # Data query: SELECT
    DQL = "dql"
    # Data control: GRANT, REVOKE
    DCL = "dcl"
    # Transaction control: COMMIT, ROLLBACK, SAVEPOINT
    TCL = "tcl"
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
@dataclass
class DMLSafetyConfig:
    """Settings that control the DML safety validator.

    Field order is part of the positional constructor signature and must be
    kept as is.
    """

    # Whether DDL statements are rejected.
    prevent_ddl: bool = True
    # Whether DCL statements are rejected.
    prevent_dcl: bool = True
    # Operation names that must carry a WHERE clause.
    require_where_clause: "set[str]" = field(default_factory=lambda: {"DELETE", "UPDATE"})
    # DDL operation names exempt from the DDL ban.
    allowed_ddl_operations: "set[str]" = field(default_factory=set)
    # Allow DDL in migration contexts
    migration_mode: bool = False
    # Limit for DML operations
    max_affected_rows: "Optional[int]" = None
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
class DMLSafetyValidator(ProcessorProtocol):
    """Unified validator for DML/DDL safety checks.

    This validator consolidates:
    - DDL prevention (CREATE, ALTER, DROP, etc.)
    - Risky DML detection (DELETE/UPDATE without WHERE)
    - DCL restrictions (GRANT, REVOKE)
    - Row limit enforcement

    Findings are appended to ``context.validation_errors``; the expression
    itself is never modified.
    """

    def __init__(self, config: "Optional[DMLSafetyConfig]" = None) -> None:
        """Initialize the DML safety validator.

        Args:
            config: Configuration for safety validation; a default
                ``DMLSafetyConfig`` is used when omitted.
        """
        self.config = config or DMLSafetyConfig()

    def process(
        self, expression: "Optional[exp.Expression]", context: "SQLProcessingContext"
    ) -> "Optional[exp.Expression]":
        """Validate ``expression`` and hand it back unchanged (ProcessorProtocol entry point).

        Args:
            expression: Parsed statement, or None when nothing was parsed.
            context: Processing context that collects errors and metadata.

        Returns:
            None when no expression was given, otherwise the same expression.
        """
        if expression is None:
            return None
        self.validate(expression, context)
        return expression

    def add_error(
        self,
        context: "SQLProcessingContext",
        message: str,
        code: str,
        risk_level: RiskLevel,
        expression: "Optional[exp.Expression]" = None,
    ) -> None:
        """Append a ``ValidationError`` attributed to this processor to the context.

        Args:
            context: Processing context whose ``validation_errors`` list is extended.
            message: Human-readable description of the problem.
            code: Short machine-readable error code.
            risk_level: Severity assigned to the finding.
            expression: Expression that triggered the finding, if any.
        """
        error = ValidationError(
            message=message, code=code, risk_level=risk_level, processor=self.__class__.__name__, expression=expression
        )
        context.validation_errors.append(error)

    def validate(self, expression: "exp.Expression", context: "SQLProcessingContext") -> None:
        """Validate SQL statement for safety issues.

        Args:
            expression: The SQL expression to validate
            context: The SQL processing context
        """
        category = self._categorize_statement(expression)
        operation = self._get_operation_type(expression)
        # Computed once: used both for the WHERE check and the metadata below.
        has_where = self._has_where_clause(expression) if category == StatementCategory.DML else None

        if category == StatementCategory.DDL and self.config.prevent_ddl:
            # migration_mode is the configured escape hatch for DDL, alongside
            # the per-operation allow-list.
            if not self.config.migration_mode and operation not in self.config.allowed_ddl_operations:
                self.add_error(
                    context,
                    message=f"DDL operation '{operation}' is not allowed",
                    code="ddl-not-allowed",
                    risk_level=RiskLevel.CRITICAL,
                    expression=expression,
                )

        elif category == StatementCategory.DML:
            if operation in self.config.require_where_clause and not has_where:
                self.add_error(
                    context,
                    message=f"{operation} without WHERE clause affects all rows",
                    code=f"{operation.lower()}-without-where",
                    risk_level=RiskLevel.HIGH,
                    expression=expression,
                )

            # None means "no limit"; an explicit 0 is a real limit and must
            # not be skipped by a truthiness test.
            row_limit = self.config.max_affected_rows
            if row_limit is not None:
                estimated_rows = self._estimate_affected_rows(expression)
                if estimated_rows > row_limit:
                    self.add_error(
                        context,
                        message=f"Operation may affect {estimated_rows:,} rows (limit: {row_limit:,})",
                        code="excessive-rows-affected",
                        risk_level=RiskLevel.MEDIUM,
                        expression=expression,
                    )

        elif category == StatementCategory.DCL and self.config.prevent_dcl:
            self.add_error(
                context,
                message=f"DCL operation '{operation}' is not allowed",
                code="dcl-not-allowed",
                risk_level=RiskLevel.HIGH,
                expression=expression,
            )

        # Store metadata in context, keyed by this processor's class name.
        context.metadata[self.__class__.__name__] = {
            "statement_category": category.value,
            "operation": operation,
            "has_where_clause": has_where,
            "affected_tables": self._extract_affected_tables(expression),
            "migration_mode": self.config.migration_mode,
        }

    @staticmethod
    def _categorize_statement(expression: "exp.Expression") -> StatementCategory:
        """Categorize SQL statement type.

        Args:
            expression: The SQL expression to categorize

        Returns:
            The statement category; anything unrecognised is treated as DQL.
        """
        if isinstance(expression, (exp.Create, exp.Alter, exp.Drop, exp.TruncateTable, exp.Comment)):
            return StatementCategory.DDL

        if isinstance(expression, (exp.Select, exp.Union, exp.Intersect, exp.Except)):
            return StatementCategory.DQL

        if isinstance(expression, (exp.Insert, exp.Update, exp.Delete, exp.Merge)):
            return StatementCategory.DML

        if isinstance(expression, (exp.Grant,)):
            return StatementCategory.DCL

        if isinstance(expression, (exp.Commit, exp.Rollback)):
            return StatementCategory.TCL

        return StatementCategory.DQL  # Default to query

    @staticmethod
    def _get_operation_type(expression: "exp.Expression") -> str:
        """Get specific operation name.

        Args:
            expression: The SQL expression

        Returns:
            The upper-cased class name of the expression (e.g. ``"DELETE"``).
        """
        return expression.__class__.__name__.upper()

    @staticmethod
    def _has_where_clause(expression: "exp.Expression") -> bool:
        """Check if DML statement has WHERE clause.

        Args:
            expression: The SQL expression to check

        Returns:
            For DELETE/UPDATE, whether a ``where`` argument is present; True
            for every other statement type, since none is required there.
        """
        if isinstance(expression, (exp.Delete, exp.Update)):
            return expression.args.get("where") is not None
        return True  # Other statements don't require WHERE

    def _estimate_affected_rows(self, expression: "exp.Expression") -> int:
        """Estimate number of rows affected by DML operation.

        This is a fixed-bucket heuristic, not a statistics-based estimate.

        Args:
            expression: The SQL expression

        Returns:
            999999999 for DELETE/UPDATE without WHERE, 1 when the WHERE clause
            has a unique-looking condition, 100 when it has an indexed-looking
            condition, otherwise 10000.
        """
        if not self._has_where_clause(expression):
            return 999999999  # Large number to indicate all rows

        where = expression.args.get("where")
        if where:
            if self._has_unique_condition(where):
                return 1
            if self._has_indexed_condition(where):
                return 100  # Rough estimate

        return 10000  # Conservative estimate

    @staticmethod
    def _has_unique_condition(where: "Optional[exp.Expression]") -> bool:
        """Check if WHERE clause uses unique columns.

        Args:
            where: The WHERE expression

        Returns:
            True if an equality whose left side is a column named like a
            primary key (id, uuid, guid, pk, primary_key) is found.
        """
        if where is None:
            return False
        # Look for id = value patterns
        for condition in where.find_all(exp.EQ):
            if isinstance(condition.left, exp.Column):
                col_name = condition.left.name.lower()
                if col_name in {"id", "uuid", "guid", "pk", "primary_key"}:
                    return True
        return False

    @staticmethod
    def _has_indexed_condition(where: "Optional[exp.Expression]") -> bool:
        """Check if WHERE clause uses indexed columns.

        Args:
            where: The WHERE expression

        Returns:
            True if a comparison whose left side is a column with a commonly
            indexed name is found. The name list is a guess, not schema data.
        """
        if where is None:
            return False
        # Look for common indexed column patterns
        for condition in where.find_all(exp.Predicate):
            if isinstance(condition, (exp.EQ, exp.GT, exp.GTE, exp.LT, exp.LTE, exp.NEQ)) and isinstance(
                condition.left, exp.Column
            ):
                col_name = condition.left.name.lower()
                # Common indexed columns
                if col_name in {"created_at", "updated_at", "email", "username", "status", "type"}:
                    return True
        return False

    @staticmethod
    def _extract_affected_tables(expression: "exp.Expression") -> "list[str]":
        """Extract table names affected by the statement.

        Only INSERT/UPDATE/DELETE and CREATE/DROP/ALTER are inspected; other
        statement types yield an empty list.

        Args:
            expression: The SQL expression

        Returns:
            List of affected table names
        """
        tables = []

        # For DML statements
        if isinstance(expression, (exp.Insert, exp.Update, exp.Delete)):
            if expression.this:
                table_expr = expression.this
                if isinstance(table_expr, exp.Table):
                    tables.append(table_expr.name)

        # For DDL statements
        elif isinstance(expression, (exp.Create, exp.Drop, exp.Alter)) and expression.this:
            # For CREATE TABLE, the table is in expression.this.this
            if isinstance(expression, exp.Create) and isinstance(expression.this, exp.Schema):
                if expression.this.this:
                    table_expr = expression.this.this
                    if isinstance(table_expr, exp.Table):
                        tables.append(table_expr.name)
            # For DROP/ALTER, table is directly in expression.this
            elif isinstance(expression.this, (exp.Table, exp.Identifier)):
                tables.append(expression.this.name)

        return tables
|