swcgeom 0.17.0__py3-none-any.whl → 0.17.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of swcgeom might be problematic. Click here for more details.

Files changed (47) hide show
  1. swcgeom/_version.py +2 -2
  2. swcgeom/analysis/feature_extractor.py +13 -12
  3. swcgeom/analysis/features.py +4 -4
  4. swcgeom/analysis/lmeasure.py +5 -5
  5. swcgeom/analysis/sholl.py +4 -4
  6. swcgeom/analysis/trunk.py +12 -11
  7. swcgeom/analysis/visualization.py +9 -9
  8. swcgeom/analysis/visualization3d.py +85 -0
  9. swcgeom/analysis/volume.py +4 -4
  10. swcgeom/core/branch.py +4 -3
  11. swcgeom/core/branch_tree.py +3 -4
  12. swcgeom/core/compartment.py +3 -2
  13. swcgeom/core/node.py +2 -2
  14. swcgeom/core/path.py +3 -2
  15. swcgeom/core/population.py +16 -27
  16. swcgeom/core/swc.py +11 -10
  17. swcgeom/core/swc_utils/base.py +8 -17
  18. swcgeom/core/swc_utils/io.py +7 -6
  19. swcgeom/core/swc_utils/normalizer.py +4 -3
  20. swcgeom/core/swc_utils/subtree.py +2 -2
  21. swcgeom/core/tree.py +22 -34
  22. swcgeom/core/tree_utils.py +11 -10
  23. swcgeom/core/tree_utils_impl.py +3 -3
  24. swcgeom/images/augmentation.py +3 -3
  25. swcgeom/images/folder.py +10 -16
  26. swcgeom/images/io.py +19 -30
  27. swcgeom/transforms/image_stack.py +6 -5
  28. swcgeom/transforms/images.py +2 -3
  29. swcgeom/transforms/neurolucida_asc.py +4 -6
  30. swcgeom/transforms/population.py +1 -3
  31. swcgeom/transforms/tree.py +8 -7
  32. swcgeom/transforms/tree_assembler.py +4 -3
  33. swcgeom/utils/ellipse.py +3 -4
  34. swcgeom/utils/neuromorpho.py +17 -16
  35. swcgeom/utils/plotter_2d.py +12 -6
  36. swcgeom/utils/plotter_3d.py +31 -0
  37. swcgeom/utils/renderer.py +6 -6
  38. swcgeom/utils/sdf.py +2 -2
  39. swcgeom/utils/solid_geometry.py +1 -3
  40. swcgeom/utils/transforms.py +1 -3
  41. swcgeom/utils/volumetric_object.py +8 -10
  42. {swcgeom-0.17.0.dist-info → swcgeom-0.17.1.dist-info}/METADATA +1 -1
  43. swcgeom-0.17.1.dist-info/RECORD +67 -0
  44. swcgeom-0.17.0.dist-info/RECORD +0 -65
  45. {swcgeom-0.17.0.dist-info → swcgeom-0.17.1.dist-info}/LICENSE +0 -0
  46. {swcgeom-0.17.0.dist-info → swcgeom-0.17.1.dist-info}/WHEEL +0 -0
  47. {swcgeom-0.17.0.dist-info → swcgeom-0.17.1.dist-info}/top_level.txt +0 -0
