sqlproof 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. sqlproof/__init__.py +32 -0
  2. sqlproof/_version.py +1 -0
  3. sqlproof/cli.py +151 -0
  4. sqlproof/client.py +159 -0
  5. sqlproof/config.py +42 -0
  6. sqlproof/contrib/__init__.py +3 -0
  7. sqlproof/contrib/supabase.py +136 -0
  8. sqlproof/core.py +344 -0
  9. sqlproof/coverage/__init__.py +6 -0
  10. sqlproof/coverage/diversity.py +11 -0
  11. sqlproof/coverage/plpgsql.py +5 -0
  12. sqlproof/coverage/schema_shape.py +7 -0
  13. sqlproof/exceptions.py +47 -0
  14. sqlproof/generators/__init__.py +21 -0
  15. sqlproof/generators/columns.py +93 -0
  16. sqlproof/generators/constraints.py +181 -0
  17. sqlproof/generators/functions.py +9 -0
  18. sqlproof/generators/graph.py +51 -0
  19. sqlproof/generators/rows.py +153 -0
  20. sqlproof/generators/sampling.py +15 -0
  21. sqlproof/generators/well_known.py +59 -0
  22. sqlproof/pytest_plugin.py +24 -0
  23. sqlproof/reporter/__init__.py +5 -0
  24. sqlproof/reporter/console.py +20 -0
  25. sqlproof/reporter/json_io.py +26 -0
  26. sqlproof/runners/__init__.py +14 -0
  27. sqlproof/runners/db.py +48 -0
  28. sqlproof/runners/migration.py +51 -0
  29. sqlproof/runners/overload.py +41 -0
  30. sqlproof/runners/property.py +119 -0
  31. sqlproof/runners/rls.py +40 -0
  32. sqlproof/runners/stateful.py +36 -0
  33. sqlproof/schema/__init__.py +27 -0
  34. sqlproof/schema/dependency_graph.py +38 -0
  35. sqlproof/schema/fingerprint.py +34 -0
  36. sqlproof/schema/introspect.py +229 -0
  37. sqlproof/schema/model.py +98 -0
  38. sqlproof/schema/parse_sql.py +206 -0
  39. sqlproof/testing.py +101 -0
  40. sqlproof/types.py +34 -0
  41. sqlproof-0.1.0a1.dist-info/METADATA +248 -0
  42. sqlproof-0.1.0a1.dist-info/RECORD +44 -0
  43. sqlproof-0.1.0a1.dist-info/WHEEL +4 -0
  44. sqlproof-0.1.0a1.dist-info/entry_points.txt +5 -0
