lsst-felis 29.2025.4500__py3-none-any.whl → 30.0.0rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- felis/__init__.py +1 -4
- felis/cli.py +172 -87
- felis/config/tap_schema/tap_schema_extensions.yaml +73 -0
- felis/datamodel.py +2 -3
- felis/db/{dialects.py → _dialects.py} +69 -4
- felis/db/{variants.py → _variants.py} +1 -1
- felis/db/database_context.py +917 -0
- felis/metadata.py +79 -11
- felis/tap_schema.py +159 -177
- felis/tests/postgresql.py +1 -1
- {lsst_felis-29.2025.4500.dist-info → lsst_felis-30.0.0rc3.dist-info}/METADATA +1 -1
- lsst_felis-30.0.0rc3.dist-info/RECORD +31 -0
- felis/db/schema.py +0 -62
- felis/db/utils.py +0 -409
- lsst_felis-29.2025.4500.dist-info/RECORD +0 -31
- /felis/db/{sqltypes.py → _sqltypes.py} +0 -0
- {lsst_felis-29.2025.4500.dist-info → lsst_felis-30.0.0rc3.dist-info}/WHEEL +0 -0
- {lsst_felis-29.2025.4500.dist-info → lsst_felis-30.0.0rc3.dist-info}/entry_points.txt +0 -0
- {lsst_felis-29.2025.4500.dist-info → lsst_felis-30.0.0rc3.dist-info}/licenses/COPYRIGHT +0 -0
- {lsst_felis-29.2025.4500.dist-info → lsst_felis-30.0.0rc3.dist-info}/licenses/LICENSE +0 -0
- {lsst_felis-29.2025.4500.dist-info → lsst_felis-30.0.0rc3.dist-info}/top_level.txt +0 -0
- {lsst_felis-29.2025.4500.dist-info → lsst_felis-30.0.0rc3.dist-info}/zip-safe +0 -0
|
@@ -0,0 +1,917 @@
|
|
|
1
|
+
"""API for managing database operations across different dialects."""
|
|
2
|
+
|
|
3
|
+
# This file is part of felis.
|
|
4
|
+
#
|
|
5
|
+
# Developed for the LSST Data Management System.
|
|
6
|
+
# This product includes software developed by the LSST Project
|
|
7
|
+
# (https://www.lsst.org).
|
|
8
|
+
# See the COPYRIGHT file at the top-level directory of this distribution
|
|
9
|
+
# for details of code ownership.
|
|
10
|
+
#
|
|
11
|
+
# This program is free software: you can redistribute it and/or modify
|
|
12
|
+
# it under the terms of the GNU General Public License as published by
|
|
13
|
+
# the Free Software Foundation, either version 3 of the License, or
|
|
14
|
+
# (at your option) any later version.
|
|
15
|
+
#
|
|
16
|
+
# This program is distributed in the hope that it will be useful,
|
|
17
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
18
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
19
|
+
# GNU General Public License for more details.
|
|
20
|
+
#
|
|
21
|
+
# You should have received a copy of the GNU General Public License
|
|
22
|
+
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
23
|
+
|
|
24
|
+
from __future__ import annotations
|
|
25
|
+
|
|
26
|
+
import logging
|
|
27
|
+
from abc import abstractmethod
|
|
28
|
+
from collections.abc import Callable, Iterator
|
|
29
|
+
from contextlib import AbstractContextManager, contextmanager
|
|
30
|
+
from typing import IO, Any, Literal, TypeAlias
|
|
31
|
+
|
|
32
|
+
from sqlalchemy import (
|
|
33
|
+
Engine,
|
|
34
|
+
MetaData,
|
|
35
|
+
create_engine,
|
|
36
|
+
inspect,
|
|
37
|
+
make_url,
|
|
38
|
+
quoted_name,
|
|
39
|
+
)
|
|
40
|
+
from sqlalchemy.engine import (
|
|
41
|
+
Connection,
|
|
42
|
+
Dialect,
|
|
43
|
+
Result,
|
|
44
|
+
)
|
|
45
|
+
from sqlalchemy.engine.mock import MockConnection, create_mock_engine
|
|
46
|
+
from sqlalchemy.engine.url import URL
|
|
47
|
+
from sqlalchemy.exc import SQLAlchemyError
|
|
48
|
+
from sqlalchemy.schema import (
|
|
49
|
+
CreateSchema,
|
|
50
|
+
DropSchema,
|
|
51
|
+
)
|
|
52
|
+
from sqlalchemy.sql import (
|
|
53
|
+
Executable,
|
|
54
|
+
text,
|
|
55
|
+
)
|
|
56
|
+
from sqlalchemy.sql.elements import TextClause
|
|
57
|
+
|
|
58
|
+
__all__ = [
|
|
59
|
+
"DatabaseContext",
|
|
60
|
+
"DatabaseContextError",
|
|
61
|
+
"MockContext",
|
|
62
|
+
"MySQLContext",
|
|
63
|
+
"PostgreSQLContext",
|
|
64
|
+
"SQLiteContext",
|
|
65
|
+
"create_database_context",
|
|
66
|
+
]
|
|
67
|
+
|
|
68
|
+
logger = logging.getLogger("felis")
|
|
69
|
+
|
|
70
|
+
SQLStatement = str | Executable | TextClause
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def _normalize_statement(statement: SQLStatement) -> Executable | TextClause:
    """Coerce a statement into an executable SQLAlchemy construct.

    Parameters
    ----------
    statement
        A raw SQL string or an already constructed SQLAlchemy statement.

    Returns
    -------
    `~sqlalchemy.sql.Executable` or `~sqlalchemy.sql.elements.TextClause`
        The string wrapped with `sqlalchemy.text`, or the input itself when
        it is not a string.
    """
    return text(statement) if isinstance(statement, str) else statement
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def _create_mock_connection(engine_url: str | URL, output_file: IO[str] | None = None) -> MockConnection:
    """Build a mock engine whose executed statements are written out as SQL.

    Parameters
    ----------
    engine_url
        The SQLAlchemy engine URL handed to ``create_mock_engine``.
    output_file
        Destination passed to the `_SQLWriter` that receives the statements.

    Returns
    -------
    `~sqlalchemy.engine.mock.MockConnection`
        The mock engine, with the writer installed as its executor.
    """
    sql_writer = _SQLWriter(output_file)
    mock_engine = create_mock_engine(engine_url, executor=sql_writer.write, paramstyle="pyformat")
    # The dialect is only known once the mock engine exists, so it is
    # attached to the writer after construction.
    sql_writer.dialect = mock_engine.dialect
    return mock_engine
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def _dialect_name(url: URL) -> str:
    """Return the dialect part of the URL's driver name.

    Any "+driver" suffix is dropped, e.g., "postgresql+psycopg2" becomes
    "postgresql".

    Parameters
    ----------
    url
        The SQLAlchemy engine URL.

    Returns
    -------
    str
        The driver name up to (not including) the first "+".
    """
    base_name, _, _ = url.drivername.partition("+")
    return base_name
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def _clear_schema(metadata: MetaData) -> None:
    """Remove the schema name from the metadata and from all of its tables.

    Nothing is changed when the metadata has no (or an empty) schema name.

    Parameters
    ----------
    metadata
        The SQLAlchemy metadata to modify in place.
    """
    if not metadata.schema:
        return
    metadata.schema = None
    for tbl in metadata.tables.values():
        tbl.schema = None
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def _get_existing_indexes(inspector: Any, table_name: str, schema: str | None) -> set[str]:
    """Collect the names of the indexes reported for a table.

    Parameters
    ----------
    inspector
        Object providing ``get_indexes(table_name, schema=...)``.
    table_name
        The table whose indexes are looked up.
    schema
        The schema passed through to the inspector (may be None).

    Returns
    -------
    set[str]
        The index names; entries without a name, or whose name is None, are
        left out.
    """
    names: set[str] = set()
    for index_info in inspector.get_indexes(table_name, schema=schema):
        index_name = index_info.get("name")
        if index_name is not None:
            names.add(index_name)
    return names
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def is_mock_url(url: URL) -> bool:
    """Check if the engine URL points to a mock connection.

    A SQLite URL is treated as a mock URL when it has no database; any other
    URL is treated as a mock URL when it has no host.

    Parameters
    ----------
    url
        The SQLAlchemy engine URL.

    Returns
    -------
    bool
        True if the URL is a mock URL, False otherwise.
    """
    # Compare the base dialect rather than the full driver name so that
    # driver-qualified SQLite URLs (e.g., "sqlite+pysqlite:///file.db") are
    # judged by their database and not misread as host-less mock URLs.
    base_dialect = url.drivername.split("+")[0]
    if base_dialect == "sqlite":
        return url.database is None
    return url.host is None
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def is_sqlite_url(url: URL | str) -> bool:
    """Tell whether an engine URL refers to SQLite.

    Parameters
    ----------
    url
        The SQLAlchemy engine URL, either as a URL object or as a string
        that is parsed with ``make_url``.

    Returns
    -------
    bool
        True if the driver name starts with "sqlite", False otherwise.
    """
    parsed = make_url(url) if isinstance(url, str) else url
    return parsed.drivername.startswith("sqlite")
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
class DatabaseContextError(Exception):
    """Error raised when a `DatabaseContext` operation fails."""
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
class DatabaseContext(AbstractContextManager):
    """Abstract interface for performing database operations in a way that
    is independent of the SQL dialect.

    Instances are context managers: leaving the ``with`` block calls
    `close`.
    """

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]:
        """Clean up resources when leaving the ``with`` block.

        An exception raised by `close` is logged and not propagated.

        Returns
        -------
        bool
            Always False, so an exception raised inside the ``with`` block
            is never suppressed.
        """
        try:
            self.close()
        except Exception:
            # Best-effort cleanup: report the failure without masking any
            # exception already propagating out of the ``with`` block.
            logger.exception("Error during cleanup of database context")
        return False

    @abstractmethod
    def close(self) -> None:
        """Release the database resources held by the context."""

    @property
    @abstractmethod
    def metadata(self) -> MetaData:
        """The SQLAlchemy metadata describing the database objects managed
        by the context (`~sqlalchemy.sql.schema.MetaData`).
        """

    @property
    @abstractmethod
    def engine(self) -> Engine:
        """The SQLAlchemy engine used by the context
        (`~sqlalchemy.engine.Engine`).
        """

    @property
    @abstractmethod
    def dialect(self) -> Dialect:
        """The SQLAlchemy dialect used by the context
        (`~sqlalchemy.engine.Dialect`).
        """

    @property
    @abstractmethod
    def dialect_name(self) -> str:
        """The name of the dialect used by the context (``str``)."""

    @abstractmethod
    def initialize(self) -> None:
        """Create the target schema in the database unless it is already
        there.

        Implementations are expected to be idempotent: repeated calls must
        be harmless, and an existing schema must not cause an error
        (logging a warning in that case is acceptable).

        Raises
        ------
        DatabaseContextError
            If there is an error instantiating the schema.
        """

    @abstractmethod
    def drop(self) -> None:
        """Remove the schema from the database if it is present.

        Implementations are expected to use ``IF EXISTS`` semantics so that
        a missing schema does not cause an error.

        Raises
        ------
        DatabaseContextError
            If there is an error dropping the schema.
        """

    @abstractmethod
    def create_all(self) -> None:
        """Create every database object described by the metadata.

        Raises
        ------
        DatabaseContextError
            If there is an error creating the schema objects in the database.
        """

    @abstractmethod
    def create_indexes(self) -> None:
        """Create the indexes described by the metadata.

        Raises
        ------
        DatabaseContextError
            If there is an error creating the indexes in the database.
        """

    @abstractmethod
    def drop_indexes(self) -> None:
        """Drop the indexes described by the metadata.

        Raises
        ------
        DatabaseContextError
            If there is an error dropping the indexes in the database.
        """

    @abstractmethod
    def execute(self, statement: SQLStatement, parameters: dict[str, Any] | None = None) -> Result:
        """Run a SQL statement and hand back its result.

        Parameters
        ----------
        statement
            The SQL statement to execute.
        parameters
            Optional bind parameters for the statement.

        Returns
        -------
        `~sqlalchemy.engine.Result`
            The result of the statement execution.

        Raises
        ------
        DatabaseContextError
            If there is an error executing the SQL statement.
        """
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
class _BaseContext(DatabaseContext):
    """Base database context providing common behavior.

    Parameters
    ----------
    engine_url
        The SQLAlchemy URL from which the engine is created on first use.
    metadata
        The SQLAlchemy metadata representing the database objects.
    require_schema
        True if a valid schema name is required on the MetaData, False if not.

    Raises
    ------
    DatabaseContextError
        If the dialect of ``engine_url`` does not match ``DIALECT``, or if
        ``require_schema`` is True and the metadata has no schema name.
    """

    # Subclasses should set this to the dialect name.
    DIALECT: str

    # Supported index actions mapped to the verb form used in error messages.
    _INDEX_ACTIONS = {"create": "creating", "drop": "dropping"}

    def __init__(self, engine_url: URL, metadata: MetaData, require_schema: bool = False) -> None:
        self._engine_url = engine_url
        self._metadata = metadata
        self._schema_name: str | None = metadata.schema
        # The engine is created lazily by the ``engine`` property.
        self._engine: Engine | None = None
        self._echo: bool = False

        # Check that the URL dialect matches this context's expected dialect
        self._validate_dialect(engine_url)

        # Ensure the schema name is set for dialects that require it
        if require_schema and self._schema_name is None:
            raise DatabaseContextError(f"Schema name must be set for context: {self.dialect_name}")

    @property
    def echo(self) -> bool:
        """Whether to log all SQL statements executed by the engine
        (``bool``).
        """
        return self._echo

    @echo.setter
    def echo(self, value: bool) -> None:
        self._echo = value
        # Look at the private attribute: the ``engine`` property never
        # returns None and would create an engine as a side effect. An engine
        # created later picks the flag up in the ``engine`` property.
        if self._engine is not None:
            self._engine.echo = value

    @classmethod
    def _validate_dialect(cls, engine_url: URL) -> None:
        """Validate that the engine dialect matches this context's expected
        dialect.

        Parameters
        ----------
        engine_url
            The SQLAlchemy database URL to validate.

        Raises
        ------
        DatabaseContextError
            If the engine dialect doesn't match the context's expected dialect.
        """
        # Normalize both the engine dialect and expected dialect for comparison
        engine_dialect = _dialect_name(engine_url)
        expected_dialect = cls.DIALECT.lower()

        if engine_dialect != expected_dialect:
            raise DatabaseContextError(
                f"Engine dialect '{engine_dialect}' does not match the context's expected dialect: "
                f"{expected_dialect}"
            )

    @property
    def engine(self) -> Engine:
        """The SQLAlchemy engine for the context, created on first access
        (`~sqlalchemy.engine.Engine`).
        """
        if self._engine is None:
            # Pass the stored echo flag so it also applies to an engine that
            # is created (or re-created after ``close``) after it was set.
            self._engine = create_engine(self._engine_url, echo=self._echo)
        return self._engine

    @property
    def metadata(self) -> MetaData:
        """The SQLAlchemy metadata given at construction
        (`~sqlalchemy.sql.schema.MetaData`).
        """
        return self._metadata

    @property
    def dialect(self) -> Dialect:
        """The dialect of the context's engine
        (`~sqlalchemy.engine.Dialect`).
        """
        return self.engine.dialect

    @property
    def dialect_name(self) -> str:
        """Get the dialect name for this database context.

        Returns
        -------
        str
            The value of the class's ``DIALECT`` attribute.
        """
        return self.DIALECT

    @property
    def schema_name(self) -> str | None:
        """Effective schema name for this context (may be None).

        Returns
        -------
        str | None
            The schema name, or None if no schema is set.
        """
        return self._schema_name

    @contextmanager
    def connect(self) -> Iterator[Connection]:
        """Context manager for database connection."""
        with self.engine.connect() as connection:
            yield connection

    def execute(self, statement: SQLStatement, parameters: dict[str, Any] | None = None) -> Result:
        """Execute a SQL statement inside its own transaction.

        Parameters
        ----------
        statement
            The SQL statement to execute; strings are wrapped as text
            clauses first.
        parameters
            Optional bind parameters for the statement.

        Returns
        -------
        `~sqlalchemy.engine.Result`
            The result of the statement execution.

        Raises
        ------
        DatabaseContextError
            If SQLAlchemy reports an error while executing the statement.
        """
        statement = _normalize_statement(statement)
        try:
            with self.connect() as conn:
                with conn.begin():
                    if parameters:
                        result = conn.execute(statement, parameters)
                    else:
                        result = conn.execute(statement)
            # NOTE(review): the result is handed back after the connection
            # block has exited, and callers fetch rows from it afterwards;
            # confirm this is safe for every supported driver.
            return result
        except SQLAlchemyError as e:
            raise DatabaseContextError(f"Error executing statement: {e}") from e

    def create_all(self) -> None:
        """Create all objects described by the metadata in one transaction.

        Raises
        ------
        DatabaseContextError
            If there is an error creating the schema objects in the database.
        """
        with self.connect() as conn:
            with conn.begin():
                try:
                    self.metadata.create_all(bind=conn)
                except SQLAlchemyError as e:
                    raise DatabaseContextError(f"Error creating database: {e}") from e

    def _manage_indexes(self, action: str) -> None:
        """Manage indexes by creating or dropping them.

        Indexes without a name are skipped, as are indexes that already
        exist (when creating) or do not exist (when dropping).

        Parameters
        ----------
        action
            The action to perform, either "create" or "drop".

        Raises
        ------
        ValueError
            If ``action`` is neither "create" nor "drop".
        DatabaseContextError
            If there is an error managing the indexes in the database.
        """
        # Reject an unknown action up front, before touching the database,
        # instead of only when the first named index is reached.
        if action not in self._INDEX_ACTIONS:
            raise ValueError(f"Invalid action '{action}'. Must be 'create' or 'drop'.")
        creating = action == "create"

        with self.connect() as conn:
            with conn.begin():
                try:
                    inspector = inspect(conn)
                    for table in self.metadata.tables.values():
                        # Fetch all existing indexes for this table once
                        existing_indexes = _get_existing_indexes(inspector, table.name, self.schema_name)

                        for index in table.indexes:
                            if index.name is None:
                                # Anonymous indexes can't be checked by name
                                logger.warning(f"Skipping anonymous index on table '{table.name}'")
                                continue

                            exists = index.name in existing_indexes
                            if creating:
                                if exists:
                                    logger.warning(
                                        f"Skipping creation of index '{index.name}' which already exists"
                                    )
                                    continue
                                index.create(bind=conn, checkfirst=False)  # We already checked
                                logger.info(f"Created index '{index.name}'")
                            else:
                                if not exists:
                                    logger.warning(f"Skipping index '{index.name}' which does not exist")
                                    continue
                                index.drop(bind=conn, checkfirst=False)  # We already checked
                                logger.info(f"Dropped index '{index.name}'")
                except SQLAlchemyError as e:
                    # Use the proper verb form ("creating"/"dropping") rather
                    # than appending "ing" to the action name.
                    raise DatabaseContextError(f"Error {self._INDEX_ACTIONS[action]} indexes: {e}") from e

    def create_indexes(self) -> None:
        """Create all indexes in the schema using the metadata object.

        Raises
        ------
        DatabaseContextError
            If there is an error creating the indexes in the database.
        """
        self._manage_indexes("create")

    def drop_indexes(self) -> None:
        """Drop all indexes in the schema using the metadata object.

        Raises
        ------
        DatabaseContextError
            If there is an error dropping the indexes in the database.
        """
        self._manage_indexes("drop")

    def _required_schema_name(self) -> str:
        """Return the schema name, ensuring that it is set.

        This is mainly here for typing purposes, because the schema_name
        property may be None, and mypy doesn't understand that we already
        checked it during initialization.

        Raises
        ------
        DatabaseContextError
            If the schema name is None.
        """
        if self.schema_name is None:
            raise DatabaseContextError("Schema name is required but not set.")
        return self.schema_name

    def close(self) -> None:
        """Close and dispose of the database engine."""
        if self._engine is not None:
            self._engine.dispose()
            self._engine = None
|
|
497
|
+
|
|
498
|
+
|
|
499
|
+
_ContextClass: TypeAlias = type[_BaseContext]
|
|
500
|
+
_ContextDecorator: TypeAlias = Callable[[_ContextClass], _ContextClass]
|
|
501
|
+
|
|
502
|
+
|
|
503
|
+
class DatabaseContextFactory:
    """Factory for creating DatabaseContext instances based on dialect type.

    Context classes are kept in a class-level registry keyed by normalized
    dialect name (lower case, without any "+driver" suffix).
    """

    _registry: dict[str, _ContextClass] = {}

    @staticmethod
    def _normalize_dialect(dialect: str) -> str:
        """Lower-case a dialect name and drop any "+driver" suffix."""
        return dialect.lower().split("+")[0]

    @classmethod
    def register(cls) -> _ContextDecorator:
        """Register a context class for its dialect.

        The dialect is determined by reading the DIALECT attribute from the
        decorated class. It is normalized in the same way as the names
        passed to `register_class` and `create_context`, so that the class
        can always be found again.

        Returns
        -------
        Callable
            The decorator function that registers the context class. The
            decorator raises `ValueError` if the class has no DIALECT
            attribute.

        Examples
        --------
        >>> @DatabaseContextFactory.register()
        ... class PostgreSQLContext(_BaseContext):
        ...     DIALECT = "postgresql"
        ...     pass

        Notes
        -----
        The registry is populated at module import time and afterwards should
        be treated as read-only.
        """

        def decorator(context_class: type[_BaseContext]) -> type[_BaseContext]:
            # Get the dialect from the class's DIALECT attribute
            if not hasattr(context_class, "DIALECT"):
                raise ValueError(f"Context class {context_class.__name__} must define a DIALECT attribute")
            # Store under the normalized name; lookups normalize too, so a
            # verbatim key with upper case or a driver suffix would be
            # unreachable.
            cls._registry[cls._normalize_dialect(context_class.DIALECT)] = context_class
            return context_class

        return decorator

    @classmethod
    def register_class(cls, dialect: str, context_class: type[_BaseContext]) -> None:
        """Register a context class for a specific dialect programmatically.

        Parameters
        ----------
        dialect
            The dialect name to register; it is normalized before use.
        context_class
            The context class to use for this dialect.
        """
        cls._registry[cls._normalize_dialect(dialect)] = context_class

    @classmethod
    def create_context(cls, dialect: str, engine_url: URL, metadata: MetaData) -> DatabaseContext:
        """Create a context instance for the given dialect.

        Parameters
        ----------
        dialect
            The database dialect name; it is normalized before the lookup.
        engine_url
            The SQLAlchemy database URL.
        metadata
            The SQLAlchemy metadata.

        Returns
        -------
        DatabaseContext
            The appropriate context instance.

        Raises
        ------
        ValueError
            If no context class is registered for the dialect.
        """
        dialect_name = cls._normalize_dialect(dialect)

        if dialect_name not in cls._registry:
            supported = cls.get_supported_dialects()
            raise ValueError(
                f"No context class registered for dialect: {dialect_name}. "
                f"Supported dialects: {', '.join(supported)}"
            )

        context_class = cls._registry[dialect_name]
        return context_class(engine_url, metadata)

    @classmethod
    def get_supported_dialects(cls) -> list[str]:
        """Get a list of supported dialect names.

        Returns
        -------
        list[str]
            List of supported dialect names.
        """
        return list(cls._registry)
