plain.postgres 0.84.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plain/postgres/CHANGELOG.md +1028 -0
- plain/postgres/README.md +925 -0
- plain/postgres/__init__.py +120 -0
- plain/postgres/agents/.claude/rules/plain-postgres.md +78 -0
- plain/postgres/aggregates.py +236 -0
- plain/postgres/backups/__init__.py +0 -0
- plain/postgres/backups/cli.py +148 -0
- plain/postgres/backups/clients.py +94 -0
- plain/postgres/backups/core.py +172 -0
- plain/postgres/base.py +1415 -0
- plain/postgres/cli/__init__.py +3 -0
- plain/postgres/cli/db.py +142 -0
- plain/postgres/cli/migrations.py +1085 -0
- plain/postgres/config.py +18 -0
- plain/postgres/connection.py +1331 -0
- plain/postgres/connections.py +77 -0
- plain/postgres/constants.py +13 -0
- plain/postgres/constraints.py +495 -0
- plain/postgres/database_url.py +94 -0
- plain/postgres/db.py +59 -0
- plain/postgres/default_settings.py +38 -0
- plain/postgres/deletion.py +475 -0
- plain/postgres/dialect.py +640 -0
- plain/postgres/entrypoints.py +4 -0
- plain/postgres/enums.py +103 -0
- plain/postgres/exceptions.py +217 -0
- plain/postgres/expressions.py +1912 -0
- plain/postgres/fields/__init__.py +2118 -0
- plain/postgres/fields/encrypted.py +354 -0
- plain/postgres/fields/json.py +413 -0
- plain/postgres/fields/mixins.py +30 -0
- plain/postgres/fields/related.py +1192 -0
- plain/postgres/fields/related_descriptors.py +290 -0
- plain/postgres/fields/related_lookups.py +223 -0
- plain/postgres/fields/related_managers.py +661 -0
- plain/postgres/fields/reverse_descriptors.py +229 -0
- plain/postgres/fields/reverse_related.py +328 -0
- plain/postgres/fields/timezones.py +143 -0
- plain/postgres/forms.py +773 -0
- plain/postgres/functions/__init__.py +189 -0
- plain/postgres/functions/comparison.py +127 -0
- plain/postgres/functions/datetime.py +454 -0
- plain/postgres/functions/math.py +140 -0
- plain/postgres/functions/mixins.py +59 -0
- plain/postgres/functions/text.py +282 -0
- plain/postgres/functions/window.py +125 -0
- plain/postgres/indexes.py +286 -0
- plain/postgres/lookups.py +758 -0
- plain/postgres/meta.py +584 -0
- plain/postgres/migrations/__init__.py +53 -0
- plain/postgres/migrations/autodetector.py +1379 -0
- plain/postgres/migrations/exceptions.py +54 -0
- plain/postgres/migrations/executor.py +188 -0
- plain/postgres/migrations/graph.py +364 -0
- plain/postgres/migrations/loader.py +377 -0
- plain/postgres/migrations/migration.py +180 -0
- plain/postgres/migrations/operations/__init__.py +34 -0
- plain/postgres/migrations/operations/base.py +139 -0
- plain/postgres/migrations/operations/fields.py +373 -0
- plain/postgres/migrations/operations/models.py +798 -0
- plain/postgres/migrations/operations/special.py +184 -0
- plain/postgres/migrations/optimizer.py +74 -0
- plain/postgres/migrations/questioner.py +340 -0
- plain/postgres/migrations/recorder.py +119 -0
- plain/postgres/migrations/serializer.py +378 -0
- plain/postgres/migrations/state.py +882 -0
- plain/postgres/migrations/utils.py +147 -0
- plain/postgres/migrations/writer.py +302 -0
- plain/postgres/options.py +207 -0
- plain/postgres/otel.py +231 -0
- plain/postgres/preflight.py +336 -0
- plain/postgres/query.py +2242 -0
- plain/postgres/query_utils.py +456 -0
- plain/postgres/registry.py +217 -0
- plain/postgres/schema.py +1885 -0
- plain/postgres/sql/__init__.py +40 -0
- plain/postgres/sql/compiler.py +1869 -0
- plain/postgres/sql/constants.py +22 -0
- plain/postgres/sql/datastructures.py +222 -0
- plain/postgres/sql/query.py +2947 -0
- plain/postgres/sql/where.py +374 -0
- plain/postgres/test/__init__.py +0 -0
- plain/postgres/test/pytest.py +117 -0
- plain/postgres/test/utils.py +18 -0
- plain/postgres/transaction.py +222 -0
- plain/postgres/types.py +92 -0
- plain/postgres/types.pyi +751 -0
- plain/postgres/utils.py +345 -0
- plain_postgres-0.84.0.dist-info/METADATA +937 -0
- plain_postgres-0.84.0.dist-info/RECORD +93 -0
- plain_postgres-0.84.0.dist-info/WHEEL +4 -0
- plain_postgres-0.84.0.dist-info/entry_points.txt +5 -0
- plain_postgres-0.84.0.dist-info/licenses/LICENSE +61 -0
|
@@ -0,0 +1,1331 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import _thread
|
|
4
|
+
import datetime
|
|
5
|
+
import logging
|
|
6
|
+
import os
|
|
7
|
+
import signal
|
|
8
|
+
import subprocess
|
|
9
|
+
import sys
|
|
10
|
+
import time
|
|
11
|
+
import warnings
|
|
12
|
+
import zoneinfo
|
|
13
|
+
from collections import deque
|
|
14
|
+
from collections.abc import Generator, Sequence
|
|
15
|
+
from contextlib import contextmanager
|
|
16
|
+
from functools import cached_property, lru_cache
|
|
17
|
+
from typing import TYPE_CHECKING, Any, LiteralString, NamedTuple, cast
|
|
18
|
+
|
|
19
|
+
import psycopg as Database
|
|
20
|
+
from psycopg import ClientCursor, IsolationLevel, adapt, adapters, errors
|
|
21
|
+
from psycopg import sql as psycopg_sql
|
|
22
|
+
from psycopg.abc import Buffer, PyFormat
|
|
23
|
+
from psycopg.postgres import types as pg_types
|
|
24
|
+
from psycopg.pq import Format
|
|
25
|
+
from psycopg.types.datetime import TimestamptzLoader
|
|
26
|
+
from psycopg.types.range import BaseRangeDumper, Range, RangeDumper
|
|
27
|
+
from psycopg.types.string import TextLoader
|
|
28
|
+
|
|
29
|
+
from plain.exceptions import ImproperlyConfigured
|
|
30
|
+
from plain.postgres import utils
|
|
31
|
+
from plain.postgres.db import (
|
|
32
|
+
DatabaseError,
|
|
33
|
+
DatabaseErrorWrapper,
|
|
34
|
+
)
|
|
35
|
+
from plain.postgres.dialect import MAX_NAME_LENGTH, quote_name
|
|
36
|
+
from plain.postgres.indexes import Index
|
|
37
|
+
from plain.postgres.schema import DatabaseSchemaEditor
|
|
38
|
+
from plain.postgres.transaction import TransactionManagementError
|
|
39
|
+
from plain.postgres.utils import CursorDebugWrapper as BaseCursorDebugWrapper
|
|
40
|
+
from plain.postgres.utils import CursorWrapper, debug_transaction
|
|
41
|
+
from plain.runtime import settings
|
|
42
|
+
|
|
43
|
+
if TYPE_CHECKING:
|
|
44
|
+
from psycopg import Connection as PsycopgConnection
|
|
45
|
+
|
|
46
|
+
from plain.postgres.connections import DatabaseConfig
|
|
47
|
+
from plain.postgres.fields import Field
|
|
48
|
+
|
|
49
|
+
# Logger for this module (named explicitly rather than via __name__).
logger = logging.getLogger("plain.postgres.connection")

# Prefix put in front of the configured database name to form the name of
# the test database.
TEST_DATABASE_PREFIX = "test_"
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def get_migratable_models() -> Generator[Any]:
    """
    Return a lazy iterator over every model included in migrations.

    Models are grouped by package, following the order in which the package
    registry reports its package configs.
    """
    from plain.packages import packages_registry
    from plain.postgres import models_registry

    labels = (
        config.package_label for config in packages_registry.get_package_configs()
    )
    return (
        model
        for label in labels
        for model in models_registry.get_models(package_label=label)
    )
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class TableInfo(NamedTuple):
    """
    One entry of table metadata, as returned by
    DatabaseConnection.get_table_list().
    """

    # Table name.
    name: str
    # Relation type; presumably a short code distinguishing tables from views
    # and similar relations -- TODO confirm against get_table_list().
    type: str
    # Table comment; presumably None when the table has no comment.
    comment: str | None
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
# Type OIDs, looked up from psycopg's type registries instead of being
# hard-coded. TIMESTAMPTZ_OID is used by DatabaseConnection.create_cursor()
# to find the connection's timestamptz loader; TSRANGE_OID / TSTZRANGE_OID
# are used by PlainRangeDumper to retarget tsrange values to tstzrange.
TIMESTAMPTZ_OID = adapters.types["timestamptz"].oid
TSRANGE_OID = pg_types["tsrange"].oid
TSTZRANGE_OID = pg_types["tstzrange"].oid
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class BaseTzLoader(TimestamptzLoader):
    """
    Timestamptz loader that stamps each loaded value with a fixed tzinfo.

    Subclasses set ``timezone``. When it is None, the loaded datetime is
    returned naive (its tzinfo is stripped).
    """

    timezone: datetime.tzinfo | None = None

    def load(self, data: Buffer) -> datetime.datetime:
        loaded = super().load(data)
        # replace() swaps the tzinfo without converting the wall-clock value.
        return loaded.replace(tzinfo=self.timezone)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def register_tzloader(tz: datetime.tzinfo | None, context: Any) -> None:
    """
    Install, on ``context``, a timestamptz loader bound to ``tz``.

    ``context`` must expose an ``adapters`` map; in this module it is either
    an AdaptersMap template or a cursor.
    """

    class SpecificTzLoader(BaseTzLoader):
        timezone = tz

    adapters_map = context.adapters
    adapters_map.register_loader("timestamptz", SpecificTzLoader)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
class PlainRangeDumper(RangeDumper):
    """
    Range dumper used by Plain.

    When psycopg upgrades to a specialized dumper typed as ``tsrange``, the
    dumper's OID is rewritten to ``tstzrange``.
    """

    def upgrade(self, obj: Range[Any], format: PyFormat) -> BaseRangeDumper:
        upgraded = super().upgrade(obj, format)
        specialized = upgraded is not self
        if specialized and upgraded.oid == TSRANGE_OID:
            upgraded.oid = TSTZRANGE_OID
        return upgraded
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
@lru_cache
def get_adapters_template(timezone: datetime.tzinfo | None) -> adapt.AdaptersMap:
    """
    Build (and cache per ``timezone``) the adapters map given to connections.

    The map is derived from psycopg's global adapters. Compared with them it
    loads jsonb, inet and cidr as plain text, dumps ``Range`` values through
    PlainRangeDumper, and loads timestamptz values in ``timezone``.
    """
    template = adapt.AdaptersMap(adapters)
    # jsonb: no-op JSON loader to avoid psycopg3 round trips.
    # inet/cidr: treated as text.
    for type_name in ("jsonb", "inet", "cidr"):
        template.register_loader(type_name, TextLoader)
    template.register_dumper(Range, PlainRangeDumper)
    register_tzloader(timezone, template)
    return template
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def _psql_settings_to_cmd_args_env(
|
|
128
|
+
settings_dict: DatabaseConfig, parameters: list[str]
|
|
129
|
+
) -> tuple[list[str], dict[str, str] | None]:
|
|
130
|
+
"""Build psql command-line arguments from database settings."""
|
|
131
|
+
args = ["psql"]
|
|
132
|
+
options = settings_dict.get("OPTIONS", {})
|
|
133
|
+
|
|
134
|
+
if user := settings_dict.get("USER"):
|
|
135
|
+
args += ["-U", user]
|
|
136
|
+
if host := settings_dict.get("HOST"):
|
|
137
|
+
args += ["-h", host]
|
|
138
|
+
if port := settings_dict.get("PORT"):
|
|
139
|
+
args += ["-p", str(port)]
|
|
140
|
+
args.extend(parameters)
|
|
141
|
+
args += [settings_dict.get("DATABASE") or "postgres"]
|
|
142
|
+
|
|
143
|
+
env: dict[str, str] = {}
|
|
144
|
+
if password := settings_dict.get("PASSWORD"):
|
|
145
|
+
env["PGPASSWORD"] = str(password)
|
|
146
|
+
|
|
147
|
+
# Map OPTIONS keys to their corresponding environment variables.
|
|
148
|
+
option_env_vars = {
|
|
149
|
+
"passfile": "PGPASSFILE",
|
|
150
|
+
"sslmode": "PGSSLMODE",
|
|
151
|
+
"sslrootcert": "PGSSLROOTCERT",
|
|
152
|
+
"sslcert": "PGSSLCERT",
|
|
153
|
+
"sslkey": "PGSSLKEY",
|
|
154
|
+
}
|
|
155
|
+
for option_key, env_var in option_env_vars.items():
|
|
156
|
+
if value := options.get(option_key):
|
|
157
|
+
env[env_var] = str(value)
|
|
158
|
+
|
|
159
|
+
return args, (env or None)
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
class DatabaseConnection:
|
|
163
|
+
"""
|
|
164
|
+
PostgreSQL database connection.
|
|
165
|
+
|
|
166
|
+
This is the only database backend supported by Plain.
|
|
167
|
+
"""
|
|
168
|
+
|
|
169
|
+
queries_limit: int = 9000
|
|
170
|
+
executable_name: str = "psql"
|
|
171
|
+
|
|
172
|
+
index_default_access_method = "btree"
|
|
173
|
+
ignored_tables: list[str] = []
|
|
174
|
+
|
|
175
|
+
def __init__(self, settings_dict: DatabaseConfig):
|
|
176
|
+
# Connection related attributes.
|
|
177
|
+
# The underlying database connection (from the database library, not a wrapper).
|
|
178
|
+
self.connection: PsycopgConnection[Any] | None = None
|
|
179
|
+
# `settings_dict` should be a dictionary containing keys such as
|
|
180
|
+
# DATABASE, USER, etc. It's called `settings_dict` instead of `settings`
|
|
181
|
+
# to disambiguate it from Plain settings modules.
|
|
182
|
+
self.settings_dict: DatabaseConfig = settings_dict
|
|
183
|
+
# Query logging in debug mode or when explicitly enabled.
|
|
184
|
+
self.queries_log: deque[dict[str, Any]] = deque(maxlen=self.queries_limit)
|
|
185
|
+
self.force_debug_cursor: bool = False
|
|
186
|
+
|
|
187
|
+
# Transaction related attributes.
|
|
188
|
+
# Tracks if the connection is in autocommit mode. Per PEP 249, by
|
|
189
|
+
# default, it isn't.
|
|
190
|
+
self.autocommit: bool = False
|
|
191
|
+
# Tracks if the connection is in a transaction managed by 'atomic'.
|
|
192
|
+
self.in_atomic_block: bool = False
|
|
193
|
+
# Increment to generate unique savepoint ids.
|
|
194
|
+
self.savepoint_state: int = 0
|
|
195
|
+
# List of savepoints created by 'atomic'.
|
|
196
|
+
self.savepoint_ids: list[str | None] = []
|
|
197
|
+
# Stack of active 'atomic' blocks.
|
|
198
|
+
self.atomic_blocks: list[Any] = []
|
|
199
|
+
# Tracks if the transaction should be rolled back to the next
|
|
200
|
+
# available savepoint because of an exception in an inner block.
|
|
201
|
+
self.needs_rollback: bool = False
|
|
202
|
+
self.rollback_exc: Exception | None = None
|
|
203
|
+
|
|
204
|
+
# Connection termination related attributes.
|
|
205
|
+
self.close_at: float | None = None
|
|
206
|
+
self.closed_in_transaction: bool = False
|
|
207
|
+
self.errors_occurred: bool = False
|
|
208
|
+
self.health_check_enabled: bool = False
|
|
209
|
+
self.health_check_done: bool = False
|
|
210
|
+
|
|
211
|
+
# A list of no-argument functions to run when the transaction commits.
|
|
212
|
+
# Each entry is an (sids, func, robust) tuple, where sids is a set of
|
|
213
|
+
# the active savepoint IDs when this function was registered and robust
|
|
214
|
+
# specifies whether it's allowed for the function to fail.
|
|
215
|
+
self.run_on_commit: list[tuple[set[str | None], Any, bool]] = []
|
|
216
|
+
|
|
217
|
+
# Should we run the on-commit hooks the next time set_autocommit(True)
|
|
218
|
+
# is called?
|
|
219
|
+
self.run_commit_hooks_on_set_autocommit_on: bool = False
|
|
220
|
+
|
|
221
|
+
# A stack of wrappers to be invoked around execute()/executemany()
|
|
222
|
+
# calls. Each entry is a function taking five arguments: execute, sql,
|
|
223
|
+
# params, many, and context. It's the function's responsibility to
|
|
224
|
+
# call execute(sql, params, many, context).
|
|
225
|
+
self.execute_wrappers: list[Any] = []
|
|
226
|
+
|
|
227
|
+
def __repr__(self) -> str:
|
|
228
|
+
return f"<{self.__class__.__qualname__} vendor='postgresql'>"
|
|
229
|
+
|
|
230
|
+
@cached_property
|
|
231
|
+
def timezone(self) -> datetime.tzinfo:
|
|
232
|
+
"""
|
|
233
|
+
Return a tzinfo of the database connection time zone.
|
|
234
|
+
|
|
235
|
+
When a datetime is read from the database, it is returned in this time
|
|
236
|
+
zone. Since PostgreSQL supports time zones, it doesn't matter which
|
|
237
|
+
time zone Plain uses, as long as aware datetimes are used everywhere.
|
|
238
|
+
Other users connecting to the database can choose their own time zone.
|
|
239
|
+
"""
|
|
240
|
+
if self.settings_dict["TIME_ZONE"] is None:
|
|
241
|
+
return datetime.UTC
|
|
242
|
+
return zoneinfo.ZoneInfo(self.settings_dict["TIME_ZONE"])
|
|
243
|
+
|
|
244
|
+
@cached_property
def timezone_name(self) -> str:
    """Name of the connection's time zone; "UTC" when TIME_ZONE is unset."""
    configured = self.settings_dict["TIME_ZONE"]
    return "UTC" if configured is None else configured
