plain.models 0.49.2__py3-none-any.whl → 0.51.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108) hide show
  1. plain/models/CHANGELOG.md +27 -0
  2. plain/models/README.md +26 -42
  3. plain/models/__init__.py +2 -0
  4. plain/models/aggregates.py +42 -19
  5. plain/models/backends/base/base.py +125 -105
  6. plain/models/backends/base/client.py +11 -3
  7. plain/models/backends/base/creation.py +24 -14
  8. plain/models/backends/base/features.py +10 -4
  9. plain/models/backends/base/introspection.py +37 -20
  10. plain/models/backends/base/operations.py +187 -91
  11. plain/models/backends/base/schema.py +338 -218
  12. plain/models/backends/base/validation.py +13 -4
  13. plain/models/backends/ddl_references.py +85 -43
  14. plain/models/backends/mysql/base.py +29 -26
  15. plain/models/backends/mysql/client.py +7 -2
  16. plain/models/backends/mysql/compiler.py +13 -4
  17. plain/models/backends/mysql/creation.py +5 -2
  18. plain/models/backends/mysql/features.py +24 -22
  19. plain/models/backends/mysql/introspection.py +22 -13
  20. plain/models/backends/mysql/operations.py +107 -40
  21. plain/models/backends/mysql/schema.py +52 -28
  22. plain/models/backends/mysql/validation.py +13 -6
  23. plain/models/backends/postgresql/base.py +41 -34
  24. plain/models/backends/postgresql/client.py +7 -2
  25. plain/models/backends/postgresql/creation.py +10 -5
  26. plain/models/backends/postgresql/introspection.py +15 -8
  27. plain/models/backends/postgresql/operations.py +110 -43
  28. plain/models/backends/postgresql/schema.py +88 -49
  29. plain/models/backends/sqlite3/_functions.py +151 -115
  30. plain/models/backends/sqlite3/base.py +37 -23
  31. plain/models/backends/sqlite3/client.py +7 -1
  32. plain/models/backends/sqlite3/creation.py +9 -5
  33. plain/models/backends/sqlite3/features.py +5 -3
  34. plain/models/backends/sqlite3/introspection.py +32 -16
  35. plain/models/backends/sqlite3/operations.py +126 -43
  36. plain/models/backends/sqlite3/schema.py +127 -92
  37. plain/models/backends/utils.py +52 -29
  38. plain/models/backups/cli.py +8 -6
  39. plain/models/backups/clients.py +16 -7
  40. plain/models/backups/core.py +24 -13
  41. plain/models/base.py +221 -229
  42. plain/models/cli.py +98 -67
  43. plain/models/config.py +1 -1
  44. plain/models/connections.py +23 -7
  45. plain/models/constraints.py +79 -56
  46. plain/models/database_url.py +1 -1
  47. plain/models/db.py +6 -2
  48. plain/models/deletion.py +80 -56
  49. plain/models/entrypoints.py +1 -1
  50. plain/models/enums.py +22 -11
  51. plain/models/exceptions.py +23 -8
  52. plain/models/expressions.py +441 -258
  53. plain/models/fields/__init__.py +272 -217
  54. plain/models/fields/json.py +123 -57
  55. plain/models/fields/mixins.py +12 -8
  56. plain/models/fields/related.py +324 -290
  57. plain/models/fields/related_descriptors.py +33 -24
  58. plain/models/fields/related_lookups.py +24 -12
  59. plain/models/fields/related_managers.py +102 -79
  60. plain/models/fields/reverse_related.py +66 -63
  61. plain/models/forms.py +101 -75
  62. plain/models/functions/comparison.py +71 -18
  63. plain/models/functions/datetime.py +79 -29
  64. plain/models/functions/math.py +43 -10
  65. plain/models/functions/mixins.py +24 -7
  66. plain/models/functions/text.py +104 -25
  67. plain/models/functions/window.py +12 -6
  68. plain/models/indexes.py +57 -32
  69. plain/models/lookups.py +228 -153
  70. plain/models/meta.py +505 -0
  71. plain/models/migrations/autodetector.py +86 -43
  72. plain/models/migrations/exceptions.py +7 -3
  73. plain/models/migrations/executor.py +33 -7
  74. plain/models/migrations/graph.py +79 -50
  75. plain/models/migrations/loader.py +45 -22
  76. plain/models/migrations/migration.py +23 -18
  77. plain/models/migrations/operations/base.py +38 -20
  78. plain/models/migrations/operations/fields.py +95 -48
  79. plain/models/migrations/operations/models.py +246 -142
  80. plain/models/migrations/operations/special.py +82 -25
  81. plain/models/migrations/optimizer.py +7 -2
  82. plain/models/migrations/questioner.py +58 -31
  83. plain/models/migrations/recorder.py +27 -16
  84. plain/models/migrations/serializer.py +50 -39
  85. plain/models/migrations/state.py +232 -156
  86. plain/models/migrations/utils.py +30 -14
  87. plain/models/migrations/writer.py +17 -14
  88. plain/models/options.py +189 -518
  89. plain/models/otel.py +16 -6
  90. plain/models/preflight.py +42 -17
  91. plain/models/query.py +400 -251
  92. plain/models/query_utils.py +109 -69
  93. plain/models/registry.py +40 -21
  94. plain/models/sql/compiler.py +190 -127
  95. plain/models/sql/datastructures.py +38 -25
  96. plain/models/sql/query.py +320 -225
  97. plain/models/sql/subqueries.py +36 -25
  98. plain/models/sql/where.py +54 -29
  99. plain/models/test/pytest.py +15 -11
  100. plain/models/test/utils.py +4 -2
  101. plain/models/transaction.py +20 -7
  102. plain/models/utils.py +17 -6
  103. {plain_models-0.49.2.dist-info → plain_models-0.51.0.dist-info}/METADATA +27 -43
  104. plain_models-0.51.0.dist-info/RECORD +123 -0
  105. plain_models-0.49.2.dist-info/RECORD +0 -122
  106. {plain_models-0.49.2.dist-info → plain_models-0.51.0.dist-info}/WHEEL +0 -0
  107. {plain_models-0.49.2.dist-info → plain_models-0.51.0.dist-info}/entry_points.txt +0 -0
  108. {plain_models-0.49.2.dist-info → plain_models-0.51.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,9 +1,21 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, Any
4
+
1
5
  from plain.models.backends.base.schema import BaseDatabaseSchemaEditor
2
6
  from plain.models.constants import LOOKUP_SEP
3
7
  from plain.models.constraints import UniqueConstraint
4
8
  from plain.models.expressions import F
5
9
  from plain.models.fields import NOT_PROVIDED
6
10
 
11
+ if TYPE_CHECKING:
12
+ from collections.abc import Sequence
13
+
14
+ from plain.models.base import Model
15
+ from plain.models.constraints import BaseConstraint
16
+ from plain.models.fields import Field
17
+ from plain.models.indexes import Index
18
+
7
19
 
8
20
  class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
9
21
  sql_rename_table = "RENAME TABLE %(old_table)s TO %(new_table)s"
@@ -37,7 +49,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
37
49
  sql_alter_column_comment = None
38
50
 
39
51
  @property
40
- def sql_delete_check(self):
52
+ def sql_delete_check(self) -> str:
41
53
  if self.connection.mysql_is_mariadb:
42
54
  # The name of the column check constraint is the same as the field
43
55
  # name on MariaDB. Adding IF EXISTS clause prevents migrations
@@ -46,7 +58,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
46
58
  return "ALTER TABLE %(table)s DROP CHECK %(name)s"
47
59
 
48
60
  @property
49
- def sql_rename_column(self):
61
+ def sql_rename_column(self) -> str:
50
62
  # MariaDB >= 10.5.2 and MySQL >= 8.0.4 support an
51
63
  # "ALTER TABLE ... RENAME COLUMN" statement.
52
64
  if self.connection.mysql_is_mariadb:
@@ -56,7 +68,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
56
68
  return super().sql_rename_column
57
69
  return "ALTER TABLE %(table)s CHANGE %(old_column)s %(new_column)s %(type)s"
58
70
 
59
- def quote_value(self, value):
71
+ def quote_value(self, value: Any) -> str:
60
72
  self.connection.ensure_connection()
61
73
  if isinstance(value, str):
62
74
  value = value.replace("%", "%%")
@@ -68,19 +80,19 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
68
80
  quoted = quoted.decode()
69
81
  return quoted
70
82
 
71
- def _is_limited_data_type(self, field):
83
+ def _is_limited_data_type(self, field: Field) -> bool:
72
84
  db_type = field.db_type(self.connection)
73
85
  return (
74
86
  db_type is not None
75
87
  and db_type.lower() in self.connection._limited_data_types
76
88
  )
77
89
 
78
- def skip_default(self, field):
90
+ def skip_default(self, field: Field) -> bool:
79
91
  if not self._supports_limited_data_type_defaults:
80
92
  return self._is_limited_data_type(field)
81
93
  return False
82
94
 
83
- def skip_default_on_alter(self, field):
95
+ def skip_default_on_alter(self, field: Field) -> bool:
84
96
  if self._is_limited_data_type(field) and not self.connection.mysql_is_mariadb:
85
97
  # MySQL doesn't support defaults for BLOB and TEXT in the
86
98
  # ALTER COLUMN statement.
@@ -88,13 +100,13 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
88
100
  return False
89
101
 
90
102
  @property
91
- def _supports_limited_data_type_defaults(self):
103
+ def _supports_limited_data_type_defaults(self) -> bool:
92
104
  # MariaDB and MySQL >= 8.0.13 support defaults for BLOB and TEXT.
93
105
  if self.connection.mysql_is_mariadb:
94
106
  return True
95
107
  return self.connection.mysql_version >= (8, 0, 13)
96
108
 
97
- def _column_default_sql(self, field):
109
+ def _column_default_sql(self, field: Field) -> str:
98
110
  if (
99
111
  not self.connection.mysql_is_mariadb
100
112
  and self._supports_limited_data_type_defaults
@@ -105,7 +117,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
105
117
  return "(%s)"
106
118
  return super()._column_default_sql(field)
107
119
 
108
- def add_field(self, model, field):
120
+ def add_field(self, model: type[Model], field: Field) -> None:
109
121
  super().add_field(model, field)
110
122
 
111
123
  # Simulate the effect of a one-off default.
@@ -113,11 +125,11 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
113
125
  if self.skip_default(field) and field.default not in (None, NOT_PROVIDED):
114
126
  effective_default = self.effective_default(field)
115
127
  self.execute(
116
- f"UPDATE {self.quote_name(model._meta.db_table)} SET {self.quote_name(field.column)} = %s",
128
+ f"UPDATE {self.quote_name(model.model_options.db_table)} SET {self.quote_name(field.column)} = %s",
117
129
  [effective_default],
118
130
  )
119
131
 
120
- def remove_constraint(self, model, constraint):
132
+ def remove_constraint(self, model: type[Model], constraint: BaseConstraint) -> None:
121
133
  if (
122
134
  isinstance(constraint, UniqueConstraint)
123
135
  and constraint.create_sql(model, self) is not None
@@ -129,7 +141,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
129
141
  )
130
142
  super().remove_constraint(model, constraint)
131
143
 
132
- def remove_index(self, model, index):
144
+ def remove_index(self, model: type[Model], index: Index) -> None:
133
145
  self._create_missing_fk_index(
134
146
  model,
135
147
  fields=[field_name for field_name, _ in index.fields_orders],
@@ -137,12 +149,12 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
137
149
  )
138
150
  super().remove_index(model, index)
139
151
 
140
- def _field_should_be_indexed(self, model, field):
152
+ def _field_should_be_indexed(self, model: type[Model], field: Field) -> bool:
141
153
  if not super()._field_should_be_indexed(model, field):
142
154
  return False
143
155
 
144
156
  storage = self.connection.introspection.get_storage_engine(
145
- self.connection.cursor(), model._meta.db_table
157
+ self.connection.cursor(), model.model_options.db_table
146
158
  )
147
159
  # No need to create an index for ForeignKey fields except if
148
160
  # db_constraint=False because the index from that constraint won't be
@@ -150,18 +162,18 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
150
162
  if (
151
163
  storage == "InnoDB"
152
164
  and field.get_internal_type() == "ForeignKey"
153
- and field.db_constraint
165
+ and field.db_constraint # type: ignore[attr-defined]
154
166
  ):
155
167
  return False
156
168
  return not self._is_limited_data_type(field)
157
169
 
158
170
  def _create_missing_fk_index(
159
171
  self,
160
- model,
172
+ model: type[Model],
161
173
  *,
162
- fields,
163
- expressions=None,
164
- ):
174
+ fields: Sequence[str],
175
+ expressions: Sequence[Any] | None = None,
176
+ ) -> None:
165
177
  """
166
178
  MySQL can remove an implicit FK index on a field when that field is
167
179
  covered by another index. "covered" here means
@@ -185,7 +197,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
185
197
  if not first_field_name:
186
198
  return
187
199
 
188
- first_field = model._meta.get_field(first_field_name)
200
+ first_field = model._model_meta.get_field(first_field_name)
189
201
  if first_field.get_internal_type() == "ForeignKey":
190
202
  column = self.connection.introspection.identifier_converter(
191
203
  first_field.column
@@ -194,7 +206,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
194
206
  constraint_names = [
195
207
  name
196
208
  for name, infodict in self.connection.introspection.get_constraints(
197
- cursor, model._meta.db_table
209
+ cursor, model.model_options.db_table
198
210
  ).items()
199
211
  if infodict["index"] and infodict["columns"][0] == column
200
212
  ]
@@ -205,7 +217,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
205
217
  self._create_index_sql(model, fields=[first_field], suffix="")
206
218
  )
207
219
 
208
- def _set_field_new_type_null_status(self, field, new_type):
220
+ def _set_field_new_type_null_status(self, field: Field, new_type: str) -> str:
209
221
  """
210
222
  Keep the null property of the old field. If it has changed, it will be
211
223
  handled separately.
@@ -217,14 +229,22 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
217
229
  return new_type
218
230
 
219
231
  def _alter_column_type_sql(
220
- self, model, old_field, new_field, new_type, old_collation, new_collation
221
- ):
232
+ self,
233
+ model: type[Model],
234
+ old_field: Field,
235
+ new_field: Field,
236
+ new_type: str,
237
+ old_collation: str,
238
+ new_collation: str,
239
+ ) -> tuple[str, list[Any]]:
222
240
  new_type = self._set_field_new_type_null_status(old_field, new_type)
