plain.models 0.49.2__py3-none-any.whl → 0.51.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108) hide show
  1. plain/models/CHANGELOG.md +27 -0
  2. plain/models/README.md +26 -42
  3. plain/models/__init__.py +2 -0
  4. plain/models/aggregates.py +42 -19
  5. plain/models/backends/base/base.py +125 -105
  6. plain/models/backends/base/client.py +11 -3
  7. plain/models/backends/base/creation.py +24 -14
  8. plain/models/backends/base/features.py +10 -4
  9. plain/models/backends/base/introspection.py +37 -20
  10. plain/models/backends/base/operations.py +187 -91
  11. plain/models/backends/base/schema.py +338 -218
  12. plain/models/backends/base/validation.py +13 -4
  13. plain/models/backends/ddl_references.py +85 -43
  14. plain/models/backends/mysql/base.py +29 -26
  15. plain/models/backends/mysql/client.py +7 -2
  16. plain/models/backends/mysql/compiler.py +13 -4
  17. plain/models/backends/mysql/creation.py +5 -2
  18. plain/models/backends/mysql/features.py +24 -22
  19. plain/models/backends/mysql/introspection.py +22 -13
  20. plain/models/backends/mysql/operations.py +107 -40
  21. plain/models/backends/mysql/schema.py +52 -28
  22. plain/models/backends/mysql/validation.py +13 -6
  23. plain/models/backends/postgresql/base.py +41 -34
  24. plain/models/backends/postgresql/client.py +7 -2
  25. plain/models/backends/postgresql/creation.py +10 -5
  26. plain/models/backends/postgresql/introspection.py +15 -8
  27. plain/models/backends/postgresql/operations.py +110 -43
  28. plain/models/backends/postgresql/schema.py +88 -49
  29. plain/models/backends/sqlite3/_functions.py +151 -115
  30. plain/models/backends/sqlite3/base.py +37 -23
  31. plain/models/backends/sqlite3/client.py +7 -1
  32. plain/models/backends/sqlite3/creation.py +9 -5
  33. plain/models/backends/sqlite3/features.py +5 -3
  34. plain/models/backends/sqlite3/introspection.py +32 -16
  35. plain/models/backends/sqlite3/operations.py +126 -43
  36. plain/models/backends/sqlite3/schema.py +127 -92
  37. plain/models/backends/utils.py +52 -29
  38. plain/models/backups/cli.py +8 -6
  39. plain/models/backups/clients.py +16 -7
  40. plain/models/backups/core.py +24 -13
  41. plain/models/base.py +221 -229
  42. plain/models/cli.py +98 -67
  43. plain/models/config.py +1 -1
  44. plain/models/connections.py +23 -7
  45. plain/models/constraints.py +79 -56
  46. plain/models/database_url.py +1 -1
  47. plain/models/db.py +6 -2
  48. plain/models/deletion.py +80 -56
  49. plain/models/entrypoints.py +1 -1
  50. plain/models/enums.py +22 -11
  51. plain/models/exceptions.py +23 -8
  52. plain/models/expressions.py +441 -258
  53. plain/models/fields/__init__.py +272 -217
  54. plain/models/fields/json.py +123 -57
  55. plain/models/fields/mixins.py +12 -8
  56. plain/models/fields/related.py +324 -290
  57. plain/models/fields/related_descriptors.py +33 -24
  58. plain/models/fields/related_lookups.py +24 -12
  59. plain/models/fields/related_managers.py +102 -79
  60. plain/models/fields/reverse_related.py +66 -63
  61. plain/models/forms.py +101 -75
  62. plain/models/functions/comparison.py +71 -18
  63. plain/models/functions/datetime.py +79 -29
  64. plain/models/functions/math.py +43 -10
  65. plain/models/functions/mixins.py +24 -7
  66. plain/models/functions/text.py +104 -25
  67. plain/models/functions/window.py +12 -6
  68. plain/models/indexes.py +57 -32
  69. plain/models/lookups.py +228 -153
  70. plain/models/meta.py +505 -0
  71. plain/models/migrations/autodetector.py +86 -43
  72. plain/models/migrations/exceptions.py +7 -3
  73. plain/models/migrations/executor.py +33 -7
  74. plain/models/migrations/graph.py +79 -50
  75. plain/models/migrations/loader.py +45 -22
  76. plain/models/migrations/migration.py +23 -18
  77. plain/models/migrations/operations/base.py +38 -20
  78. plain/models/migrations/operations/fields.py +95 -48
  79. plain/models/migrations/operations/models.py +246 -142
  80. plain/models/migrations/operations/special.py +82 -25
  81. plain/models/migrations/optimizer.py +7 -2
  82. plain/models/migrations/questioner.py +58 -31
  83. plain/models/migrations/recorder.py +27 -16
  84. plain/models/migrations/serializer.py +50 -39
  85. plain/models/migrations/state.py +232 -156
  86. plain/models/migrations/utils.py +30 -14
  87. plain/models/migrations/writer.py +17 -14
  88. plain/models/options.py +189 -518
  89. plain/models/otel.py +16 -6
  90. plain/models/preflight.py +42 -17
  91. plain/models/query.py +400 -251
  92. plain/models/query_utils.py +109 -69
  93. plain/models/registry.py +40 -21
  94. plain/models/sql/compiler.py +190 -127
  95. plain/models/sql/datastructures.py +38 -25
  96. plain/models/sql/query.py +320 -225
  97. plain/models/sql/subqueries.py +36 -25
  98. plain/models/sql/where.py +54 -29
  99. plain/models/test/pytest.py +15 -11
  100. plain/models/test/utils.py +4 -2
  101. plain/models/transaction.py +20 -7
  102. plain/models/utils.py +17 -6
  103. {plain_models-0.49.2.dist-info → plain_models-0.51.0.dist-info}/METADATA +27 -43
  104. plain_models-0.51.0.dist-info/RECORD +123 -0
  105. plain_models-0.49.2.dist-info/RECORD +0 -122
  106. {plain_models-0.49.2.dist-info → plain_models-0.51.0.dist-info}/WHEEL +0 -0
  107. {plain_models-0.49.2.dist-info → plain_models-0.51.0.dist-info}/entry_points.txt +0 -0
  108. {plain_models-0.49.2.dist-info → plain_models-0.51.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,4 +1,7 @@
1
+ from __future__ import annotations
2
+
1
3
  from datetime import datetime
4
+ from typing import TYPE_CHECKING, Any
2
5
 
3
6
  from plain.models.expressions import Func
4
7
  from plain.models.fields import (
@@ -19,11 +22,15 @@ from plain.models.lookups import (
19
22
  )
20
23
  from plain.utils import timezone
21
24
 
25
+ if TYPE_CHECKING:
26
+ from plain.models.backends.base.base import BaseDatabaseWrapper
27
+ from plain.models.sql.compiler import SQLCompiler
28
+
22
29
 
23
30
  class TimezoneMixin:
24
31
  tzinfo = None
25
32
 
26
- def get_tzname(self):
33
+ def get_tzname(self) -> str | None:
27
34
  # Timezone conversions must happen to the input datetime *before*
28
35
  # applying a function. 2015-12-31 23:00:00 -02:00 is stored in the
29
36
  # database as 2016-01-01 01:00:00 +00:00. Any results should be
@@ -38,7 +45,13 @@ class Extract(TimezoneMixin, Transform):
38
45
  lookup_name = None
39
46
  output_field = IntegerField()
40
47
 
41
- def __init__(self, expression, lookup_name=None, tzinfo=None, **extra):
48
+ def __init__(
49
+ self,
50
+ expression: Any,
51
+ lookup_name: str | None = None,
52
+ tzinfo: Any = None,
53
+ **extra: Any,
54
+ ) -> None:
42
55
  if self.lookup_name is None:
43
56
  self.lookup_name = lookup_name
44
57
  if self.lookup_name is None:
@@ -46,7 +59,9 @@ class Extract(TimezoneMixin, Transform):
46
59
  self.tzinfo = tzinfo
47
60
  super().__init__(expression, **extra)
48
61
 
49
- def as_sql(self, compiler, connection):
62
+ def as_sql(
63
+ self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
64
+ ) -> tuple[str, tuple[Any, ...]]:
50
65
  sql, params = compiler.compile(self.lhs)
51
66
  lhs_output_field = self.lhs.output_field
52
67
  if isinstance(lhs_output_field, DateTimeField):
@@ -76,11 +91,16 @@ class Extract(TimezoneMixin, Transform):
76
91
  # resolve_expression has already validated the output_field so this
77
92
  # assert should never be hit.
78
93
  raise ValueError("Tried to Extract from an invalid type.")
79
- return sql, params
94
+ return sql, tuple(params) if isinstance(params, list) else params
80
95
 
81
96
  def resolve_expression(
82
- self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
83
- ):
97
+ self,
98
+ query: Any = None,
99
+ allow_joins: bool = True,
100
+ reuse: Any = None,
101
+ summarize: bool = False,
102
+ for_save: bool = False,
103
+ ) -> Extract:
84
104
  copy = super().resolve_expression(
85
105
  query, allow_joins, reuse, summarize, for_save
86
106
  )
@@ -209,7 +229,12 @@ class Now(Func):
209
229
  template = "CURRENT_TIMESTAMP"
210
230
  output_field = DateTimeField()
211
231
 
212
- def as_postgresql(self, compiler, connection, **extra_context):
232
+ def as_postgresql(
233
+ self,
234
+ compiler: SQLCompiler,
235
+ connection: BaseDatabaseWrapper,
236
+ **extra_context: Any,
237
+ ) -> tuple[str, tuple[Any, ...]]:
213
238
  # PostgreSQL's CURRENT_TIMESTAMP means "the time at the start of the
214
239
  # transaction". Use STATEMENT_TIMESTAMP to be cross-compatible with
215
240
  # other databases.
@@ -217,12 +242,22 @@ class Now(Func):
217
242
  compiler, connection, template="STATEMENT_TIMESTAMP()", **extra_context
218
243
  )
