swcgeom 0.18.1__py3-none-any.whl → 0.19.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- swcgeom/__init__.py +12 -1
- swcgeom/analysis/__init__.py +6 -6
- swcgeom/analysis/feature_extractor.py +22 -24
- swcgeom/analysis/features.py +18 -40
- swcgeom/analysis/lmeasure.py +227 -323
- swcgeom/analysis/sholl.py +17 -23
- swcgeom/analysis/trunk.py +23 -28
- swcgeom/analysis/visualization.py +37 -44
- swcgeom/analysis/visualization3d.py +16 -25
- swcgeom/analysis/volume.py +33 -47
- swcgeom/core/__init__.py +12 -13
- swcgeom/core/branch.py +10 -17
- swcgeom/core/branch_tree.py +3 -2
- swcgeom/core/compartment.py +1 -1
- swcgeom/core/node.py +3 -6
- swcgeom/core/path.py +11 -16
- swcgeom/core/population.py +32 -51
- swcgeom/core/swc.py +25 -16
- swcgeom/core/swc_utils/__init__.py +10 -12
- swcgeom/core/swc_utils/assembler.py +5 -12
- swcgeom/core/swc_utils/base.py +40 -31
- swcgeom/core/swc_utils/checker.py +3 -8
- swcgeom/core/swc_utils/io.py +32 -47
- swcgeom/core/swc_utils/normalizer.py +17 -23
- swcgeom/core/swc_utils/subtree.py +13 -20
- swcgeom/core/tree.py +61 -51
- swcgeom/core/tree_utils.py +36 -49
- swcgeom/core/tree_utils_impl.py +4 -6
- swcgeom/images/__init__.py +2 -2
- swcgeom/images/augmentation.py +23 -39
- swcgeom/images/contrast.py +22 -46
- swcgeom/images/folder.py +32 -34
- swcgeom/images/io.py +80 -121
- swcgeom/transforms/__init__.py +13 -13
- swcgeom/transforms/base.py +28 -19
- swcgeom/transforms/branch.py +31 -41
- swcgeom/transforms/branch_tree.py +3 -1
- swcgeom/transforms/geometry.py +13 -4
- swcgeom/transforms/image_preprocess.py +2 -0
- swcgeom/transforms/image_stack.py +40 -35
- swcgeom/transforms/images.py +31 -24
- swcgeom/transforms/mst.py +27 -40
- swcgeom/transforms/neurolucida_asc.py +13 -13
- swcgeom/transforms/path.py +4 -0
- swcgeom/transforms/population.py +4 -0
- swcgeom/transforms/tree.py +16 -11
- swcgeom/transforms/tree_assembler.py +37 -54
- swcgeom/utils/__init__.py +12 -12
- swcgeom/utils/download.py +7 -14
- swcgeom/utils/dsu.py +12 -0
- swcgeom/utils/ellipse.py +26 -14
- swcgeom/utils/file.py +8 -13
- swcgeom/utils/neuromorpho.py +78 -92
- swcgeom/utils/numpy_helper.py +15 -12
- swcgeom/utils/plotter_2d.py +10 -16
- swcgeom/utils/plotter_3d.py +7 -9
- swcgeom/utils/renderer.py +16 -8
- swcgeom/utils/sdf.py +12 -23
- swcgeom/utils/solid_geometry.py +58 -2
- swcgeom/utils/transforms.py +164 -100
- swcgeom/utils/volumetric_object.py +29 -53
- {swcgeom-0.18.1.dist-info → swcgeom-0.19.0.dist-info}/METADATA +7 -6
- swcgeom-0.19.0.dist-info/RECORD +67 -0
- {swcgeom-0.18.1.dist-info → swcgeom-0.19.0.dist-info}/WHEEL +1 -1
- swcgeom/_version.py +0 -16
- swcgeom-0.18.1.dist-info/RECORD +0 -68
- {swcgeom-0.18.1.dist-info → swcgeom-0.19.0.dist-info/licenses}/LICENSE +0 -0
- {swcgeom-0.18.1.dist-info → swcgeom-0.19.0.dist-info}/top_level.txt +0 -0
swcgeom/images/io.py
CHANGED
@@ -21,7 +21,7 @@ import warnings
 from abc import ABC, abstractmethod
 from collections.abc import Callable, Iterable
 from functools import cache, lru_cache
-from typing import Any, Generic, Literal,
+from typing import Any, Generic, Literal, TypeVar, cast, overload

 import nrrd
 import numpy as np
@@ -39,10 +39,10 @@ RE_TERAFLY_ROOT = re.compile(r"^RES\((\d+)x(\d+)x(\d+)\)$")
 RE_TERAFLY_NAME = re.compile(r"^\d+(_\d+)?(_\d+)?")

 UINT_MAX = {
-    np.dtype(np.uint8): (2**8) - 1,
-    np.dtype(np.uint16): (2**16) - 1,
-    np.dtype(np.uint32): (2**32) - 1,
-    np.dtype(np.uint64): (2**64) - 1,
+    np.dtype(np.uint8): (2**8) - 1,
+    np.dtype(np.uint16): (2**16) - 1,
+    np.dtype(np.uint32): (2**32) - 1,
+    np.dtype(np.uint64): (2**64) - 1,
 }

 AXES_ORDER = {
@@ -83,19 +83,15 @@ class ImageStack(ABC, Generic[ScalarType]):
     def __getitem__(self, key):
         """Get pixel/patch of image stack.

-        Returns
-        -------
-        value : ndarray of f32
-            NDArray which shape depends on key. If key is tuple of ints,
+        Returns:
+            value: NDArray which shape depends on key. If key is tuple of ints,
         """
         raise NotImplementedError()

     def get_full(self) -> npt.NDArray[ScalarType]:
         """Get full image stack.

-        Notes
-        -----
-        this will load the full image stack into memory.
+        NOTE: this will load the full image stack into memory.
         """
         return self[:, :, :, :]

@@ -104,29 +100,20 @@ class ImageStack(ABC, Generic[ScalarType]):
         raise NotImplementedError()


-# fmt:off
 @overload
 def read_imgs(fname: str, *, dtype: ScalarType, **kwargs) -> ImageStack[ScalarType]: ...
 @overload
-def read_imgs(fname: str, *, dtype: None
-# fmt:on
-
-
+def read_imgs(fname: str, *, dtype: None = ..., **kwargs) -> ImageStack[np.float32]: ...
 def read_imgs(fname: str, **kwargs):  # type: ignore
     """Read image stack.

-
-
-
-
-
-
-        conversions occur, they will be scaled (assuming floats are
-        between 0 and 1).
-    **kwargs : dict[str, Any]
-        Forwarding to the corresponding reader.
+    Args:
+        fname: The path of image stack.
+        dtype: Casting data to specified dtype.
+            If integer and float conversions occur, they will be scaled (assuming floats
+            are between 0 and 1). Default to `np.float32`.
+        **kwargs: Forwarding to the corresponding reader.
     """
-
     kwargs.setdefault("dtype", np.float32)
     if not os.path.exists(fname):
         raise ValueError(f"image stack not exists: {fname}")
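
The hunk above rewrites the `read_imgs` overloads so that omitting `dtype` yields an `ImageStack[np.float32]`. A minimal usage sketch, assuming the signatures shown in this hunk and a placeholder file name (not verified against the released wheel):

    import numpy as np
    from swcgeom.images.io import read_imgs  # module under diff

    # Default: data is cast to float32; unsigned-integer input is scaled into [0, 1].
    imgs = read_imgs("stack.tif")                  # ImageStack[np.float32]
    patch = imgs[0:64, 0:64, 0:64, :]              # float32 patch of shape (64, 64, 64, C)

    # Keep an explicit dtype instead of the float32 default.
    raw = read_imgs("stack.tif", dtype=np.uint16)  # ImageStack[np.uint16]
    full = raw.get_full()                          # NOTE: loads the whole stack into memory
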
@@ -155,27 +142,22 @@ def save_tiff(
     data: npt.NDArray | ImageStack,
     fname: str,
     *,
-    dtype:
+    dtype: np.unsignedinteger | np.floating | None = None,
     compression: str | Literal[False] = "zlib",
     **kwargs,
 ) -> None:
     """Save image stack as tiff.

-
-
-
-
-
-
-
-
-
-
-        Compression algorithm, forwarding to `tifffile.imwrite`. If no
-        algorithnm is specify specified, we will use the zlib algorithm
-        with compression level 6 by default.
-    **kwargs : dict[str, Any]
-        Forwarding to `tifffile.imwrite`
+    Args:
+        data: The image stack.
+        fname: str
+        dtype: Casting data to specified dtype.
+            If integer and float conversions occur, they will be scaled (assuming
+            floats are between 0 and 1).
+        compression: Compression algorithm, forwarding to `tifffile.imwrite`.
+            If no algorithnm is specify specified, we will use the zlib algorithm with
+            compression level 6 by default.
+        **kwargs: Forwarding to `tifffile.imwrite`
     """
     if isinstance(data, ImageStack):
         data = data.get_full()  # TODO: avoid load full imgs to memory
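
The reworked `Args:` block above documents the float-to-integer scaling rule. A short sketch of calling `save_tiff` under those assumptions (placeholder array and file names, signature as shown in this hunk):

    import numpy as np
    from swcgeom.images.io import save_tiff  # module under diff

    # A float32 stack with values in [0, 1], shape (X, Y, Z, C).
    volume = np.random.default_rng(0).random((64, 64, 64, 1), dtype=np.float32)

    # Cast on save: float values are scaled by UINT_MAX[uint8] == 255.
    save_tiff(volume, "volume_u8.tif", dtype=np.uint8)

    # Default: keep the input dtype, compress with zlib (level 6).
    save_tiff(volume, "volume_f32.tif")
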
@@ -191,11 +173,11 @@ def save_tiff(
     if np.issubdtype(data.dtype, np.floating) and np.issubdtype(
         dtype, np.unsignedinteger
     ):
-        scaler_factor = UINT_MAX[np.dtype(dtype)]
+        scaler_factor = UINT_MAX[np.dtype(dtype)]
     elif np.issubdtype(data.dtype, np.unsignedinteger) and np.issubdtype(
         dtype, np.floating
     ):
-        scaler_factor = 1 / UINT_MAX[np.dtype(data.dtype)]
+        scaler_factor = 1 / UINT_MAX[np.dtype(data.dtype)]
     else:
         scaler_factor = 1

@@ -218,7 +200,7 @@ class NDArrayImageStack(ImageStack[ScalarType]):
     """NDArray image stack."""

     def __init__(
-        self, imgs: npt.NDArray[Any], *, dtype:
+        self, imgs: npt.NDArray[Any], *, dtype: ScalarType | None = None
     ) -> None:
         super().__init__()

@@ -231,13 +213,13 @@ class NDArrayImageStack(ImageStack[ScalarType]):
         if np.issubdtype(dtype, np.floating) and np.issubdtype(
             dtype_raw, np.unsignedinteger
         ):
-
-            imgs =
+            scalar_factor = 1.0 / UINT_MAX[dtype_raw]
+            imgs = scalar_factor * imgs.astype(dtype)
         elif np.issubdtype(dtype, np.unsignedinteger) and np.issubdtype(
             dtype_raw, np.floating
         ):
-
-            imgs *= (
+            scalar_factor = UINT_MAX[dtype]
+            imgs *= (scalar_factor * imgs).astype(dtype)
         else:
             imgs = imgs.astype(dtype)

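
The two branches above apply the same scaling rule in opposite directions. A standalone sketch of that rule (hypothetical helpers, independent of the class, using a trimmed-down `UINT_MAX` table like the one defined earlier in this file):

    import numpy as np

    UINT_MAX = {np.dtype(np.uint8): (2**8) - 1, np.dtype(np.uint16): (2**16) - 1}

    def to_float32(imgs: np.ndarray) -> np.ndarray:
        # unsigned int -> float: divide by the source dtype's max, landing in [0, 1]
        return imgs.astype(np.float32) / UINT_MAX[imgs.dtype]

    def to_uint8(imgs: np.ndarray) -> np.ndarray:
        # float in [0, 1] -> unsigned int: multiply by the target dtype's max
        return (imgs * UINT_MAX[np.dtype(np.uint8)]).astype(np.uint8)

    assert to_uint8(to_float32(np.array([0, 255], dtype=np.uint8))).tolist() == [0, 255]
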
@@ -284,27 +266,23 @@ class NrrdImageStack(NDArrayImageStack[ScalarType]):
 class V3dImageStack(NDArrayImageStack[ScalarType]):
     """v3d image stack."""

-    def
-
-
-
+    def __init_subclass__(cls, loader: Raw | PBD) -> None:
+        super().__init_subclass__()
+        cls._loader = loader
+
+    def __init__(self, fname: str, *, dtype: ScalarType, **kwargs) -> None:
+        r = self._loader()
         imgs = r.load(fname)
         super().__init__(imgs, dtype=dtype, **kwargs)


-class V3drawImageStack(V3dImageStack[ScalarType]):
+class V3drawImageStack(V3dImageStack[ScalarType], loader=Raw):
     """v3draw image stack."""

-    def __init__(self, fname: str, *, dtype: ScalarType, **kwargs) -> None:
-        super().__init__(fname, loader=Raw, dtype=dtype, **kwargs)
-

-class V3dpbdImageStack(V3dImageStack[ScalarType]):
+class V3dpbdImageStack(V3dImageStack[ScalarType], loader=PBD):
     """v3dpbd image stack."""

-    def __init__(self, fname: str, *, dtype: ScalarType, **kwargs) -> None:
-        super().__init__(fname, loader=PBD, dtype=dtype, **kwargs)
-

 class TeraflyImageStack(ImageStack[ScalarType]):
     """TeraFly image stack.
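
The change above replaces a per-call `loader=` argument with a class-level hook: each subclass declares its loader as a class keyword and `__init_subclass__` stores it on the class. A standalone sketch of that pattern with stand-in loader classes (not the real Vaa3d readers):

    class RawLoader:
        """Stand-in for an uncompressed v3draw-style reader."""
        def load(self, fname: str) -> bytes:
            with open(fname, "rb") as f:
                return f.read()

    class PBDLoader(RawLoader):
        """Stand-in for a compressed v3dpbd-style reader."""

    class V3dStackSketch:
        _loader: type[RawLoader]

        def __init_subclass__(cls, loader: type[RawLoader]) -> None:
            super().__init_subclass__()
            cls._loader = loader            # remembered once per subclass

        def __init__(self, fname: str) -> None:
            self.data = self._loader().load(fname)

    class V3drawStackSketch(V3dStackSketch, loader=RawLoader): ...
    class V3dpbdStackSketch(V3dStackSketch, loader=PBDLoader): ...
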
@@ -312,21 +290,17 @@ class TeraflyImageStack(ImageStack[ScalarType]):
     TeraFly is a terabytes of multidimensional volumetric images file
     format as described in [1]_.

-    References
-    ----------
-    .. [1] Bria, Alessandro, Giulio Iannello, Leonardo Onofri, and
-       Hanchuan Peng. “TeraFly: Real-Time Three-Dimensional
-       Visualization and Annotation of Terabytes of Multidimensional
-       Volumetric Images.” Nature Methods 13,
-       no. 3 (March 2016): 192-94. https://doi.org/10.1038/nmeth.3767.
-
-    Notes
-    -----
-    Terafly and Vaa3d use a especial right-handed coordinate system
+    NOTE: Terafly and Vaa3d use a especial right-handed coordinate system
     (with origin point in the left-top and z-axis points front), but we
-    flip y-axis to makes it a left-handed coordinate system (with
+    flip y-axis to makes it a left-handed coordinate system (with origin
     point in the left-bottom and z-axis points front). If you need to
     use its coordinate system, remember to FLIP Y-AXIS BACK.
+
+    References:
+        .. [1] Bria, Alessandro, Giulio Iannello, Leonardo Onofri, and Hanchuan Peng.
+            “TeraFly: Real-Time Three-Dimensional Visualization and Annotation of Terabytes
+            of Multidimensional Volumetric Images.” Nature Methods 13, no. 3 (March 2016):
+            192-94. https://doi.org/10.1038/nmeth.3767.
     """

     _listdir: Callable[[str], list[str]]
@@ -336,17 +310,13 @@ class TeraflyImageStack(ImageStack[ScalarType]):
         self, root: str, *, dtype: ScalarType, lru_maxsize: int | None = 128
     ) -> None:
         r"""
-
-
-
-
-
-
-
-            Forwarding to `functools.lru_cache`. A decompressed array
-            size of (256, 256, 256, 1), which is the typical size of
-            terafly image stack, takes about 256 * 256 * 256 * 1 *
-            4B = 64MB. A cache size of 128 requires about 8GB memeory.
+        Args:
+            root: The root of terafly which contains directories named as `RES(YxXxZ)`.
+            dtype: np.dtype
+            lru_maxsize: Forwarding to `functools.lru_cache`.
+                A decompressed array size of (256, 256, 256, 1), which is the typical
+                size of terafly image stack, takes about 256 * 256 * 256 * 1 * 4B = 64MB.
+                A cache size of 128 requires about 8GB memory.
         """

         super().__init__()
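
The `lru_maxsize` note above can be checked with a line of arithmetic, assuming 4-byte (float32) voxels and the typical (256, 256, 256, 1) TeraFly block:

    block_bytes = 256 * 256 * 256 * 1 * 4            # one cached block: 67,108,864 B = 64 MiB
    cache_bytes = 128 * block_bytes                   # default lru_maxsize=128 -> 8 GiB
    print(block_bytes / 2**20, cache_bytes / 2**30)   # 64.0 8.0
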
@@ -372,11 +342,8 @@ class TeraflyImageStack(ImageStack[ScalarType]):
     def __getitem__(self, key):
         """Get images in max resolution.

-
-
-        ```python
-        imgs[0, 0, 0, 0] # get value
-        imgs[0:64, 0:64, 0:64, :] # get patch
+        >>> imgs[0, 0, 0, 0]  # get value  # doctest: +SKIP
+        >>> imgs[0:64, 0:64, 0:64, :]  # get patch  # doctest: +SKIP
         ```
         """
         if not isinstance(key, tuple):
@@ -399,9 +366,8 @@ class TeraflyImageStack(ImageStack[ScalarType]):
     ) -> npt.NDArray[ScalarType]:
         """Get patch of image stack.

-        Returns
-        -------
-        patch : array of shape (X, Y, Z, C)
+        Returns:
+            patch: array of shape (X, Y, Z, C)
         """
         if isinstance(strides, int):
             strides = (strides, strides, strides)
@@ -421,10 +387,9 @@ class TeraflyImageStack(ImageStack[ScalarType]):
     def find_correspond_imgs(self, p, res_level=-1):
         """Find the image which contain this point.

-        Returns
-        -------
-
-        patch_offset : (int, int, int)
+        Returns:
+            patch: array of shape (X, Y, Z, C)
+            patch_offset: (int, int, int)
         """
         p = np.array(p)
         self._check_params(res_level, p)
@@ -442,14 +407,10 @@ class TeraflyImageStack(ImageStack[ScalarType]):
     def get_resolutions(cls, root: str) -> tuple[list[Vec3i], list[str], list[Vec3i]]:
         """Get all resolutions.

-        Returns
-        -------
-
-            Sequence of
-        roots : list[str]
-            Sequence of root of resolutions respectively.
-        patch_sizes : List of (int, int, int)
-            Sequence of patch size of resolutions respectively.
+        Returns:
+            resolutions: Sequence of sorted resolutions (from small to large).
+            roots: Sequence of root of resolutions respectively.
+            patch_sizes: Sequence of patch size of resolutions respectively.
         """

         roots = list(cls.get_resolution_dirs(root))
@@ -497,13 +458,13 @@ class TeraflyImageStack(ImageStack[ScalarType]):

         res = self.res[res_level]
         for p in coords:
-            assert np.less(
-                [0, 0, 0
-            )
+            assert np.less([0, 0, 0], p).all(), (
+                f"indices ({p[0]}, {p[1]}, {p[2]}) out of range (0, 0, 0)"
+            )

-            assert np.greater(
-                res,
-            )
+            assert np.greater(res, p).all(), (
+                f"indices ({p[0]}, {p[1]}, {p[2]}) out of range ({res[0]}, {res[1]}, {res[2]})"
+            )

     def _get_range(self, starts, ends, res_level, out):
         # pylint: disable=too-many-locals
@@ -529,13 +490,13 @@ class TeraflyImageStack(ImageStack[ScalarType]):
         if shape[1] > lens[1]:
             starts_y = starts + [0, lens[1], 0]
             ends_y = np.array([starts[0], ends[1], ends[2]])
-            ends_y += [min(shape[0], lens[0]), 0, 0]
+            ends_y += [min(shape[0], lens[0]), 0, 0]
             self._get_range(starts_y, ends_y, res_level, out[:, lens[1] :, :])

         if shape[2] > lens[2]:
             starts_z = starts + [0, 0, lens[2]]
             ends_z = np.array([starts[0], starts[1], ends[2]])
-            ends_z += [min(shape[0], lens[0]), min(shape[1], lens[1]), 0]
+            ends_z += [min(shape[0], lens[0]), min(shape[1], lens[1]), 0]
             self._get_range(starts_z, ends_z, res_level, out[:, :, lens[2] :])

     def _find_correspond_imgs(self, p, res_level):
@@ -580,14 +541,14 @@ class GrayImageStack:
     def __init__(self, imgs: ImageStack) -> None:
         self.imgs = imgs

-    # fmt: off
     @overload
     def __getitem__(self, key: Vec3i) -> np.float32: ...
     @overload
     def __getitem__(self, key: npt.NDArray[np.integer[Any]]) -> np.float32: ...
     @overload
-    def __getitem__(
-
+    def __getitem__(
+        self, key: slice | tuple[slice, slice] | tuple[slice, slice, slice]
+    ) -> npt.NDArray[np.float32]: ...
     def __getitem__(self, key):
         """Get pixel/patch of image stack."""
         v = self[key]
@@ -601,14 +562,12 @@ class GrayImageStack:
             return v[:, 0]
         if v.ndim == 1:
             return v[0]
-        raise ValueError("
+        raise ValueError("unsupported key")

     def get_full(self) -> npt.NDArray[np.float32]:
         """Get full image stack.

-        Notes
-        -----
-        this will load the full image stack into memory.
+        NOTE: this will load the full image stack into memory.
         """
         return self.imgs.get_full()[:, :, :, 0]

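
Summing up the `GrayImageStack` changes above: integer keys return a single `np.float32`, slice keys return an array, and anything else now raises the "unsupported key" error. A usage sketch under those assumptions (placeholder file name, API as shown in this diff, not verified against the released wheel):

    from swcgeom.images.io import GrayImageStack, read_imgs  # module under diff

    gray = GrayImageStack(read_imgs("stack.tif"))  # wraps a stack, dropping the channel axis
    value = gray[10, 20, 30]                       # np.float32 scalar
    patch = gray[0:64, 0:64, 0:64]                 # float32 array of shape (64, 64, 64)
    full = gray.get_full()                         # NOTE: whole stack in memory, shape (X, Y, Z)
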
swcgeom/transforms/__init__.py
CHANGED
@@ -15,16 +15,16 @@

 """A series of transformations to compose codes."""

-from swcgeom.transforms.base import *
-from swcgeom.transforms.branch import *
-from swcgeom.transforms.branch_tree import *
-from swcgeom.transforms.geometry import *
-from swcgeom.transforms.image_preprocess import *
-from swcgeom.transforms.image_stack import *
-from swcgeom.transforms.images import *
-from swcgeom.transforms.mst import *
-from swcgeom.transforms.neurolucida_asc import *
-from swcgeom.transforms.path import *
-from swcgeom.transforms.population import *
-from swcgeom.transforms.tree import *
-from swcgeom.transforms.tree_assembler import *
+from swcgeom.transforms.base import *  # noqa: F403
+from swcgeom.transforms.branch import *  # noqa: F403
+from swcgeom.transforms.branch_tree import *  # noqa: F403
+from swcgeom.transforms.geometry import *  # noqa: F403
+from swcgeom.transforms.image_preprocess import *  # noqa: F403
+from swcgeom.transforms.image_stack import *  # noqa: F403
+from swcgeom.transforms.images import *  # noqa: F403
+from swcgeom.transforms.mst import *  # noqa: F403
+from swcgeom.transforms.neurolucida_asc import *  # noqa: F403
+from swcgeom.transforms.path import *  # noqa: F403
+from swcgeom.transforms.population import *  # noqa: F403
+from swcgeom.transforms.tree import *  # noqa: F403
+from swcgeom.transforms.tree_assembler import *  # noqa: F403
swcgeom/transforms/base.py
CHANGED
@@ -18,12 +18,19 @@
 from abc import ABC, abstractmethod
 from typing import Any, Generic, TypeVar, overload

+from typing_extensions import override
+
 __all__ = ["Transform", "Transforms", "Identity"]

-T
+T = TypeVar("T")
+K = TypeVar("K")

-T1
-
+T1 = TypeVar("T1")
+T2 = TypeVar("T2")
+T3 = TypeVar("T3")
+T4 = TypeVar("T4")
+T5 = TypeVar("T5")
+T6 = TypeVar("T6")


 class Transform(ABC, Generic[T, K]):
@@ -36,9 +43,7 @@ class Transform(ABC, Generic[T, K]):
     def __call__(self, x: T) -> K:
         """Apply transform.

-        Notes
-        -----
-        All subclasses should overwrite :meth:`__call__`, supporting
+        NOTE: All subclasses should overwrite :meth:`__call__`, supporting
         applying transform in `x`.
         """
         raise NotImplementedError()
@@ -57,18 +62,14 @@ class Transform(ABC, Generic[T, K]):
         which can be particularly useful for debugging and model
         architecture introspection.

-
-
-
-
-
-
-        def extra_repr(self) -> str:
-            return f"my_parameter={self.my_parameter}"
+        >>> class Foo(Transform[T, K]):
+        ...     def __init__(self, my_parameter: int = 1):
+        ...         self.my_parameter = my_parameter
+        ...
+        ...     def extra_repr(self) -> str:
+        ...         return f"my_parameter={self.my_parameter}"

-        Notes
-        -----
-        This method should be overridden in custom modules to provide
+        NOTE: This method should be overridden in custom modules to provide
         specific details relevant to the module's functionality and
         configuration.
         """
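
The doctest added above shows a subclass supplying `extra_repr`; for context, this is how such a hook is typically consumed by a `__repr__`. A hypothetical standalone sketch, not the package's actual implementation:

    class ReprSketch:
        def extra_repr(self) -> str:
            return ""

        def __repr__(self) -> str:
            return f"{type(self).__name__}({self.extra_repr()})"

    class Scale(ReprSketch):
        def __init__(self, factor: float) -> None:
            self.factor = factor

        def extra_repr(self) -> str:
            return f"factor={self.factor}"

    print(Scale(2.0))  # Scale(factor=2.0)
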
@@ -80,7 +81,7 @@ class Transforms(Transform[T, K]):

     transforms: list[Transform[Any, Any]]

-    # fmt:off
+    # fmt: off
     @overload
     def __init__(self, t1: Transform[T, K], /) -> None: ...
     @overload
@@ -100,11 +101,16 @@ class Transforms(Transform[T, K]):
                  t3: Transform[T2, T3], t4: Transform[T3, T4],
                  t5: Transform[T4, T5], t6: Transform[T5, K], /) -> None: ...
     @overload
+    def __init__(self, t1: Transform[T, T1], t2: Transform[T1, T2],
+                 t3: Transform[T2, T3], t4: Transform[T3, T4],
+                 t5: Transform[T4, T5], t6: Transform[T5, T6],
+                 t7: Transform[T6, K], /) -> None: ...
+    @overload
     def __init__(self, t1: Transform[T, T1], t2: Transform[T1, T2],
                  t3: Transform[T2, T3], t4: Transform[T3, T4],
                  t5: Transform[T4, T5], t6: Transform[T5, T6],
                  t7: Transform[T6, Any], /, *transforms: Transform[Any, K]) -> None: ...
-    # fmt:on
+    # fmt: on
     def __init__(self, *transforms: Transform[Any, Any]) -> None:
         trans = []
         for t in transforms:
@@ -114,6 +120,7 @@ class Transforms(Transform[T, K]):
             trans.append(t)
         self.transforms = trans

+    @override
     def __call__(self, x: T) -> K:
         """Apply transforms."""
         for transform in self.transforms:
@@ -127,6 +134,7 @@ class Transforms(Transform[T, K]):
     def __len__(self) -> int:
         return len(self.transforms)

+    @override
     def extra_repr(self) -> str:
         return ", ".join([str(transform) for transform in self])

@@ -134,5 +142,6 @@ class Transforms(Transform[T, K]):
 class Identity(Transform[T, T]):
     """Resurn input as-is."""

+    @override
     def __call__(self, x: T) -> T:
         return x
swcgeom/transforms/branch.py
CHANGED
@@ -21,6 +21,7 @@ from typing import cast
 import numpy as np
 import numpy.typing as npt
 from scipy import signal
+from typing_extensions import override

 from swcgeom.core import Branch, DictSWC
 from swcgeom.transforms.base import Transform
@@ -40,13 +41,13 @@ __all__ = ["BranchLinearResampler", "BranchConvSmoother", "BranchStandardizer"]
 class _BranchResampler(Transform[Branch, Branch], ABC):
     r"""Resample branch."""

+    @override
     def __call__(self, x: Branch) -> Branch:
         xyzr = self.resample(x.xyzr())
         return Branch.from_xyzr(xyzr)

     @abstractmethod
-    def resample(self, xyzr: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
-        raise NotImplementedError()
+    def resample(self, xyzr: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]: ...


 class BranchLinearResampler(_BranchResampler):
@@ -55,26 +56,21 @@ class BranchLinearResampler(_BranchResampler):
     def __init__(self, n_nodes: int) -> None:
         """Resample branch to special num of nodes.

-
-
-        n_nodes : int
-            Number of nodes after resample.
+        Args:
+            n_nodes: Number of nodes after resample.
         """
         super().__init__()
         self.n_nodes = n_nodes

+    @override
     def resample(self, xyzr: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
         """Resampling by linear interpolation, DO NOT keep original node.

-
-
-        xyzr : np.ndarray[np.float32]
-            The array of shape (N, 4).
+        Args:
+            xyzr: The array of shape (N, 4).

-        Returns
-        -------
-        coordinates : ~numpy.NDArray[float64]
-            An array of shape (n_nodes, 4).
+        Returns:
+            coordinates: An array of shape (n_nodes, 4).
         """

         xp = np.cumsum(np.linalg.norm(xyzr[1:, :3] - xyzr[:-1, :3], axis=1))
@@ -87,6 +83,7 @@ class BranchLinearResampler(_BranchResampler):
         r = np.interp(xvals, xp, xyzr[:, 3])
         return cast(npt.NDArray[np.float32], np.stack([x, y, z, r], axis=1))

+    @override
     def extra_repr(self) -> str:
         return f"n_nodes={self.n_nodes}"

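
The `resample` shown above interpolates x, y, z, and r against cumulative arc length via `np.interp`. The same idea as a standalone helper (hypothetical function, not the package's exact implementation):

    import numpy as np
    import numpy.typing as npt

    def resample_linear(xyzr: npt.NDArray[np.float32], n_nodes: int) -> npt.NDArray[np.float32]:
        """Resample an (N, 4) branch to n_nodes points evenly spaced along its length."""
        seg_len = np.linalg.norm(xyzr[1:, :3] - xyzr[:-1, :3], axis=1)
        xp = np.concatenate([[0.0], np.cumsum(seg_len)])    # cumulative arc length at each node
        xvals = np.linspace(0.0, xp[-1], n_nodes)            # evenly spaced sample positions
        cols = [np.interp(xvals, xp, xyzr[:, i]) for i in range(4)]
        return np.stack(cols, axis=1).astype(np.float32)

    branch = np.array([[0, 0, 0, 1], [1, 0, 0, 1], [3, 0, 0, 2]], dtype=np.float32)
    assert resample_linear(branch, 4).shape == (4, 4)
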
@@ -97,18 +94,15 @@ class BranchIsometricResampler(_BranchResampler):
         self.distance = distance
         self.adjust_last_gap = adjust_last_gap

+    @override
     def resample(self, xyzr: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
         """Resampling by isometric interpolation, DO NOT keep original node.

-
-
-        xyzr : np.ndarray[np.float32]
-            The array of shape (N, 4).
+        Args:
+            xyzr: The array of shape (N, 4).

-        Returns
-        -------
-        new_xyzr : ~numpy.NDArray[float32]
-            An array of shape (n_nodes, 4).
+        Returns:
+            new_xyzr: An array of shape (n_nodes, 4).
         """

         # Compute the cumulative distances between consecutive points
@@ -138,6 +132,7 @@ class BranchIsometricResampler(_BranchResampler):
         new_xyzr[:, 3] = np.interp(new_distances, cumulative_distances, xyzr[:, 3])
         return new_xyzr

+    @override
     def extra_repr(self) -> str:
         return f"distance={self.distance},adjust_last_gap={self.adjust_last_gap}"

@@ -147,15 +142,14 @@ class BranchConvSmoother(Transform[Branch, Branch[DictSWC]]):

     def __init__(self, n_nodes: int = 5) -> None:
         """
-
-
-        n_nodes : int, default `5`
-            Window size.
+        Args:
+            n_nodes: Window size.
         """
         super().__init__()
         self.n_nodes = n_nodes
         self.kernel = np.ones(n_nodes)

+    @override
     def __call__(self, x: Branch) -> Branch[DictSWC]:
         x = x.detach()
         c = signal.convolve(np.ones(x.number_of_nodes()), self.kernel, mode="same")
@@ -173,10 +167,11 @@ class BranchConvSmoother(Transform[Branch, Branch[DictSWC]]):
 class BranchStandardizer(Transform[Branch, Branch[DictSWC]]):
     r"""Standardize branch.

-    Standardized branch starts at (0, 0, 0), ends at (1, 0, 0), up at
-
+    Standardized branch starts at (0, 0, 0), ends at (1, 0, 0), up at y, and scale max
+    radius to 1.
     """

+    @override
     def __call__(self, x: Branch) -> Branch:
         xyzr = x.xyzr()
         xyz, r = xyzr[:, 0:3], xyzr[:, 3:4]
@@ -191,23 +186,18 @@ class BranchStandardizer(Transform[Branch, Branch[DictSWC]]):
     def get_matrix(xyz: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
         r"""Get standardize transformation matrix.

-        Standardized branch starts at (0, 0, 0), ends at (1, 0, 0), up
-        at y.
+        Standardized branch starts at (0, 0, 0), ends at (1, 0, 0), up at y.

-
-
-        xyz : np.ndarray[np.float32]
-            The `x`, `y`, `z` matrix of shape (N, 3) of branch.
+        Args:
+            xyz: The `x`, `y`, `z` matrix of shape (N, 3) of branch.

-        Returns
-        -------
-        T : np.ndarray[np.float32]
-            An homogeneous transformation matrix of shape (4, 4).
+        Returns:
+            T: An homogeneous transformation matrix of shape (4, 4).
         """

-        assert (
-            xyz
-        )
+        assert xyz.ndim == 2 and xyz.shape[1] == 3, (
+            f"xyz should be of shape (N, 3), got {xyz.shape}"
+        )

         xyz = xyz[:, :3]
         T = np.identity(4)
|