plain.models 0.49.2__py3-none-any.whl → 0.50.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105)
  1. plain/models/CHANGELOG.md +13 -0
  2. plain/models/aggregates.py +42 -19
  3. plain/models/backends/base/base.py +125 -105
  4. plain/models/backends/base/client.py +11 -3
  5. plain/models/backends/base/creation.py +22 -12
  6. plain/models/backends/base/features.py +10 -4
  7. plain/models/backends/base/introspection.py +29 -16
  8. plain/models/backends/base/operations.py +187 -91
  9. plain/models/backends/base/schema.py +267 -165
  10. plain/models/backends/base/validation.py +12 -3
  11. plain/models/backends/ddl_references.py +85 -43
  12. plain/models/backends/mysql/base.py +29 -26
  13. plain/models/backends/mysql/client.py +7 -2
  14. plain/models/backends/mysql/compiler.py +12 -3
  15. plain/models/backends/mysql/creation.py +5 -2
  16. plain/models/backends/mysql/features.py +24 -22
  17. plain/models/backends/mysql/introspection.py +22 -13
  18. plain/models/backends/mysql/operations.py +106 -39
  19. plain/models/backends/mysql/schema.py +48 -24
  20. plain/models/backends/mysql/validation.py +13 -6
  21. plain/models/backends/postgresql/base.py +41 -34
  22. plain/models/backends/postgresql/client.py +7 -2
  23. plain/models/backends/postgresql/creation.py +10 -5
  24. plain/models/backends/postgresql/introspection.py +15 -8
  25. plain/models/backends/postgresql/operations.py +109 -42
  26. plain/models/backends/postgresql/schema.py +85 -46
  27. plain/models/backends/sqlite3/_functions.py +151 -115
  28. plain/models/backends/sqlite3/base.py +37 -23
  29. plain/models/backends/sqlite3/client.py +7 -1
  30. plain/models/backends/sqlite3/creation.py +9 -5
  31. plain/models/backends/sqlite3/features.py +5 -3
  32. plain/models/backends/sqlite3/introspection.py +32 -16
  33. plain/models/backends/sqlite3/operations.py +125 -42
  34. plain/models/backends/sqlite3/schema.py +82 -58
  35. plain/models/backends/utils.py +52 -29
  36. plain/models/backups/cli.py +8 -6
  37. plain/models/backups/clients.py +16 -7
  38. plain/models/backups/core.py +24 -13
  39. plain/models/base.py +113 -74
  40. plain/models/cli.py +94 -63
  41. plain/models/config.py +1 -1
  42. plain/models/connections.py +23 -7
  43. plain/models/constraints.py +65 -47
  44. plain/models/database_url.py +1 -1
  45. plain/models/db.py +6 -2
  46. plain/models/deletion.py +66 -43
  47. plain/models/entrypoints.py +1 -1
  48. plain/models/enums.py +22 -11
  49. plain/models/exceptions.py +23 -8
  50. plain/models/expressions.py +440 -257
  51. plain/models/fields/__init__.py +253 -202
  52. plain/models/fields/json.py +120 -54
  53. plain/models/fields/mixins.py +12 -8
  54. plain/models/fields/related.py +284 -252
  55. plain/models/fields/related_descriptors.py +31 -22
  56. plain/models/fields/related_lookups.py +23 -11
  57. plain/models/fields/related_managers.py +81 -47
  58. plain/models/fields/reverse_related.py +58 -55
  59. plain/models/forms.py +89 -63
  60. plain/models/functions/comparison.py +71 -18
  61. plain/models/functions/datetime.py +79 -29
  62. plain/models/functions/math.py +43 -10
  63. plain/models/functions/mixins.py +24 -7
  64. plain/models/functions/text.py +104 -25
  65. plain/models/functions/window.py +12 -6
  66. plain/models/indexes.py +52 -28
  67. plain/models/lookups.py +228 -153
  68. plain/models/migrations/autodetector.py +86 -43
  69. plain/models/migrations/exceptions.py +7 -3
  70. plain/models/migrations/executor.py +33 -7
  71. plain/models/migrations/graph.py +79 -50
  72. plain/models/migrations/loader.py +45 -22
  73. plain/models/migrations/migration.py +23 -18
  74. plain/models/migrations/operations/base.py +37 -19
  75. plain/models/migrations/operations/fields.py +89 -42
  76. plain/models/migrations/operations/models.py +245 -143
  77. plain/models/migrations/operations/special.py +82 -25
  78. plain/models/migrations/optimizer.py +7 -2
  79. plain/models/migrations/questioner.py +58 -31
  80. plain/models/migrations/recorder.py +18 -11
  81. plain/models/migrations/serializer.py +50 -39
  82. plain/models/migrations/state.py +220 -133
  83. plain/models/migrations/utils.py +29 -13
  84. plain/models/migrations/writer.py +17 -14
  85. plain/models/options.py +63 -56
  86. plain/models/otel.py +16 -6
  87. plain/models/preflight.py +35 -12
  88. plain/models/query.py +323 -228
  89. plain/models/query_utils.py +93 -58
  90. plain/models/registry.py +34 -16
  91. plain/models/sql/compiler.py +146 -97
  92. plain/models/sql/datastructures.py +38 -25
  93. plain/models/sql/query.py +255 -169
  94. plain/models/sql/subqueries.py +32 -21
  95. plain/models/sql/where.py +54 -29
  96. plain/models/test/pytest.py +15 -11
  97. plain/models/test/utils.py +4 -2
  98. plain/models/transaction.py +20 -7
  99. plain/models/utils.py +13 -5
  100. {plain_models-0.49.2.dist-info → plain_models-0.50.0.dist-info}/METADATA +1 -1
  101. plain_models-0.50.0.dist-info/RECORD +122 -0
  102. plain_models-0.49.2.dist-info/RECORD +0 -122
  103. {plain_models-0.49.2.dist-info → plain_models-0.50.0.dist-info}/WHEEL +0 -0
  104. {plain_models-0.49.2.dist-info → plain_models-0.50.0.dist-info}/entry_points.txt +0 -0
  105. {plain_models-0.49.2.dist-info → plain_models-0.50.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,11 +1,19 @@
1
1
  """Database functions that do comparisons or type conversions."""
2
2
 
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING, Any
6
+
3
7
  from plain.models.db import NotSupportedError
4
8
  from plain.models.expressions import Func, Value
5
- from plain.models.fields import TextField
9
+ from plain.models.fields import Field, TextField
6
10
  from plain.models.fields.json import JSONField
7
11
  from plain.utils.regex_helper import _lazy_re_compile
8
12
 
13
+ if TYPE_CHECKING:
14
+ from plain.models.backends.base.base import BaseDatabaseWrapper
15
+ from plain.models.sql.compiler import SQLCompiler
16
+
9
17
 
10
18
  class Cast(Func):
11
19
  """Coerce an expression to a new field type."""
@@ -13,14 +21,24 @@ class Cast(Func):
13
21
  function = "CAST"
14
22
  template = "%(function)s(%(expressions)s AS %(db_type)s)"
15
23
 
16
- def __init__(self, expression, output_field):
24
+ def __init__(self, expression: Any, output_field: Field) -> None:
17
25
  super().__init__(expression, output_field=output_field)
18
26
 
19
- def as_sql(self, compiler, connection, **extra_context):
27
+ def as_sql(
28
+ self,
29
+ compiler: SQLCompiler,
30
+ connection: BaseDatabaseWrapper,
31
+ **extra_context: Any,
32
+ ) -> tuple[str, tuple[Any, ...]]:
20
33
  extra_context["db_type"] = self.output_field.cast_db_type(connection)
21
34
  return super().as_sql(compiler, connection, **extra_context)
22
35
 
23
- def as_sqlite(self, compiler, connection, **extra_context):
36
+ def as_sqlite(
37
+ self,
38
+ compiler: SQLCompiler,
39
+ connection: BaseDatabaseWrapper,
40
+ **extra_context: Any,
41
+ ) -> tuple[str, tuple[Any, ...]]:
24
42
  db_type = self.output_field.db_type(connection)
25
43
  if db_type in {"datetime", "time"}:
26
44
  # Use strftime as datetime/time don't keep fractional seconds.
@@ -38,18 +56,28 @@ class Cast(Func):
38
56
  )
