sqlproof 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44) hide show
  1. sqlproof/__init__.py +32 -0
  2. sqlproof/_version.py +1 -0
  3. sqlproof/cli.py +151 -0
  4. sqlproof/client.py +159 -0
  5. sqlproof/config.py +42 -0
  6. sqlproof/contrib/__init__.py +3 -0
  7. sqlproof/contrib/supabase.py +136 -0
  8. sqlproof/core.py +344 -0
  9. sqlproof/coverage/__init__.py +6 -0
  10. sqlproof/coverage/diversity.py +11 -0
  11. sqlproof/coverage/plpgsql.py +5 -0
  12. sqlproof/coverage/schema_shape.py +7 -0
  13. sqlproof/exceptions.py +47 -0
  14. sqlproof/generators/__init__.py +21 -0
  15. sqlproof/generators/columns.py +93 -0
  16. sqlproof/generators/constraints.py +181 -0
  17. sqlproof/generators/functions.py +9 -0
  18. sqlproof/generators/graph.py +51 -0
  19. sqlproof/generators/rows.py +153 -0
  20. sqlproof/generators/sampling.py +15 -0
  21. sqlproof/generators/well_known.py +59 -0
  22. sqlproof/pytest_plugin.py +24 -0
  23. sqlproof/reporter/__init__.py +5 -0
  24. sqlproof/reporter/console.py +20 -0
  25. sqlproof/reporter/json_io.py +26 -0
  26. sqlproof/runners/__init__.py +14 -0
  27. sqlproof/runners/db.py +48 -0
  28. sqlproof/runners/migration.py +51 -0
  29. sqlproof/runners/overload.py +41 -0
  30. sqlproof/runners/property.py +119 -0
  31. sqlproof/runners/rls.py +40 -0
  32. sqlproof/runners/stateful.py +36 -0
  33. sqlproof/schema/__init__.py +27 -0
  34. sqlproof/schema/dependency_graph.py +38 -0
  35. sqlproof/schema/fingerprint.py +34 -0
  36. sqlproof/schema/introspect.py +229 -0
  37. sqlproof/schema/model.py +98 -0
  38. sqlproof/schema/parse_sql.py +206 -0
  39. sqlproof/testing.py +101 -0
  40. sqlproof/types.py +34 -0
  41. sqlproof-0.1.0a1.dist-info/METADATA +248 -0
  42. sqlproof-0.1.0a1.dist-info/RECORD +44 -0
  43. sqlproof-0.1.0a1.dist-info/WHEEL +4 -0
  44. sqlproof-0.1.0a1.dist-info/entry_points.txt +5 -0
