lsst-felis 29.2025.4500__py3-none-any.whl → 30.0.0rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,917 @@
1
+ """API for managing database operations across different dialects."""
2
+
3
+ # This file is part of felis.
4
+ #
5
+ # Developed for the LSST Data Management System.
6
+ # This product includes software developed by the LSST Project
7
+ # (https://www.lsst.org).
8
+ # See the COPYRIGHT file at the top-level directory of this distribution
9
+ # for details of code ownership.
10
+ #
11
+ # This program is free software: you can redistribute it and/or modify
12
+ # it under the terms of the GNU General Public License as published by
13
+ # the Free Software Foundation, either version 3 of the License, or
14
+ # (at your option) any later version.
15
+ #
16
+ # This program is distributed in the hope that it will be useful,
17
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
18
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
19
+ # GNU General Public License for more details.
20
+ #
21
+ # You should have received a copy of the GNU General Public License
22
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
23
+
24
+ from __future__ import annotations
25
+
26
+ import logging
27
+ from abc import abstractmethod
28
+ from collections.abc import Callable, Iterator
29
+ from contextlib import AbstractContextManager, contextmanager
30
+ from typing import IO, Any, Literal, TypeAlias
31
+
32
+ from sqlalchemy import (
33
+ Engine,
34
+ MetaData,
35
+ create_engine,
36
+ inspect,
37
+ make_url,
38
+ quoted_name,
39
+ )
40
+ from sqlalchemy.engine import (
41
+ Connection,
42
+ Dialect,
43
+ Result,
44
+ )
45
+ from sqlalchemy.engine.mock import MockConnection, create_mock_engine
46
+ from sqlalchemy.engine.url import URL
47
+ from sqlalchemy.exc import SQLAlchemyError
48
+ from sqlalchemy.schema import (
49
+ CreateSchema,
50
+ DropSchema,
51
+ )
52
+ from sqlalchemy.sql import (
53
+ Executable,
54
+ text,
55
+ )
56
+ from sqlalchemy.sql.elements import TextClause
57
+
58
# Names exported by a star-import of this module. The helper functions
# ``is_mock_url`` and ``is_sqlite_url`` are defined below but not listed here.
__all__ = [
    "DatabaseContext",
    "DatabaseContextError",
    "MockContext",
    "MySQLContext",
    "PostgreSQLContext",
    "SQLiteContext",
    "create_database_context",
]

# Logger registered under the fixed name "felis" (not ``__name__``).
logger = logging.getLogger("felis")

# Anything accepted as a statement by the ``execute`` methods in this module;
# plain strings are wrapped by ``_normalize_statement`` before execution.
SQLStatement = str | Executable | TextClause
71
+
72
+
73
+ def _normalize_statement(statement: SQLStatement) -> Executable | TextClause:
74
+ if isinstance(statement, str):
75
+ return text(statement)
76
+ return statement
77
+
78
+
79
def _create_mock_connection(engine_url: str | URL, output_file: IO[str] | None = None) -> MockConnection:
    """Build a mock engine that writes SQL instead of executing it.

    Parameters
    ----------
    engine_url
        The SQLAlchemy engine URL (only used to pick the dialect).
    output_file
        Destination passed to the SQL writer; None means stdout.

    Returns
    -------
    `~sqlalchemy.engine.mock.MockConnection`
        The mock engine whose executor is the SQL writer.
    """
    sql_writer = _SQLWriter(output_file)
    mock_engine = create_mock_engine(engine_url, executor=sql_writer.write, paramstyle="pyformat")
    # The writer compiles each statement itself, so it needs to know the
    # dialect, which only exists once the mock engine has been created.
    sql_writer.dialect = mock_engine.dialect
    return mock_engine
