sqlassert 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlassert/__init__.py +13 -0
- sqlassert/unique.py +605 -0
- sqlassert-0.1.0.dist-info/METADATA +173 -0
- sqlassert-0.1.0.dist-info/RECORD +6 -0
- sqlassert-0.1.0.dist-info/WHEEL +5 -0
- sqlassert-0.1.0.dist-info/top_level.txt +1 -0
sqlassert/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
from sqlassert.unique import (
|
|
2
|
+
UniqueJoinCheckResult,
|
|
3
|
+
UniqueJoinValidationResult,
|
|
4
|
+
unique_assertions,
|
|
5
|
+
validate_unique_joins,
|
|
6
|
+
)
|
|
7
|
+
|
|
8
|
+
__all__ = [
|
|
9
|
+
"UniqueJoinCheckResult",
|
|
10
|
+
"UniqueJoinValidationResult",
|
|
11
|
+
"unique_assertions",
|
|
12
|
+
"validate_unique_joins",
|
|
13
|
+
]
|
sqlassert/unique.py
ADDED
|
@@ -0,0 +1,605 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import re
|
|
4
|
+
from collections.abc import Iterable
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
import sqlglot
|
|
9
|
+
from sqlglot import exp
|
|
10
|
+
from sqlglot.errors import SqlglotError
|
|
11
|
+
from sqlglot.tokens import TokenType, Tokenizer
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
# Placeholder assertion statement emitted once per marked join.
FALSE_ASSERTION = "select false"

