swcgeom 0.20.0__cp312-cp312-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of swcgeom might be problematic; see the registry's release page for more details.

Files changed (72)
  1. swcgeom/__init__.py +21 -0
  2. swcgeom/analysis/__init__.py +13 -0
  3. swcgeom/analysis/feature_extractor.py +454 -0
  4. swcgeom/analysis/features.py +218 -0
  5. swcgeom/analysis/lmeasure.py +750 -0
  6. swcgeom/analysis/sholl.py +201 -0
  7. swcgeom/analysis/trunk.py +183 -0
  8. swcgeom/analysis/visualization.py +191 -0
  9. swcgeom/analysis/visualization3d.py +81 -0
  10. swcgeom/analysis/volume.py +143 -0
  11. swcgeom/core/__init__.py +19 -0
  12. swcgeom/core/branch.py +129 -0
  13. swcgeom/core/branch_tree.py +65 -0
  14. swcgeom/core/compartment.py +107 -0
  15. swcgeom/core/node.py +130 -0
  16. swcgeom/core/path.py +155 -0
  17. swcgeom/core/population.py +394 -0
  18. swcgeom/core/swc.py +247 -0
  19. swcgeom/core/swc_utils/__init__.py +19 -0
  20. swcgeom/core/swc_utils/assembler.py +35 -0
  21. swcgeom/core/swc_utils/base.py +180 -0
  22. swcgeom/core/swc_utils/checker.py +112 -0
  23. swcgeom/core/swc_utils/io.py +335 -0
  24. swcgeom/core/swc_utils/normalizer.py +163 -0
  25. swcgeom/core/swc_utils/subtree.py +70 -0
  26. swcgeom/core/tree.py +387 -0
  27. swcgeom/core/tree_utils.py +277 -0
  28. swcgeom/core/tree_utils_impl.py +58 -0
  29. swcgeom/images/__init__.py +9 -0
  30. swcgeom/images/augmentation.py +149 -0
  31. swcgeom/images/contrast.py +87 -0
  32. swcgeom/images/folder.py +217 -0
  33. swcgeom/images/io.py +578 -0
  34. swcgeom/images/loaders/__init__.py +8 -0
  35. swcgeom/images/loaders/pbd.cpython-312-darwin.so +0 -0
  36. swcgeom/images/loaders/pbd.pyx +523 -0
  37. swcgeom/images/loaders/raw.cpython-312-darwin.so +0 -0
  38. swcgeom/images/loaders/raw.pyx +183 -0
  39. swcgeom/transforms/__init__.py +20 -0
  40. swcgeom/transforms/base.py +136 -0
  41. swcgeom/transforms/branch.py +223 -0
  42. swcgeom/transforms/branch_tree.py +74 -0
  43. swcgeom/transforms/geometry.py +270 -0
  44. swcgeom/transforms/image_preprocess.py +107 -0
  45. swcgeom/transforms/image_stack.py +219 -0
  46. swcgeom/transforms/images.py +206 -0
  47. swcgeom/transforms/mst.py +183 -0
  48. swcgeom/transforms/neurolucida_asc.py +498 -0
  49. swcgeom/transforms/path.py +56 -0
  50. swcgeom/transforms/population.py +36 -0
  51. swcgeom/transforms/tree.py +265 -0
  52. swcgeom/transforms/tree_assembler.py +160 -0
  53. swcgeom/utils/__init__.py +18 -0
  54. swcgeom/utils/debug.py +23 -0
  55. swcgeom/utils/download.py +119 -0
  56. swcgeom/utils/dsu.py +58 -0
  57. swcgeom/utils/ellipse.py +131 -0
  58. swcgeom/utils/file.py +90 -0
  59. swcgeom/utils/neuromorpho.py +581 -0
  60. swcgeom/utils/numpy_helper.py +70 -0
  61. swcgeom/utils/plotter_2d.py +134 -0
  62. swcgeom/utils/plotter_3d.py +35 -0
  63. swcgeom/utils/renderer.py +145 -0
  64. swcgeom/utils/sdf.py +324 -0
  65. swcgeom/utils/solid_geometry.py +154 -0
  66. swcgeom/utils/transforms.py +367 -0
  67. swcgeom/utils/volumetric_object.py +483 -0
  68. swcgeom-0.20.0.dist-info/METADATA +86 -0
  69. swcgeom-0.20.0.dist-info/RECORD +72 -0
  70. swcgeom-0.20.0.dist-info/WHEEL +5 -0
  71. swcgeom-0.20.0.dist-info/licenses/LICENSE +201 -0
  72. swcgeom-0.20.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,183 @@
