plain.models 0.49.2__py3-none-any.whl → 0.50.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105)
  1. plain/models/CHANGELOG.md +13 -0
  2. plain/models/aggregates.py +42 -19
  3. plain/models/backends/base/base.py +125 -105
  4. plain/models/backends/base/client.py +11 -3
  5. plain/models/backends/base/creation.py +22 -12
  6. plain/models/backends/base/features.py +10 -4
  7. plain/models/backends/base/introspection.py +29 -16
  8. plain/models/backends/base/operations.py +187 -91
  9. plain/models/backends/base/schema.py +267 -165
  10. plain/models/backends/base/validation.py +12 -3
  11. plain/models/backends/ddl_references.py +85 -43
  12. plain/models/backends/mysql/base.py +29 -26
  13. plain/models/backends/mysql/client.py +7 -2
  14. plain/models/backends/mysql/compiler.py +12 -3
  15. plain/models/backends/mysql/creation.py +5 -2
  16. plain/models/backends/mysql/features.py +24 -22
  17. plain/models/backends/mysql/introspection.py +22 -13
  18. plain/models/backends/mysql/operations.py +106 -39
  19. plain/models/backends/mysql/schema.py +48 -24
  20. plain/models/backends/mysql/validation.py +13 -6
  21. plain/models/backends/postgresql/base.py +41 -34
  22. plain/models/backends/postgresql/client.py +7 -2
  23. plain/models/backends/postgresql/creation.py +10 -5
  24. plain/models/backends/postgresql/introspection.py +15 -8
  25. plain/models/backends/postgresql/operations.py +109 -42
  26. plain/models/backends/postgresql/schema.py +85 -46
  27. plain/models/backends/sqlite3/_functions.py +151 -115
  28. plain/models/backends/sqlite3/base.py +37 -23
  29. plain/models/backends/sqlite3/client.py +7 -1
  30. plain/models/backends/sqlite3/creation.py +9 -5
  31. plain/models/backends/sqlite3/features.py +5 -3
  32. plain/models/backends/sqlite3/introspection.py +32 -16
  33. plain/models/backends/sqlite3/operations.py +125 -42
  34. plain/models/backends/sqlite3/schema.py +82 -58
  35. plain/models/backends/utils.py +52 -29
  36. plain/models/backups/cli.py +8 -6
  37. plain/models/backups/clients.py +16 -7
  38. plain/models/backups/core.py +24 -13
  39. plain/models/base.py +113 -74
  40. plain/models/cli.py +94 -63
  41. plain/models/config.py +1 -1
  42. plain/models/connections.py +23 -7
  43. plain/models/constraints.py +65 -47
  44. plain/models/database_url.py +1 -1
  45. plain/models/db.py +6 -2
  46. plain/models/deletion.py +66 -43
  47. plain/models/entrypoints.py +1 -1
  48. plain/models/enums.py +22 -11
  49. plain/models/exceptions.py +23 -8
  50. plain/models/expressions.py +440 -257
  51. plain/models/fields/__init__.py +253 -202
  52. plain/models/fields/json.py +120 -54
  53. plain/models/fields/mixins.py +12 -8
  54. plain/models/fields/related.py +284 -252
  55. plain/models/fields/related_descriptors.py +31 -22
  56. plain/models/fields/related_lookups.py +23 -11
  57. plain/models/fields/related_managers.py +81 -47
  58. plain/models/fields/reverse_related.py +58 -55
  59. plain/models/forms.py +89 -63
  60. plain/models/functions/comparison.py +71 -18
  61. plain/models/functions/datetime.py +79 -29
  62. plain/models/functions/math.py +43 -10
  63. plain/models/functions/mixins.py +24 -7
  64. plain/models/functions/text.py +104 -25
  65. plain/models/functions/window.py +12 -6
  66. plain/models/indexes.py +52 -28
  67. plain/models/lookups.py +228 -153
  68. plain/models/migrations/autodetector.py +86 -43
  69. plain/models/migrations/exceptions.py +7 -3
  70. plain/models/migrations/executor.py +33 -7
  71. plain/models/migrations/graph.py +79 -50
  72. plain/models/migrations/loader.py +45 -22
  73. plain/models/migrations/migration.py +23 -18
  74. plain/models/migrations/operations/base.py +37 -19
  75. plain/models/migrations/operations/fields.py +89 -42
  76. plain/models/migrations/operations/models.py +245 -143
  77. plain/models/migrations/operations/special.py +82 -25
  78. plain/models/migrations/optimizer.py +7 -2
  79. plain/models/migrations/questioner.py +58 -31
  80. plain/models/migrations/recorder.py +18 -11
  81. plain/models/migrations/serializer.py +50 -39
  82. plain/models/migrations/state.py +220 -133
  83. plain/models/migrations/utils.py +29 -13
  84. plain/models/migrations/writer.py +17 -14
  85. plain/models/options.py +63 -56
  86. plain/models/otel.py +16 -6
  87. plain/models/preflight.py +35 -12
  88. plain/models/query.py +323 -228
  89. plain/models/query_utils.py +93 -58
  90. plain/models/registry.py +34 -16
  91. plain/models/sql/compiler.py +146 -97
  92. plain/models/sql/datastructures.py +38 -25
  93. plain/models/sql/query.py +255 -169
  94. plain/models/sql/subqueries.py +32 -21
  95. plain/models/sql/where.py +54 -29
  96. plain/models/test/pytest.py +15 -11
  97. plain/models/test/utils.py +4 -2
  98. plain/models/transaction.py +20 -7
  99. plain/models/utils.py +13 -5
  100. {plain_models-0.49.2.dist-info → plain_models-0.50.0.dist-info}/METADATA +1 -1
  101. plain_models-0.50.0.dist-info/RECORD +122 -0
  102. plain_models-0.49.2.dist-info/RECORD +0 -122
  103. {plain_models-0.49.2.dist-info → plain_models-0.50.0.dist-info}/WHEEL +0 -0
  104. {plain_models-0.49.2.dist-info → plain_models-0.50.0.dist-info}/entry_points.txt +0 -0
  105. {plain_models-0.49.2.dist-info → plain_models-0.50.0.dist-info}/licenses/LICENSE +0 -0
