swcgeom 0.19.4__cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of swcgeom might be problematic. Click here for more details.

Files changed (72) hide show
  1. swcgeom/__init__.py +21 -0
  2. swcgeom/analysis/__init__.py +13 -0
  3. swcgeom/analysis/feature_extractor.py +454 -0
  4. swcgeom/analysis/features.py +218 -0
  5. swcgeom/analysis/lmeasure.py +750 -0
  6. swcgeom/analysis/sholl.py +201 -0
  7. swcgeom/analysis/trunk.py +183 -0
  8. swcgeom/analysis/visualization.py +191 -0
  9. swcgeom/analysis/visualization3d.py +81 -0
  10. swcgeom/analysis/volume.py +143 -0
  11. swcgeom/core/__init__.py +19 -0
  12. swcgeom/core/branch.py +129 -0
  13. swcgeom/core/branch_tree.py +65 -0
  14. swcgeom/core/compartment.py +107 -0
  15. swcgeom/core/node.py +130 -0
  16. swcgeom/core/path.py +155 -0
  17. swcgeom/core/population.py +341 -0
  18. swcgeom/core/swc.py +247 -0
  19. swcgeom/core/swc_utils/__init__.py +19 -0
  20. swcgeom/core/swc_utils/assembler.py +35 -0
  21. swcgeom/core/swc_utils/base.py +180 -0
  22. swcgeom/core/swc_utils/checker.py +107 -0
  23. swcgeom/core/swc_utils/io.py +204 -0
  24. swcgeom/core/swc_utils/normalizer.py +163 -0
  25. swcgeom/core/swc_utils/subtree.py +70 -0
  26. swcgeom/core/tree.py +384 -0
  27. swcgeom/core/tree_utils.py +277 -0
  28. swcgeom/core/tree_utils_impl.py +58 -0
  29. swcgeom/images/__init__.py +9 -0
  30. swcgeom/images/augmentation.py +149 -0
  31. swcgeom/images/contrast.py +87 -0
  32. swcgeom/images/folder.py +217 -0
  33. swcgeom/images/io.py +578 -0
  34. swcgeom/images/loaders/__init__.py +8 -0
  35. swcgeom/images/loaders/pbd.cpython-311-x86_64-linux-gnu.so +0 -0
  36. swcgeom/images/loaders/pbd.pyx +523 -0
  37. swcgeom/images/loaders/raw.cpython-311-x86_64-linux-gnu.so +0 -0
  38. swcgeom/images/loaders/raw.pyx +183 -0
  39. swcgeom/transforms/__init__.py +20 -0
  40. swcgeom/transforms/base.py +136 -0
  41. swcgeom/transforms/branch.py +223 -0
  42. swcgeom/transforms/branch_tree.py +74 -0
  43. swcgeom/transforms/geometry.py +270 -0
  44. swcgeom/transforms/image_preprocess.py +107 -0
  45. swcgeom/transforms/image_stack.py +219 -0
  46. swcgeom/transforms/images.py +206 -0
  47. swcgeom/transforms/mst.py +183 -0
  48. swcgeom/transforms/neurolucida_asc.py +498 -0
  49. swcgeom/transforms/path.py +56 -0
  50. swcgeom/transforms/population.py +36 -0
  51. swcgeom/transforms/tree.py +265 -0
  52. swcgeom/transforms/tree_assembler.py +161 -0
  53. swcgeom/utils/__init__.py +18 -0
  54. swcgeom/utils/debug.py +23 -0
  55. swcgeom/utils/download.py +119 -0
  56. swcgeom/utils/dsu.py +58 -0
  57. swcgeom/utils/ellipse.py +131 -0
  58. swcgeom/utils/file.py +90 -0
  59. swcgeom/utils/neuromorpho.py +581 -0
  60. swcgeom/utils/numpy_helper.py +70 -0
  61. swcgeom/utils/plotter_2d.py +134 -0
  62. swcgeom/utils/plotter_3d.py +35 -0
  63. swcgeom/utils/renderer.py +145 -0
  64. swcgeom/utils/sdf.py +324 -0
  65. swcgeom/utils/solid_geometry.py +154 -0
  66. swcgeom/utils/transforms.py +367 -0
  67. swcgeom/utils/volumetric_object.py +483 -0
  68. swcgeom-0.19.4.dist-info/METADATA +86 -0
  69. swcgeom-0.19.4.dist-info/RECORD +72 -0
  70. swcgeom-0.19.4.dist-info/WHEEL +6 -0
  71. swcgeom-0.19.4.dist-info/licenses/LICENSE +201 -0
  72. swcgeom-0.19.4.dist-info/top_level.txt +1 -0