swcgeom/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.17.0'
16
- __version_tuple__ = version_tuple = (0, 17, 0)
15
+ __version__ = version = '0.17.1'
16
+ __version_tuple__ = version_tuple = (0, 17, 1)
@@ -7,10 +7,11 @@ naming specification.
7
7
  """
8
8
 
9
9
  from abc import ABC, abstractmethod
10
+ from collections.abc import Callable
10
11
  from functools import cached_property
11
12
  from itertools import chain
12
13
  from os.path import basename
13
- from typing import Any, Callable, Dict, List, Literal, Tuple, overload
14
+ from typing import Any, Literal, overload
14
15
 
15
16
  import numpy as np
16
17
  import numpy.typing as npt
@@ -54,7 +55,7 @@ Feature = Literal[
54
55
  ]
55
56
 
56
57
  NDArrayf32 = npt.NDArray[np.float32]
57
- FeatAndKwargs = Feature | Tuple[Feature, Dict[str, Any]]
58
+ FeatAndKwargs = Feature | tuple[Feature, dict[str, Any]]
58
59
 
59
60
  Feature1D = set(["length", "volume", "node_count", "bifurcation_count", "tip_count"])
60
61
 
@@ -121,9 +122,9 @@ class FeatureExtractor(ABC):
121
122
  @overload
122
123
  def get(self, feature: Feature, **kwargs) -> NDArrayf32: ...
123
124
  @overload
124
- def get(self, feature: List[FeatAndKwargs]) -> List[NDArrayf32]: ...
125
+ def get(self, feature: list[FeatAndKwargs]) -> list[NDArrayf32]: ...
125
126
  @overload
126
- def get(self, feature: Dict[Feature, Dict[str, Any]]) -> Dict[str, NDArrayf32]: ...
127
+ def get(self, feature: dict[Feature, dict[str, Any]]) -> dict[str, NDArrayf32]: ...
127
128
  # fmt:on
128
129
  def get(self, feature, **kwargs):
129
130
  """Get feature.
@@ -168,7 +169,7 @@ class FeatureExtractor(ABC):
168
169
 
169
170
  # Custom Plots
170
171
 
171
- def plot_node_branch_order(self, feature_kwargs: Dict[str, Any], **kwargs) -> Axes:
172
+ def plot_node_branch_order(self, feature_kwargs: dict[str, Any], **kwargs) -> Axes:
172
173
  vals = self._get("node_branch_order", **feature_kwargs)
173
174
  bin_edges = np.arange(int(np.ceil(vals.max() + 1))) + 0.5
174
175
  return self._plot_histogram_impl(vals, bin_edges, **kwargs)
@@ -234,7 +235,7 @@ class TreeFeatureExtractor(FeatureExtractor):
234
235
 
235
236
  def plot_sholl(
236
237
  self,
237
- feature_kwargs: Dict[str, Any], # pylint: disable=unused-argument
238
+ feature_kwargs: dict[str, Any], # pylint: disable=unused-argument
238
239
  **kwargs,
239
240
  ) -> Axes:
240
241
  _, ax = self._features.sholl.plot(**kwargs)
@@ -264,7 +265,7 @@ class PopulationFeatureExtractor(FeatureExtractor):
264
265
  """Extract features from population."""
265
266
 
266
267
  _population: Population
267
- _features: List[Features]
268
+ _features: list[Features]
268
269
 
269
270
  def __init__(self, population: Population) -> None:
270
271
  super().__init__()
@@ -279,7 +280,7 @@ class PopulationFeatureExtractor(FeatureExtractor):
279
280
 
280
281
  # Custom Plots
281
282
 
282
- def plot_sholl(self, feature_kwargs: Dict[str, Any], **kwargs) -> Axes:
283
+ def plot_sholl(self, feature_kwargs: dict[str, Any], **kwargs) -> Axes:
283
284
  vals, rs = self._get_sholl_impl(**feature_kwargs)
284
285
  ax = self._lineplot(xs=rs, ys=vals.flatten(), **kwargs)
285
286
  ax.set_ylabel("Count of Intersections")
@@ -295,7 +296,7 @@ class PopulationFeatureExtractor(FeatureExtractor):
295
296
 
296
297
  def _get_sholl_impl(
297
298
  self, steps: int = 20, **kwargs
298
- ) -> Tuple[NDArrayf32, NDArrayf32]:
299
+ ) -> tuple[NDArrayf32, NDArrayf32]:
299
300
  rmax = max(t.sholl.rmax for t in self._features)
300
301
  rs = Sholl.get_rs(rmax=rmax, steps=steps)
301
302
  vals = self._get_impl("sholl", steps=rs, **kwargs)