|
|
605
|
+
|
|
606
|
+
|
|
607
|
+
class _SQLWriter:
    """Write SQL statements to stdout or a file.

    Parameters
    ----------
    file
        The file to write the SQL statements to. If None, the statements
        will be written to stdout.
    """

    def __init__(self, file: IO[str] | None = None) -> None:
        """Initialize the SQL writer."""
        self.file = file
        # Set by the caller once the mock engine (and its dialect) exists.
        self.dialect: Dialect | None = None

    @staticmethod
    def _render_literal(value: Any) -> Any:
        """Render a bound parameter value for insertion into SQL text.

        Strings are wrapped in single quotes with embedded single quotes
        doubled, None becomes ``null``, and other values are returned
        unchanged.
        """
        if isinstance(value, str):
            # Double embedded quotes so values such as "O'Brien" do not
            # terminate the literal early and produce invalid SQL.
            return "'" + value.replace("'", "''") + "'"
        if value is None:
            return "null"
        return value

    def write(self, sql: Any, *multiparams: Any, **params: Any) -> None:
        """Write the SQL statement to a file or stdout.

        Statements with parameters will be formatted with the values
        inserted into the resultant SQL output.

        Parameters
        ----------
        sql
            The SQL statement to write.
        *multiparams
            The multiparams to use for the SQL statement.
        **params
            The params to use for the SQL statement.

        Notes
        -----
        The function's arguments are typed very loosely because this method
        in SQLAlchemy is untyped, and we do not call it directly. Only the
        parameters of the compiled statement are used; ``multiparams`` and
        ``params`` are accepted but ignored.
        """
        compiled = sql.compile(dialect=self.dialect)
        sql_str = str(compiled) + ";"
        bound_params = compiled.params
        if not bound_params:
            print(sql_str, file=self.file)
            return
        rendered = {key: self._render_literal(value) for key, value in bound_params.items()}
        # The statement is interpolated with the "%" operator, so it is
        # expected to use pyformat-style ``%(name)s`` placeholders.
        print(sql_str % rendered, file=self.file)
