sqlproof 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlproof/__init__.py +32 -0
- sqlproof/_version.py +1 -0
- sqlproof/cli.py +151 -0
- sqlproof/client.py +159 -0
- sqlproof/config.py +42 -0
- sqlproof/contrib/__init__.py +3 -0
- sqlproof/contrib/supabase.py +136 -0
- sqlproof/core.py +344 -0
- sqlproof/coverage/__init__.py +6 -0
- sqlproof/coverage/diversity.py +11 -0
- sqlproof/coverage/plpgsql.py +5 -0
- sqlproof/coverage/schema_shape.py +7 -0
- sqlproof/exceptions.py +47 -0
- sqlproof/generators/__init__.py +21 -0
- sqlproof/generators/columns.py +93 -0
- sqlproof/generators/constraints.py +181 -0
- sqlproof/generators/functions.py +9 -0
- sqlproof/generators/graph.py +51 -0
- sqlproof/generators/rows.py +153 -0
- sqlproof/generators/sampling.py +15 -0
- sqlproof/generators/well_known.py +59 -0
- sqlproof/pytest_plugin.py +24 -0
- sqlproof/reporter/__init__.py +5 -0
- sqlproof/reporter/console.py +20 -0
- sqlproof/reporter/json_io.py +26 -0
- sqlproof/runners/__init__.py +14 -0
- sqlproof/runners/db.py +48 -0
- sqlproof/runners/migration.py +51 -0
- sqlproof/runners/overload.py +41 -0
- sqlproof/runners/property.py +119 -0
- sqlproof/runners/rls.py +40 -0
- sqlproof/runners/stateful.py +36 -0
- sqlproof/schema/__init__.py +27 -0
- sqlproof/schema/dependency_graph.py +38 -0
- sqlproof/schema/fingerprint.py +34 -0
- sqlproof/schema/introspect.py +229 -0
- sqlproof/schema/model.py +98 -0
- sqlproof/schema/parse_sql.py +206 -0
- sqlproof/testing.py +101 -0
- sqlproof/types.py +34 -0
- sqlproof-0.1.0a1.dist-info/METADATA +248 -0
- sqlproof-0.1.0a1.dist-info/RECORD +44 -0
- sqlproof-0.1.0a1.dist-info/WHEEL +4 -0
- sqlproof-0.1.0a1.dist-info/entry_points.txt +5 -0
sqlproof/__init__.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
from sqlproof._version import __version__
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from sqlproof.config import ExternalTableSpec, SqlProofConfig
|
|
9
|
+
from sqlproof.core import SqlProof
|
|
10
|
+
from sqlproof.runners import sqlproof
|
|
11
|
+
|
|
12
|
+
__all__ = ["ExternalTableSpec", "SqlProof", "SqlProofConfig", "__version__", "sqlproof"]
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def __getattr__(name: str) -> object:
|
|
16
|
+
if name == "SqlProof":
|
|
17
|
+
from sqlproof.core import SqlProof
|
|
18
|
+
|
|
19
|
+
return SqlProof
|
|
20
|
+
if name == "SqlProofConfig":
|
|
21
|
+
from sqlproof.config import SqlProofConfig
|
|
22
|
+
|
|
23
|
+
return SqlProofConfig
|
|
24
|
+
if name == "ExternalTableSpec":
|
|
25
|
+
from sqlproof.config import ExternalTableSpec
|
|
26
|
+
|
|
27
|
+
return ExternalTableSpec
|
|
28
|
+
if name == "sqlproof":
|
|
29
|
+
from sqlproof.runners import sqlproof
|
|
30
|
+
|
|
31
|
+
return sqlproof
|
|
32
|
+
raise AttributeError(name)
|
sqlproof/_version.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Package version string; imported by sqlproof/__init__.py and by the CLI's
# ``version`` command.
__version__ = "0.1.0a1"
|
sqlproof/cli.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import argparse
|
|
4
|
+
import json
|
|
5
|
+
import sys
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any, cast
|
|
8
|
+
|
|
9
|
+
from sqlproof._version import __version__
|
|
10
|
+
from sqlproof.coverage.schema_shape import summarize_dataset_shape
|
|
11
|
+
from sqlproof.reporter.console import format_failure
|
|
12
|
+
from sqlproof.schema.model import Column, SchemaInfo, Table
|
|
13
|
+
from sqlproof.schema.parse_sql import parse_schema_sql
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def main(argv: list[str] | None = None) -> int:
    """Run the ``sqlproof`` command-line interface.

    Args:
        argv: Command-line arguments without the program name. ``None`` lets
            argparse read ``sys.argv[1:]``.

    Returns:
        The process exit status: ``0`` for the built-in commands, pytest's exit
        code for ``run``.

    Raises:
        SystemExit: Raised by argparse (status 2) for invalid arguments. This
            includes a missing or doubly specified schema source for
            ``introspect``, and an unreadable or non-object counterexample file
            for ``replay``/``report``.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)

    if args.command == "version":
        print(f"sqlproof {__version__}")
        return 0

    if args.command == "introspect":
        # Same "exactly one schema source" rule that SqlProofConfig enforces,
        # instead of silently ignoring --dsn when both are given.
        if args.schema_file is None and args.dsn is None:
            parser.error("one of --schema-file or --dsn is required")
        if args.schema_file is not None and args.dsn is not None:
            parser.error("--schema-file and --dsn are mutually exclusive")
        if args.schema_file is not None:
            schema = parse_schema_sql(args.schema_file.read_text(encoding="utf-8"))
        else:
            # Imported here because it is only needed for live introspection.
            from sqlproof import SqlProof

            schema = SqlProof.from_connection_string(str(args.dsn)).schema_info
        if args.format == "json":
            print(json.dumps(_schema_payload(schema), default=str))
        else:
            for table in schema.tables:
                print(f"{table.qualified_name} ({len(table.columns)} columns)")
        return 0

    if args.command == "generate-types":
        from sqlproof.types import generate_types

        output = generate_types(args.schema_file, style=args.style)
        if args.output:
            args.output.write_text(output, encoding="utf-8")
        else:
            print(output)
        return 0

    if args.command == "replay":
        payload = _load_counterexample(parser, args.counterexample)
        print(format_failure(payload))
        print(f"replay loaded {payload.get('property_name', 'counterexample')}")
        return 0

    if args.command == "report":
        payload = _load_counterexample(parser, args.counterexample)
        report_payload = _counterexample_report(payload)
        if args.format == "json":
            print(json.dumps(report_payload, default=str))
        else:
            print(format_failure(payload))
        return 0

    if args.command == "run":
        import pytest

        return int(pytest.main([args.test_path, *args.pytest_options]))

    if args.command == "clean-orphans":
        # NOTE(review): placeholder; no container lookup is performed here.
        print("No orphaned SqlProof containers found.")
        return 0

    # Unreachable while a subcommand is required; kept as a defensive fallback.
    return 1


def _build_parser() -> argparse.ArgumentParser:
    """Build the ``sqlproof`` argument parser with one subparser per command."""
    parser = argparse.ArgumentParser(prog="sqlproof")
    subcommands = parser.add_subparsers(dest="command", required=True)

    subcommands.add_parser("version")

    introspect = subcommands.add_parser("introspect")
    introspect.add_argument("--schema-file", type=Path)
    introspect.add_argument("--dsn")
    introspect.add_argument("--format", choices=["json", "text"], default="text")

    generate_types_parser = subcommands.add_parser("generate-types")
    generate_types_parser.add_argument("--schema-file", type=Path, required=True)
    generate_types_parser.add_argument("--output", type=Path)
    generate_types_parser.add_argument(
        "--style", choices=["typeddict", "dataclass", "pydantic"], default="typeddict"
    )

    replay = subcommands.add_parser("replay")
    replay.add_argument("counterexample", type=Path)

    report = subcommands.add_parser("report")
    report.add_argument("counterexample", type=Path)
    report.add_argument("--format", choices=["json", "text"], default="text")

    run = subcommands.add_parser("run")
    run.add_argument("test_path")
    run.add_argument("pytest_options", nargs=argparse.REMAINDER)

    subcommands.add_parser("clean-orphans")
    return parser


def _load_counterexample(parser: argparse.ArgumentParser, path: Path) -> dict[str, object]:
    """Read a counterexample JSON object from *path*.

    On failure, exit through ``parser.error`` (status 2) instead of a traceback.
    """
    try:
        payload = json.loads(Path(path).read_text(encoding="utf-8"))
    except OSError as exc:
        parser.error(f"cannot read counterexample file {path}: {exc}")
    except ValueError as exc:  # JSONDecodeError and UnicodeDecodeError
        parser.error(f"counterexample file {path} is not valid JSON: {exc}")
    if not isinstance(payload, dict):
        parser.error(f"counterexample file {path} must contain a JSON object")
    return cast(dict[str, object], payload)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def _schema_payload(schema: SchemaInfo) -> dict[str, object]:
    """Serialize every table of *schema* into a JSON-ready ``{"tables": [...]}`` mapping."""
    serialized_tables = [_table_payload(entry) for entry in schema.tables]
    return {"tables": serialized_tables}
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def _table_payload(table: Table) -> dict[str, object]:
    """Describe *table* (columns, keys, constraints) as JSON-serializable data."""
    schema_name = table.schema
    table_name = table.name
    column_payloads = [_column_payload(col) for col in table.columns]
    primary_key = list(table.primary_key)
    unique_constraints = [list(group) for group in table.unique_constraints]
    checks = [constraint.expression for constraint in table.check_constraints]
    foreign_keys: list[dict[str, object]] = []
    for fk in table.foreign_keys:
        foreign_keys.append(
            {
                "columns": list(fk.columns),
                "referenced_table": fk.referenced_table,
                "referenced_columns": list(fk.referenced_columns),
                "on_delete": fk.on_delete,
                "on_update": fk.on_update,
            }
        )
    return {
        "schema": schema_name,
        "name": table_name,
        "columns": column_payloads,
        "primary_key": primary_key,
        "unique_constraints": unique_constraints,
        "checks": checks,
        "foreign_keys": foreign_keys,
    }
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def _column_payload(column: Column) -> dict[str, object]:
    """Describe a single column as a JSON-friendly mapping (type reduced to its name)."""
    payload: dict[str, object] = {"name": column.name, "type": column.type.name}
    for attribute in ("nullable", "default", "is_generated", "identity"):
        payload[attribute] = getattr(column, attribute)
    return payload
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def _counterexample_report(payload: dict[str, object]) -> dict[str, object]:
    """Condense a counterexample payload into the fields emitted by ``sqlproof report``.

    ``shape`` is computed only when ``payload["dataset"]`` is a dict; otherwise it is ``{}``.
    """
    dataset = payload.get("dataset")
    shape: object = {}
    if isinstance(dataset, dict):
        shape = summarize_dataset_shape(cast(dict[str, list[dict[str, Any]]], dataset))
    report: dict[str, object] = {}
    report["property_name"] = payload.get("property_name", "counterexample")
    report["schema_fingerprint"] = payload.get("schema_fingerprint")
    report["row_context"] = payload.get("row_context", {})
    report["failure"] = payload.get("failure", {})
    report["shape"] = shape
    return report
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
if __name__ == "__main__":  # pragma: no cover
    # When executed directly, exit with the status returned by main().
    sys.exit(main(sys.argv[1:]))
|
sqlproof/client.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import re
|
|
4
|
+
from collections.abc import Generator
|
|
5
|
+
from contextlib import AbstractContextManager, contextmanager
|
|
6
|
+
from dataclasses import fields, is_dataclass
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Any, Protocol, TypeVar, cast
|
|
9
|
+
|
|
10
|
+
from sqlproof.exceptions import SqlProofMappingError, SqlProofUsageError
|
|
11
|
+
|
|
12
|
+
# Row model type accepted by ``query_typed`` and converted by ``_map_row``
# (a pydantic model, dataclass or TypedDict class).
T = TypeVar("T")
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class SqlProofClient(Protocol):
    """Structural type for the database handle that SqlProof passes to user code.

    ``InMemorySqlProofClient`` and ``PsycopgSqlProofClient`` in this module both
    provide these methods.
    """

    def query(self, sql: str, *params: Any) -> list[dict[str, Any]]:
        """Run *sql* with positional *params* and return the result rows as dicts."""

    def query_typed(self, sql: str, model: type[T], *params: Any) -> list[T]:
        """Like ``query``, but convert each row into an instance of *model*."""

    def scalar(self, sql: str, *params: Any) -> Any:
        """Return the first value of the first row, or ``None`` when there are no rows."""

    def execute(self, sql: str, *params: Any) -> int:
        """Run a statement and return an integer status (the concrete clients return a row count or ``0``)."""

    def execute_file(self, path: str) -> None:
        """Run the SQL script stored at *path*."""

    def savepoint(self) -> AbstractContextManager[None]:
        """Return a context manager that scopes the enclosed work to a savepoint."""

    def get_generated_data(self) -> dict[str, list[dict[str, Any]]]:
        """Return the generated rows, keyed by table name."""
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class InMemorySqlProofClient:
    """Database-free ``SqlProofClient`` that answers simple SELECTs from a dict.

    Only the shape ``SELECT <columns> FROM <table>`` is understood. ``WHERE``
    clauses, joins and parameters are ignored, ``execute``/``execute_file`` do
    nothing, and ``savepoint()`` does not roll anything back.
    """

    # Captures the select list and the table name (optionally schema-qualified
    # and/or double-quoted) of the first ``SELECT ... FROM <table>`` in *sql*.
    _SELECT_PATTERN = re.compile(
        r'SELECT\s+(?P<columns>.*?)\s+FROM\s+(?P<table>[\w."]+)', re.I | re.S
    )

    def __init__(self, dataset: dict[str, list[dict[str, Any]]]) -> None:
        # Mapping of table name -> rows; used directly, not copied.
        self._dataset = dataset

    def query(self, sql: str, *params: Any) -> list[dict[str, Any]]:
        """Return rows of the table named in *sql*, projected onto its select list.

        The table is looked up by its unquoted name as written (for example
        ``public.users``). Failing that, it is looked up by its last dotted
        segment (``users``). Unknown tables and non-SELECT statements yield
        ``[]``. *params* are ignored. Returned rows are shallow copies.
        """
        del params
        match = self._SELECT_PATTERN.search(sql)
        if match is None:
            return []
        rows = self._lookup_rows(match.group("table"))
        columns_sql = match.group("columns").strip()
        if columns_sql == "*":
            return [dict(row) for row in rows]
        columns = [_clean_selected_column(part) for part in columns_sql.split(",")]
        return [{column: row.get(column) for column in columns} for row in rows]

    def _lookup_rows(self, table_sql: str) -> list[dict[str, Any]]:
        """Resolve a (possibly quoted / schema-qualified) table name to its rows."""
        name = table_sql.replace('"', "")
        if name in self._dataset:
            return self._dataset[name]
        return self._dataset.get(name.rsplit(".", 1)[-1], [])

    def query_typed(self, sql: str, model: type[T], *params: Any) -> list[T]:
        """Run ``query`` and convert each row into *model* via ``_map_row``."""
        rows = self.query(sql, *params)
        return [_map_row(row, model) for row in rows]

    def scalar(self, sql: str, *params: Any) -> Any:
        """Return the first value of the first row, or ``None`` if there are no rows."""
        rows = self.query(sql, *params)
        if not rows:
            return None
        return next(iter(rows[0].values()))

    def execute(self, sql: str, *params: Any) -> int:
        """Accept and ignore a statement; always reports ``0`` affected rows."""
        del sql, params
        return 0

    def execute_file(self, path: str | Path) -> None:
        """Accept and ignore *path*; nothing is executed in memory."""
        del path

    @contextmanager
    def savepoint(self) -> Generator[None]:
        """No-op savepoint: the dataset is not snapshotted or restored."""
        yield

    def get_generated_data(self) -> dict[str, list[dict[str, Any]]]:
        """Return the backing dataset itself (not a copy)."""
        return self._dataset
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class PsycopgSqlProofClient:
    """``SqlProofClient`` that runs SQL through a live database connection.

    NOTE(review): rows are converted with ``dict(row)``, so the connection is
    presumably configured with a mapping row factory (for example psycopg's
    ``dict_row``). Confirm where the client is constructed.
    """

    def __init__(
        self,
        connection: Any,
        *,
        dataset: dict[str, list[dict[str, Any]]] | None = None,
    ) -> None:
        self._connection = connection
        # Keep the caller's dict, even when empty, so later additions stay visible.
        self._dataset = dataset if dataset is not None else {}

    def _run(self, sql: str, params: tuple[Any, ...]) -> Any:
        """Execute *sql* on the connection and return the resulting cursor.

        An empty *params* is sent as ``None``. psycopg only scans the query for
        ``%`` placeholders when parameters are given, so parameterless SQL
        containing a literal ``%`` (``LIKE 'a%'``, ``format('%s', x)``) would
        otherwise be rejected as an invalid placeholder.
        """
        return self._connection.execute(sql, params or None)

    def query(self, sql: str, *params: Any) -> list[dict[str, Any]]:
        """Run *sql* and return every fetched row as a dict."""
        cursor = self._run(sql, params)
        return [dict(row) for row in cursor.fetchall()]

    def query_typed(self, sql: str, model: type[T], *params: Any) -> list[T]:
        """Run ``query`` and convert each row into *model* via ``_map_row``."""
        rows = self.query(sql, *params)
        return [_map_row(row, model) for row in rows]

    def scalar(self, sql: str, *params: Any) -> Any:
        """Return the first value of the first row, or ``None`` if there are no rows."""
        rows = self.query(sql, *params)
        if not rows:
            return None
        return next(iter(rows[0].values()))

    def execute(self, sql: str, *params: Any) -> int:
        """Run a statement and return the cursor's ``rowcount``."""
        cursor = self._run(sql, params)
        return int(cursor.rowcount)

    def execute_file(self, path: str | Path) -> None:
        """Run each statement of the UTF-8 SQL script at *path*, in order."""
        sql = Path(path).read_text(encoding="utf-8")
        for statement in _split_sql_statements(sql):
            self.execute(statement)

    @contextmanager
    def savepoint(self) -> Generator[None]:
        """Scope the enclosed block to a savepoint.

        On success the savepoint is released. On any exception it is rolled
        back, then released, and the exception is re-raised. If the rollback
        itself fails, no release is attempted, so that failure is not masked
        by a second error.
        """
        name = "sqlproof_client"
        self.execute(f"SAVEPOINT {name}")
        try:
            yield
        except BaseException:
            self.execute(f"ROLLBACK TO SAVEPOINT {name}")
            self.execute(f"RELEASE SAVEPOINT {name}")
            raise
        self.execute(f"RELEASE SAVEPOINT {name}")

    def get_generated_data(self) -> dict[str, list[dict[str, Any]]]:
        """Return the dataset passed at construction (an empty dict if none)."""
        return self._dataset
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
# Opening/closing tag of a PostgreSQL dollar-quoted string: ``$$`` or ``$tag$``.
_DOLLAR_QUOTE_TAG = re.compile(r"\$(?:[A-Za-z_][A-Za-z0-9_]*)?\$")