@@ -2,6 +2,10 @@
2
2
  Query subclasses which provide extra functionality beyond simple data retrieval.
3
3
  """
4
4
 
5
+ from __future__ import annotations
6
+
7
+ from typing import Any
8
+
5
9
  from plain.models.exceptions import FieldError
6
10
  from plain.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE, NO_RESULTS
7
11
  from plain.models.sql.query import Query
@@ -14,7 +18,7 @@ class DeleteQuery(Query):
14
18
 
15
19
  compiler = "SQLDeleteCompiler"
16
20
 
17
- def do_query(self, table, where):
21
+ def do_query(self, table: str, where: Any) -> int:
18
22
  self.alias_map = {table: self.alias_map[table]}
19
23
  self.where = where
20
24
  cursor = self.get_compiler().execute_sql(CURSOR)
@@ -23,7 +27,7 @@ class DeleteQuery(Query):
23
27
  return cursor.rowcount
24
28
  return 0
25
29
 
26
- def delete_batch(self, id_list):
30
+ def delete_batch(self, id_list: list[Any]) -> int:
27
31
  """
28
32
  Set up and execute delete queries for all the objects in id_list.
29
33
 
@@ -48,25 +52,25 @@ class UpdateQuery(Query):
48
52
 
49
53
  compiler = "SQLUpdateCompiler"
50
54
 
51
- def __init__(self, *args, **kwargs):
55
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
52
56
  super().__init__(*args, **kwargs)
53
57
  self._setup_query()
54
58
 
55
- def _setup_query(self):
59
+ def _setup_query(self) -> None:
56
60
  """
57
61
  Run on initialization and at the end of chaining. Any attributes that
58
62
  would normally be set in __init__() should go here instead.
59
63
  """
60
- self.values = []
61
- self.related_ids = None
62
- self.related_updates = {}
64
+ self.values: list[tuple[Any, Any, Any]] = []
65
+ self.related_ids: dict[Any, list[Any]] | None = None
66
+ self.related_updates: dict[Any, list[tuple[Any, Any, Any]]] = {}
63
67
 
64
- def clone(self):
68
+ def clone(self) -> UpdateQuery:
65
69
  obj = super().clone()
66
70
  obj.related_updates = self.related_updates.copy()
67
- return obj
71
+ return obj # type: ignore[return-value]
68
72
 
69
- def update_batch(self, id_list, values):
73
+ def update_batch(self, id_list: list[Any], values: dict[str, Any]) -> None:
70
74
  self.add_update_values(values)
71
75
  for offset in range(0, len(id_list), GET_ITERATOR_CHUNK_SIZE):
72
76
  self.clear_where()
@@ -75,7 +79,7 @@ class UpdateQuery(Query):
75
79
  )
76
80
  self.get_compiler().execute_sql(NO_RESULTS)
77
81
 
78
- def add_update_values(self, values):
82
+ def add_update_values(self, values: dict[str, Any]) -> list[tuple[Any, Any, Any]]:
79
83
  """
80
84
  Convert a dictionary of field name to value mappings into an update
81
85
  query. This is the entry point for the public update() method on
@@ -99,7 +103,7 @@ class UpdateQuery(Query):
99
103
  values_seq.append((field, model, val))
100
104
  return self.add_update_fields(values_seq)
101
105
 
102
- def add_update_fields(self, values_seq):
106
+ def add_update_fields(self, values_seq: list[tuple[Any, Any, Any]]) -> None:
103
107
  """
104
108
  Append a sequence of (field, model, value) triples to the internal list
105
109
  that will be used to generate the UPDATE query. Might be more usefully
@@ -111,7 +115,7 @@ class UpdateQuery(Query):
111
115
  val = val.resolve_expression(self, allow_joins=False, for_save=True)
112
116
  self.values.append((field, model, val))
113
117
 
114
- def add_related_update(self, model, field, value):
118
+ def add_related_update(self, model: Any, field: Any, value: Any) -> None:
115
119
  """
116
120
  Add (name, value) to an update query for an ancestor model.
117
121
 
@@ -119,7 +123,7 @@ class UpdateQuery(Query):
119
123
  """
120
124
  self.related_updates.setdefault(model, []).append((field, None, value))
121
125
 
122
- def get_related_updates(self):
126
+ def get_related_updates(self) -> list[UpdateQuery]:
123
127
  """
124
128
  Return a list of query objects: one for each update required to an
125
129
  ancestor model. Each query will have the same filtering conditions as
@@ -141,19 +145,26 @@ class InsertQuery(Query):
141
145
  compiler = "SQLInsertCompiler"
142
146
 
143
147
  def __init__(
144
- self, *args, on_conflict=None, update_fields=None, unique_fields=None, **kwargs
145
- ):
148
+ self,
149
+ *args: Any,
150
+ on_conflict: str | None = None,
151
+ update_fields: list[Any] | None = None,
152
+ unique_fields: list[Any] | None = None,
153
+ **kwargs: Any,
154
+ ) -> None:
146
155
  super().__init__(*args, **kwargs)
147
- self.fields = []
148
- self.objs = []
156
+ self.fields: list[Any] = []
157
+ self.objs: list[Any] = []
149
158
  self.on_conflict = on_conflict
150
159
  self.update_fields = update_fields or []
151
160
  self.unique_fields = unique_fields or []
152
161
 
153
- def insert_values(self, fields, objs, raw=False):
162
+ def insert_values(
163
+ self, fields: list[Any], objs: list[Any], raw: bool = False
164
+ ) -> None:
154
165
  self.fields = fields
155
166
  self.objs = objs
156
- self.raw = raw
167
+ self.raw = raw # type: ignore[attr-defined]
157
168
 
158
169
 
159
170
  class AggregateQuery(Query):
@@ -164,6 +175,6 @@ class AggregateQuery(Query):
164
175
 
165
176
  compiler = "SQLAggregateCompiler"
166
177
 
167
- def __init__(self, model, inner_query):
178
+ def __init__(self, model: Any, inner_query: Any) -> None:
168
179
  self.inner_query = inner_query
169
180
  super().__init__(model)
plain/models/sql/where.py CHANGED
@@ -2,14 +2,21 @@
2
2
  Code to manage the creation and SQL rendering of 'where' constraints.
3
3
  """
4
4
 
5
+ from __future__ import annotations
6
+
5
7
  import operator
6
8
  from functools import cached_property, reduce
9
+ from typing import TYPE_CHECKING, Any
7
10
 
8
11
  from plain.models.exceptions import EmptyResultSet, FullResultSet
9
12
  from plain.models.expressions import Case, When
10
13
  from plain.models.lookups import Exact
11
14
  from plain.utils import tree
12
15
 
16
+ if TYPE_CHECKING:
17
+ from plain.models.backends.base.base import BaseDatabaseWrapper
18
+ from plain.models.sql.compiler import SQLCompiler
19
+
13
20
  # Connection types
14
21
  AND = "AND"
15
22
  OR = "OR"
@@ -35,7 +42,9 @@ class WhereNode(tree.Node):
35
42
  resolved = False
36
43
  conditional = True
37
44
 
38
- def split_having_qualify(self, negated=False, must_group_by=False):
45
+ def split_having_qualify(
46
+ self, negated: bool = False, must_group_by: bool = False
47
+ ) -> tuple[WhereNode | None, WhereNode | None, WhereNode | None]:
39
48
  """
40
49
  Return three possibly None nodes: one for those parts of self that
41
50
  should be included in the WHERE clause, one for those parts of self
@@ -111,7 +120,9 @@ class WhereNode(tree.Node):
111
120
  )
112
121
  return where_node, having_node, qualify_node
113
122
 
114
- def as_sql(self, compiler, connection):
123
+ def as_sql(
124
+ self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
125
+ ) -> tuple[str, list[Any]]:
115
126
  """
116
127
  Return the SQL version of the where clause and the value to be
117
128
  substituted in. Return '', [] if this node matches everything,
@@ -181,20 +192,20 @@ class WhereNode(tree.Node):
181
192
  sql_string = f"({sql_string})"
182
193
  return sql_string, result_params
183
194
 
184
- def get_group_by_cols(self):
195
+ def get_group_by_cols(self) -> list[Any]:
185
196
  cols = []
186
197
  for child in self.children:
187
198
  cols.extend(child.get_group_by_cols())
188
199
  return cols
189
200
 
190
- def get_source_expressions(self):
201
+ def get_source_expressions(self) -> list[Any]:
191
202
  return self.children[:]
192
203
 
193
- def set_source_expressions(self, children):
204
+ def set_source_expressions(self, children: list[Any]) -> None:
194
205
  assert len(children) == len(self.children)
195
206
  self.children = children
196
207
 
