patchworks 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
patchworks/_merge.py ADDED
@@ -0,0 +1,405 @@
1
+ """Zarr-native label merge: boundary scan → scipy CC → parallel relabel.
2
+
3
+ Three steps, all zarr-native with no dask task graph:
4
+ 1. Scan thin boundary slabs → touching label pairs (O(n_faces × face_area))
5
+ 2. scipy sparse connected_components on pairs → relabeling LUT
6
+ 3. Apply LUT to each chunk in parallel via multiprocessing.Pool
7
+
8
+ Trade-off: touching-label merge only (overlap_depth=0 semantics for merge).
9
+ IoU-overlap merge is not supported here. Keep overlap > 0 during segmentation
10
+ for boundary-cell context; trim the halo before staging so chunk boundaries
11
+ in the staged zarr are clean for this merge.
12
+
13
+ Public API
14
+ ----------
15
+ ``merge_tile_labels(labeled, write_to, ...)`` — standalone merge for labeled
16
+ dask arrays or pre-staged zarr stores. Use this directly if you already have
17
+ per-tile labels and just need the boundary-stitching step.
18
+ """
19
+ from __future__ import annotations
20
+
21
+ import logging
22
+ import os
23
+ import tempfile
24
+ from contextlib import nullcontext as _nullcontext
25
+ from itertools import product as _iproduct
26
+ from multiprocessing import Pool as _Pool
27
+ from pathlib import Path
28
+ from typing import Any, Union
29
+
30
+ import dask.array as da
31
+ import numpy as np
32
+ import zarr
33
+
34
+ try:
35
+ from tqdm.auto import tqdm as _tqdm
36
+ except ImportError:
37
+ _tqdm = None
38
+
39
+ logger = logging.getLogger(__name__)
40
+
41
+ _ZARR_V3 = int(zarr.__version__.split(".")[0]) >= 3
42
+ _LUT_WARN_THRESHOLD = 100_000_000 # warn when max_label > 100 M (LUT > 800 MB)
43
+
44
+ # Per-worker globals set by _init_worker.
45
+ # LUT is memory-mapped from disk so it is shared read-only across all workers
46
+ # (OS page cache, no per-process copy). Passing the LUT directly via pickle
47
+ # would deserialize N separate copies — e.g. 4 workers × 800 MB = 3.2 GB wasted.
48
+ _merge_lut: "np.ndarray | None" = None
49
+ _merge_lut_path: "str | None" = None
50
+ _merge_staged_path: "str | None" = None
51
+ _merge_staged_comp: "str | None" = None
52
+ _merge_out_path: "str | None" = None
53
+ _merge_out_comp: "str | None" = None
54
+
55
+
56
def _init_worker(lut_path, staged_path, staged_comp, out_path, out_comp):
    """Populate the per-process globals read by ``_relabel_chunk_worker``.

    Used as the ``multiprocessing.Pool`` initializer (and called directly in
    the single-worker path).  The LUT is opened with ``np.load(...,
    mmap_mode="r")`` rather than being passed as an array, so it is not
    pickled into each worker.

    Parameters
    ----------
    lut_path:
        Path of the ``.npy`` file holding the relabeling LUT.
    staged_path, staged_comp:
        Store path and array name of the input labels.
    out_path, out_comp:
        Store path and array name of the output labels.
    """
    global _merge_lut, _merge_lut_path
    global _merge_staged_path, _merge_staged_comp
    global _merge_out_path, _merge_out_comp

    # Map the LUT first: if this fails, none of the globals are touched.
    mapped_lut = np.load(lut_path, mmap_mode="r")

    _merge_lut, _merge_lut_path = mapped_lut, lut_path
    _merge_staged_path, _merge_staged_comp = staged_path, staged_comp
    _merge_out_path, _merge_out_comp = out_path, out_comp
65
+
66
+
67
def _relabel_chunk_worker(chunk_slice: tuple) -> None:
    """Relabel one chunk: read it from the staged array, map it, write it out.

    Reads the store locations and the LUT from the module globals set by
    ``_init_worker``.  The result is always written as ``int32``.

    Parameters
    ----------
    chunk_slice:
        Tuple of slices selecting the region to process; the same region is
        read from the input array and written to the output array.
    """
    source = zarr.open_group(_merge_staged_path, mode="r")[_merge_staged_comp]
    target = zarr.open_group(_merge_out_path, mode="r+")[_merge_out_comp]

    labels = np.asarray(source[chunk_slice], dtype=np.int64)
    highest = int(labels.max())

    # Nothing above background in this chunk: copy through without a lookup.
    if highest == 0:
        target[chunk_slice] = labels.astype(np.int32)
        return

    table = _merge_lut
    if highest >= len(table):
        # Labels beyond the end of the LUT map to themselves.
        identity_tail = np.arange(len(table), highest + 1, dtype=np.int64)
        table = np.concatenate([table, identity_tail])

    target[chunk_slice] = table[labels].astype(np.int32)