def _split_sql_statements(sql: str) -> list[str]:
    """Split a SQL script into statements on top-level semicolons.

    The following are not treated as separators, so PL/pgSQL function
    definitions and literals containing ``;`` survive intact:

    * semicolons inside single-quoted literals (``''`` escapes a quote);
    * semicolons inside double-quoted identifiers (``""`` escapes a quote);
    * semicolons inside dollar-quoted bodies (``$$ ... $$`` / ``$tag$ ... $tag$``);
    * semicolons inside ``--`` line comments and ``/* ... */`` block comments.

    Each statement is stripped of surrounding whitespace and empty statements
    are dropped.

    Limitations: backslash escapes in ``E'...'`` strings and nested block
    comments are not recognized.
    """
    statements: list[str] = []
    start = 0
    index = 0
    length = len(sql)
    while index < length:
        char = sql[index]
        if char == ";":
            statements.append(sql[start:index])
            start = index + 1
            index += 1
        elif char in "'\"":
            # Skip to the closing quote; a doubled quote is an escaped quote.
            end = sql.find(char, index + 1)
            while end != -1 and sql.startswith(char, end + 1):
                end = sql.find(char, end + 2)
            index = length if end == -1 else end + 1
        elif sql.startswith("--", index):
            end = sql.find("\n", index)
            index = length if end == -1 else end + 1
        elif sql.startswith("/*", index):
            end = sql.find("*/", index + 2)
            index = length if end == -1 else end + 2
        elif char == "$" and (
            index == 0 or not (sql[index - 1].isalnum() or sql[index - 1] in "_$")
        ):
            # A ``$`` directly after an identifier character is part of the
            # identifier, not the start of a dollar quote.
            tag = _DOLLAR_QUOTE_TAG.match(sql, index)
            if tag is None:
                index += 1
            else:
                end = sql.find(tag.group(), tag.end())
                index = length if end == -1 else end + len(tag.group())
        else:
            index += 1
    statements.append(sql[start:])
    return [statement.strip() for statement in statements if statement.strip()]
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
# ``AS`` between an expression and its alias, with any surrounding whitespace.
_ALIAS_SEPARATOR = re.compile(r"\s+AS\s+", re.I)
# One dotted segment of a name: a double-quoted identifier (which may contain
# dots) or an unquoted run without dots or quotes.
_IDENTIFIER_PART = re.compile(r'"(?:[^"]|"")*"|[^."]+')