|
|
252
|
+
|
|
253
|
+
@property
def queries_logged(self) -> bool:
    """Whether executed queries are recorded: forced on, or DEBUG enabled."""
    if self.force_debug_cursor:
        return self.force_debug_cursor
    return settings.DEBUG
|
|
256
|
+
|
|
257
|
+
@property
def queries(self) -> list[dict[str, Any]]:
    """
    Return a list copy of the logged queries.

    The log is a bounded deque, so once it is full older entries have been
    discarded; a warning is emitted in that case.
    """
    log = self.queries_log
    limit = log.maxlen
    if len(log) == limit:
        warnings.warn(
            f"Limit for query logging exceeded, only the last {limit} queries "
            "will be returned."
        )
    return list(log)
|
|
265
|
+
|
|
266
|
+
# ##### Connection and cursor methods #####
|
|
267
|
+
|
|
268
|
+
def get_connection_params(self) -> dict[str, Any]:
    """
    Build the keyword arguments passed to get_new_connection().

    A DATABASE of None connects to the ``postgres`` database. OPTIONS are
    passed through, except the Plain-only keys ``assume_role``,
    ``isolation_level`` and ``server_side_binding``. Prepared statements stay
    disabled unless OPTIONS sets ``prepare_threshold``.

    Raises:
        ImproperlyConfigured: if DATABASE is an empty string, or is longer
            than PostgreSQL's MAX_NAME_LENGTH.
    """
    settings_dict = self.settings_dict
    options = settings_dict.get("OPTIONS", {})
    db_name = settings_dict.get("DATABASE")

    if db_name == "":
        raise ImproperlyConfigured(
            "PostgreSQL database is not configured. "
            "Set DATABASE_URL or the POSTGRES_DATABASE setting."
        )
    name_length = len(db_name or "")
    if name_length > MAX_NAME_LENGTH:
        raise ImproperlyConfigured(
            f"The database name '{db_name}' ({name_length} characters) is longer "
            f"than PostgreSQL's limit of {MAX_NAME_LENGTH} characters. Supply a "
            "shorter POSTGRES_DATABASE setting."
        )

    # None means "connect to the default 'postgres' database".
    conn_params: dict[str, Any] = {
        "dbname": "postgres" if db_name is None else db_name,
        **options,
    }
    # Consumed by Plain itself; not psycopg connect() arguments.
    for plain_only_key in ("assume_role", "isolation_level", "server_side_binding"):
        conn_params.pop(plain_only_key, None)

    for setting_key, param_name in (
        ("USER", "user"),
        ("PASSWORD", "password"),
        ("HOST", "host"),
        ("PORT", "port"),
    ):
        if settings_dict[setting_key]:
            conn_params[param_name] = settings_dict[setting_key]

    conn_params["context"] = get_adapters_template(self.timezone)
    # Prepared statements are off by default so connection poolers keep
    # working; OPTIONS may re-enable them via prepare_threshold.
    conn_params["prepare_threshold"] = conn_params.pop("prepare_threshold", None)
    return conn_params
|
|
313
|
+
|
|
314
|
+
def get_new_connection(self, conn_params: dict[str, Any]) -> PsycopgConnection[Any]:
    """
    Open a new psycopg connection using ``conn_params``.

    Side effects:
        - ``self.isolation_level`` is set. It is the OPTIONS
          ``isolation_level`` value when given, and that value is also
          applied to the new connection. Otherwise it is READ_COMMITTED;
          this is assumed, not read from the server.
        - ``connection.cursor_factory`` is ServerBindingCursor when OPTIONS
          ``server_side_binding`` is True, else Cursor.

    Raises:
        ImproperlyConfigured: if OPTIONS ``isolation_level`` is not a valid
            psycopg.IsolationLevel value (chained to the ValueError).
    """
    options = self.settings_dict.get("OPTIONS", {})
    # The isolation level is resolved before connecting, so an invalid
    # setting fails without opening a connection. It is applied right after
    # connecting, before the caller configures autocommit.
    set_isolation_level = False
    try:
        isolation_level_value = options["isolation_level"]
    except KeyError:
        self.isolation_level = IsolationLevel.READ_COMMITTED
    else:
        try:
            self.isolation_level = IsolationLevel(isolation_level_value)
        except ValueError as exc:
            raise ImproperlyConfigured(
                f"Invalid transaction isolation level {isolation_level_value} "
                "specified. Use one of the psycopg.IsolationLevel values."
            ) from exc
        set_isolation_level = True

    connection = Database.connect(**conn_params)
    if set_isolation_level:
        connection.isolation_level = self.isolation_level
    # Use a server-side binding cursor if requested, otherwise the standard one.
    connection.cursor_factory = (
        ServerBindingCursor
        if options.get("server_side_binding") is True
        else Cursor
    )
    return connection
|
|
347
|
+
|
|
348
|
+
def ensure_timezone(self) -> bool:
    """
    Make the session's TimeZone match ``self.timezone_name``.

    Returns True when a ``set_config`` was issued. Returns False when there
    is no connection, no target name, or the zone already matches.
    """
    conn = self.connection
    if conn is None:
        return False
    current = conn.info.parameter_status("TimeZone")
    wanted = self.timezone_name
    if not wanted or current == wanted:
        return False
    conn.execute("SELECT set_config('TimeZone', %s, false)", [wanted])
    return True
|
|
363
|
+
|
|
364
|
+
def ensure_role(self) -> bool:
    """
    Switch the session to OPTIONS ``assume_role`` with ``SET ROLE``.

    Returns True when the role was set. Returns False when there is no
    connection or no role is configured.
    """
    conn = self.connection
    if conn is None:
        return False
    role = self.settings_dict.get("OPTIONS", {}).get("assume_role")
    if not role:
        return False
    statement = self.compose_sql("SET ROLE %s", [role])
    conn.execute(statement)  # type: ignore[arg-type]
    return True
|
|
372
|
+
|
|
373
|
+
def init_connection_state(self) -> None:
    """Apply per-session settings to a freshly opened connection."""
    # Time zone first, then the role. Assuming a role helps when the login
    # credential is not the role that owns the database objects, as with
    # temporary or ephemeral credentials.
    for configure in (self.ensure_timezone, self.ensure_role):
        configure()
|
|
380
|
+
|
|
381
|
+
def create_cursor(self) -> Any:
    """
    Create a cursor on the already-open connection.

    If the connection's timestamptz loader is not bound to
    ``self.timezone``, a timezone-specific loader is registered on the
    cursor only. This avoids copying the connection's adapter map.
    """
    assert self.connection is not None
    new_cursor = self.connection.cursor()
    conn_loader = self.connection.adapters.get_loader(TIMESTAMPTZ_OID, Format.TEXT)
    needs_own_loader = self.timezone != conn_loader.timezone  # type: ignore[union-attr]
    if needs_own_loader:
        register_tzloader(self.timezone, new_cursor)
    return new_cursor
|
|
391
|
+
|
|
392
|
+
def _set_autocommit(self, autocommit: bool) -> None:
|
|
393
|
+
"""Backend-specific implementation to enable or disable autocommit."""
|
|
394
|
+
assert self.connection is not None
|
|
395
|
+
with self.wrap_database_errors:
|
|
396
|
+
self.connection.autocommit = autocommit
|
|
397
|
+
|
|
398
|
+
def check_constraints(self, table_names: list[str] | None = None) -> None:
|
|
399
|
+
"""
|
|
400
|
+
Check constraints by setting them to immediate. Return them to deferred
|
|
401
|
+
afterward.
|
|
402
|
+
"""
|
|
403
|
+
with self.cursor() as cursor:
|
|
404
|
+
cursor.execute("SET CONSTRAINTS ALL IMMEDIATE")
|
|
405
|
+
cursor.execute("SET CONSTRAINTS ALL DEFERRED")
|
|
406
|
+
|
|
407
|
+
def is_usable(self) -> bool:
    """
    Report whether the open connection still answers ``SELECT 1``.

    Expects ``self.connection`` to be set. psycopg errors produce False
    instead of propagating, so an unusable connection can be recycled.
    """
    assert self.connection is not None
    # Talk to psycopg directly, bypassing Plain's cursor wrappers.
    try:
        self.connection.execute("SELECT 1")
    except Database.Error:
        return False
    return True
