swcgeom 0.18.1__py3-none-any.whl → 0.19.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of swcgeom might be problematic. Click here for more details.

Files changed (68) hide show
  1. swcgeom/__init__.py +12 -1
  2. swcgeom/analysis/__init__.py +6 -6
  3. swcgeom/analysis/feature_extractor.py +22 -24
  4. swcgeom/analysis/features.py +18 -40
  5. swcgeom/analysis/lmeasure.py +227 -323
  6. swcgeom/analysis/sholl.py +17 -23
  7. swcgeom/analysis/trunk.py +23 -28
  8. swcgeom/analysis/visualization.py +37 -44
  9. swcgeom/analysis/visualization3d.py +16 -25
  10. swcgeom/analysis/volume.py +33 -47
  11. swcgeom/core/__init__.py +12 -13
  12. swcgeom/core/branch.py +10 -17
  13. swcgeom/core/branch_tree.py +3 -2
  14. swcgeom/core/compartment.py +1 -1
  15. swcgeom/core/node.py +3 -6
  16. swcgeom/core/path.py +11 -16
  17. swcgeom/core/population.py +32 -51
  18. swcgeom/core/swc.py +25 -16
  19. swcgeom/core/swc_utils/__init__.py +10 -12
  20. swcgeom/core/swc_utils/assembler.py +5 -12
  21. swcgeom/core/swc_utils/base.py +40 -31
  22. swcgeom/core/swc_utils/checker.py +3 -8
  23. swcgeom/core/swc_utils/io.py +32 -47
  24. swcgeom/core/swc_utils/normalizer.py +17 -23
  25. swcgeom/core/swc_utils/subtree.py +13 -20
  26. swcgeom/core/tree.py +61 -51
  27. swcgeom/core/tree_utils.py +36 -49
  28. swcgeom/core/tree_utils_impl.py +4 -6
  29. swcgeom/images/__init__.py +2 -2
  30. swcgeom/images/augmentation.py +23 -39
  31. swcgeom/images/contrast.py +22 -46
  32. swcgeom/images/folder.py +32 -34
  33. swcgeom/images/io.py +80 -121
  34. swcgeom/transforms/__init__.py +13 -13
  35. swcgeom/transforms/base.py +28 -19
  36. swcgeom/transforms/branch.py +31 -41
  37. swcgeom/transforms/branch_tree.py +3 -1
  38. swcgeom/transforms/geometry.py +13 -4
  39. swcgeom/transforms/image_preprocess.py +2 -0
  40. swcgeom/transforms/image_stack.py +40 -35
  41. swcgeom/transforms/images.py +31 -24
  42. swcgeom/transforms/mst.py +27 -40
  43. swcgeom/transforms/neurolucida_asc.py +13 -13
  44. swcgeom/transforms/path.py +4 -0
  45. swcgeom/transforms/population.py +4 -0
  46. swcgeom/transforms/tree.py +16 -11
  47. swcgeom/transforms/tree_assembler.py +37 -54
  48. swcgeom/utils/__init__.py +12 -12
  49. swcgeom/utils/download.py +7 -14
  50. swcgeom/utils/dsu.py +12 -0
  51. swcgeom/utils/ellipse.py +26 -14
  52. swcgeom/utils/file.py +8 -13
  53. swcgeom/utils/neuromorpho.py +78 -92
  54. swcgeom/utils/numpy_helper.py +15 -12
  55. swcgeom/utils/plotter_2d.py +10 -16
  56. swcgeom/utils/plotter_3d.py +7 -9
  57. swcgeom/utils/renderer.py +16 -8
  58. swcgeom/utils/sdf.py +12 -23
  59. swcgeom/utils/solid_geometry.py +58 -2
  60. swcgeom/utils/transforms.py +164 -100
  61. swcgeom/utils/volumetric_object.py +29 -53
  62. {swcgeom-0.18.1.dist-info → swcgeom-0.19.0.dist-info}/METADATA +7 -6
  63. swcgeom-0.19.0.dist-info/RECORD +67 -0
  64. {swcgeom-0.18.1.dist-info → swcgeom-0.19.0.dist-info}/WHEEL +1 -1
  65. swcgeom/_version.py +0 -16
  66. swcgeom-0.18.1.dist-info/RECORD +0 -68
  67. {swcgeom-0.18.1.dist-info → swcgeom-0.19.0.dist-info/licenses}/LICENSE +0 -0
  68. {swcgeom-0.18.1.dist-info → swcgeom-0.19.0.dist-info}/top_level.txt +0 -0
@@ -13,9 +13,10 @@
13
13
  # limitations under the License.
14
14
 
15
15
 
16
- from typing import Iterable
16
+ from collections.abc import Iterable
17
17
 
18
18
  import numpy as np
19
+ from typing_extensions import override
19
20
 
20
21
  from swcgeom.core import Branch, BranchTree, Node, Tree
21
22
  from swcgeom.transforms.base import Transform
@@ -26,6 +27,7 @@ __all__ = ["BranchTreeAssembler"]
26
27
  class BranchTreeAssembler(Transform[BranchTree, Tree]):
27
28
  EPS = 1e-6
28
29
 
30
+ @override
29
31
  def __call__(self, x: BranchTree) -> Tree:
30
32
  nodes = [x.soma().detach()]
31
33
  stack = [(x.soma(), 0)] # n_orig, id_new
@@ -16,10 +16,11 @@
16
16
  """SWC geometry operations."""
17
17
 
18
18
  import warnings
19
- from typing import Generic, Literal, Optional, TypeVar
19
+ from typing import Generic, Literal, TypeVar
20
20
 
21
21
  import numpy as np
22
22
  import numpy.typing as npt
23
+ from typing_extensions import override
23
24
 
24
25
  from swcgeom.core import DictSWC
25
26
  from swcgeom.core.swc_utils import SWCNames
@@ -54,7 +55,7 @@ Center = Literal["root", "soma", "origin"]
54
55
  class Normalizer(Generic[T], Transform[T, T]):
55
56
  """Normalize coordinates and radius to 0-1."""
56
57
 
57
- def __init__(self, *, names: Optional[SWCNames] = None) -> None:
58
+ def __init__(self, *, names: SWCNames | None = None) -> None:
58
59
  super().__init__()
59
60
  if names is not None:
60
61
  warnings.warn(
@@ -63,6 +64,7 @@ class Normalizer(Generic[T], Transform[T, T]):
63
64
  DeprecationWarning,
64
65
  )
65
66
 
67
+ @override
66
68
  def __call__(self, x: T) -> T:
67
69
  """Scale the `x`, `y`, `z`, `r` of nodes to 0-1."""
68
70
  new_tree = x.copy()
@@ -87,6 +89,7 @@ class RadiusReseter(Generic[T], Transform[T, T]):
87
89
  new_tree.ndata[new_tree.names.r] = r
88
90
  return new_tree
89
91
 
92
+ @override
90
93
  def extra_repr(self) -> str:
91
94
  return f"r={self.r:.4f}"
92
95
 
@@ -103,8 +106,8 @@ class AffineTransform(Generic[T], Transform[T, T]):
103
106
  tm: npt.NDArray[np.float32],
104
107
  center: Center = "origin",
105
108
  *,
106
- fmt: Optional[str] = None,
107
- names: Optional[SWCNames] = None,
109
+ fmt: str | None = None,
110
+ names: SWCNames | None = None,
108
111
  ) -> None:
109
112
  self.tm, self.center = tm, center
110
113
 
@@ -122,6 +125,7 @@ class AffineTransform(Generic[T], Transform[T, T]):
122
125
  DeprecationWarning,
123
126
  )
124
127
 
128
+ @override
125
129
  def __call__(self, x: T) -> T:
126
130
  match self.center:
127
131
  case "root" | "soma":
@@ -156,6 +160,7 @@ class Translate(Generic[T], AffineTransform[T]):
156
160
  super().__init__(translate3d(tx, ty, tz), **kwargs)
157
161
  self.tx, self.ty, self.tz = tx, ty, tz
158
162
 
163
+ @override
159
164
  def extra_repr(self) -> str:
160
165
  return f"tx={self.tx:.4f}, ty={self.ty:.4f}, tz={self.tz:.4f}"
