sqlproof 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44) hide show
  1. sqlproof/__init__.py +32 -0
  2. sqlproof/_version.py +1 -0
  3. sqlproof/cli.py +151 -0
  4. sqlproof/client.py +159 -0
  5. sqlproof/config.py +42 -0
  6. sqlproof/contrib/__init__.py +3 -0
  7. sqlproof/contrib/supabase.py +136 -0
  8. sqlproof/core.py +344 -0
  9. sqlproof/coverage/__init__.py +6 -0
  10. sqlproof/coverage/diversity.py +11 -0
  11. sqlproof/coverage/plpgsql.py +5 -0
  12. sqlproof/coverage/schema_shape.py +7 -0
  13. sqlproof/exceptions.py +47 -0
  14. sqlproof/generators/__init__.py +21 -0
  15. sqlproof/generators/columns.py +93 -0
  16. sqlproof/generators/constraints.py +181 -0
  17. sqlproof/generators/functions.py +9 -0
  18. sqlproof/generators/graph.py +51 -0
  19. sqlproof/generators/rows.py +153 -0
  20. sqlproof/generators/sampling.py +15 -0
  21. sqlproof/generators/well_known.py +59 -0
  22. sqlproof/pytest_plugin.py +24 -0
  23. sqlproof/reporter/__init__.py +5 -0
  24. sqlproof/reporter/console.py +20 -0
  25. sqlproof/reporter/json_io.py +26 -0
  26. sqlproof/runners/__init__.py +14 -0
  27. sqlproof/runners/db.py +48 -0
  28. sqlproof/runners/migration.py +51 -0
  29. sqlproof/runners/overload.py +41 -0
  30. sqlproof/runners/property.py +119 -0
  31. sqlproof/runners/rls.py +40 -0
  32. sqlproof/runners/stateful.py +36 -0
  33. sqlproof/schema/__init__.py +27 -0
  34. sqlproof/schema/dependency_graph.py +38 -0
  35. sqlproof/schema/fingerprint.py +34 -0
  36. sqlproof/schema/introspect.py +229 -0
  37. sqlproof/schema/model.py +98 -0
  38. sqlproof/schema/parse_sql.py +206 -0
  39. sqlproof/testing.py +101 -0
  40. sqlproof/types.py +34 -0
  41. sqlproof-0.1.0a1.dist-info/METADATA +248 -0
  42. sqlproof-0.1.0a1.dist-info/RECORD +44 -0
  43. sqlproof-0.1.0a1.dist-info/WHEEL +4 -0
  44. sqlproof-0.1.0a1.dist-info/entry_points.txt +5 -0
@@ -0,0 +1,181 @@
1
+ from __future__ import annotations
2
+
3
+ import re
4
+ from decimal import Decimal
5
+ from typing import Any
6
+
7
+ from hypothesis import strategies as st
8
+ from hypothesis.strategies import SearchStrategy
9
+
10
+ from sqlproof.schema.model import CheckConstraint, Column
11
+
12
+
13
def refine_for_checks(
    column: Column,
    strategy: SearchStrategy[Any],
    checks: tuple[CheckConstraint, ...],
) -> SearchStrategy[Any]:
    """Narrow ``strategy`` so generated values respect the CHECK constraints.

    Constraints are applied one after another, each refining the result of the
    previous one. A constraint whose expression is not recognised for
    ``column`` leaves the strategy unchanged.
    """
    refined = strategy
    for constraint in checks:
        refined = _refine_for_check(column, refined, constraint.expression)
    return refined
21
+
22
+
23
def _refine_for_check(
    column: Column,
    strategy: SearchStrategy[Any],
    expression: str,
) -> SearchStrategy[Any]:
    """Refine ``strategy`` for one CHECK ``expression`` on ``column``.

    Recognised shapes (whole-expression, case-insensitive matches):

    * ``col IN (v1, v2, ...)`` and ``col = ANY (ARRAY[v1, v2, ...])``: sample
      from the listed literals. Commas inside quoted literals or inside
      parenthesised casts do not split values.
    * ``length(col) <op> n`` / ``char_length(col) <op> n``: a direct string
      strategy when the column type supports it.
    * ``col <op> number`` with ``>=``, ``>``, ``<=`` or ``<``: a direct numeric
      strategy when available, otherwise ``strategy`` filtered by the bound.
      The number may be parenthesised and/or cast, as Postgres renders it
      (e.g. ``price >= (0)::numeric``).

    Any other expression, including checks on other columns, returns
    ``strategy`` unchanged.
    """
    expression = _normalize_check_expression(expression)
    column_pattern = re.escape(column.name)

    in_set_match = re.fullmatch(
        rf"{column_pattern}\s+IN\s*\((?P<values>.+)\)",
        expression,
        flags=re.IGNORECASE,
    )
    if in_set_match is not None:
        values = tuple(
            _parse_sql_literal(item) for item in _split_sql_list(in_set_match.group("values"))
        )
        return _sampled_values_for_column(column, values)

    any_array_match = re.fullmatch(
        rf"\(?\s*{column_pattern}\s*=\s*ANY\s*"
        r"\(\s*ARRAY\[(?P<values>.+)\]\s*\)\s*\)?",
        expression,
        flags=re.IGNORECASE,
    )
    if any_array_match is not None:
        values = tuple(
            _parse_sql_literal(item) for item in _split_sql_list(any_array_match.group("values"))
        )
        return _sampled_values_for_column(column, values)

    length_match = re.fullmatch(
        rf"(?:char_length|length)\s*\(\s*{column_pattern}\s*\)\s*"
        r"(?P<op>>=|>|<=|<|=)\s*(?P<value>\d+)",
        expression,
        flags=re.IGNORECASE,
    )
    if length_match is not None:
        direct = _direct_length_strategy(
            column, length_match.group("op"), int(length_match.group("value"))
        )
        if direct is not None:
            return direct

    # The bound may be rendered as ``(0)::numeric`` by pg_get_constraintdef.
    range_match = re.fullmatch(
        rf"{column_pattern}\s*(?P<op>>=|>|<=|<)\s*"
        r"\(?\s*(?P<value>-?\d+(?:\.\d+)?)\s*\)?(?:\s*::[\w. ]+)?",
        expression,
        flags=re.IGNORECASE,
    )
    if range_match is None:
        return strategy
    op = range_match.group("op")
    raw_value = Decimal(range_match.group("value"))
    direct = _direct_range_strategy(column, op, raw_value)
    if direct is not None:
        return direct

    def predicate(value: Any) -> bool:
        # NULL satisfies a CHECK constraint, so None is always allowed.
        if value is None:
            return True
        comparable = Decimal(str(value))
        if comparable.is_nan():
            # Ordering comparisons with a Decimal NaN raise InvalidOperation;
            # reject NaN instead of crashing the draw.
            return False
        if op == ">=":
            return comparable >= raw_value
        if op == ">":
            return comparable > raw_value
        if op == "<=":
            return comparable <= raw_value
        return comparable < raw_value

    return strategy.filter(predicate)


