plain.postgres 0.84.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plain/postgres/CHANGELOG.md +1028 -0
- plain/postgres/README.md +925 -0
- plain/postgres/__init__.py +120 -0
- plain/postgres/agents/.claude/rules/plain-postgres.md +78 -0
- plain/postgres/aggregates.py +236 -0
- plain/postgres/backups/__init__.py +0 -0
- plain/postgres/backups/cli.py +148 -0
- plain/postgres/backups/clients.py +94 -0
- plain/postgres/backups/core.py +172 -0
- plain/postgres/base.py +1415 -0
- plain/postgres/cli/__init__.py +3 -0
- plain/postgres/cli/db.py +142 -0
- plain/postgres/cli/migrations.py +1085 -0
- plain/postgres/config.py +18 -0
- plain/postgres/connection.py +1331 -0
- plain/postgres/connections.py +77 -0
- plain/postgres/constants.py +13 -0
- plain/postgres/constraints.py +495 -0
- plain/postgres/database_url.py +94 -0
- plain/postgres/db.py +59 -0
- plain/postgres/default_settings.py +38 -0
- plain/postgres/deletion.py +475 -0
- plain/postgres/dialect.py +640 -0
- plain/postgres/entrypoints.py +4 -0
- plain/postgres/enums.py +103 -0
- plain/postgres/exceptions.py +217 -0
- plain/postgres/expressions.py +1912 -0
- plain/postgres/fields/__init__.py +2118 -0
- plain/postgres/fields/encrypted.py +354 -0
- plain/postgres/fields/json.py +413 -0
- plain/postgres/fields/mixins.py +30 -0
- plain/postgres/fields/related.py +1192 -0
- plain/postgres/fields/related_descriptors.py +290 -0
- plain/postgres/fields/related_lookups.py +223 -0
- plain/postgres/fields/related_managers.py +661 -0
- plain/postgres/fields/reverse_descriptors.py +229 -0
- plain/postgres/fields/reverse_related.py +328 -0
- plain/postgres/fields/timezones.py +143 -0
- plain/postgres/forms.py +773 -0
- plain/postgres/functions/__init__.py +189 -0
- plain/postgres/functions/comparison.py +127 -0
- plain/postgres/functions/datetime.py +454 -0
- plain/postgres/functions/math.py +140 -0
- plain/postgres/functions/mixins.py +59 -0
- plain/postgres/functions/text.py +282 -0
- plain/postgres/functions/window.py +125 -0
- plain/postgres/indexes.py +286 -0
- plain/postgres/lookups.py +758 -0
- plain/postgres/meta.py +584 -0
- plain/postgres/migrations/__init__.py +53 -0
- plain/postgres/migrations/autodetector.py +1379 -0
- plain/postgres/migrations/exceptions.py +54 -0
- plain/postgres/migrations/executor.py +188 -0
- plain/postgres/migrations/graph.py +364 -0
- plain/postgres/migrations/loader.py +377 -0
- plain/postgres/migrations/migration.py +180 -0
- plain/postgres/migrations/operations/__init__.py +34 -0
- plain/postgres/migrations/operations/base.py +139 -0
- plain/postgres/migrations/operations/fields.py +373 -0
- plain/postgres/migrations/operations/models.py +798 -0
- plain/postgres/migrations/operations/special.py +184 -0
- plain/postgres/migrations/optimizer.py +74 -0
- plain/postgres/migrations/questioner.py +340 -0
- plain/postgres/migrations/recorder.py +119 -0
- plain/postgres/migrations/serializer.py +378 -0
- plain/postgres/migrations/state.py +882 -0
- plain/postgres/migrations/utils.py +147 -0
- plain/postgres/migrations/writer.py +302 -0
- plain/postgres/options.py +207 -0
- plain/postgres/otel.py +231 -0
- plain/postgres/preflight.py +336 -0
- plain/postgres/query.py +2242 -0
- plain/postgres/query_utils.py +456 -0
- plain/postgres/registry.py +217 -0
- plain/postgres/schema.py +1885 -0
- plain/postgres/sql/__init__.py +40 -0
- plain/postgres/sql/compiler.py +1869 -0
- plain/postgres/sql/constants.py +22 -0
- plain/postgres/sql/datastructures.py +222 -0
- plain/postgres/sql/query.py +2947 -0
- plain/postgres/sql/where.py +374 -0
- plain/postgres/test/__init__.py +0 -0
- plain/postgres/test/pytest.py +117 -0
- plain/postgres/test/utils.py +18 -0
- plain/postgres/transaction.py +222 -0
- plain/postgres/types.py +92 -0
- plain/postgres/types.pyi +751 -0
- plain/postgres/utils.py +345 -0
- plain_postgres-0.84.0.dist-info/METADATA +937 -0
- plain_postgres-0.84.0.dist-info/RECORD +93 -0
- plain_postgres-0.84.0.dist-info/WHEEL +4 -0
- plain_postgres-0.84.0.dist-info/entry_points.txt +5 -0
- plain_postgres-0.84.0.dist-info/licenses/LICENSE +61 -0
|
@@ -0,0 +1,374 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Code to manage the creation and SQL rendering of 'where' constraints.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import operator
|
|
8
|
+
from functools import cached_property, reduce
|
|
9
|
+
from typing import TYPE_CHECKING, Any
|
|
10
|
+
|
|
11
|
+
from plain.postgres.exceptions import EmptyResultSet, FullResultSet
|
|
12
|
+
from plain.postgres.expressions import Case, ResolvableExpression, When
|
|
13
|
+
from plain.postgres.lookups import Exact
|
|
14
|
+
from plain.utils import tree
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
from plain.postgres.connection import DatabaseConnection
|
|
18
|
+
from plain.postgres.lookups import Lookup
|
|
19
|
+
from plain.postgres.sql.compiler import SQLCompiler
|
|
20
|
+
|
|
21
|
+
# Connection types: the SQL keyword each WhereNode uses to join its children.
AND, OR, XOR = "AND", "OR", "XOR"
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class WhereNode(tree.Node):
    """
    An SQL WHERE clause.

    The class is tied to the Query class that created it (in order to create
    the correct SQL).

    A child is usually an expression producing boolean values. Most likely the
    expression is a Lookup instance.

    However, a child could also be any class with as_sql() and either
    relabeled_clone() method or relabel_aliases() and clone() methods and
    contains_aggregate attribute.
    """

    default = AND
    # Set to True on the clone returned by resolve_expression(); as_sql()
    # always parenthesizes the SQL of a resolved node.
    resolved = False
    conditional = True

    def split_having_qualify(
        self, negated: bool = False, must_group_by: bool = False
    ) -> tuple[WhereNode | None, WhereNode | None, WhereNode | None]:
        """
        Return three possibly None nodes: one for those parts of self that
        should be included in the WHERE clause, one for those parts of self
        that must be included in the HAVING clause, and one for those parts
        that refer to window functions.

        ``negated`` is the negation state inherited from enclosing nodes; it is
        combined with this node's own ``negated`` flag and passed on to nested
        nodes together with ``must_group_by``.

        Raise NotImplementedError when children that must stay connected mix
        window-function predicates with plain predicates while
        ``must_group_by`` is set.
        """
        if not self.contains_aggregate and not self.contains_over_clause:
            return self, None, None
        in_negated = negated ^ self.negated
        # Whether or not children must be connected in the same filtering
        # clause (WHERE > HAVING > QUALIFY) to maintain logical semantic.
        must_remain_connected = (
            (in_negated and self.connector == AND)
            or (not in_negated and self.connector == OR)
            or self.connector == XOR
        )
        if (
            must_remain_connected
            and self.contains_aggregate
            and not self.contains_over_clause
        ):
            # It's much cheaper to short-circuit and stash everything in the
            # HAVING clause than split children if possible.
            return None, self, None
        where_parts = []
        having_parts = []
        qualify_parts = []
        for c in self.children:
            if hasattr(c, "split_having_qualify"):
                # Nested node: let it split itself under the combined negation.
                where_part, having_part, qualify_part = c.split_having_qualify(
                    in_negated, must_group_by
                )
                if where_part is not None:
                    where_parts.append(where_part)
                if having_part is not None:
                    having_parts.append(having_part)
                if qualify_part is not None:
                    qualify_parts.append(qualify_part)
            elif c.contains_over_clause:
                # Window-function references win over aggregates.
                qualify_parts.append(c)
            elif c.contains_aggregate:
                having_parts.append(c)
            else:
                where_parts.append(c)
        if must_remain_connected and qualify_parts:
            # Disjunctive heterogeneous predicates can be pushed down to
            # qualify as long as no conditional aggregation is involved.
            if not where_parts or not must_group_by:
                return None, None, self
            # In theory this should only be enforced when dealing with
            # where_parts containing predicates against multi-valued
            # relationships that could affect aggregation results but this
            # is complex to infer properly.
            raise NotImplementedError(
                "Heterogeneous disjunctive predicates against window functions are "
                "not implemented when performing conditional aggregation."
            )
        where_node = (
            self.create(where_parts, self.connector, self.negated)
            if where_parts
            else None
        )
        having_node = (
            self.create(having_parts, self.connector, self.negated)
            if having_parts
            else None
        )
        qualify_node = (
            self.create(qualify_parts, self.connector, self.negated)
            if qualify_parts
            else None
        )
        return where_node, having_node, qualify_node

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, list[Any]]:
        """
        Return the SQL version of the where clause and the parameters to be
        substituted in.

        Raise FullResultSet if this node matches everything (including when
        no child produces any SQL) and EmptyResultSet if this node can't
        match anything. Negation swaps the two.
        """
        result = []
        result_params = []
        # How many children must compile to "full" / "empty" before the whole
        # node collapses: AND needs every child full but only one empty; the
        # other connectors need one full child or every child empty.
        if self.connector == AND:
            full_needed, empty_needed = len(self.children), 1
        else:
            full_needed, empty_needed = 1, len(self.children)

        if self.connector == XOR:
            # PostgreSQL doesn't have a native XOR operator, so convert:
            #   a XOR b XOR c XOR ...
            # to:
            #   (a OR b OR c OR ...) AND (a + b + c + ...) == 1
            lhs = self.__class__(self.children, OR)
            rhs_sum = reduce(
                operator.add,
                (Case(When(c, then=1), default=0) for c in self.children),
            )
            rhs = Exact(1, rhs_sum)
            return self.__class__([lhs, rhs], AND, self.negated).as_sql(
                compiler, connection
            )

        for child in self.children:
            try:
                sql, params = compiler.compile(child)
            except EmptyResultSet:
                empty_needed -= 1
            except FullResultSet:
                full_needed -= 1
            else:
                if sql:
                    result.append(sql)
                    result_params.extend(params)
                else:
                    # A child that renders no SQL matches everything.
                    full_needed -= 1
            # Check after every child whether the counts already decide that
            # this node matches nothing or everything.
            if empty_needed == 0:
                if self.negated:
                    raise FullResultSet
                else:
                    raise EmptyResultSet
            if full_needed == 0:
                if self.negated:
                    raise EmptyResultSet
                else:
                    raise FullResultSet
        conn = f" {self.connector} "
        sql_string = conn.join(result)
        if not sql_string:
            raise FullResultSet
        if self.negated:
            sql_string = f"NOT ({sql_string})"
        elif len(result) > 1 or self.resolved:
            sql_string = f"({sql_string})"
        return sql_string, result_params

    def get_group_by_cols(self) -> list[Any]:
        """Return the concatenated GROUP BY columns of all children."""
        cols = []
        for child in self.children:
            cols.extend(child.get_group_by_cols())
        return cols

    def get_source_expressions(self) -> list[Any]:
        """Return a shallow copy of the children list."""
        return self.children[:]

    def set_source_expressions(self, children: list[Any]) -> None:
        """Replace the children; the new list must have the same length."""
        assert len(children) == len(self.children)
        self.children = children

    def relabel_aliases(self, change_map: dict[str, str]) -> None:
        """
        Relabel the alias values of any children. 'change_map' is a dictionary
        mapping old (current) alias values to the new values.
        """
        for pos, child in enumerate(self.children):
            if hasattr(child, "relabel_aliases"):
                # For example another WhereNode
                child.relabel_aliases(change_map)
            elif hasattr(child, "relabeled_clone"):
                self.children[pos] = child.relabeled_clone(change_map)

    def clone(self) -> WhereNode:
        """
        Return a copy with the same connector and negation; children that
        provide clone() are cloned, others are shared with the original.
        """
        clone = self.create(connector=self.connector, negated=self.negated)
        for child in self.children:
            if hasattr(child, "clone"):
                child = child.clone()
            clone.children.append(child)
        return clone

    def relabeled_clone(self, change_map: dict[str, str]) -> WhereNode:
        """Return a clone whose aliases are relabeled per ``change_map``."""
        clone = self.clone()
        clone.relabel_aliases(change_map)
        return clone

    def replace_expressions(self, replacements: dict[Any, Any]) -> WhereNode:
        """
        Return this node with expressions swapped per ``replacements``.

        If this node itself is a key, its replacement is returned as-is;
        otherwise a new node with the same connector and negation is built
        from each child's replace_expressions() result.
        """
        # Compare against None rather than relying on truthiness: a legitimate
        # replacement object may evaluate as falsy (for instance if it defines
        # __bool__/__len__) and must not be silently ignored.
        replacement = replacements.get(self)
        if replacement is not None:
            return replacement
        clone = self.create(connector=self.connector, negated=self.negated)
        for child in self.children:
            clone.children.append(child.replace_expressions(replacements))
        return clone

    def get_refs(self) -> set[Any]:
        """Return the union of every child's get_refs()."""
        refs = set()
        for child in self.children:
            refs |= child.get_refs()
        return refs

    @classmethod
    def _contains_aggregate(cls, obj: Any) -> bool:
        """Recurse through tree nodes; leaves report ``contains_aggregate``."""
        if isinstance(obj, tree.Node):
            return any(cls._contains_aggregate(c) for c in obj.children)
        return obj.contains_aggregate

    @cached_property
    def contains_aggregate(self) -> bool:
        return self._contains_aggregate(self)

    @classmethod
    def _contains_over_clause(cls, obj: Any) -> bool:
        """Recurse through tree nodes; leaves report ``contains_over_clause``."""
        if isinstance(obj, tree.Node):
            return any(cls._contains_over_clause(c) for c in obj.children)
        return obj.contains_over_clause

    @cached_property
    def contains_over_clause(self) -> bool:
        return self._contains_over_clause(self)

    @property
    def is_summary(self) -> bool:
        return any(child.is_summary for child in self.children)

    @staticmethod
    def _resolve_leaf(expr: Any, query: Any, *args: Any, **kwargs: Any) -> Any:
        """Resolve ``expr`` against ``query`` if it is resolvable; else return it."""
        if isinstance(expr, ResolvableExpression):
            expr = expr.resolve_expression(query, *args, **kwargs)
        return expr

    @classmethod
    def _resolve_node(cls, node: Any, query: Any, *args: Any, **kwargs: Any) -> None:
        """
        Resolve, in place, the lhs/rhs of ``node`` and of all its descendants.

        Mutates the nodes it visits, which is why resolve_expression() works
        on a clone.
        """
        if hasattr(node, "children"):
            for child in node.children:
                cls._resolve_node(child, query, *args, **kwargs)
        if hasattr(node, "lhs"):
            node.lhs = cls._resolve_leaf(node.lhs, query, *args, **kwargs)
        if hasattr(node, "rhs"):
            node.rhs = cls._resolve_leaf(node.rhs, query, *args, **kwargs)

    def resolve_expression(self, *args: Any, **kwargs: Any) -> WhereNode:
        """Return a resolved clone of this node (marked ``resolved``)."""
        clone = self.clone()
        clone._resolve_node(clone, *args, **kwargs)
        clone.resolved = True
        return clone

    @cached_property
    def output_field(self) -> Any:
        from plain.postgres.fields import BooleanField

        return BooleanField()

    @property
    def _output_field_or_none(self) -> Any:
        return self.output_field

    def select_format(
        self, compiler: SQLCompiler, sql: str, params: list[Any]
    ) -> tuple[str, list[Any]]:
        # Boolean expressions work directly in SELECT
        return sql, params

    def get_db_converters(self, connection: DatabaseConnection) -> list[Any]:
        return self.output_field.get_db_converters(connection)

    def get_lookup(self, lookup: str) -> type[Lookup] | None:
        return self.output_field.get_lookup(lookup)

    def leaves(self) -> Any:
        """Yield every non-WhereNode descendant, depth-first, left to right."""
        for child in self.children:
            if isinstance(child, WhereNode):
                yield from child.leaves()
            else:
                yield child
