swcgeom 0.19.4__cp312-cp312-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of swcgeom might be problematic. Click here for more details.

Files changed (72)
  1. swcgeom/__init__.py +21 -0
  2. swcgeom/analysis/__init__.py +13 -0
  3. swcgeom/analysis/feature_extractor.py +454 -0
  4. swcgeom/analysis/features.py +218 -0
  5. swcgeom/analysis/lmeasure.py +750 -0
  6. swcgeom/analysis/sholl.py +201 -0
  7. swcgeom/analysis/trunk.py +183 -0
  8. swcgeom/analysis/visualization.py +191 -0
  9. swcgeom/analysis/visualization3d.py +81 -0
  10. swcgeom/analysis/volume.py +143 -0
  11. swcgeom/core/__init__.py +19 -0
  12. swcgeom/core/branch.py +129 -0
  13. swcgeom/core/branch_tree.py +65 -0
  14. swcgeom/core/compartment.py +107 -0
  15. swcgeom/core/node.py +130 -0
  16. swcgeom/core/path.py +155 -0
  17. swcgeom/core/population.py +341 -0
  18. swcgeom/core/swc.py +247 -0
  19. swcgeom/core/swc_utils/__init__.py +19 -0
  20. swcgeom/core/swc_utils/assembler.py +35 -0
  21. swcgeom/core/swc_utils/base.py +180 -0
  22. swcgeom/core/swc_utils/checker.py +107 -0
  23. swcgeom/core/swc_utils/io.py +204 -0
  24. swcgeom/core/swc_utils/normalizer.py +163 -0
  25. swcgeom/core/swc_utils/subtree.py +70 -0
  26. swcgeom/core/tree.py +384 -0
  27. swcgeom/core/tree_utils.py +277 -0
  28. swcgeom/core/tree_utils_impl.py +58 -0
  29. swcgeom/images/__init__.py +9 -0
  30. swcgeom/images/augmentation.py +149 -0
  31. swcgeom/images/contrast.py +87 -0
  32. swcgeom/images/folder.py +217 -0
  33. swcgeom/images/io.py +578 -0
  34. swcgeom/images/loaders/__init__.py +8 -0
  35. swcgeom/images/loaders/pbd.cpython-312-darwin.so +0 -0
  36. swcgeom/images/loaders/pbd.pyx +523 -0
  37. swcgeom/images/loaders/raw.cpython-312-darwin.so +0 -0
  38. swcgeom/images/loaders/raw.pyx +183 -0
  39. swcgeom/transforms/__init__.py +20 -0
  40. swcgeom/transforms/base.py +136 -0
  41. swcgeom/transforms/branch.py +223 -0
  42. swcgeom/transforms/branch_tree.py +74 -0
  43. swcgeom/transforms/geometry.py +270 -0
  44. swcgeom/transforms/image_preprocess.py +107 -0
  45. swcgeom/transforms/image_stack.py +219 -0
  46. swcgeom/transforms/images.py +206 -0
  47. swcgeom/transforms/mst.py +183 -0
  48. swcgeom/transforms/neurolucida_asc.py +498 -0
  49. swcgeom/transforms/path.py +56 -0
  50. swcgeom/transforms/population.py +36 -0
  51. swcgeom/transforms/tree.py +265 -0
  52. swcgeom/transforms/tree_assembler.py +161 -0
  53. swcgeom/utils/__init__.py +18 -0
  54. swcgeom/utils/debug.py +23 -0
  55. swcgeom/utils/download.py +119 -0
  56. swcgeom/utils/dsu.py +58 -0
  57. swcgeom/utils/ellipse.py +131 -0
  58. swcgeom/utils/file.py +90 -0
  59. swcgeom/utils/neuromorpho.py +581 -0
  60. swcgeom/utils/numpy_helper.py +70 -0
  61. swcgeom/utils/plotter_2d.py +134 -0
  62. swcgeom/utils/plotter_3d.py +35 -0
  63. swcgeom/utils/renderer.py +145 -0
  64. swcgeom/utils/sdf.py +324 -0
  65. swcgeom/utils/solid_geometry.py +154 -0
  66. swcgeom/utils/transforms.py +367 -0
  67. swcgeom/utils/volumetric_object.py +483 -0
  68. swcgeom-0.19.4.dist-info/METADATA +86 -0
  69. swcgeom-0.19.4.dist-info/RECORD +72 -0
  70. swcgeom-0.19.4.dist-info/WHEEL +5 -0
  71. swcgeom-0.19.4.dist-info/licenses/LICENSE +201 -0
  72. swcgeom-0.19.4.dist-info/top_level.txt +1 -0
