thds.tabularasa 0.13.0__py3-none-any.whl
This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the packages exactly as they appear in their respective public registries.
- thds/tabularasa/__init__.py +6 -0
- thds/tabularasa/__main__.py +1122 -0
- thds/tabularasa/compat.py +33 -0
- thds/tabularasa/data_dependencies/__init__.py +0 -0
- thds/tabularasa/data_dependencies/adls.py +97 -0
- thds/tabularasa/data_dependencies/build.py +573 -0
- thds/tabularasa/data_dependencies/sqlite.py +286 -0
- thds/tabularasa/data_dependencies/tabular.py +167 -0
- thds/tabularasa/data_dependencies/util.py +209 -0
- thds/tabularasa/diff/__init__.py +0 -0
- thds/tabularasa/diff/data.py +346 -0
- thds/tabularasa/diff/schema.py +254 -0
- thds/tabularasa/diff/summary.py +249 -0
- thds/tabularasa/git_util.py +37 -0
- thds/tabularasa/loaders/__init__.py +0 -0
- thds/tabularasa/loaders/lazy_adls.py +44 -0
- thds/tabularasa/loaders/parquet_util.py +385 -0
- thds/tabularasa/loaders/sqlite_util.py +346 -0
- thds/tabularasa/loaders/util.py +532 -0
- thds/tabularasa/py.typed +0 -0
- thds/tabularasa/schema/__init__.py +7 -0
- thds/tabularasa/schema/compilation/__init__.py +20 -0
- thds/tabularasa/schema/compilation/_format.py +50 -0
- thds/tabularasa/schema/compilation/attrs.py +257 -0
- thds/tabularasa/schema/compilation/attrs_sqlite.py +278 -0
- thds/tabularasa/schema/compilation/io.py +96 -0
- thds/tabularasa/schema/compilation/pandas.py +252 -0
- thds/tabularasa/schema/compilation/pyarrow.py +93 -0
- thds/tabularasa/schema/compilation/sphinx.py +550 -0
- thds/tabularasa/schema/compilation/sqlite.py +69 -0
- thds/tabularasa/schema/compilation/util.py +117 -0
- thds/tabularasa/schema/constraints.py +327 -0
- thds/tabularasa/schema/dtypes.py +153 -0
- thds/tabularasa/schema/extract_from_parquet.py +132 -0
- thds/tabularasa/schema/files.py +215 -0
- thds/tabularasa/schema/metaschema.py +1007 -0
- thds/tabularasa/schema/util.py +123 -0
- thds/tabularasa/schema/validation.py +878 -0
- thds/tabularasa/sqlite3_compat.py +41 -0
- thds/tabularasa/sqlite_from_parquet.py +34 -0
- thds/tabularasa/to_sqlite.py +56 -0
- thds_tabularasa-0.13.0.dist-info/METADATA +530 -0
- thds_tabularasa-0.13.0.dist-info/RECORD +46 -0
- thds_tabularasa-0.13.0.dist-info/WHEEL +5 -0
- thds_tabularasa-0.13.0.dist-info/entry_points.txt +2 -0
- thds_tabularasa-0.13.0.dist-info/top_level.txt +1 -0

--- /dev/null
+++ thds/tabularasa/schema/compilation/attrs.py
@@ -0,0 +1,257 @@
+import typing as ty
+from operator import itemgetter
+from textwrap import wrap
+
+import typing_extensions
+
+import thds.tabularasa.loaders.util
+from thds.tabularasa.schema import metaschema
+
+from ._format import autoformat
+from .util import (
+    AUTOGEN_DISCLAIMER,
+    VarName,
+    _indent,
+    _wrap_lines_with_prefix,
+    constructor_template,
+    render_blob_store_def,
+    render_constructor,
+    sorted_class_names_for_import,
+)
+
+DEFAULT_LINE_WIDTH = 88
+
+REMOTE_BLOB_STORE_VAR_NAME = "REMOTE_BLOB_STORE"
+
+DOCSTRING_PARAM_TEMPLATE = """:param {name}: {doc}"""
+
+CUSTOM_TYPE_DEF_TEMPLATE = """{comment}{name} = {type}"""
+
+ATTRS_CLASS_DEF_TEMPLATE = """@attr.s(auto_attribs=True, frozen=True)
+class {class_name}:
+    \"\"\"{doc}
+
+    {params}
+    \"\"\"
+
+    {fields}
+"""
+
+ATTRS_FIELD_DEF_TEMPLATE_BASIC = "{name}: {type}"
+
+ATTRS_LOADER_TEMPLATE = constructor_template(
+    thds.tabularasa.loaders.util.AttrsParquetLoader,
+    exclude=["filename"],
+    type_params=["{record_type}"],
+)
+
+
+def render_type_def(
+    type_: metaschema.CustomType,
+    build_options: metaschema.BuildOptions,
+) -> str:
+    type_literal = type_.python_type_def_literal(build_options)
+
+    if build_options.type_constraint_comments:
+        comment = type_.comment
+        if comment:
+            lines = wrap(comment, DEFAULT_LINE_WIDTH - 2)
+            comment = "\n".join("# " + line for line in lines) + "\n"
+        else:
+            comment = ""
+    else:
+        comment = ""
+
+    return CUSTOM_TYPE_DEF_TEMPLATE.format(comment=comment, name=type_.class_name, type=type_literal)
+
+
+def render_attr_field_def(
+    column: metaschema.Column, build_options: metaschema.BuildOptions, builtin: bool = False
+) -> str:
+    type_literal = column.python_type_literal(build_options=build_options, builtin=builtin)
+    return ATTRS_FIELD_DEF_TEMPLATE_BASIC.format(name=column.snake_case_name, type=type_literal)
+
+
+def render_attrs_table_schema(table: metaschema.Table, build_options: metaschema.BuildOptions) -> str:
+    field_defs = []
+    params = []
+
+    for column in table.columns:
+        field_def = render_attr_field_def(column, builtin=False, build_options=build_options)
+        field_defs.append(field_def)
+        doc = _wrap_lines_with_prefix(
+            column.doc,
+            DEFAULT_LINE_WIDTH - 4,
+            first_line_prefix_len=len(f":param {column.snake_case_name}: "),
+            trailing_line_indent=2,
+        )
+        params.append(
+            DOCSTRING_PARAM_TEMPLATE.format(
+                name=column.snake_case_name,
+                doc=doc,
+            )
+        )
+
+    table_doc = _wrap_lines_with_prefix(
+        table.doc,
+        DEFAULT_LINE_WIDTH - 4,
+        first_line_prefix_len=3,  # triple quotes
+        trailing_line_indent=0,
+    )
+
+    return ATTRS_CLASS_DEF_TEMPLATE.format(
+        class_name=table.class_name,
+        doc=_indent(table_doc),
+        params=_indent("\n".join(params)),
+        fields=_indent("\n".join(field_defs)),
+    )
+
+
+PYARROW_SCHEMAS_QUALIFIED_IMPORT = "pyarrow_schemas"
+
+
+class ImportsAndCode(ty.NamedTuple):
+    """Couples code and its required imports."""
+
+    third_party_imports: ty.List[str]
+    tabularasa_imports: ty.List[str]
+    code: ty.List[str]
+
+
+def render_attrs_loaders(
+    schema: metaschema.Schema,
+    package: str,
+) -> ImportsAndCode:
+    data_dir = schema.build_options.package_data_dir
+    render_pyarrow_schemas = schema.build_options.pyarrow
+    import_lines = list()
+    if render_pyarrow_schemas:
+        import_lines.append("\n")
+        import_lines.append(f"from . import pyarrow as {PYARROW_SCHEMAS_QUALIFIED_IMPORT}")
+
+    return ImportsAndCode(
+        list(),
+        import_lines,
+        [
+            render_constructor(
+                ATTRS_LOADER_TEMPLATE,
+                kwargs=dict(
+                    record_type=VarName(table.class_name),
+                    table_name=table.snake_case_name,
+                    type_=VarName(table.class_name),
+                    package=package,
+                    data_dir=data_dir,
+                    md5=table.md5,
+                    blob_store=(
+                        None
+                        if schema.remote_blob_store is None or table.md5 is None
+                        else VarName(REMOTE_BLOB_STORE_VAR_NAME)
+                    ),
+                    pyarrow_schema=(
+                        VarName(
+                            f"{PYARROW_SCHEMAS_QUALIFIED_IMPORT}.{table.snake_case_name}_pyarrow_schema"
+                        )
+                        if render_pyarrow_schemas
+                        else None
+                    ),
+                ),
+                var_name=f"load_{table.snake_case_name}",
+            )
+            for table in schema.package_tables
+        ],
+    )
+
+
+def render_attrs_type_defs(
+    schema: metaschema.Schema,
+) -> ImportsAndCode:
+    # custom types
+    defined_custom_types = schema.defined_types
+    type_defs = [
+        render_type_def(
+            type_,
+            build_options=schema.build_options,
+        )
+        for type_ in sorted(defined_custom_types, key=lambda type_: type_.name)
+    ]
+
+    import_lines = list()
+    # external type imports
+    sep = ",\n    "
+    for module_name, class_names in sorted(schema.external_type_imports.items(), key=itemgetter(0)):
+        import_lines.append(
+            f"from {module_name} import (\n    {sep.join(sorted_class_names_for_import(class_names))},\n)\n"
+        )
+
+    return ImportsAndCode([], import_lines, type_defs)
+
+
+def _render_attrs_schema(
+    schema: metaschema.Schema,
+    type_defs: ImportsAndCode,
+    loader_defs: ty.Optional[ImportsAndCode],
+) -> str:
+    loader_defs = loader_defs or ImportsAndCode([], [], [])
+    assert loader_defs, "Loaders are optional but the line above is not"
+
+    # attrs record types
+    table_defs = [
+        render_attrs_table_schema(table, schema.build_options) for table in schema.package_tables
+    ]
+
+    # imports
+    stdlib_imports = sorted(schema.attrs_required_imports)
+    import_extensions = typing_extensions.__name__ in stdlib_imports
+    if import_extensions:
+        stdlib_imports.remove(typing_extensions.__name__)
+    import_lines = [f"import {module}\n" for module in stdlib_imports]
+
+    if import_lines:
+        import_lines.append("\n")
+    import_lines.append("import attr\n")
+    if import_extensions:
+        import_lines.append(f"import {typing_extensions.__name__}\n")
+
+    import_lines.append("\n")
+
+    if loader_defs.code:
+        import_lines.append(f"import {thds.tabularasa.loaders.util.__name__}\n")
+        if schema.remote_blob_store is not None:
+            import_lines.append(f"import {thds.tabularasa.schema.files.__name__}\n")
+
+    import_lines.extend(type_defs.tabularasa_imports)
+    import_lines.extend(loader_defs.tabularasa_imports)
+
+    # globals
+    global_var_defs = []
+    if schema.remote_blob_store is not None:
+        global_var_defs.append(
+            render_blob_store_def(schema.remote_blob_store, REMOTE_BLOB_STORE_VAR_NAME)
+        )
+
+    imports = "".join(import_lines)
+    globals_ = "\n".join(global_var_defs)
+    types = "\n".join(type_defs.code)
+    classes = "\n\n".join(table_defs)
+    loaders = "\n\n".join(loader_defs.code)
+
+    # module
+    return autoformat(
+        f"{imports}\n# {AUTOGEN_DISCLAIMER}\n\n{globals_}\n\n{types}\n\n\n{classes}\n\n{loaders}\n"
+    )
+
+
+def render_attrs_module(
+    schema: metaschema.Schema,
+    package: str,
+    loader_defs: ty.Optional[ImportsAndCode] = None,
+) -> str:
+    if loader_defs is None:
+        loader_defs = (
+            render_attrs_loaders(schema, package) if schema.build_options.package_data_dir else None
+        )
+    return _render_attrs_schema(
+        schema,
+        render_attrs_type_defs(schema),
+        loader_defs,
+    )
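
To make the template mechanics above concrete, here is a standalone sketch of the expansion that render_attrs_table_schema performs. The table and column names (ZipCode, zip, state) are hypothetical; real output derives the doc, params, and fields strings from the schema and additionally passes through autoformat.

# Standalone sketch using a copy of ATTRS_CLASS_DEF_TEMPLATE from the module above
# (hypothetical table; real rendering builds doc/params/fields from metaschema objects).
TEMPLATE = '''@attr.s(auto_attribs=True, frozen=True)
class {class_name}:
    """{doc}

    {params}
    """

    {fields}
'''

print(
    TEMPLATE.format(
        class_name="ZipCode",
        doc="One row per US ZIP code.",
        params=":param zip: 5-digit ZIP code\n    :param state: USPS state abbreviation",
        fields="zip: str\n    state: str",
    )
)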
--- /dev/null
+++ thds/tabularasa/schema/compilation/attrs_sqlite.py
@@ -0,0 +1,278 @@
+from itertools import chain
+from logging import getLogger
+from typing import Dict, List, Optional, Tuple
+
+import thds.tabularasa.loaders.util
+import thds.tabularasa.schema
+from thds.tabularasa.schema import metaschema
+from thds.tabularasa.schema.compilation.attrs import render_attr_field_def
+
+from ._format import autoformat
+from .sqlite import index_name
+from .util import AUTOGEN_DISCLAIMER, sorted_class_names_for_import
+
+_LOGGER = getLogger(__name__)
+
+PACKAGE_VARNAME = "PACKAGE"
+DB_PATH_VARNAME = "DB_PATH"
+
+ATTRS_MODULE_NAME = ".attrs"
+
+LINE_WIDTH = 88
+
+COLUMN_LINESEP = ",\n    "
+
+ATTRS_CLASS_LOADER_TEMPLATE = """class {class_name}Loader:
+
+    def __init__(self, db: util.%s):
+        self._db = db
+        self._record = util.%s({class_name})
+
+    {accessors}
+""" % (
+    thds.tabularasa.loaders.util.AttrsSQLiteDatabase.__name__,
+    thds.tabularasa.loaders.sqlite_util.sqlite_constructor_for_record_type.__name__,  # type: ignore
+)
+
+ATTRS_INDEX_ACCESSOR_TEMPLATE = """
+    def {method_name}(self, {typed_args}) -> typing.{return_type}[{class_name}]:
+        return self._db.sqlite_{index_kind}_query(
+            self._record,
+            \"\"\"
+            SELECT
+            {columns}
+            FROM {table_name}
+            INDEXED BY {index_name}
+            WHERE {condition};
+            \"\"\",
+            ({args},),
+        )
+"""
+
+ATTRS_BULK_INDEX_ACCESSOR_TEMPLATE = """
+    def {method_name}_bulk(self, {arg_name}: typing.Collection[{typed_args}]) -> typing.{return_type}[{class_name}]:
+        if {arg_name}:
+            return self._db.sqlite_{index_kind}_query(
+                self._record,
+                f\"\"\"
+                SELECT
+                {columns}
+                FROM {table_name}
+                INDEXED BY {index_name}
+                WHERE {condition};
+                \"\"\",
+                {arg_name},
+                single_col={single_col},
+            )
+        else:
+            return iter(())
+"""
+
+ATTRS_MAIN_LOADER_TEMPLATE = """class SQLiteLoader:
+    def __init__(
+        self,
+        package: typing.Optional[str] = %s,
+        db_path: str = %s,
+        cache_size: int = util.DEFAULT_ATTR_SQLITE_CACHE_SIZE,
+        mmap_size: int = util.DEFAULT_MMAP_BYTES,
+    ):
+        self._db = util.%s(package=package, db_path=db_path, cache_size=cache_size, mmap_size=mmap_size)
+        {table_loaders}
+""" % (
+    PACKAGE_VARNAME,
+    DB_PATH_VARNAME,
+    thds.tabularasa.loaders.util.AttrsSQLiteDatabase.__name__,
+)
+
+
+def render_attrs_loader_schema(table: metaschema.Table, build_options: metaschema.BuildOptions) -> str:
+    accessor_defs = []
+    column_lookup = {col.name: col for col in table.columns}
+    unq_constraints = {frozenset(c.unique) for c in table.unique_constraints}
+
+    if table.primary_key:
+        accessor_defs.append(
+            render_accessor_method(
+                table, table.primary_key, column_lookup, pk=True, build_options=build_options
+            )
+        )
+        accessor_defs.append(
+            render_accessor_method(
+                table, table.primary_key, column_lookup, pk=True, bulk=True, build_options=build_options
+            )
+        )
+
+    for idx in table.indexes:
+        unique = frozenset(idx) in unq_constraints
+        accessor_defs.append(
+            render_accessor_method(
+                table, idx, column_lookup, pk=False, unique=unique, build_options=build_options
+            )
+        )
+        accessor_defs.append(
+            render_accessor_method(
+                table,
+                idx,
+                column_lookup,
+                pk=False,
+                unique=unique,
+                bulk=True,
+                build_options=build_options,
+            )
+        )
+
+    accessors = "".join(accessor_defs).strip()
+    return ATTRS_CLASS_LOADER_TEMPLATE.format(
+        class_name=table.class_name,
+        accessors=accessors,
+    )
+
+
+def render_accessor_method(
+    table: metaschema.Table,
+    index_columns: Tuple[str, ...],
+    column_lookup: Dict[str, metaschema.Column],
+    build_options: metaschema.BuildOptions,
+    pk: bool = False,
+    unique: bool = False,
+    bulk: bool = False,
+) -> str:
+    index_column_names = tuple(map(metaschema.snake_case, index_columns))
+    method_name = "pk" if pk else f"idx_{'_'.join(index_column_names)}"
+    index_kind = "bulk" if bulk else ("pk" if pk or unique else "index")
+    return_type = "Iterator" if bulk else ("Optional" if pk or unique else "List")
+    # we use the `IS` operator to allow for comparison in case of nullable index columns
+    arg_name = "__".join(index_column_names)
+    nullsafe_condition = " AND ".join([f"{column} IS (?)" for column in index_column_names])
+    if bulk:
+        columns = [column_lookup[col] for col in index_columns]
+        has_null_cols = any(col.nullable for col in columns)
+        if len(index_column_names) == 1:
+            # for single-column indexes, we need to unpack the single value from the tuple
+            typed_args = columns[0].python_type_literal(build_options=build_options, builtin=True)
+            if has_null_cols:
+                # sqlite IN operator unfortunately doesn't support a NULL == NULL variant, the way that IS does for =
+                condition = f"{{' OR '.join(['{index_column_names[0]} IS (?)'] * len({arg_name}))}}"
+            else:
+                condition = f"{index_column_names[0]} IN ({{','.join('?' * len({arg_name}))}})"
+            single_col = "True"
+        else:
+            type_strs = ", ".join(
+                col.python_type_literal(build_options=build_options, builtin=True) for col in columns
+            )
+            typed_args = f"typing.Tuple[{type_strs}]"
+            if has_null_cols:
+                # sqlite IN operator unfortunately doesn't support a NULL == NULL variant, the way that IS does for =
+                param_tuple = f"({nullsafe_condition})"
+                condition = f"{{' OR '.join(['{param_tuple}'] * len({arg_name}))}}"
+            else:
+                param_tuple = f"({', '.join('?' * len(index_column_names))})"
+                condition = f"({', '.join(index_column_names)}) IN ({{','.join(['{param_tuple}'] * len({arg_name}))}})"
+            single_col = "False"
+    else:
+        condition = nullsafe_condition
+        typed_args = ", ".join(
+            [
+                render_attr_field_def(column_lookup[col], builtin=True, build_options=build_options)
+                for col in index_columns
+            ]
+        )
+        single_col = ""
+
+    return (ATTRS_BULK_INDEX_ACCESSOR_TEMPLATE if bulk else ATTRS_INDEX_ACCESSOR_TEMPLATE).format(
+        class_name=table.class_name,
+        method_name=method_name,
+        # use builtin types (as opposed to e.g. Literals and Newtypes) to make the API simpler to use
+        typed_args=typed_args,
+        arg_name=arg_name,
+        single_col=single_col,
+        return_type=return_type,
+        index_kind=index_kind,
+        table_name=table.snake_case_name,
+        columns=COLUMN_LINESEP.join(c.snake_case_name for c in table.columns),
+        index_name=index_name(table.snake_case_name, *index_column_names),
+        condition=condition,
+        args=", ".join(index_column_names),
+    )
+
+
+def render_attrs_main_loader(
+    tables: List[metaschema.Table],
+) -> str:
+    loader_instance_defs = [
+        f"self.{table.snake_case_name} = {table.class_name}Loader(self._db)" for table in tables
+    ]
+    table_loaders = "\n        ".join(loader_instance_defs)
+    return ATTRS_MAIN_LOADER_TEMPLATE.format(table_loaders=table_loaders)
+
+
+def _import_lines(
+    tables: List[metaschema.Table],
+    attrs_module_name: Optional[str],
+    build_options: metaschema.BuildOptions,
+):
+    # need typing always for List and Optional
+    stdlib_imports = sorted(
+        {
+            "typing",
+            *chain.from_iterable(t.attrs_sqlite_required_imports(build_options) for t in tables),
+        }
+    )
+    import_lines = [f"import {module}\n" for module in stdlib_imports]
+    if import_lines:
+        import_lines.append("\n")
+
+    import_lines.append(f"import {thds.tabularasa.loaders.sqlite_util.__name__} as util\n")
+    import_lines.append("\n")
+    if attrs_module_name is not None:
+        import_lines.append(f"from {attrs_module_name} import (\n")
+        attrs_module_classnames = {table.class_name for table in tables}
+        import_lines.extend(
+            f"    {name},\n" for name in sorted_class_names_for_import(attrs_module_classnames)
+        )
+        import_lines.append(")\n")
+
+    return "".join(import_lines)
+
+
+def _has_index(table: metaschema.Table) -> bool:
+    return (table.primary_key is not None) or len(table.indexes) > 0
+
+
+def render_attrs_sqlite_schema(
+    schema: metaschema.Schema,
+    package: str = "",
+    db_path: str = "",
+    attrs_module_name: Optional[str] = ATTRS_MODULE_NAME,
+) -> str:
+    has_database_loader = bool(package and db_path)
+
+    # do not generate a SQLite loader if there is no primary key or index defined on the table def
+    tables = [table for table in schema.package_tables if table.has_indexes]
+    tables_filtered = [table.name for table in schema.package_tables if not table.has_indexes]
+    if not tables:
+        _LOGGER.info(
+            f"Skipping SQLite loader generation for all tables: {tables_filtered}; none has any index "
+            f"specified"
+        )
+        return ""
+
+    if tables_filtered:
+        _LOGGER.info(
+            f"Skipping SQLite loader generation for the following "
+            f"tables because no indices or primary keys are defined: {', '.join(tables_filtered)}"
+        )
+
+    imports = _import_lines(tables, attrs_module_name, schema.build_options)
+    loader_defs = [render_attrs_loader_schema(table, schema.build_options) for table in tables]
+    loaders = "\n\n".join(loader_defs)
+
+    if has_database_loader:
+        constants = f'{PACKAGE_VARNAME} = "{package}"\n{DB_PATH_VARNAME} = "{db_path}"'
+        loader_def = render_attrs_main_loader(tables)
+    else:
+        constants = ""
+        loader_def = ""
+    return autoformat(
+        f"{imports}\n# {AUTOGEN_DISCLAIMER}\n\n{constants}\n\n\n{loaders}\n\n{loader_def}\n"
+    )
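
The condition-building branches in render_accessor_method are easiest to read by evaluating the fragments they emit. Below is a standalone sketch with hypothetical table and column names ("zip_code", "zip"), showing the two bulk-lookup strategies: a plain IN list for non-nullable columns, and OR-chained IS comparisons for nullable ones, since NULL never matches inside IN.

# Hypothetical single-column bulk lookup; the generated method builds its
# WHERE clause at call time from the number of arguments supplied.
args = ["02139", None, "10001"]

# non-nullable column: one "?" placeholder per value inside IN (...)
in_clause = f"zip IN ({','.join('?' * len(args))})"
print(in_clause)  # zip IN (?,?,?)

# nullable column: "IS (?)" treats NULL = NULL as a match, so OR-chain it
is_clause = " OR ".join(["zip IS (?)"] * len(args))
print(is_clause)  # zip IS (?) OR zip IS (?) OR zip IS (?)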
--- /dev/null
+++ thds/tabularasa/schema/compilation/io.py
@@ -0,0 +1,96 @@
+import ast
+from functools import singledispatch
+from itertools import starmap, zip_longest
+from logging import getLogger
+from operator import itemgetter
+from pathlib import Path
+from typing import Any, Iterator, List, Tuple, Union
+
+_LOGGER = getLogger(__name__)
+
+AST_CODE_CONTEXT_VARS = {"lineno", "col_offset", "ctx", "end_lineno", "end_col_offset"}
+
+
+def ast_eq(ast1: ast.AST, ast2: ast.AST) -> bool:
+    """Return True if two python source strings are AST-equivalent, else False"""
+    return _ast_eq(ast1, ast2)
+
+
+def ast_vars(node: ast.AST) -> Iterator[Tuple[str, Any]]:
+    """Iterator of (name, value) tuples for all attributes of an AST node *except* for those that are
+    not abstract (e.g. line numbers and column offsets)"""
+    return (
+        (name, value)
+        for name, value in sorted(vars(node).items(), key=itemgetter(0))
+        if name not in AST_CODE_CONTEXT_VARS
+    )
+
+
+@singledispatch
+def _ast_eq(ast1: Any, ast2: Any) -> bool:
+    # base case, literal values (non-AST nodes)
+    return (type(ast1) is type(ast2)) and (ast1 == ast2)
+
+
+@_ast_eq.register(ast.AST)
+def _ast_eq_ast(ast1: ast.AST, ast2: ast.AST) -> bool:
+    if type(ast1) is not type(ast2):
+        return False
+    attrs1 = ast_vars(ast1)
+    attrs2 = ast_vars(ast2)
+    return all(
+        (name1 == name2) and _ast_eq(attr1, attr2)
+        for (name1, attr1), (name2, attr2) in zip(
+            attrs1,
+            attrs2,
+        )
+    )
+
+
+@_ast_eq.register(list)
+@_ast_eq.register(tuple)
+def _ast_eq_list(ast1: List[Any], ast2: List[Any]):
+    missing = object()
+    return all(starmap(_ast_eq, zip_longest(ast1, ast2, fillvalue=missing)))
+
+
+def write_if_ast_changed(source: str, path: Union[str, Path]):  # pragma: no cover
+    """Write the source code `source` to the file at `path`, but only if the file doesn't exist, or the
+    AST of the code therein differs from that of `source`"""
+    path = Path(path)
+    this_ast = ast.parse(source)
+
+    if path.exists():
+        with open(path, "r+") as f:
+            that_source = f.read()
+            try:
+                that_ast = ast.parse(that_source)
+            except SyntaxError:
+                _LOGGER.warning(
+                    f"syntax error in code at {path}; merge conflicts? Code will be overwritten"
+                )
+                rewrite = True
+                reason = "Invalid AST"
+            else:
+                rewrite = not ast_eq(this_ast, that_ast)
+                reason = "AST changed"
+
+            if rewrite:
+                _LOGGER.info(f"writing new generated code to {path}; {reason}")
+                f.seek(0)
+                f.truncate()
+                f.write(source)
+            else:
+                _LOGGER.info(f"leaving generated code at {path}; AST unchanged")
+    else:
+        _LOGGER.info(f"writing new generated code to {path}; no prior file")
+        with open(path, "w") as f:
+            f.write(source)
+
+
+def write_sql(source: str, path: Union[str, Path]):  # pragma: no cover
+    """Write the SQL source code `source` to the file at `path`"""
+    path = Path(path)
+    _LOGGER.info(f"writing new generated code to {path}; no prior file")
+    with open(path, "w") as f:
+        f.write(source)
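
Why compare ASTs rather than raw text? A stdlib-only sketch of the idea behind ast_eq and write_if_ast_changed: ast.dump omits line and column attributes by default, so formatting-only differences compare equal, which is in spirit what the recursive _ast_eq does by skipping AST_CODE_CONTEXT_VARS.

import ast

a = ast.parse("x = [1, 2, 3]")
b = ast.parse("x = [\n    1,\n    2,\n    3,\n]")  # same code, reflowed
c = ast.parse("x = [1, 2, 3, 4]")                  # genuinely different

# ast.dump(...) excludes lineno/col_offset by default, mirroring AST_CODE_CONTEXT_VARS
print(ast.dump(a) == ast.dump(b))  # True  -> the file would be left untouched
print(ast.dump(a) == ast.dump(c))  # False -> the file would be rewritten

This is why regenerating code through autoformat and write_if_ast_changed avoids spurious writes: reflowed but semantically identical output leaves the file, and its mtime, alone.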