sqlproof 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44) hide show
  1. sqlproof/__init__.py +32 -0
  2. sqlproof/_version.py +1 -0
  3. sqlproof/cli.py +151 -0
  4. sqlproof/client.py +159 -0
  5. sqlproof/config.py +42 -0
  6. sqlproof/contrib/__init__.py +3 -0
  7. sqlproof/contrib/supabase.py +136 -0
  8. sqlproof/core.py +344 -0
  9. sqlproof/coverage/__init__.py +6 -0
  10. sqlproof/coverage/diversity.py +11 -0
  11. sqlproof/coverage/plpgsql.py +5 -0
  12. sqlproof/coverage/schema_shape.py +7 -0
  13. sqlproof/exceptions.py +47 -0
  14. sqlproof/generators/__init__.py +21 -0
  15. sqlproof/generators/columns.py +93 -0
  16. sqlproof/generators/constraints.py +181 -0
  17. sqlproof/generators/functions.py +9 -0
  18. sqlproof/generators/graph.py +51 -0
  19. sqlproof/generators/rows.py +153 -0
  20. sqlproof/generators/sampling.py +15 -0
  21. sqlproof/generators/well_known.py +59 -0
  22. sqlproof/pytest_plugin.py +24 -0
  23. sqlproof/reporter/__init__.py +5 -0
  24. sqlproof/reporter/console.py +20 -0
  25. sqlproof/reporter/json_io.py +26 -0
  26. sqlproof/runners/__init__.py +14 -0
  27. sqlproof/runners/db.py +48 -0
  28. sqlproof/runners/migration.py +51 -0
  29. sqlproof/runners/overload.py +41 -0
  30. sqlproof/runners/property.py +119 -0
  31. sqlproof/runners/rls.py +40 -0
  32. sqlproof/runners/stateful.py +36 -0
  33. sqlproof/schema/__init__.py +27 -0
  34. sqlproof/schema/dependency_graph.py +38 -0
  35. sqlproof/schema/fingerprint.py +34 -0
  36. sqlproof/schema/introspect.py +229 -0
  37. sqlproof/schema/model.py +98 -0
  38. sqlproof/schema/parse_sql.py +206 -0
  39. sqlproof/testing.py +101 -0
  40. sqlproof/types.py +34 -0
  41. sqlproof-0.1.0a1.dist-info/METADATA +248 -0
  42. sqlproof-0.1.0a1.dist-info/RECORD +44 -0
  43. sqlproof-0.1.0a1.dist-info/WHEEL +4 -0
  44. sqlproof-0.1.0a1.dist-info/entry_points.txt +5 -0
@@ -0,0 +1,41 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Callable
4
+ from typing import Any
5
+
6
+ from sqlproof.generators.functions import FunctionCall
7
+
8
+
9
def function_overloads(
    proof: Any,
    *,
    function: str,
    **kwargs: object,
) -> Callable[[Callable[..., None]], Callable[..., None]]:
    """Build a decorator that hands two overloads of *function* to a test callback.

    The returned zero-argument test collects the overloads of *function* from
    ``proof.schema_info.functions`` (via ``_function_calls``), opens
    ``proof.client_for_dataset({})`` and calls ``callback(db, first, second)``
    with the first two ``FunctionCall`` objects.

    Args:
        proof: Object exposing ``schema_info`` and ``client_for_dataset``.
        function: Name of the SQL function whose overloads are compared.
        **kwargs: Accepted for API compatibility; ignored.

    Returns:
        A decorator turning ``callback(db, call_a, call_b)`` into a no-argument test.
    """
    del kwargs

    def decorate(callback: Callable[..., None]) -> Callable[..., None]:
        def wrapped() -> None:
            calls = _function_calls(proof, function)
            if len(calls) < 2:
                # Fewer than two overloads: fall back to two zero-argument
                # placeholder calls. NOTE(review): when exactly one real overload
                # exists it is discarded here rather than reused — confirm intent.
                calls = (
                    FunctionCall(sql=f"{function}()", overload_name=f"{function}/0"),
                    FunctionCall(sql=f"{function}()", overload_name=f"{function}/0"),
                )
            with proof.client_for_dataset({}) as db:
                callback(db, calls[0], calls[1])

        # Expose the test's own name and docstring, the same way ``sqlproof`` does.
        # The attributes are copied by hand instead of using ``functools.wraps``:
        # ``wraps`` also sets ``__wrapped__``, which ``inspect.signature`` (and so
        # pytest's fixture resolution) follows. pytest would then try to inject
        # the callback's parameters as fixtures.
        wrapped.__name__ = callback.__name__
        wrapped.__doc__ = callback.__doc__
        return wrapped

    return decorate
