swcgeom 0.18.1__py3-none-any.whl → 0.19.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of swcgeom might be problematic; see the original registry listing for more details.

Files changed (68)
  1. swcgeom/__init__.py +12 -1
  2. swcgeom/analysis/__init__.py +6 -6
  3. swcgeom/analysis/feature_extractor.py +22 -24
  4. swcgeom/analysis/features.py +18 -40
  5. swcgeom/analysis/lmeasure.py +227 -323
  6. swcgeom/analysis/sholl.py +17 -23
  7. swcgeom/analysis/trunk.py +23 -28
  8. swcgeom/analysis/visualization.py +37 -44
  9. swcgeom/analysis/visualization3d.py +16 -25
  10. swcgeom/analysis/volume.py +33 -47
  11. swcgeom/core/__init__.py +12 -13
  12. swcgeom/core/branch.py +10 -17
  13. swcgeom/core/branch_tree.py +3 -2
  14. swcgeom/core/compartment.py +1 -1
  15. swcgeom/core/node.py +3 -6
  16. swcgeom/core/path.py +11 -16
  17. swcgeom/core/population.py +32 -51
  18. swcgeom/core/swc.py +25 -16
  19. swcgeom/core/swc_utils/__init__.py +10 -12
  20. swcgeom/core/swc_utils/assembler.py +5 -12
  21. swcgeom/core/swc_utils/base.py +40 -31
  22. swcgeom/core/swc_utils/checker.py +3 -8
  23. swcgeom/core/swc_utils/io.py +32 -47
  24. swcgeom/core/swc_utils/normalizer.py +17 -23
  25. swcgeom/core/swc_utils/subtree.py +13 -20
  26. swcgeom/core/tree.py +61 -51
  27. swcgeom/core/tree_utils.py +36 -49
  28. swcgeom/core/tree_utils_impl.py +4 -6
  29. swcgeom/images/__init__.py +2 -2
  30. swcgeom/images/augmentation.py +23 -39
  31. swcgeom/images/contrast.py +22 -46
  32. swcgeom/images/folder.py +32 -34
  33. swcgeom/images/io.py +80 -121
  34. swcgeom/transforms/__init__.py +13 -13
  35. swcgeom/transforms/base.py +28 -19
  36. swcgeom/transforms/branch.py +31 -41
  37. swcgeom/transforms/branch_tree.py +3 -1
  38. swcgeom/transforms/geometry.py +13 -4
  39. swcgeom/transforms/image_preprocess.py +2 -0
  40. swcgeom/transforms/image_stack.py +40 -35
  41. swcgeom/transforms/images.py +31 -24
  42. swcgeom/transforms/mst.py +27 -40
  43. swcgeom/transforms/neurolucida_asc.py +13 -13
  44. swcgeom/transforms/path.py +4 -0
  45. swcgeom/transforms/population.py +4 -0
  46. swcgeom/transforms/tree.py +16 -11
  47. swcgeom/transforms/tree_assembler.py +37 -54
  48. swcgeom/utils/__init__.py +12 -12
  49. swcgeom/utils/download.py +7 -14
  50. swcgeom/utils/dsu.py +12 -0
  51. swcgeom/utils/ellipse.py +26 -14
  52. swcgeom/utils/file.py +8 -13
  53. swcgeom/utils/neuromorpho.py +78 -92
  54. swcgeom/utils/numpy_helper.py +15 -12
  55. swcgeom/utils/plotter_2d.py +10 -16
  56. swcgeom/utils/plotter_3d.py +7 -9
  57. swcgeom/utils/renderer.py +16 -8
  58. swcgeom/utils/sdf.py +12 -23
  59. swcgeom/utils/solid_geometry.py +58 -2
  60. swcgeom/utils/transforms.py +164 -100
  61. swcgeom/utils/volumetric_object.py +29 -53
  62. {swcgeom-0.18.1.dist-info → swcgeom-0.19.0.dist-info}/METADATA +7 -6
  63. swcgeom-0.19.0.dist-info/RECORD +67 -0
  64. {swcgeom-0.18.1.dist-info → swcgeom-0.19.0.dist-info}/WHEEL +1 -1
  65. swcgeom/_version.py +0 -16
  66. swcgeom-0.18.1.dist-info/RECORD +0 -68
  67. {swcgeom-0.18.1.dist-info → swcgeom-0.19.0.dist-info/licenses}/LICENSE +0 -0
  68. {swcgeom-0.18.1.dist-info → swcgeom-0.19.0.dist-info}/top_level.txt +0 -0
swcgeom/images/io.py CHANGED
@@ -21,7 +21,7 @@ import warnings
21
21
  from abc import ABC, abstractmethod
22
22
  from collections.abc import Callable, Iterable