def _split_sql_list(values: str) -> list[str]:
    """Split a comma-separated SQL literal list at top-level commas only.

    Commas inside single-quoted literals (``'a,b'``) or inside parentheses and
    brackets (``::numeric(10, 2)``) are kept as part of the item. The ``''``
    escape toggles the quote state twice, so it is handled naturally.
    """
    items: list[str] = []
    current: list[str] = []
    depth = 0
    in_quote = False
    for char in values:
        if char == "'":
            in_quote = not in_quote
        elif not in_quote:
            if char in "([":
                depth += 1
            elif char in ")]":
                depth -= 1
            elif char == "," and depth == 0:
                items.append("".join(current))
                current = []
                continue
        current.append(char)
    items.append("".join(current))
    return items
91
+
92
+
93
def _normalize_check_expression(expression: str) -> str:
    """Reduce a CHECK expression to its bare predicate text.

    Strips surrounding whitespace and an optional ``CHECK ( ... )`` wrapper.
    It then removes any number of parentheses that enclose the *whole*
    expression, as ``pg_get_constraintdef`` emits them (``CHECK ((qty >= 0))``
    becomes ``qty >= 0``). Parentheses that only group part of the expression,
    as in ``(a > 0) AND (b > 0)``, are left alone.
    """
    value = expression.strip()
    match = re.fullmatch(r"CHECK\s*\((?P<inner>.*)\)", value, flags=re.IGNORECASE | re.DOTALL)
    if match is not None:
        value = match.group("inner").strip()
    while _is_wrapped_in_parens(value):
        value = value[1:-1].strip()
    return value


def _is_wrapped_in_parens(value: str) -> bool:
    """Return True when the opening ``(`` of ``value`` is closed by its final ``)``.

    Parentheses inside single- or double-quoted text are ignored, so
    ``(x = ')')`` counts as wrapped but ``(a) AND (b)`` does not.
    """
    if len(value) < 2 or not value.startswith("(") or not value.endswith(")"):
        return False
    depth = 0
    quote: str | None = None
    last = len(value) - 1
    for index, char in enumerate(value):
        if quote is not None:
            if char == quote:
                quote = None
        elif char in ("'", '"'):
            quote = char
        elif char == "(":
            depth += 1
        elif char == ")":
            depth -= 1
            if depth == 0:
                # The first paren to close must be the very last character.
                return index == last
    return False
99
+
100
+
101
def _parse_sql_literal(value: str) -> Any:
    """Convert one SQL literal from a CHECK expression into a Python value.

    Accepts single-quoted strings (with ``''`` escapes), integers and decimal
    numbers. Each may carry a trailing ``::type`` cast, including casts with
    type modifiers such as ``::character varying(20)`` or ``::numeric(10, 2)``.

    Returns:
        ``str`` for quoted literals, ``int`` for integers, ``Decimal`` for other
        numbers. Anything unrecognised is returned as the stripped source text.
    """
    stripped = value.strip()
    cast_match = re.fullmatch(
        r"(?P<literal>'(?:''|[^'])*'|-?\d+(?:\.\d+)?)"
        r"(?:\s*::[\w. ]+(?:\(\s*\d+(?:\s*,\s*\d+)?\s*\))?)?",
        stripped,
    )
    if cast_match is not None:
        stripped = cast_match.group("literal")
    # Require two characters so a lone "'" is not mistaken for an empty string.
    if len(stripped) >= 2 and stripped.startswith("'") and stripped.endswith("'"):
        return stripped[1:-1].replace("''", "'")
    try:
        return int(stripped)
    except ValueError:
        pass
    try:
        return Decimal(stripped)
    except ArithmeticError:
        # decimal.InvalidOperation (an ArithmeticError) means "not a number".
        return stripped
118
+
119
+
120
def _sampled_values_for_column(column: Column, values: tuple[Any, ...]) -> SearchStrategy[Any]:
    """Sample uniformly from ``values``; nullable columns may also yield ``None``.

    NULL always satisfies a CHECK constraint, so it is a valid extra choice
    whenever the column accepts it.
    """
    choices = st.sampled_from(values)
    return (st.none() | choices) if column.nullable else choices
125
+
126
+
127
def _direct_length_strategy(
    column: Column,
    op: str,
    value: int,
    *,
    default_max_size: int = 255,
) -> SearchStrategy[str | None] | None:
    """Build a text strategy satisfying ``length(column) <op> value``.

    Args:
        column: Column being generated. Only character types are handled.
        op: One of ``=``, ``<=``, ``<``, ``>=``, ``>``.
        value: Length bound from the CHECK constraint.
        default_max_size: Cap on generated length for types without a declared
            length modifier (e.g. ``text``). It is raised to the required
            minimum if a lower bound such as ``>= 300`` exceeds it.

    Returns:
        ``None`` for non-character columns, so the caller can fall back to
        another refinement. Otherwise a ``st.text`` strategy within the
        declared length modifier (the ``20`` in ``varchar(20)``) and the
        constraint, or ``st.nothing()`` when no length fits. Nullable columns
        may also yield ``None``; NULL passes a CHECK constraint, matching the
        IN/ANY branch.
    """
    name = column.type.name.lower()
    if name not in {"text", "citext", "varchar", "character varying", "char", "character"}:
        return None
    declared_max = column.type.modifiers[0] if column.type.modifiers else None

    min_size = 0
    upper_limit: int | None = None
    if op == "=":
        min_size = value
        upper_limit = value
    elif op == "<=":
        upper_limit = value
    elif op == "<":
        # Deliberately not clamped at 0: ``length(col) < 0`` is unsatisfiable.
        upper_limit = value - 1
    elif op == ">=":
        min_size = value
    elif op == ">":
        min_size = value + 1

    if declared_max is not None:
        max_size = declared_max
    else:
        # Unbounded type: keep strings small, but never below the constraint's minimum.
        max_size = max(default_max_size, min_size)
    if upper_limit is not None:
        max_size = min(max_size, upper_limit)

    strategy: SearchStrategy[str | None]
    if min_size > max_size:
        strategy = st.nothing()
    else:
        strategy = st.text(min_size=min_size, max_size=max_size)
    if column.nullable:
        return st.none() | strategy
    return strategy
147
+
148
+
149
def _direct_range_strategy(
    column: Column,
    op: str,
    raw_value: Decimal,
) -> SearchStrategy[Any] | None:
    """Build a numeric strategy for ``column >= raw_value`` or ``column > raw_value``.

    Only lower bounds on integer columns (``smallint``/``integer``/``bigint``
    and their ``intN`` aliases) and ``numeric``/``decimal`` columns are
    handled. Any other operator or type returns ``None``, so the caller falls
    back to filtering.

    The bound is rounded to the column's granularity: ``qty >= 1.5`` on an
    integer starts at 2, and ``price > 0`` on ``numeric(10, 2)`` starts at
    0.01. Values stay within the column's representable range (for example at
    most 999.99 for ``numeric(5, 2)``, and 1,000,000 for unconstrained
    numerics). An unsatisfiable bound yields ``st.nothing()``. Nullable
    columns may also yield ``None``, since NULL passes a CHECK constraint.
    """
    from decimal import ROUND_CEILING, ROUND_FLOOR

    if op not in {">=", ">"}:
        return None
    name = column.type.name.lower()
    smallint_range = (-(2**15), 2**15 - 1)
    integer_range = (-(2**31), 2**31 - 1)
    bigint_range = (-(2**63), 2**63 - 1)
    integer_ranges = {
        "smallint": smallint_range,
        "int2": smallint_range,
        "integer": integer_range,
        "int": integer_range,
        "int4": integer_range,
        "bigint": bigint_range,
        "int8": bigint_range,
    }

    strategy: SearchStrategy[Any]
    if name in integer_ranges:
        lowest, highest = integer_ranges[name]
        if op == ">=":
            minimum = int(raw_value.to_integral_value(rounding=ROUND_CEILING))
        else:
            minimum = int(raw_value.to_integral_value(rounding=ROUND_FLOOR)) + 1
        minimum = max(minimum, lowest)
        strategy = st.integers(minimum, highest) if minimum <= highest else st.nothing()
    elif name in {"numeric", "decimal"}:
        modifiers = column.type.modifiers
        if len(modifiers) > 1:
            places = modifiers[1]
        elif modifiers:
            places = 0  # numeric(p) has scale 0
        else:
            places = 2  # unconstrained numeric: generate two decimal places
        step = Decimal(1).scaleb(-places)
        if op == ">=":
            decimal_minimum = raw_value.quantize(step, rounding=ROUND_CEILING)
        else:
            # Smallest representable value strictly greater than the bound.
            decimal_minimum = raw_value.quantize(step, rounding=ROUND_FLOOR) + step
        decimal_maximum = Decimal("1000000")
        if modifiers:
            # numeric(p, s) stores magnitudes up to 10**(p - s) - 10**-s.
            decimal_maximum = min(decimal_maximum, Decimal(10) ** (modifiers[0] - places) - step)
        if decimal_minimum > decimal_maximum:
            strategy = st.nothing()
        else:
            strategy = st.decimals(
                min_value=decimal_minimum,
                max_value=decimal_maximum,
                places=places,
                allow_nan=False,
                allow_infinity=False,
            )
    else:
        return None

    if column.nullable:
        return st.none() | strategy
    return strategy
172
+
173
+
174
def unique_rows(rows: list[dict[str, Any]], columns: tuple[str, ...]) -> bool:
    """Return ``True`` when no two ``rows`` share the same values for ``columns``.

    Rows are compared on the tuple of their ``columns`` values, and the scan
    stops at the first duplicate. A row lacking one of ``columns`` raises
    ``KeyError``.
    """
    observed: set[tuple[Any, ...]] = set()
    for record in rows:
        signature = tuple(map(record.__getitem__, columns))
        if signature in observed:
            return False
        observed.add(signature)
    return True