def _clean_selected_column(sql: str) -> str:
    """Return the result key produced by one select-list item.

    An explicit alias (``expr AS name``, with ``AS`` separated by any
    whitespace) wins. Otherwise the last dotted segment of the item is used
    (``t.col`` -> ``col``), respecting double-quoted identifiers that contain
    dots. Surrounding whitespace and double quotes are removed.
    """
    value = sql.strip()
    alias_parts = _ALIAS_SEPARATOR.split(value)
    if len(alias_parts) > 1:
        value = alias_parts[-1]
    else:
        segments = _IDENTIFIER_PART.findall(value)
        if segments:
            value = segments[-1]
    return value.strip().strip('"')
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def _map_row(row: dict[str, Any], model: type[T]) -> T:
    """Convert one result row into an instance of *model*.

    Exactly one of the following kinds must match *model*:

    * pydantic models (detected via ``__pydantic_fields__``), validated with
      ``model_validate``;
    * dataclasses, constructed from the row keys that name ``init`` fields.
      Fields with a ``default`` or ``default_factory`` may be absent from
      *row*; extra row keys are ignored;
    * ``TypedDict`` classes (detected via ``__required_keys__`` /
      ``__optional_keys__``). The row is checked for required keys and returned
      as a plain ``dict``.

    Raises:
        SqlProofUsageError: *model* matches none, or more than one, of the kinds.
        SqlProofMappingError: pydantic validation fails, or a required
            field/key is missing from *row*.
    """
    from dataclasses import MISSING

    is_pydantic = hasattr(model, "__pydantic_fields__")
    is_dataclass_model = is_dataclass(model)
    is_typed_dict = hasattr(model, "__required_keys__") or hasattr(model, "__optional_keys__")
    if sum((is_pydantic, is_dataclass_model, is_typed_dict)) != 1:
        raise SqlProofUsageError("Ambiguous or unsupported row model type.")
    if is_pydantic:
        try:
            return model.model_validate(row)  # type: ignore[attr-defined,no-any-return]
        except Exception as exc:  # pragma: no cover - exercised when pydantic is installed
            raise SqlProofMappingError(str(exc)) from exc
    if is_dataclass_model:
        # init=False fields cannot be passed to the constructor; fields with a
        # default or default_factory are optional in the row.
        init_fields = [dc_field for dc_field in fields(model) if dc_field.init]
        missing = {
            dc_field.name
            for dc_field in init_fields
            if dc_field.name not in row
            and dc_field.default is MISSING
            and dc_field.default_factory is MISSING
        }
        if missing:
            raise SqlProofMappingError(f"Missing required fields: {', '.join(sorted(missing))}")
        kwargs = {dc_field.name: row[dc_field.name] for dc_field in init_fields if dc_field.name in row}
        return model(**kwargs)  # type: ignore[return-value]
    required = cast("set[str]", getattr(model, "__required_keys__", set[str]()))
    missing = required - row.keys()
    if missing:
        raise SqlProofMappingError(f"Missing required fields: {', '.join(sorted(missing))}")
    return dict(row)  # type: ignore[return-value]
