swcgeom 0.18.3__py3-none-any.whl → 0.19.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of swcgeom might be problematic. Click here for more details.

Files changed (62) hide show
  1. swcgeom/analysis/feature_extractor.py +22 -24
  2. swcgeom/analysis/features.py +18 -40
  3. swcgeom/analysis/lmeasure.py +227 -323
  4. swcgeom/analysis/sholl.py +17 -23
  5. swcgeom/analysis/trunk.py +23 -28
  6. swcgeom/analysis/visualization.py +37 -44
  7. swcgeom/analysis/visualization3d.py +16 -25
  8. swcgeom/analysis/volume.py +33 -47
  9. swcgeom/core/__init__.py +1 -6
  10. swcgeom/core/branch.py +10 -17
  11. swcgeom/core/branch_tree.py +3 -2
  12. swcgeom/core/compartment.py +1 -1
  13. swcgeom/core/node.py +3 -6
  14. swcgeom/core/path.py +11 -16
  15. swcgeom/core/population.py +32 -51
  16. swcgeom/core/swc.py +25 -16
  17. swcgeom/core/swc_utils/__init__.py +4 -6
  18. swcgeom/core/swc_utils/assembler.py +5 -12
  19. swcgeom/core/swc_utils/base.py +40 -31
  20. swcgeom/core/swc_utils/checker.py +3 -8
  21. swcgeom/core/swc_utils/io.py +32 -47
  22. swcgeom/core/swc_utils/normalizer.py +17 -23
  23. swcgeom/core/swc_utils/subtree.py +13 -20
  24. swcgeom/core/tree.py +61 -51
  25. swcgeom/core/tree_utils.py +36 -49
  26. swcgeom/core/tree_utils_impl.py +4 -6
  27. swcgeom/images/augmentation.py +23 -39
  28. swcgeom/images/contrast.py +22 -46
  29. swcgeom/images/folder.py +32 -34
  30. swcgeom/images/io.py +108 -126
  31. swcgeom/transforms/base.py +28 -19
  32. swcgeom/transforms/branch.py +31 -41
  33. swcgeom/transforms/branch_tree.py +3 -1
  34. swcgeom/transforms/geometry.py +13 -4
  35. swcgeom/transforms/image_preprocess.py +2 -0
  36. swcgeom/transforms/image_stack.py +40 -35
  37. swcgeom/transforms/images.py +31 -24
  38. swcgeom/transforms/mst.py +27 -40
  39. swcgeom/transforms/neurolucida_asc.py +13 -13
  40. swcgeom/transforms/path.py +4 -0
  41. swcgeom/transforms/population.py +4 -0
  42. swcgeom/transforms/tree.py +16 -11
  43. swcgeom/transforms/tree_assembler.py +37 -54
  44. swcgeom/utils/download.py +7 -14
  45. swcgeom/utils/dsu.py +12 -0
  46. swcgeom/utils/ellipse.py +26 -14
  47. swcgeom/utils/file.py +8 -13
  48. swcgeom/utils/neuromorpho.py +78 -92
  49. swcgeom/utils/numpy_helper.py +15 -12
  50. swcgeom/utils/plotter_2d.py +10 -16
  51. swcgeom/utils/plotter_3d.py +7 -9
  52. swcgeom/utils/renderer.py +16 -8
  53. swcgeom/utils/sdf.py +12 -23
  54. swcgeom/utils/solid_geometry.py +58 -2
  55. swcgeom/utils/transforms.py +164 -100
  56. swcgeom/utils/volumetric_object.py +29 -53
  57. {swcgeom-0.18.3.dist-info → swcgeom-0.19.1.dist-info}/METADATA +6 -5
  58. swcgeom-0.19.1.dist-info/RECORD +67 -0
  59. {swcgeom-0.18.3.dist-info → swcgeom-0.19.1.dist-info}/WHEEL +1 -1
  60. swcgeom-0.18.3.dist-info/RECORD +0 -67
  61. {swcgeom-0.18.3.dist-info → swcgeom-0.19.1.dist-info/licenses}/LICENSE +0 -0
  62. {swcgeom-0.18.3.dist-info → swcgeom-0.19.1.dist-info}/top_level.txt +0 -0
@@ -18,12 +18,19 @@
18
18
  from abc import ABC, abstractmethod
19
19
  from typing import Any, Generic, TypeVar, overload
20
20
 