@@ -0,0 +1,143 @@
1
+
2
+ # SPDX-FileCopyrightText: 2022 - 2025 Zexin Yuan <pypi@yzx9.xyz>
3
+ #
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ """Analysis of volume of a SWC tree."""
7
+
8
+ from typing import Literal
9
+
10
+ import numpy as np
11
+ from sdflit import ColoredMaterial, ObjectsScene, SDFObject, UniformSampler
12
+
13
+ from swcgeom.core import Tree
14
+ from swcgeom.utils import VolFrustumCone, VolSphere
15
+
16
+ __all__ = ["get_volume"]
17
+
18
+ ACCURACY_LEVEL = Literal["low", "middle", "high"]
19
+ ACCURACY_LEVELS: dict[ACCURACY_LEVEL, int] = {"low": 3, "middle": 5, "high": 8}
20
+
21
+
22
+ def get_volume(
23
+ tree: Tree,
24
+ *,
25
+ method: Literal["frustum_cone"] = "frustum_cone",
26
+ accuracy: int | ACCURACY_LEVEL = "middle",
27
+ ) -> float:
28
+ """Get the volume of the tree.
29
+
30
+ Args:
31
+ tree: Neuronal tree.
32
+ method: Method for volume calculation.
33
+ accuracy: Accuracy level for volume calculation. The higher the accuracy,
34
+ the more accurate the volume calculation, but the slower the
35
+ calculation. The accuracy level can be specified either as an
36
+ integer or as a string.
37
+
38
+ The string values correspond to the following accuracy levels:
39
+
40
+ - "low": 3
41
+ - "middle": 5
42
+ - "high": 8
43
+
44
+ Returns:
45
+ volume: Volume of the tree.
46
+
47
+ NOTE: The SWC format is a method for representing neurons, which includes both the
48
+ radius of individual points and their interconnectivity. Consequently, there are
49
+ multiple distinct approaches to representation within this framework.
50
+
51
+ Currently, we support a standard approach to volume calculation. This method
52
+ involves treating each node as a sphere and representing the connections between
53
+ them as truncated cone-like structures, or frustums, with varying radii at their
54
+ top and bottom surfaces.
55
+
56
+ We welcome additional representation methods through pull requests.
57
+ """
58
+ if isinstance(accuracy, str):
59
+ accuracy = ACCURACY_LEVELS[accuracy]
60
+
61
+ assert 0 < accuracy <= 10
62
+
63
+ match method:
64
+ case "frustum_cone":
65
+ return _get_volume_frustum_cone(tree, accuracy=accuracy)
66
+ case _:
67
+ raise ValueError(f"Unsupported method: {method}")
68
+
69
+
70
def _get_volume_frustum_cone(tree: Tree, *, accuracy: int) -> float:
    """Get the volume of the tree as spheres joined by frustum cones.

    Each node contributes its own sphere plus one frustum cone per child, the cone
    joining the node to that child's sphere. Contributions are gathered in a
    post-order `tree.traverse(leave=...)`: every callback returns its node sphere so
    the parent can build cones to it.

    Args:
        tree: Neuronal tree.
        accuracy: Accuracy level. Terms are enabled cumulatively:
            1 : Sphere only
            >= 2 : Sphere and Frustum Cone
            >= 3 : Sphere, Frustum Cone, and intersection in single-branch
            >= 5 : Above and Sphere-Frustum Cone intersection in multi-branch
            10 : Fully calculated by Monte Carlo method
            Levels between these thresholds behave like the next lower one
            (e.g. 4 == 3, 6..9 == 5).

    Returns:
        Sum of all per-node contributions.
    """
    if accuracy == 10:
        return _get_volume_frustum_cone_mc_only(tree)

    contributions: list[float] = []

    def leave(n: Tree.Node, children: list[VolSphere]) -> VolSphere:
        sphere = VolSphere(n.xyz(), n.r)
        cones = [
            VolFrustumCone(n.xyz(), n.r, child.center, child.radius)
            for child in children
        ]

        node_volume = sphere.get_volume()

        if accuracy >= 2:
            cone_volume = sum(cone.get_volume() for cone in cones)
            node_volume += cone_volume

        if accuracy >= 3:
            own_overlap = sum(sphere.intersect(cone).get_volume() for cone in cones)
            child_overlap = sum(
                child.intersect(cone).get_volume()
                for child, cone in zip(children, cones)
            )
            # NOTE(review): for one parent/child pair, inclusion-exclusion of
            # S | C | S_child minus the already counted S_child gives
            # -|S & S_child| + |S & C & S_child| rather than +|S & S_child|.
            # When the lens S & S_child lies inside C (e.g. equal radii), the `+`
            # applied here appears to over-count by the lens volume. Confirm.
            sphere_overlap = sum(
                child.intersect(sphere).get_volume() for child in children
            )
            node_volume -= own_overlap
            node_volume -= child_overlap
            node_volume += sphere_overlap

        if accuracy >= 5:
            # Pairwise cone-cone overlaps at a furcation, outside the node sphere.
            node_volume -= sum(
                first.intersect(second).subtract(sphere).get_volume()
                for i, first in enumerate(cones)
                for second in cones[i + 1 :]
            )

        contributions.append(node_volume)
        return sphere

    tree.traverse(leave=leave)
    # Start from 0.0 and add in traversal order, like a running accumulator.
    return sum(contributions, 0.0)
113
+
114
+
115
def _get_volume_frustum_cone_mc_only(
    tree: Tree, *, n_samples: int = 100_000_000
) -> float:
    """Estimate the tree volume by Monte Carlo sampling of an SDF scene.

    Every node sphere and every node-to-child frustum cone is added to an sdflit
    scene. Points are then sampled uniformly in the scene's bounding box, and the
    volume is `data.sum() / n_samples * box_volume`.

    Args:
        tree: Neuronal tree.
        n_samples: Number of points sampled uniformly in the scene's bounding box.
            The default matches the previously hard-coded value.

    Returns:
        Estimated volume as a Python float; 0.0 for an empty tree.

    Raises:
        ValueError: If `n_samples` is not positive.
    """
    if n_samples <= 0:
        raise ValueError(f"n_samples must be positive, got {n_samples}")

    if tree.number_of_nodes() == 0:
        return 0.0

    material = ColoredMaterial((1, 0, 0)).into()
    scene = ObjectsScene()
    scene.set_background((0, 0, 0))

    def leave(n: Tree.Node, children: list[VolSphere]) -> VolSphere:
        sphere = VolSphere(n.xyz(), n.r)
        scene.add_object(SDFObject(sphere.sdf, material).into())

        for c in children:
            fc = VolFrustumCone(n.xyz(), n.r, c.center, c.radius)
            scene.add_object(SDFObject(fc.sdf, material).into())

        return sphere

    tree.traverse(leave=leave)
    scene.build_bvh()

    # TODO: estimate the number of samples needed

    vmin, vmax = scene.bounding_box()
    sampler = UniformSampler(vmin, vmax)
    data = sampler.sample(scene.into(), n_samples)
    # Presumably `data` holds one inside/outside indicator per sample, making the
    # ratio the occupied fraction of the box -- TODO confirm against sdflit.
    volume = data.sum() / n_samples * np.subtract(vmax, vmin).prod()
    # Convert from a numpy scalar to honour the `-> float` annotation.
    return float(volume)
@@ -0,0 +1,19 @@
1
+
2
+ # SPDX-FileCopyrightText: 2022 - 2025 Zexin Yuan <pypi@yzx9.xyz>
3
+ #
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ """Neuron trees."""
7
+
8
+ from swcgeom.core import swc_utils as swc_utils
9
+ from swcgeom.core.branch import * # noqa: F403
10
+ from swcgeom.core.branch_tree import * # noqa: F403
11
+
12
+ # Segment and Segments don't expose
13
+ from swcgeom.core.compartment import Compartment, Compartments # noqa: F401
14
+ from swcgeom.core.node import * # noqa: F403
15
+ from swcgeom.core.path import * # noqa: F403
16
+ from swcgeom.core.population import * # noqa: F403
17
+ from swcgeom.core.swc import * # noqa: F403
18
+ from swcgeom.core.tree import * # noqa: F403
19
+ from swcgeom.core.tree_utils import * # noqa: F403
swcgeom/core/branch.py ADDED
@@ -0,0 +1,129 @@
1
+
2
+ # SPDX-FileCopyrightText: 2022 - 2025 Zexin Yuan <pypi@yzx9.xyz>
3
+ #
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ """Branch is a set of node points."""
7
+
8
+ from collections.abc import Iterable
9
+ from typing import Generic
10
+
11
+ import numpy as np
12
+ import numpy.typing as npt
13
+
14
+ from swcgeom.core.compartment import Compartment, Compartments
15
+ from swcgeom.core.path import Path
16
+ from swcgeom.core.swc import DictSWC, SWCTypeVar
17
+
18
__all__ = ["Branch"]