|
sqlproof/config.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Callable, Mapping, Sequence
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import TYPE_CHECKING
|
|
7
|
+
|
|
8
|
+
from hypothesis.strategies import SearchStrategy
|
|
9
|
+
|
|
10
|
+
from sqlproof.exceptions import SqlProofUsageError
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from sqlproof.client import SqlProofClient
|
|
14
|
+
|
|
15
|
+
# Hook that seeds an external table. It is called either with the client, or
# with the client and an ``int``. NOTE(review): which form is invoked, and
# where the int comes from (presumably ``ExternalTableSpec.seed_count``), is
# decided outside this file. Evaluated at import time; the ``|`` union between
# these aliases requires Python 3.10+.
ExternalSeed = Callable[["SqlProofClient"], None] | Callable[["SqlProofClient", int], None]
# A fixed integer size, or a Hypothesis strategy that draws one.
SizeSpec = int | SearchStrategy[int]
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass(frozen=True, slots=True)
class ExternalTableSpec:
    """Describe a table whose rows SqlProof samples rather than generates.

    Instances are the values of ``SqlProofConfig.external_tables``.
    NOTE(review): presumably used for foreign-key targets owned elsewhere
    (e.g. Supabase ``auth.users``); confirm against the generators.
    """

    # Key column of the external table; presumably a single column name.
    primary_key: str
    # Called with the client; returns existing values to sample from (assumed
    # to be key values; verify against callers).
    sample: Callable[[SqlProofClient], Sequence[object]]
    # Optional hook to populate the table (see ExternalSeed for accepted forms).
    seed: ExternalSeed | None = None
    # Optional fixed count or Hypothesis strategy; presumably passed to ``seed``.
    seed_count: SizeSpec | None = None
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass(frozen=True, slots=True)
|
|
28
|
+
class SqlProofConfig:
|
|
29
|
+
connection_string: str | None = None
|
|
30
|
+
schema: str = "public"
|
|
31
|
+
schema_file: str | Path | None = None
|
|
32
|
+
image: str = "postgres:16"
|
|
33
|
+
reuse_container: bool = True
|
|
34
|
+
transaction_per_run: bool = True
|
|
35
|
+
seed: Callable[[SqlProofClient], None] | None = None
|
|
36
|
+
external_tables: Mapping[str, ExternalTableSpec] | None = None
|
|
37
|
+
|
|
38
|
+
def __post_init__(self) -> None:
|
|
39
|
+
sources = [self.connection_string is not None, self.schema_file is not None]
|
|
40
|
+
if sum(sources) != 1:
|
|
41
|
+
msg = "Exactly one of connection_string or schema_file must be provided."
|
|
42
|
+
raise SqlProofUsageError(msg)
|
|
sqlproof/contrib/supabase.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import os
|
|
5
|
+
from collections.abc import Generator, Mapping
|
|
6
|
+
from contextlib import contextmanager
|
|
7
|
+
from importlib import import_module
|
|
8
|
+
from typing import Any, cast
|
|
9
|
+
|
|
10
|
+
from sqlproof.client import SqlProofClient
|
|
11
|
+
|
|
12
|
+
# GUC holding the JWT claims that Supabase's auth.* SQL helpers read.
CLAIMS_GUC = "request.jwt.claims"


