patchworks 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- patchworks/__init__.py +48 -0
- patchworks/_chunks.py +258 -0
- patchworks/_cluster.py +93 -0
- patchworks/_core.py +352 -0
- patchworks/_io.py +218 -0
- patchworks/_merge.py +405 -0
- patchworks/_relabel.py +83 -0
- patchworks/plugins/__init__.py +1 -0
- patchworks/plugins/cellpose.py +188 -0
- patchworks-0.2.0.dist-info/METADATA +294 -0
- patchworks-0.2.0.dist-info/RECORD +12 -0
- patchworks-0.2.0.dist-info/WHEEL +4 -0
patchworks/_core.py
ADDED
|
@@ -0,0 +1,352 @@
|
|
|
1
|
+
"""Core tile_process function."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
import logging
|
|
5
|
+
import os
|
|
6
|
+
from contextlib import nullcontext as _nullcontext
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Any, Callable, Union
|
|
9
|
+
|
|
10
|
+
import dask.array as da
|
|
11
|
+
import numpy as np
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
from ._chunks import auto_tile_shape
|
|
15
|
+
from ._cluster import _client_is_in_process, _distributed_client
|
|
16
|
+
from ._io import _auto_empty_threshold, load_ome_zarr
|
|
17
|
+
from ._merge import zarr_native_merge
|
|
18
|
+
from ._relabel import relabel_sequential_zarr
|
|
19
|
+
|
|
20
|
+
logger = logging.getLogger(__name__)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _stage_to_zarr(
    arr: da.Array, path: str, component: str, show_progress: bool
) -> None:
    """Persist *arr* into the zarr array ``path/component`` without gathering it.

    The write graph is built lazily (``to_zarr(..., compute=False)``) and then
    executed block by block, so the full array is never loaded into RAM. Any
    existing array at *component* is overwritten.

    Execution happens on the active distributed client when one exists;
    otherwise it uses the local dask scheduler. With *show_progress* set, the
    progress display matching the chosen scheduler is shown.
    """
    import dask

    pending = arr.to_zarr(str(path), component=component, overwrite=True, compute=False)

    dist_client = _distributed_client()
    if dist_client is None:
        # Local scheduler: optional console progress bar.
        from dask.diagnostics import ProgressBar

        with (ProgressBar() if show_progress else _nullcontext()):
            dask.compute(pending)
        return

    # Distributed scheduler: submit, optionally watch, then block until done
    # (``result()`` also re-raises any task error).
    handle = dist_client.compute(pending)
    if show_progress:
        from dask.distributed import progress as _dist_progress

        _dist_progress(handle)
    handle.result()
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _resolve_tile_shape(
    tile_shape: Union[tuple[int, ...], Callable[[tuple, Any], tuple], str, None],
    shape: tuple[int, ...],
    dtype: Any,
    use_gpu: bool,
) -> Any:
    """Resolve a ``tile_shape`` spec for an image of *shape* / *dtype*.

    ``None`` and explicit shapes are returned unchanged. A callable is invoked
    as ``tile_shape(shape, dtype)``. ``"auto"`` delegates to
    ``auto_tile_shape``. Any other string raises ``ValueError``.
    """
    if callable(tile_shape):
        return tile_shape(shape, dtype)
    if isinstance(tile_shape, str):
        if tile_shape != "auto":
            raise ValueError(f"Unknown tile_shape value: {tile_shape!r}. Use 'auto', a tuple, or a callable.")
        return auto_tile_shape(shape, dtype, use_gpu=use_gpu, verbose=True)
    return tile_shape


def tile_process(
    image: Union[da.Array, str, Path],
    fn: Callable[[np.ndarray], np.ndarray],
    *,
    tile_shape: Union[tuple[int, ...], Callable[[tuple, Any], tuple], str, None] = None,
    overlap: int = 0,
    channel: int | None = 0,
    level: int = 0,
    use_gpu: bool = False,
    progress: bool = False,
    write_to: Union[str, Path, None] = None,
    output_component: str = "labels",
    sequential_labels: bool = False,
    skip_empty: bool = False,
    empty_threshold: float | None = None,
    stage_dir: Union[str, Path, None] = None,
    keep_stage: bool = False,
    verbose: bool = False,
) -> da.Array:
    """Apply *fn* to every tile of *image* and merge labels globally.

    The core workhorse of patchworks. ``fn`` can be any callable that takes a
    NumPy array and returns an integer label array of the same shape — Cellpose,
    StarDist, Otsu threshold, your own model, anything.

    Parameters
    ----------
    image:
        Dask array *or* path to an OME-ZARR store.
    fn:
        ``(ndarray) -> ndarray`` returning integer labels of the same shape.
        Must be picklable when using distributed schedulers. The result is
        cast to int32. If the result has the same number of elements as the
        tile but a different shape (e.g. a squeezed singleton axis), it is
        reshaped back to the tile shape; any other mismatch raises
        ``ValueError``.
    tile_shape:
        Controls tiling before calling *fn*. Accepted values:

        - ``None`` : keep existing dask chunks.
        - ``tuple`` : use this exact tile shape.
        - ``"auto"`` : call ``auto_tile_shape`` based on shape and dtype.
        - ``Callable[[shape, dtype], tuple]`` : called with the image's shape
          and dtype; the return value is used. Use this with
          ``auto_tile_shape_cellpose``:

          .. code-block:: python

              from functools import partial
              from patchworks import auto_tile_shape_cellpose, tile_process
              tile_fn = partial(auto_tile_shape_cellpose, diameter=30, use_gpu=True)
              result = tile_process("image.zarr", fn, tile_shape=tile_fn)

    overlap:
        Voxels of overlap (halo) added to each tile before *fn* is called, so
        objects near tile boundaries have enough spatial context to be
        segmented correctly (Cellpose, StarDist, …). The halo is trimmed off
        before merging, so the output has the original shape. Must be >= 0;
        ``0`` disables it. The depth is capped per axis so that axes too small
        for the requested halo (e.g. a z axis of size 1) get less or none.

        Merging is always **touching-label** based: after the halo is trimmed,
        labels that touch across a tile boundary are merged into one object.
    channel:
        Channel index when *image* is a path. Ignored for arrays.
    level:
        Pyramid level when *image* is a path (0 = full resolution).
    use_gpu:
        When ``tile_shape="auto"``, size tiles against GPU VRAM instead of RAM.
        Without a distributed client it also pins the threaded scheduler to a
        single worker, so GPU evaluations run one at a time.
    progress:
        Show a progress bar during the tile-writing and relabel steps.
    write_to:
        Output zarr store path. When None, an auto-temp store is used and its
        path is logged. Pass an explicit path to control the output location.
    output_component:
        Array name inside ``write_to``. Default ``"labels"``.
    sequential_labels:
        Renumber merged labels to a contiguous ``1..N`` range. Default False:
        labels stay globally unique but gappy (block-encoded), which is fine for
        counting/measurement. Uses a cheap linear post-pass (O(voxels)), not the
        O(n_chunks²) dask built-in.
    skip_empty:
        Skip *fn* on background tiles. A tile whose max signal is <=
        ``empty_threshold`` returns all-zeros immediately. Biggest speed-up for
        sparse/mostly-background volumes. Use ``estimate_empty_tiles()`` first
        to pick a threshold.
    empty_threshold:
        Intensity at or below which a tile is empty (``skip_empty=True`` only).
        None → auto-derive via Otsu on a bounded sample.
    stage_dir:
        Directory for the temporary stage store ``_bb_stage.zarr``. ``fn`` is
        always run once per tile into this store, and the merge then reads it
        back from disk (``fn`` is never re-run). Default → next to
        ``write_to``, else next to the input store, else a fresh system temp
        directory. Concurrent runs that share a stage directory share the same
        stage store path.
    keep_stage:
        Keep the temp stage store after merging (default: delete it). Useful
        for debugging or resuming an interrupted run. The store is also left in
        place if the merge raises.
    verbose:
        Log each tile's location and shape as it is processed (emitted at
        DEBUG level on this module's logger).

    Returns
    -------
    da.Array
        Globally relabeled array (int32) backed by ``write_to`` (or an
        auto-temp zarr when ``write_to`` is None). Never loads the full volume
        into RAM. Call ``.compute()`` yourself only if the result fits in RAM.

    Raises
    ------
    ValueError
        If *overlap* is negative, *tile_shape* is an unknown string, or *fn*
        returns labels whose size does not match the tile.
    RuntimeError
        If the active Dask client runs its worker in-process.

    Examples
    --------
    **Any threshold function:**

    >>> from skimage.filters import threshold_otsu
    >>> from skimage.measure import label
    >>>
    >>> def my_fn(tile):
    ...     return label(tile > threshold_otsu(tile)).astype("int32")
    >>>
    >>> result = tile_process("image.zarr", my_fn, write_to="labels.zarr")

    **Cellpose (via the plugin):**

    >>> from patchworks.plugins.cellpose import cellpose_fn
    >>>
    >>> fn = cellpose_fn("cyto3", gpu=True, diameter=30)
    >>> result = tile_process(
    ...     "image.zarr", fn,
    ...     tile_shape=(1, 2048, 2048),
    ...     overlap=20,
    ...     write_to="labels.zarr",
    ...     progress=True,
    ... )

    **StarDist:**

    >>> from stardist.models import StarDist2D
    >>> model = StarDist2D.from_pretrained("2D_versatile_fluo")
    >>>
    >>> def stardist_fn(tile):
    ...     norm = tile.astype("float32") / tile.max()
    ...     labels, _ = model.predict_instances(norm)
    ...     return labels.astype("int32")
    >>>
    >>> result = tile_process("image.zarr", stardist_fn,
    ...                       tile_shape=(1, 1024, 1024), overlap=32)

    **Write directly to zarr (no RAM accumulation):**

    >>> tile_process("image.zarr", fn, write_to="labels.zarr", progress=True)
    """
    # Validate cheap arguments before touching any data.
    if overlap < 0:
        raise ValueError(f"overlap must be >= 0, got {overlap!r}")

    # In-process dask workers break the label merge. A GIL-holding fn starves
    # the worker heartbeat and the P2P barrier drops inputs →
    # "FutureCancelledError: lost dependencies".
    _active = _distributed_client()
    if _active is not None and _client_is_in_process(_active):
        raise RuntimeError(
            "Active Dask client uses an in-process worker (processes=False). "
            "This breaks the label merge when fn holds the GIL. Use a "
            "process-based cluster instead:\n"
            "  from patchworks import make_local_cluster\n"
            "  client, cluster = make_local_cluster(use_gpu=True)\n"
            "or drop the client to use the threaded scheduler "
            "(client.close(); cluster.close())."
        )

    # Load + tile
    image_source_path = None if isinstance(image, da.Array) else str(image)

    if not isinstance(image, da.Array):
        # Resolve the tile shape against the on-disk array and open the store
        # with those chunks directly, instead of rechunking afterwards.
        _peek = load_ome_zarr(image, channel=channel, level=level)
        _resolved = _resolve_tile_shape(tile_shape, _peek.shape, _peek.dtype, use_gpu)
        _load_chunks: tuple[int, ...] | None = None if _resolved is None else tuple(_resolved)
        tile_shape = None  # already handled at load time
        if _load_chunks is not None:
            logger.info("Loading zarr with target tiles %s", _load_chunks)
            image = load_ome_zarr(image, channel=channel, level=level, chunks=_load_chunks)
        else:
            image = _peek
    else:
        tile_shape = _resolve_tile_shape(tile_shape, image.shape, image.dtype, use_gpu)

    if tile_shape is not None:
        image = image.rechunk(tile_shape)
        logger.info("Rechunked to %s", tile_shape)

    n_tiles = int(np.prod([len(c) for c in image.chunks]))
    logger.info(
        "Processing %d tiles (per-axis %s, tile shape %s)",
        n_tiles,
        tuple(len(c) for c in image.chunks),
        tuple(c[0] for c in image.chunks),
    )

    # The auto threshold is sampled from the un-haloed image.
    image_for_threshold = image

    # Overlap — build a per-axis depth dict (clips to fit each axis).
    # An integer depth raises if any axis is smaller than the depth, so we
    # cap per axis. In practice z-axis of size 1 (2-D Cellpose) gets depth=0.
    _depth: dict[int, int] = {
        ax: min(overlap, max(0, sum(c) - 1))
        for ax, c in enumerate(image.chunks)
    }

    if overlap > 0:
        # boundary="none" is required: only this boundary mode composes with
        # trim_overlap to recover the original shape. "reflect" keeps the halo.
        image = da.overlap.overlap(image, depth=_depth, boundary="none")

    # Wrap fn with optional empty-tile skipping
    _skip_thr = empty_threshold
    if skip_empty and _skip_thr is None:
        _skip_thr = _auto_empty_threshold(image_for_threshold, channel, level)

    def active_fn(block, block_info=None):
        loc = block_info[0].get("chunk-location") if block_info else "?"
        if skip_empty and block.size and block.max() <= _skip_thr:
            if verbose:
                logger.debug("skip empty tile %s (max<=%s)", loc, _skip_thr)
            return np.zeros(block.shape, dtype=np.int32)
        if verbose:
            logger.debug("process tile %s shape=%s", loc, block.shape)
        out = np.asarray(fn(block))
        if out.shape != block.shape:
            if out.size != block.size:
                raise ValueError(
                    f"fn returned labels of shape {out.shape} for tile {loc} "
                    f"of shape {block.shape}; labels must match the tile shape."
                )
            # e.g. fn squeezed a singleton z axis: restore the tile layout so
            # trim_overlap and the zarr write see the chunk shape dask expects.
            out = out.reshape(block.shape)
        # Match the dtype declared to map_blocks below.
        return out.astype(np.int32, copy=False)

    labeled = image.map_blocks(
        active_fn, dtype=np.int32, meta=np.empty((0,) * image.ndim, dtype=np.int32)
    )

    # Trim the overlap halo so staged tiles have clean boundaries for the
    # boundary-slab scan. Without this the scan reads halo-expanded chunks and
    # the merged output is larger than the input.
    if overlap > 0:
        labeled = da.overlap.trim_overlap(labeled, depth=_depth, boundary="none")

    # With no distributed client the threaded scheduler runs many tiles at
    # once. For GPU that means several evals sharing one device → CUDA OOM.
    # Pin to a single worker thread so evals run serially. A distributed
    # client manages its own concurrency, so skip the override there.
    import dask as _dask

    if _active is None and use_gpu:
        _sched_ctx: Any = _dask.config.set(scheduler="threads", num_workers=1)
    else:
        _sched_ctx = _nullcontext()

    # Stage: run fn once per tile to a temp zarr, then the zarr-native merge
    # reads concrete data from disk (fn is never re-run). Required because the
    # merge scans the labels directly on disk.
    import tempfile

    _created_stage_base = False
    if stage_dir is not None:
        base = str(stage_dir)
    elif write_to is not None:
        base = os.path.dirname(os.path.abspath(str(write_to)))
    elif image_source_path is not None:
        base = os.path.dirname(os.path.abspath(image_source_path))
    else:
        base = tempfile.mkdtemp(prefix="bb_stage_")
        _created_stage_base = True  # ours to delete along with the stage
    # NOTE: fixed store name, so concurrent runs sharing `base` collide.
    stage_path = os.path.join(base, "_bb_stage.zarr")
    logger.info("Staging tiles to %s …", stage_path)
    with _sched_ctx:
        _stage_to_zarr(labeled, stage_path, "staged", progress)
    labeled = da.from_zarr(stage_path, component="staged")

    if skip_empty and _skip_thr is not None:
        # Reads the whole staged label volume once. A tile counts as "empty"
        # when its staged labels are all zero, whether it was skipped by the
        # threshold or fn ran and found no objects.
        def _tile_max(block: np.ndarray) -> np.ndarray:
            return np.full((1,) * block.ndim, int(block.max()), dtype=np.int32)

        _tile_maxes = labeled.map_blocks(
            _tile_max, dtype=np.int32,
            chunks=tuple(tuple(1 for _ in c) for c in labeled.chunks),
        ).compute()
        _n_empty = int((_tile_maxes == 0).sum())
        logger.info(
            "skip_empty: %d/%d tiles produced labels, %d empty "
            "(skipped at max<=%.4g or no objects found by fn)",
            int(_tile_maxes.size) - _n_empty, int(_tile_maxes.size), _n_empty, _skip_thr,
        )

    def _cleanup_stage():
        if keep_stage:
            return
        import shutil

        shutil.rmtree(stage_path, ignore_errors=True)
        if _created_stage_base:
            # The temp dir was created solely to hold the stage store.
            shutil.rmtree(base, ignore_errors=True)
        logger.info("Removed stage store %s", stage_path)

    _nw = min(4, os.cpu_count() or 1)

    if write_to is not None:
        _effective_out = str(write_to)
    else:
        _effective_out = os.path.join(
            tempfile.mkdtemp(prefix="bb_merge_"), "merged.zarr"
        )
        logger.info("write_to not set — merged labels in auto-temp %s", _effective_out)

    zarr_native_merge(
        stage_path, "staged", _effective_out, output_component,
        n_workers=_nw, show_progress=progress,
    )
    if sequential_labels:
        logger.info("Relabelling to contiguous ids…")
        relabel_sequential_zarr(_effective_out, output_component)
    # Only reached on success: a failed merge leaves the stage for inspection.
    _cleanup_stage()

    # Always return a lazy dask array backed by the output zarr.
    # Never load the full volume into RAM here — the merge already materialised
    # to disk (auto-temp when write_to=None). Caller can .compute() if needed.
    return da.from_zarr(_effective_out, component=output_component)
|
patchworks/_io.py
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
1
|
+
"""OME-ZARR loading and empty-tile estimation."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
import logging
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any, Union
|
|
7
|
+
|
|
8
|
+
import dask.array as da
|
|
9
|
+
import numpy as np
|
|
10
|
+
import zarr
|
|
11
|
+
|
|
12
|
+
logger = logging.getLogger(__name__)
|
|
13
|
+
|
|
14
|
+
# Major-version flag for the installed zarr-python (True for zarr >= 3).
# NOTE(review): not referenced anywhere else in this module; presumably
# imported by other patchworks modules. Verify before removing.
_ZARR_V3 = int(zarr.__version__.split(".")[0]) >= 3
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _ngff_channel_axis(axes: Any) -> int | None:
    """Return the index of the channel axis in NGFF ``axes`` metadata, or None.

    Accepts the v0.4+ form (a list of ``{"name": ..., "type": ...}`` dicts) and
    the v0.3 form (a list of axis-name strings such as ``["c", "z", "y", "x"]``).
    """
    for i, ax in enumerate(axes):
        if isinstance(ax, dict):
            if ax.get("type") == "channel" or ax.get("name") == "c":
                return i
        elif ax == "c":
            return i
    return None