31
+
32
+
33
def _function_calls(proof: Any, function_name: str) -> tuple[FunctionCall, ...]:
    """Return one ``FunctionCall`` per overload of *function_name*.

    Overloads are taken from ``proof.schema_info.functions`` in their stored
    order. Each call's SQL has one ``%s`` placeholder per argument, and its
    overload name is ``"<name>/<arity>"``.
    """
    found: list[FunctionCall] = []
    for candidate in proof.schema_info.functions:
        if candidate.name == function_name:
            arity = len(candidate.arg_types)
            placeholders = ", ".join(["%s"] * arity)
            found.append(
                FunctionCall(
                    sql=f"{candidate.name}({placeholders})",
                    overload_name=f"{candidate.name}/{arity}",
                )
            )
    return tuple(found)
@@ -0,0 +1,119 @@
1
+ from __future__ import annotations
2
+
3
+ import inspect
4
+ from collections.abc import Callable, Generator, Mapping
5
+ from contextlib import contextmanager
6
+ from pathlib import Path
7
+ from typing import Any, ParamSpec, TypeVar
8
+
9
+ from hypothesis import HealthCheck, given, settings
10
+
11
+ from sqlproof.exceptions import SqlProofPropertyFailure
12
+ from sqlproof.generators.graph import SizeSpec, dataset_strategy
13
+ from sqlproof.reporter.json_io import write_counterexample
14
+
15
# Generic type variables (ParamSpec for signatures, TypeVar for return types).
# NOTE(review): neither P nor R is referenced by the code visible in this
# module; they may be leftovers — confirm before removing.
P = ParamSpec("P")
R = TypeVar("R")
17
+
18
+
19
class Check:
    """Per-example helper passed to properties that declare a ``check`` parameter.

    ``row_context`` holds the keyword arguments of the innermost active
    :meth:`row` block, or ``None`` outside any row. ``run_property`` copies it
    into the counterexample payload when the property fails.
    """

    def __init__(self) -> None:
        self.row_context: dict[str, Any] | None = None

    @contextmanager
    def row(self, **context: Any) -> Generator[None]:
        """Make *context* the current ``row_context`` for the duration of the block.

        On normal exit the previous context is restored, which may be ``None``.
        If the block raises, ``row_context`` is left as *context*. The failing
        row is then the one reported, not an enclosing row or one that has
        already finished.
        """
        previous = self.row_context
        self.row_context = context
        yield
        # With ``@contextmanager``, an exception raised in the body is re-raised
        # at the ``yield`` above, so this restore only runs on a clean exit.
        self.row_context = previous

    def label(self, name: str) -> None:
        """Accept a label for the current example; currently a no-op."""
        del name
35
+
36
+
37
def sqlproof(
    proof: Any,
    *,
    sizes: Mapping[str, SizeSpec],
    runs: int = 100,
    seed: int | None = None,
    timeout_ms: int = 5000,
    commit: bool = False,
    failure_dir: str | Path = ".sqlproof/failures",
) -> Callable[[Callable[..., None]], Callable[..., None]]:
    """Turn a property function into a zero-argument test driven by ``run_property``.

    Args:
        proof: Object providing the schema, dataset strategy and database clients.
        sizes: Per-table size specifications forwarded to the dataset strategy.
        runs: Maximum number of generated examples.
        seed, timeout_ms, commit: Accepted but currently ignored.
        failure_dir: Directory where counterexample JSON files are written.

    Returns:
        A decorator whose result keeps the decorated function's name and docstring.
    """
    # NOTE(review): these options are part of the public signature but discarded.
    del seed, timeout_ms, commit

    def decorate(function: Callable[..., None]) -> Callable[..., None]:
        def wrapped() -> None:
            target_dir = Path(failure_dir)
            run_property(proof, function, sizes=sizes, runs=runs, failure_dir=target_dir)

        wrapped.__name__, wrapped.__doc__ = function.__name__, function.__doc__
        return wrapped

    return decorate
64
+
65
+
66
def run_property(
    proof: Any,
    function: Callable[..., None],
    *,
    sizes: Mapping[str, SizeSpec],
    runs: int,
    failure_dir: Path,
) -> None:
    """Run *function* as a Hypothesis property over generated datasets.

    Each example loads a dataset through ``proof.client_for_dataset`` and calls
    ``function(client)``. If the property declares a ``check`` parameter, it is
    called as ``function(client, check)`` with a fresh :class:`Check`.

    Any exception from an example is written as a JSON counterexample to
    ``failure_dir / "<property name>.json"``, overwritten on each failure. It is
    then re-raised as :class:`SqlProofPropertyFailure`, chained to the original.

    Hypothesis runs at most *runs* examples with no deadline and the
    function-scoped-fixture health check suppressed.
    """
    # A strategy supplied by the proof object takes precedence over the default.
    strategy = (
        proof.dataset_strategy(sizes=sizes)
        if hasattr(proof, "dataset_strategy")
        else dataset_strategy(proof.schema_info, sizes=sizes)
    )
    accepts_check = "check" in inspect.signature(function).parameters
    examples_run = 0

    @given(strategy)
    @settings(
        max_examples=runs,
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
    )
    def execute(dataset: dict[str, list[dict[str, Any]]]) -> None:
        nonlocal examples_run
        examples_run += 1
        check = Check()
        extra_args = (check,) if accepts_check else ()
        try:
            with proof.client_for_dataset(dataset) as client:
                function(client, *extra_args)
        except Exception as exc:
            payload: dict[str, Any] = {
                "$schema": "https://sqlproof.dev/schemas/counterexample-v1.json",
                "version": 1,
                "property_name": function.__name__,
                "seed": None,
                "runs": examples_run,
                "shrink_steps": 0,
                "schema_fingerprint": proof.schema_fingerprint,
                "row_context": check.row_context or {},
                "dataset": dataset,
                "failure": {
                    "kind": type(exc).__name__,
                    "message": str(exc),
                    "locals": {},
                    "traceback": [],
                },
            }
            target = failure_dir / f"{function.__name__}.json"
            write_counterexample(target, payload)
            raise SqlProofPropertyFailure(str(exc), counterexample=payload) from exc

    execute()