84
+
85
+
86
+ def _dialect_name(url: URL) -> str:
87
+ dialect_name = url.drivername
88
+ # Normalize dialect name (e.g., "postgresql+psycopg2" -> "postgresql")
89
+ if "+" in dialect_name:
90
+ dialect_name = dialect_name.split("+")[0]
91
+ return dialect_name
92
+
93
+
94
+ def _clear_schema(metadata: MetaData) -> None:
95
+ if metadata.schema:
96
+ metadata.schema = None
97
+ for table in metadata.tables.values():
98
+ table.schema = None
99
+
100
+
101
+ def _get_existing_indexes(inspector: Any, table_name: str, schema: str | None) -> set[str]:
102
+ return {
103
+ ix["name"]
104
+ for ix in inspector.get_indexes(table_name, schema=schema)
105
+ if "name" in ix and ix["name"] is not None
106
+ }
107
+
108
+
109
def is_mock_url(url: URL) -> bool:
    """Check if the engine URL points to a mock connection.

    A SQLite URL is considered mock when it names no database; a URL for
    any other dialect is considered mock when it names no host.

    Parameters
    ----------
    url
        The SQLAlchemy engine URL.

    Returns
    -------
    bool
        True if the URL is a mock URL, False otherwise.
    """
    # Compare on the base dialect so that driver-qualified names such as
    # "sqlite+pysqlite" are treated as SQLite. Comparing the raw driver name
    # would classify a file-backed "sqlite+pysqlite:///x.db" URL (which never
    # has a host) as a mock URL.
    base_dialect = url.drivername.split("+")[0]
    if base_dialect == "sqlite":
        return url.database is None
    return url.host is None
125
+
126
+
127
def is_sqlite_url(url: URL | str) -> bool:
    """Tell whether an engine URL refers to a SQLite database.

    Parameters
    ----------
    url
        The SQLAlchemy engine URL, either parsed or as a string.

    Returns
    -------
    bool
        True if the driver name starts with ``sqlite``, False otherwise.
    """
    parsed_url = make_url(url) if isinstance(url, str) else url
    return parsed_url.drivername.startswith("sqlite")
143
+
144
+
145
class DatabaseContextError(Exception):
    """Exception raised for errors in the DatabaseContext operations.

    Within this module it is raised for configuration problems (dialect
    mismatch, missing schema name, unsupported dialect) and to wrap
    SQLAlchemy errors raised while executing statements.
    """
147
+
148
+
149
class DatabaseContext(AbstractContextManager):
    """Abstract interface for database operations across different SQL
    dialects.

    Instances are context managers: leaving the ``with`` block calls
    `close`, and exceptions raised inside the block are never suppressed.
    """

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]:
        """Release resources on context exit without suppressing errors."""
        try:
            self.close()
        except Exception:
            # Cleanup is best-effort: a failure here is logged, not raised.
            logger.exception("Error during cleanup of database context")
        return False

    @abstractmethod
    def close(self) -> None:
        """Close the context and release its database resources."""

    @property
    @abstractmethod
    def metadata(self) -> MetaData:
        """The SQLAlchemy metadata describing the database for this context
        (`~sqlalchemy.sql.schema.MetaData`).
        """

    @property
    @abstractmethod
    def engine(self) -> Engine:
        """The SQLAlchemy engine used by this context
        (`~sqlalchemy.engine.Engine`).
        """

    @property
    @abstractmethod
    def dialect(self) -> Dialect:
        """The SQLAlchemy dialect used by this context
        (`~sqlalchemy.engine.Dialect`).
        """

    @property
    @abstractmethod
    def dialect_name(self) -> str:
        """The name of the dialect used by this context (``str``)."""

    @abstractmethod
    def initialize(self) -> None:
        """Create the target schema in the database unless it already
        exists.

        Implementations should be idempotent: calling this method more than
        once must have no adverse effect, and an already existing schema
        must not raise an error (logging a warning is acceptable).

        Raises
        ------
        DatabaseContextError
            If there is an error instantiating the schema.
        """

    @abstractmethod
    def drop(self) -> None:
        """Drop the schema from the database if it exists.

        Implementations should use ``IF EXISTS`` semantics so that a missing
        schema does not raise an error.

        Raises
        ------
        DatabaseContextError
            If there is an error dropping the schema.
        """

    @abstractmethod
    def create_all(self) -> None:
        """Create every database object described by the metadata.

        Raises
        ------
        DatabaseContextError
            If there is an error creating the schema objects in the database.
        """

    @abstractmethod
    def create_indexes(self) -> None:
        """Create every index described by the metadata.

        Raises
        ------
        DatabaseContextError
            If there is an error creating the indexes in the database.
        """

    @abstractmethod
    def drop_indexes(self) -> None:
        """Drop every index described by the metadata.

        Raises
        ------
        DatabaseContextError
            If there is an error dropping the indexes in the database.
        """

    @abstractmethod
    def execute(self, statement: SQLStatement, parameters: dict[str, Any] | None = None) -> Result:
        """Execute a SQL statement and return the result.

        Parameters
        ----------
        statement
            The SQL statement to execute.
        parameters
            Optional bind parameters for the statement.

        Returns
        -------
        `~sqlalchemy.engine.Result`
            The result of the statement execution.

        Raises
        ------
        DatabaseContextError
            If there is an error executing the SQL statement.
        """
