lsst-felis 26.2024.900__py3-none-any.whl → 29.2025.4500__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- felis/__init__.py +10 -24
- felis/cli.py +437 -341
- felis/config/tap_schema/columns.csv +33 -0
- felis/config/tap_schema/key_columns.csv +8 -0
- felis/config/tap_schema/keys.csv +8 -0
- felis/config/tap_schema/schemas.csv +2 -0
- felis/config/tap_schema/tables.csv +6 -0
- felis/config/tap_schema/tap_schema_std.yaml +273 -0
- felis/datamodel.py +1386 -193
- felis/db/dialects.py +116 -0
- felis/db/schema.py +62 -0
- felis/db/sqltypes.py +275 -48
- felis/db/utils.py +409 -0
- felis/db/variants.py +159 -0
- felis/diff.py +234 -0
- felis/metadata.py +385 -0
- felis/tap_schema.py +767 -0
- felis/tests/__init__.py +0 -0
- felis/tests/postgresql.py +134 -0
- felis/tests/run_cli.py +79 -0
- felis/types.py +57 -9
- lsst_felis-29.2025.4500.dist-info/METADATA +38 -0
- lsst_felis-29.2025.4500.dist-info/RECORD +31 -0
- {lsst_felis-26.2024.900.dist-info → lsst_felis-29.2025.4500.dist-info}/WHEEL +1 -1
- {lsst_felis-26.2024.900.dist-info → lsst_felis-29.2025.4500.dist-info/licenses}/COPYRIGHT +1 -1
- felis/check.py +0 -381
- felis/simple.py +0 -424
- felis/sql.py +0 -275
- felis/tap.py +0 -433
- felis/utils.py +0 -100
- felis/validation.py +0 -103
- felis/version.py +0 -2
- felis/visitor.py +0 -180
- lsst_felis-26.2024.900.dist-info/METADATA +0 -28
- lsst_felis-26.2024.900.dist-info/RECORD +0 -23
- {lsst_felis-26.2024.900.dist-info → lsst_felis-29.2025.4500.dist-info}/entry_points.txt +0 -0
- {lsst_felis-26.2024.900.dist-info → lsst_felis-29.2025.4500.dist-info/licenses}/LICENSE +0 -0
- {lsst_felis-26.2024.900.dist-info → lsst_felis-29.2025.4500.dist-info}/top_level.txt +0 -0
- {lsst_felis-26.2024.900.dist-info → lsst_felis-29.2025.4500.dist-info}/zip-safe +0 -0
felis/db/utils.py
ADDED
|
@@ -0,0 +1,409 @@
|
|
|
1
|
+
"""Database utility functions and classes."""
|
|
2
|
+
|
|
3
|
+
# This file is part of felis.
|
|
4
|
+
#
|
|
5
|
+
# Developed for the LSST Data Management System.
|
|
6
|
+
# This product includes software developed by the LSST Project
|
|
7
|
+
# (https://www.lsst.org).
|
|
8
|
+
# See the COPYRIGHT file at the top-level directory of this distribution
|
|
9
|
+
# for details of code ownership.
|
|
10
|
+
#
|
|
11
|
+
# This program is free software: you can redistribute it and/or modify
|
|
12
|
+
# it under the terms of the GNU General Public License as published by
|
|
13
|
+
# the Free Software Foundation, either version 3 of the License, or
|
|
14
|
+
# (at your option) any later version.
|
|
15
|
+
#
|
|
16
|
+
# This program is distributed in the hope that it will be useful,
|
|
17
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
18
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
19
|
+
# GNU General Public License for more details.
|
|
20
|
+
#
|
|
21
|
+
# You should have received a copy of the GNU General Public License
|
|
22
|
+
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
23
|
+
|
|
24
|
+
from __future__ import annotations
|
|
25
|
+
|
|
26
|
+
import logging
|
|
27
|
+
import re
|
|
28
|
+
from typing import IO, Any
|
|
29
|
+
|
|
30
|
+
from sqlalchemy import MetaData, types
|
|
31
|
+
from sqlalchemy.engine import Dialect, Engine, ResultProxy
|
|
32
|
+
from sqlalchemy.engine.mock import MockConnection, create_mock_engine
|
|
33
|
+
from sqlalchemy.engine.url import URL
|
|
34
|
+
from sqlalchemy.exc import SQLAlchemyError
|
|
35
|
+
from sqlalchemy.schema import CreateSchema, DropSchema
|
|
36
|
+
from sqlalchemy.sql import text
|
|
37
|
+
from sqlalchemy.types import TypeEngine
|
|
38
|
+
|
|
39
|
+
from .dialects import get_dialect_module
|
|
40
|
+
|
|
41
|
+
# Names exported by ``from felis.db.utils import *``.
__all__ = [
    "ConnectionWrapper",
    "DatabaseContext",
    "SQLWriter",
    "string_to_typeengine",
]
|
|
42
|
+
|
|
43
|
+
# Logger named after the top-level package rather than this module.
logger = logging.getLogger("felis")

