swcgeom 0.19.4__cp313-cp313-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of swcgeom might be problematic. Click here for more details.

Files changed (72)
  1. swcgeom/__init__.py +21 -0
  2. swcgeom/analysis/__init__.py +13 -0
  3. swcgeom/analysis/feature_extractor.py +454 -0
  4. swcgeom/analysis/features.py +218 -0
  5. swcgeom/analysis/lmeasure.py +750 -0
  6. swcgeom/analysis/sholl.py +201 -0
  7. swcgeom/analysis/trunk.py +183 -0
  8. swcgeom/analysis/visualization.py +191 -0
  9. swcgeom/analysis/visualization3d.py +81 -0
  10. swcgeom/analysis/volume.py +143 -0
  11. swcgeom/core/__init__.py +19 -0
  12. swcgeom/core/branch.py +129 -0
  13. swcgeom/core/branch_tree.py +65 -0
  14. swcgeom/core/compartment.py +107 -0
  15. swcgeom/core/node.py +130 -0
  16. swcgeom/core/path.py +155 -0
  17. swcgeom/core/population.py +341 -0
  18. swcgeom/core/swc.py +247 -0
  19. swcgeom/core/swc_utils/__init__.py +19 -0
  20. swcgeom/core/swc_utils/assembler.py +35 -0
  21. swcgeom/core/swc_utils/base.py +180 -0
  22. swcgeom/core/swc_utils/checker.py +107 -0
  23. swcgeom/core/swc_utils/io.py +204 -0
  24. swcgeom/core/swc_utils/normalizer.py +163 -0
  25. swcgeom/core/swc_utils/subtree.py +70 -0
  26. swcgeom/core/tree.py +384 -0
  27. swcgeom/core/tree_utils.py +277 -0
  28. swcgeom/core/tree_utils_impl.py +58 -0
  29. swcgeom/images/__init__.py +9 -0
  30. swcgeom/images/augmentation.py +149 -0
  31. swcgeom/images/contrast.py +87 -0
  32. swcgeom/images/folder.py +217 -0
  33. swcgeom/images/io.py +578 -0
  34. swcgeom/images/loaders/__init__.py +8 -0
  35. swcgeom/images/loaders/pbd.cpython-313-darwin.so +0 -0
  36. swcgeom/images/loaders/pbd.pyx +523 -0
  37. swcgeom/images/loaders/raw.cpython-313-darwin.so +0 -0
  38. swcgeom/images/loaders/raw.pyx +183 -0
  39. swcgeom/transforms/__init__.py +20 -0
  40. swcgeom/transforms/base.py +136 -0
  41. swcgeom/transforms/branch.py +223 -0
  42. swcgeom/transforms/branch_tree.py +74 -0
  43. swcgeom/transforms/geometry.py +270 -0
  44. swcgeom/transforms/image_preprocess.py +107 -0
  45. swcgeom/transforms/image_stack.py +219 -0
  46. swcgeom/transforms/images.py +206 -0
  47. swcgeom/transforms/mst.py +183 -0
  48. swcgeom/transforms/neurolucida_asc.py +498 -0
  49. swcgeom/transforms/path.py +56 -0
  50. swcgeom/transforms/population.py +36 -0
  51. swcgeom/transforms/tree.py +265 -0
  52. swcgeom/transforms/tree_assembler.py +161 -0
  53. swcgeom/utils/__init__.py +18 -0
  54. swcgeom/utils/debug.py +23 -0
  55. swcgeom/utils/download.py +119 -0
  56. swcgeom/utils/dsu.py +58 -0
  57. swcgeom/utils/ellipse.py +131 -0
  58. swcgeom/utils/file.py +90 -0
  59. swcgeom/utils/neuromorpho.py +581 -0
  60. swcgeom/utils/numpy_helper.py +70 -0
  61. swcgeom/utils/plotter_2d.py +134 -0
  62. swcgeom/utils/plotter_3d.py +35 -0
  63. swcgeom/utils/renderer.py +145 -0
  64. swcgeom/utils/sdf.py +324 -0
  65. swcgeom/utils/solid_geometry.py +154 -0
  66. swcgeom/utils/transforms.py +367 -0
  67. swcgeom/utils/volumetric_object.py +483 -0
  68. swcgeom-0.19.4.dist-info/METADATA +86 -0
  69. swcgeom-0.19.4.dist-info/RECORD +72 -0
  70. swcgeom-0.19.4.dist-info/WHEEL +5 -0
  71. swcgeom-0.19.4.dist-info/licenses/LICENSE +201 -0
  72. swcgeom-0.19.4.dist-info/top_level.txt +1 -0
swcgeom/core/tree.py ADDED
@@ -0,0 +1,384 @@
1
+
2
+ # SPDX-FileCopyrightText: 2022 - 2025 Zexin Yuan <pypi@yzx9.xyz>
3
+ #
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ """Neuron tree."""
7
+
8
+ import itertools
9
+ import os
10
+ from collections.abc import Callable, Iterable, Iterator
11
+ from typing import Literal, TypeVar, Union, overload
12
+
13
+ import numpy as np
14
+ import numpy.typing as npt
15
+ import pandas as pd
16
+ from typing_extensions import deprecated
17
+
18
+ from swcgeom.core.branch import Branch
19
+ from swcgeom.core.compartment import Compartment, Compartments
20
+ from swcgeom.core.node import Node
21
+ from swcgeom.core.path import Path
22
+ from swcgeom.core.swc import DictSWC, eswc_cols
23
+ from swcgeom.core.swc_utils import SWCNames, get_names, read_swc, traverse
24
+ from swcgeom.core.tree_utils_impl import Mapping, get_subtree_impl
25
+ from swcgeom.utils import PathOrIO, padding1d
26
+
27
+ __all__ = ["Tree"]
28
+
29
+ T = TypeVar("T")
30
+ K = TypeVar("K")
31
+
32
+
33
class Tree(DictSWC):
    """A neuron tree, which should be a binary tree in most cases.

    Each SWC column is stored as one array in ``ndata``. The nested ``Node``,
    ``Path``, ``Compartment`` and ``Branch`` classes are bound to a tree
    instance and are created on demand by the accessors below.
    """

    class Node(Node["Tree"]):
        """Neural node."""

        def parent(self) -> Union["Tree.Node", None]:
            """Return the parent node, or ``None`` when this node is a root.

            The parent id is used directly as a node index, so this assumes
            node ids equal their positions in the tree.
            """
            return Tree.Node(self.attach, self.pid) if self.pid != -1 else None

        def children(self) -> list["Tree.Node"]:
            """Return every node whose parent id equals this node's id."""
            children = self.attach.id()[self.attach.pid() == self.id]
            return [Tree.Node(self.attach, idx) for idx in children]

        def branch(self) -> "Tree.Branch":
            """Return the branch containing this node.

            Climbs towards the root until a furcation (or the root) is reached,
            then descends through single-child nodes until a furcation or tip.
            If this node is itself a furcation, the branch holds only this node.
            """
            ns: list["Tree.Node"] = [self]
            # Climb up until hitting a furcation or the root.
            while not ns[-1].is_furcation() and (p := ns[-1].parent()) is not None:
                ns.append(p)

            ns.reverse()
            # Descend down the (unique) child chain until a furcation or tip.
            while not (ns[-1].is_furcation() or ns[-1].is_tip()):
                ns.append(ns[-1].children()[0])

            return Tree.Branch(self.attach, [n.id for n in ns])

        def radial_distance(self) -> float:
            """The end-to-end straight-line distance to soma."""
            return self.distance(self.attach.soma())

        def subtree(self, *, out_mapping: Mapping | None = None) -> "Tree":
            """Get subtree from node.

            Args:
                out_mapping: Map from new id to old id.
            """

            n_nodes, ndata, source, names = get_subtree_impl(
                self.attach, self.id, out_mapping=out_mapping
            )
            return Tree(n_nodes, **ndata, source=source, names=names)

        def is_root(self) -> bool:
            """Whether this node has no parent."""
            return self.parent() is None

        def is_soma(self) -> bool:  # TODO: support multi soma, e.g. 3 points
            """Whether this node is a root node typed as soma."""
            return self.type == self.attach.types.soma and self.is_root()

        @overload
        def traverse(
            self, *, enter: Callable[[Node, T | None], T], mode: Literal["dfs"] = ...
        ) -> None: ...
        @overload
        def traverse(
            self, *, leave: Callable[[Node, list[K]], K], mode: Literal["dfs"] = ...
        ) -> K: ...
        @overload
        def traverse(
            self,
            *,
            enter: Callable[[Node, T | None], T],
            leave: Callable[[Node, list[K]], K],
            mode: Literal["dfs"] = ...,
        ) -> K: ...
        def traverse(self, **kwargs):  # type: ignore
            """Traverse from node.

            See Also:
                ~Tree.traverse
            """
            return self.attach.traverse(root=self.idx, **kwargs)

    class Path(Path["Tree"]):
        # TODO: should returns `Tree.Node`
        """Neural path."""

    class Compartment(Compartment["Tree"]):
        # TODO: should returns `Tree.Node`
        """Neural compartment."""

    Segment = Compartment  # Alias

    class Branch(Branch["Tree"]):
        # TODO: should returns `Tree.Node`
        """Neural branch."""

    def __init__(
        self,
        n_nodes: int,
        *,
        source: str = "",
        comments: Iterable[str] | None = None,
        names: SWCNames | None = None,
        **kwargs: npt.NDArray,
    ) -> None:
        """Create a tree with ``n_nodes`` nodes.

        A missing id column defaults to ``0 .. n_nodes-1`` and a missing pid
        column to a single chain ``-1, 0, 1, ...``. The standard SWC columns
        are passed through ``padding1d`` with fixed dtypes (the radius with
        ``padding_value=1``); any other keyword columns are forwarded as-is.
        """
        names = get_names(names)

        if names.id not in kwargs:
            kwargs[names.id] = np.arange(0, n_nodes, step=1, dtype=np.int32)

        if names.pid not in kwargs:
            kwargs[names.pid] = np.arange(-1, n_nodes - 1, step=1, dtype=np.int32)

        ndata = {
            names.id: padding1d(n_nodes, kwargs.pop(names.id, None), dtype=np.int32),
            names.type: padding1d(
                n_nodes, kwargs.pop(names.type, None), dtype=np.int32
            ),
            names.x: padding1d(n_nodes, kwargs.pop(names.x, None), dtype=np.float32),
            names.y: padding1d(n_nodes, kwargs.pop(names.y, None), dtype=np.float32),
            names.z: padding1d(n_nodes, kwargs.pop(names.z, None), dtype=np.float32),
            names.r: padding1d(
                n_nodes, kwargs.pop(names.r, None), dtype=np.float32, padding_value=1
            ),
            names.pid: padding1d(n_nodes, kwargs.pop(names.pid, None), dtype=np.int32),
        }
        # ? padding other columns
        super().__init__(
            **ndata, **kwargs, source=source, comments=comments, names=names
        )

    def __iter__(self) -> Iterator[Node]:
        return (self[i] for i in range(len(self)))

    def __repr__(self) -> str:
        n_nodes, n_edges = self.number_of_nodes(), self.number_of_edges()
        return f"Neuron Tree with {n_nodes} nodes and {n_edges} edges"

    @overload
    def __getitem__(self, key: slice) -> list[Node]: ...
    @overload
    def __getitem__(self, key: int) -> Node: ...
    @overload
    def __getitem__(self, key: str) -> npt.NDArray: ...
    def __getitem__(self, key):
        """Index by slice (list of nodes), int (node) or str (column data)."""
        if isinstance(key, slice):
            return [self.node(i) for i in range(*key.indices(len(self)))]

        if isinstance(key, (int, np.integer)):
            length = len(self)
            if key < -length or key >= length:
                raise IndexError(f"The index ({key}) is out of range.")

            if key < 0:  # Handle negative indices
                key += length

            return self.node(key)

        if isinstance(key, str):
            return self.get_ndata(key)

        raise TypeError("Invalid argument type.")

    def keys(self) -> Iterable[str]:
        return self.ndata.keys()

    def node(self, idx: int | np.integer) -> Node:
        return self.Node(self, idx)

    def soma(self, type_check: bool = True) -> Node:
        """Get soma of neuron.

        The first node is taken as the soma.

        Raises:
            ValueError: If ``type_check`` is true and that node is not typed
                as soma.
        """
        # TODO: find soma, see also: https://neuromorpho.org/myfaq.jsp
        n = self.node(0)
        if type_check and n.type != self.types.soma:
            raise ValueError(f"no soma found in: {self.source}")
        return n

    def get_furcations(self) -> list[Node]:
        """Get all node of furcations (nodes with more than one child)."""
        furcations: list[int] = []

        def collect_furcations(n: Tree.Node, children: list[None]) -> None:
            if len(children) > 1:
                furcations.append(n.id)

        self.traverse(leave=collect_furcations)
        return [self.node(i) for i in furcations]

    @deprecated("Use `get_furcations` instead")
    def get_bifurcations(self) -> list[Node]:
        """Get all node of furcations.

        .. deprecated:: 0.17.2
            Deprecated due to the wrong spelling of furcation. For now, it is just an
            alias of `get_furcations` and raise a warning. It will be change to raise
            an error in the future.
        """
        return self.get_furcations()

    def get_tips(self) -> list[Node]:
        """Get all node of tips, i.e. nodes that are no node's parent."""
        ids = self.id()
        # `np.setdiff1d(..., assume_unique=True)` is only specified for unique
        # inputs, but a furcation's id occurs several times in `pid`. `np.isin`
        # has no such precondition and keeps the original id order.
        tip_ids = ids[~np.isin(ids, self.pid())]
        return [self.node(i) for i in tip_ids]

    def get_compartments(self) -> Compartments[Compartment]:
        """Get one compartment per non-first node, linking it to its parent."""
        return Compartments(self.Compartment(self, n.pid, n.id) for n in self[1:])

    def get_segments(self) -> Compartments[Compartment]:  # Alias
        return self.get_compartments()

    def get_branches(self) -> list[Branch]:
        """Get all branches.

        A branch is a chain of nodes running from the root or a furcation down
        to the next furcation or tip.
        """

        def collect_branches(
            node: "Tree.Node", pre: list[tuple[list[Tree.Branch], list[int]]]
        ) -> tuple[list[Tree.Branch], list[int]]:
            # `pre` holds, per child, (finished branches, pending chain of ids
            # collected bottom-up that has not been closed yet).
            if len(pre) == 1:
                branches, child = pre[0]
                child.append(node.id)
                return branches, child

            branches: list[Tree.Branch] = []

            for sub_branches, child in pre:
                child.append(node.id)
                child.reverse()
                sub_branches.append(Tree.Branch(self, np.array(child, dtype=np.int32)))
                sub_branches.reverse()
                branches.extend(sub_branches)

            return branches, [node.id]

        branches, pending = self.traverse(leave=collect_branches)
        # If the root has exactly one child, the chain from the root down to
        # the first furcation/tip is still pending; close it like a furcation
        # would, otherwise that branch is silently lost.
        if len(pending) > 1:
            pending.reverse()
            branches.append(Tree.Branch(self, np.array(pending, dtype=np.int32)))
            branches.reverse()

        return branches

    def get_paths(self) -> list[Path]:
        """Get all path from soma to tips."""
        path_dic: dict[int, list[int]] = {}

        def assign_path(n: Tree.Node, pre_path: list[int] | None) -> list[int]:
            path = [] if pre_path is None else pre_path.copy()
            path.append(n.id)
            path_dic[n.id] = path
            return path

        def collect_path(
            n: Tree.Node, children: list[list[list[int]]]
        ) -> list[list[int]]:
            if len(children) == 0:
                return [path_dic[n.id]]

            return list(itertools.chain(*children))

        paths = self.traverse(enter=assign_path, leave=collect_path)
        return [self.Path(self, idx) for idx in paths]

    def get_neurites(self, type_check: bool = True) -> Iterable["Tree"]:
        """Get neurites from soma (one subtree per soma child, lazily)."""
        return (n.subtree() for n in self.soma(type_check).children())

    def get_dendrites(self, type_check: bool = True) -> Iterable["Tree"]:
        """Get dendrites (soma children typed apical or basal dendrite)."""
        types = [self.types.apical_dendrite, self.types.basal_dendrite]
        children = self.soma(type_check).children()
        return (n.subtree() for n in children if n.type in types)

    @overload
    def traverse(
        self,
        *,
        enter: Callable[[Node, T | None], T],
        root: int | np.integer = ...,
        mode: Literal["dfs"] = ...,
    ) -> None: ...
    @overload
    def traverse(
        self,
        *,
        leave: Callable[[Node, list[K]], K],
        root: int | np.integer = ...,
        mode: Literal["dfs"] = ...,
    ) -> K: ...
    @overload
    def traverse(
        self,
        *,
        enter: Callable[[Node, T | None], T],
        leave: Callable[[Node, list[K]], K],
        root: int | np.integer = ...,
        mode: Literal["dfs"] = ...,
    ) -> K: ...
    def traverse(self, *, enter=None, leave=None, **kwargs):
        """Traverse nodes.

        Args:
            enter: (n: Node, parent: T | None) => T
            leave: (n: Node, children: list[T]) => T

        See Also:
            ~swc_utils.traverse
        """

        def wrap(fn) -> Callable | None:
            # Adapt index-based callbacks of `swc_utils.traverse` to Node ones.
            if fn is None:
                return None

            def fn_wrapped(idx, *args, **kwargs):
                return fn(self[idx], *args, **kwargs)

            return fn_wrapped

        topology = (self.id(), self.pid())
        enter, leave = wrap(enter), wrap(leave)
        return traverse(topology, enter=enter, leave=leave, **kwargs)  # type: ignore

    def length(self) -> float:
        """Get length of tree (sum of all segment lengths)."""
        return sum(s.length() for s in self.get_segments())

    @staticmethod
    def from_data_frame(
        df: pd.DataFrame,
        source: str = "",
        *,
        comments: Iterable[str] | None = None,
        names: SWCNames | None = None,
    ) -> "Tree":
        """Read neuron tree from data frame."""
        names = get_names(names)
        tree = Tree(
            df.shape[0],
            **{k: df[k].to_numpy() for k in names.cols()},
            source=source,
            comments=comments,
            names=names,
        )
        return tree

    @classmethod
    def from_swc(cls, swc_file: PathOrIO, **kwargs) -> "Tree":
        """Read neuron tree from swc file.

        ``source`` is set to the absolute path for str / path-like inputs and
        left empty otherwise (e.g. file objects).

        See Also:
            ~swcgeom.core.swc_utils.read_swc
        """

        try:
            df, comments = read_swc(swc_file, **kwargs)
        except Exception as e:  # pylint: disable=broad-except
            raise ValueError(f"fails to read swc: {swc_file}") from e

        # Path-like objects (e.g. pathlib.Path) are file paths too; fsdecode
        # turns str/bytes/PathLike into a str path.
        if isinstance(swc_file, (str, os.PathLike)):
            source = os.path.abspath(os.fsdecode(swc_file))
        else:
            source = ""

        # Look columns up with the same names the reader was given (if any).
        return cls.from_data_frame(
            df, source=source, comments=comments, names=kwargs.get("names")
        )

    @classmethod
    def from_eswc(
        cls, swc_file: PathOrIO, extra_cols: list[str] | None = None, **kwargs
    ) -> "Tree":
        """Read neuron tree from eswc file.

        The eswc columns are appended to ``extra_cols``; the caller's list is
        not modified.

        See Also:
            ~swcgeom.Tree.from_swc
        """
        # Build a new list: extending the caller's list in place would leak
        # eswc columns into it and duplicate them on reuse.
        extra_cols = [*(extra_cols or []), *(k for k, _ in eswc_cols)]
        return cls.from_swc(swc_file, extra_cols=extra_cols, **kwargs)
swcgeom/core/tree_utils.py ADDED
@@ -0,0 +1,277 @@
1
+
2
+ # SPDX-FileCopyrightText: 2022 - 2025 Zexin Yuan <pypi@yzx9.xyz>
3
+ #
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ """SWC util wrapper for tree."""
7
+
8
+ import warnings
9
+ from collections.abc import Callable, Iterable
10
+ from typing import TypeVar, overload
11
+
12
+ import numpy as np
13
+ from typing_extensions import deprecated
14
+
15
+ from swcgeom.core.swc import SWCLike
16
+ from swcgeom.core.swc_utils import (
17
+ REMOVAL,
18
+ SWCNames,
19
+ Topology,
20
+ get_names,
21
+ is_bifurcate,
22
+ propagate_removal,
23
+ sort_nodes_impl,
24
+ to_sub_topology,
25
+ )
26
+ from swcgeom.core.tree import Tree
27
+ from swcgeom.core.tree_utils_impl import Mapping, get_subtree_impl, to_subtree_impl
28
+
29
__all__ = [
    "sort_tree",
    "cut_tree",
    "to_sub_tree",
    "to_subtree",
    "get_subtree",
    "redirect_tree",
    "cat_tree",
]

# Type variables for values threaded through traversal callbacks.
T, K = TypeVar("T"), TypeVar("K")