|
|
424
|
+
|
|
425
|
+
@contextmanager
def _nodb_cursor(self) -> Generator[utils.CursorWrapper]:
    """
    Yield a cursor from a separate, short-lived connection to the 'postgres'
    database instead of the configured database.

    Intended for work that should not touch the main database, such as
    creating or dropping the test database. If the 'postgres' connection
    cannot be opened, a RuntimeWarning is emitted and a fresh connection
    built from this connection's own settings is used instead. Whichever
    temporary connection is used is closed on exit.

    Raises:
        psycopg DatabaseError / plain DatabaseError: re-raised if it occurs
            after a cursor was obtained, and propagated if the fallback
            connection fails too.
    """
    cursor = None
    try:
        # DATABASE=None makes get_connection_params() connect to 'postgres'.
        conn = self.__class__({**self.settings_dict, "DATABASE": None})
        try:
            with conn.cursor() as cursor:
                yield cursor
        finally:
            conn.close()
    except (Database.DatabaseError, DatabaseError):
        # A cursor was already obtained, so the error did not come from
        # opening the 'postgres' connection (it came from the caller's block
        # or from cleanup): propagate it unchanged.
        if cursor is not None:
            raise
        # NOTE(review): the message mentions "the first PostgreSQL database",
        # but the fallback below reuses this connection's own settings_dict.
        warnings.warn(
            "Normally Plain will use a connection to the 'postgres' database "
            "to avoid running initialization queries against the production "
            "database when it's not needed (for example, when running tests). "
            "Plain was unable to create a connection to the 'postgres' database "
            "and will use the first PostgreSQL database instead.",
            RuntimeWarning,
        )
        # Fall back to a fresh connection to the configured database.
        conn = self.__class__(self.settings_dict)
        try:
            with conn.cursor() as cursor:
                yield cursor
        finally:
            conn.close()
|
|
459
|
+
|
|
460
|
+
@cached_property
def pg_version(self) -> int:
    """
    Server version number reported by psycopg
    (``connection.info.server_version``), read once and cached.
    """
    with self.temporary_connection():
        conn = self.connection
        assert conn is not None
        version = conn.info.server_version
    return version
|
|
465
|
+
|
|
466
|
+
def make_debug_cursor(self, cursor: Any) -> CursorDebugWrapper:
    """Wrap ``cursor`` in a CursorDebugWrapper bound to this connection."""
    debug_wrapper = CursorDebugWrapper(cursor, self)
    return debug_wrapper
|
|
468
|
+
|
|
469
|
+
# ##### Connection lifecycle #####
|
|
470
|
+
|
|
471
|
+
def connect(self) -> None:
    """
    Open a new database connection and reset per-connection state.

    Assumes no connection is open. Clears transaction bookkeeping left over
    from a connection that may have been closed inside an atomic block, and
    recomputes the max-age deadline and health-check flags from settings.
    Then it opens the connection, switches it to autocommit, initializes the
    session, and finally discards pending on-commit callbacks.
    """
    # Leftovers from a connection possibly closed mid-atomic-block.
    self.in_atomic_block = False
    self.savepoint_ids = []
    self.atomic_blocks = []
    self.needs_rollback = False

    # When to close / health-check the new connection.
    settings_dict = self.settings_dict
    self.health_check_enabled = settings_dict["CONN_HEALTH_CHECKS"]
    max_age = settings_dict["CONN_MAX_AGE"]
    if max_age is None:
        self.close_at = None
    else:
        self.close_at = time.monotonic() + max_age
    self.closed_in_transaction = False
    self.errors_occurred = False
    # A freshly opened connection counts as healthy.
    self.health_check_done = True

    params = self.get_connection_params()
    self.connection = self.get_new_connection(params)
    self.set_autocommit(True)
    self.init_connection_state()

    self.run_on_commit = []
|
|
493
|
+
|
|
494
|
+
def ensure_connection(self) -> None:
    """Open a connection lazily, only when none exists, wrapping driver errors."""
    if self.connection is not None:
        return
    with self.wrap_database_errors:
        self.connect()
|
|
499
|
+
|
|
500
|
+
# ##### PEP-249 connection method wrappers #####
|
|
501
|
+
|
|
502
|
+
def _prepare_cursor(self, cursor: Any) -> utils.CursorWrapper:
|
|
503
|
+
"""
|
|
504
|
+
Validate the connection is usable and perform database cursor wrapping.
|
|
505
|
+
"""
|
|
506
|
+
if self.queries_logged:
|
|
507
|
+
wrapped_cursor = self.make_debug_cursor(cursor)
|
|
508
|
+
else:
|
|
509
|
+
wrapped_cursor = self.make_cursor(cursor)
|
|
510
|
+
return wrapped_cursor
|
|
511
|
+
|
|
512
|
+
def _cursor(self) -> utils.CursorWrapper:
|
|
513
|
+
self.close_if_health_check_failed()
|
|
514
|
+
self.ensure_connection()
|
|
515
|
+
with self.wrap_database_errors:
|
|
516
|
+
return self._prepare_cursor(self.create_cursor())
|
|
517
|
+
|
|
518
|
+
def _commit(self) -> None:
|
|
519
|
+
if self.connection is not None:
|
|
520
|
+
with debug_transaction(self, "COMMIT"), self.wrap_database_errors:
|
|
521
|
+
return self.connection.commit()
|
|
522
|
+
|
|
523
|
+
def _rollback(self) -> None:
|
|
524
|
+
if self.connection is not None:
|
|
525
|
+
with debug_transaction(self, "ROLLBACK"), self.wrap_database_errors:
|
|
526
|
+
return self.connection.rollback()
|
|
527
|
+
|
|
528
|
+
def _close(self) -> None:
|
|
529
|
+
if self.connection is not None:
|
|
530
|
+
with self.wrap_database_errors:
|
|
531
|
+
return self.connection.close()
|
|
532
|
+
|
|
533
|
+
# ##### Generic wrappers for PEP-249 connection methods #####
|
|
534
|
+
|
|
535
|
+
def cursor(self) -> utils.CursorWrapper:
    """Return a new wrapped cursor, connecting first when necessary."""
    wrapped = self._cursor()
    return wrapped
|
|
538
|
+
|
|
539
|
+
def commit(self) -> None:
    """
    Commit the current transaction.

    Raises if an atomic block is active. On success, clears errors_occurred
    and flags the on-commit hooks to run at the next set_autocommit(True).
    """
    self.validate_no_atomic_block()
    self._commit()
    # Getting here means the commit worked, so the connection is usable.
    self.errors_occurred = False
    self.run_commit_hooks_on_set_autocommit_on = True
|
|
546
|
+
|
|
547
|
+
def rollback(self) -> None:
    """
    Roll back the current transaction.

    Raises if an atomic block is active. On success, clears errors_occurred
    and needs_rollback, and discards pending on-commit callbacks.
    """
    self.validate_no_atomic_block()
    self._rollback()
    # Getting here means the rollback worked, so the connection is usable.
    self.errors_occurred = False
    self.needs_rollback = False
    self.run_on_commit = []
|
|
555
|
+
|
|
556
|
+
def close(self) -> None:
    """
    Close the database connection.

    Pending on-commit callbacks are always dropped. When closed inside an
    atomic block, the connection object is kept and the block is marked as
    closed_in_transaction / needs_rollback. Otherwise ``self.connection``
    is cleared. Does nothing further when there is no connection, or when it
    was already closed mid-transaction.
    """
    self.run_on_commit = []

    # validate_no_atomic_block() is deliberately skipped: a connection in an
    # invalid state must still be disposable, and the next connect() resets
    # the transaction state anyway.
    if self.closed_in_transaction or self.connection is None:
        return

    try:
        self._close()
    finally:
        if not self.in_atomic_block:
            self.connection = None
        else:
            self.closed_in_transaction = True
            self.needs_rollback = True
|
|
573
|
+
|
|
574
|
+
# ##### Savepoint management #####
|
|
575
|
+
|
|
576
|
+
def _savepoint(self, sid: str) -> None:
    """Issue ``SAVEPOINT <sid>`` with the identifier quoted."""
    statement = f"SAVEPOINT {quote_name(sid)}"
    with self.cursor() as cur:
        cur.execute(statement)
|
|
579
|
+
|
|
580
|
+
def _savepoint_rollback(self, sid: str) -> None:
    """Issue ``ROLLBACK TO SAVEPOINT <sid>`` with the identifier quoted."""
    statement = f"ROLLBACK TO SAVEPOINT {quote_name(sid)}"
    with self.cursor() as cur:
        cur.execute(statement)
|
|
583
|
+
|
|
584
|
+
def _savepoint_commit(self, sid: str) -> None:
    """Issue ``RELEASE SAVEPOINT <sid>`` with the identifier quoted."""
    statement = f"RELEASE SAVEPOINT {quote_name(sid)}"
    with self.cursor() as cur:
        cur.execute(statement)