39
57
  return self.as_sql(compiler, connection, **extra_context)
40
58
 
41
- def as_mysql(self, compiler, connection, **extra_context):
59
+ def as_mysql(
60
+ self,
61
+ compiler: SQLCompiler,
62
+ connection: BaseDatabaseWrapper,
63
+ **extra_context: Any,
64
+ ) -> tuple[str, tuple[Any, ...]]:
42
65
  template = None
43
66
  output_type = self.output_field.get_internal_type()
44
67
  # MySQL doesn't support explicit cast to float.
45
68
  if output_type == "FloatField":
46
69
  template = "(%(expressions)s + 0.0)"
47
70
  # MariaDB doesn't support explicit cast to JSON.
48
- elif output_type == "JSONField" and connection.mysql_is_mariadb:
71
+ elif output_type == "JSONField" and connection.mysql_is_mariadb: # type: ignore[attr-defined]
49
72
  template = "JSON_EXTRACT(%(expressions)s, '$')"
50
73
  return self.as_sql(compiler, connection, template=template, **extra_context)
51
74
 
52
- def as_postgresql(self, compiler, connection, **extra_context):
75
+ def as_postgresql(
76
+ self,
77
+ compiler: SQLCompiler,
78
+ connection: BaseDatabaseWrapper,
79
+ **extra_context: Any,
80
+ ) -> tuple[str, tuple[Any, ...]]:
53
81
  # CAST would be valid too, but the :: shortcut syntax is more readable.