1
+ # -----------------------------------------------------------------------------
2
+ # This file is adapted from the v3d-py-helper project:
3
+ # https://github.com/SEU-ALLEN-codebase/v3d-py-helper
4
+ #
5
+ # Original license: MIT License
6
+ # Copyright (c) Zuohan Zhao
7
+ #
8
+ # Vaa3D in Python Made Easy
9
+ # Python library for Vaa3D functions.
10
+ #
11
+ # The original project is distributed via PyPI under the name `v3d-py-helper`,
12
+ # with its latest release (v0.4.1) not supporting Python 3.13.
13
+ #
14
+ # As of Python 3.13 (released on October 7, 2024), this package fails to build
15
+ # from source due to missing dependencies (e.g., libtiff), and no prebuilt wheels
16
+ # are available on PyPI for Python 3.13. An issue has been raised, but the project
17
+ # appears to be unmaintained at this time, and the author has not responded.
18
+ #
19
+ # To ensure continued compatibility and usability of Vaa3D features under Python 3.13+,
20
+ # we have copied and minimally adapted necessary source files into this project,
21
+ # preserving license and attribution in accordance with the MIT License.
22
+ #
23
+ # Please consult the original repository for full documentation:
24
+ # https://SEU-ALLEN-codebase.github.io/v3d-py-helper
25
+ #
26
+ # If the upstream project resumes maintenance and releases official support
27
+ # for Python 3.13+, this bundled version may be deprecated in favor of the
28
+ # canonical package.
29
+ # -----------------------------------------------------------------------------
30
+
31
+ import io
32
+ import os
33
+ import struct
34
+ import sys
35
+ import cython
36
+ cimport cython
37
+ import numpy as np
38
+ cimport numpy as np
39
+
40
+
41
+ DEF FORMAT_KEY_4 = b"raw_image_stack_by_hpeng"
42
+ DEF FORMAT_KEY_5 = b"raw5image_stack_by_hpeng"
43
+ DEF FORMAT_LEN = 24
44
+
45
+
46
cdef class Raw:
    """
    For Vaa3D raw formats (v3draw), allowing uint8, uint16 and float32 data types. The image is either 4D or 5D,
    decided by the format key at the start of the header.

    As a raw format, it not only stores the image buffer, but also the image dimension sizes and its data type in the
    header. It also stores the endian of the buffer. The header layout is: format key (FORMAT_LEN bytes), endian code
    (1 byte, b'B' or b'L'), data type (2-byte short: 1, 2 or 4 bytes per pixel), then one size per dimension
    (2-byte shorts when ``sz2byte`` is set, 4-byte ints otherwise), followed by the pixel buffer.

    The sizes in the header are stored in the reverse order of the numpy array shape, e.g. header sizes
    (s0, s1, s2, s3) correspond to an array of shape (s3, s2, s1, s0).

    This interface saves and loads a multi-dimension numpy array of any of the 3 data types. It takes the byte order
    of the file (on load) or of the numpy array (on save) into account, so byte-swapped data round-trips correctly.

    modified from v3d_external/v3d_main/basic_c_fun/stackutils.cpp

    by Zuohan Zhao

    2022/5/8
    """
    cdef:
        bint sz2byte  # header sizes are 2-byte shorts instead of 4-byte ints

    def __init__(self, sz2byte = False):
        """
        :param sz2byte: set size array to be 2 byte (short int), for compatibility. Default as False.
        """
        self.sz2byte = sz2byte

    cpdef np.ndarray load(self, path: str | os.PathLike, int choose = -1):
        """
        :param path: input image path of v3draw.
        :param choose: choose a channel(4D) or stack(5D) to load, starting from 0, default as -1, meaning all.
        :return: a numpy array of 4D or 5D based on the format key (one dimension less when `choose` is used).
        :raises RuntimeError: if the format key, endian code or data type in the header is not recognized.
        """
        cdef:
            short datatype
            list sz
            bytes endian_code_data
            str endian, dt
            char dim, header_sz, i
            bytes format_key
            long long bulk_sz, filesize

        filesize = os.path.getsize(path)
        assert filesize >= FORMAT_LEN, "File size too small, file might be corrupted"

        with open(path, "rb") as f:
            # the format key decides the number of dimensions
            format_key = f.read(FORMAT_LEN)
            if format_key == FORMAT_KEY_4:
                dim = 4
            elif format_key == FORMAT_KEY_5:
                dim = 5
            else:
                raise RuntimeError("Format key isn't for v3draw")

            # key + sizes + 2-byte data type + 1-byte endian code
            if self.sz2byte:
                header_sz = FORMAT_LEN + dim * 2 + 2 + 1
            else:
                header_sz = FORMAT_LEN + dim * 4 + 2 + 1
            assert filesize >= header_sz, "File size too small, file might be corrupted"

            endian_code_data = f.read(1)
            if endian_code_data == b'B':
                endian = '>'
            elif endian_code_data == b'L':
                endian = '<'
            else:
                raise RuntimeError('Endian code should be either B/L')

            # data type code doubles as the number of bytes per pixel
            datatype = struct.unpack(f'{endian}h', f.read(2))[0]
            if datatype == 1:
                dt = 'u1'
            elif datatype == 2:
                dt = 'u2'
            elif datatype == 4:
                dt = 'f4'
            else:
                raise RuntimeError('v3draw data type can only be 1/2/4')

            if self.sz2byte:
                sz = list(struct.unpack(f'{endian}{dim}h', f.read(dim * 2)))
            else:
                sz = list(struct.unpack(f'{endian}{dim}i', f.read(dim * 4)))

            # number of pixels in one slice of the last header dimension
            bulk_sz = 1
            for i in range(dim - 1):
                bulk_sz *= sz[i]
            assert bulk_sz * sz[-1] * datatype + header_sz == filesize, "file size doesn't match with the image"

            if choose < 0:
                return np.frombuffer(f.read(), endian + dt).reshape(sz[::-1])
            else:
                assert choose < sz[-1], "Choose index exceeding the range"
                # skip the preceding slices relative to the current position (end of header)
                f.seek(bulk_sz * datatype * choose, 1)
                img = np.frombuffer(f.read(bulk_sz * datatype), endian + dt)
                return img.reshape(sz[-2::-1])

    cpdef void save(self, path: str | os.PathLike, np.ndarray img):
        """
        :param path: output image path of v3draw.
        :param img: the image array to save; 4D or 5D, of dtype uint8, uint16 or float32, in either byte order.
            The array shape is written to the header in reverse order, mirroring :meth:`load`.
        """
        assert img.ndim in (4, 5), "The image has to be 4D or 5D"
        cdef:
            short datatype
            list sz
            bytes endian_code_data, format_key
            str endian, bo = img.dtype.byteorder
            bytes header
            char dim = img.ndim

        # compare dtypes in native byte order so byte-swapped arrays are accepted as well;
        # the actual byte order of the buffer is recorded in the header below
        native_dtype = img.dtype.newbyteorder('=')
        assert native_dtype in [np.uint8, np.uint16, np.float32], \
            "The pixel type has to be uint8, uint16 or float32"

        if dim == 4:
            format_key = FORMAT_KEY_4
        elif dim == 5:
            format_key = FORMAT_KEY_5
        else:
            raise RuntimeError("Dimension not supported by v3draw")

        # header sizes are stored in the reverse order of the numpy shape (see load)
        sz = list(img.shape)[::-1]

        # '=' means native order, so it is big endian only on big-endian machines
        if bo == '>' or (bo == '=' and sys.byteorder == 'big'):
            endian_code_data = b'B'
            endian = '>'
        else:
            endian_code_data = b'L'
            endian = '<'

        if native_dtype == np.uint8:
            datatype = 1
        elif native_dtype == np.uint16:
            datatype = 2
        elif native_dtype == np.float32:
            datatype = 4
        else:
            raise RuntimeError("numpy data type not supported by v3draw")

        # one size field per dimension, so 5D images get a 5-entry size array
        header = struct.pack(f'{endian}{FORMAT_LEN}sch{dim}{"h" if self.sz2byte else "i"}',
                             format_key, endian_code_data, datatype, *sz)

        with open(path, 'wb') as f:
            f.write(header)
            # tobytes keeps the array's own byte order, which matches the endian code written above
            f.write(img.tobytes())
@@ -0,0 +1,20 @@
1
+
2
+ # SPDX-FileCopyrightText: 2022 - 2025 Zexin Yuan <pypi@yzx9.xyz>
3
+ #
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ """A series of transformations to compose codes."""
7
+
8
+ from swcgeom.transforms.base import * # noqa: F403
9
+ from swcgeom.transforms.branch import * # noqa: F403
10
+ from swcgeom.transforms.branch_tree import * # noqa: F403
11
+ from swcgeom.transforms.geometry import * # noqa: F403
12
+ from swcgeom.transforms.image_preprocess import * # noqa: F403
13
+ from swcgeom.transforms.image_stack import * # noqa: F403
14
+ from swcgeom.transforms.images import * # noqa: F403
15
+ from swcgeom.transforms.mst import * # noqa: F403
16
+ from swcgeom.transforms.neurolucida_asc import * # noqa: F403
17
+ from swcgeom.transforms.path import * # noqa: F403
18
+ from swcgeom.transforms.population import * # noqa: F403
19
+ from swcgeom.transforms.tree import * # noqa: F403
20
+ from swcgeom.transforms.tree_assembler import * # noqa: F403
@@ -0,0 +1,136 @@
1
+ # SPDX-FileCopyrightText: 2022 - 2025 Zexin Yuan <pypi@yzx9.xyz>
2
+ #
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ """Transformation in tree."""
6
+
7
+ from abc import ABC, abstractmethod
8
+ from typing import Any, Generic, TypeVar, overload
9
+
10
+ from typing_extensions import override
11
+
12
+ __all__ = ["Transform", "Transforms", "Identity"]
13
+
14
+ T = TypeVar("T")
15
+ K = TypeVar("K")
16
+
17
+ T1 = TypeVar("T1")
18
+ T2 = TypeVar("T2")
19
+ T3 = TypeVar("T3")
20
+ T4 = TypeVar("T4")
21
+ T5 = TypeVar("T5")
22
+ T6 = TypeVar("T6")
23
+
24
+
25
class Transform(ABC, Generic[T, K]):
    r"""Abstract base of every :class:`Transform`.

    A transform is a callable object mapping a value of type `T` to a value
    of type `K`.
    """

    @abstractmethod
    def __call__(self, x: T) -> K:
        """Apply the transform to `x`.

        NOTE: Every subclass must override :meth:`__call__` to implement the
        actual transformation of `x`.
        """
        raise NotImplementedError()

    def __repr__(self) -> str:
        name = self.__class__.__name__
        details = self.extra_repr()
        return "{}({})".format(name, details)

    def extra_repr(self) -> str:
        """Return the text placed between the parentheses of :meth:`__repr__`.

        The base implementation returns an empty string, so the default repr
        is just ``ClassName()``. Override it to show the parameters or
        configuration of a transform, which helps when debugging or
        inspecting a pipeline.

        >>> class Foo(Transform[T, K]):
        ...     def __init__(self, my_parameter: int = 1):
        ...         self.my_parameter = my_parameter
        ...
        ...     def extra_repr(self) -> str:
        ...         return f"my_parameter={self.my_parameter}"

        NOTE: Custom transforms should override this method to expose the
        details relevant to their behavior and configuration.
        """
        empty = ""
        return empty
66
+
67
+
68
class Transforms(Transform[T, K]):
    """A typed pipeline that applies several transforms one after another.

    The steps of any :class:`Transforms` passed to the constructor are
    inlined, so ``Transforms(Transforms(a, b), c)`` holds ``[a, b, c]``.
    """

    transforms: list[Transform[Any, Any]]

    # fmt: off
    @overload
    def __init__(self, t1: Transform[T, K], /) -> None: ...
    @overload
    def __init__(self, t1: Transform[T, T1], t2: Transform[T1, K], /) -> None: ...
    @overload
    def __init__(self, t1: Transform[T, T1], t2: Transform[T1, T2],
                 t3: Transform[T2, K], /) -> None: ...
    @overload
    def __init__(self, t1: Transform[T, T1], t2: Transform[T1, T2],
                 t3: Transform[T2, T3], t4: Transform[T3, K], /) -> None: ...
    @overload
    def __init__(self, t1: Transform[T, T1], t2: Transform[T1, T2],
                 t3: Transform[T2, T3], t4: Transform[T3, T4],
                 t5: Transform[T4, K], /) -> None: ...
    @overload
    def __init__(self, t1: Transform[T, T1], t2: Transform[T1, T2],
                 t3: Transform[T2, T3], t4: Transform[T3, T4],
                 t5: Transform[T4, T5], t6: Transform[T5, K], /) -> None: ...
    @overload
    def __init__(self, t1: Transform[T, T1], t2: Transform[T1, T2],
                 t3: Transform[T2, T3], t4: Transform[T3, T4],
                 t5: Transform[T4, T5], t6: Transform[T5, T6],
                 t7: Transform[T6, K], /) -> None: ...
    @overload
    def __init__(self, t1: Transform[T, T1], t2: Transform[T1, T2],
                 t3: Transform[T2, T3], t4: Transform[T3, T4],
                 t5: Transform[T4, T5], t6: Transform[T5, T6],
                 t7: Transform[T6, Any], /, *transforms: Transform[Any, K]) -> None: ...
    # fmt: on
    def __init__(self, *transforms: Transform[Any, Any]) -> None:
        flattened: list[Transform[Any, Any]] = []
        for item in transforms:
            # inline the steps of a nested pipeline instead of nesting it
            flattened.extend(item.transforms if isinstance(item, Transforms) else (item,))
        self.transforms = flattened

    @override
    def __call__(self, x: T) -> K:
        """Feed `x` through every step in order and return the final result."""
        result: Any = x
        for step in self.transforms:
            result = step(result)
        return result

    def __getitem__(self, idx: int) -> Transform[Any, Any]:
        """Return the step stored at position `idx`."""
        return self.transforms[idx]

    def __len__(self) -> int:
        """Return the number of steps in the pipeline."""
        return len(self.transforms)

    @override
    def extra_repr(self) -> str:
        return ", ".join(str(step) for step in self)
