sawnergy 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sawnergy might be problematic. Click here for more details.
- sawnergy/__init__.py +13 -0
- sawnergy/embedding/SGNS_pml.py +135 -0
- sawnergy/embedding/SGNS_torch.py +177 -0
- sawnergy/embedding/__init__.py +34 -0
- sawnergy/embedding/embedder.py +578 -0
- sawnergy/logging_util.py +54 -0
- sawnergy/rin/__init__.py +9 -0
- sawnergy/rin/rin_builder.py +936 -0
- sawnergy/rin/rin_util.py +391 -0
- sawnergy/sawnergy_util.py +1182 -0
- sawnergy/visual/__init__.py +42 -0
- sawnergy/visual/visualizer.py +690 -0
- sawnergy/visual/visualizer_util.py +387 -0
- sawnergy/walks/__init__.py +16 -0
- sawnergy/walks/walker.py +795 -0
- sawnergy/walks/walker_util.py +384 -0
- sawnergy-1.0.0.dist-info/METADATA +290 -0
- sawnergy-1.0.0.dist-info/RECORD +22 -0
- sawnergy-1.0.0.dist-info/WHEEL +5 -0
- sawnergy-1.0.0.dist-info/licenses/LICENSE +201 -0
- sawnergy-1.0.0.dist-info/licenses/NOTICE +4 -0
- sawnergy-1.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1182 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
# third-party
|
|
4
|
+
import zarr
|
|
5
|
+
from zarr.storage import LocalStore, ZipStore
|
|
6
|
+
from zarr.codecs import BloscCodec, BloscShuffle, BloscCname
|
|
7
|
+
import numpy as np
|
|
8
|
+
# built-in
|
|
9
|
+
import re
|
|
10
|
+
import logging
|
|
11
|
+
from collections.abc import Sequence
|
|
12
|
+
from math import ceil
|
|
13
|
+
from datetime import datetime, date
|
|
14
|
+
import multiprocessing as mp
|
|
15
|
+
from concurrent.futures import as_completed, ThreadPoolExecutor, ProcessPoolExecutor
|
|
16
|
+
from typing import Callable, Iterable, Iterator, Any
|
|
17
|
+
import os, psutil, tempfile
|
|
18
|
+
from contextlib import contextmanager
|
|
19
|
+
from pathlib import Path
|
|
20
|
+
import warnings
|
|
21
|
+
|
|
22
|
+
# *----------------------------------------------------*
# GLOBALS
# *----------------------------------------------------*

# Module-level logger; handlers/levels are left to the host application.
_logger = logging.getLogger(__name__)

# *----------------------------------------------------*
# CLASSES
# *----------------------------------------------------*
|
|
31
|
+
|
|
32
|
+
class ArrayStorage:
|
|
33
|
+
"""A single-root-group Zarr v3 container with multiple arrays and metadata.
|
|
34
|
+
|
|
35
|
+
This wraps a root Zarr **group** (backed by a LocalStore `<name>.zarr`
|
|
36
|
+
or a read-only ZipStore `<name>.zip`). Each logical "block" is a Zarr
|
|
37
|
+
array with shape ``(N, *item_shape)`` where axis 0 is append-only.
|
|
38
|
+
Per-block metadata (chunk length, item shape, dtype) is kept in group attrs.
|
|
39
|
+
"""
|
|
40
|
+
    def __init__(self, pth: Path | str, mode: str) -> None:
        """Initialize the storage and ensure a root group exists.

        Args:
            pth: Base path. If it ends with ``.zip`` a read-only ZipStore is used;
                otherwise a LocalStore at ``<pth>.zarr`` is used.
            mode: Zarr open mode. For ZipStore this must be ``"r"``.
                For LocalStore, an existing store is opened with this mode; if
                missing, a new root group is created.

        Raises:
            ValueError: If `pth` type is invalid or ZipStore mode is not ``"r"``.
            FileNotFoundError: If a ZipStore was requested but the file is missing.
            TypeError: If the root object is an array instead of a group.
        """
        _logger.info("ArrayStorage init: pth=%s mode=%s", pth, mode)

        if not isinstance(pth, (str, Path)):
            _logger.error("Invalid 'pth' type: %s", type(pth))
            raise ValueError(f"Expected 'str' or 'Path' for 'pth'; got: {type(pth)}")

        p = Path(pth)
        self.mode = mode

        # store backend
        if p.suffix == ".zip":
            # ZipStore is read-only for safety (no overwrite semantics)
            self.store_path = p.resolve()
            _logger.info("Using ZipStore backend at %s", self.store_path)
            if mode != "r":
                _logger.error("Attempted to open ZipStore with non-read mode: %s", mode)
                raise ValueError("ZipStore must be opened read-only (mode='r').")
            if not self.store_path.exists():
                _logger.error("ZipStore path does not exist: %s", self.store_path)
                raise FileNotFoundError(f"No ZipStore at: {self.store_path}")
            self.store = ZipStore(self.store_path, mode="r")
        else:
            # local directory store at <pth>.zarr
            self.store_path = p.with_suffix(".zarr").resolve()
            _logger.info("Using LocalStore backend at %s", self.store_path)
            self.store = LocalStore(self.store_path)

        # open existing or create new root group
        try:
            # try to open the store
            _logger.info("Opening store at %s with mode=%s", self.store_path, mode)
            self.root = zarr.open(self.store, mode=mode)
            # the root must be a group. if it's not -- schema error then
            if not isinstance(self.root, zarr.Group):
                _logger.error("Root is not a group at %s", self.store_path)
                raise TypeError(f"Root at {self.store_path} must be a group.")
        except Exception:
            # NOTE(review): this broad except also catches the TypeError raised
            # just above, so for a writable LocalStore whose root is an *array*
            # the code falls through to group creation instead of propagating
            # the schema error — confirm this is intended.
            # if we can't open:
            # for ZipStore or read-only modes, we must not create, so re-raise
            if isinstance(self.store, ZipStore) or mode == "r":
                _logger.exception("Failed to open store in read-only context; re-raising")
                raise
            # otherwise, create a new group
            _logger.info("Creating new root group at %s", self.store_path)
            self.root = zarr.group(store=self.store, mode="a")

        # metadata attrs (JSON-safe); setdefault seeds the per-block registries
        # used by the _array_*_in_block helpers below.
        self._attrs = self.root.attrs
        self._attrs.setdefault("array_chunk_size_in_block", {})
        self._attrs.setdefault("array_shape_in_block", {})
        self._attrs.setdefault("array_dtype_in_block", {})
        _logger.debug("Metadata attrs initialized: keys=%s", list(self._attrs.keys()))
|
|
107
|
+
|
|
108
|
+
def close(self) -> None:
|
|
109
|
+
"""Close the underlying store if it supports closing."""
|
|
110
|
+
try:
|
|
111
|
+
if hasattr(self, "store") and hasattr(self.store, "close"):
|
|
112
|
+
self.store.close()
|
|
113
|
+
except Exception as e:
|
|
114
|
+
_logger.warning("Ignoring error while closing store: %s", e)
|
|
115
|
+
|
|
116
|
+
    def __enter__(self):
        """Enter the context manager; yields the storage itself."""
        return self

    def __exit__(self, exc_type, exc, tb):
        """Close the store on context exit; exceptions are not suppressed
        (implicit ``None`` return)."""
        self.close()

    def __del__(self):
        # Best-effort close during garbage collection; errors are swallowed
        # because interpreter shutdown may already have torn down state.
        try:
            self.close()
        except Exception:
            pass
