swcgeom 0.17.0__py3-none-any.whl → 0.17.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of swcgeom might be problematic. Click here for more details.

Files changed (49) hide show
  1. swcgeom/_version.py +2 -2
  2. swcgeom/analysis/feature_extractor.py +25 -15
  3. swcgeom/analysis/features.py +20 -8
  4. swcgeom/analysis/lmeasure.py +33 -12
  5. swcgeom/analysis/sholl.py +10 -28
  6. swcgeom/analysis/trunk.py +12 -11
  7. swcgeom/analysis/visualization.py +9 -9
  8. swcgeom/analysis/visualization3d.py +85 -0
  9. swcgeom/analysis/volume.py +4 -4
  10. swcgeom/core/branch.py +4 -3
  11. swcgeom/core/branch_tree.py +3 -4
  12. swcgeom/core/compartment.py +3 -2
  13. swcgeom/core/node.py +17 -3
  14. swcgeom/core/path.py +6 -9
  15. swcgeom/core/population.py +43 -29
  16. swcgeom/core/swc.py +11 -10
  17. swcgeom/core/swc_utils/base.py +8 -17
  18. swcgeom/core/swc_utils/checker.py +3 -11
  19. swcgeom/core/swc_utils/io.py +7 -6
  20. swcgeom/core/swc_utils/normalizer.py +4 -3
  21. swcgeom/core/swc_utils/subtree.py +2 -2
  22. swcgeom/core/tree.py +41 -40
  23. swcgeom/core/tree_utils.py +13 -17
  24. swcgeom/core/tree_utils_impl.py +3 -3
  25. swcgeom/images/augmentation.py +3 -3
  26. swcgeom/images/folder.py +12 -26
  27. swcgeom/images/io.py +21 -35
  28. swcgeom/transforms/image_stack.py +20 -8
  29. swcgeom/transforms/images.py +3 -12
  30. swcgeom/transforms/neurolucida_asc.py +4 -6
  31. swcgeom/transforms/population.py +1 -3
  32. swcgeom/transforms/tree.py +38 -25
  33. swcgeom/transforms/tree_assembler.py +4 -3
  34. swcgeom/utils/download.py +44 -21
  35. swcgeom/utils/ellipse.py +3 -4
  36. swcgeom/utils/neuromorpho.py +17 -16
  37. swcgeom/utils/plotter_2d.py +12 -6
  38. swcgeom/utils/plotter_3d.py +31 -0
  39. swcgeom/utils/renderer.py +6 -6
  40. swcgeom/utils/sdf.py +4 -7
  41. swcgeom/utils/solid_geometry.py +1 -3
  42. swcgeom/utils/transforms.py +2 -4
  43. swcgeom/utils/volumetric_object.py +8 -10
  44. {swcgeom-0.17.0.dist-info → swcgeom-0.17.2.dist-info}/METADATA +19 -19
  45. swcgeom-0.17.2.dist-info/RECORD +67 -0
  46. {swcgeom-0.17.0.dist-info → swcgeom-0.17.2.dist-info}/WHEEL +1 -1
  47. swcgeom-0.17.0.dist-info/RECORD +0 -65
  48. {swcgeom-0.17.0.dist-info → swcgeom-0.17.2.dist-info}/LICENSE +0 -0
  49. {swcgeom-0.17.0.dist-info → swcgeom-0.17.2.dist-info}/top_level.txt +0 -0