# Group 1 is the bare type name, group 2 the optional parenthesized suffix,
# group 3 the text inside the parentheses.
_DATATYPE_REGEXP = re.compile(
    r"(?P<name>\w+)"
    r"(\((?P<params>.*)\))?"
)
"""Regular expression to match data types with parameters in parentheses."""
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def string_to_typeengine(
    type_string: str, dialect: Dialect | None = None, length: int | None = None
) -> TypeEngine:
    """Convert a string representation of a datatype to a SQLAlchemy type.

    Parameters
    ----------
    type_string
        The string representation of the data type, e.g. ``VARCHAR(255)``
        or ``DECIMAL(10, 2)``. Comma-separated parameters inside the
        parentheses are whitespace-stripped; purely numeric ones are passed
        to the type constructor as integers, others as strings.
    dialect
        The SQLAlchemy dialect to use. If None, the type is looked up in the
        generic `sqlalchemy.types` module.
    length
        The length of the data type. It is only applied when the created
        type object has a ``length`` attribute that is still None; otherwise
        this parameter is ignored.

    Returns
    -------
    `sqlalchemy.types.TypeEngine`
        The SQLAlchemy type engine object.

    Raises
    ------
    ValueError
        Raised if the type string is invalid, the dialect is not supported,
        or the type is not supported.

    Notes
    -----
    This function is used when converting type override strings defined in
    fields such as ``mysql:datatype`` in the schema data.
    """
    match = _DATATYPE_REGEXP.search(type_string)
    if not match:
        raise ValueError(f"Invalid type string: {type_string}")

    type_name, _, params_str = match.groups()
    if dialect is None:
        type_class = getattr(types, type_name.upper(), None)
    else:
        try:
            dialect_module = get_dialect_module(dialect.name)
        except KeyError as e:
            raise ValueError(f"Unsupported dialect: {dialect}") from e
        type_class = getattr(dialect_module, type_name.upper(), None)

    if not type_class:
        # Report the requested name; ``type_class`` is always falsy here.
        raise ValueError(f"Unsupported type: {type_name}")

    if params_str:
        # Strip each parameter so "DECIMAL(10, 2)" yields the integer 2 and
        # not the string " 2".
        stripped = (param.strip() for param in params_str.split(","))
        params = [int(param) if param.isdigit() else param for param in stripped]
        type_obj = type_class(*params)
    else:
        type_obj = type_class()

    # Only fill in a length the type string did not already specify.
    if length is not None and hasattr(type_obj, "length") and type_obj.length is None:
        type_obj.length = length

    return type_obj
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def is_mock_url(url: URL) -> bool:
    """Check if the engine URL is a mock URL.

    A URL is considered a mock URL when it cannot address a real database.
    For the ``sqlite`` backend, this is when the database name is missing.
    For every other backend, it is when the host is missing. The backend is
    the part of the driver name before any ``+driver`` suffix, so
    ``sqlite+pysqlite:///file.db`` is handled like ``sqlite:///file.db``.

    Parameters
    ----------
    url
        The SQLAlchemy engine URL.

    Returns
    -------
    bool
        True if the URL is a mock URL, False otherwise.
    """
    # Equivalent to ``URL.get_backend_name()``: drop any "+driver" suffix.
    backend = url.drivername.partition("+")[0]
    if backend == "sqlite":
        return url.database is None
    return url.host is None
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def is_valid_engine(engine: Engine | MockConnection | None) -> bool:
    """Report whether ``engine`` can be used against a real database.

    An engine is rejected when it is None, when it is a mock connection, or
    when its URL is a mock URL as decided by `is_mock_url` (no host, or, for
    sqlite, no database name).

    Parameters
    ----------
    engine
        The SQLAlchemy engine or mock connection.

    Returns
    -------
    bool
        True if the engine is valid, False otherwise.
    """
    if engine is None:
        return False
    if isinstance(engine, MockConnection):
        return False
    return not is_mock_url(engine.url)
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
class SQLWriter:
|
|
147
|
+
"""Write SQL statements to stdout or a file.
|
|
148
|
+
|
|
149
|
+
Parameters
|
|
150
|
+
----------
|
|
151
|
+
file
|
|
152
|
+
The file to write the SQL statements to. If None, the statements
|
|
153
|
+
will be written to stdout.
|
|
154
|
+
"""
|
|
155
|
+
|
|
156
|
+
def __init__(self, file: IO[str] | None = None) -> None:
|
|
157
|
+
"""Initialize the SQL writer."""
|
|
158
|
+
self.file = file
|
|
159
|
+
self.dialect: Dialect | None = None
|
|
160
|
+
|
|
161
|
+
def write(self, sql: Any, *multiparams: Any, **params: Any) -> None:
|
|
162
|
+
"""Write the SQL statement to a file or stdout.
|
|
163
|
+
|
|
164
|
+
Statements with parameters will be formatted with the values
|
|
165
|
+
inserted into the resultant SQL output.
|
|
166
|
+
|
|
167
|
+
Parameters
|
|
168
|
+
----------
|
|
169
|
+
sql
|
|
170
|
+
The SQL statement to write.
|
|
171
|
+
*multiparams
|
|
172
|
+
The multiparams to use for the SQL statement.
|
|
173
|
+
**params
|
|
174
|
+
The params to use for the SQL statement.
|
|
175
|
+
|
|
176
|
+
Notes
|
|
177
|
+
-----
|
|
178
|
+
The functions arguments are typed very loosely because this method in
|
|
179
|
+
SQLAlchemy is untyped, amd we do not call it directly.
|
|
180
|
+
"""
|
|
181
|
+
compiled = sql.compile(dialect=self.dialect)
|
|
182
|
+
sql_str = str(compiled) + ";"
|
|
183
|
+
params_list = [compiled.params]
|
|
184
|
+
for params in params_list:
|
|
185
|
+
if not params:
|
|
186
|
+
print(sql_str, file=self.file)
|
|
187
|
+
continue
|
|
188
|
+
new_params = {}
|
|
189
|
+
for key, value in params.items():
|
|
190
|
+
if isinstance(value, str):
|
|
191
|
+
new_params[key] = f"'{value}'"
|
|
192
|
+
elif value is None:
|
|
193
|
+
new_params[key] = "null"
|
|
194
|
+
else:
|
|
195
|
+
new_params[key] = value
|
|
196
|
+
print(sql_str % new_params, file=self.file)
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
class ConnectionWrapper:
    """Wrap a SQLAlchemy engine or mock connection to provide a consistent
    interface for executing SQL statements.

    Parameters
    ----------
    engine
        The SQLAlchemy engine or mock connection to wrap.
    """

    def __init__(self, engine: Engine | MockConnection):
        """Initialize the connection wrapper."""
        self.engine = engine

    def execute(self, statement: Any) -> ResultProxy:
        """Execute a SQL statement on the engine and return the result.

        Parameters
        ----------
        statement
            The SQL statement to execute. Plain strings are wrapped with
            ``sqlalchemy.sql.text`` first.

        Returns
        -------
        ``sqlalchemy.engine.ResultProxy``
            The result of the statement execution.

        Raises
        ------
        sqlalchemy.exc.SQLAlchemyError
            Raised, after being logged, if execution on a real engine fails.
        ValueError
            Raised if the wrapped object is neither an ``Engine`` nor a
            ``MockConnection``.

        Notes
        -----
        The statement will be executed in a transaction block if not using
        a mock connection.
        """
        if isinstance(statement, str):
            statement = text(statement)
        if isinstance(self.engine, Engine):
            try:
                # ``Engine.begin`` commits when the block exits normally and
                # rolls back when it raises, so no manual rollback is needed.
                # A manual ``connection.rollback()`` in the handler would also
                # fail with UnboundLocalError whenever ``begin()`` itself
                # raised, hiding the real error.
                with self.engine.begin() as connection:
                    return connection.execute(statement)
            except SQLAlchemyError as e:
                logger.error(f"Error executing statement: {e}")
                raise
        if isinstance(self.engine, MockConnection):
            return self.engine.connect().execute(statement)
        raise ValueError("Unsupported engine type: " + str(type(self.engine)))
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
class DatabaseContext:
    """Manage the database connection and SQLAlchemy metadata.

    Parameters
    ----------
    metadata
        The SQLAlchemy metadata object.
    engine
        The SQLAlchemy engine or mock connection object.
    """

    def __init__(self, metadata: MetaData, engine: Engine | MockConnection):
        """Initialize the database context."""
        self.engine = engine
        self.dialect_name = engine.dialect.name
        self.metadata = metadata
        self.connection = ConnectionWrapper(engine)

    def _quote_identifier(self, name: str) -> str:
        """Quote ``name`` as an identifier for this dialect, if required.

        ``IdentifierPreparer.quote`` leaves names that do not need quoting
        unchanged, so ordinary schema names render exactly as before.
        """
        return self.engine.dialect.identifier_preparer.quote(name)

    def _schema_exists(self, schema_name: str) -> bool:
        """Return whether a schema (a database in MySQL) named
        ``schema_name`` exists.

        The name is compared exactly, through a bound parameter. This avoids
        both quoting problems and ``LIKE`` wildcard matches on ``_`` or ``%``.
        """
        result = self.execute(
            text(
                "SELECT schema_name FROM information_schema.schemata "
                "WHERE schema_name = :schema_name"
            ).bindparams(schema_name=schema_name)
        )
        return result.fetchone() is not None

    def initialize(self) -> None:
        """Create the schema in the database if it does not exist.

        Raises
        ------
        ValueError
            Raised if the schema name is missing, the database is not
            supported, or the schema already exists.
        sqlalchemy.exc.SQLAlchemyError
            Raised if there is an error creating the schema.

        Notes
        -----
        In MySQL, this will create a new database and, in PostgreSQL, it will
        create a new schema. For SQLite, this does nothing. For other
        variants, this is an unsupported operation.
        """
        if self.dialect_name == "sqlite":
            # Initialization is unneeded for sqlite.
            return
        schema_name = self.metadata.schema
        if schema_name is None:
            raise ValueError("Schema name is required to initialize the schema.")
        try:
            if self.dialect_name == "mysql":
                logger.debug(f"Checking if MySQL database exists: {schema_name}")
                if self._schema_exists(schema_name):
                    raise ValueError(f"MySQL database '{schema_name}' already exists.")
                logger.debug(f"Creating MySQL database: {schema_name}")
                self.execute(text(f"CREATE DATABASE {self._quote_identifier(schema_name)}"))
            elif self.dialect_name == "postgresql":
                logger.debug(f"Checking if PG schema exists: {schema_name}")
                if self._schema_exists(schema_name):
                    raise ValueError(f"PostgreSQL schema '{schema_name}' already exists.")
                logger.debug(f"Creating PG schema: {schema_name}")
                self.execute(CreateSchema(schema_name))
            else:
                raise ValueError(f"Initialization not supported for: {self.dialect_name}")
        except SQLAlchemyError as e:
            logger.error(f"Error creating schema: {e}")
            raise

    def drop(self) -> None:
        """Drop the schema in the database if it exists.

        Raises
        ------
        ValueError
            Raised if the schema name is missing (non-SQLite) or the database
            is not supported.
        sqlalchemy.exc.SQLAlchemyError
            Raised if there is an error dropping the schema.

        Notes
        -----
        In MySQL, this will drop a database. In PostgreSQL, it will drop a
        schema, cascading to its contents. A SQLite database will have all
        the tables in the metadata dropped; with a mock connection nothing is
        done for SQLite. For other database variants, this is currently an
        unsupported operation.
        """
        try:
            if self.dialect_name == "sqlite":
                if isinstance(self.engine, Engine):
                    logger.debug("Dropping tables in SQLite schema")
                    self.metadata.drop_all(bind=self.engine)
                return
            schema_name = self.metadata.schema
            if schema_name is None:
                raise ValueError("Schema name is required to drop the schema.")
            if self.dialect_name == "mysql":
                logger.debug(f"Dropping MySQL database if exists: {schema_name}")
                self.execute(text(f"DROP DATABASE IF EXISTS {self._quote_identifier(schema_name)}"))
            elif self.dialect_name == "postgresql":
                logger.debug(f"Dropping PostgreSQL schema if exists: {schema_name}")
                self.execute(DropSchema(schema_name, if_exists=True, cascade=True))
            else:
                # Previously a silent no-op, contrary to the documented contract.
                raise ValueError(f"Drop operation not supported for: {self.dialect_name}")
        except SQLAlchemyError as e:
            logger.error(f"Error dropping schema: {e}")
            raise

    def create_all(self) -> None:
        """Create all tables in the schema using the metadata object.

        Raises
        ------
        sqlalchemy.exc.SQLAlchemyError
            Raised, after being logged, if table creation fails on a real
            engine.
        ValueError
            Raised if the engine is neither an ``Engine`` nor a
            ``MockConnection``.
        """
        if isinstance(self.engine, Engine):
            try:
                # ``Engine.begin`` commits when the block exits normally and
                # rolls back when it raises, so no manual commit or rollback
                # is needed.
                with self.engine.begin() as conn:
                    self.metadata.create_all(bind=conn)
            except SQLAlchemyError as e:
                logger.error(f"Error creating tables: {e}")
                raise
        elif isinstance(self.engine, MockConnection):
            # Mock connection so no need for a transaction.
            self.metadata.create_all(self.engine)
        else:
            raise ValueError("Unsupported engine type: " + str(type(self.engine)))

    @staticmethod
    def create_mock_engine(engine_url: str | URL, output_file: IO[str] | None = None) -> MockConnection:
        """Create a mock engine for testing or dumping DDL statements.

        Statements sent to the returned mock connection are rendered by an
        `SQLWriter` and printed instead of being executed.

        Parameters
        ----------
        engine_url
            The SQLAlchemy engine URL.
        output_file
            The file to write the SQL statements to. If None, the statements
            will be written to stdout.

        Returns
        -------
        ``sqlalchemy.engine.mock.MockConnection``
            The mock connection object.
        """
        writer = SQLWriter(output_file)
        # This name resolves to SQLAlchemy's module-level
        # ``create_mock_engine`` import, not to this static method.
        engine = create_mock_engine(engine_url, executor=writer.write, paramstyle="pyformat")
        writer.dialect = engine.dialect
        return engine

    def execute(self, statement: Any) -> ResultProxy:
        """Execute a SQL statement on the engine and return the result.

        Parameters
        ----------
        statement
            The SQL statement to execute.

        Returns
        -------
        ``sqlalchemy.engine.ResultProxy``
            The result of the statement execution.

        Notes
        -----
        This is just a wrapper around the execution method of the connection
        object, which may execute on a real or mock connection.
        """
        return self.connection.execute(statement)