sqlproof/core.py ADDED
@@ -0,0 +1,344 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Callable, Generator, Mapping
4
+ from contextlib import contextmanager
5
+ from pathlib import Path
6
+ from types import TracebackType
7
+ from typing import Any, Self, cast
8
+
9
+ import psycopg
10
+ from hypothesis import strategies as st
11
+ from hypothesis.strategies import SearchStrategy
12
+ from psycopg.rows import dict_row
13
+ from psycopg.types.json import Json, Jsonb
14
+
15
+ from sqlproof.client import InMemorySqlProofClient, PsycopgSqlProofClient, SqlProofClient
16
+ from sqlproof.config import ExternalSeed, ExternalTableSpec, SqlProofConfig
17
+ from sqlproof.exceptions import SqlProofPropertyFailure, SqlProofUsageError
18
+ from sqlproof.generators.graph import ColumnOverrides, Dataset, SizeSpec, dataset_strategy
19
+ from sqlproof.generators.sampling import draw_example
20
+ from sqlproof.schema.dependency_graph import insertion_order
21
+ from sqlproof.schema.fingerprint import compute
22
+ from sqlproof.schema.introspect import introspect_schema
23
+ from sqlproof.schema.model import Column, SchemaInfo, Table
24
+ from sqlproof.schema.parse_sql import parse_schema_sql
25
+
26
+
27
+ class SqlProof:
28
+ def __init__(self, config: SqlProofConfig) -> None:
29
+ from sqlproof.runners.db import DBManager
30
+
31
+ self.config = config
32
+ self.schema_info = self._load_schema(config)
33
+ self.schema_fingerprint = compute(self.schema_info)
34
+ self._db_manager = DBManager(config) if config.connection_string is not None else None
35
+ self._external_sample_cache: dict[str, list[object]] = {}
36
+
37
+ @classmethod
38
+ def from_schema_file(cls, path: str | Path, **kwargs: Any) -> Self:
39
+ return cls(SqlProofConfig(schema_file=path, **kwargs))
40
+
41
+ @classmethod
42
+ def from_connection_string(cls, dsn: str, **kwargs: Any) -> Self:
43
+ return cls(SqlProofConfig(connection_string=dsn, **kwargs))
44
+
45
+ @classmethod
46
+ def from_config(cls, config: SqlProofConfig) -> Self:
47
+ return cls(config)
48
+
49
+ def customize(self, table: str, **overrides: object) -> Self:
50
+ del table, overrides
51
+ return self
52
+
53
+ def dataset_strategy(
54
+ self,
55
+ *,
56
+ sizes: Mapping[str, SizeSpec],
57
+ columns: ColumnOverrides | None = None,
58
+ ) -> SearchStrategy[Dataset]:
59
+ if self.config.external_tables:
60
+ return self._dataset_strategy_with_external_tables(sizes=sizes, columns=columns)
61
+ return dataset_strategy(
62
+ self.schema_info,
63
+ sizes=sizes,
64
+ columns=columns,
65
+ )
66
+
67
+ def _dataset_strategy_with_external_tables(
68
+ self,
69
+ *,
70
+ sizes: Mapping[str, SizeSpec],
71
+ columns: ColumnOverrides | None,
72
+ ) -> SearchStrategy[Dataset]:
73
+ @st.composite
74
+ def dataset(draw: st.DrawFn) -> Dataset:
75
+ external_parent_rows = self._external_parent_rows(draw=draw)
76
+ return draw(
77
+ dataset_strategy(
78
+ self.schema_info,
79
+ sizes=sizes,
80
+ external_parent_rows=external_parent_rows,
81
+ columns=columns,
82
+ )
83
+ )
84
+
85
+ return dataset()
86
+
87
+ def run_state_machine(
88
+ self,
89
+ machine_class: type,
90
+ *,
91
+ settings: Any = None,
92
+ ) -> None:
93
+ """Run a `SqlProofStateMachine` subclass against this proof.
94
+
95
+ Binds `self` as the proof for the machine, then dispatches to
96
+ `hypothesis.stateful.run_state_machine_as_test`. Each example gets
97
+ an isolated dataset client; writes from one example are rolled back
98
+ before the next begins.
99
+ """
100
+ from hypothesis.stateful import run_state_machine_as_test
101
+
102
+ from sqlproof.testing import SqlProofStateMachine
103
+
104
+ if not isinstance(machine_class, type) or not issubclass(
105
+ machine_class, SqlProofStateMachine
106
+ ):
107
+ msg = "machine_class must be a subclass of SqlProofStateMachine."
108
+ raise SqlProofUsageError(msg)
109
+
110
+ bound_class = type(
111
+ machine_class.__name__,
112
+ (machine_class,),
113
+ {"_sqlproof_proof": self},
114
+ )
115
+ run_state_machine_as_test(bound_class, settings=settings)
116
+
117
+ @contextmanager
118
+ def client_for_dataset(
119
+ self, dataset: dict[str, list[dict[str, Any]]]
120
+ ) -> Generator[SqlProofClient]:
121
+ if self._db_manager is None:
122
+ yield InMemorySqlProofClient(dataset)
123
+ return
124
+ with self._db_manager.acquire() as client:
125
+ client.execute("SAVEPOINT sqlproof_run")
126
+ try:
127
+ _insert_dataset(client, self.schema_info, dataset)
128
+ yield client
129
+ finally:
130
+ client.execute("ROLLBACK TO SAVEPOINT sqlproof_run")
131
+ client.execute("RELEASE SAVEPOINT sqlproof_run")
132
+
133
+ def check(
134
+ self,
135
+ name: str,
136
+ *,
137
+ sizes: Mapping[str, SizeSpec],
138
+ property: Callable[..., None],
139
+ setup: object | None = None,
140
+ runs: int = 100,
141
+ seed: int | None = None,
142
+ timeout_ms: int = 5000,
143
+ commit: bool = False,
144
+ ) -> None:
145
+ from sqlproof.runners.property import run_property
146
+
147
+ del name, setup, seed, timeout_ms, commit
148
+ if not callable(property):
149
+ msg = "property must be callable"
150
+ raise TypeError(msg)
151
+ run_property(self, property, sizes=sizes, runs=runs, failure_dir=Path(".sqlproof/failures"))
152
+
153
+ def invariant(
154
+ self,
155
+ name: str,
156
+ *,
157
+ sizes: Mapping[str, SizeSpec],
158
+ query: str,
159
+ expect_empty: bool = True,
160
+ runs: int = 100,
161
+ seed: int | None = None,
162
+ timeout_ms: int = 5000,
163
+ ) -> None:
164
+ del seed, timeout_ms
165
+ strategy = self.dataset_strategy(sizes=sizes)
166
+ for run_index in range(runs):
167
+ client = InMemorySqlProofClient(draw_example(strategy))
168
+ rows = client.query(query)
169
+ failed = bool(rows) if expect_empty else not rows
170
+ if failed:
171
+ payload = {
172
+ "property_name": name,
173
+ "runs": run_index + 1,
174
+ "row_context": {},
175
+ "dataset": client.get_generated_data(),
176
+ "schema_fingerprint": self.schema_fingerprint,
177
+ }
178
+ raise SqlProofPropertyFailure(
179
+ f"Invariant {name!r} failed: query returned {len(rows)} rows.",
180
+ counterexample=payload,
181
+ )
182
+
183
+ def disconnect(self) -> None:
184
+ if self._db_manager is not None:
185
+ self._db_manager.stop()
186
+ return None
187
+
188
+ def __enter__(self) -> Self:
189
+ return self
190
+
191
+ def __exit__(
192
+ self,
193
+ exc_type: type[BaseException] | None,
194
+ exc: BaseException | None,
195
+ traceback: TracebackType | None,
196
+ ) -> None:
197
+ self.disconnect()
198
+
199
+ @staticmethod
200
+ def _load_schema(config: SqlProofConfig) -> SchemaInfo:
201
+ if config.schema_file is not None:
202
+ path = Path(config.schema_file)
203
+ return parse_schema_sql(path.read_text(encoding="utf-8"), schema=config.schema)
204
+ if config.connection_string is not None:
205
+ connection = psycopg.connect(
206
+ conninfo=config.connection_string,
207
+ autocommit=True,
208
+ row_factory=cast(Any, dict_row),
209
+ )
210
+ try:
211
+ return introspect_schema(connection, schema=config.schema)
212
+ finally:
213
+ connection.close()
214
+ return SchemaInfo()
215
+
216
+ def _external_parent_rows(
217
+ self,
218
+ *,
219
+ draw: st.DrawFn | None = None,
220
+ ) -> dict[str, list[dict[str, Any]]]:
221
+ if not self.config.external_tables:
222
+ return {}
223
+ if self.config.connection_string is None:
224
+ msg = "external_tables requires a connection_string-backed SqlProof instance."
225
+ raise SqlProofUsageError(msg)
226
+
227
+ connection = psycopg.connect(
228
+ conninfo=self.config.connection_string,
229
+ autocommit=True,
230
+ row_factory=cast(Any, dict_row),
231
+ )
232
+ try:
233
+ client = PsycopgSqlProofClient(connection)
234
+ return _external_parent_rows(
235
+ self.config.external_tables,
236
+ client,
237
+ draw=draw,
238
+ sample_cache=self._external_sample_cache,
239
+ )
240
+ finally:
241
+ connection.close()
242
+
243
+
244
def _insert_dataset(
    client: SqlProofClient,
    schema_info: SchemaInfo,
    dataset: dict[str, list[dict[str, Any]]],
) -> None:
    """Insert the rows of *dataset* through *client*, one INSERT per row.

    Tables are visited in the order returned by ``insertion_order``; tables
    missing from *dataset* and empty rows are skipped. Identifiers are
    quoted and values are passed as query parameters.
    """
    for table in insertion_order(schema_info.tables):
        for row in dataset.get(table.name, []):
            if not row:
                # No columns to name, so there is nothing to insert.
                continue
            value_slots = ", ".join("%s" for _ in row)
            column_list = ", ".join(_quote_identifier(name) for name in row)
            qualified = f"{_quote_identifier(table.schema)}.{_quote_identifier(table.name)}"
            statement = f"INSERT INTO {qualified} ({column_list}) VALUES ({value_slots})"
            params = [_adapt_insert_value(table, name, value) for name, value in row.items()]
            client.execute(statement, *params)
