swcgeom 0.18.3__py3-none-any.whl → 0.19.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of swcgeom might be problematic. Click here for more details.

Files changed (62) hide show
  1. swcgeom/analysis/feature_extractor.py +22 -24
  2. swcgeom/analysis/features.py +18 -40
  3. swcgeom/analysis/lmeasure.py +227 -323
  4. swcgeom/analysis/sholl.py +17 -23
  5. swcgeom/analysis/trunk.py +23 -28
  6. swcgeom/analysis/visualization.py +37 -44
  7. swcgeom/analysis/visualization3d.py +16 -25
  8. swcgeom/analysis/volume.py +33 -47
  9. swcgeom/core/__init__.py +1 -6
  10. swcgeom/core/branch.py +10 -17
  11. swcgeom/core/branch_tree.py +3 -2
  12. swcgeom/core/compartment.py +1 -1
  13. swcgeom/core/node.py +3 -6
  14. swcgeom/core/path.py +11 -16
  15. swcgeom/core/population.py +32 -51
  16. swcgeom/core/swc.py +25 -16
  17. swcgeom/core/swc_utils/__init__.py +4 -6
  18. swcgeom/core/swc_utils/assembler.py +5 -12
  19. swcgeom/core/swc_utils/base.py +40 -31
  20. swcgeom/core/swc_utils/checker.py +3 -8
  21. swcgeom/core/swc_utils/io.py +32 -47
  22. swcgeom/core/swc_utils/normalizer.py +17 -23
  23. swcgeom/core/swc_utils/subtree.py +13 -20
  24. swcgeom/core/tree.py +61 -51
  25. swcgeom/core/tree_utils.py +36 -49
  26. swcgeom/core/tree_utils_impl.py +4 -6
  27. swcgeom/images/augmentation.py +23 -39
  28. swcgeom/images/contrast.py +22 -46
  29. swcgeom/images/folder.py +32 -34
  30. swcgeom/images/io.py +80 -121
  31. swcgeom/transforms/base.py +28 -19
  32. swcgeom/transforms/branch.py +31 -41
  33. swcgeom/transforms/branch_tree.py +3 -1
  34. swcgeom/transforms/geometry.py +13 -4
  35. swcgeom/transforms/image_preprocess.py +2 -0
  36. swcgeom/transforms/image_stack.py +40 -35
  37. swcgeom/transforms/images.py +31 -24
  38. swcgeom/transforms/mst.py +27 -40
  39. swcgeom/transforms/neurolucida_asc.py +13 -13
  40. swcgeom/transforms/path.py +4 -0
  41. swcgeom/transforms/population.py +4 -0
  42. swcgeom/transforms/tree.py +16 -11
  43. swcgeom/transforms/tree_assembler.py +37 -54
  44. swcgeom/utils/download.py +7 -14
  45. swcgeom/utils/dsu.py +12 -0
  46. swcgeom/utils/ellipse.py +26 -14
  47. swcgeom/utils/file.py +8 -13
  48. swcgeom/utils/neuromorpho.py +78 -92
  49. swcgeom/utils/numpy_helper.py +15 -12
  50. swcgeom/utils/plotter_2d.py +10 -16
  51. swcgeom/utils/plotter_3d.py +7 -9
  52. swcgeom/utils/renderer.py +16 -8
  53. swcgeom/utils/sdf.py +12 -23
  54. swcgeom/utils/solid_geometry.py +58 -2
  55. swcgeom/utils/transforms.py +164 -100
  56. swcgeom/utils/volumetric_object.py +29 -53
  57. {swcgeom-0.18.3.dist-info → swcgeom-0.19.0.dist-info}/METADATA +5 -4
  58. swcgeom-0.19.0.dist-info/RECORD +67 -0
  59. {swcgeom-0.18.3.dist-info → swcgeom-0.19.0.dist-info}/WHEEL +1 -1
  60. swcgeom-0.18.3.dist-info/RECORD +0 -67
  61. {swcgeom-0.18.3.dist-info → swcgeom-0.19.0.dist-info/licenses}/LICENSE +0 -0
  62. {swcgeom-0.18.3.dist-info → swcgeom-0.19.0.dist-info}/top_level.txt +0 -0
@@ -17,7 +17,7 @@
17
17
 
18
18
  import numpy as np
19
19
  import numpy.typing as npt
20
- from typing_extensions import deprecated
20
+ from typing_extensions import deprecated, override
21
21
 
