thds.tabularasa 0.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. thds/tabularasa/__init__.py +6 -0
  2. thds/tabularasa/__main__.py +1122 -0
  3. thds/tabularasa/compat.py +33 -0
  4. thds/tabularasa/data_dependencies/__init__.py +0 -0
  5. thds/tabularasa/data_dependencies/adls.py +97 -0
  6. thds/tabularasa/data_dependencies/build.py +573 -0
  7. thds/tabularasa/data_dependencies/sqlite.py +286 -0
  8. thds/tabularasa/data_dependencies/tabular.py +167 -0
  9. thds/tabularasa/data_dependencies/util.py +209 -0
  10. thds/tabularasa/diff/__init__.py +0 -0
  11. thds/tabularasa/diff/data.py +346 -0
  12. thds/tabularasa/diff/schema.py +254 -0
  13. thds/tabularasa/diff/summary.py +249 -0
  14. thds/tabularasa/git_util.py +37 -0
  15. thds/tabularasa/loaders/__init__.py +0 -0
  16. thds/tabularasa/loaders/lazy_adls.py +44 -0
  17. thds/tabularasa/loaders/parquet_util.py +385 -0
  18. thds/tabularasa/loaders/sqlite_util.py +346 -0
  19. thds/tabularasa/loaders/util.py +532 -0
  20. thds/tabularasa/py.typed +0 -0
  21. thds/tabularasa/schema/__init__.py +7 -0
  22. thds/tabularasa/schema/compilation/__init__.py +20 -0
  23. thds/tabularasa/schema/compilation/_format.py +50 -0
  24. thds/tabularasa/schema/compilation/attrs.py +257 -0
  25. thds/tabularasa/schema/compilation/attrs_sqlite.py +278 -0
  26. thds/tabularasa/schema/compilation/io.py +96 -0
  27. thds/tabularasa/schema/compilation/pandas.py +252 -0
  28. thds/tabularasa/schema/compilation/pyarrow.py +93 -0
  29. thds/tabularasa/schema/compilation/sphinx.py +550 -0
  30. thds/tabularasa/schema/compilation/sqlite.py +69 -0
  31. thds/tabularasa/schema/compilation/util.py +117 -0
  32. thds/tabularasa/schema/constraints.py +327 -0
  33. thds/tabularasa/schema/dtypes.py +153 -0
  34. thds/tabularasa/schema/extract_from_parquet.py +132 -0
  35. thds/tabularasa/schema/files.py +215 -0
  36. thds/tabularasa/schema/metaschema.py +1007 -0
  37. thds/tabularasa/schema/util.py +123 -0
  38. thds/tabularasa/schema/validation.py +878 -0
  39. thds/tabularasa/sqlite3_compat.py +41 -0
  40. thds/tabularasa/sqlite_from_parquet.py +34 -0
  41. thds/tabularasa/to_sqlite.py +56 -0
  42. thds_tabularasa-0.13.0.dist-info/METADATA +530 -0
  43. thds_tabularasa-0.13.0.dist-info/RECORD +46 -0
  44. thds_tabularasa-0.13.0.dist-info/WHEEL +5 -0
  45. thds_tabularasa-0.13.0.dist-info/entry_points.txt +2 -0
  46. thds_tabularasa-0.13.0.dist-info/top_level.txt +1 -0
thds/tabularasa/loaders/sqlite_util.py
@@ -0,0 +1,346 @@
+ import contextlib
+ import datetime
+ import itertools
+ import json
+ import logging
+ import os
+ import sys
+ import typing as ty
+ from functools import lru_cache, wraps
+ from pathlib import Path
+ from typing import Callable, Optional, Type
+
+ import attr
+ import cattrs.preconf.json
+ import pkg_resources
+ from filelock import FileLock
+ from typing_inspect import get_args, get_origin, is_literal_type, is_optional_type, is_union_type
+
+ from thds.core.types import StrOrPath
+ from thds.tabularasa.schema.dtypes import DType
+ from thds.tabularasa.sqlite3_compat import sqlite3
+
+ DEFAULT_ATTR_SQLITE_CACHE_SIZE = 100_000
+ DEFAULT_MMAP_BYTES = int(os.environ.get("TABULA_RASA_DEFAULT_MMAP_BYTES", 8_589_934_592))  # 8 GB
+ DISABLE_WAL_MODE = bool(os.environ.get("REF_D_DISABLE_SQLITE_WAL_MODE", False))  # any non-empty string enables this flag
+
+ PARAMETERIZABLE_BUILTINS = sys.version_info >= (3, 9)
+
+ if not PARAMETERIZABLE_BUILTINS:
+     _builtin_to_typing = {
+         list: ty.List,
+         set: ty.Set,
+         frozenset: ty.FrozenSet,
+         tuple: ty.Tuple,
+         dict: ty.Dict,
+     }
+
+     def get_generic_origin(t) -> ty.Optional[ty.Type]:
+         org = get_origin(t)
+         return None if org is None else _builtin_to_typing.get(org, org)  # type: ignore
+
+ else:
+     get_generic_origin = get_origin
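This shim exists because `resolve_newtypes` (below) re-parameterizes a resolved origin via `origin[args_resolved]`: before Python 3.9 (PEP 585), builtin containers are not subscriptable, so an origin of `list` has to be mapped back to the subscriptable `typing.List`. A minimal sketch of the failure mode being worked around:

    import typing as ty

    try:
        list[int]  # raises TypeError before Python 3.9
    except TypeError:
        pass
    ty.List[int]  # subscriptable on every supported version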
+
+
+ LITERAL_SQLITE_TYPES = {int, float, bool, str, type(None), datetime.date, datetime.datetime}
+
+
+ CONVERTER = cattrs.preconf.json.make_converter()
+
+
+ def structure_date(s: str, dt: ty.Type[datetime.date] = datetime.date) -> datetime.date:
+     return dt.fromisoformat(s)
+
+
+ CONVERTER.register_structure_hook(datetime.date, structure_date)
+ CONVERTER.register_unstructure_hook(datetime.date, datetime.date.isoformat)
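With these two hooks registered, dates cross the JSON boundary as ISO-8601 strings. A quick sketch of the round trip they provide:

    import datetime

    raw = CONVERTER.unstructure(datetime.date(2020, 1, 31))  # '2020-01-31'
    value = CONVERTER.structure(raw, datetime.date)
    assert value == datetime.date(2020, 1, 31)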
+
+ T = ty.TypeVar("T")
+ Record = ty.TypeVar("Record", bound=attr.AttrsInstance)
+
+
+ @lru_cache(None)
+ def sqlite_postprocessor_for_type(t: ty.Type[T]) -> Optional[Callable[[ty.AnyStr], Optional[T]]]:
+     """Construct a parser converting an optional JSON string from a sqlite query to a python value
+     of type T"""
+     t = resolve_newtypes(t)
+
+     if is_primitive_type(t):
+         return None
+
+     if is_optional_type(t):
+
+         def parse(s: ty.AnyStr) -> Optional[T]:
+             if s is None:
+                 return None
+             raw = json.loads(s)
+             return CONVERTER.structure(raw, t)
+
+     else:
+
+         def parse(s: ty.AnyStr) -> Optional[T]:
+             raw = json.loads(s)
+             return CONVERTER.structure(raw, t)
+
+     return parse
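A sketch of what the returned parser does for a collection-typed column (the types here are illustrative):

    import typing as ty

    parse = sqlite_postprocessor_for_type(ty.Optional[ty.List[int]])
    assert parse('[1, 2, 3]') == [1, 2, 3]
    assert parse(None) is None
    # primitive columns map to native sqlite types and need no postprocessing:
    assert sqlite_postprocessor_for_type(int) is None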
+
+
+ @lru_cache(None)
+ def sqlite_preprocessor_for_type(t: ty.Type[T]) -> Optional[Callable[[T], Optional[str]]]:
+     """Prepare a value of type T for inserting into a sqlite TEXT/JSON column by serializing it as
+     JSON"""
+     t = resolve_newtypes(t)
+
+     if is_primitive_type(t):
+         return None
+
+     if is_optional_type(t):
+
+         def unparse(value: T) -> Optional[str]:
+             if value is None:
+                 return None
+             raw = CONVERTER.unstructure(value, t)
+             return json.dumps(raw, check_circular=False, indent=None)
+
+     else:
+
+         def unparse(value: T) -> Optional[str]:
+             raw = CONVERTER.unstructure(value, t)
+             return json.dumps(raw, check_circular=False, indent=None)
+
+     return unparse
+
+
+ def resolve_newtypes(t: Type) -> Type:
+     supertype = getattr(t, "__supertype__", None)
+     if supertype is None:
+         origin = get_generic_origin(t)
+         if origin is None:
+             return t
+         args = get_args(t)
+         if not args:
+             return t
+         args_resolved = tuple(resolve_newtypes(a) for a in args)
+         return origin[args_resolved]
+     return resolve_newtypes(supertype)
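For reference, a small sketch of the resolution, using a hypothetical NewType:

    import typing as ty

    UserId = ty.NewType("UserId", int)
    assert resolve_newtypes(UserId) is int
    # NewTypes nested inside generics are resolved recursively:
    assert ty.get_args(resolve_newtypes(ty.List[UserId])) == (int,)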
+
+
+ def is_primitive_type(t: Type) -> bool:
+     if t in LITERAL_SQLITE_TYPES:
+         return True
+     elif is_optional_type(t) and is_primitive_type(nonnull_type_of(t)):
+         return True
+     elif is_literal_type(t):
+         return True
+     else:
+         return False
+
+
+ def nonnull_type_of(t: Type) -> Type:
+     if not is_union_type(t):
+         return t
+     types = get_args(t)
+     if type(None) not in types:
+         return t
+     nonnull_types = tuple(t_ for t_ in types if t_ is not type(None))  # noqa
+     return ty.Union[nonnull_types]  # type: ignore
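Behavior sketch for these two helpers:

    import typing as ty

    assert is_primitive_type(ty.Optional[str])  # str maps to a native sqlite type
    assert not is_primitive_type(ty.List[int])  # collections go through JSON instead
    assert nonnull_type_of(ty.Optional[int]) is int
    assert nonnull_type_of(ty.Union[int, str, None]) == ty.Union[int, str]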
+
+
+ def to_local_path(path: StrOrPath, package: Optional[str] = None) -> Path:
+     if package is None:
+         return Path(path)
+     else:
+         return Path(pkg_resources.resource_filename(package, str(path)))
+
+
+ def set_bulk_write_mode(con: sqlite3.Connection) -> sqlite3.Connection:
+     logger = logging.getLogger(__name__)
+     logger.debug("Setting pragmas for bulk write optimization")
+     # https://www.sqlite.org/pragma.html#pragma_synchronous
+     _log_exec_sql(logger, con, "PRAGMA synchronous = 0")  # OFF
+     # https://www.sqlite.org/pragma.html#pragma_journal_mode
+     if not DISABLE_WAL_MODE:
+         _log_exec_sql(logger, con, "PRAGMA journal_mode = WAL")
+     # https://www.sqlite.org/pragma.html#pragma_locking_mode
+     _log_exec_sql(logger, con, "PRAGMA locking_mode = EXCLUSIVE")
+
+     return con
+
+
+ def unset_bulk_write_mode(con: sqlite3.Connection) -> sqlite3.Connection:
+     logger = logging.getLogger(__name__)
+     logger.debug("Resetting pragmas after bulk write optimization")
+     # https://www.sqlite.org/pragma.html#pragma_journal_mode
+     # Reset journal_mode to the default. This is a property of the database itself, rather than of
+     # the connection; the other settings are connection-specific. Per the sqlite docs, WAL journal
+     # mode must be disabled before the locking mode is restored, else restoring it is a no-op.
+     _log_exec_sql(logger, con, "PRAGMA journal_mode = DELETE")
+     # https://www.sqlite.org/pragma.html#pragma_synchronous
+     _log_exec_sql(logger, con, "PRAGMA synchronous = 2")  # FULL (default)
+     # https://www.sqlite.org/pragma.html#pragma_locking_mode
+     _log_exec_sql(logger, con, "PRAGMA locking_mode = NORMAL")
+
+     return con
+
+
+ @contextlib.contextmanager
+ def bulk_write_connection(
+     db_path: StrOrPath, db_package: Optional[str] = None, close: bool = True
+ ) -> ty.Generator[sqlite3.Connection, None, None]:
+     """Context manager to set/unset bulk write mode on a sqlite connection. Sets pragmas for
+     efficient bulk writes, such as loosening the synchronous and locking modes. If `close` is True,
+     the connection is closed on exit. Since setting the pragmas may mutate the database file, and
+     since this context manager exists by design to enable bulk writes that intentionally mutate the
+     database, we also acquire a file lock on the database file (resolved from `db_path`, and from
+     `db_package` if the database ships as package data) on entry and release it on exit.
+     """
+     db_path_ = to_local_path(db_path, db_package).absolute()
+     lock_path = db_path_.with_suffix(".lock")
+     lock = FileLock(lock_path)
+     logger = logging.getLogger(__name__)
+     logger.info("PID %d: Acquiring lock on %s", os.getpid(), lock_path)
+     with lock:
+         con = sqlite_connection(db_path, db_package, read_only=False)
+         set_bulk_write_mode(con)
+         try:
+             yield con
+         finally:
+             unset_bulk_write_mode(con)
+
+             if close:
+                 con.close()
+
+     if lock_path.exists():
+         os.remove(lock_path)
+
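A minimal usage sketch (the database path and table are hypothetical):

    with bulk_write_connection("data/tables.sqlite") as con:
        con.execute("CREATE TABLE IF NOT EXISTS widgets (id INTEGER PRIMARY KEY, name TEXT)")
        con.executemany("INSERT INTO widgets VALUES (?, ?)", [(1, "a"), (2, "b")])
        con.commit()
    # on exit: pragmas restored, connection closed, lock file removed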
+
+ def sqlite_connection(
+     db_path: StrOrPath,
+     package: Optional[str] = None,
+     *,
+     mmap_size: Optional[int] = None,
+     read_only: bool = False,
+ ):
+     db_full_path = to_local_path(db_path, package)
+
+     logger = logging.getLogger(__name__)
+     logger.info(f"Connecting to sqlite database: {db_full_path}")
+     # sqlite3.PARSE_DECLTYPES will cover parsing dates/datetimes from the db
+     con = sqlite3.connect(
+         db_full_path.absolute(), detect_types=sqlite3.PARSE_DECLTYPES, check_same_thread=not read_only
+     )
+     if mmap_size is not None:
+         logger.info(f"Setting sqlite mmap size to {mmap_size}")
+         _log_exec_sql(logger, con, f"PRAGMA mmap_size={mmap_size};")
+
+     return con
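For example, a connection to a database shipped as package data, tuned for concurrent reads (the package and filename are hypothetical):

    con = sqlite_connection(
        "tables.sqlite", package="my_pkg.data", mmap_size=DEFAULT_MMAP_BYTES, read_only=True
    )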
+
+
+ def _log_exec_sql(
+     logger: logging.Logger, con: sqlite3.Connection, statement: str, level: int = logging.DEBUG
+ ):
+     logger.log(level, "sqlite: %s", statement)
+     con.execute(statement)
+
+
+ @lru_cache(None)
+ def sqlite_constructor_for_record_type(cls):
+     """Wrap an `attrs` record class to allow it to accept raw JSON strings from sqlite queries in
+     place of collection types. If no fields of the class have collection types, simply return the
+     record class unwrapped"""
+     postprocessors = [sqlite_postprocessor_for_type(type_) for type_ in cls.__annotations__.values()]
+     if not any(postprocessors):
+         return cls
+
+     @wraps(cls)
+     def cons(*args):
+         return cls(*(v if f is None else f(v) for f, v in zip(postprocessors, args)))
+
+     return cons
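An illustrative sketch of the wrapping, assuming a simple attrs record:

    import typing as ty

    import attr

    @attr.s(auto_attribs=True)
    class Widget:
        id: int
        tags: ty.List[str]

    make_widget = sqlite_constructor_for_record_type(Widget)
    # the JSON TEXT column is decoded transparently into the collection field:
    assert make_widget(1, '["red", "small"]') == Widget(1, ["red", "small"])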
+
+
+ class AttrsSQLiteDatabase:
+     """Base interface for loading package resources as record iterators"""
+
+     def __init__(
+         self,
+         package: ty.Optional[str],
+         db_path: StrOrPath,
+         cache_size: ty.Optional[int] = DEFAULT_ATTR_SQLITE_CACHE_SIZE,
+         mmap_size: int = DEFAULT_MMAP_BYTES,
+     ):
+         if cache_size is not None:
+             # wrap the bound method in a per-instance LRU cache
+             self.sqlite_index_query = lru_cache(cache_size)(self.sqlite_index_query)  # type: ignore
+
+         self._sqlite_con = sqlite_connection(
+             db_path,
+             package,
+             mmap_size=mmap_size,
+             read_only=True,
+         )
+
+     def sqlite_index_query(
+         self, clazz: ty.Callable[..., Record], query: str, args: ty.Tuple
+     ) -> ty.List[Record]:
+         result = self._sqlite_con.execute(query, args).fetchall()
+         return [clazz(*r) for r in result]
+
+     def sqlite_pk_query(
+         self, clazz: ty.Callable[..., Record], query: str, args: ty.Tuple
+     ) -> ty.Optional[Record]:
+         # Note: when we create PK indexes on our sqlite tables, we enforce a UNIQUE constraint, so
+         # if the build succeeds then we're guaranteed 0 or 1 results here
+         result = self.sqlite_index_query(clazz, query, args)
+         return result[0] if result else None
+
+     @ty.overload
+     def sqlite_bulk_query(
+         self,
+         clazz: ty.Callable[..., Record],
+         query: str,
+         args: ty.Collection[ty.Tuple],
+         single_col: ty.Literal[False],
+     ) -> ty.Iterator[Record]: ...
+
+     @ty.overload
+     def sqlite_bulk_query(
+         self,
+         clazz: ty.Callable[..., Record],
+         query: str,
+         args: ty.Collection,
+         single_col: ty.Literal[True],
+     ) -> ty.Iterator[Record]: ...
+
+     def sqlite_bulk_query(
+         self, clazz: ty.Callable[..., Record], query: str, args: ty.Collection, single_col: bool
+     ) -> ty.Iterator[Record]:
+         """Note: this method is intentionally left un-cached; it makes a tradeoff: it minimizes the
+         number of disk accesses and calls into sqlite at the cost of potentially re-loading the
+         same records multiple times when multiple calls pass overlapping keys. Since it isn't
+         cached, it can also be lazily evaluated as an iterator. Callers are encouraged to take
+         advantage of this laziness where it may be useful."""
+         if single_col:
+             args_ = args if isinstance(args, (list, tuple)) else list(args)
+         else:
+             args_ = list(itertools.chain.from_iterable(args))
+         cursor = self._sqlite_con.execute(query, args_)
+         for row in cursor:
+             yield clazz(*row)
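A usage sketch of the query interface, reusing the hypothetical `Widget` record from above:

    db = AttrsSQLiteDatabase("my_pkg.data", "tables.sqlite")
    one = db.sqlite_pk_query(Widget, "SELECT * FROM widgets WHERE id = ?", (1,))
    many = db.sqlite_bulk_query(
        Widget, "SELECT * FROM widgets WHERE id IN (?, ?, ?)", [1, 2, 3], single_col=True
    )
    for widget in many:  # rows are materialized lazily, one at a time
        print(widget)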
+
+
+ # SQL pre/post processing
+
+
+ def load_date(datestr: Optional[bytes]) -> Optional[datetime.date]:
+     return None if datestr is None else datetime.datetime.fromisoformat(datestr.decode()).date()
+
+
+ def load_datetime(datestr: Optional[bytes]) -> Optional[datetime.datetime]:
+     return None if datestr is None else datetime.datetime.fromisoformat(datestr.decode())
+
+
+ sqlite3.register_converter(DType.BOOL.sqlite, lambda b: bool(int(b)))
+ sqlite3.register_converter(DType.DATE.sqlite, load_date)
+ sqlite3.register_converter(DType.DATETIME.sqlite, load_datetime)
+ sqlite3.register_adapter(datetime.date, datetime.date.isoformat)
+ sqlite3.register_adapter(datetime.datetime, datetime.datetime.isoformat)
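These module-level registrations pair with the `detect_types=sqlite3.PARSE_DECLTYPES` flag passed in `sqlite_connection`: values read from columns whose declared type matches a registered converter name are converted back to Python objects, while `date`/`datetime` parameters are adapted to ISO strings on write. A round-trip sketch, assuming `DType.DATE.sqlite` names the DATE column type:

    import datetime

    con = sqlite_connection("tables.sqlite")
    con.execute("CREATE TABLE IF NOT EXISTS events (day DATE)")
    con.execute("INSERT INTO events VALUES (?)", (datetime.date(2021, 6, 1),))
    (day,) = con.execute("SELECT day FROM events").fetchone()
    assert day == datetime.date(2021, 6, 1)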