swcgeom 0.19.4__cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of swcgeom might be problematic. Click here for more details.

Files changed (72) hide show
  1. swcgeom/__init__.py +21 -0
  2. swcgeom/analysis/__init__.py +13 -0
  3. swcgeom/analysis/feature_extractor.py +454 -0
  4. swcgeom/analysis/features.py +218 -0
  5. swcgeom/analysis/lmeasure.py +750 -0
  6. swcgeom/analysis/sholl.py +201 -0
  7. swcgeom/analysis/trunk.py +183 -0
  8. swcgeom/analysis/visualization.py +191 -0
  9. swcgeom/analysis/visualization3d.py +81 -0
  10. swcgeom/analysis/volume.py +143 -0
  11. swcgeom/core/__init__.py +19 -0
  12. swcgeom/core/branch.py +129 -0
  13. swcgeom/core/branch_tree.py +65 -0
  14. swcgeom/core/compartment.py +107 -0
  15. swcgeom/core/node.py +130 -0
  16. swcgeom/core/path.py +155 -0
  17. swcgeom/core/population.py +341 -0
  18. swcgeom/core/swc.py +247 -0
  19. swcgeom/core/swc_utils/__init__.py +19 -0
  20. swcgeom/core/swc_utils/assembler.py +35 -0
  21. swcgeom/core/swc_utils/base.py +180 -0
  22. swcgeom/core/swc_utils/checker.py +107 -0
  23. swcgeom/core/swc_utils/io.py +204 -0
  24. swcgeom/core/swc_utils/normalizer.py +163 -0
  25. swcgeom/core/swc_utils/subtree.py +70 -0
  26. swcgeom/core/tree.py +384 -0
  27. swcgeom/core/tree_utils.py +277 -0
  28. swcgeom/core/tree_utils_impl.py +58 -0
  29. swcgeom/images/__init__.py +9 -0
  30. swcgeom/images/augmentation.py +149 -0
  31. swcgeom/images/contrast.py +87 -0
  32. swcgeom/images/folder.py +217 -0
  33. swcgeom/images/io.py +578 -0
  34. swcgeom/images/loaders/__init__.py +8 -0
  35. swcgeom/images/loaders/pbd.cpython-313-x86_64-linux-gnu.so +0 -0
  36. swcgeom/images/loaders/pbd.pyx +523 -0
  37. swcgeom/images/loaders/raw.cpython-313-x86_64-linux-gnu.so +0 -0
  38. swcgeom/images/loaders/raw.pyx +183 -0
  39. swcgeom/transforms/__init__.py +20 -0
  40. swcgeom/transforms/base.py +136 -0
  41. swcgeom/transforms/branch.py +223 -0
  42. swcgeom/transforms/branch_tree.py +74 -0
  43. swcgeom/transforms/geometry.py +270 -0
  44. swcgeom/transforms/image_preprocess.py +107 -0
  45. swcgeom/transforms/image_stack.py +219 -0
  46. swcgeom/transforms/images.py +206 -0
  47. swcgeom/transforms/mst.py +183 -0
  48. swcgeom/transforms/neurolucida_asc.py +498 -0
  49. swcgeom/transforms/path.py +56 -0
  50. swcgeom/transforms/population.py +36 -0
  51. swcgeom/transforms/tree.py +265 -0
  52. swcgeom/transforms/tree_assembler.py +161 -0
  53. swcgeom/utils/__init__.py +18 -0
  54. swcgeom/utils/debug.py +23 -0
  55. swcgeom/utils/download.py +119 -0
  56. swcgeom/utils/dsu.py +58 -0
  57. swcgeom/utils/ellipse.py +131 -0
  58. swcgeom/utils/file.py +90 -0
  59. swcgeom/utils/neuromorpho.py +581 -0
  60. swcgeom/utils/numpy_helper.py +70 -0
  61. swcgeom/utils/plotter_2d.py +134 -0
  62. swcgeom/utils/plotter_3d.py +35 -0
  63. swcgeom/utils/renderer.py +145 -0
  64. swcgeom/utils/sdf.py +324 -0
  65. swcgeom/utils/solid_geometry.py +154 -0
  66. swcgeom/utils/transforms.py +367 -0
  67. swcgeom/utils/volumetric_object.py +483 -0
  68. swcgeom-0.19.4.dist-info/METADATA +86 -0
  69. swcgeom-0.19.4.dist-info/RECORD +72 -0
  70. swcgeom-0.19.4.dist-info/WHEEL +6 -0
  71. swcgeom-0.19.4.dist-info/licenses/LICENSE +201 -0
  72. swcgeom-0.19.4.dist-info/top_level.txt +1 -0
@@ -0,0 +1,201 @@
1
+
2
+ # SPDX-FileCopyrightText: 2022 - 2025 Zexin Yuan <pypi@yzx9.xyz>
3
+ #
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ """Sholl analysis."""
7
+
8
+ import warnings
9
+ from typing import Literal
10
+
11
+ import numpy as np
12
+ import numpy.typing as npt
13
+ import seaborn as sns
14
+ from matplotlib.axes import Axes
15
+ from matplotlib.figure import Figure
16
+ from typing_extensions import deprecated
17
+
18
+ from swcgeom.analysis.visualization import draw
19
+ from swcgeom.core import Tree
20
+ from swcgeom.transforms import TranslateOrigin
21
+ from swcgeom.utils import draw_circles, get_fig_ax
22
+
23
__all__ = ["Sholl"]

# Axis labels shared by the Sholl plots. `YLABLE` is a historical misspelling of
# "YLABEL"; the name is kept as-is because the plotting code refers to it.
XLABEL, YLABLE = "Radial Distance", "Count of Intersections"
27
+
28
+
29
+ class Sholl:
30
+ """Sholl analysis.
31
+
32
+ Implementation of original Sholl analysis as described in [1]_. The Sholl analysis
33
+ is a method to quantify the spatial distribution of neuronal processes. It is
34
+ based on the number of intersections of concentric circles with the neuronal
35
+ processes.
36
+
37
+ References:
38
+ .. [1] Dendritic organization in the neurons of the visual and
39
+ motor cortices of the cat J. Anat., 87 (1953), pp. 387-406
40
+ """
41
+
42
+ tree: Tree
43
+ rs: npt.NDArray[np.float32]
44
+ rmax: float
45
+
46
+ # compat
47
+ step: float | None = None
48
+
49
+ def __init__(self, tree: Tree | str, step: float | None = None) -> None:
50
+ tree = Tree.from_swc(tree) if isinstance(tree, str) else tree
51
+ try:
52
+ self.tree = TranslateOrigin.transform(tree) # shift
53
+ self.rs = np.linalg.norm(self.tree.get_segments().xyz(), axis=2)
54
+ self.rmax = self.rs.max()
55
+ except Exception as e:
56
+ raise ValueError(f"invalid tree: {tree.source or ''}") from e
57
+
58
+ if step is not None:
59
+ warnings.warn(
60
+ "`Sholl(x, step=...)` has been replaced by "
61
+ "`Sholl(x).get(steps=...)` since v0.6.0 because it has "
62
+ "been change to dynamic calculate, and will be removed "
63
+ "in next version",
64
+ DeprecationWarning,
65
+ )
66
+ self.step = step
67
+
68
+ def get(self, steps: int | npt.ArrayLike = 20) -> npt.NDArray[np.int64]:
69
+ intersections = [
70
+ np.logical_or(
71
+ np.logical_and(self.rs[:, 0] <= r, self.rs[:, 1] > r),
72
+ np.logical_and(self.rs[:, 1] <= r, self.rs[:, 0] > r),
73
+ )
74
+ for r in self._get_rs(steps=steps)
75
+ ]
76
+ return np.count_nonzero(intersections, axis=1)
77
+
78
+ def intersect(self, r: float) -> int:
79
+ return np.count_nonzero(
80
+ np.logical_or(
81
+ np.logical_and(self.rs[:, 0] <= r, self.rs[:, 1] > r),
82
+ np.logical_and(self.rs[:, 1] <= r, self.rs[:, 0] > r),
83
+ )
84
+ )
85
+
86
+ def plot( # pylint: disable=too-many-arguments
87
+ self,
88
+ steps: list[float] | int = 20,
89
+ plot_type: str | None = None,
90
+ kind: Literal["bar", "linechart", "circles"] = "circles",
91
+ fig: Figure | None = None,
92
+ ax: Axes | None = None,
93
+ **kwargs,
94
+ ) -> tuple[Figure, Axes]:
95
+ """Plot Sholl analysis.
96
+
97
+ Args:
98
+ steps: Steps of raius of circle.
99
+ If steps is int, then it will be evenly divided into n radii.
100
+ kind: Kind of plot.
101
+ fig: The figure to plot on.
102
+ ax: The axes to plot on.
103
+ **kwargs: Forwarding to plot method.
104
+ """
105
+ if plot_type is not None:
106
+ warnings.warn(
107
+ "`plot_type` has been renamed to `kind` since v0.5.0, "
108
+ "and will be removed in next version",
109
+ DeprecationWarning,
110
+ )
111
+ kind = plot_type # type: ignore
112
+
113
+ xs = self._get_rs(steps=steps)
114
+ ys = self.get(steps=xs)
115
+ fig, ax = get_fig_ax(fig, ax)
116
+ match kind:
117
+ case "bar":
118
+ sns.barplot(x=xs, y=ys, ax=ax, **kwargs)
119
+ ax.set_ylabel(YLABLE)
120
+
121
+ case "linechart":
122
+ sns.lineplot(x=xs, y=ys, ax=ax, **kwargs)
123
+ ax.set_ylabel(YLABLE)
124
+
125
+ case "circles":
126
+ kwargs.setdefault("y_min", 0)
127
+ drawtree = kwargs.pop("drawtree", True)
128
+ colorbar = kwargs.pop("colorbar", True)
129
+ cmap = kwargs.pop("cmap", "Blues")
130
+ patches = draw_circles(ax, xs, ys, cmap=cmap, **kwargs)
131
+
132
+ if drawtree is True:
133
+ draw(self.tree, ax=ax, direction_indicator=False)
134
+ elif isinstance(drawtree, str):
135
+ draw(self.tree, ax=ax, color=drawtree, direction_indicator=False)
136
+
137
+ if colorbar is True:
138
+ fig.colorbar(patches, ax=ax, label=YLABLE)
139
+ elif isinstance(colorbar, (Axes, np.ndarray, list)):
140
+ fig.colorbar(patches, ax=colorbar, label=YLABLE)
141
+
142
+ case _:
143
+ raise ValueError(f"unsupported kind: {kind}")
144
+
145
+ ax.set_xlabel(XLABEL)
146
+ return fig, ax
147
+
148
+ @staticmethod
149
+ def get_rs(rmax: float, steps: int | npt.ArrayLike) -> npt.NDArray[np.float32]:
150
+ """Function to calculate the list of radius used by the sholl."""
151
+ if isinstance(steps, int):
152
+ s = rmax / (steps + 1)
153
+ return np.arange(s, rmax, s)
154
+
155
+ return np.array(steps)
156
+
157
+ def _get_rs(self, steps: int | npt.ArrayLike) -> npt.NDArray[np.float32]:
158
+ if self.step is not None: # compat
159
+ return np.arange(self.step, int(np.ceil(self.rmax)), self.step)
160
+
161
+ return self.get_rs(self.rmax, steps)
162
+
163
+ @deprecated("Use `Sholl.get(x)` instead")
164
+ def get_count(self) -> npt.NDArray[np.int32]:
165
+ """Get the count of intersection.
166
+
167
+ .. deprecated:: 0.5.0
168
+ Use :meth:`Sholl(x).get()` instead.
169
+ """
170
+
171
+ return self.get().astype(np.int32)
172
+
173
+ @deprecated("Use `Shool(x).get().mean()` instead")
174
+ def avg(self) -> float:
175
+ """Get the average of the count of intersection.
176
+
177
+ .. deprecated:: 0.6.0
178
+ Use :meth:`Shool(x).get().mean()` instead.
179
+ """
180
+
181
+ return self.get().mean()
182
+
183
+ @deprecated("Use `Shool(x).get().std()` instead")
184
+ def std(self) -> float:
185
+ """Get the std of the count of intersection.
186
+
187
+ .. deprecated:: 0.6.0
188
+ Use :meth:`Shool(x).get().std()` instead.
189
+ """
190
+
191
+ return self.get().std()
192
+
193
+ @deprecated("Use `Shool(x).get().sum()` instead")
194
+ def sum(self) -> int:
195
+ """Get the sum of the count of intersection.
196
+
197
+ .. deprecated:: 0.6.0
198
+ Use :meth:`Shool(x).get().sum()` instead.
199
+ """
200
+
201
+ return self.get().sum()
@@ -0,0 +1,183 @@
1
+
2
+ # SPDX-FileCopyrightText: 2022 - 2025 Zexin Yuan <pypi@yzx9.xyz>
3
+ #
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ """Plot trunk and florets."""
7
+
8
+ # pylint: disable=invalid-name
9
+
10
+ from collections.abc import Iterable
11
+ from itertools import chain
12
+ from typing import Any, Literal, cast
13
+
14
+ import numpy as np
15
+ import numpy.typing as npt
16
+ from matplotlib import cm
17
+ from matplotlib.axes import Axes
18
+ from matplotlib.figure import Figure
19
+ from matplotlib.patches import Circle, Ellipse, Patch, Rectangle
20
+
21
+ from swcgeom.analysis.visualization import draw
22
+ from swcgeom.core import Tree, get_subtree, to_subtree
23
+ from swcgeom.utils import get_fig_ax, mvee
24
+
25
__all__ = ["draw_trunk"]

# Supported floret bound shapes: axis-aligned bounding box or ellipse.
Bounds = Literal["aabb", "ellipse"]
# Supported projections; only "2d" is implemented (others raise ValueError).
Projection = Literal["2d"]


def _get_cmap(cmap: Any) -> Any:
    """Resolve a colormap name/object without the removed `matplotlib.cm.get_cmap`."""
    import matplotlib  # pylint: disable=import-outside-toplevel

    registry = getattr(matplotlib, "colormaps", None)
    if registry is not None and hasattr(registry, "get_cmap"):
        return registry.get_cmap(cmap)  # matplotlib >= 3.6
    return cm.get_cmap(cmap)  # older matplotlib (this function was removed in 3.9)


def draw_trunk(
    t: Tree,
    florets: Iterable[int | Iterable[int]],
    *,
    fig: Figure | None = None,
    ax: Axes | None = None,
    bound: Bounds | tuple[Bounds, dict[str, Any]] | None = "ellipse",
    point: bool | dict[str, Any] = True,
    projection: Projection = "2d",
    cmap: Any = "viridis",
    **kwargs,
) -> tuple[Figure, Axes]:
    """Draw trunk tree.

    Args:
        t: Tree
        florets: The florets that need to be removed.
            Each floret can be a subtree or multiple subtrees (e.g., dendrites are a
            bunch of subtrees), each number is the id of a tree node.
        fig: Figure to plot on.
        ax: Axes to plot on.
        bound: Kind of bound, support 'aabb', 'ellipse'.
            If bound is None, no bound will be drawn. If bound is a tuple, the second
            item will be used as kwargs and forwarded to the draw function.
        point: Draw point at the start of a subtree.
            If point is False, no point will be drawn. If point is a dict, it will
            be used as kwargs and forwarded to the draw function (a ``color`` key
            overrides the ratio color).
        projection: Projection used for bounds and points; only "2d".
        cmap: Colormap.
            Any value supported by ~matplotlib.cm.Colormap. We will use the ratio of
            the length of the subtree to the total length of the tree to determine the
            color.
        **kwargs: Forward to ~swcgeom.analysis.draw.

    Returns:
        The figure and axes that were drawn on.
    """
    # pylint: disable=too-many-locals
    trunk, tss = split_florets(t, florets)
    lens = get_length_ratio(t, tss)

    cmap = _get_cmap(cmap)
    c = cmap(lens)  # one RGBA color per floret, from its length ratio

    fig, ax = get_fig_ax(fig, ax)
    if bound is not None:
        for ts, cc in zip(tss, c):
            draw_bound(ts, ax, bound, projection, color=cc)

    if point is not False:
        point_kwargs = point if isinstance(point, dict) else {}
        for ts, cc in zip(tss, c):
            # Merge so a user `color` overrides the ratio color instead of raising
            # "got multiple values for keyword argument 'color'".
            draw_point(ts, ax, projection, **{"color": cc, **point_kwargs})

    # NOTE(review): an integer argument indexes the colormap's lookup table, so
    # `cmap(1)` is its second entry (the low end), not `cmap(1.0)` (the top).
    # Kept unchanged; confirm which color the trunk is meant to have.
    draw(trunk, ax=ax, color=cmap(1), **kwargs)
    return fig, ax
83
+
84
+
85
def split_florets(
    t: Tree, florets: Iterable[int | Iterable[int]]
) -> tuple[Tree, list[list[Tree]]]:
    """Split ``t`` into its trunk and the subtrees of each floret.

    Args:
        t: The tree to split.
        florets: Each item is one node id or an iterable of node ids. Each id is
            passed to ``get_subtree``.

    Returns:
        ``(trunk, subtrees)``. ``trunk`` is ``to_subtree(t, <all floret ids>)``.
        ``subtrees[i]`` holds one ``get_subtree`` result per id of floret ``i``.
    """
    # Materialize each floret: its ids are read twice (subtrees, then trunk). A
    # one-shot iterator (e.g. the generator from `get_dendrites`) would be empty
    # on the second pass, and those subtrees would silently stay in the trunk.
    groups = [[i] if isinstance(i, (int, np.integer)) else list(i) for i in florets]
    subtrees = [[get_subtree(t, node_id) for node_id in group] for group in groups]
    trunk = to_subtree(t, chain.from_iterable(groups))
    return trunk, subtrees
92
+
93
+
94
def get_length_ratio(t: Tree, tss: list[list[Tree]]) -> Any:
    """Return, per floret, the summed length of its subtrees divided by ``t.length()``."""
    floret_lengths = [sum(sub.length() for sub in group) for group in tss]
    total_length = t.length()
    return np.array(floret_lengths) / total_length
97
+
98
+
99
+ # Bounds
100
+
101
+
102
def draw_bound(
    ts: Iterable[Tree],
    ax: Axes,
    bound: Bounds | tuple[Bounds, dict[str, Any]],
    projection: Projection,
    **kwargs,
) -> None:
    """Add a bound patch enclosing the trees ``ts`` to ``ax``.

    Args:
        ts: Trees to enclose.
        ax: Axes to add the patch to.
        bound: Bound kind, or ``(kind, options)``. The options are forwarded to
            the patch and take precedence over ``kwargs``.
        projection: Only "2d" is supported.
        **kwargs: Default patch options (e.g. ``color``).

    Raises:
        ValueError: For an unsupported projection or bound kind.
    """
    kind, bound_kwargs = (bound, {}) if isinstance(bound, str) else bound
    if projection != "2d":
        raise ValueError(f"unsupported projection {projection}")

    # Merge instead of passing two `**` mappings: a key present in both (e.g.
    # `color`) would otherwise raise TypeError.
    patch = create_bound_2d(ts, kind, **{**kwargs, **bound_kwargs})
    ax.add_patch(patch)
116
+
117
+
118
+ def create_bound_2d(ts: Iterable[Tree], bound: Bounds, **kwargs) -> Patch:
119
+ xyz = np.concatenate([t.xyz() for t in ts])
120
+ xy = xyz[:, :2] # TODO: camera
121
+
122
+ if bound == "aabb":
123
+ return create_aabb_2d(xy, **kwargs)
124
+ if bound == "ellipse":
125
+ return create_ellipse_2d(xy, **kwargs)
126
+ raise ValueError(f"unsupported bound `{bound}` in 2d projection")
127
+
128
+
129
def create_aabb_2d(xy: npt.NDArray, fill: bool = False, **kwargs) -> Rectangle:
    """Axis-aligned bounding rectangle of the 2D points ``xy`` (one point per row)."""
    lower = xy.min(axis=0)
    upper = xy.max(axis=0)
    extent = upper - lower
    return Rectangle(
        xy=(lower[0], lower[1]),
        width=extent[0],
        height=extent[1],
        angle=0,
        fill=fill,
        **kwargs,
    )
137
+
138
+
139
def create_ellipse_2d(xy: npt.NDArray, fill: bool = False, **kwargs) -> Ellipse:
    """Ellipse patch built from ``mvee(xy)`` (presumably the minimum-volume enclosing ellipse).

    NOTE(review): matplotlib's ``Ellipse`` expects full ``width``/``height`` and an
    ``angle`` in degrees. ``mvee``'s ``a``, ``b`` and ``alpha`` are passed through
    unchanged. Confirm they are full axis lengths and degrees (not semi-axes or
    radians); otherwise the drawn bound is scaled or rotated incorrectly.
    """
    fit = mvee(xy)
    return Ellipse(
        xy=fit.centroid,  # type:ignore
        width=fit.a,
        height=fit.b,
        angle=fit.alpha,
        fill=fill,
        **kwargs,
    )
150
+
151
+
152
+ # point
153
+
154
+
155
def draw_point(ts: Iterable[Tree], ax: Axes, projection: Projection, **kwargs) -> None:
    """Add a circle marking the start of the trees ``ts`` to ``ax``.

    Raises:
        ValueError: If ``projection`` is not "2d".
    """
    if projection != "2d":
        raise ValueError(f"unsupported projection {projection}")

    ax.add_patch(create_point_2d(ts, **kwargs))
162
+
163
+
164
def create_point_2d(
    ts: Iterable[Tree], radius: float | None = None, **kwargs
) -> Circle:
    """Circle centered at the mean x/y of the first node of each tree in ``ts``.

    Args:
        ts: Trees; the first row of each tree's ``xyz()`` is used as its start.
        radius: Circle radius. Defaults to 5% of the smaller side of the x/y
            bounding box of all nodes in ``ts``.
        **kwargs: Forwarded to ``matplotlib.patches.Circle``.
    """
    # Read each tree's coordinates exactly once. `ts` may be a one-shot iterator,
    # which the previous double pass left empty (NaN center). This also caches
    # xyz() for the radius computation.
    coords = [t.xyz() for t in ts]

    if radius is None:
        xyz = np.concatenate(coords)
        radius = 0.05 * min(
            xyz[:, 0].max() - xyz[:, 0].min(),
            xyz[:, 1].max() - xyz[:, 1].min(),
        )
    radius = cast(float, radius)

    center = np.mean([c[0, :2] for c in coords], axis=0)
    return Circle(center, radius, **kwargs)
177
+
178
+
179
+ # Helpers
180
+
181
+
182
def get_dendrites(tree: Tree) -> list[int]:
    """Return the id of the first node of each subtree from ``tree.get_dendrites()``.

    Suitable as florets for ``draw_trunk`` (e.g. ``[get_dendrites(t)]`` as one
    floret). A list, not a generator, is returned so the ids can be iterated more
    than once.
    """
    return [t.node(0).id for t in tree.get_dendrites()]
@@ -0,0 +1,191 @@
1
+
2
+ # SPDX-FileCopyrightText: 2022 - 2025 Zexin Yuan <pypi@yzx9.xyz>
3
+ #
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+
19
+
20
+ """Painter utils."""
21
+
22
+ import os
23
+ import weakref
24
+ from typing import Any, Literal
25
+
26
+ import numpy as np
27
+ from matplotlib.axes import Axes
28
+ from matplotlib.figure import Figure
29
+ from matplotlib.legend import Legend
30
+
31
+ from swcgeom.core import SWCLike, Tree
32
+ from swcgeom.utils import (
33
+ CameraOptions,
34
+ SimpleCamera,
35
+ draw_direction_indicator,
36
+ draw_lines,
37
+ get_fig_ax,
38
+ palette,
39
+ )
40
+
41
__all__ = ["draw"]

# Direction-indicator position: a named corner or a custom (x, y) pair.
Positions = Literal["lt", "lb", "rt", "rb"] | tuple[float, float]

# (x, y) for each named corner ("l"/"r" = left/right, "t"/"b" = top/bottom),
# passed as `loc` to `draw_direction_indicator` (presumably axes-fraction units).
locations: dict[Literal["lt", "lb", "rt", "rb"], tuple[float, float]] = dict(
    lt=(0.10, 0.90),
    lb=(0.10, 0.10),
    rt=(0.90, 0.90),
    rb=(0.90, 0.10),
)

# Per-axes drawing state ("swc", "labels", "handles", "color"). Weak keys let an
# entry disappear once its Axes object is garbage-collected.
ax_weak_memo = weakref.WeakKeyDictionary[Axes, dict[str, Any]]()
52
+
53
+
54
def draw(
    swc: SWCLike | str,
    *,
    fig: Figure | None = None,
    ax: Axes | None = None,
    show: bool | None = None,
    camera: CameraOptions = "xy",
    color: dict[int, str] | str | None = None,
    label: str | bool = True,
    direction_indicator: Positions | Literal[False] = "rb",
    unit: str | None = None,
    **kwargs,
) -> tuple[Figure, Axes]:
    r"""Draw a neuron tree as 2D line segments.

    Args:
        swc: The tree to draw, or a path to an SWC file (loaded with
            ``Tree.from_swc``).
        fig: The figure to plot on.
        ax: The axes to plot on.
        show: Whether to call ``fig.show(warn=False)`` at the end. When not
            specified, it is called only if no ``ax`` was passed in.
        camera: Camera options (position, look-at, up).
            One, two, or three vectors are supported. A single vector is treated
            as look-at, giving the camera ((0, 0, 0), look-at, (0, 1, 0)). Two
            vectors are treated as (look-at, up), giving ((0, 0, 0), look-at, up).
            The presets "xy", "yz" and "zx" are the easiest way to choose a view.
        color: Color specification. A dict maps SWC node types to colors (resolved
            per segment by ``get_ax_color``). A string is used for every segment,
            and "vaa3d" selects the Vaa3D palette. When omitted, the next default
            palette color of this axes is used.
        label: Legend label; ``True`` uses the file name of ``swc.source`` and
            ``False`` leaves this tree out of the legend.
        direction_indicator: Draw an xyz direction indicator at 'lt', 'lb', 'rt',
            'rb' or a custom position; ``False`` disables it.
        unit: Unit text, e.g. r"$\mu m$", written near the top-left corner.
        **kwargs: Forwarded to `~matplotlib.collections.LineCollection`.

    The aspect ratio, spines, direction indicator and unit are only set up by the
    first tree drawn on an axes. Later trees enable a legend instead.

    Returns:
        fig: The figure that was drawn on.
        ax: The axes that were drawn on.
    """
    # pylint: disable=too-many-locals
    if isinstance(swc, str):
        swc = Tree.from_swc(swc)

    # Decide before `ax` is rebound: by default, only show figures created here.
    should_show = show is True or (show is None and ax is None)
    fig, ax = get_fig_ax(fig, ax)

    cam = SimpleCamera.from_options(camera)
    seg_color = get_ax_color(ax, swc, color)

    # One line per node after the first (presumably the root), joining the node
    # to its parent; `id`/`pid` values index rows of `xyz` directly.
    coords = swc.xyz()
    child_idx = swc.id()[1:]
    parent_idx = swc.pid()[1:]
    segments = np.stack([coords[child_idx], coords[parent_idx]], axis=1)
    handle = draw_lines(ax, segments, camera=cam, color=seg_color, **kwargs)

    ax.autoscale()
    _set_ax_memo(ax, swc, label=label, handle=handle)

    if len(get_ax_swc(ax)) != 1:
        # Several trees share these axes: tell them apart with a legend.
        set_ax_legend(ax, loc="upper right")
    else:
        # First tree on these axes: one-time decoration.
        ax.set_aspect(1)
        ax.spines[["top", "right"]].set_visible(False)
        if direction_indicator is not False:
            if isinstance(direction_indicator, str):
                loc = locations[direction_indicator]
            else:
                loc = direction_indicator
            draw_direction_indicator(ax, camera=cam, loc=loc)
        if unit is not None:
            ax.text(0.05, 0.95, unit, transform=ax.transAxes)

    if should_show:
        fig.show(warn=False)

    return fig, ax
133
+
134
+
135
def get_ax_swc(ax: Axes) -> list[SWCLike]:
    """Return the trees recorded as drawn on ``ax``, in drawing order.

    The returned list is the one stored in the per-axes memo (the same list that
    ``_set_ax_memo`` appends to). An axes with nothing drawn yet gets a new empty
    list instead of raising ``KeyError``.
    """
    return ax_weak_memo.setdefault(ax, {}).setdefault("swc", [])
138
+
139
+
140
def get_ax_color(
    ax: Axes,
    swc: SWCLike,
    color: dict[int, str] | str | None = None,  # TODO: improve typing
) -> str | list[str]:
    """Resolve the segment color(s) used to draw ``swc`` on ``ax``.

    Args:
        ax: Axes whose memo keeps a counter that cycles ``palette.default``.
        swc: The tree being drawn.
        color: A color string (returned unchanged), "vaa3d" (replaced by
            ``palette.vaa3d``, presumably a type-to-color dict), a dict mapping
            node types to colors, or ``None`` (next default palette color).

    Returns:
        A single color. For a dict, one color per segment (one segment per node
        after the first), chosen by the type of the segment's parent node; types
        missing from the dict fall back to the next default palette color.
    """
    if color == "vaa3d":
        color = palette.vaa3d
    elif isinstance(color, str):
        return color  # user specified

    # choose default: advance this axes' palette counter (first value is 0)
    memo = ax_weak_memo.setdefault(ax, {})
    memo["color"] = memo.get("color", -1) + 1
    c = palette.default[memo["color"] % len(palette.default)]

    if isinstance(color, dict):
        # Segment i joins node i+1 to its parent `pid[i+1]` (as in `draw`), so look
        # up the parent's type. The preceding node in array order differs from
        # the parent at branch points.
        parent_types = swc.type()[swc.pid()[1:]]
        return [color.get(node_type, c) for node_type in parent_types]

    return c
161
+
162
+
163
def set_ax_legend(ax: Axes, *args, **kwargs) -> Legend | None:
    """Create the legend of ``ax`` from the labels and handles recorded by draws.

    Entries whose label is ``False`` are skipped. Returns ``None`` without
    creating a legend when no entry remains, including for axes never drawn on.
    Extra arguments are forwarded to ``Axes.legend``.
    """
    memo = ax_weak_memo.get(ax, {})
    # Pair handles with labels positionally; zip tolerates lists of unequal length.
    entries = [
        (handle, label)
        for handle, label in zip(memo.get("handles", []), memo.get("labels", []))
        if label is not False
    ]
    if not entries:
        return None

    handles, labels = zip(*entries)
    return ax.legend(list(handles), list(labels), *args, **kwargs)