swcgeom/__init__.py ADDED
@@ -0,0 +1,21 @@
1
+ # SPDX-FileCopyrightText: 2022 - 2025 Zexin Yuan <pypi@yzx9.xyz>
2
+ #
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ """A neuron geometry library for swc format."""
6
+
7
+ from swcgeom import analysis, core, images, transforms
8
+ from swcgeom.analysis import draw
9
+ from swcgeom.core import BranchTree, Population, Populations, Tree
10
+
11
+ __all__ = [
12
+ "analysis",
13
+ "core",
14
+ "images",
15
+ "transforms",
16
+ "draw",
17
+ "BranchTree",
18
+ "Population",
19
+ "Populations",
20
+ "Tree",
21
+ ]
@@ -0,0 +1,13 @@
1
+
2
+ # SPDX-FileCopyrightText: 2022 - 2025 Zexin Yuan <pypi@yzx9.xyz>
3
+ #
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ """Analysis for neuron trees."""
7
+
8
+ from swcgeom.analysis.feature_extractor import * # noqa: F403
9
+ from swcgeom.analysis.features import * # noqa: F403
10
+ from swcgeom.analysis.sholl import * # noqa: F403
11
+ from swcgeom.analysis.trunk import * # noqa: F403
12
+ from swcgeom.analysis.visualization import * # noqa: F403
13
+ from swcgeom.analysis.volume import * # noqa: F403
@@ -0,0 +1,454 @@
1
+ # SPDX-FileCopyrightText: 2022 - 2025 Zexin Yuan <pypi@yzx9.xyz>
2
+ #
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ """Easy way to compute and visualize common features of trees and populations.
6
+
7
+ NOTE: For development, see method `Features.get_evaluator` to confirm the naming
8
+ specification.
9
+ """
10
+
11
+ from abc import ABC, abstractmethod
12
+ from collections.abc import Callable
13
+ from functools import cached_property
14
+ from itertools import chain
15
+ from os.path import basename
16
+ from typing import Any, Literal, overload
17
+
18
+ import numpy as np
19
+ import numpy.typing as npt
20
+ import seaborn as sns
21
+ from matplotlib.axes import Axes
22
+
23
+ from swcgeom.analysis.features import (
24
+ BranchFeatures,
25
+ FurcationFeatures,
26
+ NodeFeatures,
27
+ PathFeatures,
28
+ TipFeatures,
29
+ )
30
+ from swcgeom.analysis.sholl import Sholl
31
+ from swcgeom.analysis.volume import get_volume
32
+ from swcgeom.core import Population, Populations, Tree
33
+ from swcgeom.utils import padding1d
34
+
35
+ __all__ = ["Feature", "extract_feature"]
36
+
37
+ Feature = Literal[
38
+ "length",
39
+ "volume",
40
+ "sholl",
41
+ # node
42
+ "node_count",
43
+ "node_radial_distance",
44
+ "node_branch_order",
45
+ # furcation nodes
46
+ "furcation_count",
47
+ "furcation_radial_distance",
48
+ # bifurcation nodes
49
+ "bifurcation_count",
50
+ "bifurcation_radial_distance",
51
+ # tip nodes
52
+ "tip_count",
53
+ "tip_radial_distance",
54
+ # branch
55
+ "branch_length",
56
+ "branch_tortuosity",
57
+ # path
58
+ "path_length",
59
+ "path_tortuosity",
60
+ ]
61
+
62
+ NDArrayf32 = npt.NDArray[np.float32]
63
+ FeatAndKwargs = Feature | tuple[Feature, dict[str, Any]]
64
+
65
+ Feature1D = set(["length", "volume", "node_count", "furcation_count", "tip_count"])
66
+
67
+
68
class Features:
    """Lazily evaluated features of a single tree.

    Feature modules and the Sholl analysis are built on first access and cached
    on the instance.
    """

    tree: Tree

    def __init__(self, tree: Tree) -> None:
        self.tree = tree

    # Modules

    @cached_property
    def node_features(self) -> NodeFeatures:
        return NodeFeatures(self.tree)

    @cached_property
    def furcation_features(self) -> FurcationFeatures:
        # Shares the node module so both work on the same `NodeFeatures`.
        return FurcationFeatures(self.node_features)

    @cached_property
    def tip_features(self) -> TipFeatures:
        return TipFeatures(self.node_features)

    @cached_property
    def branch_features(self) -> BranchFeatures:
        return BranchFeatures(self.tree)

    @cached_property
    def path_features(self) -> PathFeatures:
        return PathFeatures(self.tree)

    # Caches

    @cached_property
    def sholl(self) -> Sholl:
        return Sholl(self.tree)

    def get(self, feature: FeatAndKwargs, **kwargs) -> NDArrayf32:
        """Evaluate a feature.

        Args:
            feature: A feature name, or a `(name, kwargs)` tuple.
            **kwargs: Extra keyword arguments for the evaluator; they override
                entries of the tuple's kwargs with the same key.

        Returns:
            Whatever the resolved evaluator returns.

        Raises:
            ValueError: If the feature name cannot be resolved.
        """
        name, merged = _get_feat_and_kwargs(feature, **kwargs)
        return self.get_evaluator(name)(**merged)

    def get_evaluator(self, feature: Feature) -> Callable[[], npt.NDArray]:
        """Resolve a feature name to the callable that computes it.

        A `get_<feature>` method on this class wins. Otherwise the name is split
        on underscores: the first component selects the `<component>_features`
        module and the remainder selects its `get_<remainder>` method.

        Raises:
            ValueError: If neither lookup yields a callable.
        """
        custom = getattr(self, f"get_{feature}", None)
        if callable(custom):
            return custom  # custom features

        prefix, *rest = feature.split("_")
        module = getattr(self, f"{prefix}_features", None)
        if module:
            calc = getattr(module, "get_" + "_".join(rest), None)
            if callable(calc):
                return calc

        raise ValueError(f"Invalid feature: {feature}")

    # Custom Features

    def get_length(self, **kwargs) -> NDArrayf32:
        """Return the tree length wrapped in a one-element float32 array."""
        total = self.tree.length(**kwargs)
        return np.array([total], dtype=np.float32)

    def get_volume(self, **kwargs) -> NDArrayf32:
        """Return the tree volume wrapped in a one-element float32 array."""
        volume = get_volume(self.tree, **kwargs)
        return np.array([volume], dtype=np.float32)

    def get_sholl(self, **kwargs) -> NDArrayf32:
        """Return the result of `Sholl.get` cast to float32."""
        counts = self.sholl.get(**kwargs)
        return counts.astype(np.float32)