@@ -0,0 +1,40 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Callable, Mapping
4
+ from typing import Any, cast
5
+
6
+ from hypothesis import HealthCheck, given, settings
7
+
8
+ from sqlproof.generators.graph import SizeSpec, dataset_strategy
9
+
10
+
11
def rls(
    proof: Any,
    *,
    sizes: Mapping[str, SizeSpec],
    roles: list[str],
    mode: str = "postgrest",
    **kwargs: object,
) -> Callable[[Callable[..., None]], Callable[..., None]]:
    """Build a decorator that runs a property once per database role.

    For every generated dataset and every role in *roles*, a fresh client is
    opened with ``proof.client_for_dataset(dataset)``. The decorator then
    executes ``SET LOCAL ROLE <role>`` and calls
    ``function(db, {"role": role}, dataset)``.

    Args:
        proof: Object exposing ``schema_info`` and ``client_for_dataset``.
        sizes: Per-table size specifications for dataset generation.
        roles: Role names to impersonate. If empty, the function is never called.
        mode: Accepted but currently ignored.
        **kwargs: Only ``runs`` (maximum Hypothesis examples, default 1) is
            read. Other keys are ignored.

    Returns:
        A decorator producing a zero-argument test that keeps the decorated
        function's name and docstring.
    """
    del mode
    runs = int(cast(Any, kwargs.pop("runs", 1)))

    def decorate(function: Callable[..., None]) -> Callable[..., None]:
        def wrapped() -> None:
            @given(dataset_strategy(proof.schema_info, sizes=sizes))
            @settings(
                max_examples=runs,
                deadline=None,
                suppress_health_check=[HealthCheck.function_scoped_fixture],
            )
            def execute(dataset: dict[str, list[dict[str, Any]]]) -> None:
                for role in roles:
                    with proof.client_for_dataset(dataset) as db:
                        # NOTE(review): the role name is interpolated unquoted. It
                        # must be a plain identifier from the test author, never
                        # untrusted input; unquoted names are case-folded by PostgreSQL.
                        db.execute(f"SET LOCAL ROLE {role}")
                        function(db, {"role": role}, dataset)

            execute()

        # Copy the test's identity by hand rather than with functools.wraps.
        # ``__wrapped__`` would make pytest inspect the inner signature and try to
        # inject (db, ctx, dataset) as fixtures.
        wrapped.__name__ = function.__name__
        wrapped.__doc__ = function.__doc__
        return wrapped

    return decorate
@@ -0,0 +1,36 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Callable, Mapping
4
+ from typing import Any, cast
5
+
6
+ from hypothesis import HealthCheck, given, settings
7
+
8
+ from sqlproof.generators.graph import SizeSpec, dataset_strategy
9
+
10
+
11
def stateful(
    proof: Any, *, sizes: Mapping[str, SizeSpec], **kwargs: object
) -> Callable[[type[Any]], type[Any]]:
    """Build a class decorator that drives the class's ``run`` method with generated data.

    The decorated class's ``run`` attribute is replaced. The new ``run(self)``
    generates up to ``runs`` datasets (read from ``kwargs``, default 1). For
    each one it opens ``proof.client_for_dataset(dataset)`` and calls the
    previous ``run(self, db)``, if the class or a base had one. Other keyword
    arguments are ignored.
    """
    max_runs = int(cast(Any, kwargs.pop("runs", 1)))

    def decorate(cls: type[Any]) -> type[Any]:
        user_run = getattr(cls, "run", None)

        def run(self: Any) -> None:
            strategy = dataset_strategy(proof.schema_info, sizes=sizes)

            @given(strategy)
            @settings(
                max_examples=max_runs,
                deadline=None,
                suppress_health_check=[HealthCheck.function_scoped_fixture],
            )
            def execute(dataset: dict[str, list[dict[str, Any]]]) -> None:
                with proof.client_for_dataset(dataset) as db:
                    if user_run is None:
                        return
                    user_run(self, db)

            execute()

        setattr(cls, "run", run)
        return cls

    return decorate
