plain.postgres 0.84.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93) hide show
  1. plain/postgres/CHANGELOG.md +1028 -0
  2. plain/postgres/README.md +925 -0
  3. plain/postgres/__init__.py +120 -0
  4. plain/postgres/agents/.claude/rules/plain-postgres.md +78 -0
  5. plain/postgres/aggregates.py +236 -0
  6. plain/postgres/backups/__init__.py +0 -0
  7. plain/postgres/backups/cli.py +148 -0
  8. plain/postgres/backups/clients.py +94 -0
  9. plain/postgres/backups/core.py +172 -0
  10. plain/postgres/base.py +1415 -0
  11. plain/postgres/cli/__init__.py +3 -0
  12. plain/postgres/cli/db.py +142 -0
  13. plain/postgres/cli/migrations.py +1085 -0
  14. plain/postgres/config.py +18 -0
  15. plain/postgres/connection.py +1331 -0
  16. plain/postgres/connections.py +77 -0
  17. plain/postgres/constants.py +13 -0
  18. plain/postgres/constraints.py +495 -0
  19. plain/postgres/database_url.py +94 -0
  20. plain/postgres/db.py +59 -0
  21. plain/postgres/default_settings.py +38 -0
  22. plain/postgres/deletion.py +475 -0
  23. plain/postgres/dialect.py +640 -0
  24. plain/postgres/entrypoints.py +4 -0
  25. plain/postgres/enums.py +103 -0
  26. plain/postgres/exceptions.py +217 -0
  27. plain/postgres/expressions.py +1912 -0
  28. plain/postgres/fields/__init__.py +2118 -0
  29. plain/postgres/fields/encrypted.py +354 -0
  30. plain/postgres/fields/json.py +413 -0
  31. plain/postgres/fields/mixins.py +30 -0
  32. plain/postgres/fields/related.py +1192 -0
  33. plain/postgres/fields/related_descriptors.py +290 -0
  34. plain/postgres/fields/related_lookups.py +223 -0
  35. plain/postgres/fields/related_managers.py +661 -0
  36. plain/postgres/fields/reverse_descriptors.py +229 -0
  37. plain/postgres/fields/reverse_related.py +328 -0
  38. plain/postgres/fields/timezones.py +143 -0
  39. plain/postgres/forms.py +773 -0
  40. plain/postgres/functions/__init__.py +189 -0
  41. plain/postgres/functions/comparison.py +127 -0
  42. plain/postgres/functions/datetime.py +454 -0
  43. plain/postgres/functions/math.py +140 -0
  44. plain/postgres/functions/mixins.py +59 -0
  45. plain/postgres/functions/text.py +282 -0
  46. plain/postgres/functions/window.py +125 -0
  47. plain/postgres/indexes.py +286 -0
  48. plain/postgres/lookups.py +758 -0
  49. plain/postgres/meta.py +584 -0
  50. plain/postgres/migrations/__init__.py +53 -0
  51. plain/postgres/migrations/autodetector.py +1379 -0
  52. plain/postgres/migrations/exceptions.py +54 -0
  53. plain/postgres/migrations/executor.py +188 -0
  54. plain/postgres/migrations/graph.py +364 -0
  55. plain/postgres/migrations/loader.py +377 -0
  56. plain/postgres/migrations/migration.py +180 -0
  57. plain/postgres/migrations/operations/__init__.py +34 -0
  58. plain/postgres/migrations/operations/base.py +139 -0
  59. plain/postgres/migrations/operations/fields.py +373 -0
  60. plain/postgres/migrations/operations/models.py +798 -0
  61. plain/postgres/migrations/operations/special.py +184 -0
  62. plain/postgres/migrations/optimizer.py +74 -0
  63. plain/postgres/migrations/questioner.py +340 -0
  64. plain/postgres/migrations/recorder.py +119 -0
  65. plain/postgres/migrations/serializer.py +378 -0
  66. plain/postgres/migrations/state.py +882 -0
  67. plain/postgres/migrations/utils.py +147 -0
  68. plain/postgres/migrations/writer.py +302 -0
  69. plain/postgres/options.py +207 -0
  70. plain/postgres/otel.py +231 -0
  71. plain/postgres/preflight.py +336 -0
  72. plain/postgres/query.py +2242 -0
  73. plain/postgres/query_utils.py +456 -0
  74. plain/postgres/registry.py +217 -0
  75. plain/postgres/schema.py +1885 -0
  76. plain/postgres/sql/__init__.py +40 -0
  77. plain/postgres/sql/compiler.py +1869 -0
  78. plain/postgres/sql/constants.py +22 -0
  79. plain/postgres/sql/datastructures.py +222 -0
  80. plain/postgres/sql/query.py +2947 -0
  81. plain/postgres/sql/where.py +374 -0
  82. plain/postgres/test/__init__.py +0 -0
  83. plain/postgres/test/pytest.py +117 -0
  84. plain/postgres/test/utils.py +18 -0
  85. plain/postgres/transaction.py +222 -0
  86. plain/postgres/types.py +92 -0
  87. plain/postgres/types.pyi +751 -0
  88. plain/postgres/utils.py +345 -0
  89. plain_postgres-0.84.0.dist-info/METADATA +937 -0
  90. plain_postgres-0.84.0.dist-info/RECORD +93 -0
  91. plain_postgres-0.84.0.dist-info/WHEEL +4 -0
  92. plain_postgres-0.84.0.dist-info/entry_points.txt +5 -0
  93. plain_postgres-0.84.0.dist-info/licenses/LICENSE +61 -0