swcgeom/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.17.0'
16
- __version_tuple__ = version_tuple = (0, 17, 0)
15
+ __version__ = version = '0.17.2'
16
+ __version_tuple__ = version_tuple = (0, 17, 2)
@@ -7,10 +7,11 @@ naming specification.
7
7
  """
8
8
 
9
9
  from abc import ABC, abstractmethod
10
+ from collections.abc import Callable
10
11
  from functools import cached_property
11
12
  from itertools import chain
12
13
  from os.path import basename
13
- from typing import Any, Callable, Dict, List, Literal, Tuple, overload
14
+ from typing import Any, Literal, overload
14
15
 
15
16
  import numpy as np
16
17
  import numpy.typing as npt
@@ -18,8 +19,8 @@ import seaborn as sns
18
19
  from matplotlib.axes import Axes
19
20
 
20
21
  from swcgeom.analysis.features import (
21
- BifurcationFeatures,
22
22
  BranchFeatures,
23
+ FurcationFeatures,
23
24
  NodeFeatures,
24
25
  PathFeatures,
25
26
  TipFeatures,
@@ -39,6 +40,9 @@ Feature = Literal[
39
40
  "node_count",
40
41
  "node_radial_distance",
41
42
  "node_branch_order",
43
+ # furcation nodes
44
+ "furcation_count",
45
+ "furcation_radial_distance",
42
46
  # bifurcation nodes
43
47
  "bifurcation_count",
44
48
  "bifurcation_radial_distance",
@@ -54,9 +58,9 @@ Feature = Literal[
54
58
  ]
55
59
 
56
60
  NDArrayf32 = npt.NDArray[np.float32]
57
- FeatAndKwargs = Feature | Tuple[Feature, Dict[str, Any]]
61
+ FeatAndKwargs = Feature | tuple[Feature, dict[str, Any]]
58
62
 
59
- Feature1D = set(["length", "volume", "node_count", "bifurcation_count", "tip_count"])
63
+ Feature1D = set(["length", "volume", "node_count", "furcation_count", "tip_count"])
60
64
 
61
65
 
62
66
  class Features:
@@ -69,7 +73,7 @@ class Features:
69
73
  @cached_property
70
74
  def node_features(self) -> NodeFeatures: return NodeFeatures(self.tree)
71
75
  @cached_property
72
- def bifurcation_features(self) -> BifurcationFeatures: return BifurcationFeatures(self.node_features)
76
+ def furcation_features(self) -> FurcationFeatures: return FurcationFeatures(self.node_features)
73
77
  @cached_property
74
78
  def tip_features(self) -> TipFeatures: return TipFeatures(self.node_features)
75
79
  @cached_property
@@ -121,9 +125,9 @@ class FeatureExtractor(ABC):
121
125
  @overload
122
126
  def get(self, feature: Feature, **kwargs) -> NDArrayf32: ...
123
127
  @overload
124
- def get(self, feature: List[FeatAndKwargs]) -> List[NDArrayf32]: ...
128
+ def get(self, feature: list[FeatAndKwargs]) -> list[NDArrayf32]: ...
125
129
  @overload
126
- def get(self, feature: Dict[Feature, Dict[str, Any]]) -> Dict[str, NDArrayf32]: ...
130
+ def get(self, feature: dict[Feature, dict[str, Any]]) -> dict[str, NDArrayf32]: ...
127
131
  # fmt:on
128
132
  def get(self, feature, **kwargs):
129
133
  """Get feature.
@@ -168,7 +172,7 @@ class FeatureExtractor(ABC):
168
172
 
169
173
  # Custom Plots
170
174
 
171
- def plot_node_branch_order(self, feature_kwargs: Dict[str, Any], **kwargs) -> Axes:
175
+ def plot_node_branch_order(self, feature_kwargs: dict[str, Any], **kwargs) -> Axes:
172
176
  vals = self._get("node_branch_order", **feature_kwargs)
173
177
  bin_edges = np.arange(int(np.ceil(vals.max() + 1))) + 0.5
174
178
  return self._plot_histogram_impl(vals, bin_edges, **kwargs)
@@ -213,6 +217,12 @@ class FeatureExtractor(ABC):
213
217
  ) -> Axes:
214
218
  raise NotImplementedError()
215
219
 
220
+ def get_bifurcation_count(self, **kwargs):
221
+ raise DeprecationWarning("Use `furcation_count` instead.")
222
+
223
+ def get_bifurcation_radial_distance(self, **kwargs):
224
+ raise DeprecationWarning("Use `furcation_radial_distance` instead.")
225
+
216
226
 
217
227
  class TreeFeatureExtractor(FeatureExtractor):
218
228
  """Extract feature from tree."""
@@ -234,7 +244,7 @@ class TreeFeatureExtractor(FeatureExtractor):
234
244
 
235
245
  def plot_sholl(
236
246
  self,
237
- feature_kwargs: Dict[str, Any], # pylint: disable=unused-argument
247
+ feature_kwargs: dict[str, Any], # pylint: disable=unused-argument
238
248
  **kwargs,
239
249
  ) -> Axes:
240
250
  _, ax = self._features.sholl.plot(**kwargs)
@@ -264,7 +274,7 @@ class PopulationFeatureExtractor(FeatureExtractor):
264
274
  """Extract features from population."""
265
275
 
266
276
  _population: Population
267
- _features: List[Features]
277
+ _features: list[Features]
268
278
 
269
279
  def __init__(self, population: Population) -> None:
270
280
  super().__init__()
@@ -279,7 +289,7 @@ class PopulationFeatureExtractor(FeatureExtractor):
279
289
 
