sqlassert 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sqlassert/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ from sqlassert.unique import (
2
+ UniqueJoinCheckResult,
3
+ UniqueJoinValidationResult,
4
+ unique_assertions,
5
+ validate_unique_joins,
6
+ )
7
+
8
+ __all__ = [
9
+ "UniqueJoinCheckResult",
10
+ "UniqueJoinValidationResult",
11
+ "unique_assertions",
12
+ "validate_unique_joins",
13
+ ]
sqlassert/unique.py ADDED
@@ -0,0 +1,605 @@
1
+ from __future__ import annotations
2
+
3
+ import re
4
+ from collections.abc import Iterable
5
+ from dataclasses import dataclass
6
+ from typing import Any
7
+
8
+ import sqlglot
9
+ from sqlglot import exp
10
+ from sqlglot.errors import SqlglotError
11
+ from sqlglot.tokens import TokenType, Tokenizer
12
+
13
+
14
# Placeholder assertion statement returned once per marker by `unique_assertions`.
FALSE_ASSERTION = "select false"
# Matches the literal `/**unique**/` marker comment, ignoring letter case.
MARKER = re.compile(re.escape(r"/**unique**/"), re.IGNORECASE)
16
+
17
+
18
@dataclass(frozen=True)
class UniqueJoinCheckResult:
    """Immutable outcome of checking one ``/**unique**/`` marker."""

    # Zero-based position of the marker among the markers found in the SQL.
    marker_index: int
    # True when uniqueness of the marked join's RHS was proven.
    valid: bool
    # Human-readable explanation of the outcome.
    reason: str
    # Names of the RHS columns inferred from the join predicate.
    inferred_key_columns: tuple[str, ...] = ()
    # Columns of the constraint that proved uniqueness (empty when unproven).
    constrained_key_columns: tuple[str, ...] = ()
    # All uniqueness constraints discovered for the RHS relation.
    unique_constraints: tuple[tuple[str, ...], ...] = ()
26
+
27
+
28
@dataclass(frozen=True)
class UniqueJoinValidationResult:
    """Immutable aggregate outcome for every marker in one SQL text."""

    # True when there were no markers or every per-marker check passed.
    valid: bool
    # Summary message; failing check reasons are joined with "; ".
    reason: str
    # One entry per marker, in marker order (empty when no markers exist).
    checks: tuple[UniqueJoinCheckResult, ...] = ()
33
+
34
+
35
+ @dataclass(frozen=True)
36
+ class _UniqueJoinPlan:
37
+ marker_index: int
38
+ reason: str
39
+ join_sql: str = ""
40
+ rhs: exp.Expression | None = None
41
+ rhs_names: frozenset[str] = frozenset()
42
+ keys: tuple[exp.Column, ...] = ()
43
+ rhs_filter_columns: tuple[exp.Column, ...] = ()
44
+
45
+
46
def validate_unique_joins(
    connection: Any,
    sql: str,
    dialect: str = "duckdb",
) -> UniqueJoinValidationResult:
    """Check every ``/**unique**/``-marked join in *sql*.

    Args:
        connection: Object exposing ``execute(query, params)``; used for
            metadata lookups about the joined relations.
        sql: SQL text that may contain marker comments.
        dialect: sqlglot dialect name used for tokenizing and parsing.

    Returns:
        A result that is valid when no markers exist or all checks pass;
        otherwise its reason joins the failing checks' reasons with "; ".
    """
    plans = _unique_join_plans(sql, dialect)
    if not plans:
        return UniqueJoinValidationResult(True, "no unique join markers found")

    checks = tuple(_validate_plan(connection, plan) for plan in plans)
    failures = [check.reason for check in checks if not check.valid]
    if failures:
        return UniqueJoinValidationResult(False, "; ".join(failures), checks)
    return UniqueJoinValidationResult(True, "all unique join assertions passed", checks)
61
+
62
+
63
def unique_assertions(sql: str, dialect: str = "duckdb") -> list[str]:
    """Return one placeholder assertion statement per marker found in *sql*."""
    marker_count = len(_unique_join_plans(sql, dialect))
    return [FALSE_ASSERTION] * marker_count
