swcgeom 0.20.0__cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of swcgeom might be problematic. Click here for more details.
- swcgeom/__init__.py +21 -0
- swcgeom/analysis/__init__.py +13 -0
- swcgeom/analysis/feature_extractor.py +454 -0
- swcgeom/analysis/features.py +218 -0
- swcgeom/analysis/lmeasure.py +750 -0
- swcgeom/analysis/sholl.py +201 -0
- swcgeom/analysis/trunk.py +183 -0
- swcgeom/analysis/visualization.py +191 -0
- swcgeom/analysis/visualization3d.py +81 -0
- swcgeom/analysis/volume.py +143 -0
- swcgeom/core/__init__.py +19 -0
- swcgeom/core/branch.py +129 -0
- swcgeom/core/branch_tree.py +65 -0
- swcgeom/core/compartment.py +107 -0
- swcgeom/core/node.py +130 -0
- swcgeom/core/path.py +155 -0
- swcgeom/core/population.py +394 -0
- swcgeom/core/swc.py +247 -0
- swcgeom/core/swc_utils/__init__.py +19 -0
- swcgeom/core/swc_utils/assembler.py +35 -0
- swcgeom/core/swc_utils/base.py +180 -0
- swcgeom/core/swc_utils/checker.py +112 -0
- swcgeom/core/swc_utils/io.py +335 -0
- swcgeom/core/swc_utils/normalizer.py +163 -0
- swcgeom/core/swc_utils/subtree.py +70 -0
- swcgeom/core/tree.py +387 -0
- swcgeom/core/tree_utils.py +277 -0
- swcgeom/core/tree_utils_impl.py +58 -0
- swcgeom/images/__init__.py +9 -0
- swcgeom/images/augmentation.py +149 -0
- swcgeom/images/contrast.py +87 -0
- swcgeom/images/folder.py +217 -0
- swcgeom/images/io.py +578 -0
- swcgeom/images/loaders/__init__.py +8 -0
- swcgeom/images/loaders/pbd.cpython-313-x86_64-linux-gnu.so +0 -0
- swcgeom/images/loaders/pbd.pyx +523 -0
- swcgeom/images/loaders/raw.cpython-313-x86_64-linux-gnu.so +0 -0
- swcgeom/images/loaders/raw.pyx +183 -0
- swcgeom/transforms/__init__.py +20 -0
- swcgeom/transforms/base.py +136 -0
- swcgeom/transforms/branch.py +223 -0
- swcgeom/transforms/branch_tree.py +74 -0
- swcgeom/transforms/geometry.py +270 -0
- swcgeom/transforms/image_preprocess.py +107 -0
- swcgeom/transforms/image_stack.py +219 -0
- swcgeom/transforms/images.py +206 -0
- swcgeom/transforms/mst.py +183 -0
- swcgeom/transforms/neurolucida_asc.py +498 -0
- swcgeom/transforms/path.py +56 -0
- swcgeom/transforms/population.py +36 -0
- swcgeom/transforms/tree.py +265 -0
- swcgeom/transforms/tree_assembler.py +160 -0
- swcgeom/utils/__init__.py +18 -0
- swcgeom/utils/debug.py +23 -0
- swcgeom/utils/download.py +119 -0
- swcgeom/utils/dsu.py +58 -0
- swcgeom/utils/ellipse.py +131 -0
- swcgeom/utils/file.py +90 -0
- swcgeom/utils/neuromorpho.py +581 -0
- swcgeom/utils/numpy_helper.py +70 -0
- swcgeom/utils/plotter_2d.py +134 -0
- swcgeom/utils/plotter_3d.py +35 -0
- swcgeom/utils/renderer.py +145 -0
- swcgeom/utils/sdf.py +324 -0
- swcgeom/utils/solid_geometry.py +154 -0
- swcgeom/utils/transforms.py +367 -0
- swcgeom/utils/volumetric_object.py +483 -0
- swcgeom-0.20.0.dist-info/METADATA +86 -0
- swcgeom-0.20.0.dist-info/RECORD +72 -0
- swcgeom-0.20.0.dist-info/WHEEL +7 -0
- swcgeom-0.20.0.dist-info/licenses/LICENSE +201 -0
- swcgeom-0.20.0.dist-info/top_level.txt +1 -0
swcgeom/__init__.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: 2022 - 2025 Zexin Yuan <pypi@yzx9.xyz>
|
|
2
|
+
#
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
"""A neuron geometry library for swc format."""
|
|
6
|
+
|
|
7
|
+
from swcgeom import analysis, core, images, transforms
|
|
8
|
+
from swcgeom.analysis import draw
|
|
9
|
+
from swcgeom.core import BranchTree, Population, Populations, Tree
|
|
10
|
+
|
|
11
|
+
__all__ = [
|
|
12
|
+
"analysis",
|
|
13
|
+
"core",
|
|
14
|
+
"images",
|
|
15
|
+
"transforms",
|
|
16
|
+
"draw",
|
|
17
|
+
"BranchTree",
|
|
18
|
+
"Population",
|
|
19
|
+
"Populations",
|
|
20
|
+
"Tree",
|
|
21
|
+
]
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
|
|
2
|
+
# SPDX-FileCopyrightText: 2022 - 2025 Zexin Yuan <pypi@yzx9.xyz>
|
|
3
|
+
#
|
|
4
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
5
|
+
|
|
6
|
+
"""Analysis for neuron trees."""
|
|
7
|
+
|
|
8
|
+
from swcgeom.analysis.feature_extractor import * # noqa: F403
|
|
9
|
+
from swcgeom.analysis.features import * # noqa: F403
|
|
10
|
+
from swcgeom.analysis.sholl import * # noqa: F403
|
|
11
|
+
from swcgeom.analysis.trunk import * # noqa: F403
|
|
12
|
+
from swcgeom.analysis.visualization import * # noqa: F403
|
|
13
|
+
from swcgeom.analysis.volume import * # noqa: F403
|
|
@@ -0,0 +1,454 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: 2022 - 2025 Zexin Yuan <pypi@yzx9.xyz>
|
|
2
|
+
#
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
"""Easy way to compute and visualize common features for feature.
|
|
6
|
+
|
|
7
|
+
NOTE: For development, see method `Features.get_evaluator` to confirm the naming
|
|
8
|
+
specification.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from abc import ABC, abstractmethod
|
|
12
|
+
from collections.abc import Callable
|
|
13
|
+
from functools import cached_property
|
|
14
|
+
from itertools import chain
|
|
15
|
+
from os.path import basename
|
|
16
|
+
from typing import Any, Literal, overload
|
|
17
|
+
|
|
18
|
+
import numpy as np
|
|
19
|
+
import numpy.typing as npt
|
|
20
|
+
import seaborn as sns
|
|
21
|
+
from matplotlib.axes import Axes
|
|
22
|
+
|
|
23
|
+
from swcgeom.analysis.features import (
|
|
24
|
+
BranchFeatures,
|
|
25
|
+
FurcationFeatures,
|
|
26
|
+
NodeFeatures,
|
|
27
|
+
PathFeatures,
|
|
28
|
+
TipFeatures,
|
|
29
|
+
)
|
|
30
|
+
from swcgeom.analysis.sholl import Sholl
|
|
31
|
+
from swcgeom.analysis.volume import get_volume
|
|
32
|
+
from swcgeom.core import Population, Populations, Tree
|
|
33
|
+
from swcgeom.utils import padding1d
|
|
34
|
+
|
|
35
|
+
__all__ = ["Feature", "extract_feature"]
|
|
36
|
+
|
|
37
|
+
Feature = Literal[
|
|
38
|
+
"length",
|
|
39
|
+
"volume",
|
|
40
|
+
"sholl",
|
|
41
|
+
# node
|
|
42
|
+
"node_count",
|
|
43
|
+
"node_radial_distance",
|
|
44
|
+
"node_branch_order",
|
|
45
|
+
# furcation nodes
|
|
46
|
+
"furcation_count",
|
|
47
|
+
"furcation_radial_distance",
|
|
48
|
+
# bifurcation nodes
|
|
49
|
+
"bifurcation_count",
|
|
50
|
+
"bifurcation_radial_distance",
|
|
51
|
+
# tip nodes
|
|
52
|
+
"tip_count",
|
|
53
|
+
"tip_radial_distance",
|
|
54
|
+
# branch
|
|
55
|
+
"branch_length",
|
|
56
|
+
"branch_tortuosity",
|
|
57
|
+
# path
|
|
58
|
+
"path_length",
|
|
59
|
+
"path_tortuosity",
|
|
60
|
+
]
|
|
61
|
+
|
|
62
|
+
NDArrayf32 = npt.NDArray[np.float32]
|
|
63
|
+
FeatAndKwargs = Feature | tuple[Feature, dict[str, Any]]
|
|
64
|
+
|
|
65
|
+
Feature1D = set(["length", "volume", "node_count", "furcation_count", "tip_count"])
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class Features:
    """Lazily evaluated features of a single tree.

    Feature modules (`node_features`, `tip_features`, ...) and the Sholl helper
    are created on first access and cached on the instance.
    """

    tree: Tree

    def __init__(self, tree: Tree) -> None:
        self.tree = tree

    # Modules

    @cached_property
    def node_features(self) -> NodeFeatures:
        return NodeFeatures(self.tree)

    @cached_property
    def furcation_features(self) -> FurcationFeatures:
        # Shares the node module so both work on the same tree.
        return FurcationFeatures(self.node_features)

    @cached_property
    def tip_features(self) -> TipFeatures:
        return TipFeatures(self.node_features)

    @cached_property
    def branch_features(self) -> BranchFeatures:
        return BranchFeatures(self.tree)

    @cached_property
    def path_features(self) -> PathFeatures:
        return PathFeatures(self.tree)

    # Caches

    @cached_property
    def sholl(self) -> Sholl:
        return Sholl(self.tree)

    def get(self, feature: FeatAndKwargs, **kwargs) -> NDArrayf32:
        """Evaluate one feature.

        Args:
            feature: Feature name, or a `(name, kwargs)` pair.
            **kwargs: Extra arguments for the evaluator; they override the
                ones bundled with `feature`.
        """
        name, merged = _get_feat_and_kwargs(feature, **kwargs)
        return self.get_evaluator(name)(**merged)

    def get_evaluator(self, feature: Feature) -> Callable[[], npt.NDArray]:
        """Resolve a feature name to the callable that computes it.

        Naming specification: `get_<feature>` on this class wins; otherwise
        the text before the first underscore selects the `<prefix>_features`
        module and the remainder selects its `get_<rest>` method.

        Raises:
            ValueError: If neither lookup yields a callable.
        """
        direct = getattr(self, f"get_{feature}", None)
        if callable(direct):
            return direct  # custom features

        prefix, *rest = feature.split("_")
        module = getattr(self, f"{prefix}_features", None)
        if module:
            method = getattr(module, "get_" + "_".join(rest), None)
            if callable(method):
                return method

        raise ValueError(f"Invalid feature: {feature}")

    # Custom Features

    def get_length(self, **kwargs) -> NDArrayf32:
        total = self.tree.length(**kwargs)
        return np.array([total], dtype=np.float32)

    def get_volume(self, **kwargs) -> NDArrayf32:
        # `get_volume` here is the module-level function, not this method.
        volume = get_volume(self.tree, **kwargs)
        return np.array([volume], dtype=np.float32)

    def get_sholl(self, **kwargs) -> NDArrayf32:
        counts = self.sholl.get(**kwargs)
        return counts.astype(np.float32)
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
class FeatureExtractor(ABC):
    """Common interface for computing and plotting features.

    Subclasses implement `_get_impl`, `_plot_1d_impl` and
    `_plot_histogram_impl`. A `get_<feature>` or `plot_<feature>` method, when
    present, takes precedence over these defaults.
    """

    @overload
    def get(self, feature: Feature, **kwargs) -> NDArrayf32: ...
    @overload
    def get(self, feature: list[FeatAndKwargs]) -> list[NDArrayf32]: ...
    @overload
    def get(self, feature: dict[Feature, dict[str, Any]]) -> dict[str, NDArrayf32]: ...
    def get(self, feature, **kwargs):
        """Get one feature, a list of features, or a dict of features.

        A dict maps feature names to their keyword arguments and yields a dict
        with the same keys; a list yields a list in the same order (`**kwargs`
        is not used for dicts and lists).

        NOTE: Shape of returned array is not uniform, `TreeFeatureExtractor`
        returns array of shape (L, ), `PopulationFeatureExtractor` returns
        array of shape (N, L).
        """
        if isinstance(feature, dict):
            return {name: self._get(name, **opts) for name, opts in feature.items()}

        if isinstance(feature, list):
            return [self._get(item) for item in feature]

        return self._get(feature, **kwargs)

    def plot(self, feature: FeatAndKwargs, title: str | bool = True, **kwargs) -> Axes:
        """Plot a feature in a way appropriate for it.

        Args:
            feature: Feature name, or a `(name, kwargs)` pair.
            title: A string is used verbatim; `True` derives the title from
                the feature name; anything else leaves the title unset.
            **kwargs: Forwarded to the selected plotting method.

        NOTE: The drawing method differs between classes, between features,
        and may differ between versions; there are NO guarantees.
        """
        feat, feat_kwargs = _get_feat_and_kwargs(feature)

        custom_plot = getattr(self, f"plot_{feat}", None)
        if callable(custom_plot):
            ax = custom_plot(feat_kwargs, **kwargs)
        elif feat in Feature1D:
            ax = self._plot_1d(feature, **kwargs)
        else:
            ax = self._plot_histogram(feature, **kwargs)  # default plot

        if title is True:
            ax.set_title(_get_feature_name(feat))
        elif isinstance(title, str):
            ax.set_title(title)

        return ax

    # Custom Plots

    def plot_node_branch_order(self, feature_kwargs: dict[str, Any], **kwargs) -> Axes:
        orders = self._get("node_branch_order", **feature_kwargs)
        # Half-integer edges: 0.5, 1.5, ... up to ceil(max) + 0.5.
        n_edges = int(np.ceil(orders.max() + 1))
        bin_edges = np.arange(n_edges) + 0.5
        return self._plot_histogram_impl(orders, bin_edges, **kwargs)

    # Deprecated feature names; both raise instead of computing anything.

    def get_bifurcation_count(self, **kwargs):
        raise DeprecationWarning("Use `furcation_count` instead.")

    def get_bifurcation_radial_distance(self, **kwargs):
        raise DeprecationWarning("Use `furcation_radial_distance` instead.")

    # Implements

    def _get(self, feature: FeatAndKwargs, **kwargs) -> NDArrayf32:
        """Compute a feature, preferring a `get_<feature>` override."""
        feat, kwargs = _get_feat_and_kwargs(feature, **kwargs)
        custom_get = getattr(self, f"get_{feat}", None)
        if not callable(custom_get):
            return self._get_impl(feat, **kwargs)  # default

        return custom_get(**kwargs)

    def _plot_1d(self, feature: FeatAndKwargs, **kwargs) -> Axes:
        ax = self._plot_1d_impl(self._get(feature), **kwargs)
        ax.set_ylabel(_get_feature_name(feature))
        return ax

    def _plot_histogram(
        self,
        feature: FeatAndKwargs,
        bins=20,
        range=None,  # pylint: disable=redefined-builtin
        **kwargs,
    ) -> Axes:
        values = self._get(feature)
        edges = np.histogram_bin_edges(values, bins=bins, range=range)
        return self._plot_histogram_impl(values, edges, **kwargs)

    @abstractmethod
    def _get_impl(self, feature: Feature, **kwargs) -> NDArrayf32:
        raise NotImplementedError()

    @abstractmethod
    def _plot_1d_impl(self, vals: NDArrayf32, **kwargs) -> Axes:
        raise NotImplementedError()

    @abstractmethod
    def _plot_histogram_impl(
        self, vals: NDArrayf32, bin_edges: npt.NDArray, **kwargs
    ) -> Axes:
        raise NotImplementedError()
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
class TreeFeatureExtractor(FeatureExtractor):
    """Feature extractor for a single tree."""

    _tree: Tree
    _features: Features

    def __init__(self, tree: Tree) -> None:
        super().__init__()
        self._tree = tree
        self._features = Features(tree)

    # Custom Features

    def get_sholl(self, **kwargs) -> NDArrayf32:
        counts = self._features.sholl.get(**kwargs)
        return counts.astype(np.float32)

    # Custom Plots

    def plot_sholl(
        self,
        feature_kwargs: dict[str, Any],  # pylint: disable=unused-argument
        **kwargs,
    ) -> Axes:
        # Drawing is delegated to the Sholl object; only the axes of its
        # two-element return value are kept.
        result = self._features.sholl.plot(**kwargs)
        _, ax = result
        return ax

    # Implements

    def _get_impl(self, feature: Feature, **kwargs) -> NDArrayf32:
        return self._features.get(feature, **kwargs)

    def _plot_histogram_impl(
        self, vals: NDArrayf32, bin_edges: npt.NDArray, **kwargs
    ) -> Axes:
        """Draw a bar chart of per-bin counts; zero values are left out."""
        nonzero = vals[vals != 0]
        counts = np.histogram(nonzero, bins=bin_edges)[0]
        centers = (bin_edges[:-1] + bin_edges[1:]) / 2

        ax: Axes = sns.barplot(x=centers, y=counts, **kwargs)
        ax.set_ylabel("Count")
        return ax

    def _plot_1d_impl(self, vals: NDArrayf32, **kwargs) -> Axes:
        """Draw one bar labelled with the base name of the tree source."""
        label = basename(self._tree.source)
        return sns.barplot(x=[label], y=vals.squeeze(), **kwargs)
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
class PopulationFeatureExtractor(FeatureExtractor):
    """Feature extractor for one population of trees."""

    _population: Population
    _features: list[Features]

    def __init__(self, population: Population) -> None:
        super().__init__()
        self._population = population
        self._features = [Features(tree) for tree in self._population]

    # Custom Features

    def get_sholl(self, **kwargs) -> NDArrayf32:
        return self._get_sholl_impl(**kwargs)[0]

    # Custom Plots

    def plot_sholl(self, feature_kwargs: dict[str, Any], **kwargs) -> Axes:
        counts, radii = self._get_sholl_impl(**feature_kwargs)
        ax = self._lineplot(xs=radii, ys=counts.flatten(), **kwargs)
        ax.set_ylabel("Count of Intersections")
        return ax

    # Implements

    def _get_impl(self, feature: Feature, **kwargs) -> NDArrayf32:
        """Evaluate the feature on every tree and stack the results.

        Each per-tree vector is passed through `padding1d` with the longest
        length so the rows can be stacked (presumably zero padding — confirm
        against `swcgeom.utils.padding1d`).
        """
        per_tree = [f.get(feature, **kwargs) for f in self._features]
        width = max(len(row) for row in per_tree)
        padded = [padding1d(width, row, dtype=np.float32) for row in per_tree]
        return np.stack(padded)

    def _get_sholl_impl(
        self, steps: int = 20, **kwargs
    ) -> tuple[NDArrayf32, NDArrayf32]:
        """Compute Sholl values of all trees on one shared set of radii.

        Returns:
            The stacked values and the radii from `Sholl.get_rs`, built from
            the largest `rmax` among the trees.
        """
        largest = max(f.sholl.rmax for f in self._features)
        radii = Sholl.get_rs(rmax=largest, steps=steps)
        return self._get_impl("sholl", steps=radii, **kwargs), radii

    def _plot_histogram_impl(
        self, vals: NDArrayf32, bin_edges: npt.NDArray, **kwargs
    ) -> Axes:
        """Draw per-tree histograms as a line plot; zero values are left out."""
        centers = (bin_edges[:-1] + bin_edges[1:]) / 2
        rows = [np.histogram(row[row != 0], bins=bin_edges)[0] for row in vals]

        ax: Axes = self._lineplot(centers, np.stack(rows), **kwargs)
        ax.set_ylabel("Count")
        return ax

    def _plot_1d_impl(self, vals: NDArrayf32, **kwargs) -> Axes:
        """Draw one bar per tree plus a dashed line at the mean value."""
        names = [basename(tree.source) for tree in self._population]
        heights = vals.flatten()

        ax: Axes = sns.barplot(x=names, y=heights, **kwargs)
        ax.axhline(y=heights.mean(), ls="--", lw=1)
        ax.set_xticks([])
        return ax

    def _lineplot(self, xs, ys, **kwargs) -> Axes:
        # Repeat the x positions once per tree so they line up with the
        # flattened y values.
        repeated = np.tile(xs, len(self._population))
        ax: Axes = sns.lineplot(x=repeated, y=ys.flatten(), **kwargs)
        return ax
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
class PopulationsFeatureExtractor(FeatureExtractor):
    """Extract features from several populations at once."""

    _populations: Populations
    _features: list[list[Features]]

    def __init__(self, populations: Populations) -> None:
        super().__init__()
        self._populations = populations
        self._features = [[Features(t) for t in p] for p in populations.populations]

    # Custom Features

    def get_sholl(self, **kwargs) -> NDArrayf32:
        vals, _ = self._get_sholl_impl(**kwargs)
        return vals

    # Custom Plots

    def plot_sholl(self, feature_kwargs: dict[str, Any], **kwargs) -> Axes:
        vals, rs = self._get_sholl_impl(**feature_kwargs)
        ax = self._lineplot(xs=rs, ys=vals, **kwargs)
        ax.set_ylabel("Count of Intersections")
        return ax

    # Implements

    def _get_impl(self, feature: Feature, **kwargs) -> NDArrayf32:
        """Evaluate the feature on every tree of every population.

        Returns:
            A zero-initialised float32 array of shape (P, T, F): P populations,
            T the largest number of trees in a population, F the longest
            feature vector. Entries without data stay zero.
        """
        vals = [[f.get(feature, **kwargs) for f in fs] for fs in self._features]
        len_max1 = max((len(v) for v in vals), default=0)
        # Pass the iterable itself to `max`: unpacking it with `*` broke when
        # there was exactly one tree in total, because `max(5)` tries to
        # iterate the int. `default=0` covers populations without any tree.
        len_max2 = max((len(vv) for v in vals for vv in v), default=0)
        out = np.zeros((len(vals), len_max1, len_max2), dtype=np.float32)
        for i, v in enumerate(vals):
            for j, vv in enumerate(v):
                out[i, j, : len(vv)] = vv

        return out

    def _get_sholl_impl(
        self, steps: int = 20, **kwargs
    ) -> tuple[NDArrayf32, NDArrayf32]:
        """Compute Sholl values of all trees on one shared set of radii.

        Returns:
            The (P, T, F) values and the radii from `Sholl.get_rs`, built from
            the largest `rmax` among all trees.
        """
        rmaxs = chain.from_iterable((t.sholl.rmax for t in p) for p in self._features)
        rmax = max(rmaxs)
        rs = Sholl.get_rs(rmax=rmax, steps=steps)
        vals = self._get_impl("sholl", steps=rs, **kwargs)
        return vals, rs

    def _plot_histogram_impl(
        self, vals: NDArrayf32, bin_edges: npt.NDArray, **kwargs
    ) -> Axes:
        """Draw per-tree histograms as lines; zero values are left out."""

        def hist(v):
            return np.histogram(v[v != 0], bins=bin_edges)[0]

        xs = (bin_edges[:-1] + bin_edges[1:]) / 2
        ys = np.stack([np.stack([hist(t) for t in p]) for p in vals])

        ax = self._lineplot(xs=xs, ys=ys, **kwargs)
        ax.set_ylabel("Count")
        return ax

    def _plot_1d_impl(self, vals: NDArrayf32, **kwargs) -> Axes:
        """Draw one box per population label."""
        labels = self._populations.labels
        xs = np.concatenate([np.full(vals.shape[1], fill_value=i) for i in labels])
        ys = vals.flatten()

        # The numbers of tree in different populations may not be equal, so
        # zero entries are dropped. NOTE(review): this also drops trees whose
        # real value is 0.
        valid = ys != 0
        xs, ys = xs[valid], ys[valid]

        ax: Axes = sns.boxplot(x=xs, y=ys, **kwargs)
        return ax

    def _lineplot(self, xs, ys, **kwargs) -> Axes:
        """Draw `ys` of shape (P, T, F) against `xs` of shape (F,), one hue per label."""
        p, t, f = ys.shape
        labels = self._populations.labels  # (P,)
        x = np.tile(xs, p * t)  # (F,) -> (P * T * F)
        y = ys.flatten()  # (P, T, F) -> (P * T * F)
        hue = np.concatenate([np.full(t * f, fill_value=i) for i in labels])

        # The numbers of tree in different populations may not be equal, so
        # rows that are zero everywhere are dropped.
        valid = np.repeat(np.any(ys != 0, axis=2), f)
        x, y, hue = x[valid], y[valid], hue[valid]

        ax: Axes = sns.lineplot(x=x, y=y, hue=hue, **kwargs)
        ax.set_ylabel("Count")
        return ax
|
|
430
|
+
|
|
431
|
+
|
|
432
|
+
def extract_feature(obj: Tree | Population | Populations) -> FeatureExtractor:
    """Create the feature extractor matching the type of `obj`.

    Raises:
        TypeError: If `obj` is not a `Tree`, `Population` or `Populations`.
    """
    # Checked in this order; the first matching type wins.
    dispatch = (
        (Tree, TreeFeatureExtractor),
        (Population, PopulationFeatureExtractor),
        (Populations, PopulationsFeatureExtractor),
    )
    for kind, extractor_cls in dispatch:
        if isinstance(obj, kind):
            return extractor_cls(obj)

    raise TypeError("Invalid argument type.")
|
|
443
|
+
|
|
444
|
+
|
|
445
|
+
def _get_feat_and_kwargs(feature: FeatAndKwargs, **kwargs):
    """Split a feature spec into its name and keyword arguments.

    A bare name is returned with `kwargs` unchanged. For a `(name, kwargs)`
    tuple the bundled arguments are merged with `kwargs`, and `kwargs` wins
    on duplicate keys.
    """
    if not isinstance(feature, tuple):
        return feature, kwargs

    name, bundled = feature[0], feature[1]
    return name, {**bundled, **kwargs}
|
|
450
|
+
|
|
451
|
+
|
|
452
|
+
def _get_feature_name(feature: FeatAndKwargs) -> str:
    """Return a display name, e.g. "node_count" becomes "Node Count"."""
    name = _get_feat_and_kwargs(feature)[0]
    spaced = name.replace("_", " ")
    return spaced.title()
|
|
@@ -0,0 +1,218 @@
|
|
|
1
|
+
|
|
2
|
+
# SPDX-FileCopyrightText: 2022 - 2025 Zexin Yuan <pypi@yzx9.xyz>
|
|
3
|
+
#
|
|
4
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
5
|
+
|
|
6
|
+
"""Feature analysis of tree."""
|
|
7
|
+
|
|
8
|
+
from abc import ABC, abstractmethod
|
|
9
|
+
from functools import cached_property
|
|
10
|
+
from typing import TypeVar
|
|
11
|
+
|
|
12
|
+
import numpy as np
|
|
13
|
+
import numpy.typing as npt
|
|
14
|
+
from typing_extensions import Self, deprecated
|
|
15
|
+
|
|
16
|
+
from swcgeom.core import Branch, BranchTree, Tree
|
|
17
|
+
|
|
18
|
+
# `FurcationFeatures` is exported alongside its deprecated alias
# `BifurcationFeatures`, so star-imports expose the recommended name too.
__all__ = [
    "NodeFeatures",
    "FurcationFeatures",
    "BifurcationFeatures",
    "TipFeatures",
    "PathFeatures",
    "BranchFeatures",
]
|
|
25
|
+
|
|
26
|
+
# Element type of the branch list accepted by `BranchFeatures.calc_angle`.
T = TypeVar("T", bound=Branch)
|
|
27
|
+
|
|
28
|
+
# Node Level
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class NodeFeatures:
    """Node-level features of a tree."""

    tree: Tree

    def __init__(self, tree: Tree) -> None:
        self.tree = tree

    @cached_property
    def _branch_tree(self) -> BranchTree:
        # Built once on first use and cached on the instance.
        return BranchTree.from_tree(self.tree)

    def get_count(self) -> npt.NDArray[np.float32]:
        """Get number of nodes.

        Returns:
            count: Array of shape (1,).
        """
        n_nodes = self.tree.number_of_nodes()
        return np.array([n_nodes], dtype=np.float32)

    def get_radial_distance(self) -> npt.NDArray[np.float32]:
        """Get the straight-line distance of every node to the soma.

        Returns:
            radial_distance: Array of shape (N,).
        """
        offsets = self.tree.xyz() - self.tree.soma().xyz()
        return np.linalg.norm(offsets, axis=1)

    def get_branch_order(self) -> npt.NDArray[np.int32]:
        """Get branch order of the critical nodes of the tree.

        Branch order is the number of bifurcations between the current
        position and the root.

        Critical nodes are the soma, the bifurcation nodes and the tips.

        Returns:
            order: Array with one entry per node of the branch tree; the root
                has order 0 and each child one more than its parent.
        """
        branch_tree = self._branch_tree
        order = np.zeros_like(branch_tree.id(), dtype=np.int32)

        def assign_depth(n: Tree.Node, pre_depth: int | None) -> int:
            # `pre_depth` is None only for the node the traversal starts at.
            depth = 0 if pre_depth is None else pre_depth + 1
            order[n.id] = depth
            return depth

        branch_tree.traverse(enter=assign_depth)
        return order
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class _SubsetNodesFeatures(ABC):
    """Features restricted to a subset of nodes, chosen by the `nodes` mask."""

    _features: NodeFeatures

    def __init__(self, features: NodeFeatures) -> None:
        self._features = features

    @classmethod
    def from_tree(cls, tree: Tree) -> Self:
        """Build the features from a tree via a new `NodeFeatures`."""
        node_features = NodeFeatures(tree)
        return cls(node_features)

    @property
    @abstractmethod
    def nodes(self) -> npt.NDArray[np.bool_]:
        """Boolean mask marking the nodes that belong to the subset."""
        raise NotImplementedError()

    def get_count(self) -> npt.NDArray[np.float32]:
        """Get number of nodes in the subset.

        Returns:
            count: Array of shape (1,).
        """
        selected = np.count_nonzero(self.nodes)
        return np.array([selected], dtype=np.float32)

    def get_radial_distance(self) -> npt.NDArray[np.float32]:
        """Get the straight-line distance to soma for the subset nodes.

        Returns:
            radial_distance: Array with one entry per selected node.
        """
        distances = self._features.get_radial_distance()
        return distances[self.nodes]
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
class FurcationFeatures(_SubsetNodesFeatures):
    """Features of the furcation nodes of a tree."""

    @cached_property
    def nodes(self) -> npt.NDArray[np.bool_]:
        # One flag per node, in the iteration order of the tree.
        flags = [node.is_furcation() for node in self._features.tree]
        return np.array(flags)
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
@deprecated("Use FurcationFeatures instead")
class BifurcationFeatures(FurcationFeatures):
    """Deprecated alias of `FurcationFeatures`.

    NOTE: Deprecated because "bifurcation" was the wrong term for furcation
    nodes. For now it is only an alias of `FurcationFeatures` that is marked as
    deprecated; it will be changed to raise an error in the future.
    """
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
class TipFeatures(_SubsetNodesFeatures):
    """Features of the tip nodes of a tree."""

    @cached_property
    def nodes(self) -> npt.NDArray[np.bool_]:
        # One flag per node, in the iteration order of the tree.
        flags = [node.is_tip() for node in self._features.tree]
        return np.array(flags)
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
# Path Level
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
class PathFeatures:
    """Path-level features of a tree."""

    tree: Tree

    def __init__(self, tree: Tree) -> None:
        self.tree = tree

    @cached_property
    def _paths(self) -> list[Tree.Path]:
        # Fetched once from the tree and cached on the instance.
        return self.tree.get_paths()

    def get_count(self) -> int:
        """Get number of paths (a plain int, unlike the array-valued getters)."""
        return len(self._paths)

    def get_length(self) -> npt.NDArray[np.float32]:
        """Get the length of every path."""
        return np.array([p.length() for p in self._paths], dtype=np.float32)

    def get_tortuosity(self) -> npt.NDArray[np.float32]:
        """Get the tortuosity of every path."""
        values = [p.tortuosity() for p in self._paths]
        return np.array(values, dtype=np.float32)
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
class BranchFeatures:
    """Branch-level features of a tree."""

    tree: Tree

    def __init__(self, tree: Tree) -> None:
        self.tree = tree

    def get_count(self) -> int:
        """Get number of branches."""
        return len(self._branches)

    def get_length(self) -> npt.NDArray[np.float32]:
        """Get length of branches."""

        length = [br.length() for br in self._branches]
        return np.array(length, dtype=np.float32)

    def get_tortuosity(self) -> npt.NDArray[np.float32]:
        """Get tortuosity of branches."""

        return np.array([br.tortuosity() for br in self._branches], dtype=np.float32)

    def get_angle(self, eps: float = 1e-7) -> npt.NDArray[np.float32]:
        """Get angle between branches.

        Args:
            eps: Small constant added to the denominator, see `calc_angle`.

        Returns:
            angle: An array of shape (N, N), where N is the number of branches.
        """
        return self.calc_angle(self._branches, eps=eps)

    @staticmethod
    def calc_angle(branches: list[T], eps: float = 1e-7) -> npt.NDArray[np.float32]:
        """Calc angle between branches.

        Each branch is reduced to the vector from its first node to its last
        node; the result holds the pairwise angles (from `np.arccos`) between
        these vectors.

        Args:
            branches: Branches to compare.
            eps: Small constant added to the product of the norms so that
                zero-length vectors do not divide by zero.

        Returns:
            angle: An array of shape (N, N), where N is the number of branches.
                An empty input gives an empty array of shape (0, 0).
        """
        if len(branches) == 0:
            # Without this guard the vector array is 1-D and the norm along
            # axis 1 below raises.
            return np.zeros((0, 0), dtype=np.float32)

        vector = np.array([br[-1].xyz() - br[0].xyz() for br in branches])
        vector_dot = np.matmul(vector, vector.T)
        vector_norm = np.linalg.norm(vector, ord=2, axis=1, keepdims=True)
        vector_norm_dot = np.matmul(vector_norm, vector_norm.T) + eps
        # Clip so rounding cannot push the cosine outside arccos's domain.
        arccos = np.clip(vector_dot / vector_norm_dot, -1, 1)
        angle = np.arccos(arccos)
        return angle

    @cached_property
    def _branches(self) -> list[Tree.Branch]:
        return self.tree.get_branches()
|