261
+
262
+
263
+ def _quote_identifier(identifier: str) -> str:
264
+ return '"' + identifier.replace('"', '""') + '"'
265
+
266
+
267
+ def _adapt_insert_value(table: Table, column_name: str, value: Any) -> object:
268
+ column = table.column(column_name)
269
+ type_name = _base_type_name(column)
270
+ if type_name == "jsonb":
271
+ return Jsonb(value)
272
+ if type_name == "json":
273
+ return Json(value)
274
+ return value
275
+
276
+
277
+ def _base_type_name(column: Column) -> str:
278
+ pg_type = column.type
279
+ while pg_type.base is not None:
280
+ pg_type = pg_type.base
281
+ return pg_type.name.lower()
282
+
283
+
284
def _external_parent_rows(
    specs: Mapping[str, ExternalTableSpec],
    client: SqlProofClient,
    *,
    draw: st.DrawFn | None = None,
    sample_cache: dict[str, list[object]] | None = None,
) -> dict[str, list[dict[str, Any]]]:
    """Build single-column parent rows for every external table spec.

    For each spec the seed count is resolved, the optional seed callable is
    run, and primary-key values are sampled (truncated to the seed count
    when one is set). Rows are stored under the table name as given and,
    for dotted names, also under the part after the last dot unless that
    key is already present.
    """
    collected: dict[str, list[dict[str, Any]]] = {}
    for qualified_name, spec in specs.items():
        count = _draw_seed_count(spec.seed_count, draw=draw)
        if spec.seed is not None:
            _call_external_seed(spec.seed, client, count)

        values = _sample_external_values(
            qualified_name,
            spec,
            client,
            sample_cache=sample_cache,
        )
        if count is not None:
            # Slicing copies, so a cached sample list is never mutated.
            values = values[:count]

        parent_rows = [{spec.primary_key: value} for value in values]
        collected[qualified_name] = parent_rows

        _prefix, dot, bare_name = qualified_name.rpartition(".")
        if dot:
            collected.setdefault(bare_name, parent_rows)
    return collected