@@ -0,0 +1,454 @@
1
+ from __future__ import annotations
2
+
3
+ from datetime import datetime
4
+ from typing import TYPE_CHECKING, Any
5
+
6
+ from plain.postgres.dialect import (
7
+ date_extract_sql,
8
+ date_trunc_sql,
9
+ datetime_cast_date_sql,
10
+ datetime_cast_time_sql,
11
+ datetime_extract_sql,
12
+ datetime_trunc_sql,
13
+ time_extract_sql,
14
+ time_trunc_sql,
15
+ )
16
+ from plain.postgres.expressions import Func
17
+ from plain.postgres.fields import (
18
+ DateField,
19
+ DateTimeField,
20
+ DurationField,
21
+ Field,
22
+ IntegerField,
23
+ TimeField,
24
+ )
25
+ from plain.postgres.lookups import (
26
+ Transform,
27
+ YearExact,
28
+ YearGt,
29
+ YearGte,
30
+ YearLt,
31
+ YearLte,
32
+ )
33
+ from plain.utils import timezone
34
+
35
+ if TYPE_CHECKING:
36
+ from plain.postgres.connection import DatabaseConnection
37
+ from plain.postgres.sql.compiler import SQLCompiler
38
+
39
+
40
class TimezoneMixin(Transform):
    """Give a transform an optional ``tzinfo`` and a way to name that timezone."""

    tzinfo = None

    def get_tzname(self) -> str | None:
        """Return the name of ``self.tzinfo``, or of the current timezone if unset."""
        # The conversion to this timezone has to be applied to the input
        # datetime *before* the function runs: 2015-12-31 23:00:00 -02:00 is
        # stored as 2016-01-01 01:00:00 +00:00, and results must reflect the
        # input datetime rather than the stored one.
        if self.tzinfo is not None:
            return timezone._get_timezone_name(self.tzinfo)
        return timezone.get_current_timezone_name()