65
+
66
+
67
def _unique_join_plans(sql: str, dialect: str) -> list[_UniqueJoinPlan]:
    """Build one plan per ``/**unique**/`` marker found in *sql*.

    Args:
        sql: SQL text that may contain marker comments.
        dialect: sqlglot dialect name used for tokenizing and parsing.

    Returns:
        An empty list when there are no markers; otherwise one plan per
        marker, in marker order. When the SQL cannot be tokenized or parsed,
        every marker gets a plan whose reason is "SQL parse failed".
    """
    try:
        marker_join_indexes = _marked_join_indexes(sql, dialect)
    except SqlglotError:
        # Tokenizing happens before parsing; report a tokenizer failure the
        # same way as a parser failure instead of letting it escape.
        return [
            _UniqueJoinPlan(index, "SQL parse failed")
            for index, _ in enumerate(MARKER.finditer(sql))
        ]
    if not marker_join_indexes:
        return []

    try:
        expressions = [expression for expression in sqlglot.parse(sql, read=dialect) if expression]
    except SqlglotError:
        return [
            _UniqueJoinPlan(index, "SQL parse failed")
            for index in range(len(marker_join_indexes))
        ]

    joins: list[exp.Join] = [
        join
        for expression in expressions
        for join in expression.find_all(exp.Join)
    ]

    plans: list[_UniqueJoinPlan] = []
    for marker_index, join_index in enumerate(marker_join_indexes):
        if join_index is None or join_index >= len(joins):
            plans.append(_UniqueJoinPlan(marker_index, "marker is not followed by a join"))
            continue
        plans.append(_plan_for_join(marker_index, joins[join_index], dialect))
    return plans
91
+
92
+
93
def _marked_join_indexes(sql: str, dialect: str) -> list[int | None]:
    """Map each marker to the ordinal of the next JOIN token after it.

    Returns:
        One entry per marker, in marker order: the zero-based position
        (among JOIN tokens) of the first JOIN token starting at or after the
        end of the marker, or ``None`` when no JOIN token follows it.
    """
    # NOTE(review): the caller uses these ordinals to index the list built
    # from `find_all(exp.Join)`; confirm that traversal order matches token
    # order for nested subqueries and CTEs.
    marker_offsets = [match.end() for match in MARKER.finditer(sql)]
    if not marker_offsets:
        return []

    join_offsets = [
        token.start
        for token in Tokenizer(dialect=dialect).tokenize(sql)
        if token.token_type is TokenType.JOIN
    ]

    def first_join_at_or_after(offset: int) -> int | None:
        for position, join_offset in enumerate(join_offsets):
            if join_offset >= offset:
                return position
        return None

    return [first_join_at_or_after(offset) for offset in marker_offsets]
113
+
114
+
115
def _plan_for_join(marker_index: int, join: exp.Join, dialect: str) -> _UniqueJoinPlan:
    """Extract the RHS relation and key columns from a marked join.

    Returns a plan carrying only a reason (no ``rhs``/``keys``) when the join
    is a SEMI/ANTI join, lacks an RHS, lacks an ON/USING predicate, has no
    identifiable RHS name, or yields no RHS key columns.
    """
    join_sql = _join_sql(join, dialect)
    prefix = _join_reason_prefix(join_sql)

    def rejected(reason: str) -> _UniqueJoinPlan:
        return _UniqueJoinPlan(marker_index, reason, join_sql)

    kind = (join.args.get("kind") or "").upper()
    if kind in {"ANTI", "SEMI"}:
        return rejected(f'{prefix}, {kind.lower()} joins are not supported')

    rhs = join.this
    if rhs is None:
        return rejected(f"{prefix}, marked join has no RHS relation")

    on = join.args.get("on")
    using = join.args.get("using") or []
    if on is None and not using:
        return rejected(f"{prefix}, marked join has no ON or USING predicate")

    rhs_names = _rhs_names(rhs)
    if not rhs_names:
        return rejected(f"{prefix}, could not identify RHS relation name or alias")

    # USING takes precedence over ON when both are present.
    if using:
        keys = _rhs_using_columns(using)
    else:
        keys = _rhs_key_columns(on, rhs_names)
    if not keys:
        return rejected(f"{prefix}, could not infer RHS key columns from join predicate")

    filter_columns = () if on is None else tuple(_rhs_filter_columns(on, rhs_names))
    return _UniqueJoinPlan(
        marker_index=marker_index,
        reason="join predicate inferred",
        join_sql=join_sql,
        rhs=rhs,
        rhs_names=frozenset(rhs_names),
        keys=tuple(keys),
        rhs_filter_columns=filter_columns,
    )
158
+
159
+
160
def _validate_plan(connection: Any, plan: _UniqueJoinPlan) -> UniqueJoinCheckResult:
    """Turn a plan into a check result by looking for a covering constraint.

    A plan without an RHS or keys fails with the plan's own reason. Otherwise
    the RHS's uniqueness constraints are collected and the check passes only
    if one of them is covered by the plan's columns.
    """
    key_names = tuple(key.name for key in plan.keys)
    if plan.rhs is None or not plan.keys:
        return UniqueJoinCheckResult(
            marker_index=plan.marker_index,
            valid=False,
            reason=plan.reason,
            inferred_key_columns=key_names,
        )

    unique_constraints = _relation_unique_constraints(connection, plan.rhs)
    proof = None
    if unique_constraints:
        proof = _validate_constraints(plan, key_names, unique_constraints)
    if proof is not None:
        return proof

    # Either no constraints were found or none of them is covered.
    return UniqueJoinCheckResult(
        marker_index=plan.marker_index,
        valid=False,
        reason=_cannot_prove_reason(plan),
        inferred_key_columns=key_names,
        unique_constraints=unique_constraints,
    )
