thds_tabularasa-0.13.0-py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (46)
  1. thds/tabularasa/__init__.py +6 -0
  2. thds/tabularasa/__main__.py +1122 -0
  3. thds/tabularasa/compat.py +33 -0
  4. thds/tabularasa/data_dependencies/__init__.py +0 -0
  5. thds/tabularasa/data_dependencies/adls.py +97 -0
  6. thds/tabularasa/data_dependencies/build.py +573 -0
  7. thds/tabularasa/data_dependencies/sqlite.py +286 -0
  8. thds/tabularasa/data_dependencies/tabular.py +167 -0
  9. thds/tabularasa/data_dependencies/util.py +209 -0
  10. thds/tabularasa/diff/__init__.py +0 -0
  11. thds/tabularasa/diff/data.py +346 -0
  12. thds/tabularasa/diff/schema.py +254 -0
  13. thds/tabularasa/diff/summary.py +249 -0
  14. thds/tabularasa/git_util.py +37 -0
  15. thds/tabularasa/loaders/__init__.py +0 -0
  16. thds/tabularasa/loaders/lazy_adls.py +44 -0
  17. thds/tabularasa/loaders/parquet_util.py +385 -0
  18. thds/tabularasa/loaders/sqlite_util.py +346 -0
  19. thds/tabularasa/loaders/util.py +532 -0
  20. thds/tabularasa/py.typed +0 -0
  21. thds/tabularasa/schema/__init__.py +7 -0
  22. thds/tabularasa/schema/compilation/__init__.py +20 -0
  23. thds/tabularasa/schema/compilation/_format.py +50 -0
  24. thds/tabularasa/schema/compilation/attrs.py +257 -0
  25. thds/tabularasa/schema/compilation/attrs_sqlite.py +278 -0
  26. thds/tabularasa/schema/compilation/io.py +96 -0
  27. thds/tabularasa/schema/compilation/pandas.py +252 -0
  28. thds/tabularasa/schema/compilation/pyarrow.py +93 -0
  29. thds/tabularasa/schema/compilation/sphinx.py +550 -0
  30. thds/tabularasa/schema/compilation/sqlite.py +69 -0
  31. thds/tabularasa/schema/compilation/util.py +117 -0
  32. thds/tabularasa/schema/constraints.py +327 -0
  33. thds/tabularasa/schema/dtypes.py +153 -0
  34. thds/tabularasa/schema/extract_from_parquet.py +132 -0
  35. thds/tabularasa/schema/files.py +215 -0
  36. thds/tabularasa/schema/metaschema.py +1007 -0
  37. thds/tabularasa/schema/util.py +123 -0
  38. thds/tabularasa/schema/validation.py +878 -0
  39. thds/tabularasa/sqlite3_compat.py +41 -0
  40. thds/tabularasa/sqlite_from_parquet.py +34 -0
  41. thds/tabularasa/to_sqlite.py +56 -0
  42. thds_tabularasa-0.13.0.dist-info/METADATA +530 -0
  43. thds_tabularasa-0.13.0.dist-info/RECORD +46 -0
  44. thds_tabularasa-0.13.0.dist-info/WHEEL +5 -0
  45. thds_tabularasa-0.13.0.dist-info/entry_points.txt +2 -0
  46. thds_tabularasa-0.13.0.dist-info/top_level.txt +1 -0
thds/tabularasa/schema/compilation/attrs.py
@@ -0,0 +1,257 @@
+ import typing as ty
+ from operator import itemgetter
+ from textwrap import wrap
+
+ import typing_extensions
+
+ import thds.tabularasa.loaders.util
+ from thds.tabularasa.schema import metaschema
+
+ from ._format import autoformat
+ from .util import (
+     AUTOGEN_DISCLAIMER,
+     VarName,
+     _indent,
+     _wrap_lines_with_prefix,
+     constructor_template,
+     render_blob_store_def,
+     render_constructor,
+     sorted_class_names_for_import,
+ )
+
+ DEFAULT_LINE_WIDTH = 88
+
+ REMOTE_BLOB_STORE_VAR_NAME = "REMOTE_BLOB_STORE"
+
+ DOCSTRING_PARAM_TEMPLATE = """:param {name}: {doc}"""
+
+ CUSTOM_TYPE_DEF_TEMPLATE = """{comment}{name} = {type}"""
+
+ ATTRS_CLASS_DEF_TEMPLATE = """@attr.s(auto_attribs=True, frozen=True)
+ class {class_name}:
+     \"\"\"{doc}
+
+     {params}
+     \"\"\"
+
+     {fields}
+ """
+
+ ATTRS_FIELD_DEF_TEMPLATE_BASIC = "{name}: {type}"
+
+ ATTRS_LOADER_TEMPLATE = constructor_template(
+     thds.tabularasa.loaders.util.AttrsParquetLoader,
+     exclude=["filename"],
+     type_params=["{record_type}"],
+ )
+
+
+ def render_type_def(
+     type_: metaschema.CustomType,
+     build_options: metaschema.BuildOptions,
+ ) -> str:
+     type_literal = type_.python_type_def_literal(build_options)
+
+     if build_options.type_constraint_comments:
+         comment = type_.comment
+         if comment:
+             lines = wrap(comment, DEFAULT_LINE_WIDTH - 2)
+             comment = "\n".join("# " + line for line in lines) + "\n"
+         else:
+             comment = ""
+     else:
+         comment = ""
+
+     return CUSTOM_TYPE_DEF_TEMPLATE.format(comment=comment, name=type_.class_name, type=type_literal)
+
+
+ def render_attr_field_def(
+     column: metaschema.Column, build_options: metaschema.BuildOptions, builtin: bool = False
+ ) -> str:
+     type_literal = column.python_type_literal(build_options=build_options, builtin=builtin)
+     return ATTRS_FIELD_DEF_TEMPLATE_BASIC.format(name=column.snake_case_name, type=type_literal)
+
+
+ def render_attrs_table_schema(table: metaschema.Table, build_options: metaschema.BuildOptions) -> str:
+     field_defs = []
+     params = []
+
+     for column in table.columns:
+         field_def = render_attr_field_def(column, builtin=False, build_options=build_options)
+         field_defs.append(field_def)
+         doc = _wrap_lines_with_prefix(
+             column.doc,
+             DEFAULT_LINE_WIDTH - 4,
+             first_line_prefix_len=len(f":param {column.snake_case_name}: "),
+             trailing_line_indent=2,
+         )
+         params.append(
+             DOCSTRING_PARAM_TEMPLATE.format(
+                 name=column.snake_case_name,
+                 doc=doc,
+             )
+         )
+
+     table_doc = _wrap_lines_with_prefix(
+         table.doc,
+         DEFAULT_LINE_WIDTH - 4,
+         first_line_prefix_len=3,  # triple quotes
+         trailing_line_indent=0,
+     )
+
+     return ATTRS_CLASS_DEF_TEMPLATE.format(
+         class_name=table.class_name,
+         doc=_indent(table_doc),
+         params=_indent("\n".join(params)),
+         fields=_indent("\n".join(field_defs)),
+     )
+
+
+ PYARROW_SCHEMAS_QUALIFIED_IMPORT = "pyarrow_schemas"
+
+
+ class ImportsAndCode(ty.NamedTuple):
+     """Couples code and its required imports."""
+
+     third_party_imports: ty.List[str]
+     tabularasa_imports: ty.List[str]
+     code: ty.List[str]
+
+
+ def render_attrs_loaders(
+     schema: metaschema.Schema,
+     package: str,
+ ) -> ImportsAndCode:
+     data_dir = schema.build_options.package_data_dir
+     render_pyarrow_schemas = schema.build_options.pyarrow
+     import_lines = list()
+     if render_pyarrow_schemas:
+         import_lines.append("\n")
+         import_lines.append(f"from . import pyarrow as {PYARROW_SCHEMAS_QUALIFIED_IMPORT}")
+
+     return ImportsAndCode(
+         list(),
+         import_lines,
+         [
+             render_constructor(
+                 ATTRS_LOADER_TEMPLATE,
+                 kwargs=dict(
+                     record_type=VarName(table.class_name),
+                     table_name=table.snake_case_name,
+                     type_=VarName(table.class_name),
+                     package=package,
+                     data_dir=data_dir,
+                     md5=table.md5,
+                     blob_store=(
+                         None
+                         if schema.remote_blob_store is None or table.md5 is None
+                         else VarName(REMOTE_BLOB_STORE_VAR_NAME)
+                     ),
+                     pyarrow_schema=(
+                         VarName(
+                             f"{PYARROW_SCHEMAS_QUALIFIED_IMPORT}.{table.snake_case_name}_pyarrow_schema"
+                         )
+                         if render_pyarrow_schemas
+                         else None
+                     ),
+                 ),
+                 var_name=f"load_{table.snake_case_name}",
+             )
+             for table in schema.package_tables
+         ],
+     )
+
+
+ def render_attrs_type_defs(
+     schema: metaschema.Schema,
+ ) -> ImportsAndCode:
+     # custom types
+     defined_custom_types = schema.defined_types
+     type_defs = [
+         render_type_def(
+             type_,
+             build_options=schema.build_options,
+         )
+         for type_ in sorted(defined_custom_types, key=lambda type_: type_.name)
+     ]
+
+     import_lines = list()
+     # external type imports
+     sep = ",\n "
+     for module_name, class_names in sorted(schema.external_type_imports.items(), key=itemgetter(0)):
+         import_lines.append(
+             f"from {module_name} import (\n {sep.join(sorted_class_names_for_import(class_names))},\n)\n"
+         )
+
+     return ImportsAndCode([], import_lines, type_defs)
+
+
+ def _render_attrs_schema(
+     schema: metaschema.Schema,
+     type_defs: ImportsAndCode,
+     loader_defs: ty.Optional[ImportsAndCode],
+ ) -> str:
+     loader_defs = loader_defs or ImportsAndCode([], [], [])
+     assert loader_defs, "Loaders are optional but the line above is not"
+
+     # attrs record types
+     table_defs = [
+         render_attrs_table_schema(table, schema.build_options) for table in schema.package_tables
+     ]
+
+     # imports
+     stdlib_imports = sorted(schema.attrs_required_imports)
+     import_extensions = typing_extensions.__name__ in stdlib_imports
+     if import_extensions:
+         stdlib_imports.remove(typing_extensions.__name__)
+     import_lines = [f"import {module}\n" for module in stdlib_imports]
+
+     if import_lines:
+         import_lines.append("\n")
+     import_lines.append("import attr\n")
+     if import_extensions:
+         import_lines.append(f"import {typing_extensions.__name__}\n")
+
+     import_lines.append("\n")
+
+     if loader_defs.code:
+         import_lines.append(f"import {thds.tabularasa.loaders.util.__name__}\n")
+     if schema.remote_blob_store is not None:
+         import_lines.append(f"import {thds.tabularasa.schema.files.__name__}\n")
+
+     import_lines.extend(type_defs.tabularasa_imports)
+     import_lines.extend(loader_defs.tabularasa_imports)
+
+     # globals
+     global_var_defs = []
+     if schema.remote_blob_store is not None:
+         global_var_defs.append(
+             render_blob_store_def(schema.remote_blob_store, REMOTE_BLOB_STORE_VAR_NAME)
+         )
+
+     imports = "".join(import_lines)
+     globals_ = "\n".join(global_var_defs)
+     types = "\n".join(type_defs.code)
+     classes = "\n\n".join(table_defs)
+     loaders = "\n\n".join(loader_defs.code)
+
+     # module
+     return autoformat(
+         f"{imports}\n# {AUTOGEN_DISCLAIMER}\n\n{globals_}\n\n{types}\n\n\n{classes}\n\n{loaders}\n"
+     )
+
+
+ def render_attrs_module(
+     schema: metaschema.Schema,
+     package: str,
+     loader_defs: ty.Optional[ImportsAndCode] = None,
+ ) -> str:
+     if loader_defs is None:
+         loader_defs = (
+             render_attrs_loaders(schema, package) if schema.build_options.package_data_dir else None
+         )
+     return _render_attrs_schema(
+         schema,
+         render_attrs_type_defs(schema),
+         loader_defs,
+     )
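
A minimal sketch of how this renderer might be driven, assuming a metaschema.Schema instance has been constructed elsewhere; the generate_attrs_module helper and the output path are illustrative, not part of the package:

    # Illustrative only: render the attrs record module for a schema and write it
    # to disk, skipping the write when only formatting would change (see io.py below).
    from thds.tabularasa.schema import metaschema
    from thds.tabularasa.schema.compilation.attrs import render_attrs_module
    from thds.tabularasa.schema.compilation.io import write_if_ast_changed

    def generate_attrs_module(schema: metaschema.Schema, package: str, out_path: str) -> None:
        source = render_attrs_module(schema, package)  # autoformatted Python source
        write_if_ast_changed(source, out_path)
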
thds/tabularasa/schema/compilation/attrs_sqlite.py
@@ -0,0 +1,278 @@
+ from itertools import chain
+ from logging import getLogger
+ from typing import Dict, List, Optional, Tuple
+
+ import thds.tabularasa.loaders.util
+ import thds.tabularasa.schema
+ from thds.tabularasa.schema import metaschema
+ from thds.tabularasa.schema.compilation.attrs import render_attr_field_def
+
+ from ._format import autoformat
+ from .sqlite import index_name
+ from .util import AUTOGEN_DISCLAIMER, sorted_class_names_for_import
+
+ _LOGGER = getLogger(__name__)
+
+ PACKAGE_VARNAME = "PACKAGE"
+ DB_PATH_VARNAME = "DB_PATH"
+
+ ATTRS_MODULE_NAME = ".attrs"
+
+ LINE_WIDTH = 88
+
+ COLUMN_LINESEP = ",\n "
+
+ ATTRS_CLASS_LOADER_TEMPLATE = """class {class_name}Loader:
+
+     def __init__(self, db: util.%s):
+         self._db = db
+         self._record = util.%s({class_name})
+
+     {accessors}
+ """ % (
+     thds.tabularasa.loaders.util.AttrsSQLiteDatabase.__name__,
+     thds.tabularasa.loaders.sqlite_util.sqlite_constructor_for_record_type.__name__,  # type: ignore
+ )
+
+ ATTRS_INDEX_ACCESSOR_TEMPLATE = """
+     def {method_name}(self, {typed_args}) -> typing.{return_type}[{class_name}]:
+         return self._db.sqlite_{index_kind}_query(
+             self._record,
+             \"\"\"
+             SELECT
+                 {columns}
+             FROM {table_name}
+             INDEXED BY {index_name}
+             WHERE {condition};
+             \"\"\",
+             ({args},),
+         )
+ """
+
+ ATTRS_BULK_INDEX_ACCESSOR_TEMPLATE = """
+     def {method_name}_bulk(self, {arg_name}: typing.Collection[{typed_args}]) -> typing.{return_type}[{class_name}]:
+         if {arg_name}:
+             return self._db.sqlite_{index_kind}_query(
+                 self._record,
+                 f\"\"\"
+                 SELECT
+                     {columns}
+                 FROM {table_name}
+                 INDEXED BY {index_name}
+                 WHERE {condition};
+                 \"\"\",
+                 {arg_name},
+                 single_col={single_col},
+             )
+         else:
+             return iter(())
+ """
+
+ ATTRS_MAIN_LOADER_TEMPLATE = """class SQLiteLoader:
+     def __init__(
+         self,
+         package: typing.Optional[str] = %s,
+         db_path: str = %s,
+         cache_size: int = util.DEFAULT_ATTR_SQLITE_CACHE_SIZE,
+         mmap_size: int = util.DEFAULT_MMAP_BYTES,
+     ):
+         self._db = util.%s(package=package, db_path=db_path, cache_size=cache_size, mmap_size=mmap_size)
+         {table_loaders}
+ """ % (
+     PACKAGE_VARNAME,
+     DB_PATH_VARNAME,
+     thds.tabularasa.loaders.util.AttrsSQLiteDatabase.__name__,
+ )
+
+
+ def render_attrs_loader_schema(table: metaschema.Table, build_options: metaschema.BuildOptions) -> str:
+     accessor_defs = []
+     column_lookup = {col.name: col for col in table.columns}
+     unq_constraints = {frozenset(c.unique) for c in table.unique_constraints}
+
+     if table.primary_key:
+         accessor_defs.append(
+             render_accessor_method(
+                 table, table.primary_key, column_lookup, pk=True, build_options=build_options
+             )
+         )
+         accessor_defs.append(
+             render_accessor_method(
+                 table, table.primary_key, column_lookup, pk=True, bulk=True, build_options=build_options
+             )
+         )
+
+     for idx in table.indexes:
+         unique = frozenset(idx) in unq_constraints
+         accessor_defs.append(
+             render_accessor_method(
+                 table, idx, column_lookup, pk=False, unique=unique, build_options=build_options
+             )
+         )
+         accessor_defs.append(
+             render_accessor_method(
+                 table,
+                 idx,
+                 column_lookup,
+                 pk=False,
+                 unique=unique,
+                 bulk=True,
+                 build_options=build_options,
+             )
+         )
+
+     accessors = "".join(accessor_defs).strip()
+     return ATTRS_CLASS_LOADER_TEMPLATE.format(
+         class_name=table.class_name,
+         accessors=accessors,
+     )
+
+
+ def render_accessor_method(
+     table: metaschema.Table,
+     index_columns: Tuple[str, ...],
+     column_lookup: Dict[str, metaschema.Column],
+     build_options: metaschema.BuildOptions,
+     pk: bool = False,
+     unique: bool = False,
+     bulk: bool = False,
+ ) -> str:
+     index_column_names = tuple(map(metaschema.snake_case, index_columns))
+     method_name = "pk" if pk else f"idx_{'_'.join(index_column_names)}"
+     index_kind = "bulk" if bulk else ("pk" if pk or unique else "index")
+     return_type = "Iterator" if bulk else ("Optional" if pk or unique else "List")
+     # we use the `IS` operator to allow for comparison in case of nullable index columns
+     arg_name = "__".join(index_column_names)
+     nullsafe_condition = " AND ".join([f"{column} IS (?)" for column in index_column_names])
+     if bulk:
+         columns = [column_lookup[col] for col in index_columns]
+         has_null_cols = any(col.nullable for col in columns)
+         if len(index_column_names) == 1:
+             # for single-column indexes, we need to unpack the single value from the tuple
+             typed_args = columns[0].python_type_literal(build_options=build_options, builtin=True)
+             if has_null_cols:
+                 # sqlite IN operator unfortunately doesn't support a NULL == NULL variant, the way that IS does for =
+                 condition = f"{{' OR '.join(['{index_column_names[0]} IS (?)'] * len({arg_name}))}}"
+             else:
+                 condition = f"{index_column_names[0]} IN ({{','.join('?' * len({arg_name}))}})"
+             single_col = "True"
+         else:
+             type_strs = ", ".join(
+                 col.python_type_literal(build_options=build_options, builtin=True) for col in columns
+             )
+             typed_args = f"typing.Tuple[{type_strs}]"
+             if has_null_cols:
+                 # sqlite IN operator unfortunately doesn't support a NULL == NULL variant, the way that IS does for =
+                 param_tuple = f"({nullsafe_condition})"
+                 condition = f"{{' OR '.join(['{param_tuple}'] * len({arg_name}))}}"
+             else:
+                 param_tuple = f"({', '.join('?' * len(index_column_names))})"
+                 condition = f"({', '.join(index_column_names)}) IN ({{','.join(['{param_tuple}'] * len({arg_name}))}})"
+             single_col = "False"
+     else:
+         condition = nullsafe_condition
+         typed_args = ", ".join(
+             [
+                 render_attr_field_def(column_lookup[col], builtin=True, build_options=build_options)
+                 for col in index_columns
+             ]
+         )
+         single_col = ""
+
+     return (ATTRS_BULK_INDEX_ACCESSOR_TEMPLATE if bulk else ATTRS_INDEX_ACCESSOR_TEMPLATE).format(
+         class_name=table.class_name,
+         method_name=method_name,
+         # use builtin types (as opposed to e.g. Literals and Newtypes) to make the API simpler to use
+         typed_args=typed_args,
+         arg_name=arg_name,
+         single_col=single_col,
+         return_type=return_type,
+         index_kind=index_kind,
+         table_name=table.snake_case_name,
+         columns=COLUMN_LINESEP.join(c.snake_case_name for c in table.columns),
+         index_name=index_name(table.snake_case_name, *index_column_names),
+         condition=condition,
+         args=", ".join(index_column_names),
+     )
+
+
+ def render_attrs_main_loader(
+     tables: List[metaschema.Table],
+ ) -> str:
+     loader_instance_defs = [
+         f"self.{table.snake_case_name} = {table.class_name}Loader(self._db)" for table in tables
+     ]
+     table_loaders = "\n        ".join(loader_instance_defs)
+     return ATTRS_MAIN_LOADER_TEMPLATE.format(table_loaders=table_loaders)
+
+
+ def _import_lines(
+     tables: List[metaschema.Table],
+     attrs_module_name: Optional[str],
+     build_options: metaschema.BuildOptions,
+ ):
+     # need typing always for List and Optional
+     stdlib_imports = sorted(
+         {
+             "typing",
+             *chain.from_iterable(t.attrs_sqlite_required_imports(build_options) for t in tables),
+         }
+     )
+     import_lines = [f"import {module}\n" for module in stdlib_imports]
+     if import_lines:
+         import_lines.append("\n")
+
+     import_lines.append(f"import {thds.tabularasa.loaders.sqlite_util.__name__} as util\n")
+     import_lines.append("\n")
+     if attrs_module_name is not None:
+         import_lines.append(f"from {attrs_module_name} import (\n")
+         attrs_module_classnames = {table.class_name for table in tables}
+         import_lines.extend(
+             f" {name},\n" for name in sorted_class_names_for_import(attrs_module_classnames)
+         )
+         import_lines.append(")\n")
+
+     return "".join(import_lines)
+
+
+ def _has_index(table: metaschema.Table) -> bool:
+     return (table.primary_key is not None) or len(table.indexes) > 0
+
+
+ def render_attrs_sqlite_schema(
+     schema: metaschema.Schema,
+     package: str = "",
+     db_path: str = "",
+     attrs_module_name: Optional[str] = ATTRS_MODULE_NAME,
+ ) -> str:
+     has_database_loader = bool(package and db_path)
+
+     # do not generate a SQLite loader if there is no primary key or index defined on the table def
+     tables = [table for table in schema.package_tables if table.has_indexes]
+     tables_filtered = [table.name for table in schema.package_tables if not table.has_indexes]
+     if not tables:
+         _LOGGER.info(
+             f"Skipping SQLite loader generation for all tables: {tables_filtered}; none has any index "
+             f"specified"
+         )
+         return ""
+
+     if tables_filtered:
+         _LOGGER.info(
+             f"Skipping SQLite loader generation for the following "
+             f"tables because no indices or primary keys are defined: {', '.join(tables_filtered)}"
+         )
+
+     imports = _import_lines(tables, attrs_module_name, schema.build_options)
+     loader_defs = [render_attrs_loader_schema(table, schema.build_options) for table in tables]
+     loaders = "\n\n".join(loader_defs)
+
+     if has_database_loader:
+         constants = f'{PACKAGE_VARNAME} = "{package}"\n{DB_PATH_VARNAME} = "{db_path}"'
+         loader_def = render_attrs_main_loader(tables)
+     else:
+         constants = ""
+         loader_def = ""
+     return autoformat(
+         f"{imports}\n# {AUTOGEN_DISCLAIMER}\n\n{constants}\n\n\n{loaders}\n\n{loader_def}\n"
+     )
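
As above, a hedged sketch of invoking this renderer; the package name, database path, and output location below are placeholders, not values used by the package:

    # Illustrative only: render the SQLite accessor module for tables that define
    # a primary key or index; render_attrs_sqlite_schema returns "" when none do.
    from thds.tabularasa.schema import metaschema
    from thds.tabularasa.schema.compilation.attrs_sqlite import render_attrs_sqlite_schema
    from thds.tabularasa.schema.compilation.io import write_if_ast_changed

    def generate_sqlite_module(schema: metaschema.Schema, out_path: str) -> None:
        source = render_attrs_sqlite_schema(
            schema,
            package="my_data_package",     # hypothetical package bundling the db
            db_path="data/tables.sqlite",  # hypothetical path within that package
        )
        if source:
            write_if_ast_changed(source, out_path)
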
thds/tabularasa/schema/compilation/io.py
@@ -0,0 +1,96 @@
+ import ast
+ from functools import singledispatch
+ from itertools import starmap, zip_longest
+ from logging import getLogger
+ from operator import itemgetter
+ from pathlib import Path
+ from typing import Any, Iterator, List, Tuple, Union
+
+ _LOGGER = getLogger(__name__)
+
+ AST_CODE_CONTEXT_VARS = {"lineno", "col_offset", "ctx", "end_lineno", "end_col_offset"}
+
+
+ def ast_eq(ast1: ast.AST, ast2: ast.AST) -> bool:
+     """Return True if two python source strings are AST-equivalent, else False"""
+     return _ast_eq(ast1, ast2)
+
+
+ def ast_vars(node: ast.AST) -> Iterator[Tuple[str, Any]]:
+     """Iterator of (name, value) tuples for all attributes of an AST node *except* for those that are
+     not abstract (e.g. line numbers and column offsets)"""
+     return (
+         (name, value)
+         for name, value in sorted(vars(node).items(), key=itemgetter(0))
+         if name not in AST_CODE_CONTEXT_VARS
+     )
+
+
+ @singledispatch
+ def _ast_eq(ast1: Any, ast2: Any) -> bool:
+     # base case, literal values (non-AST nodes)
+     return (type(ast1) is type(ast2)) and (ast1 == ast2)
+
+
+ @_ast_eq.register(ast.AST)
+ def _ast_eq_ast(ast1: ast.AST, ast2: ast.AST) -> bool:
+     if type(ast1) is not type(ast2):
+         return False
+     attrs1 = ast_vars(ast1)
+     attrs2 = ast_vars(ast2)
+     return all(
+         (name1 == name2) and _ast_eq(attr1, attr2)
+         for (name1, attr1), (name2, attr2) in zip(
+             attrs1,
+             attrs2,
+         )
+     )
+
+
+ @_ast_eq.register(list)
+ @_ast_eq.register(tuple)
+ def _ast_eq_list(ast1: List[Any], ast2: List[Any]):
+     missing = object()
+     return all(starmap(_ast_eq, zip_longest(ast1, ast2, fillvalue=missing)))
+
+
+ def write_if_ast_changed(source: str, path: Union[str, Path]):  # pragma: no cover
+     """Write the source code `source` to the file at `path`, but only if the file doesn't exist, or the
+     AST of the code therein differs from that of `source`"""
+     path = Path(path)
+     this_ast = ast.parse(source)
+
+     if path.exists():
+         with open(path, "r+") as f:
+             that_source = f.read()
+             try:
+                 that_ast = ast.parse(that_source)
+             except SyntaxError:
+                 _LOGGER.warning(
+                     f"syntax error in code at {path}; merge conflicts? Code will be overwritten"
+                 )
+                 rewrite = True
+                 reason = "Invalid AST"
+             else:
+                 rewrite = not ast_eq(this_ast, that_ast)
+                 reason = "AST changed"
+
+             if rewrite:
+                 _LOGGER.info(f"writing new generated code to {path}; {reason}")
+                 f.seek(0)
+                 f.truncate()
+                 f.write(source)
+             else:
+                 _LOGGER.info(f"leaving generated code at {path}; AST unchanged")
+     else:
+         _LOGGER.info(f"writing new generated code to {path}; no prior file")
+         with open(path, "w") as f:
+             f.write(source)
+
+
+ def write_sql(source: str, path: Union[str, Path]):  # pragma: no cover
+     """Write the SQL source code `source` to the file at `path`"""
+     path = Path(path)
+     _LOGGER.info(f"writing new generated code to {path}; no prior file")
+     with open(path, "w") as f:
+         f.write(source)
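
A small sketch of the behaviour write_if_ast_changed relies on: two sources that differ only in layout compare equal under ast_eq, because location attributes are excluded from the comparison. The file path below is illustrative:

    import ast
    from thds.tabularasa.schema.compilation.io import ast_eq, write_if_ast_changed

    a = ast.parse("x = [1, 2, 3]")
    b = ast.parse("x = [\n    1,\n    2,\n    3,\n]")  # same code, different layout
    assert ast_eq(a, b)  # location info (lineno, col_offset, ...) is ignored

    # so regenerating a file with only cosmetic differences leaves it untouched
    write_if_ast_changed("x = [1, 2, 3]\n", "/tmp/generated_example.py")  # illustrative path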