swcgeom 0.15.0__py3-none-any.whl → 0.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of swcgeom might be problematic. Click here for more details.

Files changed (42)
  1. swcgeom/_version.py +2 -2
  2. swcgeom/analysis/__init__.py +1 -3
  3. swcgeom/analysis/feature_extractor.py +3 -3
  4. swcgeom/analysis/{node_features.py → features.py} +105 -3
  5. swcgeom/analysis/lmeasure.py +821 -0
  6. swcgeom/analysis/sholl.py +31 -2
  7. swcgeom/core/__init__.py +4 -0
  8. swcgeom/core/branch.py +9 -4
  9. swcgeom/core/{segment.py → compartment.py} +14 -9
  10. swcgeom/core/node.py +0 -8
  11. swcgeom/core/path.py +21 -6
  12. swcgeom/core/population.py +47 -7
  13. swcgeom/core/swc_utils/assembler.py +12 -1
  14. swcgeom/core/swc_utils/base.py +12 -5
  15. swcgeom/core/swc_utils/checker.py +12 -2
  16. swcgeom/core/tree.py +34 -37
  17. swcgeom/core/tree_utils.py +4 -0
  18. swcgeom/images/augmentation.py +6 -1
  19. swcgeom/images/contrast.py +107 -0
  20. swcgeom/images/folder.py +71 -14
  21. swcgeom/images/io.py +74 -88
  22. swcgeom/transforms/__init__.py +2 -0
  23. swcgeom/transforms/image_preprocess.py +100 -0
  24. swcgeom/transforms/image_stack.py +1 -4
  25. swcgeom/transforms/images.py +176 -5
  26. swcgeom/transforms/mst.py +5 -5
  27. swcgeom/transforms/neurolucida_asc.py +495 -0
  28. swcgeom/transforms/tree.py +5 -1
  29. swcgeom/utils/__init__.py +1 -0
  30. swcgeom/utils/neuromorpho.py +425 -300
  31. swcgeom/utils/numpy_helper.py +14 -4
  32. swcgeom/utils/plotter_2d.py +130 -0
  33. swcgeom/utils/renderer.py +28 -139
  34. swcgeom/utils/sdf.py +5 -1
  35. {swcgeom-0.15.0.dist-info → swcgeom-0.17.0.dist-info}/METADATA +3 -3
  36. swcgeom-0.17.0.dist-info/RECORD +65 -0
  37. {swcgeom-0.15.0.dist-info → swcgeom-0.17.0.dist-info}/WHEEL +1 -1
  38. swcgeom/analysis/branch_features.py +0 -67
  39. swcgeom/analysis/path_features.py +0 -37
  40. swcgeom-0.15.0.dist-info/RECORD +0 -62
  41. {swcgeom-0.15.0.dist-info → swcgeom-0.17.0.dist-info}/LICENSE +0 -0
  42. {swcgeom-0.15.0.dist-info → swcgeom-0.17.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,107 @@
1
+ """The contrast of an image.
2
+
3
+ Notes
4
+ -----
5
+ This is experimental code, and the API is subject to change.
6
+ """
7
+
8
+ from typing import Optional, overload
9
+
10
+ import numpy as np
11
+ import numpy.typing as npt
12
+
13
+ __all__ = ["contrast_std", "contrast_michelson", "contrast_rms", "contrast_weber"]
14
+
15
+ Array3D = npt.NDArray[np.float32]
16
+
17
+
18
@overload
def contrast_std(image: Array3D) -> float:
    """Get the std contrast of an image stack.

    Parameters
    ----------
    image : ndarray

    Returns
    -------
    contrast : float
        Standard deviation of all intensities in `image`.
    """
    ...


@overload
def contrast_std(image: Array3D, contrast: float) -> Array3D:
    """Adjust the contrast of an image stack.

    Parameters
    ----------
    image : ndarray
    contrast : float
        The contrast adjustment factor. 1.0 leaves an image whose values
        lie in [0, 1] unchanged.

    Returns
    -------
    image : ndarray
        The adjusted image, clipped to [0, 1].
    """
    ...


def contrast_std(image: Array3D, contrast: Optional[float] = None):
    """Get or adjust the std contrast of an image stack.

    Parameters
    ----------
    image : ndarray
    contrast : float, optional
        When omitted, the std contrast of `image` is returned. Otherwise
        `image` is scaled by this factor and clipped to [0, 1].

    Returns
    -------
    float or ndarray
        The std contrast, or the adjusted image.
    """
    if contrast is None:
        return np.std(image).item()

    # Scaling by `contrast` scales the std by the same factor; the clip
    # keeps the result inside the [0, 1] float intensity range.
    return np.clip(contrast * image, 0, 1)
56
+
57
+
58
def contrast_michelson(image: Array3D) -> float:
    """Get the Michelson contrast of an image stack.

    Computes ``(max - min) / (max + min)`` over all intensities.

    Parameters
    ----------
    image : ndarray
        Integer or floating intensities (presumably non-negative, as the
        Michelson formula requires).

    Returns
    -------
    contrast : float
        ``nan`` (with a RuntimeWarning) when max + min is zero, e.g. for
        an all-black image.
    """
    # Promote to float64 so `vmax + vmin` cannot overflow for integer
    # stacks (uint8: 200 + 100 would wrap around to 44).
    vmax = np.float64(np.max(image))
    vmin = np.float64(np.min(image))
    return ((vmax - vmin) / (vmax + vmin)).item()
73
+
74
+
75
def contrast_rms(imgs: npt.NDArray[np.float32]) -> float:
    """Get the RMS contrast of an image stack.

    RMS contrast is the root-mean-square deviation of the intensities
    from their mean, ``sqrt(mean((I - mean(I)) ** 2))``. A uniform
    image therefore has zero contrast.

    Parameters
    ----------
    imgs : ndarray

    Returns
    -------
    contrast : float
    """
    # np.std is exactly sqrt(mean((x - x.mean()) ** 2)); for integer input
    # it accumulates in float64, so uint8 stacks cannot overflow the way
    # squaring them directly would.
    return np.std(imgs).item()
88
+
89
+
90
def contrast_weber(imgs: Array3D, mask: npt.NDArray[np.bool_]) -> float:
    """Get the Weber contrast of an image stack.

    Computes ``(L_fg - L_bg) / L_bg``, where ``L_fg`` and ``L_bg`` are the
    mean intensities of the foreground and background voxels.

    Parameters
    ----------
    imgs : ndarray
    mask : ndarray of bool
        The mask to segment the foreground and background. 1 (True) for
        foreground, 0 (False) for background. Integer 0/1 masks are
        accepted.

    Returns
    -------
    contrast : float
        Infinite or ``nan`` (with a RuntimeWarning) when the background
        mean is zero or either region is empty.
    """
    # `where=` expects a boolean mask; normalise 0/1 integer masks so the
    # documented convention works for both.
    foreground = np.asarray(mask, dtype=bool)
    l_foreground = np.mean(imgs, where=foreground)
    l_background = np.mean(imgs, where=~foreground)
    return ((l_foreground - l_background) / l_background).item()
swcgeom/images/folder.py CHANGED
@@ -1,9 +1,10 @@
1
1
  """Image stack folder."""
2
2
 
3
+ import math
3
4
  import os
4
5
  import re
5
6
  import warnings
6
- from abc import ABC, abstractmethod
7
+ from dataclasses import dataclass
7
8
  from typing import (
8
9
  Callable,
9
10
  Generic,
@@ -18,21 +19,18 @@ from typing import (
18
19
 
19
20
  import numpy as np
20
21
  import numpy.typing as npt
22
+ from tqdm import tqdm
21
23
  from typing_extensions import Self
22
24
 
23
25
  from swcgeom.images.io import ScalarType, read_imgs
24
26
  from swcgeom.transforms import Identity, Transform
25
27
 
26
- __all__ = [
27
- "ImageStackFolder",
28
- "LabeledImageStackFolder",
29
- "PathImageStackFolder",
30
- ]
28
+ __all__ = ["ImageStackFolder", "LabeledImageStackFolder", "PathImageStackFolder"]
31
29
 
32
30
  T = TypeVar("T")
33
31
 
34
32
 
35
- class ImageStackFolderBase(Generic[ScalarType, T], ABC):
33
+ class ImageStackFolderBase(Generic[ScalarType, T]):
36
34
  """Image stack folder base."""
37
35
 
38
36
  files: List[str]
@@ -51,10 +49,6 @@ class ImageStackFolderBase(Generic[ScalarType, T], ABC):
51
49
  self.dtype = dtype or np.float32
52
50
  self.transform = transform or Identity() # type: ignore
53
51
 
54
- @abstractmethod
55
- def __getitem__(self, key: str, /) -> T:
56
- raise NotImplementedError()
57
-
58
52
  def __len__(self) -> int:
59
53
  return len(self.files)
60
54
 
@@ -78,6 +72,12 @@ class ImageStackFolderBase(Generic[ScalarType, T], ABC):
78
72
 
79
73
  @staticmethod
80
74
  def read_imgs(fname: str) -> npt.NDArray[np.float32]:
75
+ """Read images.
76
+
77
+ .. deprecated:: 0.16.0
78
+ Use :meth:`~swcgeom.images.io.read_imgs(fname).get_full()` instead.
79
+ """
80
+
81
81
  warnings.warn(
82
82
  "`ImageStackFolderBase.read_imgs` serves as a "
83
83
  "straightforward wrapper for `~swcgeom.images.io.read_imgs(fname).get_full()`. "
@@ -89,12 +89,67 @@ class ImageStackFolderBase(Generic[ScalarType, T], ABC):
89
89
  return read_imgs(fname).get_full()
90
90
 
91
91
 
92
@dataclass(frozen=True)
class Statistics:
    """Summary statistics of an image stack folder.

    Returned by `ImageStackFolder.stat`. Defaults: zero count, ``nan``
    extremes, and zero mean and variance.
    """

    count: int = 0  # number of image stacks scanned
    minimum: float = math.nan  # smallest intensity seen
    maximum: float = math.nan  # largest intensity seen
    mean: float = 0.0  # mean intensity
    variance: float = 0.0  # see `ImageStackFolder.stat` for the exact definition
99
+
100
+
92
101
  class ImageStackFolder(ImageStackFolderBase[ScalarType, T]):
93
102
  """Image stack folder."""
94
103
 
95
104
  def __getitem__(self, idx: int, /) -> T:
96
105
  return self._get(self.files[idx])
97
106
 
107
def stat(self, *, transform: bool = False, verbose: bool = False) -> Statistics:
    """Statistics of folder.

    Parameters
    ----------
    transform : bool, default to False
        Apply transform to the images. If True, you need to make
        sure the transformed data is a ndarray.
    verbose : bool, default to False
        Show a tqdm progress bar while scanning the folder.

    Returns
    -------
    stat : Statistics
        `minimum` / `maximum` are the global extremes and `mean` is the
        mean over all voxels. `variance` is the per-voxel sample variance
        across the images, averaged over voxel positions.

    Raises
    ------
    ValueError
        If the folder is empty.

    Notes
    -----
    We are asserting that the images are of the same shape.
    """
    vmin, vmax = math.inf, -math.inf
    n = 0
    mean = m2 = None

    indices = range(len(self))
    for idx in tqdm(indices) if verbose else indices:
        imgs = self[idx] if transform else self._read(self.files[idx])

        vmin = min(vmin, float(np.min(imgs)))
        vmax = max(vmax, float(np.max(imgs)))

        # Welford's online algorithm, element-wise over voxel positions.
        # Accumulate in float64: accumulators shaped like an integer stack
        # would make `mean += delta / n` fail, and float32 loses precision
        # over many images.
        if mean is None:
            mean = np.zeros(np.shape(imgs), dtype=np.float64)
            m2 = np.zeros(np.shape(imgs), dtype=np.float64)

        n += 1
        delta = imgs - mean
        mean += delta / n
        m2 += delta * (imgs - mean)

    if mean is None or m2 is None:  # n == 0
        raise ValueError("empty folder")

    variance = m2 / (n - 1) if n > 1 else np.zeros_like(mean)
    return Statistics(
        count=n,
        maximum=vmax,
        minimum=vmin,
        mean=np.mean(mean).item(),
        variance=np.mean(variance).item(),
    )
152
+
98
153
  @classmethod
99
154
  def from_dir(cls, root: str, *, pattern: Optional[str] = None, **kwargs) -> Self:
100
155
  """
@@ -106,6 +161,7 @@ class ImageStackFolder(ImageStackFolderBase[ScalarType, T]):
106
161
  **kwargs
107
162
  Pass to `cls.__init__`
108
163
  """
164
+
109
165
  return cls(cls.scan(root, pattern=pattern), **kwargs)
110
166
 
111
167
 
@@ -118,8 +174,8 @@ class LabeledImageStackFolder(ImageStackFolderBase[ScalarType, T]):
118
174
  super().__init__(files, **kwargs)
119
175
  self.labels = list(labels)
120
176
 
121
- def __getitem__(self, idx: int) -> Tuple[npt.NDArray[np.float32], int]:
122
- return self.read_imgs(self.files[idx]), self.labels[idx]
177
+ def __getitem__(self, idx: int) -> Tuple[T, int]:
178
+ return self._get(self.files[idx]), self.labels[idx]
123
179
 
124
180
  @classmethod
125
181
  def from_dir(
@@ -140,7 +196,7 @@ class LabeledImageStackFolder(ImageStackFolderBase[ScalarType, T]):
140
196
  return cls(files, labels, **kwargs)
141
197
 
142
198
 
143
- class PathImageStackFolder(ImageStackFolder[ScalarType, T]):
199
+ class PathImageStackFolder(ImageStackFolderBase[ScalarType, T]):
144
200
  """Image stack folder with relpath."""
145
201
 
146
202
  root: str
@@ -164,6 +220,7 @@ class PathImageStackFolder(ImageStackFolder[ScalarType, T]):
164
220
  **kwargs
165
221
  Pass to `cls.__init__`
166
222
  """
223
+
167
224
  return cls(cls.scan(root, pattern=pattern), root=root, **kwargs)
168
225
 
169
226
 
swcgeom/images/io.py CHANGED
@@ -107,26 +107,42 @@ def read_imgs(fname: str, *, dtype: None =..., **kwargs) -> ImageStack[np.float3
107
107
  # fmt:on
108
108
 
109
109
 
110
- def read_imgs(fname: str, *, dtype=None, **kwargs): # type: ignore
111
- """Read image stack."""
112
-
113
- kwargs["dtype"] = dtype or np.float32
114
-
115
- ext = os.path.splitext(fname)[-1]
116
- if ext in [".tif", ".tiff"]:
117
- return TiffImageStack(fname, **kwargs)
118
- if ext in [".nrrd"]:
119
- return NrrdImageStack(fname, **kwargs)
120
- if ext in [".v3dpbd"]:
121
- return V3dpbdImageStack(fname, **kwargs)
122
- if ext in [".v3draw"]:
123
- return V3drawImageStack(fname, **kwargs)
124
- if ext in [".npy", ".npz"]:
125
- return NDArrayImageStack(np.load(fname), **kwargs)
110
+ def read_imgs(fname: str, **kwargs): # type: ignore
111
+ """Read image stack.
112
+
113
+ Parameters
114
+ ----------
115
+ fname : str
116
+ The path of image stack.
117
+ dtype : np.dtype, default to `np.float32`
118
+ Casting data to specified dtype. If integer and float
119
+ conversions occur, they will be scaled (assuming floats are
120
+ between 0 and 1).
121
+ **kwargs : Dict[str, Any]
122
+ Forwarding to the corresponding reader.
123
+ """
124
+
125
+ kwargs.setdefault("dtype", np.float32)
126
+ if not os.path.exists(fname):
127
+ raise ValueError(f"image stack not exists: {fname}")
128
+
129
+ # match file extension
130
+ match os.path.splitext(fname)[-1]:
131
+ case ".tif" | ".tiff":
132
+ return TiffImageStack(fname, **kwargs)
133
+ case ".nrrd":
134
+ return NrrdImageStack(fname, **kwargs)
135
+ case ".v3dpbd":
136
+ return V3dpbdImageStack(fname, **kwargs)
137
+ case ".v3draw":
138
+ return V3drawImageStack(fname, **kwargs)
139
+ case ".npy":
140
+ return NDArrayImageStack(np.load(fname), **kwargs)
141
+
142
+ # try to read as terafly
126
143
  if TeraflyImageStack.is_root(fname):
127
144
  return TeraflyImageStack(fname, **kwargs)
128
- if not os.path.exists(fname):
129
- raise ValueError("image stack not exists")
145
+
130
146
  raise ValueError("unsupported image stack")
131
147
 
132
148
 
@@ -135,7 +151,6 @@ def save_tiff(
135
151
  fname: str,
136
152
  *,
137
153
  dtype: Optional[np.unsignedinteger | np.floating] = None,
138
- swap_xy: Optional[bool] = None,
139
154
  compression: str | Literal[False] = "zlib",
140
155
  **kwargs,
141
156
  ) -> None:
@@ -164,17 +179,6 @@ def save_tiff(
164
179
  data = np.expand_dims(data, -1) # (_, _, _) -> (_, _, _, C), C === 1
165
180
 
166
181
  axes = "ZXYC"
167
- if swap_xy is not None:
168
- warnings.warn(
169
- "flag `swap_xy` is easy to implement in user space and "
170
- "is more flexiable. Since this flag is rarely used, we "
171
- "decided to remove it in the next version",
172
- DeprecationWarning,
173
- )
174
- if swap_xy is True:
175
- axes = "ZYXC"
176
- data = data.swapaxes(0, 1) # (X, Y, _, _) -> (Y, X, _, _)
177
-
178
182
  assert data.ndim == 4, "should be an array of shape (X, Y, Z, C)"
179
183
  assert data.shape[-1] in [1, 3], "support 'miniblack' or 'rgb'"
180
184
 
@@ -209,12 +213,7 @@ class NDArrayImageStack(ImageStack[ScalarType]):
209
213
  """NDArray image stack."""
210
214
 
211
215
  def __init__(
212
- self,
213
- imgs: npt.NDArray[Any],
214
- swap_xy: Optional[bool] = None,
215
- filp_xy: Optional[bool] = None,
216
- *,
217
- dtype: ScalarType,
216
+ self, imgs: npt.NDArray[Any], *, dtype: Optional[ScalarType] = None
218
217
  ) -> None:
219
218
  super().__init__()
220
219
 
@@ -222,34 +221,22 @@ class NDArrayImageStack(ImageStack[ScalarType]):
222
221
  imgs = np.expand_dims(imgs, -1)
223
222
  assert imgs.ndim == 4, "Should be shape of (X, Y, Z, C)"
224
223
 
225
- if swap_xy is not None:
226
- warnings.warn(
227
- "flag `swap_xy` now is unnecessary, tifffile will "
228
- "automatically adjust dimensions according to "
229
- "`tags.axes`, so this flag will be removed in the next "
230
- " version",
231
- DeprecationWarning,
232
- )
233
- if swap_xy is True:
234
- imgs = imgs.swapaxes(0, 1) # (Y, X, _, _) -> (X, Y, _, _)
235
-
236
- if filp_xy is not None:
237
- warnings.warn(
238
- "flag `filp_xy` is easy to implement in user space and "
239
- "is more flexiable. Since this flag is rarely used, we "
240
- "decided to remove it in the next version",
241
- DeprecationWarning,
242
- )
243
- if filp_xy is True:
244
- imgs = np.flip(imgs, (0, 1)) # (X, Y, Z, C)
224
if dtype is not None:
    dtype_raw = imgs.dtype
    if np.issubdtype(dtype, np.floating) and np.issubdtype(
        dtype_raw, np.unsignedinteger
    ):
        # uint -> float: map [0, UINT_MAX] onto [0, 1]
        scale_factor = 1.0 / UINT_MAX[dtype_raw]
        imgs = scale_factor * imgs.astype(dtype)
    elif np.issubdtype(dtype, np.unsignedinteger) and np.issubdtype(
        dtype_raw, np.floating
    ):
        # float -> uint: floats are assumed to lie in [0, 1]. Build a new
        # array (never scale in place) so the caller's data is untouched and
        # the result really has `dtype`. Clip so out-of-range values saturate
        # instead of wrapping; round so uint -> float -> uint round-trips.
        scale_factor = UINT_MAX[dtype]  # type: ignore
        imgs = np.rint(scale_factor * np.clip(imgs, 0, 1)).astype(dtype)
    else:
        imgs = imgs.astype(dtype)

self.imgs = imgs
253
240
 
254
241
  def __getitem__(self, key):
255
242
  return self.imgs.__getitem__(key)
@@ -265,15 +252,7 @@ class NDArrayImageStack(ImageStack[ScalarType]):
265
252
  class TiffImageStack(NDArrayImageStack[ScalarType]):
266
253
  """Tiff image stack."""
267
254
 
268
- def __init__(
269
- self,
270
- fname: str,
271
- swap_xy: Optional[bool] = None,
272
- filp_xy: Optional[bool] = None,
273
- *,
274
- dtype: ScalarType,
275
- **kwargs,
276
- ) -> None:
255
+ def __init__(self, fname: str, *, dtype: ScalarType, **kwargs) -> None:
277
256
  with tifffile.TiffFile(fname, **kwargs) as f:
278
257
  s = f.series[0]
279
258
  imgs, axes = s.asarray(), s.axes
@@ -285,23 +264,15 @@ class TiffImageStack(NDArrayImageStack[ScalarType]):
285
264
 
286
265
  orders = [AXES_ORDER[c] for c in axes]
287
266
  imgs = imgs.transpose(np.argsort(orders))
288
- super().__init__(imgs, swap_xy=swap_xy, filp_xy=filp_xy, dtype=dtype)
267
+ super().__init__(imgs, dtype=dtype)
289
268
 
290
269
 
291
270
  class NrrdImageStack(NDArrayImageStack[ScalarType]):
292
271
  """Nrrd image stack."""
293
272
 
294
- def __init__(
295
- self,
296
- fname: str,
297
- swap_xy: Optional[bool] = None,
298
- filp_xy: Optional[bool] = None,
299
- *,
300
- dtype: ScalarType,
301
- **kwargs,
302
- ) -> None:
273
+ def __init__(self, fname: str, *, dtype: ScalarType, **kwargs) -> None:
303
274
  imgs, header = nrrd.read(fname, **kwargs)
304
- super().__init__(imgs, swap_xy=swap_xy, filp_xy=filp_xy, dtype=dtype)
275
+ super().__init__(imgs, dtype=dtype)
305
276
  self.header = header
306
277
 
307
278
 
@@ -333,12 +304,16 @@ class V3dpbdImageStack(V3dImageStack[ScalarType]):
333
304
  class TeraflyImageStack(ImageStack[ScalarType]):
334
305
  """TeraFly image stack.
335
306
 
307
+ TeraFly is a file format for terabyte-scale multidimensional
308
+ volumetric images, as described in [1]_.
309
+
336
310
  References
337
311
  ----------
338
- [1] Bria, Alessandro, Giulio Iannello, Leonardo Onofri, and
339
- Hanchuan Peng. “TeraFly: Real-Time Three-Dimensional Visualization
340
- and Annotation of Terabytes of Multidimensional Volumetric Images.”
341
- Nature Methods 13, no. 3 (March 2016): 192-94. https://doi.org/10.1038/nmeth.3767.
312
+ .. [1] Bria, Alessandro, Giulio Iannello, Leonardo Onofri, and
313
+ Hanchuan Peng. “TeraFly: Real-Time Three-Dimensional
314
+ Visualization and Annotation of Terabytes of Multidimensional
315
+ Volumetric Images.” Nature Methods 13,
316
+ no. 3 (March 2016): 192-94. https://doi.org/10.1038/nmeth.3767.
342
317
 
343
318
  Notes
344
319
  -----
@@ -380,7 +355,12 @@ class TeraflyImageStack(ImageStack[ScalarType]):
380
355
 
381
356
  @lru_cache(maxsize=lru_maxsize)
382
357
  def read_patch(path: str) -> npt.NDArray[ScalarType]:
383
- return read_imgs(path, dtype=dtype).get_full()
358
+ match os.path.splitext(path)[-1]:
359
+ case "raw":
360
+ # Treat it as a v3draw file
361
+ return V3drawImageStack(path, dtype=dtype).get_full()
362
+ case _:
363
+ return read_imgs(path, dtype=dtype).get_full()
384
364
 
385
365
  self._listdir, self._read_patch = listdir, read_patch
386
366
 
@@ -633,6 +613,12 @@ class GrayImageStack:
633
613
 
634
614
 
635
615
  def read_images(*args, **kwargs) -> GrayImageStack:
616
+ """Read images.
617
+
618
+ .. deprecated:: 0.16.0
619
+ Use :meth:`read_imgs` instead.
620
+ """
621
+
636
622
  warnings.warn(
637
623
  "`read_images` has been replaced by `read_imgs` because it"
638
624
  "provide rgb support, and this will be removed in next version",
@@ -3,9 +3,11 @@
3
3
  from swcgeom.transforms.base import *
4
4
  from swcgeom.transforms.branch import *
5
5
  from swcgeom.transforms.geometry import *
6
+ from swcgeom.transforms.image_preprocess import *
6
7
  from swcgeom.transforms.image_stack import *
7
8
  from swcgeom.transforms.images import *
8
9
  from swcgeom.transforms.mst import *
10
+ from swcgeom.transforms.neurolucida_asc import *
9
11
  from swcgeom.transforms.path import *
10
12
  from swcgeom.transforms.population import *
11
13
  from swcgeom.transforms.tree import *
@@ -0,0 +1,100 @@
1
+ """Image stack pre-processing."""
2
+
3
+ import numpy as np
4
+ import numpy.typing as npt
5
+ from scipy.fftpack import fftn, fftshift, ifftn
6
+ from scipy.ndimage import gaussian_filter, minimum_filter
7
+
8
+ from swcgeom.transforms.base import Transform
9
+
10
+ __all__ = ["SGuoImPreProcess"]
11
+
12
+
13
class SGuoImPreProcess(Transform[npt.NDArray[np.uint8], npt.NDArray[np.uint8]]):
    """Single-Neuron Image Enhancement.

    Implementation of the image enhancement method described in the paper:

    Shuxia Guo, Xuan Zhao, Shengdian Jiang, Liya Ding, Hanchuan Peng,
    Image enhancement to leverage the 3D morphological reconstruction
    of single-cell neurons, Bioinformatics, Volume 38, Issue 2,
    January 2022, Pages 503–512, https://doi.org/10.1093/bioinformatics/btab638

    Input is a uint8 image stack of shape (X, Y, Z) or (X, Y, Z, C). Every
    spatial step acts on the first three axes; a trailing channel axis is
    processed channel by channel.
    """

    def __call__(self, x: npt.NDArray[np.uint8]) -> npt.NDArray[np.uint8]:
        # TODO: support np.float32
        assert x.dtype == np.uint8, "Image must be in uint8 format"
        assert x.ndim in (3, 4), "Image must be of shape (X, Y, Z[, C])"
        x = self.sigmoid_adjustment(x)
        x = self.subtract_min_along_z(x)
        x = self.bilateral_filter_3d(x)
        x = self.high_pass_fft(x)
        return x

    @staticmethod
    def sigmoid_adjustment(
        image: npt.NDArray[np.uint8], sigma: float = 3, percentile: float = 25
    ) -> npt.NDArray[np.uint8]:
        """Sigmoid intensity mapping centred on a percentile.

        Intensities are normalised to [0, 1], mapped through
        ``1 / (1 + exp(-sigma * (v - u)))`` where `u` is the `percentile`-th
        percentile of the normalised image, and rescaled to uint8.
        """
        image_normalized = image / 255.0
        u = np.percentile(image_normalized, percentile)
        adjusted = 1 / (1 + np.exp(-sigma * (image_normalized - u)))
        adjusted_rescaled = (adjusted * 255).astype(np.uint8)
        return adjusted_rescaled

    @staticmethod
    def subtract_min_along_z(image: npt.NDArray[np.uint8]) -> npt.NDArray[np.uint8]:
        """Subtract, for every (x, y[, c]) column, its minimum over all of z.

        The minimum never exceeds any value of its column, so the uint8
        subtraction cannot underflow.
        """
        # A centred sliding window of size Z only sees about half of the
        # column near the z borders; reduce over the whole axis instead.
        min_along_z = np.min(image, axis=2, keepdims=True)
        return image - min_along_z

    @staticmethod
    def bilateral_filter_3d(
        image: npt.NDArray[np.uint8], spatial_sigma=(1, 1, 0.33), range_sigma=35
    ) -> npt.NDArray[np.uint8]:
        """Edge-preserving bilateral filter over the three spatial axes.

        Each voxel is replaced by a weighted mean of its neighbours within
        ``ceil(3 * spatial_sigma)`` voxels along each axis. A neighbour's
        weight is a spatial Gaussian of its offset times a range Gaussian of
        its intensity difference to the centre voxel. Borders are handled by
        edge replication.

        Parameters
        ----------
        image : ndarray of uint8
            Shape (X, Y, Z) or (X, Y, Z, C).
        spatial_sigma : float or (float, float, float)
            Gaussian std in voxels along x, y and z.
        range_sigma : float
            Gaussian std of the intensity difference, in uint8 units.
        """
        # Work in float64: differences of uint8 values would wrap around.
        img = image.astype(np.float64)
        sigmas = np.broadcast_to(np.asarray(spatial_sigma, dtype=np.float64), (3,))
        radii = np.ceil(3 * np.maximum(sigmas, 0)).astype(int)
        # An axis with sigma <= 0 gets radius 0, so its offset is always 0;
        # divide by a dummy 1 there to avoid 0 / 0.
        safe_sigmas = np.where(sigmas > 0, sigmas, 1.0)

        pad = [(r, r) for r in radii] + [(0, 0)] * (img.ndim - 3)
        padded = np.pad(img, pad, mode="edge")
        sx, sy, sz = img.shape[:3]

        numerator = np.zeros_like(img)
        denominator = np.zeros_like(img)
        # Vectorised over the image, looping only over the window offsets.
        for idx in np.ndindex(*(2 * radii + 1)):
            offset = np.asarray(idx) - radii
            spatial_weight = np.exp(-0.5 * np.sum((offset / safe_sigmas) ** 2))
            neighbour = padded[
                idx[0] : idx[0] + sx, idx[1] : idx[1] + sy, idx[2] : idx[2] + sz
            ]
            range_weight = np.exp(-((neighbour - img) ** 2) / (2 * range_sigma**2))
            weight = spatial_weight * range_weight
            numerator += weight * neighbour
            denominator += weight

        # The zero offset always contributes weight 1, so denominator >= 1.
        filtered = numerator / denominator
        return np.clip(np.rint(filtered), 0, 255).astype(np.uint8)

    @staticmethod
    def high_pass_fft(image: npt.NDArray[np.uint8]) -> npt.NDArray[np.uint8]:
        """Suppress low spatial frequencies with a spherical FFT mask.

        The spectrum over the first three axes is zeroed within ``Z // 4``
        frequency bins of the zero-frequency bin. The real part of the
        inverse transform is then clipped back to uint8.
        """
        axes = (0, 1, 2)

        # fft over the spatial axes only, so a channel axis is not mixed in
        spectrum = fftshift(fftn(image.astype(np.float64), axes=axes), axes=axes)

        # create a high-pass filter; after fftshift, the zero frequency of an
        # axis of length n sits at index n // 2
        h, w, d = image.shape[:3]
        x, y, z = np.ogrid[:h, :w, :d]
        distance = np.sqrt((x - h // 2) ** 2 + (y - w // 2) ** 2 + (z - d // 2) ** 2)
        # adjust this threshold to control the filtering strength
        high_pass_mask = distance > (d // 4)
        # append singleton axes so the mask broadcasts over channels
        high_pass_mask = high_pass_mask.reshape(
            high_pass_mask.shape + (1,) * (image.ndim - 3)
        )
        spectrum = spectrum * high_pass_mask

        # inverse fft
        filtered_image = np.real(
            ifftn(np.fft.ifftshift(spectrum, axes=axes), axes=axes)
        )
        return np.clip(filtered_image, 0, 255).astype(np.uint8)
@@ -25,6 +25,7 @@ from sdflit import (
25
25
  Scene,
26
26
  SDFObject,
27
27
  )
28
+ from tqdm import tqdm
28
29
 
29
30
  from swcgeom.core import Population, Tree
30
31
  from swcgeom.transforms.base import Transform
@@ -89,8 +90,6 @@ class ToImageStack(Transform[Tree, npt.NDArray[np.uint8]]):
89
90
  samplers = self._get_samplers(coord_min, coord_max)
90
91
 
91
92
  if verbose:
92
- from tqdm import tqdm
93
-
94
93
  total = (coord_max[2] - coord_min[2]) / self.resolution[2]
95
94
  samplers = tqdm(samplers, total=total.astype(np.int64).item())
96
95
 
@@ -117,8 +116,6 @@ class ToImageStack(Transform[Tree, npt.NDArray[np.uint8]]):
117
116
  )
118
117
 
119
118
  if verbose:
120
- from tqdm import tqdm
121
-
122
119
  trees = tqdm(trees)
123
120
 
124
121
  # TODO: multiprocess