191
+
192
+
193
def _validate_constraints(
    plan: _UniqueJoinPlan,
    key_names: tuple[str, ...],
    unique_constraints: tuple[tuple[str, ...], ...],
) -> UniqueJoinCheckResult | None:
    """Return a passing result for the first constraint the plan covers.

    Args:
        plan: Plan whose key and filter columns form the covered set.
        key_names: Inferred key column names, echoed into the result.
        unique_constraints: Candidate constraints, each a tuple of columns.

    Returns:
        A valid result naming the first fully covered constraint, or ``None``
        when no constraint is covered.
    """
    covered_columns = _covered_rhs_column_names(plan)
    for constraint in unique_constraints:
        # An empty constraint would satisfy `all(...)` vacuously and "prove"
        # any join, so it is never accepted as a proof.
        if not constraint:
            continue
        if all(column.lower() in covered_columns for column in constraint):
            constrained_key_columns = tuple(constraint)
            return UniqueJoinCheckResult(
                marker_index=plan.marker_index,
                valid=True,
                reason=(
                    f"RHS uniqueness proof ({', '.join(constrained_key_columns)}) "
                    "is covered by inferred keys/filters "
                    f"({', '.join(sorted(covered_columns))})"
                ),
                inferred_key_columns=key_names,
                constrained_key_columns=constrained_key_columns,
                unique_constraints=unique_constraints,
            )
    return None
215
+
216
+
217
def _cannot_prove_reason(plan: _UniqueJoinPlan) -> str:
    """Build the failure message naming the join and its RHS key columns."""
    key_names = tuple(key.name for key in plan.keys)
    verb = "are" if len(plan.keys) != 1 else "is"
    prefix = _join_reason_prefix(plan.join_sql)
    described = _format_columns(key_names)
    return f"{prefix}, we can't prove that RHS {described} {verb} unique"
221
+
222
+
223
+ def _join_reason_prefix(join_sql: str) -> str:
224
+ return f'in join: "{join_sql}"'
225
+
226
+
227
def _join_sql(join: exp.Join, dialect: str) -> str:
    """Render *join* as single-line SQL with any unique-marker comment removed."""
    rendered = join.sql(dialect=dialect)
    # Looser than the marker matcher: tolerates optional inner asterisks and
    # whitespace inside the comment.
    without_marker = re.sub(
        r"/\*\s*\*?\s*unique\s*\*?\s*\*/",
        "",
        rendered,
        flags=re.IGNORECASE,
    )
    # Collapse every run of whitespace into a single space.
    words = without_marker.split()
    return " ".join(words)
231
+
232
+
233
def _rhs_names(rhs: exp.Expression) -> set[str]:
    """Return the lower-cased name by which columns of *rhs* are qualified.

    The alias wins when present; a table without alias uses its own name;
    anything else falls back to ``alias_or_name``. The set is empty when no
    name can be determined.
    """
    alias = rhs.alias
    if alias:
        return {alias.lower()}
    if isinstance(rhs, exp.Table):
        return {rhs.name.lower()}
    fallback = rhs.alias_or_name
    return {fallback.lower()} if fallback else set()
241
+
242
+
243
def _rhs_key_columns(on: exp.Expression, rhs_names: set[str]) -> list[exp.Column]:
    """Collect RHS columns equated to a non-RHS expression in the ON clause.

    Only equalities that are top-level conjuncts of *on* are considered
    (parentheses are looked through). An equality nested under ``OR``,
    ``NOT`` or any other operator does not restrict the RHS to a single row,
    so it must not contribute a key column.

    Args:
        on: The join's ON expression.
        rhs_names: Lower-cased names that qualify RHS columns.

    Returns:
        Copies of the matching RHS columns, de-duplicated case-insensitively
        by (table, column) and listed in source order.
    """
    keys: list[exp.Column] = []
    seen: set[tuple[str, str]] = set()

    # Depth-first walk over the AND tree; the right operand is pushed first
    # so that the left operand is processed first.
    pending: list[exp.Expression] = [on]
    while pending:
        term = pending.pop()
        if isinstance(term, exp.Paren):
            pending.append(term.this)
            continue
        if isinstance(term, exp.And):
            pending.append(term.expression)
            pending.append(term.this)
            continue
        if not isinstance(term, exp.EQ):
            continue

        rhs_column = _simple_rhs_column(term.this, term.expression, rhs_names)
        if rhs_column is None:
            rhs_column = _simple_rhs_column(term.expression, term.this, rhs_names)
        if rhs_column is None:
            continue

        key = (_column_table(rhs_column).lower(), rhs_column.name.lower())
        if key not in seen:
            keys.append(rhs_column.copy())
            seen.add(key)

    return keys