223
241
  return super()._alter_column_type_sql(
224
242
  model, old_field, new_field, new_type, old_collation, new_collation
225
243
  )
226
244
 
227
- def _field_db_check(self, field, field_db_params):
245
+ def _field_db_check(
246
+ self, field: Field, field_db_params: dict[str, Any]
247
+ ) -> str | None:
228
248
  if self.connection.mysql_is_mariadb and self.connection.mysql_version >= (
229
249
  10,
230
250
  5,
@@ -237,14 +257,18 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
237
257
  # renamed.
238
258
  return field_db_params["check"]
239
259
 
240
- def _rename_field_sql(self, table, old_field, new_field, new_type):
260
+ def _rename_field_sql(
261
+ self, table: str, old_field: Field, new_field: Field, new_type: str
262
+ ) -> str:
241
263
  new_type = self._set_field_new_type_null_status(old_field, new_type)
242
264
  return super()._rename_field_sql(table, old_field, new_field, new_type)
243
265
 
244
- def _alter_column_comment_sql(self, model, new_field, new_type, new_db_comment):
266
+ def _alter_column_comment_sql(
267
+ self, model: type[Model], new_field: Field, new_type: str, new_db_comment: str
268
+ ) -> tuple[str, list[Any]]:
245
269
  # Comment is alter when altering the column type.
246
270
  return "", []
247
271
 
248
- def _comment_sql(self, comment):
272
+ def _comment_sql(self, comment: str | None) -> str:
249
273
  comment_sql = super()._comment_sql(comment)
250
274
  return f" COMMENT {comment_sql}"
@@ -1,19 +1,26 @@
1
- from plain import preflight
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING
4
+
2
5
  from plain.models.backends.base.validation import BaseDatabaseValidation
6
+ from plain.preflight import PreflightResult
7
+
8
+ if TYPE_CHECKING:
9
+ from plain.models.fields import Field
3
10
 
4
11
 
5
12
  class DatabaseValidation(BaseDatabaseValidation):
6
- def preflight(self):
13
+ def preflight(self) -> list[PreflightResult]:
7
14
  issues = super().preflight()
8
15
  issues.extend(self._check_sql_mode())
9
16
  return issues
10
17
 
11
- def _check_sql_mode(self):
18
+ def _check_sql_mode(self) -> list[PreflightResult]:
12
19
  if not (
13
20
  self.connection.sql_mode & {"STRICT_TRANS_TABLES", "STRICT_ALL_TABLES"}
14
21
  ):
15
22
  return [
16
- preflight.PreflightResult(
23
+ PreflightResult(
17
24
  fix=f"{self.connection.display_name} Strict Mode is not set for the database connection. "
18
25
  f"{self.connection.display_name}'s Strict Mode fixes many data integrity problems in "
19
26
  f"{self.connection.display_name}, such as data truncation upon insertion, by "
@@ -25,7 +32,7 @@ class DatabaseValidation(BaseDatabaseValidation):
25
32
  ]
26
33
  return []
27
34
 
28
- def check_field_type(self, field, field_type):
35
+ def check_field_type(self, field: Field, field_type: str) -> list[PreflightResult]:
29
36
  """
30
37
  MySQL has the following field length restriction:
31
38
  No character (varchar) fields can have a length exceeding 255
@@ -39,7 +46,7 @@ class DatabaseValidation(BaseDatabaseValidation):
39
46
  and (field.max_length is None or int(field.max_length) > 255)
40
47
  ):
41
48
  errors.append(
42
- preflight.PreflightResult(
49
+ PreflightResult(
43
50
  fix=f"{self.connection.display_name} may not allow unique CharFields to have a max_length "
44
51
  "> 255.",
45
52
  obj=field,
@@ -1,15 +1,20 @@
1
+ from __future__ import annotations
2
+
3
+ import datetime
1
4
  import threading
2
5
  import warnings
6
+ from collections.abc import Generator
3
7
  from contextlib import contextmanager
4
8
  from functools import cached_property, lru_cache
9
+ from typing import Any
5
10
 
6
- import psycopg as Database
7
- from psycopg import IsolationLevel, adapt, adapters, sql
8
- from psycopg.postgres import types as pg_types
9
- from psycopg.pq import Format
10
- from psycopg.types.datetime import TimestamptzLoader
11
- from psycopg.types.range import Range, RangeDumper
12
- from psycopg.types.string import TextLoader
11
+ import psycopg as Database # type: ignore[import-untyped]
12
+ from psycopg import IsolationLevel, adapt, adapters, sql # type: ignore[import-untyped]
13
+ from psycopg.postgres import types as pg_types # type: ignore[import-untyped]
14
+ from psycopg.pq import Format # type: ignore[import-untyped]
15
+ from psycopg.types.datetime import TimestamptzLoader # type: ignore[import-untyped]
16
+ from psycopg.types.range import Range, RangeDumper # type: ignore[import-untyped]
17
+ from psycopg.types.string import TextLoader # type: ignore[import-untyped]
13
18
 
14
19
  from plain.exceptions import ImproperlyConfigured
15
20
  from plain.models.backends.base.base import BaseDatabaseWrapper
@@ -38,14 +43,14 @@ class BaseTzLoader(TimestamptzLoader):
38
43
  The timezone can be None too, in which case it will be chopped.
39
44
  """
40
45
 
41
- timezone = None
46
+ timezone: datetime.tzinfo | None = None
42
47
 
43
- def load(self, data):
48
+ def load(self, data: bytes) -> datetime.datetime:
44
49
  res = super().load(data)
45
50
  return res.replace(tzinfo=self.timezone)
46
51
 
47
52
 
48
- def register_tzloader(tz, context):
53
+ def register_tzloader(tz: datetime.tzinfo | None, context: Any) -> None:
49
54
  class SpecificTzLoader(BaseTzLoader):
50
55
  timezone = tz
51
56
 
@@ -55,7 +60,7 @@ def register_tzloader(tz, context):
55
60
  class PlainRangeDumper(RangeDumper):
56
61
  """A Range dumper customized for Plain."""
57
62
 
58
- def upgrade(self, obj, format):
63
+ def upgrade(self, obj: Any, format: Format) -> RangeDumper:
59
64
  dumper = super().upgrade(obj, format)
60
65
  if dumper is not self and dumper.oid == TSRANGE_OID:
61
66
  dumper.oid = TSTZRANGE_OID
@@ -63,7 +68,7 @@ class PlainRangeDumper(RangeDumper):
63
68
 
64
69
 
65
70
  @lru_cache
66
- def get_adapters_template(timezone):
71
+ def get_adapters_template(timezone: datetime.tzinfo | None) -> adapters.AdaptersMap:
67
72
  ctx = adapt.AdaptersMap(adapters)
68
73
  # No-op JSON loader to avoid psycopg3 round trips
69
74
  ctx.register_loader("jsonb", TextLoader)
@@ -75,13 +80,13 @@ def get_adapters_template(timezone):
75
80
  return ctx
76
81
 
77
82
 
78
- def _get_varchar_column(data):
83
+ def _get_varchar_column(data: dict[str, Any]) -> str:
79
84
  if data["max_length"] is None:
80
85
  return "varchar"
81
86
  return "varchar({max_length})".format(**data)
82
87
 
83
88
 
84
- class DatabaseWrapper(BaseDatabaseWrapper):
89
+ class PostgreSQLDatabaseWrapper(BaseDatabaseWrapper):
85
90
  vendor = "postgresql"
86
91
  display_name = "PostgreSQL"
87
92
  # This dictionary maps Field objects to their associated PostgreSQL column
@@ -166,14 +171,14 @@ class DatabaseWrapper(BaseDatabaseWrapper):
166
171
  # PostgreSQL backend-specific attributes.
167
172
  _named_cursor_idx = 0
168
173
 
169
- def get_database_version(self):
174
+ def get_database_version(self) -> tuple[int, ...]:
170
175
  """
171
176
  Return a tuple of the database's version.
172
177
  E.g. for pg_version 120004, return (12, 4).
173
178
  """
174
179
  return divmod(self.pg_version, 10000)
175
180
 
176
- def get_connection_params(self):
181
+ def get_connection_params(self) -> dict[str, Any]:
177
182
  settings_dict = self.settings_dict
178
183
  # None may be used to connect to the default 'postgres' db
179
184
  if settings_dict["NAME"] == "" and not settings_dict.get("OPTIONS", {}).get(
@@ -194,7 +199,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
194
199
  self.ops.max_name_length(),
195
200
  )
196
201
  )
197
- conn_params = {"client_encoding": "UTF8"}
202
+ conn_params: dict[str, Any] = {"client_encoding": "UTF8"}
198
203
  if settings_dict["NAME"]:
199
204
  conn_params = {
200
205
  "dbname": settings_dict["NAME"],
@@ -224,7 +229,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
224
229
  conn_params["prepare_threshold"] = conn_params.pop("prepare_threshold", None)
225
230
  return conn_params
226
231
 
227
- def get_new_connection(self, conn_params):
232
+ def get_new_connection(self, conn_params: dict[str, Any]) -> Any:
228
233
  # self.isolation_level must be set:
229
234
  # - after connecting to the database in order to obtain the database's
230
235
  # default when no value is explicitly specified in options.
@@ -257,7 +262,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
257
262
  )
258
263
  return connection
259
264
 
260
- def ensure_timezone(self):
265
+ def ensure_timezone(self) -> bool:
261
266
  if self.connection is None:
262
267
  return False
263
268
  conn_timezone_name = self.connection.info.parameter_status("TimeZone")
@@ -268,7 +273,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
268
273
  return True
269
274
  return False
270
275
 
271
- def ensure_role(self):
276
+ def ensure_role(self) -> bool:
272
277
  if self.connection is None:
273
278
  return False
274
279
  if new_role := self.settings_dict.get("OPTIONS", {}).get("assume_role"):
@@ -278,7 +283,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
278
283
  return True
279
284
  return False
280
285
 
281
- def init_connection_state(self):
286
+ def init_connection_state(self) -> None:
282
287
  super().init_connection_state()
283
288
 
284
289
  # Commit after setting the time zone.
@@ -291,7 +296,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
291
296
  if (commit_role or commit_tz) and not self.get_autocommit():
292
297
  self.connection.commit()
293
298
 
294
- def create_cursor(self, name=None):
299
+ def create_cursor(self, name: str | None = None) -> Any:
295
300
  if name:
296
301
  # In autocommit mode, the cursor will be used outside of a
297
302
  # transaction, hence use a holdable cursor.
@@ -307,10 +312,10 @@ class DatabaseWrapper(BaseDatabaseWrapper):
307
312
  register_tzloader(self.timezone, cursor)
308
313
  return cursor
309
314
 
310
- def tzinfo_factory(self, offset):
315
+ def tzinfo_factory(self, offset: int) -> datetime.tzinfo | None:
311
316
  return self.timezone
312
317
 
313
- def chunked_cursor(self):
318
+ def chunked_cursor(self) -> Any:
314
319
  self._named_cursor_idx += 1
315
320
  # Get the current async task
316
321
  # Note that right now this is behind @async_unsafe, so this is
@@ -329,11 +334,11 @@ class DatabaseWrapper(BaseDatabaseWrapper):
329
334
  )
330
335
  )
331
336
 
332
- def _set_autocommit(self, autocommit):
337
+ def _set_autocommit(self, autocommit: bool) -> None:
333
338
  with self.wrap_database_errors:
334
339
  self.connection.autocommit = autocommit
335
340
 
336
- def check_constraints(self, table_names=None):
341
+ def check_constraints(self, table_names: list[str] | None = None) -> None:
337
342
  """
338
343
  Check constraints by setting them to immediate. Return them to deferred
339
344
  afterward.
@@ -342,7 +347,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
342
347
  cursor.execute("SET CONSTRAINTS ALL IMMEDIATE")
343
348
  cursor.execute("SET CONSTRAINTS ALL DEFERRED")
344
349
 
345
- def is_usable(self):
350
+ def is_usable(self) -> bool:
346
351
  try:
347
352
  # Use a psycopg cursor directly, bypassing Plain's utilities.
348
353
  with self.connection.cursor() as cursor:
@@ -353,7 +358,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
353
358
  return True
354
359
 
355
360
  @contextmanager
356
- def _nodb_cursor(self):
361
+ def _nodb_cursor(self) -> Generator[Any, None, None]:
357
362
  cursor = None
358
363
  try:
359
364
  with super()._nodb_cursor() as cursor:
@@ -382,11 +387,11 @@ class DatabaseWrapper(BaseDatabaseWrapper):
382
387
  conn.close()
383
388
 
384
389
  @cached_property
385
- def pg_version(self):
390
+ def pg_version(self) -> int:
386
391
  with self.temporary_connection():
387
392
  return self.connection.info.server_version
388
393
 
389
- def make_debug_cursor(self, cursor):
394
+ def make_debug_cursor(self, cursor: Any) -> CursorDebugWrapper:
390
395
  return CursorDebugWrapper(cursor, self)
391
396
 
392
397
 
@@ -395,11 +400,13 @@ class CursorMixin:
395
400
  A subclass of psycopg cursor implementing callproc.
396
401
  """
397
402
 
398
- def callproc(self, name, args=None):
403
+ def callproc(
404
+ self, name: str | sql.Identifier, args: list[Any] | None = None
405
+ ) -> list[Any] | None:
399
406
  if not isinstance(name, sql.Identifier):
400
407
  name = sql.Identifier(name)
401
408
 
402
- qparts = [sql.SQL("SELECT * FROM "), name, sql.SQL("(")]
409
+ qparts: list[sql.Composable] = [sql.SQL("SELECT * FROM "), name, sql.SQL("(")]
403
410
  if args:
404
411
  for item in args:
405
412
  qparts.append(sql.Literal(item))
@@ -408,7 +415,7 @@ class CursorMixin:
408
415
 
409
416
  qparts.append(sql.SQL(")"))
410
417
  stmt = sql.Composed(qparts)
411
- self.execute(stmt)
418
+ self.execute(stmt) # type: ignore[attr-defined]
412
419
  return args
413
420
 
414
421
 
@@ -421,6 +428,6 @@ class Cursor(CursorMixin, Database.ClientCursor):
421
428
 
422
429
 
423
430
  class CursorDebugWrapper(BaseCursorDebugWrapper):
424
- def copy(self, statement):
431
+ def copy(self, statement: Any) -> Any:
425
432
  with self.debug_sql(statement):
426
433
  return self.cursor.copy(statement)
@@ -1,4 +1,7 @@
1
+ from __future__ import annotations
2
+
1
3
  import signal
4
+ from typing import Any
2
5
 
3
6
  from plain.models.backends.base.client import BaseDatabaseClient
4
7
 
@@ -7,7 +10,9 @@ class DatabaseClient(BaseDatabaseClient):
7
10
  executable_name = "psql"
8
11
 
9
12
  @classmethod
10
- def settings_to_cmd_args_env(cls, settings_dict, parameters):
13
+ def settings_to_cmd_args_env(
14
+ cls, settings_dict: dict[str, Any], parameters: list[str]
15
+ ) -> tuple[list[str], dict[str, str] | None]:
11
16
  args = [cls.executable_name]
12
17
  options = settings_dict.get("OPTIONS", {})
13
18
 
@@ -53,7 +58,7 @@ class DatabaseClient(BaseDatabaseClient):
53
58
  env["PGPASSFILE"] = str(passfile)
54
59
  return args, (env or None)
55
60
 
56
- def runshell(self, parameters):
61
+ def runshell(self, parameters: list[str]) -> None:
57
62
  sigint_handler = signal.getsignal(signal.SIGINT)
58
63
  try:
59
64
  # Allow SIGINT to pass to psql to abort queries.
@@ -1,16 +1,21 @@
1
+ from __future__ import annotations
2
+
1
3
  import sys
4
+ from typing import Any
2
5
 
3
- from psycopg import errors
6
+ from psycopg import errors # type: ignore[import-untyped]
4
7
 
5
8
  from plain.exceptions import ImproperlyConfigured
6
9
  from plain.models.backends.base.creation import BaseDatabaseCreation
7
10
 
8
11
 
9
12
  class DatabaseCreation(BaseDatabaseCreation):
10
- def _quote_name(self, name):
13
+ def _quote_name(self, name: str) -> str:
11
14
  return self.connection.ops.quote_name(name)
12
15
 
13
- def _get_database_create_suffix(self, encoding=None, template=None):
16
+ def _get_database_create_suffix(
17
+ self, encoding: str | None = None, template: str | None = None
18
+ ) -> str:
14
19
  suffix = ""
15
20
  if encoding:
16
21
  suffix += f" ENCODING '{encoding}'"
@@ -18,7 +23,7 @@ class DatabaseCreation(BaseDatabaseCreation):
18
23
  suffix += f" TEMPLATE {self._quote_name(template)}"
19
24
  return suffix and "WITH" + suffix
20
25
 
21
- def sql_table_creation_suffix(self):
26
+ def sql_table_creation_suffix(self) -> str:
22
27
  test_settings = self.connection.settings_dict["TEST"]
23
28
  if test_settings.get("COLLATION") is not None:
24
29
  raise ImproperlyConfigured(
@@ -30,7 +35,7 @@ class DatabaseCreation(BaseDatabaseCreation):
30
35
  template=test_settings.get("TEMPLATE"),
31
36
  )
32
37
 
33
- def _execute_create_test_db(self, cursor, parameters):
38
+ def _execute_create_test_db(self, cursor: Any, parameters: dict[str, Any]) -> None:
34
39
  try:
35
40
  super()._execute_create_test_db(cursor, parameters)
36
41
  except Exception as e:
@@ -1,4 +1,7 @@
1
+ from __future__ import annotations
2
+
1
3
  from collections import namedtuple
4
+ from typing import Any
2
5
 
3
6
  from plain.models.backends.base.introspection import BaseDatabaseIntrospection
4
7
  from plain.models.backends.base.introspection import FieldInfo as BaseFieldInfo
@@ -36,9 +39,9 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
36
39
  # A hook for subclasses.
37
40
  index_default_access_method = "btree"
38
41
 
39
- ignored_tables = []
42
+ ignored_tables: list[str] = []
40
43
 
41
- def get_field_type(self, data_type, description):
44
+ def get_field_type(self, data_type: Any, description: Any) -> str:
42
45
  field_type = super().get_field_type(data_type, description)
43
46
  if description.is_autofield or (
44
47
  # Required for pre-Plain 4.1 serial columns.
@@ -48,7 +51,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
48
51
  return "PrimaryKeyField"
49
52
  return field_type
50
53
 
51
- def get_table_list(self, cursor):
54
+ def get_table_list(self, cursor: Any) -> list[TableInfo]:
52
55
  """Return a list of table and view names in the current database."""
53
56
  cursor.execute(
54
57
  """
@@ -73,7 +76,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
73
76
  if row[0] not in self.ignored_tables
74
77
  ]
75
78
 
76
- def get_table_description(self, cursor, table_name):
79
+ def get_table_description(self, cursor: Any, table_name: str) -> list[FieldInfo]:
77
80
  """
78
81
  Return a description of the table with the DB-API cursor.description
79
82
  interface.
@@ -120,7 +123,9 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
120
123
  for line in cursor.description
121
124
  ]
122
125
 
123
- def get_sequences(self, cursor, table_name, table_fields=()):
126
+ def get_sequences(
127
+ self, cursor: Any, table_name: str, table_fields: tuple[Any, ...] = ()
128
+ ) -> list[dict[str, Any]]:
124
129
  cursor.execute(
125
130
  """
126
131
  SELECT
@@ -146,7 +151,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
146
151
  for row in cursor.fetchall()
147
152
  ]
148
153
 
149
- def get_relations(self, cursor, table_name):
154
+ def get_relations(self, cursor: Any, table_name: str) -> dict[str, tuple[str, str]]:
150
155
  """
151
156
  Return a dictionary of {field_name: (field_name_other_table, other_table)}
152
157
  representing all foreign keys in the given table.
@@ -171,13 +176,15 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
171
176
  )
172
177
  return {row[0]: (row[2], row[1]) for row in cursor.fetchall()}
173
178
 
174
- def get_constraints(self, cursor, table_name):
179
+ def get_constraints(
180
+ self, cursor: Any, table_name: str
181
+ ) -> dict[str, dict[str, Any]]:
175
182
  """
176
183
  Retrieve any constraints or keys (unique, pk, fk, check, index) across
177
184
  one or more columns. Also retrieve the definition of expression-based
178
185
  indexes.
179
186
  """
180
- constraints = {}
187
+ constraints: dict[str, dict[str, Any]] = {}
181
188
  # Loop over the key table, collecting things as constraints. The column
182
189
  # array must return column names in the same order in which they were
183
190
  # created.