129
+
130
+
131
+ class FeatureExtractor(ABC):
132
+ """Extract features from tree."""
133
+
134
+ @overload
135
+ def get(self, feature: Feature, **kwargs) -> NDArrayf32: ...
136
+ @overload
137
+ def get(self, feature: list[FeatAndKwargs]) -> list[NDArrayf32]: ...
138
+ @overload
139
+ def get(self, feature: dict[Feature, dict[str, Any]]) -> dict[str, NDArrayf32]: ...
140
+ def get(self, feature, **kwargs):
141
+ """Get feature.
142
+
143
+ NOTE: Shape of returned array is not uniform, `TreeFeatureExtractor` returns
144
+ array of shape (L, ), `PopulationFeatureExtracor` returns array of shape (N, L).
145
+ """
146
+ if isinstance(feature, dict):
147
+ return {k: self._get(k, **v) for k, v in feature.items()}
148
+
149
+ if isinstance(feature, list):
150
+ return [self._get(k) for k in feature]
151
+
152
+ return self._get(feature, **kwargs)
153
+
154
+ def plot(self, feature: FeatAndKwargs, title: str | bool = True, **kwargs) -> Axes:
155
+ """Plot feature with appropriate way.
156
+
157
+ NOTE: The drawing method is different in different classes, different in
158
+ different features, and may different between versions, there are NO guarantees.
159
+ """
160
+ feat, feat_kwargs = _get_feat_and_kwargs(feature)
161
+ if callable(custom_plot := getattr(self, f"plot_{feat}", None)):
162
+ ax = custom_plot(feat_kwargs, **kwargs)
163
+ elif feat in Feature1D:
164
+ ax = self._plot_1d(feature, **kwargs)
165
+ else:
166
+ ax = self._plot_histogram(feature, **kwargs) # default plot
167
+
168
+ if isinstance(title, str):
169
+ ax.set_title(title)
170
+ elif title is True:
171
+ ax.set_title(_get_feature_name(feat))
172
+
173
+ return ax
174
+
175
+ # Custom Plots
176
+
177
+ def plot_node_branch_order(self, feature_kwargs: dict[str, Any], **kwargs) -> Axes:
178
+ vals = self._get("node_branch_order", **feature_kwargs)
179
+ bin_edges = np.arange(int(np.ceil(vals.max() + 1))) + 0.5
180
+ return self._plot_histogram_impl(vals, bin_edges, **kwargs)
181
+
182
+ # Implements
183
+
184
+ def _get(self, feature: FeatAndKwargs, **kwargs) -> NDArrayf32:
185
+ feat, kwargs = _get_feat_and_kwargs(feature, **kwargs)
186
+ if callable(custom_get := getattr(self, f"get_{feat}", None)):
187
+ return custom_get(**kwargs)
188
+
189
+ return self._get_impl(feat, **kwargs) # default
190
+
191
+ def _plot_1d(self, feature: FeatAndKwargs, **kwargs) -> Axes:
192
+ vals = self._get(feature)
193
+ ax = self._plot_1d_impl(vals, **kwargs)
194
+ ax.set_ylabel(_get_feature_name(feature))
195
+ return ax
196
+
197
+ def _plot_histogram(
198
+ self,
199
+ feature: FeatAndKwargs,
200
+ bins=20,
201
+ range=None, # pylint: disable=redefined-builtin
202
+ **kwargs,
203
+ ) -> Axes:
204
+ vals = self._get(feature)
205
+ bin_edges = np.histogram_bin_edges(vals, bins, range)
206
+ return self._plot_histogram_impl(vals, bin_edges, **kwargs)
207
+
208
+ @abstractmethod
209
+ def _get_impl(self, feature: Feature, **kwargs) -> NDArrayf32:
210
+ raise NotImplementedError()
211
+
212
+ @abstractmethod
213
+ def _plot_1d_impl(self, vals: NDArrayf32, **kwargs) -> Axes:
214
+ raise NotImplementedError()
215
+
216
+ @abstractmethod
217
+ def _plot_histogram_impl(
218
+ self, vals: NDArrayf32, bin_edges: npt.NDArray, **kwargs
219
+ ) -> Axes:
220
+ raise NotImplementedError()
221
+
222
+ def get_bifurcation_count(self, **kwargs):
223
+ raise DeprecationWarning("Use `furcation_count` instead.")
224
+
225
+ def get_bifurcation_radial_distance(self, **kwargs):
226
+ raise DeprecationWarning("Use `furcation_radial_distance` instead.")
227
+
228
+
229
class TreeFeatureExtractor(FeatureExtractor):
    """Extract and plot features of a single tree."""

    _tree: Tree
    _features: Features

    def __init__(self, tree: Tree) -> None:
        super().__init__()
        self._tree = tree
        self._features = Features(tree)

    # Custom Features

    def get_sholl(self, **kwargs) -> NDArrayf32:
        """Return the result of `Sholl.get` cast to float32."""
        counts = self._features.sholl.get(**kwargs)
        return counts.astype(np.float32)

    # Custom Plots

    def plot_sholl(
        self,
        feature_kwargs: dict[str, Any],  # pylint: disable=unused-argument
        **kwargs,
    ) -> Axes:
        """Delegate to `Sholl.plot` and return only the axes."""
        _fig, axes = self._features.sholl.plot(**kwargs)
        return axes

    # Implements

    def _get_impl(self, feature: Feature, **kwargs) -> NDArrayf32:
        return self._features.get(feature, **kwargs)

    def _plot_histogram_impl(
        self, vals: NDArrayf32, bin_edges: npt.NDArray, **kwargs
    ) -> Axes:
        """Draw a bar plot of the histogram of the non-zero values."""
        nonzero = vals[vals != 0]
        counts, _ = np.histogram(nonzero, bins=bin_edges)

        # Bars are placed at the bin centers.
        lower, upper = bin_edges[:-1], bin_edges[1:]
        centers = (lower + upper) / 2

        ax: Axes = sns.barplot(x=centers, y=counts, **kwargs)
        ax.set_ylabel("Count")
        return ax

    def _plot_1d_impl(self, vals: NDArrayf32, **kwargs) -> Axes:
        """Draw a single bar labelled with the base name of the tree source."""
        label = basename(self._tree.source)
        height = vals.squeeze()
        return sns.barplot(x=[label], y=height, **kwargs)
273
+
274
+
275
class PopulationFeatureExtractor(FeatureExtractor):
    """Extract and plot features of every tree in a population."""

    _population: Population
    _features: list[Features]

    def __init__(self, population: Population) -> None:
        super().__init__()
        self._population = population
        self._features = [Features(tree) for tree in self._population]

    # Custom Features

    def get_sholl(self, **kwargs) -> NDArrayf32:
        """Return the Sholl values of all trees, evaluated on shared radii."""
        values, _radii = self._get_sholl_impl(**kwargs)
        return values

    # Custom Plots

    def plot_sholl(self, feature_kwargs: dict[str, Any], **kwargs) -> Axes:
        """Draw the Sholl values of all trees as one line plot."""
        values, radii = self._get_sholl_impl(**feature_kwargs)
        ax = self._lineplot(xs=radii, ys=values.flatten(), **kwargs)
        ax.set_ylabel("Count of Intersections")
        return ax

    # Implements

    def _get_impl(self, feature: Feature, **kwargs) -> NDArrayf32:
        """Evaluate the feature per tree and stack the padded results.

        Each result is passed through `padding1d` with the longest length
        before stacking, so the output has one row per tree.
        """
        per_tree = [f.get(feature, **kwargs) for f in self._features]
        longest = max(len(values) for values in per_tree)
        padded = [padding1d(longest, values, dtype=np.float32) for values in per_tree]
        return np.stack(padded)

    def _get_sholl_impl(
        self, steps: int = 20, **kwargs
    ) -> tuple[NDArrayf32, NDArrayf32]:
        """Evaluate Sholl for all trees on radii derived from the largest rmax.

        Returns:
            A `(values, radii)` tuple.
        """
        rmax = max(f.sholl.rmax for f in self._features)
        radii = Sholl.get_rs(rmax=rmax, steps=steps)
        # The shared radii are handed to each tree's evaluator as `steps`.
        values = self._get_impl("sholl", steps=radii, **kwargs)
        return values, radii

    def _plot_histogram_impl(
        self, vals: NDArrayf32, bin_edges: npt.NDArray, **kwargs
    ) -> Axes:
        """Draw per-tree histograms of the non-zero values as a line plot."""
        centers = (bin_edges[:-1] + bin_edges[1:]) / 2
        counts = np.stack(
            [np.histogram(row[row != 0], bins=bin_edges)[0] for row in vals]
        )

        ax: Axes = self._lineplot(centers, counts, **kwargs)
        ax.set_ylabel("Count")
        return ax

    def _plot_1d_impl(self, vals: NDArrayf32, **kwargs) -> Axes:
        """Draw one bar per tree plus a dashed line at the mean value."""
        names = [basename(tree.source) for tree in self._population]
        heights = vals.flatten()

        ax: Axes = sns.barplot(x=names, y=heights, **kwargs)
        ax.axhline(y=heights.mean(), ls="--", lw=1)
        ax.set_xticks([])
        return ax

    def _lineplot(self, xs, ys, **kwargs) -> Axes:
        """Line plot of flattened `ys` against `xs` repeated once per tree."""
        repeated_xs = np.tile(xs, len(self._population))
        flat_ys = ys.flatten()
        ax: Axes = sns.lineplot(x=repeated_xs, y=flat_ys, **kwargs)
        return ax
342
+
343
+
344
class PopulationsFeatureExtractor(FeatureExtractor):
    """Extract and plot features of several populations.

    Values are held in arrays of shape (P, T, L): population, tree within the
    population, and feature value. Missing trees and short features are left as
    zeros.
    """

    _populations: Populations
    _features: list[list[Features]]

    def __init__(self, populations: Populations) -> None:
        super().__init__()
        self._populations = populations
        self._features = [[Features(t) for t in p] for p in populations.populations]

    # Custom Features

    def get_sholl(self, **kwargs) -> NDArrayf32:
        """Return the Sholl values of all trees, evaluated on shared radii."""
        vals, _ = self._get_sholl_impl(**kwargs)
        return vals

    # Custom Plots

    def plot_sholl(self, feature_kwargs: dict[str, Any], **kwargs) -> Axes:
        """Draw the Sholl values as a line plot with one hue per population."""
        vals, rs = self._get_sholl_impl(**feature_kwargs)
        ax = self._lineplot(xs=rs, ys=vals, **kwargs)
        ax.set_ylabel("Count of Intersections")
        return ax

    # Implements

    def _get_impl(self, feature: Feature, **kwargs) -> NDArrayf32:
        """Evaluate the feature for every tree and pack it into one array.

        Returns:
            Array of shape (P, T, L), where P is the number of populations, T
            the size of the largest population and L the longest feature
            length. Unused entries stay zero.
        """
        vals = [[f.get(feature, **kwargs) for f in fs] for fs in self._features]
        len_max1 = max(len(v) for v in vals)
        # Pass ONE iterable to `max`. Unpacking the lengths with `*` breaks
        # when there is a single tree overall: `max(n)` raises TypeError.
        len_max2 = max(len(vv) for v in vals for vv in v)

        out = np.zeros((len(vals), len_max1, len_max2), dtype=np.float32)
        for i, v in enumerate(vals):
            for j, vv in enumerate(v):
                out[i, j, : len(vv)] = vv

        return out

    def _get_sholl_impl(
        self, steps: int = 20, **kwargs
    ) -> tuple[NDArrayf32, NDArrayf32]:
        """Evaluate Sholl for all trees on radii derived from the largest rmax.

        Returns:
            A `(values, radii)` tuple.
        """
        rmaxs = chain.from_iterable((t.sholl.rmax for t in p) for p in self._features)
        rmax = max(rmaxs)
        rs = Sholl.get_rs(rmax=rmax, steps=steps)
        # The shared radii are handed to each tree's evaluator as `steps`.
        vals = self._get_impl("sholl", steps=rs, **kwargs)
        return vals, rs

    def _plot_histogram_impl(
        self, vals: NDArrayf32, bin_edges: npt.NDArray, **kwargs
    ) -> Axes:
        """Draw per-tree histograms of the non-zero values as a line plot."""

        def hist(v):
            return np.histogram(v[v != 0], bins=bin_edges)[0]

        xs = (bin_edges[:-1] + bin_edges[1:]) / 2
        ys = np.stack([np.stack([hist(t) for t in p]) for p in vals])

        ax = self._lineplot(xs=xs, ys=ys, **kwargs)
        ax.set_ylabel("Count")
        return ax

    def _plot_1d_impl(self, vals: NDArrayf32, **kwargs) -> Axes:
        """Draw one box per population."""
        labels = self._populations.labels
        xs = np.concatenate([np.full(vals.shape[1], fill_value=i) for i in labels])
        ys = vals.flatten()

        # The numbers of tree in different populations may not be equal, so
        # padding zeros are dropped.
        # NOTE(review): this also drops trees whose real value is 0.
        valid = ys != 0
        xs, ys = xs[valid], ys[valid]

        ax: Axes = sns.boxplot(x=xs, y=ys, **kwargs)
        return ax

    def _lineplot(self, xs, ys, **kwargs) -> Axes:
        """Line plot of `ys` (P, T, F) against `xs` (F,), hued by population."""
        p, t, f = ys.shape
        labels = self._populations.labels  # (P,)
        x = np.tile(xs, p * t)  # (F,) -> (P * T * F)
        y = ys.flatten()  # (P, T, F) -> (P * T * F)
        hue = np.concatenate([np.full(t * f, fill_value=i) for i in labels])

        # The numbers of tree in different populations may not be equal:
        # trees whose values are all zero are treated as padding and removed.
        valid = np.repeat(np.any(ys != 0, axis=2), f)
        x, y, hue = x[valid], y[valid], hue[valid]

        ax: Axes = sns.lineplot(x=x, y=y, hue=hue, **kwargs)
        ax.set_ylabel("Count")
        return ax
430
+
431
+
432
def extract_feature(obj: Tree | Population | Populations) -> FeatureExtractor:
    """Create the feature extractor matching the type of `obj`.

    Raises:
        TypeError: If `obj` is not a `Tree`, `Population` or `Populations`.
    """
    # Checked in this order; the first matching type wins.
    dispatch = (
        (Tree, TreeFeatureExtractor),
        (Population, PopulationFeatureExtractor),
        (Populations, PopulationsFeatureExtractor),
    )
    for kind, extractor in dispatch:
        if isinstance(obj, kind):
            return extractor(obj)

    raise TypeError("Invalid argument type.")
443
+
444
+
445
def _get_feat_and_kwargs(feature: FeatAndKwargs, **kwargs):
    """Split a feature spec into its name and keyword arguments.

    For a `(name, kwargs)` tuple the tuple's kwargs are merged with `**kwargs`,
    which win on duplicate keys. A plain name is returned with `kwargs` as is.
    """
    if not isinstance(feature, tuple):
        return feature, kwargs

    name, preset = feature[0], feature[1]
    return name, {**preset, **kwargs}
450
+
451
+
452
def _get_feature_name(feature: FeatAndKwargs) -> str:
    """Return a title-cased display name, e.g. `node_count` -> `Node Count`."""
    feat, _ignored = _get_feat_and_kwargs(feature)
    spaced = feat.replace("_", " ")
    return spaced.title()
@@ -0,0 +1,218 @@
1
+
2
+ # SPDX-FileCopyrightText: 2022 - 2025 Zexin Yuan <pypi@yzx9.xyz>
3
+ #
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ """Feature analysis of tree."""
7
+
8
+ from abc import ABC, abstractmethod
9
+ from functools import cached_property
10
+ from typing import TypeVar
11
+
12
+ import numpy as np
13
+ import numpy.typing as npt
14
+ from typing_extensions import Self, deprecated
15
+
16
+ from swcgeom.core import Branch, BranchTree, Tree
17
+
18
# `FurcationFeatures` replaces the deprecated `BifurcationFeatures`, so it has
# to be exported as well; otherwise `from ... import *` only exposes the
# deprecated name.
__all__ = [
    "NodeFeatures",
    "FurcationFeatures",
    "BifurcationFeatures",
    "TipFeatures",
    "PathFeatures",
    "BranchFeatures",
]
25
+
26
+ T = TypeVar("T", bound=Branch)
27
+
28
+ # Node Level
29
+
30
+
31
class NodeFeatures:
    """Evaluate node features of a tree."""

    tree: Tree

    def __init__(self, tree: Tree) -> None:
        self.tree = tree

    @cached_property
    def _branch_tree(self) -> BranchTree:
        # Built on first use and cached on the instance.
        return BranchTree.from_tree(self.tree)

    def get_count(self) -> npt.NDArray[np.float32]:
        """Get number of nodes.

        Returns:
            count: Array of shape (1,).
        """
        n_nodes = self.tree.number_of_nodes()
        return np.array([n_nodes], dtype=np.float32)

    def get_radial_distance(self) -> npt.NDArray[np.float32]:
        """Get the end-to-end straight-line distance to soma.

        Returns:
            radial_distance: Array of shape (N,).
        """
        points = self.tree.xyz()
        origin = self.tree.soma().xyz()
        return np.linalg.norm(points - origin, axis=1)

    def get_branch_order(self) -> npt.NDArray[np.int32]:
        """Get branch order of critical nodes of tree.

        Branch order is the number of bifurcations between current position and
        the root.

        Critical nodes are the soma, bifurcation nodes and tips.

        Returns:
            order: Integer array with one entry per node id of the branch tree.
        """
        order = np.zeros_like(self._branch_tree.id(), dtype=np.int32)

        def record_order(n: Tree.Node, pre_depth: int | None) -> int:
            # `pre_depth` is the value returned for the previously entered
            # node; `None` marks the start of the traversal and gets order 0.
            current = 0 if pre_depth is None else pre_depth + 1
            order[n.id] = current
            return current

        self._branch_tree.traverse(enter=record_order)
        return order
81
+
82
+
83
class _SubsetNodesFeatures(ABC):
    """Node features restricted to the nodes selected by the `nodes` mask."""

    _features: NodeFeatures

    def __init__(self, features: NodeFeatures) -> None:
        self._features = features

    @property
    @abstractmethod
    def nodes(self) -> npt.NDArray[np.bool_]:
        """Boolean mask selecting the nodes of this subset."""
        raise NotImplementedError()

    def get_count(self) -> npt.NDArray[np.float32]:
        """Get number of selected nodes.

        Returns:
            count: Array of shape (1,).
        """
        n_selected = np.count_nonzero(self.nodes)
        return np.array([n_selected], dtype=np.float32)

    def get_radial_distance(self) -> npt.NDArray[np.float32]:
        """Get the end-to-end straight-line distance to soma.

        Returns:
            radial_distance: Array with one entry per selected node.
        """
        distances = self._features.get_radial_distance()
        return distances[self.nodes]

    @classmethod
    def from_tree(cls, tree: Tree) -> Self:
        """Build the subset features from a tree via a new `NodeFeatures`."""
        return cls(NodeFeatures(tree))
