plain.postgres 0.84.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93) hide show
  1. plain/postgres/CHANGELOG.md +1028 -0
  2. plain/postgres/README.md +925 -0
  3. plain/postgres/__init__.py +120 -0
  4. plain/postgres/agents/.claude/rules/plain-postgres.md +78 -0
  5. plain/postgres/aggregates.py +236 -0
  6. plain/postgres/backups/__init__.py +0 -0
  7. plain/postgres/backups/cli.py +148 -0
  8. plain/postgres/backups/clients.py +94 -0
  9. plain/postgres/backups/core.py +172 -0
  10. plain/postgres/base.py +1415 -0
  11. plain/postgres/cli/__init__.py +3 -0
  12. plain/postgres/cli/db.py +142 -0
  13. plain/postgres/cli/migrations.py +1085 -0
  14. plain/postgres/config.py +18 -0
  15. plain/postgres/connection.py +1331 -0
  16. plain/postgres/connections.py +77 -0
  17. plain/postgres/constants.py +13 -0
  18. plain/postgres/constraints.py +495 -0
  19. plain/postgres/database_url.py +94 -0
  20. plain/postgres/db.py +59 -0
  21. plain/postgres/default_settings.py +38 -0
  22. plain/postgres/deletion.py +475 -0
  23. plain/postgres/dialect.py +640 -0
  24. plain/postgres/entrypoints.py +4 -0
  25. plain/postgres/enums.py +103 -0
  26. plain/postgres/exceptions.py +217 -0
  27. plain/postgres/expressions.py +1912 -0
  28. plain/postgres/fields/__init__.py +2118 -0
  29. plain/postgres/fields/encrypted.py +354 -0
  30. plain/postgres/fields/json.py +413 -0
  31. plain/postgres/fields/mixins.py +30 -0
  32. plain/postgres/fields/related.py +1192 -0
  33. plain/postgres/fields/related_descriptors.py +290 -0
  34. plain/postgres/fields/related_lookups.py +223 -0
  35. plain/postgres/fields/related_managers.py +661 -0
  36. plain/postgres/fields/reverse_descriptors.py +229 -0
  37. plain/postgres/fields/reverse_related.py +328 -0
  38. plain/postgres/fields/timezones.py +143 -0
  39. plain/postgres/forms.py +773 -0
  40. plain/postgres/functions/__init__.py +189 -0
  41. plain/postgres/functions/comparison.py +127 -0
  42. plain/postgres/functions/datetime.py +454 -0
  43. plain/postgres/functions/math.py +140 -0
  44. plain/postgres/functions/mixins.py +59 -0
  45. plain/postgres/functions/text.py +282 -0
  46. plain/postgres/functions/window.py +125 -0
  47. plain/postgres/indexes.py +286 -0
  48. plain/postgres/lookups.py +758 -0
  49. plain/postgres/meta.py +584 -0
  50. plain/postgres/migrations/__init__.py +53 -0
  51. plain/postgres/migrations/autodetector.py +1379 -0
  52. plain/postgres/migrations/exceptions.py +54 -0
  53. plain/postgres/migrations/executor.py +188 -0
  54. plain/postgres/migrations/graph.py +364 -0
  55. plain/postgres/migrations/loader.py +377 -0
  56. plain/postgres/migrations/migration.py +180 -0
  57. plain/postgres/migrations/operations/__init__.py +34 -0
  58. plain/postgres/migrations/operations/base.py +139 -0
  59. plain/postgres/migrations/operations/fields.py +373 -0
  60. plain/postgres/migrations/operations/models.py +798 -0
  61. plain/postgres/migrations/operations/special.py +184 -0
  62. plain/postgres/migrations/optimizer.py +74 -0
  63. plain/postgres/migrations/questioner.py +340 -0
  64. plain/postgres/migrations/recorder.py +119 -0
  65. plain/postgres/migrations/serializer.py +378 -0
  66. plain/postgres/migrations/state.py +882 -0
  67. plain/postgres/migrations/utils.py +147 -0
  68. plain/postgres/migrations/writer.py +302 -0
  69. plain/postgres/options.py +207 -0
  70. plain/postgres/otel.py +231 -0
  71. plain/postgres/preflight.py +336 -0
  72. plain/postgres/query.py +2242 -0
  73. plain/postgres/query_utils.py +456 -0
  74. plain/postgres/registry.py +217 -0
  75. plain/postgres/schema.py +1885 -0
  76. plain/postgres/sql/__init__.py +40 -0
  77. plain/postgres/sql/compiler.py +1869 -0
  78. plain/postgres/sql/constants.py +22 -0
  79. plain/postgres/sql/datastructures.py +222 -0
  80. plain/postgres/sql/query.py +2947 -0
  81. plain/postgres/sql/where.py +374 -0
  82. plain/postgres/test/__init__.py +0 -0
  83. plain/postgres/test/pytest.py +117 -0
  84. plain/postgres/test/utils.py +18 -0
  85. plain/postgres/transaction.py +222 -0
  86. plain/postgres/types.py +92 -0
  87. plain/postgres/types.pyi +751 -0
  88. plain/postgres/utils.py +345 -0
  89. plain_postgres-0.84.0.dist-info/METADATA +937 -0
  90. plain_postgres-0.84.0.dist-info/RECORD +93 -0
  91. plain_postgres-0.84.0.dist-info/WHEEL +4 -0
  92. plain_postgres-0.84.0.dist-info/entry_points.txt +5 -0
  93. plain_postgres-0.84.0.dist-info/licenses/LICENSE +61 -0