23
23
  from functools import cache, lru_cache
24
- from typing import Any, Generic, Literal, Optional, TypeVar, cast, overload
24
+ from typing import Any, Generic, Literal, TypeVar, cast, overload
25
25
 
26
26
  import nrrd
27
27
  import numpy as np
@@ -39,10 +39,10 @@ RE_TERAFLY_ROOT = re.compile(r"^RES\((\d+)x(\d+)x(\d+)\)$")
39
39
  RE_TERAFLY_NAME = re.compile(r"^\d+(_\d+)?(_\d+)?")
40
40
 
41
41
  UINT_MAX = {
42
- np.dtype(np.uint8): (2**8) - 1, # type: ignore
43
- np.dtype(np.uint16): (2**16) - 1, # type: ignore
44
- np.dtype(np.uint32): (2**32) - 1, # type: ignore
45
- np.dtype(np.uint64): (2**64) - 1, # type: ignore
42
+ np.dtype(np.uint8): (2**8) - 1,
43
+ np.dtype(np.uint16): (2**16) - 1,
44
+ np.dtype(np.uint32): (2**32) - 1,
45
+ np.dtype(np.uint64): (2**64) - 1,
46
46
  }
47
47
 
48
48
  AXES_ORDER = {
@@ -83,19 +83,15 @@ class ImageStack(ABC, Generic[ScalarType]):
83
83
  def __getitem__(self, key):
84
84
  """Get pixel/patch of image stack.
85
85
 
86
- Returns
87
- -------
88
- value : ndarray of f32
89
- NDArray which shape depends on key. If key is tuple of ints,
86
+ Returns:
87
+ value: NDArray which shape depends on key. If key is tuple of ints,
90
88
  """
91
89
  raise NotImplementedError()
92
90
 
93
91
  def get_full(self) -> npt.NDArray[ScalarType]:
94
92
  """Get full image stack.
95
93
 
96
- Notes
97
- -----
98
- this will load the full image stack into memory.
94
+ NOTE: this will load the full image stack into memory.
99
95
  """
100
96
  return self[:, :, :, :]
101
97
 
@@ -104,29 +100,20 @@ class ImageStack(ABC, Generic[ScalarType]):
104
100
  raise NotImplementedError()
105
101
 
106
102
 
107
- # fmt:off
108
103
  @overload
109
104
  def read_imgs(fname: str, *, dtype: ScalarType, **kwargs) -> ImageStack[ScalarType]: ...
110
105
  @overload
111
- def read_imgs(fname: str, *, dtype: None =..., **kwargs) -> ImageStack[np.float32]: ...
112
- # fmt:on
113
-
114
-
106
+ def read_imgs(fname: str, *, dtype: None = ..., **kwargs) -> ImageStack[np.float32]: ...
115
107
  def read_imgs(fname: str, **kwargs): # type: ignore
116
108
  """Read image stack.
117
109
 
118
- Parameters
119
- ----------
120
- fname : str
121
- The path of image stack.
122
- dtype : np.dtype, default to `np.float32`
123
- Casting data to specified dtype. If integer and float
124
- conversions occur, they will be scaled (assuming floats are
125
- between 0 and 1).
126
- **kwargs : dict[str, Any]
127
- Forwarding to the corresponding reader.
110
+ Args:
111
+ fname: The path of image stack.
112
+ dtype: Casting data to specified dtype.
113
+ If integer and float conversions occur, they will be scaled (assuming floats
114
+ are between 0 and 1). Default to `np.float32`.
115
+ **kwargs: Forwarding to the corresponding reader.
128
116
  """
129
-
130
117
  kwargs.setdefault("dtype", np.float32)
131
118
  if not os.path.exists(fname):
132
119
  raise ValueError(f"image stack not exists: {fname}")
@@ -155,27 +142,22 @@ def save_tiff(
155
142
  data: npt.NDArray | ImageStack,
156
143
  fname: str,
157
144
  *,
158
- dtype: Optional[np.unsignedinteger | np.floating] = None,
145
+ dtype: np.unsignedinteger | np.floating | None = None,
159
146
  compression: str | Literal[False] = "zlib",
160
147
  **kwargs,
161
148
  ) -> None:
162
149
  """Save image stack as tiff.
163
150
 
164
- Parameters
165
- ----------
166
- data : array
167
- The image stack.
168
- fname : str
169
- dtype : np.dtype, optional
170
- Casting data to specified dtype. If integer and float
171
- conversions occur, they will be scaled (assuming floats are
172
- between 0 and 1).
173
- compression : str | False, default `zlib`
174
- Compression algorithm, forwarding to `tifffile.imwrite`. If no
175
- algorithnm is specify specified, we will use the zlib algorithm
176
- with compression level 6 by default.
177
- **kwargs : dict[str, Any]
178
- Forwarding to `tifffile.imwrite`
151
+ Args:
152
+ data: The image stack.
153
+ fname: str
154
+ dtype: Casting data to specified dtype.
155
+ If integer and float conversions occur, they will be scaled (assuming
156
+ floats are between 0 and 1).
157
+ compression: Compression algorithm, forwarding to `tifffile.imwrite`.
158
+ If no algorithnm is specify specified, we will use the zlib algorithm with
159
+ compression level 6 by default.
160
+ **kwargs: Forwarding to `tifffile.imwrite`
179
161
  """
180
162
  if isinstance(data, ImageStack):
181
163
  data = data.get_full() # TODO: avoid load full imgs to memory
@@ -191,11 +173,11 @@ def save_tiff(
191
173
  if np.issubdtype(data.dtype, np.floating) and np.issubdtype(
192
174
  dtype, np.unsignedinteger
193
175
  ):
194
- scaler_factor = UINT_MAX[np.dtype(dtype)] # type: ignore
176
+ scaler_factor = UINT_MAX[np.dtype(dtype)]
195
177
  elif np.issubdtype(data.dtype, np.unsignedinteger) and np.issubdtype(
196
178
  dtype, np.floating
197
179
  ):
198
- scaler_factor = 1 / UINT_MAX[np.dtype(data.dtype)] # type: ignore
180
+ scaler_factor = 1 / UINT_MAX[np.dtype(data.dtype)]
199
181
  else:
200
182
  scaler_factor = 1
201
183
 
@@ -218,7 +200,7 @@ class NDArrayImageStack(ImageStack[ScalarType]):
218
200
  """NDArray image stack."""
219
201
 
220
202
  def __init__(
221
- self, imgs: npt.NDArray[Any], *, dtype: Optional[ScalarType] = None
203
+ self, imgs: npt.NDArray[Any], *, dtype: ScalarType | None = None
222
204
  ) -> None:
223
205
  super().__init__()
224
206
 
@@ -231,13 +213,13 @@ class NDArrayImageStack(ImageStack[ScalarType]):
231
213
  if np.issubdtype(dtype, np.floating) and np.issubdtype(
232
214
  dtype_raw, np.unsignedinteger
233
215
  ):
234
- sclar_factor = 1.0 / UINT_MAX[dtype_raw]
235
- imgs = sclar_factor * imgs.astype(dtype)
216
+ scalar_factor = 1.0 / UINT_MAX[dtype_raw]
217
+ imgs = scalar_factor * imgs.astype(dtype)
236
218
  elif np.issubdtype(dtype, np.unsignedinteger) and np.issubdtype(
237
219
  dtype_raw, np.floating
238
220
  ):
239
- sclar_factor = UINT_MAX[dtype] # type: ignore
240
- imgs *= (sclar_factor * imgs).astype(dtype)
221
+ scalar_factor = UINT_MAX[dtype]
222
+ imgs *= (scalar_factor * imgs).astype(dtype)
241
223
  else:
242
224
  imgs = imgs.astype(dtype)
243
225
 
@@ -284,27 +266,23 @@ class NrrdImageStack(NDArrayImageStack[ScalarType]):
284
266
  class V3dImageStack(NDArrayImageStack[ScalarType]):
285
267
  """v3d image stack."""
286
268
 
287
- def __init__(
288
- self, fname: str, loader: Raw | PBD, *, dtype: ScalarType, **kwargs
289
- ) -> None:
290
- r = loader()
269
+ def __init_subclass__(cls, loader: Raw | PBD) -> None:
270
+ super().__init_subclass__()
271
+ cls._loader = loader
272
+
273
+ def __init__(self, fname: str, *, dtype: ScalarType, **kwargs) -> None:
274
+ r = self._loader()
291
275
  imgs = r.load(fname)
292
276
  super().__init__(imgs, dtype=dtype, **kwargs)
293
277
 
294
278
 
295
- class V3drawImageStack(V3dImageStack[ScalarType]):
279
+ class V3drawImageStack(V3dImageStack[ScalarType], loader=Raw):
296
280
  """v3draw image stack."""
297
281
 
298
- def __init__(self, fname: str, *, dtype: ScalarType, **kwargs) -> None:
299
- super().__init__(fname, loader=Raw, dtype=dtype, **kwargs)
300
-
301
282
 
302
- class V3dpbdImageStack(V3dImageStack[ScalarType]):
283
+ class V3dpbdImageStack(V3dImageStack[ScalarType], loader=PBD):
303
284
  """v3dpbd image stack."""
304
285
 
305
- def __init__(self, fname: str, *, dtype: ScalarType, **kwargs) -> None:
306
- super().__init__(fname, loader=PBD, dtype=dtype, **kwargs)
307
-
308
286
 
309
287
  class TeraflyImageStack(ImageStack[ScalarType]):
310
288
  """TeraFly image stack.
@@ -312,21 +290,17 @@ class TeraflyImageStack(ImageStack[ScalarType]):
312
290
  TeraFly is a terabytes of multidimensional volumetric images file
313
291
  format as described in [1]_.
314
292
 
315
- References
316
- ----------
317
- .. [1] Bria, Alessandro, Giulio Iannello, Leonardo Onofri, and
318
- Hanchuan Peng. “TeraFly: Real-Time Three-Dimensional
319
- Visualization and Annotation of Terabytes of Multidimensional
320
- Volumetric Images.” Nature Methods 13,
321
- no. 3 (March 2016): 192-94. https://doi.org/10.1038/nmeth.3767.
322
-
323
- Notes
324
- -----
325
- Terafly and Vaa3d use a especial right-handed coordinate system
293
+ NOTE: Terafly and Vaa3d use a especial right-handed coordinate system
326
294
  (with origin point in the left-top and z-axis points front), but we
327
- flip y-axis to makes it a left-handed coordinate system (with orgin
295
+ flip y-axis to makes it a left-handed coordinate system (with origin
328
296
  point in the left-bottom and z-axis points front). If you need to
329
297
  use its coordinate system, remember to FLIP Y-AXIS BACK.
298
+
299
+ References:
300
+ .. [1] Bria, Alessandro, Giulio Iannello, Leonardo Onofri, and Hanchuan Peng.
301
+ “TeraFly: Real-Time Three-Dimensional Visualization and Annotation of Terabytes
302
+ of Multidimensional Volumetric Images.” Nature Methods 13, no. 3 (March 2016):
303
+ 192-94. https://doi.org/10.1038/nmeth.3767.
330
304
  """
331
305
 
332
306
  _listdir: Callable[[str], list[str]]
@@ -336,17 +310,13 @@ class TeraflyImageStack(ImageStack[ScalarType]):
336
310
  self, root: str, *, dtype: ScalarType, lru_maxsize: int | None = 128
337
311
  ) -> None:
338
312
  r"""
339
- Parameters
340
- ----------
341
- root : str
342
- The root of terafly which contains directories named as
343
- `RES(YxXxZ)`.
344
- dtype : np.dtype
345
- lru_maxsize : int or None, default to 128
346
- Forwarding to `functools.lru_cache`. A decompressed array
347
- size of (256, 256, 256, 1), which is the typical size of
348
- terafly image stack, takes about 256 * 256 * 256 * 1 *
349
- 4B = 64MB. A cache size of 128 requires about 8GB memeory.
313
+ Args:
314
+ root: The root of terafly which contains directories named as `RES(YxXxZ)`.
315
+ dtype: np.dtype
316
+ lru_maxsize: Forwarding to `functools.lru_cache`.
317
+ A decompressed array size of (256, 256, 256, 1), which is the typical
318
+ size of terafly image stack, takes about 256 * 256 * 256 * 1 * 4B = 64MB.
319
+ A cache size of 128 requires about 8GB memory.
350
320
  """
351
321
 
352
322
  super().__init__()
@@ -372,11 +342,8 @@ class TeraflyImageStack(ImageStack[ScalarType]):
372
342
  def __getitem__(self, key):
373
343
  """Get images in max resolution.
374
344
 
375
- Examples
376
- --------
377
- ```python
378
- imgs[0, 0, 0, 0] # get value
379
- imgs[0:64, 0:64, 0:64, :] # get patch
345
+ >>> imgs[0, 0, 0, 0] # get value # doctest: +SKIP
346
+ >>> imgs[0:64, 0:64, 0:64, :] # get patch # doctest: +SKIP
380
347
  ```
381
348
  """
382
349
  if not isinstance(key, tuple):
@@ -399,9 +366,8 @@ class TeraflyImageStack(ImageStack[ScalarType]):
399
366
  ) -> npt.NDArray[ScalarType]:
400
367
  """Get patch of image stack.
401
368
 
402
- Returns
403
- -------
404
- patch : array of shape (X, Y, Z, C)
369
+ Returns:
370
+ patch: array of shape (X, Y, Z, C)
405
371
  """
406
372
  if isinstance(strides, int):
407
373
  strides = (strides, strides, strides)
@@ -421,10 +387,9 @@ class TeraflyImageStack(ImageStack[ScalarType]):
421
387
  def find_correspond_imgs(self, p, res_level=-1):
422
388
  """Find the image which contain this point.
423
389
 
424
- Returns
425
- -------
426
- patch : array of shape (X, Y, Z, C)
427
- patch_offset : (int, int, int)
390
+ Returns:
391
+ patch: array of shape (X, Y, Z, C)
392
+ patch_offset: (int, int, int)
428
393
  """
429
394
  p = np.array(p)
430
395
  self._check_params(res_level, p)
@@ -442,14 +407,10 @@ class TeraflyImageStack(ImageStack[ScalarType]):
442
407
  def get_resolutions(cls, root: str) -> tuple[list[Vec3i], list[str], list[Vec3i]]:
443
408
  """Get all resolutions.
444
409
 
445
- Returns
446
- -------
447
- resolutions : List of (int, int, int)
448
- Sequence of sorted resolutions (from small to large).
449
- roots : list[str]
450
- Sequence of root of resolutions respectively.
451
- patch_sizes : List of (int, int, int)
452
- Sequence of patch size of resolutions respectively.
410
+ Returns:
411
+ resolutions: Sequence of sorted resolutions (from small to large).
412
+ roots: Sequence of root of resolutions respectively.
413
+ patch_sizes: Sequence of patch size of resolutions respectively.
453
414
  """