|
|
587
|
+
|
|
588
|
+
# ##### Generic savepoint management methods #####
|
|
589
|
+
|
|
590
|
+
def savepoint(self) -> str | None:
|
|
591
|
+
"""
|
|
592
|
+
Create a savepoint inside the current transaction. Return an
|
|
593
|
+
identifier for the savepoint that will be used for the subsequent
|
|
594
|
+
rollback or commit. Return None if in autocommit mode (no transaction).
|
|
595
|
+
"""
|
|
596
|
+
if self.get_autocommit():
|
|
597
|
+
return None
|
|
598
|
+
|
|
599
|
+
thread_ident = _thread.get_ident()
|
|
600
|
+
tid = str(thread_ident).replace("-", "")
|
|
601
|
+
|
|
602
|
+
self.savepoint_state += 1
|
|
603
|
+
sid = "s%s_x%d" % (tid, self.savepoint_state) # noqa: UP031
|
|
604
|
+
|
|
605
|
+
self._savepoint(sid)
|
|
606
|
+
|
|
607
|
+
return sid
|
|
608
|
+
|
|
609
|
+
def savepoint_rollback(self, sid: str) -> None:
    """
    Roll back to savepoint ``sid``. On-commit callbacks registered while it
    was active are discarded. No-op in autocommit mode.
    """
    if self.get_autocommit():
        return

    self._savepoint_rollback(sid)

    remaining = []
    for sids, func, robust in self.run_on_commit:
        if sid in sids:
            continue
        remaining.append((sids, func, robust))
    self.run_on_commit = remaining
|
|
624
|
+
|
|
625
|
+
def savepoint_commit(self, sid: str) -> None:
    """Release savepoint ``sid``; no-op in autocommit mode."""
    if not self.get_autocommit():
        self._savepoint_commit(sid)
|
|
633
|
+
|
|
634
|
+
def clean_savepoints(self) -> None:
    """Restart this connection's savepoint id counter at zero."""
    self.savepoint_state = 0
|
|
639
|
+
|
|
640
|
+
# ##### Generic transaction management methods #####
|
|
641
|
+
|
|
642
|
+
def get_autocommit(self) -> bool:
    """
    Return the connection's autocommit flag.

    ensure_connection() is called before the flag is read.
    """
    self.ensure_connection()
    current = self.autocommit
    return current
|
|
646
|
+
|
|
647
|
+
def set_autocommit(self, autocommit: bool) -> None:
    """
    Enable or disable autocommit.

    Used internally by atomic() to manage transactions. Don't call this
    directly — use atomic() instead.

    When autocommit is turned back on and
    run_commit_hooks_on_set_autocommit_on is set, the pending on-commit
    hooks are run and that flag is cleared.

    Args:
        autocommit: The new autocommit state.

    Raises:
        TransactionManagementError: (via validate_no_atomic_block()) if an
            atomic block is active.
    """
    self.validate_no_atomic_block()
    # Drop a connection that fails its health check before using it, then
    # make sure a connection is open.
    self.close_if_health_check_failed()
    self.ensure_connection()

    if autocommit:
        self._set_autocommit(autocommit)
    else:
        # Disabling autocommit is wrapped in debug_transaction(self, "BEGIN"),
        # presumably so the implicit transaction start shows up in debug
        # output.
        with debug_transaction(self, "BEGIN"):
            self._set_autocommit(autocommit)
    self.autocommit = autocommit

    if autocommit and self.run_commit_hooks_on_set_autocommit_on:
        self.run_and_clear_commit_hooks()
        self.run_commit_hooks_on_set_autocommit_on = False
|
|
668
|
+
|
|
669
|
+
def get_rollback(self) -> bool:
    """
    Return the "needs rollback" flag -- for *advanced use* only.

    Raises TransactionManagementError when called outside an atomic block.
    """
    if self.in_atomic_block:
        return self.needs_rollback
    raise TransactionManagementError(
        "The rollback flag doesn't work outside of an 'atomic' block."
    )
|
|
676
|
+
|
|
677
|
+
def set_rollback(self, rollback: bool) -> None:
    """
    Set or clear the "needs rollback" flag -- for *advanced use* only.

    Raises TransactionManagementError when called outside an atomic block.
    """
    if self.in_atomic_block:
        self.needs_rollback = rollback
        return
    raise TransactionManagementError(
        "The rollback flag doesn't work outside of an 'atomic' block."
    )
|
|
686
|
+
|
|
687
|
+
def validate_no_atomic_block(self) -> None:
    """Raise TransactionManagementError while an atomic block is active."""
    if not self.in_atomic_block:
        return
    raise TransactionManagementError(
        "This is forbidden when an 'atomic' block is active."
    )
|
|
693
|
+
|
|
694
|
+
def validate_no_broken_transaction(self) -> None:
    """
    Refuse to continue once the current transaction is marked for rollback.

    Raises TransactionManagementError, chained from self.rollback_exc, when
    self.needs_rollback is set.
    """
    if not self.needs_rollback:
        return
    message = (
        "An error occurred in the current transaction. You can't "
        "execute queries until the end of the 'atomic' block."
    )
    raise TransactionManagementError(message) from self.rollback_exc
|
|
700
|
+
|
|
701
|
+
# ##### Connection termination handling #####
|
|
702
|
+
|
|
703
|
+
def close_if_health_check_failed(self) -> None:
    """
    Close the existing connection if it fails a health check.

    The check is skipped when there is no connection, when health checks
    are disabled, or when one has already been done (health_check_done).
    After a check, health_check_done is set.
    """
    should_check = (
        self.connection is not None
        and self.health_check_enabled
        and not self.health_check_done
    )
    if should_check:
        if not self.is_usable():
            self.close()
        self.health_check_done = True
|
|
715
|
+
|
|
716
|
+
def close_if_unusable_or_obsolete(self) -> None:
    """
    Close the current connection if unrecoverable errors have occurred
    or if it outlived its maximum age (close_at, on the time.monotonic()
    clock). Resets health_check_done whenever a connection exists.
    """
    if self.connection is None:
        return

    self.health_check_done = False

    # If autocommit was not restored (e.g. a transaction was not properly
    # closed), don't take chances: drop the connection.
    if not self.get_autocommit():
        self.close()
        return

    # If an exception other than DataError or IntegrityError occurred since
    # the last commit / rollback, check whether the connection still works.
    if self.errors_occurred:
        if not self.is_usable():
            self.close()
            return
        self.errors_occurred = False
        self.health_check_done = True

    expired = self.close_at is not None and time.monotonic() >= self.close_at
    if expired:
        self.close()
|
|
742
|
+
|
|
743
|
+
# ##### Miscellaneous #####
|
|
744
|
+
|
|
745
|
+
@cached_property
def wrap_database_errors(self) -> DatabaseErrorWrapper:
    """
    Context manager and decorator that re-raises backend-specific database
    exceptions as Plain's common exception wrappers.

    Built on first access and cached on the instance (cached_property).
    """
    error_wrapper = DatabaseErrorWrapper(self)
    return error_wrapper
|
|
752
|
+
|
|
753
|
+
def make_cursor(self, cursor: Any) -> utils.CursorWrapper:
    """Wrap a raw cursor in utils.CursorWrapper, without debug logging."""
    wrapped = utils.CursorWrapper(cursor, self)
    return wrapped
|
|
756
|
+
|
|
757
|
+
@contextmanager
def temporary_connection(self) -> Generator[utils.CursorWrapper]:
    """
    Yield a cursor, opening a connection for the duration if needed.

    If no connection existed on entry, the connection is closed again on
    exit (also when the body raises), so no dangling connection is left
    behind. Useful for operations outside of the request-response cycle.

    Usage: with self.temporary_connection() as cursor: ...
    """
    opened_here = self.connection is None
    try:
        with self.cursor() as active_cursor:
            yield active_cursor
    finally:
        if opened_here:
            self.close()
|
|
773
|
+
|
|
774
|
+
def schema_editor(self, *args: Any, **kwargs: Any) -> DatabaseSchemaEditor:
    """
    Build a new DatabaseSchemaEditor bound to this connection.

    Extra positional and keyword arguments are forwarded to its constructor.
    """
    editor = DatabaseSchemaEditor(self, *args, **kwargs)
    return editor
|
|
777
|
+
|
|
778
|
+
def runshell(self, parameters: list[str]) -> None:
    """
    Run an interactive psql shell.

    The command line and any extra environment variables come from
    _psql_settings_to_cmd_args_env(self.settings_dict, parameters); extra
    variables are layered on top of os.environ. SIGINT is ignored in this
    process while psql runs, and the previous handler is restored afterwards.

    Args:
        parameters: Extra command-line arguments, passed to
            _psql_settings_to_cmd_args_env().

    Raises:
        subprocess.CalledProcessError: if psql exits with a non-zero status
            (subprocess.run() is called with check=True).
    """
    args, env = _psql_settings_to_cmd_args_env(self.settings_dict, parameters)
    # No extra variables -> env=None, i.e. the child inherits os.environ as-is.
    env = {**os.environ, **env} if env else None
    sigint_handler = signal.getsignal(signal.SIGINT)
    try:
        # Allow SIGINT to pass to psql to abort queries.
        signal.signal(signal.SIGINT, signal.SIG_IGN)
        subprocess.run(args, env=env, check=True)
    finally:
        # Restore the original SIGINT handler.
        signal.signal(signal.SIGINT, sigint_handler)