52
+
53
+
54
class Extract(TimezoneMixin, Transform):
    """
    Extract a single named component (``lookup_name``) from a date, datetime,
    time, or duration expression. The declared output field is an
    ``IntegerField``.
    """

    lookup_name: str | None = None
    output_field = IntegerField()

    def __init__(
        self,
        expression: Any,
        lookup_name: str | None = None,
        tzinfo: Any = None,
        **extra: Any,
    ) -> None:
        """
        Store ``tzinfo`` and resolve the component name.

        The ``lookup_name`` argument is only used when the class does not
        already define one. Raises ``ValueError`` if neither provides a name.
        """
        if self.lookup_name is None:
            self.lookup_name = lookup_name
        if self.lookup_name is None:
            raise ValueError("lookup_name must be provided")
        self.tzinfo = tzinfo
        super().__init__(expression, **extra)

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """
        Compile the left-hand side and wrap it with the extract helper that
        matches its output field type.

        Raises ``ValueError`` when ``tzinfo`` is set on a non-datetime input or
        when the input type is unsupported.
        """
        # __init__ guarantees lookup_name is a str at this point.
        assert self.lookup_name is not None
        lhs_sql, lhs_params = compiler.compile(self.lhs)
        source_field = self.lhs.output_field
        extract_args = (self.lookup_name, lhs_sql, tuple(lhs_params))

        # DateTimeField is tested before DateField, so datetimes take the
        # timezone-aware path.
        if isinstance(source_field, DateTimeField):
            sql, params = datetime_extract_sql(*extract_args, self.get_tzname())
        elif self.tzinfo is not None:
            raise ValueError("tzinfo can only be used with DateTimeField.")
        elif isinstance(source_field, DateField):
            sql, params = date_extract_sql(*extract_args)
        elif isinstance(source_field, TimeField | DurationField):
            # PostgreSQL has a native duration (interval) type, so durations
            # go through the same helper as times.
            sql, params = time_extract_sql(*extract_args)
        else:
            # resolve_expression has already validated the input field type,
            # so this branch should never be reached.
            raise ValueError("Tried to Extract from an invalid type.")
        return sql, list(params)

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Extract:
        """
        Resolve the expression, then check that the requested component makes
        sense for the type of the input field.

        Raises ``ValueError`` for unsupported input types, for time components
        requested from a plain ``DateField``, and for calendar components
        requested from a ``DurationField``.
        """
        resolved = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        source_field = getattr(resolved.lhs, "output_field", None)
        if source_field is None:
            return resolved

        if not isinstance(
            source_field, DateField | DateTimeField | TimeField | DurationField
        ):
            raise ValueError(
                "Extract input expression must be DateField, DateTimeField, "
                "TimeField, or DurationField."
            )

        component = resolved.lookup_name
        time_components = ("hour", "minute", "second")
        calendar_components = (
            "year",
            "iso_year",
            "month",
            "week",
            "week_day",
            "iso_week_day",
            "quarter",
        )

        # Passing dates to functions expecting datetimes is most likely a
        # mistake. The exact type check keeps DateField subclasses out of it.
        if type(source_field) == DateField and component in time_components:  # noqa: E721
            raise ValueError(
                f"Cannot extract time component '{component}' from DateField '{source_field.name}'."
            )
        if isinstance(source_field, DurationField) and component in calendar_components:
            raise ValueError(
                f"Cannot extract component '{component}' from DurationField '{source_field.name}'."
            )
        return resolved
146
+
147
+
148
class ExtractYear(Extract):
    """Extract transform for the ``year`` component."""

    lookup_name = "year"
150
+
151
+
152
class ExtractIsoYear(Extract):
    """Extract transform for the ISO-8601 week-numbering year (``iso_year``)."""

    lookup_name = "iso_year"
156
+
157
+
158
class ExtractMonth(Extract):
    """Extract transform for the ``month`` component."""

    lookup_name = "month"
160
+
161
+
162
class ExtractDay(Extract):
    """Extract transform for the ``day`` component."""

    lookup_name = "day"
164
+
165
+
166
class ExtractWeek(Extract):
    """
    Extract transform for the ISO-8601 week number: 1-52 or 53, with Monday
    as the first day of the week.
    """

    lookup_name = "week"
173
+
174
+
175
class ExtractWeekDay(Extract):
    """
    Extract transform for the day of the week, Sunday=1 through Saturday=7.

    Python equivalent: ``(mydatetime.isoweekday() % 7) + 1``
    """

    lookup_name = "week_day"
183
+
184
+
185
class ExtractIsoWeekDay(Extract):
    """Extract transform for the ISO-8601 day of the week, Monday=1 through Sunday=7."""

    lookup_name = "iso_week_day"
189
+
190
+
191
class ExtractQuarter(Extract):
    """Extract transform for the ``quarter`` component."""

    lookup_name = "quarter"
193
+
194
+
195
class ExtractHour(Extract):
    """Extract transform for the ``hour`` component."""

    lookup_name = "hour"
197
+
198
+
199
class ExtractMinute(Extract):
    """Extract transform for the ``minute`` component."""

    lookup_name = "minute"
201
+
202
+
203
class ExtractSecond(Extract):
    """Extract transform for the ``second`` component."""

    lookup_name = "second"
205
+
206
+
207
# Calendar components are registered on DateField.
for _transform in (
    ExtractYear,
    ExtractMonth,
    ExtractDay,
    ExtractWeekDay,
    ExtractIsoWeekDay,
    ExtractWeek,
    ExtractIsoYear,
    ExtractQuarter,
):
    DateField.register_lookup(_transform)