21
+ from typing_extensions import override
22
+
21
23
  __all__ = ["Transform", "Transforms", "Identity"]
22
24
 
23
- T, K = TypeVar("T"), TypeVar("K")
25
+ T = TypeVar("T")
26
+ K = TypeVar("K")
24
27
 
25
- T1, T2, T3 = TypeVar("T1"), TypeVar("T2"), TypeVar("T3") # pylint: disable=invalid-name
26
- T4, T5, T6 = TypeVar("T4"), TypeVar("T5"), TypeVar("T6") # pylint: disable=invalid-name
28
+ T1 = TypeVar("T1")
29
+ T2 = TypeVar("T2")
30
+ T3 = TypeVar("T3")
31
+ T4 = TypeVar("T4")
32
+ T5 = TypeVar("T5")
33
+ T6 = TypeVar("T6")
27
34
 
28
35
 
29
36
  class Transform(ABC, Generic[T, K]):
@@ -36,9 +43,7 @@ class Transform(ABC, Generic[T, K]):
36
43
  def __call__(self, x: T) -> K:
37
44
  """Apply transform.
38
45
 
39
- Notes
40
- -----
41
- All subclasses should overwrite :meth:`__call__`, supporting
46
+ NOTE: All subclasses should override :meth:`__call__`, supporting
42
47
  applying transform in `x`.
43
48
  """
44
49
  raise NotImplementedError()
@@ -57,18 +62,14 @@ class Transform(ABC, Generic[T, K]):
57
62
  which can be particularly useful for debugging and model
58
63
  architecture introspection.
59
64
 
60
- Examples
61
- --------
62
- class Foo(Transform[T, K]):
63
- def __init__(self, my_parameter: int = 1):
64
- self.my_parameter = my_parameter
65
-
66
- def extra_repr(self) -> str:
67
- return f"my_parameter={self.my_parameter}"
65
+ >>> class Foo(Transform[T, K]):
66
+ ... def __init__(self, my_parameter: int = 1):
67
+ ... self.my_parameter = my_parameter
68
+ ...
69
+ ... def extra_repr(self) -> str:
70
+ ... return f"my_parameter={self.my_parameter}"
68
71
 
69
- Notes
70
- -----
71
- This method should be overridden in custom modules to provide
72
+ NOTE: This method should be overridden in custom modules to provide
72
73
  specific details relevant to the module's functionality and
73
74
  configuration.
74
75
  """
@@ -80,7 +81,7 @@ class Transforms(Transform[T, K]):
80
81
 
81
82
  transforms: list[Transform[Any, Any]]
82
83
 
83
- # fmt:off
84
+ # fmt: off
84
85
  @overload
85
86
  def __init__(self, t1: Transform[T, K], /) -> None: ...
86
87
  @overload
@@ -100,11 +101,16 @@ class Transforms(Transform[T, K]):
100
101
  t3: Transform[T2, T3], t4: Transform[T3, T4],
101
102
  t5: Transform[T4, T5], t6: Transform[T5, K], /) -> None: ...
102
103
  @overload
104
+ def __init__(self, t1: Transform[T, T1], t2: Transform[T1, T2],
105
+ t3: Transform[T2, T3], t4: Transform[T3, T4],
106
+ t5: Transform[T4, T5], t6: Transform[T5, T6],
107
+ t7: Transform[T6, K], /) -> None: ...
108
+ @overload
103
109
  def __init__(self, t1: Transform[T, T1], t2: Transform[T1, T2],
104
110
  t3: Transform[T2, T3], t4: Transform[T3, T4],
105
111
  t5: Transform[T4, T5], t6: Transform[T5, T6],
106
112
  t7: Transform[T6, Any], /, *transforms: Transform[Any, K]) -> None: ...
107
- # fmt:on
113
+ # fmt: on
108
114
  def __init__(self, *transforms: Transform[Any, Any]) -> None:
109
115
  trans = []
110
116
  for t in transforms:
@@ -114,6 +120,7 @@ class Transforms(Transform[T, K]):
114
120
  trans.append(t)
115
121
  self.transforms = trans
116
122
 
123
+ @override
117
124
  def __call__(self, x: T) -> K:
118
125
  """Apply transforms."""
119
126
  for transform in self.transforms:
@@ -127,6 +134,7 @@ class Transforms(Transform[T, K]):
127
134
  def __len__(self) -> int:
128
135
  return len(self.transforms)
129
136
 
137
+ @override
130
138
  def extra_repr(self) -> str:
131
139
  return ", ".join([str(transform) for transform in self])
132
140
 
@@ -134,5 +142,6 @@ class Transforms(Transform[T, K]):
134
142
  class Identity(Transform[T, T]):
135
143
  """Return input as-is."""
136
144
 
145
+ @override
137
146
  def __call__(self, x: T) -> T:
138
147
  return x
@@ -21,6 +21,7 @@ from typing import cast
21
21
  import numpy as np
22
22
  import numpy.typing as npt
23
23
  from scipy import signal
24
+ from typing_extensions import override
24
25
 
25
26
  from swcgeom.core import Branch, DictSWC
26
27
  from swcgeom.transforms.base import Transform
@@ -40,13 +41,13 @@ __all__ = ["BranchLinearResampler", "BranchConvSmoother", "BranchStandardizer"]
40
41
  class _BranchResampler(Transform[Branch, Branch], ABC):
41
42
  r"""Resample branch."""
42
43
 
44
+ @override
43
45
  def __call__(self, x: Branch) -> Branch:
44
46
  xyzr = self.resample(x.xyzr())
45
47
  return Branch.from_xyzr(xyzr)
46
48
 
47
49
  @abstractmethod
48
- def resample(self, xyzr: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
49
- raise NotImplementedError()
50
+ def resample(self, xyzr: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]: ...
50
51
 
51
52
 
52
53
  class BranchLinearResampler(_BranchResampler):
@@ -55,26 +56,21 @@ class BranchLinearResampler(_BranchResampler):
55
56
  def __init__(self, n_nodes: int) -> None:
56
57
  """Resample branch to a specified number of nodes.
