kontra 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124) hide show
  1. kontra/__init__.py +1871 -0
  2. kontra/api/__init__.py +22 -0
  3. kontra/api/compare.py +340 -0
  4. kontra/api/decorators.py +153 -0
  5. kontra/api/results.py +2121 -0
  6. kontra/api/rules.py +681 -0
  7. kontra/cli/__init__.py +0 -0
  8. kontra/cli/commands/__init__.py +1 -0
  9. kontra/cli/commands/config.py +153 -0
  10. kontra/cli/commands/diff.py +450 -0
  11. kontra/cli/commands/history.py +196 -0
  12. kontra/cli/commands/profile.py +289 -0
  13. kontra/cli/commands/validate.py +468 -0
  14. kontra/cli/constants.py +6 -0
  15. kontra/cli/main.py +48 -0
  16. kontra/cli/renderers.py +304 -0
  17. kontra/cli/utils.py +28 -0
  18. kontra/config/__init__.py +34 -0
  19. kontra/config/loader.py +127 -0
  20. kontra/config/models.py +49 -0
  21. kontra/config/settings.py +797 -0
  22. kontra/connectors/__init__.py +0 -0
  23. kontra/connectors/db_utils.py +251 -0
  24. kontra/connectors/detection.py +323 -0
  25. kontra/connectors/handle.py +368 -0
  26. kontra/connectors/postgres.py +127 -0
  27. kontra/connectors/sqlserver.py +226 -0
  28. kontra/engine/__init__.py +0 -0
  29. kontra/engine/backends/duckdb_session.py +227 -0
  30. kontra/engine/backends/duckdb_utils.py +18 -0
  31. kontra/engine/backends/polars_backend.py +47 -0
  32. kontra/engine/engine.py +1205 -0
  33. kontra/engine/executors/__init__.py +15 -0
  34. kontra/engine/executors/base.py +50 -0
  35. kontra/engine/executors/database_base.py +528 -0
  36. kontra/engine/executors/duckdb_sql.py +607 -0
  37. kontra/engine/executors/postgres_sql.py +162 -0
  38. kontra/engine/executors/registry.py +69 -0
  39. kontra/engine/executors/sqlserver_sql.py +163 -0
  40. kontra/engine/materializers/__init__.py +14 -0
  41. kontra/engine/materializers/base.py +42 -0
  42. kontra/engine/materializers/duckdb.py +110 -0
  43. kontra/engine/materializers/factory.py +22 -0
  44. kontra/engine/materializers/polars_connector.py +131 -0
  45. kontra/engine/materializers/postgres.py +157 -0
  46. kontra/engine/materializers/registry.py +138 -0
  47. kontra/engine/materializers/sqlserver.py +160 -0
  48. kontra/engine/result.py +15 -0
  49. kontra/engine/sql_utils.py +611 -0
  50. kontra/engine/sql_validator.py +609 -0
  51. kontra/engine/stats.py +194 -0
  52. kontra/engine/types.py +138 -0
  53. kontra/errors.py +533 -0
  54. kontra/logging.py +85 -0
  55. kontra/preplan/__init__.py +5 -0
  56. kontra/preplan/planner.py +253 -0
  57. kontra/preplan/postgres.py +179 -0
  58. kontra/preplan/sqlserver.py +191 -0
  59. kontra/preplan/types.py +24 -0
  60. kontra/probes/__init__.py +20 -0
  61. kontra/probes/compare.py +400 -0
  62. kontra/probes/relationship.py +283 -0
  63. kontra/reporters/__init__.py +0 -0
  64. kontra/reporters/json_reporter.py +190 -0
  65. kontra/reporters/rich_reporter.py +11 -0
  66. kontra/rules/__init__.py +35 -0
  67. kontra/rules/base.py +186 -0
  68. kontra/rules/builtin/__init__.py +40 -0
  69. kontra/rules/builtin/allowed_values.py +156 -0
  70. kontra/rules/builtin/compare.py +188 -0
  71. kontra/rules/builtin/conditional_not_null.py +213 -0
  72. kontra/rules/builtin/conditional_range.py +310 -0
  73. kontra/rules/builtin/contains.py +138 -0
  74. kontra/rules/builtin/custom_sql_check.py +182 -0
  75. kontra/rules/builtin/disallowed_values.py +140 -0
  76. kontra/rules/builtin/dtype.py +203 -0
  77. kontra/rules/builtin/ends_with.py +129 -0
  78. kontra/rules/builtin/freshness.py +240 -0
  79. kontra/rules/builtin/length.py +193 -0
  80. kontra/rules/builtin/max_rows.py +35 -0
  81. kontra/rules/builtin/min_rows.py +46 -0
  82. kontra/rules/builtin/not_null.py +121 -0
  83. kontra/rules/builtin/range.py +222 -0
  84. kontra/rules/builtin/regex.py +143 -0
  85. kontra/rules/builtin/starts_with.py +129 -0
  86. kontra/rules/builtin/unique.py +124 -0
  87. kontra/rules/condition_parser.py +203 -0
  88. kontra/rules/execution_plan.py +455 -0
  89. kontra/rules/factory.py +103 -0
  90. kontra/rules/predicates.py +25 -0
  91. kontra/rules/registry.py +24 -0
  92. kontra/rules/static_predicates.py +120 -0
  93. kontra/scout/__init__.py +9 -0
  94. kontra/scout/backends/__init__.py +17 -0
  95. kontra/scout/backends/base.py +111 -0
  96. kontra/scout/backends/duckdb_backend.py +359 -0
  97. kontra/scout/backends/postgres_backend.py +519 -0
  98. kontra/scout/backends/sqlserver_backend.py +577 -0
  99. kontra/scout/dtype_mapping.py +150 -0
  100. kontra/scout/patterns.py +69 -0
  101. kontra/scout/profiler.py +801 -0
  102. kontra/scout/reporters/__init__.py +39 -0
  103. kontra/scout/reporters/json_reporter.py +165 -0
  104. kontra/scout/reporters/markdown_reporter.py +152 -0
  105. kontra/scout/reporters/rich_reporter.py +144 -0
  106. kontra/scout/store.py +208 -0
  107. kontra/scout/suggest.py +200 -0
  108. kontra/scout/types.py +652 -0
  109. kontra/state/__init__.py +29 -0
  110. kontra/state/backends/__init__.py +79 -0
  111. kontra/state/backends/base.py +348 -0
  112. kontra/state/backends/local.py +480 -0
  113. kontra/state/backends/postgres.py +1010 -0
  114. kontra/state/backends/s3.py +543 -0
  115. kontra/state/backends/sqlserver.py +969 -0
  116. kontra/state/fingerprint.py +166 -0
  117. kontra/state/types.py +1061 -0
  118. kontra/version.py +1 -0
  119. kontra-0.5.2.dist-info/METADATA +122 -0
  120. kontra-0.5.2.dist-info/RECORD +124 -0
  121. kontra-0.5.2.dist-info/WHEEL +5 -0
  122. kontra-0.5.2.dist-info/entry_points.txt +2 -0
  123. kontra-0.5.2.dist-info/licenses/LICENSE +17 -0
  124. kontra-0.5.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,609 @@