175
+
176
+
177
def _set_ax_memo(
    ax: Axes, swc: SWCLike, label: str | bool | None = None, handle: Any | None = None
):
    """Record ``swc``, and optionally its legend label and handle, for ``ax``.

    Args:
        ax: Axes the tree was drawn on.
        swc: The drawn tree; always appended to the memo's "swc" list.
        label: ``True`` uses the base name of ``swc.source``. Other values
            (including ``False``, which later hides the entry) are stored as
            given. ``None`` records no label.
        handle: Legend artist to record; ``None`` records no handle.
    """
    memo = ax_weak_memo.setdefault(ax, {})
    memo.setdefault("swc", []).append(swc)

    if label is not None:
        if label is True:
            # Guard against a missing `source` (e.g. None), which
            # os.path.basename rejects with TypeError.
            label = os.path.basename(swc.source or "")
        memo.setdefault("labels", []).append(label)

    if handle is not None:
        memo.setdefault("handles", []).append(handle)
@@ -0,0 +1,81 @@
1
+
2
+ # SPDX-FileCopyrightText: 2022 - 2025 Zexin Yuan <pypi@yzx9.xyz>
3
+ #
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ """Painter utils.
7
+
8
+ NOTE: This is a experimental function, it may be changed in the future.
9
+ """
10
+
11
+ import numpy as np
12
+ from matplotlib.axes import Axes
13
+ from matplotlib.figure import Figure
14
+ from mpl_toolkits.mplot3d import Axes3D
15
+
16
+ from swcgeom.analysis.visualization import (
17
+ _set_ax_memo,
18
+ get_ax_color,
19
+ get_ax_swc,
20
+ set_ax_legend,
21
+ )
22
+ from swcgeom.core import SWCLike, Tree
23
+ from swcgeom.utils.plotter_3d import draw_lines_3d
24
+
25
__all__ = ["draw3d"]


# TODO: support Camera
def draw3d(
    swc: SWCLike | str,
    *,
    ax: Axes,
    show: bool | None = None,
    color: dict[int, str] | str | None = None,  # TODO: improve typing
    label: str | bool = True,
    **kwargs,
) -> tuple[Figure, Axes]:
    r"""Draw neuron tree on 3D axes.

    Args:
        swc: The swc tree to draw.
            If it is str, then it is treated as the path of swc file.
        ax: The 3D axes (``mpl_toolkits.mplot3d.Axes3D``) to plot on. Required.
        show: Whether to call ``fig.show(warn=False)`` at the end. Because ``ax``
            is always passed in, this happens only when ``show`` is True.
        color: Color map.
            If it is a dict, segments are colored by node type (see
            ``get_ax_color``). If it is a string, that value is used for every
            segment.
        label: Label of legend, disable if False.
        **kwargs: Forwarded to `~mpl_toolkits.mplot3d.art3d.Line3DCollection`.

    The axis limits are set to cover every tree drawn on ``ax`` so far.

    Returns:
        fig: The figure containing ``ax``.
        ax: The axes that were drawn on.
    """
    # NOTE: `assert` is skipped under `python -O`; kept to preserve the
    # AssertionError callers currently get for non-3D axes.
    assert isinstance(ax, Axes3D), "only support 3D axes."

    swc = Tree.from_swc(swc) if isinstance(swc, str) else swc

    # Same rule as the 2D `draw`; since `ax` is required, the default is not to show.
    show = (show is True) or (show is None and ax is None)
    my_color = get_ax_color(ax, swc, color)

    xyz = swc.xyz()
    starts, ends = swc.id()[1:], swc.pid()[1:]
    lines = np.stack([xyz[starts], xyz[ends]], axis=1)
    collection = draw_lines_3d(ax, lines, color=my_color, **kwargs)

    _set_ax_memo(ax, swc, label=label, handle=collection)
    n_trees = len(get_ax_swc(ax))

    if lines.size > 0:  # a single-node tree has no segments to bound
        points = lines.reshape(-1, 3)
        lo, hi = points.min(axis=0), points.max(axis=0)
        if n_trees > 1:
            # Extend the current limits instead of cropping to this tree alone, so
            # trees drawn earlier on these axes stay in view.
            prev = [ax.get_xlim(), ax.get_ylim(), ax.get_zlim()]
            lo = np.minimum(lo, [min(p) for p in prev])
            hi = np.maximum(hi, [max(p) for p in prev])
        ax.set_xlim(lo[0], hi[0])
        ax.set_ylim(lo[1], hi[1])
        ax.set_zlim(lo[2], hi[2])

    if n_trees == 1:
        # ax.set_aspect(1)
        ax.spines[["top", "right"]].set_visible(False)
    else:
        set_ax_legend(ax, loc="upper right")  # enable legend

    fig = ax.figure
    if show:
        fig.show(warn=False)  # type: ignore
    return fig, ax  # type: ignore