sqlproof 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlproof/__init__.py +32 -0
- sqlproof/_version.py +1 -0
- sqlproof/cli.py +151 -0
- sqlproof/client.py +159 -0
- sqlproof/config.py +42 -0
- sqlproof/contrib/__init__.py +3 -0
- sqlproof/contrib/supabase.py +136 -0
- sqlproof/core.py +344 -0
- sqlproof/coverage/__init__.py +6 -0
- sqlproof/coverage/diversity.py +11 -0
- sqlproof/coverage/plpgsql.py +5 -0
- sqlproof/coverage/schema_shape.py +7 -0
- sqlproof/exceptions.py +47 -0
- sqlproof/generators/__init__.py +21 -0
- sqlproof/generators/columns.py +93 -0
- sqlproof/generators/constraints.py +181 -0
- sqlproof/generators/functions.py +9 -0
- sqlproof/generators/graph.py +51 -0
- sqlproof/generators/rows.py +153 -0
- sqlproof/generators/sampling.py +15 -0
- sqlproof/generators/well_known.py +59 -0
- sqlproof/pytest_plugin.py +24 -0
- sqlproof/reporter/__init__.py +5 -0
- sqlproof/reporter/console.py +20 -0
- sqlproof/reporter/json_io.py +26 -0
- sqlproof/runners/__init__.py +14 -0
- sqlproof/runners/db.py +48 -0
- sqlproof/runners/migration.py +51 -0
- sqlproof/runners/overload.py +41 -0
- sqlproof/runners/property.py +119 -0
- sqlproof/runners/rls.py +40 -0
- sqlproof/runners/stateful.py +36 -0
- sqlproof/schema/__init__.py +27 -0
- sqlproof/schema/dependency_graph.py +38 -0
- sqlproof/schema/fingerprint.py +34 -0
- sqlproof/schema/introspect.py +229 -0
- sqlproof/schema/model.py +98 -0
- sqlproof/schema/parse_sql.py +206 -0
- sqlproof/testing.py +101 -0
- sqlproof/types.py +34 -0
- sqlproof-0.1.0a1.dist-info/METADATA +248 -0
- sqlproof-0.1.0a1.dist-info/RECORD +44 -0
- sqlproof-0.1.0a1.dist-info/WHEEL +4 -0
- sqlproof-0.1.0a1.dist-info/entry_points.txt +5 -0
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Callable
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from sqlproof.generators.functions import FunctionCall
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def function_overloads(
    proof: Any,
    *,
    function: str,
    **kwargs: object,
) -> Callable[[Callable[..., None]], Callable[..., None]]:
    """Build a decorator that hands two overloads of ``function`` to a callback.

    Args:
        proof: Object exposing ``schema_info.functions`` and a
            ``client_for_dataset`` context manager.
        function: Name of the SQL function whose overloads are looked up.
        **kwargs: Accepted for call-site compatibility and discarded.

    Returns:
        A decorator that turns ``callback(db, first, second)`` into a
        zero-argument callable carrying the callback's name and docstring.
    """
    del kwargs  # extra options are accepted but have no effect here

    def decorate(callback: Callable[..., None]) -> Callable[..., None]:
        def wrapped() -> None:
            calls = _function_calls(proof, function)
            if len(calls) < 2:
                # Fewer than two overloads are known from the schema: fall
                # back to two identical zero-argument calls so the callback
                # still receives a pair.
                calls = (
                    FunctionCall(sql=f"{function}()", overload_name=f"{function}/0"),
                    FunctionCall(sql=f"{function}()", overload_name=f"{function}/0"),
                )
            # Only the first two overloads are passed on; the client is opened
            # for an empty dataset.
            with proof.client_for_dataset({}) as db:
                callback(db, calls[0], calls[1])

        # Keep the callback's identity visible on the wrapper, as the
        # ``sqlproof`` decorator does.  functools.wraps is not used because it
        # sets ``__wrapped__``, which makes inspect.signature report the
        # callback's parameters instead of this zero-argument wrapper.
        wrapped.__name__ = getattr(callback, "__name__", wrapped.__name__)
        wrapped.__doc__ = getattr(callback, "__doc__", None)
        return wrapped

    return decorate
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _function_calls(proof: Any, function_name: str) -> tuple[FunctionCall, ...]:
|
|
34
|
+
calls: list[FunctionCall] = []
|
|
35
|
+
for function in proof.schema_info.functions:
|
|
36
|
+
if function.name != function_name:
|
|
37
|
+
continue
|
|
38
|
+
arguments = ", ".join("%s" for _arg in function.arg_types)
|
|
39
|
+
overload = f"{function.name}/{len(function.arg_types)}"
|
|
40
|
+
calls.append(FunctionCall(sql=f"{function.name}({arguments})", overload_name=overload))
|
|
41
|
+
return tuple(calls)
|
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import inspect
|
|
4
|
+
from collections.abc import Callable, Generator, Mapping
|
|
5
|
+
from contextlib import contextmanager
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any, ParamSpec, TypeVar
|
|
8
|
+
|
|
9
|
+
from hypothesis import HealthCheck, given, settings
|
|
10
|
+
|
|
11
|
+
from sqlproof.exceptions import SqlProofPropertyFailure
|
|
12
|
+
from sqlproof.generators.graph import SizeSpec, dataset_strategy
|
|
13
|
+
from sqlproof.reporter.json_io import write_counterexample
|
|
14
|
+
|
|
15
|
+
# Generic helpers for typing callables.
# NOTE(review): neither name is referenced by the definitions visible in this module - confirm before removing.
P = ParamSpec("P")
R = TypeVar("R")
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class Check:
    """Per-example helper passed to property functions that ask for ``check``."""

    def __init__(self) -> None:
        # Keyword context of the row currently being examined, if any.
        self.row_context: dict[str, Any] | None = None

    @contextmanager
    def row(self, **context: Any) -> Generator[None]:
        """Record ``context`` as the active row context for the ``with`` body.

        On exit the previously active context is put back only if there was
        one.  When no context was active before, the one set here stays in
        place after the block ends, so it can still be read afterwards
        (``run_property`` reads it when building a failure payload).
        """
        outer, self.row_context = self.row_context, context
        try:
            yield
        finally:
            if outer is not None:
                self.row_context = outer

    def label(self, name: str) -> None:
        """Accept a label; currently a no-op."""
        del name
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def sqlproof(
    proof: Any,
    *,
    sizes: Mapping[str, SizeSpec],
    runs: int = 100,
    seed: int | None = None,
    timeout_ms: int = 5000,
    commit: bool = False,
    failure_dir: str | Path = ".sqlproof/failures",
) -> Callable[[Callable[..., None]], Callable[..., None]]:
    """Decorate a property function so it runs through ``run_property``.

    Args:
        proof: Proof object forwarded to ``run_property``.
        sizes: Size specification per table, forwarded unchanged.
        runs: Number of examples, forwarded as ``runs``.
        seed: Accepted but currently ignored.
        timeout_ms: Accepted but currently ignored.
        commit: Accepted but currently ignored.
        failure_dir: Directory for counterexample files; converted to ``Path``
            each time the wrapper is called.

    Returns:
        A decorator producing a zero-argument wrapper that carries the
        property's ``__name__`` and ``__doc__``.
    """
    del seed, timeout_ms, commit

    def decorate(function: Callable[..., None]) -> Callable[..., None]:
        def wrapped() -> None:
            options = {
                "sizes": sizes,
                "runs": runs,
                "failure_dir": Path(failure_dir),
            }
            run_property(proof, function, **options)

        # Copy only name and docstring; the wrapper keeps its own
        # zero-argument signature.
        wrapped.__name__, wrapped.__doc__ = function.__name__, function.__doc__
        return wrapped

    return decorate
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def run_property(
    proof: Any,
    function: Callable[..., None],
    *,
    sizes: Mapping[str, SizeSpec],
    runs: int,
    failure_dir: Path,
) -> None:
    """Run ``function`` against generated datasets and record any failure.

    The dataset strategy comes from ``proof.dataset_strategy`` when the proof
    has that attribute, otherwise from the module-level ``dataset_strategy``
    applied to ``proof.schema_info``.  For each example a client is opened via
    ``proof.client_for_dataset`` and ``function`` is called with it, plus a
    ``Check`` instance when the function declares a ``check`` parameter.

    Raises:
        SqlProofPropertyFailure: when the property (or client setup) raises;
            the counterexample payload is first written to
            ``failure_dir / "<function name>.json"``.
    """
    strategy = (
        proof.dataset_strategy(sizes=sizes)
        if hasattr(proof, "dataset_strategy")
        else dataset_strategy(proof.schema_info, sizes=sizes)
    )
    wants_check = "check" in inspect.signature(function).parameters
    run_count = 0

    def failure_payload(
        dataset: dict[str, list[dict[str, Any]]],
        check: Check,
        exc: Exception,
        attempts: int,
    ) -> dict[str, Any]:
        # Counterexample document; seed, shrink_steps, locals and traceback
        # are emitted as fixed placeholder values.
        return {
            "$schema": "https://sqlproof.dev/schemas/counterexample-v1.json",
            "version": 1,
            "property_name": function.__name__,
            "seed": None,
            "runs": attempts,
            "shrink_steps": 0,
            "schema_fingerprint": proof.schema_fingerprint,
            "row_context": check.row_context or {},
            "dataset": dataset,
            "failure": {
                "kind": type(exc).__name__,
                "message": str(exc),
                "locals": {},
                "traceback": [],
            },
        }

    @given(strategy)
    @settings(
        max_examples=runs,
        deadline=None,
        suppress_health_check=[HealthCheck.function_scoped_fixture],
    )
    def execute(dataset: dict[str, list[dict[str, Any]]]) -> None:
        nonlocal run_count
        run_count += 1
        check = Check()
        extra_args = (check,) if wants_check else ()
        try:
            with proof.client_for_dataset(dataset) as client:
                function(client, *extra_args)
        except Exception as exc:
            # NOTE(review): this also intercepts any Exception subclass that
            # Hypothesis itself raises inside the body - confirm that is
            # intended.
            payload = failure_payload(dataset, check, exc, run_count)
            write_counterexample(failure_dir / f"{function.__name__}.json", payload)
            raise SqlProofPropertyFailure(str(exc), counterexample=payload) from exc

    execute()
