swcgeom 0.18.3__py3-none-any.whl → 0.19.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of swcgeom might be problematic; see the package registry's advisory page for details.
- swcgeom/analysis/feature_extractor.py +22 -24
- swcgeom/analysis/features.py +18 -40
- swcgeom/analysis/lmeasure.py +227 -323
- swcgeom/analysis/sholl.py +17 -23
- swcgeom/analysis/trunk.py +23 -28
- swcgeom/analysis/visualization.py +37 -44
- swcgeom/analysis/visualization3d.py +16 -25
- swcgeom/analysis/volume.py +33 -47
- swcgeom/core/__init__.py +1 -6
- swcgeom/core/branch.py +10 -17
- swcgeom/core/branch_tree.py +3 -2
- swcgeom/core/compartment.py +1 -1
- swcgeom/core/node.py +3 -6
- swcgeom/core/path.py +11 -16
- swcgeom/core/population.py +32 -51
- swcgeom/core/swc.py +25 -16
- swcgeom/core/swc_utils/__init__.py +4 -6
- swcgeom/core/swc_utils/assembler.py +5 -12
- swcgeom/core/swc_utils/base.py +40 -31
- swcgeom/core/swc_utils/checker.py +3 -8
- swcgeom/core/swc_utils/io.py +32 -47
- swcgeom/core/swc_utils/normalizer.py +17 -23
- swcgeom/core/swc_utils/subtree.py +13 -20
- swcgeom/core/tree.py +61 -51
- swcgeom/core/tree_utils.py +36 -49
- swcgeom/core/tree_utils_impl.py +4 -6
- swcgeom/images/augmentation.py +23 -39
- swcgeom/images/contrast.py +22 -46
- swcgeom/images/folder.py +32 -34
- swcgeom/images/io.py +108 -126
- swcgeom/transforms/base.py +28 -19
- swcgeom/transforms/branch.py +31 -41
- swcgeom/transforms/branch_tree.py +3 -1
- swcgeom/transforms/geometry.py +13 -4
- swcgeom/transforms/image_preprocess.py +2 -0
- swcgeom/transforms/image_stack.py +40 -35
- swcgeom/transforms/images.py +31 -24
- swcgeom/transforms/mst.py +27 -40
- swcgeom/transforms/neurolucida_asc.py +13 -13
- swcgeom/transforms/path.py +4 -0
- swcgeom/transforms/population.py +4 -0
- swcgeom/transforms/tree.py +16 -11
- swcgeom/transforms/tree_assembler.py +37 -54
- swcgeom/utils/download.py +7 -14
- swcgeom/utils/dsu.py +12 -0
- swcgeom/utils/ellipse.py +26 -14
- swcgeom/utils/file.py +8 -13
- swcgeom/utils/neuromorpho.py +78 -92
- swcgeom/utils/numpy_helper.py +15 -12
- swcgeom/utils/plotter_2d.py +10 -16
- swcgeom/utils/plotter_3d.py +7 -9
- swcgeom/utils/renderer.py +16 -8
- swcgeom/utils/sdf.py +12 -23
- swcgeom/utils/solid_geometry.py +58 -2
- swcgeom/utils/transforms.py +164 -100
- swcgeom/utils/volumetric_object.py +29 -53
- {swcgeom-0.18.3.dist-info → swcgeom-0.19.1.dist-info}/METADATA +6 -5
- swcgeom-0.19.1.dist-info/RECORD +67 -0
- {swcgeom-0.18.3.dist-info → swcgeom-0.19.1.dist-info}/WHEEL +1 -1
- swcgeom-0.18.3.dist-info/RECORD +0 -67
- {swcgeom-0.18.3.dist-info → swcgeom-0.19.1.dist-info/licenses}/LICENSE +0 -0
- {swcgeom-0.18.3.dist-info → swcgeom-0.19.1.dist-info}/top_level.txt +0 -0
swcgeom/transforms/base.py
CHANGED
|
@@ -18,12 +18,19 @@
|
|
|
18
18
|
from abc import ABC, abstractmethod
|
|
19
19
|
from typing import Any, Generic, TypeVar, overload
|
|
20
20
|
|
|
21
|
+
from typing_extensions import override
|
|
22
|
+
|
|
21
23
|
__all__ = ["Transform", "Transforms", "Identity"]
|
|
22
24
|
|
|
23
|
-
T
|
|
25
|
+
T = TypeVar("T")
|
|
26
|
+
K = TypeVar("K")
|
|
24
27
|
|
|
25
|
-
T1
|
|
26
|
-
|
|
28
|
+
T1 = TypeVar("T1")
|
|
29
|
+
T2 = TypeVar("T2")
|
|
30
|
+
T3 = TypeVar("T3")
|
|
31
|
+
T4 = TypeVar("T4")
|
|
32
|
+
T5 = TypeVar("T5")
|
|
33
|
+
T6 = TypeVar("T6")
|
|
27
34
|
|
|
28
35
|
|
|
29
36
|
class Transform(ABC, Generic[T, K]):
|
|
@@ -36,9 +43,7 @@ class Transform(ABC, Generic[T, K]):
|
|
|
36
43
|
def __call__(self, x: T) -> K:
|
|
37
44
|
"""Apply transform.
|
|
38
45
|
|
|
39
|
-
|
|
40
|
-
-----
|
|
41
|
-
All subclasses should overwrite :meth:`__call__`, supporting
|
|
46
|
+
NOTE: All subclasses should overwrite :meth:`__call__`, supporting
|
|
42
47
|
applying transform in `x`.
|
|
43
48
|
"""
|
|
44
49
|
raise NotImplementedError()
|
|
@@ -57,18 +62,14 @@ class Transform(ABC, Generic[T, K]):
|
|
|
57
62
|
which can be particularly useful for debugging and model
|
|
58
63
|
architecture introspection.
|
|
59
64
|
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
def extra_repr(self) -> str:
|
|
67
|
-
return f"my_parameter={self.my_parameter}"
|
|
65
|
+
>>> class Foo(Transform[T, K]):
|
|
66
|
+
... def __init__(self, my_parameter: int = 1):
|
|
67
|
+
... self.my_parameter = my_parameter
|
|
68
|
+
...
|
|
69
|
+
... def extra_repr(self) -> str:
|
|
70
|
+
... return f"my_parameter={self.my_parameter}"
|
|
68
71
|
|
|
69
|
-
|
|
70
|
-
-----
|
|
71
|
-
This method should be overridden in custom modules to provide
|
|
72
|
+
NOTE: This method should be overridden in custom modules to provide
|
|
72
73
|
specific details relevant to the module's functionality and
|
|
73
74
|
configuration.
|
|
74
75
|
"""
|
|
@@ -80,7 +81,7 @@ class Transforms(Transform[T, K]):
|
|
|
80
81
|
|
|
81
82
|
transforms: list[Transform[Any, Any]]
|
|
82
83
|
|
|
83
|
-
# fmt:off
|
|
84
|
+
# fmt: off
|
|
84
85
|
@overload
|
|
85
86
|
def __init__(self, t1: Transform[T, K], /) -> None: ...
|
|
86
87
|
@overload
|
|
@@ -100,11 +101,16 @@ class Transforms(Transform[T, K]):
|
|
|
100
101
|
t3: Transform[T2, T3], t4: Transform[T3, T4],
|
|
101
102
|
t5: Transform[T4, T5], t6: Transform[T5, K], /) -> None: ...
|
|
102
103
|
@overload
|
|
104
|
+
def __init__(self, t1: Transform[T, T1], t2: Transform[T1, T2],
|
|
105
|
+
t3: Transform[T2, T3], t4: Transform[T3, T4],
|
|
106
|
+
t5: Transform[T4, T5], t6: Transform[T5, T6],
|
|
107
|
+
t7: Transform[T6, K], /) -> None: ...
|
|
108
|
+
@overload
|
|
103
109
|
def __init__(self, t1: Transform[T, T1], t2: Transform[T1, T2],
|
|
104
110
|
t3: Transform[T2, T3], t4: Transform[T3, T4],
|
|
105
111
|
t5: Transform[T4, T5], t6: Transform[T5, T6],
|
|
106
112
|
t7: Transform[T6, Any], /, *transforms: Transform[Any, K]) -> None: ...
|
|
107
|
-
# fmt:on
|
|
113
|
+
# fmt: on
|
|
108
114
|
def __init__(self, *transforms: Transform[Any, Any]) -> None:
|
|
109
115
|
trans = []
|
|
110
116
|
for t in transforms:
|
|
@@ -114,6 +120,7 @@ class Transforms(Transform[T, K]):
|
|
|
114
120
|
trans.append(t)
|
|
115
121
|
self.transforms = trans
|
|
116
122
|
|
|
123
|
+
@override
|
|
117
124
|
def __call__(self, x: T) -> K:
|
|
118
125
|
"""Apply transforms."""
|
|
119
126
|
for transform in self.transforms:
|
|
@@ -127,6 +134,7 @@ class Transforms(Transform[T, K]):
|
|
|
127
134
|
def __len__(self) -> int:
|
|
128
135
|
return len(self.transforms)
|
|
129
136
|
|
|
137
|
+
@override
|
|
130
138
|
def extra_repr(self) -> str:
|
|
131
139
|
return ", ".join([str(transform) for transform in self])
|
|
132
140
|
|
|
@@ -134,5 +142,6 @@ class Transforms(Transform[T, K]):
|
|
|
134
142
|
class Identity(Transform[T, T]):
|
|
135
143
|
"""Resurn input as-is."""
|
|
136
144
|
|
|
145
|
+
@override
|
|
137
146
|
def __call__(self, x: T) -> T:
|
|
138
147
|
return x
|
swcgeom/transforms/branch.py
CHANGED
|
@@ -21,6 +21,7 @@ from typing import cast
|
|
|
21
21
|
import numpy as np
|
|
22
22
|
import numpy.typing as npt
|
|
23
23
|
from scipy import signal
|
|
24
|
+
from typing_extensions import override
|
|
24
25
|
|
|
25
26
|
from swcgeom.core import Branch, DictSWC
|
|
26
27
|
from swcgeom.transforms.base import Transform
|
|
@@ -40,13 +41,13 @@ __all__ = ["BranchLinearResampler", "BranchConvSmoother", "BranchStandardizer"]
|
|
|
40
41
|
class _BranchResampler(Transform[Branch, Branch], ABC):
|
|
41
42
|
r"""Resample branch."""
|
|
42
43
|
|
|
44
|
+
@override
|
|
43
45
|
def __call__(self, x: Branch) -> Branch:
|
|
44
46
|
xyzr = self.resample(x.xyzr())
|
|
45
47
|
return Branch.from_xyzr(xyzr)
|
|
46
48
|
|
|
47
49
|
@abstractmethod
|
|
48
|
-
def resample(self, xyzr: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
|
|
49
|
-
raise NotImplementedError()
|
|
50
|
+
def resample(self, xyzr: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]: ...
|
|
50
51
|
|
|
51
52
|
|
|
52
53
|
class BranchLinearResampler(_BranchResampler):
|
|
@@ -55,26 +56,21 @@ class BranchLinearResampler(_BranchResampler):
|
|
|
55
56
|
def __init__(self, n_nodes: int) -> None:
|
|
56
57
|
"""Resample branch to special num of nodes.
|
|
57
58
|
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
n_nodes : int
|
|
61
|
-
Number of nodes after resample.
|
|
59
|
+
Args:
|
|
60
|
+
n_nodes: Number of nodes after resample.
|
|
62
61
|
"""
|
|
63
62
|
super().__init__()
|
|
64
63
|
self.n_nodes = n_nodes
|
|
65
64
|
|
|
65
|
+
@override
|
|
66
66
|
def resample(self, xyzr: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
|
|
67
67
|
"""Resampling by linear interpolation, DO NOT keep original node.
|
|
68
68
|
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
xyzr : np.ndarray[np.float32]
|
|
72
|
-
The array of shape (N, 4).
|
|
69
|
+
Args:
|
|
70
|
+
xyzr: The array of shape (N, 4).
|
|
73
71
|
|
|
74
|
-
Returns
|
|
75
|
-
|
|
76
|
-
coordinates : ~numpy.NDArray[float64]
|
|
77
|
-
An array of shape (n_nodes, 4).
|
|
72
|
+
Returns:
|
|
73
|
+
coordinates: An array of shape (n_nodes, 4).
|
|
78
74
|
"""
|
|
79
75
|
|
|
80
76
|
xp = np.cumsum(np.linalg.norm(xyzr[1:, :3] - xyzr[:-1, :3], axis=1))
|
|
@@ -87,6 +83,7 @@ class BranchLinearResampler(_BranchResampler):
|
|
|
87
83
|
r = np.interp(xvals, xp, xyzr[:, 3])
|
|
88
84
|
return cast(npt.NDArray[np.float32], np.stack([x, y, z, r], axis=1))
|
|
89
85
|
|
|
86
|
+
@override
|
|
90
87
|
def extra_repr(self) -> str:
|
|
91
88
|
return f"n_nodes={self.n_nodes}"
|
|
92
89
|
|
|
@@ -97,18 +94,15 @@ class BranchIsometricResampler(_BranchResampler):
|
|
|
97
94
|
self.distance = distance
|
|
98
95
|
self.adjust_last_gap = adjust_last_gap
|
|
99
96
|
|
|
97
|
+
@override
|
|
100
98
|
def resample(self, xyzr: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
|
|
101
99
|
"""Resampling by isometric interpolation, DO NOT keep original node.
|
|
102
100
|
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
xyzr : np.ndarray[np.float32]
|
|
106
|
-
The array of shape (N, 4).
|
|
101
|
+
Args:
|
|
102
|
+
xyzr: The array of shape (N, 4).
|
|
107
103
|
|
|
108
|
-
Returns
|
|
109
|
-
|
|
110
|
-
new_xyzr : ~numpy.NDArray[float32]
|
|
111
|
-
An array of shape (n_nodes, 4).
|
|
104
|
+
Returns:
|
|
105
|
+
new_xyzr: An array of shape (n_nodes, 4).
|
|
112
106
|
"""
|
|
113
107
|
|
|
114
108
|
# Compute the cumulative distances between consecutive points
|
|
@@ -138,6 +132,7 @@ class BranchIsometricResampler(_BranchResampler):
|
|
|
138
132
|
new_xyzr[:, 3] = np.interp(new_distances, cumulative_distances, xyzr[:, 3])
|
|
139
133
|
return new_xyzr
|
|
140
134
|
|
|
135
|
+
@override
|
|
141
136
|
def extra_repr(self) -> str:
|
|
142
137
|
return f"distance={self.distance},adjust_last_gap={self.adjust_last_gap}"
|
|
143
138
|
|
|
@@ -147,15 +142,14 @@ class BranchConvSmoother(Transform[Branch, Branch[DictSWC]]):
|
|
|
147
142
|
|
|
148
143
|
def __init__(self, n_nodes: int = 5) -> None:
|
|
149
144
|
"""
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
n_nodes : int, default `5`
|
|
153
|
-
Window size.
|
|
145
|
+
Args:
|
|
146
|
+
n_nodes: Window size.
|
|
154
147
|
"""
|
|
155
148
|
super().__init__()
|
|
156
149
|
self.n_nodes = n_nodes
|
|
157
150
|
self.kernel = np.ones(n_nodes)
|
|
158
151
|
|
|
152
|
+
@override
|
|
159
153
|
def __call__(self, x: Branch) -> Branch[DictSWC]:
|
|
160
154
|
x = x.detach()
|
|
161
155
|
c = signal.convolve(np.ones(x.number_of_nodes()), self.kernel, mode="same")
|
|
@@ -173,10 +167,11 @@ class BranchConvSmoother(Transform[Branch, Branch[DictSWC]]):
|
|
|
173
167
|
class BranchStandardizer(Transform[Branch, Branch[DictSWC]]):
|
|
174
168
|
r"""Standardize branch.
|
|
175
169
|
|
|
176
|
-
Standardized branch starts at (0, 0, 0), ends at (1, 0, 0), up at
|
|
177
|
-
|
|
170
|
+
Standardized branch starts at (0, 0, 0), ends at (1, 0, 0), up at y, and scale max
|
|
171
|
+
radius to 1.
|
|
178
172
|
"""
|
|
179
173
|
|
|
174
|
+
@override
|
|
180
175
|
def __call__(self, x: Branch) -> Branch:
|
|
181
176
|
xyzr = x.xyzr()
|
|
182
177
|
xyz, r = xyzr[:, 0:3], xyzr[:, 3:4]
|
|
@@ -191,23 +186,18 @@ class BranchStandardizer(Transform[Branch, Branch[DictSWC]]):
|
|
|
191
186
|
def get_matrix(xyz: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
|
|
192
187
|
r"""Get standardize transformation matrix.
|
|
193
188
|
|
|
194
|
-
Standardized branch starts at (0, 0, 0), ends at (1, 0, 0), up
|
|
195
|
-
at y.
|
|
189
|
+
Standardized branch starts at (0, 0, 0), ends at (1, 0, 0), up at y.
|
|
196
190
|
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
xyz : np.ndarray[np.float32]
|
|
200
|
-
The `x`, `y`, `z` matrix of shape (N, 3) of branch.
|
|
191
|
+
Args:
|
|
192
|
+
xyz: The `x`, `y`, `z` matrix of shape (N, 3) of branch.
|
|
201
193
|
|
|
202
|
-
Returns
|
|
203
|
-
|
|
204
|
-
T : np.ndarray[np.float32]
|
|
205
|
-
An homogeneous transformation matrix of shape (4, 4).
|
|
194
|
+
Returns:
|
|
195
|
+
T: An homogeneous transformation matrix of shape (4, 4).
|
|
206
196
|
"""
|
|
207
197
|
|
|
208
|
-
assert (
|
|
209
|
-
xyz
|
|
210
|
-
)
|
|
198
|
+
assert xyz.ndim == 2 and xyz.shape[1] == 3, (
|
|
199
|
+
f"xyz should be of shape (N, 3), got {xyz.shape}"
|
|
200
|
+
)
|
|
211
201
|
|
|
212
202
|
xyz = xyz[:, :3]
|
|
213
203
|
T = np.identity(4)
|
|
@@ -13,9 +13,10 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
15
|
|
|
16
|
-
from
|
|
16
|
+
from collections.abc import Iterable
|
|
17
17
|
|
|
18
18
|
import numpy as np
|
|
19
|
+
from typing_extensions import override
|
|
19
20
|
|
|
20
21
|
from swcgeom.core import Branch, BranchTree, Node, Tree
|
|
21
22
|
from swcgeom.transforms.base import Transform
|
|
@@ -26,6 +27,7 @@ __all__ = ["BranchTreeAssembler"]
|
|
|
26
27
|
class BranchTreeAssembler(Transform[BranchTree, Tree]):
|
|
27
28
|
EPS = 1e-6
|
|
28
29
|
|
|
30
|
+
@override
|
|
29
31
|
def __call__(self, x: BranchTree) -> Tree:
|
|
30
32
|
nodes = [x.soma().detach()]
|
|
31
33
|
stack = [(x.soma(), 0)] # n_orig, id_new
|
swcgeom/transforms/geometry.py
CHANGED
|
@@ -16,10 +16,11 @@
|
|
|
16
16
|
"""SWC geometry operations."""
|
|
17
17
|
|
|
18
18
|
import warnings
|
|
19
|
-
from typing import Generic, Literal,
|
|
19
|
+
from typing import Generic, Literal, TypeVar
|
|
20
20
|
|
|
21
21
|
import numpy as np
|
|
22
22
|
import numpy.typing as npt
|
|
23
|
+
from typing_extensions import override
|
|
23
24
|
|
|
24
25
|
from swcgeom.core import DictSWC
|
|
25
26
|
from swcgeom.core.swc_utils import SWCNames
|
|
@@ -54,7 +55,7 @@ Center = Literal["root", "soma", "origin"]
|
|
|
54
55
|
class Normalizer(Generic[T], Transform[T, T]):
|
|
55
56
|
"""Noramlize coordinates and radius to 0-1."""
|
|
56
57
|
|
|
57
|
-
def __init__(self, *, names:
|
|
58
|
+
def __init__(self, *, names: SWCNames | None = None) -> None:
|
|
58
59
|
super().__init__()
|
|
59
60
|
if names is not None:
|
|
60
61
|
warnings.warn(
|
|
@@ -63,6 +64,7 @@ class Normalizer(Generic[T], Transform[T, T]):
|
|
|
63
64
|
DeprecationWarning,
|
|
64
65
|
)
|
|
65
66
|
|
|
67
|
+
@override
|
|
66
68
|
def __call__(self, x: T) -> T:
|
|
67
69
|
"""Scale the `x`, `y`, `z`, `r` of nodes to 0-1."""
|
|
68
70
|
new_tree = x.copy()
|
|
@@ -87,6 +89,7 @@ class RadiusReseter(Generic[T], Transform[T, T]):
|
|
|
87
89
|
new_tree.ndata[new_tree.names.r] = r
|
|
88
90
|
return new_tree
|
|
89
91
|
|
|
92
|
+
@override
|
|
90
93
|
def extra_repr(self) -> str:
|
|
91
94
|
return f"r={self.r:.4f}"
|
|
92
95
|
|
|
@@ -103,8 +106,8 @@ class AffineTransform(Generic[T], Transform[T, T]):
|
|
|
103
106
|
tm: npt.NDArray[np.float32],
|
|
104
107
|
center: Center = "origin",
|
|
105
108
|
*,
|
|
106
|
-
fmt:
|
|
107
|
-
names:
|
|
109
|
+
fmt: str | None = None,
|
|
110
|
+
names: SWCNames | None = None,
|
|
108
111
|
) -> None:
|
|
109
112
|
self.tm, self.center = tm, center
|
|
110
113
|
|
|
@@ -122,6 +125,7 @@ class AffineTransform(Generic[T], Transform[T, T]):
|
|
|
122
125
|
DeprecationWarning,
|
|
123
126
|
)
|
|
124
127
|
|
|
128
|
+
@override
|
|
125
129
|
def __call__(self, x: T) -> T:
|
|
126
130
|
match self.center:
|
|
127
131
|
case "root" | "soma":
|
|
@@ -156,6 +160,7 @@ class Translate(Generic[T], AffineTransform[T]):
|
|
|
156
160
|
super().__init__(translate3d(tx, ty, tz), **kwargs)
|
|
157
161
|
self.tx, self.ty, self.tz = tx, ty, tz
|
|
158
162
|
|
|
163
|
+
@override
|
|
159
164
|
def extra_repr(self) -> str:
|
|
160
165
|
return f"tx={self.tx:.4f}, ty={self.ty:.4f}, tz={self.tz:.4f}"
|
|
161
166
|
|
|
@@ -209,6 +214,7 @@ class Rotate(Generic[T], AffineTransform[T]):
|
|
|
209
214
|
self.theta = theta
|
|
210
215
|
self.center = center
|
|
211
216
|
|
|
217
|
+
@override
|
|
212
218
|
def extra_repr(self) -> str:
|
|
213
219
|
return f"n={self.n}, theta={self.theta:.4f}, center={self.center}" # TODO: improve format of n
|
|
214
220
|
|
|
@@ -231,6 +237,7 @@ class RotateX(Generic[T], AffineTransform[T]):
|
|
|
231
237
|
super().__init__(rotate3d_x(theta), center=center, **kwargs)
|
|
232
238
|
self.theta = theta
|
|
233
239
|
|
|
240
|
+
@override
|
|
234
241
|
def extra_repr(self) -> str:
|
|
235
242
|
return f"center={self.center}, theta={self.theta:.4f}"
|
|
236
243
|
|
|
@@ -247,6 +254,7 @@ class RotateY(Generic[T], AffineTransform[T]):
|
|
|
247
254
|
self.theta = theta
|
|
248
255
|
self.center = center
|
|
249
256
|
|
|
257
|
+
@override
|
|
250
258
|
def extra_repr(self) -> str:
|
|
251
259
|
return f"theta={self.theta:.4f}, center={self.center}"
|
|
252
260
|
|
|
@@ -263,6 +271,7 @@ class RotateZ(Generic[T], AffineTransform[T]):
|
|
|
263
271
|
self.theta = theta
|
|
264
272
|
self.center = center
|
|
265
273
|
|
|
274
|
+
@override
|
|
266
275
|
def extra_repr(self) -> str:
|
|
267
276
|
return f"theta={self.theta:.4f}, center={self.center}"
|
|
268
277
|
|
|
@@ -19,6 +19,7 @@ import numpy as np
|
|
|
19
19
|
import numpy.typing as npt
|
|
20
20
|
from scipy.fftpack import fftn, fftshift, ifftn
|
|
21
21
|
from scipy.ndimage import gaussian_filter, minimum_filter
|
|
22
|
+
from typing_extensions import override
|
|
22
23
|
|
|
23
24
|
from swcgeom.transforms.base import Transform
|
|
24
25
|
|
|
@@ -36,6 +37,7 @@ class SGuoImPreProcess(Transform[npt.NDArray[np.uint8], npt.NDArray[np.uint8]]):
|
|
|
36
37
|
January 2022, Pages 503–512, https://doi.org/10.1093/bioinformatics/btab638
|
|
37
38
|
"""
|
|
38
39
|
|
|
40
|
+
@override
|
|
39
41
|
def __call__(self, x: npt.NDArray[np.uint8]) -> npt.NDArray[np.uint8]:
|
|
40
42
|
# TODO: support np.float32
|
|
41
43
|
assert x.dtype == np.uint8, "Image must be in uint8 format"
|
|
@@ -15,9 +15,7 @@
|
|
|
15
15
|
|
|
16
16
|
"""Create image stack from morphology.
|
|
17
17
|
|
|
18
|
-
|
|
19
|
-
-----
|
|
20
|
-
All denpendencies need to be installed, try:
|
|
18
|
+
NOTE: All denpendencies need to be installed, try:
|
|
21
19
|
|
|
22
20
|
```sh
|
|
23
21
|
pip install swcgeom[all]
|
|
@@ -28,7 +26,7 @@ import os
|
|
|
28
26
|
import re
|
|
29
27
|
import time
|
|
30
28
|
from collections.abc import Iterable
|
|
31
|
-
from typing import
|
|
29
|
+
from typing import Sequence
|
|
32
30
|
|
|
33
31
|
import numpy as np
|
|
34
32
|
import numpy.typing as npt
|
|
@@ -42,7 +40,7 @@ from sdflit import (
|
|
|
42
40
|
SDFObject,
|
|
43
41
|
)
|
|
44
42
|
from tqdm import tqdm
|
|
45
|
-
from typing_extensions import deprecated
|
|
43
|
+
from typing_extensions import deprecated, override
|
|
46
44
|
|
|
47
45
|
from swcgeom.core import Population, Tree
|
|
48
46
|
from swcgeom.transforms.base import Transform
|
|
@@ -58,46 +56,53 @@ class ToImageStack(Transform[Tree, npt.NDArray[np.uint8]]):
|
|
|
58
56
|
def __init__(self, resolution: int | float | npt.ArrayLike = 1) -> None:
|
|
59
57
|
"""Transform tree to image stack.
|
|
60
58
|
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
Resolution of image stack.
|
|
59
|
+
Args:
|
|
60
|
+
resolution: Resolution of image stack.
|
|
61
|
+
If a scalar, it will be broadcasted to a vector of 3d.
|
|
65
62
|
"""
|
|
66
|
-
|
|
67
63
|
if isinstance(resolution, (int, float, np.integer, np.floating)):
|
|
68
64
|
resolution = [resolution, resolution, resolution] # type: ignore
|
|
69
65
|
|
|
70
66
|
self.resolution = np.array(resolution, dtype=np.float32)
|
|
71
|
-
assert len(self.resolution) == 3, "resolution
|
|
67
|
+
assert len(self.resolution) == 3, "resolution should be vector of 3d."
|
|
72
68
|
|
|
69
|
+
@override
|
|
73
70
|
def __call__(self, x: Tree) -> npt.NDArray[np.uint8]:
|
|
74
71
|
"""Transform tree to image stack.
|
|
75
72
|
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
This method loads the entire image stack into memory, so it
|
|
79
|
-
ONLY works for small image stacks, use :meth`transform_and_save`
|
|
80
|
-
for big image stack.
|
|
73
|
+
NOTE: This method loads the entire image stack into memory, so it ONLY works
|
|
74
|
+
for small image stacks, use :meth`transform_and_save` for big image stack.
|
|
81
75
|
"""
|
|
82
76
|
return np.stack(list(self.transform(x, verbose=False)), axis=0)
|
|
83
77
|
|
|
84
78
|
def transform(
|
|
85
79
|
self,
|
|
86
|
-
x: Tree,
|
|
80
|
+
x: Tree | Sequence[Tree],
|
|
87
81
|
verbose: bool = True,
|
|
88
82
|
*,
|
|
89
|
-
ranges:
|
|
83
|
+
ranges: tuple[npt.ArrayLike, npt.ArrayLike] | None = None,
|
|
90
84
|
) -> Iterable[npt.NDArray[np.uint8]]:
|
|
85
|
+
trees = [x] if isinstance(x, Tree) else x
|
|
86
|
+
if not trees:
|
|
87
|
+
return iter([]) # Return empty iterator if sequence is empty
|
|
88
|
+
|
|
89
|
+
time_start = None
|
|
91
90
|
if verbose:
|
|
92
|
-
|
|
91
|
+
sources = ", ".join(t.source for t in trees if t.source)
|
|
92
|
+
print(f"To image stack: {sources if sources else 'unnamed trees'}")
|
|
93
93
|
time_start = time.time()
|
|
94
94
|
|
|
95
|
-
scene = self._get_scene(
|
|
95
|
+
scene = self._get_scene(trees)
|
|
96
96
|
|
|
97
97
|
if ranges is None:
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
98
|
+
all_xyz = np.concatenate([t.xyz() for t in trees], axis=0)
|
|
99
|
+
all_r = np.concatenate([t.r() for t in trees], axis=0).reshape(-1, 1)
|
|
100
|
+
if all_xyz.size == 0: # Handle empty trees
|
|
101
|
+
coord_min = np.zeros(3, dtype=np.float32)
|
|
102
|
+
coord_max = np.zeros(3, dtype=np.float32)
|
|
103
|
+
else:
|
|
104
|
+
coord_min = np.floor(np.min(all_xyz - all_r, axis=0))
|
|
105
|
+
coord_max = np.ceil(np.max(all_xyz + all_r, axis=0))
|
|
101
106
|
else:
|
|
102
107
|
assert len(ranges) == 2
|
|
103
108
|
coord_min = np.array(ranges[0])
|
|
@@ -106,12 +111,12 @@ class ToImageStack(Transform[Tree, npt.NDArray[np.uint8]]):
|
|
|
106
111
|
|
|
107
112
|
samplers = self._get_samplers(coord_min, coord_max)
|
|
108
113
|
|
|
109
|
-
if verbose:
|
|
114
|
+
if verbose and time_start is not None:
|
|
110
115
|
total = (coord_max[2] - coord_min[2]) / self.resolution[2]
|
|
111
116
|
samplers = tqdm(samplers, total=total.astype(np.int64).item())
|
|
112
117
|
|
|
113
118
|
time_end = time.time()
|
|
114
|
-
print("Prepare in: ", time_end - time_start, "s")
|
|
119
|
+
print("Prepare in: ", time_end - time_start, "s")
|
|
115
120
|
|
|
116
121
|
for sampler in samplers:
|
|
117
122
|
voxel = sampler.sample(scene) # should be shape of (x, y, z, 3) and z = 1
|
|
@@ -124,12 +129,12 @@ class ToImageStack(Transform[Tree, npt.NDArray[np.uint8]]):
|
|
|
124
129
|
x: Tree,
|
|
125
130
|
verbose: bool = True,
|
|
126
131
|
*,
|
|
127
|
-
ranges:
|
|
132
|
+
ranges: tuple[npt.ArrayLike, npt.ArrayLike] | None = None,
|
|
128
133
|
) -> Iterable[npt.NDArray[np.uint8]]:
|
|
129
134
|
return self.transform(x, verbose, ranges=ranges)
|
|
130
135
|
|
|
131
136
|
def transform_and_save(
|
|
132
|
-
self, fname: str, x: Tree, verbose: bool = True, **kwargs
|
|
137
|
+
self, fname: str, x: Tree | Sequence[Tree], verbose: bool = True, **kwargs
|
|
133
138
|
) -> None:
|
|
134
139
|
self.save_tif(fname, self.transform(x, verbose=verbose, **kwargs))
|
|
135
140
|
|
|
@@ -151,11 +156,12 @@ class ToImageStack(Transform[Tree, npt.NDArray[np.uint8]]):
|
|
|
151
156
|
if not os.path.isfile(tif):
|
|
152
157
|
self.transform_and_save(tif, tree, verbose=False)
|
|
153
158
|
|
|
159
|
+
@override
|
|
154
160
|
def extra_repr(self) -> str:
|
|
155
161
|
res = ",".join(f"{a:.4f}" for a in self.resolution)
|
|
156
162
|
return f"resolution=({res})"
|
|
157
163
|
|
|
158
|
-
def _get_scene(self,
|
|
164
|
+
def _get_scene(self, trees: Sequence[Tree]) -> Scene:
|
|
159
165
|
material = ColoredMaterial((1, 0, 0)).into()
|
|
160
166
|
scene = ObjectsScene()
|
|
161
167
|
scene.set_background((0, 0, 0))
|
|
@@ -164,10 +170,11 @@ class ToImageStack(Transform[Tree, npt.NDArray[np.uint8]]):
|
|
|
164
170
|
for c in children:
|
|
165
171
|
sdf = RoundCone(_tp3f(n.xyz()), _tp3f(c.xyz()), n.r, c.r).into()
|
|
166
172
|
scene.add_object(SDFObject(sdf, material).into())
|
|
167
|
-
|
|
168
173
|
return n
|
|
169
174
|
|
|
170
|
-
|
|
175
|
+
for tree in trees:
|
|
176
|
+
tree.traverse(leave=leave)
|
|
177
|
+
|
|
171
178
|
scene.build_bvh()
|
|
172
179
|
return scene.into()
|
|
173
180
|
|
|
@@ -175,14 +182,12 @@ class ToImageStack(Transform[Tree, npt.NDArray[np.uint8]]):
|
|
|
175
182
|
self,
|
|
176
183
|
coord_min: npt.NDArray,
|
|
177
184
|
coord_max: npt.NDArray,
|
|
178
|
-
offset:
|
|
185
|
+
offset: npt.NDArray | None = None,
|
|
179
186
|
) -> Iterable[RangeSampler]:
|
|
180
187
|
"""Get Samplers.
|
|
181
188
|
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
coord_min, coord_max: npt.ArrayLike
|
|
185
|
-
Coordinates array of shape (3,).
|
|
189
|
+
Args:
|
|
190
|
+
coord_min, coord_max: Coordinates array of shape (3,).
|
|
186
191
|
"""
|
|
187
192
|
|
|
188
193
|
eps = 1e-6
|