@contextmanager
def as_supabase_user(
    db: SqlProofClient,
    user_id: str,
    *,
    role: str = "authenticated",
    extra_claims: Mapping[str, Any] | None = None,
) -> Generator[None]:
    """Run a block as a Supabase auth user by setting `request.jwt.claims`.

    Sets the transaction-local `request.jwt.claims` GUC to the JSON claims
    ``{"sub": user_id, "role": role, **extra_claims}``. This makes the
    PostgREST/Supabase helpers (`auth.uid()`, `auth.jwt()`, `auth.role()`)
    resolve to the given user. Keys in *extra_claims* may override ``sub`` and
    ``role``. This only sets claims; it does not switch the database role.

    On exit the previous value of the GUC is restored (``""`` when it was unset
    or empty), so nested invocations stack correctly.

    Composable with `db.savepoint()`. Safe under exceptions: if the body
    raises, the restore is attempted best-effort. This matters because the
    transaction may already be aborted, and a failing restore query would
    otherwise replace the caller's exception. The original exception always
    propagates.
    """
    prior = db.scalar("SELECT current_setting(%s, true)", CLAIMS_GUC)
    claims: dict[str, Any] = {"sub": user_id, "role": role}
    if extra_claims:
        claims.update(extra_claims)
    db.execute("SELECT set_config(%s, %s, true)", CLAIMS_GUC, json.dumps(claims))
    restore_value = "" if prior in (None, "") else prior
    try:
        yield
    except BaseException:
        try:
            db.execute("SELECT set_config(%s, %s, true)", CLAIMS_GUC, restore_value)
        except Exception:
            # Deliberately ignored: surfacing the body's exception matters more,
            # and a surrounding savepoint rollback discards the local GUC change.
            pass
        raise
    db.execute("SELECT set_config(%s, %s, true)", CLAIMS_GUC, restore_value)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
