plain.postgres 0.84.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. plain/postgres/CHANGELOG.md +1028 -0
  2. plain/postgres/README.md +925 -0
  3. plain/postgres/__init__.py +120 -0
  4. plain/postgres/agents/.claude/rules/plain-postgres.md +78 -0
  5. plain/postgres/aggregates.py +236 -0
  6. plain/postgres/backups/__init__.py +0 -0
  7. plain/postgres/backups/cli.py +148 -0
  8. plain/postgres/backups/clients.py +94 -0
  9. plain/postgres/backups/core.py +172 -0
  10. plain/postgres/base.py +1415 -0
  11. plain/postgres/cli/__init__.py +3 -0
  12. plain/postgres/cli/db.py +142 -0
  13. plain/postgres/cli/migrations.py +1085 -0
  14. plain/postgres/config.py +18 -0
  15. plain/postgres/connection.py +1331 -0
  16. plain/postgres/connections.py +77 -0
  17. plain/postgres/constants.py +13 -0
  18. plain/postgres/constraints.py +495 -0
  19. plain/postgres/database_url.py +94 -0
  20. plain/postgres/db.py +59 -0
  21. plain/postgres/default_settings.py +38 -0
  22. plain/postgres/deletion.py +475 -0
  23. plain/postgres/dialect.py +640 -0
  24. plain/postgres/entrypoints.py +4 -0
  25. plain/postgres/enums.py +103 -0
  26. plain/postgres/exceptions.py +217 -0
  27. plain/postgres/expressions.py +1912 -0
  28. plain/postgres/fields/__init__.py +2118 -0
  29. plain/postgres/fields/encrypted.py +354 -0
  30. plain/postgres/fields/json.py +413 -0
  31. plain/postgres/fields/mixins.py +30 -0
  32. plain/postgres/fields/related.py +1192 -0
  33. plain/postgres/fields/related_descriptors.py +290 -0
  34. plain/postgres/fields/related_lookups.py +223 -0
  35. plain/postgres/fields/related_managers.py +661 -0
  36. plain/postgres/fields/reverse_descriptors.py +229 -0
  37. plain/postgres/fields/reverse_related.py +328 -0
  38. plain/postgres/fields/timezones.py +143 -0
  39. plain/postgres/forms.py +773 -0
  40. plain/postgres/functions/__init__.py +189 -0
  41. plain/postgres/functions/comparison.py +127 -0
  42. plain/postgres/functions/datetime.py +454 -0
  43. plain/postgres/functions/math.py +140 -0
  44. plain/postgres/functions/mixins.py +59 -0
  45. plain/postgres/functions/text.py +282 -0
  46. plain/postgres/functions/window.py +125 -0
  47. plain/postgres/indexes.py +286 -0
  48. plain/postgres/lookups.py +758 -0
  49. plain/postgres/meta.py +584 -0
  50. plain/postgres/migrations/__init__.py +53 -0
  51. plain/postgres/migrations/autodetector.py +1379 -0
  52. plain/postgres/migrations/exceptions.py +54 -0
  53. plain/postgres/migrations/executor.py +188 -0
  54. plain/postgres/migrations/graph.py +364 -0
  55. plain/postgres/migrations/loader.py +377 -0
  56. plain/postgres/migrations/migration.py +180 -0
  57. plain/postgres/migrations/operations/__init__.py +34 -0
  58. plain/postgres/migrations/operations/base.py +139 -0
  59. plain/postgres/migrations/operations/fields.py +373 -0
  60. plain/postgres/migrations/operations/models.py +798 -0
  61. plain/postgres/migrations/operations/special.py +184 -0
  62. plain/postgres/migrations/optimizer.py +74 -0
  63. plain/postgres/migrations/questioner.py +340 -0
  64. plain/postgres/migrations/recorder.py +119 -0
  65. plain/postgres/migrations/serializer.py +378 -0
  66. plain/postgres/migrations/state.py +882 -0
  67. plain/postgres/migrations/utils.py +147 -0
  68. plain/postgres/migrations/writer.py +302 -0
  69. plain/postgres/options.py +207 -0
  70. plain/postgres/otel.py +231 -0
  71. plain/postgres/preflight.py +336 -0
  72. plain/postgres/query.py +2242 -0
  73. plain/postgres/query_utils.py +456 -0
  74. plain/postgres/registry.py +217 -0
  75. plain/postgres/schema.py +1885 -0
  76. plain/postgres/sql/__init__.py +40 -0
  77. plain/postgres/sql/compiler.py +1869 -0
  78. plain/postgres/sql/constants.py +22 -0
  79. plain/postgres/sql/datastructures.py +222 -0
  80. plain/postgres/sql/query.py +2947 -0
  81. plain/postgres/sql/where.py +374 -0
  82. plain/postgres/test/__init__.py +0 -0
  83. plain/postgres/test/pytest.py +117 -0
  84. plain/postgres/test/utils.py +18 -0
  85. plain/postgres/transaction.py +222 -0
  86. plain/postgres/types.py +92 -0
  87. plain/postgres/types.pyi +751 -0
  88. plain/postgres/utils.py +345 -0
  89. plain_postgres-0.84.0.dist-info/METADATA +937 -0
  90. plain_postgres-0.84.0.dist-info/RECORD +93 -0
  91. plain_postgres-0.84.0.dist-info/WHEEL +4 -0
  92. plain_postgres-0.84.0.dist-info/entry_points.txt +5 -0
  93. plain_postgres-0.84.0.dist-info/licenses/LICENSE +61 -0
@@ -0,0 +1,1912 @@
1
+ from __future__ import annotations
2
+
3
+ import copy
4
+ import datetime
5
+ import functools
6
+ import inspect
7
+ from collections import defaultdict
8
+ from decimal import Decimal
9
+ from functools import cached_property
10
+ from types import NoneType
11
+ from typing import TYPE_CHECKING, Any, Protocol, Self, runtime_checkable
12
+ from uuid import UUID
13
+
14
+ from plain.postgres import fields
15
+ from plain.postgres.constants import LOOKUP_SEP
16
+ from plain.postgres.db import NotSupportedError
17
+ from plain.postgres.dialect import (
18
+ CURRENT_ROW,
19
+ FOLLOWING,
20
+ PRECEDING,
21
+ UNBOUNDED_FOLLOWING,
22
+ UNBOUNDED_PRECEDING,
23
+ combine_expression,
24
+ quote_name,
25
+ subtract_temporals,
26
+ window_frame_range_start_end,
27
+ window_frame_rows_start_end,
28
+ )
29
+ from plain.postgres.exceptions import EmptyResultSet, FieldError, FullResultSet
30
+ from plain.postgres.query_utils import Q
31
+ from plain.utils.deconstruct import deconstructible
32
+ from plain.utils.hashable import make_hashable
33
+
34
+ if TYPE_CHECKING:
35
+ from collections.abc import Callable, Iterable, Sequence
36
+
37
+ from plain.postgres.connection import DatabaseConnection
38
+ from plain.postgres.fields import Field
39
+ from plain.postgres.lookups import Lookup, Transform
40
+ from plain.postgres.query import QuerySet
41
+ from plain.postgres.sql.compiler import SQLCompilable, SQLCompiler
42
+ from plain.postgres.sql.query import Query
43
+
44
# Public names exported by a star-import of this module.
__all__ = [
    # Core expression classes
    "F", "Value", "Case", "When", "Subquery", "Exists",
    "OuterRef", "Window", "ExpressionWrapper", "RawSQL", "OrderBy",
    # Base classes (for extension)
    "Func", "Expression", "Combinable",
    # Window frame specs
    "RowRange", "ValueRange",
]
65
+
66
+
67
@runtime_checkable
class ResolvableExpression(Protocol):
    """
    Structural type for any object exposing ``resolve_expression()``.

    Checked at runtime with ``isinstance`` to tell expression-like arguments
    apart from plain Python values that still need to be wrapped.
    """

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Any:
        """Prepare the expression for use in *query* and return the result."""
79
+
80
+
81
@runtime_checkable
class ReplaceableExpression(Protocol):
    """Structural type for objects offering ``replace_expressions()``."""

    def replace_expressions(self, replacements: dict[Any, Any]) -> Self:
        """Return the object with sub-expressions found in *replacements* swapped out."""