260
+
261
+
262
+ def _rhs_using_columns(using: list[exp.Identifier]) -> list[exp.Column]:
263
+ keys = []
264
+ seen = set()
265
+ for identifier in using:
266
+ name = identifier.name
267
+ if name.lower() in seen:
268
+ continue
269
+ keys.append(exp.column(name))
270
+ seen.add(name.lower())
271
+ return keys
272
+
273
+
274
def _simple_rhs_column(
    maybe_column: exp.Expression,
    other_side: exp.Expression,
    rhs_names: set[str],
) -> exp.Column | None:
    """Return *maybe_column* if it is an RHS column compared to a non-RHS side.

    The result is ``None`` when *maybe_column* is not a column, is not
    qualified by an RHS name, or when *other_side* itself references any
    RHS column.
    """
    is_candidate = isinstance(maybe_column, exp.Column) and _is_rhs_column(maybe_column, rhs_names)
    if not is_candidate:
        return None
    for column in other_side.find_all(exp.Column):
        if _is_rhs_column(column, rhs_names):
            return None
    return maybe_column
286
+
287
+
288
def _rhs_filter_columns(on: exp.Expression, rhs_names: set[str]) -> list[exp.Column]:
    """Collect RHS columns pinned to a single value by the ON clause.

    A column qualifies only when a top-level conjunct of *on* (parentheses
    are looked through) is an equality between that RHS column and an
    expression containing no column references, e.g. ``rhs.kind = 1``.
    Other RHS-only predicates (``rhs.x > 1``, ``rhs.x in (...)``,
    ``rhs.x = rhs.y``) still allow several RHS rows per key, so they must
    not count towards covering a uniqueness constraint.

    Args:
        on: The join's ON expression.
        rhs_names: Lower-cased names that qualify RHS columns.

    Returns:
        Copies of the pinned RHS columns, de-duplicated case-insensitively
        by (table, column) and listed in source order.
    """
    filters: list[exp.Column] = []
    seen: set[tuple[str, str]] = set()

    # Depth-first walk over the AND tree; the right operand is pushed first
    # so that the left operand is processed first.
    pending: list[exp.Expression] = [on]
    while pending:
        predicate = pending.pop()
        if isinstance(predicate, exp.Paren):
            pending.append(predicate.this)
            continue
        if isinstance(predicate, exp.And):
            pending.append(predicate.expression)
            pending.append(predicate.this)
            continue
        if not isinstance(predicate, exp.EQ):
            continue

        sides = (
            (predicate.this, predicate.expression),
            (predicate.expression, predicate.this),
        )
        for column, value in sides:
            if not isinstance(column, exp.Column) or not _is_rhs_column(column, rhs_names):
                continue
            # The compared value must be column-free (a constant expression).
            if value.find(exp.Column) is not None:
                continue
            key = (_column_table(column).lower(), column.name.lower())
            if key not in seen:
                filters.append(column.copy())
                seen.add(key)

    return filters
303
+
304
+
305
def _and_terms(expression: exp.Expression) -> Iterable[exp.Expression]:
    """Yield the operands of a (possibly nested) AND chain, left to right.

    An expression that is not an ``exp.And`` is yielded as its only term.
    """
    pending = [expression]
    while pending:
        node = pending.pop()
        if isinstance(node, exp.And):
            # Push the right operand first so the left one is visited first.
            pending.append(node.expression)
            pending.append(node.this)
        else:
            yield node
311
+
312
+
313
def _is_rhs_column(column: exp.Column, rhs_names: set[str]) -> bool:
    """Tell whether *column* is qualified by one of the RHS names.

    Unqualified columns are never treated as RHS columns.
    """
    qualifier = _column_table(column)
    if not qualifier:
        return False
    return qualifier.lower() in rhs_names
316
+
317
+
318
def _column_table(column: exp.Column) -> str:
    """Return the column's table qualifier, or "" when it is not a string."""
    qualifier = column.table
    if isinstance(qualifier, str):
        return qualifier
    return ""