54
82
  # 'expressions' is wrapped in parentheses in case it's a complex
55
83
  # expression.
@@ -66,13 +94,13 @@ class Coalesce(Func):
66
94
 
67
95
  function = "COALESCE"
68
96
 
69
- def __init__(self, *expressions, **extra):
97
+ def __init__(self, *expressions: Any, **extra: Any) -> None:
70
98
  if len(expressions) < 2:
71
99
  raise ValueError("Coalesce must take at least two expressions")
72
100
  super().__init__(*expressions, **extra)
73
101
 
74
102
  @property
75
- def empty_result_set_value(self):
103
+ def empty_result_set_value(self) -> Any:
76
104
  for expression in self.get_source_expressions():
77
105
  result = expression.empty_result_set_value
78
106
  if result is NotImplemented or result is not None:
@@ -87,13 +115,18 @@ class Collate(Func):
87
115
  # https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS
88
116
  collation_re = _lazy_re_compile(r"^[\w\-]+$")
89
117
 
90
- def __init__(self, expression, collation):
118
+ def __init__(self, expression: Any, collation: str) -> None:
91
119
  if not (collation and self.collation_re.match(collation)):
92
120
  raise ValueError(f"Invalid collation name: {collation!r}.")
93
121
  self.collation = collation
94
122
  super().__init__(expression)
95
123
 
96
- def as_sql(self, compiler, connection, **extra_context):
124
+ def as_sql(
125
+ self,
126
+ compiler: SQLCompiler,
127
+ connection: BaseDatabaseWrapper,
128
+ **extra_context: Any,
129
+ ) -> tuple[str, tuple[Any, ...]]:
97
130
  extra_context.setdefault("collation", connection.ops.quote_name(self.collation))
98
131
  return super().as_sql(compiler, connection, **extra_context)
99
132
 
@@ -109,12 +142,17 @@ class Greatest(Func):
109
142
 
110
143
  function = "GREATEST"
111
144
 
