sqlproof 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlproof/__init__.py +32 -0
- sqlproof/_version.py +1 -0
- sqlproof/cli.py +151 -0
- sqlproof/client.py +159 -0
- sqlproof/config.py +42 -0
- sqlproof/contrib/__init__.py +3 -0
- sqlproof/contrib/supabase.py +136 -0
- sqlproof/core.py +344 -0
- sqlproof/coverage/__init__.py +6 -0
- sqlproof/coverage/diversity.py +11 -0
- sqlproof/coverage/plpgsql.py +5 -0
- sqlproof/coverage/schema_shape.py +7 -0
- sqlproof/exceptions.py +47 -0
- sqlproof/generators/__init__.py +21 -0
- sqlproof/generators/columns.py +93 -0
- sqlproof/generators/constraints.py +181 -0
- sqlproof/generators/functions.py +9 -0
- sqlproof/generators/graph.py +51 -0
- sqlproof/generators/rows.py +153 -0
- sqlproof/generators/sampling.py +15 -0
- sqlproof/generators/well_known.py +59 -0
- sqlproof/pytest_plugin.py +24 -0
- sqlproof/reporter/__init__.py +5 -0
- sqlproof/reporter/console.py +20 -0
- sqlproof/reporter/json_io.py +26 -0
- sqlproof/runners/__init__.py +14 -0
- sqlproof/runners/db.py +48 -0
- sqlproof/runners/migration.py +51 -0
- sqlproof/runners/overload.py +41 -0
- sqlproof/runners/property.py +119 -0
- sqlproof/runners/rls.py +40 -0
- sqlproof/runners/stateful.py +36 -0
- sqlproof/schema/__init__.py +27 -0
- sqlproof/schema/dependency_graph.py +38 -0
- sqlproof/schema/fingerprint.py +34 -0
- sqlproof/schema/introspect.py +229 -0
- sqlproof/schema/model.py +98 -0
- sqlproof/schema/parse_sql.py +206 -0
- sqlproof/testing.py +101 -0
- sqlproof/types.py +34 -0
- sqlproof-0.1.0a1.dist-info/METADATA +248 -0
- sqlproof-0.1.0a1.dist-info/RECORD +44 -0
- sqlproof-0.1.0a1.dist-info/WHEEL +4 -0
- sqlproof-0.1.0a1.dist-info/entry_points.txt +5 -0
sqlproof/core.py
ADDED
|
@@ -0,0 +1,344 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Callable, Generator, Mapping
|
|
4
|
+
from contextlib import contextmanager
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from types import TracebackType
|
|
7
|
+
from typing import Any, Self, cast
|
|
8
|
+
|
|
9
|
+
import psycopg
|
|
10
|
+
from hypothesis import strategies as st
|
|
11
|
+
from hypothesis.strategies import SearchStrategy
|
|
12
|
+
from psycopg.rows import dict_row
|
|
13
|
+
from psycopg.types.json import Json, Jsonb
|
|
14
|
+
|
|
15
|
+
from sqlproof.client import InMemorySqlProofClient, PsycopgSqlProofClient, SqlProofClient
|
|
16
|
+
from sqlproof.config import ExternalSeed, ExternalTableSpec, SqlProofConfig
|
|
17
|
+
from sqlproof.exceptions import SqlProofPropertyFailure, SqlProofUsageError
|
|
18
|
+
from sqlproof.generators.graph import ColumnOverrides, Dataset, SizeSpec, dataset_strategy
|
|
19
|
+
from sqlproof.generators.sampling import draw_example
|
|
20
|
+
from sqlproof.schema.dependency_graph import insertion_order
|
|
21
|
+
from sqlproof.schema.fingerprint import compute
|
|
22
|
+
from sqlproof.schema.introspect import introspect_schema
|
|
23
|
+
from sqlproof.schema.model import Column, SchemaInfo, Table
|
|
24
|
+
from sqlproof.schema.parse_sql import parse_schema_sql
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class SqlProof:
    """Entry point for schema-driven property testing against PostgreSQL.

    The schema is loaded either from a SQL file (``config.schema_file``) or by
    introspecting a live database (``config.connection_string``). Hypothesis
    strategies for whole datasets are derived from that schema and used to run
    properties, invariants and state machines.
    """

    def __init__(self, config: SqlProofConfig) -> None:
        """Load the schema described by *config* and prepare database resources.

        A ``DBManager`` is only created when ``config.connection_string`` is
        set; without it, datasets are served through an in-memory client.
        """
        from sqlproof.runners.db import DBManager

        self.config = config
        self.schema_info = self._load_schema(config)
        self.schema_fingerprint = compute(self.schema_info)
        self._db_manager = DBManager(config) if config.connection_string is not None else None
        # Sampled external-table key values keyed by table name; only used for
        # specs without a seed callback (see _sample_external_values).
        self._external_sample_cache: dict[str, list[object]] = {}

    @classmethod
    def from_schema_file(cls, path: str | Path, **kwargs: Any) -> Self:
        """Build an instance whose schema is parsed from the SQL file at *path*."""
        return cls(SqlProofConfig(schema_file=path, **kwargs))

    @classmethod
    def from_connection_string(cls, dsn: str, **kwargs: Any) -> Self:
        """Build an instance whose schema is introspected from the database at *dsn*."""
        return cls(SqlProofConfig(connection_string=dsn, **kwargs))

    @classmethod
    def from_config(cls, config: SqlProofConfig) -> Self:
        """Build an instance from an already-constructed *config*."""
        return cls(config)

    def customize(self, table: str, **overrides: object) -> Self:
        """Return ``self`` unchanged.

        NOTE(review): *table* and *overrides* are currently discarded; this is
        a placeholder for per-table customisation.
        """
        del table, overrides
        return self

    def dataset_strategy(
        self,
        *,
        sizes: Mapping[str, SizeSpec],
        columns: ColumnOverrides | None = None,
    ) -> SearchStrategy[Dataset]:
        """Return a strategy generating datasets for this schema.

        Args:
            sizes: Row-count specification per table.
            columns: Optional per-column strategy overrides.

        When ``config.external_tables`` is set, each draw first resolves rows
        for those external parent tables (which requires a database).
        """
        if self.config.external_tables:
            return self._dataset_strategy_with_external_tables(sizes=sizes, columns=columns)
        return dataset_strategy(
            self.schema_info,
            sizes=sizes,
            columns=columns,
        )

    def _dataset_strategy_with_external_tables(
        self,
        *,
        sizes: Mapping[str, SizeSpec],
        columns: ColumnOverrides | None,
    ) -> SearchStrategy[Dataset]:
        """Composite strategy: resolve external parent rows, then draw a dataset."""

        @st.composite
        def dataset(draw: st.DrawFn) -> Dataset:
            external_parent_rows = self._external_parent_rows(draw=draw)
            return draw(
                dataset_strategy(
                    self.schema_info,
                    sizes=sizes,
                    external_parent_rows=external_parent_rows,
                    columns=columns,
                )
            )

        return dataset()

    def run_state_machine(
        self,
        machine_class: type,
        *,
        settings: Any = None,
    ) -> None:
        """Run a `SqlProofStateMachine` subclass against this proof.

        Binds `self` as the proof for the machine (via a dynamically created
        subclass carrying ``_sqlproof_proof``), then dispatches to
        `hypothesis.stateful.run_state_machine_as_test`. Each example gets
        an isolated dataset client; writes from one example are rolled back
        before the next begins.

        Raises:
            SqlProofUsageError: if *machine_class* is not a subclass of
                ``SqlProofStateMachine``.
        """
        from hypothesis.stateful import run_state_machine_as_test

        from sqlproof.testing import SqlProofStateMachine

        if not isinstance(machine_class, type) or not issubclass(
            machine_class, SqlProofStateMachine
        ):
            msg = "machine_class must be a subclass of SqlProofStateMachine."
            raise SqlProofUsageError(msg)

        bound_class = type(
            machine_class.__name__,
            (machine_class,),
            {"_sqlproof_proof": self},
        )
        run_state_machine_as_test(bound_class, settings=settings)

    @contextmanager
    def client_for_dataset(
        self, dataset: dict[str, list[dict[str, Any]]]
    ) -> Generator[SqlProofClient]:
        """Yield a client loaded with *dataset*.

        Without a database connection an ``InMemorySqlProofClient`` wrapping
        *dataset* is yielded. Otherwise a client is acquired from the DB
        manager, a ``sqlproof_run`` savepoint is opened, the rows are
        inserted, and on exit the savepoint is rolled back and released —
        also when insertion or the caller's body raises.
        """
        if self._db_manager is None:
            yield InMemorySqlProofClient(dataset)
            return
        with self._db_manager.acquire() as client:
            client.execute("SAVEPOINT sqlproof_run")
            try:
                _insert_dataset(client, self.schema_info, dataset)
                yield client
            finally:
                client.execute("ROLLBACK TO SAVEPOINT sqlproof_run")
                client.execute("RELEASE SAVEPOINT sqlproof_run")

    def check(
        self,
        name: str,
        *,
        sizes: Mapping[str, SizeSpec],
        property: Callable[..., None],
        setup: object | None = None,
        runs: int = 100,
        seed: int | None = None,
        timeout_ms: int = 5000,
        commit: bool = False,
    ) -> None:
        """Run *property* over *runs* generated datasets via ``run_property``.

        Failures are written under ``.sqlproof/failures``.

        NOTE(review): *name*, *setup*, *seed*, *timeout_ms* and *commit* are
        accepted but currently ignored.

        Raises:
            TypeError: if *property* is not callable.
        """
        from sqlproof.runners.property import run_property

        del name, setup, seed, timeout_ms, commit
        if not callable(property):
            msg = "property must be callable"
            raise TypeError(msg)
        run_property(self, property, sizes=sizes, runs=runs, failure_dir=Path(".sqlproof/failures"))

    def invariant(
        self,
        name: str,
        *,
        sizes: Mapping[str, SizeSpec],
        query: str,
        expect_empty: bool = True,
        runs: int = 100,
        seed: int | None = None,
        timeout_ms: int = 5000,
    ) -> None:
        """Check *query* against *runs* generated datasets.

        With ``expect_empty=True`` (default) the query must return no rows;
        otherwise it must return at least one. Every dataset is evaluated
        through an ``InMemorySqlProofClient``, regardless of whether a
        connection string is configured. *seed* and *timeout_ms* are
        currently ignored.

        Raises:
            SqlProofPropertyFailure: on the first failing dataset, carrying a
                counterexample payload with the generated data and schema
                fingerprint.
        """
        del seed, timeout_ms
        strategy = self.dataset_strategy(sizes=sizes)
        for run_index in range(runs):
            client = InMemorySqlProofClient(draw_example(strategy))
            rows = client.query(query)
            failed = bool(rows) if expect_empty else not rows
            if failed:
                payload = {
                    "property_name": name,
                    "runs": run_index + 1,
                    "row_context": {},
                    "dataset": client.get_generated_data(),
                    "schema_fingerprint": self.schema_fingerprint,
                }
                raise SqlProofPropertyFailure(
                    f"Invariant {name!r} failed: query returned {len(rows)} rows.",
                    counterexample=payload,
                )

    def disconnect(self) -> None:
        """Stop the DB manager, if one was created."""
        if self._db_manager is not None:
            self._db_manager.stop()

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        # Always release DB resources; exceptions are never suppressed.
        self.disconnect()

    @staticmethod
    def _load_schema(config: SqlProofConfig) -> SchemaInfo:
        """Load the schema from ``schema_file`` (preferred) or by introspection.

        Returns an empty ``SchemaInfo`` when neither source is configured.
        """
        if config.schema_file is not None:
            path = Path(config.schema_file)
            return parse_schema_sql(path.read_text(encoding="utf-8"), schema=config.schema)
        if config.connection_string is not None:
            connection = _connect_autocommit(config.connection_string)
            try:
                return introspect_schema(connection, schema=config.schema)
            finally:
                connection.close()
        return SchemaInfo()

    def _external_parent_rows(
        self,
        *,
        draw: st.DrawFn | None = None,
    ) -> dict[str, list[dict[str, Any]]]:
        """Resolve rows for ``config.external_tables`` over a fresh connection.

        Returns ``{}`` when no external tables are configured.

        Raises:
            SqlProofUsageError: if external tables are configured without a
                connection string.
        """
        if not self.config.external_tables:
            return {}
        if self.config.connection_string is None:
            msg = "external_tables requires a connection_string-backed SqlProof instance."
            raise SqlProofUsageError(msg)

        connection = _connect_autocommit(self.config.connection_string)
        try:
            client = PsycopgSqlProofClient(connection)
            # Module-level helper of the same name (method names are not in scope here).
            return _external_parent_rows(
                self.config.external_tables,
                client,
                draw=draw,
                sample_cache=self._external_sample_cache,
            )
        finally:
            connection.close()