|
|
790
|
+
|
|
791
|
+
def on_commit(self, func: Any, robust: bool = False) -> None:
    """
    Register func to run when the current transaction commits.

    Inside an atomic block the callback is queued together with the set of
    currently active savepoint ids (savepoint_rollback() uses them to discard
    it). Outside an atomic block it runs immediately.

    Args:
        func: Zero-argument callable.
        robust: If True, an exception raised by func is logged instead of
            propagated.

    Raises:
        TypeError: if func is not callable.
    """
    if not callable(func):
        raise TypeError("on_commit()'s callback must be a callable.")
    if self.in_atomic_block:
        # Transaction in progress; save for execution on commit.
        self.run_on_commit.append((set(self.savepoint_ids), func, robust))
        return

    # No transaction in progress; execute immediately.
    if not robust:
        func()
        return
    try:
        func()
    except Exception as e:
        # Not every callable has __qualname__ (functools.partial objects and
        # instances defining __call__ don't). Fall back to repr() so logging
        # the failure can't itself raise and defeat robust=True.
        name = getattr(func, "__qualname__", None) or repr(func)
        logger.error(
            "Error calling %s in on_commit() (%s).",
            name,
            e,
            exc_info=True,
        )
|
|
810
|
+
|
|
811
|
+
def run_and_clear_commit_hooks(self) -> None:
    """
    Run every queued on-commit callback in registration order, then leave
    the queue empty.

    The queue is detached before any callback runs, so a callback that
    registers new hooks adds them to a fresh queue. A robust callback's
    exception is logged. A non-robust callback's exception propagates, and
    the callbacks queued after it are dropped.

    Raises:
        TransactionManagementError: (via validate_no_atomic_block()) if an
            atomic block is active.
    """
    self.validate_no_atomic_block()
    pending = self.run_on_commit
    self.run_on_commit = []
    # Plain iteration instead of repeated pop(0), which is O(n) per call.
    for _, func, robust in pending:
        if not robust:
            func()
            continue
        try:
            func()
        except Exception as e:
            # functools.partial objects and callable instances have no
            # __qualname__; fall back to repr() so logging can't raise.
            name = getattr(func, "__qualname__", None) or repr(func)
            logger.error(
                "Error calling %s in on_commit() during transaction (%s).",
                name,
                e,
                exc_info=True,
            )
|
|
829
|
+
|
|
830
|
+
@contextmanager
def execute_wrapper(self, wrapper: Any) -> Generator[None]:
    """
    Context manager that installs wrapper around suitable query executions.

    wrapper is pushed onto self.execute_wrappers on entry and popped again
    on exit, even when the body raises.
    """
    self.execute_wrappers.append(wrapper)
    try:
        yield None
    finally:
        # pop() (not remove()) so nested wrappers unwind in LIFO order.
        self.execute_wrappers.pop()
|
|
841
|
+
|
|
842
|
+
# ##### SQL generation methods that require connection state #####
|
|
843
|
+
|
|
844
|
+
def compose_sql(self, query: str, params: Any) -> str:
    """
    Render query with params interpolated, via psycopg's mogrify().

    Requires an open connection: rendering goes through a ClientCursor
    created on self.connection.
    """
    assert self.connection is not None
    client_cursor = ClientCursor(self.connection)
    statement = psycopg_sql.SQL(cast(LiteralString, query))
    return client_cursor.mogrify(statement, params)
|
|
855
|
+
|
|
856
|
+
def last_executed_query(
    self,
    cursor: utils.CursorWrapper,
    sql: str,
    params: Any,
) -> str | None:
    """
    Return sql with params interpolated, i.e. the text of the query last
    executed by cursor, or None if rendering raises errors.DataError.

    The text is rebuilt from sql and params with compose_sql(); cursor
    itself is not consulted.
    """
    rendered: str | None
    try:
        rendered = self.compose_sql(sql, params)
    except errors.DataError:
        rendered = None
    return rendered
|
|
870
|
+
|
|
871
|
+
def unification_cast_sql(self, output_field: Field) -> str:
    """
    Return the SQL that casts the result of a union to output_field's type.

    The returned string contains a '%s' placeholder for the expression being
    cast. Fields that need no explicit cast get plain "%s".
    """
    # PostgreSQL resolves a union as type 'text' if input types are
    # 'unknown' (https://www.postgresql.org/docs/current/typeconv-union-case.html).
    # These types can't be implicitly cast back in the default PostgreSQL
    # configuration, so they are cast explicitly.
    needs_cast = output_field.get_internal_type() in {
        "GenericIPAddressField",
        "TimeField",
        "UUIDField",
    }
    if needs_cast:
        db_type = output_field.db_type()
        if db_type:
            # Drop any bracketed modifiers: varchar(255) -> varchar.
            base_type = db_type.partition("(")[0]
            return "CAST(%s AS {})".format(base_type)
    return "%s"
|
|
894
|
+
|
|
895
|
+
# ##### Introspection methods #####
|
|
896
|
+
|
|
897
|
+
def table_names(
    self, cursor: CursorWrapper | None = None, include_views: bool = False
) -> list[str]:
    """
    Return the names of all tables that exist in the database, sorted.

    Entries whose type is not "t" (views, partitions) are included only when
    include_views is True. Sorting uses Python's default ordering rather
    than the database's ORDER BY, to avoid subtle differences in sort order
    between databases. If cursor is None, one is opened for the call.
    """

    def _visible_names(active: CursorWrapper) -> list[str]:
        names = [
            info.name
            for info in self.get_table_list(active)
            if include_views or info.type == "t"
        ]
        names.sort()
        return names

    if cursor is not None:
        return _visible_names(cursor)
    with self.cursor() as own_cursor:
        return _visible_names(own_cursor)
|
|
918
|
+
|
|
919
|
+
def get_table_list(self, cursor: CursorWrapper) -> Sequence[TableInfo]:
    """
    Return an unsorted list of TableInfo named tuples of all tables and
    views that exist in the database.

    Each TableInfo is built from (relname, type, comment):
      - type 'p': the relation is a partition (relispartition);
      - type 'v': a view or materialized view (relkind 'v' / 'm');
      - type 't': any other selected relation (relkind 'r', 'p' or 'f').
      - comment: obj_description(c.oid, 'pg_class').

    Only relations visible on the search path (pg_table_is_visible) and
    outside the pg_catalog / pg_toast schemas are selected. Names listed in
    self.ignored_tables are filtered out in Python.

    Args:
        cursor: Cursor used to run the catalog query.
    """
    cursor.execute(
        """
        SELECT
            c.relname,
            CASE
                WHEN c.relispartition THEN 'p'
                WHEN c.relkind IN ('m', 'v') THEN 'v'
                ELSE 't'
            END,
            obj_description(c.oid, 'pg_class')
        FROM pg_catalog.pg_class c
        LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
        WHERE c.relkind IN ('f', 'm', 'p', 'r', 'v')
            AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
            AND pg_catalog.pg_table_is_visible(c.oid)
        """
    )
    return [
        TableInfo(*row)
        for row in cursor.fetchall()
        if row[0] not in self.ignored_tables
    ]
|
|
946
|
+
|
|
947
|
+
def plain_table_names(
    self, only_existing: bool = False, include_views: bool = True
) -> list[str]:
    """
    Return a list of all table names that have associated Plain models and
    are in INSTALLED_PACKAGES.

    The list covers each migratable model's db_table plus the join tables of
    its local many-to-many fields. It has no duplicates and is sorted by
    name, so the result is deterministic.

    Args:
        only_existing: If True, include only the tables that exist in the
            database (as reported by table_names()).
        include_views: Passed to table_names() when only_existing is True.
    """
    tables: set[str] = set()
    for model in get_migratable_models():
        tables.add(model.model_options.db_table)
        tables.update(
            f.m2m_db_table() for f in model._model_meta.local_many_to_many
        )
    if only_existing:
        tables &= set(self.table_names(include_views=include_views))
    # Iteration order of a set of strings changes between interpreter runs
    # (hash randomization); sort for a stable result, like table_names().
    return sorted(tables)