197
- def relabel_aliases(self, change_map):
208
+ def relabel_aliases(self, change_map: dict[str, str]) -> None:
198
209
  """
199
210
  Relabel the alias values of any children. 'change_map' is a dictionary
200
211
  mapping old (current) alias values to the new values.
@@ -206,7 +217,7 @@ class WhereNode(tree.Node):
206
217
  elif hasattr(child, "relabeled_clone"):
207
218
  self.children[pos] = child.relabeled_clone(change_map)
208
219
 
209
- def clone(self):
220
+ def clone(self) -> WhereNode:
210
221
  clone = self.create(connector=self.connector, negated=self.negated)
211
222
  for child in self.children:
212
223
  if hasattr(child, "clone"):
@@ -214,12 +225,12 @@ class WhereNode(tree.Node):
214
225
  clone.children.append(child)
215
226
  return clone
216
227
 
217
- def relabeled_clone(self, change_map):
228
+ def relabeled_clone(self, change_map: dict[str, str]) -> WhereNode:
218
229
  clone = self.clone()
219
230
  clone.relabel_aliases(change_map)
220
231
  return clone
221
232
 
222
- def replace_expressions(self, replacements):
233
+ def replace_expressions(self, replacements: dict[Any, Any]) -> WhereNode:
223
234
  if replacement := replacements.get(self):
224
235
  return replacement
225
236
  clone = self.create(connector=self.connector, negated=self.negated)
@@ -227,44 +238,44 @@ class WhereNode(tree.Node):
227
238
  clone.children.append(child.replace_expressions(replacements))
228
239
  return clone
229
240
 
230
- def get_refs(self):
241
+ def get_refs(self) -> set[Any]:
231
242
  refs = set()
232
243
  for child in self.children:
233
244
  refs |= child.get_refs()
234
245
  return refs
235
246
 
236
247
  @classmethod
237
- def _contains_aggregate(cls, obj):
248
+ def _contains_aggregate(cls, obj: Any) -> bool:
238
249
  if isinstance(obj, tree.Node):
239
250
  return any(cls._contains_aggregate(c) for c in obj.children)
240
251
  return obj.contains_aggregate
241
252
 
242
253
  @cached_property
243
- def contains_aggregate(self):
254
+ def contains_aggregate(self) -> bool:
244
255
  return self._contains_aggregate(self)
245
256
 
246
257
  @classmethod
247
- def _contains_over_clause(cls, obj):
258
+ def _contains_over_clause(cls, obj: Any) -> bool:
248
259
  if isinstance(obj, tree.Node):
249
260
  return any(cls._contains_over_clause(c) for c in obj.children)
250
261
  return obj.contains_over_clause
251
262
 
252
263
  @cached_property
253
- def contains_over_clause(self):
264
+ def contains_over_clause(self) -> bool:
254
265
  return self._contains_over_clause(self)
255
266
 
256
267
  @property
257
- def is_summary(self):
268
+ def is_summary(self) -> bool:
258
269
  return any(child.is_summary for child in self.children)
259
270
 
260
271
  @staticmethod
261
- def _resolve_leaf(expr, query, *args, **kwargs):
272
+ def _resolve_leaf(expr: Any, query: Any, *args: Any, **kwargs: Any) -> Any:
262
273
  if hasattr(expr, "resolve_expression"):
263
274
  expr = expr.resolve_expression(query, *args, **kwargs)
264
275
  return expr
265
276
 
266
277
  @classmethod
267
- def _resolve_node(cls, node, query, *args, **kwargs):
278
+ def _resolve_node(cls, node: Any, query: Any, *args: Any, **kwargs: Any) -> None:
268
279
  if hasattr(node, "children"):
269
280
  for child in node.children:
270
281
  cls._resolve_node(child, query, *args, **kwargs)
@@ -273,23 +284,25 @@ class WhereNode(tree.Node):
273
284
  if hasattr(node, "rhs"):
274
285
  node.rhs = cls._resolve_leaf(node.rhs, query, *args, **kwargs)
275
286
 
276
- def resolve_expression(self, *args, **kwargs):
287
+ def resolve_expression(self, *args: Any, **kwargs: Any) -> WhereNode:
277
288
  clone = self.clone()
278
289
  clone._resolve_node(clone, *args, **kwargs)
279
290
  clone.resolved = True
280
291
  return clone
281
292
 
282
293
  @cached_property
283
- def output_field(self):
294
+ def output_field(self) -> Any:
284
295
  from plain.models.fields import BooleanField
285
296
 
286
297
  return BooleanField()
287
298
 
288
299
  @property
289
- def _output_field_or_none(self):
300
+ def _output_field_or_none(self) -> Any:
290
301
  return self.output_field
291
302
 
292
- def select_format(self, compiler, sql, params):
303
+ def select_format(
304
+ self, compiler: SQLCompiler, sql: str, params: list[Any]
305
+ ) -> tuple[str, list[Any]]:
293
306
  # Wrap filters with a CASE WHEN expression if a database backend
294
307
  # (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
295
308
  # BY list.
@@ -297,13 +310,13 @@ class WhereNode(tree.Node):
297
310
  sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
298
311
  return sql, params
299
312
 
300
- def get_db_converters(self, connection):
313
+ def get_db_converters(self, connection: BaseDatabaseWrapper) -> list[Any]:
301
314
  return self.output_field.get_db_converters(connection)
302
315
 
303
- def get_lookup(self, lookup):
316
+ def get_lookup(self, lookup: str) -> Any:
304
317
  return self.output_field.get_lookup(lookup)
305
318
 
306
- def leaves(self):
319
+ def leaves(self) -> Any:
307
320
  for child in self.children:
308
321
  if isinstance(child, WhereNode):
309
322
  yield from child.leaves()
@@ -317,7 +330,11 @@ class NothingNode:
317
330
  contains_aggregate = False
318
331
  contains_over_clause = False
319
332
 
320
- def as_sql(self, compiler=None, connection=None):
333
+ def as_sql(
334
+ self,
335
+ compiler: SQLCompiler | None = None,
336
+ connection: BaseDatabaseWrapper | None = None,
337
+ ) -> tuple[str, list[Any]]:
321
338
  raise EmptyResultSet
322
339
 
323
340
 
@@ -326,11 +343,15 @@ class ExtraWhere:
326
343
  contains_aggregate = False
327
344
  contains_over_clause = False
328
345
 
329
- def __init__(self, sqls, params):
346
+ def __init__(self, sqls: list[str], params: list[Any] | None):
330
347
  self.sqls = sqls
331
348
  self.params = params
332
349
 
333
- def as_sql(self, compiler=None, connection=None):
350
+ def as_sql(
351
+ self,
352
+ compiler: SQLCompiler | None = None,
353
+ connection: BaseDatabaseWrapper | None = None,
354
+ ) -> tuple[str, list[Any]]:
334
355
  sqls = [f"({sql})" for sql in self.sqls]
335
356
  return " AND ".join(sqls), list(self.params or ())
336
357
 
@@ -341,14 +362,18 @@ class SubqueryConstraint:
341
362
  contains_aggregate = False
342
363
  contains_over_clause = False
343
364
 
344
- def __init__(self, alias, columns, targets, query_object):
365
+ def __init__(
366
+ self, alias: str, columns: list[str], targets: list[Any], query_object: Any
367
+ ):
345
368
  self.alias = alias
346
369
  self.columns = columns
347
370
  self.targets = targets
348
371
  query_object.clear_ordering(clear_default=True)
349
372
  self.query_object = query_object
350
373
 
351
- def as_sql(self, compiler, connection):
374
+ def as_sql(
375
+ self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
376
+ ) -> tuple[str, list[Any]]:
352
377
  query = self.query_object
353
378
  query.set_values(self.targets)
354
379
  query_compiler = query.get_compiler()
@@ -1,4 +1,8 @@
1
+ from __future__ import annotations
2
+
1
3
  import re
4
+ from collections.abc import Generator
5
+ from typing import Any
2
6
 
3
7
  import pytest
4
8
 
@@ -15,25 +19,25 @@ from .utils import (
15
19
 
16
20
 
17
21
  @pytest.fixture(autouse=True)
18
- def _db_disabled():
22
+ def _db_disabled() -> Generator[None, None, None]:
19
23
  """