def _connect_autocommit(conninfo: str) -> psycopg.Connection[Any]:
    """Open an autocommit psycopg connection whose rows are returned as dicts.

    The caller owns the connection and must close it.
    """
    return psycopg.connect(
        conninfo=conninfo,
        autocommit=True,
        row_factory=cast(Any, dict_row),
    )
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def _insert_dataset(
    client: SqlProofClient,
    schema_info: SchemaInfo,
    dataset: dict[str, list[dict[str, Any]]],
) -> None:
    """Insert every row of *dataset* through *client*.

    Tables are visited in the order returned by ``insertion_order``
    (presumably FK parents before children — confirm in dependency_graph).
    Each non-empty row becomes one parameterized ``INSERT`` naming only the
    columns present in that row; empty rows and missing/empty table entries
    are skipped. Values are adapted per column by ``_adapt_insert_value``.
    """
    for table in insertion_order(schema_info.tables):
        rows = dataset.get(table.name)
        if not rows:
            continue
        # The qualified table name is identical for every row of this table.
        table_sql = f"{_quote_identifier(table.schema)}.{_quote_identifier(table.name)}"
        for row in rows:
            if not row:
                continue
            columns = list(row)
            column_sql = ", ".join(_quote_identifier(column) for column in columns)
            placeholders = ", ".join(["%s"] * len(columns))
            values = [_adapt_insert_value(table, column, row[column]) for column in columns]
            client.execute(
                f"INSERT INTO {table_sql} ({column_sql}) VALUES ({placeholders})",
                *values,
            )
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
def _quote_identifier(identifier: str) -> str:
    """Render *identifier* as a double-quoted SQL identifier.

    Embedded double quotes are escaped by doubling them.
    """
    escaped = identifier.replace('"', '""')
    return f'"{escaped}"'
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
def _adapt_insert_value(table: Table, column_name: str, value: Any) -> object:
    """Wrap *value* for psycopg according to the column's base type.

    ``json``/``jsonb`` columns get psycopg's ``Json``/``Jsonb`` wrappers so
    dicts and lists are dumped as JSON. A ``None`` for a *nullable* JSON
    column is passed through unwrapped so it is stored as SQL ``NULL``;
    wrapping it would store the JSON literal ``null`` instead. For non-null
    JSON columns ``None`` is still wrapped (JSON ``null``), since SQL ``NULL``
    would violate the constraint. Other values are returned unchanged.
    """
    column = table.column(column_name)
    type_name = _base_type_name(column)
    if type_name not in {"json", "jsonb"}:
        return value
    if value is None and column.nullable:
        return None
    return Jsonb(value) if type_name == "jsonb" else Json(value)
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
def _base_type_name(column: Column) -> str:
    """Return the lower-cased name of the root type underneath *column*.

    Types that wrap another type (presumably domains) expose it through
    ``.base``; that chain is followed until a type without a base is found.
    """
    current = column.type
    parent = current.base
    while parent is not None:
        current, parent = parent, parent.base
    return current.name.lower()
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
def _external_parent_rows(
    specs: Mapping[str, ExternalTableSpec],
    client: SqlProofClient,
    *,
    draw: st.DrawFn | None = None,
    sample_cache: dict[str, list[object]] | None = None,
) -> dict[str, list[dict[str, Any]]]:
    """Build single-column parent rows for each external table in *specs*.

    For every spec: resolve its ``seed_count`` (drawing it when it is a
    strategy), run its ``seed`` callback if present, sample key values,
    truncate them to the seed count when one is set, and wrap each value as
    ``{spec.primary_key: value}``. A schema-qualified table name is also
    registered under its bare name, unless that bare name is already present.
    """
    result: dict[str, list[dict[str, Any]]] = {}
    for table_name, spec in specs.items():
        count = _draw_seed_count(spec.seed_count, draw=draw)
        if spec.seed is not None:
            _call_external_seed(spec.seed, client, count)
        keys = _sample_external_values(
            table_name,
            spec,
            client,
            sample_cache=sample_cache,
        )
        if count is not None:
            keys = keys[:count]
        parent_rows = [{spec.primary_key: key} for key in keys]
        result[table_name] = parent_rows
        _, dot, bare_name = table_name.rpartition(".")
        if dot:
            result.setdefault(bare_name, parent_rows)
    return result
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
def _sample_external_values(
    table_name: str,
    spec: ExternalTableSpec,
    client: SqlProofClient,
    *,
    sample_cache: dict[str, list[object]] | None,
) -> list[object]:
    """Return sampled key values for *table_name*, caching when possible.

    Specs with a ``seed`` callback, or calls without a cache, always sample
    afresh. Otherwise the first sample for *table_name* is stored in
    *sample_cache* and that same list is returned on later calls.
    """
    cacheable = spec.seed is None and sample_cache is not None
    if not cacheable:
        return list(spec.sample(client))
    if table_name not in sample_cache:
        sample_cache[table_name] = list(spec.sample(client))
    return sample_cache[table_name]
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
def _draw_seed_count(size: SizeSpec | None, *, draw: st.DrawFn | None) -> int | None:
    """Resolve an external table's ``seed_count`` to an int or ``None``.

    ``None`` and plain ints pass through unchanged. Any other value is treated
    as a strategy and drawn with *draw*.

    Raises:
        SqlProofUsageError: if a strategy needs drawing but *draw* is ``None``.
    """
    if size is None or isinstance(size, int):
        return size
    if draw is not None:
        return draw(size)
    msg = "ExternalTableSpec.seed_count strategies require dataset_strategy() generation."
    raise SqlProofUsageError(msg)
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
def _call_external_seed(
    seed: ExternalSeed,
    client: SqlProofClient,
    count: int | None,
) -> None:
    """Invoke the *seed* callback.

    It is called as ``seed(client, count)`` when a count is known, and as
    ``seed(client)`` otherwise.
    """
    if count is not None:
        cast(Callable[[SqlProofClient, int], None], seed)(client, count)
    else:
        cast(Callable[[SqlProofClient], None], seed)(client)
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def diversity_ratio(datasets: list[dict[str, list[dict[str, Any]]]]) -> float:
    """Return the fraction of *datasets* that are distinct.

    Datasets are compared through a canonical JSON encoding (sorted keys,
    ``str()`` for values JSON cannot encode). ``1.0`` means every dataset is
    unique; an empty input yields ``0.0``.
    """
    total = len(datasets)
    if total == 0:
        return 0.0
    distinct = {json.dumps(candidate, sort_keys=True, default=str) for candidate in datasets}
    return len(distinct) / total