280
290
  # Custom Plots
281
291
 
282
- def plot_sholl(self, feature_kwargs: Dict[str, Any], **kwargs) -> Axes:
292
+ def plot_sholl(self, feature_kwargs: dict[str, Any], **kwargs) -> Axes:
283
293
  vals, rs = self._get_sholl_impl(**feature_kwargs)
284
294
  ax = self._lineplot(xs=rs, ys=vals.flatten(), **kwargs)
285
295
  ax.set_ylabel("Count of Intersections")
@@ -295,7 +305,7 @@ class PopulationFeatureExtractor(FeatureExtractor):
295
305
 
296
306
  def _get_sholl_impl(
297
307
  self, steps: int = 20, **kwargs
298
- ) -> Tuple[NDArrayf32, NDArrayf32]:
308
+ ) -> tuple[NDArrayf32, NDArrayf32]:
299
309
  rmax = max(t.sholl.rmax for t in self._features)
300
310
  rs = Sholl.get_rs(rmax=rmax, steps=steps)
301
311
  vals = self._get_impl("sholl", steps=rs, **kwargs)
@@ -333,7 +343,7 @@ class PopulationsFeatureExtractor(FeatureExtractor):
333
343
  """Extract feature from population."""
334
344
 
335
345
  _populations: Populations
336
- _features: List[List[Features]]
346
+ _features: list[list[Features]]
337
347
 
338
348
  def __init__(self, populations: Populations) -> None:
339
349
  super().__init__()
@@ -348,7 +358,7 @@ class PopulationsFeatureExtractor(FeatureExtractor):
348
358
 
349
359
  # Custom Plots
350
360
 
351
- def plot_sholl(self, feature_kwargs: Dict[str, Any], **kwargs) -> Axes:
361
+ def plot_sholl(self, feature_kwargs: dict[str, Any], **kwargs) -> Axes:
352
362
  vals, rs = self._get_sholl_impl(**feature_kwargs)
353
363
  ax = self._lineplot(xs=rs, ys=vals, **kwargs)
354
364
  ax.set_ylabel("Count of Intersections")
@@ -369,7 +379,7 @@ class PopulationsFeatureExtractor(FeatureExtractor):
369
379
 
370
380
  def _get_sholl_impl(
371
381
  self, steps: int = 20, **kwargs
372
- ) -> Tuple[NDArrayf32, NDArrayf32]:
382
+ ) -> tuple[NDArrayf32, NDArrayf32]:
373
383
  rmaxs = chain.from_iterable((t.sholl.rmax for t in p) for p in self._features)
374
384
  rmax = max(rmaxs)
375
385
  rs = Sholl.get_rs(rmax=rmax, steps=steps)
@@ -2,11 +2,11 @@
2
2
 
3
3
  from abc import ABC, abstractmethod
4
4
  from functools import cached_property
5
- from typing import List, TypeVar
5
+ from typing import TypeVar
6
6
 
7
7
  import numpy as np
8
8
  import numpy.typing as npt
9
- from typing_extensions import Self
9
+ from typing_extensions import Self, deprecated
10
10
 
11
11
  from swcgeom.core import Branch, BranchTree, Tree
12
12
 
@@ -121,12 +121,24 @@ class _SubsetNodesFeatures(ABC):
121
121
  return cls(NodeFeatures(tree))
122
122
 
123
123
 
124
- class BifurcationFeatures(_SubsetNodesFeatures):
125
- """Evaluate bifurcation node feature of tree."""
124
+ class FurcationFeatures(_SubsetNodesFeatures):
125
+ """Evaluate furcation node feature of tree."""
126
126
 
127
127
  @cached_property
128
128
  def nodes(self) -> npt.NDArray[np.bool_]:
129
- return np.array([n.is_bifurcation() for n in self._features.tree])
129
+ return np.array([n.is_furcation() for n in self._features.tree])
130
+
131
+
132
+ @deprecated("Use FurcationFeatures instead")
133
+ class BifurcationFeatures(FurcationFeatures):
134
+ """Evaluate bifurcation node feature of tree.
135
+
136
+ Notes
137
+ -----
138
+ Deprecated due to the wrong spelling of furcation. For now, it
139
+ is just an alias of `FurcationFeatures` and raise a warning. It
140
+ will be change to raise an error in the future.
141
+ """
130
142
 
131
143
 
132
144
  class TipFeatures(_SubsetNodesFeatures):
@@ -163,7 +175,7 @@ class PathFeatures:
163
175
  return np.array([path.tortuosity() for path in self._paths], dtype=np.float32)
164
176
 
165
177
  @cached_property
166
- def _paths(self) -> List[Tree.Path]:
178
+ def _paths(self) -> list[Tree.Path]:
167
179
  return self.tree.get_paths()
168
180
 
169
181
 
@@ -201,7 +213,7 @@ class BranchFeatures:
201
213
  return self.calc_angle(self._branches, eps=eps)
202
214
 
203
215
  @staticmethod
204
- def calc_angle(branches: List[T], eps: float = 1e-7) -> npt.NDArray[np.float32]:
216
+ def calc_angle(branches: list[T], eps: float = 1e-7) -> npt.NDArray[np.float32]:
205
217
  """Calc agnle between branches.
206
218
 
207
219
  Returns
@@ -219,5 +231,5 @@ class BranchFeatures:
219
231
  return angle
220
232
 
221
233
  @cached_property
222
- def _branches(self) -> List[Tree.Branch]:
234
+ def _branches(self) -> list[Tree.Branch]:
223
235
  return self.tree.get_branches()
@@ -1,13 +1,11 @@
1
1
  """L-Measure analysis."""
2
2
 
3
3
  import math
4
- from typing import Literal, Tuple
4
+ from typing import Literal
5
5
 
6
6
  import numpy as np
7
7
  import numpy.typing as npt
8
-
9
8
  from swcgeom.core import Branch, Compartment, Node, Tree
10
- from swcgeom.utils import angle
11
9
 
12
10
  __all__ = ["LMeasure"]
13
11
 
@@ -69,7 +67,7 @@ class LMeasure:
69
67
  --------
70
68
  L-Measure: http://cng.gmu.edu:8080/Lm/help/N_bifs.htm
71
69
  """
72
- return len(tree.get_bifurcations())
70
+ return len(tree.get_furcations())
73
71
 
74
72
  def n_branch(self, tree: Tree) -> int:
75
73
  """Number of branches.
@@ -163,12 +161,15 @@ class LMeasure:
163
161
  --------
164
162
  L-Measure: http://cng.gmu.edu:8080/Lm/help/Partition_asymmetry.htm
165
163
  """
164
+
166
165
  children = n.children()
167
166
  assert (
168
167
  len(children) == 2
169
168
  ), "Partition asymmetry is only defined for bifurcations"
170
169
  n1 = len(children[0].subtree().get_tips())
171
170
  n2 = len(children[1].subtree().get_tips())
171
+ if n1 == n2:
172
+ return 0
172
173
  return abs(n1 - n2) / (n1 + n2 - 2)
173
174
 
174
175
  def fractal_dim(self):
@@ -274,7 +275,7 @@ class LMeasure:
274
275
  rall_power, _, _, _ = self._rall_power(bif)
275
276
  return rall_power
276
277
 
277
- def _rall_power_d(self, bif: Tree.Node) -> Tuple[float, float, float]:
278
+ def _rall_power_d(self, bif: Tree.Node) -> tuple[float, float, float]:
278
279
  children = bif.children()
279
280
  assert len(children) == 2, "Rall Power is only defined for bifurcations"
280
281
  parent = bif.parent()
@@ -284,7 +285,7 @@ class LMeasure:
284
285
  da, db = 2 * children[0].r, 2 * children[1].r
285
286
  return dp, da, db
286
287
 
287
- def _rall_power(self, bif: Tree.Node) -> Tuple[float, float, float, float]:
288
+ def _rall_power(self, bif: Tree.Node) -> tuple[float, float, float, float]:
288
289
  dp, da, db = self._rall_power_d(bif)
289
290
  start, stop, step = 0, 5, 5 / 1000
290
291
  xs = np.arange(start, stop, step)
@@ -336,7 +337,7 @@ class LMeasure:
336
337
  return (da**rall_power + db**rall_power) / dp**rall_power
337
338
 
338
339
  def bif_ampl_local(self, bif: Tree.Node) -> float:
339
- """Bifuraction angle.
340
+ """Bifurcation angle.
340
341
 
341
342
  Given a bifurcation, this function returns the angle between
342
343
  the first two compartments (in degree).
@@ -350,7 +351,7 @@ class LMeasure:
350
351
  return np.degrees(angle(v1, v2))
351
352
 
352
353
  def bif_ampl_remote(self, bif: Tree.Node) -> float:
353
- """Bifuraction angle.
354
+ """Bifurcation angle.
354
355
 
355
356
  This function returns the angle between two bifurcation points
356
357
  or between bifurcation point and terminal point or between two
@@ -361,7 +362,7 @@ class LMeasure:
361
362
  L-Measure: http://cng.gmu.edu:8080/Lm/help/Bif_ampl_remote.htm
362
363
  """
363
364
 
364
- v1, v2 = self._bif_vector_local(bif)
365
+ v1, v2 = self._bif_vector_remote(bif)
365
366
  return np.degrees(angle(v1, v2))
366
367
 
367
368
  def bif_tilt_local(self, bif: Tree.Node) -> float:
@@ -501,7 +502,7 @@ class LMeasure:
501
502
 
502
503
  def _bif_vector_local(
503
504
  self, bif: Tree.Node
504
- ) -> Tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]]:
505
+ ) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]]:
505
506
  children = bif.children()
506
507
  assert len(children) == 2, "Only defined for bifurcations"
507
508
 
@@ -511,7 +512,7 @@ class LMeasure:
511
512
 
512
513
  def _bif_vector_remote(
513
514
  self, bif: Tree.Node
514
- ) -> Tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]]:
515
+ ) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]]:
515
516
  children = bif.children()
516
517
  assert len(children) == 2, "Only defined for bifurcations"
517
518
 
@@ -665,7 +666,7 @@ class LMeasure:
665
666
  n = node
666
667
  order = 0
667
668
  while n is not None:
668
- if n.is_bifurcation():
669
+ if n.is_furcation():
669
670
  order += 1
670
671
  n = n.parent()
671
672
  return order
@@ -819,3 +820,23 @@ def pill_surface_area(ra: float, rb: float, h: float) -> float:
819
820
  bottom_hemisphere_area = 2 * math.pi * rb**2
820
821
  total_area = lateral_area + top_hemisphere_area + bottom_hemisphere_area
821
822
  return total_area
823
+
824
+
825
+ # TODO: move to `utils`
826
+ def angle(a: npt.ArrayLike, b: npt.ArrayLike) -> float:
827
+ """Get the angle of vectors.
828
+
829
+ Returns
830
+ -------
831
+ angle : float
832
+ Angle [0, PI) in radians.
833
+ """
834
+
835
+ a, b = np.array(a), np.array(b)
836
+ if np.linalg.norm(a) == 0 or np.linalg.norm(b) == 0:
837
+ raise ValueError("Input vectors must not be zero vectors.")
838
+
839
+ costheta = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
840
+ costheta = np.clip(costheta, -1, 1) # avoid numerical errors
841
+ theta = np.arccos(costheta)
842
+ return theta
swcgeom/analysis/sholl.py CHANGED
@@ -1,13 +1,14 @@
1
1
  """Sholl analysis."""
2
2
 
3
3
  import warnings
4
- from typing import List, Literal, Optional, Tuple
4
+ from typing import Literal, Optional
5
5
 
6
6
  import numpy as np
7
7
  import numpy.typing as npt
8
8
  import seaborn as sns
9
9
  from matplotlib.axes import Axes
10
10
  from matplotlib.figure import Figure
11
+ from typing_extensions import deprecated
11
12
 
12
13
  from swcgeom.analysis.visualization import draw
13
14
  from swcgeom.core import Tree
@@ -84,18 +85,18 @@ class Sholl:
84
85
 
85
86
  def plot( # pylint: disable=too-many-arguments
86
87
  self,
87
- steps: List[float] | int = 20,
88
+ steps: list[float] | int = 20,
88
89
  plot_type: str | None = None,
89
90
  kind: Literal["bar", "linechart", "circles"] = "circles",
90
91
  fig: Figure | None = None,
91
92
  ax: Axes | None = None,
92
93
  **kwargs,
93
- ) -> Tuple[Figure, Axes]:
94
+ ) -> tuple[Figure, Axes]:
94
95
  """Plot Sholl analysis.
95
96
 
96
97
  Parameters
97
98
  ----------
98
- steps : int or List[float], default to 20
99
+ steps : int or list[float], default to 20
99
100
  Steps of raius of circle. If steps is int, then it will be
100
101
  evenly divided into n radii.
101
102
  kind : "bar" | "linechart" | "circles", default `circles`
@@ -160,20 +161,17 @@ class Sholl:
160
161
 
161
162
  return self.get_rs(self.rmax, steps)
162
163
 
164
+ @deprecated("Use `Sholl.get(x)` instead")
163
165
  def get_count(self) -> npt.NDArray[np.int32]:
164
166
  """Get the count of intersection.
165
167
 
166
168
  .. deprecated:: 0.5.0
167
- Use :meth:`Sholl.get` instead.
169
+ Use :meth:`Sholl(x).get()` instead.
168
170
  """
169
171
 
170
- warnings.warn(
171
- "`Sholl.get_count` has been renamed to `get` since v0.5.0, "
172
- "and will be removed in next version",
173
- DeprecationWarning,
174
- )
175
172
  return self.get().astype(np.int32)
176
173
 
174
+ @deprecated("Use `Shool(x).get().mean()` instead")
177
175
  def avg(self) -> float:
178
176
  """Get the average of the count of intersection.
179
177
 
@@ -181,14 +179,9 @@ class Sholl:
181
179
  Use :meth:`Shool(x).get().mean()` instead.
182
180
  """
183
181
 
184
- warnings.warn(
185
- "`Sholl.avg` has been deprecated since v0.6.0 and will be "
186
- "removed in next version, use `Shool(x).get().mean()` "
187
- "instead",
188
- DeprecationWarning,
189
- )
190
182
  return self.get().mean()
191
183
 
184
+ @deprecated("Use `Shool(x).get().std()` instead")
192
185
  def std(self) -> float:
193
186
  """Get the std of the count of intersection.
194
187
 
@@ -196,14 +189,9 @@ class Sholl:
196
189
  Use :meth:`Shool(x).get().std()` instead.
197
190
  """
198
191
 
199
- warnings.warn(
200
- "`Sholl.std` has been deprecate since v0.6.0 and will be "
201
- "removed in next version, use `Shool(x).get().std()` "
202
- "instead",
203
- DeprecationWarning,
204
- )
205
192
  return self.get().std()
206
193
 
194
+ @deprecated("Use `Shool(x).get().sum()` instead")
207
195
  def sum(self) -> int:
208
196
  """Get the sum of the count of intersection.
209
197
 
@@ -211,10 +199,4 @@ class Sholl:
211
199
  Use :meth:`Shool(x).get().sum()` instead.
212
200
  """
213
201
 
214
- warnings.warn(
215
- "`Sholl.sum` has been deprecate since v0.6.0 and will be "
216
- "removed in next version, use `Shool(x).get().sum()` "
217
- "instead",
218
- DeprecationWarning,
219
- )
220
202
  return self.get().sum()
swcgeom/analysis/trunk.py CHANGED
@@ -2,8 +2,9 @@
2
2
 
3
3
  # pylint: disable=invalid-name
4
4
 
5
+ from collections.abc import Iterable
5
6
  from itertools import chain
6
- from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, cast
7
+ from typing import Any, Literal, Optional, cast
7
8
 
8
9
  import numpy as np
9
10
  import numpy.typing as npt
@@ -28,28 +29,28 @@ def draw_trunk(
28
29
  *,
29
30
  fig: Optional[Figure] = None,
30
31
  ax: Optional[Axes] = None,
31
- bound: Bounds | Tuple[Bounds, Dict[str, Any]] | None = "ellipse",
32
- point: bool | Dict[str, Any] = True,
32
+ bound: Bounds | tuple[Bounds, dict[str, Any]] | None = "ellipse",
33
+ point: bool | dict[str, Any] = True,
33
34
  projection: Projection = "2d",
34
35
  cmap: Any = "viridis",
35
36
  **kwargs,
36
- ) -> Tuple[Figure, Axes]:
37
+ ) -> tuple[Figure, Axes]:
37
38
  """Draw trunk tree.
38
39
 
39
40
  Parameters
40
41
  ----------
41
42
  t : Tree
42
- florets : List of (int | List of int)
43
+ florets : List of (int | list of int)
43
44
  The florets that needs to be removed, each floret can be a
44
45
  subtree or multiple subtrees (e.g., dendrites are a bunch of
45
46
  subtrees), each number is the id of a tree node.
46
47
  fig : ~matplotlib.figure.Figure, optional
47
48
  ax : ~matplotlib.axes.Axes, optional
48
- bound : Bounds | (Bounds, Dict[str, Any]) | None, default 'ellipse'
49
+ bound : Bounds | (Bounds, dict[str, Any]) | None, default 'ellipse'
49
50
  Kind of bound, support 'aabb', 'ellipse'. If bound is None, no
50
51
  bound will be drawn. If bound is a tuple, the second item will
51
52
  used as kwargs and forward to draw function.
52
- point : bool | Dict[str, Any], default True
53
+ point : bool | dict[str, Any], default True
53
54
  Draw point at the start of a subtree. If point is False, no
54
55
  point will be drawn. If point is a dict, this will used a
55
56
  kwargs and forward to draw function.
@@ -57,7 +58,7 @@ def draw_trunk(
57
58
  Colormap, any value supported by ~matplotlib.cm.Colormap. We
58
59
  will use the ratio of the length of the subtree to the total
59
60
  length of the tree to determine the color.
60
- **kwargs : Dict[str, Any]
61
+ **kwargs : dict[str, Any]
61
62
  Forward to ~swcgeom.analysis.draw.
62
63
  """