@@ -0,0 +1,77 @@
1
+ from __future__ import annotations
2
+
3
+ from contextvars import ContextVar
4
+ from typing import TYPE_CHECKING, Any, TypedDict
5
+
6
+ from plain.exceptions import ImproperlyConfigured
7
+ from plain.runtime import settings as plain_settings
8
+
9
+ if TYPE_CHECKING:
10
+ from plain.postgres.connection import DatabaseConnection
11
+
12
+
13
class DatabaseConfig(TypedDict, total=False):
    """Connection settings passed to ``DatabaseConnection``.

    Declared with ``total=False`` so every key is optional: a URL parser may
    fill in only the keys it can derive, whereas ``_configure_settings()``
    supplies all of them from the ``POSTGRES_*`` settings.
    """

    CONN_MAX_AGE: int | None
    CONN_HEALTH_CHECKS: bool
    HOST: str
    DATABASE: str | None
    OPTIONS: dict[str, Any]
    PASSWORD: str
    PORT: int | None
    TEST: dict[str, Any]
    TIME_ZONE: str | None
    USER: str
24
+
25
+
26
# Holds the DatabaseConnection for the current execution context.
# A ContextVar gives every asyncio.Task its own copy of the value, and each
# thread has its own native context, so threads in a pool keep their
# connection between work items (which is what lets CONN_MAX_AGE apply).
_db_conn: ContextVar[DatabaseConnection | None] = ContextVar("_db_conn", default=None)
31
+
32
+
33
def _configure_settings() -> DatabaseConfig:
    """Assemble a DatabaseConfig from the ``POSTGRES_*`` settings.

    Raises:
        ImproperlyConfigured: if ``POSTGRES_DATABASE`` is the empty string
            (the database was disabled with ``DATABASE_URL=none``) or is
            otherwise falsy (the database was never configured).
    """
    database = plain_settings.POSTGRES_DATABASE

    if database == "":
        raise ImproperlyConfigured(
            "The PostgreSQL database has been disabled (DATABASE_URL=none). "
            "No database operations are available in this context."
        )
    if not database:  # None or unresolved setting
        raise ImproperlyConfigured(
            "PostgreSQL database is not configured. "
            "Set DATABASE_URL or the individual POSTGRES_* settings."
        )

    config: DatabaseConfig = {
        "DATABASE": database,
        "USER": plain_settings.POSTGRES_USER,
        "PASSWORD": plain_settings.POSTGRES_PASSWORD,
        "HOST": plain_settings.POSTGRES_HOST,
        "PORT": plain_settings.POSTGRES_PORT,
        "CONN_MAX_AGE": plain_settings.POSTGRES_CONN_MAX_AGE,
        "CONN_HEALTH_CHECKS": plain_settings.POSTGRES_CONN_HEALTH_CHECKS,
        "OPTIONS": plain_settings.POSTGRES_OPTIONS,
        "TIME_ZONE": plain_settings.POSTGRES_TIME_ZONE,
        "TEST": {"DATABASE": None},
    }
    return config