57
58
 
58
- Parameters
59
- ----------
60
- n_nodes : int
61
- Number of nodes after resample.
59
+ Args:
60
+ n_nodes: Number of nodes after resample.
62
61
  """
63
62
  super().__init__()
64
63
  self.n_nodes = n_nodes
65
64
 
65
+ @override
66
66
  def resample(self, xyzr: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
67
67
  """Resampling by linear interpolation, DO NOT keep original node.
68
68
 
69
- Parameters
70
- ----------
71
- xyzr : np.ndarray[np.float32]
72
- The array of shape (N, 4).
69
+ Args:
70
+ xyzr: The array of shape (N, 4).
73
71
 
74
- Returns
75
- -------
76
- coordinates : ~numpy.NDArray[float64]
77
- An array of shape (n_nodes, 4).
72
+ Returns:
73
+ coordinates: An array of shape (n_nodes, 4).
78
74
  """
79
75
 
80
76
  xp = np.cumsum(np.linalg.norm(xyzr[1:, :3] - xyzr[:-1, :3], axis=1))
@@ -87,6 +83,7 @@ class BranchLinearResampler(_BranchResampler):
87
83
  r = np.interp(xvals, xp, xyzr[:, 3])
88
84
  return cast(npt.NDArray[np.float32], np.stack([x, y, z, r], axis=1))
89
85
 
86
+ @override
90
87
  def extra_repr(self) -> str:
91
88
  return f"n_nodes={self.n_nodes}"
92
89
 
@@ -97,18 +94,15 @@ class BranchIsometricResampler(_BranchResampler):
97
94
  self.distance = distance
98
95
  self.adjust_last_gap = adjust_last_gap
99
96
 
97
+ @override
100
98
  def resample(self, xyzr: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
101
99
  """Resampling by isometric interpolation, DO NOT keep original node.
102
100
 
103
- Parameters
104
- ----------
105
- xyzr : np.ndarray[np.float32]
106
- The array of shape (N, 4).
101
+ Args:
102
+ xyzr: The array of shape (N, 4).
107
103
 
108
- Returns
109
- -------
110
- new_xyzr : ~numpy.NDArray[float32]
111
- An array of shape (n_nodes, 4).
104
+ Returns:
105
+ new_xyzr: An array of shape (n_nodes, 4).
112
106
  """
113
107
 
114
108
  # Compute the cumulative distances between consecutive points
@@ -138,6 +132,7 @@ class BranchIsometricResampler(_BranchResampler):
138
132
  new_xyzr[:, 3] = np.interp(new_distances, cumulative_distances, xyzr[:, 3])
139
133
  return new_xyzr
140
134
 
135
+ @override
141
136
  def extra_repr(self) -> str:
142
137
  return f"distance={self.distance},adjust_last_gap={self.adjust_last_gap}"
143
138
 
@@ -147,15 +142,14 @@ class BranchConvSmoother(Transform[Branch, Branch[DictSWC]]):
147
142
 
148
143
  def __init__(self, n_nodes: int = 5) -> None:
149
144
  """