|
felis/db/variants.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
1
|
+
"""Handle variant overrides for a Felis column."""
|
|
2
|
+
|
|
3
|
+
# This file is part of felis.
|
|
4
|
+
#
|
|
5
|
+
# Developed for the LSST Data Management System.
|
|
6
|
+
# This product includes software developed by the LSST Project
|
|
7
|
+
# (https://www.lsst.org).
|
|
8
|
+
# See the COPYRIGHT file at the top-level directory of this distribution
|
|
9
|
+
# for details of code ownership.
|
|
10
|
+
#
|
|
11
|
+
# This program is free software: you can redistribute it and/or modify
|
|
12
|
+
# it under the terms of the GNU General Public License as published by
|
|
13
|
+
# the Free Software Foundation, either version 3 of the License, or
|
|
14
|
+
# (at your option) any later version.
|
|
15
|
+
#
|
|
16
|
+
# This program is distributed in the hope that it will be useful,
|
|
17
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
18
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
19
|
+
# GNU General Public License for more details.
|
|
20
|
+
#
|
|
21
|
+
# You should have received a copy of the GNU General Public License
|
|
22
|
+
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
23
|
+
|
|
24
|
+
from __future__ import annotations
|
|
25
|
+
|
|
26
|
+
import re
|
|
27
|
+
from collections.abc import Mapping
|
|
28
|
+
from types import MappingProxyType
|
|
29
|
+
from typing import Any
|
|
30
|
+
|
|
31
|
+
from sqlalchemy import types
|
|
32
|
+
from sqlalchemy.types import TypeEngine
|
|
33
|
+
|
|
34
|
+
from ..datamodel import Column
|
|
35
|
+
from .dialects import get_dialect_module, get_supported_dialects
|
|
36
|
+
|
|
37
|
+
# Names exported by ``from felis.db.variants import *``.
__all__ = [
    "make_variant_dict",
]
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _create_column_variant_overrides() -> dict[str, str]:
    """Build the mapping from column override field names to dialect names.

    Each supported dialect ``<name>`` contributes one entry, whose key is the
    column field name ``<name>_datatype`` and whose value is ``<name>``.

    Returns
    -------
    column_variant_overrides : `dict` [ `str`, `str` ]
        A mapping of column variant overrides to their dialect name.

    Notes
    -----
    This function is intended for internal use only.
    """
    return {f"{name}_datatype": name for name in get_supported_dialects().keys()}
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
_COLUMN_VARIANT_OVERRIDES: Mapping[str, str] = MappingProxyType(_create_column_variant_overrides())
"""Map of column variant overrides to their dialect name (read-only view)."""