@@ -0,0 +1,9 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+
6
+ @dataclass(frozen=True, slots=True)
7
+ class FunctionCall:
8
+ sql: str
9
+ overload_name: str
@@ -0,0 +1,51 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Mapping
4
+ from typing import Any
5
+
6
+ from hypothesis import strategies as st
7
+ from hypothesis.strategies import SearchStrategy
8
+
9
+ from sqlproof.schema.dependency_graph import insertion_order
10
+ from sqlproof.schema.model import SchemaInfo
11
+
12
+ from .rows import ColumnOverrides, table_rows_strategy
13
+
14
# A generated dataset: table name -> list of rows, each row mapping column name -> value.
Dataset = dict[str, list[dict[str, Any]]]
# Row count for one table: either a fixed int or a Hypothesis strategy that draws one.
SizeSpec = int | SearchStrategy[int]
16
+
17
+
18
def dataset_strategy(
    schema: SchemaInfo,
    *,
    sizes: Mapping[str, SizeSpec],
    external_parent_rows: Mapping[str, list[dict[str, Any]]] | None = None,
    columns: ColumnOverrides | None = None,
) -> SearchStrategy[Dataset]:
    """Build a strategy that generates rows for every table in ``schema``.

    Tables are filled in the order returned by ``insertion_order`` (presumably
    parents before children). Rows generated earlier in the same draw are
    offered as foreign-key parents to later tables, alongside
    ``external_parent_rows``; generated rows shadow external rows stored
    under the same key.

    Args:
        schema: Introspected schema whose tables are generated.
        sizes: Row count per table, keyed by qualified table name or bare
            table name (qualified wins). A value may be a fixed ``int`` or a
            strategy drawing one. Tables absent from ``sizes`` get no rows.
        external_parent_rows: Pre-existing rows usable as FK parents.
        columns: Per-column overrides passed through to ``table_rows_strategy``.

    Returns:
        Strategy producing ``{table_name: [row, ...]}`` for every table.
    """
    ordered_tables = insertion_order(schema.tables)
    external_parent_rows = external_parent_rows or {}

    @st.composite
    def dataset(draw: st.DrawFn) -> Dataset:
        rows_by_table: Dataset = {}
        for table in ordered_tables:
            count = _draw_size(draw, _size_for_table(sizes, table))
            available_parent_rows = {**external_parent_rows, **rows_by_table}
            rows_by_table[table.name] = draw(
                table_rows_strategy(
                    table,
                    count=count,
                    parent_rows=available_parent_rows,
                    rows_by_table=rows_by_table,
                    columns=columns,
                )
            )
        return rows_by_table

    return dataset()


def _size_for_table(sizes: Mapping[str, SizeSpec], table: Any) -> SizeSpec:
    """Return ``table``'s size spec: qualified-name key first, then bare name, else 0."""
    qualified = getattr(table, "qualified_name", None)
    if qualified is not None and qualified in sizes:
        return sizes[qualified]
    return sizes.get(table.name, 0)
46
+
47
+
48
def _draw_size(draw: st.DrawFn, size: SizeSpec) -> int:
    """Resolve a size spec: a fixed ``int`` is returned as is, a strategy is drawn from."""
    return size if isinstance(size, int) else draw(size)
@@ -0,0 +1,153 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Mapping
4
+ from dataclasses import dataclass
5
+ from decimal import Decimal
6
+ from typing import Any, cast
7
+ from uuid import UUID
8
+
9
+ from hypothesis import strategies as st
10
+ from hypothesis.strategies import SearchStrategy
11
+
12
+ from sqlproof.exceptions import SqlProofGenerationError
13
+ from sqlproof.generators.columns import strategy_for_column
14
+ from sqlproof.generators.constraints import refine_for_checks
15
+ from sqlproof.schema.model import ForeignKey, Table
16
+
17
# Generated rows grouped by table key; each row maps column name -> value.
DatasetRows = dict[str, list[dict[str, Any]]]
# User column overrides keyed "<qualified table>.<column>" or "<table>.<column>";
# values may be constants, Hypothesis strategies, or callables taking a ColumnContext.
ColumnOverrides = Mapping[str, Any]
19
+
20
+
21
+ @dataclass(frozen=True, slots=True)
22
+ class ColumnContext:
23
+ table: Table
24
+ column_name: str
25
+ row_index: int
26
+ row: dict[str, Any]
27
+ table_rows: list[dict[str, Any]]
28
+ rows_by_table: DatasetRows
29
+
30
+
31
def table_rows_strategy(
    table: Table,
    *,
    count: int,
    parent_rows: DatasetRows | None = None,
    rows_by_table: DatasetRows | None = None,
    columns: ColumnOverrides | None = None,
) -> SearchStrategy[list[dict[str, Any]]]:
    """Build a strategy producing ``count`` rows for ``table``.

    Each column is resolved by the first rule that applies:

    1. An override in ``columns`` for a non-generated column is used. A
       strategy is drawn from, a callable is called with a
       :class:`ColumnContext`, and any other value is used as is. This
       includes ``None``, which forces NULL, and applies to primary-key
       columns too.
    2. A single-column primary key gets a deterministic unique value.
    3. Generated columns are omitted.
    4. Columns with a default are omitted so the database fills them in.
    5. A single-column foreign key copies the referenced value from a row
       sampled out of ``parent_rows``. With no parent rows it is NULL when
       nullable; otherwise generation fails.
    6. A single-column unique column gets a deterministic unique value.
    7. Anything else is drawn from the column's strategy refined by the
       table's CHECK constraints.

    Args:
        table: Table to generate rows for.
        count: Number of rows produced per draw.
        parent_rows: Rows available to satisfy foreign keys, keyed by
            ``schema.table`` or bare table name.
        rows_by_table: Rows generated for other tables, exposed to callable
            overrides via ``ColumnContext.rows_by_table``. The mapping is
            shared with the caller, not copied.
        columns: Column overrides (see rule 1).

    Raises:
        SqlProofGenerationError: At draw time, when a required foreign key has
            no parent row to reference, or when the parent row lacks the
            referenced column.
    """
    # ``is None`` rather than ``or {}`` so an empty dict passed by the caller
    # stays the same (live) object that ColumnContext exposes.
    if parent_rows is None:
        parent_rows = {}
    if rows_by_table is None:
        rows_by_table = {}
    if columns is None:
        columns = {}

    # CHECK-refined strategies do not depend on the row; build each one lazily
    # on first use and reuse it afterwards.
    refined_strategies: dict[str, SearchStrategy[Any]] = {}

    def refined_strategy(column: Any) -> SearchStrategy[Any]:
        if column.name not in refined_strategies:
            refined_strategies[column.name] = refine_for_checks(
                column, strategy_for_column(column), table.check_constraints
            )
        return refined_strategies[column.name]

    @st.composite
    def rows(draw: st.DrawFn) -> list[dict[str, Any]]:
        generated: list[dict[str, Any]] = []
        for index in range(count):
            row: dict[str, Any] = {}
            for column in table.columns:
                has_override, override = _find_override(columns, table, column.name)
                if has_override and not column.is_generated:
                    context = ColumnContext(
                        table=table,
                        column_name=column.name,
                        row_index=index,
                        row=row,
                        table_rows=generated,
                        rows_by_table=rows_by_table,
                    )
                    row[column.name] = _draw_override(draw, override, context)
                    continue
                if column.name in table.primary_key and len(table.primary_key) == 1:
                    row[column.name] = _unique_value(column.name, column.type.name, index)
                    continue
                if column.is_generated:
                    continue
                if column.default is not None:
                    continue
                fk = _foreign_key_for_column(table, column.name)
                if fk is not None:
                    row[column.name] = _draw_foreign_key_value(
                        draw, table, column, fk, parent_rows
                    )
                    continue
                if _is_single_column_unique(table, column.name):
                    row[column.name] = _unique_value(column.name, column.type.name, index)
                    continue
                row[column.name] = draw(refined_strategy(column))
            generated.append(row)
        return generated

    return rows()


