swcgeom 0.17.0__py3-none-any.whl → 0.17.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of swcgeom might be problematic. Click here for more details.
- swcgeom/_version.py +2 -2
- swcgeom/analysis/feature_extractor.py +13 -12
- swcgeom/analysis/features.py +4 -4
- swcgeom/analysis/lmeasure.py +5 -5
- swcgeom/analysis/sholl.py +4 -4
- swcgeom/analysis/trunk.py +12 -11
- swcgeom/analysis/visualization.py +9 -9
- swcgeom/analysis/visualization3d.py +85 -0
- swcgeom/analysis/volume.py +4 -4
- swcgeom/core/branch.py +4 -3
- swcgeom/core/branch_tree.py +3 -4
- swcgeom/core/compartment.py +3 -2
- swcgeom/core/node.py +2 -2
- swcgeom/core/path.py +3 -2
- swcgeom/core/population.py +16 -27
- swcgeom/core/swc.py +11 -10
- swcgeom/core/swc_utils/base.py +8 -17
- swcgeom/core/swc_utils/io.py +7 -6
- swcgeom/core/swc_utils/normalizer.py +4 -3
- swcgeom/core/swc_utils/subtree.py +2 -2
- swcgeom/core/tree.py +22 -34
- swcgeom/core/tree_utils.py +11 -10
- swcgeom/core/tree_utils_impl.py +3 -3
- swcgeom/images/augmentation.py +3 -3
- swcgeom/images/folder.py +10 -16
- swcgeom/images/io.py +19 -30
- swcgeom/transforms/image_stack.py +6 -5
- swcgeom/transforms/images.py +2 -3
- swcgeom/transforms/neurolucida_asc.py +4 -6
- swcgeom/transforms/population.py +1 -3
- swcgeom/transforms/tree.py +8 -7
- swcgeom/transforms/tree_assembler.py +4 -3
- swcgeom/utils/ellipse.py +3 -4
- swcgeom/utils/neuromorpho.py +17 -16
- swcgeom/utils/plotter_2d.py +12 -6
- swcgeom/utils/plotter_3d.py +31 -0
- swcgeom/utils/renderer.py +6 -6
- swcgeom/utils/sdf.py +2 -2
- swcgeom/utils/solid_geometry.py +1 -3
- swcgeom/utils/transforms.py +1 -3
- swcgeom/utils/volumetric_object.py +8 -10
- {swcgeom-0.17.0.dist-info → swcgeom-0.17.1.dist-info}/METADATA +1 -1
- swcgeom-0.17.1.dist-info/RECORD +67 -0
- swcgeom-0.17.0.dist-info/RECORD +0 -65
- {swcgeom-0.17.0.dist-info → swcgeom-0.17.1.dist-info}/LICENSE +0 -0
- {swcgeom-0.17.0.dist-info → swcgeom-0.17.1.dist-info}/WHEEL +0 -0
- {swcgeom-0.17.0.dist-info → swcgeom-0.17.1.dist-info}/top_level.txt +0 -0
swcgeom/_version.py
CHANGED
|
@@ -7,10 +7,11 @@ naming specification.
|
|
|
7
7
|
"""
|
|
8
8
|
|
|
9
9
|
from abc import ABC, abstractmethod
|
|
10
|
+
from collections.abc import Callable
|
|
10
11
|
from functools import cached_property
|
|
11
12
|
from itertools import chain
|
|
12
13
|
from os.path import basename
|
|
13
|
-
from typing import Any,
|
|
14
|
+
from typing import Any, Literal, overload
|
|
14
15
|
|
|
15
16
|
import numpy as np
|
|
16
17
|
import numpy.typing as npt
|
|
@@ -54,7 +55,7 @@ Feature = Literal[
|
|
|
54
55
|
]
|
|
55
56
|
|
|
56
57
|
NDArrayf32 = npt.NDArray[np.float32]
|
|
57
|
-
FeatAndKwargs = Feature |
|
|
58
|
+
FeatAndKwargs = Feature | tuple[Feature, dict[str, Any]]
|
|
58
59
|
|
|
59
60
|
Feature1D = set(["length", "volume", "node_count", "bifurcation_count", "tip_count"])
|
|
60
61
|
|
|
@@ -121,9 +122,9 @@ class FeatureExtractor(ABC):
|
|
|
121
122
|
@overload
|
|
122
123
|
def get(self, feature: Feature, **kwargs) -> NDArrayf32: ...
|
|
123
124
|
@overload
|
|
124
|
-
def get(self, feature:
|
|
125
|
+
def get(self, feature: list[FeatAndKwargs]) -> list[NDArrayf32]: ...
|
|
125
126
|
@overload
|
|
126
|
-
def get(self, feature:
|
|
127
|
+
def get(self, feature: dict[Feature, dict[str, Any]]) -> dict[str, NDArrayf32]: ...
|
|
127
128
|
# fmt:on
|
|
128
129
|
def get(self, feature, **kwargs):
|
|
129
130
|
"""Get feature.
|
|
@@ -168,7 +169,7 @@ class FeatureExtractor(ABC):
|
|
|
168
169
|
|
|
169
170
|
# Custom Plots
|
|
170
171
|
|
|
171
|
-
def plot_node_branch_order(self, feature_kwargs:
|
|
172
|
+
def plot_node_branch_order(self, feature_kwargs: dict[str, Any], **kwargs) -> Axes:
|
|
172
173
|
vals = self._get("node_branch_order", **feature_kwargs)
|
|
173
174
|
bin_edges = np.arange(int(np.ceil(vals.max() + 1))) + 0.5
|
|
174
175
|
return self._plot_histogram_impl(vals, bin_edges, **kwargs)
|
|
@@ -234,7 +235,7 @@ class TreeFeatureExtractor(FeatureExtractor):
|
|
|
234
235
|
|
|
235
236
|
def plot_sholl(
|
|
236
237
|
self,
|
|
237
|
-
feature_kwargs:
|
|
238
|
+
feature_kwargs: dict[str, Any], # pylint: disable=unused-argument
|
|
238
239
|
**kwargs,
|
|
239
240
|
) -> Axes:
|
|
240
241
|
_, ax = self._features.sholl.plot(**kwargs)
|
|
@@ -264,7 +265,7 @@ class PopulationFeatureExtractor(FeatureExtractor):
|
|
|
264
265
|
"""Extract features from population."""
|
|
265
266
|
|
|
266
267
|
_population: Population
|
|
267
|
-
_features:
|
|
268
|
+
_features: list[Features]
|
|
268
269
|
|
|
269
270
|
def __init__(self, population: Population) -> None:
|
|
270
271
|
super().__init__()
|
|
@@ -279,7 +280,7 @@ class PopulationFeatureExtractor(FeatureExtractor):
|
|
|
279
280
|
|
|
280
281
|
# Custom Plots
|
|
281
282
|
|
|
282
|
-
def plot_sholl(self, feature_kwargs:
|
|
283
|
+
def plot_sholl(self, feature_kwargs: dict[str, Any], **kwargs) -> Axes:
|
|
283
284
|
vals, rs = self._get_sholl_impl(**feature_kwargs)
|
|
284
285
|
ax = self._lineplot(xs=rs, ys=vals.flatten(), **kwargs)
|
|
285
286
|
ax.set_ylabel("Count of Intersections")
|
|
@@ -295,7 +296,7 @@ class PopulationFeatureExtractor(FeatureExtractor):
|
|
|
295
296
|
|
|
296
297
|
def _get_sholl_impl(
|
|
297
298
|
self, steps: int = 20, **kwargs
|
|
298
|
-
) ->
|
|
299
|
+
) -> tuple[NDArrayf32, NDArrayf32]:
|
|
299
300
|
rmax = max(t.sholl.rmax for t in self._features)
|
|
300
301
|
rs = Sholl.get_rs(rmax=rmax, steps=steps)
|
|
301
302
|
vals = self._get_impl("sholl", steps=rs, **kwargs)
|
|
@@ -333,7 +334,7 @@ class PopulationsFeatureExtractor(FeatureExtractor):
|
|
|
333
334
|
"""Extract feature from population."""
|
|
334
335
|
|
|
335
336
|
_populations: Populations
|
|
336
|
-
_features:
|
|
337
|
+
_features: list[list[Features]]
|
|
337
338
|
|
|
338
339
|
def __init__(self, populations: Populations) -> None:
|
|
339
340
|
super().__init__()
|
|
@@ -348,7 +349,7 @@ class PopulationsFeatureExtractor(FeatureExtractor):
|
|
|
348
349
|
|
|
349
350
|
# Custom Plots
|
|
350
351
|
|
|
351
|
-
def plot_sholl(self, feature_kwargs:
|
|
352
|
+
def plot_sholl(self, feature_kwargs: dict[str, Any], **kwargs) -> Axes:
|
|
352
353
|
vals, rs = self._get_sholl_impl(**feature_kwargs)
|
|
353
354
|
ax = self._lineplot(xs=rs, ys=vals, **kwargs)
|
|
354
355
|
ax.set_ylabel("Count of Intersections")
|
|
@@ -369,7 +370,7 @@ class PopulationsFeatureExtractor(FeatureExtractor):
|
|
|
369
370
|
|
|
370
371
|
def _get_sholl_impl(
|
|
371
372
|
self, steps: int = 20, **kwargs
|
|
372
|
-
) ->
|
|
373
|
+
) -> tuple[NDArrayf32, NDArrayf32]:
|
|
373
374
|
rmaxs = chain.from_iterable((t.sholl.rmax for t in p) for p in self._features)
|
|
374
375
|
rmax = max(rmaxs)
|
|
375
376
|
rs = Sholl.get_rs(rmax=rmax, steps=steps)
|
swcgeom/analysis/features.py
CHANGED
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
from abc import ABC, abstractmethod
|
|
4
4
|
from functools import cached_property
|
|
5
|
-
from typing import
|
|
5
|
+
from typing import TypeVar
|
|
6
6
|
|
|
7
7
|
import numpy as np
|
|
8
8
|
import numpy.typing as npt
|
|
@@ -163,7 +163,7 @@ class PathFeatures:
|
|
|
163
163
|
return np.array([path.tortuosity() for path in self._paths], dtype=np.float32)
|
|
164
164
|
|
|
165
165
|
@cached_property
|
|
166
|
-
def _paths(self) ->
|
|
166
|
+
def _paths(self) -> list[Tree.Path]:
|
|
167
167
|
return self.tree.get_paths()
|
|
168
168
|
|
|
169
169
|
|
|
@@ -201,7 +201,7 @@ class BranchFeatures:
|
|
|
201
201
|
return self.calc_angle(self._branches, eps=eps)
|
|
202
202
|
|
|
203
203
|
@staticmethod
|
|
204
|
-
def calc_angle(branches:
|
|
204
|
+
def calc_angle(branches: list[T], eps: float = 1e-7) -> npt.NDArray[np.float32]:
|
|
205
205
|
"""Calc agnle between branches.
|
|
206
206
|
|
|
207
207
|
Returns
|
|
@@ -219,5 +219,5 @@ class BranchFeatures:
|
|
|
219
219
|
return angle
|
|
220
220
|
|
|
221
221
|
@cached_property
|
|
222
|
-
def _branches(self) ->
|
|
222
|
+
def _branches(self) -> list[Tree.Branch]:
|
|
223
223
|
return self.tree.get_branches()
|
swcgeom/analysis/lmeasure.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
"""L-Measure analysis."""
|
|
2
2
|
|
|
3
3
|
import math
|
|
4
|
-
from typing import Literal
|
|
4
|
+
from typing import Literal
|
|
5
5
|
|
|
6
6
|
import numpy as np
|
|
7
7
|
import numpy.typing as npt
|
|
@@ -274,7 +274,7 @@ class LMeasure:
|
|
|
274
274
|
rall_power, _, _, _ = self._rall_power(bif)
|
|
275
275
|
return rall_power
|
|
276
276
|
|
|
277
|
-
def _rall_power_d(self, bif: Tree.Node) ->
|
|
277
|
+
def _rall_power_d(self, bif: Tree.Node) -> tuple[float, float, float]:
|
|
278
278
|
children = bif.children()
|
|
279
279
|
assert len(children) == 2, "Rall Power is only defined for bifurcations"
|
|
280
280
|
parent = bif.parent()
|
|
@@ -284,7 +284,7 @@ class LMeasure:
|
|
|
284
284
|
da, db = 2 * children[0].r, 2 * children[1].r
|
|
285
285
|
return dp, da, db
|
|
286
286
|
|
|
287
|
-
def _rall_power(self, bif: Tree.Node) ->
|
|
287
|
+
def _rall_power(self, bif: Tree.Node) -> tuple[float, float, float, float]:
|
|
288
288
|
dp, da, db = self._rall_power_d(bif)
|
|
289
289
|
start, stop, step = 0, 5, 5 / 1000
|
|
290
290
|
xs = np.arange(start, stop, step)
|
|
@@ -501,7 +501,7 @@ class LMeasure:
|
|
|
501
501
|
|
|
502
502
|
def _bif_vector_local(
|
|
503
503
|
self, bif: Tree.Node
|
|
504
|
-
) ->
|
|
504
|
+
) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]]:
|
|
505
505
|
children = bif.children()
|
|
506
506
|
assert len(children) == 2, "Only defined for bifurcations"
|
|
507
507
|
|
|
@@ -511,7 +511,7 @@ class LMeasure:
|
|
|
511
511
|
|
|
512
512
|
def _bif_vector_remote(
|
|
513
513
|
self, bif: Tree.Node
|
|
514
|
-
) ->
|
|
514
|
+
) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]]:
|
|
515
515
|
children = bif.children()
|
|
516
516
|
assert len(children) == 2, "Only defined for bifurcations"
|
|
517
517
|
|
swcgeom/analysis/sholl.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
"""Sholl analysis."""
|
|
2
2
|
|
|
3
3
|
import warnings
|
|
4
|
-
from typing import
|
|
4
|
+
from typing import Literal, Optional
|
|
5
5
|
|
|
6
6
|
import numpy as np
|
|
7
7
|
import numpy.typing as npt
|
|
@@ -84,18 +84,18 @@ class Sholl:
|
|
|
84
84
|
|
|
85
85
|
def plot( # pylint: disable=too-many-arguments
|
|
86
86
|
self,
|
|
87
|
-
steps:
|
|
87
|
+
steps: list[float] | int = 20,
|
|
88
88
|
plot_type: str | None = None,
|
|
89
89
|
kind: Literal["bar", "linechart", "circles"] = "circles",
|
|
90
90
|
fig: Figure | None = None,
|
|
91
91
|
ax: Axes | None = None,
|
|
92
92
|
**kwargs,
|
|
93
|
-
) ->
|
|
93
|
+
) -> tuple[Figure, Axes]:
|
|
94
94
|
"""Plot Sholl analysis.
|
|
95
95
|
|
|
96
96
|
Parameters
|
|
97
97
|
----------
|
|
98
|
-
steps : int or
|
|
98
|
+
steps : int or list[float], default to 20
|
|
99
99
|
Steps of raius of circle. If steps is int, then it will be
|
|
100
100
|
evenly divided into n radii.
|
|
101
101
|
kind : "bar" | "linechart" | "circles", default `circles`
|
swcgeom/analysis/trunk.py
CHANGED
|
@@ -2,8 +2,9 @@
|
|
|
2
2
|
|
|
3
3
|
# pylint: disable=invalid-name
|
|
4
4
|
|
|
5
|
+
from collections.abc import Iterable
|
|
5
6
|
from itertools import chain
|
|
6
|
-
from typing import Any,
|
|
7
|
+
from typing import Any, Literal, Optional, cast
|
|
7
8
|
|
|
8
9
|
import numpy as np
|
|
9
10
|
import numpy.typing as npt
|
|
@@ -28,28 +29,28 @@ def draw_trunk(
|
|
|
28
29
|
*,
|
|
29
30
|
fig: Optional[Figure] = None,
|
|
30
31
|
ax: Optional[Axes] = None,
|
|
31
|
-
bound: Bounds |
|
|
32
|
-
point: bool |
|
|
32
|
+
bound: Bounds | tuple[Bounds, dict[str, Any]] | None = "ellipse",
|
|
33
|
+
point: bool | dict[str, Any] = True,
|
|
33
34
|
projection: Projection = "2d",
|
|
34
35
|
cmap: Any = "viridis",
|
|
35
36
|
**kwargs,
|
|
36
|
-
) ->
|
|
37
|
+
) -> tuple[Figure, Axes]:
|
|
37
38
|
"""Draw trunk tree.
|
|
38
39
|
|
|
39
40
|
Parameters
|
|
40
41
|
----------
|
|
41
42
|
t : Tree
|
|
42
|
-
florets : List of (int |
|
|
43
|
+
florets : List of (int | list of int)
|
|
43
44
|
The florets that needs to be removed, each floret can be a
|
|
44
45
|
subtree or multiple subtrees (e.g., dendrites are a bunch of
|
|
45
46
|
subtrees), each number is the id of a tree node.
|
|
46
47
|
fig : ~matplotlib.figure.Figure, optional
|
|
47
48
|
ax : ~matplotlib.axes.Axes, optional
|
|
48
|
-
bound : Bounds | (Bounds,
|
|
49
|
+
bound : Bounds | (Bounds, dict[str, Any]) | None, default 'ellipse'
|
|
49
50
|
Kind of bound, support 'aabb', 'ellipse'. If bound is None, no
|
|
50
51
|
bound will be drawn. If bound is a tuple, the second item will
|
|
51
52
|
used as kwargs and forward to draw function.
|
|
52
|
-
point : bool |
|
|
53
|
+
point : bool | dict[str, Any], default True
|
|
53
54
|
Draw point at the start of a subtree. If point is False, no
|
|
54
55
|
point will be drawn. If point is a dict, this will used a
|
|
55
56
|
kwargs and forward to draw function.
|
|
@@ -57,7 +58,7 @@ def draw_trunk(
|
|
|
57
58
|
Colormap, any value supported by ~matplotlib.cm.Colormap. We
|
|
58
59
|
will use the ratio of the length of the subtree to the total
|
|
59
60
|
length of the tree to determine the color.
|
|
60
|
-
**kwargs :
|
|
61
|
+
**kwargs : dict[str, Any]
|
|
61
62
|
Forward to ~swcgeom.analysis.draw.
|
|
62
63
|
"""
|
|
63
64
|
# pylint: disable=too-many-locals
|
|
@@ -83,14 +84,14 @@ def draw_trunk(
|
|
|
83
84
|
|
|
84
85
|
def split_florets(
|
|
85
86
|
t: Tree, florets: Iterable[int | Iterable[int]]
|
|
86
|
-
) ->
|
|
87
|
+
) -> tuple[Tree, list[list[Tree]]]:
|
|
87
88
|
florets = [[i] if isinstance(i, (int, np.integer)) else i for i in florets]
|
|
88
89
|
subtrees = [[get_subtree(t, ff) for ff in f] for f in florets]
|
|
89
90
|
trunk = to_subtree(t, chain(*florets))
|
|
90
91
|
return trunk, subtrees
|
|
91
92
|
|
|
92
93
|
|
|
93
|
-
def get_length_ratio(t: Tree, tss:
|
|
94
|
+
def get_length_ratio(t: Tree, tss: list[list[Tree]]) -> Any:
|
|
94
95
|
lens = np.array([sum(t.length() for t in ts) for ts in tss])
|
|
95
96
|
return lens / t.length()
|
|
96
97
|
|
|
@@ -101,7 +102,7 @@ def get_length_ratio(t: Tree, tss: List[List[Tree]]) -> Any:
|
|
|
101
102
|
def draw_bound(
|
|
102
103
|
ts: Iterable[Tree],
|
|
103
104
|
ax: Axes,
|
|
104
|
-
bound: Bounds |
|
|
105
|
+
bound: Bounds | tuple[Bounds, dict[str, Any]],
|
|
105
106
|
projection: Projection,
|
|
106
107
|
**kwargs,
|
|
107
108
|
) -> None:
|
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
import os
|
|
4
4
|
import weakref
|
|
5
|
-
from typing import Any,
|
|
5
|
+
from typing import Any, Literal, Optional
|
|
6
6
|
|
|
7
7
|
import numpy as np
|
|
8
8
|
from matplotlib.axes import Axes
|
|
@@ -21,15 +21,15 @@ from swcgeom.utils import (
|
|
|
21
21
|
|
|
22
22
|
__all__ = ["draw"]
|
|
23
23
|
|
|
24
|
-
Positions = Literal["lt", "lb", "rt", "rb"] |
|
|
25
|
-
locations:
|
|
24
|
+
Positions = Literal["lt", "lb", "rt", "rb"] | tuple[float, float]
|
|
25
|
+
locations: dict[Literal["lt", "lb", "rt", "rb"], tuple[float, float]] = {
|
|
26
26
|
"lt": (0.10, 0.90),
|
|
27
27
|
"lb": (0.10, 0.10),
|
|
28
28
|
"rt": (0.90, 0.90),
|
|
29
29
|
"rb": (0.90, 0.10),
|
|
30
30
|
}
|
|
31
31
|
|
|
32
|
-
ax_weak_memo = weakref.WeakKeyDictionary[Axes,
|
|
32
|
+
ax_weak_memo = weakref.WeakKeyDictionary[Axes, dict[str, Any]]({})
|
|
33
33
|
|
|
34
34
|
|
|
35
35
|
def draw(
|
|
@@ -39,7 +39,7 @@ def draw(
|
|
|
39
39
|
ax: Optional[Axes] = None,
|
|
40
40
|
show: bool | None = None,
|
|
41
41
|
camera: CameraOptions = "xy",
|
|
42
|
-
color: Optional[
|
|
42
|
+
color: Optional[dict[int, str] | str] = None,
|
|
43
43
|
label: str | bool = True,
|
|
44
44
|
direction_indicator: Positions | Literal[False] = "rb",
|
|
45
45
|
unit: Optional[str] = None,
|
|
@@ -64,7 +64,7 @@ def draw(
|
|
|
64
64
|
vector, then then threat it as (look-at, up), so camera is
|
|
65
65
|
((0, 0, 0), look-at, up). An easy way is to use the presets
|
|
66
66
|
"xy", "yz" and "zx".
|
|
67
|
-
color :
|
|
67
|
+
color : dict[int, str] | "vaa3d" | str, optional
|
|
68
68
|
Color map. If is dict, segments will be colored by the type of
|
|
69
69
|
parent node.If is string, the value will be use for any type.
|
|
70
70
|
label : str | bool, default True
|
|
@@ -120,14 +120,14 @@ def draw(
|
|
|
120
120
|
return fig, ax
|
|
121
121
|
|
|
122
122
|
|
|
123
|
-
def get_ax_swc(ax: Axes) ->
|
|
123
|
+
def get_ax_swc(ax: Axes) -> list[SWCLike]:
|
|
124
124
|
ax_weak_memo.setdefault(ax, {})
|
|
125
125
|
return ax_weak_memo[ax]["swc"]
|
|
126
126
|
|
|
127
127
|
|
|
128
128
|
def get_ax_color(
|
|
129
|
-
ax: Axes, swc: SWCLike, color: Optional[
|
|
130
|
-
) -> str |
|
|
129
|
+
ax: Axes, swc: SWCLike, color: Optional[dict[int, str] | str] = None
|
|
130
|
+
) -> str | list[str]:
|
|
131
131
|
if color == "vaa3d":
|
|
132
132
|
color = palette.vaa3d
|
|
133
133
|
elif isinstance(color, str):
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
"""Painter utils.
|
|
2
|
+
|
|
3
|
+
Notes
|
|
4
|
+
-----
|
|
5
|
+
This is a experimental function, it may be changed in the future.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from typing import Optional
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
from matplotlib.axes import Axes
|
|
12
|
+
from matplotlib.figure import Figure
|
|
13
|
+
from mpl_toolkits.mplot3d import Axes3D
|
|
14
|
+
|
|
15
|
+
from swcgeom.analysis.visualization import (
|
|
16
|
+
_set_ax_memo,
|
|
17
|
+
get_ax_color,
|
|
18
|
+
get_ax_swc,
|
|
19
|
+
set_ax_legend,
|
|
20
|
+
)
|
|
21
|
+
from swcgeom.core import SWCLike, Tree
|
|
22
|
+
from swcgeom.utils.plotter_3d import draw_lines_3d
|
|
23
|
+
|
|
24
|
+
__all__ = ["draw3d"]
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
# TODO: support Camera
|
|
28
|
+
def draw3d(
|
|
29
|
+
swc: SWCLike | str,
|
|
30
|
+
*,
|
|
31
|
+
ax: Axes,
|
|
32
|
+
show: bool | None = None,
|
|
33
|
+
color: Optional[dict[int, str] | str] = None,
|
|
34
|
+
label: str | bool = True,
|
|
35
|
+
**kwargs,
|
|
36
|
+
) -> tuple[Figure, Axes]:
|
|
37
|
+
r"""Draw neuron tree.
|
|
38
|
+
|
|
39
|
+
Parameters
|
|
40
|
+
----------
|
|
41
|
+
swc : SWCLike | str
|
|
42
|
+
If it is str, then it is treated as the path of swc file.
|
|
43
|
+
fig : ~matplotlib.axes.Figure, optional
|
|
44
|
+
ax : ~matplotlib.axes.Axes, optional
|
|
45
|
+
show : bool | None, default `None`
|
|
46
|
+
Wheather to call `plt.show()`. If not specified, it will depend
|
|
47
|
+
on if ax is passed in, it will not be called, otherwise it will
|
|
48
|
+
be called by default.
|
|
49
|
+
color : dict[int, str] | "vaa3d" | str, optional
|
|
50
|
+
Color map. If is dict, segments will be colored by the type of
|
|
51
|
+
parent node.If is string, the value will be use for any type.
|
|
52
|
+
label : str | bool, default True
|
|
53
|
+
Label of legend, disable if False.
|
|
54
|
+
**kwargs : dict[str, Unknown]
|
|
55
|
+
Forwarded to `~mpl_toolkits.mplot3d.art3d.Line3DCollection`.
|
|
56
|
+
"""
|
|
57
|
+
|
|
58
|
+
assert isinstance(ax, Axes3D), "only support 3D axes."
|
|
59
|
+
|
|
60
|
+
swc = Tree.from_swc(swc) if isinstance(swc, str) else swc
|
|
61
|
+
|
|
62
|
+
show = (show is True) or (show is None and ax is None)
|
|
63
|
+
my_color = get_ax_color(ax, swc, color) # type: ignore
|
|
64
|
+
|
|
65
|
+
xyz = swc.xyz()
|
|
66
|
+
starts, ends = swc.id()[1:], swc.pid()[1:]
|
|
67
|
+
lines = np.stack([xyz[starts], xyz[ends]], axis=1)
|
|
68
|
+
collection = draw_lines_3d(ax, lines, color=my_color, **kwargs)
|
|
69
|
+
|
|
70
|
+
min_vals = lines.reshape(-1, 3).min(axis=0)
|
|
71
|
+
max_vals = lines.reshape(-1, 3).max(axis=0)
|
|
72
|
+
ax.set_xlim(min_vals[0], max_vals[0])
|
|
73
|
+
ax.set_ylim(min_vals[1], max_vals[1])
|
|
74
|
+
ax.set_zlim(min_vals[2], max_vals[2])
|
|
75
|
+
|
|
76
|
+
_set_ax_memo(ax, swc, label=label, handle=collection)
|
|
77
|
+
|
|
78
|
+
if len(get_ax_swc(ax)) == 1:
|
|
79
|
+
# ax.set_aspect(1)
|
|
80
|
+
ax.spines[["top", "right"]].set_visible(False)
|
|
81
|
+
else:
|
|
82
|
+
set_ax_legend(ax, loc="upper right") # enable legend
|
|
83
|
+
|
|
84
|
+
fig = ax.figure
|
|
85
|
+
return fig, ax # type: ignore
|
swcgeom/analysis/volume.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
"""Analysis of volume of a SWC tree."""
|
|
2
2
|
|
|
3
|
-
from typing import
|
|
3
|
+
from typing import Literal
|
|
4
4
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
from sdflit import ColoredMaterial, ObjectsScene, SDFObject, UniformSampler
|
|
@@ -11,7 +11,7 @@ from swcgeom.utils import VolFrustumCone, VolSphere
|
|
|
11
11
|
__all__ = ["get_volume"]
|
|
12
12
|
|
|
13
13
|
ACCURACY_LEVEL = Literal["low", "middle", "high"]
|
|
14
|
-
ACCURACY_LEVELS:
|
|
14
|
+
ACCURACY_LEVELS: dict[ACCURACY_LEVEL, int] = {"low": 3, "middle": 5, "high": 8}
|
|
15
15
|
|
|
16
16
|
|
|
17
17
|
def get_volume(
|
|
@@ -93,7 +93,7 @@ def _get_volume_frustum_cone(tree: Tree, *, accuracy: int) -> float:
|
|
|
93
93
|
|
|
94
94
|
volume = 0.0
|
|
95
95
|
|
|
96
|
-
def leave(n: Tree.Node, children:
|
|
96
|
+
def leave(n: Tree.Node, children: list[VolSphere]) -> VolSphere:
|
|
97
97
|
sphere = VolSphere(n.xyz(), n.r)
|
|
98
98
|
cones = [VolFrustumCone(n.xyz(), n.r, c.center, c.radius) for c in children]
|
|
99
99
|
|
|
@@ -129,7 +129,7 @@ def _get_volume_frustum_cone_mc_only(tree: Tree) -> float:
|
|
|
129
129
|
scene = ObjectsScene()
|
|
130
130
|
scene.set_background((0, 0, 0))
|
|
131
131
|
|
|
132
|
-
def leave(n: Tree.Node, children:
|
|
132
|
+
def leave(n: Tree.Node, children: list[VolSphere]) -> VolSphere:
|
|
133
133
|
sphere = VolSphere(n.xyz(), n.r)
|
|
134
134
|
scene.add_object(SDFObject(sphere.sdf, material).into())
|
|
135
135
|
|
swcgeom/core/branch.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
"""Branch is a set of node points."""
|
|
2
2
|
|
|
3
|
-
from
|
|
3
|
+
from collections.abc import Iterable
|
|
4
|
+
from typing import Generic
|
|
4
5
|
|
|
5
6
|
import numpy as np
|
|
6
7
|
import numpy.typing as npt
|
|
@@ -92,7 +93,7 @@ class Branch(Path, Generic[SWCTypeVar]):
|
|
|
92
93
|
@classmethod
|
|
93
94
|
def from_xyzr_batch(
|
|
94
95
|
cls, xyzr_batch: npt.NDArray[np.float32]
|
|
95
|
-
) ->
|
|
96
|
+
) -> list["Branch[DictSWC]"]:
|
|
96
97
|
r"""Create list of branch form ~numpy.ndarray.
|
|
97
98
|
|
|
98
99
|
Parameters
|
|
@@ -112,7 +113,7 @@ class Branch(Path, Generic[SWCTypeVar]):
|
|
|
112
113
|
)
|
|
113
114
|
xyzr_batch = np.concatenate([xyzr_batch, ones], axis=2)
|
|
114
115
|
|
|
115
|
-
branches:
|
|
116
|
+
branches: list[Branch[DictSWC]] = []
|
|
116
117
|
for xyzr in xyzr_batch:
|
|
117
118
|
n_nodes = xyzr.shape[0]
|
|
118
119
|
idx = np.arange(0, n_nodes, step=1, dtype=np.int32)
|
swcgeom/core/branch_tree.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
"""Branch tree is a simplified neuron tree."""
|
|
2
2
|
|
|
3
3
|
import itertools
|
|
4
|
-
from typing import Dict, List
|
|
5
4
|
|
|
6
5
|
import numpy as np
|
|
7
6
|
import pandas as pd
|
|
@@ -19,13 +18,13 @@ class BranchTree(Tree):
|
|
|
19
18
|
A branch tree that contains only soma, branch, and tip nodes.
|
|
20
19
|
"""
|
|
21
20
|
|
|
22
|
-
branches:
|
|
21
|
+
branches: dict[int, list[Branch]]
|
|
23
22
|
|
|
24
|
-
def get_origin_branches(self) ->
|
|
23
|
+
def get_origin_branches(self) -> list[Branch]:
|
|
25
24
|
"""Get branches of original tree."""
|
|
26
25
|
return list(itertools.chain(*self.branches.values()))
|
|
27
26
|
|
|
28
|
-
def get_origin_node_branches(self, idx: int) ->
|
|
27
|
+
def get_origin_node_branches(self, idx: int) -> list[Branch]:
|
|
29
28
|
"""Get branches of node of original tree."""
|
|
30
29
|
return self.branches[idx]
|
|
31
30
|
|
swcgeom/core/compartment.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
"""The segment is a branch with two nodes."""
|
|
2
2
|
|
|
3
|
-
from
|
|
3
|
+
from collections.abc import Iterable
|
|
4
|
+
from typing import Generic, TypeVar
|
|
4
5
|
|
|
5
6
|
import numpy as np
|
|
6
7
|
import numpy.typing as npt
|
|
@@ -43,7 +44,7 @@ class Compartment(Path, Generic[SWCTypeVar]):
|
|
|
43
44
|
T = TypeVar("T", bound=Compartment)
|
|
44
45
|
|
|
45
46
|
|
|
46
|
-
class Compartments(
|
|
47
|
+
class Compartments(list[T]):
|
|
47
48
|
r"""Comparments contains a set of comparment."""
|
|
48
49
|
|
|
49
50
|
names: SWCNames
|
swcgeom/core/node.py
CHANGED
swcgeom/core/path.py
CHANGED
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
"""Nueron path."""
|
|
2
2
|
|
|
3
3
|
import warnings
|
|
4
|
-
from
|
|
4
|
+
from collections.abc import Iterable, Iterator
|
|
5
|
+
from typing import Generic, overload
|
|
5
6
|
|
|
6
7
|
import numpy as np
|
|
7
8
|
import numpy.typing as npt
|
|
@@ -44,7 +45,7 @@ class Path(SWCLike, Generic[SWCTypeVar]):
|
|
|
44
45
|
@overload
|
|
45
46
|
def __getitem__(self, key: int) -> Node: ...
|
|
46
47
|
@overload
|
|
47
|
-
def __getitem__(self, key: slice) ->
|
|
48
|
+
def __getitem__(self, key: slice) -> list[Node]: ...
|
|
48
49
|
@overload
|
|
49
50
|
def __getitem__(self, key: str) -> npt.NDArray: ...
|
|
50
51
|
# fmt:on
|