|
|
967
|
+
|
|
968
|
+
def get_sequences(
    self, cursor: CursorWrapper, table_name: str, table_fields: tuple[Any, ...] = ()
) -> list[dict[str, Any]]:
    """
    Return a list of introspected sequences for table_name. Each sequence
    is a dict: {'table': <table_name>, 'column': <column_name>, 'name': <sequence_name>}.

    A sequence (relkind 'S') is included when pg_depend records a dependency
    from it on a column of a table named table_name that is visible on the
    search path. Dependencies of any deptype are accepted.

    Args:
        cursor: Cursor used to run the catalog query.
        table_name: Table whose column sequences are wanted.
        table_fields: Accepted but not used by this implementation.
    """
    cursor.execute(
        """
        SELECT
            s.relname AS sequence_name,
            a.attname AS colname
        FROM
            pg_class s
            JOIN pg_depend d ON d.objid = s.oid
                AND d.classid = 'pg_class'::regclass
                AND d.refclassid = 'pg_class'::regclass
            JOIN pg_attribute a ON d.refobjid = a.attrelid
                AND d.refobjsubid = a.attnum
            JOIN pg_class tbl ON tbl.oid = d.refobjid
                AND tbl.relname = %s
                AND pg_catalog.pg_table_is_visible(tbl.oid)
        WHERE
            s.relkind = 'S';
        """,
        [table_name],
    )
    # Each row is (sequence_name, column_name).
    return [
        {"name": row[0], "table": table_name, "column": row[1]}
        for row in cursor.fetchall()
    ]
|
|
999
|
+
|
|
1000
|
+
def get_constraints(
    self, cursor: CursorWrapper, table_name: str
) -> dict[str, dict[str, Any]]:
    """
    Retrieve any constraints or keys (unique, pk, fk, check, index) across
    one or more columns. Also retrieve the definition of expression-based
    indexes.

    Returns a dict keyed by constraint / index name. Every value has the
    keys "columns", "primary_key", "unique", "foreign_key", "check",
    "index", "definition" and "options". Entries built from indexes also
    have "orders" and "type".

    Args:
        cursor: Cursor used to run the catalog queries.
        table_name: Name of a table visible on the search path.
    """
    constraints: dict[str, dict[str, Any]] = {}
    # Loop over the key table, collecting things as constraints. The column
    # array must return column names in the same order in which they were
    # created.
    # Per pg_constraint row: name, column names (ordered by position in
    # conkey), contype, "<table>.<column>" of the first referenced column
    # (NULL unless the constraint references another table), and the
    # constrained table's reloptions.
    cursor.execute(
        """
        SELECT
            c.conname,
            array(
                SELECT attname
                FROM unnest(c.conkey) WITH ORDINALITY cols(colid, arridx)
                JOIN pg_attribute AS ca ON cols.colid = ca.attnum
                WHERE ca.attrelid = c.conrelid
                ORDER BY cols.arridx
            ),
            c.contype,
            (SELECT fkc.relname || '.' || fka.attname
            FROM pg_attribute AS fka
            JOIN pg_class AS fkc ON fka.attrelid = fkc.oid
            WHERE fka.attrelid = c.confrelid AND fka.attnum = c.confkey[1]),
            cl.reloptions
        FROM pg_constraint AS c
        JOIN pg_class AS cl ON c.conrelid = cl.oid
        WHERE cl.relname = %s AND pg_catalog.pg_table_is_visible(cl.oid)
        """,
        [table_name],
    )
    for constraint, columns, kind, used_cols, options in cursor.fetchall():
        constraints[constraint] = {
            "columns": columns,
            # contype codes mapped here: 'p' primary key, 'u' unique,
            # 'f' foreign key, 'c' check. Any other code leaves all flags False.
            "primary_key": kind == "p",
            "unique": kind in ["p", "u"],
            # (referenced_table, referenced_column) of the first FK column.
            "foreign_key": tuple(used_cols.split(".", 1)) if kind == "f" else None,
            "check": kind == "c",
            "index": False,
            "definition": None,
            # The constrained table's reloptions (cl), not the constraint's.
            "options": options,
        }
    # Now get indexes
    # One row per index on the table: name, key column names and per-key
    # sort orders (both ordered by key position), uniqueness, primary-key
    # flag, access method name, full definition for expression indexes
    # (NULL otherwise), and the index's own reloptions. Sort orders are
    # only computed for the default access method (first parameter); a set
    # bit 1 in indoption is reported as 'DESC'.
    cursor.execute(
        """
        SELECT
            indexname,
            array_agg(attname ORDER BY arridx),
            indisunique,
            indisprimary,
            array_agg(ordering ORDER BY arridx),
            amname,
            exprdef,
            s2.attoptions
        FROM (
            SELECT
                c2.relname as indexname, idx.*, attr.attname, am.amname,
                CASE
                    WHEN idx.indexprs IS NOT NULL THEN
                        pg_get_indexdef(idx.indexrelid)
                END AS exprdef,
                CASE am.amname
                    WHEN %s THEN
                        CASE (option & 1)
                            WHEN 1 THEN 'DESC' ELSE 'ASC'
                        END
                END as ordering,
                c2.reloptions as attoptions
            FROM (
                SELECT *
                FROM
                    pg_index i,
                    unnest(i.indkey, i.indoption)
                        WITH ORDINALITY koi(key, option, arridx)
            ) idx
            LEFT JOIN pg_class c ON idx.indrelid = c.oid
            LEFT JOIN pg_class c2 ON idx.indexrelid = c2.oid
            LEFT JOIN pg_am am ON c2.relam = am.oid
            LEFT JOIN
                pg_attribute attr ON attr.attrelid = c.oid AND attr.attnum = idx.key
            WHERE c.relname = %s AND pg_catalog.pg_table_is_visible(c.oid)
        ) s2
        GROUP BY indexname, indisunique, indisprimary, amname, exprdef, attoptions;
        """,
        [self.index_default_access_method, table_name],
    )
    for (
        index,
        columns,
        unique,
        primary,
        orders,
        type_,
        definition,
        options,
    ) in cursor.fetchall():
        # Names already collected from pg_constraint keep their constraint
        # entry (typically the index backing a primary key or unique
        # constraint, which shares its name).
        if index not in constraints:
            # A "basic" index uses the default access method and has no
            # reloptions; it is reported with the generic Index.suffix type.
            basic_index = (
                type_ == self.index_default_access_method and options is None
            )
            constraints[index] = {
                # Expression keys have indkey 0, so the LEFT JOIN to
                # pg_attribute yields NULL; a single-expression index
                # therefore aggregates to [None], normalized to [].
                "columns": columns if columns != [None] else [],
                "orders": orders if orders != [None] else [],
                "primary_key": primary,
                "unique": unique,
                "foreign_key": None,
                "check": False,
                "index": True,
                "type": Index.suffix if basic_index else type_,
                "definition": definition,
                "options": options,
            }
    return constraints
|
|
1117
|
+
|
|
1118
|
+
# ##### Test database creation methods (merged from DatabaseCreation) #####
|
|
1119
|
+
|
|
1120
|
+
def _log(self, msg: str) -> None:
|
|
1121
|
+
sys.stderr.write(msg + os.linesep)
|
|
1122
|
+
|
|
1123
|
+
def create_test_db(self, verbosity: int = 1, prefix: str = "") -> str:
    """
    Create a test database, migrate it, and point this connection at it.
    Return the name of the test database created.

    _create_test_db() is called with autoclobber=True. If creating the
    database fails (e.g. it already exists), it is dropped and recreated
    without prompting.

    Side effects: closes the current connection, sets
    settings.POSTGRES_DATABASE and self.settings_dict["DATABASE"] to the
    test database name, applies all migrations, then calls
    ensure_connection().

    Args:
        verbosity: >= 1 logs progress to stderr; >= 2 also shows migration
            output (quiet=verbosity < 2).
        prefix: If provided, it will be prepended to the database name to
            isolate it from other test databases (see _get_test_db_name()).
    """
    # Function-level import; presumably avoids an import cycle with the CLI
    # module — TODO confirm.
    from plain.postgres.cli.migrations import apply

    test_database_name = self._get_test_db_name(prefix)

    if verbosity >= 1:
        self._log(f"Creating test database '{test_database_name}'...")

    self._create_test_db(
        test_database_name=test_database_name, verbosity=verbosity, autoclobber=True
    )

    # Switch this connection (and the global setting) to the test database;
    # presumably close() first so the next connection uses the new name.
    self.close()
    settings.POSTGRES_DATABASE = test_database_name
    self.settings_dict["DATABASE"] = test_database_name

    apply.callback(
        package_label=None,
        migration_name=None,
        fake=False,
        plan=False,
        check_unapplied=False,
        backup=False,
        no_input=True,
        atomic_batch=False,  # No need for atomic batch when creating test database
        quiet=verbosity < 2,  # Show migration output when verbosity is 2+
    )

    # Ensure a connection for the side effect of initializing the test database.
    self.ensure_connection()

    return test_database_name