# Matches the literal marker comment, ignoring letter case.
MARKER = re.compile(re.escape("/**unique**/"), re.IGNORECASE)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass(frozen=True)
class UniqueJoinCheckResult:
    """Verdict for a single unique-join marker."""

    # Zero-based position of the marker among the markers found in the SQL text.
    marker_index: int
    # True only when a uniqueness proof was found for the marked join.
    valid: bool
    # Human-readable explanation of the verdict.
    reason: str
    # Names of the RHS columns inferred from the join predicate.
    inferred_key_columns: tuple[str, ...] = ()
    # Columns of the constraint that proved uniqueness; empty when unproven.
    constrained_key_columns: tuple[str, ...] = ()
    # Every candidate unique constraint discovered for the RHS relation.
    unique_constraints: tuple[tuple[str, ...], ...] = ()
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass(frozen=True)
class UniqueJoinValidationResult:
    """Aggregated verdict for all unique-join markers in one SQL text."""

    # True when every per-marker check passed (or no markers were found).
    valid: bool
    # Summary message; failure reasons are joined with "; ".
    reason: str
    # One UniqueJoinCheckResult per marker, in marker order.
    checks: tuple[UniqueJoinCheckResult, ...] = ()
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@dataclass(frozen=True)
class _UniqueJoinPlan:
    """Everything extracted from the SQL text for one marker, before any proof is attempted."""

    # Zero-based position of the marker among the markers found in the SQL text.
    marker_index: int
    # Why planning stopped, or "join predicate inferred" on success.
    reason: str
    # Rendered SQL of the marked join with the marker comment stripped.
    join_sql: str = ""
    # Right-hand relation of the join; None when planning failed.
    rhs: exp.Expression | None = None
    # Lower-cased name or alias by which the predicate refers to the RHS.
    rhs_names: frozenset[str] = frozenset()
    # RHS columns inferred as join keys.
    keys: tuple[exp.Column, ...] = ()
    # RHS columns referenced by RHS-only predicates of the ON clause.
    rhs_filter_columns: tuple[exp.Column, ...] = ()
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def validate_unique_joins(
    connection: Any,
    sql: str,
    dialect: str = "duckdb",
) -> UniqueJoinValidationResult:
    """Check every marked join in *sql* and summarise the outcome.

    Args:
        connection: Object handed to the metadata helpers, which call
            ``connection.execute(query, params)`` on it.
        sql: SQL text that may contain ``/**unique**/`` markers.
        dialect: sqlglot dialect used to tokenize and parse *sql*.

    Returns:
        A valid result when there are no markers or every check passes;
        otherwise an invalid result whose reason joins the failing checks'
        reasons with ``"; "``.
    """
    plans = _unique_join_plans(sql, dialect)
    if not plans:
        return UniqueJoinValidationResult(True, "no unique join markers found")

    checks = tuple(_validate_plan(connection, plan) for plan in plans)
    failures = [check.reason for check in checks if not check.valid]
    if failures:
        return UniqueJoinValidationResult(False, "; ".join(failures), checks)
    return UniqueJoinValidationResult(True, "all unique join assertions passed", checks)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def unique_assertions(sql: str, dialect: str = "duckdb") -> list[str]:
    """Return one ``FALSE_ASSERTION`` statement per marker plan found in *sql*."""
    return [FALSE_ASSERTION] * len(_unique_join_plans(sql, dialect))
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def _unique_join_plans(sql: str, dialect: str) -> list[_UniqueJoinPlan]:
    """Build one plan per marker found in *sql*.

    Markers are located in the raw text, paired with the index of the next
    JOIN token, and that index is then used to pick a join node from the
    parsed statements. A parse failure yields one "SQL parse failed" plan
    per marker.
    """
    marker_join_indexes = _marked_join_indexes(sql, dialect)
    if not marker_join_indexes:
        return []

    try:
        parsed = [tree for tree in sqlglot.parse(sql, read=dialect) if tree]
    except SqlglotError:
        return [
            _UniqueJoinPlan(position, "SQL parse failed")
            for position in range(len(marker_join_indexes))
        ]

    # NOTE(review): this assumes find_all() yields join nodes in the same order
    # in which JOIN tokens appear in the text. Joins inside CTEs or nested
    # queries, and comma joins (which have no JOIN token), may break that
    # assumption -- confirm against sqlglot's traversal order.
    joins: list[exp.Join] = [
        join for tree in parsed for join in tree.find_all(exp.Join)
    ]

    def plan_for(marker_index: int, join_index: int | None) -> _UniqueJoinPlan:
        if join_index is None or join_index >= len(joins):
            return _UniqueJoinPlan(marker_index, "marker is not followed by a join")
        return _plan_for_join(marker_index, joins[join_index], dialect)

    return [
        plan_for(marker_index, join_index)
        for marker_index, join_index in enumerate(marker_join_indexes)
    ]
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def _marked_join_indexes(sql: str, dialect: str) -> list[int | None]:
    """Map each marker in *sql* to the index of the next JOIN token.

    Returns:
        One entry per marker, in text order. An entry is the zero-based index
        (among all JOIN tokens) of the first JOIN token starting at or after
        the end of the marker, or ``None`` when no such token exists. The list
        is empty when *sql* contains no marker.

    If *sql* cannot be tokenized, every marker maps to ``None`` instead of
    letting the tokenizer error escape; the caller's parse step then reports
    the failure as a regular "SQL parse failed" result.
    """
    marker_offsets = [match.end() for match in MARKER.finditer(sql)]
    if not marker_offsets:
        return []

    try:
        tokens = Tokenizer(dialect=dialect).tokenize(sql)
    except SqlglotError:
        # Keep one entry per marker so callers still produce one result each.
        return [None] * len(marker_offsets)

    join_offsets = [
        token.start for token in tokens if token.token_type is TokenType.JOIN
    ]

    indexes: list[int | None] = []
    for marker_offset in marker_offsets:
        next_join_index = next(
            (
                index
                for index, join_offset in enumerate(join_offsets)
                if join_offset >= marker_offset
            ),
            None,
        )
        indexes.append(next_join_index)
    return indexes
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def _plan_for_join(marker_index: int, join: exp.Join, dialect: str) -> _UniqueJoinPlan:
    """Extract the RHS relation, key columns and RHS-only filter columns of *join*.

    Returns a plan without ``rhs``/``keys`` (carrying only a reason and the
    rendered join SQL) when the join is ANTI/SEMI, has no RHS, has neither ON
    nor USING, or when no RHS name or key column can be inferred.
    """
    join_sql = _join_sql(join, dialect)
    prefix = _join_reason_prefix(join_sql)

    def rejected(detail: str) -> _UniqueJoinPlan:
        return _UniqueJoinPlan(marker_index, f"{prefix}, {detail}", join_sql)

    kind = (join.args.get("kind") or "").upper()
    if kind in {"ANTI", "SEMI"}:
        return rejected(f"{kind.lower()} joins are not supported")

    rhs = join.this
    on = join.args.get("on")
    using = join.args.get("using") or []
    if rhs is None:
        return rejected("marked join has no RHS relation")
    if on is None and not using:
        return rejected("marked join has no ON or USING predicate")

    rhs_names = _rhs_names(rhs)
    if not rhs_names:
        return rejected("could not identify RHS relation name or alias")

    # USING takes precedence over ON when both are present.
    if using:
        keys = _rhs_using_columns(using)
    else:
        keys = _rhs_key_columns(on, rhs_names)
    if not keys:
        return rejected("could not infer RHS key columns from join predicate")

    filter_columns = () if on is None else tuple(_rhs_filter_columns(on, rhs_names))
    return _UniqueJoinPlan(
        marker_index=marker_index,
        reason="join predicate inferred",
        join_sql=join_sql,
        rhs=rhs,
        rhs_names=frozenset(rhs_names),
        keys=tuple(keys),
        rhs_filter_columns=filter_columns,
    )
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def _validate_plan(connection: Any, plan: _UniqueJoinPlan) -> UniqueJoinCheckResult:
    """Try to prove that the RHS of a planned join is unique on the covered columns.

    A plan that failed during planning (no RHS or no keys) is reported as
    invalid with the plan's own reason. Otherwise the RHS's unique constraints
    are looked up and the first covered one yields a valid result; if none is
    covered, the "can't prove" reason is returned.
    """
    key_names = tuple(key.name for key in plan.keys)
    if plan.rhs is None or not plan.keys:
        return UniqueJoinCheckResult(
            marker_index=plan.marker_index,
            valid=False,
            reason=plan.reason,
            inferred_key_columns=key_names,
        )

    unique_constraints = _relation_unique_constraints(connection, plan.rhs)
    proof = None
    if unique_constraints:
        proof = _validate_constraints(plan, key_names, unique_constraints)
    if proof is not None:
        return proof

    # Either no constraint was found or none is covered by the join.
    return UniqueJoinCheckResult(
        marker_index=plan.marker_index,
        valid=False,
        reason=_cannot_prove_reason(plan),
        inferred_key_columns=key_names,
        unique_constraints=unique_constraints,
    )
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def _validate_constraints(
    plan: _UniqueJoinPlan,
    key_names: tuple[str, ...],
    unique_constraints: tuple[tuple[str, ...], ...],
) -> UniqueJoinCheckResult | None:
    """Return a valid result for the first constraint fully covered by the plan.

    A constraint is covered when every one of its columns (compared
    case-insensitively) is among the plan's key or filter columns. Returns
    ``None`` when no constraint is covered.
    """
    covered_columns = _covered_rhs_column_names(plan)
    for constraint in unique_constraints:
        if not {column.lower() for column in constraint} <= covered_columns:
            continue

        constrained_key_columns = tuple(constraint)
        proof_text = ", ".join(constrained_key_columns)
        covered_text = ", ".join(sorted(covered_columns))
        return UniqueJoinCheckResult(
            marker_index=plan.marker_index,
            valid=True,
            reason=(
                f"RHS uniqueness proof ({proof_text}) "
                "is covered by inferred keys/filters "
                f"({covered_text})"
            ),
            inferred_key_columns=key_names,
            constrained_key_columns=constrained_key_columns,
            unique_constraints=unique_constraints,
        )
    return None
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
def _cannot_prove_reason(plan: _UniqueJoinPlan) -> str:
    """Format the failure message naming the join and the unproven RHS key columns."""
    names = tuple(key.name for key in plan.keys)
    verb = "are" if len(names) != 1 else "is"
    prefix = _join_reason_prefix(plan.join_sql)
    return f"{prefix}, we can't prove that RHS {_format_columns(names)} {verb} unique"
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
def _join_reason_prefix(join_sql: str) -> str:
    """Return the common lead-in of every join-specific reason message."""
    return 'in join: "{}"'.format(join_sql)
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
def _join_sql(join: exp.Join, dialect: str) -> str:
    """Render *join* as single-line SQL with the unique-marker comment removed."""
    rendered = join.sql(dialect=dialect)
    without_marker = re.sub(
        r"/\*\s*\*?\s*unique\s*\*?\s*\*/",
        "",
        rendered,
        flags=re.IGNORECASE,
    )
    # Collapse every whitespace run (including newlines) to one space.
    return " ".join(without_marker.split())
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
def _rhs_names(rhs: exp.Expression) -> set[str]:
    """Return the lower-cased name by which the join predicate can refer to *rhs*.

    An alias wins; otherwise a table contributes its own name and any other
    expression its ``alias_or_name``. The set is empty when a non-table
    expression has neither.
    """
    if rhs.alias:
        label = rhs.alias
    elif isinstance(rhs, exp.Table):
        label = rhs.name
    else:
        label = rhs.alias_or_name
        if not label:
            return set()
    return {label.lower()}
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
def _rhs_key_columns(on: exp.Expression, rhs_names: set[str]) -> list[exp.Column]:
    """Collect RHS columns that the ON clause equates with a non-RHS expression.

    Only equalities that are top-level conjuncts of *on* (optionally wrapped
    in parentheses) are considered. An equality nested under ``OR``, ``NOT``,
    ``CASE`` or a subquery does not have to hold for every joined row, so
    treating its column as a join key would make the uniqueness proof unsound.

    Args:
        on: The join's ON expression.
        rhs_names: Lower-cased names that identify the RHS relation.

    Returns:
        Copies of the matching RHS columns in left-to-right order, de-duplicated
        case-insensitively by (table, column).
    """
    keys: list[exp.Column] = []
    seen: set[tuple[str, str]] = set()

    # Depth-first walk over the AND tree; the right operand is pushed first so
    # that the left operand is popped (and therefore reported) first.
    pending: list[exp.Expression] = [on]
    while pending:
        term = pending.pop()
        if isinstance(term, exp.Paren):
            pending.append(term.this)
            continue
        if isinstance(term, exp.And):
            pending.append(term.expression)
            pending.append(term.this)
            continue
        if not isinstance(term, exp.EQ):
            continue

        # Accept the RHS column on either side of the equality.
        rhs_column = _simple_rhs_column(term.this, term.expression, rhs_names)
        if rhs_column is None:
            rhs_column = _simple_rhs_column(term.expression, term.this, rhs_names)
        if rhs_column is None:
            continue

        key = (_column_table(rhs_column).lower(), rhs_column.name.lower())
        if key not in seen:
            keys.append(rhs_column.copy())
            seen.add(key)

    return keys
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
def _rhs_using_columns(using: list[exp.Identifier]) -> list[exp.Column]:
    """Turn USING identifiers into unqualified columns, dropping case-insensitive repeats."""
    columns_by_name: dict[str, exp.Column] = {}
    for identifier in using:
        name = identifier.name
        folded = name.lower()
        if folded not in columns_by_name:
            columns_by_name[folded] = exp.column(name)
    # dict preserves insertion order, so first occurrences keep their order.
    return list(columns_by_name.values())
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
def _simple_rhs_column(
    maybe_column: exp.Expression,
    other_side: exp.Expression,
    rhs_names: set[str],
) -> exp.Column | None:
    """Return *maybe_column* if it is a bare RHS column compared against non-RHS input.

    The result is ``None`` when *maybe_column* is not a column, is not
    qualified with an RHS name, or when *other_side* itself references any
    RHS column.
    """
    is_rhs_column = isinstance(maybe_column, exp.Column) and _is_rhs_column(
        maybe_column, rhs_names
    )
    if not is_rhs_column:
        return None

    for column in other_side.find_all(exp.Column):
        if _is_rhs_column(column, rhs_names):
            return None
    return maybe_column
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
def _rhs_filter_columns(on: exp.Expression, rhs_names: set[str]) -> list[exp.Column]:
    """Collect RHS columns that the ON clause pins to a single value.

    A column counts as pinned only when a top-level conjunct of *on*
    (optionally parenthesised) is an equality between that bare RHS column and
    an expression that references no column at all, e.g. ``orders.order_id = 1``.

    Other RHS-only predicates such as ``orders.order_id > 1``,
    ``orders.a = orders.b`` or ``orders.x IS NULL`` can still match many rows,
    so their columns must not be treated as covered by the join.

    Args:
        on: The join's ON expression.
        rhs_names: Lower-cased names that identify the RHS relation.

    Returns:
        Copies of the pinned RHS columns in left-to-right order, de-duplicated
        case-insensitively by (table, column).
    """
    filters: list[exp.Column] = []
    seen: set[tuple[str, str]] = set()

    # Depth-first walk over the AND tree; push right first so left pops first.
    pending: list[exp.Expression] = [on]
    while pending:
        term = pending.pop()
        if isinstance(term, exp.Paren):
            pending.append(term.this)
            continue
        if isinstance(term, exp.And):
            pending.append(term.expression)
            pending.append(term.this)
            continue
        if not isinstance(term, exp.EQ):
            continue

        for candidate, other in (
            (term.this, term.expression),
            (term.expression, term.this),
        ):
            if not isinstance(candidate, exp.Column):
                continue
            if not _is_rhs_column(candidate, rhs_names):
                continue
            # The other side must be column-free, i.e. constant for the join.
            if other.find(exp.Column) is not None:
                continue

            key = (_column_table(candidate).lower(), candidate.name.lower())
            if key not in seen:
                filters.append(candidate.copy())
                seen.add(key)

    return filters
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
def _and_terms(expression: exp.Expression) -> Iterable[exp.Expression]:
    """Yield the operands of a (possibly nested) AND chain, left to right.

    Anything that is not an ``AND`` node is yielded as-is; parentheses are not
    unwrapped.
    """
    pending = [expression]
    while pending:
        node = pending.pop()
        if isinstance(node, exp.And):
            # Push the right operand first so the left one is handled first.
            pending.append(node.expression)
            pending.append(node.this)
        else:
            yield node
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
def _is_rhs_column(column: exp.Column, rhs_names: set[str]) -> bool:
    """Tell whether *column* is qualified with one of the RHS names.

    Unqualified columns are never considered RHS columns.
    """
    qualifier = _column_table(column)
    if not qualifier:
        return False
    return qualifier.lower() in rhs_names
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
def _column_table(column: exp.Column) -> str:
    """Return the column's table qualifier, or ``""`` when it is not a string."""
    qualifier = column.table
    if isinstance(qualifier, str):
        return qualifier
    return ""
