sqlproof 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlproof/__init__.py +32 -0
- sqlproof/_version.py +1 -0
- sqlproof/cli.py +151 -0
- sqlproof/client.py +159 -0
- sqlproof/config.py +42 -0
- sqlproof/contrib/__init__.py +3 -0
- sqlproof/contrib/supabase.py +136 -0
- sqlproof/core.py +344 -0
- sqlproof/coverage/__init__.py +6 -0
- sqlproof/coverage/diversity.py +11 -0
- sqlproof/coverage/plpgsql.py +5 -0
- sqlproof/coverage/schema_shape.py +7 -0
- sqlproof/exceptions.py +47 -0
- sqlproof/generators/__init__.py +21 -0
- sqlproof/generators/columns.py +93 -0
- sqlproof/generators/constraints.py +181 -0
- sqlproof/generators/functions.py +9 -0
- sqlproof/generators/graph.py +51 -0
- sqlproof/generators/rows.py +153 -0
- sqlproof/generators/sampling.py +15 -0
- sqlproof/generators/well_known.py +59 -0
- sqlproof/pytest_plugin.py +24 -0
- sqlproof/reporter/__init__.py +5 -0
- sqlproof/reporter/console.py +20 -0
- sqlproof/reporter/json_io.py +26 -0
- sqlproof/runners/__init__.py +14 -0
- sqlproof/runners/db.py +48 -0
- sqlproof/runners/migration.py +51 -0
- sqlproof/runners/overload.py +41 -0
- sqlproof/runners/property.py +119 -0
- sqlproof/runners/rls.py +40 -0
- sqlproof/runners/stateful.py +36 -0
- sqlproof/schema/__init__.py +27 -0
- sqlproof/schema/dependency_graph.py +38 -0
- sqlproof/schema/fingerprint.py +34 -0
- sqlproof/schema/introspect.py +229 -0
- sqlproof/schema/model.py +98 -0
- sqlproof/schema/parse_sql.py +206 -0
- sqlproof/testing.py +101 -0
- sqlproof/types.py +34 -0
- sqlproof-0.1.0a1.dist-info/METADATA +248 -0
- sqlproof-0.1.0a1.dist-info/RECORD +44 -0
- sqlproof-0.1.0a1.dist-info/WHEEL +4 -0
- sqlproof-0.1.0a1.dist-info/entry_points.txt +5 -0
|
@@ -0,0 +1,181 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import re
|
|
4
|
+
from decimal import Decimal
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from hypothesis import strategies as st
|
|
8
|
+
from hypothesis.strategies import SearchStrategy
|
|
9
|
+
|
|
10
|
+
from sqlproof.schema.model import CheckConstraint, Column
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def refine_for_checks(
    column: Column,
    strategy: SearchStrategy[Any],
    checks: tuple[CheckConstraint, ...],
) -> SearchStrategy[Any]:
    """Narrow ``strategy`` by every CHECK constraint in ``checks``, in the given order.

    Each constraint's ``expression`` is passed to ``_refine_for_check`` together with the
    strategy produced by the previous step; expressions that do not concern ``column``
    leave the strategy unchanged.

    NOTE(review): refinements that build a fresh strategy (value sets, text lengths,
    lower bounds) do not wrap the incoming strategy, so earlier refinements can be
    discarded and the result may depend on the order of ``checks``.
    """
    refined = strategy
    for constraint in checks:
        refined = _refine_for_check(column, refined, constraint.expression)
    return refined
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _refine_for_check(
|
|
24
|
+
column: Column,
|
|
25
|
+
strategy: SearchStrategy[Any],
|
|
26
|
+
expression: str,
|
|
27
|
+
) -> SearchStrategy[Any]:
|
|
28
|
+
expression = _normalize_check_expression(expression)
|
|
29
|
+
in_set_match = re.fullmatch(
|
|
30
|
+
rf"{re.escape(column.name)}\s+IN\s*\((?P<values>.+)\)",
|
|
31
|
+
expression,
|
|
32
|
+
flags=re.IGNORECASE,
|
|
33
|
+
)
|
|
34
|
+
if in_set_match is not None:
|
|
35
|
+
values = tuple(
|
|
36
|
+
_parse_sql_literal(value) for value in in_set_match.group("values").split(",")
|
|
37
|
+
)
|
|
38
|
+
return _sampled_values_for_column(column, values)
|
|
39
|
+
|
|
40
|
+
any_array_match = re.fullmatch(
|
|
41
|
+
rf"\(?\s*{re.escape(column.name)}\s*=\s*ANY\s*"
|
|
42
|
+
r"\(\s*ARRAY\[(?P<values>.+)\]\s*\)\s*\)?",
|
|
43
|
+
expression,
|
|
44
|
+
flags=re.IGNORECASE,
|
|
45
|
+
)
|
|
46
|
+
if any_array_match is not None:
|
|
47
|
+
values = tuple(
|
|
48
|
+
_parse_sql_literal(value) for value in any_array_match.group("values").split(",")
|
|
49
|
+
)
|
|
50
|
+
return _sampled_values_for_column(column, values)
|
|
51
|
+
|
|
52
|
+
length_match = re.fullmatch(
|
|
53
|
+
rf"(?:char_length|length)\s*\(\s*{re.escape(column.name)}\s*\)\s*"
|
|
54
|
+
rf"(?P<op>>=|>|<=|<|=)\s*(?P<value>\d+)",
|
|
55
|
+
expression,
|
|
56
|
+
flags=re.IGNORECASE,
|
|
57
|
+
)
|
|
58
|
+
if length_match is not None:
|
|
59
|
+
direct = _direct_length_strategy(
|
|
60
|
+
column, length_match.group("op"), int(length_match.group("value"))
|
|
61
|
+
)
|
|
62
|
+
if direct is not None:
|
|
63
|
+
return direct
|
|
64
|
+
|
|
65
|
+
range_match = re.fullmatch(
|
|
66
|
+
rf"{re.escape(column.name)}\s*(?P<op>>=|>|<=|<)\s*(?P<value>-?\d+(?:\.\d+)?)",
|
|
67
|
+
expression,
|
|
68
|
+
flags=re.IGNORECASE,
|
|
69
|
+
)
|
|
70
|
+
if range_match is None:
|
|
71
|
+
return strategy
|
|
72
|
+
op = range_match.group("op")
|
|
73
|
+
raw_value = Decimal(range_match.group("value"))
|
|
74
|
+
direct = _direct_range_strategy(column, op, raw_value)
|
|
75
|
+
if direct is not None:
|
|
76
|
+
return direct
|
|
77
|
+
|
|
78
|
+
def predicate(value: Any) -> bool:
|
|
79
|
+
if value is None:
|
|
80
|
+
return True
|
|
81
|
+
comparable = Decimal(str(value))
|
|
82
|
+
if op == ">=":
|
|
83
|
+
return comparable >= raw_value
|
|
84
|
+
if op == ">":
|
|
85
|
+
return comparable > raw_value
|
|
86
|
+
if op == "<=":
|
|
87
|
+
return comparable <= raw_value
|
|
88
|
+
return comparable < raw_value
|
|
89
|
+
|
|
90
|
+
return strategy.filter(predicate)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def _normalize_check_expression(expression: str) -> str:
|
|
94
|
+
value = expression.strip()
|
|
95
|
+
match = re.fullmatch(r"CHECK\s*\((?P<inner>.*)\)", value, flags=re.IGNORECASE | re.DOTALL)
|
|
96
|
+
if match is not None:
|
|
97
|
+
return match.group("inner").strip()
|
|
98
|
+
return value
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def _parse_sql_literal(value: str) -> Any:
|
|
102
|
+
stripped = value.strip()
|
|
103
|
+
cast_match = re.fullmatch(
|
|
104
|
+
r"(?P<literal>'(?:''|[^'])*'|-?\d+(?:\.\d+)?)(?:\s*::[\w. ]+)?",
|
|
105
|
+
stripped,
|
|
106
|
+
)
|
|
107
|
+
if cast_match is not None:
|
|
108
|
+
stripped = cast_match.group("literal")
|
|
109
|
+
if stripped.startswith("'") and stripped.endswith("'"):
|
|
110
|
+
return stripped[1:-1].replace("''", "'")
|
|
111
|
+
try:
|
|
112
|
+
return int(stripped)
|
|
113
|
+
except ValueError:
|
|
114
|
+
try:
|
|
115
|
+
return Decimal(stripped)
|
|
116
|
+
except Exception:
|
|
117
|
+
return stripped
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def _sampled_values_for_column(column: Column, values: tuple[Any, ...]) -> SearchStrategy[Any]:
    """Sample one of ``values``; nullable columns may also produce ``None``."""
    choices = st.sampled_from(values)
    return (st.none() | choices) if column.nullable else choices
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def _direct_length_strategy(column: Column, op: str, value: int) -> SearchStrategy[str] | None:
    """Build a text strategy whose lengths satisfy ``length(column) <op> value``.

    Args:
        column: Column being generated; ``column.type.modifiers[0]``, when present, is
            taken as the declared maximum length.
        op: One of ``=``, ``<=``, ``<``, ``>=``, ``>``.
        value: The length bound from the CHECK expression.

    Returns:
        ``None`` for non-character column types, so the caller can fall back to another
        refinement. ``st.nothing()`` when the constraint cannot be met within the
        declared length. Otherwise ``st.text`` with the resulting size bounds; columns
        without a declared length are capped at 255 characters unless the constraint
        requires longer strings.
    """
    if column.type.name.lower() not in {
        "text",
        "citext",
        "varchar",
        "character varying",
        "char",
        "character",
    }:
        return None
    default_cap = 255
    declared = column.type.modifiers[0] if column.type.modifiers else None
    min_size = 0
    max_size = declared if declared is not None else default_cap
    if op == "=":
        min_size = value
        # An exact length longer than the declared limit is unsatisfiable.
        max_size = value if declared is None else min(declared, value)
    elif op in {"<=", "<"}:
        # ``length < 0`` is unsatisfiable, so the bound may go negative here.
        upper = value if op == "<=" else value - 1
        max_size = min(max_size, upper)
    elif op in {">=", ">"}:
        min_size = value if op == ">=" else value + 1
        if declared is None and min_size > max_size:
            # Unbounded text: open a window above the required minimum.
            max_size = min_size + default_cap
    if min_size > max_size:
        return st.nothing()
    return st.text(min_size=min_size, max_size=max_size)
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def _direct_range_strategy(
    column: Column,
    op: str,
    raw_value: Decimal,
) -> SearchStrategy[Any] | None:
    """Build a strategy for the lower bound ``column >= raw_value`` / ``column > raw_value``.

    Supported column types:

    * ``smallint``/``int2``, ``integer``/``int``/``int4``, ``bigint``/``int8``: integers
      from the smallest value satisfying the bound up to the type's maximum.
    * ``numeric``/``decimal``: decimals respecting ``numeric(precision[, scale])`` when
      modifiers are present (``numeric(p)`` has scale 0). Without modifiers, 2 decimal
      places and a default ceiling of 1,000,000 are used, widened if the bound exceeds it.

    Returns:
        ``None`` for upper-bound operators and unsupported types, so the caller can fall
        back to filtering. ``st.nothing()`` when the bound cannot be met within the
        type's range.
    """
    import math  # only needed for rounding fractional integer bounds

    if op not in {">=", ">"}:
        return None
    name = column.type.name.lower()
    integer_ranges = {
        "smallint": (-(2**15), 2**15 - 1),
        "int2": (-(2**15), 2**15 - 1),
        "integer": (-(2**31), 2**31 - 1),
        "int": (-(2**31), 2**31 - 1),
        "int4": (-(2**31), 2**31 - 1),
        "bigint": (-(2**63), 2**63 - 1),
        "int8": (-(2**63), 2**63 - 1),
    }
    if name in integer_ranges:
        lowest, highest = integer_ranges[name]
        # Smallest integer satisfying the bound; truncation would break e.g. ``x >= 1.5``.
        first = math.ceil(raw_value) if op == ">=" else math.floor(raw_value) + 1
        first = max(first, lowest)
        if first > highest:
            return st.nothing()
        return st.integers(first, highest)
    if name in {"numeric", "decimal"}:
        modifiers = column.type.modifiers
        lowest_value: Decimal | None
        if modifiers:
            places = modifiers[1] if len(modifiers) > 1 else 0
            # numeric(p, s) stores values strictly below 10 ** (p - s) in magnitude.
            highest_value = Decimal(10) ** (modifiers[0] - places) - Decimal(10) ** -places
            lowest_value = -highest_value
        else:
            places = 2
            highest_value = Decimal("1000000")
            lowest_value = None
        step = Decimal(10) ** -places
        minimum = raw_value if op == ">=" else raw_value + step
        if lowest_value is not None:
            minimum = max(minimum, lowest_value)
        if minimum > highest_value:
            if modifiers:
                return st.nothing()
            # Unconstrained numeric: widen the default ceiling instead of failing.
            highest_value = minimum + Decimal("1000000")
        return st.decimals(
            min_value=minimum,
            max_value=highest_value,
            places=places,
            allow_nan=False,
            allow_infinity=False,
        )
    return None
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
def unique_rows(rows: list[dict[str, Any]], columns: tuple[str, ...]) -> bool:
    """Report whether no two ``rows`` share the same values across ``columns``.

    Rows are compared on the tuple of their ``columns`` values, and the scan stops at the
    first repeat. An empty ``rows`` list is trivially unique.
    """
    keys_so_far: set[tuple[Any, ...]] = set()
    for record in rows:
        key = tuple(map(record.__getitem__, columns))
        if key in keys_so_far:
            return False
        keys_so_far.add(key)
    return True
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Mapping
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from hypothesis import strategies as st
|
|
7
|
+
from hypothesis.strategies import SearchStrategy
|
|
8
|
+
|
|
9
|
+
from sqlproof.schema.dependency_graph import insertion_order
|
|
10
|
+
from sqlproof.schema.model import SchemaInfo
|
|
11
|
+
|
|
12
|
+
from .rows import ColumnOverrides, table_rows_strategy
|
|
13
|
+
|
|
14
|
+
Dataset = dict[str, list[dict[str, Any]]]
SizeSpec = int | SearchStrategy[int]