|
|
127
|
+
|
|
128
|
+
# --------- PRIVATE ----------
|
|
129
|
+
|
|
130
|
+
    def _array_chunk_size_in_block(self, named: str, *, given: int | None) -> int:
        """Resolve per-block chunk length along axis 0; set default if unset.

        Args:
            named: Block name the chunk length belongs to.
            given: Caller-supplied chunk length, or ``None`` to accept the
                cached value (or the default of 10 when nothing is cached).

        Returns:
            The chunk length to use for this block.

        Raises:
            ValueError: If `given` is non-positive.
            RuntimeError: If `given` conflicts with the previously cached value.
        """
        apc = self._attrs["array_chunk_size_in_block"]
        cached = apc.get(named)
        if cached is None:
            if given is None:
                # No cached value and none supplied: fall back to 10 loudly.
                apc[named] = 10
                _logger.warning(
                    "array_chunk_size_in_block not provided for '%s'; defaulting to 10", named
                )
                warnings.warn(
                    f"You never set 'array_chunk_size_in_block' for block '{named}'. "
                    f"Defaulting to 10 — may be suboptimal for your RAM and array size.",
                    RuntimeWarning,
                    stacklevel=2,
                )
            else:
                if given <= 0:
                    _logger.error("Non-positive arrays_per_chunk for block '%s': %s", named, given)
                    raise ValueError("'array_chunk_size_in_block' must be positive")
                apc[named] = int(given)
            # Re-assign the whole mapping so the update is persisted to group
            # attrs (in-place mutation of the attrs proxy may not be flushed).
            self._attrs["array_chunk_size_in_block"] = apc
            _logger.debug("Set arrays_per_chunk for '%s' to %s", named, apc[named])
            return apc[named]

        if given is None:
            return int(cached)

        if int(cached) != int(given):
            _logger.error(
                "array_chunk_size_in_block mismatch for '%s': cached=%s, given=%s",
                named, cached, given
            )
            raise RuntimeError(
                "The specified 'array_chunk_size_in_block' does not match the value used "
                f"when the block was initialized: {named}.array_chunk_size_in_block is {cached}, "
                f"but {given} was provided."
            )
        return int(cached)
|
|
169
|
+
|
|
170
|
+
def _array_shape_in_block(self, named: str, *, given: tuple[int, ...]) -> tuple[int, ...]:
|
|
171
|
+
"""Resolve per-item shape for a block; enforce consistency if already set."""
|
|
172
|
+
shp = self._attrs["array_shape_in_block"]
|
|
173
|
+
cached = shp.get(named)
|
|
174
|
+
if cached is None:
|
|
175
|
+
shp[named] = list(map(int, given))
|
|
176
|
+
self._attrs["array_shape_in_block"] = shp
|
|
177
|
+
_logger.debug("Set shape for '%s' to %s", named, shp[named])
|
|
178
|
+
return tuple(given)
|
|
179
|
+
|
|
180
|
+
cached_t = tuple(int(x) for x in cached)
|
|
181
|
+
if cached_t != tuple(given):
|
|
182
|
+
_logger.error(
|
|
183
|
+
"Shape mismatch for '%s': cached=%s, given=%s", named, cached_t, given
|
|
184
|
+
)
|
|
185
|
+
raise RuntimeError(
|
|
186
|
+
"The specified 'array_shape_in_block' does not match the value used "
|
|
187
|
+
f"when the block was initialized: {named}.array_shape_in_block is {cached_t}, "
|
|
188
|
+
f"but {given} was provided."
|
|
189
|
+
)
|
|
190
|
+
return cached_t
|
|
191
|
+
|
|
192
|
+
def _array_dtype_in_block(self, named: str, *, given: np.dtype) -> np.dtype:
|
|
193
|
+
"""Resolve dtype for a block; store/recover via dtype.str."""
|
|
194
|
+
dty = self._attrs["array_dtype_in_block"]
|
|
195
|
+
given = np.dtype(given)
|
|
196
|
+
cached = dty.get(named)
|
|
197
|
+
if cached is None:
|
|
198
|
+
dty[named] = given.str
|
|
199
|
+
self._attrs["array_dtype_in_block"] = dty
|
|
200
|
+
_logger.debug("Set dtype for '%s' to %s", named, dty[named])
|
|
201
|
+
return given
|
|
202
|
+
|
|
203
|
+
cached_dt = np.dtype(cached)
|
|
204
|
+
if cached_dt != given:
|
|
205
|
+
_logger.error(
|
|
206
|
+
"Dtype mismatch for '%s': cached=%s, given=%s", named, cached_dt, given
|
|
207
|
+
)
|
|
208
|
+
raise RuntimeError(
|
|
209
|
+
"The specified 'array_dtype_in_block' does not match the value used "
|
|
210
|
+
f"when the block was initialized: {named}.array_dtype_in_block is {cached_dt}, "
|
|
211
|
+
f"but {given} was provided."
|
|
212
|
+
)
|
|
213
|
+
return cached_dt
|
|
214
|
+
|
|
215
|
+
    def _setdefault(
        self,
        named: str,
        shape: tuple[int, ...],
        dtype: np.dtype,
        arrays_per_chunk: int | None = None,
    ) -> zarr.Array:
        """Create or open the block array with the resolved metadata.

        Args:
            named: Block (array) name inside the root group.
            shape: Per-item shape (everything after axis 0).
            dtype: Element dtype.
            arrays_per_chunk: Optional chunk length along axis 0.

        Returns:
            The existing or newly created Zarr array.

        Raises:
            TypeError: If an existing member is not an array or has an
                incompatible shape/dtype.
            ValueError / RuntimeError: Propagated from the metadata resolvers
                on invalid or conflicting values.
        """
        # Resolve (and persist on first use) the per-block metadata.
        shape = self._array_shape_in_block(named, given=shape)
        dtype = self._array_dtype_in_block(named, given=dtype)
        apc = self._array_chunk_size_in_block(named, given=arrays_per_chunk)

        # if it already exists, validate and return it
        if named in self.root:
            block = self.root[named]
            if not isinstance(block, zarr.Array):
                raise TypeError(f"Member '{named}' is not a Zarr array")
            # only the item shape (axes 1..n) must match; axis 0 grows freely
            if block.shape[1:] != shape:
                raise TypeError(f"Incompatible existing shape {block.shape} vs (0,{shape})")
            if np.dtype(block.dtype) != np.dtype(dtype):
                raise TypeError(f"Incompatible dtype {block.dtype} vs {dtype}")
            return block

        # otherwise, create the appendable array (length 0 along axis 0)
        _logger.debug("Creating array '%s' with shape=(0,%s), chunks=(%s,%s), dtype=%s",
                      named, shape, apc, shape, dtype)
        return self.root.create_array(
            name=named,
            shape=(0,) + shape,
            chunks=(int(apc),) + shape,
            dtype=dtype,
        )