150
- Parameters
151
- ----------
152
- n_nodes : int, default `5`
153
- Window size.
145
+ Args:
146
+ n_nodes: Window size.
154
147
  """
155
148
  super().__init__()
156
149
  self.n_nodes = n_nodes
157
150
  self.kernel = np.ones(n_nodes)
158
151
 
152
+ @override
159
153
  def __call__(self, x: Branch) -> Branch[DictSWC]:
160
154
  x = x.detach()
161
155
  c = signal.convolve(np.ones(x.number_of_nodes()), self.kernel, mode="same")
@@ -173,10 +167,11 @@ class BranchConvSmoother(Transform[Branch, Branch[DictSWC]]):
173
167
  class BranchStandardizer(Transform[Branch, Branch[DictSWC]]):
174
168
  r"""Standardize branch.
175
169
 
176
- Standardized branch starts at (0, 0, 0), ends at (1, 0, 0), up at
177
- y, and scale max radius to 1.
170
+ Standardized branch starts at (0, 0, 0), ends at (1, 0, 0), up at y, and scale max
171
+ radius to 1.
178
172
  """
179
173
 
174
+ @override
180
175
  def __call__(self, x: Branch) -> Branch:
181
176
  xyzr = x.xyzr()
182
177
  xyz, r = xyzr[:, 0:3], xyzr[:, 3:4]
@@ -191,23 +186,18 @@ class BranchStandardizer(Transform[Branch, Branch[DictSWC]]):
191
186
  def get_matrix(xyz: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
192
187
  r"""Get standardize transformation matrix.
193
188
 
194
- Standardized branch starts at (0, 0, 0), ends at (1, 0, 0), up
195
- at y.
189
+ Standardized branch starts at (0, 0, 0), ends at (1, 0, 0), up at y.
196
190
 
197
- Parameters
198
- ----------
199
- xyz : np.ndarray[np.float32]
200
- The `x`, `y`, `z` matrix of shape (N, 3) of branch.
191
+ Args:
192
+ xyz: The `x`, `y`, `z` matrix of shape (N, 3) of branch.
201
193
 