454
415
 
455
416
  roots = list(cls.get_resolution_dirs(root))
@@ -497,13 +458,13 @@ class TeraflyImageStack(ImageStack[ScalarType]):
497
458
 
498
459
  res = self.res[res_level]
499
460
  for p in coords:
500
- assert np.less(
501
- [0, 0, 0], p
502
- ).all(), f"indices ({p[0]}, {p[1]}, {p[2]}) out of range (0, 0, 0)"
461
+ assert np.less([0, 0, 0], p).all(), (
462
+ f"indices ({p[0]}, {p[1]}, {p[2]}) out of range (0, 0, 0)"
463
+ )
503
464
 
504
- assert np.greater(
505
- res, p
506
- ).all(), f"indices ({p[0]}, {p[1]}, {p[2]}) out of range ({res[0]}, {res[1]}, {res[2]})"
465
+ assert np.greater(res, p).all(), (
466
+ f"indices ({p[0]}, {p[1]}, {p[2]}) out of range ({res[0]}, {res[1]}, {res[2]})"
467
+ )
507
468
 
508
469
  def _get_range(self, starts, ends, res_level, out):
509
470
  # pylint: disable=too-many-locals
@@ -529,13 +490,13 @@ class TeraflyImageStack(ImageStack[ScalarType]):
529
490
  if shape[1] > lens[1]:
530
491
  starts_y = starts + [0, lens[1], 0]
531
492
  ends_y = np.array([starts[0], ends[1], ends[2]])
532
- ends_y += [min(shape[0], lens[0]), 0, 0] # type: ignore
493
+ ends_y += [min(shape[0], lens[0]), 0, 0]
533
494
  self._get_range(starts_y, ends_y, res_level, out[:, lens[1] :, :])
534
495
 
535
496
  if shape[2] > lens[2]:
536
497
  starts_z = starts + [0, 0, lens[2]]
537
498
  ends_z = np.array([starts[0], starts[1], ends[2]])
538
- ends_z += [min(shape[0], lens[0]), min(shape[1], lens[1]), 0] # type: ignore
499
+ ends_z += [min(shape[0], lens[0]), min(shape[1], lens[1]), 0]
539
500
  self._get_range(starts_z, ends_z, res_level, out[:, :, lens[2] :])
540
501
 
541
502
  def _find_correspond_imgs(self, p, res_level):
@@ -580,14 +541,14 @@ class GrayImageStack:
580
541
  def __init__(self, imgs: ImageStack) -> None:
581
542
  self.imgs = imgs
582
543
 
583
- # fmt: off
584
544
  @overload
585
545
  def __getitem__(self, key: Vec3i) -> np.float32: ...
586
546
  @overload
587
547
  def __getitem__(self, key: npt.NDArray[np.integer[Any]]) -> np.float32: ...
588
548
  @overload
589
- def __getitem__(self, key: slice | tuple[slice, slice] | tuple[slice, slice, slice]) -> npt.NDArray[np.float32]: ...
590
- # fmt: on
549
+ def __getitem__(
550
+ self, key: slice | tuple[slice, slice] | tuple[slice, slice, slice]
551
+ ) -> npt.NDArray[np.float32]: ...
591
552
  def __getitem__(self, key):
592
553
  """Get pixel/patch of image stack."""
593
554
  v = self[key]
@@ -601,14 +562,12 @@ class GrayImageStack:
601
562
  return v[:, 0]
602
563
  if v.ndim == 1:
603
564
  return v[0]
604
- raise ValueError("unsupport key")
565
+ raise ValueError("unsupported key")
605
566
 
606
567
  def get_full(self) -> npt.NDArray[np.float32]:
607
568
  """Get full image stack.
608
569
 
609
- Notes
610
- -----
611
- this will load the full image stack into memory.
570
+ NOTE: this will load the full image stack into memory.
612
571
  """
613
572
  return self.imgs.get_full()[:, :, :, 0]
614
573
 
@@ -15,16 +15,16 @@
15
15
 
16
16
  """A series of transformations to compose codes."""
17
17
 
18
- from swcgeom.transforms.base import *
19
- from swcgeom.transforms.branch import *
20
- from swcgeom.transforms.branch_tree import *
21
- from swcgeom.transforms.geometry import *
22
- from swcgeom.transforms.image_preprocess import *
23
- from swcgeom.transforms.image_stack import *
24
- from swcgeom.transforms.images import *
25
- from swcgeom.transforms.mst import *
26
- from swcgeom.transforms.neurolucida_asc import *
27
- from swcgeom.transforms.path import *
28
- from swcgeom.transforms.population import *
29
- from swcgeom.transforms.tree import *
30
- from swcgeom.transforms.tree_assembler import *
18
+ from swcgeom.transforms.base import * # noqa: F403
19
+ from swcgeom.transforms.branch import * # noqa: F403
20
+ from swcgeom.transforms.branch_tree import * # noqa: F403
21
+ from swcgeom.transforms.geometry import * # noqa: F403
22
+ from swcgeom.transforms.image_preprocess import * # noqa: F403
23
+ from swcgeom.transforms.image_stack import * # noqa: F403
24
+ from swcgeom.transforms.images import * # noqa: F403
25
+ from swcgeom.transforms.mst import * # noqa: F403
26
+ from swcgeom.transforms.neurolucida_asc import * # noqa: F403
27
+ from swcgeom.transforms.path import * # noqa: F403
28
+ from swcgeom.transforms.population import * # noqa: F403
29
+ from swcgeom.transforms.tree import * # noqa: F403
30
+ from swcgeom.transforms.tree_assembler import * # noqa: F403
@@ -18,12 +18,19 @@
18
18
  from abc import ABC, abstractmethod
19
19
  from typing import Any, Generic, TypeVar, overload
20
20
 
21
+ from typing_extensions import override
22
+
21
23
  __all__ = ["Transform", "Transforms", "Identity"]
22
24
 
23
- T, K = TypeVar("T"), TypeVar("K")
25
+ T = TypeVar("T")
26
+ K = TypeVar("K")
24
27
 
25
- T1, T2, T3 = TypeVar("T1"), TypeVar("T2"), TypeVar("T3") # pylint: disable=invalid-name
26
- T4, T5, T6 = TypeVar("T4"), TypeVar("T5"), TypeVar("T6") # pylint: disable=invalid-name
28
+ T1 = TypeVar("T1")
29
+ T2 = TypeVar("T2")
30
+ T3 = TypeVar("T3")
31
+ T4 = TypeVar("T4")
32
+ T5 = TypeVar("T5")
33
+ T6 = TypeVar("T6")
27
34
 
28
35
 
29
36
  class Transform(ABC, Generic[T, K]):
@@ -36,9 +43,7 @@ class Transform(ABC, Generic[T, K]):
36
43
  def __call__(self, x: T) -> K:
37
44
  """Apply transform.
38
45
 
39
- Notes
40
- -----
41
- All subclasses should overwrite :meth:`__call__`, supporting
46
+ NOTE: All subclasses should overwrite :meth:`__call__`, supporting
42
47
  applying transform in `x`.
43
48
  """
44
49
  raise NotImplementedError()
@@ -57,18 +62,14 @@ class Transform(ABC, Generic[T, K]):
57
62
  which can be particularly useful for debugging and model
58
63
  architecture introspection.
59
64
 
60
- Examples
61
- --------
62
- class Foo(Transform[T, K]):
63
- def __init__(self, my_parameter: int = 1):
64
- self.my_parameter = my_parameter
65
-
66
- def extra_repr(self) -> str:
67
- return f"my_parameter={self.my_parameter}"
65
+ >>> class Foo(Transform[T, K]):
66
+ ... def __init__(self, my_parameter: int = 1):
67
+ ... self.my_parameter = my_parameter
68
+ ...
69
+ ... def extra_repr(self) -> str:
70
+ ... return f"my_parameter={self.my_parameter}"
68
71
 
69
- Notes
70
- -----
71
- This method should be overridden in custom modules to provide
72
+ NOTE: This method should be overridden in custom modules to provide
72
73
  specific details relevant to the module's functionality and
73
74
  configuration.