112
- def __init__(self, *expressions, **extra):
145
+ def __init__(self, *expressions: Any, **extra: Any) -> None:
113
146
  if len(expressions) < 2:
114
147
  raise ValueError("Greatest must take at least two expressions")
115
148
  super().__init__(*expressions, **extra)
116
149
 
117
- def as_sqlite(self, compiler, connection, **extra_context):
150
+ def as_sqlite(
151
+ self,
152
+ compiler: SQLCompiler,
153
+ connection: BaseDatabaseWrapper,
154
+ **extra_context: Any,
155
+ ) -> tuple[str, tuple[Any, ...]]:
118
156
  """Use the MAX function on SQLite."""
119
157
  return super().as_sqlite(compiler, connection, function="MAX", **extra_context)
120
158
 
@@ -123,20 +161,30 @@ class JSONObject(Func):
123
161
  function = "JSON_OBJECT"
124
162
  output_field = JSONField()
125
163
 
126
- def __init__(self, **fields):
164
+ def __init__(self, **fields: Any) -> None:
127
165
  expressions = []
128
166
  for key, value in fields.items():
129
167
  expressions.extend((Value(key), value))
130
168
  super().__init__(*expressions)
131
169
 
132
- def as_sql(self, compiler, connection, **extra_context):
170
+ def as_sql(
171
+ self,
172
+ compiler: SQLCompiler,
173
+ connection: BaseDatabaseWrapper,
174
+ **extra_context: Any,
175
+ ) -> tuple[str, tuple[Any, ...]]:
133
176
  if not connection.features.has_json_object_function:
134
177
  raise NotSupportedError(
135
178
  "JSONObject() is not supported on this database backend."
136
179
  )
137
180
  return super().as_sql(compiler, connection, **extra_context)
138
181
 
139
- def as_postgresql(self, compiler, connection, **extra_context):
182
+ def as_postgresql(
183
+ self,
184
+ compiler: SQLCompiler,
185
+ connection: BaseDatabaseWrapper,
186
+ **extra_context: Any,
187
+ ) -> tuple[str, tuple[Any, ...]]:
140
188
  copy = self.copy()