sqlproof/__init__.py ADDED
@@ -0,0 +1,32 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING
4
+
5
+ from sqlproof._version import __version__
6
+
7
+ if TYPE_CHECKING:
8
+ from sqlproof.config import ExternalTableSpec, SqlProofConfig
9
+ from sqlproof.core import SqlProof
10
+ from sqlproof.runners import sqlproof
11
+
12
+ __all__ = ["ExternalTableSpec", "SqlProof", "SqlProofConfig", "__version__", "sqlproof"]
13
+
14
+
15
+ def __getattr__(name: str) -> object:
16
+ if name == "SqlProof":
17
+ from sqlproof.core import SqlProof
18
+
19
+ return SqlProof
20
+ if name == "SqlProofConfig":
21
+ from sqlproof.config import SqlProofConfig
22
+
23
+ return SqlProofConfig
24
+ if name == "ExternalTableSpec":
25
+ from sqlproof.config import ExternalTableSpec
26
+
27
+ return ExternalTableSpec
28
+ if name == "sqlproof":
29
+ from sqlproof.runners import sqlproof
30
+
31
+ return sqlproof
32
+ raise AttributeError(name)
sqlproof/_version.py ADDED
@@ -0,0 +1 @@
1
+ __version__ = "0.1.0a1"
sqlproof/cli.py ADDED
@@ -0,0 +1,151 @@
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import json
5
+ import sys
6
+ from pathlib import Path
7
+ from typing import Any, cast
8
+
9
+ from sqlproof._version import __version__
10
+ from sqlproof.coverage.schema_shape import summarize_dataset_shape
11
+ from sqlproof.reporter.console import format_failure
12
+ from sqlproof.schema.model import Column, SchemaInfo, Table
13
+ from sqlproof.schema.parse_sql import parse_schema_sql
14
+
15
+
16
+ def main(argv: list[str] | None = None) -> int:
17
+ parser = argparse.ArgumentParser(prog="sqlproof")
18
+ subcommands = parser.add_subparsers(dest="command", required=True)
19
+
20
+ subcommands.add_parser("version")
21
+
22
+ introspect = subcommands.add_parser("introspect")
23
+ introspect.add_argument("--schema-file", type=Path)
24
+ introspect.add_argument("--dsn")
25
+ introspect.add_argument("--format", choices=["json", "text"], default="text")
26
+
27
+ generate_types_parser = subcommands.add_parser("generate-types")
28
+ generate_types_parser.add_argument("--schema-file", type=Path, required=True)
29
+ generate_types_parser.add_argument("--output", type=Path)
30
+ generate_types_parser.add_argument(
31
+ "--style", choices=["typeddict", "dataclass", "pydantic"], default="typeddict"
32
+ )
33
+
34
+ replay = subcommands.add_parser("replay")
35
+ replay.add_argument("counterexample", type=Path)
36
+
37
+ report = subcommands.add_parser("report")
38
+ report.add_argument("counterexample", type=Path)
39
+ report.add_argument("--format", choices=["json", "text"], default="text")
40
+
41
+ run = subcommands.add_parser("run")
42
+ run.add_argument("test_path")
43
+ run.add_argument("pytest_options", nargs=argparse.REMAINDER)
44
+
45
+ subcommands.add_parser("clean-orphans")
46
+
47
+ args = parser.parse_args(argv)
48
+ if args.command == "version":
49
+ print(f"sqlproof {__version__}")
50
+ return 0
51
+ if args.command == "introspect":
52
+ if args.schema_file is None and args.dsn is None:
53
+ parser.error("one of --schema-file or --dsn is required")
54
+ if args.schema_file is not None:
55
+ schema = parse_schema_sql(args.schema_file.read_text(encoding="utf-8"))
56
+ else:
57
+ from sqlproof import SqlProof
58
+
59
+ schema = SqlProof.from_connection_string(str(args.dsn)).schema_info
60
+ if args.format == "json":
61
+ print(json.dumps(_schema_payload(schema), default=str))
62
+ else:
63
+ for table in schema.tables:
64
+ print(f"{table.qualified_name} ({len(table.columns)} columns)")
65
+ return 0
66
+ if args.command == "generate-types":
67
+ from sqlproof.types import generate_types
68
+
69
+ output = generate_types(args.schema_file, style=args.style)
70
+ if args.output:
71
+ args.output.write_text(output, encoding="utf-8")
72
+ else:
73
+ print(output)
74
+ return 0
75
+ if args.command == "replay":
76
+ payload = json.loads(args.counterexample.read_text(encoding="utf-8"))
77
+ print(format_failure(payload))
78
+ print(f"replay loaded {payload.get('property_name', 'counterexample')}")
79
+ return 0
80
+ if args.command == "report":
81
+ payload = json.loads(args.counterexample.read_text(encoding="utf-8"))
82
+ report_payload = _counterexample_report(payload)
83
+ if args.format == "json":
84
+ print(json.dumps(report_payload, default=str))
85
+ else:
86
+ print(format_failure(payload))
87
+ return 0
88
+ if args.command == "run":
89
+ import pytest
90
+
91
+ return int(pytest.main([args.test_path, *args.pytest_options]))
92
+ if args.command == "clean-orphans":
93
+ print("No orphaned SqlProof containers found.")
94
+ return 0
95
+ return 1
96
+
97
+
98
+ def _schema_payload(schema: SchemaInfo) -> dict[str, object]:
99
+ return {"tables": [_table_payload(table) for table in schema.tables]}
100
+
101
+
102
+ def _table_payload(table: Table) -> dict[str, object]:
103
+ return {
104
+ "schema": table.schema,
105
+ "name": table.name,
106
+ "columns": [_column_payload(column) for column in table.columns],
107
+ "primary_key": list(table.primary_key),
108
+ "unique_constraints": [list(columns) for columns in table.unique_constraints],
109
+ "checks": [check.expression for check in table.check_constraints],
110
+ "foreign_keys": [
111
+ {
112
+ "columns": list(foreign_key.columns),
113
+ "referenced_table": foreign_key.referenced_table,
114
+ "referenced_columns": list(foreign_key.referenced_columns),
115
+ "on_delete": foreign_key.on_delete,
116
+ "on_update": foreign_key.on_update,
117
+ }
118
+ for foreign_key in table.foreign_keys
119
+ ],
120
+ }
121
+
122
+
123
+ def _column_payload(column: Column) -> dict[str, object]:
124
+ return {
125
+ "name": column.name,
126
+ "type": column.type.name,
127
+ "nullable": column.nullable,
128
+ "default": column.default,
129
+ "is_generated": column.is_generated,
130
+ "identity": column.identity,
131
+ }
132
+
133
+
134
+ def _counterexample_report(payload: dict[str, object]) -> dict[str, object]:
135
+ dataset = payload.get("dataset")
136
+ shape = (
137
+ summarize_dataset_shape(cast(dict[str, list[dict[str, Any]]], dataset))
138
+ if isinstance(dataset, dict)
139
+ else {}
140
+ )
141
+ return {
142
+ "property_name": payload.get("property_name", "counterexample"),
143
+ "schema_fingerprint": payload.get("schema_fingerprint"),
144
+ "row_context": payload.get("row_context", {}),
145
+ "failure": payload.get("failure", {}),
146
+ "shape": shape,
147
+ }
148
+
149
+
150
if __name__ == "__main__":  # pragma: no cover
    # Running the module directly uses main()'s return value as the exit status.
    raise SystemExit(main(sys.argv[1:]))
