plain.postgres 0.84.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plain/postgres/CHANGELOG.md +1028 -0
- plain/postgres/README.md +925 -0
- plain/postgres/__init__.py +120 -0
- plain/postgres/agents/.claude/rules/plain-postgres.md +78 -0
- plain/postgres/aggregates.py +236 -0
- plain/postgres/backups/__init__.py +0 -0
- plain/postgres/backups/cli.py +148 -0
- plain/postgres/backups/clients.py +94 -0
- plain/postgres/backups/core.py +172 -0
- plain/postgres/base.py +1415 -0
- plain/postgres/cli/__init__.py +3 -0
- plain/postgres/cli/db.py +142 -0
- plain/postgres/cli/migrations.py +1085 -0
- plain/postgres/config.py +18 -0
- plain/postgres/connection.py +1331 -0
- plain/postgres/connections.py +77 -0
- plain/postgres/constants.py +13 -0
- plain/postgres/constraints.py +495 -0
- plain/postgres/database_url.py +94 -0
- plain/postgres/db.py +59 -0
- plain/postgres/default_settings.py +38 -0
- plain/postgres/deletion.py +475 -0
- plain/postgres/dialect.py +640 -0
- plain/postgres/entrypoints.py +4 -0
- plain/postgres/enums.py +103 -0
- plain/postgres/exceptions.py +217 -0
- plain/postgres/expressions.py +1912 -0
- plain/postgres/fields/__init__.py +2118 -0
- plain/postgres/fields/encrypted.py +354 -0
- plain/postgres/fields/json.py +413 -0
- plain/postgres/fields/mixins.py +30 -0
- plain/postgres/fields/related.py +1192 -0
- plain/postgres/fields/related_descriptors.py +290 -0
- plain/postgres/fields/related_lookups.py +223 -0
- plain/postgres/fields/related_managers.py +661 -0
- plain/postgres/fields/reverse_descriptors.py +229 -0
- plain/postgres/fields/reverse_related.py +328 -0
- plain/postgres/fields/timezones.py +143 -0
- plain/postgres/forms.py +773 -0
- plain/postgres/functions/__init__.py +189 -0
- plain/postgres/functions/comparison.py +127 -0
- plain/postgres/functions/datetime.py +454 -0
- plain/postgres/functions/math.py +140 -0
- plain/postgres/functions/mixins.py +59 -0
- plain/postgres/functions/text.py +282 -0
- plain/postgres/functions/window.py +125 -0
- plain/postgres/indexes.py +286 -0
- plain/postgres/lookups.py +758 -0
- plain/postgres/meta.py +584 -0
- plain/postgres/migrations/__init__.py +53 -0
- plain/postgres/migrations/autodetector.py +1379 -0
- plain/postgres/migrations/exceptions.py +54 -0
- plain/postgres/migrations/executor.py +188 -0
- plain/postgres/migrations/graph.py +364 -0
- plain/postgres/migrations/loader.py +377 -0
- plain/postgres/migrations/migration.py +180 -0
- plain/postgres/migrations/operations/__init__.py +34 -0
- plain/postgres/migrations/operations/base.py +139 -0
- plain/postgres/migrations/operations/fields.py +373 -0
- plain/postgres/migrations/operations/models.py +798 -0
- plain/postgres/migrations/operations/special.py +184 -0
- plain/postgres/migrations/optimizer.py +74 -0
- plain/postgres/migrations/questioner.py +340 -0
- plain/postgres/migrations/recorder.py +119 -0
- plain/postgres/migrations/serializer.py +378 -0
- plain/postgres/migrations/state.py +882 -0
- plain/postgres/migrations/utils.py +147 -0
- plain/postgres/migrations/writer.py +302 -0
- plain/postgres/options.py +207 -0
- plain/postgres/otel.py +231 -0
- plain/postgres/preflight.py +336 -0
- plain/postgres/query.py +2242 -0
- plain/postgres/query_utils.py +456 -0
- plain/postgres/registry.py +217 -0
- plain/postgres/schema.py +1885 -0
- plain/postgres/sql/__init__.py +40 -0
- plain/postgres/sql/compiler.py +1869 -0
- plain/postgres/sql/constants.py +22 -0
- plain/postgres/sql/datastructures.py +222 -0
- plain/postgres/sql/query.py +2947 -0
- plain/postgres/sql/where.py +374 -0
- plain/postgres/test/__init__.py +0 -0
- plain/postgres/test/pytest.py +117 -0
- plain/postgres/test/utils.py +18 -0
- plain/postgres/transaction.py +222 -0
- plain/postgres/types.py +92 -0
- plain/postgres/types.pyi +751 -0
- plain/postgres/utils.py +345 -0
- plain_postgres-0.84.0.dist-info/METADATA +937 -0
- plain_postgres-0.84.0.dist-info/RECORD +93 -0
- plain_postgres-0.84.0.dist-info/WHEEL +4 -0
- plain_postgres-0.84.0.dist-info/entry_points.txt +5 -0
- plain_postgres-0.84.0.dist-info/licenses/LICENSE +61 -0
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from contextvars import ContextVar
|
|
4
|
+
from typing import TYPE_CHECKING, Any, TypedDict
|
|
5
|
+
|
|
6
|
+
from plain.exceptions import ImproperlyConfigured
|
|
7
|
+
from plain.runtime import settings as plain_settings
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
from plain.postgres.connection import DatabaseConnection
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class DatabaseConfig(TypedDict, total=False):
    """Typed dictionary describing one PostgreSQL connection configuration.

    Declared with ``total=False``, so every key is optional as far as type
    checkers are concerned.
    """

    # Connection target.
    HOST: str
    PORT: int | None
    DATABASE: str | None

    # Login details.
    USER: str
    PASSWORD: str

    # Connection handling.
    CONN_MAX_AGE: int | None
    CONN_HEALTH_CHECKS: bool
    TIME_ZONE: str | None
    OPTIONS: dict[str, Any]

    # Test configuration (``_configure_settings`` fills it with a ``DATABASE`` entry).
    TEST: dict[str, Any]
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# Connection storage lives in a module-level ContextVar, giving each
# task/thread its own slot:
# - every asyncio.Task gets its own copy (since Python 3.7.1);
# - thread pool threads maintain their own native context across work items,
#   so a connection persists across requests (honoring CONN_MAX_AGE).
_db_conn: ContextVar[DatabaseConnection | None] = ContextVar("_db_conn", default=None)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _configure_settings() -> DatabaseConfig:
    """Assemble the connection configuration from the ``POSTGRES_*`` settings.

    Returns:
        A ``DatabaseConfig`` populated from the runtime settings, with
        ``TEST`` set to ``{"DATABASE": None}``.

    Raises:
        ImproperlyConfigured: If ``POSTGRES_DATABASE`` is the empty string
            (database disabled) or any other falsy value (not configured).
    """
    database = plain_settings.POSTGRES_DATABASE

    # The empty string is reported separately from a missing value.
    if database == "":
        raise ImproperlyConfigured(
            "The PostgreSQL database has been disabled (DATABASE_URL=none). "
            "No database operations are available in this context."
        )

    # None or unresolved setting.
    if not database:
        raise ImproperlyConfigured(
            "PostgreSQL database is not configured. "
            "Set DATABASE_URL or the individual POSTGRES_* settings."
        )

    config: DatabaseConfig = {
        "DATABASE": database,
        "USER": plain_settings.POSTGRES_USER,
        "PASSWORD": plain_settings.POSTGRES_PASSWORD,
        "HOST": plain_settings.POSTGRES_HOST,
        "PORT": plain_settings.POSTGRES_PORT,
        "CONN_MAX_AGE": plain_settings.POSTGRES_CONN_MAX_AGE,
        "CONN_HEALTH_CHECKS": plain_settings.POSTGRES_CONN_HEALTH_CHECKS,
        "OPTIONS": plain_settings.POSTGRES_OPTIONS,
        "TIME_ZONE": plain_settings.POSTGRES_TIME_ZONE,
        "TEST": {"DATABASE": None},
    }
    return config
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def _create_connection() -> DatabaseConnection:
    """Instantiate a ``DatabaseConnection`` from the current settings.

    Propagates ``ImproperlyConfigured`` from ``_configure_settings`` when the
    database is disabled or not configured.
    """
    # Imported here: at module level the name is only imported under
    # TYPE_CHECKING.
    from plain.postgres.connection import DatabaseConnection

    return DatabaseConnection(_configure_settings())
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def get_connection() -> DatabaseConnection:
    """Return the current context's connection, creating and storing it on first use."""
    existing = _db_conn.get()
    if existing is not None:
        return existing

    created = _create_connection()
    _db_conn.set(created)
    return created
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def has_connection() -> bool:
    """Report whether the current context already holds a connection.

    Unlike ``get_connection``, this never creates one.
    """
    current = _db_conn.get()
    return current is not None