# Time-of-day components are registered on TimeField first, then DateTimeField.
for _field in (TimeField, DateTimeField):
    for _transform in (ExtractHour, ExtractMinute, ExtractSecond):
        _field.register_lookup(_transform)

# Both year transforms get the same set of year comparison lookups.
for _year_transform in (ExtractYear, ExtractIsoYear):
    for _lookup in (YearExact, YearGt, YearGte, YearLt, YearLte):
        _year_transform.register_lookup(_lookup)

# Keep the loop variables out of the module namespace.
del _transform, _field, _year_transform, _lookup
235
+
236
+
237
class Now(Func):
    """
    Database-side current timestamp.

    Rendered as STATEMENT_TIMESTAMP(), which is the time the current statement
    started. CURRENT_TIMESTAMP would instead give the time the transaction
    started.
    """

    template = "STATEMENT_TIMESTAMP()"
    output_field = DateTimeField()
243
+
244
+
245
class TruncBase(TimezoneMixin, Transform):
    """
    Base transform for truncating a date, time, or datetime expression to the
    precision named by ``kind``.

    Subclasses set ``kind`` on the class; ``Trunc`` takes it as an argument.
    """

    kind: str | None = None

    def __init__(
        self,
        expression: Any,
        output_field: Field | None = None,
        tzinfo: Any = None,
        **extra: Any,
    ) -> None:
        """Store ``tzinfo`` and pass the expression and ``output_field`` to the parent."""
        self.tzinfo = tzinfo
        super().__init__(expression, output_field=output_field, **extra)

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """
        Compile the left-hand side and wrap it with the truncation helper that
        matches ``self.output_field``.

        A timezone name is only computed when the input is a DateTimeField.
        Raises ``ValueError`` when ``tzinfo`` is set on a non-datetime input or
        when the output field type is unsupported.
        """
        # kind is guaranteed to be str in subclasses
        assert self.kind is not None
        sql, params = compiler.compile(self.lhs)
        tzname = None
        if isinstance(self.lhs.output_field, DateTimeField):
            tzname = self.get_tzname()
        elif self.tzinfo is not None:
            raise ValueError("tzinfo can only be used with DateTimeField.")
        # DateTimeField is tested before DateField so datetimes use the
        # datetime helper.
        if isinstance(self.output_field, DateTimeField):
            sql, params = datetime_trunc_sql(self.kind, sql, tuple(params), tzname)
        elif isinstance(self.output_field, DateField):
            sql, params = date_trunc_sql(self.kind, sql, tuple(params), tzname)
        elif isinstance(self.output_field, TimeField):
            sql, params = time_trunc_sql(self.kind, sql, tuple(params), tzname)
        else:
            raise ValueError(
                "Trunc only valid on DateField, TimeField, or DateTimeField."
            )
        return sql, list(params)

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> TruncBase:
        """
        Resolve the expression, then check that the input field, the output
        field, and ``kind`` are compatible.

        Raises ``TypeError`` when the input is not a date, time, or datetime
        field. Raises ``ValueError`` for an invalid output field or for a
        truncation that makes no sense for the input type.
        """
        copy = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        field = copy.lhs.output_field
        # DateTimeField is a subclass of DateField so this works for both.
        if not isinstance(field, DateField | TimeField):
            raise TypeError(
                f"{field.name!r} isn't a DateField, TimeField, or DateTimeField."
            )
        # If self.output_field was None, then accessing the field will trigger
        # the resolver to assign it to self.lhs.output_field.
        if not isinstance(copy.output_field, DateField | DateTimeField | TimeField):
            raise ValueError(
                "output_field must be either DateField, TimeField, or DateTimeField"
            )
        # Passing dates or times to functions expecting datetimes is most
        # likely a mistake.
        class_output_field = (
            self.__class__.output_field
            if isinstance(self.__class__.output_field, Field)
            else None
        )
        output_field = class_output_field or copy.output_field
        has_explicit_output_field = (
            class_output_field or field.__class__ is not copy.output_field.__class__
        )
        # Target type named in the error message, shared by both checks below.
        target_name = (
            output_field.__class__.__name__
            if has_explicit_output_field
            else "DateTimeField"
        )
        if type(field) == DateField and (  # noqa: E721
            isinstance(output_field, DateTimeField)
            or copy.kind in ("hour", "minute", "second", "time")
        ):
            raise ValueError(
                f"Cannot truncate DateField '{field.name}' to {target_name}."
            )
        elif isinstance(field, TimeField) and (
            isinstance(output_field, DateTimeField)
            or copy.kind in ("year", "quarter", "month", "week", "day", "date")
        ):
            raise ValueError(
                f"Cannot truncate TimeField '{field.name}' to {target_name}."
            )
        return copy

    def convert_value(
        self, value: Any, expression: Any, connection: DatabaseConnection
    ) -> Any:
        """
        Convert the database value to match ``self.output_field``.

        For a DateTimeField output, a non-None value has its tzinfo dropped and
        is made aware again with ``self.tzinfo``. Otherwise a ``datetime``
        value is reduced to a date or a time according to the output field.
        Anything else, including ``None``, is returned unchanged.
        """
        if isinstance(self.output_field, DateTimeField):
            if value is not None:
                value = value.replace(tzinfo=None)
                value = timezone.make_aware(value, self.tzinfo)
        elif isinstance(value, datetime):
            # value cannot be None here, so no None check is needed.
            if isinstance(self.output_field, DateField):
                value = value.date()
            elif isinstance(self.output_field, TimeField):
                value = value.time()
        return value
362
+
363
+
364
class Trunc(TruncBase):
    """Truncation transform whose ``kind`` is passed in at construction time."""

    def __init__(
        self,
        expression: Any,
        kind: str,
        output_field: Field | None = None,
        tzinfo: Any = None,
        **extra: Any,
    ) -> None:
        """Record ``kind`` on the instance, then defer to ``TruncBase.__init__``."""
        self.kind = kind
        super().__init__(
            expression,
            output_field=output_field,
            tzinfo=tzinfo,
            **extra,
        )
375
+
376
+
377
class TruncYear(TruncBase):
    """Truncation transform with ``kind`` fixed to ``year``."""

    kind = "year"
379
+
380
+
381
class TruncQuarter(TruncBase):
    """Truncation transform with ``kind`` fixed to ``quarter``."""

    kind = "quarter"
383
+
384
+
385
class TruncMonth(TruncBase):
    """Truncation transform with ``kind`` fixed to ``month``."""

    kind = "month"
387
+
388
+
389
class TruncWeek(TruncBase):
    """Truncation transform for ``week``: midnight on the Monday of the week."""

    kind = "week"
393
+
394
+
395
class TruncDay(TruncBase):
    """Truncation transform with ``kind`` fixed to ``day``."""

    kind = "day"
397
+
398
+
399
class TruncDate(TruncBase):
    """Cast a datetime expression to a date; registered as the ``date`` lookup."""

    kind = "date"
    lookup_name = "date"
    output_field = DateField()

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """Compile the left-hand side and cast it to a date in the active timezone."""
        # This is a cast to date, not a truncation to date.
        lhs_sql, lhs_params = compiler.compile(self.lhs)
        tzname = self.get_tzname()
        cast_sql, cast_params = datetime_cast_date_sql(
            lhs_sql, tuple(lhs_params), tzname
        )
        return cast_sql, list(cast_params)
418
+
419
+
420
class TruncTime(TruncBase):
    """Cast a datetime expression to a time; registered as the ``time`` lookup."""

    kind = "time"
    lookup_name = "time"
    output_field = TimeField()

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """Compile the left-hand side and cast it to a time in the active timezone."""
        # This is a cast to time, not a truncation to time.
        lhs_sql, lhs_params = compiler.compile(self.lhs)
        tzname = self.get_tzname()
        cast_sql, cast_params = datetime_cast_time_sql(
            lhs_sql, tuple(lhs_params), tzname
        )
        return cast_sql, list(cast_params)
439
+
440
+
441
class TruncHour(TruncBase):
    """Truncation transform with ``kind`` fixed to ``hour``."""

    kind = "hour"
443
+
444
+
445
class TruncMinute(TruncBase):
    """Truncation transform with ``kind`` fixed to ``minute``."""

    kind = "minute"
447
+
448
+
449
class TruncSecond(TruncBase):
    """Truncation transform with ``kind`` fixed to ``second``."""

    kind = "second"
451
+
452
+
453
# Register the date and time casts as lookups on DateTimeField.
DateTimeField.register_lookup(TruncDate)
DateTimeField.register_lookup(TruncTime)
@@ -0,0 +1,140 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, Any
4
+
5
+ from plain.postgres.expressions import Func
6
+ from plain.postgres.fields import Field
7
+ from plain.postgres.functions.mixins import (
8
+ FixDecimalInputMixin,
9
+ NumericOutputFieldMixin,
10
+ )
11
+ from plain.postgres.lookups import Transform
12
+
13
+ if TYPE_CHECKING:
14
+ pass
15
+
16
+
17
class Abs(Transform):
    """Transform for the SQL ``ABS`` function (lookup name ``abs``)."""

    function = "ABS"
    lookup_name = "abs"
20
+
21
+
22
class ACos(NumericOutputFieldMixin, Transform):
    """Transform for the SQL ``ACOS`` function (lookup name ``acos``)."""

    function = "ACOS"
    lookup_name = "acos"
25
+
26
+
27
class ASin(NumericOutputFieldMixin, Transform):
    """Transform for the SQL ``ASIN`` function (lookup name ``asin``)."""

    function = "ASIN"
    lookup_name = "asin"
30
+
31
+
32
class ATan(NumericOutputFieldMixin, Transform):
    """Transform for the SQL ``ATAN`` function (lookup name ``atan``)."""

    function = "ATAN"
    lookup_name = "atan"
35
+
36
+
37
class ATan2(NumericOutputFieldMixin, Func):
    """SQL ``ATAN2`` function, declared with ``arity = 2``."""

    function = "ATAN2"
    arity = 2
40
+
41
+
42
class Ceil(Transform):
    """Transform for the SQL ``CEILING`` function (lookup name ``ceil``)."""

    function = "CEILING"
    lookup_name = "ceil"
45
+
46
+
47
class Cos(NumericOutputFieldMixin, Transform):
    """Transform for the SQL ``COS`` function (lookup name ``cos``)."""

    function = "COS"
    lookup_name = "cos"
50
+
51
+
52
class Cot(NumericOutputFieldMixin, Transform):
    """Transform for the SQL ``COT`` function (lookup name ``cot``)."""

    function = "COT"
    lookup_name = "cot"
55
+
56
+
57
class Degrees(NumericOutputFieldMixin, Transform):
    """Transform for the SQL ``DEGREES`` function (lookup name ``degrees``)."""

    function = "DEGREES"
    lookup_name = "degrees"
60
+
61
+
62
class Exp(NumericOutputFieldMixin, Transform):
    """Transform for the SQL ``EXP`` function (lookup name ``exp``)."""

    function = "EXP"
    lookup_name = "exp"
65
+
66
+
67
class Floor(Transform):
    """Transform for the SQL ``FLOOR`` function (lookup name ``floor``)."""

    function = "FLOOR"
    lookup_name = "floor"
70
+
71
+
72
class Ln(NumericOutputFieldMixin, Transform):
    """Transform for the SQL ``LN`` function (lookup name ``ln``)."""

    function = "LN"
    lookup_name = "ln"
75
+
76
+
77
class Log(FixDecimalInputMixin, NumericOutputFieldMixin, Func):
    """SQL ``LOG`` function, declared with ``arity = 2``."""

    function = "LOG"
    arity = 2
80
+
81
+
82
class Mod(FixDecimalInputMixin, NumericOutputFieldMixin, Func):
    """SQL ``MOD`` function, declared with ``arity = 2``."""

    function = "MOD"
    arity = 2
85
+
86
+
87
class Pi(NumericOutputFieldMixin, Func):
    """SQL ``PI`` function, declared with ``arity = 0``."""

    function = "PI"
    arity = 0
90
+
91
+
92
class Power(NumericOutputFieldMixin, Func):
    """SQL ``POWER`` function, declared with ``arity = 2``."""

    function = "POWER"
    arity = 2
95
+
96
+
97
class Radians(NumericOutputFieldMixin, Transform):
    """Transform for the SQL ``RADIANS`` function (lookup name ``radians``)."""

    function = "RADIANS"
    lookup_name = "radians"