321
+
322
+
323
def _covered_rhs_column_names(plan: _UniqueJoinPlan) -> set[str]:
    """Return the lower-cased names of the plan's key and filter columns."""
    names = {column.name.lower() for column in plan.keys}
    names.update(column.name.lower() for column in plan.rhs_filter_columns)
    return names
328
+
329
+
330
def _relation_unique_constraints(
    connection: Any,
    relation: exp.Expression,
    seen_views: frozenset[tuple[str, str]] = frozenset(),
) -> tuple[tuple[str, ...], ...]:
    """Collect uniqueness constraints that hold for *relation*.

    A table is looked up in the constraint metadata first and, if that
    yields nothing, as a view. A subquery wrapping a SELECT is analysed
    structurally. Any other relation yields no constraints.

    Args:
        connection: Object exposing ``execute(query, params)``.
        relation: Table or subquery expression to analyse.
        seen_views: (schema, view) keys already being expanded.
    """
    # NOTE(review): a CTE reference presumably arrives here as an exp.Table
    # and is only looked up as a table/view; the README examples join CTEs,
    # so confirm whether CTE resolution is intended.
    if isinstance(relation, exp.Table):
        return _unique_constraints(connection, relation) or _view_unique_constraints(
            connection, relation, seen_views
        )

    is_select_subquery = isinstance(relation, exp.Subquery) and isinstance(relation.this, exp.Select)
    if is_select_subquery:
        return _select_unique_constraints(connection, relation.this, seen_views)
    return ()
345
+
346
+
347
def _view_unique_constraints(
    connection: Any,
    view: exp.Table,
    seen_views: frozenset[tuple[str, str]],
) -> tuple[tuple[str, ...], ...]:
    """Derive uniqueness constraints from a view's defining SELECT.

    Returns no constraints when the view is already being expanded, has no
    retrievable definition, fails to parse, or is not ``CREATE ... AS SELECT``.
    """
    # An unqualified view is keyed under the "main" schema.
    view_key = ((_table_schema(view) or "main").lower(), view.name.lower())
    if view_key in seen_views:
        # Already expanding this view further up the call chain.
        return ()

    definition = _view_sql(connection, view)
    if not definition:
        return ()

    try:
        parsed = sqlglot.parse_one(definition, read="duckdb")
    except SqlglotError:
        return ()

    body = parsed.expression if isinstance(parsed, exp.Create) else None
    if not isinstance(body, exp.Select):
        return ()
    return _select_unique_constraints(connection, body, seen_views | {view_key})
370
+
371
+
372
def _view_sql(connection: Any, view: exp.Table) -> str:
    """Fetch the defining SQL of *view* from ``duckdb_views()``.

    The lookup filters by schema only when the reference is schema-qualified.
    Returns "" when no row matches or when the query raises.
    """
    schema_name = _table_schema(view)
    schema_clause = "and schema_name = ? " if schema_name else ""
    query = (
        "select sql "
        "from duckdb_views() "
        "where view_name = ? "
        + schema_clause
        + "and not internal "
        "order by database_name, schema_name "
        "limit 1"
    )
    params = (view.name, schema_name) if schema_name else (view.name,)

    try:
        row = connection.execute(query, params).fetchone()
    except Exception:
        # Best effort: a failed metadata lookup means "no definition found".
        return ""

    if not row:
        return ""
    return row[0]
402
+
403
+
404
def _select_unique_constraints(
    connection: Any,
    select: exp.Select,
    seen_views: frozenset[tuple[str, str]],
) -> tuple[tuple[str, ...], ...]:
    """Collect uniqueness constraints implied by a SELECT.

    Sources, in order: GROUP BY keys, DISTINCT output columns, the partition
    keys of a ``QUALIFY row_number() ... = 1`` filter, and constraints
    carried through from a single projected source relation. The result is
    de-duplicated case-insensitively.
    """
    structural = (
        _group_by_constraint(select),
        _distinct_constraint(select),
        _qualify_row_number_constraint(select),
    )
    collected = [constraint for constraint in structural if constraint]
    collected.extend(_projected_source_constraints(connection, select, seen_views))
    return _dedupe_constraints(tuple(collected))
425
+
426
+
427
def _projected_source_constraints(
    connection: Any,
    select: exp.Select,
    seen_views: frozenset[tuple[str, str]],
) -> tuple[tuple[str, ...], ...]:
    """Carry a single source's constraints through the SELECT's projection.

    A source constraint survives only if every one of its columns is
    projected as a plain (optionally aliased) column; the surviving
    constraint is expressed in output column names.
    """
    source = _single_select_source(select)
    if source is None:
        return ()

    source_constraints = _relation_unique_constraints(connection, source, seen_views)
    if not source_constraints:
        return ()

    projection = _projection_map(select.expressions)
    if not projection:
        return ()

    carried = []
    for constraint in source_constraints:
        renamed = [projection.get(column.lower()) for column in constraint]
        # Drop the constraint if any of its columns is not projected.
        if None not in renamed:
            carried.append(tuple(renamed))
    return tuple(carried)