sqlproof/client.py ADDED
@@ -0,0 +1,159 @@
1
+ from __future__ import annotations
2
+
3
+ import re
4
+ from collections.abc import Generator
5
+ from contextlib import AbstractContextManager, contextmanager
6
+ from dataclasses import fields, is_dataclass
7
+ from pathlib import Path
8
+ from typing import Any, Protocol, TypeVar, cast
9
+
10
+ from sqlproof.exceptions import SqlProofMappingError, SqlProofUsageError
11
+
12
# Row model type for query_typed: a pydantic model, dataclass or TypedDict.
T = TypeVar("T")


class SqlProofClient(Protocol):
    """Structural interface for the database handle passed to property tests.

    Implemented in this module by ``InMemorySqlProofClient`` (dataset-backed
    test double) and ``PsycopgSqlProofClient`` (live connection).
    """

    def query(self, sql: str, *params: Any) -> list[dict[str, Any]]:
        """Run ``sql`` with positional ``params`` and return the rows as dicts."""
        ...

    def query_typed(self, sql: str, model: type[T], *params: Any) -> list[T]:
        """Like ``query``, but convert each row into an instance of ``model``."""
        ...

    def scalar(self, sql: str, *params: Any) -> Any:
        """Return the first column of the first row, or ``None`` if there are no rows."""
        ...

    def execute(self, sql: str, *params: Any) -> int:
        """Run a statement and return an affected-row count.

        The psycopg client returns ``cursor.rowcount``; the in-memory client returns 0.
        """
        ...

    def execute_file(self, path: str) -> None:
        """Run the SQL statements contained in the file at ``path``."""
        ...

    def savepoint(self) -> AbstractContextManager[None]:
        """Context manager whose body is rolled back to a savepoint if it raises."""
        ...

    def get_generated_data(self) -> dict[str, list[dict[str, Any]]]:
        """Return the generated dataset, keyed by table name."""
        ...