100
+
101
+
102
class Random(NumericOutputFieldMixin, Func):
    """SQL ``RANDOM`` function, declared with ``arity = 0``."""

    function = "RANDOM"
    arity = 0

    def get_group_by_cols(self) -> list[Any]:
        """Return an empty list: this expression reports no GROUP BY columns."""
        return []
108
+
109
+
110
class Round(FixDecimalInputMixin, Transform):
    """
    Transform for the SQL ``ROUND`` function (lookup name ``round``), with an
    optional precision argument.
    """

    function = "ROUND"
    lookup_name = "round"
    # Transform fixes arity at 1; clearing it lets precision be passed as well.
    arity = None

    def __init__(self, expression: Any, precision: int = 0, **extra: Any) -> None:
        """Pass ``expression`` and ``precision`` to the parent as source expressions."""
        super().__init__(expression, precision, **extra)

    def _resolve_output_field(self) -> Field:
        """Return the output field of the first source expression."""
        return self.get_source_expressions()[0].output_field
121
+
122
+
123
class Sign(Transform):
    """Transform for the SQL ``SIGN`` function (lookup name ``sign``)."""

    function = "SIGN"
    lookup_name = "sign"
126
+
127
+
128
class Sin(NumericOutputFieldMixin, Transform):
    """Transform for the SQL ``SIN`` function (lookup name ``sin``)."""

    function = "SIN"
    lookup_name = "sin"
131
+
132
+
133
class Sqrt(NumericOutputFieldMixin, Transform):
    """Transform for the SQL ``SQRT`` function (lookup name ``sqrt``)."""

    function = "SQRT"
    lookup_name = "sqrt"
136
+
137
+
138
class Tan(NumericOutputFieldMixin, Transform):
    """Transform for the SQL ``TAN`` function (lookup name ``tan``)."""

    function = "TAN"
    lookup_name = "tan"
@@ -0,0 +1,59 @@
1
+ from __future__ import annotations
2
+
3
+ import sys
4
+ from typing import TYPE_CHECKING, Any
5
+
6
+ from plain.postgres.expressions import Func
7
+ from plain.postgres.fields import DecimalField, Field, FloatField, IntegerField
8
+ from plain.postgres.functions import Cast
9
+
10
+ if TYPE_CHECKING:
11
+ from plain.postgres.connection import DatabaseConnection
12
+ from plain.postgres.sql.compiler import SQLCompiler
13
+
14
+
15
class FixDecimalInputMixin(Func):
    """
    Mixin for Func subclasses that need to convert FloatField to DecimalField.

    PostgreSQL doesn't support the following function signatures:
    - LOG(double, double)
    - MOD(double, double)
    """

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """
        Render the function with every FloatField source expression cast to
        a DecimalField.

        The casts are applied to a copy, so ``self`` is left untouched.
        ``function``, ``template`` and ``arg_joiner`` are forwarded to the
        parent ``as_sql`` so caller-supplied overrides are honoured rather
        than silently dropped.
        """
        output_field = DecimalField(decimal_places=sys.float_info.dig, max_digits=1000)

        clone = self.copy()
        clone.set_source_expressions(
            [
                Cast(expression, output_field)
                if isinstance(expression.output_field, FloatField)
                else expression
                for expression in self.get_source_expressions()
            ]
        )
        # Call the parent implementation on the clone; calling clone.as_sql()
        # would re-enter this method.
        return super(FixDecimalInputMixin, clone).as_sql(
            compiler,
            connection,
            function=function,
            template=template,
            arg_joiner=arg_joiner,
            **extra_context,
        )
47
+
48
+
49
class NumericOutputFieldMixin(Func):
    """Mixin that picks a numeric output field from the source field types."""

    def _resolve_output_field(self) -> DecimalField | FloatField | Field:
        """
        Return the output field for this function.

        Any DecimalField source gives a DecimalField. Otherwise any
        IntegerField source gives a FloatField. Otherwise the parent's
        resolution is used when there are source fields and it yields a truthy
        result. In every remaining case the result is a FloatField.
        """
        fields = self.get_source_fields()

        def _has(field_type: type) -> bool:
            return any(isinstance(candidate, field_type) for candidate in fields)

        if _has(DecimalField):
            return DecimalField()
        if _has(IntegerField):
            return FloatField()
        if fields:
            inherited = super()._resolve_output_field()
            if inherited:
                return inherited
        return FloatField()