|
sqlproof/runners/rls.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Callable, Mapping
|
|
4
|
+
from typing import Any, cast
|
|
5
|
+
|
|
6
|
+
from hypothesis import HealthCheck, given, settings
|
|
7
|
+
|
|
8
|
+
from sqlproof.generators.graph import SizeSpec, dataset_strategy
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def rls(
    proof: Any,
    *,
    sizes: Mapping[str, SizeSpec],
    roles: list[str],
    mode: str = "postgrest",
    **kwargs: object,
) -> Callable[[Callable[..., None]], Callable[..., None]]:
    """Decorate a row-level-security property so it runs once per role.

    Args:
        proof: Object exposing ``schema_info`` and ``client_for_dataset``.
        sizes: Size specification per table for dataset generation.
        roles: Role names; each generated dataset is checked under every role
            with a fresh client.
        mode: Accepted but currently ignored.
        **kwargs: Only ``runs`` (number of examples, default 1) is read; any
            other option is ignored.

    Returns:
        A decorator turning ``function(db, {"role": role}, dataset)`` into a
        zero-argument callable carrying the property's name and docstring.
    """
    del mode
    runs = int(cast(Any, kwargs.pop("runs", 1)))

    def decorate(function: Callable[..., None]) -> Callable[..., None]:
        def wrapped() -> None:
            # NOTE(review): unlike run_property, this always uses the
            # module-level dataset_strategy and never proof.dataset_strategy -
            # confirm that is intended.
            @given(dataset_strategy(proof.schema_info, sizes=sizes))
            @settings(
                max_examples=runs,
                deadline=None,
                suppress_health_check=[HealthCheck.function_scoped_fixture],
            )
            def execute(dataset: dict[str, list[dict[str, Any]]]) -> None:
                for role in roles:
                    with proof.client_for_dataset(dataset) as db:
                        # NOTE(review): the role name is interpolated into the
                        # statement verbatim (no identifier quoting), so
                        # callers must pass trusted, valid role names.
                        db.execute(f"SET LOCAL ROLE {role}")
                        function(db, {"role": role}, dataset)

            execute()

        # Keep the property's identity visible on the wrapper, as the
        # ``sqlproof`` decorator does.  functools.wraps is not used because it
        # sets ``__wrapped__``, which makes inspect.signature report the
        # property's parameters instead of this zero-argument wrapper.
        wrapped.__name__ = getattr(function, "__name__", wrapped.__name__)
        wrapped.__doc__ = getattr(function, "__doc__", None)
        return wrapped

    return decorate
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Callable, Mapping
|
|
4
|
+
from typing import Any, cast
|
|
5
|
+
|
|
6
|
+
from hypothesis import HealthCheck, given, settings
|
|
7
|
+
|
|
8
|
+
from sqlproof.generators.graph import SizeSpec, dataset_strategy
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def stateful(
    proof: Any, *, sizes: Mapping[str, SizeSpec], **kwargs: object
) -> Callable[[type[Any]], type[Any]]:
    """Class decorator that replaces ``cls.run`` with a dataset-driven runner.

    The replacement ``run(self)`` generates datasets from
    ``proof.schema_info``, opens a client for each one and, if the class
    originally defined ``run``, calls that original as ``run(self, db)``.
    Only the ``runs`` option (default 1) is read from ``kwargs``.
    """
    example_count = int(cast(Any, kwargs.pop("runs", 1)))

    def decorate(cls: type[Any]) -> type[Any]:
        # Captured before the attribute is overwritten below.
        user_run = getattr(cls, "run", None)

        def run(self: Any) -> None:
            @given(dataset_strategy(proof.schema_info, sizes=sizes))
            @settings(
                max_examples=example_count,
                deadline=None,
                suppress_health_check=[HealthCheck.function_scoped_fixture],
            )
            def execute(dataset: dict[str, list[dict[str, Any]]]) -> None:
                with proof.client_for_dataset(dataset) as db:
                    # The client is opened even when there is no original run.
                    if user_run is not None:
                        user_run(self, db)

            execute()

        setattr(cls, "run", run)
        return cls

    return decorate
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from sqlproof.schema.dependency_graph import insertion_order
|
|
4
|
+
from sqlproof.schema.fingerprint import compute
|
|
5
|
+
from sqlproof.schema.model import (
|
|
6
|
+
CheckConstraint,
|
|
7
|
+
Column,
|
|
8
|
+
ForeignKey,
|
|
9
|
+
Function,
|
|
10
|
+
PgType,
|
|
11
|
+
SchemaInfo,
|
|
12
|
+
Table,
|
|
13
|
+
)
|
|
14
|
+
from sqlproof.schema.parse_sql import parse_schema_sql
|
|
15
|
+
|
|
16
|
+
# Names re-exported by this package; each one is imported above from a
# submodule of sqlproof.schema.
__all__ = [
    "CheckConstraint",
    "Column",
    "ForeignKey",
    "Function",
    "PgType",
    "SchemaInfo",
    "Table",
    "compute",
    "insertion_order",
    "parse_schema_sql",
]
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections import defaultdict, deque
|
|
4
|
+
|
|
5
|
+
from sqlproof.exceptions import CircularDependencyError
|
|
6
|
+
from sqlproof.schema.model import Table
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def insertion_order(tables: tuple[Table, ...]) -> tuple[Table, ...]:
    """Order tables so every table comes after the tables it references.

    Self-references and references to tables outside ``tables`` are ignored.
    Tables are matched by bare name.  Tables with no outstanding references
    keep their input order; tables released together are queued in name order.

    Raises:
        CircularDependencyError: if some tables can never be released because
            their references form a cycle.
    """
    lookup = {entry.name: entry for entry in tables}
    outstanding = dict.fromkeys(lookup, 0)
    children: dict[str, set[str]] = defaultdict(set)

    for entry in tables:
        parents = set()
        for fk in entry.foreign_keys:
            target = fk.referenced_table
            if target == entry.name or target not in lookup:
                continue
            parents.add(target)
        outstanding[entry.name] = len(parents)
        for parent in parents:
            children[parent].add(entry.name)

    queue = deque(entry.name for entry in tables if outstanding[entry.name] == 0)
    result: list[Table] = []
    while queue:
        current = queue.popleft()
        result.append(lookup[current])
        for child in sorted(children[current]):
            remaining = outstanding[child] - 1
            outstanding[child] = remaining
            if remaining == 0:
                queue.append(child)

    if len(result) != len(tables):
        stuck = sorted(name for name, degree in outstanding.items() if degree > 0)
        cycle = ", ".join(stuck)
        raise CircularDependencyError(f"Circular foreign-key dependency detected: {cycle}")

    return tuple(result)
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import hashlib
|
|
4
|
+
import json
|
|
5
|
+
from dataclasses import asdict, is_dataclass
|
|
6
|
+
from typing import Any, cast
|
|
7
|
+
|
|
8
|
+
from sqlproof.schema.model import SchemaInfo
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _canonical(value: Any) -> Any:
|
|
12
|
+
if is_dataclass(value):
|
|
13
|
+
return _canonical(asdict(cast(Any, value)))
|
|
14
|
+
if isinstance(value, dict):
|
|
15
|
+
items = cast("dict[str, Any]", value)
|
|
16
|
+
return {key: _canonical(items[key]) for key in sorted(items)}
|
|
17
|
+
if isinstance(value, tuple):
|
|
18
|
+
tuple_value = cast("tuple[Any, ...]", value) # type: ignore[redundant-cast]
|
|
19
|
+
return [_canonical(item) for item in tuple_value]
|
|
20
|
+
if isinstance(value, list):
|
|
21
|
+
list_value = cast("list[Any]", value) # type: ignore[redundant-cast]
|
|
22
|
+
return [_canonical(item) for item in list_value]
|
|
23
|
+
return value
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def compute(schema_info: SchemaInfo) -> str:
    """Return a ``sha256:<hex>`` fingerprint of ``schema_info``.

    The schema is canonicalised, serialised as compact ASCII-only JSON with
    sorted keys, and hashed with SHA-256 over its UTF-8 bytes.
    """
    canonical_form = _canonical(schema_info)
    serialized = json.dumps(
        canonical_form,
        sort_keys=True,
        separators=(",", ":"),
        ensure_ascii=True,
    )
    hasher = hashlib.sha256(serialized.encode("utf-8"))
    return "sha256:" + hasher.hexdigest()