86
+
87
+
88
class Combinable:
    """
    Provide the ability to combine one or two objects with
    some connector. For example F('foo') + F('bar').

    Arithmetic operators and the explicit ``bit*()`` methods build a
    CombinedExpression; operands that are not already expressions are
    wrapped in Value() first. The ``&``, ``|`` and ``^`` operators are
    reserved for joining two conditional (boolean) expressions into a Q
    object and raise NotImplementedError for anything else.
    """

    # Arithmetic connectors
    ADD = "+"
    SUB = "-"
    MUL = "*"
    DIV = "/"
    POW = "^"
    # "%" is written doubled because the connector can end up in SQL strings
    # that also go through %-style parameter substitution.
    MOD = "%%"

    # Bitwise connectors. These are only produced by the bit*() methods; the
    # &, | and ^ operators are reserved for boolean composition.
    BITAND = "&"
    BITOR = "|"
    BITLEFTSHIFT = "<<"
    BITRIGHTSHIFT = ">>"
    BITXOR = "#"

    @staticmethod
    def _bitwise_misuse() -> NotImplementedError:
        """Build the error raised when &, | or ^ is misused for bitwise math."""
        return NotImplementedError(
            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
        )

    def _combine(
        self, other: Any, connector: str, reversed: bool
    ) -> CombinedExpression:
        """
        Return ``self <connector> other``, or ``other <connector> self`` when
        *reversed* is true (the reflected ``__r*__`` operators pass True).
        """
        # Everything must be resolvable; raw Python values become literals.
        operand = other if isinstance(other, ResolvableExpression) else Value(other)
        if reversed:
            return CombinedExpression(operand, connector, self)
        return CombinedExpression(self, connector, operand)

    # --- Arithmetic: self <op> other -------------------------------------

    def __neg__(self) -> CombinedExpression:
        # Unary minus is rendered as multiplication by -1.
        return self._combine(-1, self.MUL, False)

    def __add__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.ADD, False)

    def __sub__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.SUB, False)

    def __mul__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.MUL, False)

    def __truediv__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.DIV, False)

    def __mod__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.MOD, False)

    def __pow__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.POW, False)

    # --- Arithmetic: other <op> self -------------------------------------

    def __radd__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.ADD, True)

    def __rsub__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.SUB, True)

    def __rmul__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.MUL, True)

    def __rtruediv__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.DIV, True)

    def __rmod__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.MOD, True)

    def __rpow__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.POW, True)

    # --- Explicit bitwise operations -------------------------------------

    def bitand(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.BITAND, False)

    def bitor(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.BITOR, False)

    def bitxor(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.BITXOR, False)

    def bitleftshift(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.BITLEFTSHIFT, False)

    def bitrightshift(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.BITRIGHTSHIFT, False)

    # --- Boolean composition of conditional expressions ------------------

    def __and__(self, other: Any) -> Q:
        if not (
            getattr(self, "conditional", False)
            and getattr(other, "conditional", False)
        ):
            raise self._bitwise_misuse()
        return Q(self) & Q(other)

    def __or__(self, other: Any) -> Q:
        if not (
            getattr(self, "conditional", False)
            and getattr(other, "conditional", False)
        ):
            raise self._bitwise_misuse()
        return Q(self) | Q(other)

    def __xor__(self, other: Any) -> Q:
        if not (
            getattr(self, "conditional", False)
            and getattr(other, "conditional", False)
        ):
            raise self._bitwise_misuse()
        return Q(self) ^ Q(other)

    # Reflected boolean operators are always rejected.

    def __rand__(self, other: Any) -> None:
        raise self._bitwise_misuse()

    def __ror__(self, other: Any) -> None:
        raise self._bitwise_misuse()

    def __rxor__(self, other: Any) -> None:
        raise self._bitwise_misuse()

    def __invert__(self) -> NegatedExpression:
        return NegatedExpression(self)
220
+
221
+
222
class BaseExpression:
    """
    Base class for all query expressions.

    An expression is a tree node: its children are exposed through
    get_source_expressions()/set_source_expressions(), it renders itself to
    SQL in as_sql(), and it reports its result type through output_field
    (either passed to __init__ or inferred from its sources).
    """

    # NOTE(review): presumably the value used when the expression is evaluated
    # over an empty result set; NotImplemented opts out. Confirm in compiler.
    empty_result_set_value = NotImplemented
    # Aggregate-specific: resolve_expression() copies its `summarize` arg here.
    is_summary = False
    # Set when _resolve_output_field() returned None, so _output_field_or_none
    # can tell "no inferable type" apart from a genuine FieldError.
    _output_field_resolved_to_none = False
    # Can the expression be used in a WHERE clause?
    filterable = True
    # Can the expression be used as a source expression in Window?
    window_compatible = False

    def __init__(self, output_field: Field | None = None):
        # The instance attribute shadows the output_field cached_property, so
        # an explicit type bypasses inference entirely.
        if output_field is not None:
            self.output_field = output_field

    def __getstate__(self) -> dict[str, Any]:
        # convert_value may cache a lambda, which pickle cannot serialize;
        # drop it so it is recomputed on demand after unpickling.
        state = self.__dict__.copy()
        state.pop("convert_value", None)
        return state

    def get_db_converters(
        self, connection: DatabaseConnection
    ) -> list[Callable[..., Any]]:
        """
        Return the converters applied to raw database values: this
        expression's own converter (skipped when it is the no-op) followed by
        the output field's converters.
        """
        converters = []
        if self.convert_value is not self._convert_value_noop:
            converters.append(self.convert_value)
        converters.extend(self.output_field.get_db_converters(connection))
        return converters

    def get_source_expressions(self) -> list[Any]:
        """Return the child expressions; leaf expressions have none."""
        return []

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        """Replace the child expressions; a leaf only accepts an empty sequence."""
        assert not exprs

    def _parse_expressions(self, *expressions: Any) -> list[Any]:
        """
        Coerce raw arguments into expressions: resolvable objects pass
        through, strings become F() references, anything else becomes a
        Value() literal.
        """
        return [
            arg
            if isinstance(arg, ResolvableExpression)
            else (F(arg) if isinstance(arg, str) else Value(arg))
            for arg in expressions
        ]

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, Sequence[Any]]:
        """
        Return a (sql, params) tuple to be included in the current query.

        Arguments:
         * compiler: the query compiler responsible for generating the query.
           Must have a compile method, returning a (sql, [params]) tuple.
           Calling compiler(value) will return a quoted `value`.

         * connection: the database connection used for the current query.

        Return: (sql, params)
          Where `sql` is a string containing ordered sql parameters to be
          replaced with the elements of the list `params`.
        """
        raise NotImplementedError("Subclasses must implement as_sql()")

    @cached_property
    def contains_aggregate(self) -> bool:
        return any(
            expr and expr.contains_aggregate for expr in self.get_source_expressions()
        )

    @cached_property
    def contains_over_clause(self) -> bool:
        return any(
            expr and expr.contains_over_clause for expr in self.get_source_expressions()
        )

    @cached_property
    def contains_column_references(self) -> bool:
        return any(
            expr and expr.contains_column_references
            for expr in self.get_source_expressions()
        )

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Self:
        """
        Provide the chance to do any preprocessing or validation before being
        added to the query.

        Arguments:
         * query: the backend query implementation
         * allow_joins: boolean allowing or denying use of joins
           in this query
         * reuse: a set of reusable joins for multijoins
         * summarize: a terminal aggregate clause
         * for_save: whether this expression about to be used in a save or update

        Return: an Expression to be added to the query.
        """
        c = self.copy()
        c.is_summary = summarize
        # Note: for_save is not forwarded to the source expressions here.
        c.set_source_expressions(
            [
                expr.resolve_expression(query, allow_joins, reuse, summarize)
                if expr
                else None
                for expr in c.get_source_expressions()
            ]
        )
        return c

    @property
    def conditional(self) -> bool:
        """True when the output field is a BooleanField."""
        output_field = getattr(self, "output_field", None)
        return isinstance(output_field, fields.BooleanField)

    @property
    def field(self) -> Field:
        return self.output_field

    @cached_property
    def output_field(self) -> Field:
        """Return the output type of this expression."""
        output_field = self._resolve_output_field()
        if output_field is None:
            self._output_field_resolved_to_none = True
            raise FieldError("Cannot resolve expression type, unknown output_field")
        return output_field

    @cached_property
    def _output_field_or_none(self) -> Field | None:
        """
        Return the output field of this expression, or None if
        _resolve_output_field() didn't return an output type.
        """
        try:
            return self.output_field
        except FieldError:
            if not self._output_field_resolved_to_none:
                raise
            return None

    def _resolve_output_field(self) -> Field | None:
        """
        Attempt to infer the output type of the expression.

        As a guess, if the output fields of all source fields match then simply
        infer the same type here.

        If a source's output field resolves to None, exclude it from this check.
        If all sources are None, then an error is raised higher up the stack in
        the output_field property.
        """
        # This guess is mostly a bad idea, but there is quite a lot of code
        # (especially 3rd party Func subclasses) that depend on it, we'd need a
        # deprecation path to fix it.
        sources_iter = (
            source for source in self.get_source_fields() if source is not None
        )
        # Both loops draw from the same generator: the outer loop takes the
        # first non-None field, the inner loop checks every remaining field
        # against it, and the outer loop then returns after one iteration.
        for output_field in sources_iter:
            for source in sources_iter:
                if not isinstance(output_field, source.__class__):
                    raise FieldError(
                        f"Expression contains mixed types: {output_field.__class__.__name__}, {source.__class__.__name__}. You must "
                        "set output_field."
                    )
            return output_field
        return None

    @staticmethod
    def _convert_value_noop(
        value: Any, expression: Any, connection: DatabaseConnection
    ) -> Any:
        return value

    @cached_property
    def convert_value(self) -> Callable[[Any, Any, Any], Any]:
        """
        Expressions provide their own converters because users have the option
        of manually specifying the output_field which may be a different type
        from the one the database returns.
        """
        field = self.output_field
        internal_type = field.get_internal_type()
        if internal_type == "FloatField":
            return (
                lambda value, expression, connection: None
                if value is None
                else float(value)
            )
        elif internal_type.endswith("IntegerField"):
            return (
                lambda value, expression, connection: None
                if value is None
                else int(value)
            )
        elif internal_type == "DecimalField":
            return (
                lambda value, expression, connection: None
                if value is None
                else Decimal(value)
            )
        return self._convert_value_noop

    def get_lookup(self, lookup: str) -> type[Lookup] | None:
        return self.output_field.get_lookup(lookup)

    def get_transform(self, name: str) -> type[Transform] | None:
        return self.output_field.get_transform(name)  # type: ignore[return-type]

    def relabeled_clone(self, change_map: dict[str, str]) -> Self:
        """Return a copy whose source expressions are relabeled via *change_map*."""
        clone = self.copy()
        clone.set_source_expressions(
            [
                e.relabeled_clone(change_map) if e is not None else None
                for e in self.get_source_expressions()
            ]
        )
        return clone

    def replace_expressions(self, replacements: dict[BaseExpression, Any]) -> Self:
        """
        Return replacements[self] when present and truthy; otherwise a copy
        with the replacements applied recursively to the source expressions.
        """
        if replacement := replacements.get(self):
            return replacement
        clone = self.copy()
        source_expressions = clone.get_source_expressions()
        clone.set_source_expressions(
            [
                expr.replace_expressions(replacements) if expr else None
                for expr in source_expressions
            ]
        )
        return clone

    def get_refs(self) -> set[str]:
        """Return the union of the refs reported by the source expressions."""
        refs: set[str] = set()
        for expr in self.get_source_expressions():
            # Optional children are stored as None (see relabeled_clone and
            # resolve_expression); they contribute no references.
            if expr is not None:
                refs |= expr.get_refs()
        return refs

    def copy(self) -> Self:
        return copy.copy(self)

    def prefix_references(self, prefix: str) -> Self:
        """Return a copy in which every F() reference name is prefixed with *prefix*."""

        def prefixed(expr: Any) -> Any:
            # Keep empty (None) slots as-is, rewrite F names, recurse otherwise.
            if expr is None:
                return None
            if isinstance(expr, F):
                return F(f"{prefix}{expr.name}")
            return expr.prefix_references(prefix)

        clone = self.copy()
        clone.set_source_expressions(
            [prefixed(expr) for expr in self.get_source_expressions()]
        )
        return clone

    def get_group_by_cols(self) -> list[BaseExpression]:
        """
        Return [self] for non-aggregate expressions; otherwise collect the
        GROUP BY columns of the source expressions.
        """
        if not self.contains_aggregate:
            return [self]
        cols: list[BaseExpression] = []
        for source in self.get_source_expressions():
            cols.extend(source.get_group_by_cols())
        return cols

    def get_source_fields(self) -> list[Field | None]:
        """Return the output field (or None) of each source expression."""
        return [e._output_field_or_none for e in self.get_source_expressions()]

    def asc(self, **kwargs: Any) -> OrderBy:
        return OrderBy(self, **kwargs)

    def desc(self, **kwargs: Any) -> OrderBy:
        return OrderBy(self, descending=True, **kwargs)

    def reverse_ordering(self) -> Self:
        return self

    def flatten(self) -> Iterable[Any]:
        """
        Recursively yield this expression and all subexpressions, in
        depth-first order.
        """
        yield self
        for expr in self.get_source_expressions():
            if expr:
                if hasattr(expr, "flatten"):
                    yield from expr.flatten()
                else:
                    yield expr

    def select_format(
        self, compiler: SQLCompiler, sql: str, params: Sequence[Any]
    ) -> tuple[str, Sequence[Any]]:
        """Custom format for select clauses, delegated to the output field if it has one."""
        if output_field := getattr(self, "output_field", None):
            if select_format := getattr(output_field, "select_format", None):
                return select_format(compiler, sql, params)
        return sql, params
523
+
524
+
525
@deconstructible
class Expression(BaseExpression, Combinable):
    """
    An expression that can be combined with other expressions.

    Equality and hashing are structural: two expressions compare equal when
    their ``identity`` tuples match, i.e. same class and equivalent
    constructor arguments (defaults filled in).
    """

    # Set by @deconstructible decorator in __new__
    _constructor_args: tuple[tuple[Any, ...], dict[str, Any]]

    @cached_property
    def identity(self) -> tuple[Any, ...]:
        """
        Return ``(cls, (arg_name, hashable_value), ...)`` built from the
        constructor arguments. Model-bound fields are reduced to
        ``(model label, field name)``; unbound fields to their class; all
        other values go through make_hashable().
        """
        constructor_signature = inspect.signature(self.__init__)
        positional, keyword = self._constructor_args
        bound = constructor_signature.bind_partial(*positional, **keyword)
        bound.apply_defaults()

        def normalize(value: Any) -> Any:
            if not isinstance(value, fields.Field):
                return make_hashable(value)
            if value.name and value.model:
                return (value.model.model_options.label, value.name)
            return type(value)

        return (
            self.__class__,
            *((arg, normalize(value)) for arg, value in bound.arguments.items()),
        )

    def __eq__(self, other: object) -> bool:
        if isinstance(other, Expression):
            return other.identity == self.identity
        return NotImplemented

    def __hash__(self) -> int:
        return hash(self.identity)
558
+
559
+
560
# Type inference for CombinedExpression.output_field.
# Missing items will result in FieldError, by design.
#
# The current approach for NULL is based on lowest common denominator behavior
# i.e. if one of the supported databases is raising an error (rather than
# returning NULL) for `val <op> NULL`, then Plain raises FieldError.
#
# Each dict maps a connector to (lhs type, rhs type, result type) triples.
# Order matters: the triples are registered in list order and
# _resolve_combined_type() returns the first one whose types match via
# issubclass(), so earlier entries win.

_connector_combinations = [
    # Numeric operations - operands of same type.
    {
        connector: [
            (fields.IntegerField, fields.IntegerField, fields.IntegerField),
            (fields.FloatField, fields.FloatField, fields.FloatField),
            (fields.DecimalField, fields.DecimalField, fields.DecimalField),
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            Combinable.DIV,
            Combinable.MOD,
            Combinable.POW,
        )
    },
    # Numeric operations - operands of different type.
    {
        connector: [
            (fields.IntegerField, fields.DecimalField, fields.DecimalField),
            (fields.DecimalField, fields.IntegerField, fields.DecimalField),
            (fields.IntegerField, fields.FloatField, fields.FloatField),
            (fields.FloatField, fields.IntegerField, fields.FloatField),
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            Combinable.DIV,
            Combinable.MOD,
        )
    },
    # Bitwise operators.
    {
        connector: [
            (fields.IntegerField, fields.IntegerField, fields.IntegerField),
        ]
        for connector in (
            Combinable.BITAND,
            Combinable.BITOR,
            Combinable.BITLEFTSHIFT,
            Combinable.BITRIGHTSHIFT,
            Combinable.BITXOR,
        )
    },
    # Numeric with NULL.
    #
    # The field types are collected into ONE list per connector. A dict
    # comprehension looping over both connectors and field types would
    # reassign each connector key once per field type, keeping only the last
    # type (FloatField) and leaving IntegerField/DecimalField combined with
    # NULL unresolvable.
    {
        connector: [
            combination
            for field_type in (
                fields.IntegerField,
                fields.DecimalField,
                fields.FloatField,
            )
            for combination in (
                (field_type, NoneType, field_type),
                (NoneType, field_type, field_type),
            )
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            Combinable.DIV,
            Combinable.MOD,
            Combinable.POW,
        )
    },
    # Date/DateTimeField/DurationField/TimeField.
    {
        Combinable.ADD: [
            # Date/DateTimeField.
            (fields.DateField, fields.DurationField, fields.DateTimeField),
            (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
            (fields.DurationField, fields.DateField, fields.DateTimeField),
            (fields.DurationField, fields.DateTimeField, fields.DateTimeField),
            # DurationField.
            (fields.DurationField, fields.DurationField, fields.DurationField),
            # TimeField.
            (fields.TimeField, fields.DurationField, fields.TimeField),
            (fields.DurationField, fields.TimeField, fields.TimeField),
        ],
    },
    {
        Combinable.SUB: [
            # Date/DateTimeField.
            (fields.DateField, fields.DurationField, fields.DateTimeField),
            (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
            (fields.DateField, fields.DateField, fields.DurationField),
            (fields.DateField, fields.DateTimeField, fields.DurationField),
            (fields.DateTimeField, fields.DateField, fields.DurationField),
            (fields.DateTimeField, fields.DateTimeField, fields.DurationField),
            # DurationField.
            (fields.DurationField, fields.DurationField, fields.DurationField),
            # TimeField.
            (fields.TimeField, fields.DurationField, fields.TimeField),
            (fields.TimeField, fields.TimeField, fields.DurationField),
        ],
    },
]
661
+
662
# Registry consulted by _resolve_combined_type(): maps each connector string to
# its (lhs type, rhs type, result type) triples, in registration order.
_connector_combinators = defaultdict(list)


def register_combinable_fields(
    lhs: type[Field] | type[None],
    connector: str,
    rhs: type[Field] | type[None],
    result: type[Field],
) -> None:
    """
    Register that ``lhs <connector> rhs`` produces a ``result`` field.

    Example::

        register_combinable_fields(
            IntegerField, Combinable.ADD, FloatField, FloatField
        )

    Triples are appended, so earlier registrations win when several match.
    NOTE: _resolve_combined_type() memoizes lookups with functools.lru_cache,
    so a registration made after a type pair was first looked up is not seen
    for that pair.
    """
    triple = (lhs, rhs, result)
    bucket = _connector_combinators[connector]
    bucket.append(triple)
680
+
681
+
682
def _register_default_combinations() -> None:
    """Register every triple from _connector_combinations, preserving list order."""
    for combinations in _connector_combinations:
        for connector, field_types in combinations.items():
            for lhs, rhs, result in field_types:
                register_combinable_fields(lhs, connector, rhs, result)


# Run inside a function so the loop variables stay local. A bare module-level
# loop would leave stray globals (d, connector, field_types, lhs, rhs, result)
# that can silently mask typos in later code.
_register_default_combinations()
686
+
687
+
688
@functools.lru_cache(maxsize=128)
def _resolve_combined_type(
    connector: str, lhs_type: type[Field], rhs_type: type[Field]
) -> type[Field] | None:
    """
    Return the result field class for ``lhs_type <connector> rhs_type``.

    Scans the triples registered for *connector* in registration order and
    returns the result type of the first one whose lhs/rhs classes are
    superclasses of (or equal to) the given types; None when nothing matches.
    Results are memoized per (connector, lhs_type, rhs_type).
    """
    candidates = _connector_combinators.get(connector, ())
    return next(
        (
            result_type
            for expected_lhs, expected_rhs, result_type in candidates
            if issubclass(lhs_type, expected_lhs)
            and issubclass(rhs_type, expected_rhs)
        ),
        None,
    )
699
+
700
+
701
class CombinedExpression(Expression):
    """
    Binary SQL expression ``lhs <connector> rhs``.

    Built by the Combinable operators (e.g. ``F("a") + 1`` or
    ``F("x").bitand(3)``). Unless output_field is given explicitly, the result
    type is looked up in the registered combinable-field triples.
    """

    def __init__(
        self, lhs: Any, connector: str, rhs: Any, output_field: Field | None = None
    ):
        super().__init__(output_field=output_field)
        self.connector = connector
        self.lhs = lhs
        self.rhs = rhs

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    def __str__(self) -> str:
        return f"{self.lhs} {self.connector} {self.rhs}"

    def get_source_expressions(self) -> list[Any]:
        return [self.lhs, self.rhs]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.lhs, self.rhs = exprs

    def _resolve_output_field(self) -> Field | None:
        # Deliberately not using super(): BaseExpression._resolve_output_field()
        # guesses "all sources share one type", which is wrong for arithmetic
        # such as date - date -> duration. Consult the combinator registry.
        lhs_type = type(self.lhs._output_field_or_none)
        rhs_type = type(self.rhs._output_field_or_none)
        combined_type = _resolve_combined_type(self.connector, lhs_type, rhs_type)
        if combined_type is None:
            # Report the types that were actually looked up. Re-reading
            # lhs/rhs.output_field here would raise a generic "Cannot resolve
            # expression type" FieldError for operands without an output type
            # (NoneType), hiding this more specific message.
            raise FieldError(
                f"Cannot infer type of {self.connector!r} expression involving these "
                f"types: {lhs_type.__name__}, {rhs_type.__name__}. You must set "
                f"output_field."
            )
        return combined_type()

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, list[Any]]:
        expressions: list[str] = []
        expression_params: list[Any] = []
        for operand in (self.lhs, self.rhs):
            sql, params = compiler.compile(operand)
            expressions.append(sql)
            expression_params.extend(params)
        sql = combine_expression(self.connector, expressions)
        # Parenthesize so the result keeps its precedence when nested.
        return f"({sql})", expression_params

    @staticmethod
    def _internal_type_or_none(expression: Any) -> str | None:
        """Return the operand's output field internal type, or None if unknown."""
        try:
            return expression.output_field.get_internal_type()
        except (AttributeError, FieldError):
            return None

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> CombinedExpression | TemporalSubtraction:
        """
        Resolve both operands. Subtracting two operands of the same date/time
        type is rewritten into a TemporalSubtraction, which yields a duration.
        """
        lhs = self.lhs.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        rhs = self.rhs.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        if not isinstance(self, TemporalSubtraction):
            lhs_type = self._internal_type_or_none(lhs)
            rhs_type = self._internal_type_or_none(rhs)
            datetime_fields = {"DateField", "DateTimeField", "TimeField"}
            if (
                self.connector == self.SUB
                and lhs_type in datetime_fields
                and lhs_type == rhs_type
            ):
                return TemporalSubtraction(self.lhs, self.rhs).resolve_expression(
                    query,
                    allow_joins,
                    reuse,
                    summarize,
                    for_save,
                )
        c = self.copy()
        c.is_summary = summarize
        c.lhs = lhs
        c.rhs = rhs
        return c
796
+
797
+
798
class TemporalSubtraction(CombinedExpression):
    """
    ``lhs - rhs`` where both operands have the same date/time type.

    CombinedExpression.resolve_expression() substitutes this class when both
    sides are DateField, DateTimeField or TimeField of the same kind. The SQL
    comes from the dialect's subtract_temporals() and the result is always a
    DurationField.
    """

    output_field = fields.DurationField()

    def __init__(self, lhs: Any, rhs: Any):
        # The connector is fixed: this node only ever subtracts.
        super().__init__(lhs, self.SUB, rhs)

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, list[Any]]:
        compiled_lhs = compiler.compile(self.lhs)
        compiled_rhs = compiler.compile(self.rhs)
        lhs_internal_type = self.lhs.output_field.get_internal_type()
        sql, params = subtract_temporals(lhs_internal_type, compiled_lhs, compiled_rhs)
        return sql, [*params]
813
+
814
+
815
@deconstructible(path="plain.postgres.F")
class F(Combinable):
    """
    An object capable of resolving references to existing query objects.

    ``F(name)`` stores only the referenced name; resolve_expression() hands it
    to ``query.resolve_ref()`` to obtain the concrete expression.
    """

    def __init__(self, name: str):
        """
        Arguments:
         * name: the name of the field this expression references
        """
        self.name = name

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        return f"{class_name}({self.name})"

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Any:
        # for_save is accepted for interface compatibility but not forwarded.
        return query.resolve_ref(self.name, allow_joins, reuse, summarize)

    def replace_expressions(self, replacements: dict[Any, Any]) -> F:
        # F has no children: it is either replaced wholesale or kept as-is.
        return replacements.get(self, self)

    def asc(self, **kwargs: Any) -> OrderBy:
        return OrderBy(self, **kwargs)

    def desc(self, **kwargs: Any) -> OrderBy:
        return OrderBy(self, descending=True, **kwargs)

    def __eq__(self, other: object) -> bool:
        # The class check means subclasses (e.g. ResolvedOuterRef) never
        # compare equal to a plain F with the same name.
        if isinstance(other, F):
            return self.__class__ == other.__class__ and self.name == other.name
        return NotImplemented

    def __hash__(self) -> int:
        return hash(self.name)

    def copy(self) -> Self:
        return copy.copy(self)
858
+
859
+
860
class ResolvedOuterRef(F):
    """
    An object that contains a reference to an outer query.

    In this case, the reference to the outer query has been resolved because
    the inner query has been used as a subquery.
    """

    contains_aggregate = False
    contains_over_clause = False

    def as_sql(self, *args: Any, **kwargs: Any) -> None:
        """Always raises: an outer reference cannot be compiled on its own."""
        raise ValueError(
            "This queryset contains a reference to an outer query and may "
            "only be used in a subquery."
        )

    def resolve_expression(self, *args: Any, **kwargs: Any) -> Any:
        """Resolve like ``F`` and reject references to window expressions.

        Raises NotSupportedError if the resolved column carries an OVER
        clause. The result is flagged ``possibly_multivalued`` whenever the
        name traverses a relation (contains LOOKUP_SEP).
        """
        resolved = super().resolve_expression(*args, **kwargs)
        if resolved.contains_over_clause:
            raise NotSupportedError(
                f"Referencing outer query window expression is not supported: "
                f"{self.name}."
            )
        # FIXME: Rename possibly_multivalued to multivalued and fix detection
        # for non-multivalued JOINs (e.g. foreign key fields). This should take
        # into account only many-to-many and one-to-many relationships.
        resolved.possibly_multivalued = LOOKUP_SEP in self.name
        return resolved

    def relabeled_clone(self, relabels: dict[str, str]) -> ResolvedOuterRef:
        # Outer references are not affected by relabeling of the inner query.
        return self

    def get_group_by_cols(self) -> list[Any]:
        # Contributes nothing to the GROUP BY clause.
        return []
