plain.postgres 0.84.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93) hide show
  1. plain/postgres/CHANGELOG.md +1028 -0
  2. plain/postgres/README.md +925 -0
  3. plain/postgres/__init__.py +120 -0
  4. plain/postgres/agents/.claude/rules/plain-postgres.md +78 -0
  5. plain/postgres/aggregates.py +236 -0
  6. plain/postgres/backups/__init__.py +0 -0
  7. plain/postgres/backups/cli.py +148 -0
  8. plain/postgres/backups/clients.py +94 -0
  9. plain/postgres/backups/core.py +172 -0
  10. plain/postgres/base.py +1415 -0
  11. plain/postgres/cli/__init__.py +3 -0
  12. plain/postgres/cli/db.py +142 -0
  13. plain/postgres/cli/migrations.py +1085 -0
  14. plain/postgres/config.py +18 -0
  15. plain/postgres/connection.py +1331 -0
  16. plain/postgres/connections.py +77 -0
  17. plain/postgres/constants.py +13 -0
  18. plain/postgres/constraints.py +495 -0
  19. plain/postgres/database_url.py +94 -0
  20. plain/postgres/db.py +59 -0
  21. plain/postgres/default_settings.py +38 -0
  22. plain/postgres/deletion.py +475 -0
  23. plain/postgres/dialect.py +640 -0
  24. plain/postgres/entrypoints.py +4 -0
  25. plain/postgres/enums.py +103 -0
  26. plain/postgres/exceptions.py +217 -0
  27. plain/postgres/expressions.py +1912 -0
  28. plain/postgres/fields/__init__.py +2118 -0
  29. plain/postgres/fields/encrypted.py +354 -0
  30. plain/postgres/fields/json.py +413 -0
  31. plain/postgres/fields/mixins.py +30 -0
  32. plain/postgres/fields/related.py +1192 -0
  33. plain/postgres/fields/related_descriptors.py +290 -0
  34. plain/postgres/fields/related_lookups.py +223 -0
  35. plain/postgres/fields/related_managers.py +661 -0
  36. plain/postgres/fields/reverse_descriptors.py +229 -0
  37. plain/postgres/fields/reverse_related.py +328 -0
  38. plain/postgres/fields/timezones.py +143 -0
  39. plain/postgres/forms.py +773 -0
  40. plain/postgres/functions/__init__.py +189 -0
  41. plain/postgres/functions/comparison.py +127 -0
  42. plain/postgres/functions/datetime.py +454 -0
  43. plain/postgres/functions/math.py +140 -0
  44. plain/postgres/functions/mixins.py +59 -0
  45. plain/postgres/functions/text.py +282 -0
  46. plain/postgres/functions/window.py +125 -0
  47. plain/postgres/indexes.py +286 -0
  48. plain/postgres/lookups.py +758 -0
  49. plain/postgres/meta.py +584 -0
  50. plain/postgres/migrations/__init__.py +53 -0
  51. plain/postgres/migrations/autodetector.py +1379 -0
  52. plain/postgres/migrations/exceptions.py +54 -0
  53. plain/postgres/migrations/executor.py +188 -0
  54. plain/postgres/migrations/graph.py +364 -0
  55. plain/postgres/migrations/loader.py +377 -0
  56. plain/postgres/migrations/migration.py +180 -0
  57. plain/postgres/migrations/operations/__init__.py +34 -0
  58. plain/postgres/migrations/operations/base.py +139 -0
  59. plain/postgres/migrations/operations/fields.py +373 -0
  60. plain/postgres/migrations/operations/models.py +798 -0
  61. plain/postgres/migrations/operations/special.py +184 -0
  62. plain/postgres/migrations/optimizer.py +74 -0
  63. plain/postgres/migrations/questioner.py +340 -0
  64. plain/postgres/migrations/recorder.py +119 -0
  65. plain/postgres/migrations/serializer.py +378 -0
  66. plain/postgres/migrations/state.py +882 -0
  67. plain/postgres/migrations/utils.py +147 -0
  68. plain/postgres/migrations/writer.py +302 -0
  69. plain/postgres/options.py +207 -0
  70. plain/postgres/otel.py +231 -0
  71. plain/postgres/preflight.py +336 -0
  72. plain/postgres/query.py +2242 -0
  73. plain/postgres/query_utils.py +456 -0
  74. plain/postgres/registry.py +217 -0
  75. plain/postgres/schema.py +1885 -0
  76. plain/postgres/sql/__init__.py +40 -0
  77. plain/postgres/sql/compiler.py +1869 -0
  78. plain/postgres/sql/constants.py +22 -0
  79. plain/postgres/sql/datastructures.py +222 -0
  80. plain/postgres/sql/query.py +2947 -0
  81. plain/postgres/sql/where.py +374 -0
  82. plain/postgres/test/__init__.py +0 -0
  83. plain/postgres/test/pytest.py +117 -0
  84. plain/postgres/test/utils.py +18 -0
  85. plain/postgres/transaction.py +222 -0
  86. plain/postgres/types.py +92 -0
  87. plain/postgres/types.pyi +751 -0
  88. plain/postgres/utils.py +345 -0
  89. plain_postgres-0.84.0.dist-info/METADATA +937 -0
  90. plain_postgres-0.84.0.dist-info/RECORD +93 -0
  91. plain_postgres-0.84.0.dist-info/WHEEL +4 -0
  92. plain_postgres-0.84.0.dist-info/entry_points.txt +5 -0
  93. plain_postgres-0.84.0.dist-info/licenses/LICENSE +61 -0
@@ -0,0 +1,1331 @@
1
+ from __future__ import annotations
2
+
3
+ import _thread
4
+ import datetime
5
+ import logging
6
+ import os
7
+ import signal
8
+ import subprocess
9
+ import sys
10
+ import time
11
+ import warnings
12
+ import zoneinfo
13
+ from collections import deque
14
+ from collections.abc import Generator, Sequence
15
+ from contextlib import contextmanager
16
+ from functools import cached_property, lru_cache
17
+ from typing import TYPE_CHECKING, Any, LiteralString, NamedTuple, cast
18
+
19
+ import psycopg as Database
20
+ from psycopg import ClientCursor, IsolationLevel, adapt, adapters, errors
21
+ from psycopg import sql as psycopg_sql
22
+ from psycopg.abc import Buffer, PyFormat
23
+ from psycopg.postgres import types as pg_types
24
+ from psycopg.pq import Format
25
+ from psycopg.types.datetime import TimestamptzLoader
26
+ from psycopg.types.range import BaseRangeDumper, Range, RangeDumper
27
+ from psycopg.types.string import TextLoader
28
+
29
+ from plain.exceptions import ImproperlyConfigured
30
+ from plain.postgres import utils
31
+ from plain.postgres.db import (
32
+ DatabaseError,
33
+ DatabaseErrorWrapper,
34
+ )
35
+ from plain.postgres.dialect import MAX_NAME_LENGTH, quote_name
36
+ from plain.postgres.indexes import Index
37
+ from plain.postgres.schema import DatabaseSchemaEditor
38
+ from plain.postgres.transaction import TransactionManagementError
39
+ from plain.postgres.utils import CursorDebugWrapper as BaseCursorDebugWrapper
40
+ from plain.postgres.utils import CursorWrapper, debug_transaction
41
+ from plain.runtime import settings
42
+
43
+ if TYPE_CHECKING:
44
+ from psycopg import Connection as PsycopgConnection
45
+
46
+ from plain.postgres.connections import DatabaseConfig
47
+ from plain.postgres.fields import Field
48
+
49
# Logger named after this module's dotted path.
logger = logging.getLogger("plain.postgres.connection")

# The prefix to put on the default database name when creating
# the test database.
TEST_DATABASE_PREFIX = "test_"
54
+
55
+
56
def get_migratable_models() -> Generator[Any]:
    """
    Yield every model registered for an installed package.

    Packages are visited in the order the packages registry reports them,
    and each package's models in the order the models registry returns
    them. The registries are imported, and the package-config iterable is
    obtained, when this function is called; the per-package model lookups
    happen lazily as the result is consumed.
    """
    from plain.packages import packages_registry
    from plain.postgres import models_registry

    # Grab the package-config iterator up front, just as a generator
    # expression evaluates its outermost iterable eagerly.
    package_configs = iter(packages_registry.get_package_configs())

    def _walk() -> Generator[Any]:
        for package_config in package_configs:
            label = package_config.package_label
            for model in models_registry.get_models(package_label=label):
                yield model

    return _walk()