309
+
310
+
311
+ def _sample_external_values(
312
+ table_name: str,
313
+ spec: ExternalTableSpec,
314
+ client: SqlProofClient,
315
+ *,
316
+ sample_cache: dict[str, list[object]] | None,
317
+ ) -> list[object]:
318
+ if spec.seed is not None or sample_cache is None:
319
+ return list(spec.sample(client))
320
+ if table_name not in sample_cache:
321
+ sample_cache[table_name] = list(spec.sample(client))
322
+ return sample_cache[table_name]
323
+
324
+
325
+ def _draw_seed_count(size: SizeSpec | None, *, draw: st.DrawFn | None) -> int | None:
326
+ if size is None:
327
+ return None
328
+ if isinstance(size, int):
329
+ return size
330
+ if draw is None:
331
+ msg = "ExternalTableSpec.seed_count strategies require dataset_strategy() generation."
332
+ raise SqlProofUsageError(msg)
333
+ return draw(size)
334
+
335
+
336
+ def _call_external_seed(
337
+ seed: ExternalSeed,
338
+ client: SqlProofClient,
339
+ count: int | None,
340
+ ) -> None:
341
+ if count is None:
342
+ cast(Callable[[SqlProofClient], None], seed)(client)
343
+ return
344
+ cast(Callable[[SqlProofClient, int], None], seed)(client, count)
@@ -0,0 +1,6 @@
1
+ from __future__ import annotations
2
+
3
+ from sqlproof.coverage.diversity import diversity_ratio
4
+ from sqlproof.coverage.schema_shape import summarize_dataset_shape
5
+
6
+ __all__ = ["diversity_ratio", "summarize_dataset_shape"]
@@ -0,0 +1,11 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from typing import Any
5
+
6
+
7
def diversity_ratio(datasets: list[dict[str, list[dict[str, Any]]]]) -> float:
    """Return the fraction of *datasets* that are distinct.

    Datasets are compared by their JSON serialisation with sorted keys;
    values JSON cannot encode are stringified. An empty input yields 0.0.
    """
    if not datasets:
        return 0.0
    seen: set[str] = set()
    for dataset in datasets:
        seen.add(json.dumps(dataset, sort_keys=True, default=str))
    return len(seen) / len(datasets)
