plain.postgres 0.84.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plain/postgres/CHANGELOG.md +1028 -0
- plain/postgres/README.md +925 -0
- plain/postgres/__init__.py +120 -0
- plain/postgres/agents/.claude/rules/plain-postgres.md +78 -0
- plain/postgres/aggregates.py +236 -0
- plain/postgres/backups/__init__.py +0 -0
- plain/postgres/backups/cli.py +148 -0
- plain/postgres/backups/clients.py +94 -0
- plain/postgres/backups/core.py +172 -0
- plain/postgres/base.py +1415 -0
- plain/postgres/cli/__init__.py +3 -0
- plain/postgres/cli/db.py +142 -0
- plain/postgres/cli/migrations.py +1085 -0
- plain/postgres/config.py +18 -0
- plain/postgres/connection.py +1331 -0
- plain/postgres/connections.py +77 -0
- plain/postgres/constants.py +13 -0
- plain/postgres/constraints.py +495 -0
- plain/postgres/database_url.py +94 -0
- plain/postgres/db.py +59 -0
- plain/postgres/default_settings.py +38 -0
- plain/postgres/deletion.py +475 -0
- plain/postgres/dialect.py +640 -0
- plain/postgres/entrypoints.py +4 -0
- plain/postgres/enums.py +103 -0
- plain/postgres/exceptions.py +217 -0
- plain/postgres/expressions.py +1912 -0
- plain/postgres/fields/__init__.py +2118 -0
- plain/postgres/fields/encrypted.py +354 -0
- plain/postgres/fields/json.py +413 -0
- plain/postgres/fields/mixins.py +30 -0
- plain/postgres/fields/related.py +1192 -0
- plain/postgres/fields/related_descriptors.py +290 -0
- plain/postgres/fields/related_lookups.py +223 -0
- plain/postgres/fields/related_managers.py +661 -0
- plain/postgres/fields/reverse_descriptors.py +229 -0
- plain/postgres/fields/reverse_related.py +328 -0
- plain/postgres/fields/timezones.py +143 -0
- plain/postgres/forms.py +773 -0
- plain/postgres/functions/__init__.py +189 -0
- plain/postgres/functions/comparison.py +127 -0
- plain/postgres/functions/datetime.py +454 -0
- plain/postgres/functions/math.py +140 -0
- plain/postgres/functions/mixins.py +59 -0
- plain/postgres/functions/text.py +282 -0
- plain/postgres/functions/window.py +125 -0
- plain/postgres/indexes.py +286 -0
- plain/postgres/lookups.py +758 -0
- plain/postgres/meta.py +584 -0
- plain/postgres/migrations/__init__.py +53 -0
- plain/postgres/migrations/autodetector.py +1379 -0
- plain/postgres/migrations/exceptions.py +54 -0
- plain/postgres/migrations/executor.py +188 -0
- plain/postgres/migrations/graph.py +364 -0
- plain/postgres/migrations/loader.py +377 -0
- plain/postgres/migrations/migration.py +180 -0
- plain/postgres/migrations/operations/__init__.py +34 -0
- plain/postgres/migrations/operations/base.py +139 -0
- plain/postgres/migrations/operations/fields.py +373 -0
- plain/postgres/migrations/operations/models.py +798 -0
- plain/postgres/migrations/operations/special.py +184 -0
- plain/postgres/migrations/optimizer.py +74 -0
- plain/postgres/migrations/questioner.py +340 -0
- plain/postgres/migrations/recorder.py +119 -0
- plain/postgres/migrations/serializer.py +378 -0
- plain/postgres/migrations/state.py +882 -0
- plain/postgres/migrations/utils.py +147 -0
- plain/postgres/migrations/writer.py +302 -0
- plain/postgres/options.py +207 -0
- plain/postgres/otel.py +231 -0
- plain/postgres/preflight.py +336 -0
- plain/postgres/query.py +2242 -0
- plain/postgres/query_utils.py +456 -0
- plain/postgres/registry.py +217 -0
- plain/postgres/schema.py +1885 -0
- plain/postgres/sql/__init__.py +40 -0
- plain/postgres/sql/compiler.py +1869 -0
- plain/postgres/sql/constants.py +22 -0
- plain/postgres/sql/datastructures.py +222 -0
- plain/postgres/sql/query.py +2947 -0
- plain/postgres/sql/where.py +374 -0
- plain/postgres/test/__init__.py +0 -0
- plain/postgres/test/pytest.py +117 -0
- plain/postgres/test/utils.py +18 -0
- plain/postgres/transaction.py +222 -0
- plain/postgres/types.py +92 -0
- plain/postgres/types.pyi +751 -0
- plain/postgres/utils.py +345 -0
- plain_postgres-0.84.0.dist-info/METADATA +937 -0
- plain_postgres-0.84.0.dist-info/RECORD +93 -0
- plain_postgres-0.84.0.dist-info/WHEEL +4 -0
- plain_postgres-0.84.0.dist-info/entry_points.txt +5 -0
- plain_postgres-0.84.0.dist-info/licenses/LICENSE +61 -0
|
@@ -0,0 +1,454 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from datetime import datetime
|
|
4
|
+
from typing import TYPE_CHECKING, Any
|
|
5
|
+
|
|
6
|
+
from plain.postgres.dialect import (
|
|
7
|
+
date_extract_sql,
|
|
8
|
+
date_trunc_sql,
|
|
9
|
+
datetime_cast_date_sql,
|
|
10
|
+
datetime_cast_time_sql,
|
|
11
|
+
datetime_extract_sql,
|
|
12
|
+
datetime_trunc_sql,
|
|
13
|
+
time_extract_sql,
|
|
14
|
+
time_trunc_sql,
|
|
15
|
+
)
|
|
16
|
+
from plain.postgres.expressions import Func
|
|
17
|
+
from plain.postgres.fields import (
|
|
18
|
+
DateField,
|
|
19
|
+
DateTimeField,
|
|
20
|
+
DurationField,
|
|
21
|
+
Field,
|
|
22
|
+
IntegerField,
|
|
23
|
+
TimeField,
|
|
24
|
+
)
|
|
25
|
+
from plain.postgres.lookups import (
|
|
26
|
+
Transform,
|
|
27
|
+
YearExact,
|
|
28
|
+
YearGt,
|
|
29
|
+
YearGte,
|
|
30
|
+
YearLt,
|
|
31
|
+
YearLte,
|
|
32
|
+
)
|
|
33
|
+
from plain.utils import timezone
|
|
34
|
+
|
|
35
|
+
if TYPE_CHECKING:
|
|
36
|
+
from plain.postgres.connection import DatabaseConnection
|
|
37
|
+
from plain.postgres.sql.compiler import SQLCompiler
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class TimezoneMixin(Transform):
    """Supply the timezone name used by timezone-sensitive transforms.

    The conversion must be applied to the input datetime *before* any SQL
    function runs. For example, 2015-12-31 23:00:00 -02:00 is stored as
    2016-01-01 01:00:00 +00:00, and results should follow the input value
    rather than the stored one.
    """

    tzinfo = None

    def get_tzname(self) -> str | None:
        """Return the name of ``self.tzinfo``, or the current timezone's name if unset."""
        explicit_tz = self.tzinfo
        if explicit_tz is not None:
            return timezone._get_timezone_name(explicit_tz)
        return timezone.get_current_timezone_name()
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class Extract(TimezoneMixin, Transform):
    """Extract one component (year, month, hour, ...) of a date/time value.

    The component is named by ``lookup_name``. Subclasses set it as a class
    attribute; otherwise it is taken from the constructor argument. The
    result is typed as an IntegerField.
    """

    # Component to extract. Subclasses set this; otherwise __init__ fills it in.
    lookup_name: str | None = None
    output_field = IntegerField()

    def __init__(
        self,
        expression: Any,
        lookup_name: str | None = None,
        tzinfo: Any = None,
        **extra: Any,
    ) -> None:
        """
        Args:
            expression: The date/time/duration expression to extract from.
            lookup_name: Component name. Only used when the class does not
                already define one; a class-level value takes precedence.
            tzinfo: Timezone applied when the input is a DateTimeField.
                Setting it for any other input type makes as_sql raise.
            **extra: Forwarded to the parent initializer.

        Raises:
            ValueError: If neither the class nor the argument provides a
                lookup_name.
        """
        if self.lookup_name is None:
            self.lookup_name = lookup_name
        if self.lookup_name is None:
            raise ValueError("lookup_name must be provided")
        self.tzinfo = tzinfo
        super().__init__(expression, **extra)

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """Compile the lhs and wrap it in the extract SQL for its field type.

        ``function``, ``template``, ``arg_joiner`` and ``extra_context`` are
        accepted for signature compatibility but are not used here.

        Raises:
            ValueError: If ``tzinfo`` is set but the lhs is not a
                DateTimeField, or if the lhs field type is unsupported.
        """
        # lookup_name is guaranteed to be str after __init__ validation
        assert self.lookup_name is not None
        sql, params = compiler.compile(self.lhs)
        lhs_output_field = self.lhs.output_field
        # Branch order matters. DateTimeField is tested before DateField
        # (DateTimeField subclasses DateField). Any non-datetime input with
        # tzinfo set is rejected before the remaining type checks.
        if isinstance(lhs_output_field, DateTimeField):
            tzname = self.get_tzname()
            sql, params = datetime_extract_sql(
                self.lookup_name, sql, tuple(params), tzname
            )
        elif self.tzinfo is not None:
            raise ValueError("tzinfo can only be used with DateTimeField.")
        elif isinstance(lhs_output_field, DateField):
            sql, params = date_extract_sql(self.lookup_name, sql, tuple(params))
        elif isinstance(lhs_output_field, TimeField):
            sql, params = time_extract_sql(self.lookup_name, sql, tuple(params))
        elif isinstance(lhs_output_field, DurationField):
            # PostgreSQL has native duration (interval) type
            sql, params = time_extract_sql(self.lookup_name, sql, tuple(params))
        else:
            # resolve_expression has already validated the output_field, so
            # this branch should not be reached in practice.
            raise ValueError("Tried to Extract from an invalid type.")
        return sql, list(params)

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Extract:
        """Resolve the expression, then validate the lhs type against lookup_name.

        Validation is skipped when the resolved lhs has no ``output_field``
        (the attribute is missing or None).

        Raises:
            ValueError: If the lhs is not a Date/DateTime/Time/DurationField,
                if a time component is requested from a plain DateField, or
                if a calendar component is requested from a DurationField.
        """
        copy = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        field = getattr(copy.lhs, "output_field", None)
        if field is None:
            return copy
        if not isinstance(field, DateField | DateTimeField | TimeField | DurationField):
            raise ValueError(
                "Extract input expression must be DateField, DateTimeField, "
                "TimeField, or DurationField."
            )
        # Passing dates to functions expecting datetimes is most likely a mistake.
        # An exact type check is used so DateTimeField (a DateField subclass)
        # is still allowed to extract time components.
        if type(field) == DateField and copy.lookup_name in (  # noqa: E721
            "hour",
            "minute",
            "second",
        ):
            raise ValueError(
                f"Cannot extract time component '{copy.lookup_name}' from DateField '{field.name}'."
            )
        if isinstance(field, DurationField) and copy.lookup_name in (
            "year",
            "iso_year",
            "month",
            "week",
            "week_day",
            "iso_week_day",
            "quarter",
        ):
            raise ValueError(
                f"Cannot extract component '{copy.lookup_name}' from DurationField '{field.name}'."
            )
        return copy
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
class ExtractYear(Extract):
    """Extract the ``year`` component."""

    lookup_name = "year"


class ExtractIsoYear(Extract):
    """Extract the ISO-8601 week-numbering year."""

    lookup_name = "iso_year"


class ExtractMonth(Extract):
    """Extract the ``month`` component."""

    lookup_name = "month"


class ExtractDay(Extract):
    """Extract the ``day`` component."""

    lookup_name = "day"


class ExtractWeek(Extract):
    """
    Extract the ISO-8601 week number (1-52, or 53); weeks start on Monday.
    """

    lookup_name = "week"


class ExtractWeekDay(Extract):
    """
    Extract the day of the week, numbered Sunday=1 through Saturday=7.

    Python equivalent: (mydatetime.isoweekday() % 7) + 1
    """

    lookup_name = "week_day"


class ExtractIsoWeekDay(Extract):
    """Extract the ISO-8601 day of the week: Monday=1 through Sunday=7."""

    lookup_name = "iso_week_day"


class ExtractQuarter(Extract):
    """Extract the ``quarter`` component."""

    lookup_name = "quarter"


class ExtractHour(Extract):
    """Extract the ``hour`` component."""

    lookup_name = "hour"


class ExtractMinute(Extract):
    """Extract the ``minute`` component."""

    lookup_name = "minute"


class ExtractSecond(Extract):
    """Extract the ``second`` component."""

    lookup_name = "second"
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
# Date components are available on DateField (and so on its subclasses).
for _lookup in (
    ExtractYear,
    ExtractMonth,
    ExtractDay,
    ExtractWeekDay,
    ExtractIsoWeekDay,
    ExtractWeek,
    ExtractIsoYear,
    ExtractQuarter,
):
    DateField.register_lookup(_lookup)

# Time components are registered on TimeField first, then DateTimeField.
for _field in (TimeField, DateTimeField):
    for _lookup in (ExtractHour, ExtractMinute, ExtractSecond):
        _field.register_lookup(_lookup)

# Both year-style extracts get the year comparison lookups.
for _extract in (ExtractYear, ExtractIsoYear):
    for _lookup in (YearExact, YearGt, YearGte, YearLt, YearLte):
        _extract.register_lookup(_lookup)

# Keep the module namespace free of the loop variables.
del _field, _extract, _lookup
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
class Now(Func):
    """The current time, taken at the start of the executing statement.

    Renders ``STATEMENT_TIMESTAMP()`` rather than ``CURRENT_TIMESTAMP``,
    which would report the start time of the enclosing transaction instead.
    """

    output_field = DateTimeField()
    template = "STATEMENT_TIMESTAMP()"
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
class TruncBase(TimezoneMixin, Transform):
    """Base class for transforms that truncate a date/time value to ``kind``.

    Subclasses (or ``Trunc``'s constructor) set ``kind``. The SQL is produced
    by the dialect trunc helper that matches the *output* field type.
    """

    # Truncation unit, e.g. "year" or "hour". Supplied by subclasses.
    kind: str | None = None

    def __init__(
        self,
        expression: Any,
        output_field: Field | None = None,
        tzinfo: Any = None,
        **extra: Any,
    ) -> None:
        """
        Args:
            expression: The date/time expression to truncate.
            output_field: Optional explicit result field.
            tzinfo: Timezone applied when the input is a DateTimeField.
                Setting it for any other input type makes as_sql raise.
            **extra: Forwarded to the parent initializer.
        """
        self.tzinfo = tzinfo
        super().__init__(expression, output_field=output_field, **extra)

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """Compile the lhs and wrap it in the trunc SQL for the output field type.

        ``function``, ``template``, ``arg_joiner`` and ``extra_context`` are
        accepted for signature compatibility but are not used here.

        Raises:
            ValueError: If ``tzinfo`` is set but the lhs is not a
                DateTimeField, or if the output field type is unsupported.
        """
        # kind is guaranteed to be str in subclasses
        assert self.kind is not None
        sql, params = compiler.compile(self.lhs)
        # A timezone only applies when the input is a datetime.
        tzname = None
        if isinstance(self.lhs.output_field, DateTimeField):
            tzname = self.get_tzname()
        elif self.tzinfo is not None:
            raise ValueError("tzinfo can only be used with DateTimeField.")
        # DateTimeField must be tested before DateField because it subclasses it.
        if isinstance(self.output_field, DateTimeField):
            sql, params = datetime_trunc_sql(self.kind, sql, tuple(params), tzname)
        elif isinstance(self.output_field, DateField):
            sql, params = date_trunc_sql(self.kind, sql, tuple(params), tzname)
        elif isinstance(self.output_field, TimeField):
            sql, params = time_trunc_sql(self.kind, sql, tuple(params), tzname)
        else:
            raise ValueError(
                "Trunc only valid on DateField, TimeField, or DateTimeField."
            )
        return sql, list(params)

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> TruncBase:
        """Resolve the expression, then validate input/output types against ``kind``.

        Raises:
            TypeError: If the lhs is not a DateField, TimeField or DateTimeField.
            ValueError: If the output field is not a date/time type, or the
                requested truncation does not make sense for the input type
                (e.g. a plain DateField truncated to an hour, or a TimeField
                truncated to a month).
        """
        copy = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        field = copy.lhs.output_field
        # DateTimeField is a subclass of DateField so this works for both.
        if not isinstance(field, DateField | TimeField):
            raise TypeError(
                f"{field.name!r} isn't a DateField, TimeField, or DateTimeField."
            )
        # If self.output_field was None, then accessing the field will trigger
        # the resolver to assign it to self.lhs.output_field.
        if not isinstance(copy.output_field, DateField | DateTimeField | TimeField):
            raise ValueError(
                "output_field must be either DateField, TimeField, or DateTimeField"
            )
        # Passing dates or times to functions expecting datetimes is most
        # likely a mistake.
        class_output_field = (
            self.__class__.output_field
            if isinstance(self.__class__.output_field, Field)
            else None
        )
        output_field = class_output_field or copy.output_field
        # Decides whether the error message names the actual output field
        # class or the generic "DateTimeField".
        has_explicit_output_field = (
            class_output_field or field.__class__ is not copy.output_field.__class__
        )
        if type(field) == DateField and (  # noqa: E721
            isinstance(output_field, DateTimeField)
            or copy.kind in ("hour", "minute", "second", "time")
        ):
            raise ValueError(
                "Cannot truncate DateField '{}' to {}.".format(
                    field.name,
                    output_field.__class__.__name__
                    if has_explicit_output_field
                    else "DateTimeField",
                )
            )
        elif isinstance(field, TimeField) and (
            isinstance(output_field, DateTimeField)
            or copy.kind in ("year", "quarter", "month", "week", "day", "date")
        ):
            raise ValueError(
                "Cannot truncate TimeField '{}' to {}.".format(
                    field.name,
                    output_field.__class__.__name__
                    if has_explicit_output_field
                    else "DateTimeField",
                )
            )
        return copy

    def convert_value(
        self, value: Any, expression: Any, connection: DatabaseConnection
    ) -> Any:
        """Normalize a database result to the Python type of the output field.

        For a DateTimeField output, a non-None value has its tzinfo stripped
        and is made aware in ``self.tzinfo`` via ``timezone.make_aware``. For
        Date/Time outputs, a ``datetime`` result is narrowed with
        ``.date()`` / ``.time()``. Anything else is returned unchanged.
        """
        if isinstance(self.output_field, DateTimeField):
            if value is not None:
                value = value.replace(tzinfo=None)
                value = timezone.make_aware(value, self.tzinfo)
        elif isinstance(value, datetime):
            # value cannot be None here: the isinstance check excludes it.
            # The former `if value is None: pass` branch was unreachable.
            if isinstance(self.output_field, DateField):
                value = value.date()
            elif isinstance(self.output_field, TimeField):
                value = value.time()
        return value
|
|
362
|
+
|
|
363
|
+
|
|
364
|
+
class Trunc(TruncBase):
    """Truncate to a ``kind`` chosen when the expression is constructed."""

    def __init__(
        self,
        expression: Any,
        kind: str,
        output_field: Field | None = None,
        tzinfo: Any = None,
        **extra: Any,
    ) -> None:
        # kind is assigned first so it is already set while the base
        # initializer runs.
        self.kind = kind
        super().__init__(
            expression,
            output_field=output_field,
            tzinfo=tzinfo,
            **extra,
        )
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
class TruncYear(TruncBase):
    """Truncate to the start of the year."""

    kind = "year"


class TruncQuarter(TruncBase):
    """Truncate to the start of the quarter."""

    kind = "quarter"


class TruncMonth(TruncBase):
    """Truncate to the start of the month."""

    kind = "month"


class TruncWeek(TruncBase):
    """Truncate to midnight of the week's Monday."""

    kind = "week"


class TruncDay(TruncBase):
    """Truncate to the start of the day."""

    kind = "day"
|
|
397
|
+
|
|
398
|
+
|
|
399
|
+
class TruncDate(TruncBase):
    """Date part of a datetime, obtained with a cast rather than a truncation."""

    kind = "date"
    lookup_name = "date"
    output_field = DateField()

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """Render the lhs cast to a date, using the timezone from get_tzname()."""
        lhs_sql, lhs_params = compiler.compile(self.lhs)
        cast_sql, cast_params = datetime_cast_date_sql(
            lhs_sql, tuple(lhs_params), self.get_tzname()
        )
        return cast_sql, list(cast_params)
|
|
418
|
+
|
|
419
|
+
|
|
420
|
+
class TruncTime(TruncBase):
    """Time part of a datetime, obtained with a cast rather than a truncation."""

    kind = "time"
    lookup_name = "time"
    output_field = TimeField()

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """Render the lhs cast to a time, using the timezone from get_tzname()."""
        lhs_sql, lhs_params = compiler.compile(self.lhs)
        cast_sql, cast_params = datetime_cast_time_sql(
            lhs_sql, tuple(lhs_params), self.get_tzname()
        )
        return cast_sql, list(cast_params)
|
|
439
|
+
|
|
440
|
+
|
|
441
|
+
class TruncHour(TruncBase):
    """Truncate to the start of the hour."""

    kind = "hour"


class TruncMinute(TruncBase):
    """Truncate to the start of the minute."""

    kind = "minute"


class TruncSecond(TruncBase):
    """Truncate to the start of the second."""

    kind = "second"
|
|
451
|
+
|
|
452
|
+
|
|
453
|
+
# Register the date/time cast transforms (lookup names "date" and "time")
# on DateTimeField.
for _transform in (TruncDate, TruncTime):
    DateTimeField.register_lookup(_transform)
del _transform
|
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING, Any
|
|
4
|
+
|
|
5
|
+
from plain.postgres.expressions import Func
|
|
6
|
+
from plain.postgres.fields import Field
|
|
7
|
+
from plain.postgres.functions.mixins import (
|
|
8
|
+
FixDecimalInputMixin,
|
|
9
|
+
NumericOutputFieldMixin,
|
|
10
|
+
)
|
|
11
|
+
from plain.postgres.lookups import Transform
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
pass
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class Abs(Transform):
    """Absolute value, rendered as SQL ``ABS()``."""

    function = "ABS"
    lookup_name = "abs"