113
+
114
+
115
class FurcationFeatures(_SubsetNodesFeatures):
    """Evaluate furcation node features of a tree."""

    @cached_property
    def nodes(self) -> npt.NDArray[np.bool_]:
        # One `is_furcation()` result per node, in tree iteration order.
        flags = [node.is_furcation() for node in self._features.tree]
        return np.array(flags)
121
+
122
+
123
@deprecated("Use FurcationFeatures instead")
class BifurcationFeatures(FurcationFeatures):
    """Evaluate bifurcation node features of a tree.

    NOTE: Deprecated because "bifurcation" was the wrong term for furcation
    nodes. For now it is just an alias of `FurcationFeatures` that emits a
    deprecation warning. It will be changed to raise an error in the future.
    """
131
+
132
+
133
class TipFeatures(_SubsetNodesFeatures):
    """Evaluate tip node features of a tree."""

    @cached_property
    def nodes(self) -> npt.NDArray[np.bool_]:
        # One `is_tip()` result per node, in tree iteration order.
        flags = [node.is_tip() for node in self._features.tree]
        return np.array(flags)
139
+
140
+
141
+ # Path Level
142
+
143
+
144
class PathFeatures:
    """Path analysis of a tree."""

    tree: Tree

    def __init__(self, tree: Tree) -> None:
        self.tree = tree

    @cached_property
    def _paths(self) -> list[Tree.Path]:
        # Computed on first use and cached on the instance.
        return self.tree.get_paths()

    def get_count(self) -> int:
        """Get number of paths."""
        return len(self._paths)

    def get_length(self) -> npt.NDArray[np.float32]:
        """Get length of each path as a float32 array."""
        lengths = [path.length() for path in self._paths]
        return np.array(lengths, dtype=np.float32)

    def get_tortuosity(self) -> npt.NDArray[np.float32]:
        """Get tortuosity of each path as a float32 array."""
        values = [path.tortuosity() for path in self._paths]
        return np.array(values, dtype=np.float32)
169
+
170
+
171
class BranchFeatures:
    """Branch analysis of a tree."""

    tree: Tree

    def __init__(self, tree: Tree) -> None:
        self.tree = tree

    @cached_property
    def _branches(self) -> list[Tree.Branch]:
        # Computed on first use and cached on the instance.
        return self.tree.get_branches()

    def get_count(self) -> int:
        """Get number of branches."""
        return len(self._branches)

    def get_length(self) -> npt.NDArray[np.float32]:
        """Get length of each branch as a float32 array."""
        lengths = [branch.length() for branch in self._branches]
        return np.array(lengths, dtype=np.float32)

    def get_tortuosity(self) -> npt.NDArray[np.float32]:
        """Get tortuosity of each branch as a float32 array."""
        values = [branch.tortuosity() for branch in self._branches]
        return np.array(values, dtype=np.float32)

    def get_angle(self, eps: float = 1e-7) -> npt.NDArray[np.float32]:
        """Get angle between branches.

        Returns:
            angle: An array of shape (N, N), where N is the number of branches.
        """
        return self.calc_angle(self._branches, eps=eps)

    @staticmethod
    def calc_angle(branches: list[T], eps: float = 1e-7) -> npt.NDArray[np.float32]:
        """Calc angle between branches.

        Each branch is reduced to the vector from its first to its last node.

        Args:
            branches: The branches to compare.
            eps: Added to the product of norms to avoid division by zero.

        Returns:
            angle: An array of shape (N, N), where N is the number of branches.
        """
        directions = np.array([br[-1].xyz() - br[0].xyz() for br in branches])
        dots = np.matmul(directions, directions.T)

        norms = np.linalg.norm(directions, ord=2, axis=1, keepdims=True)
        norm_products = np.matmul(norms, norms.T) + eps

        # Clip so rounding errors cannot push the cosine outside [-1, 1].
        cosines = np.clip(dots / norm_products, -1, 1)
        return np.arccos(cosines)