@@ -0,0 +1,5 @@
1
+ from __future__ import annotations
2
+
3
+
4
def coverage_available() -> bool:
    """Report whether coverage support is available; currently always ``False``."""
    return False
@@ -0,0 +1,7 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+
6
def summarize_dataset_shape(dataset: dict[str, list[dict[str, Any]]]) -> dict[str, dict[str, int]]:
    """Map each table name in *dataset* to ``{"rows": <row count>}``."""
    summary: dict[str, dict[str, int]] = {}
    for table, rows in dataset.items():
        summary[table] = {"rows": len(rows)}
    return summary
sqlproof/exceptions.py ADDED
@@ -0,0 +1,47 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any
5
+
6
+
7
+ class SqlProofError(Exception):
8
+ """Base for all SqlProof errors."""
9
+
10
+
11
+ class SqlProofUsageError(SqlProofError):
12
+ """Caller misuse: invalid sizes, conflicting decorators, ambiguous types, etc."""
13
+
14
+
15
+ class SqlProofSchemaError(SqlProofError):
16
+ """Schema parsing or introspection failure."""
17
+
18
+
19
+ class CircularDependencyError(SqlProofSchemaError):
20
+ """FK cycle between distinct tables."""
21
+
22
+
23
+ class SqlProofGenerationError(SqlProofError):
24
+ """Data generation exhausted retry budget for assume-and-retry constraints."""
25
+
26
+
27
+ class SqlProofMappingError(SqlProofError):
28
+ """query_typed could not map a row to the requested model."""
29
+
30
+
31
+ class SqlProofTimeoutError(SqlProofError):
32
+ """A property run exceeded its timeout."""
33
+
34
+
35
+ @dataclass(slots=True)
36
+ class SqlProofPropertyFailure(SqlProofError):
37
+ """The property was falsified."""
38
+
39
+ message: str
40
+ counterexample: dict[str, Any] | None = None
41
+
42
+ def __str__(self) -> str:
43
+ return self.message
44
+
45
+
46
+ class SqlProofContainerError(SqlProofError):
47
+ """testcontainers startup, container died mid-run, etc."""
@@ -0,0 +1,21 @@
1
+ from __future__ import annotations
2
+
3
+ from sqlproof.generators.columns import strategy_for_column, strategy_for_type
4
+ from sqlproof.generators.graph import Dataset, SizeSpec, dataset_strategy
5
+ from sqlproof.generators.rows import ColumnContext, ColumnOverrides
6
+ from sqlproof.generators.well_known import emails, phone_numbers, postal_codes, slugs, urls
7
+
8
+ __all__ = [
9
+ "ColumnContext",
10
+ "ColumnOverrides",
11
+ "Dataset",
12
+ "SizeSpec",
13
+ "dataset_strategy",
14
+ "emails",
15
+ "phone_numbers",
16
+ "postal_codes",
17
+ "slugs",
18
+ "strategy_for_column",
19
+ "strategy_for_type",
20
+ "urls",
21
+ ]
@@ -0,0 +1,93 @@
1
+ from __future__ import annotations
2
+
3
+ from decimal import Decimal
4
+ from typing import Any
5
+
6
+ from hypothesis import strategies as st
7
+ from hypothesis.strategies import SearchStrategy
8
+
9
+ from sqlproof.schema.model import Column, PgType
10
+
11
# Character strategy behind every generated text value: any character except
# NUL ("\x00") and those in Unicode category "Cs" (surrogates).
# NOTE(review): presumably excluded because PostgreSQL text cannot store
# them - confirm.
POSTGRES_TEXT_ALPHABET = st.characters(
    blacklist_characters="\x00",
    blacklist_categories=("Cs",),
)
15
+
16
+
17
def strategy_for_column(column: Column) -> SearchStrategy[Any]:
    """Return the strategy for *column*'s type, also allowing ``None`` if it is nullable."""
    base = strategy_for_type(column.type)
    return st.one_of(st.none(), base) if column.nullable else base