class Branch(Path, Generic[SWCTypeVar]):
    r"""Neural branch.

    A branch is a path of nodes whose data is read from an attached SWC-like object.

    NOTE: Only a part of the data of branch nodes is valid, such as `x`, `y`, `z`
    and `r`; the `id` and `pid` are usually invalid.
    """

    attach: SWCTypeVar  # object that owns the node data
    idx: npt.NDArray[np.int32]  # indices into `attach` selecting this branch's nodes

    class Compartment(Compartment["Branch"]):
        """Segment of branch."""

    Segment = Compartment  # Alias

    def __repr__(self) -> str:
        return f"Neuron branch with {len(self)} nodes."

    def keys(self) -> Iterable[str]:
        """Return the node-data keys of the attached object."""
        return self.attach.keys()

    def get_ndata(self, key: str) -> npt.NDArray:
        """Return node data `key` of the attached object restricted to this branch."""
        return self.attach.get_ndata(key)[self.idx]

    def get_compartments(self) -> Compartments[Compartment]:
        """Return one compartment for every node except the first.

        Each compartment is built from that node's `pid` and `id`.
        """
        return Compartments(self.Compartment(self, n.pid, n.id) for n in self[1:])

    def get_segments(self) -> Compartments[Compartment]:
        """Alias of `get_compartments`."""
        return self.get_compartments()

    def detach(self) -> "Branch[DictSWC]":
        """Detach from current attached object.

        Builds a new `DictSWC` from every node-data column of this branch,
        overwrites its id/pid columns with `self.id()` / `self.pid()`, and returns a
        branch over that new object.
        """
        # pylint: disable=consider-using-dict-items
        attached = DictSWC(
            **{k: self[k] for k in self.keys()},
            source=self.attach.source,
            names=self.names,
        )
        attached.ndata[self.names.id] = self.id()
        attached.ndata[self.names.pid] = self.pid()
        # NOTE(review): `self.id()` is reused as the index array into the new
        # object, which presumes `Path.id()` yields positions 0..n-1 -- confirm.
        return Branch(attached, self.id())

    @classmethod
    def from_xyzr(cls, xyzr: npt.NDArray[np.float32]) -> "Branch[DictSWC]":
        r"""Create a branch from ~numpy.ndarray.

        Nodes are chained in row order: node `i` gets id `i` and pid `i - 1`, and
        every node gets type 3.

        Args:
            xyzr: Collection of nodes.
                If shape (n, 4), `x`, `y`, `z` and `r` of the nodes are used. If
                shape (n, 3), only `x`, `y`, `z` are used and `r` is filled with 1.
        """
        assert xyzr.ndim == 2 and xyzr.shape[1] in (
            3,
            4,
        ), f"xyzr should be of shape (N, 3) or (N, 4), got {xyzr.shape}"

        n_nodes = xyzr.shape[0]
        if xyzr.shape[1] == 3:
            ones = np.ones([n_nodes, 1], dtype=np.float32)
            xyzr = np.concatenate([xyzr, ones], axis=1)

        idx = np.arange(0, n_nodes, step=1, dtype=np.int32)
        attached = DictSWC(
            id=idx,
            type=np.full((n_nodes), fill_value=3, dtype=np.int32),
            x=xyzr[:, 0],
            y=xyzr[:, 1],
            z=xyzr[:, 2],
            r=xyzr[:, 3],
            pid=np.arange(-1, n_nodes - 1, step=1, dtype=np.int32),
        )
        return Branch(attached, idx)

    @classmethod
    def from_xyzr_batch(
        cls, xyzr_batch: npt.NDArray[np.float32]
    ) -> list["Branch[DictSWC]"]:
        r"""Create a list of branches from ~numpy.ndarray.

        Args:
            xyzr_batch: Batch of collections of nodes.
                If shape (bs, n, 4), `x`, `y`, `z` and `r` of the nodes are used.
                If shape (bs, n, 3), only `x`, `y`, `z` are used and `r` is filled
                with 1.

        Returns:
            One branch per batch entry, each built with `from_xyzr`.
        """
        # Validate the per-node feature axis (axis 2), not the node count (axis 1):
        # a batch of 2-node branches is valid, a feature width of 5 is not.
        assert xyzr_batch.ndim == 3 and xyzr_batch.shape[2] in (3, 4), (
            "xyzr_batch should be of shape (bs, N, 3) or (bs, N, 4), "
            f"got {xyzr_batch.shape}"
        )
        return [cls.from_xyzr(xyzr) for xyzr in xyzr_batch]