29
+
30
+
31
class InMemorySqlProofClient:
    """``SqlProofClient`` test double that serves rows from an in-memory dataset.

    ``dataset`` maps table names to lists of row dicts. Only a minimal
    ``SELECT <columns> FROM <table>`` shape is interpreted. Everything after
    the table reference (``WHERE``, ``JOIN``, ``ORDER BY`` ...) and all
    ``params`` are ignored, so every row of the table is returned. Writes,
    files and savepoints are no-ops.
    """

    # Select list plus the first table reference after FROM; the reference
    # may be schema-qualified and/or double-quoted (public."users").
    _SELECT_RE = re.compile(
        r'SELECT\s+(?P<columns>.*?)\s+FROM\s+(?P<table>[\w."]+)', re.I | re.S
    )

    def __init__(self, dataset: dict[str, list[dict[str, Any]]]) -> None:
        # Kept by reference: get_generated_data() returns this same object.
        self._dataset = dataset

    def query(self, sql: str, *params: Any) -> list[dict[str, Any]]:
        """Return the rows of the table named in ``sql``, projected onto its select list.

        ``SELECT *`` returns shallow copies of every row. Otherwise each
        select item ``[qualifier.]column [AS alias]`` produces the key
        ``alias`` (or ``column``) holding the row's ``column`` value. Columns
        the row lacks come back as ``None``. SQL that does not match the
        ``SELECT ... FROM`` shape returns ``[]``.
        """
        del params  # parameters are not interpreted by the in-memory double
        match = self._SELECT_RE.search(sql)
        if match is None:
            return []
        rows = self._rows_for(match.group("table"))
        columns_sql = match.group("columns").strip()
        if columns_sql == "*":
            return [dict(row) for row in rows]
        # NOTE: naive comma split; expressions such as coalesce(a, b) are unsupported.
        selected = [self._parse_select_item(part) for part in columns_sql.split(",")]
        return [{output: row.get(source) for output, source in selected} for row in rows]

    def _rows_for(self, table_ref: str) -> list[dict[str, Any]]:
        """Look up rows by the exact (unquoted) reference, then by its bare table name."""
        name = table_ref.replace('"', "")
        if name in self._dataset:
            return self._dataset[name]
        return self._dataset.get(name.rsplit(".", 1)[-1], [])

    @staticmethod
    def _parse_select_item(item: str) -> tuple[str, str]:
        """Split ``[qualifier.]column [AS alias]`` into ``(output_key, source_column)``."""
        expression, *alias = re.split(r"\s+AS\s+", item.strip(), maxsplit=1, flags=re.I)
        source = expression.rsplit(".", 1)[-1].strip().strip('"')
        output = alias[0].strip().strip('"') if alias else source
        return output, source

    def query_typed(self, sql: str, model: type[T], *params: Any) -> list[T]:
        """Run ``query`` and map every row onto ``model``."""
        rows = self.query(sql, *params)
        return [_map_row(row, model) for row in rows]

    def scalar(self, sql: str, *params: Any) -> Any:
        """Return the first value of the first row, or ``None`` when there are no rows."""
        rows = self.query(sql, *params)
        if not rows:
            return None
        return next(iter(rows[0].values()))

    def execute(self, sql: str, *params: Any) -> int:
        """Ignore the statement; the in-memory dataset is never modified."""
        del sql, params
        return 0

    def execute_file(self, path: str | Path) -> None:
        """Ignore the file; nothing is read or executed."""
        del path

    @contextmanager
    def savepoint(self) -> Generator[None]:
        """No-op savepoint: there is nothing to roll back in memory."""
        yield

    def get_generated_data(self) -> dict[str, list[dict[str, Any]]]:
        """Return the dataset passed to the constructor (same object)."""
        return self._dataset
71
+
72
+
73
class PsycopgSqlProofClient:
    """``SqlProofClient`` that runs statements on a live database connection.

    ``connection`` must provide ``execute(sql, params)`` returning a cursor
    with ``fetchall()`` and ``rowcount`` (as psycopg 3 connections do).
    ``query`` converts rows with ``dict(row)``, so the connection presumably
    uses a mapping row factory such as psycopg's ``dict_row`` — TODO confirm.
    """

    def __init__(
        self,
        connection: Any,
        *,
        dataset: dict[str, list[dict[str, Any]]] | None = None,
    ) -> None:
        self._connection = connection
        # `is not None` (not `or`) so a caller's empty dict is kept by reference.
        self._dataset = dataset if dataset is not None else {}

    def _run(self, sql: str, params: tuple[Any, ...]) -> Any:
        """Execute ``sql`` and return the driver's cursor.

        With any non-``None`` params (even ``()``) the driver parses ``%``
        placeholders, so a literal ``%`` such as ``LIKE 'a%'`` would fail.
        Parameterless statements are therefore sent with ``None``.
        """
        return self._connection.execute(sql, params if params else None)

    def query(self, sql: str, *params: Any) -> list[dict[str, Any]]:
        """Run ``sql`` and return every fetched row as a dict."""
        cursor = self._run(sql, params)
        return [dict(row) for row in cursor.fetchall()]

    def query_typed(self, sql: str, model: type[T], *params: Any) -> list[T]:
        """Run ``query`` and map every row onto ``model``."""
        rows = self.query(sql, *params)
        return [_map_row(row, model) for row in rows]

    def scalar(self, sql: str, *params: Any) -> Any:
        """Return the first value of the first row, or ``None`` when there are no rows."""
        rows = self.query(sql, *params)
        if not rows:
            return None
        return next(iter(rows[0].values()))

    def execute(self, sql: str, *params: Any) -> int:
        """Run a statement and return the cursor's ``rowcount``."""
        cursor = self._run(sql, params)
        return int(cursor.rowcount)

    def execute_file(self, path: str | Path) -> None:
        """Read a UTF-8 SQL file and execute its statements one at a time."""
        sql = Path(path).read_text(encoding="utf-8")
        for statement in _split_sql_statements(sql):
            self.execute(statement)

    @contextmanager
    def savepoint(self) -> Generator[None]:
        """Wrap the body in a savepoint, rolling back to it if the body raises.

        The savepoint is released in both cases. If the rollback itself fails,
        that error propagates (chained to the original) without an extra
        RELEASE attempt, which would only fail again and hide the cause.
        """
        name = "sqlproof_client"
        self.execute(f"SAVEPOINT {name}")
        try:
            yield
        except BaseException:
            self.execute(f"ROLLBACK TO SAVEPOINT {name}")
            self.execute(f"RELEASE SAVEPOINT {name}")
            raise
        self.execute(f"RELEASE SAVEPOINT {name}")

    def get_generated_data(self) -> dict[str, list[dict[str, Any]]]:
        """Return the dataset given to the constructor (an empty dict if none)."""
        return self._dataset