|
|
321
|
+
|
|
322
|
+
|
|
323
|
+
def _covered_rhs_column_names(plan: _UniqueJoinPlan) -> set[str]:
    """Return the lower-cased names of the plan's key columns and filter columns."""
    covered: set[str] = set()
    for group in (plan.keys, plan.rhs_filter_columns):
        for column in group:
            covered.add(column.name.lower())
    return covered
|
|
328
|
+
|
|
329
|
+
|
|
330
|
+
def _relation_unique_constraints(
    connection: Any,
    relation: exp.Expression,
    seen_views: frozenset[tuple[str, str]] = frozenset(),
) -> tuple[tuple[str, ...], ...]:
    """Return the column sets on which *relation* is provably unique.

    Tables are resolved through catalog constraints first and, if none are
    found, through their view definition. A subquery wrapping a SELECT is
    analysed structurally. Every other relation yields no constraints.

    NOTE(review): a reference to a CTE reaches this function as a table node
    and is looked up in the catalog; no CTE resolution is visible in this
    file -- confirm whether the README's CTE examples are expected to pass.
    """
    if isinstance(relation, exp.Table):
        return _unique_constraints(connection, relation) or _view_unique_constraints(
            connection, relation, seen_views
        )

    is_select_subquery = isinstance(relation, exp.Subquery) and isinstance(
        relation.this, exp.Select
    )
    if is_select_subquery:
        return _select_unique_constraints(connection, relation.this, seen_views)

    return ()
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
def _view_unique_constraints(
    connection: Any,
    view: exp.Table,
    seen_views: frozenset[tuple[str, str]],
) -> tuple[tuple[str, ...], ...]:
    """Derive unique constraints for *view* from its stored definition.

    Returns no constraints when the view was already visited (guards against
    cycles), has no retrievable definition, fails to parse, or is not a
    ``CREATE ... AS SELECT`` statement.
    """
    # An unqualified view is keyed under "main" for cycle detection.
    view_key = ((_table_schema(view) or "main").lower(), view.name.lower())
    if view_key in seen_views:
        return ()

    definition = _view_sql(connection, view)
    if not definition:
        return ()

    try:
        # The definition comes from duckdb_views(), so it is parsed as DuckDB SQL.
        created = sqlglot.parse_one(definition, read="duckdb")
    except SqlglotError:
        return ()

    if isinstance(created, exp.Create) and isinstance(created.expression, exp.Select):
        return _select_unique_constraints(
            connection, created.expression, seen_views | {view_key}
        )
    return ()