141
189
  copy.set_source_expressions(
142
190
  [
@@ -163,12 +211,17 @@ class Least(Func):
163
211
 
164
212
  function = "LEAST"
165
213
 
166
- def __init__(self, *expressions, **extra):
214
+ def __init__(self, *expressions: Any, **extra: Any) -> None:
167
215
  if len(expressions) < 2:
168
216
  raise ValueError("Least must take at least two expressions")
169
217
  super().__init__(*expressions, **extra)
170
218
 
171
- def as_sqlite(self, compiler, connection, **extra_context):
219
+ def as_sqlite(
220
+ self,
221
+ compiler: SQLCompiler,
222
+ connection: BaseDatabaseWrapper,
223
+ **extra_context: Any,
224
+ ) -> tuple[str, tuple[Any, ...]]:
172
225
  """Use the MIN function on SQLite."""
173
226
  return super().as_sqlite(compiler, connection, function="MIN", **extra_context)
174
227
 
@@ -1,4 +1,7 @@
1
+ from __future__ import annotations
2
+
1
3
  from datetime import datetime
4
+ from typing import TYPE_CHECKING, Any
2
5
 
3
6
  from plain.models.expressions import Func
4
7
  from plain.models.fields import (
@@ -19,11 +22,15 @@ from plain.models.lookups import (
19
22
  )
20
23
  from plain.utils import timezone
21
24
 
25
+ if TYPE_CHECKING:
26
+ from plain.models.backends.base.base import BaseDatabaseWrapper
27
+ from plain.models.sql.compiler import SQLCompiler
28
+
22
29
 
23
30
  class TimezoneMixin:
24
31
  tzinfo = None
25
32
 
26
- def get_tzname(self):
33
+ def get_tzname(self) -> str | None:
27
34
  # Timezone conversions must happen to the input datetime *before*
28
35
  # applying a function. 2015-12-31 23:00:00 -02:00 is stored in the
29
36
  # database as 2016-01-01 01:00:00 +00:00. Any results should be
@@ -38,7 +45,13 @@ class Extract(TimezoneMixin, Transform):
38
45
  lookup_name = None
39
46
  output_field = IntegerField()
40
47
 
41
- def __init__(self, expression, lookup_name=None, tzinfo=None, **extra):
48
+ def __init__(
49
+ self,
50
+ expression: Any,
51
+ lookup_name: str | None = None,
52
+ tzinfo: Any = None,
53
+ **extra: Any,
54
+ ) -> None:
42
55
  if self.lookup_name is None:
43
56
  self.lookup_name = lookup_name
44
57
  if self.lookup_name is None:
@@ -46,7 +59,9 @@ class Extract(TimezoneMixin, Transform):
46
59
  self.tzinfo = tzinfo
47
60
  super().__init__(expression, **extra)
48
61
 
49
- def as_sql(self, compiler, connection):
62
+ def as_sql(
63
+ self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
64
+ ) -> tuple[str, tuple[Any, ...]]:
50
65
  sql, params = compiler.compile(self.lhs)
51
66
  lhs_output_field = self.lhs.output_field
52
67
  if isinstance(lhs_output_field, DateTimeField):
@@ -76,11 +91,16 @@ class Extract(TimezoneMixin, Transform):
76
91
  # resolve_expression has already validated the output_field so this
77
92
  # assert should never be hit.
78
93
  raise ValueError("Tried to Extract from an invalid type.")
79
- return sql, params
94
+ return sql, tuple(params) if isinstance(params, list) else params
80
95
 
81
96
  def resolve_expression(
82
- self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
83
- ):
97
+ self,
98
+ query: Any = None,
99
+ allow_joins: bool = True,
100
+ reuse: Any = None,
101
+ summarize: bool = False,
102
+ for_save: bool = False,
103
+ ) -> Extract:
84
104
  copy = super().resolve_expression(
85
105
  query, allow_joins, reuse, summarize, for_save
86
106
  )
@@ -209,7 +229,12 @@ class Now(Func):
209
229
  template = "CURRENT_TIMESTAMP"
210
230
  output_field = DateTimeField()
211
231
 
212
- def as_postgresql(self, compiler, connection, **extra_context):
232
+ def as_postgresql(
233
+ self,
234
+ compiler: SQLCompiler,
235
+ connection: BaseDatabaseWrapper,
236
+ **extra_context: Any,
237
+ ) -> tuple[str, tuple[Any, ...]]:
213
238
  # PostgreSQL's CURRENT_TIMESTAMP means "the time at the start of the
214
239
  # transaction". Use STATEMENT_TIMESTAMP to be cross-compatible with
215
240
  # other databases.
@@ -217,12 +242,22 @@ class Now(Func):
217
242
  compiler, connection, template="STATEMENT_TIMESTAMP()", **extra_context
218
243
  )
219
244
 
220
- def as_mysql(self, compiler, connection, **extra_context):
245
+ def as_mysql(
246
+ self,
247
+ compiler: SQLCompiler,
248
+ connection: BaseDatabaseWrapper,
249
+ **extra_context: Any,
250
+ ) -> tuple[str, tuple[Any, ...]]:
221
251
  return self.as_sql(
222
252
  compiler, connection, template="CURRENT_TIMESTAMP(6)", **extra_context
223
253
  )
224
254
 
225
- def as_sqlite(self, compiler, connection, **extra_context):
255
+ def as_sqlite(
256
+ self,
257
+ compiler: SQLCompiler,
258
+ connection: BaseDatabaseWrapper,
259
+ **extra_context: Any,
260
+ ) -> tuple[str, tuple[Any, ...]]:
226
261
  return self.as_sql(
227
262
  compiler,
228
263
  connection,
@@ -237,15 +272,17 @@ class TruncBase(TimezoneMixin, Transform):
237
272
 
238
273
  def __init__(
239
274
  self,
240
- expression,
241
- output_field=None,
242
- tzinfo=None,
243
- **extra,
244
- ):
275
+ expression: Any,
276
+ output_field: Field | None = None,
277
+ tzinfo: Any = None,
278
+ **extra: Any,
279
+ ) -> None:
245
280
  self.tzinfo = tzinfo
246
281
  super().__init__(expression, output_field=output_field, **extra)
247
282
 
248
- def as_sql(self, compiler, connection):
283
+ def as_sql(
284
+ self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
285
+ ) -> tuple[str, tuple[Any, ...]]:
249
286
  sql, params = compiler.compile(self.lhs)
250
287
  tzname = None
251
288
  if isinstance(self.lhs.output_field, DateTimeField):
@@ -268,11 +305,16 @@ class TruncBase(TimezoneMixin, Transform):
268
305
  raise ValueError(
269
306
  "Trunc only valid on DateField, TimeField, or DateTimeField."
270
307
  )
271
- return sql, params
308
+ return sql, tuple(params) if isinstance(params, list) else params
272
309
 
273
310
  def resolve_expression(
274
- self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
275
- ):
311
+ self,
312
+ query: Any = None,
313
+ allow_joins: bool = True,
314
+ reuse: Any = None,
315
+ summarize: bool = False,
316
+ for_save: bool = False,
317
+ ) -> TruncBase:
276
318
  copy = super().resolve_expression(
277
319
  query, allow_joins, reuse, summarize, for_save
278
320
  )