120
+
121
+
122
+ def _split_sql_statements(sql: str) -> list[str]:
123
+ statements = [statement.strip() for statement in sql.split(";")]
124
+ return [statement for statement in statements if statement]
125
+
126
+
127
+ def _clean_selected_column(sql: str) -> str:
128
+ value = sql.strip()
129
+ if "." in value:
130
+ value = value.rsplit(".", 1)[1]
131
+ if " AS " in value.upper():
132
+ value = re.split(r"\s+AS\s+", value, flags=re.I)[-1]
133
+ return value.strip().strip('"')
134
+
135
+
136
+ def _map_row(row: dict[str, Any], model: type[T]) -> T:
137
+ detectors = [
138
+ hasattr(model, "__pydantic_fields__"),
139
+ is_dataclass(model),
140
+ hasattr(model, "__required_keys__") or hasattr(model, "__optional_keys__"),
141
+ ]
142
+ if sum(detectors) != 1:
143
+ raise SqlProofUsageError("Ambiguous or unsupported row model type.")
144
+ if hasattr(model, "__pydantic_fields__"):
145
+ try:
146
+ return model.model_validate(row) # type: ignore[attr-defined,no-any-return]
147
+ except Exception as exc: # pragma: no cover - exercised when pydantic is installed
148
+ raise SqlProofMappingError(str(exc)) from exc
149
+ if is_dataclass(model):
150
+ names = {field.name for field in fields(model)}
151
+ missing = {field.name for field in fields(model) if field.name not in row}
152
+ if missing:
153
+ raise SqlProofMappingError(f"Missing required fields: {', '.join(sorted(missing))}")
154
+ return model(**{name: row[name] for name in names if name in row})
155
+ required = cast("set[str]", getattr(model, "__required_keys__", set[str]()))
156
+ missing = required - row.keys()
157
+ if missing:
158
+ raise SqlProofMappingError(f"Missing required fields: {', '.join(sorted(missing))}")
159
+ return dict(row) # type: ignore[return-value]
sqlproof/config.py ADDED
@@ -0,0 +1,42 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Callable, Mapping, Sequence
4
+ from dataclasses import dataclass
5
+ from pathlib import Path
6
+ from typing import TYPE_CHECKING
7
+
8
+ from hypothesis.strategies import SearchStrategy
9
+
10
+ from sqlproof.exceptions import SqlProofUsageError
11
+
12
+ if TYPE_CHECKING:
13
+ from sqlproof.client import SqlProofClient
14
+
15
# Seed callback for an external table: called with the client, or with the
# client and an int (presumably the number of rows to seed — TODO confirm).
ExternalSeed = Callable[["SqlProofClient"], None] | Callable[["SqlProofClient", int], None]
# A size given either as a fixed int or as a Hypothesis strategy drawing one.
SizeSpec = int | SearchStrategy[int]