68
+
69
+
70
class TableInfo(NamedTuple):
    """Structure returned by DatabaseConnection.get_table_list()."""

    # Name of the relation.
    name: str
    # Relation kind as produced by get_table_list() — presumably a short
    # code distinguishing tables from views; TODO confirm against that query.
    type: str
    # Comment attached to the relation; None presumably means it has none.
    comment: str | None
76
+
77
+
78
# Type OIDs
# OID of `timestamptz` from psycopg's global adapters map; create_cursor()
# uses it to fetch the connection's timestamptz loader.
TIMESTAMPTZ_OID = adapters.types["timestamptz"].oid
# OIDs of the naive and time-zone-aware timestamp range types, from psycopg's
# built-in PostgreSQL type registry; PlainRangeDumper swaps the first for the
# second.
TSRANGE_OID = pg_types["tsrange"].oid
TSTZRANGE_OID = pg_types["tstzrange"].oid
82
+
83
+
84
class BaseTzLoader(TimestamptzLoader):
    """
    timestamptz loader that stamps each loaded value with a fixed tzinfo.

    Subclasses set ``timezone``. When it is ``None`` the tzinfo is
    stripped and the loader produces naive datetimes.
    """

    timezone: datetime.tzinfo | None = None

    def load(self, data: Buffer) -> datetime.datetime:
        loaded = super().load(data)
        # replace() swaps the tzinfo without converting, so the wall-clock
        # fields produced by the parent loader are kept as-is.
        return loaded.replace(tzinfo=self.timezone)
95
+
96
+
97
def register_tzloader(tz: datetime.tzinfo | None, context: Any) -> None:
    """
    Register a timestamptz loader bound to ``tz`` on ``context``.

    ``context`` must expose an ``adapters`` map. In this module it is
    either an ``AdaptersMap`` template (get_adapters_template) or a cursor
    (create_cursor). A fresh loader subclass is built per call so that each
    registration carries its own ``timezone``.
    """
    class SpecificTzLoader(BaseTzLoader):
        timezone = tz

    registry = context.adapters
    registry.register_loader("timestamptz", SpecificTzLoader)
102
+
103
+
104
class PlainRangeDumper(RangeDumper):
    """
    Range dumper that sends timestamp ranges as ``tstzrange``.

    Whenever psycopg's own upgrade picks a ``tsrange`` dumper (presumably
    for ranges with naive datetime bounds), the chosen dumper is retargeted
    to ``tstzrange``.
    """

    def upgrade(self, obj: Range[Any], format: PyFormat) -> BaseRangeDumper:
        specialised = super().upgrade(obj, format)
        # Only retarget a newly specialised dumper; never mutate ``self``.
        if specialised is self or specialised.oid != TSRANGE_OID:
            return specialised
        specialised.oid = TSTZRANGE_OID
        return specialised
112
+
113
+
114
@lru_cache
def get_adapters_template(timezone: datetime.tzinfo | None) -> adapt.AdaptersMap:
    """
    Build the adapters map handed to psycopg as a connection ``context``.

    The map is derived from psycopg's global ``adapters`` and is cached per
    ``timezone`` by ``lru_cache``. On top of the defaults it:

    * loads ``jsonb``, ``inet`` and ``cidr`` values as plain text,
    * dumps ``Range`` objects with ``PlainRangeDumper``,
    * loads ``timestamptz`` values stamped with ``timezone``.
    """
    template = adapt.AdaptersMap(adapters)
    # jsonb: no-op text loader to skip psycopg's JSON decode round trip.
    # inet/cidr: surfaced as strings.
    for type_name in ("jsonb", "inet", "cidr"):
        template.register_loader(type_name, TextLoader)
    template.register_dumper(Range, PlainRangeDumper)
    register_tzloader(timezone, template)
    return template
125
+
126
+
127
+ def _psql_settings_to_cmd_args_env(
128
+ settings_dict: DatabaseConfig, parameters: list[str]
129
+ ) -> tuple[list[str], dict[str, str] | None]:
130
+ """Build psql command-line arguments from database settings."""
131
+ args = ["psql"]
132
+ options = settings_dict.get("OPTIONS", {})
133
+
134
+ if user := settings_dict.get("USER"):
135
+ args += ["-U", user]
136
+ if host := settings_dict.get("HOST"):
137
+ args += ["-h", host]
138
+ if port := settings_dict.get("PORT"):
139
+ args += ["-p", str(port)]
140
+ args.extend(parameters)
141
+ args += [settings_dict.get("DATABASE") or "postgres"]
142
+
143
+ env: dict[str, str] = {}
144
+ if password := settings_dict.get("PASSWORD"):
145
+ env["PGPASSWORD"] = str(password)
146
+
147
+ # Map OPTIONS keys to their corresponding environment variables.
148
+ option_env_vars = {
149
+ "passfile": "PGPASSFILE",
150
+ "sslmode": "PGSSLMODE",
151
+ "sslrootcert": "PGSSLROOTCERT",
152
+ "sslcert": "PGSSLCERT",
153
+ "sslkey": "PGSSLKEY",
154
+ }
155
+ for option_key, env_var in option_env_vars.items():
156
+ if value := options.get(option_key):
157
+ env[env_var] = str(value)
158
+
159
+ return args, (env or None)
160
+
161
+
162
+ class DatabaseConnection:
163
+ """
164
+ PostgreSQL database connection.
165
+
166
+ This is the only database backend supported by Plain.
167
+ """
168
+
169
# Maximum number of entries kept in ``queries_log`` (used as its deque maxlen).
queries_limit: int = 9000
# Name of the interactive client binary. NOTE(review): runshell() builds its
# command through _psql_settings_to_cmd_args_env(), which hard-codes "psql",
# so this attribute is not consulted there — confirm whether anything reads it.
executable_name: str = "psql"

# Index access method name; presumably the default used when an index does
# not specify one — TODO confirm at the call site.
index_default_access_method = "btree"
# Table names to skip, presumably during introspection (not used in this view).
# NOTE(review): a mutable list on the class is shared by every instance.
ignored_tables: list[str] = []
174
+
175
+ def __init__(self, settings_dict: DatabaseConfig):
176
+ # Connection related attributes.
177
+ # The underlying database connection (from the database library, not a wrapper).
178
+ self.connection: PsycopgConnection[Any] | None = None
179
+ # `settings_dict` should be a dictionary containing keys such as
180
+ # DATABASE, USER, etc. It's called `settings_dict` instead of `settings`
181
+ # to disambiguate it from Plain settings modules.
182
+ self.settings_dict: DatabaseConfig = settings_dict
183
+ # Query logging in debug mode or when explicitly enabled.
184
+ self.queries_log: deque[dict[str, Any]] = deque(maxlen=self.queries_limit)
185
+ self.force_debug_cursor: bool = False
186
+
187
+ # Transaction related attributes.
188
+ # Tracks if the connection is in autocommit mode. Per PEP 249, by
189
+ # default, it isn't.
190
+ self.autocommit: bool = False
191
+ # Tracks if the connection is in a transaction managed by 'atomic'.
192
+ self.in_atomic_block: bool = False
193
+ # Increment to generate unique savepoint ids.
194
+ self.savepoint_state: int = 0
195
+ # List of savepoints created by 'atomic'.
196
+ self.savepoint_ids: list[str | None] = []
197
+ # Stack of active 'atomic' blocks.
198
+ self.atomic_blocks: list[Any] = []
199
+ # Tracks if the transaction should be rolled back to the next
200
+ # available savepoint because of an exception in an inner block.
201
+ self.needs_rollback: bool = False
202
+ self.rollback_exc: Exception | None = None
203
+
204
+ # Connection termination related attributes.
205
+ self.close_at: float | None = None
206
+ self.closed_in_transaction: bool = False
207
+ self.errors_occurred: bool = False
208
+ self.health_check_enabled: bool = False
209
+ self.health_check_done: bool = False
210
+
211
+ # A list of no-argument functions to run when the transaction commits.
212
+ # Each entry is an (sids, func, robust) tuple, where sids is a set of
213
+ # the active savepoint IDs when this function was registered and robust
214
+ # specifies whether it's allowed for the function to fail.
215
+ self.run_on_commit: list[tuple[set[str | None], Any, bool]] = []
216
+
217
+ # Should we run the on-commit hooks the next time set_autocommit(True)
218
+ # is called?
219
+ self.run_commit_hooks_on_set_autocommit_on: bool = False
220
+
221
+ # A stack of wrappers to be invoked around execute()/executemany()
222
+ # calls. Each entry is a function taking five arguments: execute, sql,
223
+ # params, many, and context. It's the function's responsibility to
224
+ # call execute(sql, params, many, context).
225
+ self.execute_wrappers: list[Any] = []
226
+
227
+ def __repr__(self) -> str:
228
+ return f"<{self.__class__.__qualname__} vendor='postgresql'>"
229
+
230
+ @cached_property
231
+ def timezone(self) -> datetime.tzinfo:
232
+ """
233
+ Return a tzinfo of the database connection time zone.
234
+
235
+ When a datetime is read from the database, it is returned in this time
236
+ zone. Since PostgreSQL supports time zones, it doesn't matter which
237
+ time zone Plain uses, as long as aware datetimes are used everywhere.
238
+ Other users connecting to the database can choose their own time zone.
239
+ """
240
+ if self.settings_dict["TIME_ZONE"] is None:
241
+ return datetime.UTC
242
+ return zoneinfo.ZoneInfo(self.settings_dict["TIME_ZONE"])
243
+
244
@cached_property
def timezone_name(self) -> str:
    """Name of the connection's time zone; "UTC" when TIME_ZONE is None."""
    configured = self.settings_dict["TIME_ZONE"]
    if configured is not None:
        return configured
    return "UTC"