def _get_column_variant_overrides() -> Mapping[str, str]:
    """Return the read-only mapping of column variant override fields.

    Returns
    -------
    column_variant_overrides : `~collections.abc.Mapping` [ `str`, `str` ]
        A read-only mapping of column variant override field names (such as
        ``mysql_datatype``) to their dialect name.
    """
    return _COLUMN_VARIANT_OVERRIDES
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def _get_column_variant_override(field_name: str) -> str:
    """Look up the dialect name for a column override field name such as
    ``mysql_datatype``.

    Parameters
    ----------
    field_name
        Name of the column field that holds a dialect-specific datatype.

    Returns
    -------
    dialect_name : `str`
        The name of the dialect.

    Raises
    ------
    ValueError
        Raised if the field name is not found in the column variant overrides.
    """
    # Values in the override map are dialect-name strings, never None, so a
    # None result can only mean the key is absent.
    dialect_name = _COLUMN_VARIANT_OVERRIDES.get(field_name)
    if dialect_name is None:
        raise ValueError(f"Field name {field_name} not found in column variant overrides")
    return dialect_name
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
# The consumer splits group 1 on "," and converts each piece with ``int``
# (which tolerates surrounding whitespace). The pattern must therefore accept
# a comma-separated list of integers. The previous ``\((\d+)\)`` silently
# dropped the parameters of types like ``DECIMAL(10,2)``.
_length_regex = re.compile(r"\((\d+(?:\s*,\s*\d+)*)\)")
"""A regular expression matching a parenthesized, comma-separated list of
integers such as ``(255)`` or ``(10, 2)``; group 1 holds the list contents."""
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def _process_variant_override(dialect_name: str, variant_override_str: str) -> types.TypeEngine:
    """Build the dialect-specific type described by a variant override string.

    Parameters
    ----------
    dialect_name
        The name of the dialect whose type module is searched.
    variant_override_str
        The string representation of the variant override, e.g.
        ``VARCHAR(255)``. The text before the first ``(`` is the type name.

    Returns
    -------
    variant_type : `~sqlalchemy.types.TypeEngine`
        The variant type for the given dialect, constructed with the integer
        arguments that ``_length_regex`` extracts (none if it does not match).

    Raises
    ------
    ValueError
        Raised if the type is not found in the dialect.

    Notes
    -----
    This function converts a string representation of a variant override
    into a `sqlalchemy.types.TypeEngine` object.
    """
    dialect_module = get_dialect_module(dialect_name)
    type_name, _, _ = variant_override_str.partition("(")

    # Only names that the dialect module exposes are accepted.
    if type_name not in dir(dialect_module):
        raise ValueError(f"Type {type_name} not found in dialect {dialect_name}")
    type_cls = getattr(dialect_module, type_name)

    found = _length_regex.search(variant_override_str)
    ctor_args = [int(token) for token in found.group(1).split(",")] if found else []
    return type_cls(*ctor_args)
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def make_variant_dict(column_obj: Column) -> dict[str, TypeEngine[Any]]:
    """Collect dialect-specific datatype overrides for a
    `felis.datamodel.Column`.

    Each ``(field_name, value)`` pair produced by iterating the column is
    considered. When the field name is a known variant override (for example
    ``mysql_datatype``) and the value is not None, the value is converted to
    a `sqlalchemy.types.TypeEngine` and stored under its dialect name.

    Parameters
    ----------
    column_obj
        The column object from which to build the variant dictionary.

    Returns
    -------
    `dict` [ `str`, `~sqlalchemy.types.TypeEngine` ]
        The dictionary of `str` to `sqlalchemy.types.TypeEngine` containing
        variant datatype information (e.g., for mysql, postgresql, etc).
    """
    known_overrides = _get_column_variant_overrides()
    variants: dict[str, TypeEngine[Any]] = {}
    for field_name, override_str in column_obj:
        if field_name not in known_overrides or override_str is None:
            continue
        dialect_name = _get_column_variant_override(field_name)
        variants[dialect_name] = _process_variant_override(dialect_name, override_str)
    return variants
|