|
|
370
|
+
|
|
371
|
+
|
|
372
|
+
def _view_sql(connection: Any, view: exp.Table) -> str:
    """Fetch the definition SQL of *view* from ``duckdb_views()``.

    The lookup filters on the schema only when the reference is
    schema-qualified. Returns ``""`` when the query fails or finds no row.
    """
    schema_name = _table_schema(view)
    schema_clause = "and schema_name = ? " if schema_name else ""
    query = (
        "select sql "
        "from duckdb_views() "
        "where view_name = ? "
        f"{schema_clause}"
        "and not internal "
        "order by database_name, schema_name "
        "limit 1"
    )
    params = (view.name, schema_name) if schema_name else (view.name,)

    try:
        row = connection.execute(query, params).fetchone()
    except Exception:
        # Best effort: a failing metadata lookup simply means "no definition".
        return ""

    if not row:
        return ""
    return row[0]
|
|
402
|
+
|
|
403
|
+
|
|
404
|
+
def _select_unique_constraints(
    connection: Any,
    select: exp.Select,
    seen_views: frozenset[tuple[str, str]],
) -> tuple[tuple[str, ...], ...]:
    """Collect every uniqueness proof available for the output of *select*.

    Proofs come, in this order, from GROUP BY, DISTINCT, a
    ``QUALIFY row_number() ... = 1`` filter, and constraints inherited from a
    single projected source. Duplicates are removed.
    """
    structural = (
        _group_by_constraint(select),
        _distinct_constraint(select),
        _qualify_row_number_constraint(select),
    )
    constraints = [constraint for constraint in structural if constraint]
    constraints.extend(_projected_source_constraints(connection, select, seen_views))
    return _dedupe_constraints(tuple(constraints))
|
|
425
|
+
|
|
426
|
+
|
|
427
|
+
def _projected_source_constraints(
    connection: Any,
    select: exp.Select,
    seen_views: frozenset[tuple[str, str]],
) -> tuple[tuple[str, ...], ...]:
    """Carry the source relation's unique constraints through a simple projection.

    Applies only when *select* reads from a single table or subquery without
    joins. A source constraint survives when every one of its columns is
    projected as a plain (optionally aliased) column; the surviving constraint
    is expressed in output column names.
    """
    source = _single_select_source(select)
    if source is None:
        return ()

    source_constraints = _relation_unique_constraints(connection, source, seen_views)
    if not source_constraints:
        return ()

    projection = _projection_map(select.expressions)
    if not projection:
        return ()

    inherited = []
    for constraint in source_constraints:
        output_names = [projection.get(column.lower()) for column in constraint]
        # Drop the constraint if any of its columns is not projected.
        if None not in output_names:
            inherited.append(tuple(output_names))
    return tuple(inherited)
|
|
456
|
+
|
|
457
|
+
|
|
458
|
+
def _single_select_source(select: exp.Select) -> exp.Expression | None:
    """Return the sole FROM source of *select* if it is a table or subquery.

    Returns ``None`` when the SELECT has joins, has no FROM clause, or reads
    from any other kind of source.
    """
    if select.args.get("joins"):
        return None

    from_clause = select.args.get("from_")
    if isinstance(from_clause, exp.From):
        source = from_clause.this
        if isinstance(source, (exp.Table, exp.Subquery)):
            return source
    return None
|
|
470
|
+
|
|
471
|
+
|
|
472
|
+
def _projection_map(expressions: list[exp.Expression]) -> dict[str, str]:
    """Map lower-cased source column names to their output names.

    Only plain or aliased columns take part; when a source column is projected
    more than once, the first projection wins.
    """
    projection: dict[str, str] = {}
    for expression in expressions:
        source_name, output_name = _projection_column_names(expression)
        if not (source_name and output_name):
            continue
        projection.setdefault(source_name.lower(), output_name)
    return projection
|
|
479
|
+
|
|
480
|
+
|
|
481
|
+
def _projection_column_names(expression: exp.Expression) -> tuple[str, str]:
    """Return ``(source_name, output_name)`` for a plain or aliased column.

    Any other projection yields ``("", "")``.
    """
    if isinstance(expression, exp.Column):
        return expression.name, expression.name
    if isinstance(expression, exp.Alias) and isinstance(expression.this, exp.Column):
        return expression.this.name, expression.alias
    return "", ""
|
|
487
|
+
|
|
488
|
+
|
|
489
|
+
def _dedupe_constraints(constraints: tuple[tuple[str, ...], ...]) -> tuple[tuple[str, ...], ...]:
    """Drop constraints that repeat an earlier one, comparing names case-insensitively.

    The first spelling of each constraint is kept and the original order is
    preserved. Column order inside a constraint is significant.
    """
    first_by_key: dict[tuple[str, ...], tuple[str, ...]] = {}
    for constraint in constraints:
        folded = tuple(column.lower() for column in constraint)
        first_by_key.setdefault(folded, constraint)
    return tuple(first_by_key.values())
|
|
499
|
+
|
|
500
|
+
|
|
501
|
+
def _group_by_constraint(select: exp.Select) -> tuple[str, ...]:
    """Return the grouping columns of *select* as a unique constraint.

    The result is empty (no proof) unless the GROUP BY consists solely of
    distinctly named plain columns. Dropping a non-column grouping expression,
    or ignoring ROLLUP / CUBE / GROUPING SETS, would claim uniqueness on fewer
    columns than the query actually groups by.
    """
    group = select.args.get("group")
    if not isinstance(group, exp.Group):
        return ()

    # These modifiers emit extra rows per grouping key; no simple key exists.
    if any(group.args.get(modifier) for modifier in ("grouping_sets", "cube", "rollup")):
        return ()

    expressions = group.expressions
    if not all(isinstance(expression, exp.Column) for expression in expressions):
        return ()

    columns = _simple_output_columns(expressions)
    # A shorter result means names were dropped or collapsed (e.g. a.id, b.id),
    # which would make the key too narrow.
    if len(columns) != len(expressions):
        return ()
    return columns
