patchworks 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- patchworks/__init__.py +48 -0
- patchworks/_chunks.py +258 -0
- patchworks/_cluster.py +93 -0
- patchworks/_core.py +352 -0
- patchworks/_io.py +218 -0
- patchworks/_merge.py +405 -0
- patchworks/_relabel.py +83 -0
- patchworks/plugins/__init__.py +1 -0
- patchworks/plugins/cellpose.py +188 -0
- patchworks-0.2.0.dist-info/METADATA +294 -0
- patchworks-0.2.0.dist-info/RECORD +12 -0
- patchworks-0.2.0.dist-info/WHEEL +4 -0
patchworks/_merge.py
ADDED
|
@@ -0,0 +1,405 @@
|
|
|
1
|
+
"""Zarr-native label merge: boundary scan → scipy CC → parallel relabel.
|
|
2
|
+
|
|
3
|
+
Three steps, all zarr-native with no dask task graph:
|
|
4
|
+
1. Scan thin boundary slabs → touching label pairs (O(n_faces × face_area))
|
|
5
|
+
2. scipy sparse connected_components on pairs → relabeling LUT
|
|
6
|
+
3. Apply LUT to each chunk in parallel via multiprocessing.Pool
|
|
7
|
+
|
|
8
|
+
Trade-off: touching-label merge only (overlap_depth=0 semantics for merge).
|
|
9
|
+
IoU-overlap merge is not supported here. Keep overlap > 0 during segmentation
|
|
10
|
+
for boundary-cell context; trim the halo before staging so chunk boundaries
|
|
11
|
+
in the staged zarr are clean for this merge.
|
|
12
|
+
|
|
13
|
+
Public API
|
|
14
|
+
----------
|
|
15
|
+
``merge_tile_labels(labeled, write_to, ...)`` — standalone merge for labeled
|
|
16
|
+
dask arrays or pre-staged zarr stores. Use this directly if you already have
|
|
17
|
+
per-tile labels and just need the boundary-stitching step.
|
|
18
|
+
"""
|
|
19
|
+
from __future__ import annotations
|
|
20
|
+
|
|
21
|
+
import logging
|
|
22
|
+
import os
|
|
23
|
+
import tempfile
|
|
24
|
+
from contextlib import nullcontext as _nullcontext
|
|
25
|
+
from itertools import product as _iproduct
|
|
26
|
+
from multiprocessing import Pool as _Pool
|
|
27
|
+
from pathlib import Path
|
|
28
|
+
from typing import Any, Union
|
|
29
|
+
|
|
30
|
+
import dask.array as da
|
|
31
|
+
import numpy as np
|
|
32
|
+
import zarr
|
|
33
|
+
|
|
34
|
+
try:
|
|
35
|
+
from tqdm.auto import tqdm as _tqdm
|
|
36
|
+
except ImportError:
|
|
37
|
+
_tqdm = None
|
|
38
|
+
|
|
39
|
+
logger = logging.getLogger(__name__)

# Selects between Group.create_array (zarr >= 3) and Group.zeros (zarr 2)
# in _create_zarr_label_array.
_ZARR_V3 = int(zarr.__version__.partition(".")[0]) >= 3

# Log a warning once max_label exceeds 100 M: an int64 LUT of that length
# occupies more than 800 MB.
_LUT_WARN_THRESHOLD = 100_000_000