161
166
 
@@ -209,6 +214,7 @@ class Rotate(Generic[T], AffineTransform[T]):
209
214
  self.theta = theta
210
215
  self.center = center
211
216
 
217
+ @override
212
218
  def extra_repr(self) -> str:
213
219
  return f"n={self.n}, theta={self.theta:.4f}, center={self.center}" # TODO: improve format of n
214
220
 
@@ -231,6 +237,7 @@ class RotateX(Generic[T], AffineTransform[T]):
231
237
  super().__init__(rotate3d_x(theta), center=center, **kwargs)
232
238
  self.theta = theta
233
239
 
240
+ @override
234
241
  def extra_repr(self) -> str:
235
242
  return f"center={self.center}, theta={self.theta:.4f}"
236
243
 
@@ -247,6 +254,7 @@ class RotateY(Generic[T], AffineTransform[T]):
247
254
  self.theta = theta
248
255
  self.center = center
249
256
 
257
+ @override
250
258
  def extra_repr(self) -> str:
251
259
  return f"theta={self.theta:.4f}, center={self.center}"
252
260
 
@@ -263,6 +271,7 @@ class RotateZ(Generic[T], AffineTransform[T]):
263
271
  self.theta = theta
264
272
  self.center = center
265
273
 
274
+ @override
266
275
  def extra_repr(self) -> str:
267
276
  return f"theta={self.theta:.4f}, center={self.center}"
268
277
 
@@ -19,6 +19,7 @@ import numpy as np
19
19
  import numpy.typing as npt
20
20
  from scipy.fftpack import fftn, fftshift, ifftn
21
21
  from scipy.ndimage import gaussian_filter, minimum_filter
22
+ from typing_extensions import override
22
23
 
23
24
  from swcgeom.transforms.base import Transform
24
25
 
@@ -36,6 +37,7 @@ class SGuoImPreProcess(Transform[npt.NDArray[np.uint8], npt.NDArray[np.uint8]]):
36
37
  January 2022, Pages 503–512, https://doi.org/10.1093/bioinformatics/btab638
37
38
  """
38
39
 
40
+ @override
39
41
  def __call__(self, x: npt.NDArray[np.uint8]) -> npt.NDArray[np.uint8]:
40
42
  # TODO: support np.float32
41
43
  assert x.dtype == np.uint8, "Image must be in uint8 format"
@@ -15,9 +15,7 @@
15
15
 
16
16
  """Create image stack from morphology.
17
17
 
18
- Notes
19
- -----
20
- All denpendencies need to be installed, try:
18
+ NOTE: All dependencies need to be installed, try:
21
19
 
