plain.models 0.49.2__py3-none-any.whl → 0.50.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105)
  1. plain/models/CHANGELOG.md +13 -0
  2. plain/models/aggregates.py +42 -19
  3. plain/models/backends/base/base.py +125 -105
  4. plain/models/backends/base/client.py +11 -3
  5. plain/models/backends/base/creation.py +22 -12
  6. plain/models/backends/base/features.py +10 -4
  7. plain/models/backends/base/introspection.py +29 -16
  8. plain/models/backends/base/operations.py +187 -91
  9. plain/models/backends/base/schema.py +267 -165
  10. plain/models/backends/base/validation.py +12 -3
  11. plain/models/backends/ddl_references.py +85 -43
  12. plain/models/backends/mysql/base.py +29 -26
  13. plain/models/backends/mysql/client.py +7 -2
  14. plain/models/backends/mysql/compiler.py +12 -3
  15. plain/models/backends/mysql/creation.py +5 -2
  16. plain/models/backends/mysql/features.py +24 -22
  17. plain/models/backends/mysql/introspection.py +22 -13
  18. plain/models/backends/mysql/operations.py +106 -39
  19. plain/models/backends/mysql/schema.py +48 -24
  20. plain/models/backends/mysql/validation.py +13 -6
  21. plain/models/backends/postgresql/base.py +41 -34
  22. plain/models/backends/postgresql/client.py +7 -2
  23. plain/models/backends/postgresql/creation.py +10 -5
  24. plain/models/backends/postgresql/introspection.py +15 -8
  25. plain/models/backends/postgresql/operations.py +109 -42
  26. plain/models/backends/postgresql/schema.py +85 -46
  27. plain/models/backends/sqlite3/_functions.py +151 -115
  28. plain/models/backends/sqlite3/base.py +37 -23
  29. plain/models/backends/sqlite3/client.py +7 -1
  30. plain/models/backends/sqlite3/creation.py +9 -5
  31. plain/models/backends/sqlite3/features.py +5 -3
  32. plain/models/backends/sqlite3/introspection.py +32 -16
  33. plain/models/backends/sqlite3/operations.py +125 -42
  34. plain/models/backends/sqlite3/schema.py +82 -58
  35. plain/models/backends/utils.py +52 -29
  36. plain/models/backups/cli.py +8 -6
  37. plain/models/backups/clients.py +16 -7
  38. plain/models/backups/core.py +24 -13
  39. plain/models/base.py +113 -74
  40. plain/models/cli.py +94 -63
  41. plain/models/config.py +1 -1
  42. plain/models/connections.py +23 -7
  43. plain/models/constraints.py +65 -47
  44. plain/models/database_url.py +1 -1
  45. plain/models/db.py +6 -2
  46. plain/models/deletion.py +66 -43
  47. plain/models/entrypoints.py +1 -1
  48. plain/models/enums.py +22 -11
  49. plain/models/exceptions.py +23 -8
  50. plain/models/expressions.py +440 -257
  51. plain/models/fields/__init__.py +253 -202
  52. plain/models/fields/json.py +120 -54
  53. plain/models/fields/mixins.py +12 -8
  54. plain/models/fields/related.py +284 -252
  55. plain/models/fields/related_descriptors.py +31 -22
  56. plain/models/fields/related_lookups.py +23 -11
  57. plain/models/fields/related_managers.py +81 -47
  58. plain/models/fields/reverse_related.py +58 -55
  59. plain/models/forms.py +89 -63
  60. plain/models/functions/comparison.py +71 -18
  61. plain/models/functions/datetime.py +79 -29
  62. plain/models/functions/math.py +43 -10
  63. plain/models/functions/mixins.py +24 -7
  64. plain/models/functions/text.py +104 -25
  65. plain/models/functions/window.py +12 -6
  66. plain/models/indexes.py +52 -28
  67. plain/models/lookups.py +228 -153
  68. plain/models/migrations/autodetector.py +86 -43
  69. plain/models/migrations/exceptions.py +7 -3
  70. plain/models/migrations/executor.py +33 -7
  71. plain/models/migrations/graph.py +79 -50
  72. plain/models/migrations/loader.py +45 -22
  73. plain/models/migrations/migration.py +23 -18
  74. plain/models/migrations/operations/base.py +37 -19
  75. plain/models/migrations/operations/fields.py +89 -42
  76. plain/models/migrations/operations/models.py +245 -143
  77. plain/models/migrations/operations/special.py +82 -25
  78. plain/models/migrations/optimizer.py +7 -2
  79. plain/models/migrations/questioner.py +58 -31
  80. plain/models/migrations/recorder.py +18 -11
  81. plain/models/migrations/serializer.py +50 -39
  82. plain/models/migrations/state.py +220 -133
  83. plain/models/migrations/utils.py +29 -13
  84. plain/models/migrations/writer.py +17 -14
  85. plain/models/options.py +63 -56
  86. plain/models/otel.py +16 -6
  87. plain/models/preflight.py +35 -12
  88. plain/models/query.py +323 -228
  89. plain/models/query_utils.py +93 -58
  90. plain/models/registry.py +34 -16
  91. plain/models/sql/compiler.py +146 -97
  92. plain/models/sql/datastructures.py +38 -25
  93. plain/models/sql/query.py +255 -169
  94. plain/models/sql/subqueries.py +32 -21
  95. plain/models/sql/where.py +54 -29
  96. plain/models/test/pytest.py +15 -11
  97. plain/models/test/utils.py +4 -2
  98. plain/models/transaction.py +20 -7
  99. plain/models/utils.py +13 -5
  100. {plain_models-0.49.2.dist-info → plain_models-0.50.0.dist-info}/METADATA +1 -1
  101. plain_models-0.50.0.dist-info/RECORD +122 -0
  102. plain_models-0.49.2.dist-info/RECORD +0 -122
  103. {plain_models-0.49.2.dist-info → plain_models-0.50.0.dist-info}/WHEEL +0 -0
  104. {plain_models-0.49.2.dist-info → plain_models-0.50.0.dist-info}/entry_points.txt +0 -0
  105. {plain_models-0.49.2.dist-info → plain_models-0.50.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,5 +1,8 @@
1
+ from __future__ import annotations
2
+
1
3
  import copy
2
4
  from decimal import Decimal
5
+ from typing import TYPE_CHECKING, Any
3
6
 
4
7
  from plain.models import register_model
5
8
  from plain.models.backends.base.schema import BaseDatabaseSchemaEditor
@@ -10,6 +13,11 @@ from plain.models.db import NotSupportedError
10
13
  from plain.models.registry import ModelsRegistry
11
14
  from plain.models.transaction import atomic
12
15
 
16
+ if TYPE_CHECKING:
17
+ from plain.models.base import Model
18
+ from plain.models.constraints import BaseConstraint
19
+ from plain.models.fields import Field
20
+
13
21
 
14
22
  class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
15
23
  sql_delete_table = "DROP TABLE %(table)s"
@@ -22,7 +30,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
22
30
  sql_create_unique = "CREATE UNIQUE INDEX %(name)s ON %(table)s (%(columns)s)"
23
31
  sql_delete_unique = "DROP INDEX %(name)s"
24
32
 
25
- def __enter__(self):
33
+ def __enter__(self) -> DatabaseSchemaEditor:
26
34
  # Some SQLite schema alterations need foreign key constraints to be
27
35
  # disabled. Enforce it here for the duration of the schema edition.
28
36
  if not self.connection.disable_constraint_checking():
@@ -35,19 +43,19 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
35
43
  )