# Page size requested when listing auth users through the admin API.
_ADMIN_USERS_PAGE_SIZE = 100


def seed_supabase_test_users(
    db: SqlProofClient | object,
    count: int = 20,
    *,
    email_prefix: str = "sqlproof_",
    email_domain: str = "test.invalid",
    password: str = "test_password",
) -> None:
    """Create replaceable Supabase auth users for external table FK sampling.

    Users ``<email_prefix><i>@<email_domain>`` for ``i`` in ``range(count)`` are
    created with a confirmed email. Creation goes through the Supabase admin
    REST API at ``$SUPABASE_URL``, authenticated with
    ``$SUPABASE_SERVICE_ROLE_KEY``. Emails that already exist are skipped.
    Existing users are listed page by page, so projects with more users than fit
    in one page are handled.

    Args:
        db: Unused (discarded immediately). NOTE(review): presumably kept so the
            signature matches the other seed helpers.
        count: Number of test users to ensure; must be non-negative.
        email_prefix: Prefix of the generated emails.
        email_domain: Domain of the generated emails.
        password: Password set on newly created users.

    Raises:
        ValueError: *count* is negative.
        KeyError: ``SUPABASE_SERVICE_ROLE_KEY`` or ``SUPABASE_URL`` is unset.
        httpx.HTTPStatusError: An admin API request fails.
    """
    del db
    if count < 0:
        msg = "count must be non-negative."
        raise ValueError(msg)

    # Imported dynamically; presumably keeps httpx an optional dependency.
    httpx = import_module("httpx")
    service_role_key = os.environ["SUPABASE_SERVICE_ROLE_KEY"]
    with httpx.Client(
        base_url=os.environ["SUPABASE_URL"],
        headers={"Authorization": f"Bearer {service_role_key}", "apikey": service_role_key},
        timeout=5.0,
    ) as admin:
        existing_emails = _list_user_emails(admin)
        for index in range(count):
            email = f"{email_prefix}{index}@{email_domain}"
            if email in existing_emails:
                continue
            create_response = admin.post(
                "/auth/v1/admin/users",
                json={
                    "email": email,
                    "password": password,
                    "email_confirm": True,
                },
            )
            create_response.raise_for_status()