282
+
283
+
284
class _BaseContext(DatabaseContext):
    """Base database context providing common behavior.

    The engine is created lazily on first access of `engine` and disposed of
    by `close`.

    Parameters
    ----------
    engine_url
        The SQLAlchemy database URL for connecting to the database.
    metadata
        The SQLAlchemy metadata representing the database objects.
    require_schema
        True if a valid schema name is required on the MetaData, False if not.

    Raises
    ------
    DatabaseContextError
        If the URL's dialect does not match ``DIALECT`` or if a schema name
        is required but not set on the metadata.
    """

    # Subclasses should set this to the dialect name.
    DIALECT: str

    def __init__(self, engine_url: URL, metadata: MetaData, require_schema: bool = False) -> None:
        self._engine_url = engine_url
        self._metadata = metadata
        self._schema_name: str | None = metadata.schema
        self._engine: Engine | None = None
        self._echo: bool = False

        # Check that the URL dialect matches this context's expected dialect
        self._validate_dialect(engine_url)

        # Ensure the schema name is set for dialects that require it
        if require_schema and self._schema_name is None:
            raise DatabaseContextError(f"Schema name must be set for context: {self.dialect_name}")

    @property
    def echo(self) -> bool:
        """Whether to log all SQL statements executed by the engine
        (``bool``).
        """
        return self._echo

    @echo.setter
    def echo(self, value: bool) -> None:
        self._echo = value
        # Only touch an engine that already exists. Going through the lazy
        # ``engine`` property here would force engine creation as a side
        # effect of setting a flag; a not-yet-created engine picks the flag
        # up when it is created.
        if self._engine is not None:
            self._engine.echo = value

    @classmethod
    def _validate_dialect(cls, engine_url: URL) -> None:
        """Validate that the engine dialect matches this context's expected
        dialect.

        Parameters
        ----------
        engine_url
            The SQLAlchemy database URL to validate.

        Raises
        ------
        DatabaseContextError
            If the engine dialect doesn't match the context's expected dialect.
        """
        # Normalize both the engine dialect and expected dialect for comparison
        engine_dialect = _dialect_name(engine_url)
        expected_dialect = cls.DIALECT.lower()

        if engine_dialect != expected_dialect:
            raise DatabaseContextError(
                f"Engine dialect '{engine_dialect}' does not match the context's expected dialect: "
                f"{expected_dialect}"
            )

    @property
    def engine(self) -> Engine:
        """The SQLAlchemy engine, created on first access
        (`~sqlalchemy.engine.Engine`).
        """
        if self._engine is None:
            # Pass the echo flag so it also applies to an engine that is
            # (re)created after the flag was set or after ``close``.
            self._engine = create_engine(self._engine_url, echo=self._echo)
        return self._engine

    @property
    def metadata(self) -> MetaData:
        """The SQLAlchemy metadata passed to the constructor
        (`~sqlalchemy.sql.schema.MetaData`).
        """
        return self._metadata

    @property
    def dialect(self) -> Dialect:
        """The dialect of the engine (`~sqlalchemy.engine.Dialect`).

        Accessing this creates the engine if it does not exist yet.
        """
        return self.engine.dialect

    @property
    def dialect_name(self) -> str:
        """Get the dialect name for this database context.

        Returns
        -------
        str
            The value of the class's ``DIALECT`` attribute.
        """
        return self.DIALECT

    @property
    def schema_name(self) -> str | None:
        """Effective schema name for this context (may be None).

        Returns
        -------
        str | None
            The schema name, or None if no schema is set.
        """
        return self._schema_name

    @contextmanager
    def connect(self) -> Iterator[Connection]:
        """Context manager for database connection."""
        with self.engine.connect() as connection:
            yield connection

    def execute(self, statement: SQLStatement, parameters: dict[str, Any] | None = None) -> Result:
        """Execute a statement in its own connection and transaction.

        Parameters
        ----------
        statement
            The SQL statement to execute; strings are wrapped as text.
        parameters
            Optional bind parameters for the statement.

        Returns
        -------
        `~sqlalchemy.engine.Result`
            The result of the statement execution. The connection has
            already been released when the result reaches the caller.

        Raises
        ------
        DatabaseContextError
            If SQLAlchemy raises an error while executing the statement.
        """
        statement = _normalize_statement(statement)
        try:
            with self.connect() as conn:
                with conn.begin():
                    if parameters:
                        result = conn.execute(statement, parameters)
                    else:
                        result = conn.execute(statement)
                    return result
        except SQLAlchemyError as e:
            raise DatabaseContextError(f"Error executing statement: {e}") from e

    def create_all(self) -> None:
        """Create all database objects described by the metadata.

        Raises
        ------
        DatabaseContextError
            If SQLAlchemy raises an error while creating the objects.
        """
        with self.connect() as conn:
            with conn.begin():
                try:
                    self.metadata.create_all(bind=conn)
                except SQLAlchemyError as e:
                    raise DatabaseContextError(f"Error creating database: {e}") from e

    def _manage_indexes(self, action: str) -> None:
        """Manage indexes by creating or dropping them.

        Indexes without a name are skipped, as are indexes that already
        exist (when creating) or do not exist (when dropping).

        Parameters
        ----------
        action
            The action to perform, either "create" or "drop".

        Raises
        ------
        ValueError
            If ``action`` is neither "create" nor "drop".
        DatabaseContextError
            If there is an error managing the indexes in the database.
        """
        # Reject a bad action before touching the database, regardless of
        # whether the metadata happens to contain any named index.
        gerunds = {"create": "creating", "drop": "dropping"}
        if action not in gerunds:
            raise ValueError(f"Invalid action '{action}'. Must be 'create' or 'drop'.")
        creating = action == "create"

        with self.connect() as conn:
            with conn.begin():
                try:
                    inspector = inspect(conn)
                    for table in self.metadata.tables.values():
                        # Fetch all existing indexes for this table once
                        existing_indexes = _get_existing_indexes(inspector, table.name, self.schema_name)

                        for index in table.indexes:
                            if index.name is None:
                                # Anonymous indexes can't be checked by name
                                logger.warning(f"Skipping anonymous index on table '{table.name}'")
                                continue

                            if creating:
                                if index.name in existing_indexes:
                                    logger.warning(
                                        f"Skipping creation of index '{index.name}' which already exists"
                                    )
                                    continue
                                index.create(bind=conn, checkfirst=False)  # We already checked
                                logger.info(f"Created index '{index.name}'")
                            else:
                                if index.name not in existing_indexes:
                                    logger.warning(f"Skipping index '{index.name}' which does not exist")
                                    continue
                                index.drop(bind=conn, checkfirst=False)  # We already checked
                                logger.info(f"Dropped index '{index.name}'")
                except SQLAlchemyError as e:
                    raise DatabaseContextError(f"Error {gerunds[action]} indexes: {e}") from e

    def create_indexes(self) -> None:
        """Create all indexes in the schema using the metadata object.

        Raises
        ------
        DatabaseContextError
            If there is an error creating the indexes in the database.
        """
        self._manage_indexes("create")

    def drop_indexes(self) -> None:
        """Drop all indexes in the schema using the metadata object.

        Raises
        ------
        DatabaseContextError
            If there is an error dropping the indexes in the database.
        """
        self._manage_indexes("drop")

    def _required_schema_name(self) -> str:
        """Return the schema name, ensuring that it is set.

        This is mainly here for typing purposes, because the schema_name
        property may be None, and mypy doesn't understand that we already
        checked it during initialization.

        Raises
        ------
        DatabaseContextError
            If no schema name is set.
        """
        if self.schema_name is None:
            raise DatabaseContextError("Schema name is required but not set.")
        return self.schema_name

    def close(self) -> None:
        """Close and dispose of the database engine."""
        if self._engine is not None:
            self._engine.dispose()
            # Drop the reference so a later access creates a fresh engine.
            self._engine = None