202
- Returns
203
- -------
204
- T : np.ndarray[np.float32]
205
- An homogeneous transformation matrix of shape (4, 4).
194
+ Returns:
195
+ T: A homogeneous transformation matrix of shape (4, 4).
206
196
  """
207
197
 
208
- assert (
209
- xyz.ndim == 2 and xyz.shape[1] == 3
210
- ), f"xyz should be of shape (N, 3), got {xyz.shape}"
198
+ assert xyz.ndim == 2 and xyz.shape[1] == 3, (
199
+ f"xyz should be of shape (N, 3), got {xyz.shape}"
200
+ )
211
201
 
212
202
  xyz = xyz[:, :3]
213
203
  T = np.identity(4)
@@ -13,9 +13,10 @@
13
13
  # limitations under the License.
14
14
 
15
15
 
16
- from typing import Iterable
16
+ from collections.abc import Iterable
17
17
 
18
18
  import numpy as np
19
+ from typing_extensions import override
19
20
 
20
21
  from swcgeom.core import Branch, BranchTree, Node, Tree
21
22
  from swcgeom.transforms.base import Transform
@@ -26,6 +27,7 @@ __all__ = ["BranchTreeAssembler"]
26
27
  class BranchTreeAssembler(Transform[BranchTree, Tree]):
27
28
  EPS = 1e-6
28
29
 
30
+ @override
29
31
  def __call__(self, x: BranchTree) -> Tree:
30
32
  nodes = [x.soma().detach()]
31
33
  stack = [(x.soma(), 0)] # n_orig, id_new
@@ -16,10 +16,11 @@
16
16
  """SWC geometry operations."""
17
17
 
18
18
  import warnings
19
- from typing import Generic, Literal, Optional, TypeVar
19
+ from typing import Generic, Literal, TypeVar
20
20
 
21
21
  import numpy as np
22
22
  import numpy.typing as npt
23
+ from typing_extensions import override
23
24
 
24
25
  from swcgeom.core import DictSWC
25
26
  from swcgeom.core.swc_utils import SWCNames
@@ -54,7 +55,7 @@ Center = Literal["root", "soma", "origin"]
54
55
  class Normalizer(Generic[T], Transform[T, T]):
55
56
  """Normalize coordinates and radius to 0-1."""
56
57
 
57
- def __init__(self, *, names: Optional[SWCNames] = None) -> None:
58
+ def __init__(self, *, names: SWCNames | None = None) -> None:
58
59
  super().__init__()
59
60
  if names is not None:
60
61
  warnings.warn(
@@ -63,6 +64,7 @@ class Normalizer(Generic[T], Transform[T, T]):
63
64
  DeprecationWarning,
64
65
  )
65
66
 
67
+ @override
66
68
  def __call__(self, x: T) -> T:
67
69
  """Scale the `x`, `y`, `z`, `r` of nodes to 0-1."""
68
70
  new_tree = x.copy()
@@ -87,6 +89,7 @@ class RadiusReseter(Generic[T], Transform[T, T]):
87
89
  new_tree.ndata[new_tree.names.r] = r
88
90
  return new_tree
89
91
 
92
+ @override
90
93
  def extra_repr(self) -> str:
91
94
  return f"r={self.r:.4f}"
92
95
 
@@ -103,8 +106,8 @@ class AffineTransform(Generic[T], Transform[T, T]):
103
106
  tm: npt.NDArray[np.float32],
104
107
  center: Center = "origin",
105
108
  *,
106
- fmt: Optional[str] = None,
107
- names: Optional[SWCNames] = None,
109
+ fmt: str | None = None,
110
+ names: SWCNames | None = None,
108
111
  ) -> None:
109
112
  self.tm, self.center = tm, center
110
113
 
@@ -122,6 +125,7 @@ class AffineTransform(Generic[T], Transform[T, T]):
122
125
  DeprecationWarning,
123
126
  )
124
127
 
128
+ @override
125
129
  def __call__(self, x: T) -> T:
126
130
  match self.center:
127
131
  case "root" | "soma":
@@ -156,6 +160,7 @@ class Translate(Generic[T], AffineTransform[T]):
156
160
  super().__init__(translate3d(tx, ty, tz), **kwargs)
157
161
  self.tx, self.ty, self.tz = tx, ty, tz
158
162
 
163
+ @override
159
164
  def extra_repr(self) -> str:
160
165
  return f"tx={self.tx:.4f}, ty={self.ty:.4f}, tz={self.tz:.4f}"
161
166
 
@@ -209,6 +214,7 @@ class Rotate(Generic[T], AffineTransform[T]):
209
214
  self.theta = theta
210
215
  self.center = center
211
216
 
217
+ @override
212
218
  def extra_repr(self) -> str:
213
219
  return f"n={self.n}, theta={self.theta:.4f}, center={self.center}" # TODO: improve format of n
214
220
 
@@ -231,6 +237,7 @@ class RotateX(Generic[T], AffineTransform[T]):
231
237
  super().__init__(rotate3d_x(theta), center=center, **kwargs)
232
238
  self.theta = theta
233
239
 
240
+ @override
234
241
  def extra_repr(self) -> str:
235
242
  return f"center={self.center}, theta={self.theta:.4f}"
236
243
 
@@ -247,6 +254,7 @@ class RotateY(Generic[T], AffineTransform[T]):
247
254
  self.theta = theta
248
255
  self.center = center
249
256
 
257
+ @override
250
258
  def extra_repr(self) -> str:
251
259
  return f"theta={self.theta:.4f}, center={self.center}"
252
260
 
@@ -263,6 +271,7 @@ class RotateZ(Generic[T], AffineTransform[T]):
263
271
  self.theta = theta
264
272
  self.center = center
265
273
 
274
+ @override
266
275
  def extra_repr(self) -> str:
267
276
  return f"theta={self.theta:.4f}, center={self.center}"
268
277
 
@@ -19,6 +19,7 @@ import numpy as np
19
19
  import numpy.typing as npt
20
20
  from scipy.fftpack import fftn, fftshift, ifftn
21
21
  from scipy.ndimage import gaussian_filter, minimum_filter
22
+ from typing_extensions import override
22
23
 
23
24
  from swcgeom.transforms.base import Transform
24
25
 
@@ -36,6 +37,7 @@ class SGuoImPreProcess(Transform[npt.NDArray[np.uint8], npt.NDArray[np.uint8]]):
36
37
  January 2022, Pages 503–512, https://doi.org/10.1093/bioinformatics/btab638
37
38
  """