57
+
58
+
59
def _create_connection() -> DatabaseConnection:
    """Instantiate a new DatabaseConnection configured from settings.

    The connection class is imported lazily, at call time.
    """
    from plain.postgres.connection import DatabaseConnection

    return DatabaseConnection(_configure_settings())
64
+
65
+
66
def get_connection() -> DatabaseConnection:
    """Return the current context's connection, creating and storing one if needed."""
    existing = _db_conn.get()
    if existing is not None:
        return existing

    created = _create_connection()
    _db_conn.set(created)
    return created
73
+
74
+
75
def has_connection() -> bool:
    """Report whether the current context already holds a connection.

    Unlike ``get_connection()``, this never creates one.
    """
    current = _db_conn.get()
    return current is not None
@@ -0,0 +1,13 @@
1
+ """
2
+ Constants used across the ORM in general.
3
+ """
4
+
5
+ from enum import Enum
6
+
7
# Double-underscore separator between the parts of a filter string;
# filter strings are split apart on it.
LOOKUP_SEP = "__"
9
+
10
+
11
class OnConflict(Enum):
    """Choice of behaviour when a write conflicts with existing rows.

    NOTE(review): presumably selects the ``ON CONFLICT`` strategy for bulk
    inserts — confirm against callers.
    """

    IGNORE = "ignore"
    UPDATE = "update"