|
|
506
|
+
|
|
507
|
+
|
|
508
|
+
def _distinct_constraint(select: exp.Select) -> tuple[str, ...]:
    """Return the output columns of a ``SELECT DISTINCT`` as a unique constraint.

    DISTINCT makes the full projected row unique, so the proof is only valid
    when every projection can be named. If any projection is neither a plain
    column nor an alias (for example ``*`` or an unaliased expression), no
    constraint is returned rather than one over a subset of the columns.
    """
    if select.args.get("distinct") is None:
        return ()

    expressions = select.expressions
    if not all(_simple_output_column_name(expression) for expression in expressions):
        return ()
    return _simple_output_columns(expressions)
|
|
512
|
+
|
|
513
|
+
|
|
514
|
+
def _qualify_row_number_constraint(select: exp.Select) -> tuple[str, ...]:
    """Return the partition columns of a ``QUALIFY row_number() ... = 1`` filter.

    The result is empty (no proof) unless the QUALIFY condition is exactly such
    a filter and the window partitions solely by distinctly named plain
    columns. Dropping a non-column partition expression would claim uniqueness
    on fewer columns than the window actually partitions by.
    """
    qualify = select.args.get("qualify")
    if not isinstance(qualify, exp.Qualify):
        return ()

    window = _row_number_window_filtered_to_one(qualify.this)
    if window is None:
        return ()

    partition = window.args.get("partition_by") or []
    if not all(isinstance(expression, exp.Column) for expression in partition):
        return ()

    columns = _simple_output_columns(partition)
    # A shorter result means names were dropped or collapsed, i.e. a partial key.
    if len(columns) != len(partition):
        return ()
    return columns
|
|
523
|
+
|
|
524
|
+
|
|
525
|
+
def _row_number_window_filtered_to_one(expression: exp.Expression) -> exp.Window | None:
    """Return the window of a ``row_number() over (...) = 1`` comparison.

    The comparison may be written in either operand order. Any other
    expression yields ``None``.
    """
    if not isinstance(expression, exp.EQ):
        return None

    left, right = expression.this, expression.expression
    for window, literal in ((left, right), (right, left)):
        if _is_row_number_window(window) and _is_one(literal):
            return window
    return None
|
|
534
|
+
|
|
535
|
+
|
|
536
|
+
def _is_row_number_window(expression: exp.Expression) -> bool:
    """Tell whether *expression* is a window whose function is ``row_number()``."""
    if not isinstance(expression, exp.Window):
        return False
    return isinstance(expression.this, exp.RowNumber)
|
|
538
|
+
|
|
539
|
+
|
|
540
|
+
def _is_one(expression: exp.Expression) -> bool:
    """Tell whether *expression* is the non-string literal ``1``."""
    if not isinstance(expression, exp.Literal):
        return False
    if expression.is_string:
        return False
    return expression.this == "1"
|
|
542
|
+
|
|
543
|
+
|
|
544
|
+
def _simple_output_columns(expressions: Iterable[exp.Expression]) -> tuple[str, ...]:
    """Return the names of *expressions* if every one of them has a simple name.

    Callers interpret the result as a complete key (all GROUP BY, DISTINCT or
    PARTITION BY expressions). If any expression cannot be named, the result
    is empty instead of a partial list, because a partial key would claim
    uniqueness on too few columns.

    Returns:
        The names in input order, de-duplicated case-insensitively, or ``()``
        when there are no expressions or at least one has no simple name.
    """
    columns: list[str] = []
    seen: set[str] = set()
    for expression in expressions:
        name = _simple_output_column_name(expression)
        if not name:
            # An unnameable expression invalidates the whole key.
            return ()
        folded = name.lower()
        if folded not in seen:
            columns.append(name)
            seen.add(folded)
    return tuple(columns)
|
|
553
|
+
|
|
554
|
+
|
|
555
|
+
def _simple_output_column_name(expression: exp.Expression) -> str:
    """Return the alias of an aliased expression or the name of a plain column.

    Any other expression yields ``""``.
    """
    if isinstance(expression, exp.Column):
        return expression.name
    if isinstance(expression, exp.Alias):
        return expression.alias
    return ""
|
|
561
|
+
|
|
562
|
+
|
|
563
|
+
def _unique_constraints(connection: Any, rhs: exp.Table) -> tuple[tuple[str, ...], ...]:
    """Read PRIMARY KEY and UNIQUE constraints of *rhs* from ``duckdb_constraints()``.

    The lookup filters on the schema only when the reference is
    schema-qualified. Returns no constraints when the metadata query fails.

    NOTE(review): for an unqualified table the query matches on table name
    alone, so same-named tables in other schemas would presumably contribute
    their constraints too -- confirm whether that is intended.
    """
    schema_name = _table_schema(rhs)
    schema_clause = "and schema_name = ? " if schema_name else ""
    query = (
        "select constraint_column_names "
        "from duckdb_constraints() "
        "where table_name = ? "
        f"{schema_clause}"
        "and constraint_type in ('PRIMARY KEY', 'UNIQUE') "
        "order by constraint_index"
    )
    params = (rhs.name, schema_name) if schema_name else (rhs.name,)

    try:
        rows = connection.execute(query, params).fetchall()
    except Exception:
        # Best effort: a failing metadata lookup simply means "no constraints".
        return ()

    # The first column of each row holds that constraint's column names.
    return tuple(tuple(column_names) for column_names, *_ in rows)