class ACos(NumericOutputFieldMixin, Transform):
    """Arc cosine, rendered as SQL ``ACOS()``."""

    function = "ACOS"
    lookup_name = "acos"


class ASin(NumericOutputFieldMixin, Transform):
    """Arc sine, rendered as SQL ``ASIN()``."""

    function = "ASIN"
    lookup_name = "asin"


class ATan(NumericOutputFieldMixin, Transform):
    """Arc tangent, rendered as SQL ``ATAN()``."""

    function = "ATAN"
    lookup_name = "atan"


class ATan2(NumericOutputFieldMixin, Func):
    """Two-argument arc tangent, rendered as SQL ``ATAN2()``."""

    function = "ATAN2"
    arity = 2


class Ceil(Transform):
    """Ceiling, rendered as SQL ``CEILING()``."""

    function = "CEILING"
    lookup_name = "ceil"


class Cos(NumericOutputFieldMixin, Transform):
    """Cosine, rendered as SQL ``COS()``."""

    function = "COS"
    lookup_name = "cos"


class Cot(NumericOutputFieldMixin, Transform):
    """Cotangent, rendered as SQL ``COT()``."""

    function = "COT"
    lookup_name = "cot"


class Degrees(NumericOutputFieldMixin, Transform):
    """Radians-to-degrees conversion, rendered as SQL ``DEGREES()``."""

    function = "DEGREES"
    lookup_name = "degrees"


class Exp(NumericOutputFieldMixin, Transform):
    """Exponential, rendered as SQL ``EXP()``."""

    function = "EXP"
    lookup_name = "exp"


class Floor(Transform):
    """Floor, rendered as SQL ``FLOOR()``."""

    function = "FLOOR"
    lookup_name = "floor"


class Ln(NumericOutputFieldMixin, Transform):
    """Natural logarithm, rendered as SQL ``LN()``."""

    function = "LN"
    lookup_name = "ln"


class Log(FixDecimalInputMixin, NumericOutputFieldMixin, Func):
    """Two-argument SQL ``LOG()``; float inputs are cast by FixDecimalInputMixin."""

    function = "LOG"
    arity = 2


class Mod(FixDecimalInputMixin, NumericOutputFieldMixin, Func):
    """Two-argument SQL ``MOD()``; float inputs are cast by FixDecimalInputMixin."""

    function = "MOD"
    arity = 2


class Pi(NumericOutputFieldMixin, Func):
    """The constant pi, rendered as SQL ``PI()`` with no arguments."""

    function = "PI"
    arity = 0


class Power(NumericOutputFieldMixin, Func):
    """Two-argument SQL ``POWER()``."""

    function = "POWER"
    arity = 2