895
+
896
+
897
class OuterRef(F):
    """Reference to a field of the enclosing query, usable inside a Subquery.

    Resolving it produces a ``ResolvedOuterRef``. A nested
    ``OuterRef(OuterRef(...))`` resolves to its inner ``OuterRef``, peeling
    off one level of nesting.
    """

    contains_aggregate = False

    def resolve_expression(self, *args: Any, **kwargs: Any) -> ResolvedOuterRef | F:
        target = self.name
        return target if isinstance(target, self.__class__) else ResolvedOuterRef(target)

    def relabeled_clone(self, relabels: dict[str, str]) -> OuterRef:
        # Relabeling the inner query's aliases never affects an outer reference.
        return self
907
+
908
+
909
@deconstructible(path="plain.postgres.Func")
class Func(Expression):
    """An SQL function call.

    Renders ``template`` (default ``FUNCTION(arg, arg, ...)``) from the
    compiled source expressions. ``function``, ``template`` and ``arg_joiner``
    may be overridden per class, per instance (via ``**extra``), or per call
    to ``as_sql()``.
    """

    function = None
    template = "%(function)s(%(expressions)s)"
    arg_joiner = ", "
    arity = None  # The number of arguments the function accepts.

    def __init__(
        self, *expressions: Any, output_field: Field | None = None, **extra: Any
    ):
        """
        Arguments:
         * expressions: the function arguments (parsed via _parse_expressions).
         * output_field: optional explicit output field.
         * extra: extra template context stored on the instance.

        Raises TypeError when ``arity`` is set and the argument count differs.
        """
        if self.arity is not None and len(expressions) != self.arity:
            noun = "argument" if self.arity == 1 else "arguments"
            raise TypeError(
                f"'{self.__class__.__name__}' takes exactly {self.arity} {noun} "
                f"({len(expressions)} given)"
            )
        super().__init__(output_field=output_field)
        self.source_expressions: list[Any] = self._parse_expressions(*expressions)
        self.extra = extra

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        args = self.arg_joiner.join(str(arg) for arg in self.source_expressions)
        options = {**self.extra, **self._get_repr_options()}
        if not options:
            return f"{class_name}({args})"
        rendered = ", ".join(
            f"{key!s}={value!s}" for key, value in sorted(options.items())
        )
        return f"{class_name}({args}, {rendered})"

    def _get_repr_options(self) -> dict[str, Any]:
        """Return a dict of extra __init__() options to include in the repr."""
        return {}

    def get_source_expressions(self) -> list[Any]:
        return self.source_expressions

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.source_expressions = list(exprs)

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Self:
        """Return a copy whose arguments are all resolved against ``query``."""
        clone = self.copy()
        clone.is_summary = summarize
        # Slice assignment keeps the (already copied) list object in place.
        clone.source_expressions[:] = [
            source.resolve_expression(query, allow_joins, reuse, summarize, for_save)
            for source in clone.source_expressions
        ]
        return clone

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """Compile the arguments and render the function template.

        An argument that raises EmptyResultSet is replaced by its
        ``empty_result_set_value`` (re-raising when it has none). An argument
        that raises FullResultSet is replaced by ``Value(True)``.
        """
        sql_parts: list[str] = []
        params: list[Any] = []
        for source in self.source_expressions:
            try:
                part_sql, part_params = compiler.compile(source)
            except EmptyResultSet:
                fallback = getattr(source, "empty_result_set_value", NotImplemented)
                if fallback is NotImplemented:
                    raise
                part_sql, part_params = compiler.compile(Value(fallback))
            except FullResultSet:
                part_sql, part_params = compiler.compile(Value(True))
            sql_parts.append(part_sql)
            params.extend(part_params)

        context = {**self.extra, **extra_context}
        # Use the first supplied value in this order: the parameter to this
        # method, a value supplied in __init__()'s **extra (now in
        # `context`), or the value defined on the class.
        if function is None:
            context.setdefault("function", self.function)
        else:
            context["function"] = function
        chosen_template = template or context.get("template", self.template)
        joiner = arg_joiner or context.get("arg_joiner", self.arg_joiner)
        joined_args = joiner.join(sql_parts)
        context["expressions"] = joined_args
        context["field"] = joined_args
        return chosen_template % context, params

    def copy(self) -> Self:
        """Copy that owns its own argument list and extra-context dict."""
        duplicate = super().copy()
        duplicate.source_expressions = list(self.source_expressions)
        duplicate.extra = dict(self.extra)
        return duplicate
1013
+
1014
+
1015
@deconstructible(path="plain.postgres.Value")
class Value(Expression):
    """Represent a wrapped value as a node within an expression.

    The value is emitted as a query parameter (``%s``), or as ``NULL`` for
    ``None``. When an output field is known, the value is prepared through it
    first.
    """

    # Provide a default value for `for_save` in order to allow unresolved
    # instances to be compiled until a decision is taken in #25425.
    for_save = False

    def __init__(self, value: Any, output_field: Field | None = None):
        """
        Arguments:
         * value: the value this expression represents. The value will be
           added into the sql parameter list and properly quoted.

         * output_field: an instance of the model field type that this
           expression will return, such as IntegerField() or CharField().
        """
        super().__init__(output_field=output_field)
        self.value = value

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.value!r})"

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, list[Any]]:
        """Render the value as a placeholder plus parameter (or ``NULL``)."""
        prepared = self.value
        field = self._output_field_or_none
        if field is not None:
            # Saving and querying may prepare values differently.
            prepare = field.get_db_prep_save if self.for_save else field.get_db_prep_value
            prepared = prepare(prepared, connection=connection)
            if hasattr(field, "get_placeholder"):
                placeholder = field.get_placeholder(prepared, compiler, connection)  # type: ignore[call-non-callable]
                return placeholder, [prepared]
        if prepared is None:
            return "NULL", []
        return "%s", [prepared]

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Value:
        """Resolve as usual, remembering whether the value is being saved."""
        resolved = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        resolved.for_save = for_save
        return resolved

    def get_group_by_cols(self) -> list[Any]:
        # A constant never needs to appear in GROUP BY.
        return []

    def _resolve_output_field(self) -> Field | None:
        """Infer an output field from the Python type of the value.

        Order matters: ``bool`` is checked before ``int`` (bool subclasses
        int) and ``datetime`` before ``date`` (datetime subclasses date).
        Returns None for unrecognized types.
        """
        candidates = (
            (str, fields.CharField),
            (bool, fields.BooleanField),
            (int, fields.IntegerField),
            (float, fields.FloatField),
            (datetime.datetime, fields.DateTimeField),
            (datetime.date, fields.DateField),
            (datetime.time, fields.TimeField),
            (datetime.timedelta, fields.DurationField),
            (Decimal, fields.DecimalField),
            (bytes, fields.BinaryField),
            (UUID, fields.UUIDField),
        )
        for python_type, field_class in candidates:
            if isinstance(self.value, python_type):
                return field_class()
        return None

    @property
    def empty_result_set_value(self) -> Any:
        # A constant is unaffected by an empty result set.
        return self.value
1096
+
1097
+
1098
class RawSQL(Expression):
    """A raw SQL fragment with its own parameters, emitted in parentheses.

    Without an explicit ``output_field``, a plain ``fields.Field()`` is used.
    """

    def __init__(
        self, sql: str, params: Sequence[Any], output_field: Field | None = None
    ):
        self.sql = sql
        self.params = params
        super().__init__(
            output_field=fields.Field() if output_field is None else output_field
        )

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.sql}, {self.params})"

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, Sequence[Any]]:
        """Wrap the fragment in parentheses; params are passed through as-is."""
        wrapped = f"({self.sql})"
        return wrapped, self.params

    def get_group_by_cols(self) -> list[BaseExpression]:
        # The fragment itself is grouped on as a whole.
        return [self]
1117
+
1118
+
1119
class Star(Expression):
    """The SQL ``*`` wildcard; compiles to ``*`` with no parameters."""

    def __repr__(self) -> str:
        return "'*'"

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, list[Any]]:
        no_params: list[Any] = []
        return "*", no_params
1127
+
1128
+
1129
class Col(Expression):
    """A (possibly alias-qualified) reference to a table column.

    ``target`` supplies the column name (``target.column``). Unless an
    explicit ``output_field`` is given, ``target`` doubles as the output
    field.
    """

    contains_column_references = True
    possibly_multivalued = False

    def __init__(
        self, alias: str | None, target: Any, output_field: Field | None = None
    ):
        super().__init__(output_field=target if output_field is None else output_field)
        self.alias = alias
        self.target = target

    def __repr__(self) -> str:
        shown = [self.alias, str(self.target)] if self.alias else [str(self.target)]
        return f"{self.__class__.__name__}({', '.join(shown)})"

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, list[Any]]:
        """Render ``alias.column`` (or just ``column``), each part quoted by the compiler."""
        column = self.target.column
        parts = (self.alias, column) if self.alias else (column,)
        quoted = [compiler.quote_name_unless_alias(part) for part in parts]
        return ".".join(quoted), []

    def relabeled_clone(self, change_map: dict[str, str]) -> Self:
        """Return a copy pointing at the remapped alias (self if unaliased)."""
        if self.alias is None:
            return self
        new_alias = change_map.get(self.alias, self.alias)
        return self.__class__(new_alias, self.target, self.output_field)

    def get_group_by_cols(self) -> list[BaseExpression]:
        return [self]

    def get_db_converters(
        self, connection: DatabaseConnection
    ) -> list[Callable[..., Any]]:
        """Output-field converters, followed by the target's when they differ."""
        same_field = self.target == self.output_field
        converters = self.output_field.get_db_converters(connection)
        if same_field:
            return converters
        return converters + self.target.get_db_converters(connection)