456
+
457
+
458
def _single_select_source(select: exp.Select) -> exp.Expression | None:
    """Return the SELECT's only source if it is a table or subquery.

    Returns ``None`` when the SELECT has joins, has no FROM clause, or reads
    from anything other than a table or subquery.
    """
    if select.args.get("joins"):
        return None

    from_ = select.args.get("from_")
    source = from_.this if isinstance(from_, exp.From) else None
    if isinstance(source, (exp.Table, exp.Subquery)):
        return source
    return None
470
+
471
+
472
+ def _projection_map(expressions: list[exp.Expression]) -> dict[str, str]:
473
+ projection: dict[str, str] = {}
474
+ for expression in expressions:
475
+ source_name, output_name = _projection_column_names(expression)
476
+ if source_name and output_name and source_name.lower() not in projection:
477
+ projection[source_name.lower()] = output_name
478
+ return projection
479
+
480
+
481
def _projection_column_names(expression: exp.Expression) -> tuple[str, str]:
    """Return (source name, output name) for a plain or aliased column.

    Any other projection yields a pair of empty strings.
    """
    if isinstance(expression, exp.Column):
        return expression.name, expression.name
    if isinstance(expression, exp.Alias):
        inner = expression.this
        if isinstance(inner, exp.Column):
            return inner.name, expression.alias
    return "", ""
487
+
488
+
489
+ def _dedupe_constraints(constraints: tuple[tuple[str, ...], ...]) -> tuple[tuple[str, ...], ...]:
490
+ deduped = []
491
+ seen = set()
492
+ for constraint in constraints:
493
+ key = tuple(column.lower() for column in constraint)
494
+ if key in seen:
495
+ continue
496
+ deduped.append(constraint)
497
+ seen.add(key)
498
+ return tuple(deduped)
499
+
500
+
501
def _group_by_constraint(select: exp.Select) -> tuple[str, ...]:
    """Return the simple column names of the SELECT's GROUP BY, if any."""
    group = select.args.get("group")
    if isinstance(group, exp.Group):
        return _simple_output_columns(group.expressions)
    return ()
506
+
507
+
508
def _distinct_constraint(select: exp.Select) -> tuple[str, ...]:
    """Return the simple output column names of a SELECT DISTINCT, if any."""
    has_distinct = select.args.get("distinct") is not None
    if has_distinct:
        return _simple_output_columns(select.expressions)
    return ()
512
+
513
+
514
def _qualify_row_number_constraint(select: exp.Select) -> tuple[str, ...]:
    """Return the partition keys of a ``QUALIFY row_number() ... = 1`` filter.

    Returns an empty tuple when the SELECT has no QUALIFY clause or the
    clause is not a row-number window compared for equality with 1.
    """
    qualify = select.args.get("qualify")
    window = None
    if isinstance(qualify, exp.Qualify):
        window = _row_number_window_filtered_to_one(qualify.this)
    if window is None:
        return ()
    partition_keys = window.args.get("partition_by") or []
    return _simple_output_columns(partition_keys)
523
+
524
+
525
def _row_number_window_filtered_to_one(expression: exp.Expression) -> exp.Window | None:
    """Return the row-number window from ``window = 1`` or ``1 = window``.

    Returns ``None`` for anything that is not such an equality.
    """
    if not isinstance(expression, exp.EQ):
        return None

    left, right = expression.this, expression.expression
    # Try the window on the left first, then on the right.
    for window, literal in ((left, right), (right, left)):
        if _is_row_number_window(window) and _is_one(literal):
            return window
    return None
534
+
535
+
536
def _is_row_number_window(expression: exp.Expression) -> bool:
    """Tell whether *expression* is a window over ``row_number()``."""
    if not isinstance(expression, exp.Window):
        return False
    return isinstance(expression.this, exp.RowNumber)
538
+
539
+
540
def _is_one(expression: exp.Expression) -> bool:
    """Tell whether *expression* is the non-string literal ``1``."""
    if not isinstance(expression, exp.Literal):
        return False
    if expression.is_string:
        return False
    return expression.this == "1"
542
+
543
+
544
+ def _simple_output_columns(expressions: Iterable[exp.Expression]) -> tuple[str, ...]:
545
+ columns = []
546
+ seen = set()
547
+ for expression in expressions:
548
+ name = _simple_output_column_name(expression)
549
+ if name and name.lower() not in seen:
550
+ columns.append(name)
551
+ seen.add(name.lower())
552
+ return tuple(columns)
553
+
554
+
555
def _simple_output_column_name(expression: exp.Expression) -> str:
    """Return the output name of an alias or column, or "" for anything else."""
    if isinstance(expression, exp.Column):
        return expression.name
    if isinstance(expression, exp.Alias):
        return expression.alias
    return ""