def dataset_strategy(
    schema: SchemaInfo,
    *,
    sizes: Mapping[str, SizeSpec],
    external_parent_rows: Mapping[str, list[dict[str, Any]]] | None = None,
    columns: ColumnOverrides | None = None,
) -> SearchStrategy[Dataset]:
    """Build a strategy that generates rows for every table in ``schema``.

    Tables are filled in ``insertion_order(schema.tables)``. Rows generated for earlier
    tables are offered as foreign-key parents to later ones, together with
    ``external_parent_rows``; generated rows win on key clashes.

    Args:
        schema: Schema whose tables are populated.
        sizes: Row count per table, keyed by schema-qualified name or bare table name.
            The qualified key wins when both are present. Each value is a fixed ``int``
            or a strategy drawing one. Tables without an entry get zero rows.
        external_parent_rows: Pre-existing parent rows keyed by table key.
        columns: Per-column overrides forwarded to ``table_rows_strategy``.

    Returns:
        A strategy producing a dict that maps each table's bare name to its rows.
    """
    ordered_tables = insertion_order(schema.tables)
    external_parents: Mapping[str, list[dict[str, Any]]] = external_parent_rows or {}

    def size_for(table: Any) -> SizeSpec:
        # Same precedence as column overrides: qualified name first, then bare name.
        for key in (table.qualified_name, table.name):
            if key in sizes:
                return sizes[key]
        return 0

    @st.composite
    def dataset(draw: st.DrawFn) -> Dataset:
        rows_by_table: Dataset = {}
        for table in ordered_tables:
            count = _draw_size(draw, size_for(table))
            available_parent_rows = {**external_parents, **rows_by_table}
            rows_by_table[table.name] = draw(
                table_rows_strategy(
                    table,
                    count=count,
                    parent_rows=available_parent_rows,
                    rows_by_table=rows_by_table,
                    columns=columns,
                )
            )
        return rows_by_table

    return dataset()
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _draw_size(draw: st.DrawFn, size: SizeSpec) -> int:
|
|
49
|
+
if isinstance(size, int):
|
|
50
|
+
return size
|
|
51
|
+
return draw(size)
|
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Mapping
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from decimal import Decimal
|
|
6
|
+
from typing import Any, cast
|
|
7
|
+
from uuid import UUID
|
|
8
|
+
|
|
9
|
+
from hypothesis import strategies as st
|
|
10
|
+
from hypothesis.strategies import SearchStrategy
|
|
11
|
+
|
|
12
|
+
from sqlproof.exceptions import SqlProofGenerationError
|
|
13
|
+
from sqlproof.generators.columns import strategy_for_column
|
|
14
|
+
from sqlproof.generators.constraints import refine_for_checks
|
|
15
|
+
from sqlproof.schema.model import ForeignKey, Table
|
|
16
|
+
|
|
17
|
+
# Generated rows grouped by table key; the shape of ``parent_rows`` and ``rows_by_table``.
DatasetRows = dict[str, list[dict[str, Any]]]
# Column overrides keyed by "<qualified table name>.<column>" or "<table name>.<column>".
# A value may be a Hypothesis strategy (drawn from), a callable (called with a
# ColumnContext), or any other value (used as-is).
ColumnOverrides = Mapping[str, Any]