def _list_user_emails(admin: Any) -> set[str]:
    """Return the emails of all auth users, following the admin API's pagination.

    *admin* is an ``httpx.Client`` already configured for the Supabase project.
    """
    emails: set[str] = set()
    previous_users: object = None
    page = 1
    while True:
        response = admin.get(
            "/auth/v1/admin/users",
            params={"page": page, "per_page": _ADMIN_USERS_PAGE_SIZE},
        )
        response.raise_for_status()
        users = response.json().get("users", [])
        # A server that ignores ``page`` returns the same batch again; stop
        # rather than loop forever.
        if users == previous_users:
            break
        emails.update(_email(user) for user in users)
        if len(users) < _ADMIN_USERS_PAGE_SIZE:
            break
        previous_users = users
        page += 1
    return emails
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def _escape_like(value: str) -> str:
    """Escape LIKE metacharacters (backslash, ``%``, ``_``) for a backslash ESCAPE clause."""
    return value.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")


def seed_test_users_directly(
    db: SqlProofClient,
    count: int = 20,
    *,
    email_prefix: str = "sqlproof_",
    email_domain: str = "test.invalid",
) -> list[str]:
    """Insert skeleton `auth.users` rows directly via SQL.

    For ``i`` in ``range(count)``, a row with email
    ``<email_prefix><i>@<email_domain>`` is inserted, with ``aud`` and ``role``
    set to ``'authenticated'`` and an id from ``gen_random_uuid()``.
    Idempotent: each insert is guarded by ``WHERE NOT EXISTS`` on the email,
    so existing users are left untouched.

    Use when the Supabase admin API is unavailable but the connection has
    write access to `auth.users` (e.g. local Supabase).

    Returns:
        The ids (as text), ordered by email, of every `auth.users` row whose
        email matches ``<email_prefix>%@<email_domain>``, both newly inserted
        and pre-existing. The prefix and domain are matched literally; any
        LIKE wildcards in them are escaped.

    Raises:
        ValueError: *count* is negative.
    """
    if count < 0:
        msg = "count must be non-negative."
        raise ValueError(msg)

    for index in range(count):
        email = f"{email_prefix}{index}@{email_domain}"
        db.execute(
            """
            INSERT INTO auth.users (id, aud, role, email)
            SELECT gen_random_uuid(), 'authenticated', 'authenticated', %s
            WHERE NOT EXISTS (SELECT 1 FROM auth.users WHERE email = %s)
            """,
            email,
            email,
        )

    rows = db.query(
        r"""
        SELECT id::text AS id
        FROM auth.users
        WHERE email LIKE %s ESCAPE '\'
        ORDER BY email
        """,
        f"{_escape_like(email_prefix)}%@{_escape_like(email_domain)}",
    )
    return [row["id"] for row in rows]
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def _email(user: object) -> str:
    """Extract the ``email`` of an admin-API user record.

    Returns ``""`` when *user* is not a mapping or its email is missing or not a string.
    """
    if not isinstance(user, Mapping):
        return ""
    candidate = cast(Mapping[str, object], user).get("email")
    return candidate if isinstance(candidate, str) else ""
|