plain.models 0.49.2__py3-none-any.whl → 0.50.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105) hide show
  1. plain/models/CHANGELOG.md +13 -0
  2. plain/models/aggregates.py +42 -19
  3. plain/models/backends/base/base.py +125 -105
  4. plain/models/backends/base/client.py +11 -3
  5. plain/models/backends/base/creation.py +22 -12
  6. plain/models/backends/base/features.py +10 -4
  7. plain/models/backends/base/introspection.py +29 -16
  8. plain/models/backends/base/operations.py +187 -91
  9. plain/models/backends/base/schema.py +267 -165
  10. plain/models/backends/base/validation.py +12 -3
  11. plain/models/backends/ddl_references.py +85 -43
  12. plain/models/backends/mysql/base.py +29 -26
  13. plain/models/backends/mysql/client.py +7 -2
  14. plain/models/backends/mysql/compiler.py +12 -3
  15. plain/models/backends/mysql/creation.py +5 -2
  16. plain/models/backends/mysql/features.py +24 -22
  17. plain/models/backends/mysql/introspection.py +22 -13
  18. plain/models/backends/mysql/operations.py +106 -39
  19. plain/models/backends/mysql/schema.py +48 -24
  20. plain/models/backends/mysql/validation.py +13 -6
  21. plain/models/backends/postgresql/base.py +41 -34
  22. plain/models/backends/postgresql/client.py +7 -2
  23. plain/models/backends/postgresql/creation.py +10 -5
  24. plain/models/backends/postgresql/introspection.py +15 -8
  25. plain/models/backends/postgresql/operations.py +109 -42
  26. plain/models/backends/postgresql/schema.py +85 -46
  27. plain/models/backends/sqlite3/_functions.py +151 -115
  28. plain/models/backends/sqlite3/base.py +37 -23
  29. plain/models/backends/sqlite3/client.py +7 -1
  30. plain/models/backends/sqlite3/creation.py +9 -5
  31. plain/models/backends/sqlite3/features.py +5 -3
  32. plain/models/backends/sqlite3/introspection.py +32 -16
  33. plain/models/backends/sqlite3/operations.py +125 -42
  34. plain/models/backends/sqlite3/schema.py +82 -58
  35. plain/models/backends/utils.py +52 -29
  36. plain/models/backups/cli.py +8 -6
  37. plain/models/backups/clients.py +16 -7
  38. plain/models/backups/core.py +24 -13
  39. plain/models/base.py +113 -74
  40. plain/models/cli.py +94 -63
  41. plain/models/config.py +1 -1
  42. plain/models/connections.py +23 -7
  43. plain/models/constraints.py +65 -47
  44. plain/models/database_url.py +1 -1
  45. plain/models/db.py +6 -2
  46. plain/models/deletion.py +66 -43
  47. plain/models/entrypoints.py +1 -1
  48. plain/models/enums.py +22 -11
  49. plain/models/exceptions.py +23 -8
  50. plain/models/expressions.py +440 -257
  51. plain/models/fields/__init__.py +253 -202
  52. plain/models/fields/json.py +120 -54
  53. plain/models/fields/mixins.py +12 -8
  54. plain/models/fields/related.py +284 -252
  55. plain/models/fields/related_descriptors.py +31 -22
  56. plain/models/fields/related_lookups.py +23 -11
  57. plain/models/fields/related_managers.py +81 -47
  58. plain/models/fields/reverse_related.py +58 -55
  59. plain/models/forms.py +89 -63
  60. plain/models/functions/comparison.py +71 -18
  61. plain/models/functions/datetime.py +79 -29
  62. plain/models/functions/math.py +43 -10
  63. plain/models/functions/mixins.py +24 -7
  64. plain/models/functions/text.py +104 -25
  65. plain/models/functions/window.py +12 -6
  66. plain/models/indexes.py +52 -28
  67. plain/models/lookups.py +228 -153
  68. plain/models/migrations/autodetector.py +86 -43
  69. plain/models/migrations/exceptions.py +7 -3
  70. plain/models/migrations/executor.py +33 -7
  71. plain/models/migrations/graph.py +79 -50
  72. plain/models/migrations/loader.py +45 -22
  73. plain/models/migrations/migration.py +23 -18
  74. plain/models/migrations/operations/base.py +37 -19
  75. plain/models/migrations/operations/fields.py +89 -42
  76. plain/models/migrations/operations/models.py +245 -143
  77. plain/models/migrations/operations/special.py +82 -25
  78. plain/models/migrations/optimizer.py +7 -2
  79. plain/models/migrations/questioner.py +58 -31
  80. plain/models/migrations/recorder.py +18 -11
  81. plain/models/migrations/serializer.py +50 -39
  82. plain/models/migrations/state.py +220 -133
  83. plain/models/migrations/utils.py +29 -13
  84. plain/models/migrations/writer.py +17 -14
  85. plain/models/options.py +63 -56
  86. plain/models/otel.py +16 -6
  87. plain/models/preflight.py +35 -12
  88. plain/models/query.py +323 -228
  89. plain/models/query_utils.py +93 -58
  90. plain/models/registry.py +34 -16
  91. plain/models/sql/compiler.py +146 -97
  92. plain/models/sql/datastructures.py +38 -25
  93. plain/models/sql/query.py +255 -169
  94. plain/models/sql/subqueries.py +32 -21
  95. plain/models/sql/where.py +54 -29
  96. plain/models/test/pytest.py +15 -11
  97. plain/models/test/utils.py +4 -2
  98. plain/models/transaction.py +20 -7
  99. plain/models/utils.py +13 -5
  100. {plain_models-0.49.2.dist-info → plain_models-0.50.0.dist-info}/METADATA +1 -1
  101. plain_models-0.50.0.dist-info/RECORD +122 -0
  102. plain_models-0.49.2.dist-info/RECORD +0 -122
  103. {plain_models-0.49.2.dist-info → plain_models-0.50.0.dist-info}/WHEEL +0 -0
  104. {plain_models-0.49.2.dist-info → plain_models-0.50.0.dist-info}/entry_points.txt +0 -0
  105. {plain_models-0.49.2.dist-info → plain_models-0.50.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import _thread
2
4
  import copy
3
5
  import datetime
@@ -7,8 +9,10 @@ import time
7
9
  import warnings
8
10
  import zoneinfo
9
11
  from collections import deque
12
+ from collections.abc import Generator
10
13
  from contextlib import contextmanager
11
14
  from functools import cached_property
15
+ from typing import TYPE_CHECKING, Any
12
16
 
13
17
  from plain.models.backends import utils
14
18
  from plain.models.backends.base.validation import BaseDatabaseValidation
@@ -21,6 +25,14 @@ from plain.models.db import (
21
25
  from plain.models.transaction import TransactionManagementError
22
26
  from plain.runtime import settings
23
27
 
28
+ if TYPE_CHECKING:
29
+ from plain.models.backends.base.client import BaseDatabaseClient
30
+ from plain.models.backends.base.creation import BaseDatabaseCreation
31
+ from plain.models.backends.base.features import BaseDatabaseFeatures
32
+ from plain.models.backends.base.introspection import BaseDatabaseIntrospection
33
+ from plain.models.backends.base.operations import BaseDatabaseOperations
34
+ from plain.models.backends.base.schema import BaseDatabaseSchemaEditor
35
+
24
36
  RAN_DB_VERSION_CHECK = False
25
37
 
26
38
  logger = logging.getLogger("plain.models.backends.base")
@@ -30,96 +42,102 @@ class BaseDatabaseWrapper:
30
42
  """Represent a database connection."""
31
43
 
32
44
  # Mapping of Field objects to their column types.
33
- data_types = {}
45
+ data_types: dict[str, str] = {}
34
46
  # Mapping of Field objects to their SQL suffix such as AUTOINCREMENT.
35
- data_types_suffix = {}
47
+ data_types_suffix: dict[str, str] = {}
36
48
  # Mapping of Field objects to their SQL for CHECK constraints.
37
- data_type_check_constraints = {}
38
- ops = None
39
- vendor = "unknown"
40
- display_name = "unknown"
41
- SchemaEditorClass = None
49
+ data_type_check_constraints: dict[str, str] = {}
50
+ # Instance attributes - always assigned in __init__
51
+ ops: BaseDatabaseOperations
52
+ client: BaseDatabaseClient
53
+ creation: BaseDatabaseCreation
54
+ features: BaseDatabaseFeatures
55
+ introspection: BaseDatabaseIntrospection
56
+ validation: BaseDatabaseValidation
57
+ vendor: str = "unknown"
58
+ display_name: str = "unknown"
59
+ SchemaEditorClass: type[BaseDatabaseSchemaEditor] | None = None
42
60
  # Classes instantiated in __init__().
43
- client_class = None
44
- creation_class = None
45
- features_class = None
46
- introspection_class = None
47
- ops_class = None
48
- validation_class = BaseDatabaseValidation
61
+ client_class: type[BaseDatabaseClient] | None = None
62
+ creation_class: type[BaseDatabaseCreation] | None = None
63
+ features_class: type[BaseDatabaseFeatures] | None = None
64
+ introspection_class: type[BaseDatabaseIntrospection] | None = None
65
+ ops_class: type[BaseDatabaseOperations] | None = None
66
+ validation_class: type[BaseDatabaseValidation] = BaseDatabaseValidation
49
67
 
50
- queries_limit = 9000
68
+ queries_limit: int = 9000
51
69
 
52
- def __init__(self, settings_dict):
70
+ def __init__(self, settings_dict: dict[str, Any]):
53
71
  # Connection related attributes.
54
72
  # The underlying database connection.
55
- self.connection = None
73
+ self.connection: BaseDatabaseWrapper | None = None
56
74
  # `settings_dict` should be a dictionary containing keys such as
57
75
  # NAME, USER, etc. It's called `settings_dict` instead of `settings`
58
76
  # to disambiguate it from Plain settings modules.
59
- self.settings_dict = settings_dict
77
+ self.settings_dict: dict[str, Any] = settings_dict
60
78
  # Query logging in debug mode or when explicitly enabled.
61
- self.queries_log = deque(maxlen=self.queries_limit)
62
- self.force_debug_cursor = False
79
+ self.queries_log: deque[dict[str, Any]] = deque(maxlen=self.queries_limit)
80
+ self.force_debug_cursor: bool = False
63
81
 
64
82
  # Transaction related attributes.
65
83
  # Tracks if the connection is in autocommit mode. Per PEP 249, by
66
84
  # default, it isn't.
67
- self.autocommit = False
85
+ self.autocommit: bool = False
68
86
  # Tracks if the connection is in a transaction managed by 'atomic'.
69
- self.in_atomic_block = False
87
+ self.in_atomic_block: bool = False
70
88
  # Increment to generate unique savepoint ids.
71
- self.savepoint_state = 0
89
+ self.savepoint_state: int = 0
72
90
  # List of savepoints created by 'atomic'.
73
- self.savepoint_ids = []
91
+ self.savepoint_ids: list[str] = []
74
92
  # Stack of active 'atomic' blocks.
75
- self.atomic_blocks = []
93
+ self.atomic_blocks: list[Any] = []
76
94
  # Tracks if the outermost 'atomic' block should commit on exit,
77
95
  # ie. if autocommit was active on entry.
78
- self.commit_on_exit = True
96
+ self.commit_on_exit: bool = True
79
97
  # Tracks if the transaction should be rolled back to the next
80
98
  # available savepoint because of an exception in an inner block.
81
- self.needs_rollback = False
82
- self.rollback_exc = None
99
+ self.needs_rollback: bool = False
100
+ self.rollback_exc: Exception | None = None
83
101
 
84
102
  # Connection termination related attributes.
85
- self.close_at = None
86
- self.closed_in_transaction = False
87
- self.errors_occurred = False
88
- self.health_check_enabled = False
89
- self.health_check_done = False
103
+ self.close_at: float | None = None
104
+ self.closed_in_transaction: bool = False
105
+ self.errors_occurred: bool = False
106
+ self.health_check_enabled: bool = False
107
+ self.health_check_done: bool = False
90
108
 
91
109
  # Thread-safety related attributes.
92
- self._thread_sharing_lock = threading.Lock()
93
- self._thread_sharing_count = 0
94
- self._thread_ident = _thread.get_ident()
110
+ self._thread_sharing_lock: threading.Lock = threading.Lock()
111
+ self._thread_sharing_count: int = 0
112
+ self._thread_ident: int = _thread.get_ident()
95
113
 
96
114
  # A list of no-argument functions to run when the transaction commits.
97
115
  # Each entry is an (sids, func, robust) tuple, where sids is a set of
98
116
  # the active savepoint IDs when this function was registered and robust
99
117
  # specifies whether it's allowed for the function to fail.
100
- self.run_on_commit = []
118
+ self.run_on_commit: list[tuple[set[str], Any, bool]] = []
101
119
 
102
120
  # Should we run the on-commit hooks the next time set_autocommit(True)
103
121
  # is called?
104
- self.run_commit_hooks_on_set_autocommit_on = False
122
+ self.run_commit_hooks_on_set_autocommit_on: bool = False
105
123
 
106
124
  # A stack of wrappers to be invoked around execute()/executemany()
107
125
  # calls. Each entry is a function taking five arguments: execute, sql,
108
126
  # params, many, and context. It's the function's responsibility to
109
127
  # call execute(sql, params, many, context).
110
- self.execute_wrappers = []
128
+ self.execute_wrappers: list[Any] = []
111
129
 
112
- self.client = self.client_class(self)
113
- self.creation = self.creation_class(self)
114
- self.features = self.features_class(self)
115
- self.introspection = self.introspection_class(self)
116
- self.ops = self.ops_class(self)
117
- self.validation = self.validation_class(self)
130
+ self.client: BaseDatabaseClient = self.client_class(self) # type: ignore[misc]
131
+ self.creation: BaseDatabaseCreation = self.creation_class(self) # type: ignore[misc]
132
+ self.features: BaseDatabaseFeatures = self.features_class(self) # type: ignore[misc]
133
+ self.introspection: BaseDatabaseIntrospection = self.introspection_class(self) # type: ignore[misc]
134
+ self.ops: BaseDatabaseOperations = self.ops_class(self) # type: ignore[misc]
135
+ self.validation: BaseDatabaseValidation = self.validation_class(self)
118
136
 
119
- def __repr__(self):
137
+ def __repr__(self) -> str:
120
138
  return f"<{self.__class__.__qualname__} vendor={self.vendor!r}>"
121
139
 
122
- def ensure_timezone(self):
140
+ def ensure_timezone(self) -> bool:
123
141
  """
124
142
  Ensure the connection's timezone is set to `self.timezone_name` and
125
143
  return whether it changed or not.
@@ -127,7 +145,7 @@ class BaseDatabaseWrapper:
127
145
  return False
128
146
 
129
147
  @cached_property
130
- def timezone(self):
148
+ def timezone(self) -> datetime.tzinfo:
131
149
  """
132
150
  Return a tzinfo of the database connection time zone.
133
151
 
@@ -148,7 +166,7 @@ class BaseDatabaseWrapper:
148
166
  return zoneinfo.ZoneInfo(self.settings_dict["TIME_ZONE"])
149
167
 
150
168
  @cached_property
151
- def timezone_name(self):
169
+ def timezone_name(self) -> str:
152
170
  """
153
171
  Name of the time zone of the database connection.
154
172
  """
@@ -158,11 +176,11 @@ class BaseDatabaseWrapper:
158
176
  return self.settings_dict["TIME_ZONE"]
159
177
 
160
178
  @property
161
- def queries_logged(self):
179
+ def queries_logged(self) -> bool:
162
180
  return self.force_debug_cursor or settings.DEBUG
163
181
 
164
182
  @property
165
- def queries(self):
183
+ def queries(self) -> list[dict[str, Any]]:
166
184
  if len(self.queries_log) == self.queries_log.maxlen:
167
185
  warnings.warn(
168
186
  f"Limit for query logging exceeded, only the last {self.queries_log.maxlen} queries "
@@ -170,14 +188,14 @@ class BaseDatabaseWrapper:
170
188
  )
171
189
  return list(self.queries_log)
172
190
 
173
- def get_database_version(self):
191
+ def get_database_version(self) -> tuple[int, ...]:
174
192
  """Return a tuple of the database's version."""
175
193
  raise NotImplementedError(
176
194
  "subclasses of BaseDatabaseWrapper may require a get_database_version() "
177
195
  "method."
178
196
  )
179
197
 
180
- def check_database_version_supported(self):
198
+ def check_database_version_supported(self) -> None:
181
199
  """
182
200
  Raise an error if the database version isn't supported by this
183
201
  version of Plain.
@@ -195,28 +213,28 @@ class BaseDatabaseWrapper:
195
213
 
196
214
  # ##### Backend-specific methods for creating connections and cursors #####
197
215
 
198
- def get_connection_params(self):
216
+ def get_connection_params(self) -> dict[str, Any]:
199
217
  """Return a dict of parameters suitable for get_new_connection."""
200
218
  raise NotImplementedError(
201
219
  "subclasses of BaseDatabaseWrapper may require a get_connection_params() "
202
220
  "method"
203
221
  )
204
222
 
205
- def get_new_connection(self, conn_params):
223
+ def get_new_connection(self, conn_params: dict[str, Any]) -> Any:
206
224
  """Open a connection to the database."""
207
225
  raise NotImplementedError(
208
226
  "subclasses of BaseDatabaseWrapper may require a get_new_connection() "
209
227
  "method"
210
228
  )
211
229
 
212
- def init_connection_state(self):
230
+ def init_connection_state(self) -> None:
213
231
  """Initialize the database connection settings."""
214
232
  global RAN_DB_VERSION_CHECK
215
233
  if not RAN_DB_VERSION_CHECK:
216
234
  self.check_database_version_supported()
217
235
  RAN_DB_VERSION_CHECK = True
218
236
 
219
- def create_cursor(self, name=None):
237
+ def create_cursor(self, name: str | None = None) -> Any:
220
238
  """Create a cursor. Assume that a connection is established."""
221
239
  raise NotImplementedError(
222
240
  "subclasses of BaseDatabaseWrapper may require a create_cursor() method"
@@ -224,7 +242,7 @@ class BaseDatabaseWrapper:
224
242
 
225
243
  # ##### Backend-specific methods for creating connections #####
226
244
 
227
- def connect(self):
245
+ def connect(self) -> None:
228
246
  """Connect to the database. Assume that the connection is closed."""
229
247
  # In case the previous connection was closed while in an atomic block
230
248
  self.in_atomic_block = False
@@ -247,7 +265,7 @@ class BaseDatabaseWrapper:
247
265
 
248
266
  self.run_on_commit = []
249
267
 
250
- def ensure_connection(self):
268
+ def ensure_connection(self) -> None:
251
269
  """Guarantee that a connection to the database is established."""
252
270
  if self.connection is None:
253
271
  with self.wrap_database_errors:
@@ -255,7 +273,7 @@ class BaseDatabaseWrapper:
255
273
 
256
274
  # ##### Backend-specific wrappers for PEP-249 connection methods #####
257
275
 
258
- def _prepare_cursor(self, cursor):
276
+ def _prepare_cursor(self, cursor: Any) -> utils.CursorWrapper:
259
277
  """
260
278
  Validate the connection is usable and perform database cursor wrapping.
261
279
  """
@@ -266,34 +284,34 @@ class BaseDatabaseWrapper:
266
284
  wrapped_cursor = self.make_cursor(cursor)
267
285
  return wrapped_cursor
268
286
 
269
- def _cursor(self, name=None):
287
+ def _cursor(self, name: str | None = None) -> utils.CursorWrapper:
270
288
  self.close_if_health_check_failed()
271
289
  self.ensure_connection()
272
290
  with self.wrap_database_errors:
273
291
  return self._prepare_cursor(self.create_cursor(name))
274
292
 
275
- def _commit(self):
293
+ def _commit(self) -> None:
276
294
  if self.connection is not None:
277
295
  with debug_transaction(self, "COMMIT"), self.wrap_database_errors:
278
296
  return self.connection.commit()
279
297
 
280
- def _rollback(self):
298
+ def _rollback(self) -> None:
281
299
  if self.connection is not None:
282
300
  with debug_transaction(self, "ROLLBACK"), self.wrap_database_errors:
283
301
  return self.connection.rollback()
284
302
 
285
- def _close(self):
303
+ def _close(self) -> None:
286
304
  if self.connection is not None:
287
305
  with self.wrap_database_errors:
288
306
  return self.connection.close()
289
307
 
290
308
  # ##### Generic wrappers for PEP-249 connection methods #####
291
309
 
292
- def cursor(self):
310
+ def cursor(self) -> utils.CursorWrapper:
293
311
  """Create a cursor, opening a connection if necessary."""
294
312
  return self._cursor()
295
313
 
296
- def commit(self):
314
+ def commit(self) -> None:
297
315
  """Commit a transaction and reset the dirty flag."""
298
316
  self.validate_thread_sharing()
299
317
  self.validate_no_atomic_block()
@@ -302,7 +320,7 @@ class BaseDatabaseWrapper:
302
320
  self.errors_occurred = False
303
321
  self.run_commit_hooks_on_set_autocommit_on = True
304
322
 
305
- def rollback(self):
323
+ def rollback(self) -> None:
306
324
  """Roll back a transaction and reset the dirty flag."""
307
325
  self.validate_thread_sharing()
308
326
  self.validate_no_atomic_block()
@@ -312,7 +330,7 @@ class BaseDatabaseWrapper:
312
330
  self.needs_rollback = False
313
331
  self.run_on_commit = []
314
332
 
315
- def close(self):
333
+ def close(self) -> None:
316
334
  """Close the connection to the database."""
317
335
  self.validate_thread_sharing()
318
336
  self.run_on_commit = []
@@ -333,32 +351,32 @@ class BaseDatabaseWrapper:
333
351
 
334
352
  # ##### Backend-specific savepoint management methods #####
335
353
 
336
- def _savepoint(self, sid):
354
+ def _savepoint(self, sid: str) -> None:
337
355
  with self.cursor() as cursor:
338
356
  cursor.execute(self.ops.savepoint_create_sql(sid))
339
357
 
340
- def _savepoint_rollback(self, sid):
358
+ def _savepoint_rollback(self, sid: str) -> None:
341
359
  with self.cursor() as cursor:
342
360
  cursor.execute(self.ops.savepoint_rollback_sql(sid))
343
361
 
344
- def _savepoint_commit(self, sid):
362
+ def _savepoint_commit(self, sid: str) -> None:
345
363
  with self.cursor() as cursor:
346
364
  cursor.execute(self.ops.savepoint_commit_sql(sid))
347
365
 
348
- def _savepoint_allowed(self):
366
+ def _savepoint_allowed(self) -> bool:
349
367
  # Savepoints cannot be created outside a transaction
350
368
  return self.features.uses_savepoints and not self.get_autocommit()
351
369
 
352
370
  # ##### Generic savepoint management methods #####
353
371
 
354
- def savepoint(self):
372
+ def savepoint(self) -> str | None:
355
373
  """
356
374
  Create a savepoint inside the current transaction. Return an
357
375
  identifier for the savepoint that will be used for the subsequent
358
376
  rollback or commit. Do nothing if savepoints are not supported.
359
377
  """
360
378
  if not self._savepoint_allowed():
361
- return
379
+ return None
362
380
 
363
381
  thread_ident = _thread.get_ident()
364
382
  tid = str(thread_ident).replace("-", "")
@@ -371,7 +389,7 @@ class BaseDatabaseWrapper:
371
389
 
372
390
  return sid
373
391
 
374
- def savepoint_rollback(self, sid):
392
+ def savepoint_rollback(self, sid: str) -> None:
375
393
  """
376
394
  Roll back to a savepoint. Do nothing if savepoints are not supported.
377
395
  """
@@ -388,7 +406,7 @@ class BaseDatabaseWrapper:
388
406
  if sid not in sids
389
407
  ]
390
408
 
391
- def savepoint_commit(self, sid):
409
+ def savepoint_commit(self, sid: str) -> None:
392
410
  """
393
411
  Release a savepoint. Do nothing if savepoints are not supported.
394
412
  """
@@ -398,7 +416,7 @@ class BaseDatabaseWrapper:
398
416
  self.validate_thread_sharing()
399
417
  self._savepoint_commit(sid)
400
418
 
401
- def clean_savepoints(self):
419
+ def clean_savepoints(self) -> None:
402
420
  """
403
421
  Reset the counter used to generate unique savepoint ids in this thread.
404
422
  """
@@ -406,7 +424,7 @@ class BaseDatabaseWrapper:
406
424
 
407
425
  # ##### Backend-specific transaction management methods #####
408
426
 
409
- def _set_autocommit(self, autocommit):
427
+ def _set_autocommit(self, autocommit: bool) -> None:
410
428
  """
411
429
  Backend-specific implementation to enable or disable autocommit.
412
430
  """
@@ -416,14 +434,16 @@ class BaseDatabaseWrapper:
416
434
 
417
435
  # ##### Generic transaction management methods #####
418
436
 
419
- def get_autocommit(self):
437
+ def get_autocommit(self) -> bool:
420
438
  """Get the autocommit state."""
421
439
  self.ensure_connection()
422
440
  return self.autocommit
423
441
 
424
442
  def set_autocommit(
425
- self, autocommit, force_begin_transaction_with_broken_autocommit=False
426
- ):
443
+ self,
444
+ autocommit: bool,
445
+ force_begin_transaction_with_broken_autocommit: bool = False,
446
+ ) -> None:
427
447
  """
428
448
  Enable or disable autocommit.
429
449
 
@@ -446,7 +466,7 @@ class BaseDatabaseWrapper:
446
466
  )
447
467
 
448
468
  if start_transaction_under_autocommit:
449
- self._start_transaction_under_autocommit()
469
+ self._start_transaction_under_autocommit() # type: ignore[attr-defined]
450
470
  elif autocommit:
451
471
  self._set_autocommit(autocommit)
452
472
  else:
@@ -458,7 +478,7 @@ class BaseDatabaseWrapper:
458
478
  self.run_and_clear_commit_hooks()
459
479
  self.run_commit_hooks_on_set_autocommit_on = False
460
480
 
461
- def get_rollback(self):
481
+ def get_rollback(self) -> bool:
462
482
  """Get the "needs rollback" flag -- for *advanced use* only."""
463
483
  if not self.in_atomic_block:
464
484
  raise TransactionManagementError(
@@ -466,7 +486,7 @@ class BaseDatabaseWrapper:
466
486
  )
467
487
  return self.needs_rollback
468
488
 
469
- def set_rollback(self, rollback):
489
+ def set_rollback(self, rollback: bool) -> None:
470
490
  """
471
491
  Set or unset the "needs rollback" flag -- for *advanced use* only.
472
492
  """
@@ -476,14 +496,14 @@ class BaseDatabaseWrapper:
476
496
  )
477
497
  self.needs_rollback = rollback
478
498
 
479
- def validate_no_atomic_block(self):
499
+ def validate_no_atomic_block(self) -> None:
480
500
  """Raise an error if an atomic block is active."""
481
501
  if self.in_atomic_block:
482
502
  raise TransactionManagementError(
483
503
  "This is forbidden when an 'atomic' block is active."
484
504
  )
485
505
 
486
- def validate_no_broken_transaction(self):
506
+ def validate_no_broken_transaction(self) -> None:
487
507
  if self.needs_rollback:
488
508
  raise TransactionManagementError(
489
509
  "An error occurred in the current transaction. You can't "
@@ -492,7 +512,7 @@ class BaseDatabaseWrapper:
492
512
 
493
513
  # ##### Foreign key constraints checks handling #####
494
514
 
495
- def disable_constraint_checking(self):
515
+ def disable_constraint_checking(self) -> bool:
496
516
  """
497
517
  Backends can implement as needed to temporarily disable foreign key
498
518
  constraint checking. Should return True if the constraints were
@@ -500,14 +520,14 @@ class BaseDatabaseWrapper:
500
520
  """
501
521
  return False
502
522
 
503
- def enable_constraint_checking(self):
523
+ def enable_constraint_checking(self) -> None:
504
524
  """
505
525
  Backends can implement as needed to re-enable foreign key constraint
506
526
  checking.
507
527
  """
508
528
  pass
509
529
 
510
- def check_constraints(self, table_names=None):
530
+ def check_constraints(self, table_names: list[str] | None = None) -> None:
511
531
  """
512
532
  Backends can override this method if they can apply constraint
513
533
  checking (e.g. via "SET CONSTRAINTS ALL IMMEDIATE"). Should raise an
@@ -517,7 +537,7 @@ class BaseDatabaseWrapper:
517
537
 
518
538
  # ##### Connection termination handling #####
519
539
 
520
- def is_usable(self):
540
+ def is_usable(self) -> bool:
521
541
  """
522
542
  Test if the database connection is usable.
523
543
 
@@ -530,7 +550,7 @@ class BaseDatabaseWrapper:
530
550
  "subclasses of BaseDatabaseWrapper may require an is_usable() method"
531
551
  )