@@ -0,0 +1,27 @@
1
+ from __future__ import annotations
2
+
3
+ from sqlproof.schema.dependency_graph import insertion_order
4
+ from sqlproof.schema.fingerprint import compute
5
+ from sqlproof.schema.model import (
6
+ CheckConstraint,
7
+ Column,
8
+ ForeignKey,
9
+ Function,
10
+ PgType,
11
+ SchemaInfo,
12
+ Table,
13
+ )
14
+ from sqlproof.schema.parse_sql import parse_schema_sql
15
+
16
# Public API of this package: the schema model types plus the helpers
# re-exported from dependency_graph, fingerprint and parse_sql.
__all__ = [
    "CheckConstraint",
    "Column",
    "ForeignKey",
    "Function",
    "PgType",
    "SchemaInfo",
    "Table",
    "compute",
    "insertion_order",
    "parse_schema_sql",
]
@@ -0,0 +1,38 @@
1
+ from __future__ import annotations
2
+
3
+ from collections import defaultdict, deque
4
+
5
+ from sqlproof.exceptions import CircularDependencyError
6
+ from sqlproof.schema.model import Table
7
+
8
+
9
def insertion_order(tables: tuple[Table, ...]) -> tuple[Table, ...]:
    """Return *tables* ordered so every table comes after the tables it references.

    Tables are identified by ``(schema, name)``, so same-named tables in
    different schemas do not collide. A foreign key's target is
    ``(fk.referenced_schema, fk.referenced_table)`` when the schema is set.
    Otherwise the referencing table's own schema is tried first. If that does
    not match, the key falls back to the only table with that name, when
    exactly one exists.

    Self-references and references to tables outside *tables* are ignored.
    Among tables that are ready at the same time, input order is kept for the
    initial batch and dependents are released in sorted order, so the result
    is deterministic. If the same ``(schema, name)`` appears twice, the last
    definition wins.

    Raises:
        CircularDependencyError: if the foreign keys form a cycle.
    """
    by_key = {(table.schema, table.name): table for table in tables}
    keys_by_name: dict[str, list[tuple[str, str]]] = defaultdict(list)
    for key in by_key:
        keys_by_name[key[1]].append(key)

    def resolve(table: Table, referenced_schema: str | None, referenced_table: str) -> tuple[str, str] | None:
        # Map a foreign key's target to a key in ``by_key``, or None if unknown.
        if referenced_schema is not None:
            target = (referenced_schema, referenced_table)
            return target if target in by_key else None
        same_schema = (table.schema, referenced_table)
        if same_schema in by_key:
            return same_schema
        candidates = keys_by_name.get(referenced_table, [])
        return candidates[0] if len(candidates) == 1 else None

    dependents: dict[tuple[str, str], set[tuple[str, str]]] = defaultdict(set)
    indegree = {key: 0 for key in by_key}

    for key, table in by_key.items():
        dependencies = set()
        for fk in table.foreign_keys:
            target = resolve(table, fk.referenced_schema, fk.referenced_table)
            if target is not None and target != key:
                dependencies.add(target)
        indegree[key] = len(dependencies)
        for dependency in dependencies:
            dependents[dependency].add(key)

    # Kahn's algorithm: repeatedly emit tables whose dependencies are all emitted.
    ready = deque(key for key, degree in indegree.items() if degree == 0)
    ordered: list[Table] = []
    while ready:
        key = ready.popleft()
        ordered.append(by_key[key])
        for dependent in sorted(dependents[key]):
            indegree[dependent] -= 1
            if indegree[dependent] == 0:
                ready.append(dependent)

    if len(ordered) != len(by_key):
        # Every table still holding an unmet dependency is part of, or blocked by, a cycle.
        blocked = sorted(key for key, degree in indegree.items() if degree > 0)
        cycle = ", ".join(name for _schema, name in blocked)
        raise CircularDependencyError(f"Circular foreign-key dependency detected: {cycle}")

    return tuple(ordered)
@@ -0,0 +1,34 @@
1
+ from __future__ import annotations
2
+
3
+ import hashlib
4
+ import json
5
+ from dataclasses import asdict, is_dataclass
6
+ from typing import Any, cast
7
+
8
+ from sqlproof.schema.model import SchemaInfo
9
+
10
+
11
+ def _canonical(value: Any) -> Any:
12
+ if is_dataclass(value):
13
+ return _canonical(asdict(cast(Any, value)))
14
+ if isinstance(value, dict):
15
+ items = cast("dict[str, Any]", value)
16
+ return {key: _canonical(items[key]) for key in sorted(items)}
17
+ if isinstance(value, tuple):
18
+ tuple_value = cast("tuple[Any, ...]", value) # type: ignore[redundant-cast]
19
+ return [_canonical(item) for item in tuple_value]
20
+ if isinstance(value, list):
21
+ list_value = cast("list[Any]", value) # type: ignore[redundant-cast]
22
+ return [_canonical(item) for item in list_value]
23
+ return value
24
+
25
+
26
def compute(schema_info: SchemaInfo) -> str:
    """Return a ``"sha256:<hex digest>"`` fingerprint of *schema_info*.

    The schema is canonicalised with ``_canonical`` and serialised as compact,
    key-sorted, ASCII-only JSON before hashing, so equal schemas hash equally.
    Tuple order is preserved, so reordered constraints or columns change the
    fingerprint.
    """
    serialised = json.dumps(
        _canonical(schema_info),
        ensure_ascii=True,
        separators=(",", ":"),
        sort_keys=True,
    )
    return "sha256:" + hashlib.sha256(serialised.encode("utf-8")).hexdigest()