219
244
 
220
- def as_mysql(self, compiler, connection, **extra_context):
245
+ def as_mysql(
246
+ self,
247
+ compiler: SQLCompiler,
248
+ connection: BaseDatabaseWrapper,
249
+ **extra_context: Any,
250
+ ) -> tuple[str, tuple[Any, ...]]:
221
251
  return self.as_sql(
222
252
  compiler, connection, template="CURRENT_TIMESTAMP(6)", **extra_context
223
253
  )
224
254
 
225
- def as_sqlite(self, compiler, connection, **extra_context):
255
+ def as_sqlite(
256
+ self,
257
+ compiler: SQLCompiler,
258
+ connection: BaseDatabaseWrapper,
259
+ **extra_context: Any,
260
+ ) -> tuple[str, tuple[Any, ...]]:
226
261
  return self.as_sql(
227
262
  compiler,
228
263
  connection,
@@ -237,15 +272,17 @@ class TruncBase(TimezoneMixin, Transform):
237
272
 
238
273
  def __init__(
239
274
  self,
240
- expression,
241
- output_field=None,
242
- tzinfo=None,
243
- **extra,
244
- ):
275
+ expression: Any,
276
+ output_field: Field | None = None,
277
+ tzinfo: Any = None,
278
+ **extra: Any,
279
+ ) -> None:
245
280
  self.tzinfo = tzinfo
