plain.postgres 0.84.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93) hide show
  1. plain/postgres/CHANGELOG.md +1028 -0
  2. plain/postgres/README.md +925 -0
  3. plain/postgres/__init__.py +120 -0
  4. plain/postgres/agents/.claude/rules/plain-postgres.md +78 -0
  5. plain/postgres/aggregates.py +236 -0
  6. plain/postgres/backups/__init__.py +0 -0
  7. plain/postgres/backups/cli.py +148 -0
  8. plain/postgres/backups/clients.py +94 -0
  9. plain/postgres/backups/core.py +172 -0
  10. plain/postgres/base.py +1415 -0
  11. plain/postgres/cli/__init__.py +3 -0
  12. plain/postgres/cli/db.py +142 -0
  13. plain/postgres/cli/migrations.py +1085 -0
  14. plain/postgres/config.py +18 -0
  15. plain/postgres/connection.py +1331 -0
  16. plain/postgres/connections.py +77 -0
  17. plain/postgres/constants.py +13 -0
  18. plain/postgres/constraints.py +495 -0
  19. plain/postgres/database_url.py +94 -0
  20. plain/postgres/db.py +59 -0
  21. plain/postgres/default_settings.py +38 -0
  22. plain/postgres/deletion.py +475 -0
  23. plain/postgres/dialect.py +640 -0
  24. plain/postgres/entrypoints.py +4 -0
  25. plain/postgres/enums.py +103 -0
  26. plain/postgres/exceptions.py +217 -0
  27. plain/postgres/expressions.py +1912 -0
  28. plain/postgres/fields/__init__.py +2118 -0
  29. plain/postgres/fields/encrypted.py +354 -0
  30. plain/postgres/fields/json.py +413 -0
  31. plain/postgres/fields/mixins.py +30 -0
  32. plain/postgres/fields/related.py +1192 -0
  33. plain/postgres/fields/related_descriptors.py +290 -0
  34. plain/postgres/fields/related_lookups.py +223 -0
  35. plain/postgres/fields/related_managers.py +661 -0
  36. plain/postgres/fields/reverse_descriptors.py +229 -0
  37. plain/postgres/fields/reverse_related.py +328 -0
  38. plain/postgres/fields/timezones.py +143 -0
  39. plain/postgres/forms.py +773 -0
  40. plain/postgres/functions/__init__.py +189 -0
  41. plain/postgres/functions/comparison.py +127 -0
  42. plain/postgres/functions/datetime.py +454 -0
  43. plain/postgres/functions/math.py +140 -0
  44. plain/postgres/functions/mixins.py +59 -0
  45. plain/postgres/functions/text.py +282 -0
  46. plain/postgres/functions/window.py +125 -0
  47. plain/postgres/indexes.py +286 -0
  48. plain/postgres/lookups.py +758 -0
  49. plain/postgres/meta.py +584 -0
  50. plain/postgres/migrations/__init__.py +53 -0
  51. plain/postgres/migrations/autodetector.py +1379 -0
  52. plain/postgres/migrations/exceptions.py +54 -0
  53. plain/postgres/migrations/executor.py +188 -0
  54. plain/postgres/migrations/graph.py +364 -0
  55. plain/postgres/migrations/loader.py +377 -0
  56. plain/postgres/migrations/migration.py +180 -0
  57. plain/postgres/migrations/operations/__init__.py +34 -0
  58. plain/postgres/migrations/operations/base.py +139 -0
  59. plain/postgres/migrations/operations/fields.py +373 -0
  60. plain/postgres/migrations/operations/models.py +798 -0
  61. plain/postgres/migrations/operations/special.py +184 -0
  62. plain/postgres/migrations/optimizer.py +74 -0
  63. plain/postgres/migrations/questioner.py +340 -0
  64. plain/postgres/migrations/recorder.py +119 -0
  65. plain/postgres/migrations/serializer.py +378 -0
  66. plain/postgres/migrations/state.py +882 -0
  67. plain/postgres/migrations/utils.py +147 -0
  68. plain/postgres/migrations/writer.py +302 -0
  69. plain/postgres/options.py +207 -0
  70. plain/postgres/otel.py +231 -0
  71. plain/postgres/preflight.py +336 -0
  72. plain/postgres/query.py +2242 -0
  73. plain/postgres/query_utils.py +456 -0
  74. plain/postgres/registry.py +217 -0
  75. plain/postgres/schema.py +1885 -0
  76. plain/postgres/sql/__init__.py +40 -0
  77. plain/postgres/sql/compiler.py +1869 -0
  78. plain/postgres/sql/constants.py +22 -0
  79. plain/postgres/sql/datastructures.py +222 -0
  80. plain/postgres/sql/query.py +2947 -0
  81. plain/postgres/sql/where.py +374 -0
  82. plain/postgres/test/__init__.py +0 -0
  83. plain/postgres/test/pytest.py +117 -0
  84. plain/postgres/test/utils.py +18 -0
  85. plain/postgres/transaction.py +222 -0
  86. plain/postgres/types.py +92 -0
  87. plain/postgres/types.pyi +751 -0
  88. plain/postgres/utils.py +345 -0
  89. plain_postgres-0.84.0.dist-info/METADATA +937 -0
  90. plain_postgres-0.84.0.dist-info/RECORD +93 -0
  91. plain_postgres-0.84.0.dist-info/WHEEL +4 -0
  92. plain_postgres-0.84.0.dist-info/entry_points.txt +5 -0
  93. plain_postgres-0.84.0.dist-info/licenses/LICENSE +61 -0
