plain.models 0.49.2__py3-none-any.whl → 0.50.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plain/models/CHANGELOG.md +13 -0
- plain/models/aggregates.py +42 -19
- plain/models/backends/base/base.py +125 -105
- plain/models/backends/base/client.py +11 -3
- plain/models/backends/base/creation.py +22 -12
- plain/models/backends/base/features.py +10 -4
- plain/models/backends/base/introspection.py +29 -16
- plain/models/backends/base/operations.py +187 -91
- plain/models/backends/base/schema.py +267 -165
- plain/models/backends/base/validation.py +12 -3
- plain/models/backends/ddl_references.py +85 -43
- plain/models/backends/mysql/base.py +29 -26
- plain/models/backends/mysql/client.py +7 -2
- plain/models/backends/mysql/compiler.py +12 -3
- plain/models/backends/mysql/creation.py +5 -2
- plain/models/backends/mysql/features.py +24 -22
- plain/models/backends/mysql/introspection.py +22 -13
- plain/models/backends/mysql/operations.py +106 -39
- plain/models/backends/mysql/schema.py +48 -24
- plain/models/backends/mysql/validation.py +13 -6
- plain/models/backends/postgresql/base.py +41 -34
- plain/models/backends/postgresql/client.py +7 -2
- plain/models/backends/postgresql/creation.py +10 -5
- plain/models/backends/postgresql/introspection.py +15 -8
- plain/models/backends/postgresql/operations.py +109 -42
- plain/models/backends/postgresql/schema.py +85 -46
- plain/models/backends/sqlite3/_functions.py +151 -115
- plain/models/backends/sqlite3/base.py +37 -23
- plain/models/backends/sqlite3/client.py +7 -1
- plain/models/backends/sqlite3/creation.py +9 -5
- plain/models/backends/sqlite3/features.py +5 -3
- plain/models/backends/sqlite3/introspection.py +32 -16
- plain/models/backends/sqlite3/operations.py +125 -42
- plain/models/backends/sqlite3/schema.py +82 -58
- plain/models/backends/utils.py +52 -29
- plain/models/backups/cli.py +8 -6
- plain/models/backups/clients.py +16 -7
- plain/models/backups/core.py +24 -13
- plain/models/base.py +113 -74
- plain/models/cli.py +94 -63
- plain/models/config.py +1 -1
- plain/models/connections.py +23 -7
- plain/models/constraints.py +65 -47
- plain/models/database_url.py +1 -1
- plain/models/db.py +6 -2
- plain/models/deletion.py +66 -43
- plain/models/entrypoints.py +1 -1
- plain/models/enums.py +22 -11
- plain/models/exceptions.py +23 -8
- plain/models/expressions.py +440 -257
- plain/models/fields/__init__.py +253 -202
- plain/models/fields/json.py +120 -54
- plain/models/fields/mixins.py +12 -8
- plain/models/fields/related.py +284 -252
- plain/models/fields/related_descriptors.py +31 -22
- plain/models/fields/related_lookups.py +23 -11
- plain/models/fields/related_managers.py +81 -47
- plain/models/fields/reverse_related.py +58 -55
- plain/models/forms.py +89 -63
- plain/models/functions/comparison.py +71 -18
- plain/models/functions/datetime.py +79 -29
- plain/models/functions/math.py +43 -10
- plain/models/functions/mixins.py +24 -7
- plain/models/functions/text.py +104 -25
- plain/models/functions/window.py +12 -6
- plain/models/indexes.py +52 -28
- plain/models/lookups.py +228 -153
- plain/models/migrations/autodetector.py +86 -43
- plain/models/migrations/exceptions.py +7 -3
- plain/models/migrations/executor.py +33 -7
- plain/models/migrations/graph.py +79 -50
- plain/models/migrations/loader.py +45 -22
- plain/models/migrations/migration.py +23 -18
- plain/models/migrations/operations/base.py +37 -19
- plain/models/migrations/operations/fields.py +89 -42
- plain/models/migrations/operations/models.py +245 -143
- plain/models/migrations/operations/special.py +82 -25
- plain/models/migrations/optimizer.py +7 -2
- plain/models/migrations/questioner.py +58 -31
- plain/models/migrations/recorder.py +18 -11
- plain/models/migrations/serializer.py +50 -39
- plain/models/migrations/state.py +220 -133
- plain/models/migrations/utils.py +29 -13
- plain/models/migrations/writer.py +17 -14
- plain/models/options.py +63 -56
- plain/models/otel.py +16 -6
- plain/models/preflight.py +35 -12
- plain/models/query.py +323 -228
- plain/models/query_utils.py +93 -58
- plain/models/registry.py +34 -16
- plain/models/sql/compiler.py +146 -97
- plain/models/sql/datastructures.py +38 -25
- plain/models/sql/query.py +255 -169
- plain/models/sql/subqueries.py +32 -21
- plain/models/sql/where.py +54 -29
- plain/models/test/pytest.py +15 -11
- plain/models/test/utils.py +4 -2
- plain/models/transaction.py +20 -7
- plain/models/utils.py +13 -5
- {plain_models-0.49.2.dist-info → plain_models-0.50.0.dist-info}/METADATA +1 -1
- plain_models-0.50.0.dist-info/RECORD +122 -0
- plain_models-0.49.2.dist-info/RECORD +0 -122
- {plain_models-0.49.2.dist-info → plain_models-0.50.0.dist-info}/WHEEL +0 -0
- {plain_models-0.49.2.dist-info → plain_models-0.50.0.dist-info}/entry_points.txt +0 -0
- {plain_models-0.49.2.dist-info → plain_models-0.50.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,11 +1,19 @@
|
|
1
1
|
"""Database functions that do comparisons or type conversions."""
|
2
2
|
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
from typing import TYPE_CHECKING, Any
|
6
|
+
|
3
7
|
from plain.models.db import NotSupportedError
|
4
8
|
from plain.models.expressions import Func, Value
|
5
|
-
from plain.models.fields import TextField
|
9
|
+
from plain.models.fields import Field, TextField
|
6
10
|
from plain.models.fields.json import JSONField
|
7
11
|
from plain.utils.regex_helper import _lazy_re_compile
|
8
12
|
|
13
|
+
if TYPE_CHECKING:
|
14
|
+
from plain.models.backends.base.base import BaseDatabaseWrapper
|
15
|
+
from plain.models.sql.compiler import SQLCompiler
|
16
|
+
|
9
17
|
|
10
18
|
class Cast(Func):
|
11
19
|
"""Coerce an expression to a new field type."""
|
@@ -13,14 +21,24 @@ class Cast(Func):
|
|
13
21
|
function = "CAST"
|
14
22
|
template = "%(function)s(%(expressions)s AS %(db_type)s)"
|
15
23
|
|
16
|
-
def __init__(self, expression, output_field):
|
24
|
+
def __init__(self, expression: Any, output_field: Field) -> None:
|
17
25
|
super().__init__(expression, output_field=output_field)
|
18
26
|
|
19
|
-
def as_sql(
|
27
|
+
def as_sql(
|
28
|
+
self,
|
29
|
+
compiler: SQLCompiler,
|
30
|
+
connection: BaseDatabaseWrapper,
|
31
|
+
**extra_context: Any,
|
32
|
+
) -> tuple[str, tuple[Any, ...]]:
|
20
33
|
extra_context["db_type"] = self.output_field.cast_db_type(connection)
|
21
34
|
return super().as_sql(compiler, connection, **extra_context)
|
22
35
|
|
23
|
-
def as_sqlite(
|
36
|
+
def as_sqlite(
|
37
|
+
self,
|
38
|
+
compiler: SQLCompiler,
|
39
|
+
connection: BaseDatabaseWrapper,
|
40
|
+
**extra_context: Any,
|
41
|
+
) -> tuple[str, tuple[Any, ...]]:
|
24
42
|
db_type = self.output_field.db_type(connection)
|
25
43
|
if db_type in {"datetime", "time"}:
|
26
44
|
# Use strftime as datetime/time don't keep fractional seconds.
|
@@ -38,18 +56,28 @@ class Cast(Func):
|
|
38
56
|
)
|
39
57
|
return self.as_sql(compiler, connection, **extra_context)
|
40
58
|
|
41
|
-
def as_mysql(
|
59
|
+
def as_mysql(
|
60
|
+
self,
|
61
|
+
compiler: SQLCompiler,
|
62
|
+
connection: BaseDatabaseWrapper,
|
63
|
+
**extra_context: Any,
|
64
|
+
) -> tuple[str, tuple[Any, ...]]:
|
42
65
|
template = None
|
43
66
|
output_type = self.output_field.get_internal_type()
|
44
67
|
# MySQL doesn't support explicit cast to float.
|
45
68
|
if output_type == "FloatField":
|
46
69
|
template = "(%(expressions)s + 0.0)"
|
47
70
|
# MariaDB doesn't support explicit cast to JSON.
|
48
|
-
elif output_type == "JSONField" and connection.mysql_is_mariadb:
|
71
|
+
elif output_type == "JSONField" and connection.mysql_is_mariadb: # type: ignore[attr-defined]
|
49
72
|
template = "JSON_EXTRACT(%(expressions)s, '$')"
|
50
73
|
return self.as_sql(compiler, connection, template=template, **extra_context)
|
51
74
|
|
52
|
-
def as_postgresql(
|
75
|
+
def as_postgresql(
|
76
|
+
self,
|
77
|
+
compiler: SQLCompiler,
|
78
|
+
connection: BaseDatabaseWrapper,
|
79
|
+
**extra_context: Any,
|
80
|
+
) -> tuple[str, tuple[Any, ...]]:
|
53
81
|
# CAST would be valid too, but the :: shortcut syntax is more readable.
|
54
82
|
# 'expressions' is wrapped in parentheses in case it's a complex
|
55
83
|
# expression.
|
@@ -66,13 +94,13 @@ class Coalesce(Func):
|
|
66
94
|
|
67
95
|
function = "COALESCE"
|
68
96
|
|
69
|
-
def __init__(self, *expressions, **extra):
|
97
|
+
def __init__(self, *expressions: Any, **extra: Any) -> None:
|
70
98
|
if len(expressions) < 2:
|
71
99
|
raise ValueError("Coalesce must take at least two expressions")
|
72
100
|
super().__init__(*expressions, **extra)
|
73
101
|
|
74
102
|
@property
|
75
|
-
def empty_result_set_value(self):
|
103
|
+
def empty_result_set_value(self) -> Any:
|
76
104
|
for expression in self.get_source_expressions():
|
77
105
|
result = expression.empty_result_set_value
|
78
106
|
if result is NotImplemented or result is not None:
|
@@ -87,13 +115,18 @@ class Collate(Func):
|
|
87
115
|
# https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS
|
88
116
|
collation_re = _lazy_re_compile(r"^[\w\-]+$")
|
89
117
|
|
90
|
-
def __init__(self, expression, collation):
|
118
|
+
def __init__(self, expression: Any, collation: str) -> None:
|
91
119
|
if not (collation and self.collation_re.match(collation)):
|
92
120
|
raise ValueError(f"Invalid collation name: {collation!r}.")
|
93
121
|
self.collation = collation
|
94
122
|
super().__init__(expression)
|
95
123
|
|
96
|
-
def as_sql(
|
124
|
+
def as_sql(
|
125
|
+
self,
|
126
|
+
compiler: SQLCompiler,
|
127
|
+
connection: BaseDatabaseWrapper,
|
128
|
+
**extra_context: Any,
|
129
|
+
) -> tuple[str, tuple[Any, ...]]:
|
97
130
|
extra_context.setdefault("collation", connection.ops.quote_name(self.collation))
|
98
131
|
return super().as_sql(compiler, connection, **extra_context)
|
99
132
|
|
@@ -109,12 +142,17 @@ class Greatest(Func):
|
|
109
142
|
|
110
143
|
function = "GREATEST"
|
111
144
|
|
112
|
-
def __init__(self, *expressions, **extra):
|
145
|
+
def __init__(self, *expressions: Any, **extra: Any) -> None:
|
113
146
|
if len(expressions) < 2:
|
114
147
|
raise ValueError("Greatest must take at least two expressions")
|
115
148
|
super().__init__(*expressions, **extra)
|
116
149
|
|
117
|
-
def as_sqlite(
|
150
|
+
def as_sqlite(
|
151
|
+
self,
|
152
|
+
compiler: SQLCompiler,
|
153
|
+
connection: BaseDatabaseWrapper,
|
154
|
+
**extra_context: Any,
|
155
|
+
) -> tuple[str, tuple[Any, ...]]:
|
118
156
|
"""Use the MAX function on SQLite."""
|
119
157
|
return super().as_sqlite(compiler, connection, function="MAX", **extra_context)
|
120
158
|
|
@@ -123,20 +161,30 @@ class JSONObject(Func):
|
|
123
161
|
function = "JSON_OBJECT"
|
124
162
|
output_field = JSONField()
|
125
163
|
|
126
|
-
def __init__(self, **fields):
|
164
|
+
def __init__(self, **fields: Any) -> None:
|
127
165
|
expressions = []
|
128
166
|
for key, value in fields.items():
|
129
167
|
expressions.extend((Value(key), value))
|
130
168
|
super().__init__(*expressions)
|
131
169
|
|
132
|
-
def as_sql(
|
170
|
+
def as_sql(
|
171
|
+
self,
|
172
|
+
compiler: SQLCompiler,
|
173
|
+
connection: BaseDatabaseWrapper,
|
174
|
+
**extra_context: Any,
|
175
|
+
) -> tuple[str, tuple[Any, ...]]:
|
133
176
|
if not connection.features.has_json_object_function:
|
134
177
|
raise NotSupportedError(
|
135
178
|
"JSONObject() is not supported on this database backend."
|
136
179
|
)
|
137
180
|
return super().as_sql(compiler, connection, **extra_context)
|
138
181
|
|
139
|
-
def as_postgresql(
|
182
|
+
def as_postgresql(
|
183
|
+
self,
|
184
|
+
compiler: SQLCompiler,
|
185
|
+
connection: BaseDatabaseWrapper,
|
186
|
+
**extra_context: Any,
|
187
|
+
) -> tuple[str, tuple[Any, ...]]:
|
140
188
|
copy = self.copy()
|
141
189
|
copy.set_source_expressions(
|
142
190
|
[
|
@@ -163,12 +211,17 @@ class Least(Func):
|
|
163
211
|
|
164
212
|
function = "LEAST"
|
165
213
|
|
166
|
-
def __init__(self, *expressions, **extra):
|
214
|
+
def __init__(self, *expressions: Any, **extra: Any) -> None:
|
167
215
|
if len(expressions) < 2:
|
168
216
|
raise ValueError("Least must take at least two expressions")
|
169
217
|
super().__init__(*expressions, **extra)
|
170
218
|
|
171
|
-
def as_sqlite(
|
219
|
+
def as_sqlite(
|
220
|
+
self,
|
221
|
+
compiler: SQLCompiler,
|
222
|
+
connection: BaseDatabaseWrapper,
|
223
|
+
**extra_context: Any,
|
224
|
+
) -> tuple[str, tuple[Any, ...]]:
|
172
225
|
"""Use the MIN function on SQLite."""
|
173
226
|
return super().as_sqlite(compiler, connection, function="MIN", **extra_context)
|
174
227
|
|
@@ -1,4 +1,7 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
1
3
|
from datetime import datetime
|
4
|
+
from typing import TYPE_CHECKING, Any
|
2
5
|
|
3
6
|
from plain.models.expressions import Func
|
4
7
|
from plain.models.fields import (
|
@@ -19,11 +22,15 @@ from plain.models.lookups import (
|
|
19
22
|
)
|
20
23
|
from plain.utils import timezone
|
21
24
|
|
25
|
+
if TYPE_CHECKING:
|
26
|
+
from plain.models.backends.base.base import BaseDatabaseWrapper
|
27
|
+
from plain.models.sql.compiler import SQLCompiler
|
28
|
+
|
22
29
|
|
23
30
|
class TimezoneMixin:
|
24
31
|
tzinfo = None
|
25
32
|
|
26
|
-
def get_tzname(self):
|
33
|
+
def get_tzname(self) -> str | None:
|
27
34
|
# Timezone conversions must happen to the input datetime *before*
|
28
35
|
# applying a function. 2015-12-31 23:00:00 -02:00 is stored in the
|
29
36
|
# database as 2016-01-01 01:00:00 +00:00. Any results should be
|
@@ -38,7 +45,13 @@ class Extract(TimezoneMixin, Transform):
|
|
38
45
|
lookup_name = None
|
39
46
|
output_field = IntegerField()
|
40
47
|
|
41
|
-
def __init__(
|
48
|
+
def __init__(
|
49
|
+
self,
|
50
|
+
expression: Any,
|
51
|
+
lookup_name: str | None = None,
|
52
|
+
tzinfo: Any = None,
|
53
|
+
**extra: Any,
|
54
|
+
) -> None:
|
42
55
|
if self.lookup_name is None:
|
43
56
|
self.lookup_name = lookup_name
|
44
57
|
if self.lookup_name is None:
|
@@ -46,7 +59,9 @@ class Extract(TimezoneMixin, Transform):
|
|
46
59
|
self.tzinfo = tzinfo
|
47
60
|
super().__init__(expression, **extra)
|
48
61
|
|
49
|
-
def as_sql(
|
62
|
+
def as_sql(
|
63
|
+
self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
|
64
|
+
) -> tuple[str, tuple[Any, ...]]:
|
50
65
|
sql, params = compiler.compile(self.lhs)
|
51
66
|
lhs_output_field = self.lhs.output_field
|
52
67
|
if isinstance(lhs_output_field, DateTimeField):
|
@@ -76,11 +91,16 @@ class Extract(TimezoneMixin, Transform):
|
|
76
91
|
# resolve_expression has already validated the output_field so this
|
77
92
|
# assert should never be hit.
|
78
93
|
raise ValueError("Tried to Extract from an invalid type.")
|
79
|
-
return sql, params
|
94
|
+
return sql, tuple(params) if isinstance(params, list) else params
|
80
95
|
|
81
96
|
def resolve_expression(
|
82
|
-
self,
|
83
|
-
|
97
|
+
self,
|
98
|
+
query: Any = None,
|
99
|
+
allow_joins: bool = True,
|
100
|
+
reuse: Any = None,
|
101
|
+
summarize: bool = False,
|
102
|
+
for_save: bool = False,
|
103
|
+
) -> Extract:
|
84
104
|
copy = super().resolve_expression(
|
85
105
|
query, allow_joins, reuse, summarize, for_save
|
86
106
|
)
|
@@ -209,7 +229,12 @@ class Now(Func):
|
|
209
229
|
template = "CURRENT_TIMESTAMP"
|
210
230
|
output_field = DateTimeField()
|
211
231
|
|
212
|
-
def as_postgresql(
|
232
|
+
def as_postgresql(
|
233
|
+
self,
|
234
|
+
compiler: SQLCompiler,
|
235
|
+
connection: BaseDatabaseWrapper,
|
236
|
+
**extra_context: Any,
|
237
|
+
) -> tuple[str, tuple[Any, ...]]:
|
213
238
|
# PostgreSQL's CURRENT_TIMESTAMP means "the time at the start of the
|
214
239
|
# transaction". Use STATEMENT_TIMESTAMP to be cross-compatible with
|
215
240
|
# other databases.
|
@@ -217,12 +242,22 @@ class Now(Func):
|
|
217
242
|
compiler, connection, template="STATEMENT_TIMESTAMP()", **extra_context
|
218
243
|
)
|
219
244
|
|
220
|
-
def as_mysql(
|
245
|
+
def as_mysql(
|
246
|
+
self,
|
247
|
+
compiler: SQLCompiler,
|
248
|
+
connection: BaseDatabaseWrapper,
|
249
|
+
**extra_context: Any,
|
250
|
+
) -> tuple[str, tuple[Any, ...]]:
|
221
251
|
return self.as_sql(
|
222
252
|
compiler, connection, template="CURRENT_TIMESTAMP(6)", **extra_context
|
223
253
|
)
|
224
254
|
|
225
|
-
def as_sqlite(
|
255
|
+
def as_sqlite(
|
256
|
+
self,
|
257
|
+
compiler: SQLCompiler,
|
258
|
+
connection: BaseDatabaseWrapper,
|
259
|
+
**extra_context: Any,
|
260
|
+
) -> tuple[str, tuple[Any, ...]]:
|
226
261
|
return self.as_sql(
|
227
262
|
compiler,
|
228
263
|
connection,
|
@@ -237,15 +272,17 @@ class TruncBase(TimezoneMixin, Transform):
|
|
237
272
|
|
238
273
|
def __init__(
|
239
274
|
self,
|
240
|
-
expression,
|
241
|
-
output_field=None,
|
242
|
-
tzinfo=None,
|
243
|
-
**extra,
|
244
|
-
):
|
275
|
+
expression: Any,
|
276
|
+
output_field: Field | None = None,
|
277
|
+
tzinfo: Any = None,
|
278
|
+
**extra: Any,
|
279
|
+
) -> None:
|
245
280
|
self.tzinfo = tzinfo
|
246
281
|
super().__init__(expression, output_field=output_field, **extra)
|
247
282
|
|
248
|
-
def as_sql(
|
283
|
+
def as_sql(
|
284
|
+
self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
|
285
|
+
) -> tuple[str, tuple[Any, ...]]:
|
249
286
|
sql, params = compiler.compile(self.lhs)
|
250
287
|
tzname = None
|
251
288
|
if isinstance(self.lhs.output_field, DateTimeField):
|
@@ -268,11 +305,16 @@ class TruncBase(TimezoneMixin, Transform):
|
|
268
305
|
raise ValueError(
|
269
306
|
"Trunc only valid on DateField, TimeField, or DateTimeField."
|
270
307
|
)
|
271
|
-
return sql, params
|
308
|
+
return sql, tuple(params) if isinstance(params, list) else params
|
272
309
|
|
273
310
|
def resolve_expression(
|
274
|
-
self,
|
275
|
-
|
311
|
+
self,
|
312
|
+
query: Any = None,
|
313
|
+
allow_joins: bool = True,
|
314
|
+
reuse: Any = None,
|
315
|
+
summarize: bool = False,
|
316
|
+
for_save: bool = False,
|
317
|
+
) -> TruncBase:
|
276
318
|
copy = super().resolve_expression(
|
277
319
|
query, allow_joins, reuse, summarize, for_save
|
278
320
|
)
|
@@ -325,7 +367,9 @@ class TruncBase(TimezoneMixin, Transform):
|
|
325
367
|
)
|
326
368
|
return copy
|
327
369
|
|
328
|
-
def convert_value(
|
370
|
+
def convert_value(
|
371
|
+
self, value: Any, expression: Any, connection: BaseDatabaseWrapper
|
372
|
+
) -> Any:
|
329
373
|
if isinstance(self.output_field, DateTimeField):
|
330
374
|
if value is not None:
|
331
375
|
value = value.replace(tzinfo=None)
|
@@ -348,12 +392,12 @@ class TruncBase(TimezoneMixin, Transform):
|
|
348
392
|
class Trunc(TruncBase):
|
349
393
|
def __init__(
|
350
394
|
self,
|
351
|
-
expression,
|
352
|
-
kind,
|
353
|
-
output_field=None,
|
354
|
-
tzinfo=None,
|
355
|
-
**extra,
|
356
|
-
):
|
395
|
+
expression: Any,
|
396
|
+
kind: str,
|
397
|
+
output_field: Field | None = None,
|
398
|
+
tzinfo: Any = None,
|
399
|
+
**extra: Any,
|
400
|
+
) -> None:
|
357
401
|
self.kind = kind
|
358
402
|
super().__init__(expression, output_field=output_field, tzinfo=tzinfo, **extra)
|
359
403
|
|
@@ -385,11 +429,14 @@ class TruncDate(TruncBase):
|
|
385
429
|
lookup_name = "date"
|
386
430
|
output_field = DateField()
|
387
431
|
|
388
|
-
def as_sql(
|
432
|
+
def as_sql(
|
433
|
+
self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
|
434
|
+
) -> tuple[str, tuple[Any, ...]]:
|
389
435
|
# Cast to date rather than truncate to date.
|
390
436
|
sql, params = compiler.compile(self.lhs)
|
391
437
|
tzname = self.get_tzname()
|
392
|
-
|
438
|
+
sql, params = connection.ops.datetime_cast_date_sql(sql, tuple(params), tzname)
|
439
|
+
return sql, tuple(params) if isinstance(params, list) else params
|
393
440
|
|
394
441
|
|
395
442
|
class TruncTime(TruncBase):
|
@@ -397,11 +444,14 @@ class TruncTime(TruncBase):
|
|
397
444
|
lookup_name = "time"
|
398
445
|
output_field = TimeField()
|
399
446
|
|
400
|
-
def as_sql(
|
447
|
+
def as_sql(
|
448
|
+
self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
|
449
|
+
) -> tuple[str, tuple[Any, ...]]:
|
401
450
|
# Cast to time rather than truncate to time.
|
402
451
|
sql, params = compiler.compile(self.lhs)
|
403
452
|
tzname = self.get_tzname()
|
404
|
-
|
453
|
+
sql, params = connection.ops.datetime_cast_time_sql(sql, tuple(params), tzname)
|
454
|
+
return sql, tuple(params) if isinstance(params, list) else params
|
405
455
|
|
406
456
|
|
407
457
|
class TruncHour(TruncBase):
|
plain/models/functions/math.py
CHANGED
@@ -1,5 +1,9 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from typing import TYPE_CHECKING, Any
|
4
|
+
|
1
5
|
from plain.models.expressions import Func, Value
|
2
|
-
from plain.models.fields import FloatField, IntegerField
|
6
|
+
from plain.models.fields import Field, FloatField, IntegerField
|
3
7
|
from plain.models.functions import Cast
|
4
8
|
from plain.models.functions.mixins import (
|
5
9
|
FixDecimalInputMixin,
|
@@ -7,6 +11,10 @@ from plain.models.functions.mixins import (
|
|
7
11
|
)
|
8
12
|
from plain.models.lookups import Transform
|
9
13
|
|
14
|
+
if TYPE_CHECKING:
|
15
|
+
from plain.models.backends.base.base import BaseDatabaseWrapper
|
16
|
+
from plain.models.sql.compiler import SQLCompiler
|
17
|
+
|
10
18
|
|
11
19
|
class Abs(Transform):
|
12
20
|
function = "ABS"
|
@@ -32,10 +40,15 @@ class ATan2(NumericOutputFieldMixin, Func):
|
|
32
40
|
function = "ATAN2"
|
33
41
|
arity = 2
|
34
42
|
|
35
|
-
def as_sqlite(
|
43
|
+
def as_sqlite(
|
44
|
+
self,
|
45
|
+
compiler: SQLCompiler,
|
46
|
+
connection: BaseDatabaseWrapper,
|
47
|
+
**extra_context: Any,
|
48
|
+
) -> tuple[str, tuple[Any, ...]]:
|
36
49
|
if not getattr(
|
37
50
|
connection.ops, "spatialite", False
|
38
|
-
) or connection.ops.spatial_version >= (5, 0, 0):
|
51
|
+
) or connection.ops.spatial_version >= (5, 0, 0): # type: ignore[attr-defined]
|
39
52
|
return self.as_sql(compiler, connection)
|
40
53
|
# This function is usually ATan2(y, x), returning the inverse tangent
|
41
54
|
# of y / x, but it's ATan2(x, y) on SpatiaLite < 5.0.0.
|
@@ -93,7 +106,12 @@ class Log(FixDecimalInputMixin, NumericOutputFieldMixin, Func):
|
|
93
106
|
function = "LOG"
|
94
107
|
arity = 2
|
95
108
|
|
96
|
-
def as_sqlite(
|
109
|
+
def as_sqlite(
|
110
|
+
self,
|
111
|
+
compiler: SQLCompiler,
|
112
|
+
connection: BaseDatabaseWrapper,
|
113
|
+
**extra_context: Any,
|
114
|
+
) -> tuple[str, tuple[Any, ...]]:
|
97
115
|
if not getattr(connection.ops, "spatialite", False):
|
98
116
|
return self.as_sql(compiler, connection)
|
99
117
|
# This function is usually Log(b, x) returning the logarithm of x to
|
@@ -127,13 +145,23 @@ class Random(NumericOutputFieldMixin, Func):
|
|
127
145
|
function = "RANDOM"
|
128
146
|
arity = 0
|
129
147
|
|
130
|
-
def as_mysql(
|
148
|
+
def as_mysql(
|
149
|
+
self,
|
150
|
+
compiler: SQLCompiler,
|
151
|
+
connection: BaseDatabaseWrapper,
|
152
|
+
**extra_context: Any,
|
153
|
+
) -> tuple[str, tuple[Any, ...]]:
|
131
154
|
return super().as_sql(compiler, connection, function="RAND", **extra_context)
|
132
155
|
|
133
|
-
def as_sqlite(
|
156
|
+
def as_sqlite(
|
157
|
+
self,
|
158
|
+
compiler: SQLCompiler,
|
159
|
+
connection: BaseDatabaseWrapper,
|
160
|
+
**extra_context: Any,
|
161
|
+
) -> tuple[str, tuple[Any, ...]]:
|
134
162
|
return super().as_sql(compiler, connection, function="RAND", **extra_context)
|
135
163
|
|
136
|
-
def get_group_by_cols(self):
|
164
|
+
def get_group_by_cols(self) -> list[Any]:
|
137
165
|
return []
|
138
166
|
|
139
167
|
|
@@ -142,16 +170,21 @@ class Round(FixDecimalInputMixin, Transform):
|
|
142
170
|
lookup_name = "round"
|
143
171
|
arity = None # Override Transform's arity=1 to enable passing precision.
|
144
172
|
|
145
|
-
def __init__(self, expression, precision=0, **extra):
|
173
|
+
def __init__(self, expression: Any, precision: int = 0, **extra: Any) -> None:
|
146
174
|
super().__init__(expression, precision, **extra)
|
147
175
|
|
148
|
-
def as_sqlite(
|
176
|
+
def as_sqlite(
|
177
|
+
self,
|
178
|
+
compiler: SQLCompiler,
|
179
|
+
connection: BaseDatabaseWrapper,
|
180
|
+
**extra_context: Any,
|
181
|
+
) -> tuple[str, tuple[Any, ...]]:
|
149
182
|
precision = self.get_source_expressions()[1]
|
150
183
|
if isinstance(precision, Value) and precision.value < 0:
|
151
184
|
raise ValueError("SQLite does not support negative precision.")
|
152
185
|
return super().as_sqlite(compiler, connection, **extra_context)
|
153
186
|
|
154
|
-
def _resolve_output_field(self):
|
187
|
+
def _resolve_output_field(self) -> Field:
|
155
188
|
source = self.get_source_expressions()[0]
|
156
189
|
return source.output_field
|
157
190
|
|
plain/models/functions/mixins.py
CHANGED
@@ -1,11 +1,23 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
1
3
|
import sys
|
4
|
+
from typing import TYPE_CHECKING, Any
|
2
5
|
|
3
6
|
from plain.models.fields import DecimalField, FloatField, IntegerField
|
4
7
|
from plain.models.functions import Cast
|
5
8
|
|
9
|
+
if TYPE_CHECKING:
|
10
|
+
from plain.models.backends.base.base import BaseDatabaseWrapper
|
11
|
+
from plain.models.sql.compiler import SQLCompiler
|
12
|
+
|
6
13
|
|
7
14
|
class FixDecimalInputMixin:
|
8
|
-
def as_postgresql(
|
15
|
+
def as_postgresql(
|
16
|
+
self,
|
17
|
+
compiler: SQLCompiler,
|
18
|
+
connection: BaseDatabaseWrapper,
|
19
|
+
**extra_context: Any,
|
20
|
+
) -> tuple[str, tuple[Any, ...]]:
|
9
21
|
# Cast FloatField to DecimalField as PostgreSQL doesn't support the
|
10
22
|
# following function signatures:
|
11
23
|
# - LOG(double, double)
|
@@ -24,18 +36,23 @@ class FixDecimalInputMixin:
|
|
24
36
|
|
25
37
|
|
26
38
|
class FixDurationInputMixin:
|
27
|
-
def as_mysql(
|
28
|
-
|
29
|
-
|
39
|
+
def as_mysql(
|
40
|
+
self,
|
41
|
+
compiler: SQLCompiler,
|
42
|
+
connection: BaseDatabaseWrapper,
|
43
|
+
**extra_context: Any,
|
44
|
+
) -> tuple[str, tuple[Any, ...]]:
|
45
|
+
sql, params = super().as_sql(compiler, connection, **extra_context) # type: ignore[misc]
|
46
|
+
if self.output_field.get_internal_type() == "DurationField": # type: ignore[attr-defined]
|
30
47
|
sql = f"CAST({sql} AS SIGNED)"
|
31
48
|
return sql, params
|
32
49
|
|
33
50
|
|
34
51
|
class NumericOutputFieldMixin:
|
35
|
-
def _resolve_output_field(self):
|
36
|
-
source_fields = self.get_source_fields()
|
52
|
+
def _resolve_output_field(self) -> DecimalField | FloatField:
|
53
|
+
source_fields = self.get_source_fields() # type: ignore[attr-defined]
|
37
54
|
if any(isinstance(s, DecimalField) for s in source_fields):
|
38
55
|
return DecimalField()
|
39
56
|
if any(isinstance(s, IntegerField) for s in source_fields):
|
40
57
|
return FloatField()
|
41
|
-
return super()._resolve_output_field() if source_fields else FloatField()
|
58
|
+
return super()._resolve_output_field() if source_fields else FloatField() # type: ignore[misc]
|