63
64
  # pylint: disable=too-many-locals
@@ -83,14 +84,14 @@ def draw_trunk(
83
84
 
84
85
  def split_florets(
85
86
  t: Tree, florets: Iterable[int | Iterable[int]]
86
- ) -> Tuple[Tree, List[List[Tree]]]:
87
+ ) -> tuple[Tree, list[list[Tree]]]:
87
88
  florets = [[i] if isinstance(i, (int, np.integer)) else i for i in florets]
88
89
  subtrees = [[get_subtree(t, ff) for ff in f] for f in florets]
89
90
  trunk = to_subtree(t, chain(*florets))
90
91
  return trunk, subtrees
91
92
 
92
93
 
93
- def get_length_ratio(t: Tree, tss: List[List[Tree]]) -> Any:
94
+ def get_length_ratio(t: Tree, tss: list[list[Tree]]) -> Any:
94
95
  lens = np.array([sum(t.length() for t in ts) for ts in tss])
95
96
  return lens / t.length()
96
97
 
@@ -101,7 +102,7 @@ def get_length_ratio(t: Tree, tss: List[List[Tree]]) -> Any:
101
102
  def draw_bound(
102
103
  ts: Iterable[Tree],
103
104
  ax: Axes,
104
- bound: Bounds | Tuple[Bounds, Dict[str, Any]],
105
+ bound: Bounds | tuple[Bounds, dict[str, Any]],
105
106
  projection: Projection,
106
107
  **kwargs,
107
108
  ) -> None:
@@ -2,7 +2,7 @@
2
2
 
3
3
  import os
4
4
  import weakref
5
- from typing import Any, Dict, List, Literal, Optional, Tuple
5
+ from typing import Any, Literal, Optional
6
6
 
7
7
  import numpy as np
8
8
  from matplotlib.axes import Axes