38
39
 
40
+ @override
39
41
  def __call__(self, x: npt.NDArray[np.uint8]) -> npt.NDArray[np.uint8]:
40
42
  # TODO: support np.float32
41
43
  assert x.dtype == np.uint8, "Image must be in uint8 format"
@@ -15,9 +15,7 @@
15
15
 
16
16
  """Create image stack from morphology.
17
17
 
18
- Notes
19
- -----
20
- All denpendencies need to be installed, try:
18
+ NOTE: All dependencies need to be installed, try:
21
19
 
22
20
  ```sh
23
21
  pip install swcgeom[all]
@@ -28,7 +26,7 @@ import os
28
26
  import re
29
27
  import time
30
28
  from collections.abc import Iterable
31
- from typing import Optional
29
+ from typing import Sequence
32
30
 
33
31
  import numpy as np
34
32
  import numpy.typing as npt
@@ -42,7 +40,7 @@ from sdflit import (
42
40
  SDFObject,
43
41
  )
44
42
  from tqdm import tqdm
45
- from typing_extensions import deprecated
43
+ from typing_extensions import deprecated, override
46
44
 
47
45
  from swcgeom.core import Population, Tree
48
46
  from swcgeom.transforms.base import Transform
@@ -58,46 +56,53 @@ class ToImageStack(Transform[Tree, npt.NDArray[np.uint8]]):
58
56
  def __init__(self, resolution: int | float | npt.ArrayLike = 1) -> None:
59
57
  """Transform tree to image stack.
60
58
 
61
- Parameters
62
- ----------
63
- resolution : int | (x, y, z), default `(1, 1, 1)`
64
- Resolution of image stack.
59
+ Args:
60
+ resolution: Resolution of image stack.
61
+ If a scalar, it is broadcast to a 3D vector.
65
62
  """
66
-
67
63
  if isinstance(resolution, (int, float, np.integer, np.floating)):
68
64
  resolution = [resolution, resolution, resolution] # type: ignore
69
65
 
70
66
  self.resolution = np.array(resolution, dtype=np.float32)
71
- assert len(self.resolution) == 3, "resolution shoule be vector of 3d."
67
+ assert len(self.resolution) == 3, "resolution should be vector of 3d."
72
68
 
69
+ @override
73
70
  def __call__(self, x: Tree) -> npt.NDArray[np.uint8]:
74
71
  """Transform tree to image stack.
75
72
 