22
+
23
+
24
def strategy_for_type(pg_type: PgType) -> SearchStrategy[Any]:
    """Return a Hypothesis strategy producing values for a PostgreSQL type.

    Enum types sample from ``pg_type.enum_values``. Everything else is
    dispatched on the lower-cased ``pg_type.name``; names that are not
    recognised fall back to text of at most 255 characters.

    Numeric columns respect the declared ``(precision, scale)`` modifiers,
    so generated values never exceed what the column can store. Values are
    still capped at +/-1,000,000. The auto-increment aliases
    ``smallserial``/``serial2``/``serial4``/``serial8`` map to the same
    integer ranges as their underlying types.
    """
    name = pg_type.name.lower()
    if pg_type.kind == "enum":
        return st.sampled_from(pg_type.enum_values)
    if name in {"smallint", "int2", "smallserial", "serial2"}:
        return st.integers(-32_768, 32_767)
    if name in {"integer", "int", "int4", "serial", "serial4"}:
        return st.integers(-2_147_483_648, 2_147_483_647)
    if name in {"bigint", "int8", "bigserial", "serial8"}:
        return st.integers(-(2**63), 2**63 - 1)
    if name in {"numeric", "decimal"}:
        limit, places = _numeric_limit_and_places(pg_type.modifiers)
        return st.decimals(
            # copy_negate() is exact; unary minus would round to the
            # current decimal context.
            min_value=limit.copy_negate(),
            max_value=limit,
            places=places,
            allow_nan=False,
            allow_infinity=False,
        )
    if name in {"real", "float4"}:
        return st.floats(width=32, allow_nan=False, allow_infinity=False)
    if name in {"double precision", "float8"}:
        return st.floats(allow_nan=False, allow_infinity=False)
    if name in {"boolean", "bool"}:
        return st.booleans()
    if name in {"text", "citext"}:
        return _postgres_text(max_size=255)
    if name in {"varchar", "character varying"}:
        max_size = pg_type.modifiers[0] if pg_type.modifiers else 255
        return _postgres_text(max_size=max_size)
    if name in {"char", "character"}:
        # Fixed-width type: generate exactly the declared length (default 1).
        size = pg_type.modifiers[0] if pg_type.modifiers else 1
        return _postgres_text(min_size=size, max_size=size)
    if name == "uuid":
        return st.uuids().map(str)
    if name in {
        "timestamp",
        "timestamp without time zone",
        "timestamptz",
        "timestamp with time zone",
    }:
        return st.datetimes()
    if name == "date":
        return st.dates()
    if name in {"time", "timetz"}:
        return st.times()
    if name == "interval":
        return st.timedeltas()
    if name in {"json", "jsonb"}:
        json_scalar = (
            st.none()
            | st.booleans()
            | st.floats(allow_nan=False, allow_infinity=False)
            | _postgres_text()
        )
        return st.recursive(
            json_scalar,
            lambda children: (
                st.lists(children, max_size=5)
                | st.dictionaries(_postgres_text(max_size=20), children, max_size=5)
            ),
            max_leaves=10,
        )
    if name == "bytea":
        return st.binary()
    return _postgres_text(max_size=255)


def _numeric_limit_and_places(modifiers: tuple[int, ...] | list[int]) -> tuple[Decimal, int]:
    """Return ``(max_abs_value, decimal_places)`` for a numeric column.

    ``modifiers`` is read as ``(precision, scale)``. With no modifiers the
    historical default applies: two decimal places within +/-1,000,000. With
    only a precision the scale is 0, matching PostgreSQL's ``NUMERIC(p)``.
    The limit is the largest value the column can hold (``precision`` nines
    shifted right by ``scale`` digits), capped at 1,000,000.
    """
    default_limit = Decimal("1000000")
    if not modifiers:
        return default_limit, 2

    precision = modifiers[0]
    scale = modifiers[1] if len(modifiers) > 1 else 0
    # Hypothesis requires a non-negative ``places``; a negative scale only
    # removes fractional digits, so integers are the right granularity.
    places = max(scale, 0)

    # A non-positive precision is invalid in PostgreSQL; seven or more
    # integer digits already exceed the cap. Both keep the default limit.
    if precision < 1 or precision - scale > 6:
        return default_limit, places

    # Built from a digit tuple so the value is exact regardless of the
    # active decimal context, e.g. (5, 2) -> Decimal("999.99").
    largest = Decimal((0, (9,) * precision, -scale))
    return min(default_limit, largest), places
90
+
91
+
92
def _postgres_text(*, min_size: int = 0, max_size: int | None = None) -> SearchStrategy[str]:
    """Build a text strategy over ``POSTGRES_TEXT_ALPHABET`` with the given length bounds."""
    size_bounds = {"min_size": min_size, "max_size": max_size}
    return st.text(alphabet=POSTGRES_TEXT_ALPHABET, **size_bounds)