@dataclass(frozen=True, slots=True)
class ColumnContext:
    """Generation state handed to a callable column override.

    ``table_rows_strategy`` builds one immediately before evaluating the override for a
    single column of a single row.
    """

    # Table whose row is being generated.
    table: Table
    # Name of the column the override is producing a value for.
    column_name: str
    # Zero-based position of the row within this table's generated batch.
    row_index: int
    # The (mutable) row under construction; at call time it holds the columns filled so far.
    row: dict[str, Any]
    # Rows already generated for this table in the current draw.
    table_rows: list[dict[str, Any]]
    # The ``rows_by_table`` mapping that was passed to ``table_rows_strategy``.
    rows_by_table: DatasetRows
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def table_rows_strategy(
    table: Table,
    *,
    count: int,
    parent_rows: DatasetRows | None = None,
    rows_by_table: DatasetRows | None = None,
    columns: ColumnOverrides | None = None,
) -> SearchStrategy[list[dict[str, Any]]]:
    """Build a strategy that generates ``count`` rows for ``table``.

    Columns are filled in ``table.columns`` order. Each column uses the first rule that
    applies:

    1. Single-column primary key: deterministic unique value (``_unique_value``).
    2. Generated column: omitted.
    3. Entry in ``columns`` overrides: a strategy is drawn, a callable is called with a
       :class:`ColumnContext`, and any other value is used as-is.
    4. Column with a default: omitted, so the database fills it.
    5. Single-column foreign key: value copied from a sampled parent row in
       ``parent_rows``, or ``None`` when none is usable and the column is nullable.
    6. Single-column unique constraint: deterministic unique value.
    7. Otherwise: ``strategy_for_column`` refined by the table's CHECK constraints.

    Raises:
        SqlProofGenerationError: When drawing, if a NOT NULL foreign key has no usable
            parent row.
    """
    parent_rows = {} if parent_rows is None else parent_rows
    # Keep the caller's mapping even when empty, so ColumnContext sees its later updates.
    rows_by_table = {} if rows_by_table is None else rows_by_table
    overrides: ColumnOverrides = {} if columns is None else columns
    single_column_pk = len(table.primary_key) == 1
    # Value strategies depend only on the column, so each is built once and reused.
    value_strategies: dict[str, SearchStrategy[Any]] = {}

    def value_strategy(column: Any) -> SearchStrategy[Any]:
        strategy = value_strategies.get(column.name)
        if strategy is None:
            strategy = refine_for_checks(
                column, strategy_for_column(column), table.check_constraints
            )
            value_strategies[column.name] = strategy
        return strategy

    @st.composite
    def rows(draw: st.DrawFn) -> list[dict[str, Any]]:
        generated: list[dict[str, Any]] = []
        for index in range(count):
            row: dict[str, Any] = {}
            for column in table.columns:
                name = column.name
                if single_column_pk and name in table.primary_key:
                    row[name] = _unique_value(name, column.type.name, index)
                    continue
                if column.is_generated:
                    continue
                override = _column_override(overrides, table, name)
                if override is not None:
                    context = ColumnContext(
                        table=table,
                        column_name=name,
                        row_index=index,
                        row=row,
                        table_rows=generated,
                        rows_by_table=rows_by_table,
                    )
                    row[name] = _draw_override(draw, override, context)
                    continue
                if column.default is not None:
                    continue
                fk = _foreign_key_for_column(table, name)
                if fk is not None:
                    row[name] = _draw_foreign_key_value(
                        draw, table, name, column.nullable, fk, parent_rows
                    )
                    continue
                if _is_single_column_unique(table, name):
                    row[name] = _unique_value(name, column.type.name, index)
                    continue
                row[name] = draw(value_strategy(column))
            generated.append(row)
        return generated

    return rows()