@@ -0,0 +1,229 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, Literal, cast
4
+
5
+ from sqlproof.schema.model import CheckConstraint, Column, ForeignKey, PgType, SchemaInfo, Table
6
+
7
+
8
def introspect_schema(connection: Any, *, schema: str = "public") -> SchemaInfo:
    """Read the tables, columns, constraints and enum types of *schema* from the catalogs.

    One catalog query is issued per object kind through *connection*.
    Ordinary and partitioned tables (relkind ``r``/``p``) are included. Tables
    are returned sorted by ``(schema, name)``, and columns keep their ``attnum``
    order. Only relations with at least one non-dropped column appear.
    Functions are not loaded.

    Args:
        connection: Open database connection. Its ``execute(sql, params)``
            result must support ``fetchall()`` returning rows that ``dict()``
            accepts. NOTE(review): presumably a psycopg connection with a
            dict row factory — confirm.
        schema: Name of the schema to inspect.

    Returns:
        A ``SchemaInfo`` with ``tables`` and ``enums`` populated.
    """
    enums = _load_enums(connection, schema=schema)
    enum_types = {enum_type.name: enum_type for enum_type in enums}

    table_columns: dict[tuple[str, str], list[Column]] = {}
    for record in _fetch_all(connection, _COLUMNS_SQL, schema):
        table_key = (str(record["schema_name"]), str(record["table_name"]))
        type_name = str(record["type_name"])
        # Columns whose type is a loaded enum reuse that PgType; all others are scalar.
        column_type = enum_types.get(type_name, PgType(kind="scalar", name=type_name))
        default_expr = record["default"]
        column = Column(
            name=str(record["column_name"]),
            type=column_type,
            nullable=bool(record["nullable"]),
            default=_optional_str(default_expr),
            # Serial columns (``nextval(...)`` defaults) are treated like generated ones.
            is_generated=bool(record["is_generated"]) or _is_serial_default(default_expr),
            identity=_identity(record["identity"]),
        )
        table_columns.setdefault(table_key, []).append(column)

    primary_keys = _constraint_columns(connection, _PRIMARY_KEYS_SQL, schema)
    unique_constraints = _grouped_constraints(connection, _UNIQUES_SQL, schema)
    foreign_keys = _foreign_keys(connection, schema=schema)
    check_constraints = _check_constraints(connection, schema=schema)

    tables: list[Table] = []
    for table_key in sorted(table_columns):
        table_schema, table_name = table_key
        tables.append(
            Table(
                schema=table_schema,
                name=table_name,
                columns=tuple(table_columns[table_key]),
                primary_key=primary_keys.get(table_key, ()),
                foreign_keys=tuple(foreign_keys.get(table_key, ())),
                unique_constraints=tuple(unique_constraints.get(table_key, ())),
                check_constraints=tuple(check_constraints.get(table_key, ())),
            )
        )
    return SchemaInfo(tables=tuple(tables), enums=enums)
45
+
46
+
47
+ def _fetch_all(connection: Any, sql: str, *params: object) -> list[dict[str, Any]]:
48
+ cursor = connection.execute(sql, params)
49
+ return [dict(row) for row in cursor.fetchall()]
50
+
51
+
52
def _load_enums(connection: Any, *, schema: str) -> tuple[PgType, ...]:
    """Return the enum types of *schema* as ``PgType(kind="enum")`` values.

    Labels keep the order produced by ``_ENUMS_SQL`` (``enumsortorder``).
    """
    enum_types: list[PgType] = []
    for record in _fetch_all(connection, _ENUMS_SQL, schema):
        enum_name = str(record["enum_name"])
        labels = tuple(map(str, record["enum_values"]))
        enum_types.append(PgType(kind="enum", name=enum_name, enum_values=labels))
    return tuple(enum_types)
62
+
63
+
64
def _constraint_columns(
    connection: Any, sql: str, schema: str
) -> dict[tuple[str, str], tuple[str, ...]]:
    """Map ``(schema, table)`` to the column names of one constraint per table.

    Intended for primary keys. If *sql* returns several rows for a table, the
    last row wins.
    """
    by_table: dict[tuple[str, str], tuple[str, ...]] = {}
    for record in _fetch_all(connection, sql, schema):
        owner = (str(record["schema_name"]), str(record["table_name"]))
        by_table[owner] = tuple(map(str, record["columns"]))
    return by_table