561
+
562
+
563
def _unique_constraints(connection: Any, rhs: exp.Table) -> tuple[tuple[str, ...], ...]:
    """Read PRIMARY KEY and UNIQUE constraints of *rhs* from ``duckdb_constraints()``.

    The lookup filters by schema only when the reference is schema-qualified.
    Returns one tuple of column names per constraint, ordered by constraint
    index, or an empty tuple when the query raises.
    """
    schema_name = _table_schema(rhs)
    schema_clause = "and schema_name = ? " if schema_name else ""
    query = (
        "select constraint_column_names "
        "from duckdb_constraints() "
        "where table_name = ? "
        + schema_clause
        + "and constraint_type in ('PRIMARY KEY', 'UNIQUE') "
        "order by constraint_index"
    )
    params = (rhs.name, schema_name) if schema_name else (rhs.name,)

    try:
        rows = connection.execute(query, params).fetchall()
    except Exception:
        # Best effort: a failed metadata lookup means "no constraints known".
        return ()

    constraints = [tuple(row[0]) for row in rows]
    return tuple(constraints)
591
+
592
+
593
def _table_schema(table: exp.Table) -> str:
    """Return the schema qualifier of *table*, or "" when it has none."""
    qualifier = table.args.get("db")
    if isinstance(qualifier, str):
        return qualifier
    if isinstance(qualifier, exp.Identifier):
        return qualifier.name
    return ""