def load_ome_zarr(
    store_path: Union[str, Path],
    channel: int | None = 0,
    level: int = 0,
    chunks: tuple[int, ...] | None = None,
) -> da.Array:
    """Load one spatial array from an OME-ZARR store.

    Parameters
    ----------
    store_path:
        Path to the OME-ZARR store (.zarr directory). Both the top-level
        ``multiscales`` layout (NGFF <= 0.4) and the ``ome.multiscales``
        layout (NGFF 0.5) are accepted.
    channel:
        Channel index to select (axis is dropped). Pass ``None`` to keep it.
        The channel axis is located from the multiscales ``axes`` metadata, so
        ``tczyx`` stores select along ``c``, not ``t``. If the store declares
        no channel axis, ``channel=0`` is a no-op and any other value raises
        ``ValueError``. Stores without ``axes`` metadata are assumed to be
        channel-first.
    level:
        Resolution pyramid level (0 = full resolution).
    chunks:
        Target chunk shape for the returned dask array. When a channel is
        selected, it may omit the channel axis; a chunk of 1 is used there.

    Returns
    -------
    da.Array
        For a ``czyx`` store: shape ``(z, y, x)`` when *channel* is an int, or
        ``(c, z, y, x)`` when *channel* is None. Other leading axes (e.g.
        ``t``) are kept.

    Raises
    ------
    ValueError
        If the multiscales metadata for *level* cannot be read, or if a
        non-zero *channel* is requested from a store with no channel axis.

    Examples
    --------
    >>> arr = load_ome_zarr("image.zarr", channel=0)
    >>> arr.shape
    (128, 2048, 2048)
    """
    root = zarr.open_group(str(store_path), mode="r")
    try:
        attrs = root.attrs
        # NGFF <= 0.4 keeps "multiscales" at the top level; NGFF 0.5 nests it
        # under "ome".
        multiscales = (
            attrs["multiscales"] if "multiscales" in attrs else attrs["ome"]["multiscales"]
        )
        ms0 = multiscales[0]
        path = ms0["datasets"][level]["path"]
    except (KeyError, IndexError, TypeError) as exc:
        raise ValueError(
            f"Cannot read OME-ZARR multiscales metadata at level {level} "
            f"in {store_path!r}"
        ) from exc

    # Decide which on-disk axis (if any) the channel index applies to.
    ch_axis: int | None = None
    if channel is not None:
        axes = ms0.get("axes")
        if not axes:
            ch_axis = 0  # no axes metadata: legacy channel-first assumption
        else:
            ch_axis = _ngff_channel_axis(axes)
            if ch_axis is None and channel != 0:
                raise ValueError(
                    f"{store_path!r} has no channel axis; cannot select "
                    f"channel={channel}. Pass channel=None or 0."
                )

    zarr_chunks = chunks
    if chunks is not None and channel is not None:
        chunks = tuple(chunks)
        zarr_ndim = len(root[path].shape)
        if ch_axis is not None and len(chunks) == zarr_ndim - 1:
            # Chunks describe the post-selection array: re-insert a chunk of 1
            # at the channel axis that is about to be dropped.
            zarr_chunks = chunks[:ch_axis] + (1,) + chunks[ch_axis:]
        elif zarr_ndim > len(chunks):
            zarr_chunks = (1,) * (zarr_ndim - len(chunks)) + chunks

    arr = da.from_zarr(str(store_path), component=path, chunks=zarr_chunks)
    if ch_axis is not None:
        index: list[Any] = [slice(None)] * arr.ndim
        index[ch_axis] = channel
        arr = arr[tuple(index)]
    return arr
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def _otsu_threshold(sample: np.ndarray) -> float:
    """Otsu threshold of *sample*; falls back to 0 when it cannot be computed.

    Operates on the full distribution including zeros — zeros are background
    pixels and must be included so Otsu can find the signal/background boundary.

    Returns 0.0, so that any tile with non-zero signal counts as occupied, when:

    - *sample* is empty;
    - scikit-image is not installed (a ``RuntimeWarning`` is emitted so the
      degraded threshold is not silent);
    - ``threshold_otsu`` raises.
    """
    import warnings

    flat = np.asarray(sample).ravel()
    if flat.size == 0:
        return 0.0
    try:
        from skimage.filters import threshold_otsu
    except ImportError:
        warnings.warn(
            "scikit-image is not installed; the automatic empty-tile threshold "
            "falls back to 0 (only all-zero tiles count as empty). Install "
            "scikit-image or pass an explicit threshold.",
            RuntimeWarning,
            stacklevel=2,
        )
        return 0.0
    try:
        return float(threshold_otsu(flat))
    except Exception:
        # Best effort: a threshold that cannot be computed should not abort the
        # caller; 0 marks every non-zero tile as occupied.
        return 0.0
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def _auto_empty_threshold(image: da.Array, channel: int | None, level: int) -> float:
    """Pick an empty-tile threshold from a cheap bounded sample (Otsu).

    Three windows are read from *image*. They are centred at 33 %, 50 % and
    66 % along every axis and clamped to lie inside the array. Each window
    spans at most 64 voxels along each of the last three axes and the full
    extent of any leading axes. The pooled voxels are passed to
    ``_otsu_threshold``.

    *channel* and *level* are accepted for interface compatibility and are
    not used; *image* is sampled exactly as given.
    """
    n = image.ndim
    win = [min(64, s) if i >= n - 3 else s for i, s in enumerate(image.shape)]
    samples = []
    for frac in (0.33, 0.5, 0.66):
        window = []
        for s, w in zip(image.shape, win):
            # Centre on s*frac but clamp into [0, s - w]. An unclamped negative
            # start would be read from the end of the axis and give an empty
            # slice, silently dropping that sample.
            start = min(max(int(s * frac) - w // 2, 0), s - w)
            window.append(slice(start, start + w))
        samples.append(np.asarray(image[tuple(window)]).ravel())
    sample = np.concatenate(samples)
    thr = _otsu_threshold(sample)
    logger.info("Auto empty_threshold=%.3g (Otsu on %d samples)", thr, len(samples))
    return thr
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def estimate_empty_tiles(
    image: Union[da.Array, str, Path],
    tile_shape: tuple[int, ...],
    threshold: float | None = None,
    channel: int | None = 0,
    level: int = 0,
    sample_window: tuple[int, ...] = (24, 256, 256),
) -> dict[str, Any]:
    """Fast preview of which tiles are background before processing.

    For each tile, reads a small centred window (``sample_window``) and tests
    whether its max exceeds *threshold*. Bounded I/O — runs in seconds to
    minutes on terabyte arrays.

    APPROXIMATE: only the tile centre is inspected. The actual ``tile_process``
    run always tests the whole tile inline. Use this only to pick a threshold
    and gauge the empty fraction before committing to a full run.

    Parameters
    ----------
    image:
        Dask array or OME-ZARR path.
    tile_shape:
        Tile shape you plan to use, e.g. ``(120, 697, 697)``. Its length sets
        how many trailing axes are treated as spatial. All entries must be
        positive.
    threshold:
        Empty cutoff (signal <= threshold → empty). None → Otsu on samples
        (each sample is strided down to a bounded number of voxels).
    channel, level:
        Used only when *image* is a path. With ``channel=None``, every leading
        (non-spatial) axis is read in full.
    sample_window:
        Size of the centred window read per tile, aligned with the trailing
        (spatial) axes: a 2-D ``tile_shape`` uses the last two entries. Spatial
        axes not covered by the window are read over the full tile extent.

    Returns
    -------
    dict with keys:
        ``threshold``, ``n_tiles``, ``n_occupied``, ``empty_fraction``,
        ``occupancy`` (bool ndarray, one entry per tile in the grid).

    Raises
    ------
    ValueError
        If *tile_shape* is empty or has a non-positive entry, or the OME-ZARR
        multiscales metadata cannot be read.

    Examples
    --------
    >>> info = estimate_empty_tiles("image.zarr", (120, 697, 697))
    >>> print(f"{info['empty_fraction']:.0%} of tiles are background")
    >>> labels = tile_process("image.zarr", fn, tile_shape=(120, 697, 697),
    ...                       skip_empty=True, empty_threshold=info["threshold"])
    """
    tile_shape = tuple(int(t) for t in tile_shape)
    n_spatial = len(tile_shape)
    if n_spatial == 0 or any(t <= 0 for t in tile_shape):
        raise ValueError(
            f"tile_shape must be non-empty with positive entries, got {tile_shape!r}"
        )

    z_src: Any = None
    if isinstance(image, (str, Path)):
        _root = zarr.open_group(str(image), mode="r")
        try:
            _zpath = _root.attrs["multiscales"][0]["datasets"][level]["path"]
        except (KeyError, IndexError, TypeError) as exc:
            raise ValueError(
                f"Cannot read OME-ZARR multiscales metadata at level {level} "
                f"in {image!r}"
            ) from exc
        z_src = _root[_zpath]
        sp_shape = tuple(z_src.shape[-n_spatial:])
    else:
        arr = image
        sp_shape = tuple(arr.shape[-n_spatial:])

    # Align sample_window with the trailing (spatial) axes; pad missing
    # leading entries with the full tile extent.
    window = tuple(sample_window)[-n_spatial:]
    window = tile_shape[: n_spatial - len(window)] + window
    win = [min(w, t, s) for w, t, s in zip(window, tile_shape, sp_shape)]
    grid = [int(np.ceil(s / t)) for s, t in zip(sp_shape, tile_shape)]

    _ch_prefix: tuple = ()
    if z_src is not None:
        n_leading = z_src.ndim - n_spatial
        if n_leading > 0:
            if channel is not None:
                # Channel is assumed to be the last non-spatial axis (t,c,z,y,x).
                _ch_prefix = (0,) * (n_leading - 1) + (channel,)
            else:
                # Keep every leading axis. Without explicit full slices, zarr
                # would apply the spatial slices to the leading axes instead.
                _ch_prefix = (slice(None),) * n_leading

    # Streaming single pass: store only per-tile max (a scalar) and a bounded
    # sample list for Otsu. Each sample is a strided *copy* capped at
    # _MAX_VOXELS_PER_SAMPLE, so it does not pin the full block in memory.
    _MAX_OTSU_SAMPLES = 500
    _MAX_VOXELS_PER_SAMPLE = 1 << 16
    samples: list[np.ndarray] = []
    tile_maxes: dict[tuple, float] = {}

    for idx in np.ndindex(*grid):
        sl: list[slice] = []
        for i, t, w, s in zip(idx, tile_shape, win, sp_shape):
            # Centre the window in the tile; clamp the last partial tile.
            start = min(i * t + (t - w) // 2, s - w)
            sl.append(slice(start, start + w))

        if z_src is not None:
            block = np.asarray(z_src[_ch_prefix + tuple(sl)])
        else:
            sub = arr[(...,) + tuple(sl)] if arr.ndim > n_spatial else arr[tuple(sl)]
            block = np.asarray(sub)

        tile_maxes[idx] = float(block.max()) if block.size else 0.0
        if threshold is None and len(samples) < _MAX_OTSU_SAMPLES:
            flat = block.ravel()
            step = max(1, -(-flat.size // _MAX_VOXELS_PER_SAMPLE))
            samples.append(flat[::step].copy())
        # block freed here — not stored

    if threshold is None:
        threshold = _otsu_threshold(np.concatenate(samples) if samples else np.zeros(1))

    occupancy = np.zeros(grid, dtype=bool)
    for idx, mx in tile_maxes.items():
        occupancy[idx] = mx > threshold

    n_tiles = int(occupancy.size)
    n_occ = int(occupancy.sum())
    empty_frac = 1.0 - n_occ / n_tiles if n_tiles else 0.0
    logger.info(
        "estimate_empty_tiles: threshold=%.4g occupied %d/%d tiles empty=%.0f%%",
        threshold, n_occ, n_tiles, empty_frac * 100,
    )
    return {
        "threshold": float(threshold),
        "n_tiles": n_tiles,
        "n_occupied": n_occ,
        "empty_fraction": empty_frac,
        "occupancy": occupancy,
    }
|