lsst-felis 27.2024.2300__tar.gz → 27.2024.2500__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of lsst-felis might be problematic. Click here for more details.

Files changed (34)
  1. {lsst_felis-27.2024.2300/python/lsst_felis.egg-info → lsst_felis-27.2024.2500}/PKG-INFO +1 -1
  2. {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500}/python/felis/cli.py +27 -30
  3. {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500}/python/felis/datamodel.py +52 -62
  4. lsst_felis-27.2024.2500/python/felis/db/dialects.py +63 -0
  5. lsst_felis-27.2024.2500/python/felis/db/utils.py +248 -0
  6. lsst_felis-27.2024.2300/python/felis/db/_variants.py → lsst_felis-27.2024.2500/python/felis/db/variants.py +29 -22
  7. {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500}/python/felis/metadata.py +2 -185
  8. lsst_felis-27.2024.2500/python/felis/version.py +2 -0
  9. {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500/python/lsst_felis.egg-info}/PKG-INFO +1 -1
  10. {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500}/python/lsst_felis.egg-info/SOURCES.txt +4 -4
  11. {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500}/tests/test_cli.py +12 -20
  12. {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500}/tests/test_datamodel.py +115 -54
  13. {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500}/tests/test_metadata.py +2 -1
  14. lsst_felis-27.2024.2300/python/felis/validation.py +0 -103
  15. lsst_felis-27.2024.2300/python/felis/version.py +0 -2
  16. lsst_felis-27.2024.2300/tests/test_validation.py +0 -233
  17. {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500}/COPYRIGHT +0 -0
  18. {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500}/LICENSE +0 -0
  19. {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500}/README.rst +0 -0
  20. {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500}/pyproject.toml +0 -0
  21. {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500}/python/felis/__init__.py +0 -0
  22. {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500}/python/felis/db/__init__.py +0 -0
  23. {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500}/python/felis/db/sqltypes.py +0 -0
  24. {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500}/python/felis/py.typed +0 -0
  25. {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500}/python/felis/tap.py +0 -0
  26. {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500}/python/felis/types.py +0 -0
  27. {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500}/python/lsst_felis.egg-info/dependency_links.txt +0 -0
  28. {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500}/python/lsst_felis.egg-info/entry_points.txt +0 -0
  29. {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500}/python/lsst_felis.egg-info/requires.txt +0 -0
  30. {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500}/python/lsst_felis.egg-info/top_level.txt +0 -0
  31. {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500}/python/lsst_felis.egg-info/zip-safe +0 -0
  32. {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500}/setup.cfg +0 -0
  33. {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500}/tests/test_datatypes.py +0 -0
  34. {lsst_felis-27.2024.2300 → lsst_felis-27.2024.2500}/tests/test_tap.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: lsst-felis
3
- Version: 27.2024.2300
3
+ Version: 27.2024.2500
4
4
  Summary: A vocabulary for describing catalogs and acting on those descriptions
5
5
  Author-email: Rubin Observatory Data Management <dm-admin@lists.lsst.org>
6
6
  License: GNU General Public License v3 or later (GPLv3+)
@@ -29,14 +29,14 @@ from typing import IO
29
29
  import click
30
30
  import yaml
31
31
  from pydantic import ValidationError
32
- from sqlalchemy.engine import Engine, create_engine, create_mock_engine, make_url
32
+ from sqlalchemy.engine import Engine, create_engine, make_url
33
33
  from sqlalchemy.engine.mock import MockConnection
34
34
 
35
35
  from . import __version__
36
36
  from .datamodel import Schema
37
- from .metadata import DatabaseContext, InsertDump, MetaDataBuilder
37
+ from .db.utils import DatabaseContext
38
+ from .metadata import MetaDataBuilder
38
39
  from .tap import Tap11Base, TapLoadingVisitor, init_tables
39
- from .validation import get_schema
40
40
 
41
41
  logger = logging.getLogger("felis")
42
42
 
