plain.models 0.49.2__py3-none-any.whl → 0.51.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plain/models/CHANGELOG.md +27 -0
- plain/models/README.md +26 -42
- plain/models/__init__.py +2 -0
- plain/models/aggregates.py +42 -19
- plain/models/backends/base/base.py +125 -105
- plain/models/backends/base/client.py +11 -3
- plain/models/backends/base/creation.py +24 -14
- plain/models/backends/base/features.py +10 -4
- plain/models/backends/base/introspection.py +37 -20
- plain/models/backends/base/operations.py +187 -91
- plain/models/backends/base/schema.py +338 -218
- plain/models/backends/base/validation.py +13 -4
- plain/models/backends/ddl_references.py +85 -43
- plain/models/backends/mysql/base.py +29 -26
- plain/models/backends/mysql/client.py +7 -2
- plain/models/backends/mysql/compiler.py +13 -4
- plain/models/backends/mysql/creation.py +5 -2
- plain/models/backends/mysql/features.py +24 -22
- plain/models/backends/mysql/introspection.py +22 -13
- plain/models/backends/mysql/operations.py +107 -40
- plain/models/backends/mysql/schema.py +52 -28
- plain/models/backends/mysql/validation.py +13 -6
- plain/models/backends/postgresql/base.py +41 -34
- plain/models/backends/postgresql/client.py +7 -2
- plain/models/backends/postgresql/creation.py +10 -5
- plain/models/backends/postgresql/introspection.py +15 -8
- plain/models/backends/postgresql/operations.py +110 -43
- plain/models/backends/postgresql/schema.py +88 -49
- plain/models/backends/sqlite3/_functions.py +151 -115
- plain/models/backends/sqlite3/base.py +37 -23
- plain/models/backends/sqlite3/client.py +7 -1
- plain/models/backends/sqlite3/creation.py +9 -5
- plain/models/backends/sqlite3/features.py +5 -3
- plain/models/backends/sqlite3/introspection.py +32 -16
- plain/models/backends/sqlite3/operations.py +126 -43
- plain/models/backends/sqlite3/schema.py +127 -92
- plain/models/backends/utils.py +52 -29
- plain/models/backups/cli.py +8 -6
- plain/models/backups/clients.py +16 -7
- plain/models/backups/core.py +24 -13
- plain/models/base.py +221 -229
- plain/models/cli.py +98 -67
- plain/models/config.py +1 -1
- plain/models/connections.py +23 -7
- plain/models/constraints.py +79 -56
- plain/models/database_url.py +1 -1
- plain/models/db.py +6 -2
- plain/models/deletion.py +80 -56
- plain/models/entrypoints.py +1 -1
- plain/models/enums.py +22 -11
- plain/models/exceptions.py +23 -8
- plain/models/expressions.py +441 -258
- plain/models/fields/__init__.py +272 -217
- plain/models/fields/json.py +123 -57
- plain/models/fields/mixins.py +12 -8
- plain/models/fields/related.py +324 -290
- plain/models/fields/related_descriptors.py +33 -24
- plain/models/fields/related_lookups.py +24 -12
- plain/models/fields/related_managers.py +102 -79
- plain/models/fields/reverse_related.py +66 -63
- plain/models/forms.py +101 -75
- plain/models/functions/comparison.py +71 -18
- plain/models/functions/datetime.py +79 -29
- plain/models/functions/math.py +43 -10
- plain/models/functions/mixins.py +24 -7
- plain/models/functions/text.py +104 -25
- plain/models/functions/window.py +12 -6
- plain/models/indexes.py +57 -32
- plain/models/lookups.py +228 -153
- plain/models/meta.py +505 -0
- plain/models/migrations/autodetector.py +86 -43
- plain/models/migrations/exceptions.py +7 -3
- plain/models/migrations/executor.py +33 -7
- plain/models/migrations/graph.py +79 -50
- plain/models/migrations/loader.py +45 -22
- plain/models/migrations/migration.py +23 -18
- plain/models/migrations/operations/base.py +38 -20
- plain/models/migrations/operations/fields.py +95 -48
- plain/models/migrations/operations/models.py +246 -142
- plain/models/migrations/operations/special.py +82 -25
- plain/models/migrations/optimizer.py +7 -2
- plain/models/migrations/questioner.py +58 -31
- plain/models/migrations/recorder.py +27 -16
- plain/models/migrations/serializer.py +50 -39
- plain/models/migrations/state.py +232 -156
- plain/models/migrations/utils.py +30 -14
- plain/models/migrations/writer.py +17 -14
- plain/models/options.py +189 -518
- plain/models/otel.py +16 -6
- plain/models/preflight.py +42 -17
- plain/models/query.py +400 -251
- plain/models/query_utils.py +109 -69
- plain/models/registry.py +40 -21
- plain/models/sql/compiler.py +190 -127
- plain/models/sql/datastructures.py +38 -25
- plain/models/sql/query.py +320 -225
- plain/models/sql/subqueries.py +36 -25
- plain/models/sql/where.py +54 -29
- plain/models/test/pytest.py +15 -11
- plain/models/test/utils.py +4 -2
- plain/models/transaction.py +20 -7
- plain/models/utils.py +17 -6
- {plain_models-0.49.2.dist-info → plain_models-0.51.0.dist-info}/METADATA +27 -43
- plain_models-0.51.0.dist-info/RECORD +123 -0
- plain_models-0.49.2.dist-info/RECORD +0 -122
- {plain_models-0.49.2.dist-info → plain_models-0.51.0.dist-info}/WHEEL +0 -0
- {plain_models-0.49.2.dist-info → plain_models-0.51.0.dist-info}/entry_points.txt +0 -0
- {plain_models-0.49.2.dist-info → plain_models-0.51.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,4 +1,7 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
1
3
|
from datetime import datetime
|
4
|
+
from typing import TYPE_CHECKING, Any
|
2
5
|
|
3
6
|
from plain.models.expressions import Func
|
4
7
|
from plain.models.fields import (
|
@@ -19,11 +22,15 @@ from plain.models.lookups import (
|
|
19
22
|
)
|
20
23
|
from plain.utils import timezone
|
21
24
|
|
25
|
+
if TYPE_CHECKING:
|
26
|
+
from plain.models.backends.base.base import BaseDatabaseWrapper
|
27
|
+
from plain.models.sql.compiler import SQLCompiler
|
28
|
+
|
22
29
|
|
23
30
|
class TimezoneMixin:
|
24
31
|
tzinfo = None
|
25
32
|
|
26
|
-
def get_tzname(self):
|
33
|
+
def get_tzname(self) -> str | None:
|
27
34
|
# Timezone conversions must happen to the input datetime *before*
|
28
35
|
# applying a function. 2015-12-31 23:00:00 -02:00 is stored in the
|
29
36
|
# database as 2016-01-01 01:00:00 +00:00. Any results should be
|
@@ -38,7 +45,13 @@ class Extract(TimezoneMixin, Transform):
|
|
38
45
|
lookup_name = None
|
39
46
|
output_field = IntegerField()
|
40
47
|
|
41
|
-
def __init__(
|
48
|
+
def __init__(
|
49
|
+
self,
|
50
|
+
expression: Any,
|
51
|
+
lookup_name: str | None = None,
|
52
|
+
tzinfo: Any = None,
|
53
|
+
**extra: Any,
|
54
|
+
) -> None:
|
42
55
|
if self.lookup_name is None:
|
43
56
|
self.lookup_name = lookup_name
|
44
57
|
if self.lookup_name is None:
|
@@ -46,7 +59,9 @@ class Extract(TimezoneMixin, Transform):
|
|
46
59
|
self.tzinfo = tzinfo
|
47
60
|
super().__init__(expression, **extra)
|
48
61
|
|
49
|
-
def as_sql(
|
62
|
+
def as_sql(
|
63
|
+
self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
|
64
|
+
) -> tuple[str, tuple[Any, ...]]:
|
50
65
|
sql, params = compiler.compile(self.lhs)
|
51
66
|
lhs_output_field = self.lhs.output_field
|
52
67
|
if isinstance(lhs_output_field, DateTimeField):
|
@@ -76,11 +91,16 @@ class Extract(TimezoneMixin, Transform):
|
|
76
91
|
# resolve_expression has already validated the output_field so this
|
77
92
|
# assert should never be hit.
|
78
93
|
raise ValueError("Tried to Extract from an invalid type.")
|
79
|
-
return sql, params
|
94
|
+
return sql, tuple(params) if isinstance(params, list) else params
|
80
95
|
|
81
96
|
def resolve_expression(
|
82
|
-
self,
|
83
|
-
|
97
|
+
self,
|
98
|
+
query: Any = None,
|
99
|
+
allow_joins: bool = True,
|
100
|
+
reuse: Any = None,
|
101
|
+
summarize: bool = False,
|
102
|
+
for_save: bool = False,
|
103
|
+
) -> Extract:
|
84
104
|
copy = super().resolve_expression(
|
85
105
|
query, allow_joins, reuse, summarize, for_save
|
86
106
|
)
|
@@ -209,7 +229,12 @@ class Now(Func):
|
|
209
229
|
template = "CURRENT_TIMESTAMP"
|
210
230
|
output_field = DateTimeField()
|
211
231
|
|
212
|
-
def as_postgresql(
|
232
|
+
def as_postgresql(
|
233
|
+
self,
|
234
|
+
compiler: SQLCompiler,
|
235
|
+
connection: BaseDatabaseWrapper,
|
236
|
+
**extra_context: Any,
|
237
|
+
) -> tuple[str, tuple[Any, ...]]:
|
213
238
|
# PostgreSQL's CURRENT_TIMESTAMP means "the time at the start of the
|
214
239
|
# transaction". Use STATEMENT_TIMESTAMP to be cross-compatible with
|
215
240
|
# other databases.
|
@@ -217,12 +242,22 @@ class Now(Func):
|
|
217
242
|
compiler, connection, template="STATEMENT_TIMESTAMP()", **extra_context
|
218
243
|
)
|
219
244
|
|
220
|
-
def as_mysql(
|
245
|
+
def as_mysql(
|
246
|
+
self,
|
247
|
+
compiler: SQLCompiler,
|
248
|
+
connection: BaseDatabaseWrapper,
|
249
|
+
**extra_context: Any,
|
250
|
+
) -> tuple[str, tuple[Any, ...]]:
|
221
251
|
return self.as_sql(
|
222
252
|
compiler, connection, template="CURRENT_TIMESTAMP(6)", **extra_context
|
223
253
|
)
|
224
254
|
|
225
|
-
def as_sqlite(
|
255
|
+
def as_sqlite(
|
256
|
+
self,
|
257
|
+
compiler: SQLCompiler,
|
258
|
+
connection: BaseDatabaseWrapper,
|
259
|
+
**extra_context: Any,
|
260
|
+
) -> tuple[str, tuple[Any, ...]]:
|
226
261
|
return self.as_sql(
|
227
262
|
compiler,
|
228
263
|
connection,
|
@@ -237,15 +272,17 @@ class TruncBase(TimezoneMixin, Transform):
|
|
237
272
|
|
238
273
|
def __init__(
|
239
274
|
self,
|
240
|
-
expression,
|
241
|
-
output_field=None,
|
242
|
-
tzinfo=None,
|
243
|
-
**extra,
|
244
|
-
):
|
275
|
+
expression: Any,
|
276
|
+
output_field: Field | None = None,
|
277
|
+
tzinfo: Any = None,
|
278
|
+
**extra: Any,
|
279
|
+
) -> None:
|
245
280
|
self.tzinfo = tzinfo
|
246
281
|
super().__init__(expression, output_field=output_field, **extra)
|
247
282
|
|
248
|
-
def as_sql(
|
283
|
+
def as_sql(
|
284
|
+
self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
|
285
|
+
) -> tuple[str, tuple[Any, ...]]:
|
249
286
|
sql, params = compiler.compile(self.lhs)
|
250
287
|
tzname = None
|
251
288
|
if isinstance(self.lhs.output_field, DateTimeField):
|
@@ -268,11 +305,16 @@ class TruncBase(TimezoneMixin, Transform):
|
|
268
305
|
raise ValueError(
|
269
306
|
"Trunc only valid on DateField, TimeField, or DateTimeField."
|
270
307
|
)
|
271
|
-
return sql, params
|
308
|
+
return sql, tuple(params) if isinstance(params, list) else params
|
272
309
|
|
273
310
|
def resolve_expression(
|
274
|
-
self,
|
275
|
-
|
311
|
+
self,
|
312
|
+
query: Any = None,
|
313
|
+
allow_joins: bool = True,
|
314
|
+
reuse: Any = None,
|
315
|
+
summarize: bool = False,
|
316
|
+
for_save: bool = False,
|
317
|
+
) -> TruncBase:
|
276
318
|
copy = super().resolve_expression(
|
277
319
|
query, allow_joins, reuse, summarize, for_save
|
278
320
|
)
|
@@ -325,7 +367,9 @@ class TruncBase(TimezoneMixin, Transform):
|
|
325
367
|
)
|
326
368
|
return copy
|
327
369
|
|
328
|
-
def convert_value(
|
370
|
+
def convert_value(
|
371
|
+
self, value: Any, expression: Any, connection: BaseDatabaseWrapper
|
372
|
+
) -> Any:
|
329
373
|
if isinstance(self.output_field, DateTimeField):
|
330
374
|
if value is not None:
|
331
375
|
value = value.replace(tzinfo=None)
|
@@ -348,12 +392,12 @@ class TruncBase(TimezoneMixin, Transform):
|
|
348
392
|
class Trunc(TruncBase):
|
349
393
|
def __init__(
|
350
394
|
self,
|
351
|
-
expression,
|
352
|
-
kind,
|
353
|
-
output_field=None,
|
354
|
-
tzinfo=None,
|
355
|
-
**extra,
|
356
|
-
):
|
395
|
+
expression: Any,
|
396
|
+
kind: str,
|
397
|
+
output_field: Field | None = None,
|
398
|
+
tzinfo: Any = None,
|
399
|
+
**extra: Any,
|
400
|
+
) -> None:
|
357
401
|
self.kind = kind
|
358
402
|
super().__init__(expression, output_field=output_field, tzinfo=tzinfo, **extra)
|
359
403
|
|
@@ -385,11 +429,14 @@ class TruncDate(TruncBase):
|
|
385
429
|
lookup_name = "date"
|
386
430
|
output_field = DateField()
|
387
431
|
|
388
|
-
def as_sql(
|
432
|
+
def as_sql(
|
433
|
+
self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
|
434
|
+
) -> tuple[str, tuple[Any, ...]]:
|
389
435
|
# Cast to date rather than truncate to date.
|
390
436
|
sql, params = compiler.compile(self.lhs)
|
391
437
|
tzname = self.get_tzname()
|
392
|
-
|
438
|
+
sql, params = connection.ops.datetime_cast_date_sql(sql, tuple(params), tzname)
|
439
|
+
return sql, tuple(params) if isinstance(params, list) else params
|
393
440
|
|
394
441
|
|
395
442
|
class TruncTime(TruncBase):
|
@@ -397,11 +444,14 @@ class TruncTime(TruncBase):
|
|
397
444
|
lookup_name = "time"
|
398
445
|
output_field = TimeField()
|
399
446
|
|
400
|
-
def as_sql(
|
447
|
+
def as_sql(
|
448
|
+
self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
|
449
|
+
) -> tuple[str, tuple[Any, ...]]:
|
401
450
|
# Cast to time rather than truncate to time.
|
402
451
|
sql, params = compiler.compile(self.lhs)
|
403
452
|
tzname = self.get_tzname()
|
404
|
-
|
453
|
+
sql, params = connection.ops.datetime_cast_time_sql(sql, tuple(params), tzname)
|
454
|
+
return sql, tuple(params) if isinstance(params, list) else params
|
405
455
|
|
406
456
|
|
407
457
|
class TruncHour(TruncBase):
|
plain/models/functions/math.py
CHANGED
@@ -1,5 +1,9 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from typing import TYPE_CHECKING, Any
|
4
|
+
|
1
5
|
from plain.models.expressions import Func, Value
|
2
|
-
from plain.models.fields import FloatField, IntegerField
|
6
|
+
from plain.models.fields import Field, FloatField, IntegerField
|
3
7
|
from plain.models.functions import Cast
|
4
8
|
from plain.models.functions.mixins import (
|
5
9
|
FixDecimalInputMixin,
|
@@ -7,6 +11,10 @@ from plain.models.functions.mixins import (
|
|
7
11
|
)
|
8
12
|
from plain.models.lookups import Transform
|
9
13
|
|
14
|
+
if TYPE_CHECKING:
|
15
|
+
from plain.models.backends.base.base import BaseDatabaseWrapper
|
16
|
+
from plain.models.sql.compiler import SQLCompiler
|
17
|
+
|
10
18
|
|
11
19
|
class Abs(Transform):
|
12
20
|
function = "ABS"
|
@@ -32,10 +40,15 @@ class ATan2(NumericOutputFieldMixin, Func):
|
|
32
40
|
function = "ATAN2"
|
33
41
|
arity = 2
|
34
42
|
|
35
|
-
def as_sqlite(
|
43
|
+
def as_sqlite(
|
44
|
+
self,
|
45
|
+
compiler: SQLCompiler,
|
46
|
+
connection: BaseDatabaseWrapper,
|
47
|
+
**extra_context: Any,
|
48
|
+
) -> tuple[str, tuple[Any, ...]]:
|
36
49
|
if not getattr(
|
37
50
|
connection.ops, "spatialite", False
|
38
|
-
) or connection.ops.spatial_version >= (5, 0, 0):
|
51
|
+
) or connection.ops.spatial_version >= (5, 0, 0): # type: ignore[attr-defined]
|
39
52
|
return self.as_sql(compiler, connection)
|
40
53
|
# This function is usually ATan2(y, x), returning the inverse tangent
|
41
54
|
# of y / x, but it's ATan2(x, y) on SpatiaLite < 5.0.0.
|
@@ -93,7 +106,12 @@ class Log(FixDecimalInputMixin, NumericOutputFieldMixin, Func):
|
|
93
106
|
function = "LOG"
|
94
107
|
arity = 2
|
95
108
|
|
96
|
-
def as_sqlite(
|
109
|
+
def as_sqlite(
|
110
|
+
self,
|
111
|
+
compiler: SQLCompiler,
|
112
|
+
connection: BaseDatabaseWrapper,
|
113
|
+
**extra_context: Any,
|
114
|
+
) -> tuple[str, tuple[Any, ...]]:
|
97
115
|
if not getattr(connection.ops, "spatialite", False):
|
98
116
|
return self.as_sql(compiler, connection)
|
99
117
|
# This function is usually Log(b, x) returning the logarithm of x to
|
@@ -127,13 +145,23 @@ class Random(NumericOutputFieldMixin, Func):
|
|
127
145
|
function = "RANDOM"
|
128
146
|
arity = 0
|
129
147
|
|
130
|
-
def as_mysql(
|
148
|
+
def as_mysql(
|
149
|
+
self,
|
150
|
+
compiler: SQLCompiler,
|
151
|
+
connection: BaseDatabaseWrapper,
|
152
|
+
**extra_context: Any,
|
153
|
+
) -> tuple[str, tuple[Any, ...]]:
|
131
154
|
return super().as_sql(compiler, connection, function="RAND", **extra_context)
|
132
155
|
|
133
|
-
def as_sqlite(
|
156
|
+
def as_sqlite(
|
157
|
+
self,
|
158
|
+
compiler: SQLCompiler,
|
159
|
+
connection: BaseDatabaseWrapper,
|
160
|
+
**extra_context: Any,
|
161
|
+
) -> tuple[str, tuple[Any, ...]]:
|
134
162
|
return super().as_sql(compiler, connection, function="RAND", **extra_context)
|
135
163
|
|
136
|
-
def get_group_by_cols(self):
|
164
|
+
def get_group_by_cols(self) -> list[Any]:
|
137
165
|
return []
|
138
166
|
|
139
167
|
|
@@ -142,16 +170,21 @@ class Round(FixDecimalInputMixin, Transform):
|
|
142
170
|
lookup_name = "round"
|
143
171
|
arity = None # Override Transform's arity=1 to enable passing precision.
|
144
172
|
|
145
|
-
def __init__(self, expression, precision=0, **extra):
|
173
|
+
def __init__(self, expression: Any, precision: int = 0, **extra: Any) -> None:
|
146
174
|
super().__init__(expression, precision, **extra)
|
147
175
|
|
148
|
-
def as_sqlite(
|
176
|
+
def as_sqlite(
|
177
|
+
self,
|
178
|
+
compiler: SQLCompiler,
|
179
|
+
connection: BaseDatabaseWrapper,
|
180
|
+
**extra_context: Any,
|
181
|
+
) -> tuple[str, tuple[Any, ...]]:
|
149
182
|
precision = self.get_source_expressions()[1]
|
150
183
|
if isinstance(precision, Value) and precision.value < 0:
|
151
184
|
raise ValueError("SQLite does not support negative precision.")
|
152
185
|
return super().as_sqlite(compiler, connection, **extra_context)
|
153
186
|
|
154
|
-
def _resolve_output_field(self):
|
187
|
+
def _resolve_output_field(self) -> Field:
|
155
188
|
source = self.get_source_expressions()[0]
|
156
189
|
return source.output_field
|
157
190
|
|
plain/models/functions/mixins.py
CHANGED
@@ -1,11 +1,23 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
1
3
|
import sys
|
4
|
+
from typing import TYPE_CHECKING, Any
|
2
5
|
|
3
6
|
from plain.models.fields import DecimalField, FloatField, IntegerField
|
4
7
|
from plain.models.functions import Cast
|
5
8
|
|
9
|
+
if TYPE_CHECKING:
|
10
|
+
from plain.models.backends.base.base import BaseDatabaseWrapper
|
11
|
+
from plain.models.sql.compiler import SQLCompiler
|
12
|
+
|
6
13
|
|
7
14
|
class FixDecimalInputMixin:
|
8
|
-
def as_postgresql(
|
15
|
+
def as_postgresql(
|
16
|
+
self,
|
17
|
+
compiler: SQLCompiler,
|
18
|
+
connection: BaseDatabaseWrapper,
|
19
|
+
**extra_context: Any,
|
20
|
+
) -> tuple[str, tuple[Any, ...]]:
|
9
21
|
# Cast FloatField to DecimalField as PostgreSQL doesn't support the
|
10
22
|
# following function signatures:
|
11
23
|
# - LOG(double, double)
|
@@ -24,18 +36,23 @@ class FixDecimalInputMixin:
|
|
24
36
|
|
25
37
|
|
26
38
|
class FixDurationInputMixin:
|
27
|
-
def as_mysql(
|
28
|
-
|
29
|
-
|
39
|
+
def as_mysql(
|
40
|
+
self,
|
41
|
+
compiler: SQLCompiler,
|
42
|
+
connection: BaseDatabaseWrapper,
|
43
|
+
**extra_context: Any,
|
44
|
+
) -> tuple[str, tuple[Any, ...]]:
|
45
|
+
sql, params = super().as_sql(compiler, connection, **extra_context) # type: ignore[misc]
|
46
|
+
if self.output_field.get_internal_type() == "DurationField": # type: ignore[attr-defined]
|
30
47
|
sql = f"CAST({sql} AS SIGNED)"
|
31
48
|
return sql, params
|
32
49
|
|
33
50
|
|
34
51
|
class NumericOutputFieldMixin:
|
35
|
-
def _resolve_output_field(self):
|
36
|
-
source_fields = self.get_source_fields()
|
52
|
+
def _resolve_output_field(self) -> DecimalField | FloatField:
|
53
|
+
source_fields = self.get_source_fields() # type: ignore[attr-defined]
|
37
54
|
if any(isinstance(s, DecimalField) for s in source_fields):
|
38
55
|
return DecimalField()
|
39
56
|
if any(isinstance(s, IntegerField) for s in source_fields):
|
40
57
|
return FloatField()
|
41
|
-
return super()._resolve_output_field() if source_fields else FloatField()
|
58
|
+
return super()._resolve_output_field() if source_fields else FloatField() # type: ignore[misc]
|
plain/models/functions/text.py
CHANGED
@@ -1,12 +1,25 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from typing import TYPE_CHECKING, Any
|
4
|
+
|
1
5
|
from plain.models.expressions import Func, Value
|
2
6
|
from plain.models.fields import CharField, IntegerField, TextField
|
3
7
|
from plain.models.functions import Cast, Coalesce
|
4
8
|
from plain.models.lookups import Transform
|
5
9
|
|
10
|
+
if TYPE_CHECKING:
|
11
|
+
from plain.models.backends.base.base import BaseDatabaseWrapper
|
12
|
+
from plain.models.sql.compiler import SQLCompiler
|
13
|
+
|
6
14
|
|
7
15
|
class MySQLSHA2Mixin:
|
8
|
-
def as_mysql(
|
9
|
-
|
16
|
+
def as_mysql(
|
17
|
+
self,
|
18
|
+
compiler: SQLCompiler,
|
19
|
+
connection: BaseDatabaseWrapper,
|
20
|
+
**extra_context: Any,
|
21
|
+
) -> tuple[str, tuple[Any, ...]]:
|
22
|
+
return super().as_sql( # type: ignore[misc]
|
10
23
|
compiler,
|
11
24
|
connection,
|
12
25
|
template=f"SHA2(%(expressions)s, {self.function[3:]})",
|
@@ -15,8 +28,13 @@ class MySQLSHA2Mixin:
|
|
15
28
|
|
16
29
|
|
17
30
|
class PostgreSQLSHAMixin:
|
18
|
-
def as_postgresql(
|
19
|
-
|
31
|
+
def as_postgresql(
|
32
|
+
self,
|
33
|
+
compiler: SQLCompiler,
|
34
|
+
connection: BaseDatabaseWrapper,
|
35
|
+
**extra_context: Any,
|
36
|
+
) -> tuple[str, tuple[Any, ...]]:
|
37
|
+
return super().as_sql( # type: ignore[misc]
|
20
38
|
compiler,
|
21
39
|
connection,
|
22
40
|
template="ENCODE(DIGEST(%(expressions)s, '%(function)s'), 'hex')",
|
@@ -29,7 +47,12 @@ class Chr(Transform):
|
|
29
47
|
function = "CHR"
|
30
48
|
lookup_name = "chr"
|
31
49
|
|
32
|
-
def as_mysql(
|
50
|
+
def as_mysql(
|
51
|
+
self,
|
52
|
+
compiler: SQLCompiler,
|
53
|
+
connection: BaseDatabaseWrapper,
|
54
|
+
**extra_context: Any,
|
55
|
+
) -> tuple[str, tuple[Any, ...]]:
|
33
56
|
return super().as_sql(
|
34
57
|
compiler,
|
35
58
|
connection,
|
@@ -38,7 +61,12 @@ class Chr(Transform):
|
|
38
61
|
**extra_context,
|
39
62
|
)
|
40
63
|
|
41
|
-
def as_sqlite(
|
64
|
+
def as_sqlite(
|
65
|
+
self,
|
66
|
+
compiler: SQLCompiler,
|
67
|
+
connection: BaseDatabaseWrapper,
|
68
|
+
**extra_context: Any,
|
69
|
+
) -> tuple[str, tuple[Any, ...]]:
|
42
70
|
return super().as_sql(compiler, connection, function="CHAR", **extra_context)
|
43
71
|
|
44
72
|
|
@@ -50,7 +78,12 @@ class ConcatPair(Func):
|
|
50
78
|
|
51
79
|
function = "CONCAT"
|
52
80
|
|
53
|
-
def as_sqlite(
|
81
|
+
def as_sqlite(
|
82
|
+
self,
|
83
|
+
compiler: SQLCompiler,
|
84
|
+
connection: BaseDatabaseWrapper,
|
85
|
+
**extra_context: Any,
|
86
|
+
) -> tuple[str, tuple[Any, ...]]:
|
54
87
|
coalesced = self.coalesce()
|
55
88
|
return super(ConcatPair, coalesced).as_sql(
|
56
89
|
compiler,
|
@@ -60,7 +93,12 @@ class ConcatPair(Func):
|
|
60
93
|
**extra_context,
|
61
94
|
)
|
62
95
|
|
63
|
-
def as_postgresql(
|
96
|
+
def as_postgresql(
|
97
|
+
self,
|
98
|
+
compiler: SQLCompiler,
|
99
|
+
connection: BaseDatabaseWrapper,
|
100
|
+
**extra_context: Any,
|
101
|
+
) -> tuple[str, tuple[Any, ...]]:
|
64
102
|
copy = self.copy()
|
65
103
|
copy.set_source_expressions(
|
66
104
|
[
|
@@ -74,7 +112,12 @@ class ConcatPair(Func):
|
|
74
112
|
**extra_context,
|
75
113
|
)
|
76
114
|
|
77
|
-
def as_mysql(
|
115
|
+
def as_mysql(
|
116
|
+
self,
|
117
|
+
compiler: SQLCompiler,
|
118
|
+
connection: BaseDatabaseWrapper,
|
119
|
+
**extra_context: Any,
|
120
|
+
) -> tuple[str, tuple[Any, ...]]:
|
78
121
|
# Use CONCAT_WS with an empty separator so that NULLs are ignored.
|
79
122
|
return super().as_sql(
|
80
123
|
compiler,
|
@@ -84,7 +127,7 @@ class ConcatPair(Func):
|
|
84
127
|
**extra_context,
|
85
128
|
)
|
86
129
|
|
87
|
-
def coalesce(self):
|
130
|
+
def coalesce(self) -> ConcatPair:
|
88
131
|
# null on either side results in null for expression, wrap with coalesce
|
89
132
|
c = self.copy()
|
90
133
|
c.set_source_expressions(
|
@@ -106,13 +149,13 @@ class Concat(Func):
|
|
106
149
|
function = None
|
107
150
|
template = "%(expressions)s"
|
108
151
|
|
109
|
-
def __init__(self, *expressions, **extra):
|
152
|
+
def __init__(self, *expressions: Any, **extra: Any) -> None:
|
110
153
|
if len(expressions) < 2:
|
111
154
|
raise ValueError("Concat must take at least two expressions")
|
112
155
|
paired = self._paired(expressions)
|
113
156
|
super().__init__(paired, **extra)
|
114
157
|
|
115
|
-
def _paired(self, expressions):
|
158
|
+
def _paired(self, expressions: tuple[Any, ...]) -> ConcatPair:
|
116
159
|
# wrap pairs of expressions in successive concat functions
|
117
160
|
# exp = [a, b, c, d]
|
118
161
|
# -> ConcatPair(a, ConcatPair(b, ConcatPair(c, d))))
|
@@ -126,7 +169,7 @@ class Left(Func):
|
|
126
169
|
arity = 2
|
127
170
|
output_field = CharField()
|
128
171
|
|
129
|
-
def __init__(self, expression, length, **extra):
|
172
|
+
def __init__(self, expression: Any, length: Any, **extra: Any) -> None:
|
130
173
|
"""
|
131
174
|
expression: the name of a field, or an expression returning a string
|
132
175
|
length: the number of characters to return from the start of the string
|
@@ -136,10 +179,15 @@ class Left(Func):
|
|
136
179
|
raise ValueError("'length' must be greater than 0.")
|
137
180
|
super().__init__(expression, length, **extra)
|
138
181
|
|
139
|
-
def get_substr(self):
|
182
|
+
def get_substr(self) -> Substr:
|
140
183
|
return Substr(self.source_expressions[0], Value(1), self.source_expressions[1])
|
141
184
|
|
142
|
-
def as_sqlite(
|
185
|
+
def as_sqlite(
|
186
|
+
self,
|
187
|
+
compiler: SQLCompiler,
|
188
|
+
connection: BaseDatabaseWrapper,
|
189
|
+
**extra_context: Any,
|
190
|
+
) -> tuple[str, tuple[Any, ...]]:
|
143
191
|
return self.get_substr().as_sqlite(compiler, connection, **extra_context)
|
144
192
|
|
145
193
|
|
@@ -150,7 +198,12 @@ class Length(Transform):
|
|
150
198
|
lookup_name = "length"
|
151
199
|
output_field = IntegerField()
|
152
200
|
|
153
|
-
def as_mysql(
|
201
|
+
def as_mysql(
|
202
|
+
self,
|
203
|
+
compiler: SQLCompiler,
|
204
|
+
connection: BaseDatabaseWrapper,
|
205
|
+
**extra_context: Any,
|
206
|
+
) -> tuple[str, tuple[Any, ...]]:
|
154
207
|
return super().as_sql(
|
155
208
|
compiler, connection, function="CHAR_LENGTH", **extra_context
|
156
209
|
)
|
@@ -165,7 +218,9 @@ class LPad(Func):
|
|
165
218
|
function = "LPAD"
|
166
219
|
output_field = CharField()
|
167
220
|
|
168
|
-
def __init__(
|
221
|
+
def __init__(
|
222
|
+
self, expression: Any, length: Any, fill_text: Any = Value(" "), **extra: Any
|
223
|
+
) -> None:
|
169
224
|
if (
|
170
225
|
not hasattr(length, "resolve_expression")
|
171
226
|
and length is not None
|
@@ -190,10 +245,20 @@ class Ord(Transform):
|
|
190
245
|
lookup_name = "ord"
|
191
246
|
output_field = IntegerField()
|
192
247
|
|
193
|
-
def as_mysql(
|
248
|
+
def as_mysql(
|
249
|
+
self,
|
250
|
+
compiler: SQLCompiler,
|
251
|
+
connection: BaseDatabaseWrapper,
|
252
|
+
**extra_context: Any,
|
253
|
+
) -> tuple[str, tuple[Any, ...]]:
|
194
254
|
return super().as_sql(compiler, connection, function="ORD", **extra_context)
|
195
255
|
|
196
|
-
def as_sqlite(
|
256
|
+
def as_sqlite(
|
257
|
+
self,
|
258
|
+
compiler: SQLCompiler,
|
259
|
+
connection: BaseDatabaseWrapper,
|
260
|
+
**extra_context: Any,
|
261
|
+
) -> tuple[str, tuple[Any, ...]]:
|
197
262
|
return super().as_sql(compiler, connection, function="UNICODE", **extra_context)
|
198
263
|
|
199
264
|
|
@@ -201,7 +266,7 @@ class Repeat(Func):
|
|
201
266
|
function = "REPEAT"
|
202
267
|
output_field = CharField()
|
203
268
|
|
204
|
-
def __init__(self, expression, number, **extra):
|
269
|
+
def __init__(self, expression: Any, number: Any, **extra: Any) -> None:
|
205
270
|
if (
|
206
271
|
not hasattr(number, "resolve_expression")
|
207
272
|
and number is not None
|
@@ -214,7 +279,9 @@ class Repeat(Func):
|
|
214
279
|
class Replace(Func):
|
215
280
|
function = "REPLACE"
|
216
281
|
|
217
|
-
def __init__(
|
282
|
+
def __init__(
|
283
|
+
self, expression: Any, text: Any, replacement: Any = Value(""), **extra: Any
|
284
|
+
) -> None:
|
218
285
|
super().__init__(expression, text, replacement, **extra)
|
219
286
|
|
220
287
|
|
@@ -226,7 +293,7 @@ class Reverse(Transform):
|
|
226
293
|
class Right(Left):
|
227
294
|
function = "RIGHT"
|
228
295
|
|
229
|
-
def get_substr(self):
|
296
|
+
def get_substr(self) -> Substr:
|
230
297
|
return Substr(
|
231
298
|
self.source_expressions[0], self.source_expressions[1] * Value(-1)
|
232
299
|
)
|
@@ -277,7 +344,12 @@ class StrIndex(Func):
|
|
277
344
|
arity = 2
|
278
345
|
output_field = IntegerField()
|
279
346
|
|
280
|
-
def as_postgresql(
|
347
|
+
def as_postgresql(
|
348
|
+
self,
|
349
|
+
compiler: SQLCompiler,
|
350
|
+
connection: BaseDatabaseWrapper,
|
351
|
+
**extra_context: Any,
|
352
|
+
) -> tuple[str, tuple[Any, ...]]:
|
281
353
|
return super().as_sql(compiler, connection, function="STRPOS", **extra_context)
|
282
354
|
|
283
355
|
|
@@ -285,7 +357,9 @@ class Substr(Func):
|
|
285
357
|
function = "SUBSTRING"
|
286
358
|
output_field = CharField()
|
287
359
|
|
288
|
-
def __init__(
|
360
|
+
def __init__(
|
361
|
+
self, expression: Any, pos: Any, length: Any = None, **extra: Any
|
362
|
+
) -> None:
|
289
363
|
"""
|
290
364
|
expression: the name of a field, or an expression returning a string
|
291
365
|
pos: an integer > 0, or an expression returning an integer
|
@@ -299,7 +373,12 @@ class Substr(Func):
|
|
299
373
|
expressions.append(length)
|
300
374
|
super().__init__(*expressions, **extra)
|
301
375
|
|
302
|
-
def as_sqlite(
|
376
|
+
def as_sqlite(
|
377
|
+
self,
|
378
|
+
compiler: SQLCompiler,
|
379
|
+
connection: BaseDatabaseWrapper,
|
380
|
+
**extra_context: Any,
|
381
|
+
) -> tuple[str, tuple[Any, ...]]:
|
303
382
|
return super().as_sql(compiler, connection, function="SUBSTR", **extra_context)
|
304
383
|
|
305
384
|
|