252
+
253
@property
def queries_logged(self) -> bool:
    """Whether executed queries are recorded: forced on, or settings.DEBUG."""
    if self.force_debug_cursor:
        return self.force_debug_cursor
    return settings.DEBUG
256
+
257
@property
def queries(self) -> list[dict[str, Any]]:
    """
    Return a snapshot list of the logged queries, oldest first.

    ``queries_log`` is a bounded deque. Once it has reached its limit,
    older entries may already have been discarded, so a warning is emitted
    to make that truncation visible.
    """
    if len(self.queries_log) == self.queries_log.maxlen:
        warnings.warn(
            f"Limit for query logging exceeded, only the last {self.queries_log.maxlen} queries "
            "will be returned.",
            # Attribute the warning to the code that read ``.queries``,
            # not to this property.
            stacklevel=2,
        )
    return list(self.queries_log)
265
+
266
+ # ##### Connection and cursor methods #####
267
+
268
def get_connection_params(self) -> dict[str, Any]:
    """
    Translate ``settings_dict`` into keyword arguments for get_new_connection().

    OPTIONS are forwarded to psycopg, minus the Plain-only keys.
    DATABASE=None targets the ``postgres`` database. Prepared statements
    stay disabled unless OPTIONS sets ``prepare_threshold``.

    Raises:
        ImproperlyConfigured: if DATABASE is the empty string, or is longer
            than ``MAX_NAME_LENGTH``.
    """
    config = self.settings_dict
    options = config.get("OPTIONS", {})
    db_name = config.get("DATABASE")

    if db_name == "":
        raise ImproperlyConfigured(
            "PostgreSQL database is not configured. "
            "Set DATABASE_URL or the POSTGRES_DATABASE setting."
        )
    name_length = len(db_name or "")
    if name_length > MAX_NAME_LENGTH:
        raise ImproperlyConfigured(
            f"The database name '{db_name}' ({name_length} characters) is longer than "
            f"PostgreSQL's limit of {MAX_NAME_LENGTH} characters. Supply a shorter "
            "POSTGRES_DATABASE setting."
        )

    params: dict[str, Any] = {
        # None is used to connect to the default 'postgres' db.
        "dbname": "postgres" if db_name is None else db_name,
        **options,
    }
    # Plain-specific OPTIONS that psycopg.connect() must not receive.
    for plain_only_key in ("assume_role", "isolation_level", "server_side_binding"):
        params.pop(plain_only_key, None)

    for setting_key, param_key in (
        ("USER", "user"),
        ("PASSWORD", "password"),
        ("HOST", "host"),
        ("PORT", "port"),
    ):
        if config[setting_key]:
            params[param_key] = config[setting_key]

    params["context"] = get_adapters_template(self.timezone)
    # Prepared statements are off by default (None) so connection poolers
    # keep working; OPTIONS["prepare_threshold"] re-enables them.
    params["prepare_threshold"] = params.pop("prepare_threshold", None)
    return params
313
+
314
def get_new_connection(self, conn_params: dict[str, Any]) -> PsycopgConnection[Any]:
    """
    Open a psycopg connection and apply Plain's per-connection options.

    ``self.isolation_level`` is set from ``OPTIONS["isolation_level"]`` and
    defaults to ``IsolationLevel.READ_COMMITTED``. Only an explicitly
    configured level is assigned to the psycopg connection. The cursor
    factory is ``ServerBindingCursor`` when ``OPTIONS["server_side_binding"]``
    is ``True`` and ``Cursor`` otherwise.

    Raises:
        ImproperlyConfigured: if ``isolation_level`` is not a valid
            ``psycopg.IsolationLevel`` value. This is checked before any
            connection is opened.
    """
    options = self.settings_dict.get("OPTIONS", {})
    set_isolation_level = False
    try:
        isolation_level_value = options["isolation_level"]
    except KeyError:
        # NOTE(review): READ_COMMITTED is assumed here, not read from the
        # server, and is not pushed to the driver.
        self.isolation_level = IsolationLevel.READ_COMMITTED
    else:
        try:
            self.isolation_level = IsolationLevel(isolation_level_value)
        except ValueError as err:
            raise ImproperlyConfigured(
                f"Invalid transaction isolation level {isolation_level_value} "
                "specified. Use one of the psycopg.IsolationLevel values."
            ) from err
        set_isolation_level = True
    connection = Database.connect(**conn_params)
    if set_isolation_level:
        connection.isolation_level = self.isolation_level
    # Use server-side binding cursor if requested, otherwise standard cursor
    connection.cursor_factory = (
        ServerBindingCursor
        if options.get("server_side_binding") is True
        else Cursor
    )
    return connection
347
+
348
def ensure_timezone(self) -> bool:
    """
    Make the session's TimeZone match ``timezone_name``.

    Returns True when a ``set_config`` call was issued. Returns False when
    there is no connection, no target name, or the server already reports
    the target zone.
    """
    conn = self.connection
    if conn is None:
        return False
    current = conn.info.parameter_status("TimeZone")
    wanted = self.timezone_name
    if not wanted or current == wanted:
        return False
    conn.execute("SELECT set_config('TimeZone', %s, false)", [wanted])
    return True
363
+
364
def ensure_role(self) -> bool:
    """
    Run ``SET ROLE`` for ``OPTIONS["assume_role"]`` when it is configured.

    Returns True if the statement was executed, and False when there is no
    connection or no role to assume.
    """
    if self.connection is None:
        return False
    role = self.settings_dict.get("OPTIONS", {}).get("assume_role")
    if not role:
        return False
    # compose_sql() (defined elsewhere) presumably inlines the role
    # client-side, since SET statements cannot take bind parameters.
    statement = self.compose_sql("SET ROLE %s", [role])
    self.connection.execute(statement)  # type: ignore[arg-type]
    return True
372
+
373
def init_connection_state(self) -> None:
    """
    Apply per-session settings to a freshly opened connection.

    The time zone is set first, then the optional ``SET ROLE``. The role
    step helps when the login credential (e.g. a temporary or ephemeral
    one) is not the role that owns the database objects.
    """
    for step in (self.ensure_timezone, self.ensure_role):
        step()