73
+
74
+
75
def _grouped_constraints(
    connection: Any, sql: str, schema: str
) -> dict[tuple[str, str], list[tuple[str, ...]]]:
    """Map ``(schema, table)`` to every constraint's column-name tuple, in row order.

    Used for unique constraints, where a table may have several.
    """
    by_table: dict[tuple[str, str], list[tuple[str, ...]]] = {}
    for record in _fetch_all(connection, sql, schema):
        owner = (str(record["schema_name"]), str(record["table_name"]))
        column_names = tuple(map(str, record["columns"]))
        if owner not in by_table:
            by_table[owner] = []
        by_table[owner].append(column_names)
    return by_table
83
+
84
+
85
def _foreign_keys(connection: Any, *, schema: str) -> dict[tuple[str, str], list[ForeignKey]]:
    """Group the foreign keys defined in *schema* by referencing ``(schema, table)``."""
    by_table: dict[tuple[str, str], list[ForeignKey]] = {}
    for record in _fetch_all(connection, _FOREIGN_KEYS_SQL, schema):
        owner = (str(record["schema_name"]), str(record["table_name"]))
        foreign_key = ForeignKey(
            columns=tuple(map(str, record["columns"])),
            referenced_table=str(record["referenced_table"]),
            referenced_columns=tuple(map(str, record["referenced_columns"])),
            # _FOREIGN_KEYS_SQL maps pg_constraint action codes onto exactly
            # these strings, so the casts only inform the type checker.
            on_delete=cast(
                Literal["NO ACTION", "RESTRICT", "CASCADE", "SET NULL", "SET DEFAULT"],
                str(record["on_delete"]),
            ),
            on_update=cast(
                Literal["NO ACTION", "RESTRICT", "CASCADE", "SET NULL", "SET DEFAULT"],
                str(record["on_update"]),
            ),
            referenced_schema=str(record["referenced_schema"]),
        )
        by_table.setdefault(owner, []).append(foreign_key)
    return by_table
106
+
107
+
108
def _check_constraints(
    connection: Any, *, schema: str
) -> dict[tuple[str, str], list[CheckConstraint]]:
    """Group the CHECK constraints of *schema* by ``(schema, table)``.

    Each constraint stores the ``pg_get_constraintdef`` text, unparsed.
    """
    by_table: dict[tuple[str, str], list[CheckConstraint]] = {}
    for record in _fetch_all(connection, _CHECKS_SQL, schema):
        owner = (str(record["schema_name"]), str(record["table_name"]))
        constraint = CheckConstraint(str(record["expression"]))
        by_table.setdefault(owner, []).append(constraint)
    return by_table
