swcgeom 0.4.0-py3-none-any.whl → 0.5.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of swcgeom might be problematic. Click here for more details.
- swcgeom/_version.py +2 -2
- swcgeom/analysis/feature_extractor.py +179 -170
- swcgeom/analysis/node_features.py +4 -102
- swcgeom/analysis/sholl.py +46 -22
- swcgeom/analysis/visualization.py +102 -65
- swcgeom/core/branch_tree.py +5 -3
- swcgeom/core/node.py +1 -1
- swcgeom/core/population.py +128 -26
- swcgeom/core/swc.py +14 -10
- swcgeom/core/swc_utils.py +41 -20
- swcgeom/core/tree.py +13 -10
- swcgeom/core/tree_utils.py +2 -2
- swcgeom/utils/numpy.py +2 -27
- swcgeom/utils/renderer.py +22 -24
- {swcgeom-0.4.0.dist-info → swcgeom-0.5.0.dist-info}/METADATA +1 -1
- {swcgeom-0.4.0.dist-info → swcgeom-0.5.0.dist-info}/RECORD +19 -19
- {swcgeom-0.4.0.dist-info → swcgeom-0.5.0.dist-info}/LICENSE +0 -0
- {swcgeom-0.4.0.dist-info → swcgeom-0.5.0.dist-info}/WHEEL +0 -0
- {swcgeom-0.4.0.dist-info → swcgeom-0.5.0.dist-info}/top_level.txt +0 -0
swcgeom/analysis/sholl.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
"""Sholl analysis."""
|
|
2
2
|
|
|
3
3
|
import math
|
|
4
|
+
import warnings
|
|
4
5
|
from typing import Literal, Tuple
|
|
5
6
|
|
|
6
7
|
import numpy as np
|
|
@@ -21,52 +22,75 @@ class Sholl:
|
|
|
21
22
|
cortices of the cat J. Anat., 87 (1953), pp. 387-406
|
|
22
23
|
"""
|
|
23
24
|
|
|
24
|
-
count: npt.NDArray[np.
|
|
25
|
+
count: npt.NDArray[np.int64]
|
|
25
26
|
step: float
|
|
27
|
+
steps: npt.NDArray[np.float32]
|
|
26
28
|
|
|
27
29
|
def __init__(self, tree: Tree, step: float = 1) -> None:
|
|
28
30
|
xyz = tree.get_segments().xyz() - tree.soma().xyz() # shift
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
self.count = np.array(count, dtype=np.int32)
|
|
31
|
+
r = np.linalg.norm(xyz, axis=2)
|
|
32
|
+
steps = np.arange(step, int(np.ceil(r.max())), step)
|
|
33
|
+
intersections = [np.logical_and(r[:, 0] <= i, r[:, 1] > i) for i in steps]
|
|
34
|
+
count = np.count_nonzero(intersections, axis=1)
|
|
35
|
+
|
|
36
|
+
self.count = count
|
|
37
37
|
self.step = step
|
|
38
|
+
self.steps = steps
|
|
38
39
|
|
|
39
40
|
def __getitem__(self, idx: int) -> int:
|
|
40
41
|
return self.count[idx] if 0 <= idx < len(self.count) else 0
|
|
41
42
|
|
|
42
|
-
def
|
|
43
|
+
def get(self) -> npt.NDArray[np.int64]:
|
|
43
44
|
return self.count.copy()
|
|
44
45
|
|
|
45
|
-
def
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
46
|
+
def get_count(self) -> npt.NDArray[np.int32]:
|
|
47
|
+
warnings.warn(
|
|
48
|
+
"`Sholl.get_count` has been renamed to `get` since v0.5.0, "
|
|
49
|
+
"and will be removed in next version",
|
|
50
|
+
DeprecationWarning,
|
|
51
|
+
)
|
|
52
|
+
return self.get().astype(np.int32)
|
|
49
53
|
|
|
50
54
|
def avg(self) -> float:
|
|
51
|
-
|
|
55
|
+
warnings.warn(
|
|
56
|
+
"`Sholl.avg` will be removed in next version",
|
|
57
|
+
DeprecationWarning,
|
|
58
|
+
)
|
|
59
|
+
return self.count.mean()
|
|
52
60
|
|
|
53
61
|
def std(self) -> float:
|
|
54
|
-
|
|
62
|
+
warnings.warn(
|
|
63
|
+
"`Sholl.std` will be removed in next version",
|
|
64
|
+
DeprecationWarning,
|
|
65
|
+
)
|
|
66
|
+
return self.count.std()
|
|
55
67
|
|
|
56
68
|
def sum(self) -> int:
|
|
57
|
-
|
|
69
|
+
warnings.warn(
|
|
70
|
+
"`Sholl.sum` will be removed in next version",
|
|
71
|
+
DeprecationWarning,
|
|
72
|
+
)
|
|
73
|
+
return self.count.sum()
|
|
58
74
|
|
|
59
75
|
def plot(
|
|
60
76
|
self,
|
|
61
|
-
plot_type:
|
|
77
|
+
plot_type: str | None = None,
|
|
78
|
+
kind: Literal["bar", "linechart", "circles"] = "linechart",
|
|
62
79
|
fig: Figure | None = None,
|
|
63
80
|
ax: Axes | None = None,
|
|
64
81
|
**kwargs,
|
|
65
82
|
) -> Tuple[Figure, Axes]:
|
|
83
|
+
if plot_type is not None:
|
|
84
|
+
warnings.warn(
|
|
85
|
+
"`plot_type` has been renamed to `kind` since v0.5.0, "
|
|
86
|
+
"and will be removed in next version",
|
|
87
|
+
DeprecationWarning,
|
|
88
|
+
)
|
|
89
|
+
kind = plot_type # type: ignore
|
|
90
|
+
|
|
91
|
+
x, y = self.steps, self.count
|
|
66
92
|
fig, ax = get_fig_ax(fig, ax)
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
match plot_type:
|
|
93
|
+
match kind:
|
|
70
94
|
case "bar":
|
|
71
95
|
kwargs.setdefault("width", self.step)
|
|
72
96
|
ax.bar(x, y, **kwargs)
|
|
@@ -76,6 +100,6 @@ class Sholl:
|
|
|
76
100
|
kwargs.setdefault("y_min", 0)
|
|
77
101
|
draw_circles(fig, ax, x, y, **kwargs)
|
|
78
102
|
case _:
|
|
79
|
-
raise ValueError(f"unsupported plot
|
|
103
|
+
raise ValueError(f"unsupported plot kind: {kind}")
|
|
80
104
|
|
|
81
105
|
return fig, ax
|
|
swcgeom/analysis/visualization.py
CHANGED
|
@@ -2,14 +2,22 @@
|
|
|
2
2
|
|
|
3
3
|
import os
|
|
4
4
|
import weakref
|
|
5
|
-
from typing import Any, Dict, Literal, Tuple
|
|
5
|
+
from typing import Any, Dict, List, Literal, Optional, Tuple
|
|
6
6
|
|
|
7
7
|
import numpy as np
|
|
8
8
|
from matplotlib.axes import Axes
|
|
9
9
|
from matplotlib.figure import Figure
|
|
10
|
+
from matplotlib.legend import Legend
|
|
10
11
|
|
|
11
12
|
from ..core import SWCLike, Tree
|
|
12
|
-
from ..utils import
|
|
13
|
+
from ..utils import (
|
|
14
|
+
Camera,
|
|
15
|
+
Vec3f,
|
|
16
|
+
draw_direction_indicator,
|
|
17
|
+
draw_lines,
|
|
18
|
+
get_fig_ax,
|
|
19
|
+
palette,
|
|
20
|
+
)
|
|
13
21
|
|
|
14
22
|
__all__ = ["draw"]
|
|
15
23
|
|
|
@@ -22,24 +30,31 @@ CameraPresets: Dict[CameraPreset, Camera] = {
|
|
|
22
30
|
"zy": Camera((0, 0, 0), (-1, 0, 0), (0, 0, -1)),
|
|
23
31
|
"xz": Camera((0, 0, 0), (0, -1, 0), (-1, 0, 0)),
|
|
24
32
|
}
|
|
25
|
-
CameraOptions =
|
|
26
|
-
|
|
27
|
-
|
|
33
|
+
CameraOptions = Vec3f | Tuple[Vec3f, Vec3f] | Tuple[Vec3f, Vec3f, Vec3f]
|
|
34
|
+
Positions = Literal["lt", "lb", "rt", "rb"] | Tuple[float, float]
|
|
35
|
+
positions = {
|
|
36
|
+
"lt": (0.10, 0.90),
|
|
37
|
+
"lb": (0.10, 0.10),
|
|
38
|
+
"rt": (0.90, 0.90),
|
|
39
|
+
"rb": (0.90, 0.10),
|
|
40
|
+
}
|
|
28
41
|
|
|
29
|
-
|
|
42
|
+
ax_weak_memo = weakref.WeakKeyDictionary[Axes, Dict[str, Any]]({})
|
|
30
43
|
|
|
31
44
|
|
|
32
45
|
def draw(
|
|
33
46
|
swc: SWCLike | str,
|
|
34
47
|
*,
|
|
35
|
-
fig: Figure
|
|
36
|
-
ax: Axes
|
|
48
|
+
fig: Optional[Figure] = None,
|
|
49
|
+
ax: Optional[Axes] = None,
|
|
37
50
|
camera: CameraOptions | CameraPreset = "xy",
|
|
38
|
-
color: Dict[int, str] | str
|
|
39
|
-
label: str |
|
|
51
|
+
color: Optional[Dict[int, str] | str] = None,
|
|
52
|
+
label: str | bool = True,
|
|
53
|
+
direction_indicator: Positions | Literal[False] = "rb",
|
|
54
|
+
unit: Optional[str] = None,
|
|
40
55
|
**kwargs,
|
|
41
56
|
) -> tuple[Figure, Axes]:
|
|
42
|
-
"""Draw neuron tree.
|
|
57
|
+
r"""Draw neuron tree.
|
|
43
58
|
|
|
44
59
|
Parameters
|
|
45
60
|
----------
|
|
@@ -59,6 +74,11 @@ def draw(
|
|
|
59
74
|
parent node.If is string, the value will be use for any type.
|
|
60
75
|
label : str | bool, default True
|
|
61
76
|
Label of legend, disable if False.
|
|
77
|
+
direction_indicator : str or (float, float), default 'rb'
|
|
78
|
+
Draw a xyz direction indicator, can be place on 'lt', 'lb',
|
|
79
|
+
'rt', 'rb', or custom position.
|
|
80
|
+
unit : optional[str]
|
|
81
|
+
Add unit text, e.g.: r"$\mu m$".
|
|
62
82
|
**kwargs : dict[str, Unknown]
|
|
63
83
|
Forwarded to `~matplotlib.collections.LineCollection`.
|
|
64
84
|
|
|
@@ -68,70 +88,56 @@ def draw(
|
|
|
68
88
|
ax : ~matplotlib.axes.Axes
|
|
69
89
|
"""
|
|
70
90
|
|
|
71
|
-
|
|
72
|
-
ax_weak_dict.setdefault(ax, {})
|
|
73
|
-
ax_weak_dict[ax].setdefault("swc", [])
|
|
74
|
-
|
|
75
|
-
if isinstance(swc, str):
|
|
76
|
-
swc = Tree.from_swc(swc)
|
|
77
|
-
ax_weak_dict[ax]["swc"].append(swc)
|
|
91
|
+
swc = Tree.from_swc(swc) if isinstance(swc, str) else swc
|
|
78
92
|
|
|
79
|
-
|
|
80
|
-
|
|
93
|
+
fig, ax = get_fig_ax(fig, ax)
|
|
94
|
+
my_camera = _get_camera(camera)
|
|
95
|
+
my_color = get_ax_color(ax, swc, color)
|
|
81
96
|
|
|
82
97
|
xyz = swc.xyz()
|
|
83
98
|
starts, ends = swc.id()[1:], swc.pid()[1:]
|
|
84
99
|
lines = np.stack([xyz[starts], xyz[ends]], axis=1)
|
|
85
|
-
|
|
86
100
|
collection = draw_lines(lines, ax=ax, camera=my_camera, color=my_color, **kwargs)
|
|
87
101
|
|
|
88
|
-
# legend
|
|
89
|
-
ax_weak_dict[ax].setdefault("handles", [])
|
|
90
|
-
ax_weak_dict[ax]["handles"].append(collection)
|
|
91
|
-
set_lable(ax, swc, label)
|
|
92
|
-
|
|
93
102
|
ax.autoscale()
|
|
103
|
+
_set_ax_memo(ax, swc, label=label, handle=collection) # legend
|
|
94
104
|
|
|
95
|
-
if len(
|
|
105
|
+
if len(get_ax_swc(ax)) == 1:
|
|
96
106
|
ax.set_aspect(1)
|
|
97
107
|
ax.spines[["top", "right"]].set_visible(False)
|
|
98
|
-
|
|
99
|
-
|
|
108
|
+
if direction_indicator is not False:
|
|
109
|
+
p = (
|
|
110
|
+
positions[direction_indicator]
|
|
111
|
+
if isinstance(direction_indicator, str)
|
|
112
|
+
else direction_indicator
|
|
113
|
+
)
|
|
114
|
+
draw_direction_indicator(ax=ax, camera=my_camera, position=p)
|
|
115
|
+
if unit is not None:
|
|
116
|
+
ax.text(0.05, 0.95, unit, transform=ax.transAxes)
|
|
100
117
|
else:
|
|
101
|
-
# legend
|
|
102
|
-
handles = ax_weak_dict[ax].get("handles", [])
|
|
103
|
-
labels = ax_weak_dict[ax].get("labels", [])
|
|
104
|
-
ax.legend(handles, labels, loc="upper right")
|
|
118
|
+
set_ax_legend(ax, loc="upper right") # enable legend
|
|
105
119
|
|
|
106
120
|
return fig, ax
|
|
107
121
|
|
|
108
122
|
|
|
109
|
-
def
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
if len(camera) == 1:
|
|
114
|
-
return Camera((0, 0, 0), camera, (0, 1, 0))
|
|
115
|
-
|
|
116
|
-
if len(camera) == 2:
|
|
117
|
-
return Camera((0, 0, 0), camera[0], camera[1])
|
|
118
|
-
|
|
119
|
-
return Camera(*camera)
|
|
123
|
+
def get_ax_swc(ax: Axes) -> List[SWCLike]:
|
|
124
|
+
ax_weak_memo.setdefault(ax, {})
|
|
125
|
+
return ax_weak_memo[ax]["swc"]
|
|
120
126
|
|
|
121
127
|
|
|
122
|
-
def
|
|
123
|
-
ax: Axes, swc: SWCLike, color: Dict[int, str] | str
|
|
124
|
-
) -> str |
|
|
128
|
+
def get_ax_color(
|
|
129
|
+
ax: Axes, swc: SWCLike, color: Optional[Dict[int, str] | str] = None
|
|
130
|
+
) -> str | List[str]:
|
|
125
131
|
if color == "vaa3d":
|
|
126
132
|
color = palette.vaa3d
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
return color
|
|
133
|
+
elif isinstance(color, str):
|
|
134
|
+
return color # user specified
|
|
130
135
|
|
|
131
136
|
# choose default
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
137
|
+
ax_weak_memo.setdefault(ax, {})
|
|
138
|
+
ax_weak_memo[ax].setdefault("color", -1)
|
|
139
|
+
ax_weak_memo[ax]["color"] += 1
|
|
140
|
+
c = palette.default[ax_weak_memo[ax]["color"] % len(palette.default)]
|
|
135
141
|
|
|
136
142
|
if isinstance(color, dict):
|
|
137
143
|
types = swc.type()[:-1] # colored by type of parent node
|
|
@@ -140,17 +146,48 @@ def get_color(
|
|
|
140
146
|
return c
|
|
141
147
|
|
|
142
148
|
|
|
143
|
-
def
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
149
|
+
def set_ax_legend(ax: Axes, *args, **kwargs) -> Legend | None:
|
|
150
|
+
labels = ax_weak_memo[ax].get("labels", [])
|
|
151
|
+
handles = ax_weak_memo[ax].get("handles", [])
|
|
152
|
+
|
|
153
|
+
# filter `label = False`
|
|
154
|
+
handles = [a for i, a in enumerate(handles) if labels[i] != False]
|
|
155
|
+
labels = [a for i, a in enumerate(labels) if labels[i] != False]
|
|
156
|
+
|
|
157
|
+
if len(labels) > 0:
|
|
158
|
+
return ax.legend(handles, labels, *args, **kwargs)
|
|
159
|
+
else:
|
|
160
|
+
return None
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def _set_ax_memo(
|
|
164
|
+
ax: Axes,
|
|
165
|
+
swc: SWCLike,
|
|
166
|
+
label: Optional[str | bool] = None,
|
|
167
|
+
handle: Optional[Any] = None,
|
|
168
|
+
):
|
|
169
|
+
ax_weak_memo.setdefault(ax, {})
|
|
170
|
+
ax_weak_memo[ax].setdefault("swc", [])
|
|
171
|
+
ax_weak_memo[ax]["swc"].append(swc)
|
|
172
|
+
|
|
173
|
+
if label is not None:
|
|
174
|
+
label = os.path.basename(swc.source) if label is True else label
|
|
175
|
+
ax_weak_memo[ax].setdefault("labels", [])
|
|
176
|
+
ax_weak_memo[ax]["labels"].append(label)
|
|
177
|
+
|
|
178
|
+
if handle is not None:
|
|
179
|
+
ax_weak_memo[ax].setdefault("handles", [])
|
|
180
|
+
ax_weak_memo[ax]["handles"].append(handle)
|
|
181
|
+
|
|
148
182
|
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
183
|
+
def _get_camera(camera: CameraOptions | CameraPreset) -> Camera:
|
|
184
|
+
if isinstance(camera, str):
|
|
185
|
+
return CameraPresets[camera]
|
|
186
|
+
|
|
187
|
+
if len(camera) == 1:
|
|
188
|
+
return Camera((0, 0, 0), camera, (0, 1, 0))
|
|
155
189
|
|
|
156
|
-
|
|
190
|
+
if len(camera) == 2:
|
|
191
|
+
return Camera((0, 0, 0), camera[0], camera[1])
|
|
192
|
+
|
|
193
|
+
return Camera(*camera)
|
swcgeom/core/branch_tree.py
CHANGED
|
@@ -3,6 +3,8 @@
|
|
|
3
3
|
import itertools
|
|
4
4
|
from typing import Dict, List
|
|
5
5
|
|
|
6
|
+
import pandas as pd
|
|
7
|
+
|
|
6
8
|
from .branch import Branch
|
|
7
9
|
from .tree import Tree
|
|
8
10
|
from .tree_utils import to_sub_tree
|
|
@@ -50,7 +52,7 @@ class BranchTree(Tree):
|
|
|
50
52
|
|
|
51
53
|
return branch_tree
|
|
52
54
|
|
|
53
|
-
@
|
|
54
|
-
def
|
|
55
|
-
tree = super().
|
|
55
|
+
@classmethod
|
|
56
|
+
def from_data_frame(cls, df: pd.DataFrame, *args, **kwargs) -> "BranchTree":
|
|
57
|
+
tree = super().from_data_frame(df, *args, **kwargs)
|
|
56
58
|
return BranchTree.from_tree(tree)
|
swcgeom/core/node.py
CHANGED
|
@@ -101,7 +101,7 @@ class Node(Generic[SWCTypeVar]):
|
|
|
101
101
|
return self.attach.id()[self.attach.pid() == self.id]
|
|
102
102
|
|
|
103
103
|
def is_bifurcation(self) -> bool:
|
|
104
|
-
return
|
|
104
|
+
return np.count_nonzero(self.attach.pid() == self.id) > 1
|
|
105
105
|
|
|
106
106
|
def is_tip(self) -> bool:
|
|
107
107
|
return self.id not in self.attach.pid()
|
swcgeom/core/population.py
CHANGED
|
@@ -1,26 +1,35 @@
|
|
|
1
1
|
"""Neuron population is a set of tree."""
|
|
2
2
|
|
|
3
3
|
import os
|
|
4
|
-
|
|
4
|
+
import warnings
|
|
5
|
+
from functools import reduce
|
|
6
|
+
from itertools import chain
|
|
7
|
+
from typing import Any, Dict, Iterator, List, Optional, cast, overload
|
|
8
|
+
|
|
9
|
+
from typing_extensions import Self
|
|
5
10
|
|
|
6
11
|
from .swc import eswc_cols
|
|
7
12
|
from .tree import Tree
|
|
8
13
|
|
|
9
|
-
__all__ = ["Population"]
|
|
14
|
+
__all__ = ["Population", "Populations"]
|
|
10
15
|
|
|
11
16
|
|
|
12
17
|
class Population:
|
|
13
18
|
"""Neuron population."""
|
|
14
19
|
|
|
15
|
-
|
|
20
|
+
root: str
|
|
16
21
|
swcs: List[str]
|
|
17
22
|
trees: List[Tree | None]
|
|
18
|
-
|
|
23
|
+
kwargs: Dict[str, Any]
|
|
19
24
|
|
|
20
|
-
def __init__(
|
|
25
|
+
def __init__(
|
|
26
|
+
self, swcs: List[str], lazy_loading: bool = True, root: str = "", **kwargs
|
|
27
|
+
) -> None:
|
|
21
28
|
super().__init__()
|
|
29
|
+
self.root = root
|
|
22
30
|
self.swcs = swcs
|
|
23
31
|
self.trees = [None for _ in swcs]
|
|
32
|
+
self.kwargs = kwargs
|
|
24
33
|
if not lazy_loading:
|
|
25
34
|
self.load(slice(len(swcs)))
|
|
26
35
|
|
|
@@ -47,7 +56,7 @@ class Population:
|
|
|
47
56
|
return (self[i] for i in range(self.__len__()))
|
|
48
57
|
|
|
49
58
|
def __repr__(self) -> str:
|
|
50
|
-
return f"Neuron population in '{self.
|
|
59
|
+
return f"Neuron population in '{self.root}'"
|
|
51
60
|
|
|
52
61
|
# fmt:off
|
|
53
62
|
@overload
|
|
@@ -65,28 +74,121 @@ class Population:
|
|
|
65
74
|
|
|
66
75
|
for i in idx:
|
|
67
76
|
if self.trees[i] is None:
|
|
68
|
-
self.trees[i] = Tree.from_swc(self.swcs[i], **self.
|
|
69
|
-
|
|
70
|
-
@
|
|
71
|
-
def from_swc(
|
|
72
|
-
swcs =
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
return
|
|
77
|
+
self.trees[i] = Tree.from_swc(self.swcs[i], **self.kwargs)
|
|
78
|
+
|
|
79
|
+
@classmethod
|
|
80
|
+
def from_swc(cls, root: str, ext: str = ".swc", **kwargs) -> Self:
|
|
81
|
+
swcs = cls.find_swcs(root, ext)
|
|
82
|
+
if len(swcs) == 0:
|
|
83
|
+
warnings.warn(f"no trees in population from '{root}'")
|
|
84
|
+
return Population(swcs, root=root, **kwargs)
|
|
85
|
+
|
|
86
|
+
@classmethod
|
|
87
|
+
def from_eswc(
|
|
88
|
+
cls, root: str, ext: str = ".eswc", extra_cols: List[str] = [], **kwargs
|
|
89
|
+
) -> Self:
|
|
90
|
+
extra_cols.extend(k for k, t in eswc_cols)
|
|
91
|
+
return cls.from_swc(root, ext, extra_cols=extra_cols, **kwargs)
|
|
83
92
|
|
|
84
93
|
@staticmethod
|
|
85
|
-
def find_swcs(
|
|
94
|
+
def find_swcs(root: str, ext: str = ".swc", relpath: bool = False) -> List[str]:
|
|
86
95
|
"""Find all swc files."""
|
|
87
|
-
swcs =
|
|
88
|
-
for
|
|
89
|
-
|
|
90
|
-
|
|
96
|
+
swcs: List[str] = []
|
|
97
|
+
for r, _, files in os.walk(root):
|
|
98
|
+
rr = os.path.relpath(r, root) if relpath else r
|
|
99
|
+
fs = filter(lambda f: os.path.splitext(f)[-1] == ext, files)
|
|
100
|
+
swcs.extend(os.path.join(rr, f) for f in fs)
|
|
91
101
|
|
|
92
102
|
return swcs
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
class Populations:
|
|
106
|
+
len: int
|
|
107
|
+
populations: List[Population]
|
|
108
|
+
labels: List[str]
|
|
109
|
+
|
|
110
|
+
def __init__(
|
|
111
|
+
self, populations: List[Population], labels: Optional[List[str]] = None
|
|
112
|
+
) -> None:
|
|
113
|
+
self.len = min(len(p) for p in populations)
|
|
114
|
+
self.populations = populations
|
|
115
|
+
|
|
116
|
+
labels = labels or ["" for i in populations]
|
|
117
|
+
assert len(labels) == len(
|
|
118
|
+
populations
|
|
119
|
+
), f"got {len(populations)} populations, but {len(labels)} labels"
|
|
120
|
+
self.labels = labels
|
|
121
|
+
|
|
122
|
+
# fmt:off
|
|
123
|
+
@overload
|
|
124
|
+
def __getitem__(self, key: slice) -> List[List[Tree]]: ...
|
|
125
|
+
@overload
|
|
126
|
+
def __getitem__(self, key: int) -> List[Tree]: ...
|
|
127
|
+
# fmt:on
|
|
128
|
+
def __getitem__(self, key):
|
|
129
|
+
return [p[key] for p in self.populations]
|
|
130
|
+
|
|
131
|
+
def __len__(self) -> int:
|
|
132
|
+
return self.len
|
|
133
|
+
|
|
134
|
+
def __iter__(self) -> Iterator[List[Tree]]:
|
|
135
|
+
return (self[i] for i in range(self.len))
|
|
136
|
+
|
|
137
|
+
def __repr__(self) -> str:
|
|
138
|
+
return (
|
|
139
|
+
f"A cluster of {self.num_of_populations()} neuron populations, "
|
|
140
|
+
f"each containing at least {self.len} trees"
|
|
141
|
+
)
|
|
142
|
+
|
|
143
|
+
def num_of_populations(self) -> int:
|
|
144
|
+
return len(self.populations)
|
|
145
|
+
|
|
146
|
+
def to_population(self) -> Population:
|
|
147
|
+
swcs = list(chain.from_iterable(p.swcs for p in self.populations))
|
|
148
|
+
return Population(swcs)
|
|
149
|
+
|
|
150
|
+
@classmethod
|
|
151
|
+
def from_swc(
|
|
152
|
+
cls,
|
|
153
|
+
roots: List[str],
|
|
154
|
+
ext: str = ".swc",
|
|
155
|
+
intersect: bool = True,
|
|
156
|
+
check_same: bool = True,
|
|
157
|
+
labels: Optional[List[str]] = None,
|
|
158
|
+
**kwargs,
|
|
159
|
+
) -> Self:
|
|
160
|
+
"""Get population from dirs.
|
|
161
|
+
|
|
162
|
+
Parameters
|
|
163
|
+
----------
|
|
164
|
+
roots : list of str
|
|
165
|
+
intersect : bool, default `False`
|
|
166
|
+
Take the intersection of these populations.
|
|
167
|
+
check_same : bool, default `True`
|
|
168
|
+
Check if the directories contains the same swc.
|
|
169
|
+
labels : List of str, optional
|
|
170
|
+
Label of populations.
|
|
171
|
+
**kwargs : Any
|
|
172
|
+
Forwarding to `Population`.
|
|
173
|
+
"""
|
|
174
|
+
|
|
175
|
+
fs = [Population.find_swcs(d, ext=ext, relpath=True) for d in roots]
|
|
176
|
+
if intersect:
|
|
177
|
+
inter = list(reduce(lambda a, b: set(a).intersection(set(b)), fs))
|
|
178
|
+
if len(inter) == 0:
|
|
179
|
+
warnings.warn("no intersection among populations")
|
|
180
|
+
|
|
181
|
+
fs = [inter for _ in roots]
|
|
182
|
+
elif check_same:
|
|
183
|
+
assert reduce(lambda a, b: a == b, fs), "not the same among populations"
|
|
184
|
+
|
|
185
|
+
populations = [
|
|
186
|
+
Population([os.path.join(roots[i], p) for p in fs[i]], root=d, **kwargs)
|
|
187
|
+
for i, d in enumerate(roots)
|
|
188
|
+
]
|
|
189
|
+
return cls(populations, labels=labels)
|
|
190
|
+
|
|
191
|
+
@classmethod
|
|
192
|
+
def from_eswc(cls, roots: List[str], extra_cols: List[str] = [], **kwargs) -> Self:
|
|
193
|
+
extra_cols.extend(k for k, t in eswc_cols)
|
|
194
|
+
return cls.from_swc(roots, extra_cols=extra_cols, **kwargs)
|
swcgeom/core/swc.py
CHANGED
|
@@ -8,9 +8,13 @@ import numpy.typing as npt
|
|
|
8
8
|
import pandas as pd
|
|
9
9
|
import scipy.sparse as sp
|
|
10
10
|
|
|
11
|
-
from .swc_utils import
|
|
12
|
-
|
|
13
|
-
|
|
11
|
+
from .swc_utils import (
|
|
12
|
+
check_single_root,
|
|
13
|
+
link_roots_to_nearest_,
|
|
14
|
+
mark_roots_as_somas_,
|
|
15
|
+
reset_index_,
|
|
16
|
+
sort_nodes_,
|
|
17
|
+
)
|
|
14
18
|
|
|
15
19
|
__all__ = ["swc_cols", "eswc_cols", "read_swc", "SWCLike", "SWCTypeVar"]
|
|
16
20
|
|
|
@@ -176,7 +180,7 @@ def read_swc(
|
|
|
176
180
|
swc_file: str,
|
|
177
181
|
extra_cols: List[str] | None = None,
|
|
178
182
|
fix_roots: Literal["somas", "nearest", False] = False,
|
|
179
|
-
|
|
183
|
+
sort_nodes: bool = False,
|
|
180
184
|
reset_index: bool = True,
|
|
181
185
|
) -> pd.DataFrame:
|
|
182
186
|
"""Read swc file.
|
|
@@ -189,7 +193,7 @@ def read_swc(
|
|
|
189
193
|
Read more cols in swc file.
|
|
190
194
|
fix_roots : `somas`|`nearest`|False, default `False`
|
|
191
195
|
Fix multiple roots.
|
|
192
|
-
|
|
196
|
+
sort_nodes : bool, default `False`
|
|
193
197
|
Sort the indices of neuron tree, the index for parent are
|
|
194
198
|
always less than children.
|
|
195
199
|
reset_index : bool, default `True`
|
|
@@ -214,16 +218,16 @@ def read_swc(
|
|
|
214
218
|
if fix_roots is not False and np.count_nonzero(df["pid"] == -1) > 1:
|
|
215
219
|
match fix_roots:
|
|
216
220
|
case "somas":
|
|
217
|
-
|
|
221
|
+
mark_roots_as_somas_(df)
|
|
218
222
|
case "nearest":
|
|
219
|
-
|
|
223
|
+
link_roots_to_nearest_(df)
|
|
220
224
|
case _:
|
|
221
225
|
raise ValueError(f"unknown fix type: {fix_roots}")
|
|
222
226
|
|
|
223
|
-
if
|
|
224
|
-
|
|
227
|
+
if sort_nodes:
|
|
228
|
+
sort_nodes_(df)
|
|
225
229
|
elif reset_index:
|
|
226
|
-
|
|
230
|
+
reset_index_(df)
|
|
227
231
|
|
|
228
232
|
# check swc
|
|
229
233
|
if not check_single_root(df):
|