497
+
498
+
499
# A context class (not an instance) derived from ``_BaseContext``.
_ContextClass: TypeAlias = type[_BaseContext]
# A class decorator that takes a context class and returns a context class.
_ContextDecorator: TypeAlias = Callable[[_ContextClass], _ContextClass]
501
+
502
+
503
class DatabaseContextFactory:
    """Factory for creating DatabaseContext instances based on dialect type.

    Context classes are kept in a class-level registry keyed by normalized
    dialect name (lower case, without any ``+driver`` suffix).
    """

    _registry: dict[str, _ContextClass] = {}

    @staticmethod
    def _normalize_dialect(dialect: str) -> str:
        """Return the registry key for a dialect name.

        The name is lower-cased and any driver suffix is removed, e.g.
        ``"PostgreSQL+psycopg2"`` becomes ``"postgresql"``.
        """
        return dialect.lower().split("+")[0]

    @classmethod
    def register(cls) -> _ContextDecorator:
        """Register a context class for its dialect.

        The dialect is determined by reading the DIALECT attribute from the
        decorated class.

        Returns
        -------
        Callable
            The decorator function that registers the context class.

        Examples
        --------
        >>> @DatabaseContextFactory.register()
        ... class PostgreSQLContext(_BaseContext):
        ...     DIALECT = "postgresql"
        ...     pass

        Notes
        -----
        The registry is populated at module import time and afterwards should
        be treated as read-only.
        """

        def decorator(context_class: type[_BaseContext]) -> type[_BaseContext]:
            # Get the dialect from the class's DIALECT attribute
            if not hasattr(context_class, "DIALECT"):
                raise ValueError(f"Context class {context_class.__name__} must define a DIALECT attribute")
            # Use the same key normalization as register_class/create_context;
            # storing DIALECT verbatim would make a mixed-case value
            # impossible to look up.
            cls._registry[cls._normalize_dialect(context_class.DIALECT)] = context_class
            return context_class

        return decorator

    @classmethod
    def register_class(cls, dialect: str, context_class: type[_BaseContext]) -> None:
        """Register a context class for a specific dialect programmatically.

        Parameters
        ----------
        dialect
            The dialect name to register.
        context_class
            The context class to use for this dialect.
        """
        cls._registry[cls._normalize_dialect(dialect)] = context_class

    @classmethod
    def create_context(cls, dialect: str, engine_url: URL, metadata: MetaData) -> DatabaseContext:
        """Create a context instance for the given dialect.

        Parameters
        ----------
        dialect
            The database dialect name.
        engine_url
            The SQLAlchemy database URL.
        metadata
            The SQLAlchemy metadata.

        Returns
        -------
        DatabaseContext
            The appropriate context instance.

        Raises
        ------
        ValueError
            If no context class is registered for the dialect.
        """
        dialect_name = cls._normalize_dialect(dialect)

        if dialect_name not in cls._registry:
            supported = cls.get_supported_dialects()
            raise ValueError(
                f"No context class registered for dialect: {dialect_name}. "
                f"Supported dialects: {', '.join(supported)}"
            )

        context_class = cls._registry[dialect_name]
        return context_class(engine_url, metadata)

    @classmethod
    def get_supported_dialects(cls) -> list[str]:
        """Get a list of supported dialect names.

        Returns
        -------
        list[str]
            List of supported dialect names.
        """
        return list(cls._registry.keys())