76
- Notes
77
- -----
78
- This method loads the entire image stack into memory, so it
79
- ONLY works for small image stacks, use :meth`transform_and_save`
80
- for big image stack.
73
+ NOTE: This method loads the entire image stack into memory, so it ONLY works
74
+ for small image stacks, use :meth:`transform_and_save` for big image stacks.
81
75
  """
82
76
  return np.stack(list(self.transform(x, verbose=False)), axis=0)
83
77
 
84
78
  def transform(
85
79
  self,
86
- x: Tree,
80
+ x: Tree | Sequence[Tree],
87
81
  verbose: bool = True,
88
82
  *,
89
- ranges: Optional[tuple[npt.ArrayLike, npt.ArrayLike]] = None,
83
+ ranges: tuple[npt.ArrayLike, npt.ArrayLike] | None = None,
90
84
  ) -> Iterable[npt.NDArray[np.uint8]]:
85
+ trees = [x] if isinstance(x, Tree) else x
86
+ if not trees:
87
+ return iter([]) # Return empty iterator if sequence is empty
88
+
89
+ time_start = None
91
90
  if verbose:
92
- print("To image stack: " + x.source)
91
+ sources = ", ".join(t.source for t in trees if t.source)
92
+ print(f"To image stack: {sources if sources else 'unnamed trees'}")
93
93
  time_start = time.time()
94
94
 
95
- scene = self._get_scene(x)
95
+ scene = self._get_scene(trees)
96
96
 
97
97
  if ranges is None:
98
- xyz, r = x.xyz(), x.r().reshape(-1, 1)
99
- coord_min = np.floor(np.min(xyz - r, axis=0))
100
- coord_max = np.ceil(np.max(xyz + r, axis=0))
98
+ all_xyz = np.concatenate([t.xyz() for t in trees], axis=0)
99
+ all_r = np.concatenate([t.r() for t in trees], axis=0).reshape(-1, 1)
100
+ if all_xyz.size == 0: # Handle empty trees
101
+ coord_min = np.zeros(3, dtype=np.float32)
102
+ coord_max = np.zeros(3, dtype=np.float32)
103
+ else:
104
+ coord_min = np.floor(np.min(all_xyz - all_r, axis=0))
105
+ coord_max = np.ceil(np.max(all_xyz + all_r, axis=0))
101
106
  else:
102
107
  assert len(ranges) == 2
103
108
  coord_min = np.array(ranges[0])
@@ -106,12 +111,12 @@ class ToImageStack(Transform[Tree, npt.NDArray[np.uint8]]):
106
111
 
107
112
  samplers = self._get_samplers(coord_min, coord_max)
108
113
 
109
- if verbose:
114
+ if verbose and time_start is not None:
110
115
  total = (coord_max[2] - coord_min[2]) / self.resolution[2]
111
116
  samplers = tqdm(samplers, total=total.astype(np.int64).item())
112
117
 
113
118
  time_end = time.time()
114
- print("Prepare in: ", time_end - time_start, "s") # type: ignore
119
+ print("Prepare in: ", time_end - time_start, "s")
115
120
 
116
121
  for sampler in samplers:
117
122
  voxel = sampler.sample(scene) # should be shape of (x, y, z, 3) and z = 1
@@ -124,12 +129,12 @@ class ToImageStack(Transform[Tree, npt.NDArray[np.uint8]]):
124
129
  x: Tree,
125
130
  verbose: bool = True,
126
131
  *,
127
- ranges: Optional[tuple[npt.ArrayLike, npt.ArrayLike]] = None,
132
+ ranges: tuple[npt.ArrayLike, npt.ArrayLike] | None = None,
128
133
  ) -> Iterable[npt.NDArray[np.uint8]]:
129
134
  return self.transform(x, verbose, ranges=ranges)
130
135
 
131
136
  def transform_and_save(
132
- self, fname: str, x: Tree, verbose: bool = True, **kwargs
137
+ self, fname: str, x: Tree | Sequence[Tree], verbose: bool = True, **kwargs
133
138
  ) -> None:
134
139
  self.save_tif(fname, self.transform(x, verbose=verbose, **kwargs))
135
140
 
@@ -151,11 +156,12 @@ class ToImageStack(Transform[Tree, npt.NDArray[np.uint8]]):
151
156
  if not os.path.isfile(tif):
152
157
  self.transform_and_save(tif, tree, verbose=False)
153
158
 
159
+ @override
154
160
  def extra_repr(self) -> str:
155
161
  res = ",".join(f"{a:.4f}" for a in self.resolution)
156
162
  return f"resolution=({res})"
157
163
 
158
- def _get_scene(self, x: Tree) -> Scene:
164
+ def _get_scene(self, trees: Sequence[Tree]) -> Scene:
159
165
  material = ColoredMaterial((1, 0, 0)).into()
160
166
  scene = ObjectsScene()
161
167
  scene.set_background((0, 0, 0))
@@ -164,10 +170,11 @@ class ToImageStack(Transform[Tree, npt.NDArray[np.uint8]]):
164
170
  for c in children:
165
171
  sdf = RoundCone(_tp3f(n.xyz()), _tp3f(c.xyz()), n.r, c.r).into()
166
172
  scene.add_object(SDFObject(sdf, material).into())
167
-
168
173
  return n
169
174
 
170
- x.traverse(leave=leave)
175
+ for tree in trees:
176
+ tree.traverse(leave=leave)
177
+
171
178
  scene.build_bvh()
172
179
  return scene.into()
173
180
 
@@ -175,14 +182,12 @@ class ToImageStack(Transform[Tree, npt.NDArray[np.uint8]]):
175
182
  self,
176
183
  coord_min: npt.NDArray,
177
184
  coord_max: npt.NDArray,
178
- offset: Optional[npt.NDArray] = None,
185
+ offset: npt.NDArray | None = None,
179
186
  ) -> Iterable[RangeSampler]:
180
187
  """Get Samplers.
181
188
 
182
- Parameters
183
- ----------
184
- coord_min, coord_max: npt.ArrayLike
185
- Coordinates array of shape (3,).
189
+ Args:
190
+ coord_min, coord_max: Coordinates array of shape (3,).
186
191
  """
187
192
 
188
193
  eps = 1e-6