20
24
  Every test should use this fixture by default to prevent
21
25
  access to the normal database.
22
26
  """
23
27
 
24
- def cursor_disabled(self):
28
+ def cursor_disabled(self: Any) -> None:
25
29
  pytest.fail("Database access not allowed without the `db` fixture")
26
30
 
27
- BaseDatabaseWrapper._enabled_cursor = BaseDatabaseWrapper.cursor
28
- BaseDatabaseWrapper.cursor = cursor_disabled
31
+ BaseDatabaseWrapper._enabled_cursor = BaseDatabaseWrapper.cursor # type: ignore[attr-defined]
32
+ BaseDatabaseWrapper.cursor = cursor_disabled # type: ignore[method-assign]
29
33
 
30
34
  yield
31
35
 
32
- BaseDatabaseWrapper.cursor = BaseDatabaseWrapper._enabled_cursor
36
+ BaseDatabaseWrapper.cursor = BaseDatabaseWrapper._enabled_cursor # type: ignore[method-assign]
33
37
 
34
38
 
35
39
  @pytest.fixture(scope="session")
36
- def setup_db(request):
40
+ def setup_db(request: Any) -> Generator[None, None, None]:
37
41
  """
38
42
  This fixture is called automatically by `db`,
39
43
  so a test database will only be setup if the `db` fixture is used.
@@ -58,19 +62,19 @@ def setup_db(request):
58
62
 
59
63
 
