plain.postgres 0.84.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plain/postgres/CHANGELOG.md +1028 -0
- plain/postgres/README.md +925 -0
- plain/postgres/__init__.py +120 -0
- plain/postgres/agents/.claude/rules/plain-postgres.md +78 -0
- plain/postgres/aggregates.py +236 -0
- plain/postgres/backups/__init__.py +0 -0
- plain/postgres/backups/cli.py +148 -0
- plain/postgres/backups/clients.py +94 -0
- plain/postgres/backups/core.py +172 -0
- plain/postgres/base.py +1415 -0
- plain/postgres/cli/__init__.py +3 -0
- plain/postgres/cli/db.py +142 -0
- plain/postgres/cli/migrations.py +1085 -0
- plain/postgres/config.py +18 -0
- plain/postgres/connection.py +1331 -0
- plain/postgres/connections.py +77 -0
- plain/postgres/constants.py +13 -0
- plain/postgres/constraints.py +495 -0
- plain/postgres/database_url.py +94 -0
- plain/postgres/db.py +59 -0
- plain/postgres/default_settings.py +38 -0
- plain/postgres/deletion.py +475 -0
- plain/postgres/dialect.py +640 -0
- plain/postgres/entrypoints.py +4 -0
- plain/postgres/enums.py +103 -0
- plain/postgres/exceptions.py +217 -0
- plain/postgres/expressions.py +1912 -0
- plain/postgres/fields/__init__.py +2118 -0
- plain/postgres/fields/encrypted.py +354 -0
- plain/postgres/fields/json.py +413 -0
- plain/postgres/fields/mixins.py +30 -0
- plain/postgres/fields/related.py +1192 -0
- plain/postgres/fields/related_descriptors.py +290 -0
- plain/postgres/fields/related_lookups.py +223 -0
- plain/postgres/fields/related_managers.py +661 -0
- plain/postgres/fields/reverse_descriptors.py +229 -0
- plain/postgres/fields/reverse_related.py +328 -0
- plain/postgres/fields/timezones.py +143 -0
- plain/postgres/forms.py +773 -0
- plain/postgres/functions/__init__.py +189 -0
- plain/postgres/functions/comparison.py +127 -0
- plain/postgres/functions/datetime.py +454 -0
- plain/postgres/functions/math.py +140 -0
- plain/postgres/functions/mixins.py +59 -0
- plain/postgres/functions/text.py +282 -0
- plain/postgres/functions/window.py +125 -0
- plain/postgres/indexes.py +286 -0
- plain/postgres/lookups.py +758 -0
- plain/postgres/meta.py +584 -0
- plain/postgres/migrations/__init__.py +53 -0
- plain/postgres/migrations/autodetector.py +1379 -0
- plain/postgres/migrations/exceptions.py +54 -0
- plain/postgres/migrations/executor.py +188 -0
- plain/postgres/migrations/graph.py +364 -0
- plain/postgres/migrations/loader.py +377 -0
- plain/postgres/migrations/migration.py +180 -0
- plain/postgres/migrations/operations/__init__.py +34 -0
- plain/postgres/migrations/operations/base.py +139 -0
- plain/postgres/migrations/operations/fields.py +373 -0
- plain/postgres/migrations/operations/models.py +798 -0
- plain/postgres/migrations/operations/special.py +184 -0
- plain/postgres/migrations/optimizer.py +74 -0
- plain/postgres/migrations/questioner.py +340 -0
- plain/postgres/migrations/recorder.py +119 -0
- plain/postgres/migrations/serializer.py +378 -0
- plain/postgres/migrations/state.py +882 -0
- plain/postgres/migrations/utils.py +147 -0
- plain/postgres/migrations/writer.py +302 -0
- plain/postgres/options.py +207 -0
- plain/postgres/otel.py +231 -0
- plain/postgres/preflight.py +336 -0
- plain/postgres/query.py +2242 -0
- plain/postgres/query_utils.py +456 -0
- plain/postgres/registry.py +217 -0
- plain/postgres/schema.py +1885 -0
- plain/postgres/sql/__init__.py +40 -0
- plain/postgres/sql/compiler.py +1869 -0
- plain/postgres/sql/constants.py +22 -0
- plain/postgres/sql/datastructures.py +222 -0
- plain/postgres/sql/query.py +2947 -0
- plain/postgres/sql/where.py +374 -0
- plain/postgres/test/__init__.py +0 -0
- plain/postgres/test/pytest.py +117 -0
- plain/postgres/test/utils.py +18 -0
- plain/postgres/transaction.py +222 -0
- plain/postgres/types.py +92 -0
- plain/postgres/types.pyi +751 -0
- plain/postgres/utils.py +345 -0
- plain_postgres-0.84.0.dist-info/METADATA +937 -0
- plain_postgres-0.84.0.dist-info/RECORD +93 -0
- plain_postgres-0.84.0.dist-info/WHEEL +4 -0
- plain_postgres-0.84.0.dist-info/entry_points.txt +5 -0
- plain_postgres-0.84.0.dist-info/licenses/LICENSE +61 -0
|
@@ -0,0 +1,282 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING, Any
|
|
4
|
+
|
|
5
|
+
from plain.postgres.expressions import Func, ResolvableExpression, Value
|
|
6
|
+
from plain.postgres.fields import CharField, IntegerField, TextField
|
|
7
|
+
from plain.postgres.functions import Cast, Coalesce
|
|
8
|
+
from plain.postgres.lookups import Transform
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from plain.postgres.connection import DatabaseConnection
|
|
12
|
+
from plain.postgres.sql.compiler import SQLCompiler
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class SHAMixin(Transform):
    """
    Shared SQL generation for the SHA* transforms, using PostgreSQL's pgcrypto.

    Subclasses set ``function`` to an algorithm name such as ``"SHA256"``.
    The lowercased name is passed to ``DIGEST()``, and the resulting digest
    is hex-encoded with ``ENCODE(..., 'hex')``.
    """

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """
        Compile to ``ENCODE(DIGEST(<expr>, '<algorithm>'), 'hex')``.

        The ``function``, ``template`` and ``arg_joiner`` arguments are accepted
        for signature compatibility, but they are not used. The template and the
        function name always come from this class.
        """
        algorithm = self.function
        # Every concrete subclass must name its pgcrypto algorithm.
        assert algorithm is not None
        sha_template = "ENCODE(DIGEST(%(expressions)s, '%(function)s'), 'hex')"
        return super().as_sql(
            compiler,
            connection,
            function=algorithm.lower(),
            template=sha_template,
            **extra_context,
        )
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class Chr(Transform):
    """Wrap SQL ``CHR()``; also registered as the ``chr`` lookup."""

    lookup_name = "chr"
    function = "CHR"
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class ConcatPair(Func):
    """
    Concatenate exactly two arguments with ``CONCAT()``.

    Every argument is cast to text before compiling (see ``as_sql``).
    """

    function = "CONCAT"

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """
        Compile to SQL with each source expression wrapped in a text ``Cast``.

        The casts are applied to a copy, so ``self`` is not modified. Caller
        overrides for ``function``, ``template`` and ``arg_joiner`` are passed
        through to ``Func.as_sql``. Previously they were captured by this
        signature and silently dropped.
        """
        # PostgreSQL requires explicit cast to text for CONCAT.
        copy = self.copy()
        copy.set_source_expressions(
            [
                Cast(expression, TextField())
                for expression in copy.get_source_expressions()
            ]
        )
        return super(ConcatPair, copy).as_sql(
            compiler,
            connection,
            function=function,
            template=template,
            arg_joiner=arg_joiner,
            **extra_context,
        )

    def coalesce(self) -> ConcatPair:
        """
        Return a copy with each argument wrapped in ``Coalesce(arg, '')``.

        NULL arguments then become empty strings. The original instance is
        left unchanged.
        """
        c = self.copy()
        c.set_source_expressions(
            [
                Coalesce(expression, Value(""))
                for expression in c.get_source_expressions()
            ]
        )
        return c
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class Concat(Func):
    """
    Concatenate two or more expressions into a single text value.

    The arguments are folded into right-nested ``ConcatPair`` expressions.
    This node renders only its single child: it has no function name, and
    its template is ``"%(expressions)s"``.

    NOTE(review): the arguments are not wrapped in ``Coalesce`` here, because
    ``ConcatPair.coalesce()`` is never called. NULL handling is therefore
    whatever PostgreSQL's ``CONCAT()`` does, and ``CONCAT()`` ignores NULL
    arguments.
    """

    function = None
    template = "%(expressions)s"

    def __init__(self, *expressions: Any, **extra: Any) -> None:
        """Build the nested pairs; raise ValueError if fewer than two arguments are given."""
        if len(expressions) < 2:
            raise ValueError("Concat must take at least two expressions")
        super().__init__(self._paired(expressions), **extra)

    def _paired(self, expressions: tuple[Any, ...]) -> ConcatPair:
        """
        Nest the expressions into right-associated pairs.

        ``(a, b, c, d)`` becomes
        ``ConcatPair(a, ConcatPair(b, ConcatPair(c, d)))``.
        """
        # Start from the innermost pair and wrap outward.
        nested = ConcatPair(expressions[-2], expressions[-1])
        for expression in reversed(expressions[:-2]):
            nested = ConcatPair(expression, nested)
        return nested
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
class Left(Func):
    """Return the first ``length`` characters of a string (SQL ``LEFT``)."""

    function = "LEFT"
    output_field = CharField()
    arity = 2

    def __init__(self, expression: Any, length: Any, **extra: Any) -> None:
        """
        expression: a field name, or an expression that yields a string.
        length: how many leading characters to keep. A plain (non-expression)
            value must be at least 1.
        """
        if not isinstance(length, ResolvableExpression) and length < 1:
            raise ValueError("'length' must be greater than 0.")
        super().__init__(expression, length, **extra)

    def get_substr(self) -> Substr:
        """Return the equivalent ``SUBSTRING(expression, 1, length)`` expression."""
        source, count = self.source_expressions[0], self.source_expressions[1]
        return Substr(source, Value(1), count)
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
class Length(Transform):
    """Return the number of characters in the expression (SQL ``LENGTH``)."""

    lookup_name = "length"
    function = "LENGTH"
    output_field = IntegerField()
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
class Lower(Transform):
    """Wrap SQL ``LOWER()``; also registered as the ``lower`` lookup."""

    lookup_name = "lower"
    function = "LOWER"
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
class LPad(Func):
    """Pad ``expression`` on the left to ``length`` using ``fill_text`` (SQL ``LPAD``)."""

    function = "LPAD"
    output_field = CharField()

    def __init__(
        self, expression: Any, length: Any, fill_text: Any = Value(" "), **extra: Any
    ) -> None:
        """
        expression: a field name, or an expression that yields a string.
        length: the target length. A plain, non-None value must be >= 0.
        fill_text: the padding text. It defaults to a single space.
        """
        is_literal = length is not None and not isinstance(
            length, ResolvableExpression
        )
        if is_literal and length < 0:
            raise ValueError("'length' must be greater or equal to 0.")
        super().__init__(expression, length, fill_text, **extra)
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
class LTrim(Transform):
    """Wrap SQL ``LTRIM()``; also registered as the ``ltrim`` lookup."""

    lookup_name = "ltrim"
    function = "LTRIM"
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
class MD5(Transform):
    """Wrap SQL ``MD5()``; also registered as the ``md5`` lookup."""

    lookup_name = "md5"
    function = "MD5"
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
class Ord(Transform):
    """Return the numeric code of the first character (SQL ``ASCII``); the lookup name is ``ord``."""

    lookup_name = "ord"
    function = "ASCII"
    output_field = IntegerField()
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
class Repeat(Func):
    """Repeat ``expression`` ``number`` times (SQL ``REPEAT``)."""

    function = "REPEAT"
    output_field = CharField()

    def __init__(self, expression: Any, number: Any, **extra: Any) -> None:
        """
        expression: a field name, or an expression that yields a string.
        number: the repetition count. A plain, non-None value must be >= 0.
        """
        is_literal = number is not None and not isinstance(
            number, ResolvableExpression
        )
        if is_literal and number < 0:
            raise ValueError("'number' must be greater or equal to 0.")
        super().__init__(expression, number, **extra)
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
class Replace(Func):
    """
    Replace occurrences of ``text`` in ``expression`` (SQL ``REPLACE``).

    ``replacement`` defaults to an empty string, which deletes the matches.
    """

    function = "REPLACE"

    def __init__(
        self, expression: Any, text: Any, replacement: Any = Value(""), **extra: Any
    ) -> None:
        arguments = (expression, text, replacement)
        super().__init__(*arguments, **extra)
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
class Reverse(Transform):
    """Wrap SQL ``REVERSE()``; also registered as the ``reverse`` lookup."""

    lookup_name = "reverse"
    function = "REVERSE"
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
class Right(Left):
    """Return the last ``length`` characters of a string (SQL ``RIGHT``)."""

    function = "RIGHT"

    def get_substr(self) -> Substr:
        """
        Return an equivalent ``Substr`` expression.

        PostgreSQL's ``SUBSTRING`` does not count a negative start position
        from the end of the string. ``SUBSTRING(s, -n)`` therefore returns all
        of ``s`` rather than its last ``n`` characters. Instead, the start is
        computed as ``LENGTH(expression) - length + 1``.

        If ``length`` exceeds the string length, that start falls below 1, and
        PostgreSQL returns the whole string. This matches ``RIGHT``.
        """
        source, count = self.source_expressions[0], self.source_expressions[1]
        return Substr(source, Length(source) - count + Value(1))
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
class RPad(LPad):
    """Same arguments and validation as ``LPad``, but pads on the right (SQL ``RPAD``)."""

    function = "RPAD"
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
class RTrim(Transform):
    """Wrap SQL ``RTRIM()``; also registered as the ``rtrim`` lookup."""

    lookup_name = "rtrim"
    function = "RTRIM"
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
class SHA1(SHAMixin, Transform):
    """Hex-encoded SHA-1 digest via pgcrypto ``DIGEST``; lookup name ``sha1``."""

    lookup_name = "sha1"
    function = "SHA1"
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
class SHA224(SHAMixin, Transform):
    """Hex-encoded SHA-224 digest via pgcrypto ``DIGEST``; lookup name ``sha224``."""

    lookup_name = "sha224"
    function = "SHA224"
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
class SHA256(SHAMixin, Transform):
    """Hex-encoded SHA-256 digest via pgcrypto ``DIGEST``; lookup name ``sha256``."""

    lookup_name = "sha256"
    function = "SHA256"
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
class SHA384(SHAMixin, Transform):
    """Hex-encoded SHA-384 digest via pgcrypto ``DIGEST``; lookup name ``sha384``."""

    lookup_name = "sha384"
    function = "SHA384"
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
class SHA512(SHAMixin, Transform):
    """Hex-encoded SHA-512 digest via pgcrypto ``DIGEST``; lookup name ``sha512``."""

    lookup_name = "sha512"
    function = "SHA512"
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
class StrIndex(Func):
    """
    Return the 1-based position of the first occurrence of a substring.

    The result is 0 when the substring does not occur. The first argument is
    the string to search, and the second is the substring to find.
    """

    # PostgreSQL spells this STRPOS (other databases call it INSTR).
    function = "STRPOS"
    output_field = IntegerField()
    arity = 2
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
class Substr(Func):
    """Extract part of a string (SQL ``SUBSTRING``)."""

    function = "SUBSTRING"
    output_field = CharField()

    def __init__(
        self, expression: Any, pos: Any, length: Any = None, **extra: Any
    ) -> None:
        """
        expression: a field name, or an expression that yields a string.
        pos: the 1-based start position, given as an integer > 0 or as an
            expression that yields an integer.
        length: optional number of characters to return. When omitted, the
            substring runs to the end of the string.
        """
        if not isinstance(pos, ResolvableExpression) and pos < 1:
            raise ValueError("'pos' must be greater than 0")
        arguments = (
            [expression, pos] if length is None else [expression, pos, length]
        )
        super().__init__(*arguments, **extra)
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
class Trim(Transform):
    """Wrap SQL ``TRIM()``; also registered as the ``trim`` lookup."""

    lookup_name = "trim"
    function = "TRIM"
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
class Upper(Transform):
    """Wrap SQL ``UPPER()``; also registered as the ``upper`` lookup."""

    lookup_name = "upper"
    function = "UPPER"