|
|
319
|
+
|
|
320
|
+
|
|
321
|
+
class NothingNode:
    """
    A where-clause child that can never match any row.

    Compiling it always raises EmptyResultSet, and it reports no aggregates
    and no window expressions.
    """

    contains_over_clause = False
    contains_aggregate = False

    def as_sql(
        self,
        compiler: SQLCompiler | None = None,
        connection: DatabaseConnection | None = None,
    ) -> tuple[str, list[Any]]:
        """Signal "matches nothing" to the caller by raising EmptyResultSet."""
        raise EmptyResultSet
|
|
333
|
+
|
|
334
|
+
|
|
335
|
+
class ExtraWhere:
    """
    Raw SQL condition fragments joined with AND.

    Each fragment in ``sqls`` is wrapped in parentheses and the fragments are
    joined with " AND ". ``params`` (if any) are returned as a fresh list.
    """

    # Raw SQL is opaque, so it is assumed to use neither aggregates nor
    # window functions.
    contains_aggregate = False
    contains_over_clause = False

    def __init__(self, sqls: list[str], params: list[Any] | None):
        self.sqls = sqls
        self.params = params

    def as_sql(
        self,
        compiler: SQLCompiler | None = None,
        connection: DatabaseConnection | None = None,
    ) -> tuple[str, list[Any]]:
        """
        Return the AND-joined, parenthesized fragments and their params.

        ``compiler`` and ``connection`` are accepted for interface
        compatibility and are not used.
        """
        joined = " AND ".join(f"({fragment})" for fragment in self.sqls)
        bound_params = list(self.params) if self.params else []
        return joined, bound_params
