plain.models 0.49.2__py3-none-any.whl → 0.51.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plain/models/CHANGELOG.md +27 -0
- plain/models/README.md +26 -42
- plain/models/__init__.py +2 -0
- plain/models/aggregates.py +42 -19
- plain/models/backends/base/base.py +125 -105
- plain/models/backends/base/client.py +11 -3
- plain/models/backends/base/creation.py +24 -14
- plain/models/backends/base/features.py +10 -4
- plain/models/backends/base/introspection.py +37 -20
- plain/models/backends/base/operations.py +187 -91
- plain/models/backends/base/schema.py +338 -218
- plain/models/backends/base/validation.py +13 -4
- plain/models/backends/ddl_references.py +85 -43
- plain/models/backends/mysql/base.py +29 -26
- plain/models/backends/mysql/client.py +7 -2
- plain/models/backends/mysql/compiler.py +13 -4
- plain/models/backends/mysql/creation.py +5 -2
- plain/models/backends/mysql/features.py +24 -22
- plain/models/backends/mysql/introspection.py +22 -13
- plain/models/backends/mysql/operations.py +107 -40
- plain/models/backends/mysql/schema.py +52 -28
- plain/models/backends/mysql/validation.py +13 -6
- plain/models/backends/postgresql/base.py +41 -34
- plain/models/backends/postgresql/client.py +7 -2
- plain/models/backends/postgresql/creation.py +10 -5
- plain/models/backends/postgresql/introspection.py +15 -8
- plain/models/backends/postgresql/operations.py +110 -43
- plain/models/backends/postgresql/schema.py +88 -49
- plain/models/backends/sqlite3/_functions.py +151 -115
- plain/models/backends/sqlite3/base.py +37 -23
- plain/models/backends/sqlite3/client.py +7 -1
- plain/models/backends/sqlite3/creation.py +9 -5
- plain/models/backends/sqlite3/features.py +5 -3
- plain/models/backends/sqlite3/introspection.py +32 -16
- plain/models/backends/sqlite3/operations.py +126 -43
- plain/models/backends/sqlite3/schema.py +127 -92
- plain/models/backends/utils.py +52 -29
- plain/models/backups/cli.py +8 -6
- plain/models/backups/clients.py +16 -7
- plain/models/backups/core.py +24 -13
- plain/models/base.py +221 -229
- plain/models/cli.py +98 -67
- plain/models/config.py +1 -1
- plain/models/connections.py +23 -7
- plain/models/constraints.py +79 -56
- plain/models/database_url.py +1 -1
- plain/models/db.py +6 -2
- plain/models/deletion.py +80 -56
- plain/models/entrypoints.py +1 -1
- plain/models/enums.py +22 -11
- plain/models/exceptions.py +23 -8
- plain/models/expressions.py +441 -258
- plain/models/fields/__init__.py +272 -217
- plain/models/fields/json.py +123 -57
- plain/models/fields/mixins.py +12 -8
- plain/models/fields/related.py +324 -290
- plain/models/fields/related_descriptors.py +33 -24
- plain/models/fields/related_lookups.py +24 -12
- plain/models/fields/related_managers.py +102 -79
- plain/models/fields/reverse_related.py +66 -63
- plain/models/forms.py +101 -75
- plain/models/functions/comparison.py +71 -18
- plain/models/functions/datetime.py +79 -29
- plain/models/functions/math.py +43 -10
- plain/models/functions/mixins.py +24 -7
- plain/models/functions/text.py +104 -25
- plain/models/functions/window.py +12 -6
- plain/models/indexes.py +57 -32
- plain/models/lookups.py +228 -153
- plain/models/meta.py +505 -0
- plain/models/migrations/autodetector.py +86 -43
- plain/models/migrations/exceptions.py +7 -3
- plain/models/migrations/executor.py +33 -7
- plain/models/migrations/graph.py +79 -50
- plain/models/migrations/loader.py +45 -22
- plain/models/migrations/migration.py +23 -18
- plain/models/migrations/operations/base.py +38 -20
- plain/models/migrations/operations/fields.py +95 -48
- plain/models/migrations/operations/models.py +246 -142
- plain/models/migrations/operations/special.py +82 -25
- plain/models/migrations/optimizer.py +7 -2
- plain/models/migrations/questioner.py +58 -31
- plain/models/migrations/recorder.py +27 -16
- plain/models/migrations/serializer.py +50 -39
- plain/models/migrations/state.py +232 -156
- plain/models/migrations/utils.py +30 -14
- plain/models/migrations/writer.py +17 -14
- plain/models/options.py +189 -518
- plain/models/otel.py +16 -6
- plain/models/preflight.py +42 -17
- plain/models/query.py +400 -251
- plain/models/query_utils.py +109 -69
- plain/models/registry.py +40 -21
- plain/models/sql/compiler.py +190 -127
- plain/models/sql/datastructures.py +38 -25
- plain/models/sql/query.py +320 -225
- plain/models/sql/subqueries.py +36 -25
- plain/models/sql/where.py +54 -29
- plain/models/test/pytest.py +15 -11
- plain/models/test/utils.py +4 -2
- plain/models/transaction.py +20 -7
- plain/models/utils.py +17 -6
- {plain_models-0.49.2.dist-info → plain_models-0.51.0.dist-info}/METADATA +27 -43
- plain_models-0.51.0.dist-info/RECORD +123 -0
- plain_models-0.49.2.dist-info/RECORD +0 -122
- {plain_models-0.49.2.dist-info → plain_models-0.51.0.dist-info}/WHEEL +0 -0
- {plain_models-0.49.2.dist-info → plain_models-0.51.0.dist-info}/entry_points.txt +0 -0
- {plain_models-0.49.2.dist-info → plain_models-0.51.0.dist-info}/licenses/LICENSE +0 -0
plain/models/sql/subqueries.py
CHANGED
@@ -2,6 +2,10 @@
 Query subclasses which provide extra functionality beyond simple data retrieval.
 """

+from __future__ import annotations
+
+from typing import Any
+
 from plain.models.exceptions import FieldError
 from plain.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE, NO_RESULTS
 from plain.models.sql.query import Query
@@ -14,7 +18,7 @@ class DeleteQuery(Query):

     compiler = "SQLDeleteCompiler"

-    def do_query(self, table, where):
+    def do_query(self, table: str, where: Any) -> int:
         self.alias_map = {table: self.alias_map[table]}
         self.where = where
         cursor = self.get_compiler().execute_sql(CURSOR)
@@ -23,7 +27,7 @@ class DeleteQuery(Query):
             return cursor.rowcount
         return 0

-    def delete_batch(self, id_list):
+    def delete_batch(self, id_list: list[Any]) -> int:
         """
         Set up and execute delete queries for all the objects in id_list.

@@ -32,14 +36,14 @@ class DeleteQuery(Query):
         """
         # number of objects deleted
         num_deleted = 0
-        field = self.
+        field = self.get_model_meta().get_field("id")
         for offset in range(0, len(id_list), GET_ITERATOR_CHUNK_SIZE):
             self.clear_where()
             self.add_filter(
                 f"{field.attname}__in",
                 id_list[offset : offset + GET_ITERATOR_CHUNK_SIZE],
             )
-            num_deleted += self.do_query(self.
+            num_deleted += self.do_query(self.model.model_options.db_table, self.where)
         return num_deleted


@@ -48,25 +52,25 @@ class UpdateQuery(Query):

     compiler = "SQLUpdateCompiler"

-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)
         self._setup_query()

-    def _setup_query(self):
+    def _setup_query(self) -> None:
         """
         Run on initialization and at the end of chaining. Any attributes that
         would normally be set in __init__() should go here instead.
         """
-        self.values = []
-        self.related_ids = None
-        self.related_updates = {}
+        self.values: list[tuple[Any, Any, Any]] = []
+        self.related_ids: dict[Any, list[Any]] | None = None
+        self.related_updates: dict[Any, list[tuple[Any, Any, Any]]] = {}

-    def clone(self):
+    def clone(self) -> UpdateQuery:
         obj = super().clone()
         obj.related_updates = self.related_updates.copy()
-        return obj
+        return obj  # type: ignore[return-value]

-    def update_batch(self, id_list, values):
+    def update_batch(self, id_list: list[Any], values: dict[str, Any]) -> None:
         self.add_update_values(values)
         for offset in range(0, len(id_list), GET_ITERATOR_CHUNK_SIZE):
             self.clear_where()
@@ -75,7 +79,7 @@ class UpdateQuery(Query):
             )
             self.get_compiler().execute_sql(NO_RESULTS)

-    def add_update_values(self, values):
+    def add_update_values(self, values: dict[str, Any]) -> list[tuple[Any, Any, Any]]:
         """
         Convert a dictionary of field name to value mappings into an update
         query. This is the entry point for the public update() method on
@@ -83,7 +87,7 @@ class UpdateQuery(Query):
         """
         values_seq = []
         for name, val in values.items():
-            field = self.
+            field = self.get_model_meta().get_field(name)
             direct = (
                 not (field.auto_created and not field.concrete) or not field.concrete
             )
@@ -93,13 +97,13 @@ class UpdateQuery(Query):
                     f"Cannot update model field {field!r} (only non-relations and "
                     "foreign keys permitted)."
                 )
-            if model is not self.
+            if model is not self.get_model_meta().model:
                 self.add_related_update(model, field, val)
                 continue
             values_seq.append((field, model, val))
         return self.add_update_fields(values_seq)

-    def add_update_fields(self, values_seq):
+    def add_update_fields(self, values_seq: list[tuple[Any, Any, Any]]) -> None:
         """
         Append a sequence of (field, model, value) triples to the internal list
         that will be used to generate the UPDATE query. Might be more usefully
@@ -111,7 +115,7 @@ class UpdateQuery(Query):
                 val = val.resolve_expression(self, allow_joins=False, for_save=True)
             self.values.append((field, model, val))

-    def add_related_update(self, model, field, value):
+    def add_related_update(self, model: Any, field: Any, value: Any) -> None:
         """
         Add (name, value) to an update query for an ancestor model.

@@ -119,7 +123,7 @@ class UpdateQuery(Query):
         """
         self.related_updates.setdefault(model, []).append((field, None, value))

-    def get_related_updates(self):
+    def get_related_updates(self) -> list[UpdateQuery]:
         """
         Return a list of query objects: one for each update required to an
         ancestor model. Each query will have the same filtering conditions as
@@ -141,19 +145,26 @@ class InsertQuery(Query):
     compiler = "SQLInsertCompiler"

     def __init__(
-        self,
-
+        self,
+        *args: Any,
+        on_conflict: str | None = None,
+        update_fields: list[Any] | None = None,
+        unique_fields: list[Any] | None = None,
+        **kwargs: Any,
+    ) -> None:
         super().__init__(*args, **kwargs)
-        self.fields = []
-        self.objs = []
+        self.fields: list[Any] = []
+        self.objs: list[Any] = []
         self.on_conflict = on_conflict
         self.update_fields = update_fields or []
         self.unique_fields = unique_fields or []

-    def insert_values(
+    def insert_values(
+        self, fields: list[Any], objs: list[Any], raw: bool = False
+    ) -> None:
         self.fields = fields
         self.objs = objs
-        self.raw = raw
+        self.raw = raw  # type: ignore[attr-defined]


 class AggregateQuery(Query):
@@ -164,6 +175,6 @@ class AggregateQuery(Query):

     compiler = "SQLAggregateCompiler"

-    def __init__(self, model, inner_query):
+    def __init__(self, model: Any, inner_query: Any) -> None:
         self.inner_query = inner_query
         super().__init__(model)

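
A pattern worth noting across this release (visible in the hunks above): every touched module adds `from __future__ import annotations`, which makes annotations lazily evaluated strings (PEP 563), so a method such as `clone()` can annotate its return type as the enclosing class (`UpdateQuery`) without quoting it as a forward reference. A minimal, self-contained sketch of the same idea; the `Node` class below is illustrative only, not part of plain.models:

from __future__ import annotations


class Node:
    """Illustrative stand-in; not a plain.models class."""

    def __init__(self, children: list[Node] | None = None) -> None:
        # `list[Node]` names this class before the class body has finished
        # executing; PEP 563 string annotations make that legal without
        # writing "Node" as a quoted forward reference.
        self.children = children or []

    def clone(self) -> Node:
        # The return type names the enclosing class directly, mirroring
        # UpdateQuery.clone() -> UpdateQuery in the hunk above.
        return Node(list(self.children))
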
plain/models/sql/where.py
CHANGED
@@ -2,14 +2,21 @@
 Code to manage the creation and SQL rendering of 'where' constraints.
 """

+from __future__ import annotations
+
 import operator
 from functools import cached_property, reduce
+from typing import TYPE_CHECKING, Any

 from plain.models.exceptions import EmptyResultSet, FullResultSet
 from plain.models.expressions import Case, When
 from plain.models.lookups import Exact
 from plain.utils import tree

+if TYPE_CHECKING:
+    from plain.models.backends.base.base import BaseDatabaseWrapper
+    from plain.models.sql.compiler import SQLCompiler
+
 # Connection types
 AND = "AND"
 OR = "OR"
@@ -35,7 +42,9 @@ class WhereNode(tree.Node):
     resolved = False
     conditional = True

-    def split_having_qualify(
+    def split_having_qualify(
+        self, negated: bool = False, must_group_by: bool = False
+    ) -> tuple[WhereNode | None, WhereNode | None, WhereNode | None]:
         """
         Return three possibly None nodes: one for those parts of self that
         should be included in the WHERE clause, one for those parts of self
@@ -111,7 +120,9 @@ class WhereNode(tree.Node):
         )
         return where_node, having_node, qualify_node

-    def as_sql(
+    def as_sql(
+        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
+    ) -> tuple[str, list[Any]]:
         """
         Return the SQL version of the where clause and the value to be
         substituted in. Return '', [] if this node matches everything,
@@ -181,20 +192,20 @@ class WhereNode(tree.Node):
                 sql_string = f"({sql_string})"
         return sql_string, result_params

-    def get_group_by_cols(self):
+    def get_group_by_cols(self) -> list[Any]:
         cols = []
         for child in self.children:
             cols.extend(child.get_group_by_cols())
         return cols

-    def get_source_expressions(self):
+    def get_source_expressions(self) -> list[Any]:
         return self.children[:]

-    def set_source_expressions(self, children):
+    def set_source_expressions(self, children: list[Any]) -> None:
         assert len(children) == len(self.children)
         self.children = children

-    def relabel_aliases(self, change_map):
+    def relabel_aliases(self, change_map: dict[str, str]) -> None:
         """
         Relabel the alias values of any children. 'change_map' is a dictionary
         mapping old (current) alias values to the new values.
@@ -206,7 +217,7 @@ class WhereNode(tree.Node):
             elif hasattr(child, "relabeled_clone"):
                 self.children[pos] = child.relabeled_clone(change_map)

-    def clone(self):
+    def clone(self) -> WhereNode:
         clone = self.create(connector=self.connector, negated=self.negated)
         for child in self.children:
             if hasattr(child, "clone"):
@@ -214,12 +225,12 @@ class WhereNode(tree.Node):
                 clone.children.append(child)
         return clone

-    def relabeled_clone(self, change_map):
+    def relabeled_clone(self, change_map: dict[str, str]) -> WhereNode:
         clone = self.clone()
         clone.relabel_aliases(change_map)
         return clone

-    def replace_expressions(self, replacements):
+    def replace_expressions(self, replacements: dict[Any, Any]) -> WhereNode:
         if replacement := replacements.get(self):
             return replacement
         clone = self.create(connector=self.connector, negated=self.negated)
@@ -227,44 +238,44 @@ class WhereNode(tree.Node):
             clone.children.append(child.replace_expressions(replacements))
         return clone

-    def get_refs(self):
+    def get_refs(self) -> set[Any]:
         refs = set()
         for child in self.children:
             refs |= child.get_refs()
         return refs

     @classmethod
-    def _contains_aggregate(cls, obj):
+    def _contains_aggregate(cls, obj: Any) -> bool:
         if isinstance(obj, tree.Node):
             return any(cls._contains_aggregate(c) for c in obj.children)
         return obj.contains_aggregate

     @cached_property
-    def contains_aggregate(self):
+    def contains_aggregate(self) -> bool:
         return self._contains_aggregate(self)

     @classmethod
-    def _contains_over_clause(cls, obj):
+    def _contains_over_clause(cls, obj: Any) -> bool:
         if isinstance(obj, tree.Node):
             return any(cls._contains_over_clause(c) for c in obj.children)
         return obj.contains_over_clause

     @cached_property
-    def contains_over_clause(self):
+    def contains_over_clause(self) -> bool:
         return self._contains_over_clause(self)

     @property
-    def is_summary(self):
+    def is_summary(self) -> bool:
         return any(child.is_summary for child in self.children)

     @staticmethod
-    def _resolve_leaf(expr, query, *args, **kwargs):
+    def _resolve_leaf(expr: Any, query: Any, *args: Any, **kwargs: Any) -> Any:
         if hasattr(expr, "resolve_expression"):
             expr = expr.resolve_expression(query, *args, **kwargs)
         return expr

     @classmethod
-    def _resolve_node(cls, node, query, *args, **kwargs):
+    def _resolve_node(cls, node: Any, query: Any, *args: Any, **kwargs: Any) -> None:
         if hasattr(node, "children"):
             for child in node.children:
                 cls._resolve_node(child, query, *args, **kwargs)
@@ -273,23 +284,25 @@ class WhereNode(tree.Node):
         if hasattr(node, "rhs"):
             node.rhs = cls._resolve_leaf(node.rhs, query, *args, **kwargs)

-    def resolve_expression(self, *args, **kwargs):
+    def resolve_expression(self, *args: Any, **kwargs: Any) -> WhereNode:
         clone = self.clone()
         clone._resolve_node(clone, *args, **kwargs)
         clone.resolved = True
         return clone

     @cached_property
-    def output_field(self):
+    def output_field(self) -> Any:
         from plain.models.fields import BooleanField

         return BooleanField()

     @property
-    def _output_field_or_none(self):
+    def _output_field_or_none(self) -> Any:
         return self.output_field

-    def select_format(
+    def select_format(
+        self, compiler: SQLCompiler, sql: str, params: list[Any]
+    ) -> tuple[str, list[Any]]:
         # Wrap filters with a CASE WHEN expression if a database backend
         # (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
         # BY list.
@@ -297,13 +310,13 @@ class WhereNode(tree.Node):
             sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
         return sql, params

-    def get_db_converters(self, connection):
+    def get_db_converters(self, connection: BaseDatabaseWrapper) -> list[Any]:
         return self.output_field.get_db_converters(connection)

-    def get_lookup(self, lookup):
+    def get_lookup(self, lookup: str) -> Any:
         return self.output_field.get_lookup(lookup)

-    def leaves(self):
+    def leaves(self) -> Any:
         for child in self.children:
             if isinstance(child, WhereNode):
                 yield from child.leaves()
@@ -317,7 +330,11 @@ class NothingNode:
     contains_aggregate = False
     contains_over_clause = False

-    def as_sql(
+    def as_sql(
+        self,
+        compiler: SQLCompiler | None = None,
+        connection: BaseDatabaseWrapper | None = None,
+    ) -> tuple[str, list[Any]]:
         raise EmptyResultSet


@@ -326,11 +343,15 @@ class ExtraWhere:
     contains_aggregate = False
     contains_over_clause = False

-    def __init__(self, sqls, params):
+    def __init__(self, sqls: list[str], params: list[Any] | None):
         self.sqls = sqls
         self.params = params

-    def as_sql(
+    def as_sql(
+        self,
+        compiler: SQLCompiler | None = None,
+        connection: BaseDatabaseWrapper | None = None,
+    ) -> tuple[str, list[Any]]:
         sqls = [f"({sql})" for sql in self.sqls]
         return " AND ".join(sqls), list(self.params or ())

@@ -341,14 +362,18 @@ class SubqueryConstraint:
     contains_aggregate = False
     contains_over_clause = False

-    def __init__(
+    def __init__(
+        self, alias: str, columns: list[str], targets: list[Any], query_object: Any
+    ):
         self.alias = alias
         self.columns = columns
         self.targets = targets
         query_object.clear_ordering(clear_default=True)
         self.query_object = query_object

-    def as_sql(
+    def as_sql(
+        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
+    ) -> tuple[str, list[Any]]:
         query = self.query_object
         query.set_values(self.targets)
         query_compiler = query.get_compiler()

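
The where.py hunks also introduce the `TYPE_CHECKING` guard: `SQLCompiler` and `BaseDatabaseWrapper` are imported only while a static type checker runs, so the annotations on `as_sql()` do not add runtime imports (and the import cycles they could create) between the sql and backends packages. A generic, self-contained sketch of that pattern, using a hypothetical module name rather than a plain.models import:

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by type checkers such as mypy or pyright, never at
    # runtime, which is what avoids a circular import.
    from heavy_module import Compiler  # hypothetical module


def render(compiler: Compiler) -> str:
    # With lazy annotations the hint is just a string at runtime, so the
    # missing runtime import of Compiler is never a problem.
    return f"compiled with {type(compiler).__name__}"
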
plain/models/test/pytest.py
CHANGED
@@ -1,4 +1,8 @@
+from __future__ import annotations
+
 import re
+from collections.abc import Generator
+from typing import Any

 import pytest

@@ -15,25 +19,25 @@ from .utils import (


 @pytest.fixture(autouse=True)
-def _db_disabled():
+def _db_disabled() -> Generator[None, None, None]:
     """
     Every test should use this fixture by default to prevent
     access to the normal database.
     """

-    def cursor_disabled(self):
+    def cursor_disabled(self: Any) -> None:
         pytest.fail("Database access not allowed without the `db` fixture")

-    BaseDatabaseWrapper._enabled_cursor = BaseDatabaseWrapper.cursor
-    BaseDatabaseWrapper.cursor = cursor_disabled
+    BaseDatabaseWrapper._enabled_cursor = BaseDatabaseWrapper.cursor  # type: ignore[attr-defined]
+    BaseDatabaseWrapper.cursor = cursor_disabled  # type: ignore[method-assign]

     yield

-    BaseDatabaseWrapper.cursor = BaseDatabaseWrapper._enabled_cursor
+    BaseDatabaseWrapper.cursor = BaseDatabaseWrapper._enabled_cursor  # type: ignore[method-assign]


 @pytest.fixture(scope="session")
-def setup_db(request):
+def setup_db(request: Any) -> Generator[None, None, None]:
     """
     This fixture is called automatically by `db`,
     so a test database will only be setup if the `db` fixture is used.
@@ -58,19 +62,19 @@ def setup_db(request):


 @pytest.fixture
-def db(setup_db, request):
+def db(setup_db: Any, request: Any) -> Generator[None, None, None]:
     if "isolated_db" in request.fixturenames:
         pytest.fail("The 'db' and 'isolated_db' fixtures cannot be used together")

     # Set .cursor() back to the original implementation to unblock it
-    BaseDatabaseWrapper.cursor = BaseDatabaseWrapper._enabled_cursor
+    BaseDatabaseWrapper.cursor = BaseDatabaseWrapper._enabled_cursor  # type: ignore[method-assign]

     if not db_connection.features.supports_transactions:
         pytest.fail("Database does not support transactions")

     with suppress_db_tracing():
         atomic = transaction.atomic()
-        atomic._from_testcase = True  # TODO remove this somehow?
+        atomic._from_testcase = True  # type: ignore[attr-defined] # TODO remove this somehow?
         atomic.__enter__()

     yield
@@ -90,7 +94,7 @@ def db(setup_db, request):


 @pytest.fixture
-def isolated_db(request):
+def isolated_db(request: Any) -> Generator[None, None, None]:
     """
     Create and destroy a unique test database for each test, using a prefix
     derived from the test function name to ensure isolation from the default
@@ -99,7 +103,7 @@ def isolated_db(request):
     if "db" in request.fixturenames:
         pytest.fail("The 'db' and 'isolated_db' fixtures cannot be used together")
     # Set .cursor() back to the original implementation to unblock it
-    BaseDatabaseWrapper.cursor = BaseDatabaseWrapper._enabled_cursor
+    BaseDatabaseWrapper.cursor = BaseDatabaseWrapper._enabled_cursor  # type: ignore[method-assign]

     verbosity = 1

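
For orientation, a hedged sketch of how these fixtures are consumed in a test module; the test bodies are hypothetical, and only the fixture names, the autouse blocking behavior, and the rule that `db` and `isolated_db` are mutually exclusive come from the diff above:

def test_needs_database(db):
    # Requesting `db` restores BaseDatabaseWrapper.cursor and wraps the test
    # in a transaction via transaction.atomic(), rolled back afterwards.
    ...


def test_fully_isolated(isolated_db):
    # `isolated_db` instead creates and destroys a dedicated test database;
    # requesting it together with `db` fails the test immediately.
    ...


def test_no_database_access():
    # With neither fixture, the autouse `_db_disabled` fixture makes any
    # cursor use fail the test.
    ...
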
plain/models/test/utils.py
CHANGED
@@ -1,14 +1,16 @@
+from __future__ import annotations
+
 from plain.models import db_connection
 from plain.models.otel import suppress_db_tracing


-def setup_database(*, verbosity, prefix=""):
+def setup_database(*, verbosity: int, prefix: str = "") -> str:
     old_name = db_connection.settings_dict["NAME"]
     with suppress_db_tracing():
         db_connection.creation.create_test_db(verbosity=verbosity, prefix=prefix)
     return old_name


-def teardown_database(old_name, verbosity):
+def teardown_database(old_name: str, verbosity: int) -> None:
     with suppress_db_tracing():
         db_connection.creation.destroy_test_db(old_name, verbosity)
plain/models/transaction.py
CHANGED
@@ -1,7 +1,13 @@
+from __future__ import annotations
+
+from collections.abc import Callable, Generator
 from contextlib import ContextDecorator, contextmanager
+from typing import Any, TypeVar

 from plain.models.db import DatabaseError, Error, ProgrammingError, db_connection

+F = TypeVar("F", bound=Callable[..., Any])
+

 class TransactionManagementError(ProgrammingError):
     """Transaction management is used improperly."""
@@ -10,7 +16,7 @@ class TransactionManagementError(ProgrammingError):


 @contextmanager
-def mark_for_rollback_on_error():
+def mark_for_rollback_on_error() -> Generator[None, None, None]:
     """
     Internal low-level utility to mark a transaction as "needs rollback" when
     an exception is raised while not enforcing the enclosed block to be in a
@@ -36,7 +42,7 @@ def mark_for_rollback_on_error():
         raise


-def on_commit(func, robust=False):
+def on_commit(func: Callable[[], Any], robust: bool = False) -> None:
     """
     Register `func` to be called when the current transaction is committed.
     If the current transaction is rolled back, `func` will not be called.
@@ -83,12 +89,12 @@ class Atomic(ContextDecorator):
     This is a private API.
     """

-    def __init__(self, savepoint, durable):
+    def __init__(self, savepoint: bool, durable: bool) -> None:
         self.savepoint = savepoint
         self.durable = durable
         self._from_testcase = False

-    def __enter__(self):
+    def __enter__(self) -> None:
         if (
             self.durable
             and db_connection.atomic_blocks
@@ -127,7 +133,12 @@ class Atomic(ContextDecorator):
         if db_connection.in_atomic_block:
             db_connection.atomic_blocks.append(self)

-    def __exit__(
+    def __exit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc_value: BaseException | None,
+        traceback: Any,
+    ) -> None:
         if db_connection.in_atomic_block:
             db_connection.atomic_blocks.pop()

@@ -217,8 +228,10 @@ class Atomic(ContextDecorator):
                     db_connection.in_atomic_block = False


-def atomic(
+def atomic(
+    func: F | None = None, *, savepoint: bool = True, durable: bool = False
+) -> F | Atomic:
     """Create an atomic transaction context or decorator."""
     if callable(func):
-        return Atomic(savepoint, durable)(func)
+        return Atomic(savepoint, durable)(func)  # type: ignore[return-value]
     return Atomic(savepoint, durable)

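
Since `atomic()` now spells out both call styles in its signature, a brief usage sketch; it assumes `atomic` is importable from plain.models.transaction, and the function bodies are illustrative:

from plain.models.transaction import atomic


@atomic  # bare decorator form: atomic(func) returns the wrapped callable
def transfer(amount: int) -> None:
    ...


def manual_transfer(amount: int) -> None:
    # Keyword form returns an Atomic context manager; durable=True asks for a
    # guarantee that this is the outermost atomic block.
    with atomic(durable=True):
        ...
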
plain/models/utils.py
CHANGED
@@ -1,8 +1,12 @@
+from __future__ import annotations
+
 import functools
 from collections import namedtuple
+from collections.abc import Generator
+from typing import Any


-def make_model_tuple(model):
+def make_model_tuple(model: Any) -> tuple[str, str]:
     """
     Take a model or a string of the form "package_label.ModelName" and return a
     corresponding ("package_label", "modelname") tuple. If a tuple is passed in,
@@ -15,7 +19,10 @@ def make_model_tuple(model):
             package_label, model_name = model.split(".")
             model_tuple = package_label, model_name.lower()
         else:
-            model_tuple =
+            model_tuple = (
+                model.model_options.package_label,
+                model.model_options.model_name,
+            )
         assert len(model_tuple) == 2
         return model_tuple
     except (ValueError, AssertionError):
@@ -25,7 +32,9 @@ def make_model_tuple(model):
         )


-def resolve_callables(
+def resolve_callables(
+    mapping: dict[str, Any],
+) -> Generator[tuple[str, Any], None, None]:
     """
     Generate key/value pairs for the given mapping where the values are
     evaluated if they're callable.
@@ -34,15 +43,17 @@ def resolve_callables(mapping):
         yield k, v() if callable(v) else v


-def unpickle_named_row(
+def unpickle_named_row(
+    names: tuple[str, ...], values: tuple[Any, ...]
+) -> tuple[Any, ...]:
     return create_namedtuple_class(*names)(*values)


 @functools.lru_cache
-def create_namedtuple_class(*names):
+def create_namedtuple_class(*names: str) -> type[tuple[Any, ...]]:
     # Cache type() with @lru_cache since it's too slow to be called for every
     # QuerySet evaluation.
-    def __reduce__(self):
+    def __reduce__(self: Any) -> tuple[Any, tuple[tuple[str, ...], tuple[Any, ...]]]:
         return unpickle_named_row, (names, tuple(self))

     return type(