def _draw_foreign_key_value(
    draw: st.DrawFn,
    table: Table,
    column_name: str,
    nullable: bool,
    foreign_key: ForeignKey,
    parent_rows: DatasetRows,
) -> Any:
    """Draw a value for a single-column foreign key from the available parent rows.

    Only parent rows that contain the referenced column are candidates. Generated rows
    omit columns left to a database default, so a parent row may lack the value.
    Without a candidate, a nullable column gets ``None``; otherwise generation fails.
    """
    referenced_column = foreign_key.referenced_columns[0]
    parent_key = _parent_rows_key(foreign_key, parent_rows)
    parents = parent_rows[parent_key] if parent_key is not None else []
    candidates = [parent for parent in parents if referenced_column in parent]
    if candidates:
        return draw(st.sampled_from(candidates))[referenced_column]
    if nullable:
        return None
    if parents:
        msg = (
            f"Cannot generate {table.name}.{column_name}: "
            f"parent rows for {foreign_key.referenced_table} do not contain "
            f"{referenced_column} values to reference."
        )
    else:
        msg = (
            f"Cannot generate {table.name}.{column_name}: "
            f"required foreign key has no available parent rows for "
            f"{foreign_key.referenced_table}.{foreign_key.referenced_columns[0]}."
        )
    raise SqlProofGenerationError(msg)
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def _column_override(
|
|
100
|
+
overrides: ColumnOverrides,
|
|
101
|
+
table: Table,
|
|
102
|
+
column_name: str,
|
|
103
|
+
) -> Any | None:
|
|
104
|
+
for key in (f"{table.qualified_name}.{column_name}", f"{table.name}.{column_name}"):
|
|
105
|
+
if key in overrides:
|
|
106
|
+
return overrides[key]
|
|
107
|
+
return None
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def _draw_override(draw: st.DrawFn, override: Any, context: ColumnContext) -> Any:
|
|
111
|
+
if isinstance(override, SearchStrategy):
|
|
112
|
+
return draw(cast(SearchStrategy[Any], override))
|
|
113
|
+
if callable(override):
|
|
114
|
+
return override(context)
|
|
115
|
+
return override
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def _foreign_key_for_column(table: Table, column_name: str) -> ForeignKey | None:
|
|
119
|
+
for foreign_key in table.foreign_keys:
|
|
120
|
+
if foreign_key.columns == (column_name,):
|
|
121
|
+
return foreign_key
|
|
122
|
+
return None
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def _parent_rows_key(
|
|
126
|
+
foreign_key: ForeignKey,
|
|
127
|
+
parent_rows: dict[str, list[dict[str, Any]]],
|
|
128
|
+
) -> str | None:
|
|
129
|
+
if foreign_key.referenced_schema is not None:
|
|
130
|
+
qualified = f"{foreign_key.referenced_schema}.{foreign_key.referenced_table}"
|
|
131
|
+
if qualified in parent_rows:
|
|
132
|
+
return qualified
|
|
133
|
+
if foreign_key.referenced_table in parent_rows:
|
|
134
|
+
return foreign_key.referenced_table
|
|
135
|
+
return None
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def _is_single_column_unique(table: Table, column_name: str) -> bool:
|
|
139
|
+
return any(columns == (column_name,) for columns in table.unique_constraints)
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def _unique_value(column_name: str, type_name: str, index: int) -> Any:
|
|
143
|
+
normalized = type_name.lower()
|
|
144
|
+
value = index + 1
|
|
145
|
+
if normalized in {"smallint", "int2", "integer", "int", "int4", "serial"}:
|
|
146
|
+
return value
|
|
147
|
+
if normalized in {"bigint", "int8", "bigserial"}:
|
|
148
|
+
return value
|
|
149
|
+
if normalized in {"numeric", "decimal"}:
|
|
150
|
+
return Decimal(value)
|
|
151
|
+
if normalized == "uuid":
|
|
152
|
+
return str(UUID(int=value))
|
|
153
|
+
return f"{column_name}_{value}"
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import warnings
|
|
4
|
+
from typing import TypeVar
|
|
5
|
+
|
|
6
|
+
from hypothesis.errors import NonInteractiveExampleWarning
|
|
7
|
+
from hypothesis.strategies import SearchStrategy
|
|
8
|
+
|
|
9
|
+
T = TypeVar("T")