116
+
117
+
118
+ def _optional_str(value: object) -> str | None:
119
+ return None if value is None else str(value)
120
+
121
+
122
+ def _identity(value: object) -> Literal["always", "by_default"] | None:
123
+ if value is None:
124
+ return None
125
+ normalized = str(value).lower().replace(" ", "_")
126
+ if normalized in {"always", "by_default"}:
127
+ return cast(Literal["always", "by_default"], normalized)
128
+ return None
129
+
130
+
131
+ def _is_serial_default(value: object) -> bool:
132
+ return isinstance(value, str) and value.startswith("nextval(")
133
+
134
+
135
+ _ENUMS_SQL = """
136
+ SELECT
137
+ n.nspname AS schema_name,
138
+ t.typname AS enum_name,
139
+ array_agg(e.enumlabel ORDER BY e.enumsortorder) AS enum_values
140
+ FROM pg_catalog.pg_type t
141
+ JOIN pg_catalog.pg_enum e ON e.enumtypid = t.oid
142
+ JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
143
+ WHERE n.nspname = %s
144
+ GROUP BY n.nspname, t.typname
145
+ ORDER BY n.nspname, t.typname
146
+ """
147
+
148
+ _COLUMNS_SQL = """
149
+ SELECT
150
+ ns.nspname AS schema_name,
151
+ cls.relname AS table_name,
152
+ att.attname AS column_name,
153
+ typ.typname AS type_name,
154
+ NOT att.attnotnull AS nullable,
155
+ pg_get_expr(def.adbin, def.adrelid) AS default,
156
+ att.attgenerated <> '' AS is_generated,
157
+ CASE att.attidentity WHEN 'a' THEN 'always' WHEN 'd' THEN 'by default' ELSE NULL END AS identity,
158
+ ARRAY[]::integer[] AS modifiers
159
+ FROM pg_catalog.pg_attribute att
160
+ JOIN pg_catalog.pg_class cls ON cls.oid = att.attrelid
161
+ JOIN pg_catalog.pg_namespace ns ON ns.oid = cls.relnamespace
162
+ JOIN pg_catalog.pg_type typ ON typ.oid = att.atttypid
163
+ LEFT JOIN pg_catalog.pg_attrdef def ON def.adrelid = att.attrelid AND def.adnum = att.attnum
164
+ WHERE ns.nspname = %s
165
+ AND cls.relkind IN ('r', 'p')
166
+ AND att.attnum > 0
167
+ AND NOT att.attisdropped
168
+ ORDER BY ns.nspname, cls.relname, att.attnum
169
+ """
170
+
171
+ _PRIMARY_KEYS_SQL = """
172
+ SELECT
173
+ ns.nspname AS schema_name,
174
+ cls.relname AS table_name,
175
+ array_agg(att.attname ORDER BY key_column.ordinality) AS columns
176
+ FROM pg_catalog.pg_constraint con
177
+ JOIN pg_catalog.pg_class cls ON cls.oid = con.conrelid
178
+ JOIN pg_catalog.pg_namespace ns ON ns.oid = cls.relnamespace
179
+ JOIN unnest(con.conkey) WITH ORDINALITY AS key_column(attnum, ordinality) ON true
180
+ JOIN pg_catalog.pg_attribute att ON att.attrelid = cls.oid AND att.attnum = key_column.attnum
181
+ WHERE ns.nspname = %s AND con.contype = 'p'
182
+ GROUP BY ns.nspname, cls.relname, con.oid
183
+ """
184
+
185
+ _UNIQUES_SQL = _PRIMARY_KEYS_SQL.replace("con.contype = 'p'", "con.contype = 'u'")
186
+
187
+ _FOREIGN_KEYS_SQL = """
188
+ SELECT
189
+ ns.nspname AS schema_name,
190
+ cls.relname AS table_name,
191
+ array_agg(src_att.attname ORDER BY src_key.ordinality) AS columns,
192
+ ref_ns.nspname AS referenced_schema,
193
+ ref_cls.relname AS referenced_table,
194
+ array_agg(ref_att.attname ORDER BY ref_key.ordinality) AS referenced_columns,
195
+ CASE con.confdeltype
196
+ WHEN 'r' THEN 'RESTRICT' WHEN 'c' THEN 'CASCADE' WHEN 'n' THEN 'SET NULL'
197
+ WHEN 'd' THEN 'SET DEFAULT' ELSE 'NO ACTION'
198
+ END AS on_delete,
199
+ CASE con.confupdtype
200
+ WHEN 'r' THEN 'RESTRICT' WHEN 'c' THEN 'CASCADE' WHEN 'n' THEN 'SET NULL'
201
+ WHEN 'd' THEN 'SET DEFAULT' ELSE 'NO ACTION'
202
+ END AS on_update
203
+ FROM pg_catalog.pg_constraint con
204
+ JOIN pg_catalog.pg_class cls ON cls.oid = con.conrelid
205
+ JOIN pg_catalog.pg_namespace ns ON ns.oid = cls.relnamespace
206
+ JOIN pg_catalog.pg_class ref_cls ON ref_cls.oid = con.confrelid
207
+ JOIN pg_catalog.pg_namespace ref_ns ON ref_ns.oid = ref_cls.relnamespace
208
+ JOIN unnest(con.conkey) WITH ORDINALITY AS src_key(attnum, ordinality) ON true
209
+ JOIN unnest(con.confkey) WITH ORDINALITY AS ref_key(attnum, ordinality)
210
+ ON ref_key.ordinality = src_key.ordinality
211
+ JOIN pg_catalog.pg_attribute src_att
212
+ ON src_att.attrelid = cls.oid AND src_att.attnum = src_key.attnum
213
+ JOIN pg_catalog.pg_attribute ref_att
214
+ ON ref_att.attrelid = ref_cls.oid AND ref_att.attnum = ref_key.attnum
215
+ WHERE ns.nspname = %s AND con.contype = 'f'
216
+ GROUP BY ns.nspname, cls.relname, ref_ns.nspname, ref_cls.relname, con.oid
217
+ """
218
+
219
+ _CHECKS_SQL = """
220
+ SELECT
221
+ ns.nspname AS schema_name,
222
+ cls.relname AS table_name,
223
+ pg_get_constraintdef(con.oid, true) AS expression
224
+ FROM pg_catalog.pg_constraint con
225
+ JOIN pg_catalog.pg_class cls ON cls.oid = con.conrelid
226
+ JOIN pg_catalog.pg_namespace ns ON ns.oid = cls.relnamespace
227
+ WHERE ns.nspname = %s AND con.contype = 'c'
228
+ ORDER BY ns.nspname, cls.relname, con.conname
229
+ """
@@ -0,0 +1,98 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass, field
4
+ from typing import Any, Literal
5
+
6
+ from sqlproof.exceptions import SqlProofSchemaError
7
+
8
+
9
@dataclass(frozen=True, slots=True)
class PgType:
    """Immutable description of a PostgreSQL type."""

    # Category of the type. NOTE(review): schema introspection in this package
    # only produces "scalar" and "enum"; other kinds presumably come from parsing.
    kind: Literal["scalar", "array", "enum", "domain", "composite", "range"]
    # Type name (pg_type.typname when introspected).
    name: str
    # NOTE(review): presumably the element/base type for array or domain kinds — confirm.
    base: PgType | None = None
    # Enum labels in declaration order; empty for non-enum types.
    enum_values: tuple[str, ...] = ()
    # NOTE(review): presumably the number of array dimensions (0 = not an array) — confirm.
    array_dim: int = 0
    # NOTE(review): presumably type modifiers such as varchar length — confirm.
    modifiers: tuple[int, ...] = ()