|
|
247
|
+
|
|
248
|
+
# --------- PUBLIC ----------
|
|
249
|
+
|
|
250
|
+
    def write(
        self,
        these_arrays: Sequence[np.ndarray] | np.ndarray,
        to_block_named: str,
        *,
        arrays_per_chunk: int | None = None,
    ) -> None:
        """Append arrays to a block.

        Appends a batch of arrays (all the same shape and dtype) to the Zarr array
        named `to_block_named`. The array grows along axis 0; chunk length is
        resolved per-block and stored in group attrs.

        Args:
            these_arrays: A sequence of NumPy arrays **or** a stacked ndarray with
                shape `(k, *item_shape)`. If a generic iterable is provided, it will be
                consumed into a list. All items must share the same shape and dtype.
            to_block_named: Name of the target block (array) inside the root group.
            arrays_per_chunk: Optional chunk length along axis 0. If unset and the
                block is new, defaults to 10 with a warning.

        Raises:
            RuntimeError: If the storage is opened read-only.
            ValueError: If any array's shape or dtype differs from the first element.
        """
        if self.mode == "r":
            _logger.error("Write attempted in read-only mode")
            raise RuntimeError("Cannot write to a read-only ArrayStorage")

        # Normalize to something indexable (list/tuple/ndarray)
        if not isinstance(these_arrays, (list, tuple, np.ndarray)):
            these_arrays = list(these_arrays)

        if len(these_arrays) == 0:
            _logger.info("write() called with empty input for block '%s'; no-op", to_block_named)
            return

        # First item fixes the expected item shape and dtype for the batch.
        arr0 = np.asarray(these_arrays[0])
        _logger.info("Appending %d arrays to block '%s' (item_shape=%s, dtype=%s)",
                     len(these_arrays), to_block_named, arr0.shape, arr0.dtype)
        block = self._setdefault(
            to_block_named, tuple(arr0.shape), arr0.dtype, arrays_per_chunk
        )

        # quick validation
        for i, a in enumerate(these_arrays[1:], start=1):
            a = np.asarray(a)
            if a.shape != arr0.shape:
                _logger.error("Shape mismatch at index %d: %s != %s", i, a.shape, arr0.shape)
                raise ValueError(f"these_arrays[{i}] shape {a.shape} != {arr0.shape}")
            if np.dtype(a.dtype) != np.dtype(arr0.dtype):
                _logger.error("Dtype mismatch at index %d: %s != %s", i, a.dtype, arr0.dtype)
                raise ValueError(f"these_arrays[{i}] dtype {a.dtype} != {arr0.dtype}")

        # Stack into one (k, *item_shape) buffer, grow axis 0, then bulk-assign.
        data = np.asarray(these_arrays, dtype=block.dtype)
        k = data.shape[0]
        start = block.shape[0]
        block.resize((start + k,) + arr0.shape)
        block[start:start + k, ...] = data
        _logger.info("Appended %d rows to '%s'; new length=%d", k, to_block_named, block.shape[0])
|
|
310
|
+
|
|
311
|
+
    def read(
        self,
        from_block_named: str,
        ids: int | slice | tuple[int, ...] | None = None) -> np.ndarray:
        """Read rows from a block and return a NumPy array.

        Args:
            from_block_named: Name of the block (array) to read from.
            ids: Row indices to select along axis 0. May be one of:
                - ``None``: read the entire array;
                - ``int``: a single row;
                - ``slice``: a range of rows;
                - ``tuple[int]``: explicit row indices (order preserved).

        Returns:
            A NumPy array containing the selected data (a copy).

        Raises:
            KeyError: If the named block does not exist.
            TypeError: If the named member is not a Zarr array.
        """
        if from_block_named not in self.root:
            _logger.error("read(): block '%s' does not exist", from_block_named)
            raise KeyError(f"Block '{from_block_named}' does not exist")

        block = self.root[from_block_named]
        if not isinstance(block, zarr.Array):
            _logger.error("read(): member '%s' is not a Zarr array", from_block_named)
            raise TypeError(f"Member '{from_block_named}' is not a Zarr array")

        # log selection summary (type only to avoid huge logs)
        sel_type = type(ids).__name__ if ids is not None else "all"
        _logger.debug("Reading from '%s' with selection=%s", from_block_named, sel_type)

        if ids is None:
            out = block[:]
        elif isinstance(ids, (int, slice)):
            out = block[ids, ...]
        else:
            # explicit row list: orthogonal selection preserves the given order
            idx = np.asarray(ids, dtype=np.intp)
            out = block.get_orthogonal_selection((idx,) + (slice(None),) * (block.ndim - 1))

        return np.asarray(out, copy=True)
|
|
354
|
+
|
|
355
|
+
    def block_iter(
        self,
        from_block_named: str,
        *,
        step: int = 1) -> Iterator[np.ndarray]:
        """Iterate over a block in chunks along axis 0.

        Args:
            from_block_named: Name of the block (array) to iterate over.
            step: Number of rows per yielded chunk along axis 0.

        Yields:
            NumPy arrays of shape ``(m, *item_shape)`` where ``m <= step`` for the
            last chunk.

        Raises:
            KeyError: If the named block does not exist.
            TypeError: If the named member is not a Zarr array.
        """
        if from_block_named not in self.root:
            _logger.error("block_iter(): block '%s' does not exist", from_block_named)
            raise KeyError(f"Block '{from_block_named}' does not exist")

        block = self.root[from_block_named]
        if not isinstance(block, zarr.Array):
            _logger.error("block_iter(): member '%s' is not a Zarr array", from_block_named)
            raise TypeError(f"Member '{from_block_named}' is not a Zarr array")

        _logger.info("Iterating block '%s' with step=%d", from_block_named, step)

        if block.ndim == 0:
            # scalar array: yield its single value and stop
            yield np.asarray(block[...], copy=True)
            return

        for i in range(0, block.shape[0], step):
            j = min(i + step, block.shape[0])
            out = block[i:j, ...]
            yield np.asarray(out, copy=True)
|
|
394
|
+
|
|
395
|
+
def delete_block(self, named: str) -> None:
|
|
396
|
+
"""Delete a block and remove its metadata entries.
|
|
397
|
+
|
|
398
|
+
Args:
|
|
399
|
+
named: Block (array) name to delete.
|
|
400
|
+
|
|
401
|
+
Raises:
|
|
402
|
+
RuntimeError: If the storage is opened read-only.
|
|
403
|
+
KeyError: If the block does not exist.
|
|
404
|
+
"""
|
|
405
|
+
if self.mode == "r":
|
|
406
|
+
_logger.error("delete_block() attempted in read-only mode")
|
|
407
|
+
raise RuntimeError("Cannot delete blocks from a read-only ArrayStorage")
|
|
408
|
+
|
|
409
|
+
if named not in self.root:
|
|
410
|
+
_logger.error("delete_block(): block '%s' does not exist", named)
|
|
411
|
+
raise KeyError(f"Block '{named}' does not exist")
|
|
412
|
+
|
|
413
|
+
_logger.info("Deleting block '%s'", named)
|
|
414
|
+
del self.root[named]
|
|
415
|
+
|
|
416
|
+
for key in ("array_chunk_size_in_block", "array_shape_in_block", "array_dtype_in_block"):
|
|
417
|
+
d = dict(self._attrs.get(key, {}))
|
|
418
|
+
d.pop(named, None)
|
|
419
|
+
self._attrs[key] = d
|
|
420
|
+
_logger.debug("Removed metadata entries for '%s'", named)
|
|
421
|
+
|
|
422
|
+
def add_attr(self, key: str, val: Any) -> None:
|
|
423
|
+
"""
|
|
424
|
+
Attach JSON-serializable metadata to the root group's attributes.
|
|
425
|
+
|
|
426
|
+
Coerces common non-JSON types into JSON-safe forms before writing to
|
|
427
|
+
``self.root.attrs``:
|
|
428
|
+
* NumPy scalars → native Python scalars via ``.item()``
|
|
429
|
+
* NumPy arrays → Python lists via ``.tolist()``
|
|
430
|
+
* ``set``/``tuple`` → lists
|
|
431
|
+
* ``datetime.datetime``/``datetime.date`` → ISO 8601 strings via ``.isoformat()``
|
|
432
|
+
|
|
433
|
+
Args:
|
|
434
|
+
key (str): Attribute name to set on the root group.
|
|
435
|
+
val (Any): Value to store. If not JSON-serializable as provided, it will be
|
|
436
|
+
coerced using the rules above. Large blobs should not be stored as attrs.
|
|
437
|
+
|
|
438
|
+
Raises:
|
|
439
|
+
RuntimeError: If the storage was opened in read-only mode (``mode == "r"``).
|
|
440
|
+
TypeError: If the coerced value is still not JSON-serializable by Zarr.
|
|
441
|
+
|
|
442
|
+
Examples:
|
|
443
|
+
>>> store = ArrayStorage("/tmp/demo", mode="w")
|
|
444
|
+
>>> store.add_attr("experiment", "run_3")
|
|
445
|
+
>>> store.add_attr("created_at", datetime.utcnow())
|
|
446
|
+
>>> store.add_attr("means", np.arange(3, dtype=np.float32))
|
|
447
|
+
>>> store.get_attr["experiment"]
|
|
448
|
+
'run_3'
|
|
449
|
+
|
|
450
|
+
Note:
|
|
451
|
+
If you distribute consolidated metadata, re-consolidate after changing attrs
|
|
452
|
+
so external readers can see the updates.
|
|
453
|
+
"""
|
|
454
|
+
if self.mode == "r":
|
|
455
|
+
_logger.error("Write attempted in read-only mode")
|
|
456
|
+
raise RuntimeError("Cannot write to a read-only ArrayStorage")
|
|
457
|
+
|
|
458
|
+
# coerce to JSON-safe types Zarr accepts for attrs
|
|
459
|
+
def _to_json_safe(x):
|
|
460
|
+
if isinstance(x, np.generic):
|
|
461
|
+
return x.item()
|
|
462
|
+
if isinstance(x, np.ndarray):
|
|
463
|
+
return x.tolist()
|
|
464
|
+
if isinstance(x, (set, tuple)):
|
|
465
|
+
return list(x)
|
|
466
|
+
if isinstance(x, (datetime, date)):
|
|
467
|
+
return x.isoformat()
|
|
468
|
+
return x
|
|
469
|
+
|
|
470
|
+
js_val = _to_json_safe(val)
|
|
471
|
+
try:
|
|
472
|
+
self.root.attrs[key] = js_val
|
|
473
|
+
_logger.debug("Set root attr %r=%r", key, js_val)
|
|
474
|
+
except TypeError as e:
|
|
475
|
+
_logger.error("Value for attr %r is not JSON-serializable: %s", key, e)
|
|
476
|
+
raise
|
|
477
|
+
|
|
478
|
+
    def get_attr(self, key: str) -> Any:
        """Return a root attribute by key.

        Args:
            key: Attribute name.

        Returns:
            The stored value as-is (JSON-safe form, e.g., lists/ISO strings).

        Raises:
            KeyError: If the attribute does not exist.
        """
        try:
            val = self.root.attrs[key]
        except KeyError:
            # log for diagnostics, then let the original KeyError propagate
            _logger.error("get_attr: attribute %r not found", key)
            raise
        _logger.debug("get_attr: %r=%r", key, val)
        return val
|
|
497
|
+
|
|
498
|
+
    def compress(
        self,
        into: str | Path | None = None,
        *,
        compression_level: int,
    ) -> str:
        """Write a read-only ZipStore clone of the current store.

        Copies the single root group (its attrs and all child arrays with their
        attrs) into a new ``.zip`` file.

        Args:
            into: Optional destination. If a path ending with ``.zip``, that exact
                file is created/overwritten. If a directory, the zip is created there
                with ``<store>.zip``. If ``None``, uses ``<store>.zip`` next to the
                local store.
            compression_level: Blosc compression level to use for data chunks
                (integer, 0-9). ``0`` disables compression (still writes with a Blosc
                container); higher = more compression, slower writes.

        Returns:
            Path to the created ZipStore as a string.

        Notes:
            * If the backend is already a ZipStore, this is a no-op (path returned).
            * For Zarr v3, compressors are part of the *codecs pipeline*. Here we set
              a single compressor (Blosc with Zstd) and rely on defaults for the
              serializer; that's valid and interoperable.
        """
        if isinstance(self.store, ZipStore):
            _logger.info("compress(): already a ZipStore; returning current path")
            return str(self.store_path)

        # --- destination path resolution ---
        if into is None:
            zip_path = self.store_path.with_suffix(".zip")
        else:
            into = Path(into)
            if into.suffix.lower() == ".zip":
                zip_path = into.resolve()
            else:
                # treat `into` as a directory; reuse the local store's base name
                zip_path = (into / self.store_path.with_suffix(".zip").name).resolve()
        zip_path.parent.mkdir(parents=True, exist_ok=True)

        # --- compression level checks & logs ---
        try:
            clevel = int(compression_level)
        except Exception as e:
            _logger.error("Invalid compression_level=%r (%s)", compression_level, e)
            raise

        if not (0 <= clevel <= 9):
            _logger.error("compression_level out of range: %r (expected 0..9)", clevel)
            raise ValueError("compression_level must be in the range [0, 9]")

        if clevel == 0:
            _logger.warning("Compression disabled: compression_level=0")

        _logger.info("Compressing store to ZipStore at %s with Blosc(zstd, clevel=%d, shuffle=shuffle)",
                     zip_path, clevel)

        def _attrs_dict(attrs):
            # zarr attrs expose .asdict(); fall back to dict() for mapping-likes
            try:
                return attrs.asdict()
            except Exception:
                return dict(attrs)

        with ZipStore(zip_path, mode="w") as z:
            dst_root = zarr.group(store=z)

            # carry over root-group metadata first
            dst_root.attrs.update(_attrs_dict(self.root.attrs))

            copied = 0
            for key, src in self.root.arrays():

                dst = dst_root.create_array(
                    name=key,
                    shape=src.shape,
                    chunks=src.chunks,
                    dtype=src.dtype,
                    compressors=BloscCodec(
                        cname=BloscCname.zstd,
                        clevel=clevel,
                        shuffle=BloscShuffle.shuffle,
                    )
                )

                dst.attrs.update(_attrs_dict(src.attrs))
                # NOTE(review): `src[...]` materializes the whole array in RAM
                # before writing; for very large blocks a chunkwise copy would
                # bound peak memory — confirm sizes stay manageable.
                dst[...] = src[...]
                copied += 1
                _logger.debug("Compressed array '%s' shape=%s dtype=%s", key, src.shape, src.dtype)

        _logger.info("Compression complete: %d arrays -> %s", copied, zip_path)
        return str(zip_path)