60
64
  @pytest.fixture
61
- def db(setup_db, request):
65
+ def db(setup_db: Any, request: Any) -> Generator[None, None, None]:
62
66
  if "isolated_db" in request.fixturenames:
63
67
  pytest.fail("The 'db' and 'isolated_db' fixtures cannot be used together")
64
68
 
65
69
  # Set .cursor() back to the original implementation to unblock it
66
- BaseDatabaseWrapper.cursor = BaseDatabaseWrapper._enabled_cursor
70
+ BaseDatabaseWrapper.cursor = BaseDatabaseWrapper._enabled_cursor # type: ignore[method-assign]
67
71
 
68
72
  if not db_connection.features.supports_transactions:
69
73
  pytest.fail("Database does not support transactions")
70
74
 
71
75
  with suppress_db_tracing():
72
76
  atomic = transaction.atomic()
73
- atomic._from_testcase = True # TODO remove this somehow?
77
+ atomic._from_testcase = True # type: ignore[attr-defined] # TODO remove this somehow?
74
78
  atomic.__enter__()
75
79
 
76
80
  yield
@@ -90,7 +94,7 @@ def db(setup_db, request):
90
94
 
91
95
 
92
96
  @pytest.fixture
93
- def isolated_db(request):
97
+ def isolated_db(request: Any) -> Generator[None, None, None]:
94
98
  """
95
99
  Create and destroy a unique test database for each test, using a prefix
96
100
  derived from the test function name to ensure isolation from the default
@@ -99,7 +103,7 @@ def isolated_db(request):
99
103
  if "db" in request.fixturenames:
100
104
  pytest.fail("The 'db' and 'isolated_db' fixtures cannot be used together")
101
105
  # Set .cursor() back to the original implementation to unblock it
102
- BaseDatabaseWrapper.cursor = BaseDatabaseWrapper._enabled_cursor
106
+ BaseDatabaseWrapper.cursor = BaseDatabaseWrapper._enabled_cursor # type: ignore[method-assign]
103
107
 
104
108
  verbosity = 1
105
109
 
@@ -1,14 +1,16 @@
1
+ from __future__ import annotations
2
+
1
3
  from plain.models import db_connection
2
4
  from plain.models.otel import suppress_db_tracing
3
5
 
4
6
 
5
- def setup_database(*, verbosity, prefix=""):
7
+ def setup_database(*, verbosity: int, prefix: str = "") -> str:
6
8
  old_name = db_connection.settings_dict["NAME"]
7
9
  with suppress_db_tracing():
8
10
  db_connection.creation.create_test_db(verbosity=verbosity, prefix=prefix)
9
11
  return old_name
10
12
 
11
13
 
12
- def teardown_database(old_name, verbosity):
14
+ def teardown_database(old_name: str, verbosity: int) -> None:
13
15
  with suppress_db_tracing():
14
16
  db_connection.creation.destroy_test_db(old_name, verbosity)
@@ -1,7 +1,13 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Callable, Generator
1
4
  from contextlib import ContextDecorator, contextmanager
5
+ from typing import Any, TypeVar
2
6
 
3
7
  from plain.models.db import DatabaseError, Error, ProgrammingError, db_connection
4
8
 
9
+ F = TypeVar("F", bound=Callable[..., Any])
10
+
5
11
 
6
12
  class TransactionManagementError(ProgrammingError):
7
13
  """Transaction management is used improperly."""
@@ -10,7 +16,7 @@ class TransactionManagementError(ProgrammingError):
10
16
 
11
17
 
12
18
  @contextmanager
13
- def mark_for_rollback_on_error():
19
+ def mark_for_rollback_on_error() -> Generator[None, None, None]:
14
20
  """
15
21
  Internal low-level utility to mark a transaction as "needs rollback" when
16
22
  an exception is raised while not enforcing the enclosed block to be in a
@@ -36,7 +42,7 @@ def mark_for_rollback_on_error():
36
42
  raise
37
43
 
38
44
 