|
|
351
|
+
|
|
352
|
+
|
|
353
|
+
class SubqueryConstraint:
    """
    A where-clause child that constrains ``columns`` of the table aliased
    ``alias`` against the rows produced by ``query_object``.

    Rendering is delegated to the subquery compiler's
    ``as_subquery_condition``.
    """

    # Aggregates or window functions used inside the subquery belong to the
    # subquery; the outer query treats this constraint as having neither.
    contains_aggregate = False
    contains_over_clause = False

    def __init__(
        self, alias: str, columns: list[str], targets: list[Any], query_object: Any
    ):
        self.alias = alias
        self.columns = columns
        self.targets = targets
        # Strip ordering from the subquery (clear_default=True) before storing it.
        query_object.clear_ordering(clear_default=True)
        self.query_object = query_object

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, list[Any]]:
        """
        Render the subquery condition.

        Note: calls ``set_values(self.targets)`` on the stored query object,
        mutating it each time this is compiled.
        """
        subquery = self.query_object
        subquery.set_values(self.targets)
        return subquery.get_compiler().as_subquery_condition(
            self.alias, self.columns, compiler
        )
|
|
File without changes
|
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import re
|
|
4
|
+
from collections.abc import Generator
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import pytest
|
|
8
|
+
|
|
9
|
+
from plain.postgres.otel import suppress_db_tracing
|
|
10
|
+
from plain.signals import request_finished, request_started
|
|
11
|
+
|
|
12
|
+
from .. import transaction
|
|
13
|
+
from ..connection import DatabaseConnection
|
|
14
|
+
from ..db import close_old_connections, get_connection
|
|
15
|
+
from .utils import (
|
|
16
|
+
setup_database,
|
|
17
|
+
teardown_database,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@pytest.fixture(autouse=True)
def _db_disabled() -> Generator[None]:
    """
    Block database access for every test unless a database fixture is used.

    Autouse: replaces ``DatabaseConnection.cursor`` with a stub that fails the
    test, stashing the real method on the class as ``_enabled_cursor`` so the
    ``db`` / ``isolated_db`` fixtures can reinstate it. The stashed method is
    put back once the test finishes.
    """

    def cursor_disabled(self: Any) -> None:
        pytest.fail("Database access not allowed without the `db` fixture")  # type: ignore[invalid-argument-type]

    real_cursor = DatabaseConnection.cursor
    setattr(DatabaseConnection, "_enabled_cursor", real_cursor)
    DatabaseConnection.cursor = cursor_disabled  # type: ignore[assignment]

    yield

    # Reinstate whatever was stashed as the real cursor method.
    DatabaseConnection.cursor = getattr(DatabaseConnection, "_enabled_cursor")
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@pytest.fixture(scope="session")
def setup_db(request: Any) -> Generator[None]:
    """
    Create the test database once per session, on first use.

    Requested by ``db``, so a test database is only set up when some test uses
    ``db``. While active, ``close_old_connections`` is disconnected from the
    request start/finish signals so connections stay open during request
    client testing; the handlers are reconnected and the test database torn
    down when the session ends.
    """
    verbosity = request.config.option.verbose
    original_db_name = setup_database(verbosity=verbosity)

    # Keep connections open during request client / testing
    for signal in (request_started, request_finished):
        signal.disconnect(close_old_connections)

    yield

    # Put the signals back...
    for signal in (request_started, request_finished):
        signal.connect(close_old_connections)

    # Session is over: drop the test database.
    teardown_database(original_db_name, verbosity=verbosity)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
@pytest.fixture
def db(setup_db: Any, request: Any) -> Generator[None]:
    """
    Give a test database access inside a transaction that is rolled back.

    Re-enables ``DatabaseConnection.cursor`` (blocked by ``_db_disabled``) and
    opens an atomic block before the test. On teardown it runs the constraint
    check, forces a rollback, exits the atomic block and closes the
    connection. The cleanup steps run even if the constraint check raises,
    and that error is still reported.

    Cannot be combined with ``isolated_db``.
    """
    if "isolated_db" in request.fixturenames:
        pytest.fail("The 'db' and 'isolated_db' fixtures cannot be used together")  # type: ignore[invalid-argument-type]

    # Set .cursor() back to the original implementation to unblock it
    DatabaseConnection.cursor = getattr(DatabaseConnection, "_enabled_cursor")

    with suppress_db_tracing():
        atomic = transaction.atomic()
        # NOTE(review): flag read by transaction.atomic internals; its exact
        # effect is not visible here.
        atomic._from_testcase = True
        atomic.__enter__()

    yield

    with suppress_db_tracing():
        conn = get_connection()
        try:
            # PostgreSQL can defer constraint checks until commit, and this
            # transaction is never committed, so check them explicitly.
            if not conn.needs_rollback and conn.is_usable():
                conn.check_constraints()
        finally:
            # Always roll back, leave the atomic block and close the
            # connection. Without this, a failing constraint check would leave
            # the transaction open on this connection.
            try:
                conn.set_rollback(True)
                atomic.__exit__(None, None, None)
            finally:
                conn.close()
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
@pytest.fixture
def isolated_db(request: Any) -> Generator[None]:
    """
    Create and destroy a unique test database for each test, using a prefix
    derived from the test function name to ensure isolation from the default
    test database.

    Cannot be combined with ``db``. Re-enables ``DatabaseConnection.cursor``
    (blocked by ``_db_disabled``) and always uses verbosity 1.
    """
    if "db" in request.fixturenames:
        pytest.fail("The 'db' and 'isolated_db' fixtures cannot be used together")  # type: ignore[invalid-argument-type]

    # Set .cursor() back to the original implementation to unblock it
    DatabaseConnection.cursor = getattr(DatabaseConnection, "_enabled_cursor")

    verbosity = 1
    # Collapse every run of characters outside [0-9A-Za-z_] (e.g. the
    # brackets and dashes of parametrized test ids) into a single "_".
    db_prefix = re.sub(r"[^0-9A-Za-z_]+", "_", request.node.name)

    original_db_name = setup_database(verbosity=verbosity, prefix=db_prefix)

    yield

    # Drop the database that was created just for this test.
    teardown_database(original_db_name, verbosity=verbosity)
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from plain.postgres.db import get_connection
|
|
4
|
+
from plain.postgres.otel import suppress_db_tracing
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def setup_database(*, verbosity: int, prefix: str = "") -> str:
    """
    Create the test database and return the original ``DATABASE`` name.

    The original name is read from the current connection's settings before
    ``create_test_db`` runs (with DB tracing suppressed). Pass the returned
    name to ``teardown_database``.
    """
    connection = get_connection()
    original_name = connection.settings_dict["DATABASE"]
    assert original_name is not None, "DATABASE setting must be set before creating test db"
    with suppress_db_tracing():
        connection.create_test_db(verbosity=verbosity, prefix=prefix)
    return original_name
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def teardown_database(old_name: str, verbosity: int) -> None:
    """
    Destroy the test database with DB tracing suppressed.

    ``old_name`` is the value returned by ``setup_database`` and is passed
    straight through to ``destroy_test_db``.
    """
    with suppress_db_tracing():
        connection = get_connection()
        connection.destroy_test_db(old_name, verbosity)
|