380
+
381
def create_cursor(self) -> Any:
    """
    Open a raw psycopg cursor; a connection must already be established.

    A cursor-level timestamptz loader is registered only when the
    connection's loader disagrees with ``self.timezone``. This avoids
    copying the adapters map on every cursor.
    """
    conn = self.connection
    assert conn is not None
    new_cursor = conn.cursor()
    current_loader = conn.adapters.get_loader(TIMESTAMPTZ_OID, Format.TEXT)
    if self.timezone != current_loader.timezone:  # type: ignore[union-attr]
        register_tzloader(self.timezone, new_cursor)
    return new_cursor
391
+
392
+ def _set_autocommit(self, autocommit: bool) -> None:
393
+ """Backend-specific implementation to enable or disable autocommit."""
394
+ assert self.connection is not None
395
+ with self.wrap_database_errors:
396
+ self.connection.autocommit = autocommit
397
+
398
def check_constraints(self, table_names: list[str] | None = None) -> None:
    """
    Force deferred constraints to be checked now, then defer them again.

    ``table_names`` is not used: ``SET CONSTRAINTS ALL`` applies to every
    deferrable constraint in the current transaction.
    """
    with self.cursor() as cur:
        for mode in ("IMMEDIATE", "DEFERRED"):
            cur.execute(f"SET CONSTRAINTS ALL {mode}")
406
+
407
def is_usable(self) -> bool:
    """
    Probe the connection with ``SELECT 1``.

    Assumes ``self.connection`` is set. A psycopg error is reported as
    False rather than raised, so that callers can recycle a broken
    connection.
    """
    assert self.connection is not None
    try:
        # Go straight to psycopg, bypassing Plain's cursor utilities.
        self.connection.execute("SELECT 1")
    except Database.Error:
        return False
    return True
424
+
425
@contextmanager
def _nodb_cursor(self) -> Generator[utils.CursorWrapper]:
    """
    Return a cursor from an alternative connection to be used when there is
    no need to access the main database, specifically for test db
    creation/deletion. This also prevents the production database from
    being exposed to potential child threads while (or after) the test
    database is destroyed. Refs #10868, #17786, #16969.

    The alternative connection copies ``settings_dict`` with DATABASE set
    to None, which get_connection_params() maps to the ``postgres``
    database. If no cursor can be obtained from it, a RuntimeWarning is
    emitted and a connection with the original settings is used instead.
    A database error raised after a cursor was obtained is re-raised, not
    retried. This includes errors from the caller's ``with`` body, which
    are thrown into this generator at ``yield``.
    """
    # Stays None until the alternative connection hands out a cursor; the
    # except clause uses it to tell "could not connect" apart from
    # "failed while in use".
    cursor = None
    try:
        conn = self.__class__({**self.settings_dict, "DATABASE": None})
        try:
            with conn.cursor() as cursor:
                yield cursor
        finally:
            conn.close()
    except (Database.DatabaseError, DatabaseError):
        if cursor is not None:
            # The failure happened after a cursor existed: propagate it.
            raise
        warnings.warn(
            "Normally Plain will use a connection to the 'postgres' database "
            "to avoid running initialization queries against the production "
            "database when it's not needed (for example, when running tests). "
            "Plain was unable to create a connection to the 'postgres' database "
            "and will use the first PostgreSQL database instead.",
            RuntimeWarning,
        )
        # Fallback: a fresh connection with the configured settings.
        conn = self.__class__(self.settings_dict)
        try:
            with conn.cursor() as cursor:
                yield cursor
        finally:
            conn.close()
459
+
460
@cached_property
def pg_version(self) -> int:
    """
    Server version number from ``connection.info.server_version``.

    The value is cached after the first lookup. A temporary connection is
    opened (and closed again) when none exists.
    """
    with self.temporary_connection():
        conn = self.connection
        assert conn is not None
        # Read while the connection is guaranteed to be open.
        version = conn.info.server_version
    return version
465
+
466
def make_debug_cursor(self, cursor: Any) -> CursorDebugWrapper:
    """Wrap ``cursor`` in CursorDebugWrapper (chosen by _prepare_cursor when queries are logged)."""
    debug_wrapper = CursorDebugWrapper(cursor, self)
    return debug_wrapper
468
+
469
+ # ##### Connection lifecycle #####
470
+
471
def connect(self) -> None:
    """
    Connect to the database. Assume that the connection is closed.

    Resets transaction and lifetime bookkeeping, opens a new psycopg
    connection, switches it to autocommit and applies the per-session
    settings (time zone, role) via init_connection_state().
    """
    # In case the previous connection was closed while in an atomic block
    self.in_atomic_block = False
    self.savepoint_ids = []
    self.atomic_blocks = []
    self.needs_rollback = False
    # Reset parameters defining when to close/health-check the connection.
    self.health_check_enabled = self.settings_dict["CONN_HEALTH_CHECKS"]
    max_age = self.settings_dict["CONN_MAX_AGE"]
    # CONN_MAX_AGE=None means no deadline; otherwise the deadline is taken
    # on the monotonic clock.
    self.close_at = None if max_age is None else time.monotonic() + max_age
    self.closed_in_transaction = False
    self.errors_occurred = False
    # New connections are healthy.
    self.health_check_done = True
    # Establish the connection
    conn_params = self.get_connection_params()
    self.connection = self.get_new_connection(conn_params)
    # Must come after self.connection is assigned: set_autocommit() calls
    # ensure_connection(), which would otherwise re-enter connect().
    self.set_autocommit(True)
    self.init_connection_state()

    # Drop on-commit callbacks left over from a previous connection.
    # NOTE(review): only reached if every step above succeeded.
    self.run_on_commit = []
493
+
494
def ensure_connection(self) -> None:
    """Open a connection unless one already exists; driver errors are wrapped."""
    if self.connection is not None:
        return
    with self.wrap_database_errors:
        self.connect()
499
+
500
+ # ##### PEP-249 connection method wrappers #####
501
+
502
+ def _prepare_cursor(self, cursor: Any) -> utils.CursorWrapper:
503
+ """
504
+ Validate the connection is usable and perform database cursor wrapping.
505
+ """
506
+ if self.queries_logged:
507
+ wrapped_cursor = self.make_debug_cursor(cursor)
508
+ else:
509
+ wrapped_cursor = self.make_cursor(cursor)
510
+ return wrapped_cursor
511
+
512
+ def _cursor(self) -> utils.CursorWrapper:
513
+ self.close_if_health_check_failed()
514
+ self.ensure_connection()
515
+ with self.wrap_database_errors:
516
+ return self._prepare_cursor(self.create_cursor())
517
+
518
+ def _commit(self) -> None:
519
+ if self.connection is not None:
520
+ with debug_transaction(self, "COMMIT"), self.wrap_database_errors:
521
+ return self.connection.commit()
522
+
523
+ def _rollback(self) -> None:
524
+ if self.connection is not None:
525
+ with debug_transaction(self, "ROLLBACK"), self.wrap_database_errors:
526
+ return self.connection.rollback()
527
+
528
+ def _close(self) -> None:
529
+ if self.connection is not None:
530
+ with self.wrap_database_errors:
531
+ return self.connection.close()
532
+
533
+ # ##### Generic wrappers for PEP-249 connection methods #####
534
+
535
def cursor(self) -> utils.CursorWrapper:
    """Public entry point for a wrapped cursor; connects on demand."""
    wrapped = self._cursor()
    return wrapped
538
+
539
def commit(self) -> None:
    """
    Commit the current transaction. Forbidden inside an 'atomic' block.

    On success the error flag is cleared, and the on-commit hooks are
    scheduled to run at the next set_autocommit(True).
    """
    self.validate_no_atomic_block()
    self._commit()
    # Reaching this point proves the connection works.
    self.run_commit_hooks_on_set_autocommit_on = True
    self.errors_occurred = False
546
+
547
def rollback(self) -> None:
    """
    Roll back the current transaction. Forbidden inside an 'atomic' block.

    On success the error and needs-rollback flags are cleared, and any
    pending on-commit callbacks are discarded.
    """
    self.validate_no_atomic_block()
    self._rollback()
    # Reaching this point proves the connection works.
    self.run_on_commit = []
    self.needs_rollback = False
    self.errors_occurred = False