39
- def on_commit(func, robust=False):
45
+ def on_commit(func: Callable[[], Any], robust: bool = False) -> None:
40
46
  """
41
47
  Register `func` to be called when the current transaction is committed.
42
48
  If the current transaction is rolled back, `func` will not be called.
@@ -83,12 +89,12 @@ class Atomic(ContextDecorator):
83
89
  This is a private API.
84
90
  """
85
91
 
86
- def __init__(self, savepoint, durable):
92
+ def __init__(self, savepoint: bool, durable: bool) -> None:
87
93
  self.savepoint = savepoint
88
94
  self.durable = durable
89
95
  self._from_testcase = False
90
96
 
91
- def __enter__(self):
97
+ def __enter__(self) -> None:
92
98
  if (
93
99
  self.durable
94
100
  and db_connection.atomic_blocks
@@ -127,7 +133,12 @@ class Atomic(ContextDecorator):
127
133
  if db_connection.in_atomic_block:
128
134
  db_connection.atomic_blocks.append(self)
129
135
 
130
- def __exit__(self, exc_type, exc_value, traceback):
136
+ def __exit__(
137
+ self,
138
+ exc_type: type[BaseException] | None,
139
+ exc_value: BaseException | None,
140
+ traceback: Any,
141
+ ) -> None:
131
142
  if db_connection.in_atomic_block:
132
143
  db_connection.atomic_blocks.pop()
133
144
 
@@ -217,8 +228,10 @@ class Atomic(ContextDecorator):
217
228
  db_connection.in_atomic_block = False
218
229
 
219
230
 
220
- def atomic(func=None, *, savepoint=True, durable=False):
231
+ def atomic(
232
+ func: F | None = None, *, savepoint: bool = True, durable: bool = False
233
+ ) -> F | Atomic:
221
234
  """Create an atomic transaction context or decorator."""
222
235
  if callable(func):
223
- return Atomic(savepoint, durable)(func)
236
+ return Atomic(savepoint, durable)(func) # type: ignore[return-value]
224
237
  return Atomic(savepoint, durable)
plain/models/utils.py CHANGED
@@ -1,8 +1,12 @@
1
+ from __future__ import annotations
2
+
1
3
  import functools
2
4
  from collections import namedtuple
5
+ from collections.abc import Generator
6
+ from typing import Any
3
7
 
4
8
 
5
- def make_model_tuple(model):
9
+ def make_model_tuple(model: Any) -> tuple[str, str]:
6
10
  """
7
11
  Take a model or a string of the form "package_label.ModelName" and return a
8
12
  corresponding ("package_label", "modelname") tuple. If a tuple is passed in,
@@ -25,7 +29,9 @@ def make_model_tuple(model):
25
29
  )
26
30
 
27
31
 
28
- def resolve_callables(mapping):
32
+ def resolve_callables(
33
+ mapping: dict[str, Any],
34
+ ) -> Generator[tuple[str, Any], None, None]:
29
35
  """
30
36
  Generate key/value pairs for the given mapping where the values are
31
37
  evaluated if they're callable.
@@ -34,15 +40,17 @@ def resolve_callables(mapping):
34
40
  yield k, v() if callable(v) else v
35
41
 
36
42
 
37
- def unpickle_named_row(names, values):
43
+ def unpickle_named_row(
44
+ names: tuple[str, ...], values: tuple[Any, ...]
45
+ ) -> tuple[Any, ...]:
38
46
  return create_namedtuple_class(*names)(*values)
39
47
 
40
48
 
41
49
  @functools.lru_cache
42
- def create_namedtuple_class(*names):
50
+ def create_namedtuple_class(*names: str) -> type[tuple[Any, ...]]:
43
51
  # Cache type() with @lru_cache since it's too slow to be called for every
44
52
  # QuerySet evaluation.
45
- def __reduce__(self):
53
+ def __reduce__(self: Any) -> tuple[Any, tuple[tuple[str, ...], tuple[Any, ...]]]:
46
54
  return unpickle_named_row, (names, tuple(self))
47
55
 
48
56
  return type(
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: plain.models
3
- Version: 0.49.2
3
+ Version: 0.50.0
4
4
  Summary: Model your data and store it in a database.
5
5
  Author-email: Dave Gaeddert <dave.gaeddert@dropseed.dev>
6
6
  License-File: LICENSE