swcgeom 0.14.0__py3-none-any.whl → 0.16.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of swcgeom might be problematic. Click here for more details.

Files changed (45) hide show
  1. swcgeom/_version.py +2 -2
  2. swcgeom/analysis/lmeasure.py +821 -0
  3. swcgeom/analysis/sholl.py +31 -2
  4. swcgeom/core/__init__.py +4 -0
  5. swcgeom/core/branch.py +9 -4
  6. swcgeom/core/branch_tree.py +2 -3
  7. swcgeom/core/{segment.py → compartment.py} +14 -9
  8. swcgeom/core/node.py +0 -8
  9. swcgeom/core/path.py +21 -6
  10. swcgeom/core/population.py +42 -3
  11. swcgeom/core/swc_utils/assembler.py +20 -138
  12. swcgeom/core/swc_utils/base.py +12 -5
  13. swcgeom/core/swc_utils/checker.py +12 -2
  14. swcgeom/core/swc_utils/subtree.py +2 -2
  15. swcgeom/core/tree.py +53 -49
  16. swcgeom/core/tree_utils.py +27 -5
  17. swcgeom/core/tree_utils_impl.py +22 -6
  18. swcgeom/images/augmentation.py +6 -1
  19. swcgeom/images/contrast.py +107 -0
  20. swcgeom/images/folder.py +111 -29
  21. swcgeom/images/io.py +79 -40
  22. swcgeom/transforms/__init__.py +2 -0
  23. swcgeom/transforms/base.py +41 -21
  24. swcgeom/transforms/branch.py +5 -5
  25. swcgeom/transforms/geometry.py +42 -18
  26. swcgeom/transforms/image_preprocess.py +100 -0
  27. swcgeom/transforms/image_stack.py +46 -28
  28. swcgeom/transforms/images.py +76 -6
  29. swcgeom/transforms/mst.py +10 -18
  30. swcgeom/transforms/neurolucida_asc.py +495 -0
  31. swcgeom/transforms/population.py +2 -2
  32. swcgeom/transforms/tree.py +12 -14
  33. swcgeom/transforms/tree_assembler.py +85 -19
  34. swcgeom/utils/__init__.py +1 -0
  35. swcgeom/utils/neuromorpho.py +425 -300
  36. swcgeom/utils/numpy_helper.py +14 -4
  37. swcgeom/utils/plotter_2d.py +130 -0
  38. swcgeom/utils/renderer.py +28 -139
  39. swcgeom/utils/sdf.py +5 -1
  40. {swcgeom-0.14.0.dist-info → swcgeom-0.16.0.dist-info}/METADATA +3 -3
  41. swcgeom-0.16.0.dist-info/RECORD +67 -0
  42. {swcgeom-0.14.0.dist-info → swcgeom-0.16.0.dist-info}/WHEEL +1 -1
  43. swcgeom-0.14.0.dist-info/RECORD +0 -62
  44. {swcgeom-0.14.0.dist-info → swcgeom-0.16.0.dist-info}/LICENSE +0 -0
  45. {swcgeom-0.14.0.dist-info → swcgeom-0.16.0.dist-info}/top_level.txt +0 -0
swcgeom/images/io.py CHANGED
@@ -1,6 +1,5 @@
1
1
  """Read and write image stack."""
2
2
 
3
-
4
3
  import os
5
4
  import re
6
5
  import warnings
@@ -9,11 +8,13 @@ from functools import cache, lru_cache
9
8
  from typing import (
10
9
  Any,
11
10
  Callable,
11
+ Generic,
12
12
  Iterable,
13
13
  List,
14
14
  Literal,
15
15
  Optional,
16
16
  Tuple,
17
+ TypeVar,
17
18
  cast,
18
19
  overload,
19
20
  )
@@ -27,6 +28,8 @@ from v3dpy.loaders import PBD, Raw
27
28
  __all__ = ["read_imgs", "save_tiff", "read_images"]
28
29
 
29
30
  Vec3i = Tuple[int, int, int]
31
+ ScalarType = TypeVar("ScalarType", bound=np.generic, covariant=True)
32
+
30
33
  RE_TERAFLY_ROOT = re.compile(r"^RES\((\d+)x(\d+)x(\d+)\)$")
31
34
  RE_TERAFLY_NAME = re.compile(r"^\d+(_\d+)?(_\d+)?")
32
35
 
@@ -46,30 +49,30 @@ AXES_ORDER = {
46
49
  }
47
50
 