1172
+
1173
+
1174
class Ref(Expression):
    """
    Reference to column alias of the query. For example, Ref('sum_cost') in
    qs.annotate(sum_cost=Sum('cost')) query.

    Compiles to the quoted alias name; ``source`` is the already-resolved
    expression that the alias names.
    """

    def __init__(self, refs: str, source: Any):
        super().__init__()
        self.refs = refs
        self.source = source

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.refs}, {self.source})"

    def get_source_expressions(self) -> list[Any]:
        return [self.source]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        # Exactly one expression is expected; unpacking enforces that.
        [self.source] = exprs

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Ref:
        # `source` was resolved before this Ref was created; the Ref is only a
        # name pointing at it, so there is nothing left to resolve.
        return self

    def get_refs(self) -> set[str]:
        return {self.refs}

    def relabeled_clone(self, change_map: dict[str, str]) -> Self:
        return self

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, list[Any]]:
        alias_sql = quote_name(self.refs)
        return alias_sql, []

    def get_group_by_cols(self) -> list[BaseExpression]:
        return [self]
1218
+
1219
+
1220
class ExpressionList(Func):
    """
    An expression containing multiple expressions. Can be used to provide a
    list of expressions as an argument to another expression, like a partition
    clause.

    Renders only the joined expressions (no function name or parentheses).
    """

    template = "%(expressions)s"

    def __init__(self, *expressions: Any, **extra: Any):
        if len(expressions) == 0:
            raise ValueError(
                f"{self.__class__.__name__} requires at least one expression."
            )
        super().__init__(*expressions, **extra)

    def __str__(self) -> str:
        return self.arg_joiner.join(map(str, self.source_expressions))
1238
+
1239
+
1240
class OrderByList(Func):
    """An ``ORDER BY`` clause built from expressions or field-name strings.

    A string starting with ``-`` becomes a descending ``OrderBy`` on
    ``F(name_without_dash)``. Other items are passed to ``Func`` unchanged.
    With no expressions, the clause renders as an empty string.
    """

    template = "ORDER BY %(expressions)s"

    def __init__(self, *expressions: Any, **extra: Any):
        # Use startswith() rather than expr[0]: an empty string used to raise
        # a bare IndexError here. It is now passed through like any other
        # plain string.
        normalized = tuple(
            (
                OrderBy(F(expr[1:]), descending=True)
                if isinstance(expr, str) and expr.startswith("-")
                else expr
            )
            for expr in expressions
        )
        super().__init__(*normalized, **extra)

    def as_sql(self, *args: Any, **kwargs: Any) -> tuple[str, list[Any]]:
        """Render the clause, or ``("", [])`` when there is nothing to order by."""
        if not self.source_expressions:
            return "", []
        sql, params = super().as_sql(*args, **kwargs)
        return sql, list(params)

    def get_group_by_cols(self) -> list[Any]:
        """Collect the GROUP BY columns of every ordering expression."""
        group_by_cols = []
        for order_by in self.get_source_expressions():
            group_by_cols.extend(order_by.get_group_by_cols())
        return group_by_cols
1265
+
1266
+
1267
@deconstructible(path="plain.postgres.ExpressionWrapper")
class ExpressionWrapper(Expression):
    """
    An expression that can wrap another expression so that it can provide
    extra context to the inner expression, such as the output_field.

    Compiles to exactly the inner expression's SQL.
    """

    def __init__(self, expression: Any, output_field: Field):
        super().__init__(output_field=output_field)
        self.expression = expression

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.expression = exprs[0]

    def get_source_expressions(self) -> list[Any]:
        return [self.expression]

    def get_group_by_cols(self) -> list[Any]:
        """Group by the inner expression, seen through this wrapper's output field."""
        if not isinstance(self.expression, Expression):
            # For non-expressions e.g. an SQL WHERE clause, the entire
            # `expression` must be included in the GROUP BY clause.
            return super().get_group_by_cols()
        inner = self.expression.copy()
        inner.output_field = self.output_field
        return inner.get_group_by_cols()

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, Sequence[Any]]:
        return compiler.compile(self.expression)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.expression})"
1300
+
1301
+
1302
class NegatedExpression(ExpressionWrapper):
    """The logical negation of a conditional expression.

    Compiles to ``NOT <inner sql>``. Inverting it again (``~``) returns a
    copy of the wrapped expression.
    """

    def __init__(self, expression: Any):
        super().__init__(expression, output_field=fields.BooleanField())

    def __invert__(self) -> Any:
        return self.expression.copy()

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, Sequence[Any]]:
        """Render ``NOT <inner>`` and fold trivially-decided inner expressions.

        - Inner raises EmptyResultSet (matches nothing): the negation is
          always true, so compile ``Value(True)``.
        - Inner raises FullResultSet (matches everything): the negation is
          always false, so compile ``Value(False)``. Previously the exception
          escaped. Callers such as ``Func.as_sql`` then substituted
          ``Value(True)``, inverting the result.
        """
        try:
            sql, params = super().as_sql(compiler, connection)
        except EmptyResultSet:
            return compiler.compile(Value(True))
        except FullResultSet:
            return compiler.compile(Value(False))
        return f"NOT {sql}", params

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> NegatedExpression:
        """Resolve the wrapper.

        Raises TypeError if the resolved inner expression is not conditional.
        """
        resolved = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        if not getattr(resolved.expression, "conditional", False):
            raise TypeError("Cannot negate non-conditional expressions.")
        return resolved

    def select_format(
        self, compiler: SQLCompiler, sql: str, params: Sequence[Any]
    ) -> tuple[str, Sequence[Any]]:
        # Boolean expressions work directly in SELECT
        return sql, params
1340
+
1341
+
1342
@deconstructible(path="plain.postgres.When")
class When(Expression):
    """One ``WHEN <condition> THEN <result>`` branch of a ``Case()``.

    The condition may be a Q object, a conditional (boolean) expression, or
    keyword lookups, which are folded into a Q object.
    """

    template = "WHEN %(condition)s THEN %(result)s"
    # This isn't a complete conditional expression, must be used in Case().
    conditional = False
    condition: SQLCompilable

    def __init__(
        self, condition: Q | Expression | None = None, then: Any = None, **lookups: Any
    ):
        """
        Arguments:
         * condition: a Q object or conditional expression (optional when
           lookups are given).
         * then: the branch result (parsed via _parse_expressions).
         * lookups: field lookups. Alone, they become ``Q(**lookups)``;
           alongside a conditional ``condition``, they become
           ``Q(condition, **lookups)``.

        Raises:
         * TypeError: no usable conditional condition could be formed (e.g.
           lookups combined with a non-conditional ``condition``).
         * ValueError: the condition is an empty Q().
        """
        lookups_dict: dict[str, Any] | None = lookups or None
        if lookups_dict:
            if condition is None:
                condition, lookups_dict = Q(**lookups_dict), None
            elif getattr(condition, "conditional", False):
                condition, lookups_dict = Q(condition, **lookups_dict), None
        # Any lookups still unconsumed here could not be merged into the condition.
        if (
            condition is None
            or not getattr(condition, "conditional", False)
            or lookups_dict
        ):
            raise TypeError(
                "When() supports a Q object, a boolean expression, or lookups "
                "as a condition."
            )
        if isinstance(condition, Q) and not condition:
            raise ValueError("An empty Q() can't be used as a When() condition.")
        super().__init__(output_field=None)
        self.condition = condition  # type: ignore[assignment]
        self.result = self._parse_expressions(then)[0]

    def __str__(self) -> str:
        return f"WHEN {self.condition!r} THEN {self.result!r}"

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    def get_source_expressions(self) -> list[Any]:
        return [self.condition, self.result]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.condition, self.result = exprs

    def get_source_fields(self) -> list[Field | None]:
        # We're only interested in the fields of the result expressions.
        return [self.result._output_field_or_none]

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> When:
        """Return a copy with the condition and result resolved.

        The condition is always resolved with ``for_save=False``. Only the
        result receives the caller's ``for_save``.
        """
        c = self.copy()
        c.is_summary = summarize
        if isinstance(c.condition, ResolvableExpression):
            c.condition = c.condition.resolve_expression(
                query, allow_joins, reuse, summarize, False
            )
        c.result = c.result.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return c

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        template: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        """Render the branch; params are the condition's followed by the result's."""
        template_params = extra_context
        # Always empty here; kept as the leading slot of the params tuple.
        sql_params = []
        # After resolve_expression, condition is WhereNode | resolved Expression (both SQLCompilable)
        condition_sql, condition_params = compiler.compile(self.condition)
        template_params["condition"] = condition_sql
        result_sql, result_params = compiler.compile(self.result)
        template_params["result"] = result_sql
        template = template or self.template
        return template % template_params, (
            *sql_params,
            *condition_params,
            *result_params,
        )

    def get_group_by_cols(self) -> list[Any]:
        # This is not a complete expression and cannot be used in GROUP BY.
        # Instead, contribute the GROUP BY columns of the condition and result.
        cols = []
        for source in self.get_source_expressions():
            cols.extend(source.get_group_by_cols())
        return cols