|
|
658
|
+
|
|
659
|
+
|
|
660
|
+
@DatabaseContextFactory.register()
class PostgreSQLContext(_BaseContext):
    """Database context for PostgreSQL, which requires a schema name on the
    metadata.

    Parameters
    ----------
    engine_url
        The SQLAlchemy database URL for connecting to the database.
    metadata
        The SQLAlchemy metadata representing the database objects.
    """

    DIALECT = "postgresql"

    def __init__(self, engine_url: URL, metadata: MetaData):
        super().__init__(engine_url, metadata, require_schema=True)

    def initialize(self) -> None:
        """Create the schema unless ``information_schema.schemata`` already
        lists it.

        Raises
        ------
        DatabaseContextError
            If there is an error checking for or creating the schema.
        """
        target_schema = self._required_schema_name()
        lookup_query = (
            "\n"
            "SELECT schema_name\n"
            "FROM information_schema.schemata\n"
            "WHERE schema_name = :schema_name\n"
        )
        try:
            logger.debug(f"Checking if PG schema exists: {target_schema}")
            lookup = self.execute(lookup_query, {"schema_name": target_schema})
            already_there = bool(lookup.fetchone())
            if not already_there:
                logger.debug(f"Creating PG schema: {target_schema}")
                self.execute(CreateSchema(target_schema))
        except SQLAlchemyError as err:
            raise DatabaseContextError(f"Error initializing Postgres schema: {err}") from err

    def drop(self) -> None:
        """Drop the schema, with ``IF EXISTS`` and ``CASCADE``.

        Raises
        ------
        DatabaseContextError
            If there is an error dropping the schema.
        """
        target_schema = self._required_schema_name()
        try:
            logger.debug(f"Dropping PostgreSQL schema if exists: {target_schema}")
            drop_statement = DropSchema(target_schema, if_exists=True, cascade=True)
            self.execute(drop_statement)
        except SQLAlchemyError as err:
            raise DatabaseContextError(f"Error dropping Postgres database: {err}") from err
