lsst-felis 27.2024.2300__tar.gz → 27.2024.2500__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lsst-felis might be problematic. Click here for more details.
- {lsst_felis-27.2024.2300/python/lsst_felis.egg-info → lsst_felis-27.2024.2500}/PKG-INFO +1 -1
- {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500}/python/felis/cli.py +27 -30
- {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500}/python/felis/datamodel.py +52 -62
- lsst_felis-27.2024.2500/python/felis/db/dialects.py +63 -0
- lsst_felis-27.2024.2500/python/felis/db/utils.py +248 -0
- lsst_felis-27.2024.2300/python/felis/db/_variants.py → lsst_felis-27.2024.2500/python/felis/db/variants.py +29 -22
- {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500}/python/felis/metadata.py +2 -185
- lsst_felis-27.2024.2500/python/felis/version.py +2 -0
- {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500/python/lsst_felis.egg-info}/PKG-INFO +1 -1
- {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500}/python/lsst_felis.egg-info/SOURCES.txt +4 -4
- {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500}/tests/test_cli.py +12 -20
- {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500}/tests/test_datamodel.py +115 -54
- {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500}/tests/test_metadata.py +2 -1
- lsst_felis-27.2024.2300/python/felis/validation.py +0 -103
- lsst_felis-27.2024.2300/python/felis/version.py +0 -2
- lsst_felis-27.2024.2300/tests/test_validation.py +0 -233
- {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500}/COPYRIGHT +0 -0
- {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500}/LICENSE +0 -0
- {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500}/README.rst +0 -0
- {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500}/pyproject.toml +0 -0
- {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500}/python/felis/__init__.py +0 -0
- {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500}/python/felis/db/__init__.py +0 -0
- {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500}/python/felis/db/sqltypes.py +0 -0
- {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500}/python/felis/py.typed +0 -0
- {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500}/python/felis/tap.py +0 -0
- {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500}/python/felis/types.py +0 -0
- {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500}/python/lsst_felis.egg-info/dependency_links.txt +0 -0
- {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500}/python/lsst_felis.egg-info/entry_points.txt +0 -0
- {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500}/python/lsst_felis.egg-info/requires.txt +0 -0
- {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500}/python/lsst_felis.egg-info/top_level.txt +0 -0
- {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500}/python/lsst_felis.egg-info/zip-safe +0 -0
- {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500}/setup.cfg +0 -0
- {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500}/tests/test_datatypes.py +0 -0
- {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500}/tests/test_tap.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: lsst-felis
|
|
3
|
-
Version: 27.2024.2300
|
|
3
|
+
Version: 27.2024.2500
|
|
4
4
|
Summary: A vocabulary for describing catalogs and acting on those descriptions
|
|
5
5
|
Author-email: Rubin Observatory Data Management <dm-admin@lists.lsst.org>
|
|
6
6
|
License: GNU General Public License v3 or later (GPLv3+)
|
|
@@ -29,14 +29,14 @@ from typing import IO
|
|
|
29
29
|
import click
|
|
30
30
|
import yaml
|
|
31
31
|
from pydantic import ValidationError
|
|
32
|
-
from sqlalchemy.engine import Engine, create_engine,
|
|
32
|
+
from sqlalchemy.engine import Engine, create_engine, make_url
|
|
33
33
|
from sqlalchemy.engine.mock import MockConnection
|
|
34
34
|
|
|
35
35
|
from . import __version__
|
|
36
36
|
from .datamodel import Schema
|
|
37
|
-
from .
|
|
37
|
+
from .db.utils import DatabaseContext
|
|
38
|
+
from .metadata import MetaDataBuilder
|
|
38
39
|
from .tap import Tap11Base, TapLoadingVisitor, init_tables
|
|
39
|
-
from .validation import get_schema
|
|
40
40
|
|
|
41
41
|
logger = logging.getLogger("felis")
|
|
42
42
|
|
|
@@ -92,29 +92,27 @@ def create(
|
|
|
92
92
|
"""Create database objects from the Felis file."""
|
|
93
93
|
yaml_data = yaml.safe_load(file)
|
|
94
94
|
schema = Schema.model_validate(yaml_data)
|
|
95
|
-
|
|
95
|
+
url = make_url(engine_url)
|
|
96
96
|
if schema_name:
|
|
97
97
|
logger.info(f"Overriding schema name with: {schema_name}")
|
|
98
98
|
schema.name = schema_name
|
|
99
|
-
elif
|
|
99
|
+
elif url.drivername == "sqlite":
|
|
100
100
|
logger.info("Overriding schema name for sqlite with: main")
|
|
101
101
|
schema.name = "main"
|
|
102
|
-
if not
|
|
102
|
+
if not url.host and not url.drivername == "sqlite":
|
|
103
103
|
dry_run = True
|
|
104
104
|
logger.info("Forcing dry run for non-sqlite engine URL with no host")
|
|
105
105
|
|
|
106
|
-
|
|
107
|
-
builder.build()
|
|
108
|
-
metadata = builder.metadata
|
|
106
|
+
metadata = MetaDataBuilder(schema).build()
|
|
109
107
|
logger.debug(f"Created metadata with schema name: {metadata.schema}")
|
|
110
108
|
|
|
111
109
|
engine: Engine | MockConnection
|
|
112
110
|
if not dry_run and not output_file:
|
|
113
|
-
engine = create_engine(
|
|
111
|
+
engine = create_engine(url, echo=echo)
|
|
114
112
|
else:
|
|
115
113
|
if dry_run:
|
|
116
114
|
logger.info("Dry run will be executed")
|
|
117
|
-
engine = DatabaseContext.create_mock_engine(
|
|
115
|
+
engine = DatabaseContext.create_mock_engine(url, output_file)
|
|
118
116
|
if output_file:
|
|
119
117
|
logger.info("Writing SQL output to: " + output_file.name)
|
|
120
118
|
|
|
@@ -229,10 +227,7 @@ def load_tap(
|
|
|
229
227
|
)
|
|
230
228
|
tap_visitor.visit_schema(schema)
|
|
231
229
|
else:
|
|
232
|
-
|
|
233
|
-
conn = create_mock_engine(make_url(engine_url), executor=_insert_dump.dump, paramstyle="pyformat")
|
|
234
|
-
# After the engine is created, update the executor with the dialect
|
|
235
|
-
_insert_dump.dialect = conn.dialect
|
|
230
|
+
conn = DatabaseContext.create_mock_engine(engine_url)
|
|
236
231
|
|
|
237
232
|
tap_visitor = TapLoadingVisitor.from_mock_connection(
|
|
238
233
|
conn,
|
|
@@ -245,42 +240,44 @@ def load_tap(
|
|
|
245
240
|
|
|
246
241
|
|
|
247
242
|
@cli.command("validate")
|
|
243
|
+
@click.option("--check-description", is_flag=True, help="Require description for all objects", default=False)
|
|
248
244
|
@click.option(
|
|
249
|
-
"-
|
|
250
|
-
"--schema-name",
|
|
251
|
-
help="Schema name for validation",
|
|
252
|
-
type=click.Choice(["RSP", "default"]),
|
|
253
|
-
default="default",
|
|
245
|
+
"--check-redundant-datatypes", is_flag=True, help="Check for redundant datatypes", default=False
|
|
254
246
|
)
|
|
255
247
|
@click.option(
|
|
256
|
-
"
|
|
248
|
+
"--check-tap-table-indexes",
|
|
249
|
+
is_flag=True,
|
|
250
|
+
help="Check that every table has a unique TAP table index",
|
|
251
|
+
default=False,
|
|
257
252
|
)
|
|
258
253
|
@click.option(
|
|
259
|
-
"
|
|
254
|
+
"--check-tap-principal",
|
|
255
|
+
is_flag=True,
|
|
256
|
+
help="Check that at least one column per table is flagged as TAP principal",
|
|
257
|
+
default=False,
|
|
260
258
|
)
|
|
261
259
|
@click.argument("files", nargs=-1, type=click.File())
|
|
262
260
|
def validate(
|
|
263
|
-
|
|
264
|
-
require_description: bool,
|
|
261
|
+
check_description: bool,
|
|
265
262
|
check_redundant_datatypes: bool,
|
|
263
|
+
check_tap_table_indexes: bool,
|
|
264
|
+
check_tap_principal: bool,
|
|
266
265
|
files: Iterable[io.TextIOBase],
|
|
267
266
|
) -> None:
|
|
268
267
|
"""Validate one or more felis YAML files."""
|
|
269
|
-
schema_class = get_schema(schema_name)
|
|
270
|
-
if schema_name != "default":
|
|
271
|
-
logger.info(f"Using schema '{schema_class.__name__}'")
|
|
272
|
-
|
|
273
268
|
rc = 0
|
|
274
269
|
for file in files:
|
|
275
270
|
file_name = getattr(file, "name", None)
|
|
276
271
|
logger.info(f"Validating {file_name}")
|
|
277
272
|
try:
|
|
278
273
|
data = yaml.load(file, Loader=yaml.SafeLoader)
|
|
279
|
-
|
|
274
|
+
Schema.model_validate(
|
|
280
275
|
data,
|
|
281
276
|
context={
|
|
277
|
+
"check_description": check_description,
|
|
282
278
|
"check_redundant_datatypes": check_redundant_datatypes,
|
|
283
|
-
"
|
|
279
|
+
"check_tap_table_indexes": check_tap_table_indexes,
|
|
280
|
+
"check_tap_principal": check_tap_principal,
|
|
284
281
|
},
|
|
285
282
|
)
|
|
286
283
|
except ValidationError as e:
|
|
@@ -22,7 +22,6 @@
|
|
|
22
22
|
from __future__ import annotations
|
|
23
23
|
|
|
24
24
|
import logging
|
|
25
|
-
import re
|
|
26
25
|
from collections.abc import Mapping, Sequence
|
|
27
26
|
from enum import StrEnum, auto
|
|
28
27
|
from typing import Annotated, Any, Literal, TypeAlias
|
|
@@ -30,13 +29,10 @@ from typing import Annotated, Any, Literal, TypeAlias
|
|
|
30
29
|
from astropy import units as units # type: ignore
|
|
31
30
|
from astropy.io.votable import ucd # type: ignore
|
|
32
31
|
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator, model_validator
|
|
33
|
-
from sqlalchemy import dialects
|
|
34
|
-
from sqlalchemy import types as sqa_types
|
|
35
|
-
from sqlalchemy.engine import create_mock_engine
|
|
36
|
-
from sqlalchemy.engine.interfaces import Dialect
|
|
37
|
-
from sqlalchemy.types import TypeEngine
|
|
38
32
|
|
|
33
|
+
from .db.dialects import get_supported_dialects
|
|
39
34
|
from .db.sqltypes import get_type_func
|
|
35
|
+
from .db.utils import string_to_typeengine
|
|
40
36
|
from .types import Boolean, Byte, Char, Double, FelisType, Float, Int, Long, Short, String, Text, Unicode
|
|
41
37
|
|
|
42
38
|
logger = logging.getLogger(__name__)
|
|
@@ -100,7 +96,7 @@ class BaseObject(BaseModel):
|
|
|
100
96
|
def check_description(self, info: ValidationInfo) -> BaseObject:
|
|
101
97
|
"""Check that the description is present if required."""
|
|
102
98
|
context = info.context
|
|
103
|
-
if not context or not context.get("require_description", False):
|
|
99
|
+
if not context or not context.get("check_description", False):
|
|
104
100
|
return self
|
|
105
101
|
if self.description is None or self.description == "":
|
|
106
102
|
raise ValueError("Description is required and must be non-empty")
|
|
@@ -127,51 +123,6 @@ class DataType(StrEnum):
|
|
|
127
123
|
timestamp = auto()
|
|
128
124
|
|
|
129
125
|
|
|
130
|
-
_DIALECTS = {
|
|
131
|
-
"mysql": create_mock_engine("mysql://", executor=None).dialect,
|
|
132
|
-
"postgresql": create_mock_engine("postgresql://", executor=None).dialect,
|
|
133
|
-
}
|
|
134
|
-
"""Dictionary of dialect names to SQLAlchemy dialects."""
|
|
135
|
-
|
|
136
|
-
_DIALECT_MODULES = {"mysql": getattr(dialects, "mysql"), "postgresql": getattr(dialects, "postgresql")}
|
|
137
|
-
"""Dictionary of dialect names to SQLAlchemy dialect modules."""
|
|
138
|
-
|
|
139
|
-
_DATATYPE_REGEXP = re.compile(r"(\w+)(\((.*)\))?")
|
|
140
|
-
"""Regular expression to match data types in the form "type(length)"""
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
def string_to_typeengine(
|
|
144
|
-
type_string: str, dialect: Dialect | None = None, length: int | None = None
|
|
145
|
-
) -> TypeEngine:
|
|
146
|
-
match = _DATATYPE_REGEXP.search(type_string)
|
|
147
|
-
if not match:
|
|
148
|
-
raise ValueError(f"Invalid type string: {type_string}")
|
|
149
|
-
|
|
150
|
-
type_name, _, params = match.groups()
|
|
151
|
-
if dialect is None:
|
|
152
|
-
type_class = getattr(sqa_types, type_name.upper(), None)
|
|
153
|
-
else:
|
|
154
|
-
try:
|
|
155
|
-
dialect_module = _DIALECT_MODULES[dialect.name]
|
|
156
|
-
except KeyError:
|
|
157
|
-
raise ValueError(f"Unsupported dialect: {dialect}")
|
|
158
|
-
type_class = getattr(dialect_module, type_name.upper(), None)
|
|
159
|
-
|
|
160
|
-
if not type_class:
|
|
161
|
-
raise ValueError(f"Unsupported type: {type_class}")
|
|
162
|
-
|
|
163
|
-
if params:
|
|
164
|
-
params = [int(param) if param.isdigit() else param for param in params.split(",")]
|
|
165
|
-
type_obj = type_class(*params)
|
|
166
|
-
else:
|
|
167
|
-
type_obj = type_class()
|
|
168
|
-
|
|
169
|
-
if hasattr(type_obj, "length") and getattr(type_obj, "length") is None and length is not None:
|
|
170
|
-
type_obj.length = length
|
|
171
|
-
|
|
172
|
-
return type_obj
|
|
173
|
-
|
|
174
|
-
|
|
175
126
|
class Column(BaseObject):
|
|
176
127
|
"""A column in a table."""
|
|
177
128
|
|
|
@@ -257,12 +208,11 @@ class Column(BaseObject):
|
|
|
257
208
|
raise ValueError(f"Invalid IVOA UCD: {e}")
|
|
258
209
|
return ivoa_ucd
|
|
259
210
|
|
|
260
|
-
@model_validator(mode="before")
|
|
261
|
-
|
|
262
|
-
def check_units(cls, values: dict[str, Any]) -> dict[str, Any]:
|
|
211
|
+
@model_validator(mode="after")
|
|
212
|
+
def check_units(self) -> Column:
|
|
263
213
|
"""Check that units are valid."""
|
|
264
|
-
fits_unit =
|
|
265
|
-
ivoa_unit =
|
|
214
|
+
fits_unit = self.fits_tunit
|
|
215
|
+
ivoa_unit = self.ivoa_unit
|
|
266
216
|
|
|
267
217
|
if fits_unit and ivoa_unit:
|
|
268
218
|
raise ValueError("Column cannot have both FITS and IVOA units")
|
|
@@ -274,7 +224,7 @@ class Column(BaseObject):
|
|
|
274
224
|
except ValueError as e:
|
|
275
225
|
raise ValueError(f"Invalid unit: {e}")
|
|
276
226
|
|
|
277
|
-
return values
|
|
227
|
+
return self
|
|
278
228
|
|
|
279
229
|
@model_validator(mode="before")
|
|
280
230
|
@classmethod
|
|
@@ -299,12 +249,15 @@ class Column(BaseObject):
|
|
|
299
249
|
return values
|
|
300
250
|
|
|
301
251
|
@model_validator(mode="after")
|
|
302
|
-
def
|
|
252
|
+
def check_redundant_datatypes(self, info: ValidationInfo) -> Column:
|
|
303
253
|
"""Check for redundant datatypes on columns."""
|
|
304
254
|
context = info.context
|
|
305
255
|
if not context or not context.get("check_redundant_datatypes", False):
|
|
306
256
|
return self
|
|
307
|
-
if all(
|
|
257
|
+
if all(
|
|
258
|
+
getattr(self, f"{dialect}:datatype", None) is not None
|
|
259
|
+
for dialect in get_supported_dialects().keys()
|
|
260
|
+
):
|
|
308
261
|
return self
|
|
309
262
|
|
|
310
263
|
datatype = self.datatype
|
|
@@ -317,7 +270,7 @@ class Column(BaseObject):
|
|
|
317
270
|
else:
|
|
318
271
|
datatype_obj = datatype_func()
|
|
319
272
|
|
|
320
|
-
for dialect_name, dialect in _DIALECTS.items():
|
|
273
|
+
for dialect_name, dialect in get_supported_dialects().items():
|
|
321
274
|
db_annotation = f"{dialect_name}_datatype"
|
|
322
275
|
if datatype_string := self.model_dump().get(db_annotation):
|
|
323
276
|
db_datatype_obj = string_to_typeengine(datatype_string, dialect, length)
|
|
@@ -465,6 +418,29 @@ class Table(BaseObject):
|
|
|
465
418
|
raise ValueError("Column names must be unique")
|
|
466
419
|
return columns
|
|
467
420
|
|
|
421
|
+
@model_validator(mode="after")
|
|
422
|
+
def check_tap_table_index(self, info: ValidationInfo) -> Table:
|
|
423
|
+
"""Check that the table has a TAP table index."""
|
|
424
|
+
context = info.context
|
|
425
|
+
if not context or not context.get("check_tap_table_indexes", False):
|
|
426
|
+
return self
|
|
427
|
+
if self.tap_table_index is None:
|
|
428
|
+
raise ValueError("Table is missing a TAP table index")
|
|
429
|
+
return self
|
|
430
|
+
|
|
431
|
+
@model_validator(mode="after")
|
|
432
|
+
def check_tap_principal(self, info: ValidationInfo) -> Table:
|
|
433
|
+
"""Check that at least one column is flagged as 'principal' for TAP
|
|
434
|
+
purposes.
|
|
435
|
+
"""
|
|
436
|
+
context = info.context
|
|
437
|
+
if not context or not context.get("check_tap_principal", False):
|
|
438
|
+
return self
|
|
439
|
+
for col in self.columns:
|
|
440
|
+
if col.tap_principal == 1:
|
|
441
|
+
return self
|
|
442
|
+
raise ValueError(f"Table '{self.name}' is missing at least one column designated as 'tap:principal'")
|
|
443
|
+
|
|
468
444
|
|
|
469
445
|
class SchemaVersion(BaseModel):
|
|
470
446
|
"""The version of the schema."""
|
|
@@ -554,6 +530,21 @@ class Schema(BaseObject):
|
|
|
554
530
|
raise ValueError("Table names must be unique")
|
|
555
531
|
return tables
|
|
556
532
|
|
|
533
|
+
@model_validator(mode="after")
|
|
534
|
+
def check_tap_table_indexes(self, info: ValidationInfo) -> Schema:
|
|
535
|
+
"""Check that the TAP table indexes are unique."""
|
|
536
|
+
context = info.context
|
|
537
|
+
if not context or not context.get("check_tap_table_indexes", False):
|
|
538
|
+
return self
|
|
539
|
+
table_indicies = set()
|
|
540
|
+
for table in self.tables:
|
|
541
|
+
table_index = table.tap_table_index
|
|
542
|
+
if table_index is not None:
|
|
543
|
+
if table_index in table_indicies:
|
|
544
|
+
raise ValueError(f"Duplicate 'tap:table_index' value {table_index} found in schema")
|
|
545
|
+
table_indicies.add(table_index)
|
|
546
|
+
return self
|
|
547
|
+
|
|
557
548
|
def _create_id_map(self: Schema) -> Schema:
|
|
558
549
|
"""Create a map of IDs to objects.
|
|
559
550
|
|
|
@@ -566,7 +557,6 @@ class Schema(BaseObject):
|
|
|
566
557
|
return self
|
|
567
558
|
visitor: SchemaIdVisitor = SchemaIdVisitor()
|
|
568
559
|
visitor.visit_schema(self)
|
|
569
|
-
logger.debug(f"Created schema ID map with {len(self.id_map.keys())} objects")
|
|
570
560
|
if len(visitor.duplicates):
|
|
571
561
|
raise ValueError(
|
|
572
562
|
"Duplicate IDs found in schema:\n " + "\n ".join(visitor.duplicates) + "\n"
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
# This file is part of felis.
|
|
2
|
+
#
|
|
3
|
+
# Developed for the LSST Data Management System.
|
|
4
|
+
# This product includes software developed by the LSST Project
|
|
5
|
+
# (https://www.lsst.org).
|
|
6
|
+
# See the COPYRIGHT file at the top-level directory of this distribution
|
|
7
|
+
# for details of code ownership.
|
|
8
|
+
#
|
|
9
|
+
# This program is free software: you can redistribute it and/or modify
|
|
10
|
+
# it under the terms of the GNU General Public License as published by
|
|
11
|
+
# the Free Software Foundation, either version 3 of the License, or
|
|
12
|
+
# (at your option) any later version.
|
|
13
|
+
#
|
|
14
|
+
# This program is distributed in the hope that it will be useful,
|
|
15
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
16
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
17
|
+
# GNU General Public License for more details.
|
|
18
|
+
#
|
|
19
|
+
# You should have received a copy of the GNU General Public License
|
|
20
|
+
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
21
|
+
|
|
22
|
+
import logging
|
|
23
|
+
from types import ModuleType
|
|
24
|
+
|
|
25
|
+
from sqlalchemy import dialects
|
|
26
|
+
from sqlalchemy.engine import Dialect
|
|
27
|
+
from sqlalchemy.engine.mock import create_mock_engine
|
|
28
|
+
|
|
29
|
+
from .sqltypes import MYSQL, ORACLE, POSTGRES, SQLITE
|
|
30
|
+
|
|
31
|
+
logger = logging.getLogger(__name__)
|
|
32
|
+
|
|
33
|
+
_DIALECT_NAMES = [MYSQL, POSTGRES, SQLITE, ORACLE]
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _dialect(dialect_name: str) -> Dialect:
|
|
37
|
+
"""Create the SQLAlchemy dialect for the given name."""
|
|
38
|
+
return create_mock_engine(f"{dialect_name}://", executor=None).dialect
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
_DIALECTS = {name: _dialect(name) for name in _DIALECT_NAMES}
|
|
42
|
+
"""Dictionary of dialect names to SQLAlchemy dialects."""
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def get_supported_dialects() -> dict[str, Dialect]:
|
|
46
|
+
"""Get a dictionary of the supported SQLAlchemy dialects."""
|
|
47
|
+
return _DIALECTS
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def _dialect_module(dialect_name: str) -> ModuleType:
|
|
51
|
+
"""Get the SQLAlchemy dialect module for the given name."""
|
|
52
|
+
return getattr(dialects, dialect_name)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
_DIALECT_MODULES = {name: _dialect_module(name) for name in _DIALECT_NAMES}
|
|
56
|
+
"""Dictionary of dialect names to SQLAlchemy modules for type instantiation."""
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def get_dialect_module(dialect_name: str) -> ModuleType:
|
|
60
|
+
"""Get the SQLAlchemy dialect module for the given name."""
|
|
61
|
+
if dialect_name not in _DIALECT_MODULES:
|
|
62
|
+
raise ValueError(f"Unsupported dialect: {dialect_name}")
|
|
63
|
+
return _DIALECT_MODULES[dialect_name]
|
|
@@ -0,0 +1,248 @@
|
|
|
1
|
+
# This file is part of felis.
|
|
2
|
+
#
|
|
3
|
+
# Developed for the LSST Data Management System.
|
|
4
|
+
# This product includes software developed by the LSST Project
|
|
5
|
+
# (https://www.lsst.org).
|
|
6
|
+
# See the COPYRIGHT file at the top-level directory of this distribution
|
|
7
|
+
# for details of code ownership.
|
|
8
|
+
#
|
|
9
|
+
# This program is free software: you can redistribute it and/or modify
|
|
10
|
+
# it under the terms of the GNU General Public License as published by
|
|
11
|
+
# the Free Software Foundation, either version 3 of the License, or
|
|
12
|
+
# (at your option) any later version.
|
|
13
|
+
#
|
|
14
|
+
# This program is distributed in the hope that it will be useful,
|
|
15
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
16
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
17
|
+
# GNU General Public License for more details.
|
|
18
|
+
#
|
|
19
|
+
# You should have received a copy of the GNU General Public License
|
|
20
|
+
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
21
|
+
|
|
22
|
+
from __future__ import annotations
|
|
23
|
+
|
|
24
|
+
import logging
|
|
25
|
+
import re
|
|
26
|
+
from typing import IO, Any
|
|
27
|
+
|
|
28
|
+
from sqlalchemy import MetaData, types
|
|
29
|
+
from sqlalchemy.engine import Dialect, Engine, ResultProxy
|
|
30
|
+
from sqlalchemy.engine.mock import MockConnection, create_mock_engine
|
|
31
|
+
from sqlalchemy.engine.url import URL
|
|
32
|
+
from sqlalchemy.exc import SQLAlchemyError
|
|
33
|
+
from sqlalchemy.schema import CreateSchema, DropSchema
|
|
34
|
+
from sqlalchemy.sql import text
|
|
35
|
+
from sqlalchemy.types import TypeEngine
|
|
36
|
+
|
|
37
|
+
from .dialects import get_dialect_module
|
|
38
|
+
|
|
39
|
+
logger = logging.getLogger("felis")
|
|
40
|
+
|
|
41
|
+
_DATATYPE_REGEXP = re.compile(r"(\w+)(\((.*)\))?")
|
|
42
|
+
"""Regular expression to match data types in the form "type(length)"""
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def string_to_typeengine(
|
|
46
|
+
type_string: str, dialect: Dialect | None = None, length: int | None = None
|
|
47
|
+
) -> TypeEngine:
|
|
48
|
+
"""Convert a string representation of a data type to a SQLAlchemy
|
|
49
|
+
TypeEngine.
|
|
50
|
+
"""
|
|
51
|
+
match = _DATATYPE_REGEXP.search(type_string)
|
|
52
|
+
if not match:
|
|
53
|
+
raise ValueError(f"Invalid type string: {type_string}")
|
|
54
|
+
|
|
55
|
+
type_name, _, params = match.groups()
|
|
56
|
+
if dialect is None:
|
|
57
|
+
type_class = getattr(types, type_name.upper(), None)
|
|
58
|
+
else:
|
|
59
|
+
try:
|
|
60
|
+
dialect_module = get_dialect_module(dialect.name)
|
|
61
|
+
except KeyError:
|
|
62
|
+
raise ValueError(f"Unsupported dialect: {dialect}")
|
|
63
|
+
type_class = getattr(dialect_module, type_name.upper(), None)
|
|
64
|
+
|
|
65
|
+
if not type_class:
|
|
66
|
+
raise ValueError(f"Unsupported type: {type_class}")
|
|
67
|
+
|
|
68
|
+
if params:
|
|
69
|
+
params = [int(param) if param.isdigit() else param for param in params.split(",")]
|
|
70
|
+
type_obj = type_class(*params)
|
|
71
|
+
else:
|
|
72
|
+
type_obj = type_class()
|
|
73
|
+
|
|
74
|
+
if hasattr(type_obj, "length") and getattr(type_obj, "length") is None and length is not None:
|
|
75
|
+
type_obj.length = length
|
|
76
|
+
|
|
77
|
+
return type_obj
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class SQLWriter:
|
|
81
|
+
"""Writes SQL statements to stdout or a file."""
|
|
82
|
+
|
|
83
|
+
def __init__(self, file: IO[str] | None = None) -> None:
|
|
84
|
+
"""Initialize the SQL writer.
|
|
85
|
+
|
|
86
|
+
Parameters
|
|
87
|
+
----------
|
|
88
|
+
file : `io.TextIOBase` or `None`, optional
|
|
89
|
+
The file to write the SQL statements to. If None, the statements
|
|
90
|
+
will be written to stdout.
|
|
91
|
+
"""
|
|
92
|
+
self.file = file
|
|
93
|
+
self.dialect: Dialect | None = None
|
|
94
|
+
|
|
95
|
+
def write(self, sql: Any, *multiparams: Any, **params: Any) -> None:
|
|
96
|
+
"""Write the SQL statement to a file or stdout.
|
|
97
|
+
|
|
98
|
+
Statements with parameters will be formatted with the values
|
|
99
|
+
inserted into the resultant SQL output.
|
|
100
|
+
|
|
101
|
+
Parameters
|
|
102
|
+
----------
|
|
103
|
+
sql : `typing.Any`
|
|
104
|
+
The SQL statement to write.
|
|
105
|
+
multiparams : `typing.Any`
|
|
106
|
+
The multiparams to use for the SQL statement.
|
|
107
|
+
params : `typing.Any`
|
|
108
|
+
The params to use for the SQL statement.
|
|
109
|
+
"""
|
|
110
|
+
compiled = sql.compile(dialect=self.dialect)
|
|
111
|
+
sql_str = str(compiled) + ";"
|
|
112
|
+
params_list = [compiled.params]
|
|
113
|
+
for params in params_list:
|
|
114
|
+
if not params:
|
|
115
|
+
print(sql_str, file=self.file)
|
|
116
|
+
continue
|
|
117
|
+
new_params = {}
|
|
118
|
+
for key, value in params.items():
|
|
119
|
+
if isinstance(value, str):
|
|
120
|
+
new_params[key] = f"'{value}'"
|
|
121
|
+
elif value is None:
|
|
122
|
+
new_params[key] = "null"
|
|
123
|
+
else:
|
|
124
|
+
new_params[key] = value
|
|
125
|
+
print(sql_str % new_params, file=self.file)
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
class ConnectionWrapper:
|
|
129
|
+
"""A wrapper for a SQLAlchemy engine or mock connection which provides a
|
|
130
|
+
consistent interface for executing SQL statements.
|
|
131
|
+
"""
|
|
132
|
+
|
|
133
|
+
def __init__(self, engine: Engine | MockConnection):
|
|
134
|
+
"""Initialize the connection wrapper.
|
|
135
|
+
|
|
136
|
+
Parameters
|
|
137
|
+
----------
|
|
138
|
+
engine : `sqlalchemy.Engine` or `sqlalchemy.MockConnection`
|
|
139
|
+
The SQLAlchemy engine or mock connection to wrap.
|
|
140
|
+
"""
|
|
141
|
+
self.engine = engine
|
|
142
|
+
|
|
143
|
+
def execute(self, statement: Any) -> ResultProxy:
|
|
144
|
+
"""Execute a SQL statement on the engine and return the result."""
|
|
145
|
+
if isinstance(statement, str):
|
|
146
|
+
statement = text(statement)
|
|
147
|
+
if isinstance(self.engine, MockConnection):
|
|
148
|
+
return self.engine.connect().execute(statement)
|
|
149
|
+
else:
|
|
150
|
+
with self.engine.begin() as connection:
|
|
151
|
+
result = connection.execute(statement)
|
|
152
|
+
return result
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
class DatabaseContext:
|
|
156
|
+
"""A class for managing the schema and its database connection."""
|
|
157
|
+
|
|
158
|
+
def __init__(self, metadata: MetaData, engine: Engine | MockConnection):
|
|
159
|
+
"""Initialize the database context.
|
|
160
|
+
|
|
161
|
+
Parameters
|
|
162
|
+
----------
|
|
163
|
+
metadata : `sqlalchemy.MetaData`
|
|
164
|
+
The SQLAlchemy metadata object.
|
|
165
|
+
|
|
166
|
+
engine : `sqlalchemy.Engine` or `sqlalchemy.MockConnection`
|
|
167
|
+
The SQLAlchemy engine or mock connection object.
|
|
168
|
+
"""
|
|
169
|
+
self.engine = engine
|
|
170
|
+
self.dialect_name = engine.dialect.name
|
|
171
|
+
self.metadata = metadata
|
|
172
|
+
self.conn = ConnectionWrapper(engine)
|
|
173
|
+
|
|
174
|
+
def create_if_not_exists(self) -> None:
|
|
175
|
+
"""Create the schema in the database if it does not exist.
|
|
176
|
+
|
|
177
|
+
In MySQL, this will create a new database. In PostgreSQL, it will
|
|
178
|
+
create a new schema. For other variants, this is an unsupported
|
|
179
|
+
operation.
|
|
180
|
+
|
|
181
|
+
Parameters
|
|
182
|
+
----------
|
|
183
|
+
engine: `sqlalchemy.Engine`
|
|
184
|
+
The SQLAlchemy engine object.
|
|
185
|
+
schema_name: `str`
|
|
186
|
+
The name of the schema (or database) to create.
|
|
187
|
+
"""
|
|
188
|
+
schema_name = self.metadata.schema
|
|
189
|
+
try:
|
|
190
|
+
if self.dialect_name == "mysql":
|
|
191
|
+
logger.debug(f"Creating MySQL database: {schema_name}")
|
|
192
|
+
self.conn.execute(text(f"CREATE DATABASE IF NOT EXISTS {schema_name}"))
|
|
193
|
+
elif self.dialect_name == "postgresql":
|
|
194
|
+
logger.debug(f"Creating PG schema: {schema_name}")
|
|
195
|
+
self.conn.execute(CreateSchema(schema_name, if_not_exists=True))
|
|
196
|
+
else:
|
|
197
|
+
raise ValueError("Unsupported database type:" + self.dialect_name)
|
|
198
|
+
except SQLAlchemyError as e:
|
|
199
|
+
logger.error(f"Error creating schema: {e}")
|
|
200
|
+
raise
|
|
201
|
+
|
|
202
|
+
def drop_if_exists(self) -> None:
|
|
203
|
+
"""Drop the schema in the database if it exists.
|
|
204
|
+
|
|
205
|
+
In MySQL, this will drop a database. In PostgreSQL, it will drop a
|
|
206
|
+
schema. For other variants, this is unsupported for now.
|
|
207
|
+
|
|
208
|
+
Parameters
|
|
209
|
+
----------
|
|
210
|
+
engine: `sqlalchemy.Engine`
|
|
211
|
+
The SQLAlchemy engine object.
|
|
212
|
+
schema_name: `str`
|
|
213
|
+
The name of the schema (or database) to drop.
|
|
214
|
+
"""
|
|
215
|
+
schema_name = self.metadata.schema
|
|
216
|
+
try:
|
|
217
|
+
if self.dialect_name == "mysql":
|
|
218
|
+
logger.debug(f"Dropping MySQL database if exists: {schema_name}")
|
|
219
|
+
self.conn.execute(text(f"DROP DATABASE IF EXISTS {schema_name}"))
|
|
220
|
+
elif self.dialect_name == "postgresql":
|
|
221
|
+
logger.debug(f"Dropping PostgreSQL schema if exists: {schema_name}")
|
|
222
|
+
self.conn.execute(DropSchema(schema_name, if_exists=True, cascade=True))
|
|
223
|
+
else:
|
|
224
|
+
raise ValueError(f"Unsupported database type: {self.dialect_name}")
|
|
225
|
+
except SQLAlchemyError as e:
|
|
226
|
+
logger.error(f"Error dropping schema: {e}")
|
|
227
|
+
raise
|
|
228
|
+
|
|
229
|
+
def create_all(self) -> None:
|
|
230
|
+
"""Create all tables in the schema using the metadata object."""
|
|
231
|
+
self.metadata.create_all(self.engine)
|
|
232
|
+
|
|
233
|
+
@staticmethod
|
|
234
|
+
def create_mock_engine(engine_url: str | URL, output_file: IO[str] | None = None) -> MockConnection:
|
|
235
|
+
"""Create a mock engine for testing or dumping DDL statements.
|
|
236
|
+
|
|
237
|
+
Parameters
|
|
238
|
+
----------
|
|
239
|
+
engine_url : `sqlalchemy.engine.url.URL`
|
|
240
|
+
The SQLAlchemy engine URL.
|
|
241
|
+
output_file : `typing.IO` [ `str` ] or `None`, optional
|
|
242
|
+
The file to write the SQL statements to. If None, the statements
|
|
243
|
+
will be written to stdout.
|
|
244
|
+
"""
|
|
245
|
+
writer = SQLWriter(output_file)
|
|
246
|
+
engine = create_mock_engine(engine_url, executor=writer.write)
|
|
247
|
+
writer.dialect = engine.dialect
|
|
248
|
+
return engine
|