lsst-felis 27.2024.4100-py3-none-any.whl → 27.2024.4300-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lsst-felis might be problematic.
- felis/cli.py +83 -14
- felis/datamodel.py +175 -3
- felis/db/utils.py +92 -12
- felis/schemas/tap_schema_std.yaml +273 -0
- felis/tap.py +1 -5
- felis/tap_schema.py +644 -0
- felis/tests/utils.py +122 -0
- felis/version.py +1 -1
- {lsst_felis-27.2024.4100.dist-info → lsst_felis-27.2024.4300.dist-info}/METADATA +2 -1
- lsst_felis-27.2024.4300.dist-info/RECORD +26 -0
- {lsst_felis-27.2024.4100.dist-info → lsst_felis-27.2024.4300.dist-info}/WHEEL +1 -1
- lsst_felis-27.2024.4100.dist-info/RECORD +0 -23
- {lsst_felis-27.2024.4100.dist-info → lsst_felis-27.2024.4300.dist-info}/COPYRIGHT +0 -0
- {lsst_felis-27.2024.4100.dist-info → lsst_felis-27.2024.4300.dist-info}/LICENSE +0 -0
- {lsst_felis-27.2024.4100.dist-info → lsst_felis-27.2024.4300.dist-info}/entry_points.txt +0 -0
- {lsst_felis-27.2024.4100.dist-info → lsst_felis-27.2024.4300.dist-info}/top_level.txt +0 -0
- {lsst_felis-27.2024.4100.dist-info → lsst_felis-27.2024.4300.dist-info}/zip-safe +0 -0
felis/cli.py
CHANGED
@@ -23,22 +23,21 @@
 
 from __future__ import annotations
 
-import io
 import logging
 from collections.abc import Iterable
 from typing import IO
 
 import click
-import yaml
 from pydantic import ValidationError
 from sqlalchemy.engine import Engine, create_engine, make_url
-from sqlalchemy.engine.mock import MockConnection
+from sqlalchemy.engine.mock import MockConnection, create_mock_engine
 
 from . import __version__
 from .datamodel import Schema
-from .db.utils import DatabaseContext
+from .db.utils import DatabaseContext, is_mock_url
 from .metadata import MetaDataBuilder
 from .tap import Tap11Base, TapLoadingVisitor, init_tables
+from .tap_schema import DataLoader, TableManager
 
 __all__ = ["cli"]
 
@@ -107,7 +106,7 @@ def create(
     dry_run: bool,
     output_file: IO[str] | None,
     ignore_constraints: bool,
-    file: IO,
+    file: IO[str],
 ) -> None:
     """Create database objects from the Felis file.
 
@@ -133,8 +132,7 @@ def create(
         Felis file to read.
     """
     try:
-        yaml_data = yaml.safe_load(file)
-        schema = Schema.model_validate(yaml_data, context={"id_generation": ctx.obj["id_generation"]})
+        schema = Schema.from_stream(file, context={"id_generation": ctx.obj["id_generation"]})
         url = make_url(engine_url)
         if schema_name:
             logger.info(f"Overriding schema name with: {schema_name}")
@@ -261,7 +259,7 @@ def load_tap(
     tap_keys_table: str,
     tap_key_columns_table: str,
     tap_schema_index: int,
-    file: IO,
+    file: IO[str],
 ) -> None:
     """Load TAP metadata from a Felis file.
 
@@ -304,8 +302,7 @@ def load_tap(
     The data will be loaded into the TAP_SCHEMA from the engine URL. The
     tables must have already been initialized or an error will occur.
     """
-    yaml_data = yaml.safe_load(file)
-    schema = Schema.model_validate(yaml_data)
+    schema = Schema.from_stream(file)
 
     tap_tables = init_tables(
         tap_schema_name,
@@ -345,6 +342,79 @@ def load_tap(
     tap_visitor.visit_schema(schema)
 
 
+@cli.command("load-tap-schema", help="Load metadata from a Felis file into a TAP_SCHEMA database")
+@click.option("--engine-url", envvar="FELIS_ENGINE_URL", help="SQLAlchemy Engine URL")
+@click.option("--tap-schema-name", help="Name of the TAP_SCHEMA schema in the database")
+@click.option(
+    "--tap-tables-postfix", help="Postfix which is applied to standard TAP_SCHEMA table names", default=""
+)
+@click.option("--tap-schema-index", type=int, help="TAP_SCHEMA index of the schema in this environment")
+@click.option("--dry-run", is_flag=True, help="Execute dry run only. Does not insert any data.")
+@click.option("--echo", is_flag=True, help="Print out the generated insert statements to stdout")
+@click.option("--output-file", type=click.Path(), help="Write SQL commands to a file")
+@click.argument("file", type=click.File())
+@click.pass_context
+def load_tap_schema(
+    ctx: click.Context,
+    engine_url: str,
+    tap_schema_name: str,
+    tap_tables_postfix: str,
+    tap_schema_index: int,
+    dry_run: bool,
+    echo: bool,
+    output_file: str | None,
+    file: IO[str],
+) -> None:
+    """Load TAP metadata from a Felis file.
+
+    Parameters
+    ----------
+    engine_url
+        SQLAlchemy Engine URL.
+    tap_tables_postfix
+        Postfix which is applied to standard TAP_SCHEMA table names.
+    tap_schema_index
+        TAP_SCHEMA index of the schema in this environment.
+    dry_run
+        Execute dry run only. Does not insert any data.
+    echo
+        Print out the generated insert statements to stdout.
+    output_file
+        Output file for writing generated SQL.
+    file
+        Felis file to read.
+
+    Notes
+    -----
+    The TAP_SCHEMA database must already exist or the command will fail. This
+    command will not initialize the TAP_SCHEMA tables.
+    """
+    url = make_url(engine_url)
+    engine: Engine | MockConnection
+    if dry_run or is_mock_url(url):
+        engine = create_mock_engine(url, executor=None)
+    else:
+        engine = create_engine(engine_url)
+    mgr = TableManager(
+        engine=engine,
+        apply_schema_to_metadata=False if engine.dialect.name == "sqlite" else True,
+        schema_name=tap_schema_name,
+        table_name_postfix=tap_tables_postfix,
+    )
+
+    schema = Schema.from_stream(file, context={"id_generation": ctx.obj["id_generation"]})
+
+    DataLoader(
+        schema,
+        mgr,
+        engine,
+        tap_schema_index=tap_schema_index,
+        dry_run=dry_run,
+        print_sql=echo,
+        output_path=output_file,
+    ).load()
+
+
 @cli.command("validate", help="Validate one or more Felis YAML files")
 @click.option(
     "--check-description", is_flag=True, help="Check that all objects have a description", default=False
@@ -372,7 +442,7 @@ def validate(
     check_redundant_datatypes: bool,
     check_tap_table_indexes: bool,
     check_tap_principal: bool,
-    files: Iterable[IO],
+    files: Iterable[IO[str]],
 ) -> None:
     """Validate one or more felis YAML files.
 
@@ -406,9 +476,8 @@ def validate(
         file_name = getattr(file, "name", None)
         logger.info(f"Validating {file_name}")
         try:
-            data = yaml.safe_load(file)
-            Schema.model_validate(
-                data,
+            Schema.from_stream(
+                file,
                 context={
                     "check_description": check_description,
                     "check_redundant_datatypes": check_redundant_datatypes,
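
The new `load-tap-schema` command above boils down to four steps: build a (possibly mock) engine, construct a `TableManager`, parse the schema, and hand everything to a `DataLoader`. A minimal dry-run sketch of that same flow, assuming a local `schema.yaml` file (a placeholder) and assuming the remaining `TableManager` and `DataLoader` keyword arguments are optional; only names shown in the command body above are used:

```python
# Hypothetical dry-run sketch of the load-tap-schema flow. "sqlite://" with
# no database path is a mock URL, so SQL is generated without a real database.
from sqlalchemy.engine import make_url
from sqlalchemy.engine.mock import create_mock_engine

from felis.datamodel import Schema
from felis.tap_schema import DataLoader, TableManager

url = make_url("sqlite://")
engine = create_mock_engine(url, executor=None)

manager = TableManager(engine=engine, apply_schema_to_metadata=False)
with open("schema.yaml") as f:
    schema = Schema.from_stream(f)

DataLoader(
    schema,
    manager,
    engine,
    tap_schema_index=1,
    dry_run=True,
    print_sql=True,
).load()
```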
felis/datamodel.py
CHANGED
@@ -26,10 +26,12 @@ from __future__ import annotations
 import logging
 from collections.abc import Sequence
 from enum import StrEnum, auto
-from typing import Annotated, Any, Literal, TypeAlias, Union
+from typing import IO, Annotated, Any, Generic, Literal, TypeAlias, TypeVar, Union
 
+import yaml
 from astropy import units as units  # type: ignore
 from astropy.io.votable import ucd  # type: ignore
+from lsst.resources import ResourcePath, ResourcePathExpression
 from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator, model_validator
 
 from .db.dialects import get_supported_dialects
@@ -253,7 +255,7 @@ class Column(BaseObject):
         Raises
         ------
         ValueError
-            Raised if the unit is
+            Raised if both FITS and IVOA units are provided, or if the unit is
             invalid.
         """
         fits_unit = self.fits_tunit
@@ -383,6 +385,58 @@ class Column(BaseObject):
             raise ValueError("Precision is only valid for timestamp columns")
         return self
 
+    @model_validator(mode="before")
+    @classmethod
+    def check_votable_arraysize(cls, values: dict[str, Any]) -> dict[str, Any]:
+        """Set the default value for the ``votable_arraysize`` field, which
+        corresponds to ``arraysize`` in the IVOA VOTable standard.
+
+        Parameters
+        ----------
+        values
+            Values of the column.
+
+        Returns
+        -------
+        `dict` [ `str`, `Any` ]
+            The values of the column.
+
+        Notes
+        -----
+        Following the IVOA VOTable standard, an ``arraysize`` of 1 should not
+        be used.
+        """
+        if values.get("name", None) is None or values.get("datatype", None) is None:
+            # Skip bad column data that will not validate
+            return values
+        arraysize = values.get("votable:arraysize", None)
+        if arraysize is None:
+            length = values.get("length", None)
+            datatype = values.get("datatype")
+            if length is not None and length > 1:
+                # Following the IVOA standard, arraysize of 1 is disallowed
+                if datatype == "char":
+                    arraysize = str(length)
+                elif datatype in ("string", "unicode", "binary"):
+                    arraysize = f"{length}*"
+                elif datatype in ("timestamp", "text"):
+                    arraysize = "*"
+            if arraysize is not None:
+                values["votable:arraysize"] = arraysize
+                logger.debug(
+                    f"Set default 'votable:arraysize' to '{arraysize}' on column '{values['name']}'"
+                    + f" with datatype '{values['datatype']}' and length '{values.get('length', None)}'"
+                )
+        else:
+            logger.debug(f"Using existing 'votable:arraysize' of '{arraysize}' on column '{values['name']}'")
+            if isinstance(values["votable:arraysize"], int):
+                logger.warning(
+                    f"Usage of an integer value for 'votable:arraysize' in column '{values['name']}' is "
+                    + "deprecated"
+                )
+                values["votable:arraysize"] = str(arraysize)
+        return values
+
 
 class Constraint(BaseObject):
     """Table constraint model."""
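
The defaulting rules in `check_votable_arraysize` reduce to a small lookup on `datatype` and `length`. The sketch below restates them as a standalone function for clarity; it is not part of the felis API, just an illustration of the mapping the validator applies when `votable:arraysize` is absent:

```python
# Standalone restatement of the arraysize defaults above (not felis API).
# Returns None where the validator would leave 'votable:arraysize' unset.
def default_arraysize(datatype: str, length: int | None) -> str | None:
    if length is None or length <= 1:
        return None  # the VOTable standard disallows an arraysize of 1
    if datatype == "char":
        return str(length)  # fixed-length, e.g. "8"
    if datatype in ("string", "unicode", "binary"):
        return f"{length}*"  # variable-length with an upper bound, e.g. "8*"
    if datatype in ("timestamp", "text"):
        return "*"  # unbounded
    return None


assert default_arraysize("char", 8) == "8"
assert default_arraysize("string", 8) == "8*"
assert default_arraysize("text", 8) == "*"
assert default_arraysize("double", 8) is None
```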
@@ -700,7 +754,10 @@ class SchemaIdVisitor:
         self.add(constraint)
 
 
-class Schema(BaseObject):
+T = TypeVar("T", bound=BaseObject)
+
+
+class Schema(BaseObject, Generic[T]):
     """Database schema model.
 
     This represents a database schema, which contains one or more tables.
@@ -942,3 +999,118 @@ class Schema(BaseObject):
             The ID of the object to check.
         """
         return id in self.id_map
+
+    def find_object_by_id(self, id: str, obj_type: type[T]) -> T:
+        """Find an object with the given type by its ID.
+
+        Parameters
+        ----------
+        id
+            The ID of the object to find.
+        obj_type
+            The type of the object to find.
+
+        Returns
+        -------
+        BaseObject
+            The object with the given ID and type.
+
+        Raises
+        ------
+        KeyError
+            If the object with the given ID is not found in the schema.
+        TypeError
+            If the object that is found does not have the right type.
+
+        Notes
+        -----
+        The actual return type is the user-specified argument ``T``, which is
+        expected to be a subclass of `BaseObject`.
+        """
+        obj = self[id]
+        if not isinstance(obj, obj_type):
+            raise TypeError(f"Object with ID '{id}' is not of type '{obj_type.__name__}'")
+        return obj
+
+    def get_table_by_column(self, column: Column) -> Table:
+        """Find the table that contains a column.
+
+        Parameters
+        ----------
+        column
+            The column to find.
+
+        Returns
+        -------
+        `Table`
+            The table that contains the column.
+
+        Raises
+        ------
+        ValueError
+            If the column is not found in any table.
+        """
+        for table in self.tables:
+            if column in table.columns:
+                return table
+        raise ValueError(f"Column '{column.name}' not found in any table")
+
+    @classmethod
+    def from_uri(cls, resource_path: ResourcePathExpression, context: dict[str, Any] = {}) -> Schema:
+        """Load a `Schema` from a string representing a ``ResourcePath``.
+
+        Parameters
+        ----------
+        resource_path
+            The ``ResourcePath`` pointing to a YAML file.
+        context
+            Pydantic context to be used in validation.
+
+        Returns
+        -------
+        `Schema`
+            The Felis schema loaded from the resource.
+
+        Raises
+        ------
+        yaml.YAMLError
+            Raised if there is an error loading the YAML data.
+        ValueError
+            Raised if there is an error reading the resource.
+        pydantic.ValidationError
+            Raised if the schema fails validation.
+        """
+        logger.debug(f"Loading schema from: '{resource_path}'")
+        try:
+            rp_stream = ResourcePath(resource_path).read()
+        except Exception as e:
+            raise ValueError(f"Error reading resource from '{resource_path}' : {e}") from e
+        yaml_data = yaml.safe_load(rp_stream)
+        return Schema.model_validate(yaml_data, context=context)
+
+    @classmethod
+    def from_stream(cls, source: IO[str], context: dict[str, Any] = {}) -> Schema:
+        """Load a `Schema` from a file stream which should contain YAML data.
+
+        Parameters
+        ----------
+        source
+            The file stream to read from.
+        context
+            Pydantic context to be used in validation.
+
+        Returns
+        -------
+        `Schema`
+            The Felis schema loaded from the stream.
+
+        Raises
+        ------
+        yaml.YAMLError
+            Raised if there is an error loading the YAML file.
+        pydantic.ValidationError
+            Raised if the schema fails validation.
+        """
+        logger.debug("Loading schema from: '%s'", source)
+        yaml_data = yaml.safe_load(source)
+        return Schema.model_validate(yaml_data, context=context)
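
The two new constructors give callers a choice between an open stream (`from_stream`) and anything `lsst.resources` can resolve (`from_uri`, including local paths and remote URLs), and the typed lookups avoid manual `isinstance` checks. A brief usage sketch, where the file name `sales.yaml` and the object ID `#sales.amount` are hypothetical:

```python
# Usage sketch for the new loaders and typed lookups; file name and IDs
# are placeholders.
from felis.datamodel import Column, Schema

with open("sales.yaml") as f:
    schema = Schema.from_stream(f)

# from_uri accepts any ResourcePath expression (local path, s3://, etc.).
same_schema = Schema.from_uri("sales.yaml")

column = schema.find_object_by_id("#sales.amount", Column)  # TypeError if the ID names a non-column
table = schema.get_table_by_column(column)  # ValueError if the column is in no table
print(f"{table.name}.{column.name}")
```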
felis/db/utils.py
CHANGED
@@ -106,6 +106,43 @@ def string_to_typeengine(
     return type_obj
 
 
+def is_mock_url(url: URL) -> bool:
+    """Check if the engine URL is a mock URL.
+
+    Parameters
+    ----------
+    url
+        The SQLAlchemy engine URL.
+
+    Returns
+    -------
+    bool
+        True if the URL is a mock URL, False otherwise.
+    """
+    return (url.drivername == "sqlite" and url.database is None) or (
+        url.drivername != "sqlite" and url.host is None
+    )
+
+
+def is_valid_engine(engine: Engine | MockConnection | None) -> bool:
+    """Check if the engine is valid.
+
+    The engine cannot be None; it must not be a mock connection; and it must
+    not have a mock URL that is missing a host or, for SQLite, a database name.
+
+    Parameters
+    ----------
+    engine
+        The SQLAlchemy engine or mock connection.
+
+    Returns
+    -------
+    bool
+        True if the engine is valid, False otherwise.
+    """
+    return engine is not None and not isinstance(engine, MockConnection) and not is_mock_url(engine.url)
+
+
 class SQLWriter:
     """Write SQL statements to stdout or a file.
 
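
The mock-URL rule above is easy to check against a few concrete URLs (the hostnames below are placeholders):

```python
# Spot checks of is_mock_url: sqlite with no database file, or any other
# dialect with no host, counts as a mock URL.
from sqlalchemy.engine import make_url

from felis.db.utils import is_mock_url

assert is_mock_url(make_url("sqlite://"))  # no database file
assert not is_mock_url(make_url("sqlite:///felis.db"))
assert is_mock_url(make_url("postgresql://"))  # no host
assert not is_mock_url(make_url("postgresql://db.example.com/tap"))
```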
@@ -193,12 +230,19 @@ class ConnectionWrapper:
         """
         if isinstance(statement, str):
             statement = text(statement)
-        if isinstance(self.engine, MockConnection):
+        if isinstance(self.engine, Engine):
+            try:
+                with self.engine.begin() as connection:
+                    result = connection.execute(statement)
+                    return result
+            except SQLAlchemyError as e:
+                connection.rollback()
+                logger.error(f"Error executing statement: {e}")
+                raise
+        elif isinstance(self.engine, MockConnection):
             return self.engine.connect().execute(statement)
         else:
-            with self.engine.connect() as connection:
-                result = connection.execute(statement)
-                return result
+            raise ValueError("Unsupported engine type:" + str(type(self.engine)))
 
 
 class DatabaseContext:
@@ -218,7 +262,7 @@ class DatabaseContext:
         self.engine = engine
         self.dialect_name = engine.dialect.name
         self.metadata = metadata
-        self.
+        self.connection = ConnectionWrapper(engine)
 
     def initialize(self) -> None:
         """Create the schema in the database if it does not exist.
@@ -240,14 +284,14 @@ class DatabaseContext:
         try:
             if self.dialect_name == "mysql":
                 logger.debug(f"Checking if MySQL database exists: {schema_name}")
-                result = self.
+                result = self.execute(text(f"SHOW DATABASES LIKE '{schema_name}'"))
                 if result.fetchone():
                     raise ValueError(f"MySQL database '{schema_name}' already exists.")
                 logger.debug(f"Creating MySQL database: {schema_name}")
-                self.
+                self.execute(text(f"CREATE DATABASE {schema_name}"))
             elif self.dialect_name == "postgresql":
                 logger.debug(f"Checking if PG schema exists: {schema_name}")
-                result = self.
+                result = self.execute(
                     text(
                         f"""
                         SELECT schema_name
@@ -259,7 +303,7 @@ class DatabaseContext:
                 if result.fetchone():
                     raise ValueError(f"PostgreSQL schema '{schema_name}' already exists.")
                 logger.debug(f"Creating PG schema: {schema_name}")
-                self.
+                self.execute(CreateSchema(schema_name))
             elif self.dialect_name == "sqlite":
                 # Just silently ignore this operation for SQLite. The database
                 # will still be created if it does not exist and the engine
@@ -285,13 +329,15 @@ class DatabaseContext:
         schema. For other variants, this is an unsupported operation.
         """
         schema_name = self.metadata.schema
+        if not self.engine.dialect.name == "sqlite" and self.metadata.schema is None:
+            raise ValueError("Schema name is required to drop the schema.")
         try:
             if self.dialect_name == "mysql":
                 logger.debug(f"Dropping MySQL database if exists: {schema_name}")
-                self.
+                self.execute(text(f"DROP DATABASE IF EXISTS {schema_name}"))
             elif self.dialect_name == "postgresql":
                 logger.debug(f"Dropping PostgreSQL schema if exists: {schema_name}")
-                self.
+                self.execute(DropSchema(schema_name, if_exists=True, cascade=True))
             elif self.dialect_name == "sqlite":
                 if isinstance(self.engine, Engine):
                     logger.debug("Dropping tables in SQLite schema")
@@ -304,7 +350,21 @@ class DatabaseContext:
 
     def create_all(self) -> None:
         """Create all tables in the schema using the metadata object."""
-        self.metadata.create_all(self.engine)
+        if isinstance(self.engine, Engine):
+            # Use a transaction for a real connection.
+            with self.engine.begin() as conn:
+                try:
+                    self.metadata.create_all(bind=conn)
+                    conn.commit()
+                except SQLAlchemyError as e:
+                    conn.rollback()
+                    logger.error(f"Error creating tables: {e}")
+                    raise
+        elif isinstance(self.engine, MockConnection):
+            # Mock connection so no need for a transaction.
+            self.metadata.create_all(self.engine)
+        else:
+            raise ValueError("Unsupported engine type: " + str(type(self.engine)))
 
     @staticmethod
     def create_mock_engine(engine_url: str | URL, output_file: IO[str] | None = None) -> MockConnection:
@@ -327,3 +387,23 @@ class DatabaseContext:
         engine = create_mock_engine(engine_url, executor=writer.write, paramstyle="pyformat")
         writer.dialect = engine.dialect
         return engine
+
+    def execute(self, statement: Any) -> ResultProxy:
+        """Execute a SQL statement on the engine and return the result.
+
+        Parameters
+        ----------
+        statement
+            The SQL statement to execute.
+
+        Returns
+        -------
+        ``sqlalchemy.engine.ResultProxy``
+            The result of the statement execution.
+
+        Notes
+        -----
+        This is just a wrapper around the execution method of the connection
+        object, which may execute on a real or mock connection.
+        """
+        return self.connection.execute(statement)
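
Taken together, `ConnectionWrapper` and the new `DatabaseContext.execute` mean that on a real engine each statement runs in its own transaction (`engine.begin()`), while a mock engine simply echoes the generated SQL. A minimal sketch against an in-memory SQLite engine, which SQLAlchemy backs with one reused connection per thread, so state persists across calls:

```python
# Minimal sketch: ConnectionWrapper wraps plain strings in text() and runs
# each statement in its own transaction via engine.begin().
from sqlalchemy import create_engine, inspect

from felis.db.utils import ConnectionWrapper

engine = create_engine("sqlite://")  # in-memory database
wrapper = ConnectionWrapper(engine)
wrapper.execute("CREATE TABLE demo (id INTEGER PRIMARY KEY, name TEXT)")
wrapper.execute("INSERT INTO demo (name) VALUES ('felis')")

print(inspect(engine).get_table_names())  # ['demo']
```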