|
|
591
|
+
|
|
592
|
+
|
|
593
|
+
def _table_schema(table: exp.Table) -> str:
    """Return the schema qualifier of *table*, or ``""`` when it has none."""
    db = table.args.get("db")
    if isinstance(db, exp.Identifier):
        return db.name
    return db if isinstance(db, str) else ""
|
|
600
|
+
|
|
601
|
+
|
|
602
|
+
def _format_columns(columns: tuple[str, ...]) -> str:
    """Render column names for a message: ``column x`` or ``columns x, y``."""
    if len(columns) == 1:
        (only,) = columns
        return f"column {only}"
    return "columns " + ", ".join(columns)
|
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: sqlassert
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Generate SQL assertions from unique join markers.
|
|
5
|
+
Requires-Python: >=3.10
|
|
6
|
+
Description-Content-Type: text/markdown
|
|
7
|
+
Requires-Dist: sqlglot>=28
|
|
8
|
+
Provides-Extra: test
|
|
9
|
+
Requires-Dist: duckdb; extra == "test"
|
|
10
|
+
Requires-Dist: pytest; extra == "test"
|
|
11
|
+
|
|
12
|
+
# sqlassert
|
|
13
|
+
|
|
14
|
+
`sqlassert` is a Python library for adding safety checks to SQL before you run it.
|
|
15
|
+
|
|
16
|
+
The goal is to catch common query mistakes at test time or build time, using fast static and metadata-backed proofs instead of scanning production data. You can add `sqlassert` to your test suite and validate important queries offline, making them more resilient independent of the current contents of your database.
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
```bash
|
|
20
|
+
pip install sqlassert
|
|
21
|
+
```
|
|
22
|
+
|
|
23
|
+
_Alpha warning: Today `sqlassert` supports only one check: `/**UNIQUE**/` joins. It is also only tested on duckdb._
|
|
24
|
+
|
|
25
|
+
## Features
|
|
26
|
+
|
|
27
|
+
### Unique Join
|
|
28
|
+
|
|
29
|
+
Joins often accidentally multiply rows. A query may look correct against today’s data but silently break when the RHS relation later contains multiple matching rows.
|
|
30
|
+
|
|
31
|
+
`sqlassert` lets you mark joins that are expected to be unique. That is, the result of the join must never 'grow' the number of rows with respect to the LHS.
|
|
32
|
+
|
|
33
|
+
```sql
|
|
34
|
+
select *
|
|
35
|
+
from sessions
|
|
36
|
+
/**UNIQUE**/ join users
|
|
37
|
+
on sessions.user_id = users.id;
|
|
38
|
+
```
|
|
39
|
+
|
|
40
|
+
The marker is just a SQL comment. Your SQL remains valid SQL and can still run normally. `sqlassert` reads the query separately and validates that the RHS is provably unique for the join keys.
|
|
41
|
+
|
|
42
|
+
## Usage
|
|
43
|
+
|
|
44
|
+
Run validation offline, before your application or analytics job executes the query:
|
|
45
|
+
|
|
46
|
+
```python
|
|
47
|
+
import duckdb
|
|
48
|
+
from sqlassert import validate_unique_joins
|
|
49
|
+
|
|
50
|
+
con = duckdb.connect("warehouse.duckdb")
|
|
51
|
+
|
|
52
|
+
query = """
|
|
53
|
+
select *
|
|
54
|
+
from sessions
|
|
55
|
+
/**UNIQUE**/ join users
|
|
56
|
+
on sessions.user_id = users.id
|
|
57
|
+
"""
|
|
58
|
+
|
|
59
|
+
result = validate_unique_joins(con, query)
|
|
60
|
+
|
|
61
|
+
assert result.valid, result.reason
|
|
62
|
+
```
|
|
63
|
+
|
|
64
|
+
For a test suite, keep your model/query SQL as strings or load them from files, then validate them against a db connection that has the relevant schema:
|
|
65
|
+
|
|
66
|
+
```python
|
|
67
|
+
def test_query_join_contract(con):
|
|
68
|
+
query = load_query("models/session_enrichment.sql")
|
|
69
|
+
result = validate_unique_joins(con, query)
|
|
70
|
+
|
|
71
|
+
assert result.valid, result.reason
|
|
72
|
+
```
|
|
73
|
+
|
|
74
|
+
`result.checks` contains one check per marker:
|
|
75
|
+
|
|
76
|
+
```python
|
|
77
|
+
for check in result.checks:
|
|
78
|
+
print(check.valid)
|
|
79
|
+
print(check.reason)
|
|
80
|
+
print(check.inferred_key_columns)
|
|
81
|
+
print(check.constrained_key_columns)
|
|
82
|
+
```
|
|
83
|
+
|
|
84
|
+
## Details
|
|
85
|
+
|
|
86
|
+
### Unique Join Syntax
|
|
87
|
+
|
|
88
|
+
Place `/**UNIQUE**/` immediately before the join that should be uniqueness-checked:
|
|
89
|
+
|
|
90
|
+
```sql
|
|
91
|
+
select *
|
|
92
|
+
from lhs
|
|
93
|
+
/**UNIQUE**/ left join rhs
|
|
94
|
+
on lhs.rhs_id = rhs.id;
|
|
95
|
+
```
|
|
96
|
+
|
|
97
|
+
`ON` and `USING` are both supported:
|
|
98
|
+
|
|
99
|
+
```sql
|
|
100
|
+
select *
|
|
101
|
+
from users
|
|
102
|
+
/**UNIQUE**/ join user_profiles
|
|
103
|
+
using (id);
|
|
104
|
+
```
|
|
105
|
+
|
|
106
|
+
The marker applies to the next join after the comment.
|
|
107
|
+
|
|
108
|
+
## Proofs, Not Data Checks
|
|
109
|
+
|
|
110
|
+
`sqlassert` does **not** validate by querying actual table data. It will not run `count(*)`, search for duplicates, or sample rows.
|
|
111
|
+
|
|
112
|
+
Instead, it proves uniqueness using fast information available from the SQL and database metadata. If uniqueness cannot be proven, validation fails with a reason that names the join and RHS column:
|
|
113
|
+
|
|
114
|
+
```text
|
|
115
|
+
in join: "INNER JOIN events ON sessions.event_id = events.id", we can't prove that RHS column id is unique
|
|
116
|
+
```
|
|
117
|
+
|
|
118
|
+
Supported uniqueness proofs today:
|
|
119
|
+
|
|
120
|
+
- RHS `PRIMARY KEY` and `UNIQUE` constraints from db metadata.
|
|
121
|
+
- RHS `GROUP BY` subqueries, when the join covers the grouping keys.
|
|
122
|
+
- RHS `SELECT DISTINCT` subqueries, when the join covers the selected distinct columns.
|
|
123
|
+
- RHS `QUALIFY row_number() over (partition by ...) = 1` subqueries, when the join covers the partition keys.
|
|
124
|
+
- Simple projection views and subqueries that preserve one of the proofs above.
|
|
125
|
+
|
|
126
|
+
Views can inherit uniqueness when they are simple projections over a source relation with a supported proof. Filters preserve uniqueness; computed expressions, joins inside views, unions, and arbitrary subquery semantics are not guessed.
|
|
127
|
+
|
|
128
|
+
Examples:
|
|
129
|
+
|
|
130
|
+
```sql
|
|
131
|
+
-- Proved by primary key.
|
|
132
|
+
select *
|
|
133
|
+
from sessions
|
|
134
|
+
/**UNIQUE**/ join users
|
|
135
|
+
on sessions.user_id = users.id;
|
|
136
|
+
```
|
|
137
|
+
|
|
138
|
+
```sql
|
|
139
|
+
-- Proved by composite primary key plus RHS-only filter.
|
|
140
|
+
select *
|
|
141
|
+
from sessions
|
|
142
|
+
/**UNIQUE**/ join orders
|
|
143
|
+
on sessions.user_id = orders.user_id
|
|
144
|
+
and orders.order_id = 1;
|
|
145
|
+
```
|
|
146
|
+
|
|
147
|
+
```sql
|
|
148
|
+
-- Proved by GROUP BY.
|
|
149
|
+
with latest_session as (
|
|
150
|
+
select user_id, max(ts) as max_ts
|
|
151
|
+
from sessions
|
|
152
|
+
group by user_id
|
|
153
|
+
)
|
|
154
|
+
select *
|
|
155
|
+
from users
|
|
156
|
+
/**UNIQUE**/ join latest_session
|
|
157
|
+
on users.id = latest_session.user_id;
|
|
158
|
+
```
|
|
159
|
+
|
|
160
|
+
```sql
|
|
161
|
+
-- Proved by QUALIFY row_number() = 1.
|
|
162
|
+
with sessions_ranked as (
|
|
163
|
+
select user_id, *
|
|
164
|
+
from sessions
|
|
165
|
+
qualify row_number() over (partition by user_id order by ts) = 1
|
|
166
|
+
)
|
|
167
|
+
select *
|
|
168
|
+
from users
|
|
169
|
+
/**UNIQUE**/ join sessions_ranked
|
|
170
|
+
on users.id = sessions_ranked.user_id;
|
|
171
|
+
```
|
|
172
|
+
|
|
173
|
+
More compile-time SQL checks can be added under the same model: explicit syntax, fast validation, and clear reasons when a proof is missing.
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
sqlassert/__init__.py,sha256=Sroyw8c2hVVASyOuoRvMFnyBrgDxfMomwJs8KoXEcc8,274
|
|
2
|
+
sqlassert/unique.py,sha256=Uzxti-t4Hl8iIT3bPX6NZQbWbCmzAQ-SfukhLcnc76E,19160
|
|
3
|
+
sqlassert-0.1.0.dist-info/METADATA,sha256=lcdR-fUctWAlwsW7h9W_ThsWslvtfzGlwiLje8Arfl4,4971
|
|
4
|
+
sqlassert-0.1.0.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
|
|
5
|
+
sqlassert-0.1.0.dist-info/top_level.txt,sha256=BINwf_eEiAW975zV29VyXrSnTO9G136zvJVJ2pupnLI,10
|
|
6
|
+
sqlassert-0.1.0.dist-info/RECORD,,
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
sqlassert
|