|
|
1162
|
+
|
|
1163
|
+
def _get_test_db_name(self, prefix: str = "") -> str:
|
|
1164
|
+
"""
|
|
1165
|
+
Internal implementation - return the name of the test DB that will be
|
|
1166
|
+
created. Only useful when called from create_test_db() and
|
|
1167
|
+
_create_test_db() and when no external munging is done with the 'DATABASE'
|
|
1168
|
+
settings.
|
|
1169
|
+
|
|
1170
|
+
If prefix is provided, it will be prepended to the database name.
|
|
1171
|
+
"""
|
|
1172
|
+
# Determine the base name: explicit TEST.DATABASE overrides base DATABASE.
|
|
1173
|
+
base_name = (
|
|
1174
|
+
self.settings_dict["TEST"]["DATABASE"] or self.settings_dict["DATABASE"]
|
|
1175
|
+
)
|
|
1176
|
+
if prefix:
|
|
1177
|
+
return f"{prefix}_{base_name}"
|
|
1178
|
+
if self.settings_dict["TEST"]["DATABASE"]:
|
|
1179
|
+
return self.settings_dict["TEST"]["DATABASE"]
|
|
1180
|
+
name = self.settings_dict["DATABASE"]
|
|
1181
|
+
if name is None:
|
|
1182
|
+
raise ValueError("POSTGRES_DATABASE must be set")
|
|
1183
|
+
return TEST_DATABASE_PREFIX + name
|
|
1184
|
+
|
|
1185
|
+
def _get_database_create_suffix(
|
|
1186
|
+
self, encoding: str | None = None, template: str | None = None
|
|
1187
|
+
) -> str:
|
|
1188
|
+
"""Return PostgreSQL-specific CREATE DATABASE suffix."""
|
|
1189
|
+
suffix = ""
|
|
1190
|
+
if encoding:
|
|
1191
|
+
suffix += f" ENCODING '{encoding}'"
|
|
1192
|
+
if template:
|
|
1193
|
+
suffix += f" TEMPLATE {quote_name(template)}"
|
|
1194
|
+
return suffix and "WITH" + suffix
|
|
1195
|
+
|
|
1196
|
+
def _execute_create_test_db(self, cursor: Any, parameters: dict[str, str]) -> None:
|
|
1197
|
+
try:
|
|
1198
|
+
cursor.execute("CREATE DATABASE {dbname} {suffix}".format(**parameters))
|
|
1199
|
+
except Exception as e:
|
|
1200
|
+
cause = e.__cause__
|
|
1201
|
+
if cause and not isinstance(cause, errors.DuplicateDatabase):
|
|
1202
|
+
# All errors except "database already exists" cancel tests.
|
|
1203
|
+
self._log(f"Got an error creating the test database: {e}")
|
|
1204
|
+
sys.exit(2)
|
|
1205
|
+
else:
|
|
1206
|
+
raise
|
|
1207
|
+
|
|
1208
|
+
def _create_test_db(
    self, *, test_database_name: str, verbosity: int, autoclobber: bool
) -> str:
    """
    Internal implementation - create the test database itself.

    Runs CREATE DATABASE (via _execute_create_test_db()) on a cursor from
    _nodb_cursor(). If that raises, the error is logged. Then:
      - if autoclobber is True, or the user types 'yes' at the prompt, the
        database is dropped and created again. Any failure during that
        retry is logged and the process exits with status 2;
      - otherwise "Tests cancelled." is logged and the process exits with
        status 1.
    _execute_create_test_db() may itself exit with status 2 for errors other
    than "database already exists".

    Returns:
        test_database_name, unchanged.
    """
    test_db_params = {
        "dbname": quote_name(test_database_name),
        "suffix": self.sql_table_creation_suffix(),
    }
    # Create the test database, using a cursor from _nodb_cursor().
    with self._nodb_cursor() as cursor:
        try:
            self._execute_create_test_db(cursor, test_db_params)
        except Exception as e:
            self._log(f"Got an error creating the test database: {e}")
            if not autoclobber:
                confirm = input(
                    "Type 'yes' if you would like to try deleting the test "
                    f"database '{test_database_name}', or 'no' to cancel: "
                )
            # `confirm` is only bound when autoclobber is False; the `or`
            # short-circuits before it is read otherwise.
            if autoclobber or confirm == "yes":
                try:
                    if verbosity >= 1:
                        self._log(
                            f"Destroying old test database '{test_database_name}'..."
                        )
                    cursor.execute(
                        "DROP DATABASE {dbname}".format(**test_db_params)
                    )
                    self._execute_create_test_db(cursor, test_db_params)
                except Exception as e:
                    self._log(f"Got an error recreating the test database: {e}")
                    sys.exit(2)
            else:
                self._log("Tests cancelled.")
                sys.exit(1)

    return test_database_name
|
|
1247
|
+
|
|
1248
|
+
def destroy_test_db(
|
|
1249
|
+
self, old_database_name: str | None = None, verbosity: int = 1
|
|
1250
|
+
) -> None:
|
|
1251
|
+
"""
|
|
1252
|
+
Destroy a test database, prompting the user for confirmation if the
|
|
1253
|
+
database already exists.
|
|
1254
|
+
"""
|
|
1255
|
+
self.close()
|
|
1256
|
+
|
|
1257
|
+
test_database_name = self.settings_dict["DATABASE"]
|
|
1258
|
+
if test_database_name is None:
|
|
1259
|
+
raise ValueError("Test POSTGRES_DATABASE must be set")
|
|
1260
|
+
|
|
1261
|
+
if verbosity >= 1:
|
|
1262
|
+
self._log(f"Destroying test database '{test_database_name}'...")
|
|
1263
|
+
self._destroy_test_db(test_database_name, verbosity)
|
|
1264
|
+
|
|
1265
|
+
# Restore the original database name
|
|
1266
|
+
if old_database_name is not None:
|
|
1267
|
+
settings.POSTGRES_DATABASE = old_database_name
|
|
1268
|
+
self.settings_dict["DATABASE"] = old_database_name
|
|
1269
|
+
|
|
1270
|
+
def _destroy_test_db(self, test_database_name: str, verbosity: int) -> None:
    """
    Internal implementation - drop test_database_name.

    The DROP runs on a cursor from _nodb_cursor() rather than one connected
    to the test database, because a database can't be dropped while
    connected to it. verbosity is accepted but not used here.
    """
    with self._nodb_cursor() as nodb:
        nodb.execute("DROP DATABASE " + quote_name(test_database_name))
|
|
1280
|
+
|
|
1281
|
+
def sql_table_creation_suffix(self) -> str:
    """
    Return the CREATE DATABASE suffix used for the test database, built from
    the optional CHARSET and TEMPLATE entries of settings_dict["TEST"].
    """
    test_options = self.settings_dict["TEST"]
    charset = test_options.get("CHARSET")
    template = test_options.get("TEMPLATE")
    return self._get_database_create_suffix(encoding=charset, template=template)
|
|
1290
|
+
|
|
1291
|
+
|
|
1292
|
+
class CursorMixin:
    """
    Mixin for psycopg cursor classes that adds a DB-API style callproc().
    """

    def callproc(
        self, name: str | psycopg_sql.Identifier, args: list[Any] | None = None
    ) -> list[Any] | None:
        """
        Execute ``SELECT * FROM name(arg1,arg2,...)`` and return args unchanged.

        name is quoted as an identifier unless it already is one; each
        argument is rendered as an SQL literal.
        """
        if isinstance(name, psycopg_sql.Identifier):
            proc = name
        else:
            proc = psycopg_sql.Identifier(name)

        pieces: list[psycopg_sql.Composable] = [
            psycopg_sql.SQL("SELECT * FROM "),
            proc,
            psycopg_sql.SQL("("),
        ]
        for position, value in enumerate(args or ()):
            if position:
                pieces.append(psycopg_sql.SQL(","))
            pieces.append(psycopg_sql.Literal(value))
        pieces.append(psycopg_sql.SQL(")"))

        self.execute(psycopg_sql.Composed(pieces))  # type: ignore[attr-defined]
        return args
|
|
1318
|
+
|
|
1319
|
+
|
|
1320
|
+
class ServerBindingCursor(CursorMixin, Database.Cursor):
    """
    ``Database.Cursor`` extended with CursorMixin.callproc().

    NOTE(review): presumably the server-side parameter binding variant, as
    the name suggests — confirm against where it is selected.
    """
|
|
1322
|
+
|
|
1323
|
+
|
|
1324
|
+
class Cursor(CursorMixin, Database.ClientCursor):
    """``Database.ClientCursor`` extended with CursorMixin.callproc()."""
|
|
1326
|
+
|
|
1327
|
+
|
|
1328
|
+
class CursorDebugWrapper(BaseCursorDebugWrapper):
    """Debug cursor wrapper that also routes copy() through debug_sql()."""

    def copy(self, statement: Any) -> Any:
        """Return self.cursor.copy(statement), called inside debug_sql(statement)."""
        # NOTE(review): if the wrapped cursor's copy() returns a context
        # manager that the caller enters later (as psycopg 3 documents),
        # debug_sql() only covers creating it, not the COPY itself — confirm
        # whether that is intended.
        with self.debug_sql(statement):
            copy_manager = self.cursor.copy(statement)
            return copy_manager
|