thds.core 0.0.1__py3-none-any.whl → 1.31.20250116223856__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of thds.core might be problematic.

Files changed (70)
  1. thds/core/__init__.py +48 -0
  2. thds/core/ansi_esc.py +46 -0
  3. thds/core/cache.py +201 -0
  4. thds/core/calgitver.py +82 -0
  5. thds/core/concurrency.py +100 -0
  6. thds/core/config.py +250 -0
  7. thds/core/decos.py +55 -0
  8. thds/core/dict_utils.py +188 -0
  9. thds/core/env.py +40 -0
  10. thds/core/exit_after.py +121 -0
  11. thds/core/files.py +125 -0
  12. thds/core/fretry.py +115 -0
  13. thds/core/generators.py +56 -0
  14. thds/core/git.py +81 -0
  15. thds/core/hash_cache.py +86 -0
  16. thds/core/hashing.py +106 -0
  17. thds/core/home.py +15 -0
  18. thds/core/hostname.py +10 -0
  19. thds/core/imports.py +17 -0
  20. thds/core/inspect.py +58 -0
  21. thds/core/iterators.py +9 -0
  22. thds/core/lazy.py +83 -0
  23. thds/core/link.py +153 -0
  24. thds/core/log/__init__.py +29 -0
  25. thds/core/log/basic_config.py +171 -0
  26. thds/core/log/json_formatter.py +43 -0
  27. thds/core/log/kw_formatter.py +84 -0
  28. thds/core/log/kw_logger.py +93 -0
  29. thds/core/log/logfmt.py +302 -0
  30. thds/core/merge_args.py +168 -0
  31. thds/core/meta.json +8 -0
  32. thds/core/meta.py +518 -0
  33. thds/core/parallel.py +200 -0
  34. thds/core/pickle_visit.py +24 -0
  35. thds/core/prof.py +276 -0
  36. thds/core/progress.py +112 -0
  37. thds/core/protocols.py +17 -0
  38. thds/core/py.typed +0 -0
  39. thds/core/scaling.py +39 -0
  40. thds/core/scope.py +199 -0
  41. thds/core/source.py +238 -0
  42. thds/core/source_serde.py +104 -0
  43. thds/core/sqlite/__init__.py +21 -0
  44. thds/core/sqlite/connect.py +33 -0
  45. thds/core/sqlite/copy.py +35 -0
  46. thds/core/sqlite/ddl.py +4 -0
  47. thds/core/sqlite/functions.py +63 -0
  48. thds/core/sqlite/index.py +22 -0
  49. thds/core/sqlite/insert_utils.py +23 -0
  50. thds/core/sqlite/merge.py +84 -0
  51. thds/core/sqlite/meta.py +190 -0
  52. thds/core/sqlite/read.py +66 -0
  53. thds/core/sqlite/sqlmap.py +179 -0
  54. thds/core/sqlite/structured.py +138 -0
  55. thds/core/sqlite/types.py +64 -0
  56. thds/core/sqlite/upsert.py +139 -0
  57. thds/core/sqlite/write.py +99 -0
  58. thds/core/stack_context.py +41 -0
  59. thds/core/thunks.py +40 -0
  60. thds/core/timer.py +214 -0
  61. thds/core/tmp.py +85 -0
  62. thds/core/types.py +4 -0
  63. thds.core-1.31.20250116223856.dist-info/METADATA +68 -0
  64. thds.core-1.31.20250116223856.dist-info/RECORD +67 -0
  65. {thds.core-0.0.1.dist-info → thds.core-1.31.20250116223856.dist-info}/WHEEL +1 -1
  66. thds.core-1.31.20250116223856.dist-info/entry_points.txt +4 -0
  67. thds.core-1.31.20250116223856.dist-info/top_level.txt +1 -0
  68. thds.core-0.0.1.dist-info/METADATA +0 -8
  69. thds.core-0.0.1.dist-info/RECORD +0 -4
  70. thds.core-0.0.1.dist-info/top_level.txt +0 -1
thds/core/sqlite/functions.py
@@ -0,0 +1,63 @@
+import hashlib
+import inspect
+import sqlite3
+import typing as ty
+
+
+def _slowish_hash_str_function(__string: str) -> int:
+    # Raymond Hettinger says that little-endian is slightly faster, though that was 2021.
+    # I have also tested this myself and found it to be true.
+    # https://bugs.python.org/msg401661
+    return int.from_bytes(hashlib.md5(__string.encode()).digest()[:7], byteorder="little")
+
+
+_THE_HASH_FUNCTION = _slowish_hash_str_function
+
+
+# If you need it to be faster, you can 'replace' this with your own implementation.
+def set_hash_function(f: ty.Callable[[str], int]) -> None:
+    global _THE_HASH_FUNCTION
+    _THE_HASH_FUNCTION = f
+
+
+def _pyhash_values(*args) -> int:
+    _args = (x if isinstance(x, str) else str(x) for x in args)
+    concatenated = "".join(_args)
+    hash_value = _THE_HASH_FUNCTION(concatenated)
+    return hash_value
+
+
+def _has_param_kind(signature, kind) -> bool:
+    return any(p.kind == kind for p in signature.parameters.values())
+
+
+def _num_parameters(f: ty.Callable) -> int:
+    signature = inspect.signature(f)
+    if _has_param_kind(signature, inspect.Parameter.VAR_KEYWORD):
+        raise NotImplementedError("**kwargs in sqlite functions is not supported")
+    elif _has_param_kind(signature, inspect.Parameter.VAR_POSITIONAL):
+        return -1
+    else:
+        return len(signature.parameters.keys())
+
+
+_FUNCTIONS = [_pyhash_values]
+
+
+def register_functions_on_connection(
+    conn: sqlite3.Connection,
+    *,
+    functions: ty.Collection[ty.Callable] = _FUNCTIONS,
+) -> sqlite3.Connection:
+    """By default registers our default functions.
+
+    Returns the connection itself, for chaining.
+    """
+    for f in functions:
+        narg = _num_parameters(f)
+        # SPOOKY: we're registering a function here with SQLite, and the SQLite name will
+        # be the same as its name in Python. Be very careful that you do not register two
+        # different functions with the same name - you can read their docs on what will
+        # happen, but it would be far better to just not ever do this.
+        conn.create_function(name=f.__name__, narg=narg, func=f, deterministic=True)
+    return conn
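For orientation, a minimal sketch of how this registration appears intended to be used. The in-memory database, table, and column names below are made up for illustration:

    import sqlite3

    from thds.core.sqlite.functions import register_functions_on_connection

    conn = register_functions_on_connection(sqlite3.connect(":memory:"))
    conn.execute("CREATE TABLE people (first TEXT, last TEXT)")
    conn.execute("INSERT INTO people VALUES ('Ada', 'Lovelace'), ('Alan', 'Turing')")

    # _pyhash_values is now callable from SQL; here it buckets rows into two partitions
    # by hashing the concatenated column values back in Python.
    evens = conn.execute(
        "SELECT first, last FROM people WHERE _pyhash_values(first, last) % 2 = 0"
    ).fetchall()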
thds/core/sqlite/index.py
@@ -0,0 +1,22 @@
+import typing as ty
+
+from .connect import Connectable, autoconn_scope, autoconnect
+
+
+@autoconn_scope.bound
+def create(connectable: Connectable, table_name: str, columns: ty.Collection[str], unique: bool = False):
+    """Create an index on a table in a SQLite database using only sqlite3 and SQL.
+
+    Is idempotent, but does not verify that your index DDL matches what you're asking for.
+    """
+    conn = autoconnect(connectable)
+    colnames = "_".join(colname for colname in columns).replace("-", "_")
+
+    sql_create_index = (
+        f"CREATE {'UNIQUE' if unique else ''} INDEX IF NOT EXISTS "
+        f"[{table_name}_{colnames}_idx] ON [{table_name}] ({', '.join(columns)})"
+    )
+    try:
+        conn.execute(sql_create_index)
+        # do not commit - let the caller decide when to commit, or allow autoconnect to do its job
+    except Exception:
+        print(sql_create_index)
+        raise
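A short usage sketch, assuming (as elsewhere in this package) that a Connectable may be a path to a database file; the file and table names below are hypothetical:

    from pathlib import Path

    from thds.core.sqlite import index

    db = Path("events.sqlite")  # hypothetical database containing an 'events' table
    index.create(db, "events", ["user_id", "created_at"])  # composite index, idempotent
    index.create(db, "events", ["event_id"], unique=True)  # unique index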
thds/core/sqlite/insert_utils.py
@@ -0,0 +1,23 @@
+# this works only if you have something to map the tuples to an attrs class.
+# it also does not currently offer any parallelism.
+import typing as ty
+
+from thds.core import progress
+from thds.core.sqlite import DbAndTable, TableSource, table_source
+
+T = ty.TypeVar("T")
+
+
+def tuples_to_table_source(
+    data: ty.Iterable[tuple],
+    upsert_many: ty.Callable[[ty.Iterable[T]], DbAndTable],
+    func: ty.Callable[[tuple], ty.Optional[T]],
+) -> TableSource:
+    """
+    Converts an iterable of tuples to a TableSource.
+    :param data: An iterable of tuples.
+    :param upsert_many: A function that accepts an iterable of T and returns a DbAndTable.
+    :param func: A function that processes each tuple and returns an optional T.
+    :return: The resulting table source after upserting data.
+    """
+    return table_source(*upsert_many(progress.report(filter(None, map(func, data)))))
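A speculative sketch of how this might be wired up. The upsert_people sink below is hypothetical, and DbAndTable is assumed to unpack as (database path, table name), based on how table_source is called with it above:

    import sqlite3
    from pathlib import Path

    from thds.core.sqlite.insert_utils import tuples_to_table_source

    DB = Path("people.sqlite")

    def upsert_people(records):
        # hypothetical sink: write (name, age) rows, then report where they landed
        conn = sqlite3.connect(DB)
        with conn:
            conn.execute("CREATE TABLE IF NOT EXISTS people (name TEXT PRIMARY KEY, age INT)")
            conn.executemany("INSERT OR REPLACE INTO people VALUES (?, ?)", records)
        conn.close()
        return DB, "people"

    raw = [("ada", 36), ("", 0), ("alan", 41)]
    # rows mapped to None (here: empty names) are dropped by filter(None, ...)
    src = tuples_to_table_source(raw, upsert_people, lambda t: t if t[0] else None)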
thds/core/sqlite/merge.py
@@ -0,0 +1,84 @@
+import os
+import typing as ty
+from contextlib import closing
+from pathlib import Path
+from sqlite3 import connect
+from timeit import default_timer
+
+from thds.core import log, types
+
+from .meta import get_indexes, get_tables, pydd
+
+logger = log.getLogger(__name__)
+
+
+def merge_databases(
+    filenames: ty.Iterable[types.StrOrPath],
+    table_names: ty.Collection[str] = tuple(),
+    *,
+    replace: bool = False,
+    copy_indexes: bool = True,
+) -> Path:
+    """Merges the listed tables, if found in the other filenames,
+    into the first database in the list. If no table names are listed,
+    we will instead merge _all_ tables found in _every_ database into the first.
+
+    If the table already exists in the destination, it will not be altered.
+    If the table does not, it will be created using the SQL CREATE TABLE statement
+    found in the first database where it is encountered.
+
+    This mutates the first database in the list! If you don't want it
+    mutated, make a copy of the file first, or start with an empty database file!
+
+    Allows SQL injection via the table names - don't use this on untrusted inputs.
+
+    By default, also copies indexes associated with the tables where an index with the
+    same name does not already exist in the destination table. You can disable this
+    wholesale and then create/copy the specific indexes that you want after the fact.
+    """
+    _or_replace_ = "OR REPLACE" if replace else ""
+    filenames = iter(filenames)
+    first_filename = next(filenames)
+    logger.info(f"Connecting to {first_filename}")
+    conn = connect(first_filename)
+    destination_tables = get_tables(conn)
+
+    merge_start = default_timer()
+    with closing(conn.cursor()) as cursor:
+        cursor.execute("pragma synchronous = off;")
+        cursor.execute("pragma journal_mode = off;")
+        for filename in filenames:
+            pydd(Path(filename))
+            start = default_timer()
+            to_merge = "to_merge"  # just a local/temporary constant alias
+            logger.info(f"Merging {filename} into {first_filename}")
+            cursor.execute(f"ATTACH '{os.fspath(filename)}' AS " + to_merge)
+
+            attached_tables = get_tables(conn, schema_name=to_merge)
+            cursor.execute("BEGIN")
+            for table_name in table_names or attached_tables.keys():
+                if table_name not in attached_tables:
+                    continue  # table doesn't exist in the database to merge, so we skip it!
+                if table_name not in destination_tables:
+                    cursor.execute(attached_tables[table_name])  # create the table in the destination
+                    destination_tables = get_tables(conn)  # refresh tables dict
+                cursor.execute(
+                    f"INSERT {_or_replace_} INTO {table_name} SELECT * FROM {to_merge}.[{table_name}]"
+                )
+                if copy_indexes:
+                    dest_indexes = get_indexes(table_name, conn)
+                    for idx_name, index_sql in get_indexes(
+                        table_name, conn, schema_name=to_merge
+                    ).items():
+                        if idx_name not in dest_indexes:
+                            cursor.execute(index_sql)
+            logger.info(
+                f"Committing merge of {filename} into {first_filename} after {default_timer() - start:.2f}s"
+            )
+            conn.commit()  # without a commit, DETACH DATABASE will error with Database is locked.
+            # https://stackoverflow.com/questions/56243770/sqlite3-detach-database-produces-database-is-locked-error
+            cursor.execute(f"DETACH DATABASE {to_merge}")
+
+    logger.info(f"Merge complete after {default_timer() - merge_start:.2f}s")
+    conn.close()
+    return Path(first_filename)
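A brief usage sketch; the shard paths are hypothetical and must already exist, and note again that the first file in the list is mutated in place:

    from pathlib import Path

    from thds.core.sqlite.merge import merge_databases

    shards = [Path("part-00.sqlite"), Path("part-01.sqlite"), Path("part-02.sqlite")]
    merged = merge_databases(shards)  # merge every table found in every shard into part-00
    # or, merge only the 'events' table and overwrite rows on conflicts:
    # merged = merge_databases(shards, ["events"], replace=True)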
thds/core/sqlite/meta.py
@@ -0,0 +1,190 @@
+"""Read-only utilities for inspecting a SQLite database.
+
+Note that none of these are safe from SQL injection - you should probably
+not be allowing users to specify tables in an ad-hoc fashion.
+"""
+
+import contextlib
+import os
+import sqlite3
+import typing as ty
+from pathlib import Path
+
+from thds.core import log
+from thds.core.source import from_file
+
+from .connect import autoconn_scope, autoconnect
+from .types import Connectable, TableSource
+
+logger = log.getLogger(__name__)
+
+
+def fullname(table_name: str, schema_name: str = "") -> str:
+    if schema_name:
+        return f"{schema_name}.[{table_name}]"
+    return f"[{table_name}]"
+
+
+@autoconn_scope.bound
+def list_tables(connectable: Connectable, schema_name: str = "") -> ty.List[str]:
+    conn = autoconnect(connectable)
+    return [
+        row[0]
+        for row in conn.execute(
+            f"SELECT name FROM {fullname('sqlite_master', schema_name)} WHERE type='table'"
+        )
+    ]
+
+
+@autoconn_scope.bound
+def get_tables(connectable: Connectable, *, schema_name: str = "main") -> ty.Dict[str, str]:
+    """Keys of the returned dict are the names of tables in the database.
+
+    Values of the returned dict are the raw SQL that can be used to recreate the table.
+    """
+    conn = autoconnect(connectable)
+    return {
+        row[0]: row[1]
+        for row in conn.execute(
+            f"""
+            SELECT name, sql
+            FROM {fullname('sqlite_master', schema_name)}
+            WHERE type = 'table'
+            AND sql is not null
+            """
+        )
+    }
+
+
+def pydd(path: os.PathLike):
+    """Sometimes running this on a big sqlite file before starting a
+    query will make a big difference to overall query performance.
+    """
+    with open(path, "rb") as f:
+        while f.read(1024 * 1024):
+            pass
+
+
+def table_name_from_path(db_path: Path) -> str:
+    tables = list_tables(db_path)
+    assert len(tables) == 1, f"Expected exactly one table, got {tables}"
+    return tables[0]
+
+
+def table_source(db_path: Path, table_name: str = "") -> TableSource:
+    if not table_name:
+        table_name = table_name_from_path(db_path)
+    return TableSource(from_file(db_path), table_name)
+
+
+@autoconn_scope.bound
+def primary_key_cols(table_name: str, connectable: Connectable) -> ty.Tuple[str, ...]:
+    conn = autoconnect(connectable)
+    return tuple(
+        row[0]
+        for row in conn.execute(
+            f"""
+            SELECT name
+            FROM pragma_table_info('{table_name}')
+            WHERE pk <> 0
+            """
+        )
+    )
+
+
+@autoconn_scope.bound
+def column_names(table_name: str, connectable: Connectable) -> ty.Tuple[str, ...]:
+    conn = autoconnect(connectable)
+    return tuple([row[1] for row in conn.execute(f"PRAGMA table_info({table_name})")])
+
+
+def preload_sources(*table_srcs: ty.Optional[TableSource]) -> None:
+    for table_src in table_srcs:
+        if not table_src:
+            continue
+        assert isinstance(table_src, TableSource)
+        logger.debug("Preloading %s from %s", table_src.table_name, table_src.db_src.uri)
+        pydd(table_src.db_src)
+    logger.debug("Preloading complete")
+
+
+@autoconn_scope.bound
+def get_indexes(
+    table_name: str, connectable: Connectable, *, schema_name: str = "main"
+) -> ty.Dict[str, str]:
+    """Keys of the returned dict are the names of indexes belonging to the given table.
+
+    Values of the returned dict are the raw SQL that can be used to recreate the index(es).
+    """
+    conn = autoconnect(connectable)
+    return {
+        row[0]: row[1]
+        for row in conn.execute(
+            f"""
+            SELECT name, sql
+            FROM {fullname('sqlite_master', schema_name)}
+            WHERE type = 'index' AND tbl_name = '{table_name}'
+            AND sql is not null
+            """
+        )
+        if row[1]
+        # sqlite has internal indexes, usually prefixed with "sqlite_autoindex_",
+        # that do not have sql statements associated with them.
+        # we exclude these because they are uninteresting to the average caller of this function.
+    }
+
+
+@contextlib.contextmanager
+@autoconn_scope.bound
+def debug_errors(connectable: Connectable) -> ty.Iterator:
+    try:
+        yield
+    except Exception:
+        try:
+            conn = autoconnect(connectable)
+            cur = conn.cursor()
+            cur.execute("PRAGMA database_list")
+            rows = cur.fetchall()
+            databases = {row[1]: dict(num=row[0], file=row[2]) for row in rows}
+            logger.error(f"Database info: {databases}")
+            for db_name in databases:
+                logger.error(
+                    f" In {db_name}, tables are: {get_tables(connectable, schema_name=db_name)}"
+                )
        except Exception:
+            logger.error(f"SQLite database: {connectable} is not introspectable")
+
+        raise
+
+
+@autoconn_scope.bound
+def get_table_schema(
+    conn: ty.Union[sqlite3.Connection, Connectable], table_name: str
+) -> ty.Dict[str, str]:
+    """
+    Retrieve the schema of a given table.
+
+    Args:
+        conn: The database connection object or a Connectable.
+        table_name: The name of the table.
+
+    Returns: A dictionary with column names as keys and their types as values.
+    """
+    # Ensure we have a connection object
+    connection = autoconnect(conn)
+
+    # Fetch the table schema
+    cursor = connection.cursor()
+    cursor.execute(f"PRAGMA table_info('{table_name}')")
+    schema = {row[1]: row[2].lower() for row in cursor.fetchall()}
+    return schema
+
+
+@autoconn_scope.bound
+def attach(connectable: Connectable, db_path: os.PathLike, schema_name: str) -> None:
+    """ATTACH a database to the current connection, using your provided schema name.
+
+    It must be an actual file.
+    """
+    conn = autoconnect(connectable)
+    conn.execute(f"ATTACH DATABASE '{os.fspath(db_path)}' AS {schema_name}")
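A small sketch of the inspection helpers, assuming (as table_name_from_path does above) that a Connectable may simply be a Path to a database file; 'people.sqlite' and its schema are made up, and the commented results show the shape of the return values rather than real output:

    from pathlib import Path

    from thds.core.sqlite import meta

    db = Path("people.sqlite")
    meta.list_tables(db)                 # e.g. ['people']
    meta.get_tables(db)                  # {'people': 'CREATE TABLE people (...)'}
    meta.get_table_schema(db, "people")  # {'name': 'text', 'age': 'int'}
    meta.primary_key_cols("people", db)  # ('name',)
    meta.get_indexes("people", db)       # {'people_name_idx': 'CREATE INDEX ...'}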
thds/core/sqlite/read.py
@@ -0,0 +1,66 @@
+import typing as ty
+from sqlite3 import Connection, Row
+
+
+def matching_where(to_match_colnames: ty.Iterable[str]) -> str:
+    """Creates a where clause for these column names, with named @{col} placeholders for each column."""
+    qs = " AND ".join(f"{k} = @{k}" for k in to_match_colnames)
+    return f"WHERE {qs}" if qs else ""
+
+
+def matching_select(
+    table_name: str,
+    conn: Connection,
+    to_match: ty.Mapping[str, ty.Any],
+    columns: ty.Sequence[str] = tuple(),
+) -> ty.Iterator[ty.Mapping[str, ty.Any]]:
+    """Get a single row from a table by key.
+
+    This is susceptible to SQL injection because the keys are
+    formatted directly. Do _not_ give external users the ability to
+    call this function directly and specify any of its keys.
+    """
+    cols = ", ".join(columns) if columns else "*"
+
+    qs = " AND ".join(f"{k} = ?" for k in to_match.keys())
+    where = f"WHERE {qs}" if qs else ""
+    # because we control the whole query, we're matching on the 'dumb' ? placeholder.
+
+    old_row_factory = conn.row_factory
+    conn.row_factory = Row  # this is an optimized approach to getting 'mappings' (with key names)
+    for row in conn.execute(f"SELECT {cols} FROM {table_name} {where}", tuple(to_match.values())):
+        yield row
+    conn.row_factory = old_row_factory
+
+
+matching = matching_select  # alias
+
+
+def partition(
+    n: int, i: int, columns: ty.Optional[ty.Union[str, ty.Collection[str]]] = None
+) -> ty.Dict[str, int]:
+    """Can (surprisingly) be used directly with matching().
+
+    i should be zero-indexed, whereas N is a count (natural number).
+
+    columns is an optional parameter to specify column(s) to partition on;
+    if no columns are specified, partitioning will be based on rowid.
+    """
+    assert 0 <= i < n
+    if not columns:
+        return {f"rowid % {n}": i}
+    hash_cols = columns if isinstance(columns, str) else ", ".join(columns)
+    # when SQLite uses this in a WHERE clause, as "hash(foo, bar) % 5 = 3",
+    # it hashes the _values_ in the row for those columns.
+    # Note that if we do have to fall back to this, the partitioning will be
+    # quite a bit slower because of the necessity of calling back into Python.
+    return {f"_pyhash_values({hash_cols}) % {n}": i}
+
+
+def maybe(
+    table_name: str, conn: Connection, to_match: ty.Mapping[str, ty.Any]
+) -> ty.Optional[ty.Mapping[str, ty.Any]]:
+    """Get a single row, if it exists, from a table by key"""
+    results = list(matching(table_name, conn, to_match))
+    assert len(results) == 0 or len(results) == 1
+    return results[0] if results else None
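These helpers operate on a plain sqlite3 connection, so a self-contained sketch is straightforward (the table and rows are invented for illustration):

    import sqlite3

    from thds.core.sqlite import read

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE people (name TEXT PRIMARY KEY, age INT)")
    conn.executemany("INSERT INTO people VALUES (?, ?)", [("ada", 36), ("alan", 41)])

    row = read.maybe("people", conn, {"name": "ada"})
    print(dict(row) if row else None)  # {'name': 'ada', 'age': 36}

    # partition() builds a mapping usable directly with matching(); with no columns
    # given it buckets on rowid, so no Python hash function needs to be registered.
    for r in read.matching("people", conn, read.partition(2, 0)):
        print(dict(r))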
thds/core/sqlite/sqlmap.py
@@ -0,0 +1,179 @@
+import inspect
+import shutil
+import typing as ty
+from collections import defaultdict
+from concurrent.futures import ProcessPoolExecutor
+from functools import partial
+from pathlib import Path
+from uuid import uuid4
+
+from thds.core import log, parallel, scope, thunks, tmp, types
+
+from .merge import merge_databases
+
+logger = log.getLogger(__name__)
+_tmpdir_scope = scope.Scope()
+
+
+class Partition(ty.NamedTuple):
+    partition: int  # 0 indexed
+    of: int  # count (1 indexed)
+
+
+ALL = Partition(0, 1)
+
+
+def _name(callable: ty.Callable) -> str:
+    if hasattr(callable, "func"):
+        return callable.func.__name__
+    return callable.__name__
+
+
+def _write_partition(
+    basename: str, writer: ty.Callable[[Partition, Path], ty.Any], base_dir: Path, partition: Partition
+) -> Path:
+    part_dir = base_dir / f"-{basename}-{partition.partition:02d}of{partition.of:02d}"
+    part_dir.mkdir(exist_ok=True, parents=True)
+    with log.logger_context(p=f"{partition.partition + 1:2d}/{partition.of:2d}"):
+        logger.info(f"Writing partition '{partition}' outputs into {part_dir}")
+        try:
+            writer(partition, part_dir)
+        except Exception:
+            logger.exception(f"Failed to write partition '{partition}' outputs into {part_dir}")
+            raise
+        logger.info(f"Finished writing partition {partition} outputs into {part_dir}")
+    return part_dir
+
+
+Merger = ty.Callable[[ty.Iterable[types.StrOrPath]], Path]
+
+
+def merge_sqlite_dirs(
+    merger: ty.Optional[Merger],
+    part_dirs: ty.Iterable[Path],
+    output_dir: Path,
+    max_cores: int = 2,
+) -> ty.Dict[str, Path]:
+    """Return a dictionary where the keys are the filenames, and the values are the Paths
+    of merged SQLite databases.
+
+    Any file found in any of the input part directories is assumed to be a SQLite
+    database, and will be merged, using `merger`, into all the other SQLite databases that
+    _bear the same name_ in any of the other directories.
+
+    Each final, merged SQLite database will then be _moved_ into the output_dir provided.
+
+    max_cores is the maximum number of _databases_ to merge in parallel;
+    since SQLite is doing almost all of the work, we don't imagine that we'd be able to get
+    much speedup by merging multiple databases using the same core. This has not been benchmarked.
+    """
+    _ensure_output_dir(output_dir)
+    sqlite_dbs_by_filename: ty.Dict[str, ty.List[Path]] = defaultdict(list)
+    for partition_dir in part_dirs:
+        if not partition_dir.exists():
+            continue  # a partition writer is allowed to write out nothing to a partition
+        if not partition_dir.is_dir():
+            # this may happen if people don't read the parallel_to_sqlite docstring and
+            # assume the Path is meant to be a file.
+            raise ValueError(
+                f"Partition directory {partition_dir} is not a directory!"
+                " Your code may have written directly to the provided Path as a file,"
+                " rather than writing SQLite database files into the directory as required."
+            )
+        for sqlite_db_path in partition_dir.iterdir():
+            if sqlite_db_path.is_file():
+                sqlite_dbs_by_filename[sqlite_db_path.name].append(sqlite_db_path)
+
+    thunking_merger = thunks.thunking(merger or _default_merge_databases)
+    for merged_db in parallel.yield_results(
+        [
+            thunking_merger(sqlite_db_paths)
+            for filename, sqlite_db_paths in sqlite_dbs_by_filename.items()
+        ],
+        # SQLite merge is CPU-intensive, so we use a Process Pool.
+        executor_cm=ProcessPoolExecutor(max_workers=max(min(max_cores, len(sqlite_dbs_by_filename)), 1)),
+    ):
+        logger.info(f"Moving merged database {merged_db} into {output_dir}")
+        shutil.move(str(merged_db), output_dir)
+
+    return {filename: output_dir / filename for filename in sqlite_dbs_by_filename}
+
+
+def _ensure_output_dir(output_directory: Path):
+    if output_directory.exists():
+        if not output_directory.is_dir():
+            raise ValueError("Output path must be a directory if it exists!")
+    else:
+        output_directory.mkdir(parents=True, exist_ok=True)
+    assert output_directory.is_dir()
+
+
+_default_merge_databases: Merger = partial(
+    merge_databases,
+    **{
+        k: v.default
+        for k, v in inspect.signature(merge_databases).parameters.items()
+        if v.default != inspect._empty
+    },
+)
+# TODO - feels like this suggests creating a general utility that creates partials where all the defaults are applied
+# the typing seems a bit tricky though
+
+
+@_tmpdir_scope.bound
+def partitions_to_sqlite(
+    partition_writer: ty.Callable[[Partition, Path], ty.Any],
+    output_directory: Path,
+    partitions: ty.Sequence[Partition],
+    *,
+    custom_merger: ty.Optional[Merger] = None,
+    max_workers: int = 0,
+) -> ty.Dict[str, Path]:
+    """By default, will use one Process worker per partition provided."""
+    temp_dir = _tmpdir_scope.enter(tmp.tempdir_same_fs(output_directory))
+
+    part_directories = list(
+        parallel.yield_results(
+            [
+                thunks.thunking(_write_partition)(
+                    _name(partition_writer) + uuid4().hex[:20],
+                    partition_writer,
+                    temp_dir,
+                    partition,
+                )
+                for partition in partitions
+            ],
+            executor_cm=ProcessPoolExecutor(max_workers=max_workers or len(partitions)),
+            # executor_cm=contextlib.nullcontext(loky.get_reusable_executor(max_workers=N)),
+        )
+    )
+    return merge_sqlite_dirs(
+        custom_merger if custom_merger is not None else _default_merge_databases,
+        part_directories,
+        output_directory,
+        max_cores=max_workers,
+    )
+
+
+def parallel_to_sqlite(
+    partition_writer: ty.Callable[[Partition, Path], ty.Any],
+    output_directory: Path,
+    N: int = 8,
+    custom_merger: ty.Optional[Merger] = None,
+) -> ty.Dict[str, Path]:
+    """The partition_writer will be provided a partition number and a directory (as a Path).
+
+    It must write one or more sqlite databases to the directory provided, using (of
+    course) the given partition to filter/query its input.
+
+    Any files found in that output directory will be assumed to be a SQLite database, and
+    any _matching_ filenames across the set of partitions that are written will be merged
+    into each other. Therefore, there will be a single database in the output directory
+    for every unique filename found in any of the partition directories after writing.
+    """
+    return partitions_to_sqlite(
+        partition_writer,
+        output_directory,
+        [Partition(i, N) for i in range(N)],
+        custom_merger=custom_merger,
+    )
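Finally, a sketch of the fan-out/merge flow described in the parallel_to_sqlite docstring. write_numbers is a hypothetical partition writer, and the __main__ guard matters because the work is dispatched to a ProcessPoolExecutor:

    import sqlite3
    from pathlib import Path

    from thds.core.sqlite.sqlmap import Partition, parallel_to_sqlite

    def write_numbers(partition: Partition, out_dir: Path) -> None:
        # each worker writes only the slice of 0..999 belonging to its partition
        conn = sqlite3.connect(out_dir / "numbers.sqlite")
        conn.execute("CREATE TABLE numbers (n INTEGER PRIMARY KEY)")
        conn.executemany(
            "INSERT INTO numbers VALUES (?)",
            [(n,) for n in range(1000) if n % partition.of == partition.partition],
        )
        conn.commit()
        conn.close()

    if __name__ == "__main__":
        results = parallel_to_sqlite(write_numbers, Path("out"), N=4)
        # expected shape: {'numbers.sqlite': Path('out/numbers.sqlite')}, with all partitions merged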