swcgeom 0.16.0__py3-none-any.whl → 0.18.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of swcgeom might be problematic; see the package registry's page for this release for more details.

Files changed (72) hide show
  1. swcgeom/__init__.py +26 -1
  2. swcgeom/analysis/__init__.py +21 -8
  3. swcgeom/analysis/feature_extractor.py +43 -18
  4. swcgeom/analysis/features.py +250 -0
  5. swcgeom/analysis/lmeasure.py +48 -12
  6. swcgeom/analysis/sholl.py +25 -28
  7. swcgeom/analysis/trunk.py +27 -11
  8. swcgeom/analysis/visualization.py +24 -9
  9. swcgeom/analysis/visualization3d.py +100 -0
  10. swcgeom/analysis/volume.py +19 -4
  11. swcgeom/core/__init__.py +31 -12
  12. swcgeom/core/branch.py +19 -3
  13. swcgeom/core/branch_tree.py +18 -4
  14. swcgeom/core/compartment.py +18 -2
  15. swcgeom/core/node.py +32 -3
  16. swcgeom/core/path.py +21 -9
  17. swcgeom/core/population.py +58 -29
  18. swcgeom/core/swc.py +26 -10
  19. swcgeom/core/swc_utils/__init__.py +21 -7
  20. swcgeom/core/swc_utils/assembler.py +15 -0
  21. swcgeom/core/swc_utils/base.py +23 -17
  22. swcgeom/core/swc_utils/checker.py +19 -12
  23. swcgeom/core/swc_utils/io.py +24 -7
  24. swcgeom/core/swc_utils/normalizer.py +20 -4
  25. swcgeom/core/swc_utils/subtree.py +17 -2
  26. swcgeom/core/tree.py +56 -40
  27. swcgeom/core/tree_utils.py +28 -17
  28. swcgeom/core/tree_utils_impl.py +18 -3
  29. swcgeom/images/__init__.py +17 -2
  30. swcgeom/images/augmentation.py +18 -3
  31. swcgeom/images/contrast.py +15 -0
  32. swcgeom/images/folder.py +27 -26
  33. swcgeom/images/io.py +94 -117
  34. swcgeom/transforms/__init__.py +28 -12
  35. swcgeom/transforms/base.py +17 -2
  36. swcgeom/transforms/branch.py +74 -8
  37. swcgeom/transforms/branch_tree.py +82 -0
  38. swcgeom/transforms/geometry.py +22 -7
  39. swcgeom/transforms/image_preprocess.py +15 -0
  40. swcgeom/transforms/image_stack.py +36 -9
  41. swcgeom/transforms/images.py +121 -14
  42. swcgeom/transforms/mst.py +15 -0
  43. swcgeom/transforms/neurolucida_asc.py +20 -7
  44. swcgeom/transforms/path.py +15 -0
  45. swcgeom/transforms/population.py +16 -3
  46. swcgeom/transforms/tree.py +84 -30
  47. swcgeom/transforms/tree_assembler.py +23 -7
  48. swcgeom/utils/__init__.py +27 -12
  49. swcgeom/utils/debug.py +15 -0
  50. swcgeom/utils/download.py +59 -21
  51. swcgeom/utils/dsu.py +15 -0
  52. swcgeom/utils/ellipse.py +18 -4
  53. swcgeom/utils/file.py +15 -0
  54. swcgeom/utils/neuromorpho.py +35 -23
  55. swcgeom/utils/numpy_helper.py +15 -0
  56. swcgeom/utils/plotter_2d.py +27 -6
  57. swcgeom/utils/plotter_3d.py +48 -0
  58. swcgeom/utils/renderer.py +21 -6
  59. swcgeom/utils/sdf.py +19 -7
  60. swcgeom/utils/solid_geometry.py +16 -3
  61. swcgeom/utils/transforms.py +17 -4
  62. swcgeom/utils/volumetric_object.py +23 -10
  63. {swcgeom-0.16.0.dist-info → swcgeom-0.18.3.dist-info}/LICENSE +1 -1
  64. {swcgeom-0.16.0.dist-info → swcgeom-0.18.3.dist-info}/METADATA +28 -24
  65. swcgeom-0.18.3.dist-info/RECORD +67 -0
  66. {swcgeom-0.16.0.dist-info → swcgeom-0.18.3.dist-info}/WHEEL +1 -1
  67. swcgeom/_version.py +0 -16
  68. swcgeom/analysis/branch_features.py +0 -67
  69. swcgeom/analysis/node_features.py +0 -121
  70. swcgeom/analysis/path_features.py +0 -37
  71. swcgeom-0.16.0.dist-info/RECORD +0 -67
  72. {swcgeom-0.16.0.dist-info → swcgeom-0.18.3.dist-info}/top_level.txt +0 -0
swcgeom/images/io.py CHANGED
@@ -1,33 +1,38 @@
1
+ # Copyright 2022-2025 Zexin Yuan
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
1
16
  """Read and write image stack."""
2
17
 
3
18
  import os
4
19
  import re
5
20
  import warnings
6
21
  from abc import ABC, abstractmethod
22
+ from collections.abc import Callable, Iterable
7
23
  from functools import cache, lru_cache
8
- from typing import (
9
- Any,
10
- Callable,
11
- Generic,
12
- Iterable,
13
- List,
14
- Literal,
15
- Optional,
16
- Tuple,
17
- TypeVar,
18
- cast,
19
- overload,
20
- )
24
+ from typing import Any, Generic, Literal, Optional, TypeVar, cast, overload
21
25
 
22
26
  import nrrd
23
27
  import numpy as np
24
28
  import numpy.typing as npt
25
29
  import tifffile
30
+ from typing_extensions import deprecated
26
31
  from v3dpy.loaders import PBD, Raw
27
32
 
28
33
  __all__ = ["read_imgs", "save_tiff", "read_images"]
29
34
 
30
- Vec3i = Tuple[int, int, int]
35
+ Vec3i = tuple[int, int, int]
31
36
  ScalarType = TypeVar("ScalarType", bound=np.generic, covariant=True)
32
37
 
33
38
  RE_TERAFLY_ROOT = re.compile(r"^RES\((\d+)x(\d+)x(\d+)\)$")
@@ -58,17 +63,17 @@ class ImageStack(ABC, Generic[ScalarType]):
58
63
  def __getitem__(self, key: int) -> npt.NDArray[ScalarType]: ... # array of shape (Y, Z, C)
59
64
  @overload
60
65
  @abstractmethod
61
- def __getitem__(self, key: Tuple[int, int]) -> npt.NDArray[ScalarType]: ... # array of shape (Z, C)
66
+ def __getitem__(self, key: tuple[int, int]) -> npt.NDArray[ScalarType]: ... # array of shape (Z, C)
62
67
  @overload
63
68
  @abstractmethod
64
- def __getitem__(self, key: Tuple[int, int, int]) -> npt.NDArray[ScalarType]: ... # array of shape (C,)
69
+ def __getitem__(self, key: tuple[int, int, int]) -> npt.NDArray[ScalarType]: ... # array of shape (C,)
65
70
  @overload
66
71
  @abstractmethod
67
- def __getitem__(self, key: Tuple[int, int, int, int]) -> ScalarType: ... # value
72
+ def __getitem__(self, key: tuple[int, int, int, int]) -> ScalarType: ... # value
68
73
  @overload
69
74
  @abstractmethod
70
75
  def __getitem__(
71
- self, key: slice | Tuple[slice, slice] | Tuple[slice, slice, slice] | Tuple[slice, slice, slice, slice],
76
+ self, key: slice | tuple[slice, slice] | tuple[slice, slice, slice] | tuple[slice, slice, slice, slice],
72
77
  ) -> npt.NDArray[ScalarType]: ... # array of shape (X, Y, Z, C)
73
78
  @overload
74
79
  @abstractmethod
@@ -95,7 +100,7 @@ class ImageStack(ABC, Generic[ScalarType]):
95
100
  return self[:, :, :, :]
96
101
 
97
102
  @property
98
- def shape(self) -> Tuple[int, int, int, int]:
103
+ def shape(self) -> tuple[int, int, int, int]:
99
104
  raise NotImplementedError()
100
105
 
101
106
 
@@ -107,26 +112,42 @@ def read_imgs(fname: str, *, dtype: None =..., **kwargs) -> ImageStack[np.float3
107
112
  # fmt:on
108
113
 
109
114
 
110
- def read_imgs(fname: str, *, dtype=None, **kwargs): # type: ignore
111
- """Read image stack."""
115
+ def read_imgs(fname: str, **kwargs): # type: ignore
116
+ """Read image stack.
112
117
 
113
- kwargs["dtype"] = dtype or np.float32
118
+ Parameters
119
+ ----------
120
+ fname : str
121
+ The path of image stack.
122
+ dtype : np.dtype, default to `np.float32`
123
+ Casting data to specified dtype. If integer and float
124
+ conversions occur, they will be scaled (assuming floats are
125
+ between 0 and 1).
126
+ **kwargs : dict[str, Any]
127
+ Forwarding to the corresponding reader.
128
+ """
114
129
 
115
- ext = os.path.splitext(fname)[-1]
116
- if ext in [".tif", ".tiff"]:
117
- return TiffImageStack(fname, **kwargs)
118
- if ext in [".nrrd"]:
119
- return NrrdImageStack(fname, **kwargs)
120
- if ext in [".v3dpbd"]:
121
- return V3dpbdImageStack(fname, **kwargs)
122
- if ext in [".v3draw"]:
123
- return V3drawImageStack(fname, **kwargs)
124
- if ext in [".npy", ".npz"]:
125
- return NDArrayImageStack(np.load(fname), **kwargs)
130
+ kwargs.setdefault("dtype", np.float32)
131
+ if not os.path.exists(fname):
132
+ raise ValueError(f"image stack not exists: {fname}")
133
+
134
+ # match file extension
135
+ match os.path.splitext(fname)[-1]:
136
+ case ".tif" | ".tiff":
137
+ return TiffImageStack(fname, **kwargs)
138
+ case ".nrrd":
139
+ return NrrdImageStack(fname, **kwargs)
140
+ case ".v3dpbd":
141
+ return V3dpbdImageStack(fname, **kwargs)
142
+ case ".v3draw":
143
+ return V3drawImageStack(fname, **kwargs)
144
+ case ".npy":
145
+ return NDArrayImageStack(np.load(fname), **kwargs)
146
+
147
+ # try to read as terafly
126
148
  if TeraflyImageStack.is_root(fname):
127
149
  return TeraflyImageStack(fname, **kwargs)
128
- if not os.path.exists(fname):
129
- raise ValueError("image stack not exists")
150
+
130
151
  raise ValueError("unsupported image stack")
131
152
 
132
153
 
@@ -135,7 +156,6 @@ def save_tiff(
135
156
  fname: str,
136
157
  *,
137
158
  dtype: Optional[np.unsignedinteger | np.floating] = None,
138
- swap_xy: Optional[bool] = None,
139
159
  compression: str | Literal[False] = "zlib",
140
160
  **kwargs,
141
161
  ) -> None:
@@ -154,7 +174,7 @@ def save_tiff(
154
174
  Compression algorithm, forwarding to `tifffile.imwrite`. If no
155
175
  algorithnm is specify specified, we will use the zlib algorithm
156
176
  with compression level 6 by default.
157
- **kwargs : Dict[str, Any]
177
+ **kwargs : dict[str, Any]
158
178
  Forwarding to `tifffile.imwrite`
159
179
  """
160
180
  if isinstance(data, ImageStack):
@@ -164,17 +184,6 @@ def save_tiff(
164
184
  data = np.expand_dims(data, -1) # (_, _, _) -> (_, _, _, C), C === 1
165
185
 
166
186
  axes = "ZXYC"
167
- if swap_xy is not None:
168
- warnings.warn(
169
- "flag `swap_xy` is easy to implement in user space and "
170
- "is more flexiable. Since this flag is rarely used, we "
171
- "decided to remove it in the next version",
172
- DeprecationWarning,
173
- )
174
- if swap_xy is True:
175
- axes = "ZYXC"
176
- data = data.swapaxes(0, 1) # (X, Y, _, _) -> (Y, X, _, _)
177
-
178
187
  assert data.ndim == 4, "should be an array of shape (X, Y, Z, C)"
179
188
  assert data.shape[-1] in [1, 3], "support 'miniblack' or 'rgb'"
180
189
 
@@ -209,12 +218,7 @@ class NDArrayImageStack(ImageStack[ScalarType]):
209
218
  """NDArray image stack."""
210
219
 
211
220
  def __init__(
212
- self,
213
- imgs: npt.NDArray[Any],
214
- swap_xy: Optional[bool] = None,
215
- filp_xy: Optional[bool] = None,
216
- *,
217
- dtype: ScalarType,
221
+ self, imgs: npt.NDArray[Any], *, dtype: Optional[ScalarType] = None
218
222
  ) -> None:
219
223
  super().__init__()
220
224
 
@@ -222,34 +226,22 @@ class NDArrayImageStack(ImageStack[ScalarType]):
222
226
  imgs = np.expand_dims(imgs, -1)
223
227
  assert imgs.ndim == 4, "Should be shape of (X, Y, Z, C)"
224
228
 
225
- if swap_xy is not None:
226
- warnings.warn(
227
- "flag `swap_xy` now is unnecessary, tifffile will "
228
- "automatically adjust dimensions according to "
229
- "`tags.axes`, so this flag will be removed in the next "
230
- " version",
231
- DeprecationWarning,
232
- )
233
- if swap_xy is True:
234
- imgs = imgs.swapaxes(0, 1) # (Y, X, _, _) -> (X, Y, _, _)
235
-
236
- if filp_xy is not None:
237
- warnings.warn(
238
- "flag `filp_xy` is easy to implement in user space and "
239
- "is more flexiable. Since this flag is rarely used, we "
240
- "decided to remove it in the next version",
241
- DeprecationWarning,
242
- )
243
- if filp_xy is True:
244
- imgs = np.flip(imgs, (0, 1)) # (X, Y, Z, C)
229
+ if dtype is not None:
230
+ dtype_raw = imgs.dtype
231
+ if np.issubdtype(dtype, np.floating) and np.issubdtype(
232
+ dtype_raw, np.unsignedinteger
233
+ ):
234
+ sclar_factor = 1.0 / UINT_MAX[dtype_raw]
235
+ imgs = sclar_factor * imgs.astype(dtype)
236
+ elif np.issubdtype(dtype, np.unsignedinteger) and np.issubdtype(
237
+ dtype_raw, np.floating
238
+ ):
239
+ sclar_factor = UINT_MAX[dtype] # type: ignore
240
+ imgs *= (sclar_factor * imgs).astype(dtype)
241
+ else:
242
+ imgs = imgs.astype(dtype)
245
243
 
246
- dtype_raw = imgs.dtype
247
- self.imgs = imgs.astype(dtype)
248
- if np.issubdtype(dtype, np.floating) and np.issubdtype(
249
- dtype_raw, np.unsignedinteger
250
- ): # TODO: add a option to disable this
251
- sclar_factor = 1.0 / UINT_MAX[imgs.dtype]
252
- self.imgs *= sclar_factor
244
+ self.imgs = imgs
253
245
 
254
246
  def __getitem__(self, key):
255
247
  return self.imgs.__getitem__(key)
@@ -258,22 +250,14 @@ class NDArrayImageStack(ImageStack[ScalarType]):
258
250
  return self.imgs
259
251
 
260
252
  @property
261
- def shape(self) -> Tuple[int, int, int, int]:
262
- return cast(Tuple[int, int, int, int], self.imgs.shape)
253
+ def shape(self) -> tuple[int, int, int, int]:
254
+ return cast(tuple[int, int, int, int], self.imgs.shape)
263
255
 
264
256
 
265
257
  class TiffImageStack(NDArrayImageStack[ScalarType]):
266
258
  """Tiff image stack."""
267
259
 
268
- def __init__(
269
- self,
270
- fname: str,
271
- swap_xy: Optional[bool] = None,
272
- filp_xy: Optional[bool] = None,
273
- *,
274
- dtype: ScalarType,
275
- **kwargs,
276
- ) -> None:
260
+ def __init__(self, fname: str, *, dtype: ScalarType, **kwargs) -> None:
277
261
  with tifffile.TiffFile(fname, **kwargs) as f:
278
262
  s = f.series[0]
279
263
  imgs, axes = s.asarray(), s.axes
@@ -285,23 +269,15 @@ class TiffImageStack(NDArrayImageStack[ScalarType]):
285
269
 
286
270
  orders = [AXES_ORDER[c] for c in axes]
287
271
  imgs = imgs.transpose(np.argsort(orders))
288
- super().__init__(imgs, swap_xy=swap_xy, filp_xy=filp_xy, dtype=dtype)
272
+ super().__init__(imgs, dtype=dtype)
289
273
 
290
274
 
291
275
class NrrdImageStack(NDArrayImageStack[ScalarType]):
    """Nrrd image stack.

    The whole volume is loaded with `nrrd.read` and handed to
    `NDArrayImageStack`; the parsed NRRD header is kept on `self.header`.
    """

    def __init__(self, fname: str, *, dtype: ScalarType, **kwargs) -> None:
        # extra keyword arguments go straight to `nrrd.read`
        volume, nrrd_header = nrrd.read(fname, **kwargs)
        super().__init__(volume, dtype=dtype)
        self.header = nrrd_header
306
282
 
307
283
 
@@ -353,7 +329,7 @@ class TeraflyImageStack(ImageStack[ScalarType]):
353
329
  use its coordinate system, remember to FLIP Y-AXIS BACK.
354
330
  """
355
331
 
356
- _listdir: Callable[[str], List[str]]
332
+ _listdir: Callable[[str], list[str]]
357
333
  _read_patch: Callable[[str], npt.NDArray]
358
334
 
359
335
  def __init__(
@@ -379,12 +355,17 @@ class TeraflyImageStack(ImageStack[ScalarType]):
379
355
  self.res, self.res_dirs, self.res_patch_sizes = self.get_resolutions(root)
380
356
 
381
357
  @cache
382
- def listdir(path: str) -> List[str]:
358
+ def listdir(path: str) -> list[str]:
383
359
  return os.listdir(path)
384
360
 
385
361
  @lru_cache(maxsize=lru_maxsize)
386
362
  def read_patch(path: str) -> npt.NDArray[ScalarType]:
387
- return read_imgs(path, dtype=dtype).get_full()
363
+ match os.path.splitext(path)[-1]:
364
+ case "raw":
365
+ # Treat it as a v3draw file
366
+ return V3drawImageStack(path, dtype=dtype).get_full()
367
+ case _:
368
+ return read_imgs(path, dtype=dtype).get_full()
388
369
 
389
370
  self._listdir, self._read_patch = listdir, read_patch
390
371
 
@@ -453,19 +434,19 @@ class TeraflyImageStack(ImageStack[ScalarType]):
453
434
  raise NotImplementedError() # TODO
454
435
 
455
436
  @property
456
- def shape(self) -> Tuple[int, int, int, int]:
437
+ def shape(self) -> tuple[int, int, int, int]:
457
438
  res_max = self.res[-1]
458
439
  return res_max[0], res_max[1], res_max[2], 1
459
440
 
460
441
  @classmethod
461
- def get_resolutions(cls, root: str) -> Tuple[List[Vec3i], List[str], List[Vec3i]]:
442
+ def get_resolutions(cls, root: str) -> tuple[list[Vec3i], list[str], list[Vec3i]]:
462
443
  """Get all resolutions.
463
444
 
464
445
  Returns
465
446
  -------
466
447
  resolutions : List of (int, int, int)
467
448
  Sequence of sorted resolutions (from small to large).
468
- roots : List[str]
449
+ roots : list[str]
469
450
  Sequence of root of resolutions respectively.
470
451
  patch_sizes : List of (int, int, int)
471
452
  Sequence of patch size of resolutions respectively.
@@ -572,7 +553,7 @@ class TeraflyImageStack(ImageStack[ScalarType]):
572
553
  if (invalid := diff > 10 * v).all():
573
554
  return None, None
574
555
 
575
- diff[invalid] = np.NINF # remove values which greate than v
556
+ diff[invalid] = -np.inf # remove values which greater than v
576
557
 
577
558
  # find the index of the value smaller than v and closest to v
578
559
  idx = np.argmax(diff)
@@ -605,7 +586,7 @@ class GrayImageStack:
605
586
  @overload
606
587
  def __getitem__(self, key: npt.NDArray[np.integer[Any]]) -> np.float32: ...
607
588
  @overload
608
- def __getitem__(self, key: slice | Tuple[slice, slice] | Tuple[slice, slice, slice]) -> npt.NDArray[np.float32]: ...
589
+ def __getitem__(self, key: slice | tuple[slice, slice] | tuple[slice, slice, slice]) -> npt.NDArray[np.float32]: ...
609
590
  # fmt: on
610
591
  def __getitem__(self, key):
611
592
  """Get pixel/patch of image stack."""
@@ -632,10 +613,11 @@ class GrayImageStack:
632
613
  return self.imgs.get_full()[:, :, :, 0]
633
614
 
634
615
  @property
635
- def shape(self) -> Tuple[int, int, int]:
616
+ def shape(self) -> tuple[int, int, int]:
636
617
  return self.imgs.shape[:-1]
637
618
 
638
619
 
620
+ @deprecated("Use `read_imgs` instead")
639
621
  def read_images(*args, **kwargs) -> GrayImageStack:
640
622
  """Read images.
641
623
 
@@ -643,9 +625,4 @@ def read_images(*args, **kwargs) -> GrayImageStack:
643
625
  Use :meth:`read_imgs` instead.
644
626
  """
645
627
 
646
- warnings.warn(
647
- "`read_images` has been replaced by `read_imgs` because it"
648
- "provide rgb support, and this will be removed in next version",
649
- DeprecationWarning,
650
- )
651
628
  return GrayImageStack(read_imgs(*args, **kwargs))
@@ -1,14 +1,30 @@
1
+ # Copyright 2022-2025 Zexin Yuan
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
1
16
  """A series of transformations to compose codes."""
2
17
 
3
- from swcgeom.transforms.base import *
4
- from swcgeom.transforms.branch import *
5
- from swcgeom.transforms.geometry import *
6
- from swcgeom.transforms.image_preprocess import *
7
- from swcgeom.transforms.image_stack import *
8
- from swcgeom.transforms.images import *
9
- from swcgeom.transforms.mst import *
10
- from swcgeom.transforms.neurolucida_asc import *
11
- from swcgeom.transforms.path import *
12
- from swcgeom.transforms.population import *
13
- from swcgeom.transforms.tree import *
14
- from swcgeom.transforms.tree_assembler import *
18
+ from swcgeom.transforms.base import * # noqa: F403
19
+ from swcgeom.transforms.branch import * # noqa: F403
20
+ from swcgeom.transforms.branch_tree import * # noqa: F403
21
+ from swcgeom.transforms.geometry import * # noqa: F403
22
+ from swcgeom.transforms.image_preprocess import * # noqa: F403
23
+ from swcgeom.transforms.image_stack import * # noqa: F403
24
+ from swcgeom.transforms.images import * # noqa: F403
25
+ from swcgeom.transforms.mst import * # noqa: F403
26
+ from swcgeom.transforms.neurolucida_asc import * # noqa: F403
27
+ from swcgeom.transforms.path import * # noqa: F403
28
+ from swcgeom.transforms.population import * # noqa: F403
29
+ from swcgeom.transforms.tree import * # noqa: F403
30
+ from swcgeom.transforms.tree_assembler import * # noqa: F403
@@ -1,3 +1,18 @@
1
+ # Copyright 2022-2025 Zexin Yuan
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
1
16
  """Transformation in tree."""
2
17
 
3
18
  from abc import ABC, abstractmethod
@@ -33,7 +48,7 @@ class Transform(ABC, Generic[T, K]):
33
48
  repr_ = self.extra_repr()
34
49
  return f"{classname}({repr_})"
35
50
 
36
- def extra_repr(self):
51
+ def extra_repr(self) -> str:
37
52
  """Provides a human-friendly representation of the module.
38
53
 
39
54
  This method extends the basic string representation provided by
@@ -48,7 +63,7 @@ class Transform(ABC, Generic[T, K]):
48
63
  def __init__(self, my_parameter: int = 1):
49
64
  self.my_parameter = my_parameter
50
65
 
51
- def extra_repr(self):
66
+ def extra_repr(self) -> str:
52
67
  return f"my_parameter={self.my_parameter}"
53
68
 
54
69
  Notes
@@ -1,3 +1,18 @@
1
+ # Copyright 2022-2025 Zexin Yuan
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
1
16
  """Transformation in branch."""
2
17
 
3
18
  from abc import ABC, abstractmethod
@@ -72,10 +87,61 @@ class BranchLinearResampler(_BranchResampler):
72
87
  r = np.interp(xvals, xp, xyzr[:, 3])
73
88
  return cast(npt.NDArray[np.float32], np.stack([x, y, z, r], axis=1))
74
89
 
75
- def extra_repr(self):
90
+ def extra_repr(self) -> str:
76
91
  return f"n_nodes={self.n_nodes}"
77
92
 
78
93
 
94
class BranchIsometricResampler(_BranchResampler):
    """Resample a branch with (approximately) equal spacing between nodes.

    Parameters
    ----------
    distance : float
        Target distance between consecutive resampled nodes, measured along
        the branch. Must be positive.
    adjust_last_gap : bool, default to `True`
        If True, nodes are spread evenly over the whole branch, so every gap
        is ``total_length / (n_nodes - 1)``, which is never larger than
        `distance`. If False, nodes are placed exactly `distance` apart from
        the first node and the last node is appended, so only the final gap
        may be shorter.
    """

    def __init__(self, distance: float, *, adjust_last_gap: bool = True) -> None:
        super().__init__()
        # `not > 0` also rejects NaN; a non-positive step would otherwise fail
        # later with a divide-by-zero / overflow inside `resample`.
        if not distance > 0:
            raise ValueError(f"distance should be positive, got {distance}")

        self.distance = distance
        self.adjust_last_gap = adjust_last_gap

    def resample(self, xyzr: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
        """Resampling by isometric interpolation; original interior nodes
        are NOT kept.

        Positions and radii are linearly interpolated along the arc length
        of the polyline. The first and the last node are always kept.

        Parameters
        ----------
        xyzr : np.ndarray[np.float32]
            The array of shape (N, 4), N >= 1.

        Returns
        -------
        new_xyzr : ~numpy.NDArray[float32]
            An array of shape (M, 4) with ``M = ceil(L / distance) + 1``,
            where L is the total length of the branch.
        """

        xyzr = np.asarray(xyzr)

        # Arc length from the first node to every original node
        seg_lengths = np.linalg.norm(np.diff(xyzr[:, :3], axis=0), axis=1)
        cumulative_distances = np.concatenate([[0.0], np.cumsum(seg_lengths)])
        total_length = float(cumulative_distances[-1])

        # Number of output nodes; a zero-length branch yields a single node
        n_nodes = int(np.ceil(total_length / self.distance)) + 1

        # Arc-length positions of the new nodes, always ending at the last node
        if self.adjust_last_gap:
            new_distances = np.linspace(0.0, total_length, n_nodes)
        else:
            # Built from an integer count rather than a float `np.arange`
            # range so that `len(new_distances) == n_nodes` always holds;
            # `np.append` accepts the scalar end point (`np.concatenate` did not).
            steps = np.arange(n_nodes - 1) * self.distance
            new_distances = np.append(steps, total_length)

        # Interpolate x, y, z and r independently along the arc length
        new_xyzr = np.empty((n_nodes, 4), dtype=np.float32)
        for i in range(4):
            new_xyzr[:, i] = np.interp(new_distances, cumulative_distances, xyzr[:, i])
        return new_xyzr

    def extra_repr(self) -> str:
        return f"distance={self.distance},adjust_last_gap={self.adjust_last_gap}"
+
144
+
79
145
  class BranchConvSmoother(Transform[Branch, Branch[DictSWC]]):
80
146
  r"""Smooth the branch by sliding window."""
81
147
 
@@ -88,24 +154,24 @@ class BranchConvSmoother(Transform[Branch, Branch[DictSWC]]):
88
154
  """
89
155
  super().__init__()
90
156
  self.n_nodes = n_nodes
91
- self.kernal = np.ones(n_nodes)
157
+ self.kernel = np.ones(n_nodes)
92
158
 
93
159
  def __call__(self, x: Branch) -> Branch[DictSWC]:
94
160
  x = x.detach()
95
- c = signal.convolve(np.ones(x.number_of_nodes()), self.kernal, mode="same")
161
+ c = signal.convolve(np.ones(x.number_of_nodes()), self.kernel, mode="same")
96
162
  for k in ["x", "y", "z"]:
97
163
  v = x.get_ndata(k)
98
- s = signal.convolve(v, self.kernal, mode="same")
164
+ s = signal.convolve(v, self.kernel, mode="same")
99
165
  x.attach.ndata[k][1:-1] = (s / c)[1:-1]
100
166
 
101
167
  return x
102
168
 
103
- def extra_repr(self):
169
+ def extra_repr(self) -> str:
104
170
  return f"n_nodes={self.n_nodes}"
105
171
 
106
172
 
107
173
  class BranchStandardizer(Transform[Branch, Branch[DictSWC]]):
108
- r"""Standarize branch.
174
+ r"""Standardize branch.
109
175
 
110
176
  Standardized branch starts at (0, 0, 0), ends at (1, 0, 0), up at
111
177
  y, and scale max radius to 1.
@@ -123,7 +189,7 @@ class BranchStandardizer(Transform[Branch, Branch[DictSWC]]):
123
189
 
124
190
  @staticmethod
125
191
  def get_matrix(xyz: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
126
- r"""Get standarize transformation matrix.
192
+ r"""Get standardize transformation matrix.
127
193
 
128
194
  Standardized branch starts at (0, 0, 0), ends at (1, 0, 0), up
129
195
  at y.
@@ -136,7 +202,7 @@ class BranchStandardizer(Transform[Branch, Branch[DictSWC]]):
136
202
  Returns
137
203
  -------
138
204
  T : np.ndarray[np.float32]
139
- An homogeneous transfomation matrix of shape (4, 4).
205
+ An homogeneous transformation matrix of shape (4, 4).
140
206
  """
141
207
 
142
208
  assert (
@@ -0,0 +1,82 @@
1
+ # Copyright 2022-2025 Zexin Yuan
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import Iterable
17
+
18
+ import numpy as np
19
+
20
+ from swcgeom.core import Branch, BranchTree, Node, Tree
21
+ from swcgeom.transforms.base import Transform
22
+
23
+ __all__ = ["BranchTreeAssembler"]
24
+
25
+
26
class BranchTreeAssembler(Transform[BranchTree, Tree]):
    """Rebuild a node-level :class:`Tree` from a :class:`BranchTree`.

    Starting at the soma, the branches stored for every node are matched to
    that node's children (see :meth:`pair`), and each branch's nodes are
    re-inserted between the node and the matched child with fresh,
    consecutive ids.
    """

    # Distance under which two points are treated as the same location.
    EPS = 1e-6

    def __call__(self, x: BranchTree) -> Tree:
        nodes = [x.soma().detach()]
        # Depth-first worklist of (node in `x`, id of its copy in `nodes`).
        todo = [(x.soma(), 0)]
        while todo:
            parent, parent_new_id = todo.pop()
            kids = parent.children()
            candidates = x.branches.get(parent.id, [])

            for branch, child in self.pair(candidates, kids):
                head_dup = np.linalg.norm(branch[0].xyz() - parent.xyz()) < self.EPS
                tail_dup = np.linalg.norm(branch[-1].xyz() - child.xyz()) < self.EPS
                # Skip the first branch node when it duplicates `parent`.
                first = 1 if head_dup else 0
                # NOTE(review): both choices drop `branch[-1]`, and when it
                # coincides with `child` the slice also drops `branch[-2]`.
                # If the intent is to drop only a duplicate of `child`, this
                # looks off by one (`-1 if tail_dup else None`) — confirm
                # against the BranchTree branch layout.
                last = -2 if tail_dup else -1

                segment = [n.detach() for n in branch[first:last]]
                segment.append(child.detach())

                # Re-index the segment as a chain hanging off `parent`'s copy.
                base = len(nodes)
                for offset, node in enumerate(segment):
                    node.id = base + offset
                    node.pid = base + offset - 1
                segment[0].pid = parent_new_id

                nodes.extend(segment)
                todo.append((child, segment[-1].id))

        columns = {
            k: np.array([n.__getattribute__(k) for n in nodes])
            for k in x.names.cols()
        }
        return Tree(
            len(nodes),
            source=x.source,
            comments=x.comments,
            names=x.names,
            **columns,
        )

    def pair(
        self, branches: list[Branch], endpoints: list[Node]
    ) -> Iterable[tuple[Branch, Node]]:
        """Match every branch to one endpoint, greedily by distance.

        Repeatedly takes the closest remaining (branch last node, endpoint)
        pair, so the result is one-to-one but not necessarily the matching
        with the minimum total distance.

        Parameters
        ----------
        branches : list[Branch]
        endpoints : list[Node]
            Must have the same length as `branches`.

        Returns
        -------
        pairs : list of (Branch, Node)
        """
        assert len(branches) == len(endpoints)
        branch_ends = np.reshape([br[-1].xyz() for br in branches], (-1, 1, 3))
        targets = np.reshape([n.xyz() for n in endpoints], (1, -1, 3))
        dis = np.linalg.norm(branch_ends - targets, axis=-1)

        pairs = []
        for _ in branches:
            bi, ei = np.unravel_index(np.argmin(dis), dis.shape)
            pairs.append((branches[bi], endpoints[ei]))
            # retire both the matched branch (row) and endpoint (column)
            dis[bi, :] = np.inf
            dis[:, ei] = np.inf

        return pairs