swcgeom 0.19.4__cp312-cp312-macosx_14_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of swcgeom might be problematic.
- swcgeom/__init__.py +21 -0
- swcgeom/analysis/__init__.py +13 -0
- swcgeom/analysis/feature_extractor.py +454 -0
- swcgeom/analysis/features.py +218 -0
- swcgeom/analysis/lmeasure.py +750 -0
- swcgeom/analysis/sholl.py +201 -0
- swcgeom/analysis/trunk.py +183 -0
- swcgeom/analysis/visualization.py +191 -0
- swcgeom/analysis/visualization3d.py +81 -0
- swcgeom/analysis/volume.py +143 -0
- swcgeom/core/__init__.py +19 -0
- swcgeom/core/branch.py +129 -0
- swcgeom/core/branch_tree.py +65 -0
- swcgeom/core/compartment.py +107 -0
- swcgeom/core/node.py +130 -0
- swcgeom/core/path.py +155 -0
- swcgeom/core/population.py +341 -0
- swcgeom/core/swc.py +247 -0
- swcgeom/core/swc_utils/__init__.py +19 -0
- swcgeom/core/swc_utils/assembler.py +35 -0
- swcgeom/core/swc_utils/base.py +180 -0
- swcgeom/core/swc_utils/checker.py +107 -0
- swcgeom/core/swc_utils/io.py +204 -0
- swcgeom/core/swc_utils/normalizer.py +163 -0
- swcgeom/core/swc_utils/subtree.py +70 -0
- swcgeom/core/tree.py +384 -0
- swcgeom/core/tree_utils.py +277 -0
- swcgeom/core/tree_utils_impl.py +58 -0
- swcgeom/images/__init__.py +9 -0
- swcgeom/images/augmentation.py +149 -0
- swcgeom/images/contrast.py +87 -0
- swcgeom/images/folder.py +217 -0
- swcgeom/images/io.py +578 -0
- swcgeom/images/loaders/__init__.py +8 -0
- swcgeom/images/loaders/pbd.cpython-312-darwin.so +0 -0
- swcgeom/images/loaders/pbd.pyx +523 -0
- swcgeom/images/loaders/raw.cpython-312-darwin.so +0 -0
- swcgeom/images/loaders/raw.pyx +183 -0
- swcgeom/transforms/__init__.py +20 -0
- swcgeom/transforms/base.py +136 -0
- swcgeom/transforms/branch.py +223 -0
- swcgeom/transforms/branch_tree.py +74 -0
- swcgeom/transforms/geometry.py +270 -0
- swcgeom/transforms/image_preprocess.py +107 -0
- swcgeom/transforms/image_stack.py +219 -0
- swcgeom/transforms/images.py +206 -0
- swcgeom/transforms/mst.py +183 -0
- swcgeom/transforms/neurolucida_asc.py +498 -0
- swcgeom/transforms/path.py +56 -0
- swcgeom/transforms/population.py +36 -0
- swcgeom/transforms/tree.py +265 -0
- swcgeom/transforms/tree_assembler.py +161 -0
- swcgeom/utils/__init__.py +18 -0
- swcgeom/utils/debug.py +23 -0
- swcgeom/utils/download.py +119 -0
- swcgeom/utils/dsu.py +58 -0
- swcgeom/utils/ellipse.py +131 -0
- swcgeom/utils/file.py +90 -0
- swcgeom/utils/neuromorpho.py +581 -0
- swcgeom/utils/numpy_helper.py +70 -0
- swcgeom/utils/plotter_2d.py +134 -0
- swcgeom/utils/plotter_3d.py +35 -0
- swcgeom/utils/renderer.py +145 -0
- swcgeom/utils/sdf.py +324 -0
- swcgeom/utils/solid_geometry.py +154 -0
- swcgeom/utils/transforms.py +367 -0
- swcgeom/utils/volumetric_object.py +483 -0
- swcgeom-0.19.4.dist-info/METADATA +86 -0
- swcgeom-0.19.4.dist-info/RECORD +72 -0
- swcgeom-0.19.4.dist-info/WHEEL +5 -0
- swcgeom-0.19.4.dist-info/licenses/LICENSE +201 -0
- swcgeom-0.19.4.dist-info/top_level.txt +1 -0
swcgeom/images/loaders/raw.pyx
@@ -0,0 +1,183 @@
# -----------------------------------------------------------------------------
# This file is adapted from the v3d-py-helper project:
# https://github.com/SEU-ALLEN-codebase/v3d-py-helper
#
# Original license: MIT License
# Copyright (c) Zuohan Zhao
#
# Vaa3D in Python Made Easy
# Python library for Vaa3D functions.
#
# The original project is distributed via PyPI under the name `v3d-py-helper`,
# with its latest release (v0.4.1) not supporting Python 3.13.
#
# As of Python 3.13 (released on October 7, 2024), this package fails to build
# from source due to missing dependencies (e.g., libtiff), and no prebuilt wheels
# are available on PyPI for Python 3.13. An issue has been raised, but the project
# appears to be unmaintained at this time, and the author has not responded.
#
# To ensure continued compatibility and usability of Vaa3D features under Python 3.13+,
# we have copied and minimally adapted necessary source files into this project,
# preserving license and attribution in accordance with the MIT License.
#
# Please consult the original repository for full documentation:
# https://SEU-ALLEN-codebase.github.io/v3d-py-helper
#
# If the upstream project resumes maintenance and releases official support
# for Python 3.13+, this bundled version may be deprecated in favor of the
# canonical package.
# -----------------------------------------------------------------------------

import io
import os
import struct
import sys
import cython
cimport cython
import numpy as np
cimport numpy as np


DEF FORMAT_KEY_4 = b"raw_image_stack_by_hpeng"
DEF FORMAT_KEY_5 = b"raw5image_stack_by_hpeng"
DEF FORMAT_LEN = 24


cdef class Raw:
    """
    For vaa3d raw formats, allowing uint8, uint16 and float32 data types. The image dimension is presumed to be 4.

    As a raw format, it not only stores the image buffer, but also the image dimension sizes and its data type in the
    header. It also stores the endian of the buffer.

    This interface saves and loads a multi-dimension numpy array of either of the 3 data types. It also compares the endian of
    the file/numpy array and the machine to make it work properly. Based on the v3draw format key in the header, the
    array can be 4D or 5D.

    modified from v3d_external/v3d_main/basic_c_fun/stackutils.cpp

    by Zuohan Zhao

    2022/5/8
    """
    cdef:
        bint sz2byte

    def __init__(self, sz2byte = False):
        """

        :param sz2byte: set size array to be 2 byte (short int), for compatibility. Default as False.
        """
        self.sz2byte = sz2byte

    cpdef np.ndarray load(self, path: str | os.PathLike, int choose = -1):
        """
        :param path: input image path of v3draw.
        :param choose: choose a channel(4D) or stack(5D) to load, starting from 0, default as -1, meaning all.
        :return: a numpy array of 4D or 5D based on the format key.
        """
        cdef:
            short datatype
            list sz
            bytes endian_code_data
            str endian, dt
            char dim, header_sz, i
            bytes format_key
            long long bulk_sz, filesize

        filesize = os.path.getsize(path)
        assert filesize >= FORMAT_LEN, "File size too small, file might be corrupted"

        with open(path, "rb") as f:
            format_key = f.read(FORMAT_LEN)
            if format_key == FORMAT_KEY_4:
                dim = 4
            elif format_key == FORMAT_KEY_5:
                dim = 5
            else:
                raise RuntimeError("Format key isn't for v3draw")
            if self.sz2byte:
                header_sz = FORMAT_LEN + dim * 2 + 2 + 1
            else:
                header_sz = FORMAT_LEN + dim * 4 + 2 + 1
            assert filesize >= header_sz, "File size too small, file might be corrupted"
            endian_code_data = f.read(1)
            if endian_code_data == b'B':
                endian = '>'
            elif endian_code_data == b'L':
                endian = '<'
            else:
                raise RuntimeError('Endian code should be either B/L')
            datatype = struct.unpack(f'{endian}h', f.read(2))[0]
            if datatype == 1:
                dt = 'u1'
            elif datatype == 2:
                dt = 'u2'
            elif datatype == 4:
                dt = 'f4'
            else:
                raise RuntimeError('v3draw data type can only be 1/2/4')
            if self.sz2byte:
                sz = list(struct.unpack(f'{endian}{dim}h', f.read(dim * 2)))
            else:
                sz = list(struct.unpack(f'{endian}{dim}i', f.read(dim * 4)))
            bulk_sz = 1
            for i in range(dim - 1):
                bulk_sz *= sz[i]
            assert bulk_sz * sz[-1] * datatype + header_sz == filesize, "file size doesn't match with the image"
            if choose < 0:
                return np.frombuffer(f.read(), endian + dt).reshape(sz[::-1])
            else:
                assert choose < sz[-1], "Choose index exceeding the range"
                f.seek(bulk_sz * datatype * choose, 1)
                img = np.frombuffer(f.read(bulk_sz * datatype), endian + dt)
                return img.reshape(sz[-2::-1])

    cpdef void save(self, path: str | os.PathLike, np.ndarray img):
        """
        :param path: output image path of v3draw.
        :param img: the image array to save.
        """
        assert img.ndim == 4, "The image has to be 4D"
        assert img.dtype in [np.uint8, np.uint16], "The pixel type has to be uint8 or uint16"
        cdef:
            short datatype
            list sz
            bytes endian_code_data, format
            str endian, bo = img.dtype.byteorder
            bytes header
            char dim = img.ndim

        if dim == 4:
            format = FORMAT_KEY_4
        elif dim == 5:
            format = FORMAT_KEY_5
        else:
            raise RuntimeError("Dimension not supported by v3draw")

        sz = [img.shape[i] for i in range(dim)]
        sz.extend([1] * (dim - len(sz)))
        sz.reverse()

        if bo == '>' or bo == '=' and sys.byteorder == 'big':
            endian_code_data = b'B'
            endian = '>'
        else:
            endian_code_data = b'L'
            endian = '<'

        if img.dtype == np.uint8:
            datatype = 1
        elif img.dtype == np.uint16:
            datatype = 2
        elif img.dtype == np.float32:
            datatype = 4
        else:
            raise RuntimeError("numpy data type not supported by v3draw")

        header = struct.pack(f'{endian}{FORMAT_LEN}sch4{"h" if self.sz2byte else "i"}',
                             format, endian_code_data, datatype, *sz)

        with open(path, 'wb') as f:
            f.write(header)
            f.write(img.tobytes())
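For orientation, a minimal round-trip sketch of the bundled loader; the module path follows the wheel layout above, and the array shape and file name are made up for illustration:

    import numpy as np
    from swcgeom.images.loaders.raw import Raw

    raw = Raw()
    img = np.zeros((2, 4, 8, 8), dtype=np.uint8)  # a toy 4D uint8 stack; save() requires exactly 4 dimensions
    raw.save("example.v3draw", img)               # writes the 24-byte format key, the size/dtype header, then the buffer
    loaded = raw.load("example.v3draw")           # reads the header back and reshapes the raw buffer accordingly
    assert loaded.shape == img.shape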
swcgeom/transforms/__init__.py
@@ -0,0 +1,20 @@

# SPDX-FileCopyrightText: 2022 - 2025 Zexin Yuan <pypi@yzx9.xyz>
#
# SPDX-License-Identifier: Apache-2.0

"""A series of transformations to compose codes."""

from swcgeom.transforms.base import *  # noqa: F403
from swcgeom.transforms.branch import *  # noqa: F403
from swcgeom.transforms.branch_tree import *  # noqa: F403
from swcgeom.transforms.geometry import *  # noqa: F403
from swcgeom.transforms.image_preprocess import *  # noqa: F403
from swcgeom.transforms.image_stack import *  # noqa: F403
from swcgeom.transforms.images import *  # noqa: F403
from swcgeom.transforms.mst import *  # noqa: F403
from swcgeom.transforms.neurolucida_asc import *  # noqa: F403
from swcgeom.transforms.path import *  # noqa: F403
from swcgeom.transforms.population import *  # noqa: F403
from swcgeom.transforms.tree import *  # noqa: F403
from swcgeom.transforms.tree_assembler import *  # noqa: F403
swcgeom/transforms/base.py
@@ -0,0 +1,136 @@
# SPDX-FileCopyrightText: 2022 - 2025 Zexin Yuan <pypi@yzx9.xyz>
#
# SPDX-License-Identifier: Apache-2.0

"""Transformation in tree."""

from abc import ABC, abstractmethod
from typing import Any, Generic, TypeVar, overload

from typing_extensions import override

__all__ = ["Transform", "Transforms", "Identity"]

T = TypeVar("T")
K = TypeVar("K")

T1 = TypeVar("T1")
T2 = TypeVar("T2")
T3 = TypeVar("T3")
T4 = TypeVar("T4")
T5 = TypeVar("T5")
T6 = TypeVar("T6")


class Transform(ABC, Generic[T, K]):
    r"""An abstract class representing a :class:`Transform`.

    All transforms that represent a map from `T` to `K`.
    """

    @abstractmethod
    def __call__(self, x: T) -> K:
        """Apply transform.

        NOTE: All subclasses should overwrite :meth:`__call__`, supporting
        applying transform in `x`.
        """
        raise NotImplementedError()

    def __repr__(self) -> str:
        classname = self.__class__.__name__
        repr_ = self.extra_repr()
        return f"{classname}({repr_})"

    def extra_repr(self) -> str:
        """Provides a human-friendly representation of the module.

        This method extends the basic string representation provided by
        `__repr__` method. It is designed to display additional details
        about the module's parameters or its specific configuration,
        which can be particularly useful for debugging and model
        architecture introspection.

        >>> class Foo(Transform[T, K]):
        ...     def __init__(self, my_parameter: int = 1):
        ...         self.my_parameter = my_parameter
        ...
        ...     def extra_repr(self) -> str:
        ...         return f"my_parameter={self.my_parameter}"

        NOTE: This method should be overridden in custom modules to provide
        specific details relevant to the module's functionality and
        configuration.
        """
        return ""


class Transforms(Transform[T, K]):
    """A simple typed wrapper for transforms."""

    transforms: list[Transform[Any, Any]]

    # fmt: off
    @overload
    def __init__(self, t1: Transform[T, K], /) -> None: ...
    @overload
    def __init__(self, t1: Transform[T, T1], t2: Transform[T1, K], /) -> None: ...
    @overload
    def __init__(self, t1: Transform[T, T1], t2: Transform[T1, T2],
                 t3: Transform[T2, K], /) -> None: ...
    @overload
    def __init__(self, t1: Transform[T, T1], t2: Transform[T1, T2],
                 t3: Transform[T2, T3], t4: Transform[T3, K], /) -> None: ...
    @overload
    def __init__(self, t1: Transform[T, T1], t2: Transform[T1, T2],
                 t3: Transform[T2, T3], t4: Transform[T3, T4],
                 t5: Transform[T4, K], /) -> None: ...
    @overload
    def __init__(self, t1: Transform[T, T1], t2: Transform[T1, T2],
                 t3: Transform[T2, T3], t4: Transform[T3, T4],
                 t5: Transform[T4, T5], t6: Transform[T5, K], /) -> None: ...
    @overload
    def __init__(self, t1: Transform[T, T1], t2: Transform[T1, T2],
                 t3: Transform[T2, T3], t4: Transform[T3, T4],
                 t5: Transform[T4, T5], t6: Transform[T5, T6],
                 t7: Transform[T6, K], /) -> None: ...
    @overload
    def __init__(self, t1: Transform[T, T1], t2: Transform[T1, T2],
                 t3: Transform[T2, T3], t4: Transform[T3, T4],
                 t5: Transform[T4, T5], t6: Transform[T5, T6],
                 t7: Transform[T6, Any], /, *transforms: Transform[Any, K]) -> None: ...
    # fmt: on
    def __init__(self, *transforms: Transform[Any, Any]) -> None:
        trans = []
        for t in transforms:
            if isinstance(t, Transforms):
                trans.extend(t.transforms)
            else:
                trans.append(t)
        self.transforms = trans

    @override
    def __call__(self, x: T) -> K:
        """Apply transforms."""
        for transform in self.transforms:
            x = transform(x)

        return x  # type: ignore

    def __getitem__(self, idx: int) -> Transform[Any, Any]:
        return self.transforms[idx]

    def __len__(self) -> int:
        return len(self.transforms)

    @override
    def extra_repr(self) -> str:
        return ", ".join([str(transform) for transform in self])


class Identity(Transform[T, T]):
    """Resurn input as-is."""

    @override
    def __call__(self, x: T) -> T:
        return x
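A rough sketch of how these pieces compose; AddOne is a hypothetical transform written for illustration, not part of the package:

    from swcgeom.transforms import Identity, Transform, Transforms

    class AddOne(Transform[int, int]):  # hypothetical example transform
        def __call__(self, x: int) -> int:
            return x + 1

        def extra_repr(self) -> str:
            return "delta=1"

    # Nested Transforms instances are flattened into a single list of steps.
    pipeline = Transforms(AddOne(), Identity(), AddOne())
    assert pipeline(1) == 3
    print(pipeline)  # roughly: Transforms(AddOne(delta=1), Identity(), AddOne(delta=1))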
swcgeom/transforms/branch.py
@@ -0,0 +1,223 @@

# SPDX-FileCopyrightText: 2022 - 2025 Zexin Yuan <pypi@yzx9.xyz>
#
# SPDX-License-Identifier: Apache-2.0

"""Transformation in branch."""

from abc import ABC, abstractmethod
from typing import cast

import numpy as np
import numpy.typing as npt
from scipy import signal
from typing_extensions import override

from swcgeom.core import Branch, DictSWC
from swcgeom.transforms.base import Transform
from swcgeom.utils import (
    angle,
    rotate3d_x,
    rotate3d_y,
    rotate3d_z,
    scale3d,
    to_homogeneous,
    translate3d,
)

__all__ = ["BranchLinearResampler", "BranchConvSmoother", "BranchStandardizer"]


class _BranchResampler(Transform[Branch, Branch], ABC):
    r"""Resample branch."""

    @override
    def __call__(self, x: Branch) -> Branch:
        xyzr = self.resample(x.xyzr())
        return Branch.from_xyzr(xyzr)

    @abstractmethod
    def resample(self, xyzr: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]: ...


class BranchLinearResampler(_BranchResampler):
    r"""Resampling by linear interpolation, DO NOT keep original node."""

    def __init__(self, n_nodes: int) -> None:
        """Resample branch to special num of nodes.

        Args:
            n_nodes: Number of nodes after resample.
        """
        super().__init__()
        self.n_nodes = n_nodes

    @override
    def resample(self, xyzr: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
        """Resampling by linear interpolation, DO NOT keep original node.

        Args:
            xyzr: The array of shape (N, 4).

        Returns:
            coordinates: An array of shape (n_nodes, 4).
        """

        xp = np.cumsum(np.linalg.norm(xyzr[1:, :3] - xyzr[:-1, :3], axis=1))
        xp = np.insert(xp, 0, 0)
        xvals = np.linspace(0, xp[-1], self.n_nodes)

        x = np.interp(xvals, xp, xyzr[:, 0])
        y = np.interp(xvals, xp, xyzr[:, 1])
        z = np.interp(xvals, xp, xyzr[:, 2])
        r = np.interp(xvals, xp, xyzr[:, 3])
        return cast(npt.NDArray[np.float32], np.stack([x, y, z, r], axis=1))

    @override
    def extra_repr(self) -> str:
        return f"n_nodes={self.n_nodes}"


class BranchIsometricResampler(_BranchResampler):
    def __init__(self, distance: float, *, adjust_last_gap: bool = True) -> None:
        super().__init__()
        self.distance = distance
        self.adjust_last_gap = adjust_last_gap

    @override
    def resample(self, xyzr: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
        """Resampling by isometric interpolation, DO NOT keep original node.

        Args:
            xyzr: The array of shape (N, 4).

        Returns:
            new_xyzr: An array of shape (n_nodes, 4).
        """

        # Compute the cumulative distances between consecutive points
        diffs = np.diff(xyzr[:, :3], axis=0)
        distances = np.sqrt((diffs**2).sum(axis=1))
        cumulative_distances = np.concatenate([[0], np.cumsum(distances)])

        total_length = cumulative_distances[-1]
        n_nodes = int(np.ceil(total_length / self.distance)) + 1

        # Determine the new distances
        if self.adjust_last_gap and n_nodes > 1:
            new_distances = np.linspace(0, total_length, n_nodes)
        else:
            new_distances = np.arange(0, total_length, self.distance)
            # keep endpoint
            new_distances = np.concatenate([new_distances, total_length])

        # Interpolate the new points
        new_xyzr = np.zeros((n_nodes, 4), dtype=np.float32)
        new_xyzr[:, :3] = np.array(
            [
                np.interp(new_distances, cumulative_distances, xyzr[:, i])
                for i in range(3)
            ]
        ).T
        new_xyzr[:, 3] = np.interp(new_distances, cumulative_distances, xyzr[:, 3])
        return new_xyzr

    @override
    def extra_repr(self) -> str:
        return f"distance={self.distance},adjust_last_gap={self.adjust_last_gap}"


class BranchConvSmoother(Transform[Branch, Branch[DictSWC]]):
    r"""Smooth the branch by sliding window."""

    def __init__(self, n_nodes: int = 5) -> None:
        """
        Args:
            n_nodes: Window size.
        """
        super().__init__()
        self.n_nodes = n_nodes
        self.kernel = np.ones(n_nodes)

    @override
    def __call__(self, x: Branch) -> Branch[DictSWC]:
        x = x.detach()
        c = signal.convolve(np.ones(x.number_of_nodes()), self.kernel, mode="same")
        for k in ["x", "y", "z"]:
            v = x.get_ndata(k)
            s = signal.convolve(v, self.kernel, mode="same")
            x.attach.ndata[k][1:-1] = (s / c)[1:-1]

        return x

    def extra_repr(self) -> str:
        return f"n_nodes={self.n_nodes}"


class BranchStandardizer(Transform[Branch, Branch[DictSWC]]):
    r"""Standardize branch.

    Standardized branch starts at (0, 0, 0), ends at (1, 0, 0), up at y, and scale max
    radius to 1.
    """

    @override
    def __call__(self, x: Branch) -> Branch:
        xyzr = x.xyzr()
        xyz, r = xyzr[:, 0:3], xyzr[:, 3:4]
        T = self.get_matrix(xyz)

        xyz4 = to_homogeneous(xyz, 1).transpose()  # (4, N)
        new_xyz = np.dot(T, xyz4)[:3].transpose()
        new_xyzr = np.concatenate([new_xyz, r / r.max()], axis=1)
        return Branch.from_xyzr(new_xyzr)

    @staticmethod
    def get_matrix(xyz: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
        r"""Get standardize transformation matrix.

        Standardized branch starts at (0, 0, 0), ends at (1, 0, 0), up at y.

        Args:
            xyz: The `x`, `y`, `z` matrix of shape (N, 3) of branch.

        Returns:
            T: An homogeneous transformation matrix of shape (4, 4).
        """

        assert xyz.ndim == 2 and xyz.shape[1] == 3, (
            f"xyz should be of shape (N, 3), got {xyz.shape}"
        )

        xyz = xyz[:, :3]
        T = np.identity(4)
        v = np.concatenate([xyz[-1] - xyz[0], np.zeros((1))])[:, None]

        # translate to the origin
        T = translate3d(-xyz[0, 0], -xyz[0, 1], -xyz[0, 2]).dot(T)

        # scale to unit vector
        s = (1 / np.linalg.norm(v[:3, 0])).item()
        T = scale3d(s, s, s).dot(T)

        # rotate v to the xz-plane, v should be (x, 0, z) now
        vy = np.dot(T, v)[:, 0]
        # when looking at the xz-plane along the positive y-axis, the
        # coordinates should be (z, x)
        T = rotate3d_y(angle([vy[2], vy[0]], [0, 1])).dot(T)

        # rotate v to the x-axis, v should be (1, 0, 0) now
        vx = np.dot(T, v)[:, 0]
        T = rotate3d_z(angle([vx[0], vx[1]], [1, 0])).dot(T)

        # rotate the farthest point to the xy-plane
        if xyz.shape[0] > 2:
            xyz4 = to_homogeneous(xyz, 1).transpose()  # (4, N)
            new_xyz4 = np.dot(T, xyz4)  # (4, N)
            max_index = np.argmax(np.linalg.norm(new_xyz4[1:3, :], axis=0)[1:-1]) + 1
            max_xyz4 = xyz4[:, max_index].reshape(4, 1)
            max_xyz4_t = np.dot(T, max_xyz4)  # (4, 1)
            angle_x = angle(max_xyz4_t[1:3, 0], [1, 0])
            T = rotate3d_x(angle_x).dot(T)

        return T
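A small sketch of chaining the branch transforms above with the Transforms wrapper; the xyzr coordinates here are made up for illustration:

    import numpy as np
    from swcgeom.core import Branch
    from swcgeom.transforms import BranchLinearResampler, BranchStandardizer, Transforms

    # A toy 3-node branch: columns are x, y, z, r.
    xyzr = np.array([[0, 0, 0, 1], [1, 1, 0, 1], [2, 0, 0, 1]], dtype=np.float32)
    branch = Branch.from_xyzr(xyzr)

    pipeline = Transforms(BranchLinearResampler(n_nodes=10), BranchStandardizer())
    standardized = pipeline(branch)  # 10 nodes, starting near (0, 0, 0) and ending near (1, 0, 0)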
swcgeom/transforms/branch_tree.py
@@ -0,0 +1,74 @@

# SPDX-FileCopyrightText: 2022 - 2025 Zexin Yuan <pypi@yzx9.xyz>
#
# SPDX-License-Identifier: Apache-2.0

from collections.abc import Iterable

import numpy as np
from typing_extensions import override

from swcgeom.core import Branch, BranchTree, Node, Tree
from swcgeom.transforms.base import Transform

__all__ = ["BranchTreeAssembler"]


class BranchTreeAssembler(Transform[BranchTree, Tree]):
    EPS = 1e-6

    @override
    def __call__(self, x: BranchTree) -> Tree:
        nodes = [x.soma().detach()]
        stack = [(x.soma(), 0)]  # n_orig, id_new
        while len(stack):
            n_orig, pid_new = stack.pop()
            children = n_orig.children()

            for br, c in self.pair(x.branches.get(n_orig.id, []), children):
                s = 1 if np.linalg.norm(br[0].xyz() - n_orig.xyz()) < self.EPS else 0
                e = -2 if np.linalg.norm(br[-1].xyz() - c.xyz()) < self.EPS else -1

                br_nodes = [n.detach() for n in br[s:e]] + [c.detach()]
                for i, n in enumerate(br_nodes):
                    # reindex
                    n.id = len(nodes) + i
                    n.pid = len(nodes) + i - 1

                br_nodes[0].pid = pid_new
                nodes.extend(br_nodes)
                stack.append((c, br_nodes[-1].id))

        return Tree(
            len(nodes),
            source=x.source,
            comments=x.comments,
            names=x.names,
            **{
                k: np.array([n.__getattribute__(k) for n in nodes])
                for k in x.names.cols()
            },
        )

    def pair(
        self, branches: list[Branch], endpoints: list[Node]
    ) -> Iterable[tuple[Branch, Node]]:
        assert len(branches) == len(endpoints)
        xyz1 = [br[-1].xyz() for br in branches]
        xyz2 = [n.xyz() for n in endpoints]
        v = np.reshape(xyz1, (-1, 1, 3)) - np.reshape(xyz2, (1, -1, 3))
        dis = np.linalg.norm(v, axis=-1)

        # greedy algorithm
        pairs = []
        for _ in range(len(branches)):
            # find minimal
            min_idx = np.argmin(dis)
            min_branch_idx, min_endpoint_idx = np.unravel_index(min_idx, dis.shape)
            pairs.append((branches[min_branch_idx], endpoints[min_endpoint_idx]))

            # remove current node
            dis[min_branch_idx, :] = np.inf
            dis[:, min_endpoint_idx] = np.inf

        return pairs
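A hedged sketch of the intended round trip; Tree.from_swc and BranchTree.from_tree are assumed constructors here (not shown in this diff), and the SWC path is a placeholder:

    from swcgeom.core import BranchTree, Tree
    from swcgeom.transforms import BranchTreeAssembler

    tree = Tree.from_swc("neuron.swc")            # placeholder path, assumed constructor
    branch_tree = BranchTree.from_tree(tree)      # assumed branch-level view of the tree
    assembled = BranchTreeAssembler()(branch_tree)  # stitches branches back into a node-level Tree,
                                                    # greedily matching branch ends to child endpoints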