|
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from plain.postgres.expressions import Func
|
|
6
|
+
from plain.postgres.fields import Field, FloatField, IntegerField
|
|
7
|
+
|
|
8
|
+
# Public window-function expressions exported by this module.
__all__ = [
    "CumeDist", "DenseRank", "FirstValue", "Lag", "LastValue", "Lead",
    "NthValue", "Ntile", "PercentRank", "Rank", "RowNumber",
]  # fmt: skip
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class CumeDist(Func):
    """Window function ``CUME_DIST()``; the result is a float."""

    window_compatible = True
    function = "CUME_DIST"
    output_field = FloatField()
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class DenseRank(Func):
    """Window function ``DENSE_RANK()``; the result is an integer."""

    window_compatible = True
    function = "DENSE_RANK"
    output_field = IntegerField()
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class FirstValue(Func):
    """Window function ``FIRST_VALUE(expr)``; it takes exactly one argument."""

    window_compatible = True
    function = "FIRST_VALUE"
    arity = 1
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class LagLeadFunction(Func):
    """
    Shared base for ``LAG``/``LEAD``: ``FUNC(expression, offset[, default])``.

    The output field is taken from the source expression.
    """

    window_compatible = True

    def __init__(
        self, expression: Any, offset: int = 1, default: Any = None, **extra: Any
    ) -> None:
        """
        Validate and store the arguments.

        Raises ValueError if ``expression`` is None, or if ``offset`` is None
        or not positive. ``default`` is only passed on when it is not None.
        """
        name = self.__class__.__name__
        if expression is None:
            raise ValueError(f"{name} requires a non-null source expression.")
        if offset is None or offset <= 0:
            raise ValueError(f"{name} requires a positive integer for the offset.")
        if default is None:
            args: tuple[Any, ...] = (expression, offset)
        else:
            args = (expression, offset, default)
        super().__init__(*args, **extra)

    def _resolve_output_field(self) -> Field:
        # The result has the same type as the windowed expression.
        return self.get_source_expressions()[0].output_field
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class Lag(LagLeadFunction):
    """Window function ``LAG(expression, offset[, default])``."""

    function = "LAG"
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class LastValue(Func):
    """Window function ``LAST_VALUE(expr)``; it takes exactly one argument."""

    window_compatible = True
    function = "LAST_VALUE"
    arity = 1
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class Lead(LagLeadFunction):
    """Window function ``LEAD(expression, offset[, default])``."""

    function = "LEAD"
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class NthValue(Func):
    """
    Window function ``NTH_VALUE(expression, nth)``.

    The output field is taken from the source expression.
    """

    window_compatible = True
    function = "NTH_VALUE"

    def __init__(self, expression: Any, nth: int = 1, **extra: Any) -> None:
        """Raise ValueError if ``expression`` is None, or if ``nth`` is None or not positive."""
        name = self.__class__.__name__
        if expression is None:
            raise ValueError(f"{name} requires a non-null source expression.")
        if nth is None or nth <= 0:
            raise ValueError(f"{name} requires a positive integer as for nth.")
        super().__init__(expression, nth, **extra)

    def _resolve_output_field(self) -> Field:
        # The result has the same type as the windowed expression.
        return self.get_source_expressions()[0].output_field
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
class Ntile(Func):
    """Window function ``NTILE(num_buckets)``; the result is an integer bucket number."""

    window_compatible = True
    function = "NTILE"
    output_field = IntegerField()

    def __init__(self, num_buckets: int = 1, **extra: Any) -> None:
        """Raise ValueError when ``num_buckets`` is zero or negative."""
        if num_buckets <= 0:
            raise ValueError("num_buckets must be greater than 0.")
        super().__init__(num_buckets, **extra)
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
class PercentRank(Func):
    """Window function ``PERCENT_RANK()``; the result is a float."""

    window_compatible = True
    function = "PERCENT_RANK"
    output_field = FloatField()
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class Rank(Func):
    """Window function ``RANK()``; the result is an integer."""

    window_compatible = True
    function = "RANK"
    output_field = IntegerField()
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
class RowNumber(Func):
    """Window function ``ROW_NUMBER()``; the result is an integer."""

    window_compatible = True
    function = "ROW_NUMBER"
    output_field = IntegerField()