555
+
556
def close(self) -> None:
    """
    Close the database connection, discarding pending on-commit callbacks.

    validate_no_atomic_block() is deliberately not called, so that a
    connection in an invalid state can still be discarded; the next
    connect() resets the transaction state. When closed inside an 'atomic'
    block, ``self.connection`` is kept and the block is marked as needing
    a rollback.
    """
    self.run_on_commit = []

    if self.connection is None or self.closed_in_transaction:
        return
    try:
        self._close()
    finally:
        if not self.in_atomic_block:
            self.connection = None
        else:
            self.closed_in_transaction = True
            self.needs_rollback = True
573
+
574
+ # ##### Savepoint management #####
575
+
576
def _savepoint(self, sid: str) -> None:
    """Issue ``SAVEPOINT`` with the quoted ``sid``."""
    with self.cursor() as cur:
        cur.execute("SAVEPOINT " + quote_name(sid))
579
+
580
def _savepoint_rollback(self, sid: str) -> None:
    """Issue ``ROLLBACK TO SAVEPOINT`` with the quoted ``sid``."""
    with self.cursor() as cur:
        cur.execute("ROLLBACK TO SAVEPOINT " + quote_name(sid))
583
+
584
def _savepoint_commit(self, sid: str) -> None:
    """Issue ``RELEASE SAVEPOINT`` with the quoted ``sid``."""
    with self.cursor() as cur:
        cur.execute("RELEASE SAVEPOINT " + quote_name(sid))
587
+
588
+ # ##### Generic savepoint management methods #####
589
+
590
def savepoint(self) -> str | None:
    """
    Create a savepoint in the current transaction and return its id.

    The id combines the current thread ident with a per-connection counter.
    In autocommit mode there is no transaction to mark, so None is
    returned and the counter is left untouched.
    """
    if self.get_autocommit():
        return None

    thread_part = str(_thread.get_ident()).replace("-", "")
    self.savepoint_state += 1
    sid = f"s{thread_part}_x{self.savepoint_state}"
    self._savepoint(sid)
    return sid
608
+
609
def savepoint_rollback(self, sid: str) -> None:
    """
    Roll back to savepoint ``sid``; nothing happens in autocommit mode.

    On-commit callbacks registered while ``sid`` was active are discarded.
    """
    if self.get_autocommit():
        return

    self._savepoint_rollback(sid)

    survivors = []
    for sids, func, robust in self.run_on_commit:
        if sid not in sids:
            survivors.append((sids, func, robust))
    self.run_on_commit = survivors
624
+
625
def savepoint_commit(self, sid: str) -> None:
    """Release savepoint ``sid``; nothing happens in autocommit mode."""
    if not self.get_autocommit():
        self._savepoint_commit(sid)
633
+
634
def clean_savepoints(self) -> None:
    """Restart this connection's savepoint id counter from zero."""
    self.savepoint_state = 0
639
+
640
+ # ##### Generic transaction management methods #####
641
+
642
def get_autocommit(self) -> bool:
    """Return the tracked autocommit flag, connecting first if necessary."""
    self.ensure_connection()
    current = self.autocommit
    return current
646
+
647
def set_autocommit(self, autocommit: bool) -> None:
    """
    Turn autocommit on or off for the underlying connection.

    This is used internally by atomic() to manage transactions; don't call
    it directly — use atomic() instead. Turning autocommit off is recorded
    as a BEGIN by debug_transaction(). Turning it on runs any on-commit
    hooks that commit() scheduled.
    """
    self.validate_no_atomic_block()
    self.close_if_health_check_failed()
    self.ensure_connection()

    if not autocommit:
        with debug_transaction(self, "BEGIN"):
            self._set_autocommit(autocommit)
    else:
        self._set_autocommit(autocommit)
    self.autocommit = autocommit

    if not (autocommit and self.run_commit_hooks_on_set_autocommit_on):
        return
    self.run_and_clear_commit_hooks()
    self.run_commit_hooks_on_set_autocommit_on = False
668
+
669
def get_rollback(self) -> bool:
    """Return the "needs rollback" flag -- *advanced use* only, inside 'atomic'."""
    if self.in_atomic_block:
        return self.needs_rollback
    raise TransactionManagementError(
        "The rollback flag doesn't work outside of an 'atomic' block."
    )
676
+
677
def set_rollback(self, rollback: bool) -> None:
    """Set or clear the "needs rollback" flag -- *advanced use* only, inside 'atomic'."""
    if self.in_atomic_block:
        self.needs_rollback = rollback
        return
    raise TransactionManagementError(
        "The rollback flag doesn't work outside of an 'atomic' block."
    )
686
+
687
def validate_no_atomic_block(self) -> None:
    """Raise TransactionManagementError while an 'atomic' block is active."""
    if not self.in_atomic_block:
        return
    raise TransactionManagementError(
        "This is forbidden when an 'atomic' block is active."
    )
693
+
694
def validate_no_broken_transaction(self) -> None:
    """
    Refuse further queries once the current 'atomic' block has failed.

    Raises TransactionManagementError, chained to ``rollback_exc``, when
    ``needs_rollback`` is set.
    """
    if not self.needs_rollback:
        return
    raise TransactionManagementError(
        "An error occurred in the current transaction. You can't "
        "execute queries until the end of the 'atomic' block."
    ) from self.rollback_exc
700
+
701
+ # ##### Connection termination handling #####
702
+
703
def close_if_health_check_failed(self) -> None:
    """
    Run the pending health check, if any, and close the connection when
    it fails.

    A check is pending only when a connection exists, checks are enabled,
    and ``health_check_done`` is false. It is marked done afterwards.
    """
    check_pending = (
        self.connection is not None
        and self.health_check_enabled
        and not self.health_check_done
    )
    if not check_pending:
        return
    if not self.is_usable():
        self.close()
    self.health_check_done = True
715
+
716
def close_if_unusable_or_obsolete(self) -> None:
    """
    Close the connection if it may be broken or has outlived its max age.

    The checks run in this order:

    1. Autocommit was not restored.
    2. Errors occurred since the last commit/rollback and a ``SELECT 1``
       probe fails.
    3. The ``close_at`` deadline has passed.
    """
    if self.connection is None:
        return

    self.health_check_done = False

    # Autocommit still off means a transaction was not properly closed;
    # don't take chances.
    if not self.get_autocommit():
        self.close()
        return

    # Errors other than DataError/IntegrityError since the last
    # commit/rollback: keep the connection only if it still answers.
    if self.errors_occurred:
        if not self.is_usable():
            self.close()
            return
        self.errors_occurred = False
        self.health_check_done = True

    if self.close_at is not None and time.monotonic() >= self.close_at:
        self.close()
742
+
743
+ # ##### Miscellaneous #####
744
+
745
@cached_property
def wrap_database_errors(self) -> DatabaseErrorWrapper:
    """
    Per-connection DatabaseErrorWrapper, created on first access.

    It works as a context manager or decorator that re-raises
    backend-specific database exceptions as Plain's common wrappers.
    """
    error_wrapper = DatabaseErrorWrapper(self)
    return error_wrapper
752
+
753
def make_cursor(self, cursor: Any) -> utils.CursorWrapper:
    """Wrap a raw driver cursor in ``utils.CursorWrapper`` (no debug logging)."""
    wrapped = utils.CursorWrapper(cursor, self)
    return wrapped
756
+
757
@contextmanager
def temporary_connection(self) -> Generator[utils.CursorWrapper]:
    """
    Yield a cursor, closing the connection afterwards only if it was opened here.

    Useful outside the request-response cycle, where a connection opened
    just for this work should not be left dangling.

    Usage: ``with self.temporary_connection() as cursor: ...``
    """
    opened_here = self.connection is None
    try:
        with self.cursor() as cur:
            yield cur
    finally:
        # Leave pre-existing connections alone; only close what we opened.
        if opened_here:
            self.close()
773
+
774
def schema_editor(self, *args: Any, **kwargs: Any) -> DatabaseSchemaEditor:
    """Build a fresh ``DatabaseSchemaEditor`` bound to this connection."""
    editor = DatabaseSchemaEditor(self, *args, **kwargs)
    return editor