22
22
  from swcgeom.transforms.base import Identity, Transform
23
23
 
@@ -49,12 +49,14 @@ class ImagesCenterCrop(Transform[NDArrayf32, NDArrayf32]):
49
49
  else (shape_out, shape_out, shape_out)
50
50
  )
51
51
 
52
+ @override
52
53
  def __call__(self, x: NDArrayf32) -> NDArrayf32:
53
54
  diff = np.subtract(x.shape[:3], self.shape_out)
54
55
  s = diff // 2
55
56
  e = np.add(s, self.shape_out)
56
57
  return x[s[0] : e[0], s[1] : e[1], s[2] : e[2], :]
57
58
 
59
+ @override
58
60
  def extra_repr(self) -> str:
59
61
  return f"shape_out=({','.join(str(a) for a in self.shape_out)})"
60
62
 
@@ -73,9 +75,11 @@ class ImagesScale(Transform[NDArrayf32, NDArrayf32]):
73
75
  super().__init__()
74
76
  self.scaler = scaler
75
77
 
78
+ @override
76
79
  def __call__(self, x: NDArrayf32) -> NDArrayf32:
77
80
  return self.scaler * x
78
81
 
82
+ @override
79
83
  def extra_repr(self) -> str:
80
84
  return f"scaler={self.scaler}"
81
85
 
@@ -85,9 +89,11 @@ class ImagesClip(Transform[NDArrayf32, NDArrayf32]):
85
89
  super().__init__()
86
90
  self.vmin, self.vmax = vmin, vmax
87
91
 
92
+ @override
88
93
  def __call__(self, x: NDArrayf32) -> NDArrayf32:
89
94
  return np.clip(x, self.vmin, self.vmax)
90
95
 
96
+ @override
91
97
  def extra_repr(self) -> str:
92
98
  return f"vmin={self.vmin}, vmax={self.vmax}"
93
99
 
@@ -99,9 +105,11 @@ class ImagesFlip(Transform[NDArrayf32, NDArrayf32]):
99
105
  super().__init__()
100
106
  self.axis = axis
101
107
 
108
+ @override
102
109
  def __call__(self, x: NDArrayf32) -> NDArrayf32:
103
110
  return np.flip(x, axis=self.axis)
104
111
 
112
+ @override
105
113
  def extra_repr(self) -> str:
106
114
  return f"axis={self.axis}"
107
115
 
@@ -109,15 +117,13 @@ class ImagesFlip(Transform[NDArrayf32, NDArrayf32]):
109
117
  class ImagesFlipY(ImagesFlip):
110
118
  """Flip image stack along Y-axis.
111
119
 
112
- See Also
113
- --------
114
- ~.images.io.TeraflyImageStack:
115
- Terafly and Vaa3d use a especial right-handed coordinate system
116
- (with origin point in the left-top and z-axis points front),
117
- but we flip y-axis to makes it a left-handed coordinate system
118
- (with orgin point in the left-bottom and z-axis points front).
119
- If you need to use its coordinate system, remember to FLIP
120
- Y-AXIS BACK.
120
+ See Also:
121
+ ~.images.io.TeraflyImageStack:
122
+ Terafly and Vaa3d use a especial right-handed coordinate system (with
123
+ origin point in the left-top and z-axis points front), but we flip y-axis
124
+ to makes it a left-handed coordinate system (with origin point in the
125
+ left-bottom and z-axis points front). If you need to use its coordinate
126
+ system, remember to FLIP Y-AXIS BACK.
121
127
  """
122
128
 
123
129
  def __init__(self, axis: int = 1, /) -> None:
@@ -127,6 +133,7 @@ class ImagesFlipY(ImagesFlip):
127
133
  class ImagesNormalizer(Transform[NDArrayf32, NDArrayf32]):
128
134
  """Normalize image stack."""
129
135
 
136
+ @override
130
137
  def __call__(self, x: NDArrayf32) -> NDArrayf32:
131
138
  mean = np.mean(x)
132
139
  variance = np.var(x)
@@ -136,9 +143,8 @@ class ImagesNormalizer(Transform[NDArrayf32, NDArrayf32]):
136
143
  class ImagesMeanVarianceAdjustment(Transform[NDArrayf32, NDArrayf32]):
137
144
  """Adjust image stack mean and variance.
138
145
 
139
- See Also
140
- --------
141
- ~swcgeom.images.ImageStackFolder.stat
146
+ See Also:
147
+ ~swcgeom.images.ImageStackFolder.stat
142
148
  """
143
149
 
144
150
  def __init__(self, mean: float, variance: float) -> None:
@@ -146,9 +152,11 @@ class ImagesMeanVarianceAdjustment(Transform[NDArrayf32, NDArrayf32]):
146
152
  self.mean = mean
147
153
  self.variance = variance
148
154
 
155
+ @override
149
156
  def __call__(self, x: NDArrayf32) -> NDArrayf32:
150
157
  return (x - self.mean) / self.variance
151
158
 
159
+ @override
152
160
  def extra_repr(self) -> str:
153
161
  return f"mean={self.mean}, variance={self.variance}"
154
162
 
@@ -159,14 +167,10 @@ class ImagesScaleToUnitRange(Transform[NDArrayf32, NDArrayf32]):
159
167
  def __init__(self, vmin: float, vmax: float, *, clip: bool = True) -> None:
160
168
  """Scale image stack to unit range.
161
169
 
162
- Parameters
163
- ----------
164
- vmin : float
165
- Minimum value.
166
- vmax : float
167
- Maximum value.
168
- clip : bool, default True
169
- Clip values to [0, 1] to avoid numerical issues.
170
+ Args:
171
+ vmin: Minimum value.
172
+ vmax: Maximum value.
173
+ clip: Clip values to [0, 1] to avoid numerical issues.
170
174
  """
171
175
 
172
176
  super().__init__()
@@ -176,9 +180,11 @@ class ImagesScaleToUnitRange(Transform[NDArrayf32, NDArrayf32]):
176
180
  self.clip = clip
177
181
  self.post = ImagesClip(0, 1) if self.clip else Identity()
178
182
 
183
+ @override
179
184
  def __call__(self, x: NDArrayf32) -> NDArrayf32:
180
185
  return self.post((x - self.vmin) / self.diff)
181
186
 
187
+ @override
182
188
  def extra_repr(self) -> str:
183
189
  return f"vmin={self.vmin}, vmax={self.vmax}, clip={self.clip}"
184
190
 
@@ -186,15 +192,15 @@ class ImagesScaleToUnitRange(Transform[NDArrayf32, NDArrayf32]):
186
192
  class ImagesHistogramEqualization(Transform[NDArrayf32, NDArrayf32]):
187
193
  """Image histogram equalization.
188
194
 
189
- References
190
- ----------
191
- http://www.janeriksolem.net/histogram-equalization-with-python-and.html
195
+ References:
196
+ http://www.janeriksolem.net/histogram-equalization-with-python-and.html
192
197
  """
193
198
 
194
199
  def __init__(self, bins: int = 256) -> None:
195
200
  super().__init__()
196
201
  self.bins = bins
197
202
 
203
+ @override
198
204
  def __call__(self, x: NDArrayf32) -> NDArrayf32:
199
205
  # get image histogram
200
206
  hist, bin_edges = np.histogram(x.flatten(), self.bins, density=True)
@@ -205,5 +211,6 @@ class ImagesHistogramEqualization(Transform[NDArrayf32, NDArrayf32]):
205
211
  equalized = np.interp(x.flatten(), bin_edges[:-1], cdf)
206
212
  return equalized.reshape(x.shape).astype(np.float32)
207
213
 
214
+ @override
208
215
  def extra_repr(self) -> str:
209
216
  return f"bins={self.bins}"
swcgeom/transforms/mst.py CHANGED
@@ -16,12 +16,12 @@
16
16
  """Minimum spanning tree."""
17
17
 
18
18
  import warnings
19
- from typing import Optional
20
19
 
21
20
  import numpy as np
22
21
  import pandas as pd
23
22
  from numpy import ma
24
23
  from numpy import typing as npt
24
+ from typing_extensions import override
25
25
 
26
26
  from swcgeom.core import Tree, sort_tree
27
27
  from swcgeom.core.swc_utils import SWCNames, SWCTypes, get_names, get_types
@@ -36,8 +36,7 @@ class PointsToCuntzMST(Transform[npt.NDArray[np.float32], Tree]):
36
36
  Creates trees corresponding to the minimum spanning tree keeping
37
37
  the path length to the root small (with balancing factor bf).
38
38
 