@@ -0,0 +1,120 @@
1
+ from .registry import models_registry, register_model # noqa Create the registry first
2
+ from . import (
3
+ preflight, # noqa Imported for side effects (registers preflight checks)
4
+ )
5
+
6
+ # Imports that would create circular imports if sorted
7
+ from .base import Model
8
+ from .constraints import CheckConstraint, UniqueConstraint
9
+ from .db import IntegrityError, get_connection
10
+ from .deletion import CASCADE, DO_NOTHING, PROTECT, RESTRICT, SET, SET_DEFAULT, SET_NULL
11
+ from .enums import IntegerChoices, TextChoices
12
+ from .fields import (
13
+ BigIntegerField,
14
+ BinaryField,
15
+ BooleanField,
16
+ CharField,
17
+ DateField,
18
+ DateTimeField,
19
+ DecimalField,
20
+ DurationField,
21
+ EmailField,
22
+ FloatField,
23
+ GenericIPAddressField,
24
+ IntegerField,
25
+ PositiveBigIntegerField,
26
+ PositiveIntegerField,
27
+ PositiveSmallIntegerField,
28
+ PrimaryKeyField,
29
+ SmallIntegerField,
30
+ TextField,
31
+ TimeField,
32
+ URLField,
33
+ UUIDField,
34
+ )
35
+ from .fields.json import JSONField
36
+ from .fields.timezones import TimeZoneField
37
+ from .fields.related import (
38
+ ForeignKeyField,
39
+ ManyToManyField,
40
+ )
41
+ from .fields.reverse_descriptors import (
42
+ ReverseForeignKey,
43
+ ReverseManyToMany,
44
+ )
45
+ from .indexes import Index
46
+ from .options import Options
47
+ from .query import QuerySet
48
+ from .query_utils import Q
49
+ from . import types
50
+
51
# Public API of ``plain.postgres``: everything needed to define model classes
# (fields, relations, constraints, indexes, options, deletion behaviors), plus
# the most common query helpers such as ``Q`` and ``QuerySet``. Advanced
# query-time features (aggregates, expressions, etc.) are not re-exported here;
# import them from their own modules (e.g. ``plain.postgres.aggregates``).
__all__ = [
    # From constraints
    "CheckConstraint",
    "UniqueConstraint",
    # From enums
    "IntegerChoices",
    "TextChoices",
    # From fields
    "BigIntegerField",
    "BinaryField",
    "BooleanField",
    "CharField",
    "DateField",
    "DateTimeField",
    "DecimalField",
    "DurationField",
    "EmailField",
    "FloatField",
    "GenericIPAddressField",
    "IntegerField",
    "PositiveBigIntegerField",
    "PositiveIntegerField",
    "PositiveSmallIntegerField",
    "PrimaryKeyField",
    "SmallIntegerField",
    "TextField",
    "TimeField",
    "URLField",
    "UUIDField",
    # From fields.json
    "JSONField",
    # From fields.timezones
    "TimeZoneField",
    # From indexes
    "Index",
    # From deletion (on_delete behaviors for ForeignKeyField)
    "CASCADE",
    "DO_NOTHING",
    "PROTECT",
    "RESTRICT",
    "SET",
    "SET_DEFAULT",
    "SET_NULL",
    # From options (used as ``model_options = postgres.Options(...)``)
    "Options",
    # From query
    "QuerySet",
    # From query_utils
    "Q",
    # From base
    "Model",
    # From fields.related
    "ForeignKeyField",
    "ManyToManyField",
    # From fields.reverse_descriptors
    "ReverseForeignKey",
    "ReverseManyToMany",
    # From db
    "get_connection",
    "IntegrityError",
    # From registry
    "register_model",
    "models_registry",
    # Typed field module (``from plain.postgres import types``)
    "types",
]
@@ -0,0 +1,78 @@
1
+ ---
2
+ paths:
3
+ - "**/*.py"
4
+ ---
5
+
6
+ # Database & Models
7
+
8
+ ## Field Imports
9
+
10
+ Import fields via `from plain.postgres import types` and annotate with Python types:
11
+
12
+ ```python
13
+ from plain import postgres
+ from plain.postgres import types
14
+
15
+ name: str = types.CharField(max_length=100)
16
+ car: Car = types.ForeignKeyField("Car", on_delete=postgres.CASCADE)
17
+ ```
18
+
19
+ Do NOT import field classes directly from `plain.postgres` or `plain.postgres.fields`.
20
+
21
+ ## Schema Changes
22
+
23
+ When creating new models or modifying existing model fields/relationships, always enter plan mode first. Database schema is hard to change after the fact, so get the design right before writing code.
24
+
25
+ In your plan, present:
26
+
27
+ - Proposed schema as a table (model, field, type, constraints)
28
+ - Relationship cardinality (1:1, 1:N, M:N)
29
+ - Key decisions: nullable vs default, indexing, cascade behavior
30
+ - Whether the data could live on an existing model instead of a new one
31
+
32
+ Get approval before writing any model code or generating migrations.
33
+
34
+ ## Migrations
35
+
36
+ - `uv run plain makemigrations` — create migrations (`--dry-run` to preview, `--check` for CI)
37
+ - `uv run plain migrate --backup` — apply migrations
38
+ - `uv run plain migrations list` — view status (not `migrate --list`)
39
+ - Before committing, consolidate multiple uncommitted migrations into one:
40
+ delete the intermediate files, run `migrations prune --yes` to clean stale DB records,
41
+ run `makemigrations` fresh, then `migrate --fake` to mark it applied
42
+ - Use `migrations squash` only for already-committed/deployed migrations — never for dev cleanup
43
+ - Only write migrations by hand for custom data migrations
44
+
45
+ Run `uv run plain docs postgres --section migrations` for full workflow details.
46
+
47
+ ## Querying
48
+
49
+ Use `Model.query` to build querysets (e.g., `User.query.filter(is_active=True)`).
50
+
51
+ - Use `select_related()` for FK access in loops, `prefetch_related()` for reverse/M2M
52
+ - Use `.annotate(Count(...))` instead of calling `.count()` per row
53
+ - Fetch all data in the view — templates should never trigger queries
54
+ - Use `.exists()` not `.count() > 0`, `.count()` not `len(qs)`
55
+ - Use `bulk_create`/`bulk_update` for batch ops, `.update()`/`.delete()` for mass ops
56
+ - Use `.values_list()` when you only need specific columns
57
+ - Wrap multi-step writes in `transaction.atomic()`
58
+ - Always paginate list queries — unbounded querysets get slower as data grows
59
+
60
+ Run `uv run plain docs postgres --section querying` for full patterns with code examples.
61
+
62
+ ## Schema Design
63
+
64
+ - Index fields used in `.filter()` and `.order_by()`
65
+ - Use `UniqueConstraint` in constraints, not `unique=True` on fields
66
+ - Choose `on_delete` deliberately: CASCADE for children, PROTECT for referenced data
67
+ - No `allow_null` on string fields — use `default=""`
68
+
69
+ Run `uv run plain docs postgres --section constraints` for full patterns with code examples.
70
+
71
+ ## Differences from Django
72
+
73
+ - Use `Model.query` not `Model.objects`
74
+ - Import fields from `plain.postgres.types` not `plain.postgres.fields` — and don't import field classes directly from `plain.postgres`
75
+ - Use `model_options = postgres.Options(...)` not `class Meta`
76
+ - Fields don't accept `unique=True` — use `UniqueConstraint` in constraints
77
+ - Never format raw SQL strings — always use parameterized queries
78
+ - Migrations are forward-only — no reverse migrations. `RunPython` takes a single callable (no `reverse_code` or `noop`). The callable signature is `fn(models, schema_editor)`, not `fn(apps, schema_editor)`
@@ -0,0 +1,236 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, Any
4
+
5
+ from plain.postgres.exceptions import FieldError, FullResultSet
6
+ from plain.postgres.expressions import (
7
+ Func,
8
+ ResolvableExpression,
9
+ Star,
10
+ Value,
11
+ )
12
+ from plain.postgres.fields import IntegerField
13
+ from plain.postgres.functions.comparison import Coalesce
14
+ from plain.postgres.functions.mixins import NumericOutputFieldMixin
15
+
16
+ if TYPE_CHECKING:
17
+ from collections.abc import Sequence
18
+
19
+ from plain.postgres.connection import DatabaseConnection
20
+ from plain.postgres.expressions import Expression
21
+ from plain.postgres.query_utils import Q
22
+ from plain.postgres.sql.compiler import SQLCompiler
23
+
24
+
25
# Public names of this module: the base ``Aggregate`` class plus the concrete
# aggregates. Only these are picked up by ``from ... import *``.
__all__ = [
    "Aggregate",
    "Avg",
    "Count",
    "Max",
    "Min",
    "StdDev",
    "Sum",
    "Variance",
]
35
+
36
+
37
class Aggregate(Func):
    """Base class for SQL aggregate functions.

    On top of ``Func`` it supports three options:

    * ``distinct=True`` -- renders ``DISTINCT`` inside the call; rejected
      unless the subclass sets ``allow_distinct = True``.
    * ``filter=`` -- a ``Q`` or expression rendered as a
      ``FILTER (WHERE ...)`` clause after the call.
    * ``default=`` -- a fallback applied by wrapping the resolved aggregate in
      ``Coalesce``; rejected when the subclass defines a non-None
      ``empty_result_set_value`` (e.g. ``Count``).
    """

    template = "%(function)s(%(distinct)s%(expressions)s)"
    contains_aggregate = True
    # Set by subclasses; used to build ``default_alias`` and error messages.
    name = None
    # Wrapped around the base template when a filter is present. The doubled
    # ``%%`` survives the first ``%`` interpolation in ``as_sql`` so that
    # ``%(filter)s`` is filled in later with the compiled filter SQL.
    filter_template = "%s FILTER (WHERE %%(filter)s)"
    window_compatible = True
    allow_distinct = False
    # NOTE(review): presumably the value reported when the query can match no
    # rows (``Count`` uses 0). A non-None value also forbids ``default=``.
    empty_result_set_value = None

    def __init__(
        self,
        *expressions: Any,
        distinct: bool = False,
        filter: Q | Expression | None = None,
        default: Any = None,
        **extra: Any,
    ) -> None:
        """Store the aggregate options and pass the expressions to ``Func``.

        Raises:
            TypeError: if ``distinct`` is requested but the subclass does not
                allow it, or if ``default`` is given for an aggregate with a
                non-None ``empty_result_set_value``.
        """
        if distinct and not self.allow_distinct:
            raise TypeError(f"{self.__class__.__name__} does not allow distinct.")
        if default is not None and self.empty_result_set_value is not None:
            raise TypeError(f"{self.__class__.__name__} does not allow default.")
        self.distinct = distinct
        self.filter = filter
        self.default = default
        super().__init__(*expressions, **extra)

    def get_source_fields(self) -> list[Any]:
        """Return the output fields of the aggregated expressions only."""
        # Don't return the filter expression since it's not a source field.
        return [e._output_field_or_none for e in super().get_source_expressions()]

    def get_source_expressions(self) -> list[Expression]:
        """Return the aggregated expressions, followed by the filter if truthy."""
        source_expressions = super().get_source_expressions()
        if self.filter:
            return source_expressions + [self.filter]
        return source_expressions

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        """Inverse of ``get_source_expressions``.

        When a (truthy) filter is currently set, the last element of ``exprs``
        becomes the new filter and the rest are handed to ``Func``.
        """
        exprs_list = list(exprs)
        self.filter = self.filter and exprs_list.pop()
        super().set_source_expressions(exprs_list)

    def resolve_expression(  # type: ignore[override]
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Expression:
        """Resolve the aggregate (and its filter) against ``query``.

        Returns the resolved copy, or a ``Coalesce(copy, default)`` wrapper
        when a ``default`` was supplied.

        Raises:
            FieldError: when not summarizing and one of the aggregated
                expressions itself contains an aggregate.
        """
        # Aggregates are not allowed in UPDATE queries, so ignore for_save
        c = super().resolve_expression(query, allow_joins, reuse, summarize)
        if c.filter is not None:
            c.filter = c.filter.resolve_expression(query, allow_joins, reuse, summarize)
        if not summarize:
            # Call Aggregate.get_source_expressions() to avoid
            # returning self.filter and including that in this loop.
            expressions = super(Aggregate, c).get_source_expressions()
            for index, expr in enumerate(expressions):
                if expr.contains_aggregate:
                    # Use the unresolved expression for a readable error name.
                    before_resolved = self.get_source_expressions()[index]
                    name = (
                        before_resolved.name
                        if hasattr(before_resolved, "name")
                        else repr(before_resolved)
                    )
                    raise FieldError(
                        f"Cannot compute {c.name}('{name}'): '{name}' is an aggregate"
                    )
        if (default := c.default) is None:
            return c
        if isinstance(default, ResolvableExpression):
            default = default.resolve_expression(query, allow_joins, reuse, summarize)
            # Give an untyped default expression the aggregate's output type.
            if default._output_field_or_none is None:
                default.output_field = c._output_field_or_none
        else:
            default = Value(default, c._output_field_or_none)
        c.default = None  # Reset the default argument before wrapping.
        coalesce = Coalesce(c, default, output_field=c._output_field_or_none)
        coalesce.is_summary = c.is_summary
        return coalesce

    @property
    def default_alias(self) -> str:
        """Alias used when none is given: ``<expression name>__<aggregate name>``.

        Raises:
            TypeError: if the subclass defines no ``name``, or if there is not
                exactly one named source expression (a filter counts as an
                additional source expression).
        """
        expressions = self.get_source_expressions()
        if len(expressions) == 1 and hasattr(expressions[0], "name"):
            if self.name is None:
                raise TypeError("Aggregate subclasses must define a name")
            return f"{expressions[0].name}__{self.name.lower()}"
        raise TypeError("Complex expressions require an alias")

    def get_group_by_cols(self) -> list[Any]:
        """Aggregates contribute no GROUP BY columns."""
        return []

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """Compile to SQL, adding ``DISTINCT`` and a ``FILTER (WHERE ...)`` clause.

        If the filter compiles to ``FullResultSet`` (it matches everything),
        the FILTER clause is omitted. Filter params are appended after the
        aggregate's own params, matching their order in the SQL.
        """
        extra_context["distinct"] = "DISTINCT " if self.distinct else ""
        if self.filter is not None:
            # Use FILTER clause for aggregates when filter is specified
            try:
                filter_sql, filter_params = self.filter.as_sql(compiler, connection)  # type: ignore[union-attr]
            except FullResultSet:
                pass
            else:
                filter_template = self.filter_template % extra_context.get(
                    "template", template or self.template
                )
                sql, params = super().as_sql(
                    compiler,
                    connection,
                    function=function,
                    template=filter_template,
                    arg_joiner=arg_joiner,
                    filter=filter_sql,
                    **extra_context,
                )
                return sql, [*params, *filter_params]
        return super().as_sql(
            compiler,
            connection,
            function=function,
            template=template,
            arg_joiner=arg_joiner,
            **extra_context,
        )

    def _get_repr_options(self) -> dict[str, Any]:
        """Include ``distinct`` and ``filter`` in the repr when they are set."""
        options = super()._get_repr_options()
        if self.distinct:
            options["distinct"] = self.distinct
        if self.filter:
            options["filter"] = self.filter
        return options
176
+
177
+
178
class Avg(NumericOutputFieldMixin, Aggregate):
    """``AVG(...)`` aggregate; ``distinct=True`` is permitted."""

    name = "Avg"
    function = "AVG"
    allow_distinct = True
182
+
183
+
184
class Count(Aggregate):
    """``COUNT(...)`` aggregate.

    The string ``"*"`` is converted to ``Star()``. Because its
    ``empty_result_set_value`` is ``0``, ``Aggregate.__init__`` rejects
    ``default=`` for this aggregate.
    """

    name = "Count"
    function = "COUNT"
    output_field = IntegerField()
    empty_result_set_value = 0
    allow_distinct = True

    def __init__(
        self, expression: Any, filter: Q | Expression | None = None, **extra: Any
    ) -> None:
        expression = Star() if expression == "*" else expression
        # A filter cannot be combined with a Star expression.
        if filter is not None and isinstance(expression, Star):
            raise ValueError("Star cannot be used with filter. Please specify a field.")
        super().__init__(expression, filter=filter, **extra)
199
+
200
+
201
class Max(Aggregate):
    """``MAX(...)`` aggregate."""

    name = "Max"
    function = "MAX"
204
+
205
+
206
class Min(Aggregate):
    """``MIN(...)`` aggregate."""

    name = "Min"
    function = "MIN"
209
+
210
+
211
class StdDev(NumericOutputFieldMixin, Aggregate):
    """Standard deviation: ``STDDEV_POP`` by default, ``STDDEV_SAMP`` if ``sample``."""

    name = "StdDev"

    def __init__(self, expression: Any, sample: bool = False, **extra: Any) -> None:
        # Pick the SQL function per instance before Func initialization.
        if sample:
            self.function = "STDDEV_SAMP"
        else:
            self.function = "STDDEV_POP"
        super().__init__(expression, **extra)

    def _get_repr_options(self) -> dict[str, Any]:
        options = dict(super()._get_repr_options())
        options["sample"] = self.function == "STDDEV_SAMP"
        return options
220
+
221
+
222
class Sum(Aggregate):
    """``SUM(...)`` aggregate; ``distinct=True`` is permitted."""

    name = "Sum"
    function = "SUM"
    allow_distinct = True
226
+
227
+
228
class Variance(NumericOutputFieldMixin, Aggregate):
    """Variance: ``VAR_POP`` by default, ``VAR_SAMP`` if ``sample``."""

    name = "Variance"

    def __init__(self, expression: Any, sample: bool = False, **extra: Any) -> None:
        # Pick the SQL function per instance before Func initialization.
        if sample:
            self.function = "VAR_SAMP"
        else:
            self.function = "VAR_POP"
        super().__init__(expression, **extra)

    def _get_repr_options(self) -> dict[str, Any]:
        options = dict(super()._get_repr_options())
        options["sample"] = self.function == "VAR_SAMP"
        return options
File without changes
@@ -0,0 +1,148 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import time
5
+ from pathlib import Path
6
+
7
+ import click
8
+
9
+ from .core import DatabaseBackups, get_git_branch
10
+
11
+
12
# Root `backups` command group; subcommands are attached via @cli.command.
@click.group("backups")
def cli() -> None:
    """Local database backups"""
16
+
17
+
18
@cli.command("list")
@click.option(
    "--branch",
    "branch",
    is_flag=False,
    flag_value="__current__",
    default=None,
    help="Filter by branch (defaults to current branch if flag used without value)",
)
def list_backups(branch: str | None) -> None:
    """List database backups"""
    handler = DatabaseBackups()
    found = handler.find_backups()

    # "__current__" is the flag_value click supplies for a bare `--branch`.
    if branch == "__current__":
        branch = get_git_branch()

    if branch:
        found = [item for item in found if item.metadata.get("git_branch") == branch]

    if not found:
        message = (
            f"No backups found for branch '{branch}'" if branch else "No backups found"
        )
        click.secho(message, fg="yellow")
        return

    # Size the NAME and SOURCE columns to their widest values.
    name_col = max(len(item.name) for item in found)
    source_col = max(len(item.metadata.get("source") or "-") for item in found)

    click.secho(
        f"{'NAME':<{name_col}} {'SOURCE':<{source_col}} {'SIZE':<10} BRANCH",
        dim=True,
    )

    for item in found:
        archive = item.path / "default.backup"
        size_str = (
            f"{os.path.getsize(archive) / 1024 / 1024:.2f} MB"
            if archive.exists()
            else "-"
        )
        meta = item.metadata
        source = meta.get("source") or "-"
        git_branch = meta.get("git_branch") or "-"
        click.echo(
            f"{item.name:<{name_col}} {source:<{source_col}} {size_str:<10} {git_branch}"
        )
72
+
73
+
74
@cli.command("create")
@click.option("--pg-dump", default="pg_dump", envvar="PG_DUMP")
@click.argument("backup_name", default="")
def create_backup(backup_name: str, pg_dump: str) -> None:
    """Create a database backup"""
    # NOTE: the docstring above is the click help text; keep it as-is.
    backups_handler = DatabaseBackups()

    # Default to a timestamp so unnamed backups don't collide.
    if not backup_name:
        backup_name = time.strftime("%Y%m%d_%H%M%S")

    try:
        backup_dir = backups_handler.create(
            backup_name,
            source="manual",
            pg_dump=pg_dump,
        )
    except Exception as e:
        click.secho(str(e), fg="red")
        # SystemExit instead of the `site`-provided `exit()` builtin, which is
        # not guaranteed to exist (e.g. under `python -S` or frozen apps).
        raise SystemExit(1) from None

    # Prefer a cwd-relative path for display, but a backups directory outside
    # the cwd must not turn a successful backup into a ValueError traceback.
    try:
        display_path = backup_dir.relative_to(Path.cwd())
    except ValueError:
        display_path = backup_dir
    click.secho(f"Backup created in {display_path}", fg="green")
95
+
96
+
97
@cli.command("restore")
@click.option("--latest", is_flag=True)
@click.option("--pg-restore", default="pg_restore", envvar="PG_RESTORE")
@click.argument("backup_name", default="")
def restore_backup(backup_name: str, latest: bool, pg_restore: str) -> None:
    """Restore a database backup"""
    # NOTE: the docstring above is the click help text; keep it as-is.
    backups_handler = DatabaseBackups()

    # Exactly one of BACKUP_NAME / --latest must be given.
    if backup_name and latest:
        raise click.UsageError("Only one of --latest or backup_name is allowed")

    if not backup_name and not latest:
        raise click.UsageError("Backup name or --latest is required")

    if latest:
        # NOTE(review): assumes find_backups() returns newest first — confirm
        # against DatabaseBackups in core.py.
        backups = backups_handler.find_backups()
        if not backups:
            # Previously an uncaught IndexError traceback.
            click.secho("No backups found", fg="red")
            raise SystemExit(1)
        backup_name = backups[0].name

    click.secho(f"Restoring backup {backup_name}...", bold=True)

    try:
        backups_handler.restore(
            backup_name,
            pg_restore=pg_restore,
        )
    except Exception as e:
        click.secho(str(e), fg="red")
        # SystemExit instead of the `site`-provided `exit()` builtin.
        raise SystemExit(1) from None
    click.echo(f"Backup {backup_name} restored successfully.")
125
+
126
+
127
@cli.command("delete")
@click.argument("backup_name")
def delete_backup(backup_name: str) -> None:
    """Delete a database backup"""
    # NOTE: the docstring above is the click help text; keep it as-is.
    backups_handler = DatabaseBackups()
    try:
        backups_handler.delete(backup_name)
    except Exception as e:
        click.secho(str(e), fg="red")
        # Exit non-zero (like `create`/`restore`) so scripts and CI can detect
        # the failure; previously this returned with exit status 0.
        raise SystemExit(1) from None
    click.secho(f"Backup {backup_name} deleted", fg="green")
138
+
139
+
140
@cli.command("clear")
@click.confirmation_option(prompt="Are you sure you want to delete all backups?")
def clear_backups() -> None:
    """Clear all database backups"""
    # Delete every backup reported by the handler, then confirm.
    for entry in DatabaseBackups().find_backups():
        entry.delete()
    click.secho("All backups deleted", fg="green")
@@ -0,0 +1,94 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import subprocess
5
+ from pathlib import Path
6
+ from typing import TYPE_CHECKING
7
+
8
+ from plain.exceptions import ImproperlyConfigured
9
+
10
+ if TYPE_CHECKING:
11
+ from plain.postgres.connection import DatabaseConnection
12
+
13
+
14
class PostgresBackupClient:
    """Create and restore database backups with the PostgreSQL CLI tools.

    Connection settings come from ``connection.settings_dict``. The password
    and SSL options are passed to child processes through libpq ``PG*``
    environment variables. User, host and port become command-line flags.
    """

    def __init__(self, connection: DatabaseConnection) -> None:
        self.connection = connection

    def get_env(self) -> dict[str, str]:
        """Return ``PG*`` environment variables for the configured connection.

        ``PASSWORD`` maps to ``PGPASSWORD``, and a fixed set of ``OPTIONS`` keys
        map to their libpq environment variables. Missing or empty values are
        omitted.
        """
        settings_dict = self.connection.settings_dict
        # Tolerate ``OPTIONS`` being present but explicitly ``None``.
        options = settings_dict.get("OPTIONS") or {}
        env: dict[str, str] = {}

        if password := settings_dict.get("PASSWORD"):
            env["PGPASSWORD"] = str(password)

        # Map OPTIONS keys to their corresponding environment variables.
        option_env_vars = {
            "passfile": "PGPASSFILE",
            "sslmode": "PGSSLMODE",
            "sslrootcert": "PGSSLROOTCERT",
            "sslcert": "PGSSLCERT",
            "sslkey": "PGSSLKEY",
        }
        for option_key, env_var in option_env_vars.items():
            if value := options.get(option_key):
                env[env_var] = str(value)

        return env

    def _get_conn_args(self) -> list[str]:
        """Build common connection CLI args (``-U``/``-h``/``-p``) from settings."""
        settings_dict = self.connection.settings_dict
        args: list[str] = []
        if user := settings_dict.get("USER"):
            args += ["-U", user]
        if host := settings_dict.get("HOST"):
            args += ["-h", host]
        if port := settings_dict.get("PORT"):
            args += ["-p", str(port)]
        return args

    def _subprocess_env(self) -> dict[str, str]:
        """Return the current process environment overlaid with ``get_env()``."""
        return {**os.environ, **self.get_env()}

    def _run(self, cmd: str | list[str], *, shell: bool = False) -> None:
        """Run one command, raising ``CalledProcessError`` on a non-zero exit."""
        subprocess.run(cmd, env=self._subprocess_env(), check=True, shell=shell)

    def _run_pipe(
        self,
        first: list[str],
        second: list[str],
        *,
        stdin: Any = None,
        stdout: Any = None,
    ) -> None:
        """Run ``first | second`` without a shell and check both exit codes.

        ``stdin`` feeds ``first`` and ``stdout`` receives the output of
        ``second`` (file objects or ``None`` to inherit). Like bash's
        ``pipefail``, the rightmost failing command is reported.

        Raises:
            subprocess.CalledProcessError: if either command exits non-zero.
            OSError: if either command cannot be started.
        """
        env = self._subprocess_env()
        upstream = subprocess.Popen(first, stdin=stdin, stdout=subprocess.PIPE, env=env)
        pipe = upstream.stdout
        assert pipe is not None  # guaranteed by stdout=PIPE
        try:
            downstream = subprocess.Popen(second, stdin=pipe, stdout=stdout, env=env)
        except BaseException:
            upstream.kill()
            upstream.wait()
            raise
        finally:
            # Drop our copy of the read end so `upstream` gets SIGPIPE if
            # `downstream` exits early instead of blocking forever.
            pipe.close()

        downstream_rc = downstream.wait()
        upstream_rc = upstream.wait()
        if downstream_rc != 0:
            raise subprocess.CalledProcessError(downstream_rc, second)
        if upstream_rc != 0:
            raise subprocess.CalledProcessError(upstream_rc, first)

    @staticmethod
    def _split_command(command: str) -> list[str]:
        """Split a user-supplied command string (e.g. ``"docker exec db pg_dump"``)."""
        import shlex

        # shlex honours quoting, which the former shell-string execution did.
        return shlex.split(command)

    @staticmethod
    def _quote_literal(value: str) -> str:
        """Quote ``value`` as an SQL string literal (standard_conforming_strings)."""
        return "'" + value.replace("'", "''") + "'"

    @staticmethod
    def _quote_ident(value: str) -> str:
        """Quote ``value`` as an SQL identifier."""
        return '"' + value.replace('"', '""') + '"'

    def _require_dbname(self) -> str:
        """Return the configured database name.

        Raises:
            ImproperlyConfigured: if ``DATABASE`` is missing or empty.
        """
        dbname = self.connection.settings_dict.get("DATABASE")
        if not dbname:
            raise ImproperlyConfigured("POSTGRES_DATABASE is required in settings")
        return dbname

    def create_backup(self, backup_path: Path, *, pg_dump: str = "pg_dump") -> None:
        """Dump the database (custom format, ``-Fc``) to a gzip-compressed file.

        Raises:
            ImproperlyConfigured: if no database name is configured.
            subprocess.CalledProcessError: if ``pg_dump`` or ``gzip`` fails.
        """
        dbname = self._require_dbname()
        args = self._split_command(pg_dump) + self._get_conn_args() + ["-Fc", dbname]

        # Equivalent of `pg_dump ... | gzip > backup_path`, but without a shell:
        # paths need no quoting, and a failing pg_dump is no longer masked by
        # gzip succeeding (which used to yield an empty "successful" backup).
        with open(backup_path, "wb") as out:
            self._run_pipe(args, ["gzip"], stdout=out)

    def restore_backup(
        self, backup_path: Path, *, pg_restore: str = "pg_restore", psql: str = "psql"
    ) -> None:
        """Drop and recreate the database, then restore it from ``backup_path``.

        Raises:
            ImproperlyConfigured: if no database name is configured.
            OSError: if the backup file cannot be opened (checked before the
                database is dropped).
            subprocess.CalledProcessError: if any psql/gunzip/pg_restore call fails.
        """
        dbname = self._require_dbname()
        conn_args = self._get_conn_args()
        psql_args = self._split_command(psql) + conn_args + ["-d", "template1"]
        restore_args = self._split_command(pg_restore) + conn_args + ["-d", dbname]

        # Open the backup first so a missing/unreadable file fails before the
        # existing database is dropped.
        with open(backup_path, "rb") as backup_file:
            # Drop and recreate the database via template1
            drop_create_cmds = [
                "SELECT pg_terminate_backend(pid) FROM pg_stat_activity"
                f" WHERE datname = {self._quote_literal(dbname)}"
                " AND pid <> pg_backend_pid()",
                f"DROP DATABASE IF EXISTS {self._quote_ident(dbname)}",
                f"CREATE DATABASE {self._quote_ident(dbname)}",
            ]
            for sql in drop_create_cmds:
                self._run(psql_args + ["-c", sql])

            # Equivalent of `gunzip < backup_path | pg_restore ...`, checking
            # both exit codes.
            self._run_pipe(["gunzip"], restore_args, stdin=backup_file)