|
sqlproof/exceptions.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class SqlProofError(Exception):
    """Base for all SqlProof errors."""


class SqlProofUsageError(SqlProofError):
    """Caller misuse: invalid sizes, conflicting decorators, ambiguous types, etc."""


class SqlProofSchemaError(SqlProofError):
    """Schema parsing or introspection failure."""


class CircularDependencyError(SqlProofSchemaError):
    """FK cycle between distinct tables."""


class SqlProofGenerationError(SqlProofError):
    """Data generation exhausted retry budget for assume-and-retry constraints."""


class SqlProofMappingError(SqlProofError):
    """query_typed could not map a row to the requested model."""


class SqlProofTimeoutError(SqlProofError):
    """A property run exceeded its timeout."""


class SqlProofPropertyFailure(SqlProofError):
    """The property was falsified.

    Attributes:
        message: Human-readable description; also returned by ``str(exc)``.
        counterexample: Optional payload describing the falsifying example.

    This is a plain exception class rather than a dataclass, for three reasons:
    - ``args`` is populated, so instances survive pickling (e.g. when sent
      between worker processes).
    - Equality and hashing stay identity-based like other exceptions.
    - A dataclass ``__eq__`` would make instances unhashable.
    """

    __slots__ = ("message", "counterexample")

    def __init__(self, message: str, counterexample: dict[str, Any] | None = None) -> None:
        # Passing both fields to Exception keeps ``args`` in sync, which
        # BaseException.__reduce__ uses to rebuild the instance on unpickling.
        super().__init__(message, counterexample)
        self.message = message
        self.counterexample = counterexample

    def __str__(self) -> str:
        return self.message

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(message={self.message!r}, "
            f"counterexample={self.counterexample!r})"
        )