@@ -325,7 +367,9 @@ class TruncBase(TimezoneMixin, Transform):
325
367
  )
326
368
  return copy
327
369
 
328
- def convert_value(self, value, expression, connection):
370
+ def convert_value(
371
+ self, value: Any, expression: Any, connection: BaseDatabaseWrapper
372
+ ) -> Any:
329
373
  if isinstance(self.output_field, DateTimeField):
330
374
  if value is not None:
331
375
  value = value.replace(tzinfo=None)
@@ -348,12 +392,12 @@ class TruncBase(TimezoneMixin, Transform):
348
392
  class Trunc(TruncBase):
349
393
  def __init__(
350
394
  self,
351
- expression,
352
- kind,
353
- output_field=None,
354
- tzinfo=None,
355
- **extra,
356
- ):
395
+ expression: Any,
396
+ kind: str,
397
+ output_field: Field | None = None,
398
+ tzinfo: Any = None,
399
+ **extra: Any,
400
+ ) -> None:
357
401
  self.kind = kind
358
402
  super().__init__(expression, output_field=output_field, tzinfo=tzinfo, **extra)
359
403
 
@@ -385,11 +429,14 @@ class TruncDate(TruncBase):
385
429
  lookup_name = "date"
386
430
  output_field = DateField()
387
431
 
388
- def as_sql(self, compiler, connection):
432
+ def as_sql(
433
+ self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
434
+ ) -> tuple[str, tuple[Any, ...]]:
389
435
  # Cast to date rather than truncate to date.
390
436
  sql, params = compiler.compile(self.lhs)
391
437
  tzname = self.get_tzname()
392
- return connection.ops.datetime_cast_date_sql(sql, tuple(params), tzname)
438
+ sql, params = connection.ops.datetime_cast_date_sql(sql, tuple(params), tzname)
439
+ return sql, tuple(params) if isinstance(params, list) else params
393
440
 
394
441
 
395
442
  class TruncTime(TruncBase):
@@ -397,11 +444,14 @@ class TruncTime(TruncBase):
397
444
  lookup_name = "time"
398
445
  output_field = TimeField()
399
446
 
400
- def as_sql(self, compiler, connection):
447
+ def as_sql(
448
+ self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
449
+ ) -> tuple[str, tuple[Any, ...]]:
401
450
  # Cast to time rather than truncate to time.
402
451
  sql, params = compiler.compile(self.lhs)
