swcgeom 0.14.0__py3-none-any.whl → 0.16.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of swcgeom might be problematic. Click here for more details.
- swcgeom/_version.py +2 -2
- swcgeom/analysis/lmeasure.py +821 -0
- swcgeom/analysis/sholl.py +31 -2
- swcgeom/core/__init__.py +4 -0
- swcgeom/core/branch.py +9 -4
- swcgeom/core/branch_tree.py +2 -3
- swcgeom/core/{segment.py → compartment.py} +14 -9
- swcgeom/core/node.py +0 -8
- swcgeom/core/path.py +21 -6
- swcgeom/core/population.py +42 -3
- swcgeom/core/swc_utils/assembler.py +20 -138
- swcgeom/core/swc_utils/base.py +12 -5
- swcgeom/core/swc_utils/checker.py +12 -2
- swcgeom/core/swc_utils/subtree.py +2 -2
- swcgeom/core/tree.py +53 -49
- swcgeom/core/tree_utils.py +27 -5
- swcgeom/core/tree_utils_impl.py +22 -6
- swcgeom/images/augmentation.py +6 -1
- swcgeom/images/contrast.py +107 -0
- swcgeom/images/folder.py +111 -29
- swcgeom/images/io.py +79 -40
- swcgeom/transforms/__init__.py +2 -0
- swcgeom/transforms/base.py +41 -21
- swcgeom/transforms/branch.py +5 -5
- swcgeom/transforms/geometry.py +42 -18
- swcgeom/transforms/image_preprocess.py +100 -0
- swcgeom/transforms/image_stack.py +46 -28
- swcgeom/transforms/images.py +76 -6
- swcgeom/transforms/mst.py +10 -18
- swcgeom/transforms/neurolucida_asc.py +495 -0
- swcgeom/transforms/population.py +2 -2
- swcgeom/transforms/tree.py +12 -14
- swcgeom/transforms/tree_assembler.py +85 -19
- swcgeom/utils/__init__.py +1 -0
- swcgeom/utils/neuromorpho.py +425 -300
- swcgeom/utils/numpy_helper.py +14 -4
- swcgeom/utils/plotter_2d.py +130 -0
- swcgeom/utils/renderer.py +28 -139
- swcgeom/utils/sdf.py +5 -1
- {swcgeom-0.14.0.dist-info → swcgeom-0.16.0.dist-info}/METADATA +3 -3
- swcgeom-0.16.0.dist-info/RECORD +67 -0
- {swcgeom-0.14.0.dist-info → swcgeom-0.16.0.dist-info}/WHEEL +1 -1
- swcgeom-0.14.0.dist-info/RECORD +0 -62
- {swcgeom-0.14.0.dist-info → swcgeom-0.16.0.dist-info}/LICENSE +0 -0
- {swcgeom-0.14.0.dist-info → swcgeom-0.16.0.dist-info}/top_level.txt +0 -0
swcgeom/images/io.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
|
1
1
|
"""Read and write image stack."""
|
|
2
2
|
|
|
3
|
-
|
|
4
3
|
import os
|
|
5
4
|
import re
|
|
6
5
|
import warnings
|
|
@@ -9,11 +8,13 @@ from functools import cache, lru_cache
|
|
|
9
8
|
from typing import (
|
|
10
9
|
Any,
|
|
11
10
|
Callable,
|
|
11
|
+
Generic,
|
|
12
12
|
Iterable,
|
|
13
13
|
List,
|
|
14
14
|
Literal,
|
|
15
15
|
Optional,
|
|
16
16
|
Tuple,
|
|
17
|
+
TypeVar,
|
|
17
18
|
cast,
|
|
18
19
|
overload,
|
|
19
20
|
)
|
|
@@ -27,6 +28,8 @@ from v3dpy.loaders import PBD, Raw
|
|
|
27
28
|
__all__ = ["read_imgs", "save_tiff", "read_images"]
|
|
28
29
|
|
|
29
30
|
Vec3i = Tuple[int, int, int]
|
|
31
|
+
ScalarType = TypeVar("ScalarType", bound=np.generic, covariant=True)
|
|
32
|
+
|
|
30
33
|
RE_TERAFLY_ROOT = re.compile(r"^RES\((\d+)x(\d+)x(\d+)\)$")
|
|
31
34
|
RE_TERAFLY_NAME = re.compile(r"^\d+(_\d+)?(_\d+)?")
|
|
32
35
|
|
|
@@ -46,30 +49,30 @@ AXES_ORDER = {
|
|
|
46
49
|
}
|
|
47
50
|
|
|
48
51
|
|
|
49
|
-
class ImageStack(ABC):
|
|
52
|
+
class ImageStack(ABC, Generic[ScalarType]):
|
|
50
53
|
"""Image stack."""
|
|
51
54
|
|
|
52
55
|
# fmt: off
|
|
53
56
|
@overload
|
|
54
57
|
@abstractmethod
|
|
55
|
-
def __getitem__(self, key: int) -> npt.NDArray[
|
|
58
|
+
def __getitem__(self, key: int) -> npt.NDArray[ScalarType]: ... # array of shape (Y, Z, C)
|
|
56
59
|
@overload
|
|
57
60
|
@abstractmethod
|
|
58
|
-
def __getitem__(self, key: Tuple[int, int]) -> npt.NDArray[
|
|
61
|
+
def __getitem__(self, key: Tuple[int, int]) -> npt.NDArray[ScalarType]: ... # array of shape (Z, C)
|
|
59
62
|
@overload
|
|
60
63
|
@abstractmethod
|
|
61
|
-
def __getitem__(self, key: Tuple[int, int, int]) -> npt.NDArray[
|
|
64
|
+
def __getitem__(self, key: Tuple[int, int, int]) -> npt.NDArray[ScalarType]: ... # array of shape (C,)
|
|
62
65
|
@overload
|
|
63
66
|
@abstractmethod
|
|
64
|
-
def __getitem__(self, key: Tuple[int, int, int, int]) ->
|
|
67
|
+
def __getitem__(self, key: Tuple[int, int, int, int]) -> ScalarType: ... # value
|
|
65
68
|
@overload
|
|
66
69
|
@abstractmethod
|
|
67
70
|
def __getitem__(
|
|
68
71
|
self, key: slice | Tuple[slice, slice] | Tuple[slice, slice, slice] | Tuple[slice, slice, slice, slice],
|
|
69
|
-
) -> npt.NDArray[
|
|
72
|
+
) -> npt.NDArray[ScalarType]: ... # array of shape (X, Y, Z, C)
|
|
70
73
|
@overload
|
|
71
74
|
@abstractmethod
|
|
72
|
-
def __getitem__(self, key: npt.NDArray[np.integer[Any]]) -> npt.NDArray[
|
|
75
|
+
def __getitem__(self, key: npt.NDArray[np.integer[Any]]) -> npt.NDArray[ScalarType]: ...
|
|
73
76
|
# fmt: on
|
|
74
77
|
@abstractmethod
|
|
75
78
|
def __getitem__(self, key):
|
|
@@ -82,7 +85,7 @@ class ImageStack(ABC):
|
|
|
82
85
|
"""
|
|
83
86
|
raise NotImplementedError()
|
|
84
87
|
|
|
85
|
-
def get_full(self) -> npt.NDArray[
|
|
88
|
+
def get_full(self) -> npt.NDArray[ScalarType]:
|
|
86
89
|
"""Get full image stack.
|
|
87
90
|
|
|
88
91
|
Notes
|
|
@@ -96,8 +99,19 @@ class ImageStack(ABC):
|
|
|
96
99
|
raise NotImplementedError()
|
|
97
100
|
|
|
98
101
|
|
|
99
|
-
|
|
102
|
+
# fmt:off
|
|
103
|
+
@overload
|
|
104
|
+
def read_imgs(fname: str, *, dtype: ScalarType, **kwargs) -> ImageStack[ScalarType]: ...
|
|
105
|
+
@overload
|
|
106
|
+
def read_imgs(fname: str, *, dtype: None =..., **kwargs) -> ImageStack[np.float32]: ...
|
|
107
|
+
# fmt:on
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def read_imgs(fname: str, *, dtype=None, **kwargs): # type: ignore
|
|
100
111
|
"""Read image stack."""
|
|
112
|
+
|
|
113
|
+
kwargs["dtype"] = dtype or np.float32
|
|
114
|
+
|
|
101
115
|
ext = os.path.splitext(fname)[-1]
|
|
102
116
|
if ext in [".tif", ".tiff"]:
|
|
103
117
|
return TiffImageStack(fname, **kwargs)
|
|
@@ -191,7 +205,7 @@ def save_tiff(
|
|
|
191
205
|
tifffile.imwrite(fname, data, **kwargs)
|
|
192
206
|
|
|
193
207
|
|
|
194
|
-
class NDArrayImageStack(ImageStack):
|
|
208
|
+
class NDArrayImageStack(ImageStack[ScalarType]):
|
|
195
209
|
"""NDArray image stack."""
|
|
196
210
|
|
|
197
211
|
def __init__(
|
|
@@ -199,6 +213,8 @@ class NDArrayImageStack(ImageStack):
|
|
|
199
213
|
imgs: npt.NDArray[Any],
|
|
200
214
|
swap_xy: Optional[bool] = None,
|
|
201
215
|
filp_xy: Optional[bool] = None,
|
|
216
|
+
*,
|
|
217
|
+
dtype: ScalarType,
|
|
202
218
|
) -> None:
|
|
203
219
|
super().__init__()
|
|
204
220
|
|
|
@@ -206,10 +222,6 @@ class NDArrayImageStack(ImageStack):
|
|
|
206
222
|
imgs = np.expand_dims(imgs, -1)
|
|
207
223
|
assert imgs.ndim == 4, "Should be shape of (X, Y, Z, C)"
|
|
208
224
|
|
|
209
|
-
sclar_factor = 1.0
|
|
210
|
-
if np.issubdtype((dtype := imgs.dtype), np.unsignedinteger):
|
|
211
|
-
sclar_factor /= UINT_MAX[dtype]
|
|
212
|
-
|
|
213
225
|
if swap_xy is not None:
|
|
214
226
|
warnings.warn(
|
|
215
227
|
"flag `swap_xy` now is unnecessary, tifffile will "
|
|
@@ -231,12 +243,18 @@ class NDArrayImageStack(ImageStack):
|
|
|
231
243
|
if filp_xy is True:
|
|
232
244
|
imgs = np.flip(imgs, (0, 1)) # (X, Y, Z, C)
|
|
233
245
|
|
|
234
|
-
|
|
246
|
+
dtype_raw = imgs.dtype
|
|
247
|
+
self.imgs = imgs.astype(dtype)
|
|
248
|
+
if np.issubdtype(dtype, np.floating) and np.issubdtype(
|
|
249
|
+
dtype_raw, np.unsignedinteger
|
|
250
|
+
): # TODO: add a option to disable this
|
|
251
|
+
sclar_factor = 1.0 / UINT_MAX[imgs.dtype]
|
|
252
|
+
self.imgs *= sclar_factor
|
|
235
253
|
|
|
236
254
|
def __getitem__(self, key):
|
|
237
255
|
return self.imgs.__getitem__(key)
|
|
238
256
|
|
|
239
|
-
def get_full(self) -> npt.NDArray[
|
|
257
|
+
def get_full(self) -> npt.NDArray[ScalarType]:
|
|
240
258
|
return self.imgs
|
|
241
259
|
|
|
242
260
|
@property
|
|
@@ -244,7 +262,7 @@ class NDArrayImageStack(ImageStack):
|
|
|
244
262
|
return cast(Tuple[int, int, int, int], self.imgs.shape)
|
|
245
263
|
|
|
246
264
|
|
|
247
|
-
class TiffImageStack(NDArrayImageStack):
|
|
265
|
+
class TiffImageStack(NDArrayImageStack[ScalarType]):
|
|
248
266
|
"""Tiff image stack."""
|
|
249
267
|
|
|
250
268
|
def __init__(
|
|
@@ -252,6 +270,8 @@ class TiffImageStack(NDArrayImageStack):
|
|
|
252
270
|
fname: str,
|
|
253
271
|
swap_xy: Optional[bool] = None,
|
|
254
272
|
filp_xy: Optional[bool] = None,
|
|
273
|
+
*,
|
|
274
|
+
dtype: ScalarType,
|
|
255
275
|
**kwargs,
|
|
256
276
|
) -> None:
|
|
257
277
|
with tifffile.TiffFile(fname, **kwargs) as f:
|
|
@@ -265,10 +285,10 @@ class TiffImageStack(NDArrayImageStack):
|
|
|
265
285
|
|
|
266
286
|
orders = [AXES_ORDER[c] for c in axes]
|
|
267
287
|
imgs = imgs.transpose(np.argsort(orders))
|
|
268
|
-
super().__init__(imgs, swap_xy=swap_xy, filp_xy=filp_xy)
|
|
288
|
+
super().__init__(imgs, swap_xy=swap_xy, filp_xy=filp_xy, dtype=dtype)
|
|
269
289
|
|
|
270
290
|
|
|
271
|
-
class NrrdImageStack(NDArrayImageStack):
|
|
291
|
+
class NrrdImageStack(NDArrayImageStack[ScalarType]):
|
|
272
292
|
"""Nrrd image stack."""
|
|
273
293
|
|
|
274
294
|
def __init__(
|
|
@@ -276,45 +296,53 @@ class NrrdImageStack(NDArrayImageStack):
|
|
|
276
296
|
fname: str,
|
|
277
297
|
swap_xy: Optional[bool] = None,
|
|
278
298
|
filp_xy: Optional[bool] = None,
|
|
299
|
+
*,
|
|
300
|
+
dtype: ScalarType,
|
|
279
301
|
**kwargs,
|
|
280
302
|
) -> None:
|
|
281
303
|
imgs, header = nrrd.read(fname, **kwargs)
|
|
282
|
-
super().__init__(imgs, swap_xy=swap_xy, filp_xy=filp_xy)
|
|
304
|
+
super().__init__(imgs, swap_xy=swap_xy, filp_xy=filp_xy, dtype=dtype)
|
|
283
305
|
self.header = header
|
|
284
306
|
|
|
285
307
|
|
|
286
|
-
class V3dImageStack(NDArrayImageStack):
|
|
308
|
+
class V3dImageStack(NDArrayImageStack[ScalarType]):
|
|
287
309
|
"""v3d image stack."""
|
|
288
310
|
|
|
289
|
-
def __init__(
|
|
311
|
+
def __init__(
|
|
312
|
+
self, fname: str, loader: Raw | PBD, *, dtype: ScalarType, **kwargs
|
|
313
|
+
) -> None:
|
|
290
314
|
r = loader()
|
|
291
315
|
imgs = r.load(fname)
|
|
292
|
-
super().__init__(imgs, **kwargs)
|
|
316
|
+
super().__init__(imgs, dtype=dtype, **kwargs)
|
|
293
317
|
|
|
294
318
|
|
|
295
|
-
class V3drawImageStack(V3dImageStack):
|
|
319
|
+
class V3drawImageStack(V3dImageStack[ScalarType]):
|
|
296
320
|
"""v3draw image stack."""
|
|
297
321
|
|
|
298
|
-
def __init__(self, fname: str, **kwargs) -> None:
|
|
299
|
-
super().__init__(fname, loader=Raw, **kwargs)
|
|
322
|
+
def __init__(self, fname: str, *, dtype: ScalarType, **kwargs) -> None:
|
|
323
|
+
super().__init__(fname, loader=Raw, dtype=dtype, **kwargs)
|
|
300
324
|
|
|
301
325
|
|
|
302
|
-
class V3dpbdImageStack(V3dImageStack):
|
|
326
|
+
class V3dpbdImageStack(V3dImageStack[ScalarType]):
|
|
303
327
|
"""v3dpbd image stack."""
|
|
304
328
|
|
|
305
|
-
def __init__(self, fname: str, **kwargs) -> None:
|
|
306
|
-
super().__init__(fname, loader=PBD, **kwargs)
|
|
329
|
+
def __init__(self, fname: str, *, dtype: ScalarType, **kwargs) -> None:
|
|
330
|
+
super().__init__(fname, loader=PBD, dtype=dtype, **kwargs)
|
|
307
331
|
|
|
308
332
|
|
|
309
|
-
class TeraflyImageStack(ImageStack):
|
|
333
|
+
class TeraflyImageStack(ImageStack[ScalarType]):
|
|
310
334
|
"""TeraFly image stack.
|
|
311
335
|
|
|
336
|
+
TeraFly is a file format for terabyte-scale multidimensional
|
|
337
|
+
volumetric images, as described in [1]_.
|
|
338
|
+
|
|
312
339
|
References
|
|
313
340
|
----------
|
|
314
|
-
[1] Bria, Alessandro, Giulio Iannello, Leonardo Onofri, and
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
341
|
+
.. [1] Bria, Alessandro, Giulio Iannello, Leonardo Onofri, and
|
|
342
|
+
Hanchuan Peng. “TeraFly: Real-Time Three-Dimensional
|
|
343
|
+
Visualization and Annotation of Terabytes of Multidimensional
|
|
344
|
+
Volumetric Images.” Nature Methods 13,
|
|
345
|
+
no. 3 (March 2016): 192-94. https://doi.org/10.1038/nmeth.3767.
|
|
318
346
|
|
|
319
347
|
Notes
|
|
320
348
|
-----
|
|
@@ -328,21 +356,26 @@ class TeraflyImageStack(ImageStack):
|
|
|
328
356
|
_listdir: Callable[[str], List[str]]
|
|
329
357
|
_read_patch: Callable[[str], npt.NDArray]
|
|
330
358
|
|
|
331
|
-
def __init__(
|
|
359
|
+
def __init__(
|
|
360
|
+
self, root: str, *, dtype: ScalarType, lru_maxsize: int | None = 128
|
|
361
|
+
) -> None:
|
|
332
362
|
r"""
|
|
333
363
|
Parameters
|
|
334
364
|
----------
|
|
335
365
|
root : str
|
|
336
366
|
The root of terafly which contains directories named as
|
|
337
367
|
`RES(YxXxZ)`.
|
|
368
|
+
dtype : np.dtype
|
|
338
369
|
lru_maxsize : int or None, default to 128
|
|
339
370
|
Forwarding to `functools.lru_cache`. A decompressed array
|
|
340
371
|
size of (256, 256, 256, 1), which is the typical size of
|
|
341
372
|
terafly image stack, takes about 256 * 256 * 256 * 1 *
|
|
342
373
|
4B = 64MB. A cache size of 128 requires about 8GB memory.
|
|
343
374
|
"""
|
|
375
|
+
|
|
344
376
|
super().__init__()
|
|
345
377
|
self.root = root
|
|
378
|
+
self.dtype = dtype
|
|
346
379
|
self.res, self.res_dirs, self.res_patch_sizes = self.get_resolutions(root)
|
|
347
380
|
|
|
348
381
|
@cache
|
|
@@ -350,8 +383,8 @@ class TeraflyImageStack(ImageStack):
|
|
|
350
383
|
return os.listdir(path)
|
|
351
384
|
|
|
352
385
|
@lru_cache(maxsize=lru_maxsize)
|
|
353
|
-
def read_patch(path: str) -> npt.NDArray[
|
|
354
|
-
return read_imgs(path).get_full()
|
|
386
|
+
def read_patch(path: str) -> npt.NDArray[ScalarType]:
|
|
387
|
+
return read_imgs(path, dtype=dtype).get_full()
|
|
355
388
|
|
|
356
389
|
self._listdir, self._read_patch = listdir, read_patch
|
|
357
390
|
|
|
@@ -382,7 +415,7 @@ class TeraflyImageStack(ImageStack):
|
|
|
382
415
|
|
|
383
416
|
def get_patch(
|
|
384
417
|
self, starts, ends, strides: int | Vec3i = 1, res_level=-1
|
|
385
|
-
) -> npt.NDArray[
|
|
418
|
+
) -> npt.NDArray[ScalarType]:
|
|
386
419
|
"""Get patch of image stack.
|
|
387
420
|
|
|
388
421
|
Returns
|
|
@@ -397,7 +430,7 @@ class TeraflyImageStack(ImageStack):
|
|
|
397
430
|
assert np.equal(strides, [1, 1, 1]).all() # TODO: support stride
|
|
398
431
|
|
|
399
432
|
shape_out = np.concatenate([ends - starts, [1]])
|
|
400
|
-
out = np.zeros(shape_out, dtype=
|
|
433
|
+
out = np.zeros(shape_out, dtype=self.dtype)
|
|
401
434
|
self._get_range(starts, ends, res_level, out=out)
|
|
402
435
|
|
|
403
436
|
# flip y-axis to make it a left-handed coordinate system
|
|
@@ -604,6 +637,12 @@ class GrayImageStack:
|
|
|
604
637
|
|
|
605
638
|
|
|
606
639
|
def read_images(*args, **kwargs) -> GrayImageStack:
|
|
640
|
+
"""Read images.
|
|
641
|
+
|
|
642
|
+
.. deprecated:: 0.16.0
|
|
643
|
+
Use :meth:`read_imgs` instead.
|
|
644
|
+
"""
|
|
645
|
+
|
|
607
646
|
warnings.warn(
|
|
608
647
|
"`read_images` has been replaced by `read_imgs` because it"
|
|
609
648
|
"provide rgb support, and this will be removed in next version",
|
swcgeom/transforms/__init__.py
CHANGED
|
@@ -3,9 +3,11 @@
|
|
|
3
3
|
from swcgeom.transforms.base import *
|
|
4
4
|
from swcgeom.transforms.branch import *
|
|
5
5
|
from swcgeom.transforms.geometry import *
|
|
6
|
+
from swcgeom.transforms.image_preprocess import *
|
|
6
7
|
from swcgeom.transforms.image_stack import *
|
|
7
8
|
from swcgeom.transforms.images import *
|
|
8
9
|
from swcgeom.transforms.mst import *
|
|
10
|
+
from swcgeom.transforms.neurolucida_asc import *
|
|
9
11
|
from swcgeom.transforms.path import *
|
|
10
12
|
from swcgeom.transforms.population import *
|
|
11
13
|
from swcgeom.transforms.tree import *
|
swcgeom/transforms/base.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
"""Transformation in tree."""
|
|
2
2
|
|
|
3
3
|
from abc import ABC, abstractmethod
|
|
4
|
-
from typing import Any, Generic, TypeVar,
|
|
4
|
+
from typing import Any, Generic, TypeVar, overload
|
|
5
5
|
|
|
6
6
|
__all__ = ["Transform", "Transforms", "Identity"]
|
|
7
7
|
|
|
@@ -15,26 +15,49 @@ class Transform(ABC, Generic[T, K]):
|
|
|
15
15
|
r"""An abstract class representing a :class:`Transform`.
|
|
16
16
|
|
|
17
17
|
All transforms that represent a map from `T` to `K`.
|
|
18
|
-
|
|
19
|
-
Methods
|
|
20
|
-
-------
|
|
21
|
-
__call__(x: T) -> K
|
|
22
|
-
All subclasses should overwrite :meth:`__call__`, supporting
|
|
23
|
-
applying transform in `x`.
|
|
24
|
-
__repr__() -> str
|
|
25
|
-
Subclasses could also optionally overwrite :meth:`__repr__`.
|
|
26
|
-
Avoid using the underscore `_` because it is used by
|
|
27
|
-
`Transforms`. If not provied, class name will be a default
|
|
28
|
-
value.
|
|
29
18
|
"""
|
|
30
19
|
|
|
31
20
|
@abstractmethod
|
|
32
21
|
def __call__(self, x: T) -> K:
|
|
33
|
-
"""Apply transform.
|
|
22
|
+
"""Apply transform.
|
|
23
|
+
|
|
24
|
+
Notes
|
|
25
|
+
-----
|
|
26
|
+
All subclasses should overwrite :meth:`__call__`, supporting
|
|
27
|
+
applying transform in `x`.
|
|
28
|
+
"""
|
|
34
29
|
raise NotImplementedError()
|
|
35
30
|
|
|
36
31
|
def __repr__(self) -> str:
|
|
37
|
-
|
|
32
|
+
classname = self.__class__.__name__
|
|
33
|
+
repr_ = self.extra_repr()
|
|
34
|
+
return f"{classname}({repr_})"
|
|
35
|
+
|
|
36
|
+
def extra_repr(self):
|
|
37
|
+
"""Provides a human-friendly representation of the module.
|
|
38
|
+
|
|
39
|
+
This method extends the basic string representation provided by
|
|
40
|
+
`__repr__` method. It is designed to display additional details
|
|
41
|
+
about the module's parameters or its specific configuration,
|
|
42
|
+
which can be particularly useful for debugging and model
|
|
43
|
+
architecture introspection.
|
|
44
|
+
|
|
45
|
+
Examples
|
|
46
|
+
--------
|
|
47
|
+
class Foo(Transform[T, K]):
|
|
48
|
+
def __init__(self, my_parameter: int = 1):
|
|
49
|
+
self.my_parameter = my_parameter
|
|
50
|
+
|
|
51
|
+
def extra_repr(self):
|
|
52
|
+
return f"my_parameter={self.my_parameter}"
|
|
53
|
+
|
|
54
|
+
Notes
|
|
55
|
+
-----
|
|
56
|
+
This method should be overridden in custom modules to provide
|
|
57
|
+
specific details relevant to the module's functionality and
|
|
58
|
+
configuration.
|
|
59
|
+
"""
|
|
60
|
+
return ""
|
|
38
61
|
|
|
39
62
|
|
|
40
63
|
class Transforms(Transform[T, K]):
|
|
@@ -81,23 +104,20 @@ class Transforms(Transform[T, K]):
|
|
|
81
104
|
for transform in self.transforms:
|
|
82
105
|
x = transform(x)
|
|
83
106
|
|
|
84
|
-
return
|
|
107
|
+
return x # type: ignore
|
|
85
108
|
|
|
86
109
|
def __getitem__(self, idx: int) -> Transform[Any, Any]:
|
|
87
110
|
return self.transforms[idx]
|
|
88
111
|
|
|
89
|
-
def __repr__(self) -> str:
|
|
90
|
-
return "_".join([str(transform) for transform in self])
|
|
91
|
-
|
|
92
112
|
def __len__(self) -> int:
|
|
93
113
|
return len(self.transforms)
|
|
94
114
|
|
|
115
|
+
def extra_repr(self) -> str:
|
|
116
|
+
return ", ".join([str(transform) for transform in self])
|
|
117
|
+
|
|
95
118
|
|
|
96
119
|
class Identity(Transform[T, T]):
|
|
97
120
|
"""Return input as-is."""
|
|
98
121
|
|
|
99
122
|
def __call__(self, x: T) -> T:
|
|
100
123
|
return x
|
|
101
|
-
|
|
102
|
-
def __repr__(self) -> str:
|
|
103
|
-
return ""
|
swcgeom/transforms/branch.py
CHANGED
|
@@ -48,9 +48,6 @@ class BranchLinearResampler(_BranchResampler):
|
|
|
48
48
|
super().__init__()
|
|
49
49
|
self.n_nodes = n_nodes
|
|
50
50
|
|
|
51
|
-
def __repr__(self) -> str:
|
|
52
|
-
return f"BranchLinearResampler-{self.n_nodes}"
|
|
53
|
-
|
|
54
51
|
def resample(self, xyzr: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
|
|
55
52
|
"""Resampling by linear interpolation, DO NOT keep original node.
|
|
56
53
|
|
|
@@ -75,6 +72,9 @@ class BranchLinearResampler(_BranchResampler):
|
|
|
75
72
|
r = np.interp(xvals, xp, xyzr[:, 3])
|
|
76
73
|
return cast(npt.NDArray[np.float32], np.stack([x, y, z, r], axis=1))
|
|
77
74
|
|
|
75
|
+
def extra_repr(self):
|
|
76
|
+
return f"n_nodes={self.n_nodes}"
|
|
77
|
+
|
|
78
78
|
|
|
79
79
|
class BranchConvSmoother(Transform[Branch, Branch[DictSWC]]):
|
|
80
80
|
r"""Smooth the branch by sliding window."""
|
|
@@ -100,8 +100,8 @@ class BranchConvSmoother(Transform[Branch, Branch[DictSWC]]):
|
|
|
100
100
|
|
|
101
101
|
return x
|
|
102
102
|
|
|
103
|
-
def
|
|
104
|
-
return f"
|
|
103
|
+
def extra_repr(self):
|
|
104
|
+
return f"n_nodes={self.n_nodes}"
|
|
105
105
|
|
|
106
106
|
|
|
107
107
|
class BranchStandardizer(Transform[Branch, Branch[DictSWC]]):
|
swcgeom/transforms/geometry.py
CHANGED
|
@@ -72,6 +72,9 @@ class RadiusReseter(Generic[T], Transform[T, T]):
|
|
|
72
72
|
new_tree.ndata[new_tree.names.r] = r
|
|
73
73
|
return new_tree
|
|
74
74
|
|
|
75
|
+
def extra_repr(self):
|
|
76
|
+
return f"r={self.r:.4f}"
|
|
77
|
+
|
|
75
78
|
|
|
76
79
|
class AffineTransform(Generic[T], Transform[T, T]):
|
|
77
80
|
"""Apply affine matrix."""
|
|
@@ -85,10 +88,18 @@ class AffineTransform(Generic[T], Transform[T, T]):
|
|
|
85
88
|
tm: npt.NDArray[np.float32],
|
|
86
89
|
center: Center = "origin",
|
|
87
90
|
*,
|
|
88
|
-
fmt: str,
|
|
91
|
+
fmt: Optional[str] = None,
|
|
89
92
|
names: Optional[SWCNames] = None,
|
|
90
93
|
) -> None:
|
|
91
|
-
self.tm, self.center
|
|
94
|
+
self.tm, self.center = tm, center
|
|
95
|
+
|
|
96
|
+
if fmt is not None:
|
|
97
|
+
warnings.warn(
|
|
98
|
+
"`fmt` parameter is no longer needed, now use the "
|
|
99
|
+
"extra_repr(), you can directly remove it.",
|
|
100
|
+
DeprecationWarning,
|
|
101
|
+
)
|
|
102
|
+
|
|
92
103
|
if names is not None:
|
|
93
104
|
warnings.warn(
|
|
94
105
|
"`name` parameter is no longer needed, now use the "
|
|
@@ -111,9 +122,6 @@ class AffineTransform(Generic[T], Transform[T, T]):
|
|
|
111
122
|
|
|
112
123
|
return self.apply(x, tm)
|
|
113
124
|
|
|
114
|
-
def __repr__(self) -> str:
|
|
115
|
-
return self.fmt
|
|
116
|
-
|
|
117
125
|
@staticmethod
|
|
118
126
|
def apply(x: T, tm: npt.NDArray[np.float32]) -> T:
|
|
119
127
|
xyzw = x.xyzw().dot(tm.T).T
|
|
@@ -130,8 +138,11 @@ class Translate(Generic[T], AffineTransform[T]):
|
|
|
130
138
|
"""Translate SWC."""
|
|
131
139
|
|
|
132
140
|
def __init__(self, tx: float, ty: float, tz: float, **kwargs) -> None:
|
|
133
|
-
|
|
134
|
-
|
|
141
|
+
super().__init__(translate3d(tx, ty, tz), **kwargs)
|
|
142
|
+
self.tx, self.ty, self.tz = tx, ty, tz
|
|
143
|
+
|
|
144
|
+
def extra_repr(self):
|
|
145
|
+
return f"tx={self.tx:.4f}, ty={self.ty:.4f}, tz={self.tz:.4f}"
|
|
135
146
|
|
|
136
147
|
@classmethod
|
|
137
148
|
def transform(cls, x: T, tx: float, ty: float, tz: float, **kwargs) -> T:
|
|
@@ -158,8 +169,7 @@ class Scale(Generic[T], AffineTransform[T]):
|
|
|
158
169
|
def __init__(
|
|
159
170
|
self, sx: float, sy: float, sz: float, center: Center = "root", **kwargs
|
|
160
171
|
) -> None:
|
|
161
|
-
|
|
162
|
-
super().__init__(scale3d(sx, sy, sz), center=center, fmt=fmt, **kwargs)
|
|
172
|
+
super().__init__(scale3d(sx, sy, sz), center=center, **kwargs)
|
|
163
173
|
|
|
164
174
|
@classmethod
|
|
165
175
|
def transform( # pylint: disable=too-many-arguments
|
|
@@ -180,6 +190,12 @@ class Rotate(Generic[T], AffineTransform[T]):
|
|
|
180
190
|
) -> None:
|
|
181
191
|
fmt = f"Rotate-{n[0]}-{n[1]}-{n[2]}-{theta:.4f}"
|
|
182
192
|
super().__init__(rotate3d(n, theta), center=center, fmt=fmt, **kwargs)
|
|
193
|
+
self.n = n
|
|
194
|
+
self.theta = theta
|
|
195
|
+
self.center = center
|
|
196
|
+
|
|
197
|
+
def extra_repr(self):
|
|
198
|
+
return f"n={self.n}, theta={self.theta:.4f}, center={self.center}" # TODO: improve format of n
|
|
183
199
|
|
|
184
200
|
@classmethod
|
|
185
201
|
def transform(
|
|
@@ -197,9 +213,11 @@ class RotateX(Generic[T], AffineTransform[T]):
|
|
|
197
213
|
"""Rotate SWC with x-axis."""
|
|
198
214
|
|
|
199
215
|
def __init__(self, theta: float, center: Center = "root", **kwargs) -> None:
|
|
200
|
-
super().__init__(
|
|
201
|
-
|
|
202
|
-
|
|
216
|
+
super().__init__(rotate3d_x(theta), center=center, **kwargs)
|
|
217
|
+
self.theta = theta
|
|
218
|
+
|
|
219
|
+
def extra_repr(self):
|
|
220
|
+
return f"center={self.center}, theta={self.theta:.4f}"
|
|
203
221
|
|
|
204
222
|
@classmethod
|
|
205
223
|
def transform(cls, x: T, theta: float, center: Center = "root", **kwargs) -> T:
|
|
@@ -210,9 +228,12 @@ class RotateY(Generic[T], AffineTransform[T]):
|
|
|
210
228
|
"""Rotate SWC with y-axis."""
|
|
211
229
|
|
|
212
230
|
def __init__(self, theta: float, center: Center = "root", **kwargs) -> None:
|
|
213
|
-
super().__init__(
|
|
214
|
-
|
|
215
|
-
|
|
231
|
+
super().__init__(rotate3d_y(theta), center=center, **kwargs)
|
|
232
|
+
self.theta = theta
|
|
233
|
+
self.center = center
|
|
234
|
+
|
|
235
|
+
def extra_repr(self):
|
|
236
|
+
return f"theta={self.theta:.4f}, center={self.center}"
|
|
216
237
|
|
|
217
238
|
@classmethod
|
|
218
239
|
def transform(cls, x: T, theta: float, center: Center = "root", **kwargs) -> T:
|
|
@@ -223,9 +244,12 @@ class RotateZ(Generic[T], AffineTransform[T]):
|
|
|
223
244
|
"""Rotate SWC with z-axis."""
|
|
224
245
|
|
|
225
246
|
def __init__(self, theta: float, center: Center = "root", **kwargs) -> None:
|
|
226
|
-
super().__init__(
|
|
227
|
-
|
|
228
|
-
|
|
247
|
+
super().__init__(rotate3d_z(theta), center=center, **kwargs)
|
|
248
|
+
self.theta = theta
|
|
249
|
+
self.center = center
|
|
250
|
+
|
|
251
|
+
def extra_repr(self):
|
|
252
|
+
return f"theta={self.theta:.4f}, center={self.center}"
|
|
229
253
|
|
|
230
254
|
@classmethod
|
|
231
255
|
def transform(cls, x: T, theta: float, center: Center = "root", **kwargs) -> T:
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
"""Image stack pre-processing."""
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import numpy.typing as npt
|
|
5
|
+
from scipy.fftpack import fftn, fftshift, ifftn
|
|
6
|
+
from scipy.ndimage import gaussian_filter, minimum_filter
|
|
7
|
+
|
|
8
|
+
from swcgeom.transforms.base import Transform
|
|
9
|
+
|
|
10
|
+
__all__ = ["SGuoImPreProcess"]
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class SGuoImPreProcess(Transform[npt.NDArray[np.uint8], npt.NDArray[np.uint8]]):
|
|
14
|
+
"""Single-Neuron Image Enhancement.
|
|
15
|
+
|
|
16
|
+
Implementation of the image enhancement method described in the paper:
|
|
17
|
+
|
|
18
|
+
Shuxia Guo, Xuan Zhao, Shengdian Jiang, Liya Ding, Hanchuan Peng,
|
|
19
|
+
Image enhancement to leverage the 3D morphological reconstruction
|
|
20
|
+
of single-cell neurons, Bioinformatics, Volume 38, Issue 2,
|
|
21
|
+
January 2022, Pages 503–512, https://doi.org/10.1093/bioinformatics/btab638
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
def __call__(self, x: npt.NDArray[np.uint8]) -> npt.NDArray[np.uint8]:
|
|
25
|
+
# TODO: support np.float32
|
|
26
|
+
assert x.dtype == np.uint8, "Image must be in uint8 format"
|
|
27
|
+
x = self.sigmoid_adjustment(x)
|
|
28
|
+
x = self.subtract_min_along_z(x)
|
|
29
|
+
x = self.bilateral_filter_3d(x)
|
|
30
|
+
x = self.high_pass_fft(x)
|
|
31
|
+
return x
|
|
32
|
+
|
|
33
|
+
@staticmethod
|
|
34
|
+
def sigmoid_adjustment(
|
|
35
|
+
image: npt.NDArray[np.uint8], sigma: float = 3, percentile: float = 25
|
|
36
|
+
) -> npt.NDArray[np.uint8]:
|
|
37
|
+
image_normalized = image / 255.0
|
|
38
|
+
u = np.percentile(image_normalized, percentile)
|
|
39
|
+
adjusted = 1 / (1 + np.exp(-sigma * (image_normalized - u)))
|
|
40
|
+
adjusted_rescaled = (adjusted * 255).astype(np.uint8)
|
|
41
|
+
return adjusted_rescaled
|
|
42
|
+
|
|
43
|
+
@staticmethod
|
|
44
|
+
def subtract_min_along_z(image: npt.NDArray[np.uint8]) -> npt.NDArray[np.uint8]:
|
|
45
|
+
min_along_z = minimum_filter(
|
|
46
|
+
image,
|
|
47
|
+
size=(1, 1, image.shape[2], 1),
|
|
48
|
+
mode="constant",
|
|
49
|
+
cval=np.max(image).item(),
|
|
50
|
+
)
|
|
51
|
+
subtracted = image - min_along_z
|
|
52
|
+
return subtracted
|
|
53
|
+
|
|
54
|
+
@staticmethod
|
|
55
|
+
def bilateral_filter_3d(
|
|
56
|
+
image: npt.NDArray[np.uint8], spatial_sigma=(1, 1, 0.33), range_sigma=35
|
|
57
|
+
) -> npt.NDArray[np.uint8]:
|
|
58
|
+
# initialize the output image
|
|
59
|
+
filtered_image = np.zeros_like(image)
|
|
60
|
+
|
|
61
|
+
spatial_gaussian = gaussian_filter(image, spatial_sigma)
|
|
62
|
+
|
|
63
|
+
# traverse each pixel to perform bilateral filtering
|
|
64
|
+
# TODO: optimization is needed
|
|
65
|
+
for z in range(image.shape[2]):
|
|
66
|
+
for y in range(image.shape[1]):
|
|
67
|
+
for x in range(image.shape[0]):
|
|
68
|
+
value = image[x, y, z]
|
|
69
|
+
range_weight = np.exp(
|
|
70
|
+
-((image - value) ** 2) / (2 * range_sigma**2)
|
|
71
|
+
)
|
|
72
|
+
weights = spatial_gaussian * range_weight
|
|
73
|
+
filtered_image[x, y, z] = np.sum(image * weights) / np.sum(weights)
|
|
74
|
+
|
|
75
|
+
return filtered_image
|
|
76
|
+
|
|
77
|
+
@staticmethod
|
|
78
|
+
def high_pass_fft(image: npt.NDArray[np.uint8]) -> npt.NDArray[np.uint8]:
|
|
79
|
+
# fft
|
|
80
|
+
fft_image = fftn(image)
|
|
81
|
+
fft_shifted = fftshift(fft_image)
|
|
82
|
+
|
|
83
|
+
# create a high-pass filter
|
|
84
|
+
h, w, d, _ = image.shape
|
|
85
|
+
x, y, z = np.ogrid[:h, :w, :d]
|
|
86
|
+
center = (h / 2, w / 2, d / 2)
|
|
87
|
+
distance = np.sqrt(
|
|
88
|
+
(x - center[0]) ** 2 + (y - center[1]) ** 2 + (z - center[2]) ** 2
|
|
89
|
+
)
|
|
90
|
+
# adjust this threshold to control the filtering strength
|
|
91
|
+
high_pass_mask = distance > (d // 4)
|
|
92
|
+
# apply the high-pass filter
|
|
93
|
+
fft_shifted *= high_pass_mask
|
|
94
|
+
|
|
95
|
+
# inverse fft
|
|
96
|
+
fft_unshifted = np.fft.ifftshift(fft_shifted)
|
|
97
|
+
filtered_image = np.real(ifftn(fft_unshifted))
|
|
98
|
+
|
|
99
|
+
filtered_rescaled = np.clip(filtered_image, 0, 255).astype(np.uint8)
|
|
100
|
+
return filtered_rescaled
|