# Distance below which two connection points count as coincident (cat_tree).
EPS = 1e-5
42
+
43
+
44
def is_binary_tree(tree: Tree, exclude_soma: bool = True) -> bool:
    """Check whether ``tree`` is a bifurcating tree.

    The decision is delegated to `is_bifurcate` on the (id, pid) topology.

    Args:
        tree: The neuron tree to inspect.
        exclude_soma: Forwarded as ``exclude_root``; presumably exempts the
            root (soma) from the check.
    """
    topology = (tree.id(), tree.pid())
    return is_bifurcate(topology, exclude_root=exclude_soma)
47
+
48
+
49
def sort_tree(tree: Tree) -> Tree:
    """Return a copy of the neuron tree with sorted node indices.

    The input tree is left untouched; sorting is applied to a copy.

    See Also:
        ~.core.swc_utils.sort_nodes
    """
    duplicate = tree.copy()
    return _sort_tree(duplicate)
56
+
57
+
58
@overload
def cut_tree(
    tree: Tree, *, enter: Callable[[Tree.Node, T | None], tuple[T, bool]]
) -> Tree: ...
@overload
def cut_tree(
    tree: Tree, *, leave: Callable[[Tree.Node, list[K]], tuple[K, bool]]
) -> Tree: ...
def cut_tree(tree: Tree, *, enter=None, leave=None):
    """Traverse the tree and cut off the nodes flagged by a callback.

    The callback returns a ``(value, remove)`` pair. ``value`` is threaded
    through the traversal as in ``Tree.traverse``. A truthy ``remove`` deletes
    the current node together with all of its descendants.

    Args:
        tree: The tree to cut.
        enter: Pre-order callback ``(node, parent_value) -> (value, remove)``.
            Once a node is removed, the callback is not invoked for its
            descendants. Used in preference to ``leave`` when both are given.
        leave: Post-order callback ``(node, child_values) -> (value, remove)``.

    Returns:
        The tree without the removed nodes, or a copy of ``tree`` when neither
        callback is supplied.
    """
    if not enter and not leave:
        return tree.copy()

    removals: list[int] = []

    if enter:

        def on_enter(
            node: Tree.Node, inherited: tuple[T, bool] | None
        ) -> tuple[T, bool]:
            # `inherited` is the (value, removed) pair produced for the parent.
            if inherited is not None and inherited[1]:
                removals.append(node.id)
                return inherited  # keep propagating the removal downwards

            value, remove = enter(node, None if inherited is None else inherited[0])
            if remove:
                removals.append(node.id)
            return value, remove

        tree.traverse(enter=on_enter)
    else:

        def on_leave(node: Tree.Node, child_values: list[K]) -> K:
            value, remove = leave(node, child_values)
            if remove:
                removals.append(node.id)
            return value

        tree.traverse(leave=on_leave)

    return to_subtree(tree, removals)
104
+
105
+
106
@deprecated("Use `to_subtree` instead")
def to_sub_tree(swc_like: SWCLike, sub: Topology) -> tuple[Tree, dict[int, int]]:
    """Create subtree from origin tree.

    You can directly mark nodes for removal in ``sub``. Removal is propagated
    to their descendants here (via `propagate_removal`), so marking a
    non-leaf node is enough to drop its whole subtree.

    .. deprecated:: 0.6.0
        Use :meth:`to_subtree` instead.

    Returns
        tree: Tree
        id_map: dict[int, int], mapping an old node index to its new index.
    """

    sub = propagate_removal(sub)
    (new_id, new_pid), id_map_arr = to_sub_topology(sub)

    names = swc_like.names
    n_nodes = new_id.shape[0]
    # `id_map_arr[new_idx]` is the old index, so this gathers kept rows.
    ndata = {k: swc_like.get_ndata(k)[id_map_arr].copy() for k in swc_like.keys()}
    # Use the configured column names: hard-coded "id"/"pid" keys would add
    # stray columns instead of replacing the topology for custom names.
    ndata.update({names.id: new_id, names.pid: new_pid})

    subtree = Tree(n_nodes, **ndata, source=swc_like.source, names=names)

    id_map = {idx: i for i, idx in enumerate(id_map_arr)}
    return subtree, id_map