403
452
  tzname = self.get_tzname()
404
- return connection.ops.datetime_cast_time_sql(sql, tuple(params), tzname)
453
+ sql, params = connection.ops.datetime_cast_time_sql(sql, tuple(params), tzname)
454
+ return sql, tuple(params) if isinstance(params, list) else params
405
455
 
406
456
 
407
457
  class TruncHour(TruncBase):
@@ -1,5 +1,9 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, Any
4
+
1
5
  from plain.models.expressions import Func, Value
2
- from plain.models.fields import FloatField, IntegerField
6
+ from plain.models.fields import Field, FloatField, IntegerField
3
7
  from plain.models.functions import Cast
4
8
  from plain.models.functions.mixins import (
5
9
  FixDecimalInputMixin,
@@ -7,6 +11,10 @@ from plain.models.functions.mixins import (
7
11
  )
8
12
  from plain.models.lookups import Transform
9
13
 
14
+ if TYPE_CHECKING:
15
+ from plain.models.backends.base.base import BaseDatabaseWrapper
16
+ from plain.models.sql.compiler import SQLCompiler
17
+
10
18
 
11
19
  class Abs(Transform):
12
20
  function = "ABS"
@@ -32,10 +40,15 @@ class ATan2(NumericOutputFieldMixin, Func):
32
40
  function = "ATAN2"
33
41
  arity = 2
34
42
 
35
- def as_sqlite(self, compiler, connection, **extra_context):
43
+ def as_sqlite(
44
+ self,
45
+ compiler: SQLCompiler,
46
+ connection: BaseDatabaseWrapper,
47
+ **extra_context: Any,
48
+ ) -> tuple[str, tuple[Any, ...]]:
36
49
  if not getattr(
37
50
  connection.ops, "spatialite", False
38
- ) or connection.ops.spatial_version >= (5, 0, 0):
51
+ ) or connection.ops.spatial_version >= (5, 0, 0): # type: ignore[attr-defined]
39
52
  return self.as_sql(compiler, connection)
40
53
  # This function is usually ATan2(y, x), returning the inverse tangent
41
54
  # of y / x, but it's ATan2(x, y) on SpatiaLite < 5.0.0.
@@ -93,7 +106,12 @@ class Log(FixDecimalInputMixin, NumericOutputFieldMixin, Func):
93
106
  function = "LOG"
94
107
  arity = 2
95
108
 
96
- def as_sqlite(self, compiler, connection, **extra_context):
109
+ def as_sqlite(
110
+ self,
111
+ compiler: SQLCompiler,
112
+ connection: BaseDatabaseWrapper,
113
+ **extra_context: Any,
114
+ ) -> tuple[str, tuple[Any, ...]]:
97
115
  if not getattr(connection.ops, "spatialite", False):
98
116
  return self.as_sql(compiler, connection)
99
117
  # This function is usually Log(b, x) returning the logarithm of x to
@@ -127,13 +145,23 @@ class Random(NumericOutputFieldMixin, Func):
127
145
  function = "RANDOM"
128
146
  arity = 0
129
147
 
130
- def as_mysql(self, compiler, connection, **extra_context):
148
+ def as_mysql(
149
+ self,
150
+ compiler: SQLCompiler,
151
+ connection: BaseDatabaseWrapper,
152
+ **extra_context: Any,
153
+ ) -> tuple[str, tuple[Any, ...]]:
131
154
  return super().as_sql(compiler, connection, function="RAND", **extra_context)
132
155
 
133
- def as_sqlite(self, compiler, connection, **extra_context):
156
+ def as_sqlite(
157
+ self,
158
+ compiler: SQLCompiler,
159
+ connection: BaseDatabaseWrapper,
160
+ **extra_context: Any,
161
+ ) -> tuple[str, tuple[Any, ...]]:
134
162
  return super().as_sql(compiler, connection, function="RAND", **extra_context)
135
163
 
136
- def get_group_by_cols(self):
164
+ def get_group_by_cols(self) -> list[Any]:
137
165
  return []