36
44
  return super().__enter__()
37
45
 
38
- def __exit__(self, exc_type, exc_value, traceback):
46
+ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
39
47
  self.connection.check_constraints()
40
48
  super().__exit__(exc_type, exc_value, traceback)
41
49
  self.connection.enable_constraint_checking()
42
50
 
43
- def quote_value(self, value):
51
+ def quote_value(self, value: Any) -> str:
44
52
  # The backend "mostly works" without this function and there are use
45
53
  # cases for compiling Python without the sqlite3 libraries (e.g.
46
54
  # security hardening).
47
55
  try:
48
56
  import sqlite3
49
57
 
50
- value = sqlite3.adapt(value)
58
+ value = sqlite3.adapt(value) # type: ignore[call-overload]
51
59
  except ImportError:
52
60
  pass
53
61
  except sqlite3.ProgrammingError:
@@ -71,12 +79,12 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
71
79
  f"Cannot quote parameter value {value!r} of type {type(value)}"
72
80
  )
73
81
 
74
- def prepare_default(self, value):
82
+ def prepare_default(self, value: Any) -> str:
75
83
  return self.quote_value(value)
76
84
 
77
85
  def _is_referenced_by_fk_constraint(
78
- self, table_name, column_name=None, ignore_self=False
79
- ):
86
+ self, table_name: str, column_name: str | None = None, ignore_self: bool = False
87
+ ) -> bool:
80
88
  """
81
89
  Return whether or not the provided table name is referenced by another
82
90
  one. If `column_name` is specified, only references pointing to that
@@ -98,8 +106,12 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
98
106
  return False
99
107
 
100
108
  def alter_db_table(
101
- self, model, old_db_table, new_db_table, disable_constraints=True
102
- ):
109
+ self,
110
+ model: type[Model],
111
+ old_db_table: str,
112
+ new_db_table: str,
113
+ disable_constraints: bool = True,
114
+ ) -> None:
103
115
  if (
104
116
  not self.connection.features.supports_atomic_references_rename
105
117
  and disable_constraints
@@ -117,7 +129,13 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
117
129
  else:
118
130
  super().alter_db_table(model, old_db_table, new_db_table)
119
131
 
120
- def alter_field(self, model, old_field, new_field, strict=False):
132
+ def alter_field(
133
+ self,
134
+ model: type[Model],
135
+ old_field: Field,
136
+ new_field: Field,
137
+ strict: bool = False,
138
+ ) -> None:
121
139
  if not self._field_should_be_altered(old_field, new_field):
122
140
  return
123
141
  old_field_name = old_field.name
@@ -168,8 +186,12 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
168
186
  super().alter_field(model, old_field, new_field, strict=strict)
169
187
 
170
188
  def _remake_table(
171
- self, model, create_field=None, delete_field=None, alter_fields=None
172
- ):
189
+ self,
190
+ model: type[Model],
191
+ create_field: Field | None = None,
192
+ delete_field: Field | None = None,
193
+ alter_fields: list[tuple[Field, Field]] | None = None,
194
+ ) -> None:
173
195
  """