def _find_override(
    overrides: ColumnOverrides,
    table: Table,
    column_name: str,
) -> tuple[bool, Any]:
    """Return ``(True, value)`` when an override exists for the column, else ``(False, None)``.

    Unlike a plain lookup returning ``None``, this distinguishes an explicit
    ``None`` override (force NULL) from no override. The schema-qualified key
    wins over the bare ``<table>.<column>`` key.
    """
    for key in (f"{table.qualified_name}.{column_name}", f"{table.name}.{column_name}"):
        if key in overrides:
            return True, overrides[key]
    return False, None


def _draw_foreign_key_value(
    draw: st.DrawFn,
    table: Table,
    column: Any,
    fk: ForeignKey,
    parent_rows: DatasetRows,
) -> Any:
    """Pick the value of a single-column foreign key from a sampled parent row."""
    parent_key = _parent_rows_key(fk, parent_rows)
    parents = parent_rows[parent_key] if parent_key is not None else []
    referenced_column = fk.referenced_columns[0]
    if parents:
        parent = draw(st.sampled_from(parents))
        if referenced_column not in parent:
            msg = (
                f"Cannot generate {table.name}.{column.name}: "
                f"parent row from {fk.referenced_table} has no value for "
                f"{referenced_column} (was it left to a database default?)."
            )
            raise SqlProofGenerationError(msg)
        return parent[referenced_column]
    if column.nullable:
        return None
    msg = (
        f"Cannot generate {table.name}.{column.name}: "
        f"required foreign key has no available parent rows for "
        f"{fk.referenced_table}.{fk.referenced_columns[0]}."
    )
    raise SqlProofGenerationError(msg)
97
+
98
+
99
def _column_override(
    overrides: ColumnOverrides,
    table: Table,
    column_name: str,
) -> Any | None:
    """Look up the user override for ``column_name`` of ``table``.

    The schema-qualified key ``<qualified_name>.<column>`` takes precedence
    over ``<table>.<column>``. Returns ``None`` when neither key is present,
    which cannot be told apart from an explicit ``None`` override.
    """
    candidate_keys = (
        f"{table.qualified_name}.{column_name}",
        f"{table.name}.{column_name}",
    )
    return next((overrides[key] for key in candidate_keys if key in overrides), None)
108
+
109
+
110
def _draw_override(draw: st.DrawFn, override: Any, context: ColumnContext) -> Any:
    """Turn an override spec into a concrete column value.

    Checked in order: a Hypothesis strategy is drawn from; any other callable
    is invoked with ``context``; every remaining value is returned unchanged.
    """
    if isinstance(override, SearchStrategy):
        return draw(cast(SearchStrategy[Any], override))
    return override(context) if callable(override) else override
116
+
117
+
118
def _foreign_key_for_column(table: Table, column_name: str) -> ForeignKey | None:
    """Return the first foreign key whose column list is exactly ``(column_name,)``.

    Composite foreign keys that merely include the column are not matched.
    """
    target = (column_name,)
    return next((fk for fk in table.foreign_keys if fk.columns == target), None)
123
+
124
+
125
def _parent_rows_key(
    foreign_key: ForeignKey,
    parent_rows: dict[str, list[dict[str, Any]]],
) -> str | None:
    """Find the ``parent_rows`` key holding the table ``foreign_key`` references.

    When the foreign key names a schema, ``schema.table`` is tried first;
    the bare table name is the fallback. Returns ``None`` if neither exists.
    """
    candidates: list[str] = []
    if foreign_key.referenced_schema is not None:
        candidates.append(f"{foreign_key.referenced_schema}.{foreign_key.referenced_table}")
    candidates.append(foreign_key.referenced_table)
    return next((key for key in candidates if key in parent_rows), None)
136
+
137
+
138
def _is_single_column_unique(table: Table, column_name: str) -> bool:
    """Report whether ``column_name`` alone is covered by a UNIQUE constraint."""
    target = (column_name,)
    for constrained in table.unique_constraints:
        if constrained == target:
            return True
    return False