1
+ # src/kontra/engine/sql_validator.py
2
+ """
3
+ SQL validation using sqlglot for safe remote execution.
4
+
5
+ Ensures user-provided SQL is read-only before executing on production databases.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from dataclasses import dataclass
11
+ from typing import List, Optional, Set, Tuple
12
+
13
+ import sqlglot
14
+ from sqlglot import exp
15
+ from sqlglot.errors import ParseError
16
+
17
+
18
# Statement types that are NOT allowed (write operations and external access).
# validate_sql() walks the whole parsed AST and rejects the query as soon as a
# node's exact type (type(node), not isinstance) is a member of this set.
FORBIDDEN_STATEMENT_TYPES: Set[type] = {
    exp.Insert,
    exp.Update,
    exp.Delete,
    exp.Drop,
    exp.Create,
    exp.Alter,
    exp.Merge,
    exp.Grant,
    exp.Revoke,
    exp.Command,  # Generic command execution
    exp.Copy,  # COPY command (file I/O)
    exp.Set,  # SET commands (configuration changes)
    exp.Use,  # USE database (context switching)
    exp.Attach,  # ATTACH external databases (SEC-002)
}
35
+
36
# Table prefixes/schemas that are forbidden (system catalogs).
# These allow information disclosure attacks.
#
# Matching is done by _is_forbidden_table() on the lowercased reference:
#   * the full dotted reference is tested with startswith(prefix), and
#   * every dotted segment is tested with startswith(prefix minus its
#     trailing "."), so "sys." also matches any segment beginning with "sys".
FORBIDDEN_TABLE_PREFIXES: Set[str] = {
    # PostgreSQL system catalogs
    "pg_",
    # SQL Server system views
    "sys.",
    # Standard information schema (both PostgreSQL and SQL Server)
    "information_schema.",
}
46
+
47
# Specific system tables to block (without prefix, for tables accessed directly).
# _is_forbidden_table() compares the lowercased last dotted segment of a table
# reference against this set, so entries must be lowercase bare table names.
FORBIDDEN_TABLES: Set[str] = {
    # PostgreSQL sensitive tables
    "pg_shadow",
    "pg_authid",
    "pg_roles",
    "pg_user",
    "pg_database",
    "pg_tablespace",
    "pg_settings",
    "pg_stat_activity",
    "pg_stat_user_tables",
    # SQL Server sensitive tables
    "syslogins",
    "sysobjects",
    "syscolumns",
    "sysusers",
    "sysdatabases",
}
66
+
67
# Function names that could have side effects (case-insensitive).
# Entries are stored lowercase; _check_forbidden_functions() lowercases each
# candidate name before testing membership, which is what makes the match
# case-insensitive.
FORBIDDEN_FUNCTIONS: Set[str] = {
    # PostgreSQL
    "pg_sleep",
    "pg_terminate_backend",
    "pg_cancel_backend",
    "pg_reload_conf",
    "set_config",
    "dblink",
    "dblink_exec",
    "lo_import",
    "lo_export",
    "pg_file_write",
    "pg_read_file",
    "pg_ls_dir",
    # SQL Server
    "xp_cmdshell",
    "xp_regread",
    "xp_regwrite",
    "sp_executesql",
    "sp_oacreate",
    "openrowset",
    "opendatasource",
    "bulk",
    # Generic dangerous
    "exec",
    "execute",
    "call",
    "sleep",
    # DuckDB file access (SEC-001: arbitrary file read)
    "read_csv",
    "read_csv_auto",
    "read_parquet",
    "read_json",
    "read_json_auto",
    "read_json_objects",
    "read_blob",
    "read_text",
    "read_ndjson",
    "read_ndjson_auto",
    "read_ndjson_objects",
    # DuckDB file listing/globbing
    "glob",
    "list_files",
    # DuckDB external access
    "httpfs_get",
    "http_get",
    "s3_get",
    # DuckDB query functions that could bypass table reference
    "query",
    "query_table",
}
119
+
120
+
121
@dataclass
class ValidationResult:
    """Result of SQL validation, as returned by validate_sql()."""

    # True only when every validation check passed.
    is_safe: bool
    # Human-readable explanation of the rejection; stays None on success.
    reason: Optional[str] = None
    # Normalized SQL; validate_sql() only fills this in when the SQL is safe.
    parsed_sql: Optional[str] = None
    # sqlglot dialect name the SQL was parsed with (e.g. "postgres", "tsql").
    # validate_sql() leaves it None for the empty-SQL rejection.
    dialect: Optional[str] = None
129
+
130
+
131
def validate_sql(
    sql: str,
    dialect: str = "postgres",
    allow_cte: bool = True,
    allow_subqueries: bool = True,
) -> ValidationResult:
    """
    Validate that SQL is safe for remote execution.

    A SQL statement is considered safe if:
    1. It tokenizes and parses successfully
    2. It is exactly one statement (no SQL injection via ;)
    3. It's a SELECT statement (not INSERT, UPDATE, DELETE, etc.)
    4. It doesn't contain forbidden statement types or forbidden functions
    5. It doesn't reference system catalog tables
    6. It respects the allow_cte / allow_subqueries restrictions

    Args:
        sql: The SQL statement to validate
        dialect: SQL dialect for parsing ("postgres", "tsql", "duckdb");
            unrecognised names fall back to "postgres"
        allow_cte: Allow WITH clauses (CTEs)
        allow_subqueries: Allow subqueries in WHERE/FROM

    Returns:
        ValidationResult with is_safe=True if SQL is safe, False otherwise.
        Malformed SQL never raises; it is reported as is_safe=False.
    """
    sql = sql.strip()

    if not sql:
        return ValidationResult(is_safe=False, reason="Empty SQL statement")

    # Map dialect names
    dialect_map = {
        "postgres": "postgres",
        "postgresql": "postgres",
        "sqlserver": "tsql",
        "mssql": "tsql",
        "tsql": "tsql",
        "duckdb": "duckdb",
    }
    sqlglot_dialect = dialect_map.get(dialect.lower(), "postgres")

    def _reject(reason: str) -> ValidationResult:
        """Build the rejection result for the dialect in use."""
        return ValidationResult(
            is_safe=False,
            reason=reason,
            dialect=sqlglot_dialect,
        )

    try:
        # Parse SQL - this will catch syntax errors.
        statements = sqlglot.parse(sql, dialect=sqlglot_dialect)
    except sqlglot.errors.SqlglotError as e:
        # SqlglotError is the common base of ParseError and TokenError, so
        # tokenizer failures (e.g. an unterminated string literal) are also
        # reported as unsafe instead of escaping as an exception.
        return _reject(f"SQL parse error: {e}")

    # Must be exactly one statement (no SQL injection via semicolons)
    if len(statements) != 1:
        return _reject(
            f"Expected 1 statement, found {len(statements)}. Multiple statements not allowed."
        )

    stmt = statements[0]

    if stmt is None:
        return _reject("Failed to parse SQL statement")

    # Check statement type - must be SELECT (or WITH for CTEs)
    is_select = isinstance(stmt, exp.Select)
    is_cte_select = isinstance(stmt, exp.With) and allow_cte

    if not (is_select or is_cte_select):
        stmt_type = type(stmt).__name__
        return _reject(f"Only SELECT statements allowed, found: {stmt_type}")

    # sqlglot attaches a WITH clause to the SELECT it belongs to, so the root
    # node of a CTE query is still exp.Select. The type check above therefore
    # cannot enforce allow_cte=False; look for a With node explicitly.
    if not allow_cte and stmt.find(exp.With) is not None:
        return _reject("CTEs (WITH clauses) not allowed")

    # Check for forbidden statement types anywhere in the AST
    for node in stmt.walk():
        node_type = type(node)
        if node_type in FORBIDDEN_STATEMENT_TYPES:
            return _reject(f"Forbidden operation: {node_type.__name__}")

    # Check for forbidden functions
    forbidden_found = _check_forbidden_functions(stmt)
    if forbidden_found:
        return _reject(f"Forbidden function: {forbidden_found}")

    # Check for system catalog access (information disclosure)
    forbidden_table = _check_forbidden_tables(stmt)
    if forbidden_table:
        return _reject(f"Access to system catalog not allowed: {forbidden_table}")

    # Check for subqueries if not allowed
    if not allow_subqueries and stmt.find(exp.Subquery) is not None:
        return _reject("Subqueries not allowed")

    # SQL is safe - return normalized version
    try:
        normalized = stmt.sql(dialect=sqlglot_dialect)
    except Exception:
        # Best effort: normalization is cosmetic, so fall back to the input.
        normalized = sql

    return ValidationResult(
        is_safe=True,
        parsed_sql=normalized,
        dialect=sqlglot_dialect,
    )
259
+
260
+
261
def _check_forbidden_functions(stmt: exp.Expression) -> Optional[str]:
    """
    Check for forbidden function calls in the AST.

    Each function node is tested three ways against FORBIDDEN_FUNCTIONS:
    its ``name`` (the function name for anonymous/unknown functions), its
    ``sql_name()`` (the canonical SQL name of typed functions such as
    ReadCSV), and a small class-name mapping as a last resort.

    Args:
        stmt: Root of the parsed statement to inspect.

    Returns:
        The (lowercased) name of the first forbidden function found,
        None otherwise.
    """
    # Class names (lowercased) of typed reader functions -> SQL function name.
    class_to_func = {
        "readcsv": "read_csv",
        "readparquet": "read_parquet",
        "readjson": "read_json",
    }

    for node in stmt.walk():
        # exp.Anonymous is a subclass of exp.Func, so this single isinstance
        # check also covers unknown functions parsed as Anonymous.
        if not isinstance(node, exp.Func):
            continue

        # node.name is the text of the "this" argument. For anonymous
        # functions that is the function name, but for typed functions it is
        # the first argument: UPPER('sleep') would report name == "sleep".
        # A string/number literal is never a function name, so skip it to
        # avoid rejecting harmless literals.
        if not isinstance(node.args.get("this"), exp.Literal):
            func_name = node.name.lower() if node.name else ""
            if func_name in FORBIDDEN_FUNCTIONS:
                return func_name

        # Check sql_name() for functions like ReadCSV, ReadParquet
        # that have specific class types.
        try:
            sql_name = node.sql_name().lower() if hasattr(node, "sql_name") else ""
        except Exception:
            # Best effort: a node that cannot report its SQL name is still
            # covered by the other checks.
            sql_name = ""
        if sql_name in FORBIDDEN_FUNCTIONS:
            return sql_name

        # Check class name directly for specific types.
        mapped_name = class_to_func.get(type(node).__name__.lower())
        if mapped_name in FORBIDDEN_FUNCTIONS:
            return mapped_name

    return None
303
+
304
+
305
def _check_forbidden_tables(stmt: exp.Expression) -> Optional[str]:
    """
    Look for references to system catalog tables anywhere in the statement.

    Every Table node in the AST is rendered to its dotted name and tested
    against the forbidden patterns (pg_*, sys.*, information_schema.*) and
    the explicit forbidden table list.

    Returns the first forbidden table reference encountered, or None when
    the statement touches no system catalog.
    """
    for candidate in stmt.walk():
        # Only direct table references are of interest.
        if not isinstance(candidate, exp.Table):
            continue

        qualified = _get_full_table_name(candidate)
        if not qualified:
            continue

        hit = _is_forbidden_table(qualified)
        if hit:
            return hit

    return None
324
+
325
+
326
+ def _get_full_table_name(table_node: exp.Table) -> Optional[str]:
327
+ """
328
+ Extract the full table name from a Table node, including schema if present.
329
+
330
+ Returns schema.table or just table name.
331
+ """
332
+ parts = []
333
+
334
+ # Get catalog (database) if present
335
+ if table_node.catalog:
336
+ parts.append(str(table_node.catalog))
337
+
338
+ # Get schema if present
339
+ if table_node.db:
340
+ parts.append(str(table_node.db))
341
+
342
+ # Get table name
343
+ if table_node.name:
344
+ parts.append(str(table_node.name))
345
+
346
+ if parts:
347
+ return ".".join(parts)
348
+ return None
349
+
350
+
351
def _is_forbidden_table(table_ref: str) -> Optional[str]:
    """
    Decide whether a table reference points at a forbidden system object.

    Matching is case-insensitive and happens in two stages:
    1. The last dotted segment (the bare table name) is looked up in
       FORBIDDEN_TABLES, so both "pg_user" and "public.pg_user" match.
    2. For each entry of FORBIDDEN_TABLE_PREFIXES, the whole reference is
       tested with startswith(prefix), and every dotted segment is tested
       with startswith(prefix without its trailing "."). The second test
       means "sys." also matches any segment that begins with "sys".

    Args:
        table_ref: Full table reference (e.g., "pg_user", "sys.tables", "information_schema.columns")

    Returns:
        The original table reference if it is forbidden, None otherwise.
    """
    lowered = table_ref.lower()
    segments = lowered.split(".")

    # Stage 1: exact match on the bare table name.
    if segments[-1] in FORBIDDEN_TABLES:
        return table_ref

    # Stage 2: prefix match on the full reference or on any single segment.
    for prefix in FORBIDDEN_TABLE_PREFIXES:
        bare_prefix = prefix.rstrip(".")
        if lowered.startswith(prefix) or any(
            segment.startswith(bare_prefix) for segment in segments
        ):
            return table_ref

    return None
382
+
383
+
384
def transpile_sql(
    sql: str,
    from_dialect: str,
    to_dialect: str,
) -> Tuple[bool, str]:
    """
    Convert a SQL statement between dialects using sqlglot.

    Known aliases ("postgresql", "sqlserver", "mssql", ...) are mapped to
    their sqlglot dialect names; any other dialect string is handed to
    sqlglot unchanged.

    Args:
        sql: The SQL statement to transpile
        from_dialect: Source dialect ("postgres", "tsql", "duckdb")
        to_dialect: Target dialect

    Returns:
        Tuple of (success, result_sql_or_error). On failure the second item
        is the error message.
    """
    aliases = {
        "postgres": "postgres",
        "postgresql": "postgres",
        "sqlserver": "tsql",
        "mssql": "tsql",
        "tsql": "tsql",
        "duckdb": "duckdb",
    }

    read_dialect = aliases.get(from_dialect.lower(), from_dialect)
    write_dialect = aliases.get(to_dialect.lower(), to_dialect)

    try:
        statements = sqlglot.transpile(sql, read=read_dialect, write=write_dialect)
    except Exception as e:
        # Any sqlglot failure is reported through the return value.
        return False, str(e)

    if not statements:
        return False, "Transpilation returned empty result"

    # Only the first transpiled statement is returned.
    return True, statements[0]
419
+
420
+
421
def format_table_reference(
    schema: str,
    table: str,
    dialect: str,
) -> str:
    """
    Format a table reference for a specific SQL dialect.

    Identifier delimiters that occur inside a name are escaped by doubling
    them (``"`` -> ``""`` for PostgreSQL/DuckDB, ``]`` -> ``]]`` for SQL
    Server), so a name can never terminate its own quoting and inject SQL.

    Args:
        schema: Schema name (e.g., "public", "dbo")
        table: Table name
        dialect: SQL dialect ("postgres", "sqlserver", "duckdb"),
            matched case-insensitively

    Returns:
        Properly quoted table reference. Unrecognised dialects get a plain,
        unquoted ``schema.table``.
    """
    dialect = dialect.lower()

    if dialect in ("postgres", "postgresql", "duckdb"):
        # PostgreSQL/DuckDB: "schema"."table" (embedded " is written as "")
        safe_schema = schema.replace('"', '""')
        safe_table = table.replace('"', '""')
        return f'"{safe_schema}"."{safe_table}"'

    if dialect in ("sqlserver", "mssql", "tsql"):
        # SQL Server: [schema].[table] (embedded ] is written as ]])
        safe_schema = schema.replace("]", "]]")
        safe_table = table.replace("]", "]]")
        return f"[{safe_schema}].[{safe_table}]"

    # Default: schema.table
    return f"{schema}.{table}"
448
+
449
+
450
def replace_table_placeholder(
    sql: str,
    schema: str,
    table: str,
    dialect: str,
    placeholder: str = "{table}",
) -> str:
    """
    Substitute the table placeholder in a SQL template.

    The schema and table are first rendered with format_table_reference()
    for the given dialect; every occurrence of the placeholder is then
    replaced by that rendered reference.

    Args:
        sql: SQL with placeholder
        schema: Schema name
        table: Table name
        dialect: SQL dialect
        placeholder: Placeholder string to replace (default: "{table}")

    Returns:
        SQL with placeholder replaced
    """
    qualified_name = format_table_reference(schema, table, dialect)
    return sql.replace(placeholder, qualified_name)
472
+
473
+
474
def to_count_query(sql: str, dialect: str = "postgres") -> Tuple[bool, str]:
    """
    Transform a SELECT query into a COUNT(*) query for violation counting.

    Strategy:
    - Simple SELECT (no DISTINCT, GROUP BY, LIMIT): Rewrite SELECT to COUNT(*)
    - Complex SELECT (has DISTINCT, GROUP BY, or LIMIT): Wrap in COUNT(*)

    Examples:
        SELECT * FROM t WHERE x < 0
        → SELECT COUNT(*) FROM t WHERE x < 0

        SELECT DISTINCT region FROM t
        → SELECT COUNT(*) FROM (SELECT DISTINCT region FROM t) AS _v

        SELECT a FROM t GROUP BY a HAVING COUNT(*) > 1
        → SELECT COUNT(*) FROM (SELECT a FROM t GROUP BY a HAVING COUNT(*) > 1) AS _v

    Args:
        sql: The SELECT query to transform
        dialect: SQL dialect ("postgres", "sqlserver", "duckdb");
            unrecognised names fall back to "postgres"

    Returns:
        Tuple of (success, transformed_sql_or_error). Malformed SQL is
        reported as (False, message) rather than raised.
    """
    # Map dialect names
    dialect_map = {
        "postgres": "postgres",
        "postgresql": "postgres",
        "sqlserver": "tsql",
        "mssql": "tsql",
        "tsql": "tsql",
        "duckdb": "duckdb",
    }
    sqlglot_dialect = dialect_map.get(dialect.lower(), "postgres")

    try:
        parsed = sqlglot.parse_one(sql, dialect=sqlglot_dialect)
    except sqlglot.errors.SqlglotError as e:
        # SqlglotError covers both ParseError and TokenError, so tokenizer
        # failures (e.g. an unterminated string) honour the tuple contract
        # instead of propagating as an exception.
        return False, f"SQL parse error: {e}"

    if parsed is None:
        return False, "Failed to parse SQL"

    # Verify it's a SELECT statement (or WITH/CTE)
    if not isinstance(parsed, (exp.Select, exp.With)):
        return False, f"Expected SELECT statement, got {type(parsed).__name__}"

    # Check if we need to wrap (complex query) or can rewrite (simple query)
    if _needs_wrapping(parsed):
        # Wrap: SELECT COUNT(*) FROM (...) AS _v
        result = _wrap_in_count(parsed, sqlglot_dialect)
    else:
        # Rewrite: Replace SELECT expressions with COUNT(*)
        result = _rewrite_to_count(parsed, sqlglot_dialect)

    return True, result
533
+
534
+
535
def _needs_wrapping(parsed: exp.Expression) -> bool:
    """
    Check if a query needs wrapping vs simple rewriting.

    Rewriting replaces the projection with COUNT(*), which only counts the
    original result rows when the projection does not influence how many
    rows come back. Wrapping is always semantically safe, so anything
    doubtful is wrapped.

    Needs wrapping if:
    - Is a WITH (CTE) root - wrap to be safe
    - Has DISTINCT (changing SELECT would change result set)
    - Has HAVING or QUALIFY (they filter on the projection/aggregation)
    - Has an aggregate in the projection (e.g. SELECT MAX(x) FROM t returns
      one row; rewriting would count every row of t instead)
    - Has GROUP BY (rewriting would return multiple rows)
    - Has LIMIT/OFFSET/FETCH (rewriting would ignore the limit)
    - Has UNION/INTERSECT/EXCEPT (compound queries)

    NOTE(review): wrapping an unaliased aggregate produces a derived table
    with an unnamed column, which SQL Server is understood to reject -
    confirm against the tsql executor.

    Args:
        parsed: Root of the parsed query.

    Returns:
        True if the query must be wrapped in SELECT COUNT(*) FROM (...),
        False if its projection can be rewritten in place.
    """
    if isinstance(parsed, exp.With):
        return True

    if isinstance(parsed, exp.Select):
        # DISTINCT in the main SELECT
        if parsed.args.get("distinct"):
            return True

        # Row filters that depend on aggregated/windowed values
        if parsed.args.get("having") or parsed.args.get("qualify"):
            return True

        # Aggregates in the projection collapse (or reshape) the row set
        if any(projection.find(exp.AggFunc) for projection in parsed.expressions):
            return True

    # Anywhere in the tree: grouping, row limiting, or set operations
    structural_nodes = (
        exp.Group,
        exp.Limit,
        exp.Offset,
        exp.Fetch,
        exp.Union,
        exp.Intersect,
        exp.Except,
    )
    return parsed.find(*structural_nodes) is not None
567
+
568
+
569
def _wrap_in_count(parsed: exp.Expression, dialect: str) -> str:
    """
    Render ``SELECT COUNT(*) FROM (<parsed>) AS _v`` in the given dialect.

    The original query becomes a derived table aliased ``_v``; it is turned
    into a subquery via the expression's own ``subquery()`` builder when it
    has one, otherwise a Subquery node is assembled by hand.
    """
    if hasattr(parsed, "subquery"):
        derived = parsed.subquery(alias="_v")
    else:
        # Fallback: build the aliased subquery node manually.
        derived = exp.Subquery(
            this=parsed,
            alias=exp.TableAlias(this=exp.Identifier(this="_v")),
        )

    outer = exp.Select(expressions=[exp.Count(this=exp.Star())]).from_(derived)
    return outer.sql(dialect=dialect)
586
+
587
+
588
def _rewrite_to_count(parsed: exp.Expression, dialect: str) -> str:
    """
    Rewrite a simple SELECT to use COUNT(*) instead of column expressions.

        SELECT a, b, c FROM t WHERE x < 0 ORDER BY a
        → SELECT COUNT(*) FROM t WHERE x < 0

    Non-SELECT expressions cannot be rewritten and are wrapped instead.
    The input expression is left untouched; the rewrite is applied to a copy.

    Args:
        parsed: Parsed query to rewrite.
        dialect: sqlglot dialect name used to render the result.

    Returns:
        The rewritten SQL string.
    """
    if not isinstance(parsed, exp.Select):
        # Fallback to wrapping for non-SELECT
        return _wrap_in_count(parsed, dialect)

    # Work on a copy so the caller's AST is not mutated.
    rewritten = parsed.copy()

    # Replace the SELECT expressions with COUNT(*)
    rewritten.set("expressions", [exp.Count(this=exp.Star())])

    # Remove any DISTINCT (shouldn't be here, but just in case)
    if rewritten.args.get("distinct"):
        rewritten.set("distinct", None)

    # Drop ORDER BY: it cannot change the count, and ordering an ungrouped
    # aggregate by a plain column is rejected by PostgreSQL and SQL Server
    # ("column must appear in the GROUP BY clause ...").
    if rewritten.args.get("order"):
        rewritten.set("order", None)

    return rewritten.sql(dialect=dialect)