174
196
  Shortcut to transform a model from old_model into new_model
175
197
 
@@ -189,8 +211,8 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
189
211
  # Self-referential fields must be recreated rather than copied from
190
212
  # the old model to ensure their remote_field.field_name doesn't refer
191
213
  # to an altered field.
192
- def is_self_referential(f):
193
- return f.is_relation and f.remote_field.model is model
214
+ def is_self_referential(f: Field) -> bool:
215
+ return f.is_relation and f.remote_field.model is model # type: ignore[attr-defined]
194
216
 
195
217
  # Work out the new fields dict / mapping
196
218
  body = {
@@ -276,7 +298,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
276
298
  meta = type("Meta", (), meta_contents)
277
299
  body_copy["Meta"] = meta
278
300
  body_copy["__module__"] = model.__module__
279
- register_model(type(model._meta.object_name, model.__bases__, body_copy))
301
+ register_model(type(model._meta.object_name, model.__bases__, body_copy)) # type: ignore[arg-type]
280
302
 
281
303
  # Construct a model with a renamed table name.
282
304
  body_copy = copy.deepcopy(body)
@@ -291,7 +313,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
291
313
  body_copy["Meta"] = meta
292
314
  body_copy["__module__"] = model.__module__
293
315
  new_model = type(f"New{model._meta.object_name}", model.__bases__, body_copy)
294
- register_model(new_model)
316
+ register_model(new_model) # type: ignore[arg-type]
295
317
 
296
318
  # Create a new table with the updated schema.
297
319
  self.create_model(new_model)
@@ -299,7 +321,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
299
321
  # Copy data from the old table into the new table
300
322
  self.execute(
301
323
  "INSERT INTO {} ({}) SELECT {} FROM {}".format(
302
- self.quote_name(new_model._meta.db_table),
324
+ self.quote_name(new_model._meta.db_table), # type: ignore[attr-defined]
303
325
  ", ".join(self.quote_name(x) for x in mapping),
304
326
  ", ".join(mapping.values()),
305
327
  self.quote_name(model._meta.db_table),
@@ -311,8 +333,8 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
311
333
 
312
334
  # Rename the new table to take way for the old
313
335
  self.alter_db_table(
314
- new_model,
315
- new_model._meta.db_table,
336
+ new_model, # type: ignore[arg-type]
337
+ new_model._meta.db_table, # type: ignore[attr-defined]
316
338
  model._meta.db_table,
317
339
  disable_constraints=False,
318
340
  )
@@ -325,7 +347,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
325
347
  if restore_pk_field:
326
348
  restore_pk_field.primary_key = True
327
349
 
328
- def delete_model(self, model, handle_autom2m=True):
350
+ def delete_model(self, model: type[Model], handle_autom2m: bool = True) -> None:
329
351
  if handle_autom2m:
330
352
  super().delete_model(model)
331
353
  else:
@@ -343,7 +365,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
343
365
  ):
344
366
  self.deferred_sql.remove(sql)
345
367
 
346
- def add_field(self, model, field):
368
+ def add_field(self, model: type[Model], field: Field) -> None:
347
369
  """Create a field on a model."""
348
370
  if (
349
371
  # Primary keys are not supported in ALTER TABLE
@@ -360,7 +382,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
360
382
  else:
361
383
  super().add_field(model, field)
362
384
 
363
- def remove_field(self, model, field):
385
+ def remove_field(self, model: type[Model], field: Field) -> None:
364
386
  """
365
387
  Remove a field from a model. Usually involves deleting a column,
366
388
  but for M2Ms may involve deleting a table.
@@ -374,8 +396,8 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
374
396
  # Primary keys, unique fields, indexed fields, and foreign keys are
375
397
  # not supported in ALTER TABLE DROP COLUMN.
376
398
  and not field.primary_key
377
- and not (field.remote_field and field.db_index)
378
- and not (field.remote_field and field.db_constraint)
399
+ and not (field.remote_field and field.db_index) # type: ignore[attr-defined]
400
+ and not (field.remote_field and field.db_constraint) # type: ignore[attr-defined]
379
401
  ):
380
402
  super().remove_field(model, field)
381
403
  # For everything else, remake.
@@ -387,15 +409,15 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
387
409
 
388
410
  def _alter_field(
389
411
  self,
390
- model,
391
- old_field,
392
- new_field,
393
- old_type,
394
- new_type,
395
- old_db_params,
396
- new_db_params,
397
- strict=False,
398
- ):
412
+ model: type[Model],
413
+ old_field: Field,
414
+ new_field: Field,
415
+ old_type: str,
416
+ new_type: str,
417
+ old_db_params: dict[str, Any],
418
+ new_db_params: dict[str, Any],
419
+ strict: bool = False,
420
+ ) -> None:
399
421
  """Perform a "physical" (non-ManyToMany) field update."""
400
422
  # Use "ALTER TABLE ... RENAME COLUMN" if only the column name
401
423
  # changed and there aren't any constraints.
@@ -405,9 +427,9 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
405
427
  and self.column_sql(model, old_field) == self.column_sql(model, new_field)
406
428
  and not (
407
429
  old_field.remote_field
408
- and old_field.db_constraint
430
+ and old_field.db_constraint # type: ignore[attr-defined]
409
431
  or new_field.remote_field
410
- and new_field.db_constraint
432
+ and new_field.db_constraint # type: ignore[attr-defined]
411
433
  )
412
434
  ):
413
435
  return self.execute(
@@ -440,37 +462,39 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
440
462
  for related_model in related_models:
441
463
  self._remake_table(related_model)
442
464
 
443
- def _alter_many_to_many(self, model, old_field, new_field, strict):
465
+ def _alter_many_to_many(
466
+ self, model: type[Model], old_field: Field, new_field: Field, strict: bool
467
+ ) -> None:
444
468
  """Alter M2Ms to repoint their to= endpoints."""
445
469
  if (
446
- old_field.remote_field.through._meta.db_table
447
- == new_field.remote_field.through._meta.db_table
470
+ old_field.remote_field.through._meta.db_table # type: ignore[attr-defined]
471
+ == new_field.remote_field.through._meta.db_table # type: ignore[attr-defined]
448
472
  ):
449
473
  # The field name didn't change, but some options did, so we have to
450
474
  # propagate this altering.
451
475
  self._remake_table(
452
- old_field.remote_field.through,
476
+ old_field.remote_field.through, # type: ignore[attr-defined]
453
477
  alter_fields=[
454
478
  (
455
479
  # The field that points to the target model is needed,
456
480
  # so that table can be remade with the new m2m field -
457
481
  # this is m2m_reverse_field_name().
458
- old_field.remote_field.through._meta.get_field(
459
- old_field.m2m_reverse_field_name()
482
+ old_field.remote_field.through._meta.get_field( # type: ignore[attr-defined]
483
+ old_field.m2m_reverse_field_name() # type: ignore[attr-defined]
460
484
  ),
461
- new_field.remote_field.through._meta.get_field(
462
- new_field.m2m_reverse_field_name()
485
+ new_field.remote_field.through._meta.get_field( # type: ignore[attr-defined]
486
+ new_field.m2m_reverse_field_name() # type: ignore[attr-defined]
463
487
  ),
464
488
  ),
465
489
  (
466
490
  # The field that points to the model itself is needed,
467
491
  # so that table can be remade with the new self field -
468
492
  # this is m2m_field_name().
469
- old_field.remote_field.through._meta.get_field(
470
- old_field.m2m_field_name()
493
+ old_field.remote_field.through._meta.get_field( # type: ignore[attr-defined]
494
+ old_field.m2m_field_name() # type: ignore[attr-defined]
471
495
  ),
472
- new_field.remote_field.through._meta.get_field(
473
- new_field.m2m_field_name()
496
+ new_field.remote_field.through._meta.get_field( # type: ignore[attr-defined]
497
+ new_field.m2m_field_name() # type: ignore[attr-defined]
474
498
  ),
475
499
  ),
476
500
  ],
@@ -478,32 +502,32 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
478
502
  return
479
503
 
480
504
  # Make a new through table
481
- self.create_model(new_field.remote_field.through)
505
+ self.create_model(new_field.remote_field.through) # type: ignore[attr-defined]
482
506
  # Copy the data across
483
507
  self.execute(
484
508
  "INSERT INTO {} ({}) SELECT {} FROM {}".format(
485
- self.quote_name(new_field.remote_field.through._meta.db_table),
509
+ self.quote_name(new_field.remote_field.through._meta.db_table), # type: ignore[attr-defined]
486
510
  ", ".join(
487
511
  [
488
512
  "id",
489
- new_field.m2m_column_name(),
490
- new_field.m2m_reverse_name(),
513
+ new_field.m2m_column_name(), # type: ignore[attr-defined]
514
+ new_field.m2m_reverse_name(), # type: ignore[attr-defined]
491
515
  ]
492
516
  ),
493
517
  ", ".join(
494
518
  [
495
519
  "id",
496
- old_field.m2m_column_name(),
497
- old_field.m2m_reverse_name(),
520
+ old_field.m2m_column_name(), # type: ignore[attr-defined]
521
+ old_field.m2m_reverse_name(), # type: ignore[attr-defined]
498
522
  ]
499
523
  ),
500
- self.quote_name(old_field.remote_field.through._meta.db_table),
524
+ self.quote_name(old_field.remote_field.through._meta.db_table), # type: ignore[attr-defined]
501
525
  )
502
526
  )
503
527
  # Delete the old through table
504
- self.delete_model(old_field.remote_field.through)
528
+ self.delete_model(old_field.remote_field.through) # type: ignore[attr-defined]
505
529
 
506
- def add_constraint(self, model, constraint):
530
+ def add_constraint(self, model: type[Model], constraint: BaseConstraint) -> None:
507
531
  if isinstance(constraint, UniqueConstraint) and (
508
532
  constraint.condition
509
533
  or constraint.contains_expressions
@@ -514,7 +538,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
514
538
  else:
515
539
  self._remake_table(model)
516
540
 
517
- def remove_constraint(self, model, constraint):
541
+ def remove_constraint(self, model: type[Model], constraint: BaseConstraint) -> None:
518
542
  if isinstance(constraint, UniqueConstraint) and (
519
543
  constraint.condition
520
544
  or constraint.contains_expressions
@@ -525,5 +549,5 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
525
549
  else:
526
550
  self._remake_table(model)
527
551
 
528
- def _collate_sql(self, collation):
552
+ def _collate_sql(self, collation: str) -> str:
529
553
  return "COLLATE " + collation
@@ -1,52 +1,59 @@
1
+ from __future__ import annotations
2
+
1
3
  import datetime
2
4
  import decimal
3
5
  import functools
4
6
  import logging
5
7
  import time
8
+ from collections.abc import Generator, Iterator
6
9
  from contextlib import contextmanager
7
10
  from hashlib import md5
11
+ from typing import TYPE_CHECKING, Any
8
12
 
9
13
  from plain.models.db import NotSupportedError
10
14
  from plain.models.otel import db_span
11
15
  from plain.utils.dateparse import parse_time
12
16
 
17
+ if TYPE_CHECKING:
18
+ from plain.models.backends.base.base import BaseDatabaseWrapper
19
+
13
20
  logger = logging.getLogger("plain.models.backends")
14
21
 
15
22
 
16
23
  class CursorWrapper:
17
- def __init__(self, cursor, db):
24
+ def __init__(self, cursor: Any, db: Any) -> None:
18
25
  self.cursor = cursor
19
26
  self.db = db
20
27
 
21
28
  WRAP_ERROR_ATTRS = frozenset(["fetchone", "fetchmany", "fetchall", "nextset"])
22
29
 
23
- def __getattr__(self, attr):
30
+ def __getattr__(self, attr: str) -> Any:
24
31
  cursor_attr = getattr(self.cursor, attr)
25
32
  if attr in CursorWrapper.WRAP_ERROR_ATTRS:
26
33
  return self.db.wrap_database_errors(cursor_attr)
27
34
  else:
28
35
  return cursor_attr
29
36
 
30
- def __iter__(self):
37
+ def __iter__(self) -> Iterator[Any]:
31
38
  with self.db.wrap_database_errors:
32
39
  yield from self.cursor
33
40
 
34
- def __enter__(self):
41
+ def __enter__(self) -> CursorWrapper:
35
42
  return self
36
43
 
37
- def __exit__(self, type, value, traceback):
44
+ def __exit__(self, type: Any, value: Any, traceback: Any) -> None:
38
45
  # Close instead of passing through to avoid backend-specific behavior
39
46
  # (#17671). Catch errors liberally because errors in cleanup code
40
47
  # aren't useful.
41
48
  try:
42
- self.close()
49
+ self.close() # type: ignore[attr-defined]
43
50
  except self.db.Database.Error:
44
51
  pass
45
52
 
46
53
  # The following methods cannot be implemented in __getattr__, because the
47
54
  # code must run when the method is invoked, not just when it is accessed.
48
55
 
49
- def callproc(self, procname, params=None, kparams=None):
56
+ def callproc(self, procname: str, params: Any = None, kparams: Any = None) -> Any:
50
57
  # Keyword parameters for callproc aren't supported in PEP 249, but the
51
58
  # database driver may support them (e.g. cx_Oracle).
52
59
  if kparams is not None and not self.db.features.supports_callproc_kwargs:
@@ -64,23 +71,25 @@ class CursorWrapper:
64
71
  params = params or ()
65
72
  return self.cursor.callproc(procname, params, kparams)
66
73
 
67
- def execute(self, sql, params=None):
74
+ def execute(self, sql: str, params: Any = None) -> Any:
68
75
  return self._execute_with_wrappers(
69
76
  sql, params, many=False, executor=self._execute
70
77
  )
71
78
 
72
- def executemany(self, sql, param_list):
79
+ def executemany(self, sql: str, param_list: Any) -> Any:
73
80
  return self._execute_with_wrappers(
74
81
  sql, param_list, many=True, executor=self._executemany
75
82
  )
76
83
 
77
- def _execute_with_wrappers(self, sql, params, many, executor):
78
- context = {"connection": self.db, "cursor": self}
84
+ def _execute_with_wrappers(
85
+ self, sql: str, params: Any, many: bool, executor: Any
86
+ ) -> Any:
87
+ context: dict[str, Any] = {"connection": self.db, "cursor": self}
79
88
  for wrapper in reversed(self.db.execute_wrappers):
80
89
  executor = functools.partial(wrapper, executor)
81
90
  return executor(sql, params, many, context)
82
91
 
83
- def _execute(self, sql, params, *ignored_wrapper_args):
92
+ def _execute(self, sql: str, params: Any, *ignored_wrapper_args: Any) -> Any:
84
93
  # Wrap in an OpenTelemetry span with standard attributes.
85
94
  with db_span(self.db, sql, params=params):
86
95
  self.db.validate_no_broken_transaction()
@@ -90,7 +99,9 @@ class CursorWrapper:
90
99
  else:
91
100
  return self.cursor.execute(sql, params)
92
101
 
93
- def _executemany(self, sql, param_list, *ignored_wrapper_args):
102
+ def _executemany(
103
+ self, sql: str, param_list: Any, *ignored_wrapper_args: Any
104
+ ) -> Any:
94
105
  with db_span(self.db, sql, many=True, params=param_list):
95
106
  self.db.validate_no_broken_transaction()
96
107
  with self.db.wrap_database_errors:
@@ -100,18 +111,22 @@ class CursorWrapper:
100
111
  class CursorDebugWrapper(CursorWrapper):
101
112
  # XXX callproc isn't instrumented at this time.
102
113
 
103
- def execute(self, sql, params=None):
114
+ def execute(self, sql: str, params: Any = None) -> Any:
104
115
  with self.debug_sql(sql, params, use_last_executed_query=True):
105
116
  return super().execute(sql, params)
106
117
 
107
- def executemany(self, sql, param_list):
118
+ def executemany(self, sql: str, param_list: Any) -> Any:
108
119
  with self.debug_sql(sql, param_list, many=True):
109
120
  return super().executemany(sql, param_list)
110
121
 
111
122
  @contextmanager
112
123
  def debug_sql(
113
- self, sql=None, params=None, use_last_executed_query=False, many=False
114
- ):
124
+ self,
125
+ sql: str | None = None,
126
+ params: Any = None,
127
+ use_last_executed_query: bool = False,
128
+ many: bool = False,
129
+ ) -> Generator[None, None, None]:
115
130
  start = time.monotonic()
116
131
  try:
117
132
  yield
@@ -121,7 +136,7 @@ class CursorDebugWrapper(CursorWrapper):
121
136
  if use_last_executed_query:
122
137
  sql = self.db.ops.last_executed_query(self.cursor, sql, params)
123
138
  try:
124
- times = len(params) if many else ""
139
+ times = len(params) if many else "" # type: ignore[arg-type]
125
140
  except TypeError:
126
141
  # params could be an iterator.
127
142
  times = "?"
@@ -145,7 +160,9 @@ class CursorDebugWrapper(CursorWrapper):
145
160
 
146
161
 
147
162
  @contextmanager
148
- def debug_transaction(connection, sql):
163
+ def debug_transaction(
164
+ connection: BaseDatabaseWrapper, sql: str
165
+ ) -> Generator[None, None, None]:
149
166
  start = time.monotonic()
150
167
  try:
151
168
  yield
@@ -171,7 +188,7 @@ def debug_transaction(connection, sql):
171
188
  )
172
189
 
173
190
 
174
- def split_tzname_delta(tzname):
191
+ def split_tzname_delta(tzname: str) -> tuple[str, str | None, str | None]:
175
192
  """
176
193
  Split a time zone name into a 3-tuple of (name, sign, offset).
177
194
  """
@@ -188,13 +205,15 @@ def split_tzname_delta(tzname):
188
205
  ###############################################
189
206
 
190
207
 
191
- def typecast_date(s):
208
+ def typecast_date(s: str | None) -> datetime.date | None:
192
209
  return (
193
210
  datetime.date(*map(int, s.split("-"))) if s else None
194
211
  ) # return None if s is null
195
212
 
196
213
 
197
- def typecast_time(s): # does NOT store time zone information
214
+ def typecast_time(
215
+ s: str | None,
216
+ ) -> datetime.time | None: # does NOT store time zone information
198
217
  if not s:
199
218
  return None
200
219
  hour, minutes, seconds = s.split(":")
@@ -207,7 +226,9 @@ def typecast_time(s): # does NOT store time zone information
207
226
  )
208
227
 
209
228
 
210
- def typecast_timestamp(s): # does NOT store time zone information
229
+ def typecast_timestamp(
230
+ s: str | None,
231
+ ) -> datetime.date | datetime.datetime | None: # does NOT store time zone information
211
232
  # "2005-07-29 15:48:00.590358-05"
212
233
  # "2005-07-29 09:56:00-05"
213
234
  if not s:
@@ -243,7 +264,7 @@ def typecast_timestamp(s): # does NOT store time zone information
243
264
  ###############################################
244
265
 
245
266
 
246
- def split_identifier(identifier):
267
+ def split_identifier(identifier: str) -> tuple[str, str]:
247
268
  """
248
269
  Split an SQL identifier into a two element tuple of (namespace, name).
249
270
 
@@ -257,7 +278,7 @@ def split_identifier(identifier):
257
278
  return namespace.strip('"'), name.strip('"')
258
279
 
259
280
 
260
- def truncate_name(identifier, length=None, hash_len=4):
281
+ def truncate_name(identifier: str, length: int | None = None, hash_len: int = 4) -> str:
261
282
  """
262
283
  Shorten an SQL identifier to a repeatable mangled version with the given
263
284
  length.
@@ -278,7 +299,7 @@ def truncate_name(identifier, length=None, hash_len=4):
278
299
  )
279
300
 
280
301
 
281
- def names_digest(*args, length):
302
+ def names_digest(*args: str, length: int) -> str:
282
303
  """
283
304
  Generate a 32-bit digest of a set of arguments that can be used to shorten
284
305
  identifying names.
@@ -289,7 +310,9 @@ def names_digest(*args, length):
289
310
  return h.hexdigest()[:length]
290
311
 
291
312
 
292
- def format_number(value, max_digits, decimal_places):
313
+ def format_number(
314
+ value: decimal.Decimal | None, max_digits: int | None, decimal_places: int | None
315
+ ) -> str | None:
293
316
  """
294
317
  Format a number into a string with the requisite number of digits and
295
318
  decimal places.
@@ -304,12 +327,12 @@ def format_number(value, max_digits, decimal_places):
304
327
  decimal.Decimal(1).scaleb(-decimal_places), context=context
305
328
  )
306
329
  else:
307
- context.traps[decimal.Rounded] = 1
330
+ context.traps[decimal.Rounded] = 1 # type: ignore[assignment]
308
331
  value = context.create_decimal(value)
309
332
  return f"{value:f}"
310
333
 
311
334
 
312
- def strip_quotes(table_name):
335
+ def strip_quotes(table_name: str) -> str:
313
336
  """
314
337
  Strip quotes off of quoted table names to make them safe for use in index
315
338
  names, sequence names, etc. For example '"USER"."TABLE"' (an Oracle naming
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import os
2
4
  import time
3
5
  from pathlib import Path
@@ -11,13 +13,13 @@ from .core import DatabaseBackups
11
13
 
12
14
  @register_cli("backups")
13
15
  @click.group("backups")
14
- def cli():
16
+ def cli() -> None:
15
17
  """Local database backups"""
16
18
  pass
17
19
 
18
20
 
19
21
  @cli.command("list")
20
- def list_backups():
22
+ def list_backups() -> None:
21
23
  backups_handler = DatabaseBackups()
22
24
  backups = backups_handler.find_backups()
23
25
  if not backups:
@@ -40,7 +42,7 @@ def list_backups():
40
42
  @cli.command("create")
41
43
  @click.option("--pg-dump", default="pg_dump", envvar="PG_DUMP")
42
44
  @click.argument("backup_name", default="")
43
- def create_backup(backup_name, pg_dump):
45
+ def create_backup(backup_name: str, pg_dump: str) -> None:
44
46
  backups_handler = DatabaseBackups()
45
47
 
46
48
  if not backup_name:
@@ -62,7 +64,7 @@ def create_backup(backup_name, pg_dump):
62
64
  @click.option("--latest", is_flag=True)
63
65
  @click.option("--pg-restore", default="pg_restore", envvar="PG_RESTORE")
64
66
  @click.argument("backup_name", default="")
65
- def restore_backup(backup_name, latest, pg_restore):
67
+ def restore_backup(backup_name: str, latest: bool, pg_restore: str) -> None:
66
68
  backups_handler = DatabaseBackups()
67
69
 
68
70
  if backup_name and latest:
@@ -89,7 +91,7 @@ def restore_backup(backup_name, latest, pg_restore):
89
91
 
90
92
  @cli.command("delete")
91
93
  @click.argument("backup_name")
92
- def delete_backup(backup_name):
94
+ def delete_backup(backup_name: str) -> None:
93
95
  backups_handler = DatabaseBackups()
94
96
  try:
95
97
  backups_handler.delete(backup_name)
@@ -101,7 +103,7 @@ def delete_backup(backup_name):
101
103
 
102
104
  @cli.command("clear")
103
105
  @click.confirmation_option(prompt="Are you sure you want to delete all backups?")
104
- def clear_backups():
106
+ def clear_backups() -> None:
105
107
  backups_handler = DatabaseBackups()
106
108
  backups = backups_handler.find_backups()
107
109
  for backup in backups: