swcgeom 0.19.4__cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of swcgeom might be problematic. Click here for more details.
- swcgeom/__init__.py +21 -0
- swcgeom/analysis/__init__.py +13 -0
- swcgeom/analysis/feature_extractor.py +454 -0
- swcgeom/analysis/features.py +218 -0
- swcgeom/analysis/lmeasure.py +750 -0
- swcgeom/analysis/sholl.py +201 -0
- swcgeom/analysis/trunk.py +183 -0
- swcgeom/analysis/visualization.py +191 -0
- swcgeom/analysis/visualization3d.py +81 -0
- swcgeom/analysis/volume.py +143 -0
- swcgeom/core/__init__.py +19 -0
- swcgeom/core/branch.py +129 -0
- swcgeom/core/branch_tree.py +65 -0
- swcgeom/core/compartment.py +107 -0
- swcgeom/core/node.py +130 -0
- swcgeom/core/path.py +155 -0
- swcgeom/core/population.py +341 -0
- swcgeom/core/swc.py +247 -0
- swcgeom/core/swc_utils/__init__.py +19 -0
- swcgeom/core/swc_utils/assembler.py +35 -0
- swcgeom/core/swc_utils/base.py +180 -0
- swcgeom/core/swc_utils/checker.py +107 -0
- swcgeom/core/swc_utils/io.py +204 -0
- swcgeom/core/swc_utils/normalizer.py +163 -0
- swcgeom/core/swc_utils/subtree.py +70 -0
- swcgeom/core/tree.py +384 -0
- swcgeom/core/tree_utils.py +277 -0
- swcgeom/core/tree_utils_impl.py +58 -0
- swcgeom/images/__init__.py +9 -0
- swcgeom/images/augmentation.py +149 -0
- swcgeom/images/contrast.py +87 -0
- swcgeom/images/folder.py +217 -0
- swcgeom/images/io.py +578 -0
- swcgeom/images/loaders/__init__.py +8 -0
- swcgeom/images/loaders/pbd.cpython-313-x86_64-linux-gnu.so +0 -0
- swcgeom/images/loaders/pbd.pyx +523 -0
- swcgeom/images/loaders/raw.cpython-313-x86_64-linux-gnu.so +0 -0
- swcgeom/images/loaders/raw.pyx +183 -0
- swcgeom/transforms/__init__.py +20 -0
- swcgeom/transforms/base.py +136 -0
- swcgeom/transforms/branch.py +223 -0
- swcgeom/transforms/branch_tree.py +74 -0
- swcgeom/transforms/geometry.py +270 -0
- swcgeom/transforms/image_preprocess.py +107 -0
- swcgeom/transforms/image_stack.py +219 -0
- swcgeom/transforms/images.py +206 -0
- swcgeom/transforms/mst.py +183 -0
- swcgeom/transforms/neurolucida_asc.py +498 -0
- swcgeom/transforms/path.py +56 -0
- swcgeom/transforms/population.py +36 -0
- swcgeom/transforms/tree.py +265 -0
- swcgeom/transforms/tree_assembler.py +161 -0
- swcgeom/utils/__init__.py +18 -0
- swcgeom/utils/debug.py +23 -0
- swcgeom/utils/download.py +119 -0
- swcgeom/utils/dsu.py +58 -0
- swcgeom/utils/ellipse.py +131 -0
- swcgeom/utils/file.py +90 -0
- swcgeom/utils/neuromorpho.py +581 -0
- swcgeom/utils/numpy_helper.py +70 -0
- swcgeom/utils/plotter_2d.py +134 -0
- swcgeom/utils/plotter_3d.py +35 -0
- swcgeom/utils/renderer.py +145 -0
- swcgeom/utils/sdf.py +324 -0
- swcgeom/utils/solid_geometry.py +154 -0
- swcgeom/utils/transforms.py +367 -0
- swcgeom/utils/volumetric_object.py +483 -0
- swcgeom-0.19.4.dist-info/METADATA +86 -0
- swcgeom-0.19.4.dist-info/RECORD +72 -0
- swcgeom-0.19.4.dist-info/WHEEL +6 -0
- swcgeom-0.19.4.dist-info/licenses/LICENSE +201 -0
- swcgeom-0.19.4.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,201 @@
|
|
|
1
|
+
|
|
2
|
+
# SPDX-FileCopyrightText: 2022 - 2025 Zexin Yuan <pypi@yzx9.xyz>
|
|
3
|
+
#
|
|
4
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
5
|
+
|
|
6
|
+
"""Sholl analysis."""
|
|
7
|
+
|
|
8
|
+
import warnings
|
|
9
|
+
from typing import Literal
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
import numpy.typing as npt
|
|
13
|
+
import seaborn as sns
|
|
14
|
+
from matplotlib.axes import Axes
|
|
15
|
+
from matplotlib.figure import Figure
|
|
16
|
+
from typing_extensions import deprecated
|
|
17
|
+
|
|
18
|
+
from swcgeom.analysis.visualization import draw
|
|
19
|
+
from swcgeom.core import Tree
|
|
20
|
+
from swcgeom.transforms import TranslateOrigin
|
|
21
|
+
from swcgeom.utils import draw_circles, get_fig_ax
|
|
22
|
+
|
|
23
|
+
__all__ = ["Sholl"]
|
|
24
|
+
|
|
25
|
+
XLABEL = "Radial Distance"
|
|
26
|
+
YLABLE = "Count of Intersections"
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class Sholl:
|
|
30
|
+
"""Sholl analysis.
|
|
31
|
+
|
|
32
|
+
Implementation of original Sholl analysis as described in [1]_. The Sholl analysis
|
|
33
|
+
is a method to quantify the spatial distribution of neuronal processes. It is
|
|
34
|
+
based on the number of intersections of concentric circles with the neuronal
|
|
35
|
+
processes.
|
|
36
|
+
|
|
37
|
+
References:
|
|
38
|
+
.. [1] Dendritic organization in the neurons of the visual and
|
|
39
|
+
motor cortices of the cat J. Anat., 87 (1953), pp. 387-406
|
|
40
|
+
"""
|
|
41
|
+
|
|
42
|
+
tree: Tree
|
|
43
|
+
rs: npt.NDArray[np.float32]
|
|
44
|
+
rmax: float
|
|
45
|
+
|
|
46
|
+
# compat
|
|
47
|
+
step: float | None = None
|
|
48
|
+
|
|
49
|
+
def __init__(self, tree: Tree | str, step: float | None = None) -> None:
|
|
50
|
+
tree = Tree.from_swc(tree) if isinstance(tree, str) else tree
|
|
51
|
+
try:
|
|
52
|
+
self.tree = TranslateOrigin.transform(tree) # shift
|
|
53
|
+
self.rs = np.linalg.norm(self.tree.get_segments().xyz(), axis=2)
|
|
54
|
+
self.rmax = self.rs.max()
|
|
55
|
+
except Exception as e:
|
|
56
|
+
raise ValueError(f"invalid tree: {tree.source or ''}") from e
|
|
57
|
+
|
|
58
|
+
if step is not None:
|
|
59
|
+
warnings.warn(
|
|
60
|
+
"`Sholl(x, step=...)` has been replaced by "
|
|
61
|
+
"`Sholl(x).get(steps=...)` since v0.6.0 because it has "
|
|
62
|
+
"been change to dynamic calculate, and will be removed "
|
|
63
|
+
"in next version",
|
|
64
|
+
DeprecationWarning,
|
|
65
|
+
)
|
|
66
|
+
self.step = step
|
|
67
|
+
|
|
68
|
+
def get(self, steps: int | npt.ArrayLike = 20) -> npt.NDArray[np.int64]:
|
|
69
|
+
intersections = [
|
|
70
|
+
np.logical_or(
|
|
71
|
+
np.logical_and(self.rs[:, 0] <= r, self.rs[:, 1] > r),
|
|
72
|
+
np.logical_and(self.rs[:, 1] <= r, self.rs[:, 0] > r),
|
|
73
|
+
)
|
|
74
|
+
for r in self._get_rs(steps=steps)
|
|
75
|
+
]
|
|
76
|
+
return np.count_nonzero(intersections, axis=1)
|
|
77
|
+
|
|
78
|
+
def intersect(self, r: float) -> int:
|
|
79
|
+
return np.count_nonzero(
|
|
80
|
+
np.logical_or(
|
|
81
|
+
np.logical_and(self.rs[:, 0] <= r, self.rs[:, 1] > r),
|
|
82
|
+
np.logical_and(self.rs[:, 1] <= r, self.rs[:, 0] > r),
|
|
83
|
+
)
|
|
84
|
+
)
|
|
85
|
+
|
|
86
|
+
def plot( # pylint: disable=too-many-arguments
|
|
87
|
+
self,
|
|
88
|
+
steps: list[float] | int = 20,
|
|
89
|
+
plot_type: str | None = None,
|
|
90
|
+
kind: Literal["bar", "linechart", "circles"] = "circles",
|
|
91
|
+
fig: Figure | None = None,
|
|
92
|
+
ax: Axes | None = None,
|
|
93
|
+
**kwargs,
|
|
94
|
+
) -> tuple[Figure, Axes]:
|
|
95
|
+
"""Plot Sholl analysis.
|
|
96
|
+
|
|
97
|
+
Args:
|
|
98
|
+
steps: Steps of raius of circle.
|
|
99
|
+
If steps is int, then it will be evenly divided into n radii.
|
|
100
|
+
kind: Kind of plot.
|
|
101
|
+
fig: The figure to plot on.
|
|
102
|
+
ax: The axes to plot on.
|
|
103
|
+
**kwargs: Forwarding to plot method.
|
|
104
|
+
"""
|
|
105
|
+
if plot_type is not None:
|
|
106
|
+
warnings.warn(
|
|
107
|
+
"`plot_type` has been renamed to `kind` since v0.5.0, "
|
|
108
|
+
"and will be removed in next version",
|
|
109
|
+
DeprecationWarning,
|
|
110
|
+
)
|
|
111
|
+
kind = plot_type # type: ignore
|
|
112
|
+
|
|
113
|
+
xs = self._get_rs(steps=steps)
|
|
114
|
+
ys = self.get(steps=xs)
|
|
115
|
+
fig, ax = get_fig_ax(fig, ax)
|
|
116
|
+
match kind:
|
|
117
|
+
case "bar":
|
|
118
|
+
sns.barplot(x=xs, y=ys, ax=ax, **kwargs)
|
|
119
|
+
ax.set_ylabel(YLABLE)
|
|
120
|
+
|
|
121
|
+
case "linechart":
|
|
122
|
+
sns.lineplot(x=xs, y=ys, ax=ax, **kwargs)
|
|
123
|
+
ax.set_ylabel(YLABLE)
|
|
124
|
+
|
|
125
|
+
case "circles":
|
|
126
|
+
kwargs.setdefault("y_min", 0)
|
|
127
|
+
drawtree = kwargs.pop("drawtree", True)
|
|
128
|
+
colorbar = kwargs.pop("colorbar", True)
|
|
129
|
+
cmap = kwargs.pop("cmap", "Blues")
|
|
130
|
+
patches = draw_circles(ax, xs, ys, cmap=cmap, **kwargs)
|
|
131
|
+
|
|
132
|
+
if drawtree is True:
|
|
133
|
+
draw(self.tree, ax=ax, direction_indicator=False)
|
|
134
|
+
elif isinstance(drawtree, str):
|
|
135
|
+
draw(self.tree, ax=ax, color=drawtree, direction_indicator=False)
|
|
136
|
+
|
|
137
|
+
if colorbar is True:
|
|
138
|
+
fig.colorbar(patches, ax=ax, label=YLABLE)
|
|
139
|
+
elif isinstance(colorbar, (Axes, np.ndarray, list)):
|
|
140
|
+
fig.colorbar(patches, ax=colorbar, label=YLABLE)
|
|
141
|
+
|
|
142
|
+
case _:
|
|
143
|
+
raise ValueError(f"unsupported kind: {kind}")
|
|
144
|
+
|
|
145
|
+
ax.set_xlabel(XLABEL)
|
|
146
|
+
return fig, ax
|
|
147
|
+
|
|
148
|
+
@staticmethod
|
|
149
|
+
def get_rs(rmax: float, steps: int | npt.ArrayLike) -> npt.NDArray[np.float32]:
|
|
150
|
+
"""Function to calculate the list of radius used by the sholl."""
|
|
151
|
+
if isinstance(steps, int):
|
|
152
|
+
s = rmax / (steps + 1)
|
|
153
|
+
return np.arange(s, rmax, s)
|
|
154
|
+
|
|
155
|
+
return np.array(steps)
|
|
156
|
+
|
|
157
|
+
def _get_rs(self, steps: int | npt.ArrayLike) -> npt.NDArray[np.float32]:
|
|
158
|
+
if self.step is not None: # compat
|
|
159
|
+
return np.arange(self.step, int(np.ceil(self.rmax)), self.step)
|
|
160
|
+
|
|
161
|
+
return self.get_rs(self.rmax, steps)
|
|
162
|
+
|
|
163
|
+
@deprecated("Use `Sholl.get(x)` instead")
|
|
164
|
+
def get_count(self) -> npt.NDArray[np.int32]:
|
|
165
|
+
"""Get the count of intersection.
|
|
166
|
+
|
|
167
|
+
.. deprecated:: 0.5.0
|
|
168
|
+
Use :meth:`Sholl(x).get()` instead.
|
|
169
|
+
"""
|
|
170
|
+
|
|
171
|
+
return self.get().astype(np.int32)
|
|
172
|
+
|
|
173
|
+
@deprecated("Use `Shool(x).get().mean()` instead")
|
|
174
|
+
def avg(self) -> float:
|
|
175
|
+
"""Get the average of the count of intersection.
|
|
176
|
+
|
|
177
|
+
.. deprecated:: 0.6.0
|
|
178
|
+
Use :meth:`Shool(x).get().mean()` instead.
|
|
179
|
+
"""
|
|
180
|
+
|
|
181
|
+
return self.get().mean()
|
|
182
|
+
|
|
183
|
+
@deprecated("Use `Shool(x).get().std()` instead")
|
|
184
|
+
def std(self) -> float:
|
|
185
|
+
"""Get the std of the count of intersection.
|
|
186
|
+
|
|
187
|
+
.. deprecated:: 0.6.0
|
|
188
|
+
Use :meth:`Shool(x).get().std()` instead.
|
|
189
|
+
"""
|
|
190
|
+
|
|
191
|
+
return self.get().std()
|
|
192
|
+
|
|
193
|
+
@deprecated("Use `Shool(x).get().sum()` instead")
|
|
194
|
+
def sum(self) -> int:
|
|
195
|
+
"""Get the sum of the count of intersection.
|
|
196
|
+
|
|
197
|
+
.. deprecated:: 0.6.0
|
|
198
|
+
Use :meth:`Shool(x).get().sum()` instead.
|
|
199
|
+
"""
|
|
200
|
+
|
|
201
|
+
return self.get().sum()
|
|
@@ -0,0 +1,183 @@
|
|
|
1
|
+
|
|
2
|
+
# SPDX-FileCopyrightText: 2022 - 2025 Zexin Yuan <pypi@yzx9.xyz>
|
|
3
|
+
#
|
|
4
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
5
|
+
|
|
6
|
+
"""Plot trunk and florets."""
|
|
7
|
+
|
|
8
|
+
# pylint: disable=invalid-name
|
|
9
|
+
|
|
10
|
+
from collections.abc import Iterable
|
|
11
|
+
from itertools import chain
|
|
12
|
+
from typing import Any, Literal, cast
|
|
13
|
+
|
|
14
|
+
import numpy as np
|
|
15
|
+
import numpy.typing as npt
|
|
16
|
+
from matplotlib import cm
|
|
17
|
+
from matplotlib.axes import Axes
|
|
18
|
+
from matplotlib.figure import Figure
|
|
19
|
+
from matplotlib.patches import Circle, Ellipse, Patch, Rectangle
|
|
20
|
+
|
|
21
|
+
from swcgeom.analysis.visualization import draw
|
|
22
|
+
from swcgeom.core import Tree, get_subtree, to_subtree
|
|
23
|
+
from swcgeom.utils import get_fig_ax, mvee
|
|
24
|
+
|
|
25
|
+
__all__ = ["draw_trunk"]

# Supported projections and bounding-shape kinds accepted by `draw_trunk`.
Projection = Literal["2d"]
Bounds = Literal["aabb", "ellipse"]
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def draw_trunk(
    t: Tree,
    florets: Iterable[int | Iterable[int]],
    *,
    fig: Figure | None = None,
    ax: Axes | None = None,
    bound: Bounds | tuple[Bounds, dict[str, Any]] | None = "ellipse",
    point: bool | dict[str, Any] = True,
    projection: Projection = "2d",
    cmap: Any = "viridis",
    **kwargs,
) -> tuple[Figure, Axes]:
    """Draw trunk tree.

    Args:
        t: Tree
        florets: The florets that need to be removed from the trunk.
            Each floret can be a subtree or multiple subtrees (e.g., dendrites are a
            bunch of subtrees), each number is the id of a tree node.
        fig: Figure to plot on.
        ax: Axes to plot on.
        bound: Kind of bound, support 'aabb', 'ellipse'.
            If bound is None, no bound will be drawn. If bound is a tuple, the second
            item will be used as kwargs and forwarded to the draw function; keys in
            it (e.g. `color`) override the per-floret defaults.
        point: Draw point at the start of a subtree.
            If point is False, no point will be drawn. If point is a dict, it will be
            used as kwargs and forwarded to the draw function; keys in it (e.g.
            `color`) override the per-floret defaults.
        projection: Projection, only "2d" is supported.
        cmap: Colormap.
            A colormap name, a `~matplotlib.colors.Colormap`, or None for the default
            colormap. We use the ratio of the length of each floret to the total
            length of the tree to determine its color.
        **kwargs: Forward to ~swcgeom.analysis.draw.

    Returns:
        fig: The figure plotted on.
        ax: The axes plotted on.
    """
    # pylint: disable=too-many-locals
    import matplotlib as mpl  # `matplotlib.cm.get_cmap` was removed in Matplotlib 3.9

    trunk, tss = split_florets(t, florets)
    lens = get_length_ratio(t, tss)

    cmap = mpl.colormaps.get_cmap(cmap)  # accepts a name, a Colormap, or None
    c = cmap(lens)

    fig, ax = get_fig_ax(fig, ax)
    if bound is not None:
        for ts, cc in zip(tss, c):
            draw_bound(ts, ax, bound, projection, color=cc)

    if point is not False:
        point_kwargs = point if isinstance(point, dict) else {}
        for ts, cc in zip(tss, c):
            # merge so a user-supplied `color` overrides instead of raising TypeError
            draw_point(ts, ax, projection, **{"color": cc, **point_kwargs})

    # NOTE(review): an int argument indexes the colormap's lookup table (cmap(1) is
    # its second entry, not the top color), and `draw` only honours str/dict colors,
    # so this RGBA default may fall back to the palette. Confirm the intent.
    draw(trunk, ax=ax, **{"color": cmap(1), **kwargs})
    return fig, ax
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def split_florets(
    t: Tree, florets: Iterable[int | Iterable[int]]
) -> tuple[Tree, list[list[Tree]]]:
    """Split `t` into its trunk and the florets given by node ids.

    Args:
        t: The tree.
        florets: Each item is either one node id or an iterable of node ids that
            together form one floret.

    Returns:
        trunk: The result of `to_subtree(t, ...)` given every floret id.
        subtrees: For each floret, the list of `get_subtree(t, id)` for its ids.
    """
    # Materialize every group: it is consumed twice below (subtrees, then trunk), so
    # a one-shot iterator such as a generator would be empty on the second pass and
    # the trunk would silently keep the florets.
    groups = [[i] if isinstance(i, (int, np.integer)) else list(i) for i in florets]
    subtrees = [[get_subtree(t, idx) for idx in group] for group in groups]
    trunk = to_subtree(t, chain.from_iterable(groups))
    return trunk, subtrees
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def get_length_ratio(t: Tree, tss: list[list[Tree]]) -> Any:
    """Return, for each floret, its summed subtree length divided by `t.length()`."""
    floret_lengths = []
    for subtrees in tss:
        floret_lengths.append(sum(sub.length() for sub in subtrees))

    total = t.length()
    return np.array(floret_lengths) / total
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
# Bounds
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def draw_bound(
    ts: Iterable[Tree],
    ax: Axes,
    bound: Bounds | tuple[Bounds, dict[str, Any]],
    projection: Projection,
    **kwargs,
) -> None:
    """Add a patch bounding all nodes of `ts` to `ax`.

    Args:
        ts: The subtrees of one floret.
        ax: Axes to draw on.
        bound: A bound kind, or `(kind, kwargs)`; those kwargs override `**kwargs`.
        projection: Projection, only "2d" is supported.
        **kwargs: Default patch kwargs supplied by the caller (e.g. `color`).

    Raises:
        ValueError: On an unsupported projection.
    """
    kind, bound_kwargs = (bound, {}) if isinstance(bound, str) else bound
    if projection == "2d":
        # Merge instead of unpacking both mappings in the call: a key present in both
        # (e.g. `color`) would otherwise raise "got multiple values for keyword".
        patch = create_bound_2d(ts, kind, **{**kwargs, **bound_kwargs})
    else:
        raise ValueError(f"unsupported projection {projection}")

    ax.add_patch(patch)
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def create_bound_2d(ts: Iterable[Tree], bound: Bounds, **kwargs) -> Patch:
|
|
119
|
+
xyz = np.concatenate([t.xyz() for t in ts])
|
|
120
|
+
xy = xyz[:, :2] # TODO: camera
|
|
121
|
+
|
|
122
|
+
if bound == "aabb":
|
|
123
|
+
return create_aabb_2d(xy, **kwargs)
|
|
124
|
+
if bound == "ellipse":
|
|
125
|
+
return create_ellipse_2d(xy, **kwargs)
|
|
126
|
+
raise ValueError(f"unsupported bound `{bound}` in 2d projection")
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def create_aabb_2d(xy: npt.NDArray, fill: bool = False, **kwargs) -> Rectangle:
    """Axis-aligned bounding rectangle of the points `xy` (columns 0 and 1 used).

    Args:
        xy: Points, one per row.
        fill: Whether the rectangle is filled.
        **kwargs: Forwarded to `matplotlib.patches.Rectangle`.
    """
    lower = xy.min(axis=0)
    upper = xy.max(axis=0)
    x0, y0 = lower[0], lower[1]
    return Rectangle(
        xy=(x0, y0),
        width=upper[0] - x0,
        height=upper[1] - y0,
        angle=0,
        fill=fill,
        **kwargs,
    )
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def create_ellipse_2d(xy: npt.NDArray, fill: bool = False, **kwargs) -> Ellipse:
    """Ellipse patch from the minimum-volume enclosing ellipse (`mvee`) of `xy`.

    Args:
        xy: 2D points, one per row.
        fill: Whether the ellipse is filled.
        **kwargs: Forwarded to `matplotlib.patches.Ellipse`.

    NOTE(review): matplotlib's `Ellipse` takes full axis lengths for `width` and
    `height` and an angle in degrees; confirm that `mvee(...).a`/`.b` are full (not
    semi-) axes and that `.alpha` is expressed in degrees.
    """
    fitted = mvee(xy)
    geometry = dict(
        xy=fitted.centroid,  # type:ignore
        width=fitted.a,
        height=fitted.b,
        angle=fitted.alpha,
    )
    return Ellipse(**geometry, fill=fill, **kwargs)
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
# point
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def draw_point(ts: Iterable[Tree], ax: Axes, projection: Projection, **kwargs) -> None:
    """Add a circle marking the start of the subtrees `ts` to `ax`.

    Raises:
        ValueError: On an unsupported projection (only "2d" is supported).
    """
    if projection != "2d":
        raise ValueError(f"unsupported projection {projection}")

    ax.add_patch(create_point_2d(ts, **kwargs))
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def create_point_2d(
    ts: Iterable[Tree], radius: float | None = None, **kwargs
) -> Circle:
    """Circle centred on the mean XY position of the first node of each subtree.

    Args:
        ts: Subtrees; `t.xyz()[0]` (first node) of each one is averaged.
        radius: Circle radius. Defaults to 5% of the smaller of the X and Y extents
            of all nodes in `ts`.
        **kwargs: Forwarded to `matplotlib.patches.Circle`.
    """
    # `ts` may be traversed twice (radius, then centre); materialize it so a one-shot
    # iterator does not leave the second pass empty (which would give a NaN centre).
    ts = list(ts)
    if radius is None:
        xyz = np.concatenate([t.xyz() for t in ts])  # TODO: cache
        radius = 0.05 * min(
            xyz[:, 0].max() - xyz[:, 0].min(),
            xyz[:, 1].max() - xyz[:, 1].min(),
        )
    radius = cast(float, radius)

    center = np.mean([t.xyz()[0, :2] for t in ts], axis=0)
    return Circle(center, radius, **kwargs)
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
# Helpers
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def get_dendrites(tree: Tree) -> Iterable[int]:
    """Return the id of the first node of every subtree from `tree.get_dendrites()`.

    A list is returned rather than a lazy generator so the result can be iterated
    more than once, e.g. when it is passed as one floret group to `draw_trunk`.
    """
    return [t.node(0).id for t in tree.get_dendrites()]
|
|
@@ -0,0 +1,191 @@
|
|
|
1
|
+
|
|
2
|
+
# SPDX-FileCopyrightText: 2022 - 2025 Zexin Yuan <pypi@yzx9.xyz>
|
|
3
|
+
#
|
|
4
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
5
|
+
|
|
6
|
+
#
|
|
7
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
8
|
+
# you may not use this file except in compliance with the License.
|
|
9
|
+
# You may obtain a copy of the License at
|
|
10
|
+
#
|
|
11
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
12
|
+
#
|
|
13
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
14
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
15
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
16
|
+
# See the License for the specific language governing permissions and
|
|
17
|
+
# limitations under the License.
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
"""Painter utils."""
|
|
21
|
+
|
|
22
|
+
import os
|
|
23
|
+
import weakref
|
|
24
|
+
from typing import Any, Literal
|
|
25
|
+
|
|
26
|
+
import numpy as np
|
|
27
|
+
from matplotlib.axes import Axes
|
|
28
|
+
from matplotlib.figure import Figure
|
|
29
|
+
from matplotlib.legend import Legend
|
|
30
|
+
|
|
31
|
+
from swcgeom.core import SWCLike, Tree
|
|
32
|
+
from swcgeom.utils import (
|
|
33
|
+
CameraOptions,
|
|
34
|
+
SimpleCamera,
|
|
35
|
+
draw_direction_indicator,
|
|
36
|
+
draw_lines,
|
|
37
|
+
get_fig_ax,
|
|
38
|
+
palette,
|
|
39
|
+
)
|
|
40
|
+
|
|
41
|
+
__all__ = ["draw"]

# Where to place the direction indicator: a named corner or explicit coordinates.
Positions = Literal["lt", "lb", "rt", "rb"] | tuple[float, float]

# Coordinates of each named corner (l/r = left/right, t/b = top/bottom); presumably
# axes-fraction coordinates, confirm against `draw_direction_indicator`.
locations: dict[Literal["lt", "lb", "rt", "rb"], tuple[float, float]] = dict(
    lt=(0.10, 0.90),
    lb=(0.10, 0.10),
    rt=(0.90, 0.90),
    rb=(0.90, 0.10),
)

# Per-axes drawing state ("swc", "labels", "handles", "color" counter). Weak keys
# let an entry be discarded once its Axes is no longer referenced elsewhere.
ax_weak_memo: weakref.WeakKeyDictionary[Axes, dict[str, Any]] = (
    weakref.WeakKeyDictionary()
)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def draw(
    swc: SWCLike | str,
    *,
    fig: Figure | None = None,
    ax: Axes | None = None,
    show: bool | None = None,
    camera: CameraOptions = "xy",
    color: dict[int, str] | str | None = None,
    label: str | bool = True,
    direction_indicator: Positions | Literal[False] = "rb",
    unit: str | None = None,
    **kwargs,
) -> tuple[Figure, Axes]:
    r"""Draw neuron tree.

    Args:
        swc: The swc tree to draw.
            If it is str, then it is treated as the path of swc file.
        fig: The figure to plot on.
        ax: The axes to plot on.
        show: Whether to call `fig.show()`.
            Only `True` forces it; if not specified, it is called only when no `ax`
            is passed in.
        camera: Camera options (position, look-at, up).
            One, two, or three vectors are supported. If only one vector is given,
            it is treated as look-at, so the camera is ((0, 0, 0), look-at,
            (0, 1, 0)); if two vectors are given, they are treated as
            (look-at, up), so the camera is ((0, 0, 0), look-at, up).
            An easy way is to use the presets "xy", "yz" and "zx".
        color: Color map.
            If it is a dict, segments are colored by the type of the parent node.
            If it is a string, the value is used for every type.
        label: Label of legend, disable if False.
        direction_indicator: Draw a xyz direction indicator.
            Can be placed at 'lt', 'lb', 'rt', 'rb', or a custom position.
            Only drawn for the first tree on the axes.
        unit: Unit text placed at the top-left of the axes, e.g.: r"$\mu m$".
            Only drawn for the first tree on the axes.
        **kwargs: Forwarded to `draw_lines` (and from there, per the original
            docs, to `~matplotlib.collections.LineCollection`).

    Returns:
        fig: The figure to plot on.
        ax: The axes to plot on.
    """
    # pylint: disable=too-many-locals
    if isinstance(swc, str):
        swc = Tree.from_swc(swc)

    should_show = show is True or (show is None and ax is None)
    fig, ax = get_fig_ax(fig, ax)

    cam = SimpleCamera.from_options(camera)
    seg_color = get_ax_color(ax, swc, color)

    # one line per node after index 0, joining the node to its parent
    coords = swc.xyz()
    child_ids = swc.id()[1:]
    parent_ids = swc.pid()[1:]
    segments = np.stack([coords[child_ids], coords[parent_ids]], axis=1)
    handle = draw_lines(ax, segments, camera=cam, color=seg_color, **kwargs)

    ax.autoscale()
    _set_ax_memo(ax, swc, label=label, handle=handle)

    is_first_tree = len(get_ax_swc(ax)) == 1
    if not is_first_tree:
        set_ax_legend(ax, loc="upper right")  # enable legend
    else:
        ax.set_aspect(1)
        ax.spines[["top", "right"]].set_visible(False)
        if direction_indicator is not False:
            if isinstance(direction_indicator, str):
                loc = locations[direction_indicator]
            else:
                loc = direction_indicator
            draw_direction_indicator(ax, camera=cam, loc=loc)
        if unit is not None:
            ax.text(0.05, 0.95, unit, transform=ax.transAxes)

    if should_show:
        fig.show(warn=False)

    return fig, ax
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def get_ax_swc(ax: Axes) -> list[SWCLike]:
    """Return the live list of SWCs recorded as drawn on `ax`.

    Axes that nothing has been recorded for yield an empty list instead of raising
    `KeyError`; the list is stored in the memo, so later records append to it.
    """
    return ax_weak_memo.setdefault(ax, {}).setdefault("swc", [])
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def get_ax_color(
    ax: Axes,
    swc: SWCLike,
    color: dict[int, str] | str | None = None,  # TODO: improve typing
) -> str | list[str]:
    """Resolve the color(s) used to draw `swc` on `ax`.

    Args:
        ax: Axes being drawn on. Every call that does not return early advances a
            per-axes counter, so successive trees get successive palette colors.
        swc: The tree being drawn.
        color: A color string (returned as is), "vaa3d" (use `palette.vaa3d`), a
            mapping from node type to color, or None.

    Returns:
        A single color, or, for a mapping, one color per segment (node i joined to
        its parent, for i >= 1) chosen by the parent's type; types missing from the
        mapping fall back to the palette color.
    """
    if color == "vaa3d":
        color = palette.vaa3d
    elif isinstance(color, str):
        return color  # user specified

    # choose default: advance this axes' palette counter
    memo = ax_weak_memo.setdefault(ax, {})
    memo["color"] = memo.get("color", -1) + 1
    c = palette.default[memo["color"] % len(palette.default)]

    if isinstance(color, dict):
        # Segments are drawn as (node i, parent of node i) for i >= 1, i.e. with
        # `pid()[1:]` as the parent indices, so the parent type is `type[pid[i]]`.
        # `type()[:-1]` (type of node i - 1) only matched when every parent is the
        # preceding node, and was wrong at each branch start.
        parent_types = swc.type()[swc.pid()[1:]]
        return [color.get(node_type, c) for node_type in parent_types]

    return c
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def set_ax_legend(ax: Axes, *args, **kwargs) -> Legend | None:
    """Create a legend on `ax` from the labels and handles recorded for it.

    Entries whose label is `False` are skipped. Returns None (and adds no legend)
    when nothing labelled has been recorded for `ax`.

    Args:
        ax: Axes to add the legend to.
        *args: Forwarded to `Axes.legend` after the handles and labels.
        **kwargs: Forwarded to `Axes.legend`.
    """
    memo = ax_weak_memo.get(ax, {})
    # Pair each handle with its own label; zip cannot raise IndexError even if the
    # two recorded lists ever differ in length.
    pairs = [
        (handle, lbl)
        for handle, lbl in zip(memo.get("handles", []), memo.get("labels", []))
        if lbl is not False  # filter `label = False`
    ]
    if not pairs:
        return None

    handles = [handle for handle, _ in pairs]
    labels = [lbl for _, lbl in pairs]
    return ax.legend(handles, labels, *args, **kwargs)
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
def _set_ax_memo(
    ax: Axes, swc: SWCLike, label: str | bool | None = None, handle: Any | None = None
):
    """Record that `swc` was drawn on `ax`, with its legend label and artist.

    `label=True` is replaced by the base name of `swc.source`; a `None` label or
    handle is not recorded.
    """
    memo = ax_weak_memo.setdefault(ax, {})
    memo.setdefault("swc", []).append(swc)

    if label is not None:
        if label is True:
            label = os.path.basename(swc.source)
        memo.setdefault("labels", []).append(label)

    if handle is not None:
        memo.setdefault("handles", []).append(handle)
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
|
|
2
|
+
# SPDX-FileCopyrightText: 2022 - 2025 Zexin Yuan <pypi@yzx9.xyz>
|
|
3
|
+
#
|
|
4
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
5
|
+
|
|
6
|
+
"""Painter utils.
|
|
7
|
+
|
|
8
|
+
NOTE: This is an experimental module; it may change in the future.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
from matplotlib.axes import Axes
|
|
13
|
+
from matplotlib.figure import Figure
|
|
14
|
+
from mpl_toolkits.mplot3d import Axes3D
|
|
15
|
+
|
|
16
|
+
from swcgeom.analysis.visualization import (
|
|
17
|
+
_set_ax_memo,
|
|
18
|
+
get_ax_color,
|
|
19
|
+
get_ax_swc,
|
|
20
|
+
set_ax_legend,
|
|
21
|
+
)
|
|
22
|
+
from swcgeom.core import SWCLike, Tree
|
|
23
|
+
from swcgeom.utils.plotter_3d import draw_lines_3d
|
|
24
|
+
|
|
25
|
+
__all__ = ["draw3d"]


# TODO: support Camera
def draw3d(
    swc: SWCLike | str,
    *,
    ax: Axes,
    show: bool | None = None,
    color: dict[int, str] | str | None = None,  # TODO: improve typing
    label: str | bool = True,
    **kwargs,
) -> tuple[Figure, Axes]:
    r"""Draw neuron tree on a 3D axes.

    Args:
        swc: The swc tree to draw.
            If it is str, then it is treated as the path of swc file.
        ax: The 3D axes (`mpl_toolkits.mplot3d.Axes3D`) to plot on.
        show: Whether to call `fig.show()` after drawing.
            Since `ax` is always passed in, it is only called when `show` is True.
        color: Color map.
            If it is a dict, segments are colored by the type of the parent node.
            If it is a string, the value is used for every type.
        label: Label of legend, disable if False.
        **kwargs: Forwarded to `~mpl_toolkits.mplot3d.art3d.Line3DCollection`.

    Returns:
        fig: The figure owning `ax`.
        ax: The axes plotted on.
    """
    assert isinstance(ax, Axes3D), "only support 3D axes."

    swc = Tree.from_swc(swc) if isinstance(swc, str) else swc

    # `ax` is mandatory, so the "default to showing when no ax is given" rule of
    # `draw` never applies here.
    show = show is True
    my_color = get_ax_color(ax, swc, color)

    xyz = swc.xyz()
    starts, ends = swc.id()[1:], swc.pid()[1:]
    lines = np.stack([xyz[starts], xyz[ends]], axis=1)
    collection = draw_lines_3d(ax, lines, color=my_color, **kwargs)

    _set_ax_memo(ax, swc, label=label, handle=collection)
    n_drawn = len(get_ax_swc(ax))

    # Fit the limits to this tree; when other trees are already on `ax`, grow the
    # current limits instead of replacing them so earlier trees are not cropped.
    points = lines.reshape(-1, 3)
    min_vals = points.min(axis=0)
    max_vals = points.max(axis=0)
    if n_drawn > 1:
        for i, get_lim in enumerate((ax.get_xlim, ax.get_ylim, ax.get_zlim)):
            lo, hi = get_lim()
            min_vals[i] = min(min_vals[i], lo, hi)
            max_vals[i] = max(max_vals[i], lo, hi)
    ax.set_xlim(min_vals[0], max_vals[0])
    ax.set_ylim(min_vals[1], max_vals[1])
    ax.set_zlim(min_vals[2], max_vals[2])

    if n_drawn == 1:
        # ax.set_aspect(1)
        ax.spines[["top", "right"]].set_visible(False)
    else:
        set_ax_legend(ax, loc="upper right")  # enable legend

    fig = ax.figure
    if show:
        fig.show(warn=False)  # type: ignore
    return fig, ax  # type: ignore
|