74
75
  """
@@ -80,7 +81,7 @@ class Transforms(Transform[T, K]):
80
81
 
81
82
  transforms: list[Transform[Any, Any]]
82
83
 
83
- # fmt:off
84
+ # fmt: off
84
85
  @overload
85
86
  def __init__(self, t1: Transform[T, K], /) -> None: ...
86
87
  @overload
@@ -100,11 +101,16 @@ class Transforms(Transform[T, K]):
100
101
  t3: Transform[T2, T3], t4: Transform[T3, T4],
101
102
  t5: Transform[T4, T5], t6: Transform[T5, K], /) -> None: ...
102
103
  @overload
104
+ def __init__(self, t1: Transform[T, T1], t2: Transform[T1, T2],
105
+ t3: Transform[T2, T3], t4: Transform[T3, T4],
106
+ t5: Transform[T4, T5], t6: Transform[T5, T6],
107
+ t7: Transform[T6, K], /) -> None: ...
108
+ @overload
103
109
  def __init__(self, t1: Transform[T, T1], t2: Transform[T1, T2],
104
110
  t3: Transform[T2, T3], t4: Transform[T3, T4],
105
111
  t5: Transform[T4, T5], t6: Transform[T5, T6],
106
112
  t7: Transform[T6, Any], /, *transforms: Transform[Any, K]) -> None: ...
107
- # fmt:on
113
+ # fmt: on
108
114
  def __init__(self, *transforms: Transform[Any, Any]) -> None:
109
115
  trans = []
110
116
  for t in transforms:
@@ -114,6 +120,7 @@ class Transforms(Transform[T, K]):
114
120
  trans.append(t)
115
121
  self.transforms = trans
116
122
 
123
+ @override
117
124
  def __call__(self, x: T) -> K:
118
125
  """Apply transforms."""
119
126
  for transform in self.transforms:
@@ -127,6 +134,7 @@ class Transforms(Transform[T, K]):
127
134
  def __len__(self) -> int:
128
135
  return len(self.transforms)
129
136
 
137
+ @override
130
138
  def extra_repr(self) -> str:
131
139
  return ", ".join([str(transform) for transform in self])
132
140
 
@@ -134,5 +142,6 @@ class Transforms(Transform[T, K]):
134
142
  class Identity(Transform[T, T]):
135
143
  """Resurn input as-is."""
136
144
 
145
+ @override
137
146
  def __call__(self, x: T) -> T:
138
147
  return x
@@ -21,6 +21,7 @@ from typing import cast
21
21
  import numpy as np
22
22
  import numpy.typing as npt
23
23
  from scipy import signal
24
+ from typing_extensions import override
24
25
 
25
26
  from swcgeom.core import Branch, DictSWC
26
27
  from swcgeom.transforms.base import Transform
@@ -40,13 +41,13 @@ __all__ = ["BranchLinearResampler", "BranchConvSmoother", "BranchStandardizer"]
40
41
  class _BranchResampler(Transform[Branch, Branch], ABC):
41
42
  r"""Resample branch."""
42
43
 
44
+ @override
43
45
  def __call__(self, x: Branch) -> Branch:
44
46
  xyzr = self.resample(x.xyzr())
45
47
  return Branch.from_xyzr(xyzr)
46
48
 
47
49
  @abstractmethod
48
- def resample(self, xyzr: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
49
- raise NotImplementedError()
50
+ def resample(self, xyzr: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]: ...
50
51
 
51
52
 
52
53
  class BranchLinearResampler(_BranchResampler):
@@ -55,26 +56,21 @@ class BranchLinearResampler(_BranchResampler):
55
56
  def __init__(self, n_nodes: int) -> None:
56
57
  """Resample branch to special num of nodes.
57
58
 
58
- Parameters
59
- ----------
60
- n_nodes : int
61
- Number of nodes after resample.
59
+ Args:
60
+ n_nodes: Number of nodes after resample.
62
61
  """
63
62
  super().__init__()
64
63
  self.n_nodes = n_nodes
65
64
 
65
+ @override
66
66
  def resample(self, xyzr: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
67
67
  """Resampling by linear interpolation, DO NOT keep original node.
68
68
 
69
- Parameters
70
- ----------
71
- xyzr : np.ndarray[np.float32]
72
- The array of shape (N, 4).
69
+ Args:
70
+ xyzr: The array of shape (N, 4).
73
71
 