82
+
83
+
84
def _boundary_face_specs(
    shape: tuple[int, ...], chunk_shape: tuple[int, ...]
) -> list[tuple[int, int]]:
    """List every interior chunk boundary as an ``(axis, position)`` pair.

    For each axis, a boundary sits at every positive multiple of the chunk
    size that is strictly less than the axis length, i.e. at the first index
    of every chunk except the first one.

    Parameters
    ----------
    shape:
        Array shape.
    chunk_shape:
        Chunk size per axis; paired with *shape* by position.

    Returns
    -------
    list of (int, int)
        ``(axis, position)`` pairs, ordered by axis and then by position.

    Raises
    ------
    ValueError
        If any chunk size is not positive.  (A chunk size of 0 or less would
        otherwise never advance past the axis length and loop forever.)
    """
    specs: list[tuple[int, int]] = []
    for ax, (size, step) in enumerate(zip(shape, chunk_shape)):
        if step <= 0:
            raise ValueError(
                f"chunk size must be positive, got {step} for axis {ax}"
            )
        specs.extend((ax, pos) for pos in range(step, size, step))
    return specs
94
+
95
+
96
def _scan_touching_pairs(
    zarr_path: str, component: str, chunk_shape: tuple[int, ...]
) -> np.ndarray:
    """Collect pairs of distinct labels that face each other across chunk boundaries.

    For every interior boundary reported by ``_boundary_face_specs`` a
    two-voxel-thick slab (the last index before the boundary and the first
    index after it) is read.  The slab is read in tiles of at most
    ``chunk_shape`` along the other axes instead of as one full face, which
    keeps each read small.

    Parameters
    ----------
    zarr_path:
        Path of the zarr group to open read-only.
    component:
        Name of the label array inside that group.
    chunk_shape:
        Chunk size per axis; defines both the boundaries and the tile size.

    Returns
    -------
    np.ndarray
        ``(N, 2)`` int64 array of unique pairs, each row sorted ascending.
        Pairs involving a label ``<= 0`` or two equal labels are excluded.
    """
    labels = zarr.open_group(zarr_path, mode="r")[component]
    dims = labels.shape
    ndim = labels.ndim

    found: list[np.ndarray] = []
    for axis, cut in _boundary_face_specs(dims, chunk_shape):
        # Tile the face along every axis except the one being cut.
        other_axes = [d for d in range(ndim) if d != axis]
        tile_starts = [range(0, dims[d], chunk_shape[d]) for d in other_axes]
        for starts in _iproduct(*tile_starts):
            index: list = [slice(None)] * ndim
            index[axis] = slice(cut - 1, cut + 1)
            for d, start in zip(other_axes, starts):
                index[d] = slice(start, min(start + chunk_shape[d], dims[d]))

            # Bring the cut axis to the front: slab[0] / slab[1] are the two sides.
            slab = np.moveaxis(np.asarray(labels[tuple(index)]), axis, 0)
            before = slab[0].ravel().astype(np.int64)
            after = slab[1].ravel().astype(np.int64)

            touching = (before > 0) & (after > 0) & (before != after)
            if not touching.any():
                continue
            edges = np.stack([before[touching], after[touching]], axis=1)
            edges.sort(axis=1)
            found.append(np.unique(edges, axis=0))

    if found:
        return np.unique(np.vstack(found), axis=0)
    return np.empty((0, 2), dtype=np.int64)