|
|
703
|
+
|
|
704
|
+
|
|
705
|
+
@DatabaseContextFactory.register()
class MySQLContext(_BaseContext):
    """Database context for MySQL.

    The schema is instantiated as a database, as MySQL does not have a
    distinct schema concept, unlike Postgres.

    Parameters
    ----------
    engine_url
        The SQLAlchemy database URL for connecting to the database.
    metadata
        The SQLAlchemy metadata representing the database objects.
    """

    DIALECT = "mysql"

    def __init__(self, engine_url: URL, metadata: MetaData):
        super().__init__(engine_url, metadata, require_schema=True)

    def _quoted_schema_name(self) -> str:
        """Return the schema name rendered as a quoted identifier.

        A ``quoted_name`` is a plain string when interpolated into an
        f-string, so the quoting has to be applied explicitly through the
        dialect's identifier preparer.
        """
        forced = quoted_name(self._required_schema_name(), quote=True)
        return self.dialect.identifier_preparer.quote(forced)

    def initialize(self) -> None:
        """Create the database unless ``SHOW DATABASES`` already lists it.

        Raises
        ------
        DatabaseContextError
            If there is an error checking for or creating the database.
        """
        from sqlalchemy import DDL

        schema_name = self._required_schema_name()
        try:
            logger.debug(f"Checking if MySQL database exists: {schema_name}")
            # NOTE(review): LIKE treats "_" and "%" in the name as wildcards,
            # so a differently named database may match; confirm acceptable.
            result = self.execute("SHOW DATABASES LIKE :schema_name", {"schema_name": schema_name})
            if result.fetchone():
                return
            logger.debug(f"Creating MySQL database: {schema_name}")
            create_stmt = DDL(f"CREATE DATABASE {self._quoted_schema_name()}")
            self.execute(create_stmt)
        except SQLAlchemyError as e:
            raise DatabaseContextError(f"Error initializing MySQL database: {e}") from e

    def drop(self) -> None:
        """Drop the database, with ``IF EXISTS``.

        Raises
        ------
        DatabaseContextError
            If there is an error dropping the database.
        """
        from sqlalchemy import DDL

        schema_name = self._required_schema_name()
        try:
            logger.debug(f"Dropping MySQL database if exists: {schema_name}")
            drop_stmt = DDL(f"DROP DATABASE IF EXISTS {self._quoted_schema_name()}")
            self.execute(drop_stmt)
        except SQLAlchemyError as e:
            raise DatabaseContextError(f"Error dropping MySQL database: {e}") from e