48
51
 
49
- class ImageStack(ABC):
52
+ class ImageStack(ABC, Generic[ScalarType]):
50
53
  """Image stack."""
51
54
 
52
55
  # fmt: off
53
56
  @overload
54
57
  @abstractmethod
55
- def __getitem__(self, key: int) -> npt.NDArray[np.float32]: ... # array of shape (Y, Z, C)
58
+ def __getitem__(self, key: int) -> npt.NDArray[ScalarType]: ... # array of shape (Y, Z, C)
56
59
  @overload
57
60
  @abstractmethod
58
- def __getitem__(self, key: Tuple[int, int]) -> npt.NDArray[np.float32]: ... # array of shape (Z, C)
61
+ def __getitem__(self, key: Tuple[int, int]) -> npt.NDArray[ScalarType]: ... # array of shape (Z, C)
59
62
  @overload
60
63
  @abstractmethod
61
- def __getitem__(self, key: Tuple[int, int, int]) -> npt.NDArray[np.float32]: ... # array of shape (C,)
64
+ def __getitem__(self, key: Tuple[int, int, int]) -> npt.NDArray[ScalarType]: ... # array of shape (C,)
62
65
  @overload
63
66
  @abstractmethod
64
- def __getitem__(self, key: Tuple[int, int, int, int]) -> np.float32: ... # value
67
+ def __getitem__(self, key: Tuple[int, int, int, int]) -> ScalarType: ... # value
65
68
  @overload
66
69
  @abstractmethod
67
70
  def __getitem__(
68
71
  self, key: slice | Tuple[slice, slice] | Tuple[slice, slice, slice] | Tuple[slice, slice, slice, slice],
69
- ) -> npt.NDArray[np.float32]: ... # array of shape (X, Y, Z, C)
72
+ ) -> npt.NDArray[ScalarType]: ... # array of shape (X, Y, Z, C)
70
73
  @overload
71
74
  @abstractmethod
72
- def __getitem__(self, key: npt.NDArray[np.integer[Any]]) -> npt.NDArray[np.float32]: ...
75
+ def __getitem__(self, key: npt.NDArray[np.integer[Any]]) -> npt.NDArray[ScalarType]: ...
73
76
  # fmt: on
74
77
  @abstractmethod
75
78
  def __getitem__(self, key):
@@ -82,7 +85,7 @@ class ImageStack(ABC):
82
85
  """
83
86
  raise NotImplementedError()
84
87
 
85
- def get_full(self) -> npt.NDArray[np.float32]:
88
+ def get_full(self) -> npt.NDArray[ScalarType]:
86
89
  """Get full image stack.
87
90
 
88
91
  Notes
@@ -96,8 +99,19 @@ class ImageStack(ABC):
96
99
  raise NotImplementedError()
97
100
 
98
101
 
99
- def read_imgs(fname: str, **kwargs) -> ImageStack:
102
+ # fmt:off
103
+ @overload
104
+ def read_imgs(fname: str, *, dtype: ScalarType, **kwargs) -> ImageStack[ScalarType]: ...
105
+ @overload
106
+ def read_imgs(fname: str, *, dtype: None =..., **kwargs) -> ImageStack[np.float32]: ...
107
+ # fmt:on
108
+
109
+
110
+ def read_imgs(fname: str, *, dtype=None, **kwargs): # type: ignore
100
111
  """Read image stack."""
112
+
113
+ kwargs["dtype"] = dtype or np.float32
114
+
101
115
  ext = os.path.splitext(fname)[-1]
102
116
  if ext in [".tif", ".tiff"]:
103
117
  return TiffImageStack(fname, **kwargs)