def draw_example(strategy: SearchStrategy[T]) -> T:
    """Draw a single example value from ``strategy`` outside of a ``@given`` test.

    Hypothesis's ``NonInteractiveExampleWarning`` is ignored only while ``.example()``
    runs. The previous warning filters are restored afterwards.
    """
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=NonInteractiveExampleWarning)
        sample = strategy.example()
    return sample
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from hypothesis import strategies as st
|
|
4
|
+
from hypothesis.strategies import SearchStrategy
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def slugs(min_length: int = 1, max_length: int = 64) -> SearchStrategy[str]:
    """Generate lowercase slugs of ``[a-z0-9-]`` that neither start nor end with ``-``."""
    characters = "abcdefghijklmnopqrstuvwxyz0123456789-"

    def has_no_edge_hyphen(candidate: str) -> bool:
        return not (candidate.startswith("-") or candidate.endswith("-"))

    raw = st.text(alphabet=characters, min_size=min_length, max_size=max_length)
    return raw.filter(has_no_edge_hyphen)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def emails(domains: list[str] | None = None) -> SearchStrategy[str]:
|
|
15
|
+
domain_strategy = (
|
|
16
|
+
st.sampled_from(domains) if domains else st.sampled_from(["example.com", "test.dev"])
|
|
17
|
+
)
|
|
18
|
+
local = st.text(
|
|
19
|
+
alphabet="abcdefghijklmnopqrstuvwxyz0123456789._-",
|
|
20
|
+
min_size=1,
|
|
21
|
+
max_size=32,
|
|
22
|
+
).filter(lambda value: value.strip("."))
|
|
23
|
+
return st.builds(_join_email, local, domain_strategy)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def urls(
|
|
27
|
+
schemes: tuple[str, ...] = ("https", "http"),
|
|
28
|
+
*,
|
|
29
|
+
include_path: bool = True,
|
|
30
|
+
include_query: bool = True,
|
|
31
|
+
include_fragment: bool = False,
|
|
32
|
+
) -> SearchStrategy[str]:
|
|
33
|
+
scheme = st.sampled_from(schemes)
|
|
34
|
+
host = st.sampled_from(["example.com", "sqlproof.dev", "localhost"])
|
|
35
|
+
path = slugs(max_length=20).map(lambda value: f"/{value}") if include_path else st.just("")
|
|
36
|
+
query = st.text(max_size=10).map(lambda value: f"?q={value}") if include_query else st.just("")
|
|
37
|
+
fragment = (
|
|
38
|
+
slugs(max_length=10).map(lambda value: f"#{value}") if include_fragment else st.just("")
|
|
39
|
+
)
|
|
40
|
+
return st.builds(_join_url, scheme, host, path, query, fragment)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def phone_numbers(country: str | None = None) -> SearchStrategy[str]:
    """Generate phone numbers as a country prefix followed by ten digits (first digit 2-9).

    Args:
        country: Country code, matched case-insensitively (consistent with
            ``postal_codes``). ``None``, ``US`` or ``CA`` use ``+1``; any other value
            uses ``+44``.
    """
    region = None if country is None else country.upper()
    prefix = "+1" if region in {None, "US", "CA"} else "+44"
    return st.integers(2_000_000_000, 9_999_999_999).map(lambda number: f"{prefix}{number}")
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def postal_codes(country: str) -> SearchStrategy[str]:
    """Generate postal codes for ``country`` (matched case-insensitively).

    ``US``/``USA`` yield five-digit zero-padded ZIP codes. Any other country uses a
    generic fallback: 3-10 characters of ``[A-Z0-9 ]`` with no leading or trailing
    spaces, which also rules out all-space codes.
    """
    if country.upper() in {"US", "USA"}:
        return st.integers(0, 99_999).map(lambda number: f"{number:05d}")
    return st.text(
        alphabet="ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 ", min_size=3, max_size=10
    ).filter(lambda value: value == value.strip())
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def _join_email(lhs: str, rhs: str) -> str:
|
|
55
|
+
return f"{lhs}@{rhs}"
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _join_url(scheme: str, host: str, path: str, query: str, fragment: str) -> str:
|
|
59
|
+
return f"{scheme}://{host}{path}{query}{fragment}"
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def pytest_addoption(parser: pytest.Parser) -> None:
    """Register SqlProof's command-line options in a dedicated ``sqlproof`` option group."""
    group = parser.getgroup("sqlproof")
    option_specs: tuple[tuple[str, dict[str, object]], ...] = (
        ("--sqlproof-seed", {"action": "store", "type": int, "help": "Fix the SqlProof seed."}),
        (
            "--sqlproof-runs",
            {"action": "store", "type": int, "help": "Override SqlProof run count."},
        ),
        (
            "--sqlproof-show-counterexample",
            {"action": "store_true", "help": "Print full SqlProof counterexamples."},
        ),
        ("--sqlproof-coverage", {"action": "store_true", "help": "Enable PL/pgSQL coverage."}),
        (
            "--sqlproof-diversity-report",
            {"action": "store_true", "help": "Print generator diversity report."},
        ),
        ("--sqlproof-postgres-image", {"action": "store", "help": "Override Postgres image."}),
        ("--sqlproof-verbose", {"action": "store_true", "help": "Enable DEBUG logging."}),
    )
    for flag, spec in option_specs:
        group.addoption(flag, **spec)
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any, cast
|
|
4
|
+
|
|
5
|
+
from sqlproof.coverage.schema_shape import summarize_dataset_shape
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def format_failure(payload: dict[str, Any]) -> str:
    """Render a failure payload as a short, multi-line, human-readable summary.

    Reads ``property_name`` (default ``"counterexample"``), ``failure["kind"]`` (default
    ``"unknown"``) and ``failure["message"]`` (default ``""``). A truthy ``row_context``
    is appended. When ``dataset`` is a dict, its shape from ``summarize_dataset_shape``
    is appended too. A missing or non-dict ``failure`` (e.g. ``None``) is treated as
    empty instead of raising.
    """
    failure = payload.get("failure")
    if not isinstance(failure, dict):
        failure = {}
    lines = [
        f"Property failed: {payload.get('property_name', 'counterexample')}",
        f"Failure: {failure.get('kind', 'unknown')}: {failure.get('message', '')}",
    ]
    if payload.get("row_context"):
        lines.append(f"Row context: {payload['row_context']}")
    dataset = payload.get("dataset")
    if isinstance(dataset, dict):
        shape = summarize_dataset_shape(cast(dict[str, list[dict[str, Any]]], dataset))
        lines.append(f"Dataset shape: {shape}")
    return "\n".join(lines)
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from dataclasses import is_dataclass
|
|
5
|
+
from datetime import date, datetime, time, timedelta
|
|
6
|
+
from decimal import Decimal
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Any
|
|
9
|
+
from uuid import UUID
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def write_counterexample(path: Path, payload: dict[str, Any]) -> None:
    """Write ``payload`` to ``path`` as indented, key-sorted UTF-8 JSON.

    Missing parent directories are created first. Values the ``json`` encoder cannot
    handle natively go through ``_json_default``.
    """
    target_dir = path.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(payload, default=_json_default, indent=2, sort_keys=True)
    path.write_text(serialized, encoding="utf-8")
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _json_default(value: Any) -> Any:
|
|
20
|
+
if isinstance(value, Decimal | datetime | date | time | UUID):
|
|
21
|
+
return str(value)
|
|
22
|
+
if isinstance(value, timedelta):
|
|
23
|
+
return value.total_seconds()
|
|
24
|
+
if is_dataclass(value):
|
|
25
|
+
return value.__dict__
|
|
26
|
+
raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from sqlproof.runners.migration import migration
|
|
4
|
+
from sqlproof.runners.overload import function_overloads
|
|
5
|
+
from sqlproof.runners.property import Check, sqlproof
|
|
6
|
+
from sqlproof.runners.rls import rls
|
|
7
|
+
from sqlproof.runners.stateful import stateful
|
|
8
|
+
|
|
9
|
+
# Attach the specialised runners to the main ``sqlproof`` decorator so they are
# reachable as ``sqlproof.<name>`` as well as by direct import.
for _attribute, _runner in (
    ("stateful", stateful),
    ("migration", migration),
    ("rls", rls),
    ("function_overloads", function_overloads),
):
    setattr(sqlproof, _attribute, _runner)