39
- References
40
- ----------
39
+ References:
41
40
  .. [1] Cuntz, H., Forstner, F., Borst, A. & Häusser, M. One Rule to
42
41
  Grow Them Al: A General Theory of Neuronal Branching and Its
43
42
  Practical Application. PLOS Comput Biol 6, e1000877 (2010).
@@ -52,21 +51,15 @@ class PointsToCuntzMST(Transform[npt.NDArray[np.float32], Tree]):
52
51
  furcations: int = 2,
53
52
  exclude_soma: bool = True,
54
53
  sort: bool = True,
55
- names: Optional[SWCNames] = None,
56
- types: Optional[SWCTypes] = None,
54
+ names: SWCNames | None = None,
55
+ types: SWCTypes | None = None,
57
56
  ) -> None:
58
57
  """
59
- Parameters
60
- ----------
61
- bf : float, default `0.4`
62
- Balancing factor between 0~1.
63
- furcations : int, default `2`
64
- Suppress multi-furcations which more than k. If set to -1,
65
- no suppression.
66
- exclude_soma : bool, default `True`
67
- Suppress multi-furcations exclude soma.
68
- names : SWCNames, optional
69
- types : SWCTypes, optional
58
+ Args:
59
+ bf: Balancing factor between 0~1.
60
+ furcations: Suppress multi-furcations which more than k.
61
+ If set to -1, no suppression.
62
+ exclude_soma: Suppress multi-furcations exclude soma.
70
63
  """
71
64
  self.bf = np.clip(bf, 0, 1)
72
65
  self.furcations = furcations
@@ -75,29 +68,26 @@ class PointsToCuntzMST(Transform[npt.NDArray[np.float32], Tree]):
75
68
  self.names = get_names(names)
76
69
  self.types = get_types(types)
77
70
 
71
+ @override
78
72
  def __call__( # pylint: disable=too-many-locals
79
73
  self,
80
74
  points: npt.NDArray[np.floating],
81
- soma: Optional[npt.ArrayLike] = None,
75
+ soma: npt.ArrayLike | None = None,
82
76
  *,
83
- names: Optional[SWCNames] = None,
77
+ names: SWCNames | None = None,
84
78
  ) -> Tree:
85
79
  """
86
- Paramters
87
- ---------
88
- points : array of shape (N, 3)
89
- Positions of points cloud.
90
- soma : array of shape (3,), default `None`
91
- Position of soma. If none, use the first point as soma.
92
- names : SWCNames, optional
80
+ Args:
81
+ points: Positions of points cloud.
82
+ soma: Position of soma. If none, use the first point as soma.
93
83
  """
94
84
  if names is None:
95
85
  names = self.names
96
86
  else:
97
87
  warnings.warn(
98
- "`PointsToCuntzMST(...)(names=...)` has been "
99
- "replaced by `PointsToCuntzMST(...,names=...)` since "
100
- "v0.12.0, and will be removed in next version",
88
+ "`PointsToCuntzMST(...)(names=...)` has been replaced by "
89
+ "`PointsToCuntzMST(...,names=...)` since v0.12.0, and will be removed "
90
+ "in next version",
101
91
  DeprecationWarning,
102
92
  )
103
93
  names = get_names(names) # TODO: remove it
@@ -155,6 +145,7 @@ class PointsToCuntzMST(Transform[npt.NDArray[np.float32], Tree]):
155
145
  t = sort_tree(t)
156
146
  return t
157
147
 
148
+ @override
158
149
  def extra_repr(self) -> str: # TODO: names, types
159
150
  return f"bf={self.bf:.4f}, furcations={self.furcations}, exclude_soma={self.exclude_soma}, sort={self.sort}"
160
151
 
@@ -166,22 +157,17 @@ class PointsToMST(PointsToCuntzMST): # pylint: disable=too-few-public-methods
166
157
  self,
167
158
  furcations: int = 2,
168
159
  *,
169
- k_furcations: Optional[int] = None,
160
+ k_furcations: int | None = None,
170
161
  exclude_soma: bool = True,
171
- names: Optional[SWCNames] = None,
172
- types: Optional[SWCTypes] = None,
162
+ names: SWCNames | None = None,
163
+ types: SWCTypes | None = None,
173
164
  **kwargs,
174
165
  ) -> None:
175
166
  """
176
- Parameters
177
- ----------
178
- furcations : int, default `2`
179
- Suppress multifurcations which more than k. If set to -1,
180
- no suppression.
181
- exclude_soma : bool, default `True`
182
- Suppress multi-furcations exclude soma.
183
- names : SWCNames, optional
184
- types : SWCTypes, optional
167
+ Args:
168
+ furcations: Suppress multifurcations which more than k.
169
+ If set to -1, no suppression.
170
+ exclude_soma: Suppress multi-furcations exclude soma.
185
171
  """
186
172
 
187
173
  if k_furcations is not None:
@@ -202,5 +188,6 @@ class PointsToMST(PointsToCuntzMST): # pylint: disable=too-few-public-methods
202
188
  **kwargs,
203
189
  )
204
190
 
191
+ @override
205
192
  def extra_repr(self) -> str:
206
193
  return f"furcations-{self.furcations}, exclude-soma={self.exclude_soma}"
@@ -18,8 +18,9 @@
18
18
  import os
19
19
  import re
20
20
  from enum import Enum, auto
21
- from io import TextIOBase
22
- from typing import Any, NamedTuple, Optional, cast
21
+ from typing import IO, Any, NamedTuple, cast
22
+
23
+ from typing_extensions import override
23
24
 
24
25
  from swcgeom.core import Tree
25
26
  from swcgeom.core.swc_utils import SWCNames, SWCTypes, get_names, get_types
@@ -31,6 +32,7 @@ __all__ = ["NeurolucidaAscToSwc"]
31
32
  class NeurolucidaAscToSwc(Transform[str, Tree]):
32
33
  """Convert neurolucida asc format to swc format."""
33
34
 
35
+ @override
34
36
  def __call__(self, x: str) -> Tree:
35
37
  return self.convert(x)
36
38
 
@@ -42,7 +44,7 @@ class NeurolucidaAscToSwc(Transform[str, Tree]):
42
44
  return tree
43
45
 
44
46
  @classmethod
45
- def from_stream(cls, x: TextIOBase, *, source: str = "") -> Tree:
47
+ def from_stream(cls, x: IO[str], *, source: str = "") -> Tree:
46
48
  parser = Parser(x, source=source)
47
49
  ast = parser.parse()
48
50
  tree = cls.from_ast(ast)
@@ -52,8 +54,8 @@ class NeurolucidaAscToSwc(Transform[str, Tree]):
52
54
  def from_ast(
53
55
  ast: "AST",
54
56
  *,
55
- names: Optional[SWCNames] = None,
56
- types: Optional[SWCTypes] = None,
57
+ names: SWCNames | None = None,
58
+ types: SWCTypes | None = None,
57
59
  ) -> Tree:
58
60
  names = get_names(names)
59
61
  types = get_types(types)
@@ -129,8 +131,8 @@ class ASTNode:
129
131
  self,
130
132
  type: ASTType,
131
133
  value: Any = None,
132
- tokens: Optional[list["Token"]] = None,
133
- children: Optional[list["ASTNode"]] = None,
134
+ tokens: list["Token"] | None = None,
135
+ children: list["ASTNode"] | None = None,
134
136
  ):
135
137
  self.type = type
136
138
  self.value = value
@@ -149,9 +151,7 @@ class ASTNode:
149
151
  """
150
152
  Compare two ASTNode objects.
151
153
 
152
- Notes
153
- -----
154
- The `parent`, `tokens` attribute is not compared.
154
+ NOTE: The `parent`, `tokens` attribute is not compared.
155
155
  """