class SqlProofContainerError(SqlProofError):
    """testcontainers startup, container died mid-run, etc."""
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from sqlproof.generators.columns import strategy_for_column, strategy_for_type
|
|
4
|
+
from sqlproof.generators.graph import Dataset, SizeSpec, dataset_strategy
|
|
5
|
+
from sqlproof.generators.rows import ColumnContext, ColumnOverrides
|
|
6
|
+
from sqlproof.generators.well_known import emails, phone_numbers, postal_codes, slugs, urls
|
|
7
|
+
|
|
8
|
+
# Public API of the generators package (kept sorted, case-sensitive).
__all__ = [
    "ColumnContext", "ColumnOverrides", "Dataset", "SizeSpec",
    "dataset_strategy", "emails", "phone_numbers", "postal_codes", "slugs",
    "strategy_for_column", "strategy_for_type", "urls",
]
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from decimal import Decimal
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from hypothesis import strategies as st
|
|
7
|
+
from hypothesis.strategies import SearchStrategy
|
|
8
|
+
|
|
9
|
+
from sqlproof.schema.model import Column, PgType
|
|
10
|
+
|
|
11
|
+
# Character alphabet for generated PostgreSQL text values.
# - NUL ("\x00") is excluded because PostgreSQL text values cannot contain it.
# - Surrogate code points (Unicode category "Cs") are excluded because lone
#   surrogates cannot be encoded as UTF-8.
POSTGRES_TEXT_ALPHABET = st.characters(
    blacklist_characters="\x00",
    blacklist_categories=("Cs",),
)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def strategy_for_column(column: Column) -> SearchStrategy[Any]:
    """Return a Hypothesis strategy producing values for *column*.

    Values come from the strategy for the column's type. Nullable columns
    may additionally yield ``None``.
    """
    base = strategy_for_type(column.type)
    return st.one_of(st.none(), base) if column.nullable else base
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def strategy_for_type(pg_type: PgType) -> SearchStrategy[Any]:
    """Return a Hypothesis strategy for values of the PostgreSQL type *pg_type*.

    Names are matched case-insensitively against both SQL-standard spellings
    (``character varying``) and ``pg_catalog`` aliases (``int4``, ``bpchar``).
    Enum types sample from their declared labels. Unrecognised types fall
    back to text of at most 255 characters.
    """
    name = pg_type.name.lower()
    if pg_type.kind == "enum":
        return st.sampled_from(pg_type.enum_values)
    if name in {"smallint", "int2", "smallserial"}:
        return st.integers(-32_768, 32_767)
    if name in {"integer", "int", "int4", "serial"}:
        return st.integers(-2_147_483_648, 2_147_483_647)
    if name in {"bigint", "int8", "bigserial"}:
        return st.integers(-(2**63), 2**63 - 1)
    if name in {"numeric", "decimal"}:
        return _numeric_strategy(pg_type)
    if name in {"real", "float4"}:
        return st.floats(width=32, allow_nan=False, allow_infinity=False)
    if name in {"double precision", "float8"}:
        return st.floats(allow_nan=False, allow_infinity=False)
    if name in {"boolean", "bool"}:
        return st.booleans()
    if name in {"text", "citext"}:
        return _postgres_text(max_size=255)
    if name in {"varchar", "character varying"}:
        max_size = pg_type.modifiers[0] if pg_type.modifiers else 255
        return _postgres_text(max_size=max_size)
    if name in {"char", "character", "bpchar"}:
        # Fixed-width: generate exactly the declared length (default 1).
        size = pg_type.modifiers[0] if pg_type.modifiers else 1
        return _postgres_text(min_size=size, max_size=size)
    if name == "uuid":
        return st.uuids().map(str)
    if name in {
        "timestamp",
        "timestamp without time zone",
        "timestamptz",
        "timestamp with time zone",
    }:
        return st.datetimes()
    if name == "date":
        return st.dates()
    if name in {"time", "timetz", "time without time zone", "time with time zone"}:
        return st.times()
    if name == "interval":
        return st.timedeltas()
    if name in {"json", "jsonb"}:
        json_scalar = (
            st.none()
            | st.booleans()
            | st.floats(allow_nan=False, allow_infinity=False)
            | _postgres_text()
        )
        return st.recursive(
            json_scalar,
            lambda children: (
                st.lists(children, max_size=5)
                | st.dictionaries(_postgres_text(max_size=20), children, max_size=5)
            ),
            max_leaves=10,
        )
    if name == "bytea":
        return st.binary()
    return _postgres_text(max_size=255)