600
+
601
+
602
+ def _format_columns(columns: tuple[str, ...]) -> str:
603
+ if len(columns) == 1:
604
+ return f"column {columns[0]}"
605
+ return f"columns {', '.join(columns)}"
@@ -0,0 +1,173 @@
1
+ Metadata-Version: 2.4
2
+ Name: sqlassert
3
+ Version: 0.1.0
4
+ Summary: Generate SQL assertions from unique join markers.
5
+ Requires-Python: >=3.10
6
+ Description-Content-Type: text/markdown
7
+ Requires-Dist: sqlglot>=28
8
+ Provides-Extra: test
9
+ Requires-Dist: duckdb; extra == "test"
10
+ Requires-Dist: pytest; extra == "test"
11
+
12
+ # sqlassert
13
+
14
+ `sqlassert` is a Python library for adding safety checks to SQL before you run it.
15
+
16
+ The goal is to catch common query mistakes at test time or build time, using fast static and metadata-backed proofs instead of scanning production data. You can add `sqlassert` to your test suite and validate important queries offline, making them more resilient independent of the current contents of your database.
17
+
18
+
19
+ ```bash
20
+ pip install sqlassert
21
+ ```
22
+
23
+ _Alpha warning: Today `sqlassert` supports only one check: `/**UNIQUE**/` joins. It is also only tested on duckdb._
24
+
25
+ ## Features
26
+
27
+ ### Unique Join
28
+
29
+ Joins often accidentally multiply rows. A query may look correct against today’s data but silently break when the RHS relation later contains multiple matching rows.
30
+
31
+ `sqlassert` lets you mark joins that are expected to be unique. That is, the result of the join must never 'grow' the number of rows with respect to the LHS.
32
+
33
+ ```sql
34
+ select *
35
+ from sessions
36
+ /**UNIQUE**/ join users
37
+ on sessions.user_id = users.id;
38
+ ```
39
+
40
+ The marker is just a SQL comment. Your SQL remains valid SQL and can still run normally. `sqlassert` reads the query separately and validates that the RHS is provably unique for the join keys.
41
+
42
+ ## Usage
43
+
44
+ Run validation offline, before your application or analytics job executes the query:
45
+
46
+ ```python
47
+ import duckdb
48
+ from sqlassert import validate_unique_joins
49
+
50
+ con = duckdb.connect("warehouse.duckdb")
51
+
52
+ query = """
53
+ select *
54
+ from sessions
55
+ /**UNIQUE**/ join users
56
+ on sessions.user_id = users.id
57
+ """
58
+
59
+ result = validate_unique_joins(con, query)
60
+
61
+ assert result.valid, result.reason
62
+ ```
63
+
64
+ For a test suite, keep your model/query SQL as strings or load them from files, then validate them against a db connection that has the relevant schema:
65
+
66
+ ```python
67
+ def test_query_join_contract(con):
68
+ query = load_query("models/session_enrichment.sql")
69
+ result = validate_unique_joins(con, query)
70
+
71
+ assert result.valid, result.reason
72
+ ```
73
+
74
+ `result.checks` contains one check per marker:
75
+
76
+ ```python
77
+ for check in result.checks:
78
+ print(check.valid)
79
+ print(check.reason)
80
+ print(check.inferred_key_columns)
81
+ print(check.constrained_key_columns)
82
+ ```
83
+
84
+ ## Details
85
+
86
+ ### Unique Join Syntax
87
+
88
+ Place `/**UNIQUE**/` immediately before the join that should be uniqueness-checked:
89
+
90
+ ```sql
91
+ select *
92
+ from lhs
93
+ /**UNIQUE**/ left join rhs
94
+ on lhs.rhs_id = rhs.id;
95
+ ```
96
+
97
+ `ON` and `USING` are both supported:
98
+
99
+ ```sql
100
+ select *
101
+ from users
102
+ /**UNIQUE**/ join user_profiles
103
+ using (id);
104
+ ```
105
+
106
+ The marker applies to the next join after the comment.
107
+
108
+ ## Proofs, Not Data Checks
109
+
110
+ `sqlassert` does **not** validate by querying actual table data. It will not run `count(*)`, search for duplicates, or sample rows.
111
+
112
+ Instead, it proves uniqueness using fast information available from the SQL and database metadata. If uniqueness cannot be proven, validation fails with a reason that names the join and RHS column:
113
+
114
+ ```text
115
+ in join: "INNER JOIN events ON sessions.event_id = events.id", we can't prove that RHS column id is unique
116
+ ```
117
+
118
+ Supported uniqueness proofs today:
119
+
120
+ - RHS `PRIMARY KEY` and `UNIQUE` constraints from db metadata.
121
+ - RHS `GROUP BY` subqueries, when the join covers the grouping keys.
122
+ - RHS `SELECT DISTINCT` subqueries, when the join covers the selected distinct columns.
123
+ - RHS `QUALIFY row_number() over (partition by ...) = 1` subqueries, when the join covers the partition keys.
124
+ - Simple projection views and subqueries that preserve one of the proofs above.
125
+
126
+ Views can inherit uniqueness when they are simple projections over a source relation with a supported proof. Filters preserve uniqueness; computed expressions, joins inside views, unions, and arbitrary subquery semantics are not guessed.
127
+
128
+ Examples:
129
+
130
+ ```sql
131
+ -- Proved by primary key.
132
+ select *
133
+ from sessions
134
+ /**UNIQUE**/ join users
135
+ on sessions.user_id = users.id;
136
+ ```
137
+
138
+ ```sql
139
+ -- Proved by composite primary key plus RHS-only filter.
140
+ select *
141
+ from sessions
142
+ /**UNIQUE**/ join orders
143
+ on sessions.user_id = orders.user_id
144
+ and orders.order_id = 1;
145
+ ```
146
+
147
+ ```sql
148
+ -- Proved by GROUP BY.
149
+ with latest_session as (
150
+ select user_id, max(ts) as max_ts
151
+ from sessions
152
+ group by user_id
153
+ )
154
+ select *
155
+ from users
156
+ /**UNIQUE**/ join latest_session
157
+ on users.id = latest_session.user_id;
158
+ ```
159
+
160
+ ```sql
161
+ -- Proved by QUALIFY row_number() = 1.
162
+ with sessions_ranked as (
163
+ select user_id, *
164
+ from sessions
165
+ qualify row_number() over (partition by user_id order by ts) = 1
166
+ )
167
+ select *
168
+ from users
169
+ /**UNIQUE**/ join sessions_ranked
170
+ on users.id = sessions_ranked.user_id;
171
+ ```
172
+
173
+ More compile-time SQL checks can be added under the same model: explicit syntax, fast validation, and clear reasons when a proof is missing.
@@ -0,0 +1,6 @@
1
+ sqlassert/__init__.py,sha256=Sroyw8c2hVVASyOuoRvMFnyBrgDxfMomwJs8KoXEcc8,274
2
+ sqlassert/unique.py,sha256=Uzxti-t4Hl8iIT3bPX6NZQbWbCmzAQ-SfukhLcnc76E,19160
3
+ sqlassert-0.1.0.dist-info/METADATA,sha256=lcdR-fUctWAlwsW7h9W_ThsWslvtfzGlwiLje8Arfl4,4971
4
+ sqlassert-0.1.0.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
5
+ sqlassert-0.1.0.dist-info/top_level.txt,sha256=BINwf_eEiAW975zV29VyXrSnTO9G136zvJVJ2pupnLI,10
6
+ sqlassert-0.1.0.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (82.0.1)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1 @@
1
+ sqlassert