|
|
@@ -0,0 +1,229 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any, Literal, cast
|
|
4
|
+
|
|
5
|
+
from sqlproof.schema.model import CheckConstraint, Column, ForeignKey, PgType, SchemaInfo, Table
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def introspect_schema(connection: Any, *, schema: str = "public") -> SchemaInfo:
    """Build a ``SchemaInfo`` for one schema by querying the system catalogs.

    Enums are loaded first so columns whose type name matches an enum reuse
    that enum's ``PgType``; every other column gets a ``scalar`` type with the
    catalog type name.  Tables are emitted sorted by ``(schema, name)``.
    Functions and domains are not collected here.
    """
    enums = _load_enums(connection, schema=schema)
    enum_lookup = {enum.name: enum for enum in enums}

    table_columns: dict[tuple[str, str], list[Column]] = {}
    for record in _fetch_all(connection, _COLUMNS_SQL, schema):
        table_key = (str(record["schema_name"]), str(record["table_name"]))
        type_name = str(record["type_name"])
        column_type = enum_lookup.get(type_name)
        if column_type is None:
            column_type = PgType(kind="scalar", name=type_name)
        default_expr = record["default"]
        # A nextval(...) default is treated like a generated column.
        generated = bool(record["is_generated"]) or _is_serial_default(default_expr)
        column = Column(
            name=str(record["column_name"]),
            type=column_type,
            nullable=bool(record["nullable"]),
            default=_optional_str(default_expr),
            is_generated=generated,
            identity=_identity(record["identity"]),
        )
        table_columns.setdefault(table_key, []).append(column)

    primary_keys = _constraint_columns(connection, _PRIMARY_KEYS_SQL, schema)
    unique_constraints = _grouped_constraints(connection, _UNIQUES_SQL, schema)
    foreign_keys = _foreign_keys(connection, schema=schema)
    check_constraints = _check_constraints(connection, schema=schema)

    tables: list[Table] = []
    for table_key, columns in sorted(table_columns.items()):
        table_schema, table_name = table_key
        tables.append(
            Table(
                schema=table_schema,
                name=table_name,
                columns=tuple(columns),
                primary_key=primary_keys.get(table_key, ()),
                foreign_keys=tuple(foreign_keys.get(table_key, ())),
                unique_constraints=tuple(unique_constraints.get(table_key, ())),
                check_constraints=tuple(check_constraints.get(table_key, ())),
            )
        )
    return SchemaInfo(tables=tuple(tables), enums=enums)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _fetch_all(connection: Any, sql: str, *params: object) -> list[dict[str, Any]]:
|
|
48
|
+
cursor = connection.execute(sql, params)
|
|
49
|
+
return [dict(row) for row in cursor.fetchall()]
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _load_enums(connection: Any, *, schema: str) -> tuple[PgType, ...]:
    """Return the enum types of ``schema`` as ``PgType`` values of kind ``enum``.

    Names and labels are coerced to ``str``; label order is whatever the query
    returns.
    """
    enums: list[PgType] = []
    for record in _fetch_all(connection, _ENUMS_SQL, schema):
        labels = tuple(str(label) for label in record["enum_values"])
        enums.append(PgType(kind="enum", name=str(record["enum_name"]), enum_values=labels))
    return tuple(enums)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def _constraint_columns(
    connection: Any, sql: str, schema: str
) -> dict[tuple[str, str], tuple[str, ...]]:
    """Map ``(schema, table)`` to the column names returned by ``sql``.

    If the query yields several rows for the same table, the last one wins.
    """
    by_table: dict[tuple[str, str], tuple[str, ...]] = {}
    for record in _fetch_all(connection, sql, schema):
        table_key = (str(record["schema_name"]), str(record["table_name"]))
        by_table[table_key] = tuple(str(name) for name in record["columns"])
    return by_table
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def _grouped_constraints(
    connection: Any, sql: str, schema: str
) -> dict[tuple[str, str], list[tuple[str, ...]]]:
    """Collect, per ``(schema, table)``, one column tuple for each row of ``sql``.

    Constraints appear in the order the query returns them.
    """
    by_table: dict[tuple[str, str], list[tuple[str, ...]]] = {}
    for record in _fetch_all(connection, sql, schema):
        table_key = (str(record["schema_name"]), str(record["table_name"]))
        column_names = tuple(str(name) for name in record["columns"])
        bucket = by_table.get(table_key)
        if bucket is None:
            bucket = by_table[table_key] = []
        bucket.append(column_names)
    return by_table
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def _foreign_keys(connection: Any, *, schema: str) -> dict[tuple[str, str], list[ForeignKey]]:
    """Collect the foreign keys of ``schema`` grouped by ``(schema, table)``.

    The referential actions are taken from the query as strings and only cast
    for the type checker; they are not validated at runtime.
    """
    by_table: dict[tuple[str, str], list[ForeignKey]] = {}
    for record in _fetch_all(connection, _FOREIGN_KEYS_SQL, schema):
        table_key = (str(record["schema_name"]), str(record["table_name"]))
        source_columns = tuple(str(name) for name in record["columns"])
        target_table = str(record["referenced_table"])
        target_columns = tuple(str(name) for name in record["referenced_columns"])
        delete_action = cast(
            Literal["NO ACTION", "RESTRICT", "CASCADE", "SET NULL", "SET DEFAULT"],
            str(record["on_delete"]),
        )
        update_action = cast(
            Literal["NO ACTION", "RESTRICT", "CASCADE", "SET NULL", "SET DEFAULT"],
            str(record["on_update"]),
        )
        foreign_key = ForeignKey(
            columns=source_columns,
            referenced_table=target_table,
            referenced_columns=target_columns,
            on_delete=delete_action,
            on_update=update_action,
            referenced_schema=str(record["referenced_schema"]),
        )
        by_table.setdefault(table_key, []).append(foreign_key)
    return by_table
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def _check_constraints(
    connection: Any, *, schema: str
) -> dict[tuple[str, str], list[CheckConstraint]]:
    """Collect the CHECK constraints of ``schema`` grouped by ``(schema, table)``.

    Each constraint holds only the expression text from the query; no parsed
    form is attached here.
    """
    by_table: dict[tuple[str, str], list[CheckConstraint]] = {}
    for record in _fetch_all(connection, _CHECKS_SQL, schema):
        table_key = (str(record["schema_name"]), str(record["table_name"]))
        constraint = CheckConstraint(str(record["expression"]))
        by_table.setdefault(table_key, []).append(constraint)
    return by_table
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def _optional_str(value: object) -> str | None:
|
|
119
|
+
return None if value is None else str(value)
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def _identity(value: object) -> Literal["always", "by_default"] | None:
|
|
123
|
+
if value is None:
|
|
124
|
+
return None
|
|
125
|
+
normalized = str(value).lower().replace(" ", "_")
|
|
126
|
+
if normalized in {"always", "by_default"}:
|
|
127
|
+
return cast(Literal["always", "by_default"], normalized)
|
|
128
|
+
return None
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def _is_serial_default(value: object) -> bool:
|
|
132
|
+
return isinstance(value, str) and value.startswith("nextval(")
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
_ENUMS_SQL = """
|
|
136
|
+
SELECT
|
|
137
|
+
n.nspname AS schema_name,
|
|
138
|
+
t.typname AS enum_name,
|
|
139
|
+
array_agg(e.enumlabel ORDER BY e.enumsortorder) AS enum_values
|
|
140
|
+
FROM pg_catalog.pg_type t
|
|
141
|
+
JOIN pg_catalog.pg_enum e ON e.enumtypid = t.oid
|
|
142
|
+
JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
|
|
143
|
+
WHERE n.nspname = %s
|
|
144
|
+
GROUP BY n.nspname, t.typname
|
|
145
|
+
ORDER BY n.nspname, t.typname
|
|
146
|
+
"""
|
|
147
|
+
|
|
148
|
+
_COLUMNS_SQL = """
|
|
149
|
+
SELECT
|
|
150
|
+
ns.nspname AS schema_name,
|
|
151
|
+
cls.relname AS table_name,
|
|
152
|
+
att.attname AS column_name,
|
|
153
|
+
typ.typname AS type_name,
|
|
154
|
+
NOT att.attnotnull AS nullable,
|
|
155
|
+
pg_get_expr(def.adbin, def.adrelid) AS default,
|
|
156
|
+
att.attgenerated <> '' AS is_generated,
|
|
157
|
+
CASE att.attidentity WHEN 'a' THEN 'always' WHEN 'd' THEN 'by default' ELSE NULL END AS identity,
|
|
158
|
+
ARRAY[]::integer[] AS modifiers
|
|
159
|
+
FROM pg_catalog.pg_attribute att
|
|
160
|
+
JOIN pg_catalog.pg_class cls ON cls.oid = att.attrelid
|
|
161
|
+
JOIN pg_catalog.pg_namespace ns ON ns.oid = cls.relnamespace
|
|
162
|
+
JOIN pg_catalog.pg_type typ ON typ.oid = att.atttypid
|
|
163
|
+
LEFT JOIN pg_catalog.pg_attrdef def ON def.adrelid = att.attrelid AND def.adnum = att.attnum
|
|
164
|
+
WHERE ns.nspname = %s
|
|
165
|
+
AND cls.relkind IN ('r', 'p')
|
|
166
|
+
AND att.attnum > 0
|
|
167
|
+
AND NOT att.attisdropped
|
|
168
|
+
ORDER BY ns.nspname, cls.relname, att.attnum
|
|
169
|
+
"""
|
|
170
|
+
|
|
171
|
+
_PRIMARY_KEYS_SQL = """
|
|
172
|
+
SELECT
|
|
173
|
+
ns.nspname AS schema_name,
|
|
174
|
+
cls.relname AS table_name,
|
|
175
|
+
array_agg(att.attname ORDER BY key_column.ordinality) AS columns
|
|
176
|
+
FROM pg_catalog.pg_constraint con
|
|
177
|
+
JOIN pg_catalog.pg_class cls ON cls.oid = con.conrelid
|
|
178
|
+
JOIN pg_catalog.pg_namespace ns ON ns.oid = cls.relnamespace
|
|
179
|
+
JOIN unnest(con.conkey) WITH ORDINALITY AS key_column(attnum, ordinality) ON true
|
|
180
|
+
JOIN pg_catalog.pg_attribute att ON att.attrelid = cls.oid AND att.attnum = key_column.attnum
|
|
181
|
+
WHERE ns.nspname = %s AND con.contype = 'p'
|
|
182
|
+
GROUP BY ns.nspname, cls.relname, con.oid
|
|
183
|
+
"""
|
|
184
|
+
|
|
185
|
+
_UNIQUES_SQL = _PRIMARY_KEYS_SQL.replace("con.contype = 'p'", "con.contype = 'u'")
|
|
186
|
+
|
|
187
|
+
_FOREIGN_KEYS_SQL = """
|
|
188
|
+
SELECT
|
|
189
|
+
ns.nspname AS schema_name,
|
|
190
|
+
cls.relname AS table_name,
|
|
191
|
+
array_agg(src_att.attname ORDER BY src_key.ordinality) AS columns,
|
|
192
|
+
ref_ns.nspname AS referenced_schema,
|
|
193
|
+
ref_cls.relname AS referenced_table,
|
|
194
|
+
array_agg(ref_att.attname ORDER BY ref_key.ordinality) AS referenced_columns,
|
|
195
|
+
CASE con.confdeltype
|
|
196
|
+
WHEN 'r' THEN 'RESTRICT' WHEN 'c' THEN 'CASCADE' WHEN 'n' THEN 'SET NULL'
|
|
197
|
+
WHEN 'd' THEN 'SET DEFAULT' ELSE 'NO ACTION'
|
|
198
|
+
END AS on_delete,
|
|
199
|
+
CASE con.confupdtype
|
|
200
|
+
WHEN 'r' THEN 'RESTRICT' WHEN 'c' THEN 'CASCADE' WHEN 'n' THEN 'SET NULL'
|
|
201
|
+
WHEN 'd' THEN 'SET DEFAULT' ELSE 'NO ACTION'
|
|
202
|
+
END AS on_update
|
|
203
|
+
FROM pg_catalog.pg_constraint con
|
|
204
|
+
JOIN pg_catalog.pg_class cls ON cls.oid = con.conrelid
|
|
205
|
+
JOIN pg_catalog.pg_namespace ns ON ns.oid = cls.relnamespace
|
|
206
|
+
JOIN pg_catalog.pg_class ref_cls ON ref_cls.oid = con.confrelid
|
|
207
|
+
JOIN pg_catalog.pg_namespace ref_ns ON ref_ns.oid = ref_cls.relnamespace
|
|
208
|
+
JOIN unnest(con.conkey) WITH ORDINALITY AS src_key(attnum, ordinality) ON true
|
|
209
|
+
JOIN unnest(con.confkey) WITH ORDINALITY AS ref_key(attnum, ordinality)
|
|
210
|
+
ON ref_key.ordinality = src_key.ordinality
|
|
211
|
+
JOIN pg_catalog.pg_attribute src_att
|
|
212
|
+
ON src_att.attrelid = cls.oid AND src_att.attnum = src_key.attnum
|
|
213
|
+
JOIN pg_catalog.pg_attribute ref_att
|
|
214
|
+
ON ref_att.attrelid = ref_cls.oid AND ref_att.attnum = ref_key.attnum
|
|
215
|
+
WHERE ns.nspname = %s AND con.contype = 'f'
|
|
216
|
+
GROUP BY ns.nspname, cls.relname, ref_ns.nspname, ref_cls.relname, con.oid
|
|
217
|
+
"""
|
|
218
|
+
|
|
219
|
+
_CHECKS_SQL = """
|
|
220
|
+
SELECT
|
|
221
|
+
ns.nspname AS schema_name,
|
|
222
|
+
cls.relname AS table_name,
|
|
223
|
+
pg_get_constraintdef(con.oid, true) AS expression
|
|
224
|
+
FROM pg_catalog.pg_constraint con
|
|
225
|
+
JOIN pg_catalog.pg_class cls ON cls.oid = con.conrelid
|
|
226
|
+
JOIN pg_catalog.pg_namespace ns ON ns.oid = cls.relnamespace
|
|
227
|
+
WHERE ns.nspname = %s AND con.contype = 'c'
|
|
228
|
+
ORDER BY ns.nspname, cls.relname, con.conname
|
|
229
|
+
"""
|
sqlproof/schema/model.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from typing import Any, Literal
|
|
5
|
+
|
|
6
|
+
from sqlproof.exceptions import SqlProofSchemaError
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass(frozen=True, slots=True)
class PgType:
    """Immutable description of a PostgreSQL type."""

    # Category of the type.
    kind: Literal["scalar", "array", "enum", "domain", "composite", "range"]
    # Type name.
    name: str
    # Optional nested type; presumably the element/underlying type - TODO confirm.
    base: PgType | None = None
    # Enum labels; empty by default.
    enum_values: tuple[str, ...] = ()
    # presumably the number of array dimensions - TODO confirm.
    array_dim: int = 0
    # presumably type modifiers such as length or precision - TODO confirm.
    modifiers: tuple[int, ...] = ()
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass(frozen=True, slots=True)
class ParsedCheck:
    """Structured form of a CHECK constraint."""

    # Which shape of check this is.
    kind: Literal["range", "in_set", "regex", "length", "compound"]
    # Name of the column the check applies to.
    column: str
    # Kind-specific data; its structure is not defined in this module.
    payload: Any
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclass(frozen=True, slots=True)
class CheckConstraint:
    """A CHECK constraint: its expression text and an optional parsed form."""

    # Constraint expression as text.
    expression: str
    # Structured interpretation of the expression; None when not available.
    parsed: ParsedCheck | None = None
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@dataclass(frozen=True, slots=True)
class Column:
    """Immutable description of a table column."""

    # Column name.
    name: str
    # Column type.
    type: PgType
    # Whether the column accepts NULL.
    nullable: bool
    # Default expression as text, or None when there is no default.
    default: str | None
    # Whether the column's value is produced by the database.
    is_generated: bool
    # Identity mode, or None for a non-identity column.
    identity: Literal["always", "by_default"] | None = None
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@dataclass(frozen=True, slots=True)
class ForeignKey:
    """Immutable description of a foreign-key constraint."""

    # Referencing column names on the owning table.
    columns: tuple[str, ...]
    # Name of the referenced table (without schema).
    referenced_table: str
    # Referenced column names.
    referenced_columns: tuple[str, ...]
    # Referential action on delete.
    on_delete: Literal["NO ACTION", "RESTRICT", "CASCADE", "SET NULL", "SET DEFAULT"]
    # Referential action on update.
    on_update: Literal["NO ACTION", "RESTRICT", "CASCADE", "SET NULL", "SET DEFAULT"]
    # Schema of the referenced table; None when not recorded.
    referenced_schema: str | None = None
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@dataclass(frozen=True, slots=True)
|
|
53
|
+
class Table:
|
|
54
|
+
schema: str
|
|
55
|
+
name: str
|
|
56
|
+
columns: tuple[Column, ...]
|
|
57
|
+
primary_key: tuple[str, ...]
|
|
58
|
+
foreign_keys: tuple[ForeignKey, ...]
|
|
59
|
+
unique_constraints: tuple[tuple[str, ...], ...]
|
|
60
|
+
check_constraints: tuple[CheckConstraint, ...]
|
|
61
|
+
opaque_constraints: tuple[str, ...] = ()
|
|
62
|
+
|
|
63
|
+
@property
|
|
64
|
+
def qualified_name(self) -> str:
|
|
65
|
+
return f"{self.schema}.{self.name}"
|
|
66
|
+
|
|
67
|
+
def column(self, name: str) -> Column:
|
|
68
|
+
for column in self.columns:
|
|
69
|
+
if column.name == name:
|
|
70
|
+
return column
|
|
71
|
+
msg = f"Unknown column {name!r} on table {self.qualified_name!r}."
|
|
72
|
+
raise SqlProofSchemaError(msg)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
@dataclass(frozen=True, slots=True)
class Function:
    """Immutable description of a SQL function signature."""

    # Schema the function belongs to.
    schema: str
    # Function name; overloads share the same name.
    name: str
    # Argument types.
    arg_types: tuple[PgType, ...]
    # Return type.
    return_type: PgType
    # Volatility category.
    volatility: Literal["immutable", "stable", "volatile"]
    # Implementation language name.
    language: str
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
@dataclass(frozen=True, slots=True)
|
|
86
|
+
class SchemaInfo:
|
|
87
|
+
tables: tuple[Table, ...] = ()
|
|
88
|
+
enums: tuple[PgType, ...] = ()
|
|
89
|
+
functions: tuple[Function, ...] = ()
|
|
90
|
+
domains: tuple[PgType, ...] = ()
|
|
91
|
+
opaque_sql: tuple[str, ...] = field(default_factory=tuple)
|
|
92
|
+
|
|
93
|
+
def table(self, name: str, schema: str = "public") -> Table:
|
|
94
|
+
for table in self.tables:
|
|
95
|
+
if table.name == name and table.schema == schema:
|
|
96
|
+
return table
|
|
97
|
+
msg = f"Unknown table {schema}.{name}."
|
|
98
|
+
raise SqlProofSchemaError(msg)
|