@@ -92,29 +92,27 @@ def create(
92
92
  """Create database objects from the Felis file."""
93
93
  yaml_data = yaml.safe_load(file)
94
94
  schema = Schema.model_validate(yaml_data)
95
- url_obj = make_url(engine_url)
95
+ url = make_url(engine_url)
96
96
  if schema_name:
97
97
  logger.info(f"Overriding schema name with: {schema_name}")
98
98
  schema.name = schema_name
99
- elif url_obj.drivername == "sqlite":
99
+ elif url.drivername == "sqlite":
100
100
  logger.info("Overriding schema name for sqlite with: main")
101
101
  schema.name = "main"
102
- if not url_obj.host and not url_obj.drivername == "sqlite":
102
+ if not url.host and not url.drivername == "sqlite":
103
103
  dry_run = True
104
104
  logger.info("Forcing dry run for non-sqlite engine URL with no host")
105
105
 
106
- builder = MetaDataBuilder(schema)
107
- builder.build()
108
- metadata = builder.metadata
106
+ metadata = MetaDataBuilder(schema).build()
109
107
  logger.debug(f"Created metadata with schema name: {metadata.schema}")
110
108
 
111
109
  engine: Engine | MockConnection
112
110
  if not dry_run and not output_file:
113
- engine = create_engine(engine_url, echo=echo)
111
+ engine = create_engine(url, echo=echo)
114
112
  else:
115
113
  if dry_run:
116
114
  logger.info("Dry run will be executed")
117
- engine = DatabaseContext.create_mock_engine(url_obj, output_file)
115
+ engine = DatabaseContext.create_mock_engine(url, output_file)
118
116
  if output_file:
119
117
  logger.info("Writing SQL output to: " + output_file.name)
120
118
 
@@ -229,10 +227,7 @@ def load_tap(
229
227
  )
230
228
  tap_visitor.visit_schema(schema)
231
229
  else:
232
- _insert_dump = InsertDump()
233
- conn = create_mock_engine(make_url(engine_url), executor=_insert_dump.dump, paramstyle="pyformat")
234
- # After the engine is created, update the executor with the dialect
235
- _insert_dump.dialect = conn.dialect
230
+ conn = DatabaseContext.create_mock_engine(engine_url)
236
231
 
237
232
  tap_visitor = TapLoadingVisitor.from_mock_connection(
238
233
  conn,
@@ -245,42 +240,44 @@ def load_tap(
245
240
 
246
241
 
247
242
  @cli.command("validate")
243
+ @click.option("--check-description", is_flag=True, help="Require description for all objects", default=False)
248
244
  @click.option(
249
- "-s",
250
- "--schema-name",
251
- help="Schema name for validation",
252
- type=click.Choice(["RSP", "default"]),
253
- default="default",
245
+ "--check-redundant-datatypes", is_flag=True, help="Check for redundant datatypes", default=False
254
246
  )
255
247
  @click.option(
256
- "-d", "--require-description", is_flag=True, help="Require description for all objects", default=False
248
+ "--check-tap-table-indexes",
249
+ is_flag=True,
250
+ help="Check that every table has a unique TAP table index",
251
+ default=False,
257
252
  )
258
253
  @click.option(
259
- "-t", "--check-redundant-datatypes", is_flag=True, help="Check for redundant datatypes", default=False
254
+ "--check-tap-principal",
255
+ is_flag=True,
256
+ help="Check that at least one column per table is flagged as TAP principal",
257
+ default=False,
260
258
  )
261
259
  @click.argument("files", nargs=-1, type=click.File())
262
260
  def validate(
263
- schema_name: str,
264
- require_description: bool,
261
+ check_description: bool,
265
262
  check_redundant_datatypes: bool,
263
+ check_tap_table_indexes: bool,
264
+ check_tap_principal: bool,
266
265
  files: Iterable[io.TextIOBase],
267
266
  ) -> None:
268
267
  """Validate one or more felis YAML files."""
269
- schema_class = get_schema(schema_name)
270
- if schema_name != "default":
271
- logger.info(f"Using schema '{schema_class.__name__}'")
272
-
273
268
  rc = 0
274
269
  for file in files:
275
270
  file_name = getattr(file, "name", None)
276
271
  logger.info(f"Validating {file_name}")
277
272
  try:
278
273
  data = yaml.load(file, Loader=yaml.SafeLoader)
279
- schema_class.model_validate(
274
+ Schema.model_validate(
280
275
  data,
281
276
  context={
277
+ "check_description": check_description,
282
278
  "check_redundant_datatypes": check_redundant_datatypes,
283
- "require_description": require_description,
279
+ "check_tap_table_indexes": check_tap_table_indexes,
280
+ "check_tap_principal": check_tap_principal,
284
281
  },
285
282
  )
286
283
  except ValidationError as e:
@@ -22,7 +22,6 @@
22
22
  from __future__ import annotations
23
23
 
24
24
  import logging
25
- import re
26
25
  from collections.abc import Mapping, Sequence
27
26
  from enum import StrEnum, auto
28
27
  from typing import Annotated, Any, Literal, TypeAlias
@@ -30,13 +29,10 @@ from typing import Annotated, Any, Literal, TypeAlias
30
29
  from astropy import units as units # type: ignore
31
30
  from astropy.io.votable import ucd # type: ignore
32
31
  from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator, model_validator
33
- from sqlalchemy import dialects
34
- from sqlalchemy import types as sqa_types
35
- from sqlalchemy.engine import create_mock_engine
36
- from sqlalchemy.engine.interfaces import Dialect
37
- from sqlalchemy.types import TypeEngine
38
32
 
33
+ from .db.dialects import get_supported_dialects
39
34
  from .db.sqltypes import get_type_func
35
+ from .db.utils import string_to_typeengine
40
36
  from .types import Boolean, Byte, Char, Double, FelisType, Float, Int, Long, Short, String, Text, Unicode
41
37
 
42
38
  logger = logging.getLogger(__name__)
@@ -100,7 +96,7 @@ class BaseObject(BaseModel):
100
96
  def check_description(self, info: ValidationInfo) -> BaseObject:
101
97
  """Check that the description is present if required."""
102
98
  context = info.context
103
- if not context or not context.get("require_description", False):
99
+ if not context or not context.get("check_description", False):
104
100
  return self
105
101
  if self.description is None or self.description == "":
106
102
  raise ValueError("Description is required and must be non-empty")
@@ -127,51 +123,6 @@ class DataType(StrEnum):
127
123
  timestamp = auto()
128
124
 
129
125
 
130
- _DIALECTS = {
131
- "mysql": create_mock_engine("mysql://", executor=None).dialect,
132
- "postgresql": create_mock_engine("postgresql://", executor=None).dialect,
133
- }
134
- """Dictionary of dialect names to SQLAlchemy dialects."""
135
-
136
- _DIALECT_MODULES = {"mysql": getattr(dialects, "mysql"), "postgresql": getattr(dialects, "postgresql")}
137
- """Dictionary of dialect names to SQLAlchemy dialect modules."""
138
-
139
- _DATATYPE_REGEXP = re.compile(r"(\w+)(\((.*)\))?")
140
- """Regular expression to match data types in the form "type(length)"""
141
-
142
-
143
- def string_to_typeengine(
144
- type_string: str, dialect: Dialect | None = None, length: int | None = None
145
- ) -> TypeEngine:
146
- match = _DATATYPE_REGEXP.search(type_string)
147
- if not match:
148
- raise ValueError(f"Invalid type string: {type_string}")
149
-
150
- type_name, _, params = match.groups()
151
- if dialect is None:
152
- type_class = getattr(sqa_types, type_name.upper(), None)
153
- else:
154
- try:
155
- dialect_module = _DIALECT_MODULES[dialect.name]
156
- except KeyError:
157
- raise ValueError(f"Unsupported dialect: {dialect}")
158
- type_class = getattr(dialect_module, type_name.upper(), None)
159
-
160
- if not type_class:
161
- raise ValueError(f"Unsupported type: {type_class}")
162
-
163
- if params:
164
- params = [int(param) if param.isdigit() else param for param in params.split(",")]
165
- type_obj = type_class(*params)
166
- else:
167
- type_obj = type_class()
168
-
169
- if hasattr(type_obj, "length") and getattr(type_obj, "length") is None and length is not None:
170
- type_obj.length = length
171
-
172
- return type_obj
173
-
174
-
175
126
  class Column(BaseObject):
176
127
  """A column in a table."""
177
128
 
@@ -257,12 +208,11 @@ class Column(BaseObject):
257
208
  raise ValueError(f"Invalid IVOA UCD: {e}")
258
209
  return ivoa_ucd
259
210
 
260
- @model_validator(mode="before")
261
- @classmethod
262
- def check_units(cls, values: dict[str, Any]) -> dict[str, Any]:
211
+ @model_validator(mode="after")
212
+ def check_units(self) -> Column:
263
213
  """Check that units are valid."""
264
- fits_unit = values.get("fits:tunit")
265
- ivoa_unit = values.get("ivoa:unit")
214
+ fits_unit = self.fits_tunit
215
+ ivoa_unit = self.ivoa_unit
266
216
 
267
217
  if fits_unit and ivoa_unit:
268
218
  raise ValueError("Column cannot have both FITS and IVOA units")
@@ -274,7 +224,7 @@ class Column(BaseObject):
274
224
  except ValueError as e:
275
225
  raise ValueError(f"Invalid unit: {e}")
276
226
 
277
- return values
227
+ return self
278
228
 
279
229
  @model_validator(mode="before")
280
230
  @classmethod
@@ -299,12 +249,15 @@ class Column(BaseObject):
299
249
  return values
300
250
 
301
251
  @model_validator(mode="after")
302
- def check_datatypes(self, info: ValidationInfo) -> Column:
252
+ def check_redundant_datatypes(self, info: ValidationInfo) -> Column:
303
253
  """Check for redundant datatypes on columns."""
304
254
  context = info.context
305
255
  if not context or not context.get("check_redundant_datatypes", False):
306
256
  return self
307
- if all(getattr(self, f"{dialect}:datatype", None) is not None for dialect in _DIALECTS.keys()):
257
+ if all(
258
+ getattr(self, f"{dialect}:datatype", None) is not None
259
+ for dialect in get_supported_dialects().keys()
260
+ ):
308
261
  return self
309
262
 
310
263
  datatype = self.datatype
@@ -317,7 +270,7 @@ class Column(BaseObject):
317
270
  else:
318
271
  datatype_obj = datatype_func()
319
272
 
320
- for dialect_name, dialect in _DIALECTS.items():
273
+ for dialect_name, dialect in get_supported_dialects().items():
321
274
  db_annotation = f"{dialect_name}_datatype"
322
275
  if datatype_string := self.model_dump().get(db_annotation):
323
276
  db_datatype_obj = string_to_typeengine(datatype_string, dialect, length)
@@ -465,6 +418,29 @@ class Table(BaseObject):
465
418
  raise ValueError("Column names must be unique")
466
419
  return columns
467
420
 
421
+ @model_validator(mode="after")
422
+ def check_tap_table_index(self, info: ValidationInfo) -> Table:
423
+ """Check that the table has a TAP table index."""
424
+ context = info.context
425
+ if not context or not context.get("check_tap_table_indexes", False):
426
+ return self
427
+ if self.tap_table_index is None:
428
+ raise ValueError("Table is missing a TAP table index")
429
+ return self
430
+
431
+ @model_validator(mode="after")
432
+ def check_tap_principal(self, info: ValidationInfo) -> Table:
433
+ """Check that at least one column is flagged as 'principal' for TAP
434
+ purposes.
435
+ """
436
+ context = info.context
437
+ if not context or not context.get("check_tap_principal", False):
438
+ return self
439
+ for col in self.columns:
440
+ if col.tap_principal == 1:
441
+ return self
442
+ raise ValueError(f"Table '{self.name}' is missing at least one column designated as 'tap:principal'")
443
+
468
444
 
469
445
  class SchemaVersion(BaseModel):
470
446
  """The version of the schema."""
@@ -554,6 +530,21 @@ class Schema(BaseObject):
554
530
  raise ValueError("Table names must be unique")
555
531
  return tables
556
532
 
533
+ @model_validator(mode="after")
534
+ def check_tap_table_indexes(self, info: ValidationInfo) -> Schema:
535
+ """Check that the TAP table indexes are unique."""
536
+ context = info.context
537
+ if not context or not context.get("check_tap_table_indexes", False):
538
+ return self
539
+ table_indicies = set()
540
+ for table in self.tables:
541
+ table_index = table.tap_table_index
542
+ if table_index is not None:
543
+ if table_index in table_indicies:
544
+ raise ValueError(f"Duplicate 'tap:table_index' value {table_index} found in schema")
545
+ table_indicies.add(table_index)
546
+ return self
547
+
557
548
  def _create_id_map(self: Schema) -> Schema:
558
549
  """Create a map of IDs to objects.
559
550
 
@@ -566,7 +557,6 @@ class Schema(BaseObject):
566
557
  return self
567
558
  visitor: SchemaIdVisitor = SchemaIdVisitor()
568
559
  visitor.visit_schema(self)
569
- logger.debug(f"Created schema ID map with {len(self.id_map.keys())} objects")
570
560
  if len(visitor.duplicates):
571
561
  raise ValueError(
572
562
  "Duplicate IDs found in schema:\n " + "\n ".join(visitor.duplicates) + "\n"
@@ -0,0 +1,63 @@
1
+ # This file is part of felis.
2
+ #
3
+ # Developed for the LSST Data Management System.
4
+ # This product includes software developed by the LSST Project
5
+ # (https://www.lsst.org).
6
+ # See the COPYRIGHT file at the top-level directory of this distribution
7
+ # for details of code ownership.
8
+ #
9
+ # This program is free software: you can redistribute it and/or modify
10
+ # it under the terms of the GNU General Public License as published by
11
+ # the Free Software Foundation, either version 3 of the License, or
12
+ # (at your option) any later version.
13
+ #
14
+ # This program is distributed in the hope that it will be useful,
15
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
16
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17
+ # GNU General Public License for more details.
18
+ #
19
+ # You should have received a copy of the GNU General Public License
20
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
21
+
22
+ import logging
23
+ from types import ModuleType
24
+
25
+ from sqlalchemy import dialects
26
+ from sqlalchemy.engine import Dialect
27
+ from sqlalchemy.engine.mock import create_mock_engine
28
+
29
+ from .sqltypes import MYSQL, ORACLE, POSTGRES, SQLITE
30
+
31
+ logger = logging.getLogger(__name__)
32
+
33
+ _DIALECT_NAMES = [MYSQL, POSTGRES, SQLITE, ORACLE]
34
+
35
+
36
+ def _dialect(dialect_name: str) -> Dialect:
37
+ """Create the SQLAlchemy dialect for the given name."""
38
+ return create_mock_engine(f"{dialect_name}://", executor=None).dialect
39
+
40
+
41
+ _DIALECTS = {name: _dialect(name) for name in _DIALECT_NAMES}
42
+ """Dictionary of dialect names to SQLAlchemy dialects."""
43
+
44
+
45
+ def get_supported_dialects() -> dict[str, Dialect]:
46
+ """Get a dictionary of the supported SQLAlchemy dialects."""
47
+ return _DIALECTS
48
+
49
+
50
+ def _dialect_module(dialect_name: str) -> ModuleType:
51
+ """Get the SQLAlchemy dialect module for the given name."""
52
+ return getattr(dialects, dialect_name)
53
+
54
+
55
+ _DIALECT_MODULES = {name: _dialect_module(name) for name in _DIALECT_NAMES}
56
+ """Dictionary of dialect names to SQLAlchemy modules for type instantiation."""
57
+
58
+
59
+ def get_dialect_module(dialect_name: str) -> ModuleType:
60
+ """Get the SQLAlchemy dialect module for the given name."""
61
+ if dialect_name not in _DIALECT_MODULES:
62
+ raise ValueError(f"Unsupported dialect: {dialect_name}")
63
+ return _DIALECT_MODULES[dialect_name]
@@ -0,0 +1,248 @@
1
+ # This file is part of felis.
2
+ #
3
+ # Developed for the LSST Data Management System.
4
+ # This product includes software developed by the LSST Project
5
+ # (https://www.lsst.org).
6
+ # See the COPYRIGHT file at the top-level directory of this distribution
7
+ # for details of code ownership.
8
+ #
9
+ # This program is free software: you can redistribute it and/or modify
10
+ # it under the terms of the GNU General Public License as published by
11
+ # the Free Software Foundation, either version 3 of the License, or
12
+ # (at your option) any later version.
13
+ #
14
+ # This program is distributed in the hope that it will be useful,
15
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
16
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17
+ # GNU General Public License for more details.
18
+ #
19
+ # You should have received a copy of the GNU General Public License
20
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
21
+
22
+ from __future__ import annotations
23
+
24
+ import logging
25
+ import re
26
+ from typing import IO, Any
27
+
28
+ from sqlalchemy import MetaData, types
29
+ from sqlalchemy.engine import Dialect, Engine, ResultProxy
30
+ from sqlalchemy.engine.mock import MockConnection, create_mock_engine
31
+ from sqlalchemy.engine.url import URL
32
+ from sqlalchemy.exc import SQLAlchemyError
33
+ from sqlalchemy.schema import CreateSchema, DropSchema
34
+ from sqlalchemy.sql import text
35
+ from sqlalchemy.types import TypeEngine
36
+
37
+ from .dialects import get_dialect_module
38
+
39
+ logger = logging.getLogger("felis")
40
+
41
+ _DATATYPE_REGEXP = re.compile(r"(\w+)(\((.*)\))?")
42
+ """Regular expression to match data types in the form "type(length)"""
43
+
44
+
45
+ def string_to_typeengine(
46
+ type_string: str, dialect: Dialect | None = None, length: int | None = None
47
+ ) -> TypeEngine:
48
+ """Convert a string representation of a data type to a SQLAlchemy
49
+ TypeEngine.
50
+ """
51
+ match = _DATATYPE_REGEXP.search(type_string)
52
+ if not match:
53
+ raise ValueError(f"Invalid type string: {type_string}")
54
+
55
+ type_name, _, params = match.groups()
56
+ if dialect is None:
57
+ type_class = getattr(types, type_name.upper(), None)
58
+ else:
59
+ try:
60
+ dialect_module = get_dialect_module(dialect.name)
61
+ except KeyError:
62
+ raise ValueError(f"Unsupported dialect: {dialect}")
63
+ type_class = getattr(dialect_module, type_name.upper(), None)
64
+
65
+ if not type_class:
66
+ raise ValueError(f"Unsupported type: {type_class}")
67
+
68
+ if params:
69
+ params = [int(param) if param.isdigit() else param for param in params.split(",")]
70
+ type_obj = type_class(*params)
71
+ else:
72
+ type_obj = type_class()
73
+
74
+ if hasattr(type_obj, "length") and getattr(type_obj, "length") is None and length is not None:
75
+ type_obj.length = length
76
+
77
+ return type_obj
78
+
79
+
80
+ class SQLWriter:
81
+ """Writes SQL statements to stdout or a file."""
82
+
83
+ def __init__(self, file: IO[str] | None = None) -> None:
84
+ """Initialize the SQL writer.
85
+
86
+ Parameters
87
+ ----------
88
+ file : `io.TextIOBase` or `None`, optional
89
+ The file to write the SQL statements to. If None, the statements
90
+ will be written to stdout.
91
+ """
92
+ self.file = file
93
+ self.dialect: Dialect | None = None
94
+
95
+ def write(self, sql: Any, *multiparams: Any, **params: Any) -> None:
96
+ """Write the SQL statement to a file or stdout.
97
+
98
+ Statements with parameters will be formatted with the values
99
+ inserted into the resultant SQL output.
100
+
101
+ Parameters
102
+ ----------
103
+ sql : `typing.Any`
104
+ The SQL statement to write.
105
+ multiparams : `typing.Any`
106
+ The multiparams to use for the SQL statement.
107
+ params : `typing.Any`
108
+ The params to use for the SQL statement.
109
+ """
110
+ compiled = sql.compile(dialect=self.dialect)
111
+ sql_str = str(compiled) + ";"
112
+ params_list = [compiled.params]
113
+ for params in params_list:
114
+ if not params:
115
+ print(sql_str, file=self.file)
116
+ continue
117
+ new_params = {}
118
+ for key, value in params.items():
119
+ if isinstance(value, str):
120
+ new_params[key] = f"'{value}'"
121
+ elif value is None:
122
+ new_params[key] = "null"
123
+ else:
124
+ new_params[key] = value
125
+ print(sql_str % new_params, file=self.file)
126
+
127
+
128
+ class ConnectionWrapper:
129
+ """A wrapper for a SQLAlchemy engine or mock connection which provides a
130
+ consistent interface for executing SQL statements.
131
+ """
132
+
133
+ def __init__(self, engine: Engine | MockConnection):
134
+ """Initialize the connection wrapper.
135
+
136
+ Parameters
137
+ ----------
138
+ engine : `sqlalchemy.Engine` or `sqlalchemy.MockConnection`
139
+ The SQLAlchemy engine or mock connection to wrap.
140
+ """
141
+ self.engine = engine
142
+
143
+ def execute(self, statement: Any) -> ResultProxy:
144
+ """Execute a SQL statement on the engine and return the result."""
145
+ if isinstance(statement, str):
146
+ statement = text(statement)
147
+ if isinstance(self.engine, MockConnection):
148
+ return self.engine.connect().execute(statement)
149
+ else:
150
+ with self.engine.begin() as connection:
151
+ result = connection.execute(statement)
152
+ return result
153
+
154
+
155
+ class DatabaseContext:
156
+ """A class for managing the schema and its database connection."""
157
+
158
+ def __init__(self, metadata: MetaData, engine: Engine | MockConnection):
159
+ """Initialize the database context.
160
+
161
+ Parameters
162
+ ----------
163
+ metadata : `sqlalchemy.MetaData`
164
+ The SQLAlchemy metadata object.
165
+
166
+ engine : `sqlalchemy.Engine` or `sqlalchemy.MockConnection`
167
+ The SQLAlchemy engine or mock connection object.
168
+ """
169
+ self.engine = engine
170
+ self.dialect_name = engine.dialect.name
171
+ self.metadata = metadata
172
+ self.conn = ConnectionWrapper(engine)
173
+
174
+ def create_if_not_exists(self) -> None:
175
+ """Create the schema in the database if it does not exist.
176
+
177
+ In MySQL, this will create a new database. In PostgreSQL, it will
178
+ create a new schema. For other variants, this is an unsupported
179
+ operation.
180
+
181
+ Parameters
182
+ ----------
183
+ engine: `sqlalchemy.Engine`
184
+ The SQLAlchemy engine object.
185
+ schema_name: `str`
186
+ The name of the schema (or database) to create.
187
+ """
188
+ schema_name = self.metadata.schema
189
+ try:
190
+ if self.dialect_name == "mysql":
191
+ logger.debug(f"Creating MySQL database: {schema_name}")
192
+ self.conn.execute(text(f"CREATE DATABASE IF NOT EXISTS {schema_name}"))
193
+ elif self.dialect_name == "postgresql":
194
+ logger.debug(f"Creating PG schema: {schema_name}")
195
+ self.conn.execute(CreateSchema(schema_name, if_not_exists=True))
196
+ else:
197
+ raise ValueError("Unsupported database type:" + self.dialect_name)
198
+ except SQLAlchemyError as e:
199
+ logger.error(f"Error creating schema: {e}")
200
+ raise
201
+
202
+ def drop_if_exists(self) -> None:
203
+ """Drop the schema in the database if it exists.
204
+
205
+ In MySQL, this will drop a database. In PostgreSQL, it will drop a
206
+ schema. For other variants, this is unsupported for now.
207
+
208
+ Parameters
209
+ ----------
210
+ engine: `sqlalchemy.Engine`
211
+ The SQLAlchemy engine object.
212
+ schema_name: `str`
213
+ The name of the schema (or database) to drop.
214
+ """
215
+ schema_name = self.metadata.schema
216
+ try:
217
+ if self.dialect_name == "mysql":
218
+ logger.debug(f"Dropping MySQL database if exists: {schema_name}")
219
+ self.conn.execute(text(f"DROP DATABASE IF EXISTS {schema_name}"))
220
+ elif self.dialect_name == "postgresql":
221
+ logger.debug(f"Dropping PostgreSQL schema if exists: {schema_name}")
222
+ self.conn.execute(DropSchema(schema_name, if_exists=True, cascade=True))
223
+ else:
224
+ raise ValueError(f"Unsupported database type: {self.dialect_name}")
225
+ except SQLAlchemyError as e:
226
+ logger.error(f"Error dropping schema: {e}")
227
+ raise
228
+
229
+ def create_all(self) -> None:
230
+ """Create all tables in the schema using the metadata object."""
231
+ self.metadata.create_all(self.engine)
232
+
233
+ @staticmethod
234
+ def create_mock_engine(engine_url: str | URL, output_file: IO[str] | None = None) -> MockConnection:
235
+ """Create a mock engine for testing or dumping DDL statements.
236
+
237
+ Parameters
238
+ ----------
239
+ engine_url : `sqlalchemy.engine.url.URL`
240
+ The SQLAlchemy engine URL.
241
+ output_file : `typing.IO` [ `str` ] or `None`, optional
242
+ The file to write the SQL statements to. If None, the statements
243
+ will be written to stdout.
244
+ """
245
+ writer = SQLWriter(output_file)
246
+ engine = create_mock_engine(engine_url, executor=writer.write)
247
+ writer.dialect = engine.dialect
248
+ return engine