605
+
606
+
607
+ class _SQLWriter:
608
+ """Write SQL statements to stdout or a file.
609
+
610
+ Parameters
611
+ ----------
612
+ file
613
+ The file to write the SQL statements to. If None, the statements
614
+ will be written to stdout.
615
+ """
616
+
617
+ def __init__(self, file: IO[str] | None = None) -> None:
618
+ """Initialize the SQL writer."""
619
+ self.file = file
620
+ self.dialect: Dialect | None = None
621
+
622
+ def write(self, sql: Any, *multiparams: Any, **params: Any) -> None:
623
+ """Write the SQL statement to a file or stdout.
624
+
625
+ Statements with parameters will be formatted with the values
626
+ inserted into the resultant SQL output.
627
+
628
+ Parameters
629
+ ----------
630
+ sql
631
+ The SQL statement to write.
632
+ *multiparams
633
+ The multiparams to use for the SQL statement.
634
+ **params
635
+ The params to use for the SQL statement.
636
+
637
+ Notes
638
+ -----
639
+ The functions arguments are typed very loosely because this method in
640
+ SQLAlchemy is untyped, amd we do not call it directly.
641
+ """
642
+ compiled = sql.compile(dialect=self.dialect)
643
+ sql_str = str(compiled) + ";"
644
+ params_list = [compiled.params]
645
+ for params in params_list:
646
+ if not params:
647
+ print(sql_str, file=self.file)
648
+ continue
649
+ new_params = {}
650
+ for key, value in params.items():
651
+ if isinstance(value, str):
652
+ new_params[key] = f"'{value}'"
653
+ elif value is None:
654
+ new_params[key] = "null"
655
+ else:
656
+ new_params[key] = value
657
+ print(sql_str % new_params, file=self.file)
658
+
659
+
660
@DatabaseContextFactory.register()
class PostgreSQLContext(_BaseContext):
    """Database context for Postgres.

    A schema name must be set on the metadata.

    Parameters
    ----------
    engine_url
        The SQLAlchemy database URL for connecting to the database.
    metadata
        The SQLAlchemy metadata representing the database objects.
    """

    DIALECT = "postgresql"

    def __init__(self, engine_url: URL, metadata: MetaData):
        super().__init__(engine_url, metadata, require_schema=True)

    def initialize(self) -> None:
        """Create the Postgres schema unless it is already present.

        Raises
        ------
        DatabaseContextError
            If there is an error checking for or creating the schema.
        """
        target_schema = self._required_schema_name()
        exists_query = """
                SELECT schema_name
                FROM information_schema.schemata
                WHERE schema_name = :schema_name
                """
        try:
            logger.debug(f"Checking if PG schema exists: {target_schema}")
            existing = self.execute(exists_query, {"schema_name": target_schema})
            if not existing.fetchone():
                logger.debug(f"Creating PG schema: {target_schema}")
                self.execute(CreateSchema(target_schema))
        except SQLAlchemyError as e:
            raise DatabaseContextError(f"Error initializing Postgres schema: {e}") from e

    def drop(self) -> None:
        """Drop the Postgres schema, cascading, if it exists.

        Raises
        ------
        DatabaseContextError
            If there is an error dropping the schema.
        """
        target_schema = self._required_schema_name()
        drop_stmt = DropSchema(target_schema, if_exists=True, cascade=True)
        try:
            logger.debug(f"Dropping PostgreSQL schema if exists: {target_schema}")
            self.execute(drop_stmt)
        except SQLAlchemyError as e:
            raise DatabaseContextError(f"Error dropping Postgres database: {e}") from e