|
|
@@ -0,0 +1,495 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from enum import Enum
|
|
4
|
+
from types import NoneType
|
|
5
|
+
from typing import TYPE_CHECKING, Any
|
|
6
|
+
|
|
7
|
+
from plain.exceptions import ValidationError
|
|
8
|
+
from plain.postgres.exceptions import FieldError
|
|
9
|
+
from plain.postgres.expressions import (
|
|
10
|
+
Exists,
|
|
11
|
+
ExpressionList,
|
|
12
|
+
F,
|
|
13
|
+
OrderBy,
|
|
14
|
+
ReplaceableExpression,
|
|
15
|
+
)
|
|
16
|
+
from plain.postgres.indexes import IndexExpression
|
|
17
|
+
from plain.postgres.lookups import Exact
|
|
18
|
+
from plain.postgres.query_utils import Q
|
|
19
|
+
from plain.postgres.sql.query import Query
|
|
20
|
+
|
|
21
|
+
if TYPE_CHECKING:
|
|
22
|
+
from plain.postgres.base import Model
|
|
23
|
+
from plain.postgres.schema import DatabaseSchemaEditor, Statement
|
|
24
|
+
|
|
25
|
+
__all__ = ["BaseConstraint", "CheckConstraint", "Deferrable", "UniqueConstraint"]
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class BaseConstraint:
    """Base class for named model constraints.

    Subclasses provide the SQL hooks (``constraint_sql``, ``create_sql``,
    ``remove_sql``) and ``validate``; this class stores the name and the
    violation code/message and implements ``deconstruct``/``clone``.
    """

    default_violation_error_message = 'Constraint "%(name)s" is violated.'
    violation_error_code: str | None = None
    violation_error_message: str | None = None

    def __init__(
        self,
        *,
        name: str,
        violation_error_code: str | None = None,
        violation_error_message: str | None = None,
    ) -> None:
        """Store the name and optional violation code/message.

        When no message is given, the class-level default message is used.
        """
        self.name = name
        if violation_error_code is not None:
            self.violation_error_code = violation_error_code
        self.violation_error_message = (
            self.default_violation_error_message
            if violation_error_message is None
            else violation_error_message
        )

    @property
    def contains_expressions(self) -> bool:
        """Whether the constraint is defined by expressions; ``False`` here."""
        return False

    def constraint_sql(
        self, model: type[Model], schema_editor: DatabaseSchemaEditor
    ) -> str | None:
        """Subclass hook; the base implementation always raises."""
        raise NotImplementedError(
            "subclasses of BaseConstraint must provide a constraint_sql() method"
        )

    def create_sql(
        self, model: type[Model], schema_editor: DatabaseSchemaEditor
    ) -> str | Statement | None:
        """Subclass hook; the base implementation always raises."""
        raise NotImplementedError(
            "subclasses of BaseConstraint must provide a create_sql() method"
        )

    def remove_sql(
        self, model: type[Model], schema_editor: DatabaseSchemaEditor
    ) -> str | Statement | None:
        """Subclass hook; the base implementation always raises."""
        raise NotImplementedError(
            "subclasses of BaseConstraint must provide a remove_sql() method"
        )

    def validate(
        self, model: type[Model], instance: Model, exclude: set[str] | None = None
    ) -> None:
        """Subclass hook; the base implementation always raises."""
        raise NotImplementedError(
            "subclasses of BaseConstraint must provide a validate() method"
        )

    def get_violation_error_message(self) -> str:
        """Return the violation message with ``%(name)s`` filled in."""
        message = self.violation_error_message
        assert message is not None
        return message % {"name": self.name}

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Return ``(import path, args, kwargs)`` sufficient to rebuild the constraint.

        The message is included only when it differs from the default, and
        the code only when it is set.
        """
        cls = self.__class__
        # Constraints defined in this module are referenced via the package path.
        path = f"{cls.__module__}.{cls.__name__}".replace(
            "plain.postgres.constraints", "plain.postgres"
        )
        kwargs: dict[str, Any] = {"name": self.name}

        message = self.violation_error_message
        if message is not None and message != self.default_violation_error_message:
            kwargs["violation_error_message"] = message

        code = self.violation_error_code
        if code is not None:
            kwargs["violation_error_code"] = code

        return (path, (), kwargs)

    def clone(self) -> BaseConstraint:
        """Create a new instance of the same class from ``deconstruct()`` output."""
        _path, args, kwargs = self.deconstruct()
        return self.__class__(*args, **kwargs)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
class CheckConstraint(BaseConstraint):
    """Constraint whose condition is a ``Q`` object or boolean expression."""

    def __init__(
        self,
        *,
        check: Q,
        name: str,
        violation_error_code: str | None = None,
        violation_error_message: str | None = None,
    ) -> None:
        """Store ``check`` and delegate the remaining arguments to the base class.

        Raises:
            TypeError: If ``check`` has no truthy ``conditional`` attribute.
        """
        self.check = check
        is_conditional = getattr(check, "conditional", False)
        if not is_conditional:
            raise TypeError(
                "CheckConstraint.check must be a Q instance or boolean expression."
            )
        super().__init__(
            name=name,
            violation_error_code=violation_error_code,
            violation_error_message=violation_error_message,
        )

    def _get_check_sql(
        self, model: type[Model], schema_editor: DatabaseSchemaEditor
    ) -> str:
        """Compile ``check`` into SQL text with its parameters quoted inline."""
        check_query = Query(model=model, alias_cols=False)
        where_node = check_query.build_where(self.check)
        sql, params = where_node.as_sql(
            check_query.get_compiler(), schema_editor.connection
        )
        quoted_params = tuple(schema_editor.quote_value(param) for param in params)
        return sql % quoted_params

    def constraint_sql(
        self, model: type[Model], schema_editor: DatabaseSchemaEditor
    ) -> str:
        """Return the SQL produced by ``schema_editor._check_sql`` for this check."""
        return schema_editor._check_sql(
            self.name, self._get_check_sql(model, schema_editor)
        )

    def create_sql(
        self, model: type[Model], schema_editor: DatabaseSchemaEditor
    ) -> Statement | None:
        """Return the statement produced by ``schema_editor._create_check_sql``."""
        return schema_editor._create_check_sql(
            model, self.name, self._get_check_sql(model, schema_editor)
        )

    def remove_sql(
        self, model: type[Model], schema_editor: DatabaseSchemaEditor
    ) -> Statement | None:
        """Return the statement that deletes this check constraint."""
        return schema_editor._delete_constraint_sql(
            schema_editor.sql_delete_check, model, self.name
        )

    def validate(
        self, model: type[Model], instance: Model, exclude: set[str] | None = None
    ) -> None:
        """Evaluate the check against the instance's field values.

        Raises:
            ValidationError: If the check evaluates to a falsy result.

        A ``FieldError`` raised during evaluation is swallowed, leaving the
        constraint unvalidated.
        """
        field_values = instance._get_field_value_map(
            meta=model._model_meta, exclude=exclude
        )
        try:
            passed = Q(self.check).check(field_values)
            if not passed:
                raise ValidationError(
                    self.get_violation_error_message(), code=self.violation_error_code
                )
        except FieldError:
            pass

    def __repr__(self) -> str:
        if self.violation_error_code is None:
            code_part = ""
        else:
            code_part = f" violation_error_code={self.violation_error_code!r}"

        message = self.violation_error_message
        if message is None or message == self.default_violation_error_message:
            message_part = ""
        else:
            message_part = f" violation_error_message={message!r}"

        return (
            f"<{self.__class__.__qualname__}: check={self.check}"
            f" name={self.name!r}{code_part}{message_part}>"
        )

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, CheckConstraint):
            return super().__eq__(other)
        return (
            self.name == other.name
            and self.check == other.check
            and self.violation_error_code == other.violation_error_code
            and self.violation_error_message == other.violation_error_message
        )

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Extend the base deconstruction with the ``check`` keyword."""
        path, args, kwargs = super().deconstruct()
        kwargs.update(check=self.check)
        return path, args, kwargs
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
class Deferrable(Enum):
    """Deferral modes accepted by ``UniqueConstraint(deferrable=...)``."""

    DEFERRED = "deferred"
    IMMEDIATE = "immediate"

    # A similar format was proposed for Python 3.10.
    def __repr__(self) -> str:
        """Render as ``Deferrable.<MEMBER>``."""
        enum_name = self.__class__.__qualname__
        return f"{enum_name}.{self._name_}"
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
class UniqueConstraint(BaseConstraint):
    """Uniqueness constraint over either field names or expressions.

    ``fields`` and positional ``expressions`` are mutually exclusive.
    Optional ``condition``, ``deferrable``, ``include`` and ``opclasses``
    are passed on to the schema editor when generating SQL.
    """

    expressions: tuple[ReplaceableExpression, ...]

    def __init__(
        self,
        *expressions: str | ReplaceableExpression,
        fields: tuple[str, ...] | list[str] = (),
        name: str | None = None,
        condition: Q | None = None,
        deferrable: Deferrable | None = None,
        include: tuple[str, ...] | list[str] | None = None,
        opclasses: tuple[str, ...] | list[str] = (),
        violation_error_code: str | None = None,
        violation_error_message: str | None = None,
    ) -> None:
        """Validate the argument combination and store the normalized values.

        Raises:
            ValueError: For a missing name, missing or conflicting
                fields/expressions, wrongly typed arguments, or an
                unsupported combination with ``deferrable``/``opclasses``.
        """
        # The checks run in a fixed order; the first failure is the one reported.
        if not name:
            raise ValueError("A unique constraint must be named.")
        if not (expressions or fields):
            raise ValueError(
                "At least one field or expression is required to define a "
                "unique constraint."
            )
        if expressions and fields:
            raise ValueError(
                "UniqueConstraint.fields and expressions are mutually exclusive."
            )
        if condition is not None and not isinstance(condition, Q):
            raise ValueError("UniqueConstraint.condition must be a Q instance.")
        if condition and deferrable:
            raise ValueError("UniqueConstraint with conditions cannot be deferred.")
        if include and deferrable:
            raise ValueError("UniqueConstraint with include fields cannot be deferred.")
        if opclasses and deferrable:
            raise ValueError("UniqueConstraint with opclasses cannot be deferred.")
        if expressions and deferrable:
            raise ValueError("UniqueConstraint with expressions cannot be deferred.")
        if expressions and opclasses:
            raise ValueError(
                "UniqueConstraint.opclasses cannot be used with expressions. "
                "Use a custom OpClass() instead."
            )
        if deferrable is not None and not isinstance(deferrable, Deferrable):
            raise ValueError(
                "UniqueConstraint.deferrable must be a Deferrable instance."
            )
        if include is not None and not isinstance(include, (list, tuple)):
            raise ValueError("UniqueConstraint.include must be a list or tuple.")
        if not isinstance(opclasses, (list, tuple)):
            raise ValueError("UniqueConstraint.opclasses must be a list or tuple.")
        if opclasses and len(fields) != len(opclasses):
            raise ValueError(
                "UniqueConstraint.fields and UniqueConstraint.opclasses must "
                "have the same number of elements."
            )

        self.fields = tuple(fields)
        self.condition = condition
        self.deferrable = deferrable
        self.include = tuple(include) if include else ()
        self.opclasses = opclasses
        # String expressions are wrapped as field references.
        self.expressions = tuple(
            F(item) if isinstance(item, str) else item for item in expressions
        )
        super().__init__(
            name=name,
            violation_error_code=violation_error_code,
            violation_error_message=violation_error_message,
        )

    @property
    def contains_expressions(self) -> bool:
        """``True`` when the constraint was defined with expressions."""
        return bool(self.expressions)

    def _forward_fields(self, model: type[Model]) -> list[Any]:
        """Resolve ``self.fields`` to the model's forward field objects."""
        meta = model._model_meta
        return [meta.get_forward_field(field_name) for field_name in self.fields]

    def _include_columns(self, model: type[Model]) -> list[Any]:
        """Resolve ``self.include`` to the columns of the model's forward fields."""
        meta = model._model_meta
        return [
            meta.get_forward_field(field_name).column for field_name in self.include
        ]

    def _opclasses_or_none(self) -> tuple[str, ...] | None:
        """Return the opclasses as a tuple, or ``None`` when there are none."""
        return tuple(self.opclasses) if self.opclasses else None

    def _get_condition_sql(
        self, model: type[Model], schema_editor: DatabaseSchemaEditor
    ) -> str | None:
        """Compile ``condition`` into SQL with inlined parameters, or ``None``."""
        if self.condition is None:
            return None
        condition_query = Query(model=model, alias_cols=False)
        where_node = condition_query.build_where(self.condition)
        sql, params = where_node.as_sql(
            condition_query.get_compiler(), schema_editor.connection
        )
        quoted_params = tuple(schema_editor.quote_value(param) for param in params)
        return sql % quoted_params

    def _get_index_expressions(
        self, model: type[Model], schema_editor: DatabaseSchemaEditor
    ) -> Any:
        """Wrap the expressions for index use and resolve them against ``model``.

        Returns ``None`` when the constraint has no expressions.
        """
        if not self.expressions:
            return None
        wrapped = [IndexExpression(expression) for expression in self.expressions]
        return ExpressionList(*wrapped).resolve_expression(
            Query(model, alias_cols=False),
        )

    def constraint_sql(
        self, model: type[Model], schema_editor: DatabaseSchemaEditor
    ) -> str | None:
        """Return the SQL produced by ``schema_editor._unique_sql``."""
        fields = self._forward_fields(model)
        include = self._include_columns(model)
        condition = self._get_condition_sql(model, schema_editor)
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._unique_sql(
            model,
            fields,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=self._opclasses_or_none(),
            expressions=expressions,
        )

    def create_sql(
        self, model: type[Model], schema_editor: DatabaseSchemaEditor
    ) -> Statement | None:
        """Return the statement produced by ``schema_editor._create_unique_sql``."""
        fields = self._forward_fields(model)
        include = self._include_columns(model)
        condition = self._get_condition_sql(model, schema_editor)
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._create_unique_sql(
            model,
            fields,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=self._opclasses_or_none(),
            expressions=expressions,
        )

    def remove_sql(
        self, model: type[Model], schema_editor: DatabaseSchemaEditor
    ) -> Statement | None:
        """Return the statement produced by ``schema_editor._delete_unique_sql``."""
        condition = self._get_condition_sql(model, schema_editor)
        include = self._include_columns(model)
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._delete_unique_sql(
            model,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=self._opclasses_or_none(),
            expressions=expressions,
        )

    def __repr__(self) -> str:
        message = self.violation_error_message
        show_message = (
            message is not None and message != self.default_violation_error_message
        )
        # Each entry is either empty or starts with a space, so they concatenate.
        pieces = [
            f" fields={repr(self.fields)}" if self.fields else "",
            f" expressions={repr(self.expressions)}" if self.expressions else "",
            f" name={repr(self.name)}",
            "" if self.condition is None else f" condition={self.condition}",
            "" if self.deferrable is None else f" deferrable={self.deferrable!r}",
            f" include={repr(self.include)}" if self.include else "",
            f" opclasses={repr(self.opclasses)}" if self.opclasses else "",
            (
                ""
                if self.violation_error_code is None
                else f" violation_error_code={self.violation_error_code!r}"
            ),
            f" violation_error_message={message!r}" if show_message else "",
        ]
        return "<{}:{}>".format(self.__class__.__qualname__, "".join(pieces))

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, UniqueConstraint):
            return super().__eq__(other)
        return (
            self.name == other.name
            and self.fields == other.fields
            and self.condition == other.condition
            and self.deferrable == other.deferrable
            and self.include == other.include
            and self.opclasses == other.opclasses
            and self.expressions == other.expressions
            and self.violation_error_code == other.violation_error_code
            and self.violation_error_message == other.violation_error_message
        )

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Return ``(path, expressions, kwargs)``; only truthy options are included."""
        path, _args, kwargs = super().deconstruct()
        for option in ("fields", "condition", "deferrable", "include", "opclasses"):
            value = getattr(self, option)
            if value:
                kwargs[option] = value
        return path, self.expressions, kwargs

    def _references_excluded(self, exclude: set[str]) -> bool:
        """Return whether any expression references a field named in ``exclude``."""
        for expression in self.expressions:
            if hasattr(expression, "flatten"):
                if any(
                    isinstance(node, F) and node.name in exclude
                    for node in expression.flatten()  # type: ignore[operator]
                ):
                    return True
            elif isinstance(expression, F) and expression.name in exclude:
                return True
        return False

    def validate(
        self, model: type[Model], instance: Model, exclude: set[str] | None = None
    ) -> None:
        """Check ``instance`` against existing rows for a uniqueness violation.

        Validation is skipped when a constrained field is excluded or, for
        field-based constraints, when any constrained value is ``None``.

        Raises:
            ValidationError: If a conflicting row exists.
        """
        candidates = model.query
        if self.fields:
            lookups: dict[str, Any] = {}
            for field_name in self.fields:
                if exclude and field_name in exclude:
                    return
                field = model._model_meta.get_forward_field(field_name)
                value = getattr(instance, field.attname)
                if value is None:
                    # A composite constraint containing NULL value cannot cause
                    # a violation since NULL != NULL in SQL.
                    return
                lookups[field.name] = value
            candidates = candidates.filter(**lookups)
        else:
            # Ignore constraints with excluded fields.
            if exclude and self._references_excluded(exclude):
                return
            field_values = instance._get_field_value_map(
                meta=model._model_meta, exclude=exclude
            )
            replacements: dict[Any, Any] = {
                F(field_name): value for field_name, value in field_values.items()
            }
            comparisons = []
            for expression in self.expressions:
                # Ignore ordering.
                target = (
                    expression.expression
                    if isinstance(expression, OrderBy)
                    else expression
                )
                comparisons.append(
                    Exact(target, target.replace_expressions(replacements))
                )
            candidates = candidates.filter(*comparisons)

        # An already-saved instance must not conflict with its own row.
        instance_id = instance.id
        if not instance._state.adding and instance_id is not None:
            candidates = candidates.exclude(id=instance_id)

        if self.condition:
            against = instance._get_field_value_map(
                meta=model._model_meta, exclude=exclude
            )
            try:
                conflict = self.condition & Exists(candidates.filter(self.condition))
                if conflict.check(against):
                    raise ValidationError(
                        self.get_violation_error_message(),
                        code=self.violation_error_code,
                    )
            except FieldError:
                pass
            return

        if not candidates.exists():
            return
        if self.expressions:
            raise ValidationError(
                self.get_violation_error_message(),
                code=self.violation_error_code,
            )
        # When fields are defined, use the unique_error_message() for
        # backward compatibility.
        for constraint_model, constraints in instance.get_constraints():
            for constraint in constraints:
                if constraint is self:
                    raise ValidationError(
                        instance.unique_error_message(
                            constraint_model,
                            self.fields,
                        ),
                    )
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
# Copyright (c) Kenneth Reitz & individual contributors
|
|
4
|
+
# All rights reserved.
|
|
5
|
+
# Redistribution and use in source and binary forms, with or without modification,
|
|
6
|
+
# are permitted provided that the following conditions are met:
|
|
7
|
+
# 1. Redistributions of source code must retain the above copyright notice,
|
|
8
|
+
# this list of conditions and the following disclaimer.
|
|
9
|
+
# 2. Redistributions in binary form must reproduce the above copyright
|
|
10
|
+
# notice, this list of conditions and the following disclaimer in the
|
|
11
|
+
# documentation and/or other materials provided with the distribution.
|
|
12
|
+
# 3. Neither the name of Plain nor the names of its contributors may be used
|
|
13
|
+
# to endorse or promote products derived from this software without
|
|
14
|
+
# specific prior written permission.
|
|
15
|
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
|
16
|
+
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
|
17
|
+
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
18
|
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
|
|
19
|
+
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
|
20
|
+
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
|
21
|
+
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
|
|
22
|
+
# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
|
23
|
+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
|
24
|
+
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
25
|
+
import urllib.parse as urlparse
|
|
26
|
+
|
|
27
|
+
from .connections import DatabaseConfig
|
|
28
|
+
|
|
29
|
+
SCHEMES = {"postgres", "postgresql", "pgsql"}

# Register database schemes in URLs. The membership check keeps the shared
# urllib list free of duplicates if this module is imported more than once.
for scheme in SCHEMES:
    if scheme not in urlparse.uses_netloc:
        urlparse.uses_netloc.append(scheme)


def parse_database_url(url: str) -> DatabaseConfig:
    """Parse a PostgreSQL database URL into a ``DatabaseConfig``.

    Args:
        url: A URL whose scheme is one of ``SCHEMES``.

    Returns:
        A config with ``DATABASE``, ``USER``, ``PASSWORD``, ``HOST`` and
        ``PORT``. ``OPTIONS`` is added only when the URL has query parameters;
        for a repeated parameter the last value wins.

    Raises:
        ValueError: If the scheme is unsupported, or (raised by ``urllib``)
            if the port is not a valid integer.
    """
    spliturl = urlparse.urlsplit(url)

    if spliturl.scheme not in SCHEMES:
        raise ValueError(
            f"No support for '{spliturl.scheme}'. We support: {', '.join(sorted(SCHEMES))}"
        )

    path = spliturl.path[1:]
    query = urlparse.parse_qs(spliturl.query)

    # Handle percent-encoded hostnames (e.g. socket paths).
    hostname = spliturl.hostname or ""
    if "%" in hostname:
        # Use netloc to avoid lowercased paths, strip credentials if present.
        hostname = spliturl.netloc.rpartition("@")[2]
        # netloc still carries any ":port" suffix; drop it so HOST holds only
        # the host. Only an empty or all-digit suffix is treated as a port.
        host_part, colon, port_part = hostname.rpartition(":")
        if colon and (not port_part or (port_part.isascii() and port_part.isdigit())):
            hostname = host_part
        hostname = urlparse.unquote(hostname)

    parsed_config: DatabaseConfig = {
        "DATABASE": urlparse.unquote(path or ""),
        "USER": urlparse.unquote(spliturl.username or ""),
        "PASSWORD": urlparse.unquote(spliturl.password or ""),
        "HOST": hostname,
        "PORT": spliturl.port,
    }

    # Pass the query string into OPTIONS.
    options = {key: values[-1] for key, values in query.items()}
    if options:
        parsed_config["OPTIONS"] = options

    return parsed_config
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def build_database_url(config: DatabaseConfig) -> str:
    """Build a ``postgresql://`` URL from a configuration dictionary.

    Args:
        config: Connection settings. Missing or ``None`` values for
            ``USER``, ``PASSWORD``, ``HOST``, ``DATABASE`` and ``OPTIONS``
            are treated as empty.

    Returns:
        The URL, with ``OPTIONS`` encoded as the query string.
    """
    options = config.get("OPTIONS") or {}
    query = urlparse.urlencode(list(options.items()))

    # safe="" also escapes "/", which would otherwise end the netloc early
    # when it appears in a user name or password.
    user = urlparse.quote(str(config.get("USER") or ""), safe="")
    password = urlparse.quote(str(config.get("PASSWORD") or ""), safe="")
    host = str(config.get("HOST") or "")
    port = config.get("PORT")
    # "or" guards against DATABASE=None being rendered as the text "None".
    name = urlparse.quote(str(config.get("DATABASE") or ""))

    # A host starting with "/" (a socket path) must be percent-encoded to
    # stay inside the netloc. Other hosts are left untouched.
    if host.startswith("/"):
        host = urlparse.quote(host, safe="")

    netloc = ""
    if user or password:
        netloc += user
        if password:
            netloc += f":{password}"
        netloc += "@"
    netloc += host
    if port:
        netloc += f":{port}"

    return urlparse.urlunsplit(("postgresql", netloc, f"/{name}", query, ""))
|