@@ -0,0 +1,65 @@
1
+
2
+ # SPDX-FileCopyrightText: 2022 - 2025 Zexin Yuan <pypi@yzx9.xyz>
3
+ #
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ """Branch tree is a simplified neuron tree."""
7
+
8
+ import itertools
9
+
10
+ import numpy as np
11
+ import pandas as pd
12
+ from typing_extensions import Self
13
+
14
+ from swcgeom.core.branch import Branch
15
+ from swcgeom.core.swc_utils import to_sub_topology
16
+ from swcgeom.core.tree import Tree
17
+
18
__all__ = ["BranchTree"]


class BranchTree(Tree):
    """Topology-preserving reduction of a neuron tree.

    Besides the root (id 0), only the end node of every branch of the original tree
    is kept, linked to the start node of that branch; this keeps soma, branch and
    tip nodes only. The full original branches are stored in `branches`, keyed by
    the position of their start node in this tree's `id_map` (presumably the node
    index in this tree -- see `to_sub_topology`).
    """

    # start-node index -> detached branches of the original tree starting there
    branches: dict[int, list[Branch]]

    def get_origin_branches(self) -> list[Branch]:
        """Get all branches of the original tree, flattened into one list."""
        return list(itertools.chain.from_iterable(self.branches.values()))

    def get_origin_node_branches(self, idx: int) -> list[Branch]:
        """Get the original-tree branches starting at node `idx` (KeyError if none)."""
        return self.branches[idx]

    @classmethod
    def from_tree(cls, tree: Tree) -> Self:
        """Generate a branch tree from `tree`."""
        branches = tree.get_branches()

        end_ids = [br[-1].id for br in branches]
        start_ids = [br[0].id for br in branches]
        sub_id = np.array([0, *end_ids], dtype=np.int32)
        sub_pid = np.array([-1, *start_ids], dtype=np.int32)

        (new_id, new_pid), id_map = to_sub_topology((sub_id, sub_pid))

        columns = {key: tree.get_ndata(key)[id_map].copy() for key in tree.keys()}
        columns["id"] = new_id
        columns["pid"] = new_pid

        branch_tree = cls(
            new_id.shape[0], **columns, source=tree.source, names=tree.names
        )

        grouped: dict[int, list[Branch]] = {}
        for br in branches:
            start = np.nonzero(id_map == br[0].id)[0][0].item()
            grouped.setdefault(start, []).append(br.detach())
        branch_tree.branches = grouped

        return branch_tree

    @classmethod
    def from_data_frame(cls, df: pd.DataFrame, *args, **kwargs) -> Self:
        """Read a tree from a data frame, then reduce it with `from_tree`."""
        full_tree = super().from_data_frame(df, *args, **kwargs)
        return cls.from_tree(full_tree)
@@ -0,0 +1,107 @@
1
+
2
+ # SPDX-FileCopyrightText: 2022 - 2025 Zexin Yuan <pypi@yzx9.xyz>
3
+ #
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ """The segment is a branch with two nodes."""
7
+
8
+ from collections.abc import Iterable
9
+ from typing import Generic, TypeVar
10
+
11
+ import numpy as np
12
+ import numpy.typing as npt
13
+
14
+ from swcgeom.core.path import Path
15
+ from swcgeom.core.swc import DictSWC, SWCTypeVar
16
+ from swcgeom.core.swc_utils import SWCNames, get_names
17
+
18
__all__ = ["Compartment", "Compartments", "Segment", "Segments"]