703
+
704
+
705
@DatabaseContextFactory.register()
class MySQLContext(_BaseContext):
    """Database context for MySQL.

    A schema name must be set on the metadata; it is used as the name of
    the MySQL database.

    Parameters
    ----------
    engine_url
        The SQLAlchemy database URL for connecting to the database.
    metadata
        The SQLAlchemy metadata representing the database objects.
    """

    DIALECT = "mysql"

    def __init__(self, engine_url: URL, metadata: MetaData):
        super().__init__(engine_url, metadata, require_schema=True)

    def _quoted_schema_name(self) -> str:
        """Return the schema name quoted as an identifier for this dialect.

        Interpolating a ``quoted_name`` into a string yields the bare name,
        because its ``quote`` flag is only a hint for SQLAlchemy's compiler.
        The dialect's identifier preparer is used to apply the quoting, so
        names that are reserved words or contain special characters are
        rendered safely.
        """
        schema_name = self._required_schema_name()
        return self.dialect.identifier_preparer.quote(quoted_name(schema_name, quote=True))

    def initialize(self) -> None:
        """Create the MySQL database unless it is already present.

        Raises
        ------
        DatabaseContextError
            If there is an error checking for or creating the database.
        """
        # The schema is instantiated as a database, as MySQL does not have a
        # distinct schema concept, unlike Postgres.
        schema_name = self._required_schema_name()
        try:
            logger.debug(f"Checking if MySQL database exists: {schema_name}")
            # NOTE(review): LIKE treats "_" and "%" in the name as wildcards,
            # so this check may match a differently named database - confirm
            # whether an exact comparison is wanted.
            result = self.execute("SHOW DATABASES LIKE :schema_name", {"schema_name": schema_name})
            if result.fetchone():
                return
            logger.debug(f"Creating MySQL database: {schema_name}")
            from sqlalchemy import DDL

            create_stmt = DDL(f"CREATE DATABASE {self._quoted_schema_name()}")
            self.execute(create_stmt)
        except SQLAlchemyError as e:
            raise DatabaseContextError(f"Error initializing MySQL database: {e}") from e

    def drop(self) -> None:
        """Drop the MySQL database if it exists.

        Raises
        ------
        DatabaseContextError
            If there is an error dropping the database.
        """
        schema_name = self._required_schema_name()
        try:
            logger.debug(f"Dropping MySQL database if exists: {schema_name}")
            from sqlalchemy import DDL

            drop_stmt = DDL(f"DROP DATABASE IF EXISTS {self._quoted_schema_name()}")
            self.execute(drop_stmt)
        except SQLAlchemyError as e:
            raise DatabaseContextError(f"Error dropping MySQL database: {e}") from e