140
+
141
+
142
def _unique_value(column_name: str, type_name: str, index: int) -> Any:
    """Return a deterministic value that is distinct for every ``index``.

    The value is based on ``index + 1`` and typed for the column:

    * integer types, including every serial alias: the ``int`` itself
    * floating-point types: the same number as a ``float``
    * ``numeric``/``decimal``: a ``Decimal``
    * ``uuid``: the UUID with that integer value, as a string
    * anything else: the string ``"<column_name>_<index + 1>"``

    ``type_name`` is matched case-insensitively.
    """
    integer_types = {
        "smallint", "int2", "smallserial", "serial2",
        "integer", "int", "int4", "serial", "serial4",
        "bigint", "int8", "bigserial", "serial8",
    }
    float_types = {"real", "float4", "double precision", "float8", "float"}
    normalized = type_name.lower()
    value = index + 1
    if normalized in integer_types:
        return value
    if normalized in float_types:
        return float(value)
    if normalized in {"numeric", "decimal"}:
        return Decimal(value)
    if normalized == "uuid":
        return str(UUID(int=value))
    return f"{column_name}_{value}"
@@ -0,0 +1,15 @@
1
+ from __future__ import annotations
2
+
3
+ import warnings
4
+ from typing import TypeVar
5
+
6
+ from hypothesis.errors import NonInteractiveExampleWarning
7
+ from hypothesis.strategies import SearchStrategy
8
+
9
# Type of the value produced by the strategy given to ``draw_example``.
T = TypeVar("T")
10
+
11
+
12
def draw_example(strategy: SearchStrategy[T]) -> T:
    """Return one example value from ``strategy`` outside a Hypothesis test.

    ``NonInteractiveExampleWarning`` is ignored only while ``.example()``
    runs. ``catch_warnings`` restores the previous warning filters on exit.
    """
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=NonInteractiveExampleWarning)
        sample = strategy.example()
    return sample
@@ -0,0 +1,59 @@
1
+ from __future__ import annotations
2
+
3
+ from hypothesis import strategies as st
4
+ from hypothesis.strategies import SearchStrategy
5
+
6
+
7
def slugs(min_length: int = 1, max_length: int = 64) -> SearchStrategy[str]:
    """Strategy for URL slugs made of lowercase ASCII letters, digits and ``-``.

    Lengths fall within ``[min_length, max_length]``. Values that start or
    end with a hyphen are filtered out.
    """

    def has_no_edge_hyphen(candidate: str) -> bool:
        return not (candidate.startswith("-") or candidate.endswith("-"))

    return st.text(
        alphabet="abcdefghijklmnopqrstuvwxyz0123456789-",
        min_size=min_length,
        max_size=max_length,
    ).filter(has_no_edge_hyphen)
12
+
13
+
14
def emails(domains: list[str] | None = None) -> SearchStrategy[str]:
    """Strategy for syntactically valid e-mail addresses.

    The local part is 1-32 characters from ``[a-z0-9._-]`` in RFC 5322
    dot-atom form: it never starts or ends with ``.`` and never contains
    ``..``. The domain is sampled from ``domains``, or from ``example.com``
    and ``test.dev`` when ``domains`` is ``None`` or empty.
    """
    domain_strategy = (
        st.sampled_from(domains) if domains else st.sampled_from(["example.com", "test.dev"])
    )

    def is_dot_atom(value: str) -> bool:
        # Dots may only separate non-empty atoms.
        return not value.startswith(".") and not value.endswith(".") and ".." not in value

    local = st.text(
        alphabet="abcdefghijklmnopqrstuvwxyz0123456789._-",
        min_size=1,
        max_size=32,
    ).filter(is_dot_atom)
    return st.builds(_join_email, local, domain_strategy)
24
+
25
+
26
def urls(
    schemes: tuple[str, ...] = ("https", "http"),
    *,
    include_path: bool = True,
    include_query: bool = True,
    include_fragment: bool = False,
) -> SearchStrategy[str]:
    """Strategy for well-formed absolute URLs.

    Builds ``<scheme>://<host><path><query><fragment>`` from these parts:

    * scheme: one of ``schemes``
    * host: one of ``example.com``, ``sqlproof.dev``, ``localhost``
    * path (optional): ``/<slug>``
    * query (optional): ``?q=<value>``
    * fragment (optional): ``#<slug>``

    The query value is arbitrary text of up to 10 characters, percent-encoded
    with ``urllib.parse.quote(..., safe="")``. Spaces, ``#``, ``&``, ``?`` or
    non-ASCII characters therefore cannot break the URL's structure.
    """
    from urllib.parse import quote

    def encode_query(value: str) -> str:
        return f"?q={quote(value, safe='')}"

    scheme = st.sampled_from(schemes)
    host = st.sampled_from(["example.com", "sqlproof.dev", "localhost"])
    path = slugs(max_length=20).map(lambda value: f"/{value}") if include_path else st.just("")
    query = st.text(max_size=10).map(encode_query) if include_query else st.just("")
    fragment = (
        slugs(max_length=10).map(lambda value: f"#{value}") if include_fragment else st.just("")
    )
    return st.builds(_join_url, scheme, host, path, query, fragment)
41
+
42
+
43
def phone_numbers(country: str | None = None) -> SearchStrategy[str]:
    """Strategy for phone numbers: a country prefix followed by ten digits.

    ``None``, ``"US"`` and ``"CA"`` (any letter case, like ``postal_codes``)
    use ``+1``; every other country falls back to ``+44``. The ten-digit part
    is drawn from 2000000000-9999999999, so it never starts with 0 or 1.
    """
    normalized = country.upper() if country is not None else None
    prefix = "+1" if normalized in {None, "US", "CA"} else "+44"
    return st.integers(2_000_000_000, 9_999_999_999).map(lambda number: f"{prefix}{number}")