17
+
18
+
19
@dataclass(frozen=True, slots=True)
class ParsedCheck:
    """Structured interpretation of a CHECK constraint on one column."""

    # Shape of the check; determines how ``payload`` is interpreted.
    kind: Literal["range", "in_set", "regex", "length", "compound"]
    # Name of the constrained column.
    column: str
    # NOTE(review): untyped; its structure per ``kind`` is defined by whatever
    # produces ParsedCheck (not visible in this file) — confirm.
    payload: Any
24
+
25
+
26
@dataclass(frozen=True, slots=True)
class CheckConstraint:
    """A CHECK constraint: its SQL text plus an optional structured form."""

    # Constraint text; catalog introspection stores pg_get_constraintdef() output here.
    expression: str
    # Structured form, or None when no parsed representation is available
    # (catalog introspection never fills it in).
    parsed: ParsedCheck | None = None
30
+
31
+
32
@dataclass(frozen=True, slots=True)
class Column:
    """A single table column."""

    name: str
    # Resolved column type.
    type: PgType
    # True when the column accepts NULL.
    nullable: bool
    # Default expression as SQL text, or None when the column has no default.
    default: str | None
    # True for generated columns; introspection also sets it for serial
    # (``nextval(...)``) defaults.
    is_generated: bool
    # Identity kind for GENERATED ... AS IDENTITY columns, otherwise None.
    identity: Literal["always", "by_default"] | None = None
40
+
41
+
42
@dataclass(frozen=True, slots=True)
class ForeignKey:
    """A foreign-key constraint declared on the owning table."""

    # Referencing columns on the owning table, in constraint order.
    columns: tuple[str, ...]
    # Bare name of the referenced table (schema is stored separately).
    referenced_table: str
    # Referenced columns, positionally paired with ``columns``.
    referenced_columns: tuple[str, ...]
    # Referential actions.
    on_delete: Literal["NO ACTION", "RESTRICT", "CASCADE", "SET NULL", "SET DEFAULT"]
    on_update: Literal["NO ACTION", "RESTRICT", "CASCADE", "SET NULL", "SET DEFAULT"]
    # Schema of the referenced table; None when not specified.
    referenced_schema: str | None = None
50
+
51
+
52
+ @dataclass(frozen=True, slots=True)
53
+ class Table:
54
+ schema: str
55
+ name: str
56
+ columns: tuple[Column, ...]
57
+ primary_key: tuple[str, ...]
58
+ foreign_keys: tuple[ForeignKey, ...]
59
+ unique_constraints: tuple[tuple[str, ...], ...]
60
+ check_constraints: tuple[CheckConstraint, ...]
61
+ opaque_constraints: tuple[str, ...] = ()
62
+
63
+ @property
64
+ def qualified_name(self) -> str:
65
+ return f"{self.schema}.{self.name}"
66
+
67
+ def column(self, name: str) -> Column:
68
+ for column in self.columns:
69
+ if column.name == name:
70
+ return column
71
+ msg = f"Unknown column {name!r} on table {self.qualified_name!r}."
72
+ raise SqlProofSchemaError(msg)
73
+
74
+
75
@dataclass(frozen=True, slots=True)
class Function:
    """Signature and metadata of a SQL function."""

    schema: str
    name: str
    # Argument types in positional order; the length is used as the overload arity.
    arg_types: tuple[PgType, ...]
    return_type: PgType
    volatility: Literal["immutable", "stable", "volatile"]
    # NOTE(review): presumably the implementation language name (e.g. "sql", "plpgsql") — confirm.
    language: str
83
+
84
+
85
+ @dataclass(frozen=True, slots=True)
86
+ class SchemaInfo:
87
+ tables: tuple[Table, ...] = ()
88
+ enums: tuple[PgType, ...] = ()
89
+ functions: tuple[Function, ...] = ()
90
+ domains: tuple[PgType, ...] = ()
91
+ opaque_sql: tuple[str, ...] = field(default_factory=tuple)
92
+
93
+ def table(self, name: str, schema: str = "public") -> Table:
94
+ for table in self.tables:
95
+ if table.name == name and table.schema == schema:
96
+ return table
97
+ msg = f"Unknown table {schema}.{name}."
98
+ raise SqlProofSchemaError(msg)