749
+
750
+
751
@DatabaseContextFactory.register()
class SQLiteContext(_BaseContext):
    """Database context for SQLite.

    Any schema name set on the metadata is cleared, because this context
    does not use one.

    Parameters
    ----------
    engine_url
        The SQLAlchemy database URL for connecting to the database.
    metadata
        The SQLAlchemy metadata representing the database objects.
    """

    DIALECT = "sqlite"

    def __init__(self, engine_url: URL, metadata: MetaData):
        # Strip the schema before the base class records the schema name;
        # no schema name is required for this dialect.
        _clear_schema(metadata)
        super().__init__(engine_url, metadata)

    def initialize(self) -> None:
        """Do nothing; SQLite needs no schema initialization."""
        return None

    def drop(self) -> None:
        """Drop all the tables described by the metadata.

        Raises
        ------
        DatabaseContextError
            If there is an error dropping the tables.
        """
        try:
            logger.debug("Dropping tables in SQLite schema")
            # There is no schema to drop, so remove the tables themselves.
            self.metadata.drop_all(bind=self.engine)
        except SQLAlchemyError as e:
            raise DatabaseContextError(f"Error dropping SQLite database: {e}") from e
782
+
783
+
784
class MockContext(DatabaseContext):
    """Database context backed by a mock connection.

    Schema initialization, dropping and index management are no-ops;
    `create_all` and `execute` are forwarded to the mock connection.

    Parameters
    ----------
    metadata
        The SQLAlchemy metadata defining the database objects.
    connection
        The SQLAlchemy mock connection.
    """

    def __init__(self, metadata: MetaData, connection: MockConnection):
        self._metadata = metadata
        self._connection = connection
        self._dialect = connection.dialect

    @property
    def dialect(self) -> Dialect:
        """The dialect of the mock connection
        (`~sqlalchemy.engine.Dialect`).
        """
        return self._dialect

    @property
    def dialect_name(self) -> str:
        """The name reported by the mock connection's dialect (``str``)."""
        return self.dialect.name

    @property
    def metadata(self) -> MetaData:
        """The SQLAlchemy metadata passed to the constructor
        (`~sqlalchemy.sql.schema.MetaData`).
        """
        return self._metadata

    @property
    def engine(self) -> Engine:
        """Unavailable for a mock context.

        Raises
        ------
        DatabaseContextError
            Always, because there is no real engine.
        """
        raise DatabaseContextError("MockContext does not provide an engine.")

    def initialize(self) -> None:
        """Do nothing; a mock connection needs no initialization."""

    def drop(self) -> None:
        """Do nothing; a mock connection does not drop anything."""

    def create_all(self) -> None:
        """Emit the creation of all metadata objects through the mock
        connection.
        """
        self._metadata.create_all(self._connection)

    def create_indexes(self) -> None:
        """Do nothing; a mock connection can't create indexes."""

    def drop_indexes(self) -> None:
        """Do nothing; a mock connection can't drop indexes."""

    def execute(self, statement: SQLStatement, parameters: dict[str, Any] | None = None) -> Result:
        """Pass a statement to the mock connection.

        Parameters
        ----------
        statement
            The SQL statement to execute; strings are wrapped as text.
        parameters
            Optional bind parameters for the statement.

        Returns
        -------
        `~sqlalchemy.engine.Result`
            Whatever the mock connection's ``execute`` returns.
        """
        executable = _normalize_statement(statement)
        mock_conn = self._connection.connect()
        if not parameters:
            return mock_conn.execute(executable)
        return mock_conn.execute(executable, parameters)

    def close(self) -> None:
        """Close the mock connection (no-op)."""