156
156
  return (
157
157
  isinstance(__value, ASTNode)
@@ -162,7 +162,7 @@ class ASTNode:
162
162
 
163
163
 
164
164
  class AST(ASTNode):
165
- def __init__(self, children: Optional[list[ASTNode]] = None, source: str = ""):
165
+ def __init__(self, children: list[ASTNode] | None = None, source: str = ""):
166
166
  super().__init__(ASTType.ROOT, children=children)
167
167
  self.source = source
168
168
 
@@ -216,7 +216,7 @@ class AssertionTokenTypeError(Exception):
216
216
 
217
217
 
218
218
  class Parser:
219
- def __init__(self, r: TextIOBase, *, source: str = ""):
219
+ def __init__(self, r: IO[str], *, source: str = ""):
220
220
  self.lexer = Lexer(r)
221
221
  self.next_token = None
222
222
  self.source = source
@@ -426,7 +426,7 @@ RE_FLOAT = re.compile(r"[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?")
426
426
 
427
427
 
428
428
  class Lexer:
429
- def __init__(self, r: TextIOBase):
429
+ def __init__(self, r: IO[str]):
430
430
  self.r = r
431
431
  self.lineno = 1
432
432
  self.column = 1
@@ -15,6 +15,8 @@
15
15
 
16
16
  """Transformation in path."""
17
17
 
18
+ from typing_extensions import override
19
+
18
20
  from swcgeom.core import Path, Tree, redirect_tree
19
21
  from swcgeom.transforms.base import Transform
20
22
 
@@ -24,6 +26,7 @@ __all__ = ["PathToTree", "PathReverser"]
24
26
  class PathToTree(Transform[Path, Tree]):
25
27
  """Transform path to tree."""
26
28
 
29
+ @override
27
30
  def __call__(self, x: Path) -> Tree:
28
31
  t = Tree(
29
32
  x.number_of_nodes(),
@@ -55,6 +58,7 @@ class PathReverser(Transform[Path, Path]):
55
58
  super().__init__()
56
59
  self.to_tree = PathToTree()
57
60
 
61
+ @override
58
62
  def __call__(self, x: Path) -> Path:
59
63
  x[0].type, x[-1].type = x[-1].type, x[0].type
60
64
  t = self.to_tree(x)
@@ -15,6 +15,8 @@
15
15
 
16
16
  """Transformation in population."""
17
17
 
18
+ from typing_extensions import override
19
+
18
20
  from swcgeom.core import Population, Tree
19
21
  from swcgeom.transforms.base import Transform
20
22
 
@@ -28,6 +30,7 @@ class PopulationTransform(Transform[Population, Population]):
28
30
  super().__init__()
29
31
  self.transform = transform
30
32
 
33
+ @override
31
34
  def __call__(self, population: Population) -> Population:
32
35
  trees: list[Tree] = []
33
36
  for t in population:
@@ -38,5 +41,6 @@ class PopulationTransform(Transform[Population, Population]):
38
41
 
39
42
  return Population(trees, root=population.root)
40
43
 
44
+ @override
41
45
  def extra_repr(self) -> str:
42
46
  return f"transform={self.transform}"
@@ -16,10 +16,9 @@
16
16
  """Transformation in tree."""
17
17
 
18
18
  from collections.abc import Callable
19
- from typing import Optional
20
19
 
21
20
  import numpy as np
22
- from typing_extensions import deprecated
21
+ from typing_extensions import deprecated, override
23
22
 
24
23
  from swcgeom.core import Branch, BranchTree, DictSWC, Path, Tree, cut_tree, to_subtree
25
24
  from swcgeom.core.swc_utils import SWCTypes, get_types
@@ -46,6 +45,7 @@ __all__ = [
46
45
  class ToBranchTree(Transform[Tree, BranchTree]):
47
46
  """Transform tree to branch tree."""
48
47
 
48
+ @override
49
49
  def __call__(self, x: Tree) -> BranchTree:
50
50
  return BranchTree.from_tree(x)
51
51
 
@@ -56,6 +56,7 @@ class ToLongestPath(Transform[Tree, Path[DictSWC]]):
56
56
  def __init__(self, *, detach: bool = True) -> None:
57
57
  self.detach = detach
58
58
 
59
+ @override
59
60
  def __call__(self, x: Tree) -> Path[DictSWC]:
60
61
  paths = x.get_paths()
61
62
  idx = np.argmax([p.length() for p in paths])
@@ -71,6 +72,7 @@ class TreeSmoother(Transform[Tree, Tree]): # pylint: disable=missing-class-docs
71
72
  self.n_nodes = n_nodes
72
73
  self.trans = BranchConvSmoother(n_nodes=n_nodes)
73
74
 
75
+ @override
74
76
  def __call__(self, x: Tree) -> Tree:
75
77
  x = x.copy()
76
78
  for br in x.get_branches():
@@ -82,6 +84,7 @@ class TreeSmoother(Transform[Tree, Tree]): # pylint: disable=missing-class-docs
82
84
 
83
85
  return x
84
86
 
87
+ @override
85
88
  def extra_repr(self) -> str:
86
89
  return f"n_nodes={self.n_nodes}"
87
90
 
@@ -100,15 +103,14 @@ class CutByType(Transform[Tree, Tree]):
100
103
 
101
104
  In order to preserve the tree structure, all ancestor nodes of the node to be preserved will be preserved.
102
105
 
103
- Notes
104
- -----
105
- Not all reserved nodes are of the specified type.
106
+ NOTE: Not all reserved nodes are of the specified type.
106
107
  """
107
108
 
108
109
  def __init__(self, type: int) -> None: # pylint: disable=redefined-builtin
109
110
  super().__init__()
110
111
  self.type = type
111
112
 
113
+ @override
112
114
  def __call__(self, x: Tree) -> Tree:
113
115
  removals = set(x.id()[x.type() != self.type])
114
116
 
@@ -121,6 +123,7 @@ class CutByType(Transform[Tree, Tree]):
121
123
  y = to_subtree(x, removals)
122
124
  return y
123
125
 
126
+ @override
124
127
  def extra_repr(self) -> str:
125
128
  return f"type={self.type}"
126
129
 
@@ -128,7 +131,7 @@ class CutByType(Transform[Tree, Tree]):
128
131
  class CutAxonTree(CutByType):
129
132
  """Cut axon tree."""
130
133
 
131
- def __init__(self, types: Optional[SWCTypes] = None) -> None:
134
+ def __init__(self, types: SWCTypes | None = None) -> None:
132
135
  types = get_types(types)
133
136
  super().__init__(type=types.axon)
134
137
 
@@ -136,7 +139,7 @@ class CutAxonTree(CutByType):
136
139
  class CutDendriteTree(CutByType):
137
140
  """Cut dendrite tree."""
138
141
 
139
- def __init__(self, types: Optional[SWCTypes] = None) -> None:
142
+ def __init__(self, types: SWCTypes | None = None) -> None:
140
143
  types = get_types(types)
141
144
  super().__init__(type=types.basal_dendrite) # TODO: apical dendrite
142
145
 
@@ -149,6 +152,7 @@ class CutByFurcationOrder(Transform[Tree, Tree]):
149
152
  def __init__(self, max_bifurcation_order: int) -> None:
150
153
  self.max_furcation_order = max_bifurcation_order
151
154
 
155
+ @override
152
156
  def __call__(self, x: Tree) -> Tree:
153
157
  return cut_tree(x, enter=self._enter)
154
158
 
@@ -169,9 +173,7 @@ class CutByFurcationOrder(Transform[Tree, Tree]):
169
173
  class CutByBifurcationOrder(CutByFurcationOrder):
170
174
  """Cut tree by bifurcation order.
171
175
 
172
- Notes
173
- -----
174
- Deprecated due to the wrong spelling of furcation. For now, it
176
+ NOTE: Deprecated due to the wrong spelling of furcation. For now, it
175
177
  is just an alias of `CutByFurcationOrder` and raise a warning. It
176
178
  will be change to raise an error in the future.
177
179
  """
@@ -197,7 +199,7 @@ class CutShortTipBranch(Transform[Tree, Tree]):
197
199
  callbacks: list[Callable[[Tree.Branch], None]]
198
200
 
199
201
  def __init__(
200
- self, thre: float = 5, callback: Optional[Callable[[Tree.Branch], None]] = None
202
+ self, thre: float = 5, callback: Callable[[Tree.Branch], None] | None = None
201
203
  ) -> None:
202
204
  self.thre = thre
203
205
  self.callbacks = []
@@ -205,6 +207,7 @@ class CutShortTipBranch(Transform[Tree, Tree]):
205
207
  if callback is not None:
206
208
  self.callbacks.append(callback)
207
209
 
210
+ @override
208
211
  def __call__(self, x: Tree) -> Tree:
209
212
  removals: list[int] = []
210
213
  self.callbacks.append(lambda br: removals.append(br[1].id))
@@ -212,6 +215,7 @@ class CutShortTipBranch(Transform[Tree, Tree]):
212
215
  self.callbacks.pop()
213
216
  return to_subtree(x, removals)
214
217
 
218
+ @override
215
219
  def extra_repr(self) -> str:
216
220
  return f"threshold={self.thre}"
217
221
 
@@ -252,6 +256,7 @@ class Resampler(Transform[Tree, Tree]):
252
256
  self.resampler = branch_resampler
253
257
  self.assembler = BranchTreeAssembler()
254
258
 
259
+ @override
255
260
  def __call__(self, x: Tree) -> Tree:
256
261
  t = BranchTree.from_tree(x)
257
262
  t.branches = {