132
+
133
+
134
def _build_relabel_lut(pairs: np.ndarray, max_label: int) -> np.ndarray:
    """Build a lookup table that maps each label to its merged id.

    The pairs are treated as undirected edges of a graph over the labels
    ``0..max_label``; every label is mapped to the smallest label in its
    connected component.  Labels that appear in no pair map to themselves.

    Parameters
    ----------
    pairs:
        ``(N, 2)`` array of label pairs to merge.  Rows that reference a label
        greater than *max_label* are ignored.
    max_label:
        Largest label the table must cover.

    Returns
    -------
    np.ndarray
        int64 array of length ``max_label + 1``.
    """
    if max_label > _LUT_WARN_THRESHOLD:
        logger.warning(
            "_build_relabel_lut: max_label=%d → LUT ~%.0f MB. "
            "Memory use is bounded but large LUTs slow the merge.",
            max_label, max_label * 8 / 1024**2,
        )

    size = max_label + 1
    identity = np.arange(size, dtype=np.int64)
    if max_label == 0 or len(pairs) == 0:
        return identity

    from scipy.sparse import csr_matrix
    from scipy.sparse.csgraph import connected_components

    in_range = pairs[(pairs[:, 0] < size) & (pairs[:, 1] < size)]
    if len(in_range) == 0:
        return identity

    # Symmetric adjacency: each pair contributes an edge in both directions.
    left, right = in_range[:, 0], in_range[:, 1]
    src = np.concatenate([left, right])
    dst = np.concatenate([right, left])
    weights = np.ones(len(src), dtype=np.float32)
    adjacency = csr_matrix((weights, (src, dst)), shape=(size, size))

    n_groups, group_of = connected_components(adjacency, directed=False)

    # Representative of each component = its smallest member label.
    smallest = np.full(n_groups, size, dtype=np.int64)
    np.minimum.at(smallest, group_of, identity)
    return smallest[group_of]
162
+
163
+
164
def _create_zarr_label_array(
    group: zarr.Group, name: str, shape: tuple, chunks: tuple
) -> zarr.Array:
    """Create (or replace) an ``int32`` array called *name* inside *group*.

    An existing member with the same name is deleted first.  The creation
    call depends on the installed zarr major version (``_ZARR_V3``).

    Parameters
    ----------
    group:
        Open, writable zarr group.
    name:
        Array name inside the group.
    shape, chunks:
        Shape and chunk shape of the new array.

    Returns
    -------
    zarr.Array
        The newly created array.
    """
    if name in group:
        del group[name]

    layout = dict(shape=shape, chunks=chunks, dtype=np.int32)
    if not _ZARR_V3:
        return group.zeros(name, overwrite=True, **layout)
    return group.create_array(name, **layout)