|
|
749
|
+
|
|
750
|
+
|
|
751
|
+
@DatabaseContextFactory.register()
class SQLiteContext(_BaseContext):
    """Database context for SQLite, which works without a schema name.

    Any schema name on the metadata and its tables is cleared when the
    context is constructed.

    Parameters
    ----------
    engine_url
        The SQLAlchemy database URL for connecting to the database.
    metadata
        The SQLAlchemy metadata representing the database objects.
    """

    DIALECT = "sqlite"

    def __init__(self, engine_url: URL, metadata: MetaData):
        # Strip the schema first so the base class records no schema name;
        # a schema name is not required for this dialect.
        _clear_schema(metadata)
        super().__init__(engine_url, metadata)

    def initialize(self) -> None:
        """Do nothing; SQLite needs no initialization."""
        return None

    def drop(self) -> None:
        """Drop all tables described by the metadata from the database.

        Raises
        ------
        DatabaseContextError
            If there is an error dropping the tables.
        """
        try:
            logger.debug("Dropping tables in SQLite schema")
            metadata = self.metadata
            metadata.drop_all(bind=self.engine)
        except SQLAlchemyError as err:
            raise DatabaseContextError(f"Error dropping SQLite database: {err}") from err
|
|
782
|
+
|
|
783
|
+
|
|
784
|
+
class MockContext(DatabaseContext):
    """Database context backed by a mock connection.

    Schema initialization, dropping, and index management are no-ops;
    `create_all` and `execute` are forwarded to the mock connection.

    Parameters
    ----------
    metadata
        The SQLAlchemy metadata defining the database objects.
    connection
        The SQLAlchemy mock connection.
    """

    def __init__(self, metadata: MetaData, connection: MockConnection):
        self._connection = connection
        self._dialect = connection.dialect
        self._metadata = metadata

    @property
    def dialect(self) -> Dialect:
        """The dialect of the mock connection
        (`~sqlalchemy.engine.Dialect`).
        """
        return self._dialect

    @property
    def dialect_name(self) -> str:
        """The name of the mock connection's dialect (``str``)."""
        return self.dialect.name

    @property
    def metadata(self) -> MetaData:
        """The SQLAlchemy metadata given at construction
        (`~sqlalchemy.sql.schema.MetaData`).
        """
        return self._metadata

    @property
    def engine(self) -> Engine:
        """Not available; accessing it always raises
        `DatabaseContextError`.
        """
        raise DatabaseContextError("MockContext does not provide an engine.")

    def initialize(self) -> None:
        """Do nothing; a mock connection needs no initialization."""

    def drop(self) -> None:
        """Do nothing; a mock connection has nothing to drop."""

    def create_all(self) -> None:
        """Run ``create_all`` on the metadata against the mock connection."""
        self._metadata.create_all(self._connection)

    def create_indexes(self) -> None:
        """Do nothing; indexes cannot be created on a mock connection."""

    def drop_indexes(self) -> None:
        """Do nothing; indexes cannot be dropped on a mock connection."""

    def execute(self, statement: SQLStatement, parameters: dict[str, Any] | None = None) -> Result:
        """Pass a statement to the mock connection.

        Parameters
        ----------
        statement
            The SQL statement; strings are wrapped as text clauses first.
        parameters
            Optional bind parameters for the statement.

        Returns
        -------
        `~sqlalchemy.engine.Result`
            Whatever the mock connection's ``execute`` returns.
        """
        prepared = _normalize_statement(statement)
        mock_conn = self._connection.connect()
        if parameters:
            return mock_conn.execute(prepared, parameters)
        return mock_conn.execute(prepared)

    def close(self) -> None:
        """Close the mock connection (no-op)."""