|
|
592
|
+
|
|
593
|
+
    @classmethod
    @contextmanager
    def compress_and_cleanup(cls, output_pth: str | Path, compression_level: int) -> Iterator[ArrayStorage]:
        """
        Create a temporary ArrayStorage, yield it for writes, then compress it into `output_pth`.
        The temporary local store is deleted after compression.

        Args:
            output_pth: Destination .zip file or directory (delegated to `compress(into=...)`).
            compression_level: Blosc compression level to use for data chunks
                (integer, 0-9). ``0`` disables compression (still writes with a Blosc
                container); higher = more compression, slower writes.
        """
        output_pth = Path(output_pth)
        _logger.info("compress_and_cleanup: creating temp store (suffix .zarr)")
        with tempfile.TemporaryDirectory(suffix=".zarr") as tmp_dir:
            # tmp_dir already ends in .zarr, so ArrayStorage reuses it verbatim
            arr_storage = cls(tmp_dir, mode="w")
            try:
                yield arr_storage
            finally:
                # NOTE(review): compression runs in `finally`, so a partially
                # written store is compressed even when the `with` body raised —
                # confirm this best-effort persistence is intended.
                _logger.info("compress_and_cleanup: compressing to %s (compression level of %d)", output_pth, compression_level)
                arr_storage.compress(output_pth, compression_level=compression_level)
                arr_storage.close()
                _logger.info("compress_and_cleanup: temp store cleaned up")
|
|
617
|
+
|
|
618
|
+
# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= #
|
|
619
|
+
# PARALLEL PROCESSING AND EFFICIENT MEMORY USAGE RELATED FUNCTIONS
|
|
620
|
+
# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= #
|
|
621
|
+
|
|
622
|
+
def is_main_process() -> bool:
    """Return True when running in the top-level (non-spawned) main process."""
    has_no_parent = mp.parent_process() is None
    return has_no_parent and mp.current_process().name == "MainProcess"
|
|
625
|
+
|
|
626
|
+
def _apply(f: Callable, x: Any, extra_args: tuple, extra_kwargs: dict) -> Any:
|
|
627
|
+
return f(x, *extra_args, **extra_kwargs)
|
|
628
|
+
|
|
629
|
+
def elementwise_processor(
|
|
630
|
+
in_parallel: bool = False,
|
|
631
|
+
Executor: type[ThreadPoolExecutor] | type[ProcessPoolExecutor] | None = None,
|
|
632
|
+
max_workers: int | None = None,
|
|
633
|
+
capture_output: bool = True,
|
|
634
|
+
) -> Callable[[Iterable[Any], Callable[..., Any], Any], list[Any] | None]:
|
|
635
|
+
"""Factory that returns a function to process an iterable elementwise.
|
|
636
|
+
|
|
637
|
+
The returned callable executes a provided `function` over each element of an
|
|
638
|
+
`iterable`, either sequentially or in parallel using the specified
|
|
639
|
+
`Executor`. Results are optionally captured and returned as a list.
|
|
640
|
+
|
|
641
|
+
Args:
|
|
642
|
+
in_parallel: If True, process with a concurrent executor; otherwise run sequentially.
|
|
643
|
+
Executor: Executor class to use when `in_parallel` is True
|
|
644
|
+
(e.g., `ThreadPoolExecutor` or `ProcessPoolExecutor`). Ignored if `in_parallel` is False.
|
|
645
|
+
max_workers: Maximum parallel workers. Defaults to `os.cpu_count()` when None.
|
|
646
|
+
capture_output: If True, collect and return results; if False, execute for side effects and return None.
|
|
647
|
+
|
|
648
|
+
Returns:
|
|
649
|
+
A callable with signature:
|
|
650
|
+
`(iterable, function, *extra_args, **extra_kwargs) -> list | None`
|
|
651
|
+
When `capture_output` is True, the list preserves the input order.
|
|
652
|
+
|
|
653
|
+
Raises:
|
|
654
|
+
ValueError: If `in_parallel` is True and `Executor` is None.
|
|
655
|
+
Exception: Any exception raised by `function` for a given element is propagated.
|
|
656
|
+
|
|
657
|
+
Notes:
|
|
658
|
+
- In parallel mode, task results are re-ordered to match input order.
|
|
659
|
+
- In non-capturing modes, tasks are still awaited so exceptions surface.
|
|
660
|
+
|
|
661
|
+
Example:
|
|
662
|
+
>>> runner = elementwise_processor(in_parallel=True, Executor=ThreadPoolExecutor, max_workers=4)
|
|
663
|
+
>>> out = runner(range(5), lambda x: x * 2)
|
|
664
|
+
>>> out
|
|
665
|
+
[0, 2, 4, 6, 8]
|
|
666
|
+
"""
|
|
667
|
+
def inner(iterable: Iterable[Any], function: Callable, *extra_args: Any, **extra_kwargs: Any) -> list[Any] | None:
|
|
668
|
+
"""Execute `function` over `iterable` per the configuration of the factory.
|
|
669
|
+
|
|
670
|
+
Args:
|
|
671
|
+
iterable: Collection of input elements to process.
|
|
672
|
+
function: Callable applied to each element of `iterable`.
|
|
673
|
+
*extra_args: Extra positional arguments forwarded to `function`.
|
|
674
|
+
**extra_kwargs: Extra keyword arguments forwarded to `function`.
|
|
675
|
+
|
|
676
|
+
Returns:
|
|
677
|
+
List of results when `capture_output` is True; otherwise None.
|
|
678
|
+
|
|
679
|
+
Raises:
|
|
680
|
+
ValueError: If `Executor` is missing while `in_parallel` is True.
|
|
681
|
+
Exception: Any exception raised by `function` is propagated.
|
|
682
|
+
"""
|
|
683
|
+
_logger.debug(
|
|
684
|
+
"elementwise_processor: in_parallel=%s, Executor=%s, max_workers=%s, capture_output=%s, func=%s",
|
|
685
|
+
in_parallel, getattr(Executor, "__name__", None), max_workers, capture_output, getattr(function, "__name__", repr(function))
|
|
686
|
+
)
|
|
687
|
+
|
|
688
|
+
if not in_parallel:
|
|
689
|
+
_logger.info("elementwise_processor: running sequentially")
|
|
690
|
+
if capture_output:
|
|
691
|
+
result = [function(x, *extra_args, **extra_kwargs) for x in iterable]
|
|
692
|
+
_logger.info("elementwise_processor: sequential completed with %d results", len(result))
|
|
693
|
+
return result
|
|
694
|
+
else:
|
|
695
|
+
for x in iterable:
|
|
696
|
+
function(x, *extra_args, **extra_kwargs)
|
|
697
|
+
_logger.info("elementwise_processor: sequential completed (no capture)")
|
|
698
|
+
return None
|
|
699
|
+
|
|
700
|
+
if Executor is None:
|
|
701
|
+
_logger.error("elementwise_processor: Executor is required when in_parallel=True")
|
|
702
|
+
raise ValueError("An 'Executor' argument must be provided if 'in_parallel' is True.")
|
|
703
|
+
|
|
704
|
+
local_max_workers = max_workers or (os.cpu_count() or 1)
|
|
705
|
+
_logger.info("elementwise_processor: starting parallel with %d workers via %s", local_max_workers, Executor.__name__)
|
|
706
|
+
with Executor(max_workers=local_max_workers) as executor:
|
|
707
|
+
futures = {executor.submit(_apply, function, x, extra_args, extra_kwargs): i
|
|
708
|
+
for i, x in enumerate(iterable)}
|
|
709
|
+
_logger.info("elementwise_processor: submitted %d tasks", len(futures))
|
|
710
|
+
if capture_output:
|
|
711
|
+
results: list[Any] = [None] * len(futures)
|
|
712
|
+
for fut in as_completed(futures):
|
|
713
|
+
idx = futures[fut]
|
|
714
|
+
try:
|
|
715
|
+
results[idx] = fut.result()
|
|
716
|
+
except Exception:
|
|
717
|
+
_logger.exception("elementwise_processor: task %d raised", idx)
|
|
718
|
+
raise
|
|
719
|
+
_logger.info("elementwise_processor: parallel completed with %d results", len(results))
|
|
720
|
+
return results
|
|
721
|
+
else:
|
|
722
|
+
for fut in as_completed(futures):
|
|
723
|
+
try:
|
|
724
|
+
fut.result()
|
|
725
|
+
except Exception:
|
|
726
|
+
_logger.exception("elementwise_processor: task %d raised", futures[fut])
|
|
727
|
+
raise
|
|
728
|
+
_logger.info("elementwise_processor: parallel completed (no capture)")
|
|
729
|
+
return None
|
|
730
|
+
|
|
731
|
+
return inner
|
|
732
|
+
|
|
733
|
+
def files_from(dir_path: str, pattern: re.Pattern | None = None) -> list[str]:
    """List files in a directory whose names match a regex pattern.

    Args:
        dir_path: Path to the directory to scan.
        pattern: Compiled regex matched against each file *name* via
            `re.Pattern.match`. If None, every file matches.
            (Annotation fixed from the implicit-Optional `re.Pattern = None`.)

    Returns:
        A sorted list of (string) file paths under `dir_path` that match `pattern`.

    Raises:
        FileNotFoundError: If `dir_path` does not exist.
        NotADirectoryError: If `dir_path` is not a directory.
        PermissionError: If `dir_path` is inaccessible.

    Notes:
        - Only regular files are returned; directories are ignored.
    """
    pat = pattern or re.compile(r".*")
    dp = Path(dir_path)
    _logger.debug("files_from: scanning %s with pattern=%r", dp, pat.pattern if hasattr(pat, "pattern") else pat)
    # sorted() on names gives deterministic output regardless of OS listing order
    files = [
        str(dp / file_name)
        for file_name in sorted(os.listdir(dir_path))
        if pat.match(file_name) and (dp / file_name).is_file()
    ]
    _logger.debug("files_from: matched %d files in %s", len(files), dp)
    return files