172
+
173
+
174
def zarr_native_merge(
    staged_path: str,
    staged_component: str,
    out_path: str,
    out_component: str,
    n_workers: int = 4,
    show_progress: bool = False,
) -> None:
    """Zarr-native label merge: boundary scan → scipy CC → parallel relabel.

    Reads *staged_path/staged_component*, merges labels that touch across
    chunk boundaries, and writes the result (``int32``) to
    *out_path/out_component*.  No dask task graph is built for the merge
    itself; dask is only used to compute the maximum label.

    Parameters
    ----------
    staged_path, staged_component:
        Store path and array name of the per-tile labels (input).
    out_path, out_component:
        Store path and array name of the merged labels (output).  An existing
        array with that name is replaced.
    n_workers:
        Upper bound on worker processes for the relabel step.  With one
        worker (or a single chunk) the relabel runs in this process.
    show_progress:
        Show a tqdm progress bar for the relabel step when tqdm is installed.

    Raises
    ------
    ValueError
        If input and output refer to the same array.  The output array is
        recreated (emptied) before the relabel step reads the input, so an
        in-place merge would destroy the data it is about to read.
    """
    global _merge_lut

    # Refuse in-place operation before anything is opened or deleted.
    same_store = os.path.realpath(str(staged_path)) == os.path.realpath(str(out_path))
    if same_store and str(staged_component).strip("/") == str(out_component).strip("/"):
        raise ValueError(
            "zarr_native_merge: input and output refer to the same array "
            f"({staged_path}/{staged_component}); the output array is recreated "
            "before relabeling, which would destroy the input. "
            "Write to a different store or component."
        )

    import shutil

    root = zarr.open_group(staged_path, mode="r")
    arr = root[staged_component]
    shape, chunk_shape = arr.shape, arr.chunks

    max_label = int(da.from_zarr(staged_path, component=staged_component).max().compute())
    logger.info(
        "zarr_native_merge: shape=%s chunks=%s max_label=%d", shape, chunk_shape, max_label
    )
    if max_label > np.iinfo(np.int32).max:
        # The output array and the per-chunk cast are int32; larger ids wrap.
        logger.warning(
            "zarr_native_merge: max_label=%d exceeds the int32 output range; "
            "labels above %d will overflow in the output.",
            max_label, np.iinfo(np.int32).max,
        )

    n_faces = len(_boundary_face_specs(shape, chunk_shape))
    logger.info("zarr_native_merge: scanning %d boundary faces…", n_faces)
    pairs = _scan_touching_pairs(staged_path, staged_component, chunk_shape)
    logger.info("zarr_native_merge: %d touching pairs → building LUT", len(pairs))

    lut = _build_relabel_lut(pairs, max_label)
    n_remapped = int((lut != np.arange(len(lut), dtype=np.int64)).sum())
    logger.info("zarr_native_merge: %d labels remapped across boundaries", n_remapped)

    out_root = zarr.open_group(out_path, mode="a")
    _create_zarr_label_array(out_root, out_component, shape, chunk_shape)

    # One slice tuple per chunk, clipped at the array edge.
    n_per_dim = [(s + c - 1) // c for s, c in zip(shape, chunk_shape)]
    chunk_slices = [
        tuple(
            slice(i * c, min((i + 1) * c, s))
            for i, c, s in zip(idx, chunk_shape, shape)
        )
        for idx in _iproduct(*[range(n) for n in n_per_dim])
    ]
    n_chunks = len(chunk_slices)
    n_w = max(1, min(n_workers, n_chunks))
    logger.info("zarr_native_merge: relabeling %d chunks with %d worker(s)…", n_chunks, n_w)

    # Save LUT to a temp .npy file so workers memory-map it instead of each
    # receiving a pickled copy through the Pool initargs.
    _lut_dir = tempfile.mkdtemp(prefix="bb_lut_")
    lut_path = os.path.join(_lut_dir, "lut.npy")
    np.save(lut_path, lut)
    del lut  # parent no longer needs it; workers load via mmap

    try:
        if n_w <= 1:
            _init_worker(lut_path, staged_path, staged_component, out_path, out_component)
            it: Any = chunk_slices
            if show_progress and _tqdm is not None:
                it = _tqdm(it, total=n_chunks, desc="relabel chunks")
            for sl in it:
                _relabel_chunk_worker(sl)
        else:
            with _Pool(
                processes=n_w,
                initializer=_init_worker,
                initargs=(lut_path, staged_path, staged_component, out_path, out_component),
            ) as pool:
                it = pool.imap_unordered(_relabel_chunk_worker, chunk_slices)
                if show_progress and _tqdm is not None:
                    it = _tqdm(it, total=n_chunks, desc="relabel chunks")
                for _ in it:
                    pass
    finally:
        # The serial path maps the LUT in this process; drop that reference
        # before deleting the file so the mapping does not outlive it (an open
        # mapping can also block deletion on some platforms).
        _merge_lut = None
        shutil.rmtree(_lut_dir, ignore_errors=True)
251
+
252
+
253
+ # ---------------------------------------------------------------------------
254
+ # Public standalone merge API
255
+ # ---------------------------------------------------------------------------
256
+
257
+
258
+ def merge_tile_labels(
259
+ labeled: Union["da.Array", str, Path],
260
+ write_to: Union[str, Path, None] = None,
261
+ *,
262
+ input_component: str = "labels",
263
+ output_component: str = "labels",
264
+ overlap: int = 0,
265
+ sequential_labels: bool = False,
266
+ n_workers: int | None = None,
267
+ stage_dir: Union[str, Path, None] = None,
268
+ keep_stage: bool = False,
269
+ progress: bool = False,
270
+ ) -> "da.Array":
271
+ """Merge per-tile labels into a globally consistent label array.
272
+
273
+ Standalone merge step — use this when you already have per-tile labels
274
+ (from your own segmentation pipeline) and just need the boundary stitching.
275
+
276
+ Accepts either:
277
+
278
+ - A **dask array** of per-tile integer labels (e.g. output of
279
+ ``dask.array.map_blocks`` on your own segmentation function).
280
+ - A **zarr store path** whose ``input_component`` array already contains
281
+ per-tile labels written by your own pipeline.
282
+
283
+ Labels that **touch** across tile boundaries are merged into a single ID.
284
+ The merge is zarr-native (boundary scan → scipy connected components →
285
+ parallel relabel) — no dask task graph, scales to thousands of tiles.
286
+
287
+ Parameters
288
+ ----------
289
+ labeled:
290
+ Per-tile label array. Either a dask array or a path to a zarr store
291
+ that contains per-tile labels in ``input_component``.
292
+ write_to:
293
+ Output zarr store path. When None, an auto-temp store is used.
294
+ input_component:
295
+ Array name inside a zarr *input* store (ignored for dask arrays).
296
+ output_component:
297
+ Array name inside ``write_to``. Default ``"labels"``.
298
+ overlap:
299
+ If ``labeled`` is a dask array that was computed with ``da.overlap``,
300
+ pass the same depth here to trim the halos before merging.
301
+ Set 0 (default) if the array has no overlap halos.
302
+ sequential_labels:
303
+ Renumber the merged labels to a contiguous ``1..N`` range via a cheap
304
+ linear post-pass (O(voxels)). Default False.
305
+ n_workers:
306
+ Parallel workers for the relabel step. Default ``min(4, cpu_count)``.
307
+ stage_dir:
308
+ Directory for the temp stage zarr when *labeled* is a dask array.
309
+ Default: a system temp directory.
310
+ keep_stage:
311
+ Keep the temp stage zarr after merging. Default False.
312
+ progress:
313
+ Show a progress bar during the relabel step.
314
+
315
+ Returns
316
+ -------
317
+ da.Array
318
+ Merged label array (int32) backed by ``write_to``.
319
+
320
+ Examples
321
+ --------
322
+ **From a dask array of per-tile labels:**
323
+
324
+ >>> import dask.array as da
325
+ >>> from patchworks import merge_tile_labels
326
+ >>>
327
+ >>> # your own tiling + segmentation
328
+ >>> image = da.from_zarr("image.zarr").rechunk((1, 1024, 1024))
329
+ >>> labeled = image.map_blocks(my_segment_fn, dtype="int32",
330
+ ... meta=np.empty((0,) * image.ndim, dtype="int32"))
331
+ >>>
332
+ >>> merged = merge_tile_labels(labeled, write_to="labels.zarr", progress=True)
333
+
334
+ **From a pre-staged zarr store (your pipeline already wrote labels):**
335
+
336
+ >>> merged = merge_tile_labels(
337
+ ... "my_staged_labels.zarr",
338
+ ... input_component="raw_labels",
339
+ ... write_to="merged_labels.zarr",
340
+ ... sequential_labels=True,
341
+ ... )
342
+
343
+ **Trim overlap halos before merging:**
344
+
345
+ >>> # if labeled was computed with da.overlap.overlap(depth=20)
346
+ >>> merged = merge_tile_labels(labeled, write_to="labels.zarr", overlap=20)
347
+ """
348
+ import dask.array as da
349
+ from ._relabel import relabel_sequential_zarr
350
+
351
+ nw = n_workers if n_workers is not None else min(4, os.cpu_count() or 1)
352
+
353
+ # -- Stage dask array to zarr if needed --
354
+ stage_path: str | None = None
355
+ staged_component = "staged"
356
+
357
+ if isinstance(labeled, (str, Path)):
358
+ stage_path = str(labeled)
359
+ staged_component = input_component
360
+ else:
361
+ # labeled is a dask array
362
+ if overlap > 0:
363
+ labeled = da.overlap.trim_overlap(labeled, depth=overlap, boundary="none")
364
+
365
+ _base = str(stage_dir) if stage_dir is not None else tempfile.mkdtemp(prefix="bb_stage_")
366
+ stage_path = os.path.join(_base, "_bb_stage.zarr")
367
+
368
+ import dask
369
+ from dask.diagnostics import ProgressBar
370
+
371
+ ctx = ProgressBar() if progress else _nullcontext()
372
+ logger.info("Staging per-tile labels to %s …", stage_path)
373
+ with ctx:
374
+ dask.compute(
375
+ labeled.to_zarr(stage_path, component=staged_component, overwrite=True, compute=False)
376
+ )
377
+
378
+ # -- Resolve output path --
379
+ if write_to is not None:
380
+ effective_out = str(write_to)
381
+ else:
382
+ effective_out = os.path.join(
383
+ tempfile.mkdtemp(prefix="bb_merge_"), "merged.zarr"
384
+ )
385
+ logger.info("write_to not set — merged labels in auto-temp %s", effective_out)
386
+
387
+ # -- Merge --
388
+ zarr_native_merge(
389
+ stage_path, staged_component,
390
+ effective_out, output_component,
391
+ n_workers=nw,
392
+ show_progress=progress,
393
+ )
394
+
395
+ if sequential_labels:
396
+ logger.info("Relabelling to contiguous ids…")
397
+ relabel_sequential_zarr(effective_out, output_component)
398
+
399
+ # -- Cleanup temp stage (only when we created it) --
400
+ if not isinstance(labeled, (str, Path)) and not keep_stage:
401
+ import shutil
402
+ shutil.rmtree(stage_path, ignore_errors=True)
403
+ logger.info("Removed stage store %s", stage_path)
404
+
405
+ return da.from_zarr(effective_out, component=output_component)
patchworks/_relabel.py ADDED
@@ -0,0 +1,83 @@
1
+ """Linear sequential relabelling (O(voxels), not O(n_chunks²))."""
2
+ from __future__ import annotations
3
+
4
+ import logging
5
+ from itertools import product as _iproduct
6
+
7
+ import numpy as np
8
+ import zarr
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
+ _LUT_WARN_THRESHOLD = 100_000_000 # warn when max_label > 100 M (LUT > 800 MB)
14
+
15
+
16
def relabel_sequential_array(labels: np.ndarray) -> np.ndarray:
    """Remap *labels* to a contiguous ``0, 1, … N`` range.

    Background (0) stays 0; every other distinct value gets an id in
    ``1..N`` in ascending order of the original value.  This also holds when
    the input contains no 0 at all (the smallest label becomes 1, not 0),
    when it is empty, and when it contains negative values (they are treated
    as ordinary labels).

    The remap is a rank lookup (``np.unique`` + ``np.searchsorted``), so no
    table sized by the largest label value is allocated.

    Parameters
    ----------
    labels:
        Array of label ids, any shape.

    Returns
    -------
    np.ndarray
        Array of the same shape; ``uint16`` when ``N`` is below the uint16
        maximum, otherwise ``uint32``.

    Examples
    --------
    >>> relabel_sequential_array(np.array([0, 500000, 500000, 7]))
    array([0, 2, 2, 1], dtype=uint16)
    """
    labels = np.asarray(labels)

    # Sorted distinct non-background ids; position in this array + 1 = new id.
    ids = np.unique(labels)
    ids = ids[ids != 0]
    n = ids.size

    ranks = np.searchsorted(ids, labels) + 1
    out = np.where(labels == 0, 0, ranks)

    dtype = np.uint16 if n < np.iinfo(np.uint16).max else np.uint32
    return out.astype(dtype)
41
+
42
+
43
def relabel_sequential_zarr(store_path: str, component: str = "labels") -> int:
    """Relabel a written label zarr to contiguous ids, in place.  Returns N.

    Two passes over the array, one chunk at a time:

    1. collect the distinct ids into a set;
    2. map every chunk through a lookup table and write it back.

    Background (0) stays 0 and every other id is mapped to ``1..N`` in
    ascending order — also when the array contains no 0 at all.

    Parameters
    ----------
    store_path:
        Path of the zarr group, opened in ``r+`` mode.
    component:
        Name of the label array inside the group.

    Returns
    -------
    int
        Number of distinct non-zero labels (``N``).  0 when the array is
        empty or holds only background; nothing is rewritten in that case.

    Raises
    ------
    ValueError
        If the array contains negative ids, which cannot index the table.
    """
    root = zarr.open_group(store_path, mode="r+")
    z = root[component]
    z_shape, z_chunks = z.shape, z.chunks

    # Iterate over actual zarr chunks in ALL dimensions.  Stepping only along
    # the first axis would read the full extent of the remaining axes in one
    # allocation, which is far too large for big volumes.
    n_per_dim = [(s + c - 1) // c for s, c in zip(z_shape, z_chunks)]
    chunk_slices = [
        tuple(slice(i * c, min((i + 1) * c, s)) for i, c, s in zip(idx, z_chunks, z_shape))
        for idx in _iproduct(*[range(n) for n in n_per_dim])
    ]

    # Pass 1: distinct ids.
    uniq: set[int] = set()
    for sl in chunk_slices:
        uniq.update(np.unique(np.asarray(z[sl])).tolist())
    uniq.discard(0)  # background is never renumbered

    if not uniq:
        logger.info("relabel_sequential_zarr: %d objects renumbered to 1..%d", 0, 0)
        return 0

    sorted_ids = np.array(sorted(uniq), dtype=np.int64)
    if sorted_ids[0] < 0:
        raise ValueError(
            f"relabel_sequential_zarr: negative label {int(sorted_ids[0])} "
            "cannot be renumbered"
        )
    max_label = int(sorted_ids[-1])
    if max_label > _LUT_WARN_THRESHOLD:
        logger.warning(
            "relabel_sequential_zarr: max_label=%d → LUT size ~%.0f MB.",
            max_label, max_label * 8 / 1024**2,
        )

    # lut[0] stays 0; non-zero ids map to 1..N regardless of whether a
    # background voxel was present.
    lut = np.zeros(max_label + 1, dtype=np.int64)
    lut[sorted_ids] = np.arange(1, sorted_ids.size + 1)
    n = sorted_ids.size

    # Use same dtype logic as relabel_sequential_array so output never overflows.
    out_dtype = np.uint16 if n < np.iinfo(np.uint16).max else np.uint32

    # Pass 2: remap chunk by chunk.
    for sl in chunk_slices:
        block = np.asarray(z[sl])
        z[sl] = lut[block].astype(out_dtype)
    logger.info("relabel_sequential_zarr: %d objects renumbered to 1..%d", n, n)
    return int(n)
@@ -0,0 +1 @@
1
+ # patchworks plugins
@@ -0,0 +1,188 @@
1
+ """Cellpose plugin for patchworks.
2
+
3
+ Requires cellpose >= 3.0 (compatible with v3 and v4).
4
+
5
+ Usage
6
+ -----
7
+ >>> from patchworks.plugins.cellpose import cellpose_fn
8
+ >>> from patchworks import tile_process
9
+ >>>
10
+ >>> fn = cellpose_fn("cyto3", gpu=True, diameter=30)
11
+ >>> result = tile_process("image.zarr", fn, tile_shape=(1, 2048, 2048),
12
+ ... overlap=20, write_to="labels.zarr", progress=True)
13
+ """
14
+ from __future__ import annotations
15
+
16
+ import importlib.metadata
17
+ import logging
18
+ from functools import partial
19
+ from typing import Any, Callable
20
+
21
+ import numpy as np
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+ try:
26
+ from cellpose import models as _cellpose_models
27
+ _CELLPOSE_VERSION: tuple[int, ...] = tuple(
28
+ int(x) for x in importlib.metadata.version("cellpose").split(".")[:2]
29
+ )
30
+ _CELLPOSE_V4 = _CELLPOSE_VERSION[0] >= 4
31
+ except ImportError as _e:
32
+ _cellpose_models = None # type: ignore[assignment]
33
+ _CELLPOSE_VERSION = (0, 0)
34
+ _CELLPOSE_V4 = False
35
+
36
+ # Per-process model cache keyed by (model_type, gpu)
37
+ _model_cache: dict[tuple, Any] = {}
38
+
39
+
40
def _require_cellpose():
    """Raise ``ImportError`` with install instructions if cellpose is unavailable.

    ``_cellpose_models`` is ``None`` when the module-level import of cellpose
    failed; in every other case this is a no-op.
    """
    if _cellpose_models is not None:
        return
    raise ImportError(
        "cellpose is not installed. Install it with:\n"
        " pip install cellpose\n"
        "or:\n"
        " pip install patchworks[cellpose]"
    )
48
+
49
+
50
def cellpose_fn(
    model: str = "cyto3",
    *,
    gpu: bool = False,
    diameter: float | None = None,
    do_3D: bool = False,
    channels: list[int] | None = None,
    channel_axis: int | None = None,
    **cellpose_kwargs: Any,
) -> Callable[[np.ndarray], np.ndarray]:
    """Build a Cellpose segmentation callable for ``tile_process``.

    Convenience wrapper that bundles the model settings into a config dict
    and binds it to the tile runner with ``functools.partial``.

    Parameters
    ----------
    model:
        Cellpose model type: ``"cyto3"``, ``"cyto2"``, ``"nuclei"``, etc.
    gpu:
        Use GPU for inference.
    diameter:
        Expected cell diameter in pixels.  ``None`` → Cellpose auto-estimates.
    do_3D:
        Run in 3-D mode.  Each tile must contain the full z-stack — use
        ``auto_tile_shape_cellpose(do_3D=True)`` for appropriate tile shapes.
    channels:
        *Cellpose 3 only.*  ``[cytoplasm_channel, nucleus_channel]`` (1-based,
        0 = greyscale).  ``[0, 0]`` → greyscale.  ``[1, 2]`` → cyto=ch1, nuc=ch2.
    channel_axis:
        *Cellpose 4 only.*  Index of the channel axis in the input array.
        ``None`` → greyscale input.
    **cellpose_kwargs:
        Extra kwargs forwarded to ``model.eval()``
        (e.g. ``flow_threshold``, ``cellprob_threshold``, ``anisotropy``).

    Returns
    -------
    Callable[[ndarray], ndarray]
        ``functools.partial`` around the module-level runner, ready for
        ``tile_process``.

    Raises
    ------
    ImportError
        If cellpose is not installed.

    Examples
    --------
    Greyscale 2-D:

    >>> fn = cellpose_fn("cyto3", gpu=True, diameter=30)
    >>> result = tile_process("image.zarr", fn, tile_shape=(1, 2048, 2048), overlap=20)

    Nuclear segmentation:

    >>> fn = cellpose_fn("nuclei", diameter=15)
    >>> result = tile_process("image.zarr", fn, channel=1)

    3-D with anisotropy:

    >>> fn = cellpose_fn("cyto3", gpu=True, do_3D=True, anisotropy=3.0, diameter=20)
    >>> from functools import partial
    >>> from patchworks import auto_tile_shape_cellpose, tile_process
    >>> tile_fn = partial(auto_tile_shape_cellpose, do_3D=True, use_gpu=True, diameter=20)
    >>> result = tile_process("image.zarr", fn, tile_shape=tile_fn, overlap=10)
    """
    _require_cellpose()
    config = _make_config(
        model=model,
        gpu=gpu,
        channels=channels,
        channel_axis=channel_axis,
        diameter=diameter,
        do_3D=do_3D,
        **cellpose_kwargs,
    )
    return partial(_run, cellpose_dict=config)
114
+
115
+
116
+ def _make_config(
117
+ model: str = "cyto3",
118
+ gpu: bool = False,
119
+ channels: list[int] | None = None,
120
+ channel_axis: int | None = None,
121
+ diameter: float | None = None,
122
+ do_3D: bool = False,
123
+ **cellpose_kwargs: Any,
124
+ ) -> dict[str, Any]:
125
+ return {
126
+ "model": model,
127
+ "gpu": gpu,
128
+ "channels": channels if channels is not None else [0, 0],
129
+ "channel_axis": channel_axis,
130
+ "diameter": diameter,
131
+ "do_3D": do_3D,
132
+ "cellpose_kwargs": cellpose_kwargs,
133
+ }
134
+
135
+
136
def _get_model(cellpose_dict: dict[str, Any]) -> Any:
    """Return the Cellpose model for this config, creating it on first use.

    Models are kept in the module-level ``_model_cache`` keyed by
    ``(model, gpu)``, so repeated calls in the same process reuse one
    instance.  The model class depends on the installed Cellpose major
    version (``_CELLPOSE_V4``).

    Raises
    ------
    ImportError
        If cellpose is not installed.
    """
    _require_cellpose()

    model_type = cellpose_dict["model"]
    use_gpu = cellpose_dict.get("gpu", False)
    cache_key = (model_type, use_gpu)

    if cache_key in _model_cache:
        return _model_cache[cache_key]

    if _CELLPOSE_V4:
        instance = _cellpose_models.CellposeModel(model_type=model_type, gpu=use_gpu)
    else:
        instance = _cellpose_models.Cellpose(model_type=model_type, gpu=use_gpu)
    _model_cache[cache_key] = instance
    return _model_cache[cache_key]
152
+
153
+
154
def _run(block: np.ndarray, cellpose_dict: dict[str, Any]) -> np.ndarray:
    """Segment one tile with the cached Cellpose model and return int32 masks.

    Parameters
    ----------
    block:
        Image tile.
    cellpose_dict:
        Config dict as produced by ``_make_config``.

    Returns
    -------
    np.ndarray
        First element of ``model.eval(...)`` cast to ``int32``.  In 2-D mode
        a tile with a leading singleton axis is segmented without that axis
        and the axis is added back to the result.
    """
    model = _get_model(cellpose_dict)
    volumetric = cellpose_dict["do_3D"]
    extra = cellpose_dict.get("cellpose_kwargs", {})

    # Cellpose 4 selects channels by axis; older versions take a channel pair.
    if _CELLPOSE_V4:
        eval_kwargs: dict[str, Any] = dict(
            channel_axis=cellpose_dict.get("channel_axis"),
            diameter=cellpose_dict["diameter"],
            do_3D=volumetric,
            **extra,
        )
    else:
        eval_kwargs = dict(
            channels=cellpose_dict["channels"],
            diameter=cellpose_dict["diameter"],
            do_3D=volumetric,
            **extra,
        )

    if volumetric:
        eval_kwargs["z_axis"] = 0
        return model.eval(block, **eval_kwargs)[0].astype("int32")

    # Squeeze singleton z so Cellpose gets a clean 2-D image
    if block.ndim == 3 and block.shape[0] == 1:
        plane_masks = model.eval(block[0], **eval_kwargs)[0].astype("int32")
        return plane_masks[np.newaxis]
    return model.eval(block, **eval_kwargs)[0].astype("int32")
183
+
184
+
185
# Public aliases for the lower-level helpers, for advanced users who want to
# build the config, fetch the cached model, or run a tile themselves.
make_cellpose_config, get_cellpose_model, run_cellpose = (
    _make_config,
    _get_model,
    _run,
)