class Compartment(Path, Generic[SWCTypeVar]):
    r"""Two-node path (parent, current node) viewing data of an external object."""

    attach: SWCTypeVar  # object whose node data is viewed
    idx: npt.NDArray[np.int32]  # two indices into `attach`: [parent, current]

    def __init__(self, attach: SWCTypeVar, pid: int, idx: int) -> None:
        endpoints = np.array([pid, idx])
        super().__init__(attach, endpoints)

    def keys(self) -> Iterable[str]:
        """Return the node-data keys of the attached object."""
        return self.attach.keys()

    def get_ndata(self, key: str) -> npt.NDArray:
        """Return node data `key` of the attached object at the two endpoints."""
        column = self.attach.get_ndata(key)
        return column[self.idx]

    def detach(self) -> "Compartment[DictSWC]":
        """Detach from current attached object.

        Builds a standalone `DictSWC` from this compartment's columns, overwrites
        its id/pid columns with `self.id()` / `self.pid()`, and returns a new
        compartment spanning positions 0 and 1 of it.
        """
        columns = {key: self[key] for key in self.keys()}
        standalone = DictSWC(**columns, source=self.attach.source, names=self.names)
        standalone.ndata[self.names.id] = self.id()
        standalone.ndata[self.names.pid] = self.pid()
        return Compartment(standalone, 0, 1)
47
+
48
+
49
T = TypeVar("T", bound=Compartment)


class Compartments(list[T]):
    r"""Compartments contains a set of compartments.

    A list of compartments with batched accessors. Every accessor returns an array
    whose axis 0 indexes the compartments and whose axis 1 is
    (parent, current node).
    """

    names: SWCNames  # taken from the first compartment, or `get_names()` if empty

    def __init__(self, segments: Iterable[T]) -> None:
        super().__init__(segments)
        self.names = self[0].names if len(self) > 0 else get_names()

    def id(self) -> npt.NDArray[np.int32]:  # pylint: disable=invalid-name
        """Get the ids of shape (n_sample, 2)."""
        return self.get_ndata(self.names.id)

    def type(self) -> npt.NDArray[np.int32]:
        """Get the types of shape (n_sample, 2)."""
        return self.get_ndata(self.names.type)

    def x(self) -> npt.NDArray[np.float32]:
        """Get the x coordinates of shape (n_sample, 2)."""
        return self.get_ndata(self.names.x)

    def y(self) -> npt.NDArray[np.float32]:
        """Get the y coordinates of shape (n_sample, 2)."""
        return self.get_ndata(self.names.y)

    def z(self) -> npt.NDArray[np.float32]:
        """Get the z coordinates of shape (n_sample, 2)."""
        return self.get_ndata(self.names.z)

    def r(self) -> npt.NDArray[np.float32]:
        """Get the radius of shape (n_sample, 2)."""
        return self.get_ndata(self.names.r)

    def pid(self) -> npt.NDArray[np.int32]:
        """Get the ids of parent of shape (n_sample, 2)."""
        return self.get_ndata(self.names.pid)

    def xyz(self) -> npt.NDArray[np.float32]:
        """Get the coordinates of shape (n_sample, 2, 3)."""
        return np.stack([self.x(), self.y(), self.z()], axis=2)

    def xyzr(self) -> npt.NDArray[np.float32]:
        """Get the xyzr array of shape (n_sample, 2, 4)."""
        return np.stack([self.x(), self.y(), self.z(), self.r()], axis=2)

    def get_ndata(self, key: str) -> npt.NDArray:
        """Get ndata of shape (n_sample, 2).

        The order of axis 1 is (parent, current node). An empty collection yields
        an empty array of shape (0, 2), so `xyz()` / `xyzr()` also work when empty.
        """
        rows = [s.get_ndata(key) for s in self]
        if not rows:
            # np.array([]) would be 1-D, and np.stack(..., axis=2) in xyz()/xyzr()
            # would then raise an AxisError.
            return np.empty((0, 2))
        return np.array(rows)


# Aliases
Segment = Compartment
Segments = Compartments
swcgeom/core/node.py ADDED
@@ -0,0 +1,130 @@
1
+
2
+ # SPDX-FileCopyrightText: 2022 - 2025 Zexin Yuan <pypi@yzx9.xyz>
3
+ #
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ """Neuron node."""
7
+
8
+ from collections.abc import Iterable
9
+ from typing import Any, Generic
10
+
11
+ import numpy as np
12
+ import numpy.typing as npt
13
+ from typing_extensions import deprecated
14
+
15
+ from swcgeom.core.swc import DictSWC, SWCTypeVar
16
+ from swcgeom.core.swc_utils import SWCNames
17
+
18
__all__ = ["Node"]