@@ -333,7 +334,7 @@ class PopulationsFeatureExtractor(FeatureExtractor):
333
334
  """Extract feature from population."""
334
335
 
335
336
  _populations: Populations
336
- _features: List[List[Features]]
337
+ _features: list[list[Features]]
337
338
 
338
339
  def __init__(self, populations: Populations) -> None:
339
340
  super().__init__()
@@ -348,7 +349,7 @@ class PopulationsFeatureExtractor(FeatureExtractor):
348
349
 
349
350
  # Custom Plots
350
351
 
351
- def plot_sholl(self, feature_kwargs: Dict[str, Any], **kwargs) -> Axes:
352
+ def plot_sholl(self, feature_kwargs: dict[str, Any], **kwargs) -> Axes:
352
353
  vals, rs = self._get_sholl_impl(**feature_kwargs)
353
354
  ax = self._lineplot(xs=rs, ys=vals, **kwargs)
354
355
  ax.set_ylabel("Count of Intersections")
@@ -369,7 +370,7 @@ class PopulationsFeatureExtractor(FeatureExtractor):
369
370
 
370
371
  def _get_sholl_impl(
371
372
  self, steps: int = 20, **kwargs
372
- ) -> Tuple[NDArrayf32, NDArrayf32]:
373
+ ) -> tuple[NDArrayf32, NDArrayf32]:
373
374
  rmaxs = chain.from_iterable((t.sholl.rmax for t in p) for p in self._features)
374
375
  rmax = max(rmaxs)
375
376
  rs = Sholl.get_rs(rmax=rmax, steps=steps)
@@ -2,7 +2,7 @@
2
2
 
3
3
  from abc import ABC, abstractmethod
4
4
  from functools import cached_property
5
- from typing import List, TypeVar
5
+ from typing import TypeVar
6
6
 
7
7
  import numpy as np
8
8
  import numpy.typing as npt
@@ -163,7 +163,7 @@ class PathFeatures:
163
163
  return np.array([path.tortuosity() for path in self._paths], dtype=np.float32)
164
164
 
165
165
  @cached_property
166
- def _paths(self) -> List[Tree.Path]:
166
+ def _paths(self) -> list[Tree.Path]:
167
167
  return self.tree.get_paths()
168
168
 
169
169
 
@@ -201,7 +201,7 @@ class BranchFeatures:
201
201
  return self.calc_angle(self._branches, eps=eps)
202
202
 
203
203
  @staticmethod
204
- def calc_angle(branches: List[T], eps: float = 1e-7) -> npt.NDArray[np.float32]:
204
+ def calc_angle(branches: list[T], eps: float = 1e-7) -> npt.NDArray[np.float32]:
205
205
  """Calc agnle between branches.
206
206
 
207
207
  Returns
@@ -219,5 +219,5 @@ class BranchFeatures:
219
219
  return angle
220
220
 
221
221
  @cached_property
222
- def _branches(self) -> List[Tree.Branch]:
222
+ def _branches(self) -> list[Tree.Branch]:
223
223
  return self.tree.get_branches()
@@ -1,7 +1,7 @@
1
1
  """L-Measure analysis."""
2
2
 
3
3
  import math
4
- from typing import Literal, Tuple
4
+ from typing import Literal
5
5
 
6
6
  import numpy as np
7
7
  import numpy.typing as npt
@@ -274,7 +274,7 @@ class LMeasure:
274
274
  rall_power, _, _, _ = self._rall_power(bif)
275
275
  return rall_power
276
276
 
277
- def _rall_power_d(self, bif: Tree.Node) -> Tuple[float, float, float]:
277
+ def _rall_power_d(self, bif: Tree.Node) -> tuple[float, float, float]:
278
278
  children = bif.children()
279
279
  assert len(children) == 2, "Rall Power is only defined for bifurcations"
280
280
  parent = bif.parent()
@@ -284,7 +284,7 @@ class LMeasure:
284
284
  da, db = 2 * children[0].r, 2 * children[1].r
285
285
  return dp, da, db
286
286
 
287
- def _rall_power(self, bif: Tree.Node) -> Tuple[float, float, float, float]:
287
+ def _rall_power(self, bif: Tree.Node) -> tuple[float, float, float, float]:
288
288
  dp, da, db = self._rall_power_d(bif)
289
289
  start, stop, step = 0, 5, 5 / 1000
290
290
  xs = np.arange(start, stop, step)
@@ -501,7 +501,7 @@ class LMeasure:
501
501
 
502
502
  def _bif_vector_local(
503
503
  self, bif: Tree.Node
504
- ) -> Tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]]:
504
+ ) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]]:
505
505
  children = bif.children()
506
506
  assert len(children) == 2, "Only defined for bifurcations"
507
507
 
@@ -511,7 +511,7 @@ class LMeasure:
511
511
 
512
512
  def _bif_vector_remote(
513
513
  self, bif: Tree.Node
514
- ) -> Tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]]:
514
+ ) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]]:
515
515
  children = bif.children()
516
516
  assert len(children) == 2, "Only defined for bifurcations"
517
517
 
swcgeom/analysis/sholl.py CHANGED
@@ -1,7 +1,7 @@
1
1
  """Sholl analysis."""
2
2
 
3
3
  import warnings
4
- from typing import List, Literal, Optional, Tuple
4
+ from typing import Literal, Optional
5
5
 
6
6
  import numpy as np
7
7
  import numpy.typing as npt
@@ -84,18 +84,18 @@ class Sholl:
84
84
 
85
85
  def plot( # pylint: disable=too-many-arguments
86
86
  self,
87
- steps: List[float] | int = 20,
87
+ steps: list[float] | int = 20,
88
88
  plot_type: str | None = None,
89
89
  kind: Literal["bar", "linechart", "circles"] = "circles",
90
90
  fig: Figure | None = None,
91
91
  ax: Axes | None = None,
92
92
  **kwargs,
93
- ) -> Tuple[Figure, Axes]:
93
+ ) -> tuple[Figure, Axes]:
94
94
  """Plot Sholl analysis.
95
95
 
96
96
  Parameters
97
97
  ----------
98
- steps : int or List[float], default to 20
98
+ steps : int or list[float], default to 20
99
99
  Steps of raius of circle. If steps is int, then it will be
100
100
  evenly divided into n radii.
101
101
  kind : "bar" | "linechart" | "circles", default `circles`
swcgeom/analysis/trunk.py CHANGED
@@ -2,8 +2,9 @@
2
2
 
3
3
  # pylint: disable=invalid-name
4
4
 
5
+ from collections.abc import Iterable
5
6
  from itertools import chain
6
- from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, cast
7
+ from typing import Any, Literal, Optional, cast
7
8
 
8
9
  import numpy as np
9
10
  import numpy.typing as npt
@@ -28,28 +29,28 @@ def draw_trunk(
28
29
  *,
29
30
  fig: Optional[Figure] = None,
30
31
  ax: Optional[Axes] = None,
31
- bound: Bounds | Tuple[Bounds, Dict[str, Any]] | None = "ellipse",
32
- point: bool | Dict[str, Any] = True,
32
+ bound: Bounds | tuple[Bounds, dict[str, Any]] | None = "ellipse",
33
+ point: bool | dict[str, Any] = True,
33
34
  projection: Projection = "2d",
34
35
  cmap: Any = "viridis",
35
36
  **kwargs,
36
- ) -> Tuple[Figure, Axes]:
37
+ ) -> tuple[Figure, Axes]:
37
38
  """Draw trunk tree.
38
39
 
39
40
  Parameters
40
41
  ----------
41
42
  t : Tree
42
- florets : List of (int | List of int)
43
+ florets : List of (int | list of int)
43
44
  The florets that needs to be removed, each floret can be a
44
45
  subtree or multiple subtrees (e.g., dendrites are a bunch of
45
46
  subtrees), each number is the id of a tree node.
46
47
  fig : ~matplotlib.figure.Figure, optional
47
48
  ax : ~matplotlib.axes.Axes, optional
48
- bound : Bounds | (Bounds, Dict[str, Any]) | None, default 'ellipse'
49
+ bound : Bounds | (Bounds, dict[str, Any]) | None, default 'ellipse'
49
50
  Kind of bound, support 'aabb', 'ellipse'. If bound is None, no
50
51
  bound will be drawn. If bound is a tuple, the second item will
51
52
  used as kwargs and forward to draw function.
52
- point : bool | Dict[str, Any], default True
53
+ point : bool | dict[str, Any], default True
53
54
  Draw point at the start of a subtree. If point is False, no
54
55
  point will be drawn. If point is a dict, this will used a
55
56
  kwargs and forward to draw function.
@@ -57,7 +58,7 @@ def draw_trunk(
57
58
  Colormap, any value supported by ~matplotlib.cm.Colormap. We
58
59
  will use the ratio of the length of the subtree to the total
59
60
  length of the tree to determine the color.
60
- **kwargs : Dict[str, Any]
61
+ **kwargs : dict[str, Any]
61
62
  Forward to ~swcgeom.analysis.draw.
62
63
  """
63
64
  # pylint: disable=too-many-locals
@@ -83,14 +84,14 @@ def draw_trunk(
83
84
 
84
85
  def split_florets(
85
86
  t: Tree, florets: Iterable[int | Iterable[int]]
86
- ) -> Tuple[Tree, List[List[Tree]]]:
87
+ ) -> tuple[Tree, list[list[Tree]]]:
87
88
  florets = [[i] if isinstance(i, (int, np.integer)) else i for i in florets]
88
89
  subtrees = [[get_subtree(t, ff) for ff in f] for f in florets]
89
90
  trunk = to_subtree(t, chain(*florets))
90
91
  return trunk, subtrees
91
92
 
92
93
 
93
- def get_length_ratio(t: Tree, tss: List[List[Tree]]) -> Any:
94
+ def get_length_ratio(t: Tree, tss: list[list[Tree]]) -> Any:
94
95
  lens = np.array([sum(t.length() for t in ts) for ts in tss])
95
96
  return lens / t.length()
96
97
 
@@ -101,7 +102,7 @@ def get_length_ratio(t: Tree, tss: List[List[Tree]]) -> Any:
101
102
  def draw_bound(
102
103
  ts: Iterable[Tree],
103
104
  ax: Axes,
104
- bound: Bounds | Tuple[Bounds, Dict[str, Any]],
105
+ bound: Bounds | tuple[Bounds, dict[str, Any]],
105
106
  projection: Projection,
106
107
  **kwargs,
107
108
  ) -> None:
@@ -2,7 +2,7 @@
2
2
 
3
3
  import os
4
4
  import weakref
5
- from typing import Any, Dict, List, Literal, Optional, Tuple
5
+ from typing import Any, Literal, Optional
6
6
 
7
7
  import numpy as np
8
8
  from matplotlib.axes import Axes
@@ -21,15 +21,15 @@ from swcgeom.utils import (
21
21
 
22
22
  __all__ = ["draw"]
23
23
 
24
- Positions = Literal["lt", "lb", "rt", "rb"] | Tuple[float, float]
25
- locations: Dict[Literal["lt", "lb", "rt", "rb"], Tuple[float, float]] = {
24
+ Positions = Literal["lt", "lb", "rt", "rb"] | tuple[float, float]
25
+ locations: dict[Literal["lt", "lb", "rt", "rb"], tuple[float, float]] = {
26
26
  "lt": (0.10, 0.90),
27
27
  "lb": (0.10, 0.10),
28
28
  "rt": (0.90, 0.90),
29
29
  "rb": (0.90, 0.10),
30
30
  }
31
31
 
32
- ax_weak_memo = weakref.WeakKeyDictionary[Axes, Dict[str, Any]]({})
32
+ ax_weak_memo = weakref.WeakKeyDictionary[Axes, dict[str, Any]]({})
33
33
 
34
34
 
35
35
  def draw(
@@ -39,7 +39,7 @@ def draw(
39
39
  ax: Optional[Axes] = None,
40
40
  show: bool | None = None,
41
41
  camera: CameraOptions = "xy",
42
- color: Optional[Dict[int, str] | str] = None,
42
+ color: Optional[dict[int, str] | str] = None,
43
43
  label: str | bool = True,
44
44
  direction_indicator: Positions | Literal[False] = "rb",
45
45
  unit: Optional[str] = None,
@@ -64,7 +64,7 @@ def draw(
64
64
  vector, then then threat it as (look-at, up), so camera is
65
65
  ((0, 0, 0), look-at, up). An easy way is to use the presets
66
66
  "xy", "yz" and "zx".
67
- color : Dict[int, str] | "vaa3d" | str, optional
67
+ color : dict[int, str] | "vaa3d" | str, optional
68
68
  Color map. If is dict, segments will be colored by the type of
69
69
  parent node.If is string, the value will be use for any type.
70
70
  label : str | bool, default True
@@ -120,14 +120,14 @@ def draw(
120
120
  return fig, ax
121
121
 
122
122
 
123
- def get_ax_swc(ax: Axes) -> List[SWCLike]:
123
+ def get_ax_swc(ax: Axes) -> list[SWCLike]:
124
124
  ax_weak_memo.setdefault(ax, {})
125
125
  return ax_weak_memo[ax]["swc"]
126
126
 
127
127
 
128
128
  def get_ax_color(
129
- ax: Axes, swc: SWCLike, color: Optional[Dict[int, str] | str] = None
130
- ) -> str | List[str]:
129
+ ax: Axes, swc: SWCLike, color: Optional[dict[int, str] | str] = None
130
+ ) -> str | list[str]:
131
131
  if color == "vaa3d":
132
132
  color = palette.vaa3d
133
133
  elif isinstance(color, str):
@@ -0,0 +1,85 @@
1
+ """Painter utils.
2
+
3
+ Notes
4
+ -----
5
+ This is a experimental function, it may be changed in the future.
6
+ """
7
+
8
+ from typing import Optional
9
+
10
+ import numpy as np
11
+ from matplotlib.axes import Axes
12
+ from matplotlib.figure import Figure
13
+ from mpl_toolkits.mplot3d import Axes3D
14
+
15
+ from swcgeom.analysis.visualization import (
16
+ _set_ax_memo,
17
+ get_ax_color,
18
+ get_ax_swc,
19
+ set_ax_legend,
20
+ )
21
+ from swcgeom.core import SWCLike, Tree
22
+ from swcgeom.utils.plotter_3d import draw_lines_3d
23
+
24
+ __all__ = ["draw3d"]
25
+
26
+
27
+ # TODO: support Camera
28
+ def draw3d(
29
+ swc: SWCLike | str,
30
+ *,
31
+ ax: Axes,
32
+ show: bool | None = None,
33
+ color: Optional[dict[int, str] | str] = None,
34
+ label: str | bool = True,
35
+ **kwargs,
36
+ ) -> tuple[Figure, Axes]:
37
+ r"""Draw neuron tree.
38
+
39
+ Parameters
40
+ ----------
41
+ swc : SWCLike | str
42
+ If it is str, then it is treated as the path of swc file.
43
+ fig : ~matplotlib.axes.Figure, optional
44
+ ax : ~matplotlib.axes.Axes, optional
45
+ show : bool | None, default `None`
46
+ Wheather to call `plt.show()`. If not specified, it will depend
47
+ on if ax is passed in, it will not be called, otherwise it will
48
+ be called by default.
49
+ color : dict[int, str] | "vaa3d" | str, optional
50
+ Color map. If is dict, segments will be colored by the type of
51
+ parent node.If is string, the value will be use for any type.
52
+ label : str | bool, default True
53
+ Label of legend, disable if False.
54
+ **kwargs : dict[str, Unknown]
55
+ Forwarded to `~mpl_toolkits.mplot3d.art3d.Line3DCollection`.
56
+ """
57
+
58
+ assert isinstance(ax, Axes3D), "only support 3D axes."
59
+
60
+ swc = Tree.from_swc(swc) if isinstance(swc, str) else swc
61
+
62
+ show = (show is True) or (show is None and ax is None)
63
+ my_color = get_ax_color(ax, swc, color) # type: ignore
64
+
65
+ xyz = swc.xyz()
66
+ starts, ends = swc.id()[1:], swc.pid()[1:]
67
+ lines = np.stack([xyz[starts], xyz[ends]], axis=1)
68
+ collection = draw_lines_3d(ax, lines, color=my_color, **kwargs)
69
+
70
+ min_vals = lines.reshape(-1, 3).min(axis=0)
71
+ max_vals = lines.reshape(-1, 3).max(axis=0)
72
+ ax.set_xlim(min_vals[0], max_vals[0])
73
+ ax.set_ylim(min_vals[1], max_vals[1])
74
+ ax.set_zlim(min_vals[2], max_vals[2])
75
+
76
+ _set_ax_memo(ax, swc, label=label, handle=collection)
77
+
78
+ if len(get_ax_swc(ax)) == 1:
79
+ # ax.set_aspect(1)
80
+ ax.spines[["top", "right"]].set_visible(False)
81
+ else:
82
+ set_ax_legend(ax, loc="upper right") # enable legend
83
+
84
+ fig = ax.figure
85
+ return fig, ax # type: ignore
@@ -1,6 +1,6 @@
1
1
  """Analysis of volume of a SWC tree."""
2
2
 
3
- from typing import Dict, List, Literal
3
+ from typing import Literal
4
4
 
5
5
  import numpy as np
6
6
  from sdflit import ColoredMaterial, ObjectsScene, SDFObject, UniformSampler
@@ -11,7 +11,7 @@ from swcgeom.utils import VolFrustumCone, VolSphere
11
11
  __all__ = ["get_volume"]
12
12
 
13
13
  ACCURACY_LEVEL = Literal["low", "middle", "high"]
14
- ACCURACY_LEVELS: Dict[ACCURACY_LEVEL, int] = {"low": 3, "middle": 5, "high": 8}
14
+ ACCURACY_LEVELS: dict[ACCURACY_LEVEL, int] = {"low": 3, "middle": 5, "high": 8}
15
15
 
16
16
 
17
17
  def get_volume(
@@ -93,7 +93,7 @@ def _get_volume_frustum_cone(tree: Tree, *, accuracy: int) -> float:
93
93
 
94
94
  volume = 0.0
95
95
 
96
- def leave(n: Tree.Node, children: List[VolSphere]) -> VolSphere:
96
+ def leave(n: Tree.Node, children: list[VolSphere]) -> VolSphere:
97
97
  sphere = VolSphere(n.xyz(), n.r)
98
98
  cones = [VolFrustumCone(n.xyz(), n.r, c.center, c.radius) for c in children]
99
99
 
@@ -129,7 +129,7 @@ def _get_volume_frustum_cone_mc_only(tree: Tree) -> float:
129
129
  scene = ObjectsScene()
130
130
  scene.set_background((0, 0, 0))
131
131
 
132
- def leave(n: Tree.Node, children: List[VolSphere]) -> VolSphere:
132
+ def leave(n: Tree.Node, children: list[VolSphere]) -> VolSphere:
133
133
  sphere = VolSphere(n.xyz(), n.r)
134
134
  scene.add_object(SDFObject(sphere.sdf, material).into())
135
135
 
swcgeom/core/branch.py CHANGED
@@ -1,6 +1,7 @@
1
1
  """Branch is a set of node points."""
2
2
 
3
- from typing import Generic, Iterable, List
3
+ from collections.abc import Iterable
4
+ from typing import Generic
4
5
 
5
6
  import numpy as np
6
7
  import numpy.typing as npt
@@ -92,7 +93,7 @@ class Branch(Path, Generic[SWCTypeVar]):
92
93
  @classmethod
93
94
  def from_xyzr_batch(
94
95
  cls, xyzr_batch: npt.NDArray[np.float32]
95
- ) -> List["Branch[DictSWC]"]:
96
+ ) -> list["Branch[DictSWC]"]:
96
97
  r"""Create list of branch form ~numpy.ndarray.
97
98
 
98
99
  Parameters
@@ -112,7 +113,7 @@ class Branch(Path, Generic[SWCTypeVar]):
112
113
  )
113
114
  xyzr_batch = np.concatenate([xyzr_batch, ones], axis=2)
114
115
 
115
- branches: List[Branch[DictSWC]] = []
116
+ branches: list[Branch[DictSWC]] = []
116
117
  for xyzr in xyzr_batch:
117
118
  n_nodes = xyzr.shape[0]
118
119
  idx = np.arange(0, n_nodes, step=1, dtype=np.int32)
@@ -1,7 +1,6 @@
1
1
  """Branch tree is a simplified neuron tree."""
2
2
 
3
3
  import itertools
4
- from typing import Dict, List
5
4
 
6
5
  import numpy as np
7
6
  import pandas as pd
@@ -19,13 +18,13 @@ class BranchTree(Tree):
19
18
  A branch tree that contains only soma, branch, and tip nodes.
20
19
  """
21
20
 
22
- branches: Dict[int, List[Branch]]
21
+ branches: dict[int, list[Branch]]
23
22
 
24
- def get_origin_branches(self) -> List[Branch]:
23
+ def get_origin_branches(self) -> list[Branch]:
25
24
  """Get branches of original tree."""
26
25
  return list(itertools.chain(*self.branches.values()))
27
26
 
28
- def get_origin_node_branches(self, idx: int) -> List[Branch]:
27
+ def get_origin_node_branches(self, idx: int) -> list[Branch]:
29
28
  """Get branches of node of original tree."""
30
29
  return self.branches[idx]
31
30
 
@@ -1,6 +1,7 @@
1
1
  """The segment is a branch with two nodes."""
2
2
 
3
- from typing import Generic, Iterable, List, TypeVar
3
+ from collections.abc import Iterable
4
+ from typing import Generic, TypeVar
4
5
 
5
6
  import numpy as np
6
7
  import numpy.typing as npt
@@ -43,7 +44,7 @@ class Compartment(Path, Generic[SWCTypeVar]):
43
44
  T = TypeVar("T", bound=Compartment)
44
45
 
45
46
 
46
- class Compartments(List[T]):
47
+ class Compartments(list[T]):
47
48
  r"""Comparments contains a set of comparment."""
48
49
 
49
50
  names: SWCNames
swcgeom/core/node.py CHANGED
@@ -1,7 +1,7 @@
1
1
  """Nueron node."""
2
2
 
3
- import warnings
4
- from typing import Any, Generic, Iterable
3
+ from collections.abc import Iterable
4
+ from typing import Any, Generic
5
5
 
6
6
  import numpy as np
7
7
  import numpy.typing as npt
swcgeom/core/path.py CHANGED
@@ -1,7 +1,8 @@
1
1
  """Nueron path."""
2
2
 
3
3
  import warnings
4
- from typing import Generic, Iterable, Iterator, List, overload
4
+ from collections.abc import Iterable, Iterator
5
+ from typing import Generic, overload
5
6
 
6
7
  import numpy as np
7
8
  import numpy.typing as npt
@@ -44,7 +45,7 @@ class Path(SWCLike, Generic[SWCTypeVar]):
44
45
  @overload
45
46
  def __getitem__(self, key: int) -> Node: ...
46
47
  @overload
47
- def __getitem__(self, key: slice) -> List[Node]: ...
48
+ def __getitem__(self, key: slice) -> list[Node]: ...
48
49
  @overload
49
50
  def __getitem__(self, key: str) -> npt.NDArray: ...
50
51
  # fmt:on