845
+
846
+
847
def create_database_context(
    engine_url: str | URL,
    metadata: MetaData,
    output_file: IO[str] | None = None,
    dry_run: bool = False,
    echo: bool | None = None,
) -> DatabaseContext:
    """Build the `DatabaseContext` that matches an engine URL.

    Parameters
    ----------
    engine_url
        The database URL for the database connection.
    metadata
        The SQLAlchemy MetaData representing the database objects.
    output_file
        Output file for writing generated SQL commands. Only used by the
        mock context.
    dry_run
        If True, a mock context is returned, so operations are not
        executed. If False, a normal context is used unless the URL itself
        is a mock URL.
    echo
        If not None, the value is assigned to the ``echo`` attribute of a
        real (non-mock) context, controlling statement logging by the
        engine.

    Returns
    -------
    DatabaseContext
        A `MockContext` if the URL appears like a mock URL or if ``dry_run``
        is True, otherwise the context registered with the factory for the
        URL's dialect.

    Raises
    ------
    DatabaseContextError
        If the dialect is not supported or if there's an issue creating
        the context.
    """
    url = make_url(engine_url) if isinstance(engine_url, str) else engine_url

    if is_mock_url(url) or dry_run:
        # Mock URLs and dry runs only emit SQL; nothing is executed.
        if _dialect_name(url) == "sqlite":
            _clear_schema(metadata)
        return MockContext(metadata, _create_mock_connection(url, output_file))

    # A real context for the URL's dialect, looked up through the factory.
    try:
        dialect_name = _dialect_name(url)
        try:
            context = DatabaseContextFactory.create_context(dialect_name, url, metadata)
            # Echo is settable for real contexts only.
            if echo is not None and hasattr(context, "echo"):
                context.echo = echo
            return context
        except ValueError as e:
            supported = DatabaseContextFactory.get_supported_dialects()
            raise DatabaseContextError(
                f"Unsupported dialect: {dialect_name}. Supported dialects are: {', '.join(supported)}"
            ) from e
    except DatabaseContextError:
        # Already the right exception type; let it through untouched.
        raise
    except Exception as e:
        raise DatabaseContextError(f"Failed to create database context: {e}") from e