777
+
778
def runshell(self, parameters: list[str]) -> None:
    """
    Launch an interactive ``psql`` session for this connection's settings.

    Args:
        parameters: Extra command-line arguments forwarded to psql.

    Raises:
        subprocess.CalledProcessError: if psql exits with a non-zero status.
    """
    cmd_args, extra_env = _psql_settings_to_cmd_args_env(
        self.settings_dict, parameters
    )
    # Inherit the current environment only when extra variables are needed;
    # None tells subprocess to inherit os.environ unchanged.
    child_env = {**os.environ, **extra_env} if extra_env else None
    previous_handler = signal.getsignal(signal.SIGINT)
    try:
        # Ignore SIGINT here so Ctrl-C reaches psql (to abort queries)
        # without interrupting this process.
        signal.signal(signal.SIGINT, signal.SIG_IGN)
        subprocess.run(cmd_args, env=child_env, check=True)
    finally:
        signal.signal(signal.SIGINT, previous_handler)
790
+
791
def on_commit(self, func: Any, robust: bool = False) -> None:
    """
    Register ``func`` to run when the current transaction commits.

    Inside an atomic block the callback is queued on ``self.run_on_commit``
    together with a snapshot of the active savepoint ids. Outside an atomic
    block it is called immediately.

    Args:
        func: Zero-argument callable.
        robust: If True, an exception raised by ``func`` while running
            immediately is logged instead of propagated.

    Raises:
        TypeError: if ``func`` is not callable.
    """
    if not callable(func):
        raise TypeError("on_commit()'s callback must be a callable.")
    if self.in_atomic_block:
        # Transaction in progress; save for execution on commit.
        self.run_on_commit.append((set(self.savepoint_ids), func, robust))
        return

    # No transaction in progress; execute immediately.
    if not robust:
        func()
        return
    try:
        func()
    except Exception as e:
        # Not every callable has __qualname__ (functools.partial objects and
        # callable instances don't). Fall back to repr() so the logging call
        # itself cannot raise and defeat robust=True. The name is passed as a
        # %-argument so '%' characters in it are not treated as format codes.
        name = getattr(func, "__qualname__", repr(func))
        logger.error(
            "Error calling %s in on_commit() (%s).",
            name,
            e,
            exc_info=True,
        )
810
+
811
def run_and_clear_commit_hooks(self) -> None:
    """
    Run every queued on_commit() callback in registration order.

    Must be called outside an atomic block (enforced by
    ``validate_no_atomic_block()``). The queue is reset before any callback
    runs, so ``self.run_on_commit`` is empty afterwards even if a callback
    raises. Exceptions from callbacks registered with ``robust=True`` are
    logged. Any other callback's exception propagates, and the remaining
    callbacks are not run.
    """
    self.validate_no_atomic_block()
    pending = self.run_on_commit
    self.run_on_commit = []
    # `pending` is a local list nothing else appends to, so plain iteration
    # replaces the quadratic pop(0) loop without changing behavior.
    for _, func, robust in pending:
        if not robust:
            func()
            continue
        try:
            func()
        except Exception as e:
            # functools.partial objects and callable instances lack
            # __qualname__; fall back to repr() so logging cannot raise.
            name = getattr(func, "__qualname__", repr(func))
            logger.error(
                "Error calling %s in on_commit() during transaction (%s).",
                name,
                e,
                exc_info=True,
            )
829
+
830
@contextmanager
def execute_wrapper(self, wrapper: Any) -> Generator[None]:
    """
    Temporarily install ``wrapper`` around suitable query executions.

    The wrapper is pushed onto ``self.execute_wrappers`` for the duration of
    the ``with`` block and popped off again on exit, even on error.
    """
    wrappers = self.execute_wrappers
    wrappers.append(wrapper)
    try:
        yield
    finally:
        self.execute_wrappers.pop()
841
+
842
+ # ##### SQL generation methods that require connection state #####
843
+
844
def compose_sql(self, query: str, params: Any) -> str:
    """
    Render ``query`` with ``params`` inlined, using psycopg's ``mogrify``.

    A ``ClientCursor`` on the live connection does the formatting, so a
    connection must already be open.
    """
    assert self.connection is not None
    client_cursor = ClientCursor(self.connection)
    template = psycopg_sql.SQL(cast(LiteralString, query))
    return client_cursor.mogrify(template, params)
855
+
856
def last_executed_query(
    self,
    cursor: utils.CursorWrapper,
    sql: str,
    params: Any,
) -> str | None:
    """
    Best-effort rendering of ``sql`` with ``params`` substituted.

    ``cursor`` is accepted for interface compatibility but not used. Returns
    None if composing the query raises ``errors.DataError``.
    """
    rendered: str | None
    try:
        rendered = self.compose_sql(sql, params)
    except errors.DataError:
        rendered = None
    return rendered
870
+
871
def unification_cast_sql(self, output_field: Field) -> str:
    """
    Return SQL casting a UNION result column back to ``output_field``'s type.

    The returned string contains a ``%s`` placeholder for the expression
    being cast. It is just ``"%s"`` when no cast is needed.
    """
    # PostgreSQL resolves a union of 'unknown' inputs as 'text'
    # (https://www.postgresql.org/docs/current/typeconv-union-case.html).
    # These types don't implicitly cast back from text in the default
    # configuration, so they get an explicit CAST.
    needs_cast = output_field.get_internal_type() in (
        "GenericIPAddressField",
        "TimeField",
        "UUIDField",
    )
    if not needs_cast:
        return "%s"
    db_type = output_field.db_type()
    if not db_type:
        return "%s"
    # Drop any parenthesised modifier: varchar(255) -> varchar.
    base_type = db_type.partition("(")[0]
    return f"CAST(%s AS {base_type})"
894
+
895
+ # ##### Introspection methods #####
896
+
897
def table_names(
    self, cursor: CursorWrapper | None = None, include_views: bool = False
) -> list[str]:
    """
    Return the names of all tables in the database, sorted in Python.

    Only entries whose type is ``"t"`` are included unless
    ``include_views`` is True, in which case every entry from
    ``get_table_list()`` is returned. Sorting is done in Python rather than
    with ORDER BY to avoid database-specific collation differences. A cursor
    is opened for the lookup if none is supplied.
    """

    def collect(cur: CursorWrapper) -> list[str]:
        names = [
            info.name
            for info in self.get_table_list(cur)
            if include_views or info.type == "t"
        ]
        names.sort()
        return names

    if cursor is not None:
        return collect(cursor)
    with self.cursor() as new_cursor:
        return collect(new_cursor)
918
+
919
def get_table_list(self, cursor: CursorWrapper) -> Sequence[TableInfo]:
    """
    Return an unsorted list of ``TableInfo`` rows for visible relations.

    Each row is (name, type, comment). Type is ``'p'`` for partitions, ``'v'``
    for views and materialized views, and ``'t'`` otherwise. System schemas
    and names listed in ``self.ignored_tables`` are excluded.
    """
    query = """
        SELECT
            c.relname,
            CASE
                WHEN c.relispartition THEN 'p'
                WHEN c.relkind IN ('m', 'v') THEN 'v'
                ELSE 't'
            END,
            obj_description(c.oid, 'pg_class')
        FROM pg_catalog.pg_class c
        LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
        WHERE c.relkind IN ('f', 'm', 'p', 'r', 'v')
            AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
            AND pg_catalog.pg_table_is_visible(c.oid)
    """
    cursor.execute(query)
    rows = cursor.fetchall()
    return [TableInfo(*row) for row in rows if row[0] not in self.ignored_tables]
946
+
947
def plain_table_names(
    self, only_existing: bool = False, include_views: bool = True
) -> list[str]:
    """
    Return a sorted list of all table names that have associated Plain
    models and are in INSTALLED_PACKAGES.

    Each migratable model contributes its own ``db_table`` plus the join
    tables of its local many-to-many fields.

    Args:
        only_existing: If True, include only the tables that
            ``table_names()`` reports as present in the database.
        include_views: Forwarded to ``table_names()`` when ``only_existing``
            is True.
    """
    tables: set[str] = set()
    for model in get_migratable_models():
        tables.add(model.model_options.db_table)
        tables.update(
            f.m2m_db_table() for f in model._model_meta.local_many_to_many
        )
    if only_existing:
        tables &= set(self.table_names(include_views=include_views))
    # Sort so the result is deterministic: the iteration order of a set of
    # strings changes between interpreter runs (hash randomization).
    # This matches table_names(), which also sorts.
    return sorted(tables)