532
552
 
533
- def close_if_health_check_failed(self):
553
+ def close_if_health_check_failed(self) -> None:
534
554
  """Close existing connection if it fails a health check."""
535
555
  if (
536
556
  self.connection is None
@@ -543,7 +563,7 @@ class BaseDatabaseWrapper:
543
563
  self.close()
544
564
  self.health_check_done = True
545
565
 
546
- def close_if_unusable_or_obsolete(self):
566
+ def close_if_unusable_or_obsolete(self) -> None:
547
567
  """
548
568
  Close the current connection if unrecoverable errors have occurred
549
569
  or if it outlived its maximum age.
@@ -573,11 +593,11 @@ class BaseDatabaseWrapper:
573
593
  # ##### Thread safety handling #####
574
594
 
575
595
  @property
576
- def allow_thread_sharing(self):
596
+ def allow_thread_sharing(self) -> bool:
577
597
  with self._thread_sharing_lock:
578
598
  return self._thread_sharing_count > 0
579
599
 
580
- def validate_thread_sharing(self):
600
+ def validate_thread_sharing(self) -> None:
581
601
  """
582
602
  Validate that the connection isn't accessed by another thread than the
583
603
  one which originally created it, unless the connection was explicitly
@@ -594,7 +614,7 @@ class BaseDatabaseWrapper:
594
614
 
595
615
  # ##### Miscellaneous #####
596
616
 
597
- def prepare_database(self):
617
+ def prepare_database(self) -> None:
598
618
  """
599
619
  Hook to do any database check or preparation, generally called before
600
620
  migrating a project or an app.
@@ -602,30 +622,30 @@ class BaseDatabaseWrapper:
602
622
  pass
603
623
 
604
624
  @cached_property
605
- def wrap_database_errors(self):
625
+ def wrap_database_errors(self) -> DatabaseErrorWrapper:
606
626
  """
607
627
  Context manager and decorator that re-throws backend-specific database
608
628
  exceptions using Plain's common wrappers.
609
629
  """
610
630
  return DatabaseErrorWrapper(self)
611
631
 
612
- def chunked_cursor(self):
632
+ def chunked_cursor(self) -> utils.CursorWrapper:
613
633
  """
614
634
  Return a cursor that tries to avoid caching in the database (if
615
635
  supported by the database), otherwise return a regular cursor.
616
636
  """
617
637
  return self.cursor()
618
638
 
619
- def make_debug_cursor(self, cursor):
639
+ def make_debug_cursor(self, cursor: Any) -> utils.CursorDebugWrapper:
620
640
  """Create a cursor that logs all queries in self.queries_log."""
621
641
  return utils.CursorDebugWrapper(cursor, self)
622
642
 
623
- def make_cursor(self, cursor):
643
+ def make_cursor(self, cursor: Any) -> utils.CursorWrapper:
624
644
  """Create a cursor without debug logging."""
625
645
  return utils.CursorWrapper(cursor, self)
626
646
 
627
647
  @contextmanager
628
- def temporary_connection(self):
648
+ def temporary_connection(self) -> Generator[utils.CursorWrapper, None, None]:
629
649
  """
630
650
  Context manager that ensures that a connection is established, and
631
651
  if it opened one, closes it to avoid leaving a dangling connection.
@@ -642,7 +662,7 @@ class BaseDatabaseWrapper:
642
662
  self.close()
643
663
 
644
664
  @contextmanager
645
- def _nodb_cursor(self):
665
+ def _nodb_cursor(self) -> Generator[utils.CursorWrapper, None, None]:
646
666
  """
647
667
  Return a cursor from an alternative connection to be used when there is
648
668
  no need to access the main database, specifically for test db
@@ -657,7 +677,7 @@ class BaseDatabaseWrapper:
657
677
  finally:
658
678
  conn.close()
659
679
 
660
- def schema_editor(self, *args, **kwargs):
680
+ def schema_editor(self, *args: Any, **kwargs: Any) -> BaseDatabaseSchemaEditor:
661
681
  """
662
682
  Return a new instance of this backend's SchemaEditor.
663
683
  """
@@ -667,7 +687,7 @@ class BaseDatabaseWrapper:
667
687
  )
668
688
  return self.SchemaEditorClass(self, *args, **kwargs)
669
689
 
670
- def on_commit(self, func, robust=False):
690
+ def on_commit(self, func: Any, robust: bool = False) -> None:
671
691
  if not callable(func):
672
692
  raise TypeError("on_commit()'s callback must be a callable.")
673
693
  if self.in_atomic_block:
@@ -692,7 +712,7 @@ class BaseDatabaseWrapper:
692
712
  else:
693
713
  func()
694
714
 
695
- def run_and_clear_commit_hooks(self):
715
+ def run_and_clear_commit_hooks(self) -> None:
696
716
  self.validate_no_atomic_block()
697
717
  current_run_on_commit = self.run_on_commit
698
718
  self.run_on_commit = []
@@ -712,7 +732,7 @@ class BaseDatabaseWrapper:
712
732
  func()
713
733
 
714
734
  @contextmanager
715
- def execute_wrapper(self, wrapper):
735
+ def execute_wrapper(self, wrapper: Any) -> Generator[None, None, None]:
716
736
  """
717
737
  Return a context manager under which the wrapper is applied to suitable
718
738
  database query executions.
@@ -723,7 +743,7 @@ class BaseDatabaseWrapper:
723
743
  finally:
724
744
  self.execute_wrappers.pop()
725
745
 
726
- def copy(self):
746
+ def copy(self) -> BaseDatabaseWrapper:
727
747
  """
728
748
  Return a copy of this connection.
729
749
 
@@ -1,5 +1,11 @@
1
+ from __future__ import annotations
2
+
1
3
  import os
2
4
  import subprocess
5
+ from typing import TYPE_CHECKING, Any
6
+
7
+ if TYPE_CHECKING:
8
+ from plain.models.backends.base.base import BaseDatabaseWrapper
3
9
 
4
10
 
5
11
  class BaseDatabaseClient:
@@ -9,18 +15,20 @@ class BaseDatabaseClient:
9
15
  # (e.g., "psql"). Subclasses must override this.
10
16
  executable_name = None
11
17
 
12
- def __init__(self, connection):
18
+ def __init__(self, connection: BaseDatabaseWrapper) -> None:
13
19
  # connection is an instance of BaseDatabaseWrapper.
14
20
  self.connection = connection
15
21
 
16
22
  @classmethod
17
- def settings_to_cmd_args_env(cls, settings_dict, parameters):
23
+ def settings_to_cmd_args_env(
24
+ cls, settings_dict: dict[str, Any], parameters: list[str]
25
+ ) -> tuple[list[str], dict[str, str] | None]:
18
26
  raise NotImplementedError(
19
27
  "subclasses of BaseDatabaseClient must provide a "
20
28
  "settings_to_cmd_args_env() method or override a runshell()."
21
29
  )
22
30
 
23
- def runshell(self, parameters):
31
+ def runshell(self, parameters: list[str]) -> None:
24
32
  args, env = self.settings_to_cmd_args_env(
25
33
  self.connection.settings_dict, parameters
26
34
  )