|
|
@@ -0,0 +1,286 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from types import NoneType
|
|
4
|
+
from typing import TYPE_CHECKING, Any, Self
|
|
5
|
+
|
|
6
|
+
from plain.postgres.expressions import Col, ExpressionList, F, Func, OrderBy
|
|
7
|
+
from plain.postgres.query_utils import Q
|
|
8
|
+
from plain.postgres.sql.query import Query
|
|
9
|
+
from plain.postgres.utils import names_digest, split_identifier
|
|
10
|
+
from plain.utils.functional import partition
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from plain.postgres.base import Model
|
|
14
|
+
from plain.postgres.expressions import Expression
|
|
15
|
+
from plain.postgres.schema import DatabaseSchemaEditor, Statement
|
|
16
|
+
|
|
17
|
+
# Public API of this module: only Index is exported; IndexExpression is internal.
__all__ = ["Index"]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class Index:
    """
    Describe an index on a model, built from field names or from expressions.

    Exactly one of ``fields`` (names, optionally prefixed with ``-`` for
    descending order) or positional ``expressions`` must be given. A name is
    required whenever ``opclasses``, ``condition``, ``expressions`` or
    ``include`` are used.
    """

    suffix = "idx"
    # Upper bound on the length of a generated index name (see set_name_with_model).
    max_name_length = 30

    def __init__(
        self,
        *expressions: Any,
        fields: tuple[str, ...] | list[str] = (),
        name: str | None = None,
        opclasses: tuple[str, ...] | list[str] = (),
        condition: Q | None = None,
        include: tuple[str, ...] | list[str] | None = None,
    ) -> None:
        # The checks run in a fixed order. The first failing check decides
        # which ValueError is raised.
        if opclasses and not name:
            raise ValueError("An index must be named to use opclasses.")
        if not isinstance(condition, NoneType | Q):
            raise ValueError("Index.condition must be a Q instance.")
        if condition and not name:
            raise ValueError("An index must be named to use condition.")
        if not isinstance(fields, list | tuple):
            raise ValueError("Index.fields must be a list or tuple.")
        if not isinstance(opclasses, list | tuple):
            raise ValueError("Index.opclasses must be a list or tuple.")
        if not expressions and not fields:
            raise ValueError(
                "At least one field or expression is required to define an index."
            )
        if expressions:
            if fields:
                raise ValueError(
                    "Index.fields and expressions are mutually exclusive.",
                )
            if not name:
                raise ValueError("An index must be named to use expressions.")
            if opclasses:
                raise ValueError(
                    "Index.opclasses cannot be used with expressions. Use "
                    "a custom OpClass() instead."
                )
        if opclasses and len(fields) != len(opclasses):
            raise ValueError(
                "Index.fields and Index.opclasses must have the same number of "
                "elements."
            )
        if fields and any(not isinstance(field, str) for field in fields):
            raise ValueError("Index.fields must contain only strings with field names.")
        if include and not name:
            raise ValueError("A covering index must be named.")
        if not isinstance(include, NoneType | list | tuple):
            raise ValueError("Index.include must be a list or tuple.")

        self.fields = list(fields)
        # One (field name without the "-" prefix, "DESC" or "") pair per field.
        self.fields_orders = []
        for field_name in self.fields:
            direction = "DESC" if field_name.startswith("-") else ""
            self.fields_orders.append((field_name.removeprefix("-"), direction))
        self.name = name if name else ""
        self.opclasses: tuple[str, ...] = tuple(opclasses)
        self.condition = condition
        self.include = () if not include else tuple(include)
        # Bare strings among the expressions are treated as field references.
        self.expressions: tuple[Expression, ...] = tuple(  # type: ignore[assignment]
            F(item) if isinstance(item, str) else item for item in expressions
        )

    @property
    def contains_expressions(self) -> bool:
        """True when the index was built from expressions rather than field names."""
        return len(self.expressions) > 0

    def _get_condition_sql(
        self, model: type[Model], schema_editor: DatabaseSchemaEditor
    ) -> str | None:
        """Render ``condition`` as inline SQL with quoted parameters; return None if there is no condition."""
        if self.condition is None:
            return None
        query = Query(model=model, alias_cols=False)
        where = query.build_where(self.condition)
        compiler = query.get_compiler()
        sql, params = where.as_sql(compiler, schema_editor.connection)
        # DDL cannot take bound parameters, so the values are inlined as quoted literals.
        quoted = tuple(map(schema_editor.quote_value, params))
        return sql % quoted

    def create_sql(
        self, model: type[Model], schema_editor: DatabaseSchemaEditor, **kwargs: Any
    ) -> Statement:
        """Build the CREATE INDEX statement through ``schema_editor._create_index_sql``."""
        meta = model._model_meta
        include = [meta.get_forward_field(field_name).column for field_name in self.include]
        condition = self._get_condition_sql(model, schema_editor)
        if not self.expressions:
            fields = [
                meta.get_forward_field(field_name)
                for field_name, _ in self.fields_orders
            ]
            # Per-column ordering suffixes ("" or "DESC").
            col_suffixes = tuple(order for _, order in self.fields_orders)
            expressions = None
        else:
            wrapped = [IndexExpression(expression) for expression in self.expressions]
            expressions = ExpressionList(*wrapped).resolve_expression(
                Query(model, alias_cols=False),
            )
            fields = None
            col_suffixes = ()
        return schema_editor._create_index_sql(
            model,
            fields=fields,
            name=self.name,
            col_suffixes=col_suffixes,
            opclasses=self.opclasses,
            condition=condition,
            include=include,
            expressions=expressions,
            **kwargs,
        )

    def remove_sql(
        self, model: type[Model], schema_editor: DatabaseSchemaEditor, **kwargs: Any
    ) -> Statement:
        """Build the DROP INDEX statement for this index's name."""
        return schema_editor._delete_index_sql(model, self.name, **kwargs)

    def deconstruct(self) -> tuple[str, tuple[Expression, ...], dict[str, Any]]:
        """Return ``(import path, expressions, kwargs)``; falsy options are left out of kwargs."""
        cls = self.__class__
        path = f"{cls.__module__}.{cls.__name__}".replace(
            "plain.postgres.indexes", "plain.postgres"
        )
        kwargs: dict[str, Any] = {"name": self.name}
        optional = (
            ("fields", self.fields),
            ("opclasses", self.opclasses),
            ("condition", self.condition),
            ("include", self.include),
        )
        for key, value in optional:
            if value:
                kwargs[key] = value
        return (path, self.expressions, kwargs)

    def clone(self) -> Index:
        """Create a copy of this Index from its deconstructed form."""
        args, kwargs = self.deconstruct()[1:]
        return type(self)(*args, **kwargs)

    def set_name_with_model(self, model: type[Model]) -> None:
        """
        Generate and assign a name for the index.

        The name has three parts: the table name (11 chars plus "_"), the first
        column name (7 chars plus "_"), and a 6-char hash plus "_" plus the
        suffix. Each part is truncated to fit. A leading "_" or digit is
        replaced with "D".
        """
        table_name = split_identifier(model.model_options.db_table)[1]
        meta = model._model_meta
        column_names = [
            meta.get_forward_field(field_name).column
            for field_name, _ in self.fields_orders
        ]
        column_names_with_order = [
            f"-{column_name}" if order else f"{column_name}"
            for column_name, (_, order) in zip(column_names, self.fields_orders)
        ]
        hash_data = [table_name, *column_names_with_order, self.suffix]
        self.name = (
            f"{table_name[:11]}_{column_names[0][:7]}_"
            f"{names_digest(*hash_data, length=6)}_{self.suffix}"
        )
        if len(self.name) > self.max_name_length:
            raise ValueError(
                "Index name too long. Is self.suffix longer than 3 characters?"
            )
        first_char = self.name[0]
        if first_char == "_" or first_char.isdigit():
            self.name = "D" + self.name[1:]

    def __repr__(self) -> str:
        parts = [
            "" if not self.fields else f" fields={self.fields!r}",
            "" if not self.expressions else f" expressions={self.expressions!r}",
            "" if not self.name else f" name={self.name!r}",
            "" if self.condition is None else f" condition={self.condition}",
            "" if not self.include else f" include={self.include!r}",
            "" if not self.opclasses else f" opclasses={self.opclasses!r}",
        ]
        body = "".join(parts)
        return f"<{self.__class__.__qualname__}:{body}>"

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Index):
            return NotImplemented
        return self.deconstruct() == other.deconstruct()
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
class IndexExpression(Func):
    """Order and wrap expressions for CREATE INDEX statements."""

    # Render the single (possibly wrapped) child expression as-is.
    template = "%(expressions)s"
    # Expression types allowed above the indexed expression. The order of this
    # tuple is the nesting order (outermost first) used when rebuilding the tree.
    wrapper_classes = (OrderBy,)

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Self:
        """
        Validate the wrapper layout, rebuild the expression tree, and resolve it.

        Wrappers (instances of ``wrapper_classes``) must each appear at most
        once and must sit at the top of the tree. The first non-wrapper
        expression becomes the root. If the resolved root is not a ``Col``, it
        is parenthesized. The wrappers are then re-nested (in
        ``wrapper_classes`` order) around the root.

        Raises:
            ValueError: if a wrapper type occurs more than once, or if the
                wrappers are not the topmost expressions.

        Note: this replaces ``self``'s source expressions in place before
        delegating to ``Func.resolve_expression``.
        """
        # NOTE(review): flatten() appears to yield self first (index 0 is skipped
        # below), followed by the nested expressions.
        expressions = list(self.flatten())
        # Split expressions and wrappers.
        # NOTE(review): partition() is assumed to return
        # (items failing the predicate, items passing it); confirm in
        # plain.utils.functional.
        index_expressions, wrappers = partition(
            lambda e: isinstance(e, self.wrapper_classes),
            expressions,
        )
        wrapper_types = [type(wrapper) for wrapper in wrappers]
        # Each wrapper type may appear at most once.
        if len(wrapper_types) != len(set(wrapper_types)):
            raise ValueError(
                "Multiple references to {} can't be used in an indexed "
                "expression.".format(
                    ", ".join(
                        [
                            wrapper_cls.__qualname__
                            for wrapper_cls in self.wrapper_classes
                        ]
                    )
                )
            )
        # The wrappers must directly follow self in the flattened order, i.e. be topmost.
        if expressions[1 : len(wrappers) + 1] != wrappers:
            raise ValueError(
                "{} must be topmost expressions in an indexed expression.".format(
                    ", ".join(
                        [
                            wrapper_cls.__qualname__
                            for wrapper_cls in self.wrapper_classes
                        ]
                    )
                )
            )
        # Wrap expressions in parentheses if they are not column references.
        # index_expressions[0] is presumably self, so [1] is the first real non-wrapper.
        root_expression = index_expressions[1]
        # Resolved only to inspect its type; the unresolved root is what gets wrapped.
        resolve_root_expression = root_expression.resolve_expression(
            query,
            allow_joins,
            reuse,
            summarize,
            for_save,
        )
        if not isinstance(resolve_root_expression, Col):
            root_expression = Func(root_expression, template="(%(expressions)s)")

        if wrappers:
            # Order wrappers and set their expressions.
            wrappers = sorted(
                wrappers,
                key=lambda w: self.wrapper_classes.index(type(w)),
            )
            # Copy the wrappers so the caller's expression objects are not mutated.
            wrappers = [wrapper.copy() for wrapper in wrappers]
            # Chain them: each wrapper's only child is the next wrapper.
            for i, wrapper in enumerate(wrappers[:-1]):
                wrapper.set_source_expressions([wrappers[i + 1]])
            # Set the root expression on the deepest wrapper.
            wrappers[-1].set_source_expressions([root_expression])
            self.set_source_expressions([wrappers[0]])
        else:
            # Use the root expression, if there are no wrappers.
            self.set_source_expressions([root_expression])
        return super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
|