967
+
968
def get_sequences(
    self, cursor: CursorWrapper, table_name: str, table_fields: tuple[Any, ...] = ()
) -> list[dict[str, Any]]:
    """
    Return the sequences owned by columns of ``table_name``.

    Each item is ``{'name': <sequence_name>, 'table': <table_name>,
    'column': <column_name>}``. ``table_fields`` is accepted for interface
    compatibility and is not used.
    """
    query = """
        SELECT
            s.relname AS sequence_name,
            a.attname AS colname
        FROM
            pg_class s
            JOIN pg_depend d ON d.objid = s.oid
                AND d.classid = 'pg_class'::regclass
                AND d.refclassid = 'pg_class'::regclass
            JOIN pg_attribute a ON d.refobjid = a.attrelid
                AND d.refobjsubid = a.attnum
            JOIN pg_class tbl ON tbl.oid = d.refobjid
                AND tbl.relname = %s
                AND pg_catalog.pg_table_is_visible(tbl.oid)
        WHERE
            s.relkind = 'S';
    """
    cursor.execute(query, [table_name])
    # The SELECT list has exactly two columns: sequence name, column name.
    return [
        {"name": seq_name, "table": table_name, "column": col_name}
        for seq_name, col_name in cursor.fetchall()
    ]
999
+
1000
def get_constraints(
    self, cursor: CursorWrapper, table_name: str
) -> dict[str, dict[str, Any]]:
    """
    Retrieve any constraints or keys (unique, pk, fk, check, index) across
    one or more columns. Also retrieve the definition of expression-based
    indexes.

    Args:
        cursor: Cursor used to run the two catalog queries below.
        table_name: Relation name matched against ``pg_class.relname``. Only
            relations for which ``pg_catalog.pg_table_is_visible`` is true
            are considered.

    Returns:
        A dict keyed by constraint/index name. Every entry has "columns",
        "primary_key", "unique", "foreign_key", "check", "index",
        "definition" and "options". Entries coming from the index query also
        have "orders" and "type".
    """
    constraints: dict[str, dict[str, Any]] = {}
    # Loop over the key table, collecting things as constraints. The column
    # array must return column names in the same order in which they were
    # created.
    cursor.execute(
        """
        SELECT
            c.conname,
            array(
                SELECT attname
                FROM unnest(c.conkey) WITH ORDINALITY cols(colid, arridx)
                JOIN pg_attribute AS ca ON cols.colid = ca.attnum
                WHERE ca.attrelid = c.conrelid
                ORDER BY cols.arridx
            ),
            c.contype,
            (SELECT fkc.relname || '.' || fka.attname
            FROM pg_attribute AS fka
            JOIN pg_class AS fkc ON fka.attrelid = fkc.oid
            WHERE fka.attrelid = c.confrelid AND fka.attnum = c.confkey[1]),
            cl.reloptions
        FROM pg_constraint AS c
        JOIN pg_class AS cl ON c.conrelid = cl.oid
        WHERE cl.relname = %s AND pg_catalog.pg_table_is_visible(cl.oid)
        """,
        [table_name],
    )
    for constraint, columns, kind, used_cols, options in cursor.fetchall():
        constraints[constraint] = {
            "columns": columns,
            # contype codes checked here: 'p' primary key, 'u' unique,
            # 'f' foreign key, 'c' check.
            "primary_key": kind == "p",
            "unique": kind in ["p", "u"],
            # used_cols is "<referenced table>.<referenced column>" for only
            # the FIRST referenced column (confkey[1]); split at the first dot.
            # NOTE(review): a referenced table name containing '.' would be
            # split incorrectly here.
            "foreign_key": tuple(used_cols.split(".", 1)) if kind == "f" else None,
            "check": kind == "c",
            "index": False,
            "definition": None,
            "options": options,
        }
    # Now get indexes
    # One row per index. Columns and per-column orderings are aggregated in
    # key order (arridx). Bit 1 of indoption selects DESC; orderings are only
    # produced for the default access method (the bound %s), otherwise NULL.
    # exprdef holds pg_get_indexdef() only for indexes with expressions.
    cursor.execute(
        """
        SELECT
            indexname,
            array_agg(attname ORDER BY arridx),
            indisunique,
            indisprimary,
            array_agg(ordering ORDER BY arridx),
            amname,
            exprdef,
            s2.attoptions
        FROM (
            SELECT
                c2.relname as indexname, idx.*, attr.attname, am.amname,
                CASE
                    WHEN idx.indexprs IS NOT NULL THEN
                        pg_get_indexdef(idx.indexrelid)
                END AS exprdef,
                CASE am.amname
                    WHEN %s THEN
                        CASE (option & 1)
                            WHEN 1 THEN 'DESC' ELSE 'ASC'
                        END
                END as ordering,
                c2.reloptions as attoptions
            FROM (
                SELECT *
                FROM
                    pg_index i,
                    unnest(i.indkey, i.indoption)
                        WITH ORDINALITY koi(key, option, arridx)
            ) idx
            LEFT JOIN pg_class c ON idx.indrelid = c.oid
            LEFT JOIN pg_class c2 ON idx.indexrelid = c2.oid
            LEFT JOIN pg_am am ON c2.relam = am.oid
            LEFT JOIN
                pg_attribute attr ON attr.attrelid = c.oid AND attr.attnum = idx.key
            WHERE c.relname = %s AND pg_catalog.pg_table_is_visible(c.oid)
        ) s2
        GROUP BY indexname, indisunique, indisprimary, amname, exprdef, attoptions;
        """,
        [self.index_default_access_method, table_name],
    )
    for (
        index,
        columns,
        unique,
        primary,
        orders,
        type_,
        definition,
        options,
    ) in cursor.fetchall():
        # Skip names already recorded by the constraint query above.
        if index not in constraints:
            # A plain index (default access method, no reloptions) is
            # reported with Index.suffix instead of the access method name.
            basic_index = (
                type_ == self.index_default_access_method and options is None
            )
            constraints[index] = {
                # Key positions with no matching pg_attribute row (presumably
                # expression keys) aggregate to [None]; normalize that to [].
                "columns": columns if columns != [None] else [],
                "orders": orders if orders != [None] else [],
                "primary_key": primary,
                "unique": unique,
                "foreign_key": None,
                "check": False,
                "index": True,
                "type": Index.suffix if basic_index else type_,
                "definition": definition,
                "options": options,
            }
    return constraints
1117
+
1118
+ # ##### Test database creation methods (merged from DatabaseCreation) #####
1119
+
1120
+ def _log(self, msg: str) -> None:
1121
+ sys.stderr.write(msg + os.linesep)
1122
+
1123
def create_test_db(self, verbosity: int = 1, prefix: str = "") -> str:
    """
    Create and migrate a test database, returning its name.

    Any existing database with the same name is dropped and recreated
    without prompting. Afterwards ``settings.POSTGRES_DATABASE`` and this
    connection's ``settings_dict["DATABASE"]`` point at the new database.

    Args:
        verbosity: 1+ logs progress; 2+ also shows migration output.
        prefix: If provided, prepended to the database name to isolate it
            from other test databases.
    """
    from plain.postgres.cli.migrations import apply

    db_name = self._get_test_db_name(prefix)
    if verbosity >= 1:
        self._log(f"Creating test database '{db_name}'...")
    self._create_test_db(
        test_database_name=db_name, verbosity=verbosity, autoclobber=True
    )

    # Repoint the global settings and this connection at the new database
    # before migrating it.
    self.close()
    settings.POSTGRES_DATABASE = db_name
    self.settings_dict["DATABASE"] = db_name

    show_migration_output = verbosity >= 2
    apply.callback(
        package_label=None,
        migration_name=None,
        fake=False,
        plan=False,
        check_unapplied=False,
        backup=False,
        no_input=True,
        atomic_batch=False,  # A fresh test database needs no atomic batch.
        quiet=not show_migration_output,
    )

    # Connect once so the test database is initialized.
    self.ensure_connection()
    return db_name
1162
+
1163
+ def _get_test_db_name(self, prefix: str = "") -> str:
1164
+ """
1165
+ Internal implementation - return the name of the test DB that will be
1166
+ created. Only useful when called from create_test_db() and
1167
+ _create_test_db() and when no external munging is done with the 'DATABASE'
1168
+ settings.
1169
+
1170
+ If prefix is provided, it will be prepended to the database name.
1171
+ """
1172
+ # Determine the base name: explicit TEST.DATABASE overrides base DATABASE.
1173
+ base_name = (
1174
+ self.settings_dict["TEST"]["DATABASE"] or self.settings_dict["DATABASE"]
1175
+ )
1176
+ if prefix:
1177
+ return f"{prefix}_{base_name}"
1178
+ if self.settings_dict["TEST"]["DATABASE"]:
1179
+ return self.settings_dict["TEST"]["DATABASE"]
1180
+ name = self.settings_dict["DATABASE"]
1181
+ if name is None:
1182
+ raise ValueError("POSTGRES_DATABASE must be set")
1183
+ return TEST_DATABASE_PREFIX + name
1184
+
1185
+ def _get_database_create_suffix(
1186
+ self, encoding: str | None = None, template: str | None = None
1187
+ ) -> str:
1188
+ """Return PostgreSQL-specific CREATE DATABASE suffix."""
1189
+ suffix = ""
1190
+ if encoding:
1191
+ suffix += f" ENCODING '{encoding}'"
1192
+ if template:
1193
+ suffix += f" TEMPLATE {quote_name(template)}"
1194
+ return suffix and "WITH" + suffix
1195
+
1196
+ def _execute_create_test_db(self, cursor: Any, parameters: dict[str, str]) -> None:
1197
+ try:
1198
+ cursor.execute("CREATE DATABASE {dbname} {suffix}".format(**parameters))
1199
+ except Exception as e:
1200
+ cause = e.__cause__
1201
+ if cause and not isinstance(cause, errors.DuplicateDatabase):
1202
+ # All errors except "database already exists" cancel tests.
1203
+ self._log(f"Got an error creating the test database: {e}")
1204
+ sys.exit(2)
1205
+ else:
1206
+ raise
1207
+
1208
def _create_test_db(
    self, *, test_database_name: str, verbosity: int, autoclobber: bool
) -> str:
    """
    Internal implementation - create the test database itself.

    If the first CREATE DATABASE fails, the database is dropped and
    recreated. With ``autoclobber`` this happens without asking; otherwise
    the user must type "yes". Declining exits with status 1. A failure while
    recreating exits with status 2. Returns ``test_database_name``.
    """
    create_params = {
        "dbname": quote_name(test_database_name),
        "suffix": self.sql_table_creation_suffix(),
    }
    with self._nodb_cursor() as cursor:
        try:
            self._execute_create_test_db(cursor, create_params)
        except Exception as e:
            self._log(f"Got an error creating the test database: {e}")
            if autoclobber:
                confirm = "yes"
            else:
                confirm = input(
                    "Type 'yes' if you would like to try deleting the test "
                    f"database '{test_database_name}', or 'no' to cancel: "
                )
            if confirm != "yes":
                self._log("Tests cancelled.")
                sys.exit(1)
            try:
                if verbosity >= 1:
                    self._log(
                        f"Destroying old test database '{test_database_name}'..."
                    )
                cursor.execute("DROP DATABASE {dbname}".format(**create_params))
                self._execute_create_test_db(cursor, create_params)
            except Exception as exc:
                self._log(f"Got an error recreating the test database: {exc}")
                sys.exit(2)

    return test_database_name
1247
+
1248
def destroy_test_db(
    self, old_database_name: str | None = None, verbosity: int = 1
) -> None:
    """
    Drop the current test database, optionally restoring the original name.

    Closes this connection, then drops the database named by
    ``settings_dict["DATABASE"]`` without prompting. If
    ``old_database_name`` is given, ``settings.POSTGRES_DATABASE`` and
    ``settings_dict["DATABASE"]`` are pointed back at it afterwards.

    Raises:
        ValueError: if ``settings_dict["DATABASE"]`` is None.
    """
    self.close()

    doomed_name = self.settings_dict["DATABASE"]
    if doomed_name is None:
        raise ValueError("Test POSTGRES_DATABASE must be set")

    if verbosity >= 1:
        self._log(f"Destroying test database '{doomed_name}'...")
    self._destroy_test_db(doomed_name, verbosity)

    if old_database_name is None:
        return
    # Restore the original database name
    settings.POSTGRES_DATABASE = old_database_name
    self.settings_dict["DATABASE"] = old_database_name
1269
+
1270
def _destroy_test_db(self, test_database_name: str, verbosity: int) -> None:
    """
    Internal implementation - issue ``DROP DATABASE`` for the test database.

    The statement runs on a cursor from ``self._nodb_cursor()`` rather than
    a connection to the test database itself, because PostgreSQL refuses to
    drop the database a session is connected to. ``verbosity`` is accepted
    but not used.
    """
    with self._nodb_cursor() as nodb_cursor:
        statement = f"DROP DATABASE {quote_name(test_database_name)}"
        nodb_cursor.execute(statement)
1280
+
1281
def sql_table_creation_suffix(self) -> str:
    """
    Return the suffix for the test database's CREATE DATABASE statement.

    It is built from the TEST settings' optional "CHARSET" (encoding) and
    "TEMPLATE" entries via ``_get_database_create_suffix``.
    """
    test_options = self.settings_dict["TEST"]
    charset = test_options.get("CHARSET")
    template_name = test_options.get("TEMPLATE")
    return self._get_database_create_suffix(encoding=charset, template=template_name)
1290
+
1291
+
1292
class CursorMixin:
    """
    Mixin that adds a DB-API style ``callproc()`` to psycopg cursor classes.
    """

    def callproc(
        self, name: str | psycopg_sql.Identifier, args: list[Any] | None = None
    ) -> list[Any] | None:
        """
        Call database function ``name`` via ``SELECT * FROM name(arg, ...)``.

        A plain string ``name`` is wrapped in ``psycopg_sql.Identifier`` so
        it is quoted. Each item of ``args`` is embedded as a
        ``psycopg_sql.Literal``. Results stay on the cursor; ``args`` is
        returned unchanged.
        """
        func_ident = (
            name
            if isinstance(name, psycopg_sql.Identifier)
            else psycopg_sql.Identifier(name)
        )
        pieces: list[psycopg_sql.Composable] = [
            psycopg_sql.SQL("SELECT * FROM "),
            func_ident,
            psycopg_sql.SQL("("),
        ]
        if args:
            literals = [psycopg_sql.Literal(value) for value in args]
            pieces.append(psycopg_sql.SQL(",").join(literals))
        pieces.append(psycopg_sql.SQL(")"))
        self.execute(psycopg_sql.Composed(pieces))  # type: ignore[attr-defined]
        return args
1318
+
1319
+
1320
class ServerBindingCursor(CursorMixin, Database.Cursor):
    """
    ``Database.Cursor`` extended with ``CursorMixin.callproc()``.

    NOTE(review): presumably ``Database`` is the psycopg module, in which
    case this cursor uses server-side parameter binding; confirm at the
    import site.
    """
1322
+
1323
+
1324
class Cursor(CursorMixin, Database.ClientCursor):
    """
    ``Database.ClientCursor`` extended with ``CursorMixin.callproc()``.

    NOTE(review): presumably ``Database`` is the psycopg module, in which
    case parameters are bound client-side; confirm at the import site.
    """
1326
+
1327
+
1328
class CursorDebugWrapper(BaseCursorDebugWrapper):
    """Debug cursor wrapper that also routes ``copy()`` through ``debug_sql()``."""

    def copy(self, statement: Any) -> Any:
        """
        Return ``self.cursor.copy(statement)``, obtained inside
        ``self.debug_sql(statement)``.

        NOTE(review): if the underlying ``copy()`` returns a context manager
        that only does its work when entered, ``debug_sql`` here covers
        obtaining it rather than the data transfer; confirm that is intended.
        """
        with self.debug_sql(statement):
            return self.cursor.copy(statement)