@dataclass(frozen=True, slots=True)
class ExternalTableSpec:
    """Describe a table whose rows are not generated by sqlproof (e.g. ``auth.users``).

    NOTE(review): presumably used so foreign keys into such tables can point
    at existing rows obtained via ``sample`` — confirm against core.py.
    """

    # Name of the external table's primary-key column.
    primary_key: str
    # Returns existing values to sample from (assumed to be key values — verify).
    sample: Callable[[SqlProofClient], Sequence[object]]
    # Optional callback that populates the table before sampling.
    seed: ExternalSeed | None = None
    # How many rows to seed: an int or a Hypothesis strategy of ints.
    seed_count: SizeSpec | None = None
25
+
26
+
27
+ @dataclass(frozen=True, slots=True)
28
+ class SqlProofConfig:
29
+ connection_string: str | None = None
30
+ schema: str = "public"
31
+ schema_file: str | Path | None = None
32
+ image: str = "postgres:16"
33
+ reuse_container: bool = True
34
+ transaction_per_run: bool = True
35
+ seed: Callable[[SqlProofClient], None] | None = None
36
+ external_tables: Mapping[str, ExternalTableSpec] | None = None
37
+
38
+ def __post_init__(self) -> None:
39
+ sources = [self.connection_string is not None, self.schema_file is not None]
40
+ if sum(sources) != 1:
41
+ msg = "Exactly one of connection_string or schema_file must be provided."
42
+ raise SqlProofUsageError(msg)
sqlproof/contrib/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from __future__ import annotations
2
+
3
+ __all__ = []
sqlproof/contrib/supabase.py ADDED
@@ -0,0 +1,136 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import os
5
+ from collections.abc import Generator, Mapping
6
+ from contextlib import contextmanager
7
+ from importlib import import_module
8
+ from typing import Any, cast
9
+
10
+ from sqlproof.client import SqlProofClient
11
+
12
CLAIMS_GUC = "request.jwt.claims"


@contextmanager
def as_supabase_user(
    db: SqlProofClient,
    user_id: str,
    *,
    role: str = "authenticated",
    extra_claims: Mapping[str, Any] | None = None,
) -> Generator[None]:
    """Run a block as a Supabase auth user by setting `request.jwt.claims`.

    Sets the `request.jwt.claims` GUC with ``set_config(..., true)`` (local to
    the current transaction) to the JSON object
    ``{"sub": user_id, "role": role, **extra_claims}``. ``extra_claims`` is
    merged last and may override ``sub``/``role``. This lets PostgREST/Supabase
    helpers (`auth.uid()`, `auth.jwt()`, `auth.role()`) resolve to the given
    user. The value read on entry is written back on exit (an unset GUC is
    restored as an empty string), so nested invocations stack correctly.

    Composable with `db.savepoint()`. If the block raises, restoring is
    best-effort: a failure while restoring is suppressed so the block's own
    exception propagates instead of being masked. Such a failure happens, for
    example, when a SQL error has aborted the transaction.
    """
    from contextlib import suppress

    prior = db.scalar("SELECT current_setting(%s, true)", CLAIMS_GUC)
    claims: dict[str, Any] = {"sub": user_id, "role": role}
    if extra_claims:
        claims.update(extra_claims)
    db.execute("SELECT set_config(%s, %s, true)", CLAIMS_GUC, json.dumps(claims))
    restore_value = "" if prior in (None, "") else prior

    def restore() -> None:
        db.execute("SELECT set_config(%s, %s, true)", CLAIMS_GUC, restore_value)

    try:
        yield
    except BaseException:
        with suppress(Exception):
            restore()
        raise
    restore()