@@ -191,7 +205,7 @@ def save_tiff(
191
205
  tifffile.imwrite(fname, data, **kwargs)
192
206
 
193
207
 
194
- class NDArrayImageStack(ImageStack):
208
+ class NDArrayImageStack(ImageStack[ScalarType]):
195
209
  """NDArray image stack."""
196
210
 
197
211
  def __init__(
@@ -199,6 +213,8 @@ class NDArrayImageStack(ImageStack):
199
213
  imgs: npt.NDArray[Any],
200
214
  swap_xy: Optional[bool] = None,
201
215
  filp_xy: Optional[bool] = None,
216
+ *,
217
+ dtype: ScalarType,
202
218
  ) -> None:
203
219
  super().__init__()
204
220
 
@@ -206,10 +222,6 @@ class NDArrayImageStack(ImageStack):
206
222
  imgs = np.expand_dims(imgs, -1)
207
223
  assert imgs.ndim == 4, "Should be shape of (X, Y, Z, C)"
208
224
 
209
- sclar_factor = 1.0
210
- if np.issubdtype((dtype := imgs.dtype), np.unsignedinteger):
211
- sclar_factor /= UINT_MAX[dtype]
212
-
213
225
  if swap_xy is not None:
214
226
  warnings.warn(
215
227
  "flag `swap_xy` now is unnecessary, tifffile will "
@@ -231,12 +243,18 @@ class NDArrayImageStack(ImageStack):
231
243
  if filp_xy is True:
232
244
  imgs = np.flip(imgs, (0, 1)) # (X, Y, Z, C)
233
245
 
234
- self.imgs = imgs.astype(np.float32) * sclar_factor
246
+ dtype_raw = imgs.dtype
247
+ self.imgs = imgs.astype(dtype)
248
+ if np.issubdtype(dtype, np.floating) and np.issubdtype(
249
+ dtype_raw, np.unsignedinteger
250
+ ): # TODO: add a option to disable this
251
+ sclar_factor = 1.0 / UINT_MAX[imgs.dtype]
252
+ self.imgs *= sclar_factor
235
253
 
236
254
  def __getitem__(self, key):
237
255
  return self.imgs.__getitem__(key)
238
256
 
239
- def get_full(self) -> npt.NDArray[np.float32]:
257
+ def get_full(self) -> npt.NDArray[ScalarType]:
240
258
  return self.imgs
241
259
 
242
260
  @property
@@ -244,7 +262,7 @@ class NDArrayImageStack(ImageStack):
244
262
  return cast(Tuple[int, int, int, int], self.imgs.shape)
245
263
 
246
264
 
247
- class TiffImageStack(NDArrayImageStack):
265
+ class TiffImageStack(NDArrayImageStack[ScalarType]):
248
266
  """Tiff image stack."""
249
267
 
250
268
  def __init__(
@@ -252,6 +270,8 @@ class TiffImageStack(NDArrayImageStack):
252
270
  fname: str,
253
271
  swap_xy: Optional[bool] = None,
254
272
  filp_xy: Optional[bool] = None,
273
+ *,
274
+ dtype: ScalarType,
255
275
  **kwargs,
256
276
  ) -> None:
257
277
  with tifffile.TiffFile(fname, **kwargs) as f:
@@ -265,10 +285,10 @@ class TiffImageStack(NDArrayImageStack):
265
285
 
266
286
  orders = [AXES_ORDER[c] for c in axes]
267
287
  imgs = imgs.transpose(np.argsort(orders))
268
- super().__init__(imgs, swap_xy=swap_xy, filp_xy=filp_xy)
288
+ super().__init__(imgs, swap_xy=swap_xy, filp_xy=filp_xy, dtype=dtype)
269
289
 
270
290
 
271
- class NrrdImageStack(NDArrayImageStack):
291
+ class NrrdImageStack(NDArrayImageStack[ScalarType]):
272
292
  """Nrrd image stack."""
273
293
 
274
294
  def __init__(
@@ -276,45 +296,53 @@ class NrrdImageStack(NDArrayImageStack):
276
296
  fname: str,
277
297
  swap_xy: Optional[bool] = None,
278
298
  filp_xy: Optional[bool] = None,
299
+ *,
300
+ dtype: ScalarType,
279
301
  **kwargs,
280
302
  ) -> None:
281
303
  imgs, header = nrrd.read(fname, **kwargs)
282
- super().__init__(imgs, swap_xy=swap_xy, filp_xy=filp_xy)
304
+ super().__init__(imgs, swap_xy=swap_xy, filp_xy=filp_xy, dtype=dtype)
283
305
  self.header = header
284
306
 
285
307
 
286
- class V3dImageStack(NDArrayImageStack):
308
+ class V3dImageStack(NDArrayImageStack[ScalarType]):
287
309
  """v3d image stack."""
288
310
 
289
- def __init__(self, fname: str, loader: Raw | PBD, **kwargs) -> None:
311
+ def __init__(
312
+ self, fname: str, loader: Raw | PBD, *, dtype: ScalarType, **kwargs
313
+ ) -> None:
290
314
  r = loader()
291
315
  imgs = r.load(fname)
292
- super().__init__(imgs, **kwargs)
316
+ super().__init__(imgs, dtype=dtype, **kwargs)
293
317
 
294
318
 
295
- class V3drawImageStack(V3dImageStack):
319
+ class V3drawImageStack(V3dImageStack[ScalarType]):
296
320
  """v3draw image stack."""
297
321
 
298
- def __init__(self, fname: str, **kwargs) -> None:
299
- super().__init__(fname, loader=Raw, **kwargs)
322
+ def __init__(self, fname: str, *, dtype: ScalarType, **kwargs) -> None:
323
+ super().__init__(fname, loader=Raw, dtype=dtype, **kwargs)
300
324
 
301
325
 
302
- class V3dpbdImageStack(V3dImageStack):
326
+ class V3dpbdImageStack(V3dImageStack[ScalarType]):
303
327
  """v3dpbd image stack."""
304
328
 
305
- def __init__(self, fname: str, **kwargs) -> None:
306
- super().__init__(fname, loader=PBD, **kwargs)
329
+ def __init__(self, fname: str, *, dtype: ScalarType, **kwargs) -> None:
330
+ super().__init__(fname, loader=PBD, dtype=dtype, **kwargs)
307
331
 
308
332
 
309
- class TeraflyImageStack(ImageStack):
333
+ class TeraflyImageStack(ImageStack[ScalarType]):
310
334
  """TeraFly image stack.
311
335
 
336
+ TeraFly is a terabytes of multidimensional volumetric images file
337
+ format as described in [1]_.
338
+
312
339
  References
313
340
  ----------
314
- [1] Bria, Alessandro, Giulio Iannello, Leonardo Onofri, and
315
- Hanchuan Peng. “TeraFly: Real-Time Three-Dimensional Visualization
316
- and Annotation of Terabytes of Multidimensional Volumetric Images.”
317
- Nature Methods 13, no. 3 (March 2016): 192-94. https://doi.org/10.1038/nmeth.3767.
341
+ .. [1] Bria, Alessandro, Giulio Iannello, Leonardo Onofri, and
342
+ Hanchuan Peng. “TeraFly: Real-Time Three-Dimensional
343
+ Visualization and Annotation of Terabytes of Multidimensional
344
+ Volumetric Images.” Nature Methods 13,
345
+ no. 3 (March 2016): 192-94. https://doi.org/10.1038/nmeth.3767.
318
346
 
319
347
  Notes
320
348
  -----
@@ -328,21 +356,26 @@ class TeraflyImageStack(ImageStack):
328
356
  _listdir: Callable[[str], List[str]]
329
357
  _read_patch: Callable[[str], npt.NDArray]
330
358
 
331
- def __init__(self, root: str, *, lru_maxsize: int | None = 128) -> None:
359
+ def __init__(
360
+ self, root: str, *, dtype: ScalarType, lru_maxsize: int | None = 128
361
+ ) -> None:
332
362
  r"""
333
363
  Parameters
334
364
  ----------
335
365
  root : str
336
366
  The root of terafly which contains directories named as
337
367
  `RES(YxXxZ)`.
368
+ dtype : np.dtype
338
369
  lru_maxsize : int or None, default to 128
339
370
  Forwarding to `functools.lru_cache`. A decompressed array
340
371
  size of (256, 256, 256, 1), which is the typical size of
341
372
  terafly image stack, takes about 256 * 256 * 256 * 1 *
342
373
  4B = 64MB. A cache size of 128 requires about 8GB memeory.
343
374
  """
375
+
344
376
  super().__init__()
345
377
  self.root = root
378
+ self.dtype = dtype
346
379
  self.res, self.res_dirs, self.res_patch_sizes = self.get_resolutions(root)
347
380
 
348
381
  @cache
@@ -350,8 +383,8 @@ class TeraflyImageStack(ImageStack):
350
383
  return os.listdir(path)
351
384
 
352
385
  @lru_cache(maxsize=lru_maxsize)
353
- def read_patch(path: str) -> npt.NDArray[np.float32]:
354
- return read_imgs(path).get_full()
386
+ def read_patch(path: str) -> npt.NDArray[ScalarType]:
387
+ return read_imgs(path, dtype=dtype).get_full()
355
388
 
356
389
  self._listdir, self._read_patch = listdir, read_patch
357
390
 
@@ -382,7 +415,7 @@ class TeraflyImageStack(ImageStack):
382
415
 
383
416
  def get_patch(
384
417
  self, starts, ends, strides: int | Vec3i = 1, res_level=-1
385
- ) -> npt.NDArray[np.float32]:
418
+ ) -> npt.NDArray[ScalarType]:
386
419
  """Get patch of image stack.
387
420
 
388
421
  Returns
@@ -397,7 +430,7 @@ class TeraflyImageStack(ImageStack):
397
430
  assert np.equal(strides, [1, 1, 1]).all() # TODO: support stride
398
431
 
399
432
  shape_out = np.concatenate([ends - starts, [1]])
400
- out = np.zeros(shape_out, dtype=np.float32)
433
+ out = np.zeros(shape_out, dtype=self.dtype)
401
434
  self._get_range(starts, ends, res_level, out=out)
402
435
 
403
436
  # flip y-axis to makes it a left-handed coordinate system
@@ -604,6 +637,12 @@ class GrayImageStack:
604
637
 
605
638
 
606
639
  def read_images(*args, **kwargs) -> GrayImageStack:
640
+ """Read images.
641
+
642
+ .. deprecated:: 0.16.0
643
+ Use :meth:`read_imgs` instead.
644
+ """
645
+
607
646
  warnings.warn(
608
647
  "`read_images` has been replaced by `read_imgs` because it"
609
648
  "provide rgb support, and this will be removed in next version",
@@ -3,9 +3,11 @@
3
3
  from swcgeom.transforms.base import *
4
4
  from swcgeom.transforms.branch import *
5
5
  from swcgeom.transforms.geometry import *
6
+ from swcgeom.transforms.image_preprocess import *
6
7
  from swcgeom.transforms.image_stack import *
7
8
  from swcgeom.transforms.images import *
8
9
  from swcgeom.transforms.mst import *
10
+ from swcgeom.transforms.neurolucida_asc import *
9
11
  from swcgeom.transforms.path import *
10
12
  from swcgeom.transforms.population import *
11
13
  from swcgeom.transforms.tree import *
@@ -1,7 +1,7 @@
1
1
  """Transformation in tree."""
2
2
 
3
3
  from abc import ABC, abstractmethod
4
- from typing import Any, Generic, TypeVar, cast, overload
4
+ from typing import Any, Generic, TypeVar, overload
5
5
 
6
6
  __all__ = ["Transform", "Transforms", "Identity"]
7
7
 
@@ -15,26 +15,49 @@ class Transform(ABC, Generic[T, K]):
15
15
  r"""An abstract class representing a :class:`Transform`.
16
16
 
17
17
  All transforms that represent a map from `T` to `K`.
18
-
19
- Methods
20
- -------
21
- __call__(x: T) -> K
22
- All subclasses should overwrite :meth:`__call__`, supporting
23
- applying transform in `x`.
24
- __repr__() -> str
25
- Subclasses could also optionally overwrite :meth:`__repr__`.
26
- Avoid using the underscore `_` because it is used by
27
- `Transforms`. If not provied, class name will be a default
28
- value.
29
18
  """
30
19
 
31
20
  @abstractmethod
32
21
  def __call__(self, x: T) -> K:
33
- """Apply transform."""
22
+ """Apply transform.
23
+
24
+ Notes
25
+ -----
26
+ All subclasses should overwrite :meth:`__call__`, supporting
27
+ applying transform in `x`.
28
+ """
34
29
  raise NotImplementedError()
35
30
 
36
31
  def __repr__(self) -> str:
37
- return self.__class__.__name__
32
+ classname = self.__class__.__name__
33
+ repr_ = self.extra_repr()
34
+ return f"{classname}({repr_})"
35
+
36
+ def extra_repr(self):
37
+ """Provides a human-friendly representation of the module.
38
+
39
+ This method extends the basic string representation provided by
40
+ `__repr__` method. It is designed to display additional details
41
+ about the module's parameters or its specific configuration,
42
+ which can be particularly useful for debugging and model
43
+ architecture introspection.
44
+
45
+ Examples
46
+ --------
47
+ class Foo(Transform[T, K]):
48
+ def __init__(self, my_parameter: int = 1):
49
+ self.my_parameter = my_parameter
50
+
51
+ def extra_repr(self):
52
+ return f"my_parameter={self.my_parameter}"
53
+
54
+ Notes
55
+ -----
56
+ This method should be overridden in custom modules to provide
57
+ specific details relevant to the module's functionality and
58
+ configuration.
59
+ """
60
+ return ""
38
61
 
39
62
 
40
63
  class Transforms(Transform[T, K]):
@@ -81,23 +104,20 @@ class Transforms(Transform[T, K]):
81
104
  for transform in self.transforms:
82
105
  x = transform(x)
83
106
 
84
- return cast(K, x)
107
+ return x # type: ignore
85
108
 
86
109
  def __getitem__(self, idx: int) -> Transform[Any, Any]:
87
110
  return self.transforms[idx]
88
111
 
89
- def __repr__(self) -> str:
90
- return "_".join([str(transform) for transform in self])
91
-
92
112
  def __len__(self) -> int:
93
113
  return len(self.transforms)
94
114
 
115
+ def extra_repr(self) -> str:
116
+ return ", ".join([str(transform) for transform in self])
117
+
95
118
 
96
119
  class Identity(Transform[T, T]):
97
120
  """Resurn input as-is."""
98
121
 
99
122
  def __call__(self, x: T) -> T:
100
123
  return x
101
-
102
- def __repr__(self) -> str:
103
- return ""
@@ -48,9 +48,6 @@ class BranchLinearResampler(_BranchResampler):
48
48
  super().__init__()
49
49
  self.n_nodes = n_nodes
50
50
 
51
- def __repr__(self) -> str:
52
- return f"BranchLinearResampler-{self.n_nodes}"
53
-
54
51
  def resample(self, xyzr: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
55
52
  """Resampling by linear interpolation, DO NOT keep original node.
56
53
 
@@ -75,6 +72,9 @@ class BranchLinearResampler(_BranchResampler):
75
72
  r = np.interp(xvals, xp, xyzr[:, 3])
76
73
  return cast(npt.NDArray[np.float32], np.stack([x, y, z, r], axis=1))
77
74
 
75
+ def extra_repr(self):
76
+ return f"n_nodes={self.n_nodes}"
77
+
78
78
 
79
79
  class BranchConvSmoother(Transform[Branch, Branch[DictSWC]]):
80
80
  r"""Smooth the branch by sliding window."""
@@ -100,8 +100,8 @@ class BranchConvSmoother(Transform[Branch, Branch[DictSWC]]):
100
100
 
101
101
  return x
102
102
 
103
- def __repr__(self) -> str:
104
- return f"BranchConvSmoother-{self.n_nodes}"
103
+ def extra_repr(self):
104
+ return f"n_nodes={self.n_nodes}"
105
105
 
106
106
 
107
107
  class BranchStandardizer(Transform[Branch, Branch[DictSWC]]):
@@ -72,6 +72,9 @@ class RadiusReseter(Generic[T], Transform[T, T]):
72
72
  new_tree.ndata[new_tree.names.r] = r
73
73
  return new_tree
74
74
 
75
+ def extra_repr(self):
76
+ return f"r={self.r:.4f}"
77
+
75
78
 
76
79
  class AffineTransform(Generic[T], Transform[T, T]):
77
80
  """Apply affine matrix."""
@@ -85,10 +88,18 @@ class AffineTransform(Generic[T], Transform[T, T]):
85
88
  tm: npt.NDArray[np.float32],
86
89
  center: Center = "origin",
87
90
  *,
88
- fmt: str,
91
+ fmt: Optional[str] = None,
89
92
  names: Optional[SWCNames] = None,
90
93
  ) -> None:
91
- self.tm, self.center, self.fmt = tm, center, fmt
94
+ self.tm, self.center = tm, center
95
+
96
+ if fmt is not None:
97
+ warnings.warn(
98
+ "`fmt` parameter is no longer needed, now use the "
99
+ "extra_repr(), you can directly remove it.",
100
+ DeprecationWarning,
101
+ )
102
+
92
103
  if names is not None:
93
104
  warnings.warn(
94
105
  "`name` parameter is no longer needed, now use the "
@@ -111,9 +122,6 @@ class AffineTransform(Generic[T], Transform[T, T]):
111
122
 
112
123
  return self.apply(x, tm)
113
124
 
114
- def __repr__(self) -> str:
115
- return self.fmt
116
-
117
125
  @staticmethod
118
126
  def apply(x: T, tm: npt.NDArray[np.float32]) -> T:
119
127
  xyzw = x.xyzw().dot(tm.T).T
@@ -130,8 +138,11 @@ class Translate(Generic[T], AffineTransform[T]):
130
138
  """Translate SWC."""
131
139
 
132
140
  def __init__(self, tx: float, ty: float, tz: float, **kwargs) -> None:
133
- fmt = f"Translate-{tx}-{ty}-{tz}"
134
- super().__init__(translate3d(tx, ty, tz), fmt=fmt, **kwargs)
141
+ super().__init__(translate3d(tx, ty, tz), **kwargs)
142
+ self.tx, self.ty, self.tz = tx, ty, tz
143
+
144
+ def extra_repr(self):
145
+ return f"tx={self.tx:.4f}, ty={self.ty:.4f}, tz={self.tz:.4f}"
135
146
 
136
147
  @classmethod
137
148
  def transform(cls, x: T, tx: float, ty: float, tz: float, **kwargs) -> T:
@@ -158,8 +169,7 @@ class Scale(Generic[T], AffineTransform[T]):
158
169
  def __init__(
159
170
  self, sx: float, sy: float, sz: float, center: Center = "root", **kwargs
160
171
  ) -> None:
161
- fmt = f"Scale-{sx}-{sy}-{sz}"
162
- super().__init__(scale3d(sx, sy, sz), center=center, fmt=fmt, **kwargs)
172
+ super().__init__(scale3d(sx, sy, sz), center=center, **kwargs)
163
173
 
164
174
  @classmethod
165
175
  def transform( # pylint: disable=too-many-arguments
@@ -180,6 +190,12 @@ class Rotate(Generic[T], AffineTransform[T]):
180
190
  ) -> None:
181
191
  fmt = f"Rotate-{n[0]}-{n[1]}-{n[2]}-{theta:.4f}"
182
192
  super().__init__(rotate3d(n, theta), center=center, fmt=fmt, **kwargs)
193
+ self.n = n
194
+ self.theta = theta
195
+ self.center = center
196
+
197
+ def extra_repr(self):
198
+ return f"n={self.n}, theta={self.theta:.4f}, center={self.center}" # TODO: imporve format of n
183
199
 
184
200
  @classmethod
185
201
  def transform(
@@ -197,9 +213,11 @@ class RotateX(Generic[T], AffineTransform[T]):
197
213
  """Rotate SWC with x-axis."""
198
214
 
199
215
  def __init__(self, theta: float, center: Center = "root", **kwargs) -> None:
200
- super().__init__(
201
- rotate3d_x(theta), center=center, fmt=f"RotateX-{theta}", **kwargs
202
- )
216
+ super().__init__(rotate3d_x(theta), center=center, **kwargs)
217
+ self.theta = theta
218
+
219
+ def extra_repr(self):
220
+ return f"center={self.center}, theta={self.theta:.4f}"
203
221
 
204
222
  @classmethod
205
223
  def transform(cls, x: T, theta: float, center: Center = "root", **kwargs) -> T:
@@ -210,9 +228,12 @@ class RotateY(Generic[T], AffineTransform[T]):
210
228
  """Rotate SWC with y-axis."""
211
229
 
212
230
  def __init__(self, theta: float, center: Center = "root", **kwargs) -> None:
213
- super().__init__(
214
- rotate3d_y(theta), center=center, fmt=f"RotateX-{theta}", **kwargs
215
- )
231
+ super().__init__(rotate3d_y(theta), center=center, **kwargs)
232
+ self.theta = theta
233
+ self.center = center
234
+
235
+ def extra_repr(self):
236
+ return f"theta={self.theta:.4f}, center={self.center}"
216
237
 
217
238
  @classmethod
218
239
  def transform(cls, x: T, theta: float, center: Center = "root", **kwargs) -> T:
@@ -223,9 +244,12 @@ class RotateZ(Generic[T], AffineTransform[T]):
223
244
  """Rotate SWC with z-axis."""
224
245
 
225
246
  def __init__(self, theta: float, center: Center = "root", **kwargs) -> None:
226
- super().__init__(
227
- rotate3d_z(theta), center=center, fmt=f"RotateX-{theta}", **kwargs
228
- )
247
+ super().__init__(rotate3d_z(theta), center=center, **kwargs)
248
+ self.theta = theta
249
+ self.center = center
250
+
251
+ def extra_repr(self):
252
+ return f"theta={self.theta:.4f}, center={self.center}"
229
253
 
230
254
  @classmethod
231
255
  def transform(cls, x: T, theta: float, center: Center = "root", **kwargs) -> T:
@@ -0,0 +1,100 @@
1
+ """Image stack pre-processing."""
2
+
3
+ import numpy as np
4
+ import numpy.typing as npt
5
+ from scipy.fftpack import fftn, fftshift, ifftn
6
+ from scipy.ndimage import gaussian_filter, minimum_filter
7
+
8
+ from swcgeom.transforms.base import Transform
9
+
10
+ __all__ = ["SGuoImPreProcess"]
11
+
12
+
13
+ class SGuoImPreProcess(Transform[npt.NDArray[np.uint8], npt.NDArray[np.uint8]]):
14
+ """Single-Neuron Image Enhancement.
15
+
16
+ Implementation of the image enhancement method described in the paper:
17
+
18
+ Shuxia Guo, Xuan Zhao, Shengdian Jiang, Liya Ding, Hanchuan Peng,
19
+ Image enhancement to leverage the 3D morphological reconstruction
20
+ of single-cell neurons, Bioinformatics, Volume 38, Issue 2,
21
+ January 2022, Pages 503–512, https://doi.org/10.1093/bioinformatics/btab638
22
+ """
23
+
24
+ def __call__(self, x: npt.NDArray[np.uint8]) -> npt.NDArray[np.uint8]:
25
+ # TODO: support np.float32
26
+ assert x.dtype == np.uint8, "Image must be in uint8 format"
27
+ x = self.sigmoid_adjustment(x)
28
+ x = self.subtract_min_along_z(x)
29
+ x = self.bilateral_filter_3d(x)
30
+ x = self.high_pass_fft(x)
31
+ return x
32
+
33
+ @staticmethod
34
+ def sigmoid_adjustment(
35
+ image: npt.NDArray[np.uint8], sigma: float = 3, percentile: float = 25
36
+ ) -> npt.NDArray[np.uint8]:
37
+ image_normalized = image / 255.0
38
+ u = np.percentile(image_normalized, percentile)
39
+ adjusted = 1 / (1 + np.exp(-sigma * (image_normalized - u)))
40
+ adjusted_rescaled = (adjusted * 255).astype(np.uint8)
41
+ return adjusted_rescaled
42
+
43
+ @staticmethod
44
+ def subtract_min_along_z(image: npt.NDArray[np.uint8]) -> npt.NDArray[np.uint8]:
45
+ min_along_z = minimum_filter(
46
+ image,
47
+ size=(1, 1, image.shape[2], 1),
48
+ mode="constant",
49
+ cval=np.max(image).item(),
50
+ )
51
+ subtracted = image - min_along_z
52
+ return subtracted
53
+
54
+ @staticmethod
55
+ def bilateral_filter_3d(
56
+ image: npt.NDArray[np.uint8], spatial_sigma=(1, 1, 0.33), range_sigma=35
57
+ ) -> npt.NDArray[np.uint8]:
58
+ # initialize the output image
59
+ filtered_image = np.zeros_like(image)
60
+
61
+ spatial_gaussian = gaussian_filter(image, spatial_sigma)
62
+
63
+ # traverse each pixel to perform bilateral filtering
64
+ # TODO: optimization is needed
65
+ for z in range(image.shape[2]):
66
+ for y in range(image.shape[1]):
67
+ for x in range(image.shape[0]):
68
+ value = image[x, y, z]
69
+ range_weight = np.exp(
70
+ -((image - value) ** 2) / (2 * range_sigma**2)
71
+ )
72
+ weights = spatial_gaussian * range_weight
73
+ filtered_image[x, y, z] = np.sum(image * weights) / np.sum(weights)
74
+
75
+ return filtered_image
76
+
77
+ @staticmethod
78
+ def high_pass_fft(image: npt.NDArray[np.uint8]) -> npt.NDArray[np.uint8]:
79
+ # fft
80
+ fft_image = fftn(image)
81
+ fft_shifted = fftshift(fft_image)
82
+
83
+ # create a high-pass filter
84
+ h, w, d, _ = image.shape
85
+ x, y, z = np.ogrid[:h, :w, :d]
86
+ center = (h / 2, w / 2, d / 2)
87
+ distance = np.sqrt(
88
+ (x - center[0]) ** 2 + (y - center[1]) ** 2 + (z - center[2]) ** 2
89
+ )
90
+ # adjust this threshold to control the filtering strength
91
+ high_pass_mask = distance > (d // 4)
92
+ # apply the high-pass filter
93
+ fft_shifted *= high_pass_mask
94
+
95
+ # inverse fft
96
+ fft_unshifted = np.fft.ifftshift(fft_shifted)
97
+ filtered_image = np.real(ifftn(fft_unshifted))
98
+
99
+ filtered_rescaled = np.clip(filtered_image, 0, 255).astype(np.uint8)
100
+ return filtered_rescaled