|
|
757
|
+
|
|
758
|
+
def file_chunks_generator(file_path: str, chunk_size: int, skip_header: bool = True) -> Iterable[list[str]]:
    """Stream a text file as lists of complete lines, ~`chunk_size` bytes each.

    Relies on `readlines(sizehint)`, which reads approximately `chunk_size`
    bytes but never splits a line.

    Args:
        file_path: UTF-8 encoded text file to read.
        chunk_size: Approximate byte budget per yielded chunk (a size hint).
        skip_header: When True, the first line is consumed and never yielded.

    Yields:
        Lists of whole lines (strings).

    Notes:
        - `sizehint` is advisory only; actual chunk sizes vary.
        - An empty file with `skip_header=True` produces no chunks at all.
    """
    _logger.info("file_chunks_generator: file=%s chunk_size=%d skip_header=%s", file_path, chunk_size, skip_header)
    with open(file_path, "r", encoding="utf-8") as fh:
        if skip_header:
            # readline() returns "" only at EOF, i.e. the file had no lines
            if fh.readline() == "":
                _logger.info("file_chunks_generator: file empty after header skip")
                return
            _logger.debug("file_chunks_generator: skipped header line")
        while chunk := fh.readlines(chunk_size):
            yield chunk
    _logger.debug("file_chunks_generator: completed for %s", file_path)