246
281
  super().__init__(expression, output_field=output_field, **extra)
247
282
 
248
- def as_sql(self, compiler, connection):
283
+ def as_sql(
284
+ self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
285
+ ) -> tuple[str, tuple[Any, ...]]:
249
286
  sql, params = compiler.compile(self.lhs)
250
287
  tzname = None
251
288
  if isinstance(self.lhs.output_field, DateTimeField):
@@ -268,11 +305,16 @@ class TruncBase(TimezoneMixin, Transform):
268
305
  raise ValueError(
269
306
  "Trunc only valid on DateField, TimeField, or DateTimeField."
270
307
  )
271
- return sql, params
308
+ return sql, tuple(params) if isinstance(params, list) else params
272
309
 
273
310
  def resolve_expression(
274
- self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
275
- ):
311
+ self,
312
+ query: Any = None,
313
+ allow_joins: bool = True,
314
+ reuse: Any = None,
315
+ summarize: bool = False,
316
+ for_save: bool = False,
317
+ ) -> TruncBase:
276
318
  copy = super().resolve_expression(
277
319
  query, allow_joins, reuse, summarize, for_save
278
320
  )
@@ -325,7 +367,9 @@ class TruncBase(TimezoneMixin, Transform):
325
367
  )
326
368
  return copy
327
369
 
328
- def convert_value(self, value, expression, connection):
370
+ def convert_value(
371
+ self, value: Any, expression: Any, connection: BaseDatabaseWrapper
372
+ ) -> Any:
329
373
  if isinstance(self.output_field, DateTimeField):
330
374
  if value is not None:
331
375
  value = value.replace(tzinfo=None)
@@ -348,12 +392,12 @@ class TruncBase(TimezoneMixin, Transform):
348
392
  class Trunc(TruncBase):
349
393
  def __init__(
350
394
  self,
351
- expression,
352
- kind,
353
- output_field=None,
354
- tzinfo=None,
355
- **extra,
356
- ):
395
+ expression: Any,
396
+ kind: str,
397
+ output_field: Field | None = None,
398
+ tzinfo: Any = None,
399
+ **extra: Any,
400
+ ) -> None:
357
401
  self.kind = kind
358
402
  super().__init__(expression, output_field=output_field, tzinfo=tzinfo, **extra)
359
403
 
@@ -385,11 +429,14 @@ class TruncDate(TruncBase):
385
429
  lookup_name = "date"
386
430
  output_field = DateField()
387
431
 
388
- def as_sql(self, compiler, connection):
432
+ def as_sql(
433
+ self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
434
+ ) -> tuple[str, tuple[Any, ...]]:
389
435
  # Cast to date rather than truncate to date.
390
436
  sql, params = compiler.compile(self.lhs)
391
437
  tzname = self.get_tzname()
392
- return connection.ops.datetime_cast_date_sql(sql, tuple(params), tzname)
438
+ sql, params = connection.ops.datetime_cast_date_sql(sql, tuple(params), tzname)
439
+ return sql, tuple(params) if isinstance(params, list) else params
393
440
 
394
441
 
395
442
  class TruncTime(TruncBase):
@@ -397,11 +444,14 @@ class TruncTime(TruncBase):
397
444
  lookup_name = "time"
398
445
  output_field = TimeField()
399
446
 
400
- def as_sql(self, compiler, connection):
447
+ def as_sql(
448
+ self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
449
+ ) -> tuple[str, tuple[Any, ...]]:
401
450
  # Cast to time rather than truncate to time.
402
451
  sql, params = compiler.compile(self.lhs)
403
452
  tzname = self.get_tzname()
404
- return connection.ops.datetime_cast_time_sql(sql, tuple(params), tzname)
453
+ sql, params = connection.ops.datetime_cast_time_sql(sql, tuple(params), tzname)
454
+ return sql, tuple(params) if isinstance(params, list) else params
405
455
 
406
456
 
407
457
  class TruncHour(TruncBase):
@@ -1,5 +1,9 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, Any
4
+
1
5
  from plain.models.expressions import Func, Value
2
- from plain.models.fields import FloatField, IntegerField
6
+ from plain.models.fields import Field, FloatField, IntegerField
3
7
  from plain.models.functions import Cast
4
8
  from plain.models.functions.mixins import (
5
9
  FixDecimalInputMixin,
@@ -7,6 +11,10 @@ from plain.models.functions.mixins import (
7
11
  )
8
12
  from plain.models.lookups import Transform
9
13
 
14
+ if TYPE_CHECKING:
15
+ from plain.models.backends.base.base import BaseDatabaseWrapper
16
+ from plain.models.sql.compiler import SQLCompiler
17
+
10
18
 
11
19
  class Abs(Transform):
12
20
  function = "ABS"
@@ -32,10 +40,15 @@ class ATan2(NumericOutputFieldMixin, Func):
32
40
  function = "ATAN2"
33
41
  arity = 2
34
42
 
35
- def as_sqlite(self, compiler, connection, **extra_context):
43
+ def as_sqlite(
44
+ self,
45
+ compiler: SQLCompiler,
46
+ connection: BaseDatabaseWrapper,
47
+ **extra_context: Any,
48
+ ) -> tuple[str, tuple[Any, ...]]:
36
49
  if not getattr(
37
50
  connection.ops, "spatialite", False
38
- ) or connection.ops.spatial_version >= (5, 0, 0):
51
+ ) or connection.ops.spatial_version >= (5, 0, 0): # type: ignore[attr-defined]
39
52
  return self.as_sql(compiler, connection)
40
53
  # This function is usually ATan2(y, x), returning the inverse tangent
41
54
  # of y / x, but it's ATan2(x, y) on SpatiaLite < 5.0.0.
@@ -93,7 +106,12 @@ class Log(FixDecimalInputMixin, NumericOutputFieldMixin, Func):
93
106
  function = "LOG"
94
107
  arity = 2
95
108
 
96
- def as_sqlite(self, compiler, connection, **extra_context):
109
+ def as_sqlite(
110
+ self,
111
+ compiler: SQLCompiler,
112
+ connection: BaseDatabaseWrapper,
113
+ **extra_context: Any,
114
+ ) -> tuple[str, tuple[Any, ...]]:
97
115
  if not getattr(connection.ops, "spatialite", False):
98
116
  return self.as_sql(compiler, connection)
99
117
  # This function is usually Log(b, x) returning the logarithm of x to
@@ -127,13 +145,23 @@ class Random(NumericOutputFieldMixin, Func):
127
145
  function = "RANDOM"
128
146
  arity = 0
129
147
 
130
- def as_mysql(self, compiler, connection, **extra_context):
148
+ def as_mysql(
149
+ self,
150
+ compiler: SQLCompiler,
151
+ connection: BaseDatabaseWrapper,
152
+ **extra_context: Any,
153
+ ) -> tuple[str, tuple[Any, ...]]:
131
154
  return super().as_sql(compiler, connection, function="RAND", **extra_context)
132
155
 
133
- def as_sqlite(self, compiler, connection, **extra_context):
156
+ def as_sqlite(
157
+ self,
158
+ compiler: SQLCompiler,
159
+ connection: BaseDatabaseWrapper,
160
+ **extra_context: Any,
161
+ ) -> tuple[str, tuple[Any, ...]]:
134
162
  return super().as_sql(compiler, connection, function="RAND", **extra_context)
135
163
 
136
- def get_group_by_cols(self):
164
+ def get_group_by_cols(self) -> list[Any]:
137
165
  return []
138
166
 
139
167
 
@@ -142,16 +170,21 @@ class Round(FixDecimalInputMixin, Transform):
142
170
  lookup_name = "round"
143
171
  arity = None # Override Transform's arity=1 to enable passing precision.
144
172
 
145
- def __init__(self, expression, precision=0, **extra):
173
+ def __init__(self, expression: Any, precision: int = 0, **extra: Any) -> None:
146
174
  super().__init__(expression, precision, **extra)
147
175
 
148
- def as_sqlite(self, compiler, connection, **extra_context):
176
+ def as_sqlite(
177
+ self,
178
+ compiler: SQLCompiler,
179
+ connection: BaseDatabaseWrapper,
180
+ **extra_context: Any,
181
+ ) -> tuple[str, tuple[Any, ...]]:
149
182
  precision = self.get_source_expressions()[1]
150
183
  if isinstance(precision, Value) and precision.value < 0:
151
184
  raise ValueError("SQLite does not support negative precision.")
152
185
  return super().as_sqlite(compiler, connection, **extra_context)
153
186
 
154
- def _resolve_output_field(self):
187
+ def _resolve_output_field(self) -> Field:
155
188
  source = self.get_source_expressions()[0]
156
189
  return source.output_field
157
190
 
@@ -1,11 +1,23 @@
1
+ from __future__ import annotations
2
+
1
3
  import sys
4
+ from typing import TYPE_CHECKING, Any
2
5
 
3
6
  from plain.models.fields import DecimalField, FloatField, IntegerField
4
7
  from plain.models.functions import Cast
5
8
 
9
+ if TYPE_CHECKING:
10
+ from plain.models.backends.base.base import BaseDatabaseWrapper
11
+ from plain.models.sql.compiler import SQLCompiler
12
+
6
13
 
7
14
  class FixDecimalInputMixin:
8
- def as_postgresql(self, compiler, connection, **extra_context):
15
+ def as_postgresql(
16
+ self,
17
+ compiler: SQLCompiler,
18
+ connection: BaseDatabaseWrapper,
19
+ **extra_context: Any,
20
+ ) -> tuple[str, tuple[Any, ...]]:
9
21
  # Cast FloatField to DecimalField as PostgreSQL doesn't support the
10
22
  # following function signatures:
11
23
  # - LOG(double, double)
@@ -24,18 +36,23 @@ class FixDecimalInputMixin:
24
36
 
25
37
 
26
38
  class FixDurationInputMixin:
27
- def as_mysql(self, compiler, connection, **extra_context):
28
- sql, params = super().as_sql(compiler, connection, **extra_context)
29
- if self.output_field.get_internal_type() == "DurationField":
39
+ def as_mysql(
40
+ self,
41
+ compiler: SQLCompiler,
42
+ connection: BaseDatabaseWrapper,
43
+ **extra_context: Any,
44
+ ) -> tuple[str, tuple[Any, ...]]:
45
+ sql, params = super().as_sql(compiler, connection, **extra_context) # type: ignore[misc]
46
+ if self.output_field.get_internal_type() == "DurationField": # type: ignore[attr-defined]
30
47
  sql = f"CAST({sql} AS SIGNED)"
31
48
  return sql, params
32
49
 
33
50
 
34
51
  class NumericOutputFieldMixin:
35
- def _resolve_output_field(self):
36
- source_fields = self.get_source_fields()
52
+ def _resolve_output_field(self) -> DecimalField | FloatField:
53
+ source_fields = self.get_source_fields() # type: ignore[attr-defined]
37
54
  if any(isinstance(s, DecimalField) for s in source_fields):
38
55
  return DecimalField()
39
56
  if any(isinstance(s, IntegerField) for s in source_fields):
40
57
  return FloatField()
41
- return super()._resolve_output_field() if source_fields else FloatField()
58
+ return super()._resolve_output_field() if source_fields else FloatField() # type: ignore[misc]
@@ -1,12 +1,25 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, Any
4
+
1
5
  from plain.models.expressions import Func, Value
2
6
  from plain.models.fields import CharField, IntegerField, TextField
3
7
  from plain.models.functions import Cast, Coalesce
4
8
  from plain.models.lookups import Transform
5
9
 
10
+ if TYPE_CHECKING:
11
+ from plain.models.backends.base.base import BaseDatabaseWrapper
12
+ from plain.models.sql.compiler import SQLCompiler
13
+
6
14
 
7
15
  class MySQLSHA2Mixin:
8
- def as_mysql(self, compiler, connection, **extra_context):
9
- return super().as_sql(
16
+ def as_mysql(
17
+ self,
18
+ compiler: SQLCompiler,
19
+ connection: BaseDatabaseWrapper,
20
+ **extra_context: Any,
21
+ ) -> tuple[str, tuple[Any, ...]]:
22
+ return super().as_sql( # type: ignore[misc]
10
23
  compiler,
11
24
  connection,
12
25
  template=f"SHA2(%(expressions)s, {self.function[3:]})",
@@ -15,8 +28,13 @@ class MySQLSHA2Mixin:
15
28
 
16
29
 
17
30
  class PostgreSQLSHAMixin:
18
- def as_postgresql(self, compiler, connection, **extra_context):
19
- return super().as_sql(
31
+ def as_postgresql(
32
+ self,
33
+ compiler: SQLCompiler,
34
+ connection: BaseDatabaseWrapper,
35
+ **extra_context: Any,
36
+ ) -> tuple[str, tuple[Any, ...]]:
37
+ return super().as_sql( # type: ignore[misc]
20
38
  compiler,
21
39
  connection,
22
40
  template="ENCODE(DIGEST(%(expressions)s, '%(function)s'), 'hex')",
@@ -29,7 +47,12 @@ class Chr(Transform):
29
47
  function = "CHR"
30
48
  lookup_name = "chr"
31
49
 
32
- def as_mysql(self, compiler, connection, **extra_context):
50
+ def as_mysql(
51
+ self,
52
+ compiler: SQLCompiler,
53
+ connection: BaseDatabaseWrapper,
54
+ **extra_context: Any,
55
+ ) -> tuple[str, tuple[Any, ...]]:
33
56
  return super().as_sql(
34
57
  compiler,
35
58
  connection,
@@ -38,7 +61,12 @@ class Chr(Transform):
38
61
  **extra_context,
39
62
  )
40
63
 
41
- def as_sqlite(self, compiler, connection, **extra_context):
64
+ def as_sqlite(
65
+ self,
66
+ compiler: SQLCompiler,
67
+ connection: BaseDatabaseWrapper,
68
+ **extra_context: Any,
69
+ ) -> tuple[str, tuple[Any, ...]]:
42
70
  return super().as_sql(compiler, connection, function="CHAR", **extra_context)
43
71
 
44
72
 
@@ -50,7 +78,12 @@ class ConcatPair(Func):
50
78
 
51
79
  function = "CONCAT"
52
80
 
53
- def as_sqlite(self, compiler, connection, **extra_context):
81
+ def as_sqlite(
82
+ self,
83
+ compiler: SQLCompiler,
84
+ connection: BaseDatabaseWrapper,
85
+ **extra_context: Any,
86
+ ) -> tuple[str, tuple[Any, ...]]:
54
87
  coalesced = self.coalesce()
55
88
  return super(ConcatPair, coalesced).as_sql(
56
89
  compiler,
@@ -60,7 +93,12 @@ class ConcatPair(Func):
60
93
  **extra_context,
61
94
  )
62
95
 
63
- def as_postgresql(self, compiler, connection, **extra_context):
96
+ def as_postgresql(
97
+ self,
98
+ compiler: SQLCompiler,
99
+ connection: BaseDatabaseWrapper,
100
+ **extra_context: Any,
101
+ ) -> tuple[str, tuple[Any, ...]]:
64
102
  copy = self.copy()
65
103
  copy.set_source_expressions(
66
104
  [
@@ -74,7 +112,12 @@ class ConcatPair(Func):
74
112
  **extra_context,
75
113
  )
76
114
 
77
- def as_mysql(self, compiler, connection, **extra_context):
115
+ def as_mysql(
116
+ self,
117
+ compiler: SQLCompiler,
118
+ connection: BaseDatabaseWrapper,
119
+ **extra_context: Any,
120
+ ) -> tuple[str, tuple[Any, ...]]:
78
121
  # Use CONCAT_WS with an empty separator so that NULLs are ignored.
79
122
  return super().as_sql(
80
123
  compiler,
@@ -84,7 +127,7 @@ class ConcatPair(Func):
84
127
  **extra_context,
85
128
  )
86
129
 
87
- def coalesce(self):
130
+ def coalesce(self) -> ConcatPair:
88
131
  # null on either side results in null for expression, wrap with coalesce
89
132
  c = self.copy()
90
133
  c.set_source_expressions(
@@ -106,13 +149,13 @@ class Concat(Func):
106
149
  function = None
107
150
  template = "%(expressions)s"
108
151
 
109
- def __init__(self, *expressions, **extra):
152
+ def __init__(self, *expressions: Any, **extra: Any) -> None:
110
153
  if len(expressions) < 2:
111
154
  raise ValueError("Concat must take at least two expressions")
112
155
  paired = self._paired(expressions)
113
156
  super().__init__(paired, **extra)
114
157
 
115
- def _paired(self, expressions):
158
+ def _paired(self, expressions: tuple[Any, ...]) -> ConcatPair:
116
159
  # wrap pairs of expressions in successive concat functions
117
160
  # exp = [a, b, c, d]
118
161
  # -> ConcatPair(a, ConcatPair(b, ConcatPair(c, d))))
@@ -126,7 +169,7 @@ class Left(Func):
126
169
  arity = 2
127
170
  output_field = CharField()
128
171
 
129
- def __init__(self, expression, length, **extra):
172
+ def __init__(self, expression: Any, length: Any, **extra: Any) -> None:
130
173
  """
131
174
  expression: the name of a field, or an expression returning a string
132
175
  length: the number of characters to return from the start of the string
@@ -136,10 +179,15 @@ class Left(Func):
136
179
  raise ValueError("'length' must be greater than 0.")
137
180
  super().__init__(expression, length, **extra)
138
181
 
139
- def get_substr(self):
182
+ def get_substr(self) -> Substr:
140
183
  return Substr(self.source_expressions[0], Value(1), self.source_expressions[1])
141
184
 
142
- def as_sqlite(self, compiler, connection, **extra_context):
185
+ def as_sqlite(
186
+ self,
187
+ compiler: SQLCompiler,
188
+ connection: BaseDatabaseWrapper,
189
+ **extra_context: Any,
190
+ ) -> tuple[str, tuple[Any, ...]]:
143
191
  return self.get_substr().as_sqlite(compiler, connection, **extra_context)
144
192
 
145
193
 
@@ -150,7 +198,12 @@ class Length(Transform):
150
198
  lookup_name = "length"
151
199
  output_field = IntegerField()
152
200
 
153
- def as_mysql(self, compiler, connection, **extra_context):
201
+ def as_mysql(
202
+ self,
203
+ compiler: SQLCompiler,
204
+ connection: BaseDatabaseWrapper,
205
+ **extra_context: Any,
206
+ ) -> tuple[str, tuple[Any, ...]]:
154
207
  return super().as_sql(
155
208
  compiler, connection, function="CHAR_LENGTH", **extra_context
156
209
  )
@@ -165,7 +218,9 @@ class LPad(Func):
165
218
  function = "LPAD"
166
219
  output_field = CharField()
167
220
 
168
- def __init__(self, expression, length, fill_text=Value(" "), **extra):
221
+ def __init__(
222
+ self, expression: Any, length: Any, fill_text: Any = Value(" "), **extra: Any
223
+ ) -> None:
169
224
  if (
170
225
  not hasattr(length, "resolve_expression")
171
226
  and length is not None
@@ -190,10 +245,20 @@ class Ord(Transform):
190
245
  lookup_name = "ord"
191
246
  output_field = IntegerField()
192
247
 
193
- def as_mysql(self, compiler, connection, **extra_context):
248
+ def as_mysql(
249
+ self,
250
+ compiler: SQLCompiler,
251
+ connection: BaseDatabaseWrapper,
252
+ **extra_context: Any,
253
+ ) -> tuple[str, tuple[Any, ...]]:
194
254
  return super().as_sql(compiler, connection, function="ORD", **extra_context)
195
255
 
196
- def as_sqlite(self, compiler, connection, **extra_context):
256
+ def as_sqlite(
257
+ self,
258
+ compiler: SQLCompiler,
259
+ connection: BaseDatabaseWrapper,
260
+ **extra_context: Any,
261
+ ) -> tuple[str, tuple[Any, ...]]:
197
262
  return super().as_sql(compiler, connection, function="UNICODE", **extra_context)
198
263
 
199
264
 
@@ -201,7 +266,7 @@ class Repeat(Func):
201
266
  function = "REPEAT"
202
267
  output_field = CharField()
203
268
 
204
- def __init__(self, expression, number, **extra):
269
+ def __init__(self, expression: Any, number: Any, **extra: Any) -> None:
205
270
  if (
206
271
  not hasattr(number, "resolve_expression")
207
272
  and number is not None
@@ -214,7 +279,9 @@ class Repeat(Func):
214
279
  class Replace(Func):
215
280
  function = "REPLACE"
216
281
 
217
- def __init__(self, expression, text, replacement=Value(""), **extra):
282
+ def __init__(
283
+ self, expression: Any, text: Any, replacement: Any = Value(""), **extra: Any
284
+ ) -> None:
218
285
  super().__init__(expression, text, replacement, **extra)
219
286
 
220
287
 
@@ -226,7 +293,7 @@ class Reverse(Transform):
226
293
  class Right(Left):
227
294
  function = "RIGHT"
228
295
 
229
- def get_substr(self):
296
+ def get_substr(self) -> Substr:
230
297
  return Substr(
231
298
  self.source_expressions[0], self.source_expressions[1] * Value(-1)
232
299
  )
@@ -277,7 +344,12 @@ class StrIndex(Func):
277
344
  arity = 2
278
345
  output_field = IntegerField()
279
346
 
280
- def as_postgresql(self, compiler, connection, **extra_context):
347
+ def as_postgresql(
348
+ self,
349
+ compiler: SQLCompiler,
350
+ connection: BaseDatabaseWrapper,
351
+ **extra_context: Any,
352
+ ) -> tuple[str, tuple[Any, ...]]:
281
353
  return super().as_sql(compiler, connection, function="STRPOS", **extra_context)
282
354
 
283
355
 
@@ -285,7 +357,9 @@ class Substr(Func):
285
357
  function = "SUBSTRING"
286
358
  output_field = CharField()
287
359
 
288
- def __init__(self, expression, pos, length=None, **extra):
360
+ def __init__(
361
+ self, expression: Any, pos: Any, length: Any = None, **extra: Any
362
+ ) -> None:
289
363
  """
290
364
  expression: the name of a field, or an expression returning a string
291
365
  pos: an integer > 0, or an expression returning an integer
@@ -299,7 +373,12 @@ class Substr(Func):
299
373
  expressions.append(length)
300
374
  super().__init__(*expressions, **extra)
301
375
 
302
- def as_sqlite(self, compiler, connection, **extra_context):
376
+ def as_sqlite(
377
+ self,
378
+ compiler: SQLCompiler,
379
+ connection: BaseDatabaseWrapper,
380
+ **extra_context: Any,
381
+ ) -> tuple[str, tuple[Any, ...]]:
303
382
  return super().as_sql(compiler, connection, function="SUBSTR", **extra_context)
304
383
 
305
384