135
+
136
+
137
def to_subtree(
    swc_like: SWCLike,
    removals: Iterable[int],
    *,
    out_mapping: Mapping | None = None,
) -> Tree:
    """Create subtree from origin tree.

    Each listed node is marked for removal, and the removal is propagated to
    its descendants before the kept nodes are extracted.

    Args:
        swc_like: SWCLike
        removals: A list of id of nodes to be removed. They are used as array
            positions, so this assumes ids equal node indices.
        out_mapping: Map new id to old id.
    """

    marked_ids = swc_like.id().copy()
    for victim in removals:
        marked_ids[victim] = REMOVAL

    topology = propagate_removal((marked_ids, swc_like.pid()))
    n_nodes, ndata, source, names = to_subtree_impl(
        swc_like, topology, out_mapping=out_mapping
    )
    return Tree(n_nodes, **ndata, source=source, names=names)
160
+
161
+
162
def get_subtree(
    swc_like: SWCLike, n: int, *, out_mapping: Mapping | None = None
) -> Tree:
    """Get subtree rooted at n.

    Args:
        swc_like: SWCLike
        n: Id of the root of the subtree.
        out_mapping: Map new id to old id.

    Returns:
        A new `Tree` built from the extracted node data.
    """

    extracted = get_subtree_impl(swc_like, n, out_mapping=out_mapping)
    n_nodes, ndata, source, names = extracted
    return Tree(n_nodes, **ndata, source=source, names=names)
177
+
178
+
179
def redirect_tree(tree: Tree, new_root: int, sort: bool = True) -> Tree:
    """Set root to new point and redirect tree graph.

    The edges on the path between ``new_root`` and the old root are reversed,
    and the types of the new and old root are swapped. The input tree is not
    changed; the work is done on a copy.

    Args:
        tree: The tree.
        new_root: The id of new root.
        sort: If true, sort indices of nodes after redirect.
    """

    tree = tree.copy()

    # Collect the chain new_root -> ... -> old root before editing any pid.
    chain = [tree.node(new_root)]
    ancestor = chain[-1].parent()
    while ancestor is not None:
        chain.append(ancestor)
        ancestor = chain[-1].parent()

    head, tail = chain[0], chain[-1]
    head.pid = -1
    head.type, tail.type = tail.type, head.type
    # Every ancestor on the chain now hangs below its former child.
    for below, above in zip(chain, chain[1:]):
        above.pid = below.id

    if sort:
        _sort_tree(tree)

    return tree
202
+
203
+
204
def cat_tree(  # pylint: disable=too-many-arguments
    tree1: Tree,
    tree2: Tree,
    node1: int = 0,
    node2: int = 0,
    *,
    translate: bool = True,
    names: SWCNames | None = None,
    no_move: bool | None = None,  # legacy
) -> Tree:
    """Concatenates the second tree onto the first one.

    ``tree2`` is re-rooted at ``node2`` if needed and attached below ``node1``
    of ``tree1``. When the two connection points are closer than ``EPS``,
    ``node2`` is dropped and its children are linked to ``node1`` directly.
    Only the columns of ``tree1`` are kept; those missing from ``tree2`` are
    zero-padded. Both inputs are copied before being modified.

    Args:
        tree1: The tree to attach to.
        tree2: The tree to be attached.
        node1: The node id of the tree to be connected.
        node2: The node id of the connection point.
        translate: Whether to translate node_2 to node_1.
            If False, add link between node_1 and node_2 without translate.
        names: Column names used to access ``ndata``.
        no_move: Deprecated since v0.12.0; equivalent to ``translate=not no_move``.

    Returns:
        The concatenated tree with sorted node indices.
    """
    if no_move is not None:
        warnings.warn(
            "`no_move` has been deprecated, it is replaced by `translate` in "
            "v0.12.0, and this will be removed in next version",
            DeprecationWarning,
            stacklevel=2,
        )
        translate = not no_move

    names = get_names(names)
    tree, tree2 = tree1.copy(), tree2.copy()
    if not tree2.node(node2).is_root():
        tree2 = redirect_tree(tree2, node2, sort=False)

    c = tree.node(node1)
    if translate:
        tree2.ndata[names.x] -= tree2.node(node2).x - c.x
        tree2.ndata[names.y] -= tree2.node(node2).y - c.y
        tree2.ndata[names.z] -= tree2.node(node2).z - c.z

    ns = tree.number_of_nodes()
    if np.linalg.norm(tree2.node(node2).xyz() - c.xyz()) < EPS:
        # Coincident connection points: drop node2, re-link its children.
        remove = [node2 + ns]
        link_to_root = [n.id + ns for n in tree2.node(node2).children()]
    else:
        remove = None
        link_to_root = [node2 + ns]

    # APIs of tree2 are no longer available since we modify the topology.
    # Note: the root's pid (-1) becomes ns - 1 here; it is overwritten below
    # (via link_to_root) or the root is deleted (via remove).
    tree2.ndata[names.id] += ns
    tree2.ndata[names.pid] += ns

    # Iterate over a snapshot since values are replaced inside the loop.
    for k, v in list(tree.ndata.items()):  # only keep keys in tree1
        if k in tree2.ndata:
            tree.ndata[k] = np.concatenate([v, tree2.ndata[k]])
        else:
            tree.ndata[k] = np.pad(v, (0, tree2.number_of_nodes()))

    for n in link_to_root:
        tree.node(n).pid = node1

    if remove is not None:  # TODO: This should be easy to implement during sort
        for k, v in list(tree.ndata.items()):
            tree.ndata[k] = np.delete(v, remove)

    _sort_tree(tree)
    return tree
270
+
271
+
272
def _sort_tree(tree: Tree) -> Tree:
    """Sort the indices of neuron tree inplace.

    Every column is reordered by the mapping from `sort_nodes_impl`, then the
    id and pid columns are replaced with the renumbered topology.

    Returns:
        The same tree object, for convenience.
    """
    names = tree.names
    (new_ids, new_pids), id_map = sort_nodes_impl((tree.id(), tree.pid()))
    tree.ndata = {k: v[id_map] for k, v in tree.ndata.items()}
    # Write through the tree's own column names: hard-coded "id"/"pid" keys
    # would leave the real topology columns unsorted for custom names.
    tree.ndata.update({names.id: new_ids, names.pid: new_pids})
    return tree