swcgeom 0.13.2__py3-none-any.whl → 0.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of swcgeom might be problematic; more details are available on the package registry page.

swcgeom/images/folder.py CHANGED
@@ -2,14 +2,25 @@
2
2
 
3
3
  import os
4
4
  import re
5
+ import warnings
5
6
  from abc import ABC, abstractmethod
6
- from typing import Callable, Generic, Iterable, List, Literal, Optional, Tuple, TypeVar
7
+ from typing import (
8
+ Callable,
9
+ Generic,
10
+ Iterable,
11
+ List,
12
+ Literal,
13
+ Optional,
14
+ Tuple,
15
+ TypeVar,
16
+ overload,
17
+ )
7
18
 
8
19
  import numpy as np
9
20
  import numpy.typing as npt
10
21
  from typing_extensions import Self
11
22
 
12
- from swcgeom.images.io import read_imgs
23
+ from swcgeom.images.io import ScalarType, read_imgs
13
24
  from swcgeom.transforms import Identity, Transform
14
25
 
15
26
  __all__ = [
@@ -21,20 +32,23 @@ __all__ = [
21
32
  T = TypeVar("T")
22
33
 
23
34
 
24
- class ImageStackFolderBase(Generic[T], ABC):
35
+ class ImageStackFolderBase(Generic[ScalarType, T], ABC):
25
36
  """Image stack folder base."""
26
37
 
27
38
  files: List[str]
28
- transform: Transform[npt.NDArray[np.float32], T]
39
+ transform: Transform[npt.NDArray[ScalarType], T]
29
40
 
30
- def __init__(
31
- self,
32
- files: Iterable[str],
33
- *,
34
- transform: Optional[Transform[npt.NDArray[np.float32], T]] = None,
35
- ) -> None:
41
+ # fmt: off
42
+ @overload
43
+ def __init__(self, files: Iterable[str], *, dtype: None = ..., transform: Optional[Transform[npt.NDArray[np.float32], T]] = None) -> None: ...
44
+ @overload
45
+ def __init__(self, files: Iterable[str], *, dtype: ScalarType, transform: Optional[Transform[npt.NDArray[ScalarType], T]] = None) -> None: ...
46
+ # fmt: on
47
+
48
+ def __init__(self, files: Iterable[str], *, dtype=None, transform=None) -> None:
36
49
  super().__init__()
37
50
  self.files = list(files)
51
+ self.dtype = dtype or np.float32
38
52
  self.transform = transform or Identity() # type: ignore
39
53
 
40
54
  @abstractmethod
@@ -45,13 +59,12 @@ class ImageStackFolderBase(Generic[T], ABC):
45
59
  return len(self.files)
46
60
 
47
61
  def _get(self, fname: str) -> T:
48
- imgs = self.read_imgs(fname)
62
+ imgs = self._read(fname)
49
63
  imgs = self.transform(imgs)
50
64
  return imgs
51
65
 
52
- @staticmethod
53
- def read_imgs(fname: str) -> npt.NDArray[np.float32]:
54
- return read_imgs(fname).get_full()
66
+ def _read(self, fname: str) -> npt.NDArray[ScalarType]:
67
+ return read_imgs(fname, dtype=self.dtype).get_full() # type: ignore
55
68
 
56
69
  @staticmethod
57
70
  def scan(root: str, *, pattern: Optional[str] = None) -> List[str]:
@@ -63,8 +76,20 @@ class ImageStackFolderBase(Generic[T], ABC):
63
76
 
64
77
  return fs
65
78
 
79
+ @staticmethod
80
+ def read_imgs(fname: str) -> npt.NDArray[np.float32]:
81
+ warnings.warn(
82
+ "`ImageStackFolderBase.read_imgs` serves as a "
83
+ "straightforward wrapper for `~swcgeom.images.io.read_imgs(fname).get_full()`. "
84
+ "However, as it is not utilized within our internal "
85
+ "processes, it is scheduled for removal in the "
86
+ "forthcoming version.",
87
+ DeprecationWarning,
88
+ )
89
+ return read_imgs(fname).get_full()
90
+
66
91
 
67
- class ImageStackFolder(Generic[T], ImageStackFolderBase[T]):
92
+ class ImageStackFolder(ImageStackFolderBase[ScalarType, T]):
68
93
  """Image stack folder."""
69
94
 
70
95
  def __getitem__(self, idx: int, /) -> T:
@@ -84,7 +109,7 @@ class ImageStackFolder(Generic[T], ImageStackFolderBase[T]):
84
109
  return cls(cls.scan(root, pattern=pattern), **kwargs)
85
110
 
86
111
 
87
- class LabeledImageStackFolder(Generic[T], ImageStackFolderBase[T]):
112
+ class LabeledImageStackFolder(ImageStackFolderBase[ScalarType, T]):
88
113
  """Image stack folder with label."""
89
114
 
90
115
  labels: List[int]
@@ -115,7 +140,7 @@ class LabeledImageStackFolder(Generic[T], ImageStackFolderBase[T]):
115
140
  return cls(files, labels, **kwargs)
116
141
 
117
142
 
118
- class PathImageStackFolder(Generic[T], ImageStackFolder[T]):
143
+ class PathImageStackFolder(ImageStackFolder[ScalarType, T]):
119
144
  """Image stack folder with relpath."""
120
145
 
121
146
  root: str
swcgeom/images/io.py CHANGED
@@ -1,6 +1,5 @@
1
1
  """Read and write image stack."""
2
2
 
3
-
4
3
  import os
5
4
  import re
6
5
  import warnings
@@ -9,11 +8,13 @@ from functools import cache, lru_cache
9
8
  from typing import (
10
9
  Any,
11
10
  Callable,
11
+ Generic,
12
12
  Iterable,
13
13
  List,
14
14
  Literal,
15
15
  Optional,
16
16
  Tuple,
17
+ TypeVar,
17
18
  cast,
18
19
  overload,
19
20
  )
@@ -27,6 +28,8 @@ from v3dpy.loaders import PBD, Raw
27
28
  __all__ = ["read_imgs", "save_tiff", "read_images"]
28
29
 
29
30
  Vec3i = Tuple[int, int, int]
31
+ ScalarType = TypeVar("ScalarType", bound=np.generic, covariant=True)
32
+
30
33
  RE_TERAFLY_ROOT = re.compile(r"^RES\((\d+)x(\d+)x(\d+)\)$")
31
34
  RE_TERAFLY_NAME = re.compile(r"^\d+(_\d+)?(_\d+)?")
32
35
 
@@ -46,30 +49,30 @@ AXES_ORDER = {
46
49
  }
47
50
 
48
51
 
49
- class ImageStack(ABC):
52
+ class ImageStack(ABC, Generic[ScalarType]):
50
53
  """Image stack."""
51
54
 
52
55
  # fmt: off
53
56
  @overload
54
57
  @abstractmethod
55
- def __getitem__(self, key: int) -> npt.NDArray[np.float32]: ... # array of shape (Y, Z, C)
58
+ def __getitem__(self, key: int) -> npt.NDArray[ScalarType]: ... # array of shape (Y, Z, C)
56
59
  @overload
57
60
  @abstractmethod
58
- def __getitem__(self, key: Tuple[int, int]) -> npt.NDArray[np.float32]: ... # array of shape (Z, C)
61
+ def __getitem__(self, key: Tuple[int, int]) -> npt.NDArray[ScalarType]: ... # array of shape (Z, C)
59
62
  @overload
60
63
  @abstractmethod
61
- def __getitem__(self, key: Tuple[int, int, int]) -> npt.NDArray[np.float32]: ... # array of shape (C,)
64
+ def __getitem__(self, key: Tuple[int, int, int]) -> npt.NDArray[ScalarType]: ... # array of shape (C,)
62
65
  @overload
63
66
  @abstractmethod
64
- def __getitem__(self, key: Tuple[int, int, int, int]) -> np.float32: ... # value
67
+ def __getitem__(self, key: Tuple[int, int, int, int]) -> ScalarType: ... # value
65
68
  @overload
66
69
  @abstractmethod
67
70
  def __getitem__(
68
71
  self, key: slice | Tuple[slice, slice] | Tuple[slice, slice, slice] | Tuple[slice, slice, slice, slice],
69
- ) -> npt.NDArray[np.float32]: ... # array of shape (X, Y, Z, C)
72
+ ) -> npt.NDArray[ScalarType]: ... # array of shape (X, Y, Z, C)
70
73
  @overload
71
74
  @abstractmethod
72
- def __getitem__(self, key: npt.NDArray[np.integer[Any]]) -> npt.NDArray[np.float32]: ...
75
+ def __getitem__(self, key: npt.NDArray[np.integer[Any]]) -> npt.NDArray[ScalarType]: ...
73
76
  # fmt: on
74
77
  @abstractmethod
75
78
  def __getitem__(self, key):
@@ -82,7 +85,7 @@ class ImageStack(ABC):
82
85
  """
83
86
  raise NotImplementedError()
84
87
 
85
- def get_full(self) -> npt.NDArray[np.float32]:
88
+ def get_full(self) -> npt.NDArray[ScalarType]:
86
89
  """Get full image stack.
87
90
 
88
91
  Notes
@@ -96,8 +99,19 @@ class ImageStack(ABC):
96
99
  raise NotImplementedError()
97
100
 
98
101
 
99
- def read_imgs(fname: str, **kwargs) -> ImageStack:
102
+ # fmt:off
103
+ @overload
104
+ def read_imgs(fname: str, *, dtype: ScalarType, **kwargs) -> ImageStack[ScalarType]: ...
105
+ @overload
106
+ def read_imgs(fname: str, *, dtype: None =..., **kwargs) -> ImageStack[np.float32]: ...
107
+ # fmt:on
108
+
109
+
110
+ def read_imgs(fname: str, *, dtype=None, **kwargs): # type: ignore
100
111
  """Read image stack."""
112
+
113
+ kwargs["dtype"] = dtype or np.float32
114
+
101
115
  ext = os.path.splitext(fname)[-1]
102
116
  if ext in [".tif", ".tiff"]:
103
117
  return TiffImageStack(fname, **kwargs)
@@ -191,7 +205,7 @@ def save_tiff(
191
205
  tifffile.imwrite(fname, data, **kwargs)
192
206
 
193
207
 
194
- class NDArrayImageStack(ImageStack):
208
+ class NDArrayImageStack(ImageStack[ScalarType]):
195
209
  """NDArray image stack."""
196
210
 
197
211
  def __init__(
@@ -199,6 +213,8 @@ class NDArrayImageStack(ImageStack):
199
213
  imgs: npt.NDArray[Any],
200
214
  swap_xy: Optional[bool] = None,
201
215
  filp_xy: Optional[bool] = None,
216
+ *,
217
+ dtype: ScalarType,
202
218
  ) -> None:
203
219
  super().__init__()
204
220
 
@@ -206,10 +222,6 @@ class NDArrayImageStack(ImageStack):
206
222
  imgs = np.expand_dims(imgs, -1)
207
223
  assert imgs.ndim == 4, "Should be shape of (X, Y, Z, C)"
208
224
 
209
- sclar_factor = 1.0
210
- if np.issubdtype((dtype := imgs.dtype), np.unsignedinteger):
211
- sclar_factor /= UINT_MAX[dtype]
212
-
213
225
  if swap_xy is not None:
214
226
  warnings.warn(
215
227
  "flag `swap_xy` now is unnecessary, tifffile will "
@@ -231,12 +243,18 @@ class NDArrayImageStack(ImageStack):
231
243
  if filp_xy is True:
232
244
  imgs = np.flip(imgs, (0, 1)) # (X, Y, Z, C)
233
245
 
234
- self.imgs = imgs.astype(np.float32) * sclar_factor
246
+ dtype_raw = imgs.dtype
247
+ self.imgs = imgs.astype(dtype)
248
+ if np.issubdtype(dtype, np.floating) and np.issubdtype(
249
+ dtype_raw, np.unsignedinteger
250
+ ): # TODO: add a option to disable this
251
+ sclar_factor = 1.0 / UINT_MAX[imgs.dtype]
252
+ self.imgs *= sclar_factor
235
253
 
236
254
  def __getitem__(self, key):
237
255
  return self.imgs.__getitem__(key)
238
256
 
239
- def get_full(self) -> npt.NDArray[np.float32]:
257
+ def get_full(self) -> npt.NDArray[ScalarType]:
240
258
  return self.imgs
241
259
 
242
260
  @property
@@ -244,7 +262,7 @@ class NDArrayImageStack(ImageStack):
244
262
  return cast(Tuple[int, int, int, int], self.imgs.shape)
245
263
 
246
264
 
247
- class TiffImageStack(NDArrayImageStack):
265
+ class TiffImageStack(NDArrayImageStack[ScalarType]):
248
266
  """Tiff image stack."""
249
267
 
250
268
  def __init__(
@@ -252,6 +270,8 @@ class TiffImageStack(NDArrayImageStack):
252
270
  fname: str,
253
271
  swap_xy: Optional[bool] = None,
254
272
  filp_xy: Optional[bool] = None,
273
+ *,
274
+ dtype: ScalarType,
255
275
  **kwargs,
256
276
  ) -> None:
257
277
  with tifffile.TiffFile(fname, **kwargs) as f:
@@ -265,10 +285,10 @@ class TiffImageStack(NDArrayImageStack):
265
285
 
266
286
  orders = [AXES_ORDER[c] for c in axes]
267
287
  imgs = imgs.transpose(np.argsort(orders))
268
- super().__init__(imgs, swap_xy=swap_xy, filp_xy=filp_xy)
288
+ super().__init__(imgs, swap_xy=swap_xy, filp_xy=filp_xy, dtype=dtype)
269
289
 
270
290
 
271
- class NrrdImageStack(NDArrayImageStack):
291
+ class NrrdImageStack(NDArrayImageStack[ScalarType]):
272
292
  """Nrrd image stack."""
273
293
 
274
294
  def __init__(
@@ -276,37 +296,41 @@ class NrrdImageStack(NDArrayImageStack):
276
296
  fname: str,
277
297
  swap_xy: Optional[bool] = None,
278
298
  filp_xy: Optional[bool] = None,
299
+ *,
300
+ dtype: ScalarType,
279
301
  **kwargs,
280
302
  ) -> None:
281
303
  imgs, header = nrrd.read(fname, **kwargs)
282
- super().__init__(imgs, swap_xy=swap_xy, filp_xy=filp_xy)
304
+ super().__init__(imgs, swap_xy=swap_xy, filp_xy=filp_xy, dtype=dtype)
283
305
  self.header = header
284
306
 
285
307
 
286
- class V3dImageStack(NDArrayImageStack):
308
+ class V3dImageStack(NDArrayImageStack[ScalarType]):
287
309
  """v3d image stack."""
288
310
 
289
- def __init__(self, fname: str, loader: Raw | PBD, **kwargs) -> None:
311
+ def __init__(
312
+ self, fname: str, loader: Raw | PBD, *, dtype: ScalarType, **kwargs
313
+ ) -> None:
290
314
  r = loader()
291
315
  imgs = r.load(fname)
292
- super().__init__(imgs, **kwargs)
316
+ super().__init__(imgs, dtype=dtype, **kwargs)
293
317
 
294
318
 
295
- class V3drawImageStack(V3dImageStack):
319
+ class V3drawImageStack(V3dImageStack[ScalarType]):
296
320
  """v3draw image stack."""
297
321
 
298
- def __init__(self, fname: str, **kwargs) -> None:
299
- super().__init__(fname, loader=Raw, **kwargs)
322
+ def __init__(self, fname: str, *, dtype: ScalarType, **kwargs) -> None:
323
+ super().__init__(fname, loader=Raw, dtype=dtype, **kwargs)
300
324
 
301
325
 
302
- class V3dpbdImageStack(V3dImageStack):
326
+ class V3dpbdImageStack(V3dImageStack[ScalarType]):
303
327
  """v3dpbd image stack."""
304
328
 
305
- def __init__(self, fname: str, **kwargs) -> None:
306
- super().__init__(fname, loader=PBD, **kwargs)
329
+ def __init__(self, fname: str, *, dtype: ScalarType, **kwargs) -> None:
330
+ super().__init__(fname, loader=PBD, dtype=dtype, **kwargs)
307
331
 
308
332
 
309
- class TeraflyImageStack(ImageStack):
333
+ class TeraflyImageStack(ImageStack[ScalarType]):
310
334
  """TeraFly image stack.
311
335
 
312
336
  References
@@ -328,21 +352,26 @@ class TeraflyImageStack(ImageStack):
328
352
  _listdir: Callable[[str], List[str]]
329
353
  _read_patch: Callable[[str], npt.NDArray]
330
354
 
331
- def __init__(self, root: str, *, lru_maxsize: int | None = 128) -> None:
355
+ def __init__(
356
+ self, root: str, *, dtype: ScalarType, lru_maxsize: int | None = 128
357
+ ) -> None:
332
358
  r"""
333
359
  Parameters
334
360
  ----------
335
361
  root : str
336
362
  The root of terafly which contains directories named as
337
363
  `RES(YxXxZ)`.
364
+ dtype : np.dtype
338
365
  lru_maxsize : int or None, default to 128
339
366
  Forwarding to `functools.lru_cache`. A decompressed array
340
367
  size of (256, 256, 256, 1), which is the typical size of
341
368
  terafly image stack, takes about 256 * 256 * 256 * 1 *
342
369
  4B = 64MB. A cache size of 128 requires about 8GB memeory.
343
370
  """
371
+
344
372
  super().__init__()
345
373
  self.root = root
374
+ self.dtype = dtype
346
375
  self.res, self.res_dirs, self.res_patch_sizes = self.get_resolutions(root)
347
376
 
348
377
  @cache
@@ -350,8 +379,8 @@ class TeraflyImageStack(ImageStack):
350
379
  return os.listdir(path)
351
380
 
352
381
  @lru_cache(maxsize=lru_maxsize)
353
- def read_patch(path: str) -> npt.NDArray[np.float32]:
354
- return read_imgs(path).get_full()
382
+ def read_patch(path: str) -> npt.NDArray[ScalarType]:
383
+ return read_imgs(path, dtype=dtype).get_full()
355
384
 
356
385
  self._listdir, self._read_patch = listdir, read_patch
357
386
 
@@ -382,7 +411,7 @@ class TeraflyImageStack(ImageStack):
382
411
 
383
412
  def get_patch(
384
413
  self, starts, ends, strides: int | Vec3i = 1, res_level=-1
385
- ) -> npt.NDArray[np.float32]:
414
+ ) -> npt.NDArray[ScalarType]:
386
415
  """Get patch of image stack.
387
416
 
388
417
  Returns
@@ -397,7 +426,7 @@ class TeraflyImageStack(ImageStack):
397
426
  assert np.equal(strides, [1, 1, 1]).all() # TODO: support stride
398
427
 
399
428
  shape_out = np.concatenate([ends - starts, [1]])
400
- out = np.zeros(shape_out, dtype=np.float32)
429
+ out = np.zeros(shape_out, dtype=self.dtype)
401
430
  self._get_range(starts, ends, res_level, out=out)
402
431
 
403
432
  # flip y-axis to makes it a left-handed coordinate system
@@ -1,7 +1,7 @@
1
1
  """Transformation in tree."""
2
2
 
3
3
  from abc import ABC, abstractmethod
4
- from typing import Any, Generic, TypeVar, cast, overload
4
+ from typing import Any, Generic, TypeVar, overload
5
5
 
6
6
  __all__ = ["Transform", "Transforms", "Identity"]
7
7
 
@@ -15,26 +15,49 @@ class Transform(ABC, Generic[T, K]):
15
15
  r"""An abstract class representing a :class:`Transform`.
16
16
 
17
17
  All transforms that represent a map from `T` to `K`.
18
-
19
- Methods
20
- -------
21
- __call__(x: T) -> K
22
- All subclasses should overwrite :meth:`__call__`, supporting
23
- applying transform in `x`.
24
- __repr__() -> str
25
- Subclasses could also optionally overwrite :meth:`__repr__`.
26
- Avoid using the underscore `_` because it is used by
27
- `Transforms`. If not provied, class name will be a default
28
- value.
29
18
  """
30
19
 
31
20
  @abstractmethod
32
21
  def __call__(self, x: T) -> K:
33
- """Apply transform."""
22
+ """Apply transform.
23
+
24
+ Notes
25
+ -----
26
+ All subclasses should overwrite :meth:`__call__`, supporting
27
+ applying transform in `x`.
28
+ """
34
29
  raise NotImplementedError()
35
30
 
36
31
  def __repr__(self) -> str:
37
- return self.__class__.__name__
32
+ classname = self.__class__.__name__
33
+ repr_ = self.extra_repr()
34
+ return f"{classname}({repr_})"
35
+
36
+ def extra_repr(self):
37
+ """Provides a human-friendly representation of the module.
38
+
39
+ This method extends the basic string representation provided by
40
+ `__repr__` method. It is designed to display additional details
41
+ about the module's parameters or its specific configuration,
42
+ which can be particularly useful for debugging and model
43
+ architecture introspection.
44
+
45
+ Examples
46
+ --------
47
+ class Foo(Transform[T, K]):
48
+ def __init__(self, my_parameter: int = 1):
49
+ self.my_parameter = my_parameter
50
+
51
+ def extra_repr(self):
52
+ return f"my_parameter={self.my_parameter}"
53
+
54
+ Notes
55
+ -----
56
+ This method should be overridden in custom modules to provide
57
+ specific details relevant to the module's functionality and
58
+ configuration.
59
+ """
60
+ return ""
38
61
 
39
62
 
40
63
  class Transforms(Transform[T, K]):
@@ -81,23 +104,20 @@ class Transforms(Transform[T, K]):
81
104
  for transform in self.transforms:
82
105
  x = transform(x)
83
106
 
84
- return cast(K, x)
107
+ return x # type: ignore
85
108
 
86
109
  def __getitem__(self, idx: int) -> Transform[Any, Any]:
87
110
  return self.transforms[idx]
88
111
 
89
- def __repr__(self) -> str:
90
- return "_".join([str(transform) for transform in self])
91
-
92
112
  def __len__(self) -> int:
93
113
  return len(self.transforms)
94
114
 
115
+ def extra_repr(self) -> str:
116
+ return ", ".join([str(transform) for transform in self])
117
+
95
118
 
96
119
  class Identity(Transform[T, T]):
97
120
  """Resurn input as-is."""
98
121
 
99
122
  def __call__(self, x: T) -> T:
100
123
  return x
101
-
102
- def __repr__(self) -> str:
103
- return ""
@@ -48,9 +48,6 @@ class BranchLinearResampler(_BranchResampler):
48
48
  super().__init__()
49
49
  self.n_nodes = n_nodes
50
50
 
51
- def __repr__(self) -> str:
52
- return f"BranchLinearResampler-{self.n_nodes}"
53
-
54
51
  def resample(self, xyzr: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
55
52
  """Resampling by linear interpolation, DO NOT keep original node.
56
53
 
@@ -75,6 +72,9 @@ class BranchLinearResampler(_BranchResampler):
75
72
  r = np.interp(xvals, xp, xyzr[:, 3])
76
73
  return cast(npt.NDArray[np.float32], np.stack([x, y, z, r], axis=1))
77
74
 
75
+ def extra_repr(self):
76
+ return f"n_nodes={self.n_nodes}"
77
+
78
78
 
79
79
  class BranchConvSmoother(Transform[Branch, Branch[DictSWC]]):
80
80
  r"""Smooth the branch by sliding window."""
@@ -100,8 +100,8 @@ class BranchConvSmoother(Transform[Branch, Branch[DictSWC]]):
100
100
 
101
101
  return x
102
102
 
103
- def __repr__(self) -> str:
104
- return f"BranchConvSmoother-{self.n_nodes}"
103
+ def extra_repr(self):
104
+ return f"n_nodes={self.n_nodes}"
105
105
 
106
106
 
107
107
  class BranchStandardizer(Transform[Branch, Branch[DictSWC]]):
@@ -72,6 +72,9 @@ class RadiusReseter(Generic[T], Transform[T, T]):
72
72
  new_tree.ndata[new_tree.names.r] = r
73
73
  return new_tree
74
74
 
75
+ def extra_repr(self):
76
+ return f"r={self.r:.4f}"
77
+
75
78
 
76
79
  class AffineTransform(Generic[T], Transform[T, T]):
77
80
  """Apply affine matrix."""
@@ -85,10 +88,18 @@ class AffineTransform(Generic[T], Transform[T, T]):
85
88
  tm: npt.NDArray[np.float32],
86
89
  center: Center = "origin",
87
90
  *,
88
- fmt: str,
91
+ fmt: Optional[str] = None,
89
92
  names: Optional[SWCNames] = None,
90
93
  ) -> None:
91
- self.tm, self.center, self.fmt = tm, center, fmt
94
+ self.tm, self.center = tm, center
95
+
96
+ if fmt is not None:
97
+ warnings.warn(
98
+ "`fmt` parameter is no longer needed, now use the "
99
+ "extra_repr(), you can directly remove it.",
100
+ DeprecationWarning,
101
+ )
102
+
92
103
  if names is not None:
93
104
  warnings.warn(
94
105
  "`name` parameter is no longer needed, now use the "
@@ -111,9 +122,6 @@ class AffineTransform(Generic[T], Transform[T, T]):
111
122
 
112
123
  return self.apply(x, tm)
113
124
 
114
- def __repr__(self) -> str:
115
- return self.fmt
116
-
117
125
  @staticmethod
118
126
  def apply(x: T, tm: npt.NDArray[np.float32]) -> T:
119
127
  xyzw = x.xyzw().dot(tm.T).T
@@ -130,8 +138,11 @@ class Translate(Generic[T], AffineTransform[T]):
130
138
  """Translate SWC."""
131
139
 
132
140
  def __init__(self, tx: float, ty: float, tz: float, **kwargs) -> None:
133
- fmt = f"Translate-{tx}-{ty}-{tz}"
134
- super().__init__(translate3d(tx, ty, tz), fmt=fmt, **kwargs)
141
+ super().__init__(translate3d(tx, ty, tz), **kwargs)
142
+ self.tx, self.ty, self.tz = tx, ty, tz
143
+
144
+ def extra_repr(self):
145
+ return f"tx={self.tx:.4f}, ty={self.ty:.4f}, tz={self.tz:.4f}"
135
146
 
136
147
  @classmethod
137
148
  def transform(cls, x: T, tx: float, ty: float, tz: float, **kwargs) -> T:
@@ -158,8 +169,7 @@ class Scale(Generic[T], AffineTransform[T]):
158
169
  def __init__(
159
170
  self, sx: float, sy: float, sz: float, center: Center = "root", **kwargs
160
171
  ) -> None:
161
- fmt = f"Scale-{sx}-{sy}-{sz}"
162
- super().__init__(scale3d(sx, sy, sz), center=center, fmt=fmt, **kwargs)
172
+ super().__init__(scale3d(sx, sy, sz), center=center, **kwargs)
163
173
 
164
174
  @classmethod
165
175
  def transform( # pylint: disable=too-many-arguments
@@ -180,6 +190,12 @@ class Rotate(Generic[T], AffineTransform[T]):
180
190
  ) -> None:
181
191
  fmt = f"Rotate-{n[0]}-{n[1]}-{n[2]}-{theta:.4f}"
182
192
  super().__init__(rotate3d(n, theta), center=center, fmt=fmt, **kwargs)
193
+ self.n = n
194
+ self.theta = theta
195
+ self.center = center
196
+
197
+ def extra_repr(self):
198
+ return f"n={self.n}, theta={self.theta:.4f}, center={self.center}" # TODO: imporve format of n
183
199
 
184
200
  @classmethod
185
201
  def transform(
@@ -197,9 +213,11 @@ class RotateX(Generic[T], AffineTransform[T]):
197
213
  """Rotate SWC with x-axis."""
198
214
 
199
215
  def __init__(self, theta: float, center: Center = "root", **kwargs) -> None:
200
- super().__init__(
201
- rotate3d_x(theta), center=center, fmt=f"RotateX-{theta}", **kwargs
202
- )
216
+ super().__init__(rotate3d_x(theta), center=center, **kwargs)
217
+ self.theta = theta
218
+
219
+ def extra_repr(self):
220
+ return f"center={self.center}, theta={self.theta:.4f}"
203
221
 
204
222
  @classmethod
205
223
  def transform(cls, x: T, theta: float, center: Center = "root", **kwargs) -> T:
@@ -210,9 +228,12 @@ class RotateY(Generic[T], AffineTransform[T]):
210
228
  """Rotate SWC with y-axis."""
211
229
 
212
230
  def __init__(self, theta: float, center: Center = "root", **kwargs) -> None:
213
- super().__init__(
214
- rotate3d_y(theta), center=center, fmt=f"RotateX-{theta}", **kwargs
215
- )
231
+ super().__init__(rotate3d_y(theta), center=center, **kwargs)
232
+ self.theta = theta
233
+ self.center = center
234
+
235
+ def extra_repr(self):
236
+ return f"theta={self.theta:.4f}, center={self.center}"
216
237
 
217
238
  @classmethod
218
239
  def transform(cls, x: T, theta: float, center: Center = "root", **kwargs) -> T:
@@ -223,9 +244,12 @@ class RotateZ(Generic[T], AffineTransform[T]):
223
244
  """Rotate SWC with z-axis."""
224
245
 
225
246
  def __init__(self, theta: float, center: Center = "root", **kwargs) -> None:
226
- super().__init__(
227
- rotate3d_z(theta), center=center, fmt=f"RotateX-{theta}", **kwargs
228
- )
247
+ super().__init__(rotate3d_z(theta), center=center, **kwargs)
248
+ self.theta = theta
249
+ self.center = center
250
+
251
+ def extra_repr(self):
252
+ return f"theta={self.theta:.4f}, center={self.center}"
229
253
 
230
254
  @classmethod
231
255
  def transform(cls, x: T, theta: float, center: Center = "root", **kwargs) -> T: