plain.models 0.49.2__py3-none-any.whl → 0.51.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108) hide show
  1. plain/models/CHANGELOG.md +27 -0
  2. plain/models/README.md +26 -42
  3. plain/models/__init__.py +2 -0
  4. plain/models/aggregates.py +42 -19
  5. plain/models/backends/base/base.py +125 -105
  6. plain/models/backends/base/client.py +11 -3
  7. plain/models/backends/base/creation.py +24 -14
  8. plain/models/backends/base/features.py +10 -4
  9. plain/models/backends/base/introspection.py +37 -20
  10. plain/models/backends/base/operations.py +187 -91
  11. plain/models/backends/base/schema.py +338 -218
  12. plain/models/backends/base/validation.py +13 -4
  13. plain/models/backends/ddl_references.py +85 -43
  14. plain/models/backends/mysql/base.py +29 -26
  15. plain/models/backends/mysql/client.py +7 -2
  16. plain/models/backends/mysql/compiler.py +13 -4
  17. plain/models/backends/mysql/creation.py +5 -2
  18. plain/models/backends/mysql/features.py +24 -22
  19. plain/models/backends/mysql/introspection.py +22 -13
  20. plain/models/backends/mysql/operations.py +107 -40
  21. plain/models/backends/mysql/schema.py +52 -28
  22. plain/models/backends/mysql/validation.py +13 -6
  23. plain/models/backends/postgresql/base.py +41 -34
  24. plain/models/backends/postgresql/client.py +7 -2
  25. plain/models/backends/postgresql/creation.py +10 -5
  26. plain/models/backends/postgresql/introspection.py +15 -8
  27. plain/models/backends/postgresql/operations.py +110 -43
  28. plain/models/backends/postgresql/schema.py +88 -49
  29. plain/models/backends/sqlite3/_functions.py +151 -115
  30. plain/models/backends/sqlite3/base.py +37 -23
  31. plain/models/backends/sqlite3/client.py +7 -1
  32. plain/models/backends/sqlite3/creation.py +9 -5
  33. plain/models/backends/sqlite3/features.py +5 -3
  34. plain/models/backends/sqlite3/introspection.py +32 -16
  35. plain/models/backends/sqlite3/operations.py +126 -43
  36. plain/models/backends/sqlite3/schema.py +127 -92
  37. plain/models/backends/utils.py +52 -29
  38. plain/models/backups/cli.py +8 -6
  39. plain/models/backups/clients.py +16 -7
  40. plain/models/backups/core.py +24 -13
  41. plain/models/base.py +221 -229
  42. plain/models/cli.py +98 -67
  43. plain/models/config.py +1 -1
  44. plain/models/connections.py +23 -7
  45. plain/models/constraints.py +79 -56
  46. plain/models/database_url.py +1 -1
  47. plain/models/db.py +6 -2
  48. plain/models/deletion.py +80 -56
  49. plain/models/entrypoints.py +1 -1
  50. plain/models/enums.py +22 -11
  51. plain/models/exceptions.py +23 -8
  52. plain/models/expressions.py +441 -258
  53. plain/models/fields/__init__.py +272 -217
  54. plain/models/fields/json.py +123 -57
  55. plain/models/fields/mixins.py +12 -8
  56. plain/models/fields/related.py +324 -290
  57. plain/models/fields/related_descriptors.py +33 -24
  58. plain/models/fields/related_lookups.py +24 -12
  59. plain/models/fields/related_managers.py +102 -79
  60. plain/models/fields/reverse_related.py +66 -63
  61. plain/models/forms.py +101 -75
  62. plain/models/functions/comparison.py +71 -18
  63. plain/models/functions/datetime.py +79 -29
  64. plain/models/functions/math.py +43 -10
  65. plain/models/functions/mixins.py +24 -7
  66. plain/models/functions/text.py +104 -25
  67. plain/models/functions/window.py +12 -6
  68. plain/models/indexes.py +57 -32
  69. plain/models/lookups.py +228 -153
  70. plain/models/meta.py +505 -0
  71. plain/models/migrations/autodetector.py +86 -43
  72. plain/models/migrations/exceptions.py +7 -3
  73. plain/models/migrations/executor.py +33 -7
  74. plain/models/migrations/graph.py +79 -50
  75. plain/models/migrations/loader.py +45 -22
  76. plain/models/migrations/migration.py +23 -18
  77. plain/models/migrations/operations/base.py +38 -20
  78. plain/models/migrations/operations/fields.py +95 -48
  79. plain/models/migrations/operations/models.py +246 -142
  80. plain/models/migrations/operations/special.py +82 -25
  81. plain/models/migrations/optimizer.py +7 -2
  82. plain/models/migrations/questioner.py +58 -31
  83. plain/models/migrations/recorder.py +27 -16
  84. plain/models/migrations/serializer.py +50 -39
  85. plain/models/migrations/state.py +232 -156
  86. plain/models/migrations/utils.py +30 -14
  87. plain/models/migrations/writer.py +17 -14
  88. plain/models/options.py +189 -518
  89. plain/models/otel.py +16 -6
  90. plain/models/preflight.py +42 -17
  91. plain/models/query.py +400 -251
  92. plain/models/query_utils.py +109 -69
  93. plain/models/registry.py +40 -21
  94. plain/models/sql/compiler.py +190 -127
  95. plain/models/sql/datastructures.py +38 -25
  96. plain/models/sql/query.py +320 -225
  97. plain/models/sql/subqueries.py +36 -25
  98. plain/models/sql/where.py +54 -29
  99. plain/models/test/pytest.py +15 -11
  100. plain/models/test/utils.py +4 -2
  101. plain/models/transaction.py +20 -7
  102. plain/models/utils.py +17 -6
  103. {plain_models-0.49.2.dist-info → plain_models-0.51.0.dist-info}/METADATA +27 -43
  104. plain_models-0.51.0.dist-info/RECORD +123 -0
  105. plain_models-0.49.2.dist-info/RECORD +0 -122
  106. {plain_models-0.49.2.dist-info → plain_models-0.51.0.dist-info}/WHEEL +0 -0
  107. {plain_models-0.49.2.dist-info → plain_models-0.51.0.dist-info}/entry_points.txt +0 -0
  108. {plain_models-0.49.2.dist-info → plain_models-0.51.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,4 +1,7 @@
1
+ from __future__ import annotations
2
+
1
3
  import json
4
+ from typing import TYPE_CHECKING, Any, cast
2
5
 
3
6
  from plain import exceptions, preflight
4
7
  from plain.models import expressions, lookups
@@ -14,6 +17,13 @@ from plain.models.lookups import (
14
17
  from . import Field
15
18
  from .mixins import CheckFieldDefaultMixin
16
19
 
20
+ if TYPE_CHECKING:
21
+ from plain.models.backends.base.base import BaseDatabaseWrapper
22
+ from plain.models.backends.mysql.base import MySQLDatabaseWrapper
23
+ from plain.models.backends.sqlite3.base import SQLiteDatabaseWrapper
24
+ from plain.models.sql.compiler import SQLCompiler
25
+ from plain.preflight.results import PreflightResult
26
+
17
27
  __all__ = ["JSONField"]
18
28
 
19
29
 
@@ -28,9 +38,9 @@ class JSONField(CheckFieldDefaultMixin, Field):
28
38
  def __init__(
29
39
  self,
30
40
  *,
31
- encoder=None,
32
- decoder=None,
33
- **kwargs,
41
+ encoder: type[json.JSONEncoder] | None = None,
42
+ decoder: type[json.JSONDecoder] | None = None,
43
+ **kwargs: Any,
34
44
  ):
35
45
  if encoder and not callable(encoder):
36
46
  raise ValueError("The encoder parameter must be a callable object.")
@@ -40,22 +50,22 @@ class JSONField(CheckFieldDefaultMixin, Field):
40
50
  self.decoder = decoder
41
51
  super().__init__(**kwargs)
42
52
 
43
- def preflight(self, **kwargs):
53
+ def preflight(self, **kwargs: Any) -> list[PreflightResult]:
44
54
  errors = super().preflight(**kwargs)
45
55
  errors.extend(self._check_supported())
46
56
  return errors
47
57
 
48
- def _check_supported(self):
58
+ def _check_supported(self) -> list[PreflightResult]:
49
59
  errors = []
50
60
 
51
61
  if (
52
- self.model._meta.required_db_vendor
53
- and self.model._meta.required_db_vendor != db_connection.vendor
62
+ self.model.model_options.required_db_vendor
63
+ and self.model.model_options.required_db_vendor != db_connection.vendor
54
64
  ):
55
65
  return errors
56
66
 
57
67
  if not (
58
- "supports_json_field" in self.model._meta.required_db_features
68
+ "supports_json_field" in self.model.model_options.required_db_features
59
69
  or db_connection.features.supports_json_field
60
70
  ):
61
71
  errors.append(
@@ -67,7 +77,7 @@ class JSONField(CheckFieldDefaultMixin, Field):
67
77
  )
68
78
  return errors
69
79
 
70
- def deconstruct(self):
80
+ def deconstruct(self) -> tuple[str, str, list[Any], dict[str, Any]]:
71
81
  name, path, args, kwargs = super().deconstruct()
72
82
  if self.encoder is not None:
73
83
  kwargs["encoder"] = self.encoder
@@ -75,7 +85,9 @@ class JSONField(CheckFieldDefaultMixin, Field):
75
85
  kwargs["decoder"] = self.decoder
76
86
  return name, path, args, kwargs
77
87
 
78
- def from_db_value(self, value, expression, connection):
88
+ def from_db_value(
89
+ self, value: Any, expression: Any, connection: BaseDatabaseWrapper
90
+ ) -> Any:
79
91
  if value is None:
80
92
  return value
81
93
  # Some backends (SQLite at least) extract non-string values in their
@@ -87,10 +99,12 @@ class JSONField(CheckFieldDefaultMixin, Field):
87
99
  except json.JSONDecodeError:
88
100
  return value
89
101
 
90
- def get_internal_type(self):
102
+ def get_internal_type(self) -> str:
91
103
  return "JSONField"
92
104
 
93
- def get_db_prep_value(self, value, connection, prepared=False):
105
+ def get_db_prep_value(
106
+ self, value: Any, connection: BaseDatabaseWrapper, prepared: bool = False
107
+ ) -> Any:
94
108
  if isinstance(value, expressions.Value) and isinstance(
95
109
  value.output_field, JSONField
96
110
  ):
@@ -99,18 +113,18 @@ class JSONField(CheckFieldDefaultMixin, Field):
99
113
  return value
100
114
  return connection.ops.adapt_json_value(value, self.encoder)
101
115
 
102
- def get_db_prep_save(self, value, connection):
116
+ def get_db_prep_save(self, value: Any, connection: BaseDatabaseWrapper) -> Any:
103
117
  if value is None:
104
118
  return value
105
119
  return self.get_db_prep_value(value, connection)
106
120
 
107
- def get_transform(self, name):
121
+ def get_transform(self, name: str) -> KeyTransformFactory | type[Transform]:
108
122
  transform = super().get_transform(name)
109
123
  if transform:
110
124
  return transform
111
125
  return KeyTransformFactory(name)
112
126
 
113
- def validate(self, value, model_instance):
127
+ def validate(self, value: Any, model_instance: Any) -> None:
114
128
  super().validate(value, model_instance)
115
129
  try:
116
130
  json.dumps(value, cls=self.encoder)
@@ -121,11 +135,11 @@ class JSONField(CheckFieldDefaultMixin, Field):
121
135
  params={"value": value},
122
136
  )
123
137
 
124
- def value_to_string(self, obj):
138
+ def value_to_string(self, obj: Any) -> Any:
125
139
  return self.value_from_object(obj)
126
140
 
127
141
 
128
- def compile_json_path(key_transforms, include_root=True):
142
+ def compile_json_path(key_transforms: list[Any], include_root: bool = True) -> str:
129
143
  path = ["$"] if include_root else []
130
144
  for key_transform in key_transforms:
131
145
  try:
@@ -142,7 +156,9 @@ class DataContains(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
142
156
  lookup_name = "contains"
143
157
  postgres_operator = "@>"
144
158
 
145
- def as_sql(self, compiler, connection):
159
+ def as_sql(
160
+ self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
161
+ ) -> tuple[str, tuple[Any, ...]]:
146
162
  if not connection.features.supports_json_field_contains:
147
163
  raise NotSupportedError(
148
164
  "contains lookup is not supported on this database backend."
@@ -157,7 +173,9 @@ class ContainedBy(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
157
173
  lookup_name = "contained_by"
158
174
  postgres_operator = "<@"
159
175
 
160
- def as_sql(self, compiler, connection):
176
+ def as_sql(
177
+ self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
178
+ ) -> tuple[str, tuple[Any, ...]]:
161
179
  if not connection.features.supports_json_field_contains:
162
180
  raise NotSupportedError(
163
181
  "contained_by lookup is not supported on this database backend."
@@ -169,13 +187,18 @@ class ContainedBy(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
169
187
 
170
188
 
171
189
  class HasKeyLookup(PostgresOperatorLookup):
172
- logical_operator = None
190
+ logical_operator: str | None = None
173
191
 
174
- def compile_json_path_final_key(self, key_transform):
192
+ def compile_json_path_final_key(self, key_transform: Any) -> str:
175
193
  # Compile the final key without interpreting ints as array elements.
176
194
  return f".{json.dumps(key_transform)}"
177
195
 
178
- def as_sql(self, compiler, connection, template=None):
196
+ def as_sql(
197
+ self,
198
+ compiler: SQLCompiler,
199
+ connection: BaseDatabaseWrapper,
200
+ template: str | None = None,
201
+ ) -> tuple[str, tuple[Any, ...]]:
179
202
  # Process JSON path from the left-hand side.
180
203
  if isinstance(self.lhs, KeyTransform):
181
204
  lhs, lhs_params, lhs_key_transforms = self.lhs.preprocess_lhs(
@@ -205,12 +228,16 @@ class HasKeyLookup(PostgresOperatorLookup):
205
228
  sql = f"({self.logical_operator.join([sql] * len(rhs_params))})"
206
229
  return sql, tuple(lhs_params) + tuple(rhs_params)
207
230
 
208
- def as_mysql(self, compiler, connection):
231
+ def as_mysql(
232
+ self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
233
+ ) -> tuple[str, tuple[Any, ...]]:
209
234
  return self.as_sql(
210
235
  compiler, connection, template="JSON_CONTAINS_PATH(%s, 'one', %%s)"
211
236
  )
212
237
 
213
- def as_postgresql(self, compiler, connection):
238
+ def as_postgresql(
239
+ self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
240
+ ) -> tuple[str, tuple[Any, ...]]:
214
241
  if isinstance(self.rhs, KeyTransform):
215
242
  *_, rhs_key_transforms = self.rhs.preprocess_lhs(compiler, connection)
216
243
  for key in rhs_key_transforms[:-1]:
@@ -218,7 +245,9 @@ class HasKeyLookup(PostgresOperatorLookup):
218
245
  self.rhs = rhs_key_transforms[-1]
219
246
  return super().as_postgresql(compiler, connection)
220
247
 
221
- def as_sqlite(self, compiler, connection):
248
+ def as_sqlite(
249
+ self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
250
+ ) -> tuple[str, tuple[Any, ...]]:
222
251
  return self.as_sql(
223
252
  compiler, connection, template="JSON_TYPE(%s, %%s) IS NOT NULL"
224
253
  )
@@ -235,7 +264,7 @@ class HasKeys(HasKeyLookup):
235
264
  postgres_operator = "?&"
236
265
  logical_operator = " AND "
237
266
 
238
- def get_prep_lookup(self):
267
+ def get_prep_lookup(self) -> list[str]:
239
268
  return [str(item) for item in self.rhs]
240
269
 
241
270
 
@@ -246,7 +275,7 @@ class HasAnyKeys(HasKeys):
246
275
 
247
276
 
248
277
  class HasKeyOrArrayIndex(HasKey):
249
- def compile_json_path_final_key(self, key_transform):
278
+ def compile_json_path_final_key(self, key_transform: Any) -> str:
250
279
  return compile_json_path([key_transform], include_root=False)
251
280
 
252
281
 
@@ -258,14 +287,18 @@ class CaseInsensitiveMixin:
258
287
  case-sensitive.
259
288
  """
260
289
 
261
- def process_lhs(self, compiler, connection):
262
- lhs, lhs_params = super().process_lhs(compiler, connection)
290
+ def process_lhs(
291
+ self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
292
+ ) -> tuple[str, list[Any]]:
293
+ lhs, lhs_params = super().process_lhs(compiler, connection) # type: ignore[misc]
263
294
  if connection.vendor == "mysql":
264
295
  return f"LOWER({lhs})", lhs_params
265
296
  return lhs, lhs_params
266
297
 
267
- def process_rhs(self, compiler, connection):
268
- rhs, rhs_params = super().process_rhs(compiler, connection)
298
+ def process_rhs(
299
+ self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
300
+ ) -> tuple[str, list[Any]]:
301
+ rhs, rhs_params = super().process_rhs(compiler, connection) # type: ignore[misc]
269
302
  if connection.vendor == "mysql":
270
303
  return f"LOWER({rhs})", rhs_params
271
304
  return rhs, rhs_params
@@ -274,7 +307,9 @@ class CaseInsensitiveMixin:
274
307
  class JSONExact(lookups.Exact):
275
308
  can_use_none_as_rhs = True
276
309
 
277
- def process_rhs(self, compiler, connection):
310
+ def process_rhs(
311
+ self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
312
+ ) -> tuple[str, list[Any]]:
278
313
  rhs, rhs_params = super().process_rhs(compiler, connection)
279
314
  # Treat None lookup values as null.
280
315
  if rhs == "%s" and rhs_params == [None]:
@@ -302,11 +337,13 @@ class KeyTransform(Transform):
302
337
  postgres_operator = "->"
303
338
  postgres_nested_operator = "#>"
304
339
 
305
- def __init__(self, key_name, *args, **kwargs):
340
+ def __init__(self, key_name: str, *args: Any, **kwargs: Any):
306
341
  super().__init__(*args, **kwargs)
307
342
  self.key_name = str(key_name)
308
343
 
309
- def preprocess_lhs(self, compiler, connection):
344
+ def preprocess_lhs(
345
+ self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
346
+ ) -> tuple[str, tuple[Any, ...], list[str]]:
310
347
  key_transforms = [self.key_name]
311
348
  previous = self.lhs
312
349
  while isinstance(previous, KeyTransform):
@@ -315,12 +352,16 @@ class KeyTransform(Transform):
315
352
  lhs, params = compiler.compile(previous)
316
353
  return lhs, params, key_transforms
317
354
 
318
- def as_mysql(self, compiler, connection):
355
+ def as_mysql(
356
+ self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
357
+ ) -> tuple[str, tuple[Any, ...]]:
319
358
  lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
320
359
  json_path = compile_json_path(key_transforms)
321
360
  return f"JSON_EXTRACT({lhs}, %s)", tuple(params) + (json_path,)
322
361
 
323
- def as_postgresql(self, compiler, connection):
362
+ def as_postgresql(
363
+ self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
364
+ ) -> tuple[str, tuple[Any, ...]]:
324
365
  lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
325
366
  if len(key_transforms) > 1:
326
367
  sql = f"({lhs} {self.postgres_nested_operator} %s)"
@@ -331,11 +372,17 @@ class KeyTransform(Transform):
331
372
  lookup = self.key_name
332
373
  return f"({lhs} {self.postgres_operator} %s)", tuple(params) + (lookup,)
333
374
 
334
- def as_sqlite(self, compiler, connection):
375
+ def as_sqlite(
376
+ self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
377
+ ) -> tuple[str, tuple[Any, ...]]:
378
+ sqlite_connection = cast(SQLiteDatabaseWrapper, connection)
335
379
  lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
336
380
  json_path = compile_json_path(key_transforms)
337
381
  datatype_values = ",".join(
338
- [repr(datatype) for datatype in connection.ops.jsonfield_datatype_values]
382
+ [
383
+ repr(datatype)
384
+ for datatype in sqlite_connection.ops.jsonfield_datatype_values # type: ignore[attr-defined]
385
+ ]
339
386
  )
340
387
  return (
341
388
  f"(CASE WHEN JSON_TYPE({lhs}, %s) IN ({datatype_values}) "
@@ -348,8 +395,11 @@ class KeyTextTransform(KeyTransform):
348
395
  postgres_nested_operator = "#>>"
349
396
  output_field = TextField()
350
397
 
351
- def as_mysql(self, compiler, connection):
352
- if connection.mysql_is_mariadb:
398
+ def as_mysql(
399
+ self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
400
+ ) -> tuple[str, tuple[Any, ...]]:
401
+ mysql_connection = cast(MySQLDatabaseWrapper, connection)
402
+ if mysql_connection.mysql_is_mariadb:
353
403
  # MariaDB doesn't support -> and ->> operators (see MDEV-13594).
354
404
  sql, params = super().as_mysql(compiler, connection)
355
405
  return f"JSON_UNQUOTE({sql})", params
@@ -359,7 +409,7 @@ class KeyTextTransform(KeyTransform):
359
409
  return f"({lhs} ->> %s)", tuple(params) + (json_path,)
360
410
 
361
411
  @classmethod
362
- def from_lookup(cls, lookup):
412
+ def from_lookup(cls, lookup: str) -> Any:
363
413
  transform, *keys = lookup.split(LOOKUP_SEP)
364
414
  if not keys:
365
415
  raise ValueError("Lookup must contain key or index transforms.")
@@ -379,7 +429,7 @@ class KeyTransformTextLookupMixin:
379
429
  representation.
380
430
  """
381
431
 
382
- def __init__(self, key_transform, *args, **kwargs):
432
+ def __init__(self, key_transform: Any, *args: Any, **kwargs: Any):
383
433
  if not isinstance(key_transform, KeyTransform):
384
434
  raise TypeError(
385
435
  "Transform should be an instance of KeyTransform in order to "
@@ -390,12 +440,14 @@ class KeyTransformTextLookupMixin:
390
440
  *key_transform.source_expressions,
391
441
  **key_transform.extra,
392
442
  )
393
- super().__init__(key_text_transform, *args, **kwargs)
443
+ super().__init__(key_text_transform, *args, **kwargs) # type: ignore[misc]
394
444
 
395
445
 
396
446
  class KeyTransformIsNull(lookups.IsNull):
397
447
  # key__isnull=False is the same as has_key='key'
398
- def as_sqlite(self, compiler, connection):
448
+ def as_sqlite(
449
+ self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
450
+ ) -> tuple[str, tuple[Any, ...]]:
399
451
  template = "JSON_TYPE(%s, %%s) IS NULL"
400
452
  if not self.rhs:
401
453
  template = "JSON_TYPE(%s, %%s) IS NOT NULL"
@@ -407,7 +459,13 @@ class KeyTransformIsNull(lookups.IsNull):
407
459
 
408
460
 
409
461
  class KeyTransformIn(lookups.In):
410
- def resolve_expression_parameter(self, compiler, connection, sql, param):
462
+ def resolve_expression_parameter(
463
+ self,
464
+ compiler: SQLCompiler,
465
+ connection: BaseDatabaseWrapper,
466
+ sql: str,
467
+ param: Any,
468
+ ) -> tuple[str, tuple[Any, ...]]:
411
469
  sql, params = super().resolve_expression_parameter(
412
470
  compiler,
413
471
  connection,
@@ -418,25 +476,31 @@ class KeyTransformIn(lookups.In):
418
476
  not hasattr(param, "as_sql")
419
477
  and not connection.features.has_native_json_field
420
478
  ):
421
- if connection.vendor == "mysql" or (
422
- connection.vendor == "sqlite"
423
- and params[0] not in connection.ops.jsonfield_datatype_values
424
- ):
479
+ if connection.vendor == "mysql":
425
480
  sql = "JSON_EXTRACT(%s, '$')"
426
- if connection.vendor == "mysql" and connection.mysql_is_mariadb:
427
- sql = f"JSON_UNQUOTE({sql})"
481
+ elif connection.vendor == "sqlite":
482
+ sqlite_connection = cast(SQLiteDatabaseWrapper, connection)
483
+ if params[0] not in sqlite_connection.ops.jsonfield_datatype_values: # type: ignore[attr-defined]
484
+ sql = "JSON_EXTRACT(%s, '$')"
485
+ if connection.vendor == "mysql":
486
+ mysql_connection = cast(MySQLDatabaseWrapper, connection)
487
+ if mysql_connection.mysql_is_mariadb:
488
+ sql = f"JSON_UNQUOTE({sql})"
428
489
  return sql, params
429
490
 
430
491
 
431
492
  class KeyTransformExact(JSONExact):
432
- def process_rhs(self, compiler, connection):
493
+ def process_rhs(
494
+ self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
495
+ ) -> tuple[str, list[Any]]:
433
496
  if isinstance(self.rhs, KeyTransform):
434
497
  return super(lookups.Exact, self).process_rhs(compiler, connection)
435
498
  rhs, rhs_params = super().process_rhs(compiler, connection)
436
499
  if connection.vendor == "sqlite":
500
+ sqlite_connection = cast(SQLiteDatabaseWrapper, connection)
437
501
  func = []
438
502
  for value in rhs_params:
439
- if value in connection.ops.jsonfield_datatype_values:
503
+ if value in sqlite_connection.ops.jsonfield_datatype_values: # type: ignore[attr-defined]
440
504
  func.append("%s")
441
505
  else:
442
506
  func.append("JSON_EXTRACT(%s, '$')")
@@ -487,8 +551,10 @@ class KeyTransformIRegex(
487
551
 
488
552
 
489
553
  class KeyTransformNumericLookupMixin:
490
- def process_rhs(self, compiler, connection):
491
- rhs, rhs_params = super().process_rhs(compiler, connection)
554
+ def process_rhs(
555
+ self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
556
+ ) -> tuple[str, list[Any]]:
557
+ rhs, rhs_params = super().process_rhs(compiler, connection) # type: ignore[misc]
492
558
  if not connection.features.has_native_json_field:
493
559
  rhs_params = [json.loads(value) for value in rhs_params]
494
560
  return rhs, rhs_params
@@ -529,8 +595,8 @@ KeyTransform.register_lookup(KeyTransformGte)
529
595
 
530
596
 
531
597
  class KeyTransformFactory:
532
- def __init__(self, key_name):
598
+ def __init__(self, key_name: str):
533
599
  self.key_name = key_name
534
600
 
535
- def __call__(self, *args, **kwargs):
601
+ def __call__(self, *args: Any, **kwargs: Any) -> KeyTransform:
536
602
  return KeyTransform(self.key_name, *args, **kwargs)
@@ -1,3 +1,7 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
1
5
  from plain.preflight import PreflightResult
2
6
 
3
7
  NOT_PROVIDED = object()
@@ -6,10 +10,10 @@ NOT_PROVIDED = object()
6
10
  class FieldCacheMixin:
7
11
  """Provide an API for working with the model's fields value cache."""
8
12
 
9
- def get_cache_name(self):
13
+ def get_cache_name(self) -> str:
10
14
  raise NotImplementedError
11
15
 
12
- def get_cached_value(self, instance, default=NOT_PROVIDED):
16
+ def get_cached_value(self, instance: Any, default: Any = NOT_PROVIDED) -> Any:
13
17
  cache_name = self.get_cache_name()
14
18
  try:
15
19
  return instance._state.fields_cache[cache_name]
@@ -18,20 +22,20 @@ class FieldCacheMixin:
18
22
  raise
19
23
  return default
20
24
 
21
- def is_cached(self, instance):
25
+ def is_cached(self, instance: Any) -> bool:
22
26
  return self.get_cache_name() in instance._state.fields_cache
23
27
 
24
- def set_cached_value(self, instance, value):
28
+ def set_cached_value(self, instance: Any, value: Any) -> None:
25
29
  instance._state.fields_cache[self.get_cache_name()] = value
26
30
 
27
- def delete_cached_value(self, instance):
31
+ def delete_cached_value(self, instance: Any) -> None:
28
32
  del instance._state.fields_cache[self.get_cache_name()]
29
33
 
30
34
 
31
35
  class CheckFieldDefaultMixin:
32
36
  _default_fix = ("<valid default>", "<invalid default>")
33
37
 
34
- def _check_default(self):
38
+ def _check_default(self) -> list[PreflightResult]: # type: ignore[misc]
35
39
  if (
36
40
  self.has_default()
37
41
  and self.default is not None
@@ -53,7 +57,7 @@ class CheckFieldDefaultMixin:
53
57
  else:
54
58
  return []
55
59
 
56
- def preflight(self, **kwargs):
57
- errors = super().preflight(**kwargs)
60
+ def preflight(self, **kwargs: Any) -> list[PreflightResult]: # type: ignore[misc]
61
+ errors = super().preflight(**kwargs) # type: ignore[misc]
58
62
  errors.extend(self._check_default())
59
63
  return errors