42
+
43
+
44
def seed_supabase_test_users(
    db: SqlProofClient | object,
    count: int = 20,
    *,
    email_prefix: str = "sqlproof_",
    email_domain: str = "test.invalid",
    password: str = "test_password",
) -> None:
    """Create replaceable Supabase auth users for external table FK sampling.

    Ensures users ``<email_prefix><i>@<email_domain>`` exist for every
    ``i in range(count)`` via the Supabase Auth admin API
    (``/auth/v1/admin/users``), authenticated with the service-role key.
    Emails that already exist (compared case-insensitively) are skipped, so
    repeated calls only create what is missing. Existing users are listed
    page by page, so duplicate detection also works beyond the first page.

    Args:
        db: Unused; presumably accepted so the function fits the
            ``(client, count)`` seed-callback shape — TODO confirm.
        count: Number of users to ensure; ``0`` returns without any I/O.
        email_prefix: Local-part prefix of the generated emails.
        email_domain: Domain of the generated emails.
        password: Password given to newly created users.

    Raises:
        ValueError: if ``count`` is negative.
        KeyError: if ``SUPABASE_URL`` or ``SUPABASE_SERVICE_ROLE_KEY`` is unset.
        ModuleNotFoundError: if the optional ``httpx`` dependency is missing.
        httpx.HTTPStatusError: if an admin API request fails.
    """
    del db
    if count < 0:
        msg = "count must be non-negative."
        raise ValueError(msg)
    if count == 0:
        return

    # httpx is an optional dependency, imported only when actually needed.
    httpx = import_module("httpx")
    service_role_key = os.environ["SUPABASE_SERVICE_ROLE_KEY"]
    with httpx.Client(
        base_url=os.environ["SUPABASE_URL"],
        headers={"Authorization": f"Bearer {service_role_key}", "apikey": service_role_key},
        timeout=5.0,
    ) as admin:
        existing_emails = _existing_admin_user_emails(admin)
        for index in range(count):
            email = f"{email_prefix}{index}@{email_domain}"
            if email.lower() in existing_emails:
                continue
            create_response = admin.post(
                "/auth/v1/admin/users",
                json={
                    "email": email,
                    "password": password,
                    "email_confirm": True,
                },
            )
            create_response.raise_for_status()


def _existing_admin_user_emails(admin: Any, *, per_page: int = 200) -> set[str]:
    """Return the lower-cased emails of all users listed by the Auth admin API.

    Requests pages of ``GET /auth/v1/admin/users`` until one comes back empty
    (or identical to the previous page, in case the server ignores ``page``).
    """
    emails: set[str] = set()
    previous: object = None
    page = 1
    while True:
        response = admin.get(
            "/auth/v1/admin/users", params={"page": page, "per_page": per_page}
        )
        response.raise_for_status()
        users = response.json().get("users", [])
        if not users or users == previous:
            return emails
        emails.update(_email(user).lower() for user in users)
        previous = users
        page += 1
83
+
84
+
85
def seed_test_users_directly(
    db: SqlProofClient,
    count: int = 20,
    *,
    email_prefix: str = "sqlproof_",
    email_domain: str = "test.invalid",
) -> list[str]:
    """Insert skeleton `auth.users` rows directly via SQL.

    For each ``i in range(count)`` a row with email
    ``<email_prefix><i>@<email_domain>`` is inserted unless a row with that
    email already exists (``INSERT ... SELECT ... WHERE NOT EXISTS``), so
    repeated calls do not duplicate users.

    Returns the ids (as text) of every `auth.users` row whose email matches
    ``<email_prefix>%@<email_domain>``, ordered by email. This includes newly
    inserted and pre-existing rows. LIKE wildcards in the prefix and domain
    are escaped, so they match literally.

    Use when the Supabase admin API is unavailable but the connection has
    write access to `auth.users` (e.g. local Supabase).

    Raises:
        ValueError: if ``count`` is negative.
    """
    if count < 0:
        msg = "count must be non-negative."
        raise ValueError(msg)

    for index in range(count):
        email = f"{email_prefix}{index}@{email_domain}"
        db.execute(
            """
            INSERT INTO auth.users (id, aud, role, email)
            SELECT gen_random_uuid(), 'authenticated', 'authenticated', %s
            WHERE NOT EXISTS (SELECT 1 FROM auth.users WHERE email = %s)
            """,
            email,
            email,
        )

    pattern = f"{_escape_like(email_prefix)}%@{_escape_like(email_domain)}"
    rows = db.query(
        r"""
        SELECT id::text AS id
        FROM auth.users
        WHERE email LIKE %s ESCAPE '\'
        ORDER BY email
        """,
        pattern,
    )
    return [row["id"] for row in rows]


def _escape_like(text: str) -> str:
    """Escape ``\\``, ``%`` and ``_`` for a LIKE pattern using ``ESCAPE '\\'``."""
    # Backslash first, so the escapes added below are not doubled again.
    return text.replace("\\", "\\\\").replace("%", r"\%").replace("_", r"\_")
128
+
129
+
130
+ def _email(user: object) -> str:
131
+ if isinstance(user, Mapping):
132
+ user_mapping = cast(Mapping[str, object], user)
133
+ value = user_mapping.get("email")
134
+ if isinstance(value, str):
135
+ return value
136
+ return ""