1435
+
1436
+
1437
@deconstructible(path="plain.postgres.Case")
class Case(Expression):
    """
    An SQL searched CASE expression:

        CASE
            WHEN n > 0
                THEN 'positive'
            WHEN n < 0
                THEN 'negative'
            ELSE 'zero'
        END
    """

    template = "CASE %(cases)s ELSE %(default)s END"
    case_joiner = " "

    def __init__(
        self,
        *cases: When,
        default: Any = None,
        output_field: Field | None = None,
        **extra: Any,
    ):
        """
        Arguments:
         * cases: When() branches, rendered in order.
         * default: the ELSE value (parsed via _parse_expressions).
         * output_field: optional explicit output field.
         * extra: extra template context. A "template" key overrides the
           class template in as_sql().

        Raises TypeError if any positional argument is not a When.
        """
        if not all(isinstance(case, When) for case in cases):
            raise TypeError("Positional arguments must all be When objects.")
        super().__init__(output_field)
        self.cases = list(cases)
        self.default = self._parse_expressions(default)[0]
        self.extra = extra

    def __str__(self) -> str:
        return "CASE {}, ELSE {!r}".format(
            ", ".join(str(c) for c in self.cases),
            self.default,
        )

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    def get_source_expressions(self) -> list[Any]:
        return self.cases + [self.default]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        # The last expression is the default; everything before it is a case.
        *self.cases, self.default = exprs

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Case:
        """Return a copy with every branch and the default resolved."""
        c = self.copy()
        c.is_summary = summarize
        for pos, case in enumerate(c.cases):
            c.cases[pos] = case.resolve_expression(
                query, allow_joins, reuse, summarize, for_save
            )
        c.default = c.default.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return c

    def copy(self) -> Self:
        # Copy the cases list so resolve_expression() can replace items in place.
        c = super().copy()
        c.cases = c.cases[:]
        return c

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        template: str | None = None,
        case_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """Render the CASE expression.

        Branches that raise EmptyResultSet (can never match) are dropped. The
        first branch that raises FullResultSet (always matches) becomes the
        ELSE value, and any branches after it are not rendered. If no
        branches remain, only the default's SQL is returned, without a CASE.
        """
        # No branches at all: the expression is just its default.
        if not self.cases:
            sql, params = compiler.compile(self.default)
            return sql, list(params)
        template_params = {**self.extra, **extra_context}
        case_parts = []
        sql_params = []
        default_sql, default_params = compiler.compile(self.default)
        for case in self.cases:
            try:
                case_sql, case_params = compiler.compile(case)
            except EmptyResultSet:
                # This branch can never match; omit it.
                continue
            except FullResultSet:
                # This branch always matches: its result replaces the default
                # and later branches are unreachable.
                default_sql, default_params = compiler.compile(case.result)
                break
            case_parts.append(case_sql)
            sql_params.extend(case_params)
        if not case_parts:
            return default_sql, list(default_params)
        case_joiner = case_joiner or self.case_joiner
        template_params["cases"] = case_joiner.join(case_parts)
        template_params["default"] = default_sql
        # Default params come after all WHEN params, matching the SQL order.
        sql_params.extend(default_params)
        template = template or template_params.get("template", self.template)
        sql = template % template_params
        if self._output_field_or_none is not None:
            # Wrap the CASE in the connection's unification-cast SQL for the
            # output field (presumably so all branches share one type — see
            # unification_cast_sql).
            sql = connection.unification_cast_sql(self.output_field) % sql
        return sql, sql_params

    def get_group_by_cols(self) -> list[Any]:
        # Without branches the CASE is just its default, so group by that.
        if not self.cases:
            return self.default.get_group_by_cols()
        return super().get_group_by_cols()
1548
+
1549
+
1550
class Subquery(BaseExpression, Combinable):
    """
    An explicit subquery. It may contain OuterRef() references to the outer
    query which will be resolved when it is applied to that query.

    Accepts either a QuerySet (its ``sql_query`` is used) or an sql.Query.
    The query is cloned and marked as a subquery, so the caller's object is
    not mutated.
    """

    template = "(%(subquery)s)"
    contains_aggregate = False
    empty_result_set_value = None

    def __init__(
        self,
        query: QuerySet[Any] | Query,
        output_field: Field | None = None,
        **extra: Any,
    ):
        # Import here to avoid circular import
        from plain.postgres.sql.query import Query

        source_query = query if isinstance(query, Query) else query.sql_query
        self.query = source_query.clone()
        self.query.subquery = True
        self.extra = extra
        super().__init__(output_field)

    def get_source_expressions(self) -> list[Any]:
        return [self.query]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.query = exprs[0]

    def _resolve_output_field(self) -> Field | None:
        # The subquery's output type is whatever its (single) selected column yields.
        return self.query.output_field

    def copy(self) -> Self:
        """Copy that owns an independent clone of the inner query."""
        duplicate = super().copy()
        duplicate.query = duplicate.query.clone()
        return duplicate

    @property
    def external_aliases(self) -> dict[str, bool]:
        return self.query.external_aliases

    def get_external_cols(self) -> list[Any]:
        return self.query.get_external_cols()

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        template: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        """Compile the inner query and embed it in the template."""
        context = {**self.extra, **extra_context}
        inner_sql, inner_params = self.query.as_sql(compiler, connection)
        # Drop the first and last characters of the compiled query (assumed to
        # be the enclosing parentheses added for subqueries) so the template
        # controls the wrapping.
        context["subquery"] = inner_sql[1:-1]
        chosen_template = template or context.get("template", self.template)
        return chosen_template % context, inner_params

    def get_group_by_cols(self) -> list[Any]:
        return self.query.get_group_by_cols(wrapper=self)
1619
+
1620
+
1621
class Exists(Subquery):
    """A boolean ``EXISTS(...)`` test over a subquery.

    Its ``empty_result_set_value`` is False: a subquery that cannot return
    rows does not exist.
    """

    template = "EXISTS(%(subquery)s)"
    output_field = fields.BooleanField()
    empty_result_set_value = False

    def __init__(self, query: QuerySet[Any] | Query, **kwargs: Any):
        super().__init__(query, **kwargs)
        # Replace the cloned query with the form produced by its exists().
        exists_query = self.query.exists()
        self.query = exists_query

    def select_format(
        self, compiler: SQLCompiler, sql: str, params: Sequence[Any]
    ) -> tuple[str, Sequence[Any]]:
        # Boolean expressions work directly in SELECT
        return sql, params
1635
+
1636
+
1637
@deconstructible(path="plain.postgres.OrderBy")
class OrderBy(Expression):
    """An ``<expression> ASC|DESC [NULLS FIRST|NULLS LAST]`` ordering term."""

    template = "%(expression)s %(ordering)s"
    conditional = False

    def __init__(
        self,
        expression: Any,
        descending: bool = False,
        nulls_first: bool | None = None,
        nulls_last: bool | None = None,
    ):
        """
        Arguments:
         * expression: a resolvable expression to order by.
         * descending: order DESC instead of ASC.
         * nulls_first / nulls_last: True or None. They are mutually exclusive.

        Raises ValueError for conflicting or False null-ordering flags, or for
        a non-expression ``expression``.
        """
        if nulls_first and nulls_last:
            raise ValueError("nulls_first and nulls_last are mutually exclusive")
        if False in (nulls_first, nulls_last):
            raise ValueError("nulls_first and nulls_last values must be True or None.")
        self.nulls_first = nulls_first
        self.nulls_last = nulls_last
        self.descending = descending
        if not isinstance(expression, ResolvableExpression):
            raise ValueError("expression must be an expression type")
        self.expression = expression

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.expression}, descending={self.descending})"

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.expression = exprs[0]

    def get_source_expressions(self) -> list[Any]:
        return [self.expression]

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        template: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        """Render the ordering term, appending any NULLS modifier."""
        sql_template = template or self.template
        if self.nulls_last:
            sql_template += " NULLS LAST"
        elif self.nulls_first:
            sql_template += " NULLS FIRST"
        expression_sql, params = compiler.compile(self.expression)
        direction = "DESC" if self.descending else "ASC"
        placeholders = {
            "expression": expression_sql,
            "ordering": direction,
            **extra_context,
        }
        # A template may reference the expression several times; repeat its
        # params once per occurrence.
        params *= sql_template.count("%(expression)s")
        return (sql_template % placeholders).rstrip(), params

    def get_group_by_cols(self) -> list[Any]:
        return [
            col
            for source in self.get_source_expressions()
            for col in source.get_group_by_cols()
        ]

    def reverse_ordering(self) -> OrderBy:
        """Flip direction and null placement in place; returns self."""
        self.descending = not self.descending
        if self.nulls_first:
            self.nulls_first, self.nulls_last = None, True
        elif self.nulls_last:
            self.nulls_first, self.nulls_last = True, None
        return self

    def asc(self) -> None:  # type: ignore[override]
        # Mutates in place (unlike F.asc(), which builds a new OrderBy).
        self.descending = False

    def desc(self) -> None:  # type: ignore[override]
        # Mutates in place (unlike F.desc(), which builds a new OrderBy).
        self.descending = True