129
+
130
+
131
class Identity(Transform[T, T]):
    """Return input as-is."""

    @override
    def __call__(self, x: T) -> T:
        """Return `x` unchanged (the very same object, not a copy)."""
        return x
@@ -0,0 +1,223 @@
1
+
2
+ # SPDX-FileCopyrightText: 2022 - 2025 Zexin Yuan <pypi@yzx9.xyz>
3
+ #
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ """Transformation in branch."""
7
+
8
+ from abc import ABC, abstractmethod
9
+ from typing import cast
10
+
11
+ import numpy as np
12
+ import numpy.typing as npt
13
+ from scipy import signal
14
+ from typing_extensions import override
15
+
16
+ from swcgeom.core import Branch, DictSWC
17
+ from swcgeom.transforms.base import Transform
18
+ from swcgeom.utils import (
19
+ angle,
20
+ rotate3d_x,
21
+ rotate3d_y,
22
+ rotate3d_z,
23
+ scale3d,
24
+ to_homogeneous,
25
+ translate3d,
26
+ )
27
+
28
+ __all__ = ["BranchLinearResampler", "BranchConvSmoother", "BranchStandardizer"]
29
+
30
+
31
class _BranchResampler(Transform[Branch, Branch], ABC):
    r"""Base class for branch resamplers.

    Subclasses only implement :meth:`resample` on the raw ``xyzr`` array of
    the branch; this class converts the branch to that array and builds a new
    :class:`Branch` from the result.
    """

    @override
    def __call__(self, x: Branch) -> Branch:
        return Branch.from_xyzr(self.resample(x.xyzr()))

    @abstractmethod
    def resample(self, xyzr: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
        """Return the resampled ``xyzr`` array; implemented by subclasses."""
        ...
41
+
42
+
43
class BranchLinearResampler(_BranchResampler):
    r"""Resampling by linear interpolation, DO NOT keep original node."""

    def __init__(self, n_nodes: int) -> None:
        """Resample branch to a fixed number of nodes.

        Args:
            n_nodes: Number of nodes after resample.
        """
        super().__init__()
        self.n_nodes = n_nodes

    @override
    def resample(self, xyzr: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
        """Place `n_nodes` evenly along the arc length of the branch.

        The first and last output nodes coincide with the first and last
        input nodes; the original intermediate nodes are not kept.

        Args:
            xyzr: The array of shape (N, 4).

        Returns:
            coordinates: An array of shape (n_nodes, 4).
        """
        # arc length from the first node to each node, starting at 0
        segment_lengths = np.linalg.norm(np.diff(xyzr[:, :3], axis=0), axis=1)
        arc = np.insert(np.cumsum(segment_lengths), 0, 0)
        samples = np.linspace(0, arc[-1], self.n_nodes)

        # interpolate x, y, z and r independently along the arc length
        columns = [np.interp(samples, arc, xyzr[:, col]) for col in range(4)]
        return cast(npt.NDArray[np.float32], np.stack(columns, axis=1))

    @override
    def extra_repr(self) -> str:
        return f"n_nodes={self.n_nodes}"
79
+
80
+
81
class BranchIsometricResampler(_BranchResampler):
    r"""Resampling at (near) equal spacing along the branch, DO NOT keep original node."""

    def __init__(self, distance: float, *, adjust_last_gap: bool = True) -> None:
        """Resample branch so that consecutive nodes are about `distance` apart.

        Args:
            distance: Target spacing between consecutive resampled nodes,
                measured along the branch. Must be positive.
            adjust_last_gap: If True, spread the nodes evenly over the whole
                branch so that every gap has the same length (at most
                `distance`). If False, place nodes exactly `distance` apart
                from the start and append the branch endpoint, so the last
                gap may be shorter.

        Raises:
            ValueError: If `distance` is not positive.
        """
        super().__init__()
        # written as `not > 0` so NaN is rejected as well
        if not distance > 0:
            raise ValueError(f"distance must be positive, got {distance}")
        self.distance = distance
        self.adjust_last_gap = adjust_last_gap

    @override
    def resample(self, xyzr: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
        """Resampling by isometric interpolation, DO NOT keep original node.

        Both the first and the last input nodes are always kept as endpoints.

        Args:
            xyzr: The array of shape (N, 4).

        Returns:
            new_xyzr: A float32 array of shape (M, 4), where M depends on the
                branch length and `distance`.
        """

        # Compute the cumulative distances between consecutive points
        diffs = np.diff(xyzr[:, :3], axis=0)
        distances = np.sqrt((diffs**2).sum(axis=1))
        cumulative_distances = np.concatenate([[0], np.cumsum(distances)])

        total_length = cumulative_distances[-1]
        n_nodes = int(np.ceil(total_length / self.distance)) + 1

        # Determine the new distances
        if self.adjust_last_gap and n_nodes > 1:
            new_distances = np.linspace(0, total_length, n_nodes)
        else:
            # keep endpoint; `np.append` accepts the scalar `total_length`,
            # which `np.concatenate` would reject as a 0-d array
            new_distances = np.append(
                np.arange(0, total_length, self.distance), total_length
            )

        # Interpolate the new points, sized from the actual sample positions
        new_xyzr = np.zeros((new_distances.shape[0], 4), dtype=np.float32)
        for i in range(4):
            new_xyzr[:, i] = np.interp(new_distances, cumulative_distances, xyzr[:, i])
        return new_xyzr

    @override
    def extra_repr(self) -> str:
        return f"distance={self.distance},adjust_last_gap={self.adjust_last_gap}"
128
+
129
+
130
class BranchConvSmoother(Transform[Branch, Branch[DictSWC]]):
    r"""Smooth the branch with a sliding-window (box) average.

    Only the x, y and z coordinates of interior nodes are smoothed; the first
    and last nodes keep their original positions.
    """

    def __init__(self, n_nodes: int = 5) -> None:
        """
        Args:
            n_nodes: Window size.
        """
        super().__init__()
        self.n_nodes = n_nodes
        self.kernel = np.ones(n_nodes)

    @override
    def __call__(self, x: Branch) -> Branch[DictSWC]:
        smoothed = x.detach()
        # how many nodes the window covers at each position (fewer near the ends),
        # used to turn window sums into averages
        coverage = signal.convolve(
            np.ones(smoothed.number_of_nodes()), self.kernel, mode="same"
        )
        for axis_name in ("x", "y", "z"):
            window_sum = signal.convolve(
                smoothed.get_ndata(axis_name), self.kernel, mode="same"
            )
            averaged = window_sum / coverage
            smoothed.attach.ndata[axis_name][1:-1] = averaged[1:-1]

        return smoothed

    def extra_repr(self) -> str:
        return f"n_nodes={self.n_nodes}"
155
+
156
+
157
class BranchStandardizer(Transform[Branch, Branch[DictSWC]]):
    r"""Standardize branch.

    Standardized branch starts at (0, 0, 0), ends at (1, 0, 0), up at y, and scale max
    radius to 1.
    """

    @override
    def __call__(self, x: Branch) -> Branch:
        """Return a new branch mapped by :meth:`get_matrix`, radii divided by their max.

        If the maximum radius is 0, the radii are kept unchanged instead of
        becoming NaN through a 0 / 0 division.
        """
        xyzr = x.xyzr()
        xyz, r = xyzr[:, 0:3], xyzr[:, 3:4]
        T = self.get_matrix(xyz)

        xyz4 = to_homogeneous(xyz, 1).transpose()  # (4, N)
        new_xyz = np.dot(T, xyz4)[:3].transpose()  # back to (N, 3)

        r_max = r.max()
        new_r = r / r_max if r_max != 0 else r
        new_xyzr = np.concatenate([new_xyz, new_r], axis=1)
        return Branch.from_xyzr(new_xyzr)

    @staticmethod
    def get_matrix(xyz: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
        r"""Get standardize transformation matrix.

        Standardized branch starts at (0, 0, 0), ends at (1, 0, 0), up at y.
        The matrix is built by left-multiplying, in order: a translation, a
        uniform scale, a rotation about y, a rotation about z and (for more
        than two points) a rotation about x.

        NOTE: If the first and last points coincide, the scale factor is
        1 / 0 = inf and the resulting matrix is not meaningful.

        Args:
            xyz: The `x`, `y`, `z` matrix of shape (N, 3) of branch.

        Returns:
            T: An homogeneous transformation matrix of shape (4, 4).
        """

        assert xyz.ndim == 2 and xyz.shape[1] == 3, (
            f"xyz should be of shape (N, 3), got {xyz.shape}"
        )

        # kept in case asserts are stripped (python -O)
        xyz = xyz[:, :3]
        T = np.identity(4)
        # start-to-end direction as a homogeneous vector (w = 0, so translations ignore it)
        v = np.concatenate([xyz[-1] - xyz[0], np.zeros((1))])[:, None]

        # translate to the origin
        T = translate3d(-xyz[0, 0], -xyz[0, 1], -xyz[0, 2]).dot(T)

        # scale to unit vector
        s = (1 / np.linalg.norm(v[:3, 0])).item()
        T = scale3d(s, s, s).dot(T)

        # rotate v to the xz-plane, v should be (x, 0, z) now
        vy = np.dot(T, v)[:, 0]
        # when looking at the xz-plane along the positive y-axis, the
        # coordinates should be (z, x)
        T = rotate3d_y(angle([vy[2], vy[0]], [0, 1])).dot(T)

        # rotate v to the x-axis, v should be (1, 0, 0) now
        vx = np.dot(T, v)[:, 0]
        T = rotate3d_z(angle([vx[0], vx[1]], [1, 0])).dot(T)

        # rotate the farthest point to the xy-plane
        if xyz.shape[0] > 2:
            xyz4 = to_homogeneous(xyz, 1).transpose()  # (4, N)
            new_xyz4 = np.dot(T, xyz4)  # (4, N)
            # interior point (endpoints excluded) farthest from the x-axis,
            # i.e. with the largest norm of its (y, z) components
            max_index = np.argmax(np.linalg.norm(new_xyz4[1:3, :], axis=0)[1:-1]) + 1
            max_xyz4 = xyz4[:, max_index].reshape(4, 1)
            max_xyz4_t = np.dot(T, max_xyz4)  # (4, 1)
            angle_x = angle(max_xyz4_t[1:3, 0], [1, 0])
            T = rotate3d_x(angle_x).dot(T)

        return T
@@ -0,0 +1,74 @@
1
+
2
+ # SPDX-FileCopyrightText: 2022 - 2025 Zexin Yuan <pypi@yzx9.xyz>
3
+ #
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ from collections.abc import Iterable
7
+
8
+ import numpy as np
9
+ from typing_extensions import override
10
+
11
+ from swcgeom.core import Branch, BranchTree, Node, Tree
12
+ from swcgeom.transforms.base import Transform
13
+
14
+ __all__ = ["BranchTreeAssembler"]
15
+
16
+
17
class BranchTreeAssembler(Transform[BranchTree, Tree]):
    """Rebuild a :class:`Tree` from a :class:`BranchTree`.

    Starting from the soma, every branch-tree node's branches are matched to
    its children (see :meth:`pair`). The nodes of each branch are then spliced
    in between the parent and the child, and ids are reassigned consecutively.
    """

    EPS = 1e-6  # distance below which two nodes are treated as the same point

    @override
    def __call__(self, x: BranchTree) -> Tree:
        nodes = [x.soma().detach()]
        stack = [(x.soma(), 0)]  # (node in the branch tree, its id in the new tree)
        while stack:
            n_orig, pid_new = stack.pop()
            children = n_orig.children()

            for br, c in self.pair(x.branches.get(n_orig.id, []), children):
                # Drop the first branch node if it duplicates the parent, which is already
                # in `nodes`; drop the last one if it duplicates the child `c`, which is
                # appended explicitly below. Otherwise keep all branch nodes.
                s = 1 if np.linalg.norm(br[0].xyz() - n_orig.xyz()) < self.EPS else 0
                e = -1 if np.linalg.norm(br[-1].xyz() - c.xyz()) < self.EPS else None

                br_nodes = [n.detach() for n in br[s:e]] + [c.detach()]
                offset = len(nodes)
                for i, n in enumerate(br_nodes):
                    # reindex: consecutive ids, each node's parent is the previous one
                    n.id = offset + i
                    n.pid = offset + i - 1

                # attach the first spliced node to the already-placed parent
                br_nodes[0].pid = pid_new
                nodes.extend(br_nodes)
                stack.append((c, br_nodes[-1].id))

        return Tree(
            len(nodes),
            source=x.source,
            comments=x.comments,
            names=x.names,
            **{
                k: np.array([getattr(n, k) for n in nodes])
                for k in x.names.cols()
            },
        )

    def pair(
        self, branches: list[Branch], endpoints: list[Node]
    ) -> Iterable[tuple[Branch, Node]]:
        """Match each branch to the endpoint closest to its last node.

        Greedy: the globally closest (branch, endpoint) pair is taken first,
        then both are removed from consideration, until all are paired.

        Args:
            branches: Branches leaving one branch-tree node.
            endpoints: Child nodes of that branch-tree node; must have the
                same length as `branches`.

        Returns:
            A list of `(branch, endpoint)` pairs.
        """
        assert len(branches) == len(endpoints)
        xyz1 = [br[-1].xyz() for br in branches]
        xyz2 = [n.xyz() for n in endpoints]
        # pairwise distances, shape (len(branches), len(endpoints))
        v = np.reshape(xyz1, (-1, 1, 3)) - np.reshape(xyz2, (1, -1, 3))
        dis = np.linalg.norm(v, axis=-1)

        # greedy algorithm
        pairs = []
        for _ in range(len(branches)):
            # find minimal
            min_idx = np.argmin(dis)
            min_branch_idx, min_endpoint_idx = np.unravel_index(min_idx, dis.shape)
            pairs.append((branches[min_branch_idx], endpoints[min_endpoint_idx]))

            # remove the matched branch and endpoint from further consideration
            dis[min_branch_idx, :] = np.inf
            dis[:, min_endpoint_idx] = np.inf

        return pairs
+ return pairs