|
|
791
|
+
|
|
792
|
+
def chunked_file(file_path: str, allowed_memory_percentage_hint: float, num_workers: int) -> Iterable[list[str]]:
    """Yield line chunks of a file, sized by a per-worker memory allowance.

    Plans chunk byte-sizes from currently available system memory divided
    across `num_workers`, then delegates the actual streaming to
    `file_chunks_generator`.

    Args:
        file_path: Path to a UTF-8 text file.
        allowed_memory_percentage_hint: Fraction in (0, 1] of *available* RAM
            to budget in total, split evenly across workers.
        num_workers: Number of workers the chunks are intended for.

    Yields:
        Lists of strings, each a chunk of lines from the file.

    Raises:
        ValueError: If the memory hint is outside (0, 1] or `num_workers` < 1.

    Notes:
        - Heuristic only: decoded Python strings occupy more than raw bytes.
        - A file that fits within one worker's budget is yielded whole.
    """
    if not (0 < allowed_memory_percentage_hint <= 1.0):
        _logger.error("chunked_file: invalid allowed_memory_percentage_hint=%s", allowed_memory_percentage_hint)
        raise ValueError(f"Invalid allowed_memory_percentage_hint parameter: expected a value between 0 and 1, instead got: {allowed_memory_percentage_hint}")

    if num_workers < 1:
        _logger.error("chunked_file: num_workers must be >= 1 (got %s)", num_workers)
        raise ValueError("num_workers must be at least 1")

    available_ram = psutil.virtual_memory().available
    per_worker_budget = max(1, int((allowed_memory_percentage_hint * available_ram) / num_workers))
    total_bytes = os.path.getsize(file_path)
    _logger.info("chunked_file: file_size=%d bytes, memory_per_worker=%d bytes", total_bytes, per_worker_budget)

    # small file: one worker can hold it all, no streaming required
    if total_bytes <= per_worker_budget:
        _logger.info("chunked_file: file fits in memory per worker; yielding all lines")
        yield read_lines(file_path)
        return

    planned_chunks = max(1, ceil(total_bytes / per_worker_budget))
    bytes_per_chunk = max(1, total_bytes // planned_chunks)
    _logger.info("chunked_file: planning %d chunks (~%d bytes each)", planned_chunks, bytes_per_chunk)

    yield from file_chunks_generator(file_path, bytes_per_chunk)
|
|
837
|
+
|
|
838
|
+
def dir_chunks_generator(file_paths: list[str], files_per_chunk: int, residual_files: int):
    """Yield slices of `file_paths`, giving the first chunks one extra file.

    The first `residual_files` chunks each carry `files_per_chunk + 1` paths;
    the remaining full chunks carry `files_per_chunk`; any leftover tail is
    yielded last.

    Args:
        file_paths: Full list of file paths to partition.
        files_per_chunk: Base number of files per chunk (>= 0).
        residual_files: How many leading chunks receive one additional file.

    Yields:
        Lists (slices) of `file_paths`.

    Notes:
        - `files_per_chunk <= 0` short-circuits to a single chunk of everything.
    """
    total_files = len(file_paths)
    _logger.debug("dir_chunks_generator: total_files=%d, files_per_chunk=%d, residual_files=%d",
                  total_files, files_per_chunk, residual_files)

    if files_per_chunk <= 0:
        _logger.info("dir_chunks_generator: files_per_chunk<=0 → yielding all %d files at once", total_files)
        yield file_paths
        return

    full_chunks = (total_files - residual_files) // files_per_chunk
    _logger.debug("dir_chunks_generator: full_chunks=%d", full_chunks)

    cursor = 0
    for chunk_idx in range(full_chunks):
        step = files_per_chunk + (1 if chunk_idx < residual_files else 0)
        yield file_paths[cursor:cursor + step]
        cursor += step

    if cursor < total_files:
        _logger.debug("dir_chunks_generator: yielding tail chunk of %d files", total_files - cursor)
        yield file_paths[cursor:]
|
|
877
|
+
|
|
878
|
+
def chunked_dir(dir_path: str, allowed_memory_percentage_hint: float, num_workers: int):
    """Partition a directory's files into memory-bounded chunks per worker.

    Uses the first file as a size representative (assumes similarly sized
    files) to estimate how many files fit inside one worker's memory budget,
    then yields path lists via `dir_chunks_generator`.

    Args:
        dir_path: Directory containing the files to chunk.
        allowed_memory_percentage_hint: Fraction in (0, 1] of available RAM
            budgeted across all workers.
        num_workers: Number of workers that will process the chunks.

    Yields:
        Lists of file paths sized for concurrent processing.

    Raises:
        ValueError: On invalid inputs or an empty directory.
        MemoryError: If even one representative file exceeds the per-worker budget.

    Notes:
        - A zero-byte first file is treated as 1 byte to avoid division by zero.
        - Purely heuristic; real memory use depends on processing overhead.
    """
    if not (0 < allowed_memory_percentage_hint <= 1.0):
        _logger.error("chunked_dir: invalid allowed_memory_percentage_hint=%s", allowed_memory_percentage_hint)
        raise ValueError(f"Invalid allowed_memory_percentage_hint parameter: expected a value between 0 and 1, instead got: {allowed_memory_percentage_hint}")
    if num_workers < 1:
        _logger.error("chunked_dir: num_workers must be >= 1 (got %s)", num_workers)
        raise ValueError("num_workers must be at least 1")

    budget = max(1, int((psutil.virtual_memory().available * allowed_memory_percentage_hint) / num_workers))
    _logger.info("chunked_dir: memory_per_worker=%d bytes (hint=%s, workers=%d)", budget, allowed_memory_percentage_hint, num_workers)

    file_paths = files_from(dir_path)
    if not file_paths:
        _logger.error("chunked_dir: no files found in directory %s", dir_path)
        raise ValueError(f"No files found in directory: {dir_path}")

    sample_size = os.path.getsize(file_paths[0])
    if sample_size == 0:
        _logger.warning("chunked_dir: first file is zero bytes; falling back to 1 byte for sizing")
        sample_size = 1

    per_worker_capacity = budget // sample_size
    _logger.info("chunked_dir: representative_file_size=%d bytes -> files_per_worker=%d", sample_size, per_worker_capacity)

    if per_worker_capacity < 1:
        _logger.error("chunked_dir: files too large for current memory hint per worker")
        raise MemoryError(
            f"The files contained in {dir_path} are too large. Cannot distribute the files across the workers. "
            "Solution: increase 'allowed_memory_percentage_hint', if possible, or decrease 'num_workers'"
        )

    total = len(file_paths)
    leftover = total % per_worker_capacity
    _logger.info("chunked_dir: num_files=%d -> files_per_chunk=%d, residual_files=%d", total, per_worker_capacity, leftover)

    yield from dir_chunks_generator(file_paths, per_worker_capacity, leftover)
|
|
937
|
+
|
|
938
|
+
def read_lines(file_path: str, skip_header: bool = True) -> list[str]:
    """Return every line of a UTF-8 text file, optionally dropping the first.

    Args:
        file_path: Path to the input file.
        skip_header: When True and the file is non-empty, the first line is
            excluded from the result.

    Returns:
        The file's lines as a list of strings.

    Notes:
        - Loads the entire file via `readlines()`; prefer streaming helpers
          for very large files.
    """
    _logger.debug("read_lines: reading %s (skip_header=%s)", file_path, skip_header)
    with open(file_path, "r", encoding="utf-8") as fh:
        lines = fh.readlines()
    _logger.debug("read_lines: read %d lines from %s", len(lines), file_path)
    if skip_header and lines:
        return lines[1:]
    return lines
|
|
957
|
+
|
|
958
|
+
def temporary_file(prefix: str, suffix: str) -> Path:
    """Create an empty named temporary file on disk and return its path.

    The file is created with ``delete=False`` and its handle is closed
    immediately, so only the path survives; other processes may open or write
    it afterwards. The caller owns cleanup and must delete the file when done.

    Args:
        prefix: Filename prefix for the temporary file.
        suffix: Filename suffix (e.g., an extension such as ``".txt"``).

    Returns:
        Path: Location of the newly created temporary file.

    Notes:
        The file lives in the system's default temporary directory.
    """
    with tempfile.NamedTemporaryFile(prefix=prefix, suffix=suffix, delete=False) as handle:
        created_at = handle.name
    return Path(created_at)
|
|
979
|
+
|
|
980
|
+
def batches_of(iterable: Iterable,
               batch_size: int = -1,
               *,
               out_as: type = list,
               ranges: bool = False,
               inclusive_end: bool = False):
    """Yield elements of `iterable` in fixed-size batches or index ranges.

    Works with any iterable (lists, ranges, generators, file objects, etc.).
    For sliceable sequences, a fast path uses len()+slicing; for general
    iterables, items are accumulated into chunks.

    When `ranges=True`, yields 0-based index ranges based on consumption
    order: `(start, end_exclusive)` (or `(start, end_inclusive)` if
    `inclusive_end=True`) without materializing the data.

    Args:
        iterable: Any iterable to batch (sequence or generator).
        batch_size: Number of items per batch. If <= 0, the entire iterable is
            yielded as a single batch. Defaults to -1.
        out_as: Constructor to wrap each yielded batch (e.g., `list`, `tuple`)
            or to wrap the index pair when `ranges=True`. Defaults to `list`.
        ranges: If True, yield index ranges instead of data batches. Defaults to False.
        inclusive_end: If `ranges=True`, return an inclusive end index instead of
            exclusive. Ignored when `ranges=False`. Defaults to False.

    Yields:
        For `ranges=False`: a batch of up to `batch_size` elements, wrapped with `out_as`.
        For `ranges=True`: an index pair `(start, end_exclusive)` (or inclusive),
        wrapped with `out_as`.

    Examples:
        >>> list(batches_of([1,2,3,4,5], batch_size=2))
        [[1, 2], [3, 4], [5]]

        >>> list(batches_of(range(10), batch_size=4, ranges=True, out_as=tuple))
        [(0, 4), (4, 8), (8, 10)]

        >>> gen = (i*i for i in range(7))
        >>> list(batches_of(gen, batch_size=3, out_as=tuple))
        [(0, 1, 4), (9, 16, 25), (36,)]
    """
    # try fast path for sliceable sequences (len + slicing)
    try:
        n = len(iterable)   # may raise TypeError for generators
        _ = iterable[0:0]   # cheap probe for slicing support
        is_sliceable = True
    except Exception:
        n = None
        is_sliceable = False

    if is_sliceable:
        if batch_size <= 0:
            # Whole sequence as one batch. Handled explicitly: the previous
            # `batch_size = n` then `range(0, 0, 0)` raised ValueError on an
            # empty sequence; now an empty input yields one empty batch,
            # consistent with the generic-iterable path below.
            if ranges:
                yield out_as((0, n - 1)) if inclusive_end else out_as((0, n))
            else:
                yield out_as(iterable[0:n])
            return
        for start in range(0, n, batch_size):
            end_excl = min(start + batch_size, n)
            if ranges:
                yield out_as((start, end_excl - 1)) if inclusive_end else out_as((start, end_excl))
            else:
                yield out_as(iterable[start:end_excl])
        return

    # generic-iterable path (generators, iterators, file objects, etc.)
    it = iter(iterable)

    if batch_size <= 0:
        # consume everything into a single batch
        chunk = list(it)
        if ranges:
            end_excl = len(chunk)
            yield out_as((0, end_excl - 1)) if inclusive_end else out_as((0, end_excl))
        else:
            yield out_as(chunk)
        return

    start_idx = 0
    while True:
        chunk = []
        try:
            for _ in range(batch_size):
                chunk.append(next(it))
        except StopIteration:
            pass

        if not chunk:
            break

        if ranges:
            end_excl = start_idx + len(chunk)
            yield out_as((start_idx, end_excl - 1)) if inclusive_end else out_as((start_idx, end_excl))
        else:
            yield out_as(chunk)

        start_idx += len(chunk)
|
|
1075
|
+
|
|
1076
|
+
def create_updated_subprocess_env(**var_vals: Any) -> dict[str, str]:
    """Return a copy of the current environment with the given overrides.

    Handy for preparing an `env` mapping to pass to `subprocess.run`.
    Conversion rules per value:
      * `None`      -> the variable is removed from the child environment
      * `bool`      -> stored as `"TRUE"` / `"FALSE"`
      * path-like   -> converted via `os.fspath`
      * anything else (`int`, `str`, ...) -> converted with `str`

    Args:
        **var_vals: Environment variable names mapped to desired values.

    Returns:
        dict[str, str]: A fresh environment dictionary for `subprocess.run`.

    Examples:
        >>> env = create_updated_subprocess_env(OMP_NUM_THREADS=1, MKL_DYNAMIC=False)
        >>> env["OMP_NUM_THREADS"]
        '1'
        >>> env["MKL_DYNAMIC"]
        'FALSE'
    """
    def _as_env_str(value: Any) -> str:
        # bool must be tested before the generic branch: it is also an int
        if isinstance(value, bool):
            return "TRUE" if value else "FALSE"
        if hasattr(value, "__fspath__"):
            return os.fspath(value)
        return str(value)

    env: dict[str, str] = os.environ.copy()
    for name, value in var_vals.items():
        if value is None:
            env.pop(name, None)
        else:
            env[name] = _as_env_str(value)
    return env
|
|
1109
|
+
|
|
1110
|
+
def current_time() -> str:
    """Return the current local time formatted as ``%Y-%m-%d_%H%M%S`` (e.g. ``2024-01-31_235959``)."""
    return datetime.now().strftime("%Y-%m-%d_%H%M%S")
|
|
1113
|
+
|
|
1114
|
+
def compose_steps(
    *steps: tuple[
        Callable[..., Any], dict[str, Any] | None
    ]
) -> Callable[[Any], Any]:
    """Compose a pipeline from an ordered sequence of (function, kwargs) pairs.

    This helper returns a unary function that feeds an input value through each
    step you provide, in the exact order the steps appear in the argument list.
    Each step is a 2-tuple ``(func, kwargs)``; the composed function will call
    ``func(current, **(kwargs or {}))``, where *current* is the running value,
    and use the return value as the next *current*.

    Args:
        *steps: Variable-length sequence of pairs ``(callable, kwargs_dict_or_None)``.
            - Each callable must accept at least one positional argument
              (the current value) plus any keyword arguments supplied.
            - ``kwargs`` may be ``None`` to indicate no keyword arguments.
            - The order of steps determines execution order.

    Returns:
        Callable[[Any], Any]: A function ``g(x)`` that applies all steps to ``x``
        and returns the final result.

    Raises:
        TypeError: If any element of ``steps`` is not a 2-tuple of
            ``(callable, dict_or_None)``.
        Any exception raised by an individual step is propagated unchanged.

    Notes:
        - If a step mutates its input and returns ``None``, the next step will
          receive ``None``. Ensure each step returns the value you want to pass on.
        - ``kwargs`` is shallow-copied (via ``dict(kwargs)``) before each call so a
          callee cannot mutate the original mapping.

    Examples:
        >>> def scale(a, *, c): return a * c
        >>> def shift(a, *, b): return a + b
        >>> pipeline = compose_steps((scale, {'c': 2}), (shift, {'b': 3}))
        >>> pipeline(10)
        23
    """
    # validation: check BOTH halves of each pair up front. Previously only
    # pair[0] was checked for callability, so a non-dict kwargs value failed
    # deep inside the returned pipeline instead of at composition time.
    for i, pair in enumerate(steps):
        if not (
            isinstance(pair, tuple)
            and len(pair) == 2
            and callable(pair[0])
            and (pair[1] is None or isinstance(pair[1], dict))
        ):
            raise TypeError(
                f"steps[{i}] must be a (callable, kwargs_dict_or_None) pair; got: {pair!r}"
            )

    def inner(x: Any) -> Any:
        for func, kwargs in steps:
            # dict(kwargs) shallow-copies so a callee can't mutate our mapping
            x = func(x, **({} if kwargs is None else dict(kwargs)))
        return x

    return inner
|
|
1169
|
+
|
|
1170
|
+
|
|
1171
|
+
# Curated public API for `from sawnergy.sawnergy_util import *`. Deliberately a
# subset: helpers such as `read_lines`, `temporary_file`, `current_time`, and
# the chunk generators remain importable by name but are not re-exported.
# NOTE(review): "ArrayStorage" is presumably defined earlier in this file,
# outside this excerpt — confirm it still exists before relying on it here.
__all__ = [
    "ArrayStorage",
    "elementwise_processor",
    "files_from",
    "chunked_file", # <-- legacy
    "chunked_dir", # <-- legacy
    "batches_of",
    "compose_steps"
]
|
|
1180
|
+
|
|
1181
|
+
if __name__ == "__main__":
    # Intentionally a no-op: this module only provides importable utilities,
    # so running it directly does nothing.
    pass
|