1712
+
1713
+
1714
class Window(Expression):
    """
    Wrap a window-compatible expression in an ``OVER (...)`` clause.

    ``partition_by`` accepts a single value or a list/tuple of values and is
    stored as an ``ExpressionList``. ``order_by`` accepts a string field
    reference, an expression, or a list/tuple of them and is stored as an
    ``OrderByList``. ``frame`` is an optional frame clause (for example a
    ``WindowFrame`` subclass) that is compiled after the ordering.
    """

    template = "%(expression)s OVER (%(window)s)"
    # Although the main expression may either be an aggregate or an
    # expression with an aggregate function, the GROUP BY that will
    # be introduced in the query as a result is not desired.
    contains_aggregate = False
    contains_over_clause = True
    partition_by: ExpressionList | None
    order_by: OrderByList | None

    def __init__(
        self,
        expression: Any,
        partition_by: Any = None,
        order_by: Any = None,
        frame: Any = None,
        output_field: Field | None = None,
    ):
        """
        Build the window expression.

        Raises:
            ValueError: if ``expression`` does not have a truthy
                ``window_compatible`` attribute, or if ``order_by`` is not a
                string, an expression, or a list/tuple of them.
        """
        self.partition_by = partition_by
        self.order_by = order_by
        self.frame = frame

        if not getattr(expression, "window_compatible", False):
            raise ValueError(
                f"Expression '{expression.__class__.__name__}' isn't compatible with OVER clauses."
            )

        if self.partition_by is not None:
            # Normalize a single partition value into a one-element sequence.
            partition_by_values = (
                self.partition_by
                if isinstance(self.partition_by, tuple | list)
                else (self.partition_by,)
            )
            self.partition_by = ExpressionList(*partition_by_values)

        if self.order_by is not None:
            if isinstance(self.order_by, list | tuple):
                self.order_by = OrderByList(*self.order_by)
            elif isinstance(self.order_by, BaseExpression | str):
                self.order_by = OrderByList(self.order_by)
            else:
                raise ValueError(
                    "Window.order_by must be either a string reference to a "
                    "field, an expression, or a list or tuple of them."
                )
        super().__init__(output_field=output_field)
        self.source_expression = self._parse_expressions(expression)[0]

    def _resolve_output_field(self) -> Field | None:
        # The window's type is the type of the wrapped expression.
        return self.source_expression.output_field

    def get_source_expressions(self) -> list[Any]:
        return [self.source_expression, self.partition_by, self.order_by, self.frame]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.source_expression, self.partition_by, self.order_by, self.frame = exprs

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        template: str | None = None,
    ) -> tuple[str, tuple[Any, ...]]:
        """
        Compile to ``<expression> OVER (<window>)``.

        The window part is assembled from the PARTITION BY list, the
        ordering, and the frame, in that order, separated by spaces. An
        explicit ``template`` overrides the class-level one.

        Returns:
            The SQL string and a tuple of parameters: the expression's
            parameters followed by the window's.
        """
        expr_sql, params = compiler.compile(self.source_expression)
        window_sql: list[str] = []
        window_params: tuple[Any, ...] = ()

        if self.partition_by is not None:
            # The ExpressionList is rendered directly with a PARTITION BY
            # template instead of going through compiler.compile().
            sql_expr, sql_params = self.partition_by.as_sql(
                compiler=compiler,
                connection=connection,
                template="PARTITION BY %(expressions)s",
            )
            window_sql.append(sql_expr)
            window_params += tuple(sql_params)

        if self.order_by is not None:
            order_sql, order_params = compiler.compile(self.order_by)
            window_sql.append(order_sql)
            window_params += tuple(order_params)

        if self.frame:
            frame_sql, frame_params = compiler.compile(self.frame)
            window_sql.append(frame_sql)
            window_params += tuple(frame_params)

        template = template or self.template

        return (
            template % {"expression": expr_sql, "window": " ".join(window_sql).strip()},
            (*params, *window_params),
        )

    def __str__(self) -> str:
        """
        Return a readable (non-SQL) rendering, as used by ``__repr__``.

        Window components are separated by single spaces so that adjacent
        parts (partition, ordering, frame) do not run together.
        """
        window_parts = []
        if self.partition_by:
            window_parts.append("PARTITION BY " + str(self.partition_by))
        if self.order_by:
            window_parts.append(str(self.order_by))
        if self.frame:
            window_parts.append(str(self.frame))
        # Skip parts that render empty so no stray separators appear.
        window = " ".join(part for part in window_parts if part)
        return "{} OVER ({})".format(str(self.source_expression), window)

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    def get_group_by_cols(self) -> list[Any]:
        """Return GROUP BY columns from the partition and ordering expressions."""
        group_by_cols = []
        if self.partition_by:
            group_by_cols.extend(self.partition_by.get_group_by_cols())
        if self.order_by is not None:
            group_by_cols.extend(self.order_by.get_group_by_cols())
        return group_by_cols
1824
+
1825
+
1826
class WindowFrame(Expression):
    """
    Base class for the frame clause of a window expression.

    Concrete subclasses set ``frame_type`` and implement
    ``window_frame_start_end()``. The shared rendering lives here and does
    only minimal validation. Either bound may be omitted by passing ``None``;
    ``__str__`` shows an omitted ``end`` as UNBOUNDED FOLLOWING.
    """

    template = "%(frame_type)s BETWEEN %(start)s AND %(end)s"
    frame_type: str

    def __init__(self, start: int | None = None, end: int | None = None):
        # Both bounds are held as Value expressions (exposed through the
        # source-expression API below). Their raw ``.value`` is read back
        # when rendering.
        self.start = Value(start)
        self.end = Value(end)

    def get_source_expressions(self) -> list[Any]:
        return [self.start, self.end]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.start, self.end = exprs

    def get_group_by_cols(self) -> list[Any]:
        # A frame clause never contributes GROUP BY columns.
        return []

    def window_frame_start_end(
        self, connection: DatabaseConnection, start: int | None, end: int | None
    ) -> tuple[str, str]:
        """Return the window frame start and end for the given connection."""
        raise NotImplementedError("Subclasses must implement window_frame_start_end()")

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, list[Any]]:
        """Render the frame clause; it never carries query parameters."""
        start_sql, end_sql = self.window_frame_start_end(
            connection, self.start.value, self.end.value
        )
        substitutions = {
            "frame_type": self.frame_type,
            "start": start_sql,
            "end": end_sql,
        }
        return self.template % substitutions, []

    def __str__(self) -> str:
        lower = self.start.value
        upper = self.end.value

        # Start bound: negative -> N PRECEDING, zero -> CURRENT ROW,
        # anything else (None or positive) -> UNBOUNDED PRECEDING.
        if lower is None:
            start_text = UNBOUNDED_PRECEDING
        elif lower < 0:
            start_text = f"{abs(lower)} {PRECEDING}"
        elif lower == 0:
            start_text = CURRENT_ROW
        else:
            start_text = UNBOUNDED_PRECEDING

        # End bound: positive -> N FOLLOWING, zero -> CURRENT ROW,
        # anything else (None or negative) -> UNBOUNDED FOLLOWING.
        if upper is None:
            end_text = UNBOUNDED_FOLLOWING
        elif upper > 0:
            end_text = f"{upper} {FOLLOWING}"
        elif upper == 0:
            end_text = CURRENT_ROW
        else:
            end_text = UNBOUNDED_FOLLOWING

        substitutions = {
            "frame_type": self.frame_type,
            "start": start_text,
            "end": end_text,
        }
        return self.template % substitutions

    def __repr__(self) -> str:
        return f"<{type(self).__name__}: {self}>"
1895
+
1896
+
1897
class RowRange(WindowFrame):
    """Window frame rendered with the ``ROWS`` frame type."""

    frame_type = "ROWS"

    def window_frame_start_end(
        self,
        connection: DatabaseConnection,
        start: int | None,
        end: int | None,
    ) -> tuple[str, str]:
        """
        Render the ROWS-frame bounds.

        Formatting is delegated entirely to ``window_frame_rows_start_end``;
        ``connection`` is accepted for interface compatibility but not used.
        """
        bounds = window_frame_rows_start_end(start, end)
        return bounds
1904
+
1905
+
1906
class ValueRange(WindowFrame):
    """Window frame rendered with the ``RANGE`` frame type."""

    frame_type = "RANGE"

    def window_frame_start_end(
        self,
        connection: DatabaseConnection,
        start: int | None,
        end: int | None,
    ) -> tuple[str, str]:
        """
        Render the RANGE-frame bounds.

        Formatting is delegated entirely to ``window_frame_range_start_end``;
        ``connection`` is accepted for interface compatibility but not used.
        """
        bounds = window_frame_range_start_end(start, end)
        return bounds