22
20
  ```sh
23
21
  pip install swcgeom[all]
@@ -28,7 +26,7 @@ import os
28
26
  import re
29
27
  import time
30
28
  from collections.abc import Iterable
31
- from typing import Optional
29
+ from typing import Sequence
32
30
 
33
31
  import numpy as np
34
32
  import numpy.typing as npt
@@ -42,7 +40,7 @@ from sdflit import (
42
40
  SDFObject,
43
41
  )
44
42
  from tqdm import tqdm
45
- from typing_extensions import deprecated
43
+ from typing_extensions import deprecated, override
46
44
 
47
45
  from swcgeom.core import Population, Tree
48
46
  from swcgeom.transforms.base import Transform
@@ -58,46 +56,53 @@ class ToImageStack(Transform[Tree, npt.NDArray[np.uint8]]):
58
56
  def __init__(self, resolution: int | float | npt.ArrayLike = 1) -> None:
59
57
  """Transform tree to image stack.
60
58
 
61
- Parameters
62
- ----------
63
- resolution : int | (x, y, z), default `(1, 1, 1)`
64
- Resolution of image stack.
59
+ Args:
60
+ resolution: Resolution of image stack.
61
+ If a scalar, it will be broadcasted to a vector of 3d.
65
62
  """
66
-
67
63
  if isinstance(resolution, (int, float, np.integer, np.floating)):
68
64
  resolution = [resolution, resolution, resolution] # type: ignore
69
65
 
70
66
  self.resolution = np.array(resolution, dtype=np.float32)
71
- assert len(self.resolution) == 3, "resolution shoule be vector of 3d."
67
+ assert len(self.resolution) == 3, "resolution should be vector of 3d."
72
68
 
69
+ @override
73
70
  def __call__(self, x: Tree) -> npt.NDArray[np.uint8]:
74
71
  """Transform tree to image stack.
75
72
 
76
- Notes
77
- -----
78
- This method loads the entire image stack into memory, so it
79
- ONLY works for small image stacks, use :meth`transform_and_save`
80
- for big image stack.
73
+ NOTE: This method loads the entire image stack into memory, so it ONLY works
74
+ for small image stacks, use :meth:`transform_and_save` for big image stack.
81
75
  """
82
76
  return np.stack(list(self.transform(x, verbose=False)), axis=0)
83
77
 
84
78
  def transform(
85
79
  self,
86
- x: Tree,
80
+ x: Tree | Sequence[Tree],
87
81
  verbose: bool = True,
88
82
  *,
89
- ranges: Optional[tuple[npt.ArrayLike, npt.ArrayLike]] = None,
83
+ ranges: tuple[npt.ArrayLike, npt.ArrayLike] | None = None,
90
84
  ) -> Iterable[npt.NDArray[np.uint8]]:
85
+ trees = [x] if isinstance(x, Tree) else x
86
+ if not trees:
87
+ return iter([]) # Return empty iterator if sequence is empty
88
+
89
+ time_start = None
91
90
  if verbose:
92
- print("To image stack: " + x.source)
91
+ sources = ", ".join(t.source for t in trees if t.source)
92
+ print(f"To image stack: {sources if sources else 'unnamed trees'}")
93
93
  time_start = time.time()
94
94
 
95
- scene = self._get_scene(x)
95
+ scene = self._get_scene(trees)
96
96
 
97
97
  if ranges is None:
98
- xyz, r = x.xyz(), x.r().reshape(-1, 1)
99
- coord_min = np.floor(np.min(xyz - r, axis=0))
100
- coord_max = np.ceil(np.max(xyz + r, axis=0))
98
+ all_xyz = np.concatenate([t.xyz() for t in trees], axis=0)
99
+ all_r = np.concatenate([t.r() for t in trees], axis=0).reshape(-1, 1)
100
+ if all_xyz.size == 0: # Handle empty trees
101
+ coord_min = np.zeros(3, dtype=np.float32)
102
+ coord_max = np.zeros(3, dtype=np.float32)
103
+ else:
104
+ coord_min = np.floor(np.min(all_xyz - all_r, axis=0))
105
+ coord_max = np.ceil(np.max(all_xyz + all_r, axis=0))
101
106
  else:
102
107
  assert len(ranges) == 2
103
108
  coord_min = np.array(ranges[0])
@@ -106,12 +111,12 @@ class ToImageStack(Transform[Tree, npt.NDArray[np.uint8]]):
106
111
 
107
112
  samplers = self._get_samplers(coord_min, coord_max)
108
113
 
109
- if verbose:
114
+ if verbose and time_start is not None:
110
115
  total = (coord_max[2] - coord_min[2]) / self.resolution[2]
111
116
  samplers = tqdm(samplers, total=total.astype(np.int64).item())
112
117
 
113
118
  time_end = time.time()
114
- print("Prepare in: ", time_end - time_start, "s") # type: ignore
119
+ print("Prepare in: ", time_end - time_start, "s")
115
120
 
116
121
  for sampler in samplers:
117
122
  voxel = sampler.sample(scene) # should be shape of (x, y, z, 3) and z = 1
@@ -124,12 +129,12 @@ class ToImageStack(Transform[Tree, npt.NDArray[np.uint8]]):
124
129
  x: Tree,
125
130
  verbose: bool = True,
126
131
  *,
127
- ranges: Optional[tuple[npt.ArrayLike, npt.ArrayLike]] = None,
132
+ ranges: tuple[npt.ArrayLike, npt.ArrayLike] | None = None,
128
133
  ) -> Iterable[npt.NDArray[np.uint8]]:
129
134
  return self.transform(x, verbose, ranges=ranges)
130
135
 
131
136
  def transform_and_save(
132
- self, fname: str, x: Tree, verbose: bool = True, **kwargs
137
+ self, fname: str, x: Tree | Sequence[Tree], verbose: bool = True, **kwargs
133
138
  ) -> None:
134
139
  self.save_tif(fname, self.transform(x, verbose=verbose, **kwargs))
135
140
 
@@ -151,11 +156,12 @@ class ToImageStack(Transform[Tree, npt.NDArray[np.uint8]]):
151
156
  if not os.path.isfile(tif):
152
157
  self.transform_and_save(tif, tree, verbose=False)
153
158
 
159
+ @override
154
160
  def extra_repr(self) -> str:
155
161
  res = ",".join(f"{a:.4f}" for a in self.resolution)
156
162
  return f"resolution=({res})"
157
163
 
158
- def _get_scene(self, x: Tree) -> Scene:
164
+ def _get_scene(self, trees: Sequence[Tree]) -> Scene:
159
165
  material = ColoredMaterial((1, 0, 0)).into()
160
166
  scene = ObjectsScene()
161
167
  scene.set_background((0, 0, 0))
@@ -164,10 +170,11 @@ class ToImageStack(Transform[Tree, npt.NDArray[np.uint8]]):
164
170
  for c in children:
165
171
  sdf = RoundCone(_tp3f(n.xyz()), _tp3f(c.xyz()), n.r, c.r).into()
166
172
  scene.add_object(SDFObject(sdf, material).into())
167
-
168
173
  return n
169
174
 
170
- x.traverse(leave=leave)
175
+ for tree in trees:
176
+ tree.traverse(leave=leave)
177
+
171
178
  scene.build_bvh()
172
179
  return scene.into()
173
180
 
@@ -175,14 +182,12 @@ class ToImageStack(Transform[Tree, npt.NDArray[np.uint8]]):
175
182
  self,
176
183
  coord_min: npt.NDArray,
177
184
  coord_max: npt.NDArray,
178
- offset: Optional[npt.NDArray] = None,
185
+ offset: npt.NDArray | None = None,
179
186
  ) -> Iterable[RangeSampler]:
180
187
  """Get Samplers.
181
188
 
182
- Parameters
183
- ----------
184
- coord_min, coord_max: npt.ArrayLike
185
- Coordinates array of shape (3,).
189
+ Args:
190
+ coord_min, coord_max: Coordinates array of shape (3,).
186
191
  """
187
192
 
188
193
  eps = 1e-6
@@ -17,7 +17,7 @@
17
17
 
18
18
  import numpy as np
19
19
  import numpy.typing as npt
20
- from typing_extensions import deprecated
20
+ from typing_extensions import deprecated, override
21
21
 
22
22
  from swcgeom.transforms.base import Identity, Transform
23
23
 
@@ -49,12 +49,14 @@ class ImagesCenterCrop(Transform[NDArrayf32, NDArrayf32]):
49
49
  else (shape_out, shape_out, shape_out)
50
50
  )
51
51
 
52
+ @override
52
53
  def __call__(self, x: NDArrayf32) -> NDArrayf32:
53
54
  diff = np.subtract(x.shape[:3], self.shape_out)
54
55
  s = diff // 2
55
56
  e = np.add(s, self.shape_out)
56
57
  return x[s[0] : e[0], s[1] : e[1], s[2] : e[2], :]
57
58
 
59
+ @override
58
60
  def extra_repr(self) -> str:
59
61
  return f"shape_out=({','.join(str(a) for a in self.shape_out)})"
60
62
 
@@ -73,9 +75,11 @@ class ImagesScale(Transform[NDArrayf32, NDArrayf32]):
73
75
  super().__init__()
74
76
  self.scaler = scaler
75
77
 
78
+ @override
76
79
  def __call__(self, x: NDArrayf32) -> NDArrayf32:
77
80
  return self.scaler * x
78
81
 
82
+ @override
79
83
  def extra_repr(self) -> str:
80
84
  return f"scaler={self.scaler}"
81
85
 
@@ -85,9 +89,11 @@ class ImagesClip(Transform[NDArrayf32, NDArrayf32]):
85
89
  super().__init__()
86
90
  self.vmin, self.vmax = vmin, vmax
87
91
 
92
+ @override
88
93
  def __call__(self, x: NDArrayf32) -> NDArrayf32:
89
94
  return np.clip(x, self.vmin, self.vmax)
90
95
 
96
+ @override
91
97
  def extra_repr(self) -> str:
92
98
  return f"vmin={self.vmin}, vmax={self.vmax}"
93
99
 
@@ -99,9 +105,11 @@ class ImagesFlip(Transform[NDArrayf32, NDArrayf32]):
99
105
  super().__init__()
100
106
  self.axis = axis
101
107
 
108
+ @override
102
109
  def __call__(self, x: NDArrayf32) -> NDArrayf32:
103
110
  return np.flip(x, axis=self.axis)
104
111
 
112
+ @override
105
113
  def extra_repr(self) -> str:
106
114
  return f"axis={self.axis}"
107
115
 
@@ -109,15 +117,13 @@ class ImagesFlip(Transform[NDArrayf32, NDArrayf32]):
109
117
  class ImagesFlipY(ImagesFlip):
110
118
  """Flip image stack along Y-axis.
111
119
 
112
- See Also
113
- --------
114
- ~.images.io.TeraflyImageStack:
115
- Terafly and Vaa3d use a especial right-handed coordinate system
116
- (with origin point in the left-top and z-axis points front),
117
- but we flip y-axis to makes it a left-handed coordinate system
118
- (with orgin point in the left-bottom and z-axis points front).
119
- If you need to use its coordinate system, remember to FLIP
120
- Y-AXIS BACK.
120
+ See Also:
121
+ ~.images.io.TeraflyImageStack:
122
+ Terafly and Vaa3d use a special right-handed coordinate system (with
123
+ origin point in the left-top and z-axis points front), but we flip y-axis
124
+ to make it a left-handed coordinate system (with origin point in the
125
+ left-bottom and z-axis points front). If you need to use its coordinate
126
+ system, remember to FLIP Y-AXIS BACK.
121
127
  """
122
128
 
123
129
  def __init__(self, axis: int = 1, /) -> None:
@@ -127,6 +133,7 @@ class ImagesFlipY(ImagesFlip):
127
133
  class ImagesNormalizer(Transform[NDArrayf32, NDArrayf32]):
128
134
  """Normalize image stack."""
129
135
 
136
+ @override
130
137
  def __call__(self, x: NDArrayf32) -> NDArrayf32:
131
138
  mean = np.mean(x)
132
139
  variance = np.var(x)
@@ -136,9 +143,8 @@ class ImagesNormalizer(Transform[NDArrayf32, NDArrayf32]):
136
143
  class ImagesMeanVarianceAdjustment(Transform[NDArrayf32, NDArrayf32]):
137
144
  """Adjust image stack mean and variance.
138
145
 
139
- See Also
140
- --------
141
- ~swcgeom.images.ImageStackFolder.stat
146
+ See Also:
147
+ ~swcgeom.images.ImageStackFolder.stat
142
148
  """
143
149
 
144
150
  def __init__(self, mean: float, variance: float) -> None:
@@ -146,9 +152,11 @@ class ImagesMeanVarianceAdjustment(Transform[NDArrayf32, NDArrayf32]):
146
152
  self.mean = mean
147
153
  self.variance = variance
148
154
 
155
+ @override
149
156
  def __call__(self, x: NDArrayf32) -> NDArrayf32:
150
157
  return (x - self.mean) / self.variance
151
158
 
159
+ @override
152
160
  def extra_repr(self) -> str:
153
161
  return f"mean={self.mean}, variance={self.variance}"
154
162
 
@@ -159,14 +167,10 @@ class ImagesScaleToUnitRange(Transform[NDArrayf32, NDArrayf32]):
159
167
  def __init__(self, vmin: float, vmax: float, *, clip: bool = True) -> None:
160
168
  """Scale image stack to unit range.
161
169
 
162
- Parameters
163
- ----------
164
- vmin : float
165
- Minimum value.
166
- vmax : float
167
- Maximum value.
168
- clip : bool, default True
169
- Clip values to [0, 1] to avoid numerical issues.
170
+ Args:
171
+ vmin: Minimum value.
172
+ vmax: Maximum value.
173
+ clip: Clip values to [0, 1] to avoid numerical issues.
170
174
  """
171
175
 
172
176
  super().__init__()
@@ -176,9 +180,11 @@ class ImagesScaleToUnitRange(Transform[NDArrayf32, NDArrayf32]):
176
180
  self.clip = clip
177
181
  self.post = ImagesClip(0, 1) if self.clip else Identity()
178
182
 
183
+ @override
179
184
  def __call__(self, x: NDArrayf32) -> NDArrayf32:
180
185
  return self.post((x - self.vmin) / self.diff)
181
186
 
187
+ @override
182
188
  def extra_repr(self) -> str:
183
189
  return f"vmin={self.vmin}, vmax={self.vmax}, clip={self.clip}"
184
190
 
@@ -186,15 +192,15 @@ class ImagesScaleToUnitRange(Transform[NDArrayf32, NDArrayf32]):
186
192
  class ImagesHistogramEqualization(Transform[NDArrayf32, NDArrayf32]):
187
193
  """Image histogram equalization.
188
194
 
189
- References
190
- ----------
191
- http://www.janeriksolem.net/histogram-equalization-with-python-and.html
195
+ References:
196
+ http://www.janeriksolem.net/histogram-equalization-with-python-and.html
192
197
  """
193
198
 
194
199
  def __init__(self, bins: int = 256) -> None:
195
200
  super().__init__()
196
201
  self.bins = bins
197
202
 
203
+ @override
198
204
  def __call__(self, x: NDArrayf32) -> NDArrayf32:
199
205
  # get image histogram
200
206
  hist, bin_edges = np.histogram(x.flatten(), self.bins, density=True)
@@ -205,5 +211,6 @@ class ImagesHistogramEqualization(Transform[NDArrayf32, NDArrayf32]):
205
211
  equalized = np.interp(x.flatten(), bin_edges[:-1], cdf)
206
212
  return equalized.reshape(x.shape).astype(np.float32)
207
213
 
214
+ @override
208
215
  def extra_repr(self) -> str:
209
216
  return f"bins={self.bins}"
swcgeom/transforms/mst.py CHANGED
@@ -16,12 +16,12 @@
16
16
  """Minimum spanning tree."""
17
17
 
18
18
  import warnings
19
- from typing import Optional
20
19
 
21
20
  import numpy as np
22
21
  import pandas as pd
23
22
  from numpy import ma
24
23
  from numpy import typing as npt
24
+ from typing_extensions import override
25
25
 
26
26
  from swcgeom.core import Tree, sort_tree
27
27
  from swcgeom.core.swc_utils import SWCNames, SWCTypes, get_names, get_types
@@ -36,8 +36,7 @@ class PointsToCuntzMST(Transform[npt.NDArray[np.float32], Tree]):
36
36
  Creates trees corresponding to the minimum spanning tree keeping
37
37
  the path length to the root small (with balancing factor bf).
38
38
 
39
- References
40
- ----------
39
+ References:
41
40
  .. [1] Cuntz, H., Forstner, F., Borst, A. & Häusser, M. One Rule to
42
41
  Grow Them All: A General Theory of Neuronal Branching and Its
43
42
  Practical Application. PLOS Comput Biol 6, e1000877 (2010).
@@ -52,21 +51,15 @@ class PointsToCuntzMST(Transform[npt.NDArray[np.float32], Tree]):
52
51
  furcations: int = 2,
53
52
  exclude_soma: bool = True,
54
53
  sort: bool = True,
55
- names: Optional[SWCNames] = None,
56
- types: Optional[SWCTypes] = None,
54
+ names: SWCNames | None = None,
55
+ types: SWCTypes | None = None,
57
56
  ) -> None:
58
57
  """
59
- Parameters
60
- ----------
61
- bf : float, default `0.4`
62
- Balancing factor between 0~1.
63
- furcations : int, default `2`
64
- Suppress multi-furcations which more than k. If set to -1,
65
- no suppression.
66
- exclude_soma : bool, default `True`
67
- Suppress multi-furcations exclude soma.
68
- names : SWCNames, optional
69
- types : SWCTypes, optional
58
+ Args:
59
+ bf: Balancing factor between 0~1.
60
+ furcations: Suppress multi-furcations that have more than k branches.
61
+ If set to -1, no suppression.
62
+ exclude_soma: Suppress multi-furcations exclude soma.
70
63
  """
71
64
  self.bf = np.clip(bf, 0, 1)
72
65
  self.furcations = furcations
@@ -75,29 +68,26 @@ class PointsToCuntzMST(Transform[npt.NDArray[np.float32], Tree]):
75
68
  self.names = get_names(names)
76
69
  self.types = get_types(types)
77
70
 
71
+ @override
78
72
  def __call__( # pylint: disable=too-many-locals
79
73
  self,
80
74
  points: npt.NDArray[np.floating],
81
- soma: Optional[npt.ArrayLike] = None,
75
+ soma: npt.ArrayLike | None = None,
82
76
  *,
83
- names: Optional[SWCNames] = None,
77
+ names: SWCNames | None = None,
84
78
  ) -> Tree:
85
79
  """
86
- Paramters
87
- ---------
88
- points : array of shape (N, 3)
89
- Positions of points cloud.
90
- soma : array of shape (3,), default `None`
91
- Position of soma. If none, use the first point as soma.
92
- names : SWCNames, optional
80
+ Args:
81
+ points: Positions of points cloud.
82
+ soma: Position of soma. If none, use the first point as soma.
93
83
  """
94
84
  if names is None:
95
85
  names = self.names
96
86
  else:
97
87
  warnings.warn(
98
- "`PointsToCuntzMST(...)(names=...)` has been "
99
- "replaced by `PointsToCuntzMST(...,names=...)` since "
100
- "v0.12.0, and will be removed in next version",
88
+ "`PointsToCuntzMST(...)(names=...)` has been replaced by "
89
+ "`PointsToCuntzMST(...,names=...)` since v0.12.0, and will be removed "
90
+ "in next version",
101
91
  DeprecationWarning,
102
92
  )
103
93
  names = get_names(names) # TODO: remove it
@@ -155,6 +145,7 @@ class PointsToCuntzMST(Transform[npt.NDArray[np.float32], Tree]):
155
145
  t = sort_tree(t)
156
146
  return t
157
147
 
148
+ @override
158
149
  def extra_repr(self) -> str: # TODO: names, types
159
150
  return f"bf={self.bf:.4f}, furcations={self.furcations}, exclude_soma={self.exclude_soma}, sort={self.sort}"
160
151
 
@@ -166,22 +157,17 @@ class PointsToMST(PointsToCuntzMST): # pylint: disable=too-few-public-methods
166
157
  self,
167
158
  furcations: int = 2,
168
159
  *,
169
- k_furcations: Optional[int] = None,
160
+ k_furcations: int | None = None,
170
161
  exclude_soma: bool = True,
171
- names: Optional[SWCNames] = None,
172
- types: Optional[SWCTypes] = None,
162
+ names: SWCNames | None = None,
163
+ types: SWCTypes | None = None,
173
164
  **kwargs,
174
165
  ) -> None:
175
166
  """
176
- Parameters
177
- ----------
178
- furcations : int, default `2`
179
- Suppress multifurcations which more than k. If set to -1,
180
- no suppression.
181
- exclude_soma : bool, default `True`
182
- Suppress multi-furcations exclude soma.
183
- names : SWCNames, optional
184
- types : SWCTypes, optional
167
+ Args:
168
+ furcations: Suppress multifurcations that have more than k branches.
169
+ If set to -1, no suppression.
170
+ exclude_soma: Suppress multi-furcations exclude soma.
185
171
  """
186
172
 
187
173
  if k_furcations is not None:
@@ -202,5 +188,6 @@ class PointsToMST(PointsToCuntzMST): # pylint: disable=too-few-public-methods
202
188
  **kwargs,
203
189
  )
204
190
 
191
+ @override
205
192
  def extra_repr(self) -> str:
206
193
  return f"furcations-{self.furcations}, exclude-soma={self.exclude_soma}"