|
|
845
|
+
|
|
846
|
+
|
|
847
|
+
def create_database_context(
    engine_url: str | URL,
    metadata: MetaData,
    output_file: IO[str] | None = None,
    dry_run: bool = False,
    echo: bool | None = None,
) -> DatabaseContext:
    """Build the `DatabaseContext` matching an engine URL.

    Parameters
    ----------
    engine_url
        The database URL for the database connection, as a URL object or a
        string.
    metadata
        The SQLAlchemy MetaData representing the database objects.
    output_file
        Output file for writing generated SQL commands; only used for the
        mock context.
    dry_run
        If True, configure the context to perform a dry run, where operations
        will not be executed.
        If False, use a normal context where operations are executed.
    echo
        If True, the SQLAlchemy engine will log all statements to the console.
        Only applied to non-mock contexts, and only when not None.

    Returns
    -------
    DatabaseContext
        A `MockContext` if the URL looks like a mock URL or if ``dry_run``
        is True; otherwise the context that `DatabaseContextFactory` creates
        for the URL's dialect.

    Raises
    ------
    DatabaseContextError
        If the dialect is not supported or if there's an issue creating
        the context.
    """
    url = make_url(engine_url) if isinstance(engine_url, str) else engine_url

    # Mock URLs and dry runs never touch a real database.
    if is_mock_url(url) or dry_run:
        if _dialect_name(url) == "sqlite":
            _clear_schema(metadata)
        return MockContext(metadata, _create_mock_connection(url, output_file))

    # Otherwise create a real context through the factory.
    try:
        dialect_name = _dialect_name(url)
        try:
            context = DatabaseContextFactory.create_context(dialect_name, url, metadata)
            # The echo flag is settable for real contexts only.
            if echo is not None and hasattr(context, "echo"):
                context.echo = echo
            return context
        except ValueError as err:
            supported = DatabaseContextFactory.get_supported_dialects()
            raise DatabaseContextError(
                f"Unsupported dialect: {dialect_name}. Supported dialects are: {', '.join(supported)}"
            ) from err
    except DatabaseContextError:
        # Already the documented error type; let it through untouched.
        raise
    except Exception as err:
        raise DatabaseContextError(f"Failed to create database context: {err}") from err
|