swcgeom-0.20.0-cp312-cp312-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of swcgeom might be problematic. Click here for more details.

Files changed (72)
  1. swcgeom/__init__.py +21 -0
  2. swcgeom/analysis/__init__.py +13 -0
  3. swcgeom/analysis/feature_extractor.py +454 -0
  4. swcgeom/analysis/features.py +218 -0
  5. swcgeom/analysis/lmeasure.py +750 -0
  6. swcgeom/analysis/sholl.py +201 -0
  7. swcgeom/analysis/trunk.py +183 -0
  8. swcgeom/analysis/visualization.py +191 -0
  9. swcgeom/analysis/visualization3d.py +81 -0
  10. swcgeom/analysis/volume.py +143 -0
  11. swcgeom/core/__init__.py +19 -0
  12. swcgeom/core/branch.py +129 -0
  13. swcgeom/core/branch_tree.py +65 -0
  14. swcgeom/core/compartment.py +107 -0
  15. swcgeom/core/node.py +130 -0
  16. swcgeom/core/path.py +155 -0
  17. swcgeom/core/population.py +394 -0
  18. swcgeom/core/swc.py +247 -0
  19. swcgeom/core/swc_utils/__init__.py +19 -0
  20. swcgeom/core/swc_utils/assembler.py +35 -0
  21. swcgeom/core/swc_utils/base.py +180 -0
  22. swcgeom/core/swc_utils/checker.py +112 -0
  23. swcgeom/core/swc_utils/io.py +335 -0
  24. swcgeom/core/swc_utils/normalizer.py +163 -0
  25. swcgeom/core/swc_utils/subtree.py +70 -0
  26. swcgeom/core/tree.py +387 -0
  27. swcgeom/core/tree_utils.py +277 -0
  28. swcgeom/core/tree_utils_impl.py +58 -0
  29. swcgeom/images/__init__.py +9 -0
  30. swcgeom/images/augmentation.py +149 -0
  31. swcgeom/images/contrast.py +87 -0
  32. swcgeom/images/folder.py +217 -0
  33. swcgeom/images/io.py +578 -0
  34. swcgeom/images/loaders/__init__.py +8 -0
  35. swcgeom/images/loaders/pbd.cpython-312-darwin.so +0 -0
  36. swcgeom/images/loaders/pbd.pyx +523 -0
  37. swcgeom/images/loaders/raw.cpython-312-darwin.so +0 -0
  38. swcgeom/images/loaders/raw.pyx +183 -0
  39. swcgeom/transforms/__init__.py +20 -0
  40. swcgeom/transforms/base.py +136 -0
  41. swcgeom/transforms/branch.py +223 -0
  42. swcgeom/transforms/branch_tree.py +74 -0
  43. swcgeom/transforms/geometry.py +270 -0
  44. swcgeom/transforms/image_preprocess.py +107 -0
  45. swcgeom/transforms/image_stack.py +219 -0
  46. swcgeom/transforms/images.py +206 -0
  47. swcgeom/transforms/mst.py +183 -0
  48. swcgeom/transforms/neurolucida_asc.py +498 -0
  49. swcgeom/transforms/path.py +56 -0
  50. swcgeom/transforms/population.py +36 -0
  51. swcgeom/transforms/tree.py +265 -0
  52. swcgeom/transforms/tree_assembler.py +160 -0
  53. swcgeom/utils/__init__.py +18 -0
  54. swcgeom/utils/debug.py +23 -0
  55. swcgeom/utils/download.py +119 -0
  56. swcgeom/utils/dsu.py +58 -0
  57. swcgeom/utils/ellipse.py +131 -0
  58. swcgeom/utils/file.py +90 -0
  59. swcgeom/utils/neuromorpho.py +581 -0
  60. swcgeom/utils/numpy_helper.py +70 -0
  61. swcgeom/utils/plotter_2d.py +134 -0
  62. swcgeom/utils/plotter_3d.py +35 -0
  63. swcgeom/utils/renderer.py +145 -0
  64. swcgeom/utils/sdf.py +324 -0
  65. swcgeom/utils/solid_geometry.py +154 -0
  66. swcgeom/utils/transforms.py +367 -0
  67. swcgeom/utils/volumetric_object.py +483 -0
  68. swcgeom-0.20.0.dist-info/METADATA +86 -0
  69. swcgeom-0.20.0.dist-info/RECORD +72 -0
  70. swcgeom-0.20.0.dist-info/WHEEL +5 -0
  71. swcgeom-0.20.0.dist-info/licenses/LICENSE +201 -0
  72. swcgeom-0.20.0.dist-info/top_level.txt +1 -0
swcgeom/core/tree.py ADDED
@@ -0,0 +1,387 @@
1
+ # SPDX-FileCopyrightText: 2022 - 2025 Zexin Yuan <pypi@yzx9.xyz>
2
+ #
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ """Neuron tree."""
6
+
7
+ import itertools
8
+ import os
9
+ from collections.abc import Callable, Iterable, Iterator
10
+ from typing import Literal, TypeVar, Union, overload
11
+
12
+ import numpy as np
13
+ import numpy.typing as npt
14
+ import pandas as pd
15
+ from typing_extensions import deprecated
16
+
17
+ from swcgeom.core.branch import Branch
18
+ from swcgeom.core.compartment import Compartment, Compartments
19
+ from swcgeom.core.node import Node
20
+ from swcgeom.core.path import Path
21
+ from swcgeom.core.swc import DictSWC, eswc_cols
22
+ from swcgeom.core.swc_utils import SWCNames, get_names, read_swc, traverse
23
+ from swcgeom.core.tree_utils_impl import Mapping, get_subtree_impl
24
+ from swcgeom.utils import PathOrIO, padding1d
25
+
26
__all__ = ["Tree"]

# Type variables for the `traverse` overloads: T is the value returned by an
# `enter` callback (each call receives its parent's value), K is the value
# returned by a `leave` callback (each call receives its children's values).
T = TypeVar("T")
K = TypeVar("K")
30
+
31
+
32
class Tree(DictSWC):
    """A neuron tree, which should be a binary tree in most cases.

    Node attributes are kept as one array per column (see `keys`). Several
    helpers assume that node ids coincide with array positions and that the
    root sits at index 0 (e.g. `soma`, `get_compartments`, `Node.parent`).
    """

    class Node(Node["Tree"]):
        """Neural node."""

        def parent(self) -> Union["Tree.Node", None]:
            """Return the parent node, or None when this node is a root."""
            # The parent id is used directly as the node index.
            return Tree.Node(self.attach, self.pid) if self.pid != -1 else None

        def children(self) -> list["Tree.Node"]:
            """Return every node whose parent id equals this node's id."""
            children = self.attach.id()[self.attach.pid() == self.id]
            return [Tree.Node(self.attach, idx) for idx in children]

        def branch(self) -> "Tree.Branch":
            """Return the branch that contains this node.

            Ancestors are collected upwards until a furcation or the root is
            reached (inclusive); then the single child is followed downwards
            until a furcation or a tip is reached. The resulting nodes are
            ordered from the top of the branch downwards.
            """
            ns: list["Tree.Node"] = [self]
            while not ns[-1].is_furcation() and (p := ns[-1].parent()) is not None:
                ns.append(p)

            ns.reverse()
            while not (ns[-1].is_furcation() or ns[-1].is_tip()):
                ns.append(ns[-1].children()[0])

            return Tree.Branch(self.attach, [n.id for n in ns])

        def radial_distance(self) -> float:
            """The end-to-end straight-line distance to soma."""
            return self.distance(self.attach.soma())

        def subtree(self, *, out_mapping: Mapping | None = None) -> "Tree":
            """Get subtree from node.

            Args:
                out_mapping: Map from new id to old id.

            Returns:
                A new `Tree` rooted at this node.
            """

            n_nodes, ndata, source, names = get_subtree_impl(
                self.attach, self.id, out_mapping=out_mapping
            )
            return Tree(n_nodes, **ndata, source=source, names=names)

        def is_root(self) -> bool:
            """Whether this node has no parent."""
            return self.parent() is None

        def is_soma(self) -> bool:  # TODO: support multi soma, e.g. 3 points
            """Whether this node is a root whose type is the soma type."""
            return self.type == self.attach.types.soma and self.is_root()

        @overload
        def traverse(
            self, *, enter: Callable[[Node, T | None], T], mode: Literal["dfs"] = ...
        ) -> None: ...
        @overload
        def traverse(
            self, *, leave: Callable[[Node, list[K]], K], mode: Literal["dfs"] = ...
        ) -> K: ...
        @overload
        def traverse(
            self,
            *,
            enter: Callable[[Node, T | None], T],
            leave: Callable[[Node, list[K]], K],
            mode: Literal["dfs"] = ...,
        ) -> K: ...
        def traverse(self, **kwargs):  # type: ignore
            """Traverse from node.

            See Also:
                ~Tree.traverse
            """
            return self.attach.traverse(root=self.idx, **kwargs)

    class Path(Path["Tree"]):
        # TODO: should returns `Tree.Node`
        """Neural path."""

    class Compartment(Compartment["Tree"]):
        # TODO: should returns `Tree.Node`
        """Neural compartment."""

    Segment = Compartment  # Alias

    class Branch(Branch["Tree"]):
        # TODO: should returns `Tree.Node`
        """Neural branch."""

    def __init__(
        self,
        n_nodes: int,
        *,
        source: str = "",
        comments: Iterable[str] | None = None,
        names: SWCNames | None = None,
        **kwargs: npt.NDArray,
    ) -> None:
        """Create a tree from per-column node data.

        Args:
            n_nodes: Number of nodes.
            source: Where the tree comes from, e.g. the path of an swc file.
            comments: Comment lines kept alongside the data.
            names: Column names, resolved through `get_names`.
            **kwargs: Column arrays keyed by column name. Missing ids default
                to `0 .. n_nodes-1` and missing parent ids to a single chain
                (`-1, 0, 1, ...`). The standard columns are padded to
                `n_nodes` with `padding1d` (radius padded with 1); any other
                columns are passed through unchanged.
        """
        names = get_names(names)

        if names.id not in kwargs:
            kwargs[names.id] = np.arange(0, n_nodes, step=1, dtype=np.int32)

        if names.pid not in kwargs:
            kwargs[names.pid] = np.arange(-1, n_nodes - 1, step=1, dtype=np.int32)

        ndata = {
            names.id: padding1d(n_nodes, kwargs.pop(names.id, None), dtype=np.int32),
            names.type: padding1d(
                n_nodes, kwargs.pop(names.type, None), dtype=np.int32
            ),
            names.x: padding1d(n_nodes, kwargs.pop(names.x, None), dtype=np.float32),
            names.y: padding1d(n_nodes, kwargs.pop(names.y, None), dtype=np.float32),
            names.z: padding1d(n_nodes, kwargs.pop(names.z, None), dtype=np.float32),
            names.r: padding1d(
                n_nodes, kwargs.pop(names.r, None), dtype=np.float32, padding_value=1
            ),
            names.pid: padding1d(n_nodes, kwargs.pop(names.pid, None), dtype=np.int32),
        }
        # ? padding other columns
        super().__init__(
            **ndata, **kwargs, source=source, comments=comments, names=names
        )

    def __iter__(self) -> Iterator[Node]:
        """Iterate over nodes in index order."""
        return (self[i] for i in range(len(self)))

    def __repr__(self) -> str:
        n_nodes, n_edges = self.number_of_nodes(), self.number_of_edges()
        return f"Neuron Tree with {n_nodes} nodes and {n_edges} edges"

    @overload
    def __getitem__(self, key: slice) -> list[Node]: ...
    @overload
    def __getitem__(self, key: int) -> Node: ...
    @overload
    def __getitem__(self, key: str) -> npt.NDArray: ...
    def __getitem__(self, key):
        """Index nodes by position/slice, or fetch a column by name.

        Raises:
            IndexError: If an integer index is out of range.
            TypeError: If `key` is not a slice, integer or string.
        """
        if isinstance(key, slice):
            return [self.node(i) for i in range(*key.indices(len(self)))]

        if isinstance(key, (int, np.integer)):
            length = len(self)
            if key < -length or key >= length:
                raise IndexError(f"The index ({key}) is out of range.")

            if key < 0:  # Handle negative indices
                key += length

            return self.node(key)

        if isinstance(key, str):
            return self.get_ndata(key)

        raise TypeError("Invalid argument type.")

    def keys(self) -> Iterable[str]:
        """Names of all node data columns."""
        return self.ndata.keys()

    def node(self, idx: int | np.integer) -> Node:
        """Return the node at index `idx`."""
        return self.Node(self, idx)

    def soma(self, type_check: bool = True) -> Node:
        """Get soma of neuron.

        The first node is taken as the soma.

        Raises:
            ValueError: If `type_check` is set and the first node is not of
                the soma type.
        """
        # TODO: find soma, see also: https://neuromorpho.org/myfaq.jsp
        n = self.node(0)
        if type_check and n.type != self.types.soma:
            raise ValueError(f"no soma found in: {self.source}")
        return n

    def get_furcations(self) -> list[Node]:
        """Get all nodes that have more than one child."""
        furcations: list[int] = []

        def collect_furcations(n: Tree.Node, children: list[None]) -> None:
            if len(children) > 1:
                furcations.append(n.id)

        self.traverse(leave=collect_furcations)
        return [self.node(i) for i in furcations]

    @deprecated("Use `get_furcations` instead")
    def get_bifurcations(self) -> list[Node]:
        """Get all node of furcations.

        .. deprecated:: 0.17.2
            Deprecated due to the wrong spelling of furcation. For now, it is just an
            alias of `get_furcations` and raise a warning. It will be change to raise
            an error in the future.
        """
        return self.get_furcations()

    def get_tips(self) -> list[Node]:
        """Get all tips, i.e. nodes that are not the parent of any node."""
        tip_ids = np.setdiff1d(self.id(), self.pid(), assume_unique=True)
        return [self.node(i) for i in tip_ids]

    def get_compartments(self) -> Compartments[Compartment]:
        """Get one compartment per non-root node, from its parent to itself.

        Assumes the root is the node at index 0.
        """
        return Compartments(self.Compartment(self, n.pid, n.id) for n in self[1:])

    def get_segments(self) -> Compartments[Compartment]:  # Alias
        """Alias of `get_compartments`."""
        return self.get_compartments()

    def get_branches(self) -> list[Branch]:
        """Split the tree into branches.

        A branch is a maximal chain of single-child nodes, closed at the root,
        at furcations and at tips. Every branch is ordered from its top node
        downwards.
        """

        def collect_branches(
            node: "Tree.Node", pre: list[tuple[list[Tree.Branch], list[int]]]
        ) -> tuple[list[Tree.Branch], list[int]]:
            if len(pre) == 1:
                # Single child: extend the chain coming up from below.
                branches, children = pre[0]
                children.append(node.id)
                return branches, children

            # Tip or furcation: close every chain coming up from the children.
            branches: list[Tree.Branch] = []
            for sub_branches, children in pre:
                children.append(node.id)
                children.reverse()  # chains are built bottom-up
                sub_branches.append(
                    Tree.Branch(self, np.array(children, dtype=np.int32))
                )
                sub_branches.reverse()
                branches.extend(sub_branches)

            return branches, [node.id]

        branches, children = self.traverse(leave=collect_branches)
        if len(children) > 0:
            # The chain left at the root was also built bottom-up; reverse it
            # so it is ordered top-down like the branches closed above.
            children.reverse()
            branches.append(Tree.Branch(self, np.array(children, dtype=np.int32)))

        return branches

    def get_paths(self) -> list[Path]:
        """Get all paths from the root to the tips."""
        path_dic: dict[int, list[int]] = {}

        def assign_path(n: Tree.Node, pre_path: list[int] | None) -> list[int]:
            path = [] if pre_path is None else pre_path.copy()
            path.append(n.id)
            path_dic[n.id] = path
            return path

        def collect_path(
            n: Tree.Node, children: list[list[list[int]]]
        ) -> list[list[int]]:
            if len(children) == 0:
                return [path_dic[n.id]]

            return list(itertools.chain(*children))

        paths = self.traverse(enter=assign_path, leave=collect_path)
        return [self.Path(self, idx) for idx in paths]

    def get_neurites(self, type_check: bool = True) -> Iterable["Tree"]:
        """Get neurites from soma: one subtree per child of the soma."""
        return (n.subtree() for n in self.soma(type_check).children())

    def get_dendrites(self, type_check: bool = True) -> Iterable["Tree"]:
        """Get the subtrees of soma children typed apical or basal dendrite."""
        types = [self.types.apical_dendrite, self.types.basal_dendrite]
        children = self.soma(type_check).children()
        return (n.subtree() for n in children if n.type in types)

    @overload
    def traverse(
        self,
        *,
        enter: Callable[[Node, T | None], T],
        root: int | np.integer = ...,
        mode: Literal["dfs"] = ...,
    ) -> None: ...
    @overload
    def traverse(
        self,
        *,
        leave: Callable[[Node, list[K]], K],
        root: int | np.integer = ...,
        mode: Literal["dfs"] = ...,
    ) -> K: ...
    @overload
    def traverse(
        self,
        *,
        enter: Callable[[Node, T | None], T],
        leave: Callable[[Node, list[K]], K],
        root: int | np.integer = ...,
        mode: Literal["dfs"] = ...,
    ) -> K: ...
    def traverse(self, *, enter=None, leave=None, **kwargs):
        """Traverse nodes.

        Args:
            enter: (n: Node, parent: T | None) => T
            leave: (n: Node, children: list[T]) => T

        See Also:
            ~swc_utils.traverse
        """

        def wrap(fn) -> Callable | None:
            # Adapt a node callback to the index-based `swc_utils.traverse`.
            if fn is None:
                return None

            def fn_wrapped(idx, *args, **kwargs):
                return fn(self[idx], *args, **kwargs)

            return fn_wrapped

        topology = (self.id(), self.pid())
        enter, leave = wrap(enter), wrap(leave)
        return traverse(topology, enter=enter, leave=leave, **kwargs)  # type: ignore

    def length(self) -> float:
        """Get length of tree, the sum of all compartment lengths."""
        return sum(s.length() for s in self.get_segments())

    @staticmethod
    def from_data_frame(
        df: pd.DataFrame,
        source: str = "",
        *,
        comments: Iterable[str] | None = None,
        names: SWCNames | None = None,
    ) -> "Tree":
        """Read neuron tree from data frame.

        Only the columns listed by `names.cols()` are read.
        """
        names = get_names(names)
        tree = Tree(
            df.shape[0],
            **{k: df[k].to_numpy() for k in names.cols()},
            source=source,
            comments=comments,
            names=names,
        )
        return tree

    @classmethod
    def from_swc(cls, swc_file: PathOrIO, **kwargs) -> "Tree":
        """Read neuron tree from swc file.

        Args:
            swc_file: A path (`str` or `os.PathLike`) or other input accepted
                by `read_swc`.
            **kwargs: Forwarded to `read_swc`.

        Raises:
            ValueError: If the file cannot be read.

        See Also:
            ~swcgeom.core.swc_utils.read_swc
        """

        try:
            df, comments = read_swc(swc_file, **kwargs)
        except Exception as e:  # pylint: disable=broad-except
            raise ValueError(f"fails to read swc: {swc_file}") from e

        # Record an absolute path as the source for every path-like input
        # (including `pathlib.Path`); other inputs have no path to record.
        is_path = isinstance(swc_file, (str, os.PathLike))
        source = os.path.abspath(swc_file) if is_path else ""
        return cls.from_data_frame(df, source=source, comments=comments)

    @classmethod
    def from_eswc(
        cls, swc_file: str, extra_cols: list[str] | None = None, **kwargs
    ) -> "Tree":
        """Read neuron tree from eswc file.

        Args:
            swc_file: Path of the eswc file.
            extra_cols: Additional columns to read besides the ESWC columns.
                The list passed in is not modified.
            **kwargs: Forwarded to `from_swc`.

        See Also:
            ~swcgeom.Tree.from_swc
        """
        # Build a fresh list: extending the caller's list in place would
        # mutate it and make repeated calls accumulate duplicate ESWC columns.
        cols = [*(extra_cols or []), *(name for name, _ in eswc_cols)]
        return cls.from_swc(swc_file, extra_cols=cols, **kwargs)
swcgeom/core/tree_utils.py ADDED
@@ -0,0 +1,277 @@
1
+
2
+ # SPDX-FileCopyrightText: 2022 - 2025 Zexin Yuan <pypi@yzx9.xyz>
3
+ #
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ """SWC util wrapper for tree."""
7
+
8
+ import warnings
9
+ from collections.abc import Callable, Iterable
10
+ from typing import TypeVar, overload
11
+
12
+ import numpy as np
13
+ from typing_extensions import deprecated
14
+
15
+ from swcgeom.core.swc import SWCLike
16
+ from swcgeom.core.swc_utils import (
17
+ REMOVAL,
18
+ SWCNames,
19
+ Topology,
20
+ get_names,
21
+ is_bifurcate,
22
+ propagate_removal,
23
+ sort_nodes_impl,
24
+ to_sub_topology,
25
+ )
26
+ from swcgeom.core.tree import Tree
27
+ from swcgeom.core.tree_utils_impl import Mapping, get_subtree_impl, to_subtree_impl
28
+
29
__all__ = [
    "sort_tree",
    "cut_tree",
    "to_sub_tree",
    "to_subtree",
    "get_subtree",
    "redirect_tree",
    "cat_tree",
]

# Type variables for the `cut_tree` callbacks: T is produced by an `enter`
# callback and handed to the children, K is produced by a `leave` callback and
# collected from the children.
T = TypeVar("T")
K = TypeVar("K")

# Distance below which `cat_tree` treats the two connection points as the same
# location (the duplicate node from the second tree is then dropped).
EPS = 1e-5
42
+
43
+
44
def is_binary_tree(tree: Tree, exclude_soma: bool = True) -> bool:
    """Check whether the tree topology is bifurcate (see `is_bifurcate`).

    Args:
        tree: The tree to check.
        exclude_soma: Forwarded as `exclude_root`, so the root is ignored.
    """
    topology = (tree.id(), tree.pid())
    return is_bifurcate(topology, exclude_root=exclude_soma)
47
+
48
+
49
def sort_tree(tree: Tree) -> Tree:
    """Sort the indices of neuron tree, working on a copy of the input.

    See Also:
        ~.core.swc_utils.sort_nodes
    """
    duplicate = tree.copy()
    return _sort_tree(duplicate)
56
+
57
+
58
@overload
def cut_tree(
    tree: Tree, *, enter: Callable[[Tree.Node, T | None], tuple[T, bool]]
) -> Tree: ...
@overload
def cut_tree(
    tree: Tree, *, leave: Callable[[Tree.Node, list[K]], tuple[K, bool]]
) -> Tree: ...
def cut_tree(tree: Tree, *, enter=None, leave=None):
    """Traverse and cut the tree.

    The callback returns ``(value, remove)``; returning a `True` remove flag
    deletes the current node and its children. `enter` takes precedence over
    `leave`; with neither callback a plain copy is returned.
    """
    marked: list[int] = []

    if not enter and not leave:
        return tree.copy()

    if enter:

        def on_enter(
            node: Tree.Node, inherited: tuple[T, bool] | None
        ) -> tuple[T, bool]:
            # A removed ancestor removes this node too; forward its state
            # unchanged so the flag keeps propagating downwards.
            if inherited is not None and inherited[1]:
                marked.append(node.id)
                return inherited

            outcome = enter(node, inherited[0] if inherited else None)
            value, drop = outcome
            if drop:
                marked.append(node.id)
            return value, drop

        tree.traverse(enter=on_enter)
    else:

        def on_leave(node: Tree.Node, below: list[K]) -> K:
            value, drop = leave(node, below)
            if drop:
                marked.append(node.id)
            return value

        tree.traverse(leave=on_leave)

    return to_subtree(tree, marked)
104
+
105
+
106
@deprecated("Use `to_subtree` instead")
def to_sub_tree(swc_like: SWCLike, sub: Topology) -> tuple[Tree, dict[int, int]]:
    """Create subtree from origin tree.

    Mark nodes for removal directly in `sub`; removal marks are propagated to
    the descendants of every marked node (via `propagate_removal`) before the
    subtree is built.

    .. deprecated:: 0.6.0
        Use :meth:`to_subtree` instead.

    Returns:
        tree: The new tree.
        id_map: Map from old id to new id.
    """

    sub = propagate_removal(sub)
    (new_id, new_pid), id_map_arr = to_sub_topology(sub)

    names = swc_like.names
    n_nodes = new_id.shape[0]
    ndata = {k: swc_like.get_ndata(k)[id_map_arr].copy() for k in swc_like.keys()}
    # Store the new topology under the configured column names; hard-coding
    # "id"/"pid" would keep the old ids when custom names are in use.
    ndata[names.id] = new_id
    ndata[names.pid] = new_pid

    subtree = Tree(n_nodes, **ndata, source=swc_like.source, names=names)

    # id_map_arr[i] is the old index of new node i; invert it to old -> new.
    id_map = {old: new for new, old in enumerate(id_map_arr)}
    return subtree, id_map
135
+
136
+
137
def to_subtree(
    swc_like: SWCLike,
    removals: Iterable[int],
    *,
    out_mapping: Mapping | None = None,
) -> Tree:
    """Create subtree from origin tree.

    Args:
        swc_like: SWCLike
        removals: A list of id of nodes to be removed. The removal is
            propagated to their descendants via `propagate_removal`.
        out_mapping: Map new id to old id.
    """
    marked_ids = swc_like.id().copy()
    for node_id in removals:
        marked_ids[node_id] = REMOVAL  # ids are used as array positions here

    topology = propagate_removal((marked_ids, swc_like.pid()))
    count, columns, origin, col_names = to_subtree_impl(
        swc_like, topology, out_mapping=out_mapping
    )
    return Tree(count, **columns, source=origin, names=col_names)
160
+
161
+
162
def get_subtree(
    swc_like: SWCLike, n: int, *, out_mapping: Mapping | None = None
) -> Tree:
    """Get subtree rooted at n.

    Args:
        swc_like: SWCLike
        n: Id of the root of the subtree.
        out_mapping: Map new id to old id.

    Returns:
        A new `Tree` built from the extracted node data.
    """
    extracted = get_subtree_impl(swc_like, n, out_mapping=out_mapping)
    count, columns, origin, col_names = extracted
    return Tree(count, **columns, source=origin, names=col_names)
177
+
178
+
179
def redirect_tree(tree: Tree, new_root: int, sort: bool = True) -> Tree:
    """Set root to new point and redirect tree graph.

    Works on a copy: the edges between `new_root` and the old root are
    reversed, and the types of the old and the new root are swapped.

    Args:
        tree: The tree.
        new_root: The id of new root.
        sort: If true, sort indices of nodes after redirect.
    """
    tree = tree.copy()

    # Chain of nodes from the new root up to the old root.
    chain = [tree.node(new_root)]
    ancestor = chain[-1].parent()
    while ancestor is not None:
        chain.append(ancestor)
        ancestor = chain[-1].parent()

    head, tail = chain[0], chain[-1]
    head.pid = -1
    head.type, tail.type = tail.type, head.type
    # Reverse every edge along the chain: each ancestor now hangs below the
    # node that used to be its child.
    for below, above in zip(chain[:-1], chain[1:]):
        above.pid = below.id

    if sort:
        _sort_tree(tree)

    return tree
202
+
203
+
204
def cat_tree(  # pylint: disable=too-many-arguments
    tree1: Tree,
    tree2: Tree,
    node1: int = 0,
    node2: int = 0,
    *,
    translate: bool = True,
    names: SWCNames | None = None,
    no_move: bool | None = None,  # legacy
) -> Tree:
    """Concatenates the second tree onto the first one.

    `tree2` is re-rooted at `node2` when needed and linked to `node1` of
    `tree1`. If the two connection points coincide (within `EPS`), `node2` is
    dropped and its children are linked to `node1` directly. Neither input is
    modified. Only the columns of `tree1` are kept; those missing from `tree2`
    are zero-padded.

    Args:
        tree1: Tree
        tree2: Tree
        node1: The node id of the tree to be connected.
        node2: The node id of the connection point.
        translate: Whether to translate node_2 to node_1.
            If False, add link between node_1 and node_2 without translate.
        names: Column names, resolved through `get_names`.
        no_move: Deprecated since v0.12.0, use `translate` instead.

    Returns:
        A new, sorted tree.
    """
    if no_move is not None:
        warnings.warn(
            "`no_move` has been deprecated, it is replaced by `translate` in "
            "v0.12.0, and this will be removed in next version",
            DeprecationWarning,
            stacklevel=2,  # attribute the warning to the caller
        )
        translate = not no_move

    names = get_names(names)
    tree, tree2 = tree1.copy(), tree2.copy()
    if not tree2.node(node2).is_root():
        tree2 = redirect_tree(tree2, node2, sort=False)

    c = tree.node(node1)
    if translate:
        tree2.ndata[names.x] -= tree2.node(node2).x - c.x
        tree2.ndata[names.y] -= tree2.node(node2).y - c.y
        tree2.ndata[names.z] -= tree2.node(node2).z - c.z

    ns = tree.number_of_nodes()
    n2 = tree2.number_of_nodes()  # read before the topology is modified
    if np.linalg.norm(tree2.node(node2).xyz() - c.xyz()) < EPS:
        # Coincident connection points: drop node2, link its children instead.
        remove = [node2 + ns]
        link_to_root = [n.id + ns for n in tree2.node(node2).children()]
    else:
        remove = None
        link_to_root = [node2 + ns]

    # APIs of tree2 are no longer available since we modify the topology
    tree2.ndata[names.id] += ns
    tree2.ndata[names.pid] += ns

    for k, v in tree.ndata.items():  # only keep keys in tree1
        if k in tree2.ndata:
            tree.ndata[k] = np.concatenate([v, tree2.ndata[k]])
        else:
            tree.ndata[k] = np.pad(v, (0, n2))

    for n in link_to_root:
        tree.node(n).pid = node1

    if remove is not None:  # TODO: This should be easy to implement during sort
        for k, v in tree.ndata.items():
            tree.ndata[k] = np.delete(v, remove)

    _sort_tree(tree)
    return tree
270
+
271
+
272
def _sort_tree(tree: Tree) -> Tree:
    """Sort the indices of neuron tree inplace.

    Every node column is reordered with the permutation returned by
    `sort_nodes_impl`, and the id / parent-id columns are replaced by the
    renumbered topology.

    Returns:
        The same tree object, for chaining.
    """
    names = tree.names
    (new_ids, new_pids), id_map = sort_nodes_impl((tree.id(), tree.pid()))
    tree.ndata = {k: v[id_map] for k, v in tree.ndata.items()}
    # Write back under the tree's own column names; hard-coded "id"/"pid"
    # keys would leave the real topology columns unsorted for custom names.
    tree.ndata[names.id] = new_ids
    tree.ndata[names.pid] = new_pids
    return tree