@@ -21,15 +21,15 @@ from swcgeom.utils import (
21
21
 
22
22
  __all__ = ["draw"]
23
23
 
24
- Positions = Literal["lt", "lb", "rt", "rb"] | Tuple[float, float]
25
- locations: Dict[Literal["lt", "lb", "rt", "rb"], Tuple[float, float]] = {
24
+ Positions = Literal["lt", "lb", "rt", "rb"] | tuple[float, float]
25
+ locations: dict[Literal["lt", "lb", "rt", "rb"], tuple[float, float]] = {
26
26
  "lt": (0.10, 0.90),
27
27
  "lb": (0.10, 0.10),
28
28
  "rt": (0.90, 0.90),
29
29
  "rb": (0.90, 0.10),
30
30
  }
31
31
 
32
- ax_weak_memo = weakref.WeakKeyDictionary[Axes, Dict[str, Any]]({})
32
+ ax_weak_memo = weakref.WeakKeyDictionary[Axes, dict[str, Any]]({})
33
33
 
34
34
 
35
35
  def draw(
@@ -39,7 +39,7 @@ def draw(
39
39
  ax: Optional[Axes] = None,
40
40
  show: bool | None = None,
41
41
  camera: CameraOptions = "xy",
42
- color: Optional[Dict[int, str] | str] = None,
42
+ color: Optional[dict[int, str] | str] = None,
43
43
  label: str | bool = True,
44
44
  direction_indicator: Positions | Literal[False] = "rb",
45
45
  unit: Optional[str] = None,
@@ -64,7 +64,7 @@ def draw(
64
64
  vector, then then threat it as (look-at, up), so camera is
65
65
  ((0, 0, 0), look-at, up). An easy way is to use the presets
66
66
  "xy", "yz" and "zx".
67
- color : Dict[int, str] | "vaa3d" | str, optional
67
+ color : dict[int, str] | "vaa3d" | str, optional
68
68
  Color map. If is dict, segments will be colored by the type of
69
69
  parent node.If is string, the value will be use for any type.
70
70
  label : str | bool, default True
@@ -120,14 +120,14 @@ def draw(
120
120
  return fig, ax
121
121
 
122
122
 
123
- def get_ax_swc(ax: Axes) -> List[SWCLike]:
123
+ def get_ax_swc(ax: Axes) -> list[SWCLike]:
124
124
  ax_weak_memo.setdefault(ax, {})
125
125
  return ax_weak_memo[ax]["swc"]
126
126
 
127
127
 
128
128
  def get_ax_color(
129
- ax: Axes, swc: SWCLike, color: Optional[Dict[int, str] | str] = None
130
- ) -> str | List[str]:
129
+ ax: Axes, swc: SWCLike, color: Optional[dict[int, str] | str] = None
130
+ ) -> str | list[str]:
131
131
  if color == "vaa3d":
132
132
  color = palette.vaa3d
133
133
  elif isinstance(color, str):
@@ -0,0 +1,85 @@
1
+ """Painter utils.
2
+
3
+ Notes
4
+ -----
5
+ This is a experimental function, it may be changed in the future.
6
+ """
7
+
8
+ from typing import Optional
9
+
10
+ import numpy as np
11
+ from matplotlib.axes import Axes
12
+ from matplotlib.figure import Figure
13
+ from mpl_toolkits.mplot3d import Axes3D
14
+
15
+ from swcgeom.analysis.visualization import (
16
+ _set_ax_memo,
17
+ get_ax_color,
18
+ get_ax_swc,
19
+ set_ax_legend,
20
+ )
21
+ from swcgeom.core import SWCLike, Tree
22
+ from swcgeom.utils.plotter_3d import draw_lines_3d
23
+
24
+ __all__ = ["draw3d"]
25
+
26
+
27
+ # TODO: support Camera
28
+ def draw3d(
29
+ swc: SWCLike | str,
30
+ *,
31
+ ax: Axes,
32
+ show: bool | None = None,
33
+ color: Optional[dict[int, str] | str] = None,
34
+ label: str | bool = True,
35
+ **kwargs,
36
+ ) -> tuple[Figure, Axes]:
37
+ r"""Draw neuron tree.
38
+
39
+ Parameters
40
+ ----------
41
+ swc : SWCLike | str
42
+ If it is str, then it is treated as the path of swc file.
43
+ fig : ~matplotlib.axes.Figure, optional
44
+ ax : ~matplotlib.axes.Axes, optional
45
+ show : bool | None, default `None`
46
+ Wheather to call `plt.show()`. If not specified, it will depend
47
+ on if ax is passed in, it will not be called, otherwise it will
48
+ be called by default.
49
+ color : dict[int, str] | "vaa3d" | str, optional
50
+ Color map. If is dict, segments will be colored by the type of
51
+ parent node.If is string, the value will be use for any type.
52
+ label : str | bool, default True
53
+ Label of legend, disable if False.
54
+ **kwargs : dict[str, Unknown]
55
+ Forwarded to `~mpl_toolkits.mplot3d.art3d.Line3DCollection`.
56
+ """
57
+
58
+ assert isinstance(ax, Axes3D), "only support 3D axes."
59
+
60
+ swc = Tree.from_swc(swc) if isinstance(swc, str) else swc
61
+
62
+ show = (show is True) or (show is None and ax is None)
63
+ my_color = get_ax_color(ax, swc, color) # type: ignore
64
+
65
+ xyz = swc.xyz()
66
+ starts, ends = swc.id()[1:], swc.pid()[1:]
67
+ lines = np.stack([xyz[starts], xyz[ends]], axis=1)
68
+ collection = draw_lines_3d(ax, lines, color=my_color, **kwargs)
69
+
70
+ min_vals = lines.reshape(-1, 3).min(axis=0)
71
+ max_vals = lines.reshape(-1, 3).max(axis=0)
72
+ ax.set_xlim(min_vals[0], max_vals[0])
73
+ ax.set_ylim(min_vals[1], max_vals[1])
74
+ ax.set_zlim(min_vals[2], max_vals[2])
75
+
76
+ _set_ax_memo(ax, swc, label=label, handle=collection)
77
+
78
+ if len(get_ax_swc(ax)) == 1:
79
+ # ax.set_aspect(1)
80
+ ax.spines[["top", "right"]].set_visible(False)
81
+ else:
82
+ set_ax_legend(ax, loc="upper right") # enable legend
83
+
84
+ fig = ax.figure
85
+ return fig, ax # type: ignore