class Radians(NumericOutputFieldMixin, Transform):
    """Degrees-to-radians conversion, rendered as SQL ``RADIANS()``."""

    function = "RADIANS"
    lookup_name = "radians"
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
class Random(NumericOutputFieldMixin, Func):
    """A random value from SQL ``RANDOM()``; takes no arguments."""

    function = "RANDOM"
    arity = 0

    def get_group_by_cols(self) -> list[Any]:
        """Contribute no columns to GROUP BY."""
        # NOTE(review): presumably grouping by a volatile RANDOM() call would
        # split every row into its own group, hence excluded.
        return list()
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
class Round(FixDecimalInputMixin, Transform):
    """SQL ``ROUND()``, with ``precision`` passed as the second SQL argument (default 0)."""

    function = "ROUND"
    lookup_name = "round"
    # Transform otherwise fixes arity at 1; None allows the precision argument.
    arity = None

    def __init__(self, expression: Any, precision: int = 0, **extra: Any) -> None:
        super().__init__(expression, precision, **extra)

    def _resolve_output_field(self) -> Field:
        """Use the output field of the expression being rounded (the first source)."""
        expressions = self.get_source_expressions()
        rounded = expressions[0]
        return rounded.output_field
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
class Sign(Transform):
    """Sign of the value, rendered as SQL ``SIGN()``."""

    function = "SIGN"
    lookup_name = "sign"


class Sin(NumericOutputFieldMixin, Transform):
    """Sine, rendered as SQL ``SIN()``."""

    function = "SIN"
    lookup_name = "sin"


class Sqrt(NumericOutputFieldMixin, Transform):
    """Square root, rendered as SQL ``SQRT()``."""

    function = "SQRT"
    lookup_name = "sqrt"


class Tan(NumericOutputFieldMixin, Transform):
    """Tangent, rendered as SQL ``TAN()``."""

    function = "TAN"
    lookup_name = "tan"
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import sys
|
|
4
|
+
from typing import TYPE_CHECKING, Any
|
|
5
|
+
|
|
6
|
+
from plain.postgres.expressions import Func
|
|
7
|
+
from plain.postgres.fields import DecimalField, Field, FloatField, IntegerField
|
|
8
|
+
from plain.postgres.functions import Cast
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from plain.postgres.connection import DatabaseConnection
|
|
12
|
+
from plain.postgres.sql.compiler import SQLCompiler
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class FixDecimalInputMixin(Func):
    """
    Mixin for Func subclasses that need to convert FloatField to DecimalField.

    PostgreSQL doesn't support the following function signatures:
    - LOG(double, double)
    - MOD(double, double)

    Every source expression whose output field is a FloatField is wrapped in
    a Cast to a DecimalField before the SQL is generated.
    """

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """Compile the function with float arguments cast to a decimal type.

        The source expressions are replaced on a copy, not on ``self``.
        ``function``, ``template``, ``arg_joiner`` and ``extra_context`` are
        forwarded to the parent ``as_sql``. Previously the first three were
        silently dropped, so caller overrides had no effect.
        """
        # decimal_places = sys.float_info.dig, the number of decimal digits a
        # float can represent faithfully; max_digits is deliberately generous.
        output_field = DecimalField(decimal_places=sys.float_info.dig, max_digits=1000)

        clone = self.copy()
        clone.set_source_expressions(
            [
                Cast(expression, output_field)
                if isinstance(expression.output_field, FloatField)
                else expression
                for expression in self.get_source_expressions()
            ]
        )
        # Skip this mixin in the MRO so the clone is compiled by the next
        # as_sql implementation; calling clone.as_sql would cast again.
        return super(FixDecimalInputMixin, clone).as_sql(
            compiler,
            connection,
            function=function,
            template=template,
            arg_joiner=arg_joiner,
            **extra_context,
        )
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class NumericOutputFieldMixin(Func):
    """Choose a numeric output field based on the source field types.

    Precedence:
    1. Any DecimalField source gives a DecimalField.
    2. Otherwise, any IntegerField source gives a FloatField.
    3. Otherwise, the parent class's resolution is used.
    FloatField is the fallback when that resolution is falsy or there are no
    source fields.
    """

    def _resolve_output_field(self) -> DecimalField | FloatField | Field:
        sources = self.get_source_fields()
        if not sources:
            return FloatField()
        if any(isinstance(src, DecimalField) for src in sources):
            return DecimalField()
        if any(isinstance(src, IntegerField) for src in sources):
            return FloatField()
        inherited = super()._resolve_output_field()
        return inherited or FloatField()
|