74
- Returns
75
- -------
76
- coordinates : ~numpy.NDArray[float64]
77
- An array of shape (n_nodes, 4).
72
+ Returns:
73
+ coordinates: An array of shape (n_nodes, 4).
78
74
  """
79
75
 
80
76
  xp = np.cumsum(np.linalg.norm(xyzr[1:, :3] - xyzr[:-1, :3], axis=1))
@@ -87,6 +83,7 @@ class BranchLinearResampler(_BranchResampler):
87
83
  r = np.interp(xvals, xp, xyzr[:, 3])
88
84
  return cast(npt.NDArray[np.float32], np.stack([x, y, z, r], axis=1))
89
85
 
86
+ @override
90
87
  def extra_repr(self) -> str:
91
88
  return f"n_nodes={self.n_nodes}"
92
89
 
@@ -97,18 +94,15 @@ class BranchIsometricResampler(_BranchResampler):
97
94
  self.distance = distance
98
95
  self.adjust_last_gap = adjust_last_gap
99
96
 
97
+ @override
100
98
  def resample(self, xyzr: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
101
99
  """Resampling by isometric interpolation, DO NOT keep original node.
102
100
 
103
- Parameters
104
- ----------
105
- xyzr : np.ndarray[np.float32]
106
- The array of shape (N, 4).
101
+ Args:
102
+ xyzr: The array of shape (N, 4).
107
103
 
108
- Returns
109
- -------
110
- new_xyzr : ~numpy.NDArray[float32]
111
- An array of shape (n_nodes, 4).
104
+ Returns:
105
+ new_xyzr: An array of shape (n_nodes, 4).
112
106
  """
113
107
 
114
108
  # Compute the cumulative distances between consecutive points
@@ -138,6 +132,7 @@ class BranchIsometricResampler(_BranchResampler):
138
132
  new_xyzr[:, 3] = np.interp(new_distances, cumulative_distances, xyzr[:, 3])
139
133
  return new_xyzr
140
134
 
135
+ @override
141
136
  def extra_repr(self) -> str:
142
137
  return f"distance={self.distance},adjust_last_gap={self.adjust_last_gap}"
143
138
 
@@ -147,15 +142,14 @@ class BranchConvSmoother(Transform[Branch, Branch[DictSWC]]):
147
142
 
148
143
  def __init__(self, n_nodes: int = 5) -> None:
149
144
  """
150
- Parameters
151
- ----------
152
- n_nodes : int, default `5`
153
- Window size.
145
+ Args:
146
+ n_nodes: Window size.
154
147
  """
155
148
  super().__init__()
156
149
  self.n_nodes = n_nodes
157
150
  self.kernel = np.ones(n_nodes)
158
151
 
152
+ @override
159
153
  def __call__(self, x: Branch) -> Branch[DictSWC]:
160
154
  x = x.detach()
161
155
  c = signal.convolve(np.ones(x.number_of_nodes()), self.kernel, mode="same")
@@ -173,10 +167,11 @@ class BranchConvSmoother(Transform[Branch, Branch[DictSWC]]):
173
167
  class BranchStandardizer(Transform[Branch, Branch[DictSWC]]):
174
168
  r"""Standardize branch.
175
169
 
176
- Standardized branch starts at (0, 0, 0), ends at (1, 0, 0), up at
177
- y, and scale max radius to 1.
170
+ Standardized branch starts at (0, 0, 0), ends at (1, 0, 0), up at y, and scale max
171
+ radius to 1.
178
172
  """
179
173
 
174
+ @override
180
175
  def __call__(self, x: Branch) -> Branch:
181
176
  xyzr = x.xyzr()
182
177
  xyz, r = xyzr[:, 0:3], xyzr[:, 3:4]
@@ -191,23 +186,18 @@ class BranchStandardizer(Transform[Branch, Branch[DictSWC]]):
191
186
  def get_matrix(xyz: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
192
187
  r"""Get standardize transformation matrix.
193
188
 
194
- Standardized branch starts at (0, 0, 0), ends at (1, 0, 0), up
195
- at y.
189
+ Standardized branch starts at (0, 0, 0), ends at (1, 0, 0), up at y.
196
190
 
197
- Parameters
198
- ----------
199
- xyz : np.ndarray[np.float32]
200
- The `x`, `y`, `z` matrix of shape (N, 3) of branch.
191
+ Args:
192
+ xyz: The `x`, `y`, `z` matrix of shape (N, 3) of branch.
201
193
 
202
- Returns
203
- -------
204
- T : np.ndarray[np.float32]
205
- An homogeneous transformation matrix of shape (4, 4).
194
+ Returns:
195
+ T: An homogeneous transformation matrix of shape (4, 4).
206
196
  """
207
197
 
208
- assert (
209
- xyz.ndim == 2 and xyz.shape[1] == 3
210
- ), f"xyz should be of shape (N, 3), got {xyz.shape}"
198
+ assert xyz.ndim == 2 and xyz.shape[1] == 3, (
199
+ f"xyz should be of shape (N, 3), got {xyz.shape}"
200
+ )
211
201
 
212
202
  xyz = xyz[:, :3]
213
203
  T = np.identity(4)