del _attribute, _runner

__all__ = ["Check", "function_overloads", "migration", "rls", "sqlproof", "stateful"]
|
sqlproof/runners/db.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Generator
|
|
4
|
+
from contextlib import contextmanager
|
|
5
|
+
from typing import Any, cast
|
|
6
|
+
|
|
7
|
+
import psycopg
|
|
8
|
+
from psycopg.rows import dict_row
|
|
9
|
+
|
|
10
|
+
from sqlproof.client import PsycopgSqlProofClient, SqlProofClient
|
|
11
|
+
from sqlproof.config import SqlProofConfig
|
|
12
|
+
from sqlproof.exceptions import SqlProofUsageError
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class DBManager:
    """Own a single psycopg connection and hand out SqlProof clients bound to it.

    The connection is opened lazily by :meth:`acquire`, or explicitly by :meth:`start`,
    with ``autocommit=False`` and a ``dict_row`` row factory. It is shared by every
    :meth:`acquire` until :meth:`stop` closes it.
    """

    def __init__(self, config: SqlProofConfig) -> None:
        self.config = config
        # True between a successful start() and stop().
        self.started = False
        self._connection: Any | None = None

    def start(self) -> None:
        """Open the database connection unless an open one already exists.

        A previously opened connection that reports ``closed`` is replaced.

        Raises:
            SqlProofUsageError: If ``config.connection_string`` is ``None``.
        """
        if self.config.connection_string is None:
            msg = "DBManager requires a connection_string for real Postgres execution."
            raise SqlProofUsageError(msg)
        if self.started and self._connection is not None and not self._connection.closed:
            return
        self._connection = psycopg.connect(
            conninfo=self.config.connection_string,
            autocommit=False,
            row_factory=cast(Any, dict_row),
        )
        self.started = True

    @contextmanager
    def acquire(self, *, persistent: bool = False) -> Generator[SqlProofClient]:
        """Yield a client bound to the shared connection, opening it first if needed.

        Args:
            persistent: Currently ignored; every client uses the same connection.

        If the block leaves the connection in an aborted transaction (a statement failed
        and was not rolled back), that transaction is rolled back on exit. Later
        ``acquire`` calls therefore get a usable connection. Any other transaction state
        is left untouched.

        Raises:
            SqlProofUsageError: If no connection string is configured, or the connection
                could not be established.
        """
        del persistent
        if self._connection is None or self._connection.closed:
            self.start()
        connection = self._connection
        if connection is None:
            msg = "DBManager failed to establish a database connection."
            raise SqlProofUsageError(msg)
        try:
            yield PsycopgSqlProofClient(connection)
        finally:
            if (
                not connection.closed
                and connection.info.transaction_status == psycopg.pq.TransactionStatus.INERROR
            ):
                connection.rollback()

    def stop(self) -> None:
        """Close the connection, if one is open, and mark the manager as stopped."""
        if self._connection is not None:
            self._connection.close()
        self._connection = None
        self.started = False
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Callable, Mapping
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any, cast
|
|
6
|
+
|
|
7
|
+
from hypothesis import HealthCheck, given, settings
|
|
8
|
+
|
|
9
|
+
from sqlproof.generators.graph import SizeSpec, dataset_strategy
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def migration(
    proof: Any,
    *,
    before_schema: str,
    migration: str,
    sizes: Mapping[str, SizeSpec],
    **kwargs: object,
) -> Callable[[Callable[..., None]], Callable[..., None]]:
    """Build a decorator for migration property tests.

    Calling the decorated (zero-argument) function runs a Hypothesis test over datasets
    from ``dataset_strategy(proof.schema_info, sizes=sizes)``. For each dataset, two
    clients are opened with ``proof.client_for_dataset(dataset)``. The ``migration`` is
    applied to the second one only, and then ``function(before, after)`` is called.

    Args:
        proof: Object providing ``schema_info`` and ``client_for_dataset``.
        before_schema: Currently ignored. NOTE(review): the "before" client is not built
            from this schema; confirm that is intended.
        migration: Path to an existing SQL file, or inline SQL text.
        sizes: Per-table row counts forwarded to ``dataset_strategy``.
        **kwargs: Only ``runs`` (Hypothesis ``max_examples``, default 1) is read; other
            keys are ignored.
    """
    del before_schema
    example_count = int(cast(Any, kwargs.pop("runs", 1)))

    def decorate(function: Callable[..., None]) -> Callable[..., None]:
        def wrapped() -> None:
            datasets = dataset_strategy(proof.schema_info, sizes=sizes)
            run_settings = settings(
                max_examples=example_count,
                deadline=None,
                suppress_health_check=[HealthCheck.function_scoped_fixture],
            )

            def execute(dataset: dict[str, list[dict[str, Any]]]) -> None:
                with proof.client_for_dataset(dataset) as before, proof.client_for_dataset(
                    dataset
                ) as after:
                    _execute_migration(after, migration)
                    function(before, after)

            given(datasets)(run_settings(execute))()

        return wrapped

    return decorate
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _execute_migration(db: Any, migration: str) -> None:
|
|
47
|
+
path = Path(migration)
|
|
48
|
+
if path.exists():
|
|
49
|
+
db.execute_file(path)
|
|
50
|
+
else:
|
|
51
|
+
db.execute(migration)
|