# Upper bound on the magnitude of generated numeric values, kept so that
# wide or unconstrained numeric columns still produce readable examples.
_NUMERIC_MAGNITUDE_CAP = Decimal("1000000")


def _numeric_strategy(pg_type: PgType) -> SearchStrategy[Decimal]:
    """Decimals that fit a ``numeric``/``decimal`` column with *pg_type*'s typmod.

    ``numeric(p, s)`` stores at most ``p`` digits, ``s`` of them after the
    point, so the largest magnitude is ``p`` nines scaled by ``10**-s``
    (``numeric(5, 2)`` -> 999.99). ``numeric(p)`` means scale 0. An
    unconstrained ``numeric`` keeps two decimal places. All magnitudes are
    capped at ``_NUMERIC_MAGNITUDE_CAP``.
    """
    modifiers = pg_type.modifiers
    if modifiers and modifiers[0] > 0:
        precision = modifiers[0]
        scale = modifiers[1] if len(modifiers) > 1 else 0
        # Built from a digit tuple so no Decimal context rounding is applied.
        largest = Decimal((0, (9,) * precision, -scale))
        bound = min(largest, _NUMERIC_MAGNITUDE_CAP)
        # A negative scale (allowed since PostgreSQL 15) still means whole numbers.
        places = max(scale, 0)
    else:
        bound = _NUMERIC_MAGNITUDE_CAP
        places = 2
    return st.decimals(
        # copy_negate() is exact; unary minus would round to context precision.
        min_value=bound.copy_negate(),
        max_value=bound,
        places=places,
        allow_nan=False,
        allow_infinity=False,
    )
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def _postgres_text(*, min_size: int = 0, max_size: int | None = None) -> SearchStrategy[str]:
    """Text strategy drawing only from ``POSTGRES_TEXT_ALPHABET``.

    *min_size* and *max_size* bound the length in characters; ``max_size=None``
    leaves it unbounded.
    """
    return st.text(POSTGRES_TEXT_ALPHABET, min_size=min_size, max_size=max_size)
|