46
+
47
+
48
def postal_codes(country: str) -> SearchStrategy[str]:
    """Strategy for postal codes.

    For ``"US"``/``"USA"`` (any letter case): five zero-padded digits. For any
    other country: 3-10 characters of uppercase letters, digits and spaces.
    """
    if country.upper() not in {"US", "USA"}:
        return st.text(alphabet="ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 ", min_size=3, max_size=10)
    return st.integers(0, 99_999).map("{:05d}".format)
52
+
53
+
54
def _join_email(lhs: str, rhs: str) -> str:
    """Combine a local part and a domain into ``local@domain``."""
    return "{}@{}".format(lhs, rhs)
56
+
57
+
58
def _join_url(scheme: str, host: str, path: str, query: str, fragment: str) -> str:
    """Concatenate URL parts; ``path``, ``query`` and ``fragment`` carry their own delimiters."""
    authority = f"{scheme}://{host}"
    return f"{authority}{path}{query}{fragment}"
@@ -0,0 +1,24 @@
1
+ from __future__ import annotations
2
+
3
+ import pytest
4
+
5
+
6
def pytest_addoption(parser: pytest.Parser) -> None:
    """Register SqlProof's command-line options in a ``sqlproof`` option group."""
    option_specs: tuple[tuple[str, dict[str, object]], ...] = (
        ("--sqlproof-seed", {"action": "store", "type": int, "help": "Fix the SqlProof seed."}),
        (
            "--sqlproof-runs",
            {"action": "store", "type": int, "help": "Override SqlProof run count."},
        ),
        (
            "--sqlproof-show-counterexample",
            {"action": "store_true", "help": "Print full SqlProof counterexamples."},
        ),
        ("--sqlproof-coverage", {"action": "store_true", "help": "Enable PL/pgSQL coverage."}),
        (
            "--sqlproof-diversity-report",
            {"action": "store_true", "help": "Print generator diversity report."},
        ),
        ("--sqlproof-postgres-image", {"action": "store", "help": "Override Postgres image."}),
        ("--sqlproof-verbose", {"action": "store_true", "help": "Enable DEBUG logging."}),
    )
    group = parser.getgroup("sqlproof")
    for flag, option_kwargs in option_specs:
        group.addoption(flag, **option_kwargs)
@@ -0,0 +1,5 @@
1
+ from __future__ import annotations
2
+
3
+ from sqlproof.reporter.json_io import write_counterexample
4
+
5
# Public API of the reporter package: only the counterexample writer is re-exported.
__all__ = ["write_counterexample"]
@@ -0,0 +1,20 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, cast
4
+
5
+ from sqlproof.coverage.schema_shape import summarize_dataset_shape
6
+
7
+
8
def format_failure(payload: dict[str, Any]) -> str:
    """Render a counterexample payload as a short multi-line summary.

    Reads these keys from ``payload``:

    * ``property_name``
    * ``failure``: a mapping with ``kind`` and ``message``
    * ``row_context`` (optional)
    * ``dataset`` (optional): a mapping of table -> rows, summarised with
      ``summarize_dataset_shape``

    Missing values get placeholders. A ``failure`` of ``None`` is treated as
    empty, and a non-mapping ``failure`` (e.g. a bare string) is shown as the
    message instead of raising ``AttributeError``.
    """
    from collections.abc import Mapping

    failure = payload.get("failure")
    if failure is None:
        failure = {}
    elif not isinstance(failure, Mapping):
        failure = {"message": str(failure)}
    lines = [
        f"Property failed: {payload.get('property_name', 'counterexample')}",
        f"Failure: {failure.get('kind', 'unknown')}: {failure.get('message', '')}",
    ]
    if payload.get("row_context"):
        lines.append(f"Row context: {payload['row_context']}")
    dataset = payload.get("dataset")
    if isinstance(dataset, dict):
        shape = summarize_dataset_shape(cast(dict[str, list[dict[str, Any]]], dataset))
        lines.append(f"Dataset shape: {shape}")
    return "\n".join(lines)
@@ -0,0 +1,26 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from dataclasses import is_dataclass
5
+ from datetime import date, datetime, time, timedelta
6
+ from decimal import Decimal
7
+ from pathlib import Path
8
+ from typing import Any
9
+ from uuid import UUID
10
+
11
+
12
def write_counterexample(path: Path, payload: dict[str, Any]) -> None:
    """Write ``payload`` to ``path`` as UTF-8, pretty-printed, key-sorted JSON.

    Missing parent directories are created first. Values the ``json`` module
    cannot encode natively are passed to ``_json_default``.
    """
    target_dir = path.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    document = json.dumps(payload, default=_json_default, indent=2, sort_keys=True)
    path.write_text(document, encoding="utf-8")
17
+
18
+
19
+ def _json_default(value: Any) -> Any:
20
+ if isinstance(value, Decimal | datetime | date | time | UUID):
21
+ return str(value)
22
+ if isinstance(value, timedelta):
23
+ return value.total_seconds()
24
+ if is_dataclass(value):
25
+ return value.__dict__
26
+ raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")
@@ -0,0 +1,14 @@
1
+ from __future__ import annotations
2
+
3
+ from sqlproof.runners.migration import migration
4
+ from sqlproof.runners.overload import function_overloads
5
+ from sqlproof.runners.property import Check, sqlproof
6
+ from sqlproof.runners.rls import rls
7
+ from sqlproof.runners.stateful import stateful
8
+
9
# Attach the specialised runners as attributes of the ``sqlproof`` entry point,
# so callers can write ``sqlproof.stateful(...)``, ``sqlproof.migration(...)``,
# ``sqlproof.rls(...)`` and ``sqlproof.function_overloads(...)``.
sqlproof.stateful = stateful  # type: ignore[attr-defined]
sqlproof.migration = migration  # type: ignore[attr-defined]
sqlproof.rls = rls  # type: ignore[attr-defined]
sqlproof.function_overloads = function_overloads  # type: ignore[attr-defined]