class Node(Generic[SWCTypeVar]):
    """Neural node viewed through its attached SWC-like object.

    The node stores no data of its own: reading or writing any of the properties
    below goes to `attach.get_ndata(<column>)[idx]`.
    """

    attach: SWCTypeVar  # object holding the node data
    idx: int | np.integer  # position of this node within `attach`'s node data
    names: SWCNames  # column names, copied from `attach.names`

    @property
    def id(self) -> int:
        return self[self.names.id]

    @id.setter
    def id(self, v: int):
        self[self.names.id] = v

    @property
    def type(self) -> int:
        return self[self.names.type]

    @type.setter
    def type(self, v: int):
        self[self.names.type] = v

    @property
    def x(self) -> float:
        return self[self.names.x]

    @x.setter
    def x(self, v: float):
        self[self.names.x] = v

    @property
    def y(self) -> float:
        return self[self.names.y]

    @y.setter
    def y(self, v: float):
        self[self.names.y] = v

    @property
    def z(self) -> float:
        return self[self.names.z]

    @z.setter
    def z(self, v: float):
        self[self.names.z] = v

    @property
    def r(self) -> float:
        return self[self.names.r]

    @r.setter
    def r(self, v: float):
        self[self.names.r] = v

    @property
    def pid(self) -> int:
        return self[self.names.pid]

    @pid.setter
    def pid(self, v: int):
        self[self.names.pid] = v

    def __init__(self, attach: SWCTypeVar, idx: int | np.integer) -> None:
        super().__init__()
        self.attach = attach
        self.idx = idx
        self.names = attach.names

    def __getitem__(self, key: str) -> Any:
        column = self.attach.get_ndata(key)
        return column[self.idx]

    def __setitem__(self, k: str, v: Any) -> None:
        column = self.attach.get_ndata(k)
        column[self.idx] = v

    def __str__(self) -> str:
        return self.format_swc()

    def __repr__(self) -> str:
        return self.format_swc()

    def xyz(self) -> npt.NDArray[np.float32]:
        """Return (`x`, `y`, `z`) as a float32 array of shape (3,)."""
        return np.array((self.x, self.y, self.z), dtype=np.float32)

    def xyzr(self) -> npt.NDArray[np.float32]:
        """Return (`x`, `y`, `z`, `r`) as a float32 array of shape (4,)."""
        return np.array((self.x, self.y, self.z, self.r), dtype=np.float32)

    def keys(self) -> Iterable[str]:
        """Return the node-data keys of the attached object."""
        return self.attach.keys()

    def distance(self, b: "Node") -> float:
        """Return the Euclidean distance between this node and `b`."""
        offset = self.xyz() - b.xyz()
        return np.linalg.norm(offset).item()

    def format_swc(self) -> str:
        """Return the node as one SWC line: id type x y z r pid (floats to 4 dp)."""
        coords = [f"{value:.4f}" for value in (self.x, self.y, self.z, self.r)]
        fields = [self.id, self.type, *coords, self.pid]
        return " ".join(str(field) for field in fields)

    def is_furcation(self) -> bool:
        """True when more than one node in `attach` names this node as parent."""
        n_children = np.count_nonzero(self.attach.pid() == self.id)
        return n_children > 1

    @deprecated("Use is_furcation instead")
    def is_bifurcation(self) -> bool:
        """Deprecated alias of `is_furcation`.

        NOTE: Deprecated because of the misspelling of "furcation". For now it only
        forwards to `is_furcation` and emits a deprecation warning; it may be
        changed to raise an error in the future.
        """
        return self.is_furcation()

    def is_tip(self) -> bool:
        """True when no node in `attach` names this node as parent."""
        return self.id not in self.attach.pid()

    def detach(self) -> "Node[DictSWC]":
        """Detach from current attached object.

        Builds a one-node `DictSWC` holding this node's values, with id reset to 0
        and pid to -1, and returns a node over it.
        """
        columns = {key: np.array([self[key]]) for key in self.keys()}
        standalone = DictSWC(**columns, source=self.attach.source, names=self.names)
        standalone.ndata[self.names.id] = np.array([0])
        standalone.ndata[self.names.pid] = np.array([-1])
        return Node(standalone, 0)
+ return Node(attact, 0)