138
166
 
139
167
 
@@ -142,16 +170,21 @@ class Round(FixDecimalInputMixin, Transform):
142
170
  lookup_name = "round"
143
171
  arity = None # Override Transform's arity=1 to enable passing precision.
144
172
 
145
- def __init__(self, expression, precision=0, **extra):
173
+ def __init__(self, expression: Any, precision: int = 0, **extra: Any) -> None:
146
174
  super().__init__(expression, precision, **extra)
147
175
 
148
- def as_sqlite(self, compiler, connection, **extra_context):
176
+ def as_sqlite(
177
+ self,
178
+ compiler: SQLCompiler,
179
+ connection: BaseDatabaseWrapper,
180
+ **extra_context: Any,
181
+ ) -> tuple[str, tuple[Any, ...]]:
149
182
  precision = self.get_source_expressions()[1]
150
183
  if isinstance(precision, Value) and precision.value < 0:
151
184
  raise ValueError("SQLite does not support negative precision.")
152
185
  return super().as_sqlite(compiler, connection, **extra_context)
153
186
 
154
- def _resolve_output_field(self):
187
+ def _resolve_output_field(self) -> Field:
155
188
  source = self.get_source_expressions()[0]
156
189
  return source.output_field
157
190
 
@@ -1,11 +1,23 @@
1
+ from __future__ import annotations
2
+
1
3
  import sys
4
+ from typing import TYPE_CHECKING, Any
2
5
 
3
6
  from plain.models.fields import DecimalField, FloatField, IntegerField
4
7
  from plain.models.functions import Cast
5
8
 
9
+ if TYPE_CHECKING:
10
+ from plain.models.backends.base.base import BaseDatabaseWrapper
11
+ from plain.models.sql.compiler import SQLCompiler
12
+
6
13
 
7
14
  class FixDecimalInputMixin:
8
- def as_postgresql(self, compiler, connection, **extra_context):
15
+ def as_postgresql(
16
+ self,
17
+ compiler: SQLCompiler,
18
+ connection: BaseDatabaseWrapper,
19
+ **extra_context: Any,
20
+ ) -> tuple[str, tuple[Any, ...]]:
9
21
  # Cast FloatField to DecimalField as PostgreSQL doesn't support the
10
22
  # following function signatures:
11
23
  # - LOG(double, double)
@@ -24,18 +36,23 @@ class FixDecimalInputMixin:
24
36
 
25
37
 
26
38
  class FixDurationInputMixin:
27
- def as_mysql(self, compiler, connection, **extra_context):
28
- sql, params = super().as_sql(compiler, connection, **extra_context)
29
- if self.output_field.get_internal_type() == "DurationField":
39
+ def as_mysql(
40
+ self,
41
+ compiler: SQLCompiler,
42
+ connection: BaseDatabaseWrapper,
43
+ **extra_context: Any,
44
+ ) -> tuple[str, tuple[Any, ...]]:
45
+ sql, params = super().as_sql(compiler, connection, **extra_context) # type: ignore[misc]
46
+ if self.output_field.get_internal_type() == "DurationField": # type: ignore[attr-defined]
30
47
  sql = f"CAST({sql} AS SIGNED)"
31
48
  return sql, params
32
49
 
33
50
 
34
51
  class NumericOutputFieldMixin:
35
- def _resolve_output_field(self):
36
- source_fields = self.get_source_fields()
52
+ def _resolve_output_field(self) -> DecimalField | FloatField:
53
+ source_fields = self.get_source_fields() # type: ignore[attr-defined]
37
54
  if any(isinstance(s, DecimalField) for s in source_fields):
38
55
  return DecimalField()
39
56
  if any(isinstance(s, IntegerField) for s in source_fields):
40
57
  return FloatField()
41
- return super()._resolve_output_field() if source_fields else FloatField()
58
+ return super()._resolve_output_field() if source_fields else FloatField() # type: ignore[misc]