__all__ = ["Check", "function_overloads", "migration", "rls", "sqlproof", "stateful"]
sqlproof/runners/db.py ADDED
@@ -0,0 +1,48 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Generator
4
+ from contextlib import contextmanager
5
+ from typing import Any, cast
6
+
7
+ import psycopg
8
+ from psycopg.rows import dict_row
9
+
10
+ from sqlproof.client import PsycopgSqlProofClient, SqlProofClient
11
+ from sqlproof.config import SqlProofConfig
12
+ from sqlproof.exceptions import SqlProofUsageError
13
+
14
+
15
class DBManager:
    """Manage one lazily opened psycopg connection shared by SqlProof clients.

    The connection is opened on demand with ``autocommit=False`` and a
    ``dict_row`` row factory. Every :meth:`acquire` wraps that same connection
    in a new ``PsycopgSqlProofClient``; :meth:`stop` closes it.
    """

    def __init__(self, config: SqlProofConfig) -> None:
        self.config = config
        # True between a successful start() and the next stop().
        self.started = False
        # The open psycopg connection, or None when not connected.
        self._connection: Any | None = None

    def start(self) -> None:
        """Open the connection unless one is already open.

        Raises:
            SqlProofUsageError: ``config.connection_string`` is ``None``.
        """
        if self.config.connection_string is None:
            msg = "DBManager requires a connection_string for real Postgres execution."
            raise SqlProofUsageError(msg)
        if self.started:
            return
        self._connection = psycopg.connect(
            conninfo=self.config.connection_string,
            autocommit=False,
            row_factory=cast(Any, dict_row),
        )
        self.started = True

    @contextmanager
    def acquire(self, *, persistent: bool = False) -> Generator[SqlProofClient]:
        """Yield a client bound to the shared connection, connecting on first use.

        ``persistent`` is accepted for API compatibility but currently ignored.
        If the ``with`` body raises, the current transaction is rolled back
        (best effort) before the exception propagates. Otherwise one failure
        would leave the shared non-autocommit connection in an aborted
        transaction for every later caller.

        Raises:
            SqlProofUsageError: No connection string is configured, or no
                connection could be established.
        """
        del persistent
        if self._connection is None:
            self.start()
        connection = self._connection
        if connection is None:
            msg = "DBManager failed to establish a database connection."
            raise SqlProofUsageError(msg)
        try:
            yield PsycopgSqlProofClient(connection)
        except BaseException:
            self._rollback_quietly(connection)
            raise

    def stop(self) -> None:
        """Close the connection, if any, and reset the manager so it can be restarted.

        State is reset before closing, so a failing ``close()`` does not leave
        the manager holding a half-closed connection.
        """
        connection, self._connection = self._connection, None
        self.started = False
        if connection is not None:
            connection.close()

    @staticmethod
    def _rollback_quietly(connection: Any) -> None:
        """Roll back ``connection``; a failure here must not mask the caller's error."""
        try:
            connection.rollback()
        except psycopg.Error:
            # The connection may already be broken. The exception being
            # propagated by acquire() is the informative one.
            pass
@@ -0,0 +1,51 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Callable, Mapping
4
+ from pathlib import Path
5
+ from typing import Any, cast
6
+
7
+ from hypothesis import HealthCheck, given, settings
8
+
9
+ from sqlproof.generators.graph import SizeSpec, dataset_strategy
10
+
11
+
12
def migration(
    proof: Any,
    *,
    before_schema: str,
    migration: str,
    sizes: Mapping[str, SizeSpec],
    **kwargs: object,
) -> Callable[[Callable[..., None]], Callable[..., None]]:
    """Decorator factory for migration property tests.

    The decorated function is called as ``function(before, after)`` with two
    clients, both loaded from the same generated dataset. ``migration`` (a
    file path or inline SQL) is applied to ``after`` only.

    Datasets come from ``proof.schema_info``, sized by ``sizes``. ``runs``,
    popped from ``kwargs`` with a default of 1, sets Hypothesis'
    ``max_examples``. ``before_schema`` and any other ``kwargs`` are accepted
    but currently unused.
    """
    del before_schema
    runs = int(cast(Any, kwargs.pop("runs", 1)))
    migration_source = migration

    def decorate(function: Callable[..., None]) -> Callable[..., None]:
        def wrapped() -> None:
            dataset_source = dataset_strategy(proof.schema_info, sizes=sizes)
            run_settings = settings(
                max_examples=runs,
                deadline=None,
                suppress_health_check=[HealthCheck.function_scoped_fixture],
            )

            @given(dataset_source)
            @run_settings
            def execute(dataset: dict[str, list[dict[str, Any]]]) -> None:
                with proof.client_for_dataset(dataset) as before:
                    with proof.client_for_dataset(dataset) as after:
                        _execute_migration(after, migration_source)
                        function(before, after)

            execute()

        return wrapped

    return decorate
44
+
45
+
46
def _execute_migration(db: Any, migration: str) -> None:
    """Apply ``migration`` to ``db``: as a file if it names one, otherwise as SQL text.

    Only regular files are run via ``db.execute_file``. Strings that cannot be
    a usable path are treated as SQL: on Python <= 3.12, inline SQL longer
    than the OS file-name limit makes ``Path.exists()``/``is_file()`` raise
    ``OSError`` (ENAMETOOLONG), and text containing NUL raises ``ValueError``.
    """
    path = Path(migration)
    try:
        is_file = path.is_file()
    except (OSError, ValueError):
        is_file = False
    if is_file:
        db.execute_file(path)
    else:
        db.execute(migration)
+ db.execute(migration)