kontra 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kontra/__init__.py +1871 -0
- kontra/api/__init__.py +22 -0
- kontra/api/compare.py +340 -0
- kontra/api/decorators.py +153 -0
- kontra/api/results.py +2121 -0
- kontra/api/rules.py +681 -0
- kontra/cli/__init__.py +0 -0
- kontra/cli/commands/__init__.py +1 -0
- kontra/cli/commands/config.py +153 -0
- kontra/cli/commands/diff.py +450 -0
- kontra/cli/commands/history.py +196 -0
- kontra/cli/commands/profile.py +289 -0
- kontra/cli/commands/validate.py +468 -0
- kontra/cli/constants.py +6 -0
- kontra/cli/main.py +48 -0
- kontra/cli/renderers.py +304 -0
- kontra/cli/utils.py +28 -0
- kontra/config/__init__.py +34 -0
- kontra/config/loader.py +127 -0
- kontra/config/models.py +49 -0
- kontra/config/settings.py +797 -0
- kontra/connectors/__init__.py +0 -0
- kontra/connectors/db_utils.py +251 -0
- kontra/connectors/detection.py +323 -0
- kontra/connectors/handle.py +368 -0
- kontra/connectors/postgres.py +127 -0
- kontra/connectors/sqlserver.py +226 -0
- kontra/engine/__init__.py +0 -0
- kontra/engine/backends/duckdb_session.py +227 -0
- kontra/engine/backends/duckdb_utils.py +18 -0
- kontra/engine/backends/polars_backend.py +47 -0
- kontra/engine/engine.py +1205 -0
- kontra/engine/executors/__init__.py +15 -0
- kontra/engine/executors/base.py +50 -0
- kontra/engine/executors/database_base.py +528 -0
- kontra/engine/executors/duckdb_sql.py +607 -0
- kontra/engine/executors/postgres_sql.py +162 -0
- kontra/engine/executors/registry.py +69 -0
- kontra/engine/executors/sqlserver_sql.py +163 -0
- kontra/engine/materializers/__init__.py +14 -0
- kontra/engine/materializers/base.py +42 -0
- kontra/engine/materializers/duckdb.py +110 -0
- kontra/engine/materializers/factory.py +22 -0
- kontra/engine/materializers/polars_connector.py +131 -0
- kontra/engine/materializers/postgres.py +157 -0
- kontra/engine/materializers/registry.py +138 -0
- kontra/engine/materializers/sqlserver.py +160 -0
- kontra/engine/result.py +15 -0
- kontra/engine/sql_utils.py +611 -0
- kontra/engine/sql_validator.py +609 -0
- kontra/engine/stats.py +194 -0
- kontra/engine/types.py +138 -0
- kontra/errors.py +533 -0
- kontra/logging.py +85 -0
- kontra/preplan/__init__.py +5 -0
- kontra/preplan/planner.py +253 -0
- kontra/preplan/postgres.py +179 -0
- kontra/preplan/sqlserver.py +191 -0
- kontra/preplan/types.py +24 -0
- kontra/probes/__init__.py +20 -0
- kontra/probes/compare.py +400 -0
- kontra/probes/relationship.py +283 -0
- kontra/reporters/__init__.py +0 -0
- kontra/reporters/json_reporter.py +190 -0
- kontra/reporters/rich_reporter.py +11 -0
- kontra/rules/__init__.py +35 -0
- kontra/rules/base.py +186 -0
- kontra/rules/builtin/__init__.py +40 -0
- kontra/rules/builtin/allowed_values.py +156 -0
- kontra/rules/builtin/compare.py +188 -0
- kontra/rules/builtin/conditional_not_null.py +213 -0
- kontra/rules/builtin/conditional_range.py +310 -0
- kontra/rules/builtin/contains.py +138 -0
- kontra/rules/builtin/custom_sql_check.py +182 -0
- kontra/rules/builtin/disallowed_values.py +140 -0
- kontra/rules/builtin/dtype.py +203 -0
- kontra/rules/builtin/ends_with.py +129 -0
- kontra/rules/builtin/freshness.py +240 -0
- kontra/rules/builtin/length.py +193 -0
- kontra/rules/builtin/max_rows.py +35 -0
- kontra/rules/builtin/min_rows.py +46 -0
- kontra/rules/builtin/not_null.py +121 -0
- kontra/rules/builtin/range.py +222 -0
- kontra/rules/builtin/regex.py +143 -0
- kontra/rules/builtin/starts_with.py +129 -0
- kontra/rules/builtin/unique.py +124 -0
- kontra/rules/condition_parser.py +203 -0
- kontra/rules/execution_plan.py +455 -0
- kontra/rules/factory.py +103 -0
- kontra/rules/predicates.py +25 -0
- kontra/rules/registry.py +24 -0
- kontra/rules/static_predicates.py +120 -0
- kontra/scout/__init__.py +9 -0
- kontra/scout/backends/__init__.py +17 -0
- kontra/scout/backends/base.py +111 -0
- kontra/scout/backends/duckdb_backend.py +359 -0
- kontra/scout/backends/postgres_backend.py +519 -0
- kontra/scout/backends/sqlserver_backend.py +577 -0
- kontra/scout/dtype_mapping.py +150 -0
- kontra/scout/patterns.py +69 -0
- kontra/scout/profiler.py +801 -0
- kontra/scout/reporters/__init__.py +39 -0
- kontra/scout/reporters/json_reporter.py +165 -0
- kontra/scout/reporters/markdown_reporter.py +152 -0
- kontra/scout/reporters/rich_reporter.py +144 -0
- kontra/scout/store.py +208 -0
- kontra/scout/suggest.py +200 -0
- kontra/scout/types.py +652 -0
- kontra/state/__init__.py +29 -0
- kontra/state/backends/__init__.py +79 -0
- kontra/state/backends/base.py +348 -0
- kontra/state/backends/local.py +480 -0
- kontra/state/backends/postgres.py +1010 -0
- kontra/state/backends/s3.py +543 -0
- kontra/state/backends/sqlserver.py +969 -0
- kontra/state/fingerprint.py +166 -0
- kontra/state/types.py +1061 -0
- kontra/version.py +1 -0
- kontra-0.5.2.dist-info/METADATA +122 -0
- kontra-0.5.2.dist-info/RECORD +124 -0
- kontra-0.5.2.dist-info/WHEEL +5 -0
- kontra-0.5.2.dist-info/entry_points.txt +2 -0
- kontra-0.5.2.dist-info/licenses/LICENSE +17 -0
- kontra-0.5.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,609 @@
|
|
|
1
|
+
# src/kontra/engine/sql_validator.py
|
|
2
|
+
"""
|
|
3
|
+
SQL validation using sqlglot for safe remote execution.
|
|
4
|
+
|
|
5
|
+
Ensures user-provided SQL is read-only before executing on production databases.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from dataclasses import dataclass
|
|
11
|
+
from typing import List, Optional, Set, Tuple
|
|
12
|
+
|
|
13
|
+
import sqlglot
|
|
14
|
+
from sqlglot import exp
|
|
15
|
+
from sqlglot.errors import ParseError
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
# Statement types that are NOT allowed (write operations and external access)
|
|
19
|
+
FORBIDDEN_STATEMENT_TYPES: Set[type] = {
|
|
20
|
+
exp.Insert,
|
|
21
|
+
exp.Update,
|
|
22
|
+
exp.Delete,
|
|
23
|
+
exp.Drop,
|
|
24
|
+
exp.Create,
|
|
25
|
+
exp.Alter,
|
|
26
|
+
exp.Merge,
|
|
27
|
+
exp.Grant,
|
|
28
|
+
exp.Revoke,
|
|
29
|
+
exp.Command, # Generic command execution
|
|
30
|
+
exp.Copy, # COPY command (file I/O)
|
|
31
|
+
exp.Set, # SET commands (configuration changes)
|
|
32
|
+
exp.Use, # USE database (context switching)
|
|
33
|
+
exp.Attach, # ATTACH external databases (SEC-002)
|
|
34
|
+
}
|
|
35
|
+
|
|
36
|
+
# Table prefixes/schemas that are forbidden (system catalogs)
|
|
37
|
+
# These allow information disclosure attacks
|
|
38
|
+
FORBIDDEN_TABLE_PREFIXES: Set[str] = {
|
|
39
|
+
# PostgreSQL system catalogs
|
|
40
|
+
"pg_",
|
|
41
|
+
# SQL Server system views
|
|
42
|
+
"sys.",
|
|
43
|
+
# Standard information schema (both PostgreSQL and SQL Server)
|
|
44
|
+
"information_schema.",
|
|
45
|
+
}
|
|
46
|
+
|
|
47
|
+
# Specific system tables to block (without prefix, for tables accessed directly)
|
|
48
|
+
FORBIDDEN_TABLES: Set[str] = {
|
|
49
|
+
# PostgreSQL sensitive tables
|
|
50
|
+
"pg_shadow",
|
|
51
|
+
"pg_authid",
|
|
52
|
+
"pg_roles",
|
|
53
|
+
"pg_user",
|
|
54
|
+
"pg_database",
|
|
55
|
+
"pg_tablespace",
|
|
56
|
+
"pg_settings",
|
|
57
|
+
"pg_stat_activity",
|
|
58
|
+
"pg_stat_user_tables",
|
|
59
|
+
# SQL Server sensitive tables
|
|
60
|
+
"syslogins",
|
|
61
|
+
"sysobjects",
|
|
62
|
+
"syscolumns",
|
|
63
|
+
"sysusers",
|
|
64
|
+
"sysdatabases",
|
|
65
|
+
}
|
|
66
|
+
|
|
67
|
+
# Function names that could have side effects (case-insensitive)
|
|
68
|
+
FORBIDDEN_FUNCTIONS: Set[str] = {
|
|
69
|
+
# PostgreSQL
|
|
70
|
+
"pg_sleep",
|
|
71
|
+
"pg_terminate_backend",
|
|
72
|
+
"pg_cancel_backend",
|
|
73
|
+
"pg_reload_conf",
|
|
74
|
+
"set_config",
|
|
75
|
+
"dblink",
|
|
76
|
+
"dblink_exec",
|
|
77
|
+
"lo_import",
|
|
78
|
+
"lo_export",
|
|
79
|
+
"pg_file_write",
|
|
80
|
+
"pg_read_file",
|
|
81
|
+
"pg_ls_dir",
|
|
82
|
+
# SQL Server
|
|
83
|
+
"xp_cmdshell",
|
|
84
|
+
"xp_regread",
|
|
85
|
+
"xp_regwrite",
|
|
86
|
+
"sp_executesql",
|
|
87
|
+
"sp_oacreate",
|
|
88
|
+
"openrowset",
|
|
89
|
+
"opendatasource",
|
|
90
|
+
"bulk",
|
|
91
|
+
# Generic dangerous
|
|
92
|
+
"exec",
|
|
93
|
+
"execute",
|
|
94
|
+
"call",
|
|
95
|
+
"sleep",
|
|
96
|
+
# DuckDB file access (SEC-001: arbitrary file read)
|
|
97
|
+
"read_csv",
|
|
98
|
+
"read_csv_auto",
|
|
99
|
+
"read_parquet",
|
|
100
|
+
"read_json",
|
|
101
|
+
"read_json_auto",
|
|
102
|
+
"read_json_objects",
|
|
103
|
+
"read_blob",
|
|
104
|
+
"read_text",
|
|
105
|
+
"read_ndjson",
|
|
106
|
+
"read_ndjson_auto",
|
|
107
|
+
"read_ndjson_objects",
|
|
108
|
+
# DuckDB file listing/globbing
|
|
109
|
+
"glob",
|
|
110
|
+
"list_files",
|
|
111
|
+
# DuckDB external access
|
|
112
|
+
"httpfs_get",
|
|
113
|
+
"http_get",
|
|
114
|
+
"s3_get",
|
|
115
|
+
# DuckDB query functions that could bypass table reference
|
|
116
|
+
"query",
|
|
117
|
+
"query_table",
|
|
118
|
+
}
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
@dataclass
|
|
122
|
+
class ValidationResult:
|
|
123
|
+
"""Result of SQL validation."""
|
|
124
|
+
|
|
125
|
+
is_safe: bool
|
|
126
|
+
reason: Optional[str] = None
|
|
127
|
+
parsed_sql: Optional[str] = None # Normalized SQL if parsing succeeded
|
|
128
|
+
dialect: Optional[str] = None
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def validate_sql(
|
|
132
|
+
sql: str,
|
|
133
|
+
dialect: str = "postgres",
|
|
134
|
+
allow_cte: bool = True,
|
|
135
|
+
allow_subqueries: bool = True,
|
|
136
|
+
) -> ValidationResult:
|
|
137
|
+
"""
|
|
138
|
+
Validate that SQL is safe for remote execution.
|
|
139
|
+
|
|
140
|
+
A SQL statement is considered safe if:
|
|
141
|
+
1. It parses successfully
|
|
142
|
+
2. It's a SELECT statement (not INSERT, UPDATE, DELETE, etc.)
|
|
143
|
+
3. It doesn't contain forbidden functions
|
|
144
|
+
4. It doesn't contain multiple statements (no SQL injection via ;)
|
|
145
|
+
|
|
146
|
+
Args:
|
|
147
|
+
sql: The SQL statement to validate
|
|
148
|
+
dialect: SQL dialect for parsing ("postgres", "tsql", "duckdb")
|
|
149
|
+
allow_cte: Allow WITH clauses (CTEs)
|
|
150
|
+
allow_subqueries: Allow subqueries in WHERE/FROM
|
|
151
|
+
|
|
152
|
+
Returns:
|
|
153
|
+
ValidationResult with is_safe=True if SQL is safe, False otherwise
|
|
154
|
+
"""
|
|
155
|
+
sql = sql.strip()
|
|
156
|
+
|
|
157
|
+
if not sql:
|
|
158
|
+
return ValidationResult(is_safe=False, reason="Empty SQL statement")
|
|
159
|
+
|
|
160
|
+
# Map dialect names
|
|
161
|
+
dialect_map = {
|
|
162
|
+
"postgres": "postgres",
|
|
163
|
+
"postgresql": "postgres",
|
|
164
|
+
"sqlserver": "tsql",
|
|
165
|
+
"mssql": "tsql",
|
|
166
|
+
"tsql": "tsql",
|
|
167
|
+
"duckdb": "duckdb",
|
|
168
|
+
}
|
|
169
|
+
sqlglot_dialect = dialect_map.get(dialect.lower(), "postgres")
|
|
170
|
+
|
|
171
|
+
try:
|
|
172
|
+
# Parse SQL - this will catch syntax errors
|
|
173
|
+
statements = sqlglot.parse(sql, dialect=sqlglot_dialect)
|
|
174
|
+
except ParseError as e:
|
|
175
|
+
return ValidationResult(
|
|
176
|
+
is_safe=False,
|
|
177
|
+
reason=f"SQL parse error: {e}",
|
|
178
|
+
dialect=sqlglot_dialect,
|
|
179
|
+
)
|
|
180
|
+
|
|
181
|
+
# Must be exactly one statement (no SQL injection via semicolons)
|
|
182
|
+
if len(statements) != 1:
|
|
183
|
+
return ValidationResult(
|
|
184
|
+
is_safe=False,
|
|
185
|
+
reason=f"Expected 1 statement, found {len(statements)}. Multiple statements not allowed.",
|
|
186
|
+
dialect=sqlglot_dialect,
|
|
187
|
+
)
|
|
188
|
+
|
|
189
|
+
stmt = statements[0]
|
|
190
|
+
|
|
191
|
+
if stmt is None:
|
|
192
|
+
return ValidationResult(
|
|
193
|
+
is_safe=False,
|
|
194
|
+
reason="Failed to parse SQL statement",
|
|
195
|
+
dialect=sqlglot_dialect,
|
|
196
|
+
)
|
|
197
|
+
|
|
198
|
+
# Check statement type - must be SELECT (or WITH for CTEs)
|
|
199
|
+
is_select = isinstance(stmt, exp.Select)
|
|
200
|
+
is_cte_select = isinstance(stmt, exp.With) and allow_cte
|
|
201
|
+
|
|
202
|
+
if not (is_select or is_cte_select):
|
|
203
|
+
stmt_type = type(stmt).__name__
|
|
204
|
+
return ValidationResult(
|
|
205
|
+
is_safe=False,
|
|
206
|
+
reason=f"Only SELECT statements allowed, found: {stmt_type}",
|
|
207
|
+
dialect=sqlglot_dialect,
|
|
208
|
+
)
|
|
209
|
+
|
|
210
|
+
# Check for forbidden statement types anywhere in the AST
|
|
211
|
+
for node in stmt.walk():
|
|
212
|
+
node_type = type(node)
|
|
213
|
+
if node_type in FORBIDDEN_STATEMENT_TYPES:
|
|
214
|
+
return ValidationResult(
|
|
215
|
+
is_safe=False,
|
|
216
|
+
reason=f"Forbidden operation: {node_type.__name__}",
|
|
217
|
+
dialect=sqlglot_dialect,
|
|
218
|
+
)
|
|
219
|
+
|
|
220
|
+
# Check for forbidden functions
|
|
221
|
+
forbidden_found = _check_forbidden_functions(stmt)
|
|
222
|
+
if forbidden_found:
|
|
223
|
+
return ValidationResult(
|
|
224
|
+
is_safe=False,
|
|
225
|
+
reason=f"Forbidden function: {forbidden_found}",
|
|
226
|
+
dialect=sqlglot_dialect,
|
|
227
|
+
)
|
|
228
|
+
|
|
229
|
+
# Check for system catalog access (information disclosure)
|
|
230
|
+
forbidden_table = _check_forbidden_tables(stmt)
|
|
231
|
+
if forbidden_table:
|
|
232
|
+
return ValidationResult(
|
|
233
|
+
is_safe=False,
|
|
234
|
+
reason=f"Access to system catalog not allowed: {forbidden_table}",
|
|
235
|
+
dialect=sqlglot_dialect,
|
|
236
|
+
)
|
|
237
|
+
|
|
238
|
+
# Check for subqueries if not allowed
|
|
239
|
+
if not allow_subqueries:
|
|
240
|
+
for node in stmt.walk():
|
|
241
|
+
if isinstance(node, exp.Subquery):
|
|
242
|
+
return ValidationResult(
|
|
243
|
+
is_safe=False,
|
|
244
|
+
reason="Subqueries not allowed",
|
|
245
|
+
dialect=sqlglot_dialect,
|
|
246
|
+
)
|
|
247
|
+
|
|
248
|
+
# SQL is safe - return normalized version
|
|
249
|
+
try:
|
|
250
|
+
normalized = stmt.sql(dialect=sqlglot_dialect)
|
|
251
|
+
except Exception:
|
|
252
|
+
normalized = sql # Fallback to original if normalization fails
|
|
253
|
+
|
|
254
|
+
return ValidationResult(
|
|
255
|
+
is_safe=True,
|
|
256
|
+
parsed_sql=normalized,
|
|
257
|
+
dialect=sqlglot_dialect,
|
|
258
|
+
)
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
def _check_forbidden_functions(stmt: exp.Expression) -> Optional[str]:
|
|
262
|
+
"""
|
|
263
|
+
Check for forbidden function calls in the AST.
|
|
264
|
+
|
|
265
|
+
Returns the name of the forbidden function if found, None otherwise.
|
|
266
|
+
"""
|
|
267
|
+
for node in stmt.walk():
|
|
268
|
+
if isinstance(node, exp.Func):
|
|
269
|
+
# Check function name via multiple methods
|
|
270
|
+
func_name = node.name.lower() if node.name else ""
|
|
271
|
+
if func_name in FORBIDDEN_FUNCTIONS:
|
|
272
|
+
return func_name
|
|
273
|
+
|
|
274
|
+
# Check sql_name() for functions like ReadCSV, ReadParquet
|
|
275
|
+
# that have specific class types
|
|
276
|
+
try:
|
|
277
|
+
sql_name = node.sql_name().lower() if hasattr(node, "sql_name") else ""
|
|
278
|
+
if sql_name in FORBIDDEN_FUNCTIONS:
|
|
279
|
+
return sql_name
|
|
280
|
+
except Exception:
|
|
281
|
+
pass
|
|
282
|
+
|
|
283
|
+
# Check class name directly for specific types
|
|
284
|
+
class_name = type(node).__name__.lower()
|
|
285
|
+
# Map class names to function names
|
|
286
|
+
class_to_func = {
|
|
287
|
+
"readcsv": "read_csv",
|
|
288
|
+
"readparquet": "read_parquet",
|
|
289
|
+
"readjson": "read_json",
|
|
290
|
+
}
|
|
291
|
+
if class_name in class_to_func:
|
|
292
|
+
mapped_name = class_to_func[class_name]
|
|
293
|
+
if mapped_name in FORBIDDEN_FUNCTIONS:
|
|
294
|
+
return mapped_name
|
|
295
|
+
|
|
296
|
+
# Also check for CALL statements disguised as functions
|
|
297
|
+
if isinstance(node, exp.Anonymous):
|
|
298
|
+
name = node.name.lower() if hasattr(node, "name") and node.name else ""
|
|
299
|
+
if name in FORBIDDEN_FUNCTIONS:
|
|
300
|
+
return name
|
|
301
|
+
|
|
302
|
+
return None
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
def _check_forbidden_tables(stmt: exp.Expression) -> Optional[str]:
|
|
306
|
+
"""
|
|
307
|
+
Check for access to forbidden system catalog tables.
|
|
308
|
+
|
|
309
|
+
Walks the AST looking for table references that match system catalog
|
|
310
|
+
patterns (pg_*, sys.*, information_schema.*).
|
|
311
|
+
|
|
312
|
+
Returns the forbidden table reference if found, None otherwise.
|
|
313
|
+
"""
|
|
314
|
+
for node in stmt.walk():
|
|
315
|
+
# Check Table nodes (direct table references)
|
|
316
|
+
if isinstance(node, exp.Table):
|
|
317
|
+
table_name = _get_full_table_name(node)
|
|
318
|
+
if table_name:
|
|
319
|
+
forbidden = _is_forbidden_table(table_name)
|
|
320
|
+
if forbidden:
|
|
321
|
+
return forbidden
|
|
322
|
+
|
|
323
|
+
return None
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
def _get_full_table_name(table_node: exp.Table) -> Optional[str]:
|
|
327
|
+
"""
|
|
328
|
+
Extract the full table name from a Table node, including schema if present.
|
|
329
|
+
|
|
330
|
+
Returns schema.table or just table name.
|
|
331
|
+
"""
|
|
332
|
+
parts = []
|
|
333
|
+
|
|
334
|
+
# Get catalog (database) if present
|
|
335
|
+
if table_node.catalog:
|
|
336
|
+
parts.append(str(table_node.catalog))
|
|
337
|
+
|
|
338
|
+
# Get schema if present
|
|
339
|
+
if table_node.db:
|
|
340
|
+
parts.append(str(table_node.db))
|
|
341
|
+
|
|
342
|
+
# Get table name
|
|
343
|
+
if table_node.name:
|
|
344
|
+
parts.append(str(table_node.name))
|
|
345
|
+
|
|
346
|
+
if parts:
|
|
347
|
+
return ".".join(parts)
|
|
348
|
+
return None
|
|
349
|
+
|
|
350
|
+
|
|
351
|
+
def _is_forbidden_table(table_ref: str) -> Optional[str]:
|
|
352
|
+
"""
|
|
353
|
+
Check if a table reference matches forbidden patterns.
|
|
354
|
+
|
|
355
|
+
Args:
|
|
356
|
+
table_ref: Full table reference (e.g., "pg_user", "sys.tables", "information_schema.columns")
|
|
357
|
+
|
|
358
|
+
Returns:
|
|
359
|
+
The forbidden table reference if matched, None otherwise.
|
|
360
|
+
"""
|
|
361
|
+
table_lower = table_ref.lower()
|
|
362
|
+
|
|
363
|
+
# Check exact matches first
|
|
364
|
+
# Handle both "pg_user" and "public.pg_user" etc.
|
|
365
|
+
table_parts = table_lower.split(".")
|
|
366
|
+
base_table = table_parts[-1] # Last part is the table name
|
|
367
|
+
|
|
368
|
+
if base_table in FORBIDDEN_TABLES:
|
|
369
|
+
return table_ref
|
|
370
|
+
|
|
371
|
+
# Check prefixes (handles schema.table patterns)
|
|
372
|
+
for prefix in FORBIDDEN_TABLE_PREFIXES:
|
|
373
|
+
# Check if the full reference starts with the prefix
|
|
374
|
+
if table_lower.startswith(prefix):
|
|
375
|
+
return table_ref
|
|
376
|
+
# Also check if any part starts with the prefix
|
|
377
|
+
for part in table_parts:
|
|
378
|
+
if part.startswith(prefix.rstrip(".")):
|
|
379
|
+
return table_ref
|
|
380
|
+
|
|
381
|
+
return None
|
|
382
|
+
|
|
383
|
+
|
|
384
|
+
def transpile_sql(
|
|
385
|
+
sql: str,
|
|
386
|
+
from_dialect: str,
|
|
387
|
+
to_dialect: str,
|
|
388
|
+
) -> Tuple[bool, str]:
|
|
389
|
+
"""
|
|
390
|
+
Transpile SQL from one dialect to another.
|
|
391
|
+
|
|
392
|
+
Args:
|
|
393
|
+
sql: The SQL statement to transpile
|
|
394
|
+
from_dialect: Source dialect ("postgres", "tsql", "duckdb")
|
|
395
|
+
to_dialect: Target dialect
|
|
396
|
+
|
|
397
|
+
Returns:
|
|
398
|
+
Tuple of (success, result_sql_or_error)
|
|
399
|
+
"""
|
|
400
|
+
dialect_map = {
|
|
401
|
+
"postgres": "postgres",
|
|
402
|
+
"postgresql": "postgres",
|
|
403
|
+
"sqlserver": "tsql",
|
|
404
|
+
"mssql": "tsql",
|
|
405
|
+
"tsql": "tsql",
|
|
406
|
+
"duckdb": "duckdb",
|
|
407
|
+
}
|
|
408
|
+
|
|
409
|
+
src = dialect_map.get(from_dialect.lower(), from_dialect)
|
|
410
|
+
dst = dialect_map.get(to_dialect.lower(), to_dialect)
|
|
411
|
+
|
|
412
|
+
try:
|
|
413
|
+
result = sqlglot.transpile(sql, read=src, write=dst)
|
|
414
|
+
if result:
|
|
415
|
+
return True, result[0]
|
|
416
|
+
return False, "Transpilation returned empty result"
|
|
417
|
+
except Exception as e:
|
|
418
|
+
return False, str(e)
|
|
419
|
+
|
|
420
|
+
|
|
421
|
+
def format_table_reference(
|
|
422
|
+
schema: str,
|
|
423
|
+
table: str,
|
|
424
|
+
dialect: str,
|
|
425
|
+
) -> str:
|
|
426
|
+
"""
|
|
427
|
+
Format a table reference for a specific SQL dialect.
|
|
428
|
+
|
|
429
|
+
Args:
|
|
430
|
+
schema: Schema name (e.g., "public", "dbo")
|
|
431
|
+
table: Table name
|
|
432
|
+
dialect: SQL dialect ("postgres", "sqlserver", "duckdb")
|
|
433
|
+
|
|
434
|
+
Returns:
|
|
435
|
+
Properly quoted table reference
|
|
436
|
+
"""
|
|
437
|
+
dialect = dialect.lower()
|
|
438
|
+
|
|
439
|
+
if dialect in ("postgres", "postgresql", "duckdb"):
|
|
440
|
+
# PostgreSQL/DuckDB: "schema"."table"
|
|
441
|
+
return f'"{schema}"."{table}"'
|
|
442
|
+
elif dialect in ("sqlserver", "mssql", "tsql"):
|
|
443
|
+
# SQL Server: [schema].[table]
|
|
444
|
+
return f"[{schema}].[{table}]"
|
|
445
|
+
else:
|
|
446
|
+
# Default: schema.table
|
|
447
|
+
return f"{schema}.{table}"
|
|
448
|
+
|
|
449
|
+
|
|
450
|
+
def replace_table_placeholder(
|
|
451
|
+
sql: str,
|
|
452
|
+
schema: str,
|
|
453
|
+
table: str,
|
|
454
|
+
dialect: str,
|
|
455
|
+
placeholder: str = "{table}",
|
|
456
|
+
) -> str:
|
|
457
|
+
"""
|
|
458
|
+
Replace {table} placeholder with properly formatted table reference.
|
|
459
|
+
|
|
460
|
+
Args:
|
|
461
|
+
sql: SQL with placeholder
|
|
462
|
+
schema: Schema name
|
|
463
|
+
table: Table name
|
|
464
|
+
dialect: SQL dialect
|
|
465
|
+
placeholder: Placeholder string to replace (default: "{table}")
|
|
466
|
+
|
|
467
|
+
Returns:
|
|
468
|
+
SQL with placeholder replaced
|
|
469
|
+
"""
|
|
470
|
+
table_ref = format_table_reference(schema, table, dialect)
|
|
471
|
+
return sql.replace(placeholder, table_ref)
|
|
472
|
+
|
|
473
|
+
|
|
474
|
+
def to_count_query(sql: str, dialect: str = "postgres") -> Tuple[bool, str]:
|
|
475
|
+
"""
|
|
476
|
+
Transform a SELECT query into a COUNT(*) query for violation counting.
|
|
477
|
+
|
|
478
|
+
Strategy:
|
|
479
|
+
- Simple SELECT (no DISTINCT, GROUP BY, LIMIT): Rewrite SELECT to COUNT(*)
|
|
480
|
+
- Complex SELECT (has DISTINCT, GROUP BY, or LIMIT): Wrap in COUNT(*)
|
|
481
|
+
|
|
482
|
+
Examples:
|
|
483
|
+
SELECT * FROM t WHERE x < 0
|
|
484
|
+
→ SELECT COUNT(*) FROM t WHERE x < 0
|
|
485
|
+
|
|
486
|
+
SELECT DISTINCT region FROM t
|
|
487
|
+
→ SELECT COUNT(*) FROM (SELECT DISTINCT region FROM t) AS _v
|
|
488
|
+
|
|
489
|
+
SELECT a FROM t GROUP BY a HAVING COUNT(*) > 1
|
|
490
|
+
→ SELECT COUNT(*) FROM (SELECT a FROM t GROUP BY a HAVING COUNT(*) > 1) AS _v
|
|
491
|
+
|
|
492
|
+
Args:
|
|
493
|
+
sql: The SELECT query to transform
|
|
494
|
+
dialect: SQL dialect ("postgres", "sqlserver", "duckdb")
|
|
495
|
+
|
|
496
|
+
Returns:
|
|
497
|
+
Tuple of (success, transformed_sql_or_error)
|
|
498
|
+
"""
|
|
499
|
+
# Map dialect names
|
|
500
|
+
dialect_map = {
|
|
501
|
+
"postgres": "postgres",
|
|
502
|
+
"postgresql": "postgres",
|
|
503
|
+
"sqlserver": "tsql",
|
|
504
|
+
"mssql": "tsql",
|
|
505
|
+
"tsql": "tsql",
|
|
506
|
+
"duckdb": "duckdb",
|
|
507
|
+
}
|
|
508
|
+
sqlglot_dialect = dialect_map.get(dialect.lower(), "postgres")
|
|
509
|
+
|
|
510
|
+
try:
|
|
511
|
+
parsed = sqlglot.parse_one(sql, dialect=sqlglot_dialect)
|
|
512
|
+
except ParseError as e:
|
|
513
|
+
return False, f"SQL parse error: {e}"
|
|
514
|
+
|
|
515
|
+
if parsed is None:
|
|
516
|
+
return False, "Failed to parse SQL"
|
|
517
|
+
|
|
518
|
+
# Verify it's a SELECT statement (or WITH/CTE)
|
|
519
|
+
if not isinstance(parsed, (exp.Select, exp.With)):
|
|
520
|
+
return False, f"Expected SELECT statement, got {type(parsed).__name__}"
|
|
521
|
+
|
|
522
|
+
# Check if we need to wrap (complex query) or can rewrite (simple query)
|
|
523
|
+
needs_wrap = _needs_wrapping(parsed)
|
|
524
|
+
|
|
525
|
+
if needs_wrap:
|
|
526
|
+
# Wrap: SELECT COUNT(*) FROM (...) AS _v
|
|
527
|
+
result = _wrap_in_count(parsed, sqlglot_dialect)
|
|
528
|
+
else:
|
|
529
|
+
# Rewrite: Replace SELECT expressions with COUNT(*)
|
|
530
|
+
result = _rewrite_to_count(parsed, sqlglot_dialect)
|
|
531
|
+
|
|
532
|
+
return True, result
|
|
533
|
+
|
|
534
|
+
|
|
535
|
+
def _needs_wrapping(parsed: exp.Expression) -> bool:
|
|
536
|
+
"""
|
|
537
|
+
Check if a query needs wrapping vs simple rewriting.
|
|
538
|
+
|
|
539
|
+
Needs wrapping if:
|
|
540
|
+
- Has DISTINCT (changing SELECT would change result set)
|
|
541
|
+
- Has GROUP BY (rewriting would return multiple rows)
|
|
542
|
+
- Has LIMIT/OFFSET (rewriting would ignore the limit)
|
|
543
|
+
- Has UNION/INTERSECT/EXCEPT (compound queries)
|
|
544
|
+
"""
|
|
545
|
+
# Check for DISTINCT in the main SELECT
|
|
546
|
+
if isinstance(parsed, exp.Select):
|
|
547
|
+
if parsed.args.get("distinct"):
|
|
548
|
+
return True
|
|
549
|
+
|
|
550
|
+
# Check for GROUP BY
|
|
551
|
+
if parsed.find(exp.Group):
|
|
552
|
+
return True
|
|
553
|
+
|
|
554
|
+
# Check for LIMIT or OFFSET
|
|
555
|
+
if parsed.find(exp.Limit) or parsed.find(exp.Offset):
|
|
556
|
+
return True
|
|
557
|
+
|
|
558
|
+
# Check for set operations (UNION, INTERSECT, EXCEPT)
|
|
559
|
+
if parsed.find(exp.Union) or parsed.find(exp.Intersect) or parsed.find(exp.Except):
|
|
560
|
+
return True
|
|
561
|
+
|
|
562
|
+
# Check for WITH (CTE) - wrap to be safe
|
|
563
|
+
if isinstance(parsed, exp.With):
|
|
564
|
+
return True
|
|
565
|
+
|
|
566
|
+
return False
|
|
567
|
+
|
|
568
|
+
|
|
569
|
+
def _wrap_in_count(parsed: exp.Expression, dialect: str) -> str:
|
|
570
|
+
"""
|
|
571
|
+
Wrap a query in SELECT COUNT(*) FROM (...) AS _v.
|
|
572
|
+
"""
|
|
573
|
+
# Create: SELECT COUNT(*) FROM (original_query) AS _v
|
|
574
|
+
count_star = exp.Count(this=exp.Star())
|
|
575
|
+
|
|
576
|
+
# Handle different expression types
|
|
577
|
+
if hasattr(parsed, "subquery"):
|
|
578
|
+
subquery = parsed.subquery(alias="_v")
|
|
579
|
+
else:
|
|
580
|
+
# Fallback: wrap in parentheses manually
|
|
581
|
+
subquery = exp.Subquery(this=parsed, alias=exp.TableAlias(this=exp.Identifier(this="_v")))
|
|
582
|
+
|
|
583
|
+
wrapped = exp.Select(expressions=[count_star]).from_(subquery)
|
|
584
|
+
|
|
585
|
+
return wrapped.sql(dialect=dialect)
|
|
586
|
+
|
|
587
|
+
|
|
588
|
+
def _rewrite_to_count(parsed: exp.Expression, dialect: str) -> str:
|
|
589
|
+
"""
|
|
590
|
+
Rewrite a simple SELECT to use COUNT(*) instead of column expressions.
|
|
591
|
+
|
|
592
|
+
SELECT a, b, c FROM t WHERE x < 0
|
|
593
|
+
→ SELECT COUNT(*) FROM t WHERE x < 0
|
|
594
|
+
"""
|
|
595
|
+
if not isinstance(parsed, exp.Select):
|
|
596
|
+
# Fallback to wrapping for non-SELECT
|
|
597
|
+
return _wrap_in_count(parsed, dialect)
|
|
598
|
+
|
|
599
|
+
# Create COUNT(*) expression
|
|
600
|
+
count_star = exp.Count(this=exp.Star())
|
|
601
|
+
|
|
602
|
+
# Replace the SELECT expressions with COUNT(*)
|
|
603
|
+
parsed.set("expressions", [count_star])
|
|
604
|
+
|
|
605
|
+
# Remove any DISTINCT (shouldn't be here, but just in case)
|
|
606
|
+
if parsed.args.get("distinct"):
|
|
607
|
+
parsed.set("distinct", None)
|
|
608
|
+
|
|
609
|
+
return parsed.sql(dialect=dialect)
|