@@ -0,0 +1,495 @@
1
+ from __future__ import annotations
2
+
3
+ from enum import Enum
4
+ from types import NoneType
5
+ from typing import TYPE_CHECKING, Any
6
+
7
+ from plain.exceptions import ValidationError
8
+ from plain.postgres.exceptions import FieldError
9
+ from plain.postgres.expressions import (
10
+ Exists,
11
+ ExpressionList,
12
+ F,
13
+ OrderBy,
14
+ ReplaceableExpression,
15
+ )
16
+ from plain.postgres.indexes import IndexExpression
17
+ from plain.postgres.lookups import Exact
18
+ from plain.postgres.query_utils import Q
19
+ from plain.postgres.sql.query import Query
20
+
21
+ if TYPE_CHECKING:
22
+ from plain.postgres.base import Model
23
+ from plain.postgres.schema import DatabaseSchemaEditor, Statement
24
+
25
+ __all__ = ["BaseConstraint", "CheckConstraint", "Deferrable", "UniqueConstraint"]
26
+
27
+
28
class BaseConstraint:
    """Shared behaviour for declarative table constraints.

    Stores the constraint name and how violations are reported. It also
    provides ``deconstruct()``/``clone()`` for serialising and copying a
    constraint. Subclasses must implement the SQL hooks and ``validate()``.
    """

    default_violation_error_message = 'Constraint "%(name)s" is violated.'
    violation_error_code: str | None = None
    violation_error_message: str | None = None

    def __init__(
        self,
        *,
        name: str,
        violation_error_code: str | None = None,
        violation_error_message: str | None = None,
    ) -> None:
        self.name = name
        # Leave the class-level default (None) in place unless a code is given.
        if violation_error_code is not None:
            self.violation_error_code = violation_error_code
        self.violation_error_message = (
            self.default_violation_error_message
            if violation_error_message is None
            else violation_error_message
        )

    @property
    def contains_expressions(self) -> bool:
        """Whether the constraint is defined by expressions (never, here)."""
        return False

    def constraint_sql(
        self, model: type[Model], schema_editor: DatabaseSchemaEditor
    ) -> str | None:
        """Return the inline SQL for this constraint; subclasses must override."""
        raise NotImplementedError(
            "subclasses of BaseConstraint must provide a constraint_sql() method"
        )

    def create_sql(
        self, model: type[Model], schema_editor: DatabaseSchemaEditor
    ) -> str | Statement | None:
        """Return the SQL that adds this constraint; subclasses must override."""
        raise NotImplementedError(
            "subclasses of BaseConstraint must provide a create_sql() method"
        )

    def remove_sql(
        self, model: type[Model], schema_editor: DatabaseSchemaEditor
    ) -> str | Statement | None:
        """Return the SQL that drops this constraint; subclasses must override."""
        raise NotImplementedError(
            "subclasses of BaseConstraint must provide a remove_sql() method"
        )

    def validate(
        self, model: type[Model], instance: Model, exclude: set[str] | None = None
    ) -> None:
        """Check ``instance`` against the constraint; subclasses must override."""
        raise NotImplementedError(
            "subclasses of BaseConstraint must provide a validate() method"
        )

    def get_violation_error_message(self) -> str:
        """Return the violation message with ``%(name)s`` filled in."""
        assert self.violation_error_message is not None
        return self.violation_error_message % {"name": self.name}

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Return ``(import path, args, kwargs)`` that recreate this constraint.

        The path uses the ``plain.postgres`` package rather than the
        ``plain.postgres.constraints`` module. The default violation message
        and an unset error code are omitted from ``kwargs``.
        """
        cls = type(self)
        path = f"{cls.__module__}.{cls.__name__}".replace(
            "plain.postgres.constraints", "plain.postgres"
        )
        kwargs: dict[str, Any] = {"name": self.name}
        message = self.violation_error_message
        if message is not None and message != self.default_violation_error_message:
            kwargs["violation_error_message"] = message
        if self.violation_error_code is not None:
            kwargs["violation_error_code"] = self.violation_error_code
        return path, (), kwargs

    def clone(self) -> BaseConstraint:
        """Return a new instance built from ``deconstruct()``."""
        _path, args, kwargs = self.deconstruct()
        return type(self)(*args, **kwargs)
100
+
101
+
102
class CheckConstraint(BaseConstraint):
    """A table ``CHECK`` constraint defined by a ``Q`` or boolean expression."""

    def __init__(
        self,
        *,
        check: Q,
        name: str,
        violation_error_code: str | None = None,
        violation_error_message: str | None = None,
    ) -> None:
        """Store ``check`` after confirming it is a conditional expression.

        Raises:
            TypeError: if ``check`` lacks a truthy ``conditional`` attribute.
        """
        self.check = check
        is_conditional = getattr(check, "conditional", False)
        if not is_conditional:
            raise TypeError(
                "CheckConstraint.check must be a Q instance or boolean expression."
            )
        super().__init__(
            name=name,
            violation_error_code=violation_error_code,
            violation_error_message=violation_error_message,
        )

    def _get_check_sql(
        self, model: type[Model], schema_editor: DatabaseSchemaEditor
    ) -> str:
        """Compile ``check`` to SQL with params inlined via ``quote_value``."""
        query = Query(model=model, alias_cols=False)
        where_node = query.build_where(self.check)
        sql_template, params = where_node.as_sql(
            query.get_compiler(), schema_editor.connection
        )
        quoted = tuple(schema_editor.quote_value(param) for param in params)
        return sql_template % quoted

    def constraint_sql(
        self, model: type[Model], schema_editor: DatabaseSchemaEditor
    ) -> str:
        """Return the inline ``CHECK`` clause for a table definition."""
        return schema_editor._check_sql(
            self.name, self._get_check_sql(model, schema_editor)
        )

    def create_sql(
        self, model: type[Model], schema_editor: DatabaseSchemaEditor
    ) -> Statement | None:
        """Return the statement that adds this check to an existing table."""
        return schema_editor._create_check_sql(
            model, self.name, self._get_check_sql(model, schema_editor)
        )

    def remove_sql(
        self, model: type[Model], schema_editor: DatabaseSchemaEditor
    ) -> Statement | None:
        """Return the statement that drops this check."""
        return schema_editor._delete_constraint_sql(
            schema_editor.sql_delete_check, model, self.name
        )

    def validate(
        self, model: type[Model], instance: Model, exclude: set[str] | None = None
    ) -> None:
        """Evaluate ``check`` in Python against the instance's field values.

        A ``FieldError`` while evaluating (e.g. the check references an
        excluded field) skips validation.

        Raises:
            ValidationError: if the check evaluates to false.
        """
        field_values = instance._get_field_value_map(
            meta=model._model_meta, exclude=exclude
        )
        try:
            satisfied = Q(self.check).check(field_values)
        except FieldError:
            return
        if not satisfied:
            raise ValidationError(
                self.get_violation_error_message(), code=self.violation_error_code
            )

    def __repr__(self) -> str:
        code_part = (
            f" violation_error_code={self.violation_error_code!r}"
            if self.violation_error_code is not None
            else ""
        )
        message_is_default = (
            self.violation_error_message is None
            or self.violation_error_message == self.default_violation_error_message
        )
        message_part = (
            ""
            if message_is_default
            else f" violation_error_message={self.violation_error_message!r}"
        )
        return (
            f"<{self.__class__.__qualname__}: check={self.check} "
            f"name={self.name!r}{code_part}{message_part}>"
        )

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, CheckConstraint):
            return super().__eq__(other)
        return (
            self.name == other.name
            and self.check == other.check
            and self.violation_error_code == other.violation_error_code
            and self.violation_error_message == other.violation_error_message
        )

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Extend the base deconstruction with the ``check`` keyword."""
        path, args, kwargs = super().deconstruct()
        return path, args, {**kwargs, "check": self.check}