# ---------------------------------------------------------------------------
# Per-worker state, assigned by _init_worker in every pool process (and in the
# parent process when the relabel step runs with a single worker).
#
# The LUT is memory-mapped from a .npy file instead of being passed through
# ``initargs``: pickling it would give every worker its own deserialized copy
# (e.g. 4 workers x 800 MB = 3.2 GB), whereas a read-only mmap lets all
# processes share the OS page cache.
# ---------------------------------------------------------------------------
_merge_lut: "np.ndarray | None" = None  # relabeling LUT (read-only mmap)
_merge_lut_path: "str | None" = None  # .npy file backing _merge_lut
_merge_staged_path: "str | None" = None  # zarr store with the per-tile labels
_merge_staged_comp: "str | None" = None  # array name inside the staged store
_merge_out_path: "str | None" = None  # zarr store receiving merged labels
_merge_out_comp: "str | None" = None  # array name inside the output store
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def _init_worker(lut_path, staged_path, staged_comp, out_path, out_comp):
    """Pool initializer: publish the merge inputs as per-process globals.

    Runs once in every pool worker. It also runs directly in the parent when
    the relabel step uses a single worker. After it runs,
    ``_relabel_chunk_worker`` only needs a chunk slice per task.

    Parameters
    ----------
    lut_path:
        Path of the ``.npy`` relabeling LUT. It is opened with
        ``mmap_mode="r"`` so workers share it through the OS page cache.
    staged_path, staged_comp:
        Zarr store and array name of the per-tile (input) labels.
    out_path, out_comp:
        Zarr store and array name that merged labels are written to.
    """
    global _merge_lut, _merge_lut_path
    global _merge_staged_path, _merge_staged_comp
    global _merge_out_path, _merge_out_comp

    # Load first: a bad LUT path fails before any other global is touched.
    _merge_lut = np.load(lut_path, mmap_mode="r")
    _merge_lut_path = lut_path
    _merge_staged_path, _merge_staged_comp = staged_path, staged_comp
    _merge_out_path, _merge_out_comp = out_path, out_comp
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def _relabel_chunk_worker(chunk_slice: tuple) -> None:
    """Relabel one chunk of the staged array through the LUT and write it out.

    Reads the module-level ``_merge_*`` globals, so ``_init_worker`` must have
    run in this process first.

    Parameters
    ----------
    chunk_slice:
        Tuple of slices selecting one zarr chunk of the staged/output arrays.

    Labels ``>= len(lut)`` (not expected when the LUT was built from the global
    maximum, but handled defensively) keep their own id.
    """
    src = zarr.open_group(_merge_staged_path, mode="r")[_merge_staged_comp]
    dst = zarr.open_group(_merge_out_path, mode="r+")[_merge_out_comp]
    block = np.asarray(src[chunk_slice], dtype=np.int64)
    max_b = int(block.max())
    if max_b == 0:
        # All background: nothing to look up.
        dst[chunk_slice] = block.astype(np.int32)
        return

    lut = _merge_lut
    n = len(lut)
    if max_b < n:
        out = lut[block]
    else:
        # Out-of-range ids map to themselves. Clip the gather index instead of
        # concatenating an extension onto the LUT: np.concatenate would copy the
        # whole (possibly multi-GB, memory-mapped) LUT into this process.
        in_range = block < n
        out = np.where(in_range, lut[np.where(in_range, block, 0)], block)
    dst[chunk_slice] = out.astype(np.int32)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def _boundary_face_specs(
    shape: tuple[int, ...], chunk_shape: tuple[int, ...]
) -> list[tuple[int, int]]:
    """List every interior chunk boundary as an ``(axis, position)`` pair.

    *position* is the first index of the chunk on the high side of the cut.
    The boundary therefore lies between ``position - 1`` and ``position`` along
    *axis*. The outer edges of the array are never listed. Results are ordered
    by axis, then by ascending position.

    >>> _boundary_face_specs((10, 5), (4, 5))
    [(0, 4), (0, 8)]
    """
    return [
        (axis, cut)
        for axis, (extent, step) in enumerate(zip(shape, chunk_shape))
        for cut in range(step, extent, step)
    ]
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def _scan_touching_pairs(
    zarr_path: str, component: str, chunk_shape: tuple[int, ...]
) -> np.ndarray:
    """Scan chunk-boundary slabs; return (N, 2) int64 array of touching pairs.

    At every interior boundary reported by ``_boundary_face_specs``, the two
    voxel layers on either side of the cut are compared element-wise. Each
    pair of distinct, non-zero labels facing each other becomes a row
    ``(smaller, larger)``. The result is sorted and free of duplicates.

    The face is read one chunk-column at a time, so each read covers only a
    two-voxel-thick slab of a single chunk's face. Reading the full face at
    once (slice(None) on the face axes) would allocate face_area x 8 bytes in
    one shot, because values are cast to int64. For example,
    37888 x 27392 x 8 = 8 GiB for a single z-face (OOM on real datasets).
    """
    arr = zarr.open_group(zarr_path, mode="r")[component]
    shape = arr.shape
    ndim = arr.ndim
    found: list[np.ndarray] = []

    for axis, cut in _boundary_face_specs(shape, chunk_shape):
        other_axes = [a for a in range(ndim) if a != axis]
        starts_per_axis = [range(0, shape[a], chunk_shape[a]) for a in other_axes]
        for starts in _iproduct(*starts_per_axis):
            index: list = [slice(None)] * ndim
            index[axis] = slice(cut - 1, cut + 1)  # last layer | first layer
            for a, start in zip(other_axes, starts):
                index[a] = slice(start, min(start + chunk_shape[a], shape[a]))

            # Bring the boundary axis to the front: slab[0] is the low side,
            # slab[1] the high side of the cut.
            slab = np.moveaxis(np.asarray(arr[tuple(index)]), axis, 0)
            low = slab[0].ravel().astype(np.int64)
            high = slab[1].ravel().astype(np.int64)
            touching = (low > 0) & (high > 0) & (low != high)
            if not touching.any():
                continue
            edges = np.sort(np.column_stack((low[touching], high[touching])), axis=1)
            found.append(np.unique(edges, axis=0))

    if found:
        return np.unique(np.vstack(found), axis=0)
    return np.empty((0, 2), dtype=np.int64)
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def _build_relabel_lut(pairs: np.ndarray, max_label: int) -> np.ndarray:
    """Touching-pairs → scipy connected components → relabeling LUT.

    Parameters
    ----------
    pairs:
        ``(N, 2)`` integer array of label ids that touch across a chunk
        boundary, e.g. the output of ``_scan_touching_pairs``.
    max_label:
        Largest label id in the array. The LUT has ``max_label + 1`` entries.

    Returns
    -------
    np.ndarray
        int64 LUT where ``lut[i]`` is the smallest label id in the connected
        component containing ``i``. Labels that appear in no pair map to
        themselves. Pairs referencing ids above *max_label* are ignored.

    Only the ids that occur in *pairs* become graph nodes (compressed via
    ``np.unique``). Building the graph over all ``max_label + 1`` ids would
    make ``connected_components`` allocate per-node work arrays even for
    labels that never touch a boundary.
    """
    if max_label > _LUT_WARN_THRESHOLD:
        logger.warning(
            "_build_relabel_lut: max_label=%d → LUT ~%.0f MB. "
            "Memory use is bounded but large LUTs slow the merge.",
            max_label, max_label * 8 / 1024**2,
        )
    lut = np.arange(max_label + 1, dtype=np.int64)
    if len(pairs) == 0 or max_label == 0:
        return lut
    from scipy.sparse import csr_matrix
    from scipy.sparse.csgraph import connected_components

    n = max_label + 1
    pairs = np.asarray(pairs, dtype=np.int64)
    valid = (pairs[:, 0] < n) & (pairs[:, 1] < n)
    pairs = pairs[valid]
    if len(pairs) == 0:
        return lut

    # Compress the touching ids to 0..m-1. ``nodes`` is sorted ascending, so
    # the smallest compressed index in a component is also its smallest label.
    nodes, inverse = np.unique(pairs.ravel(), return_inverse=True)
    edges = inverse.reshape(pairs.shape)
    m = nodes.size
    rows = np.concatenate([edges[:, 0], edges[:, 1]])
    cols = np.concatenate([edges[:, 1], edges[:, 0]])
    graph = csr_matrix(
        (np.ones(rows.size, dtype=np.float32), (rows, cols)), shape=(m, m)
    )
    n_cc, cc_of = connected_components(graph, directed=False)
    cc_min = np.full(n_cc, m, dtype=np.int64)
    np.minimum.at(cc_min, cc_of, np.arange(m, dtype=np.int64))
    lut[nodes] = nodes[cc_min[cc_of]]
    return lut
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def _create_zarr_label_array(
    group: zarr.Group, name: str, shape: tuple, chunks: tuple
) -> zarr.Array:
    """(Re)create an empty int32 array *name* in *group* and return it.

    Any existing member called *name* is deleted first. On zarr 3 the array
    comes from ``Group.create_array``; on zarr 2 it comes from
    ``Group.zeros(..., overwrite=True)``.
    """
    if name in group:
        del group[name]
    if not _ZARR_V3:
        return group.zeros(name, shape=shape, chunks=chunks, dtype=np.int32, overwrite=True)
    return group.create_array(name, shape=shape, chunks=chunks, dtype=np.int32)
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
def zarr_native_merge(
    staged_path: str,
    staged_component: str,
    out_path: str,
    out_component: str,
    n_workers: int = 4,
    show_progress: bool = False,
) -> None:
    """Zarr-native label merge: boundary scan → scipy CC → parallel relabel.

    Scales to 2000+ chunks where the dask_image approach stalls (O(n_chunks²)
    graph). Reads *staged_path/staged_component*, merges touching cross-boundary
    labels, writes result to *out_path/out_component*. No dask task graph.

    Parameters
    ----------
    staged_path, staged_component:
        Zarr store and array name holding the per-tile labels. They are only
        read.
    out_path, out_component:
        Zarr store and array name for the merged labels. Any existing array
        of that name is deleted and recreated as int32 with the input's shape
        and chunks.
    n_workers:
        Number of relabel processes, clamped to ``[1, n_chunks]``. A value of
        1 runs in-process.
    show_progress:
        Show a tqdm bar for the relabel step, if tqdm is installed.

    Raises
    ------
    ValueError
        If the output array is the input array. The output is recreated before
        the relabel pass reads the input, so this would destroy the labels.
        Also raised if the largest label does not fit the int32 output.
    """
    import shutil

    # Guard against in-place use: _create_zarr_label_array below deletes the
    # output array before any chunk of the input has been relabeled.
    if (
        os.path.realpath(str(staged_path)) == os.path.realpath(str(out_path))
        and str(staged_component).strip("/") == str(out_component).strip("/")
    ):
        raise ValueError(
            f"zarr_native_merge: output array {out_path!s}/{out_component} is the "
            "input array; in-place merging is not supported — choose a different "
            "out_path or out_component."
        )

    root = zarr.open_group(staged_path, mode="r")
    arr = root[staged_component]
    shape, chunk_shape = arr.shape, arr.chunks

    max_label = int(da.from_zarr(staged_path, component=staged_component).max().compute())
    logger.info(
        "zarr_native_merge: shape=%s chunks=%s max_label=%d", shape, chunk_shape, max_label
    )
    # The output array is int32; larger ids would silently wrap in astype().
    int32_max = int(np.iinfo(np.int32).max)
    if max_label > int32_max:
        raise ValueError(
            f"zarr_native_merge: max_label={max_label} exceeds the int32 output "
            f"range (max {int32_max}); renumber the input labels first."
        )

    n_faces = len(_boundary_face_specs(shape, chunk_shape))
    logger.info("zarr_native_merge: scanning %d boundary faces…", n_faces)
    pairs = _scan_touching_pairs(staged_path, staged_component, chunk_shape)
    logger.info("zarr_native_merge: %d touching pairs → building LUT", len(pairs))

    lut = _build_relabel_lut(pairs, max_label)
    n_remapped = int((lut != np.arange(len(lut), dtype=np.int64)).sum())
    logger.info("zarr_native_merge: %d labels remapped across boundaries", n_remapped)

    out_root = zarr.open_group(out_path, mode="a")
    _create_zarr_label_array(out_root, out_component, shape, chunk_shape)

    n_per_dim = [(s + c - 1) // c for s, c in zip(shape, chunk_shape)]
    chunk_slices = [
        tuple(
            slice(i * c, min((i + 1) * c, s))
            for i, c, s in zip(idx, chunk_shape, shape)
        )
        for idx in _iproduct(*[range(n) for n in n_per_dim])
    ]
    n_chunks = len(chunk_slices)
    n_w = max(1, min(n_workers, n_chunks))
    logger.info("zarr_native_merge: relabeling %d chunks with %d worker(s)…", n_chunks, n_w)

    # Save LUT to a temp .npy file so workers memory-map it (shared OS page cache).
    # Pickling the LUT array directly via multiprocessing initargs would
    # deserialize a full copy per worker — e.g. 4 workers × 800 MB = 3.2 GB.
    lut_dir = tempfile.mkdtemp(prefix="bb_lut_")
    try:
        # np.save sits inside the try so the temp dir is removed even when
        # writing the LUT fails (e.g. disk full).
        lut_path = os.path.join(lut_dir, "lut.npy")
        np.save(lut_path, lut)
        del lut  # parent no longer needs it; workers load via mmap

        if n_w <= 1:
            _init_worker(lut_path, staged_path, staged_component, out_path, out_component)
            it: Any = chunk_slices
            if show_progress and _tqdm is not None:
                it = _tqdm(it, total=n_chunks, desc="relabel chunks")
            for sl in it:
                _relabel_chunk_worker(sl)
        else:
            with _Pool(
                processes=n_w,
                initializer=_init_worker,
                initargs=(lut_path, staged_path, staged_component, out_path, out_component),
            ) as pool:
                it = pool.imap_unordered(_relabel_chunk_worker, chunk_slices)
                if show_progress and _tqdm is not None:
                    it = _tqdm(it, total=n_chunks, desc="relabel chunks")
                for _ in it:
                    pass
    finally:
        shutil.rmtree(lut_dir, ignore_errors=True)
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
# ---------------------------------------------------------------------------
|
|
254
|
+
# Public standalone merge API
|
|
255
|
+
# ---------------------------------------------------------------------------
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
def merge_tile_labels(
    labeled: Union["da.Array", str, Path],
    write_to: Union[str, Path, None] = None,
    *,
    input_component: str = "labels",
    output_component: str = "labels",
    overlap: int = 0,
    sequential_labels: bool = False,
    n_workers: int | None = None,
    stage_dir: Union[str, Path, None] = None,
    keep_stage: bool = False,
    progress: bool = False,
) -> "da.Array":
    """Merge per-tile labels into a globally consistent label array.

    Standalone merge step — use this when you already have per-tile labels
    (from your own segmentation pipeline) and just need the boundary stitching.

    Accepts either:

    - A **dask array** of per-tile integer labels (e.g. output of
      ``dask.array.map_blocks`` on your own segmentation function).
    - A **zarr store path** whose ``input_component`` array already contains
      per-tile labels written by your own pipeline.

    Labels that **touch** across tile boundaries are merged into a single ID.
    The merge is zarr-native (boundary scan → scipy connected components →
    parallel relabel) — no dask task graph, scales to thousands of tiles.

    .. note::
       Label ids are treated as global. Equal values in different tiles are
       considered the same object, and this function does not offset them.
       If your segmentation restarts at 1 in every tile, make the ids unique
       across tiles first (e.g. add a running per-tile offset to non-zero
       voxels).

    Parameters
    ----------
    labeled:
        Per-tile label array. Either a dask array or a path to a zarr store
        that contains per-tile labels in ``input_component``.
    write_to:
        Output zarr store path. When None, an auto-temp store is used.
    input_component:
        Array name inside a zarr *input* store (ignored for dask arrays).
    output_component:
        Array name inside ``write_to``. Default ``"labels"``.
    overlap:
        If ``labeled`` is a dask array that was computed with ``da.overlap``,
        pass the same depth here to trim the halos before merging.
        Set 0 (default) if the array has no overlap halos. Ignored, with a
        logged warning, when ``labeled`` is a zarr path.
    sequential_labels:
        Renumber the merged labels to a contiguous ``1..N`` range via a cheap
        linear post-pass (O(voxels)). Default False.
    n_workers:
        Parallel workers for the relabel step. Default ``min(4, cpu_count)``.
    stage_dir:
        Directory for the temp stage zarr when *labeled* is a dask array.
        Default: a fresh system temp directory, removed together with the
        stage.
    keep_stage:
        Keep the temp stage zarr after merging. Default False. Unless this is
        set, the stage is removed even when the merge fails.
    progress:
        Show a progress bar during the relabel step.

    Returns
    -------
    da.Array
        Merged label array (int32) backed by ``write_to``.

    Examples
    --------
    **From a dask array of per-tile labels:**

    >>> import dask.array as da
    >>> from patchworks import merge_tile_labels
    >>>
    >>> # your own tiling + segmentation
    >>> image = da.from_zarr("image.zarr").rechunk((1, 1024, 1024))
    >>> labeled = image.map_blocks(my_segment_fn, dtype="int32",
    ...                            meta=np.empty((0,) * image.ndim, dtype="int32"))
    >>>
    >>> merged = merge_tile_labels(labeled, write_to="labels.zarr", progress=True)

    **From a pre-staged zarr store (your pipeline already wrote labels):**

    >>> merged = merge_tile_labels(
    ...     "my_staged_labels.zarr",
    ...     input_component="raw_labels",
    ...     write_to="merged_labels.zarr",
    ...     sequential_labels=True,
    ... )

    **Trim overlap halos before merging:**

    >>> # if labeled was computed with da.overlap.overlap(depth=20)
    >>> merged = merge_tile_labels(labeled, write_to="labels.zarr", overlap=20)
    """
    import shutil

    import dask.array as da
    from ._relabel import relabel_sequential_zarr

    nw = n_workers if n_workers is not None else min(4, os.cpu_count() or 1)

    # -- Resolve the input: a zarr path is used as-is; a dask array is staged --
    is_path_input = isinstance(labeled, (str, Path))
    stage_path: str | None = None
    staged_component = "staged"
    created_stage_base: str | None = None  # temp dir this call made, if any

    if is_path_input:
        stage_path = str(labeled)
        staged_component = input_component
        if overlap > 0:
            logger.warning(
                "merge_tile_labels: overlap=%d is ignored for zarr-path input; "
                "trim the halos before writing the store.",
                overlap,
            )
    else:
        if overlap > 0:
            labeled = da.overlap.trim_overlap(labeled, depth=overlap, boundary="none")
        if stage_dir is not None:
            stage_base = str(stage_dir)
        else:
            stage_base = created_stage_base = tempfile.mkdtemp(prefix="bb_stage_")
        stage_path = os.path.join(stage_base, "_bb_stage.zarr")

    try:
        # -- Stage dask array to zarr if needed --
        if not is_path_input:
            import dask
            from dask.diagnostics import ProgressBar

            ctx = ProgressBar() if progress else _nullcontext()
            logger.info("Staging per-tile labels to %s …", stage_path)
            with ctx:
                dask.compute(
                    labeled.to_zarr(stage_path, component=staged_component, overwrite=True, compute=False)
                )

        # -- Resolve output path --
        if write_to is not None:
            effective_out = str(write_to)
        else:
            effective_out = os.path.join(
                tempfile.mkdtemp(prefix="bb_merge_"), "merged.zarr"
            )
            logger.info("write_to not set — merged labels in auto-temp %s", effective_out)

        # -- Merge --
        zarr_native_merge(
            stage_path, staged_component,
            effective_out, output_component,
            n_workers=nw,
            show_progress=progress,
        )

        if sequential_labels:
            logger.info("Relabelling to contiguous ids…")
            relabel_sequential_zarr(effective_out, output_component)
    finally:
        # -- Cleanup temp stage (only when we created it), even on failure --
        if not is_path_input and not keep_stage:
            shutil.rmtree(stage_path, ignore_errors=True)
            logger.info("Removed stage store %s", stage_path)
            if created_stage_base is not None:
                # The auto-created parent dir would otherwise be left behind empty.
                shutil.rmtree(created_stage_base, ignore_errors=True)

    return da.from_zarr(effective_out, component=output_component)
|
patchworks/_relabel.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
"""Linear sequential relabelling (O(voxels), not O(n_chunks²))."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
import logging
|
|
5
|
+
from itertools import product as _iproduct
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
import zarr
|
|
9
|
+
|
|
10
|
+
logger = logging.getLogger(__name__)

# LUT sizes are max_label + 1 int64 entries. Above 100 M labels that is more
# than 800 MB, so the relabel helpers log a warning.
_LUT_WARN_THRESHOLD = 100 * 1_000_000
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def relabel_sequential_array(labels: np.ndarray) -> np.ndarray:
    """Remap *labels* to a contiguous ``0, 1, … N`` range.

    Background (0) stays 0. The N distinct non-zero labels become ``1..N``,
    in ascending order of their original value. This holds even when the
    input contains no background voxel. Runs in one ``np.unique`` + a
    lookup-table gather, i.e. O(voxels) — unlike dask's ``relabel_sequential``
    which is O(n_chunks²).

    The result dtype is ``uint16`` when ``N < 65535``, else ``uint32``. An empty
    input returns an empty ``uint16`` array of the same shape.

    Raises
    ------
    ValueError
        If *labels* contains negative values, which cannot index the LUT.

    Examples
    --------
    >>> relabel_sequential_array(np.array([0, 500000, 500000, 7]))
    array([0, 2, 2, 1], dtype=uint16)
    >>> relabel_sequential_array(np.array([3, 5, 5]))  # no background present
    array([1, 2, 2], dtype=uint16)
    """
    labels = np.asarray(labels)
    if labels.size == 0:
        return labels.astype(np.uint16)
    uniq = np.unique(labels)
    if uniq[0] < 0:
        raise ValueError(
            f"relabel_sequential_array: labels must be non-negative (found {uniq[0]})"
        )
    if uniq[0] != 0:
        # No background voxel present. Reserve id 0 for it anyway; otherwise
        # the smallest label would map to 0 and silently become background.
        uniq = np.concatenate((np.zeros(1, dtype=uniq.dtype), uniq))
    max_label = int(uniq[-1])
    if max_label > _LUT_WARN_THRESHOLD:
        logger.warning(
            "relabel_sequential_array: max_label=%d → LUT size ~%.0f MB. "
            "Consider using write_to= so labels never need to be in RAM.",
            max_label, max_label * 8 / 1024**2,
        )
    lut = np.zeros(max_label + 1, dtype=np.int64)
    lut[uniq] = np.arange(uniq.size)
    n = uniq.size - 1  # uniq always starts with 0 here
    dtype = np.uint16 if n < np.iinfo(np.uint16).max else np.uint32
    return lut[labels].astype(dtype)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def relabel_sequential_zarr(store_path: str, component: str = "labels") -> int:
    """Relabel a written label zarr to contiguous ids, in place. Returns N.

    Background (0) stays 0. The N distinct non-zero labels become ``1..N``,
    in ascending order of their original value. This holds even when the
    array contains no background voxel.

    Two-pass streaming algorithm — safe for arrays far larger than RAM.
    Pass 1 marks every id seen in a presence bitmap (one bool per possible id,
    1/8 the size of the int64 LUT built afterwards). Pass 2 applies the
    lookup-table remap chunk by chunk. A Python ``set`` of ids would instead
    cost tens of bytes per label.

    Raises
    ------
    ValueError
        If the array contains negative labels. This is detected in pass 1,
        before anything is written.
    """
    root = zarr.open_group(store_path, mode="r+")
    z = root[component]
    z_shape, z_chunks = z.shape, z.chunks

    # Iterate over actual zarr chunks in ALL dimensions. The z-slab approach
    # (step = z_chunks[0], slice z[i0:i0+step]) reads the full y/x extent per
    # step — for chunks like (120, 731, 731) that means (120, 37888, 27392)
    # = 464 GiB in one allocation (MemoryError).
    n_per_dim = [(s + c - 1) // c for s, c in zip(z_shape, z_chunks)]
    chunk_slices = [
        tuple(slice(i * c, min((i + 1) * c, s)) for i, c, s in zip(idx, z_chunks, z_shape))
        for idx in _iproduct(*[range(n) for n in n_per_dim])
    ]

    # Pass 1: presence bitmap of every id, grown geometrically as needed.
    present = np.zeros(1, dtype=bool)
    for sl in chunk_slices:
        ids = np.unique(np.asarray(z[sl]))
        if ids.size == 0:
            continue
        if ids[0] < 0:
            raise ValueError(
                f"relabel_sequential_zarr: labels must be non-negative (found {ids[0]})"
            )
        top = int(ids[-1])
        if top >= present.size:
            grown = np.zeros(max(top + 1, 2 * present.size), dtype=bool)
            grown[: present.size] = present
            present = grown
        present[ids] = True
    # Always reserve 0 for background so the smallest label never maps to 0.
    present[0] = True
    sorted_ids = np.flatnonzero(present).astype(np.int64)

    max_label = int(sorted_ids[-1])
    if max_label > _LUT_WARN_THRESHOLD:
        logger.warning(
            "relabel_sequential_zarr: max_label=%d → LUT size ~%.0f MB.",
            max_label, max_label * 8 / 1024**2,
        )
    lut = np.zeros(max_label + 1, dtype=np.int64)
    lut[sorted_ids] = np.arange(sorted_ids.size)
    n = sorted_ids.size - 1  # number of non-background labels
    # Use same dtype logic as relabel_sequential_array so output never overflows.
    out_dtype = np.uint16 if n < np.iinfo(np.uint16).max else np.uint32

    # Pass 2: apply the LUT chunk by chunk.
    for sl in chunk_slices:
        block = np.asarray(z[sl])
        z[sl] = lut[block].astype(out_dtype)
    logger.info("relabel_sequential_zarr: %d objects renumbered to 1..%d", n, n)
    return int(n)
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# patchworks plugins
|
|
@@ -0,0 +1,188 @@
|
|
|
1
|
+
"""Cellpose plugin for patchworks.
|
|
2
|
+
|
|
3
|
+
Requires cellpose >= 3.0 (compatible with v3 and v4).
|
|
4
|
+
|
|
5
|
+
Usage
|
|
6
|
+
-----
|
|
7
|
+
>>> from patchworks.plugins.cellpose import cellpose_fn
|
|
8
|
+
>>> from patchworks import tile_process
|
|
9
|
+
>>>
|
|
10
|
+
>>> fn = cellpose_fn("cyto3", gpu=True, diameter=30)
|
|
11
|
+
>>> result = tile_process("image.zarr", fn, tile_shape=(1, 2048, 2048),
|
|
12
|
+
... overlap=20, write_to="labels.zarr", progress=True)
|
|
13
|
+
"""
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import importlib.metadata
|
|
17
|
+
import logging
|
|
18
|
+
from functools import partial
|
|
19
|
+
from typing import Any, Callable
|
|
20
|
+
|
|
21
|
+
import numpy as np
|
|
22
|
+
|
|
23
|
+
logger = logging.getLogger(__name__)


def _leading_int(text: str) -> int:
    """Return the integer spelled by the leading digits of *text* (0 if none).

    ``"12"`` → 12, ``"0rc1"`` → 0, ``"dev0"`` → 0.
    """
    n_digits = len(text) - len(text.lstrip("0123456789"))
    return int(text[:n_digits]) if n_digits else 0


def _parse_version(version: str) -> tuple[int, ...]:
    """Parse up to the first two dotted fields of *version* into integers.

    Pre-release or local suffixes are tolerated (``"4.0rc1"`` → ``(4, 0)``).
    Calling bare ``int()`` on each field would raise ``ValueError`` for them,
    which the ``except ImportError`` below does not catch, so importing this
    module would crash.
    """
    return tuple(_leading_int(part) for part in version.split(".")[:2])


try:
    from cellpose import models as _cellpose_models
    _CELLPOSE_VERSION: tuple[int, ...] = _parse_version(
        importlib.metadata.version("cellpose")
    )
    _CELLPOSE_V4 = _CELLPOSE_VERSION[0] >= 4
except ImportError:
    # Also reached when the cellpose distribution metadata is missing
    # (importlib.metadata.PackageNotFoundError subclasses ImportError).
    _cellpose_models = None  # type: ignore[assignment]
    _CELLPOSE_VERSION = (0, 0)
    _CELLPOSE_V4 = False

# Per-process model cache keyed by (model_type, gpu)
_model_cache: dict[tuple, Any] = {}
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _require_cellpose():
    """Raise ``ImportError`` with install instructions if cellpose is missing.

    Returns ``None`` when ``cellpose.models`` was importable at module load.
    """
    if _cellpose_models is not None:
        return
    raise ImportError(
        "cellpose is not installed. Install it with:\n"
        "  pip install cellpose\n"
        "or:\n"
        "  pip install patchworks[cellpose]"
    )
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def cellpose_fn(
    model: str = "cyto3",
    *,
    gpu: bool = False,
    diameter: float | None = None,
    do_3D: bool = False,
    channels: list[int] | None = None,
    channel_axis: int | None = None,
    **cellpose_kwargs: Any,
) -> Callable[[np.ndarray], np.ndarray]:
    """Build a picklable per-tile Cellpose function for ``tile_process``.

    Convenience wrapper that configures the model and returns the
    segmentation callable in one step. Raises ``ImportError`` right here,
    when the function is built, if cellpose is not installed.

    Parameters
    ----------
    model:
        Cellpose model type: ``"cyto3"``, ``"cyto2"``, ``"nuclei"``, etc.
    gpu:
        Use GPU for inference.
    diameter:
        Expected cell diameter in pixels. ``None`` → Cellpose auto-estimates.
    do_3D:
        Run in 3-D mode. Each tile must contain the full z-stack — use
        ``auto_tile_shape_cellpose(do_3D=True)`` for appropriate tile shapes.
    channels:
        *Cellpose 3 only.* ``[cytoplasm_channel, nucleus_channel]`` (1-based,
        0 = greyscale). ``[0, 0]`` → greyscale. ``[1, 2]`` → cyto=ch1, nuc=ch2.
    channel_axis:
        *Cellpose 4 only.* Index of the channel axis in the input array.
        ``None`` → greyscale input.
    **cellpose_kwargs:
        Extra kwargs forwarded to ``model.eval()``
        (e.g. ``flow_threshold``, ``cellprob_threshold``, ``anisotropy``).

    Returns
    -------
    Callable[[ndarray], ndarray]
        ``functools.partial`` of the per-tile runner with the config bound.

    Examples
    --------
    Greyscale 2-D:

    >>> fn = cellpose_fn("cyto3", gpu=True, diameter=30)
    >>> result = tile_process("image.zarr", fn, tile_shape=(1, 2048, 2048), overlap=20)

    Nuclear segmentation:

    >>> fn = cellpose_fn("nuclei", diameter=15)
    >>> result = tile_process("image.zarr", fn, channel=1)

    3-D with anisotropy:

    >>> fn = cellpose_fn("cyto3", gpu=True, do_3D=True, anisotropy=3.0, diameter=20)
    >>> from functools import partial
    >>> from patchworks import auto_tile_shape_cellpose, tile_process
    >>> tile_fn = partial(auto_tile_shape_cellpose, do_3D=True, use_gpu=True, diameter=20)
    >>> result = tile_process("image.zarr", fn, tile_shape=tile_fn, overlap=10)
    """
    _require_cellpose()
    config = _make_config(
        model=model,
        gpu=gpu,
        channels=channels,
        channel_axis=channel_axis,
        diameter=diameter,
        do_3D=do_3D,
        **cellpose_kwargs,
    )
    return partial(_run, cellpose_dict=config)
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def _make_config(
    model: str = "cyto3",
    gpu: bool = False,
    channels: list[int] | None = None,
    channel_axis: int | None = None,
    diameter: float | None = None,
    do_3D: bool = False,
    **cellpose_kwargs: Any,
) -> dict[str, Any]:
    """Bundle Cellpose settings into the plain dict consumed by ``_run``.

    ``channels`` defaults to a fresh ``[0, 0]`` (greyscale) list per call.
    Extra keyword arguments are stored under ``"cellpose_kwargs"`` and later
    forwarded to ``model.eval()``.
    """
    if channels is None:
        channels = [0, 0]
    config: dict[str, Any] = dict(
        model=model,
        gpu=gpu,
        channels=channels,
        channel_axis=channel_axis,
        diameter=diameter,
        do_3D=do_3D,
        cellpose_kwargs=cellpose_kwargs,
    )
    return config
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def _get_model(cellpose_dict: dict[str, Any]) -> Any:
    """Return a worker-local cached Cellpose model.

    Models are cached in ``_model_cache`` under ``(model, gpu)``, so each
    process builds a given model at most once. Cellpose 4 uses
    ``models.CellposeModel``; earlier versions use ``models.Cellpose``.
    """
    _require_cellpose()
    model_type = cellpose_dict["model"]
    gpu = cellpose_dict.get("gpu", False)
    key = (model_type, gpu)
    if key in _model_cache:
        return _model_cache[key]

    factory = _cellpose_models.CellposeModel if _CELLPOSE_V4 else _cellpose_models.Cellpose
    model = factory(model_type=model_type, gpu=gpu)
    _model_cache[key] = model
    return model
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def _run(block: np.ndarray, cellpose_dict: dict[str, Any]) -> np.ndarray:
    """Segment one tile with a cached Cellpose model and return int32 masks.

    Cellpose 4 receives ``channel_axis``; older versions receive ``channels``.
    Both also get ``diameter``, ``do_3D`` and the extra ``cellpose_kwargs``.
    In 3-D mode the whole tile is passed with ``z_axis=0``. In 2-D mode a
    ``(1, Y, X)`` tile is squeezed to ``(Y, X)`` and the singleton axis is
    restored on the mask.
    """
    model = _get_model(cellpose_dict)
    do_3D = cellpose_dict["do_3D"]

    if not _CELLPOSE_V4:
        kwargs: dict[str, Any] = dict(
            channels=cellpose_dict["channels"],
            diameter=cellpose_dict["diameter"],
            do_3D=do_3D,
            **cellpose_dict.get("cellpose_kwargs", {}),
        )
    else:
        kwargs = dict(
            channel_axis=cellpose_dict.get("channel_axis"),
            diameter=cellpose_dict["diameter"],
            do_3D=do_3D,
            **cellpose_dict.get("cellpose_kwargs", {}),
        )

    if not do_3D:
        has_singleton_z = block.ndim == 3 and block.shape[0] == 1
        image = block[0] if has_singleton_z else block
        masks = model.eval(image, **kwargs)[0].astype("int32")
        return masks[np.newaxis] if has_singleton_z else masks

    kwargs["z_axis"] = 0
    return model.eval(block, **kwargs)[0].astype("int32")
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
# Public names for the lower-level helpers (config builder, cached model
# getter, per-tile runner), kept for advanced users.
make_cellpose_config, get_cellpose_model, run_cellpose = _make_config, _get_model, _run
|