194
+
195
+
196
class Deferrable(Enum):
    """Deferral mode for a constraint (passed as ``deferrable=``)."""

    DEFERRED = "deferred"
    IMMEDIATE = "immediate"

    def __repr__(self) -> str:
        # Render as ``Deferrable.DEFERRED`` instead of the default
        # ``<Deferrable.DEFERRED: 'deferred'>`` so it reads like source code.
        return f"{type(self).__qualname__}.{self.name}"
203
+
204
+
205
class UniqueConstraint(BaseConstraint):
    """A ``UNIQUE`` constraint over model fields or index expressions.

    Exactly one of ``fields`` or positional ``expressions`` must be given;
    string expressions are wrapped in ``F()``. Other options:

    - ``condition`` (a ``Q``) makes the constraint partial.
    - ``include`` lists extra columns.
    - ``opclasses`` gives one operator class per entry in ``fields``.
    - ``deferrable`` cannot be combined with ``condition``, ``include``,
      ``opclasses`` or expressions.
    """

    expressions: tuple[ReplaceableExpression, ...]

    def __init__(
        self,
        *expressions: str | ReplaceableExpression,
        fields: tuple[str, ...] | list[str] = (),
        name: str | None = None,
        condition: Q | None = None,
        deferrable: Deferrable | None = None,
        include: tuple[str, ...] | list[str] | None = None,
        opclasses: tuple[str, ...] | list[str] = (),
        violation_error_code: str | None = None,
        violation_error_message: str | None = None,
    ) -> None:
        """Validate the option combination and store normalized values.

        Raises:
            ValueError: the name is missing, or neither or both of ``fields``
                and ``expressions`` are given. Also raised when an option has
                the wrong type or incompatible options are combined.
        """
        if not name:
            raise ValueError("A unique constraint must be named.")
        if not expressions and not fields:
            raise ValueError(
                "At least one field or expression is required to define a "
                "unique constraint."
            )
        if expressions and fields:
            raise ValueError(
                "UniqueConstraint.fields and expressions are mutually exclusive."
            )
        if not isinstance(condition, NoneType | Q):
            raise ValueError("UniqueConstraint.condition must be a Q instance.")
        if condition and deferrable:
            raise ValueError("UniqueConstraint with conditions cannot be deferred.")
        if include and deferrable:
            raise ValueError("UniqueConstraint with include fields cannot be deferred.")
        if opclasses and deferrable:
            raise ValueError("UniqueConstraint with opclasses cannot be deferred.")
        if expressions and deferrable:
            raise ValueError("UniqueConstraint with expressions cannot be deferred.")
        if expressions and opclasses:
            raise ValueError(
                "UniqueConstraint.opclasses cannot be used with expressions. "
                "Use a custom OpClass() instead."
            )
        if not isinstance(deferrable, NoneType | Deferrable):
            raise ValueError(
                "UniqueConstraint.deferrable must be a Deferrable instance."
            )
        if not isinstance(include, NoneType | list | tuple):
            raise ValueError("UniqueConstraint.include must be a list or tuple.")
        if not isinstance(opclasses, list | tuple):
            raise ValueError("UniqueConstraint.opclasses must be a list or tuple.")
        if opclasses and len(fields) != len(opclasses):
            raise ValueError(
                "UniqueConstraint.fields and UniqueConstraint.opclasses must "
                "have the same number of elements."
            )
        self.fields = tuple(fields)
        self.condition = condition
        self.deferrable = deferrable
        self.include = tuple(include) if include else ()
        # Normalize to a tuple like ``fields`` and ``include``. Otherwise
        # ``__eq__`` reports two constraints as different merely because
        # one was declared with a list and the other with a tuple.
        self.opclasses = tuple(opclasses)
        self.expressions = tuple(
            F(expression) if isinstance(expression, str) else expression
            for expression in expressions
        )
        super().__init__(
            name=name,
            violation_error_code=violation_error_code,
            violation_error_message=violation_error_message,
        )

    @property
    def contains_expressions(self) -> bool:
        """Whether the constraint is defined by expressions instead of fields."""
        return bool(self.expressions)

    def _get_fields(self, model: type[Model]) -> list[Any]:
        """Resolve ``self.fields`` to field objects on ``model``."""
        return [
            model._model_meta.get_forward_field(field_name)
            for field_name in self.fields
        ]

    def _get_include_columns(self, model: type[Model]) -> list[Any]:
        """Resolve ``self.include`` to column names on ``model``."""
        return [
            model._model_meta.get_forward_field(field_name).column
            for field_name in self.include
        ]

    def _get_condition_sql(
        self, model: type[Model], schema_editor: DatabaseSchemaEditor
    ) -> str | None:
        """Compile ``condition`` to SQL with params inlined, or None if unset."""
        if self.condition is None:
            return None
        query = Query(model=model, alias_cols=False)
        where = query.build_where(self.condition)
        compiler = query.get_compiler()
        sql, params = where.as_sql(compiler, schema_editor.connection)
        return sql % tuple(schema_editor.quote_value(p) for p in params)

    def _get_index_expressions(
        self, model: type[Model], schema_editor: DatabaseSchemaEditor
    ) -> Any:
        """Wrap expressions in ``IndexExpression`` and resolve them for ``model``.

        Returns None for a fields-based constraint.
        """
        if not self.expressions:
            return None
        index_expressions = [
            IndexExpression(expression) for expression in self.expressions
        ]
        return ExpressionList(*index_expressions).resolve_expression(
            Query(model, alias_cols=False),
        )

    def constraint_sql(
        self, model: type[Model], schema_editor: DatabaseSchemaEditor
    ) -> str | None:
        """Return the inline ``UNIQUE`` clause for a table definition."""
        fields = self._get_fields(model)
        include = self._get_include_columns(model)
        condition = self._get_condition_sql(model, schema_editor)
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._unique_sql(
            model,
            fields,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=tuple(self.opclasses) if self.opclasses else None,
            expressions=expressions,
        )

    def create_sql(
        self, model: type[Model], schema_editor: DatabaseSchemaEditor
    ) -> Statement | None:
        """Return the statement that adds this constraint to an existing table."""
        fields = self._get_fields(model)
        include = self._get_include_columns(model)
        condition = self._get_condition_sql(model, schema_editor)
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._create_unique_sql(
            model,
            fields,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=tuple(self.opclasses) if self.opclasses else None,
            expressions=expressions,
        )

    def remove_sql(
        self, model: type[Model], schema_editor: DatabaseSchemaEditor
    ) -> Statement | None:
        """Return the statement that drops this constraint."""
        condition = self._get_condition_sql(model, schema_editor)
        include = self._get_include_columns(model)
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._delete_unique_sql(
            model,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=tuple(self.opclasses) if self.opclasses else None,
            expressions=expressions,
        )

    def __repr__(self) -> str:
        return "<{}:{}{}{}{}{}{}{}{}{}>".format(
            self.__class__.__qualname__,
            "" if not self.fields else f" fields={repr(self.fields)}",
            "" if not self.expressions else f" expressions={repr(self.expressions)}",
            f" name={repr(self.name)}",
            "" if self.condition is None else f" condition={self.condition}",
            "" if self.deferrable is None else f" deferrable={self.deferrable!r}",
            "" if not self.include else f" include={repr(self.include)}",
            "" if not self.opclasses else f" opclasses={repr(self.opclasses)}",
            (
                ""
                if self.violation_error_code is None
                else f" violation_error_code={self.violation_error_code!r}"
            ),
            (
                ""
                if self.violation_error_message is None
                or self.violation_error_message == self.default_violation_error_message
                else f" violation_error_message={self.violation_error_message!r}"
            ),
        )

    def __eq__(self, other: object) -> bool:
        if isinstance(other, UniqueConstraint):
            return (
                self.name == other.name
                and self.fields == other.fields
                and self.condition == other.condition
                and self.deferrable == other.deferrable
                and self.include == other.include
                and self.opclasses == other.opclasses
                and self.expressions == other.expressions
                and self.violation_error_code == other.violation_error_code
                and self.violation_error_message == other.violation_error_message
            )
        return super().__eq__(other)

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Return the base deconstruction plus non-default options.

        Expressions are returned as positional args.
        """
        path, args, kwargs = super().deconstruct()
        if self.fields:
            kwargs["fields"] = self.fields
        if self.condition:
            kwargs["condition"] = self.condition
        if self.deferrable:
            kwargs["deferrable"] = self.deferrable
        if self.include:
            kwargs["include"] = self.include
        if self.opclasses:
            kwargs["opclasses"] = self.opclasses
        return path, self.expressions, kwargs

    def validate(
        self, model: type[Model], instance: Model, exclude: set[str] | None = None
    ) -> None:
        """Raise ValidationError if another row already matches ``instance``.

        Validation is skipped when a constrained field is in ``exclude``. For
        fields-based constraints it is also skipped when any constrained value
        is None. The instance's own row is excluded unless it is being added.
        A ``FieldError`` raised while evaluating ``condition`` is swallowed.

        Raises:
            ValidationError: on a detected violation.
        """
        queryset = model.query
        if self.fields:
            lookup_kwargs = {}
            for field_name in self.fields:
                if exclude and field_name in exclude:
                    return
                field = model._model_meta.get_forward_field(field_name)
                lookup_value = getattr(instance, field.attname)
                if lookup_value is None:
                    # A composite constraint containing NULL value cannot cause
                    # a violation since NULL != NULL in SQL.
                    return
                lookup_kwargs[field.name] = lookup_value
            queryset = queryset.filter(**lookup_kwargs)
        else:
            # Ignore constraints with excluded fields.
            if exclude:
                for expression in self.expressions:
                    if hasattr(expression, "flatten"):
                        for expr in expression.flatten():  # type: ignore[operator]
                            if isinstance(expr, F) and expr.name in exclude:
                                return
                    elif isinstance(expression, F) and expression.name in exclude:
                        return
            replacements: dict[Any, Any] = {
                F(field): value
                for field, value in instance._get_field_value_map(
                    meta=model._model_meta, exclude=exclude
                ).items()
            }
            expressions = []
            for expr in self.expressions:
                # Ignore ordering.
                if isinstance(expr, OrderBy):
                    expr = expr.expression
                expressions.append(Exact(expr, expr.replace_expressions(replacements)))
            queryset = queryset.filter(*expressions)
        model_class_id = instance.id
        if not instance._state.adding and model_class_id is not None:
            queryset = queryset.exclude(id=model_class_id)
        if not self.condition:
            if queryset.exists():
                if self.expressions:
                    raise ValidationError(
                        self.get_violation_error_message(),
                        code=self.violation_error_code,
                    )
                # When fields are defined, use the unique_error_message() for
                # backward compatibility.
                for constraint_model, constraints in instance.get_constraints():
                    for constraint in constraints:
                        if constraint is self:
                            raise ValidationError(
                                instance.unique_error_message(
                                    constraint_model,
                                    self.fields,
                                ),
                            )
                # The constraint is not registered on the instance's model
                # (e.g. it is a clone). A duplicate was still found, so report
                # it instead of silently passing.
                raise ValidationError(
                    self.get_violation_error_message(),
                    code=self.violation_error_code,
                )
        else:
            against = instance._get_field_value_map(
                meta=model._model_meta, exclude=exclude
            )
            try:
                if (self.condition & Exists(queryset.filter(self.condition))).check(
                    against
                ):
                    raise ValidationError(
                        self.get_violation_error_message(),
                        code=self.violation_error_code,
                    )
            except FieldError:
                pass
@@ -0,0 +1,94 @@
1
+ from __future__ import annotations
2
+
3
+ # Copyright (c) Kenneth Reitz & individual contributors
4
+ # All rights reserved.
5
+ # Redistribution and use in source and binary forms, with or without modification,
6
+ # are permitted provided that the following conditions are met:
7
+ # 1. Redistributions of source code must retain the above copyright notice,
8
+ # this list of conditions and the following disclaimer.
9
+ # 2. Redistributions in binary form must reproduce the above copyright
10
+ # notice, this list of conditions and the following disclaimer in the
11
+ # documentation and/or other materials provided with the distribution.
12
+ # 3. Neither the name of Plain nor the names of its contributors may be used
13
+ # to endorse or promote products derived from this software without
14
+ # specific prior written permission.
15
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16
+ # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17
+ # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
19
+ # ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20
+ # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21
+ # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
22
+ # ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23
+ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24
+ # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25
+ import urllib.parse as urlparse
26
+
27
+ from .connections import DatabaseConfig
28
+
29
SCHEMES = {"postgres", "postgresql", "pgsql"}

# Register the database schemes with urllib.parse. Re-assembled or joined
# URLs then keep the "scheme://netloc/path" form. The membership check keeps
# the process-wide list free of duplicates if this module runs more than once
# (e.g. importlib.reload()).
for scheme in sorted(SCHEMES):
    if scheme not in urlparse.uses_netloc:
        urlparse.uses_netloc.append(scheme)


def parse_database_url(url: str) -> DatabaseConfig:
    """Parse a PostgreSQL database URL into a DatabaseConfig dictionary.

    Accepts the ``postgres://``, ``postgresql://`` and ``pgsql://`` schemes.
    User, password and database name are percent-decoded. A percent-encoded
    host (e.g. ``%2Fvar%2Frun%2Fpostgresql`` for a Unix socket directory) is
    decoded with its original letter case and without any ``:port`` suffix.
    Query parameters are returned under ``OPTIONS``: for a repeated key the
    last value wins, and keys with blank values are dropped.

    Raises:
        ValueError: if the scheme is unsupported or the port is not a valid
            integer.
    """
    spliturl = urlparse.urlsplit(url)

    if spliturl.scheme not in SCHEMES:
        raise ValueError(
            f"No support for '{spliturl.scheme}'. We support: {', '.join(sorted(SCHEMES))}"
        )

    path = spliturl.path[1:]
    query = urlparse.parse_qs(spliturl.query)
    port = spliturl.port

    # Handle percent-encoded hostnames (e.g. socket paths).
    hostname = spliturl.hostname or ""
    if "%" in hostname:
        # ``.hostname`` is lowercased, which would corrupt a case-sensitive
        # path, so re-read the host from the raw netloc. Strip the
        # "user:password@" prefix and any ":port" suffix; the suffix used to
        # leak into HOST.
        hostname = spliturl.netloc.rpartition("@")[2]
        if port is not None:
            hostname = hostname.rpartition(":")[0]
        if hostname.startswith("[") and hostname.endswith("]"):
            # Bracketed IPv6 literal (e.g. with a %25-encoded zone id).
            hostname = hostname[1:-1]
        hostname = urlparse.unquote(hostname)

    parsed_config: DatabaseConfig = {
        "DATABASE": urlparse.unquote(path or ""),
        "USER": urlparse.unquote(spliturl.username or ""),
        "PASSWORD": urlparse.unquote(spliturl.password or ""),
        "HOST": hostname,
        "PORT": port,
    }

    # Pass the query string into OPTIONS.
    options = {key: values[-1] for key, values in query.items()}
    if options:
        parsed_config["OPTIONS"] = options

    return parsed_config
71
+
72
+
73
def build_database_url(config: DatabaseConfig) -> str:
    """Build a ``postgresql://`` URL from a configuration dictionary.

    Intended to round-trip with ``parse_database_url()``:

    - User name and password are fully percent-encoded, including ``/``,
      ``:`` and ``@``.
    - A HOST containing ``/`` (a Unix socket directory) is percent-encoded so
      it stays a single netloc component.
    - Missing or ``None`` values are treated as empty.
    - ``OPTIONS`` become the query string.
    """
    options = config.get("OPTIONS") or {}
    query = urlparse.urlencode(list(options.items()))

    # safe="" so "/" is encoded as well. With the default safe="/" a slash in
    # a user name or password would end the netloc and corrupt the URL.
    user = urlparse.quote(str(config.get("USER") or ""), safe="")
    password = urlparse.quote(str(config.get("PASSWORD") or ""), safe="")
    host = str(config.get("HOST") or "")
    if "/" in host:
        # Socket directory, e.g. "/var/run/postgresql" -> "%2Fvar%2Frun%2F...".
        host = urlparse.quote(host, safe="")
    port = config.get("PORT")
    # DATABASE may be None per DatabaseConfig; avoid rendering "None".
    name = urlparse.quote(str(config.get("DATABASE") or ""))

    netloc = ""
    if user or password:
        netloc += user
        if password:
            netloc += f":{password}"
        netloc += "@"
    netloc += host
    if port:
        netloc += f":{port}"

    return urlparse.urlunsplit(("postgresql", netloc, f"/{name}", query, ""))