swcgeom 0.14.0__py3-none-any.whl → 0.16.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of swcgeom might be problematic. (The original page linked to further details here; that link was not preserved.)

Files changed (45)
  1. swcgeom/_version.py +2 -2
  2. swcgeom/analysis/lmeasure.py +821 -0
  3. swcgeom/analysis/sholl.py +31 -2
  4. swcgeom/core/__init__.py +4 -0
  5. swcgeom/core/branch.py +9 -4
  6. swcgeom/core/branch_tree.py +2 -3
  7. swcgeom/core/{segment.py → compartment.py} +14 -9
  8. swcgeom/core/node.py +0 -8
  9. swcgeom/core/path.py +21 -6
  10. swcgeom/core/population.py +42 -3
  11. swcgeom/core/swc_utils/assembler.py +20 -138
  12. swcgeom/core/swc_utils/base.py +12 -5
  13. swcgeom/core/swc_utils/checker.py +12 -2
  14. swcgeom/core/swc_utils/subtree.py +2 -2
  15. swcgeom/core/tree.py +53 -49
  16. swcgeom/core/tree_utils.py +27 -5
  17. swcgeom/core/tree_utils_impl.py +22 -6
  18. swcgeom/images/augmentation.py +6 -1
  19. swcgeom/images/contrast.py +107 -0
  20. swcgeom/images/folder.py +111 -29
  21. swcgeom/images/io.py +79 -40
  22. swcgeom/transforms/__init__.py +2 -0
  23. swcgeom/transforms/base.py +41 -21
  24. swcgeom/transforms/branch.py +5 -5
  25. swcgeom/transforms/geometry.py +42 -18
  26. swcgeom/transforms/image_preprocess.py +100 -0
  27. swcgeom/transforms/image_stack.py +46 -28
  28. swcgeom/transforms/images.py +76 -6
  29. swcgeom/transforms/mst.py +10 -18
  30. swcgeom/transforms/neurolucida_asc.py +495 -0
  31. swcgeom/transforms/population.py +2 -2
  32. swcgeom/transforms/tree.py +12 -14
  33. swcgeom/transforms/tree_assembler.py +85 -19
  34. swcgeom/utils/__init__.py +1 -0
  35. swcgeom/utils/neuromorpho.py +425 -300
  36. swcgeom/utils/numpy_helper.py +14 -4
  37. swcgeom/utils/plotter_2d.py +130 -0
  38. swcgeom/utils/renderer.py +28 -139
  39. swcgeom/utils/sdf.py +5 -1
  40. {swcgeom-0.14.0.dist-info → swcgeom-0.16.0.dist-info}/METADATA +3 -3
  41. swcgeom-0.16.0.dist-info/RECORD +67 -0
  42. {swcgeom-0.14.0.dist-info → swcgeom-0.16.0.dist-info}/WHEEL +1 -1
  43. swcgeom-0.14.0.dist-info/RECORD +0 -62
  44. {swcgeom-0.14.0.dist-info → swcgeom-0.16.0.dist-info}/LICENSE +0 -0
  45. {swcgeom-0.14.0.dist-info → swcgeom-0.16.0.dist-info}/top_level.txt +0 -0
swcgeom/core/tree.py CHANGED
@@ -20,15 +20,14 @@ from typing import (
20
20
  import numpy as np
21
21
  import numpy.typing as npt
22
22
  import pandas as pd
23
- from typing_extensions import Self
24
23
 
25
24
  from swcgeom.core.branch import Branch
25
+ from swcgeom.core.compartment import Compartment, Compartments
26
26
  from swcgeom.core.node import Node
27
27
  from swcgeom.core.path import Path
28
- from swcgeom.core.segment import Segment, Segments
29
28
  from swcgeom.core.swc import DictSWC, eswc_cols
30
29
  from swcgeom.core.swc_utils import SWCNames, get_names, read_swc, traverse
31
- from swcgeom.core.tree_utils_impl import get_subtree_impl
30
+ from swcgeom.core.tree_utils_impl import Mapping, get_subtree_impl
32
31
  from swcgeom.utils import PathOrIO, padding1d
33
32
 
34
33
  __all__ = ["Tree"]
@@ -49,15 +48,6 @@ class Tree(DictSWC):
49
48
  children = self.attach.id()[self.attach.pid() == self.id]
50
49
  return [Tree.Node(self.attach, idx) for idx in children]
51
50
 
52
- def get_branch(self) -> "Tree.Branch":
53
- warnings.warn(
54
- "`Tree.Node.get_branch` has been renamed to "
55
- "`Tree.Node.branch` since v0.3.1 and will be removed "
56
- "in next version",
57
- DeprecationWarning,
58
- )
59
- return self.branch()
60
-
61
51
  def branch(self) -> "Tree.Branch":
62
52
  ns: List["Tree.Node"] = [self]
63
53
  while not ns[-1].is_bifurcation() and (p := ns[-1].parent()) is not None:
@@ -73,9 +63,18 @@ class Tree(DictSWC):
73
63
  """The end-to-end straight-line distance to soma."""
74
64
  return self.distance(self.attach.soma())
75
65
 
76
- def subtree(self) -> "Tree":
77
- """Get subtree from node."""
78
- n_nodes, ndata, source, names = get_subtree_impl(self.attach, self.id)
66
+ def subtree(self, *, out_mapping: Optional[Mapping] = None) -> "Tree":
67
+ """Get subtree from node.
68
+
69
+ Parameters
70
+ ----------
71
+ out_mapping : List of int or Dict[int, int], optional
72
+ Map from new id to old id.
73
+ """
74
+
75
+ n_nodes, ndata, source, names = get_subtree_impl(
76
+ self.attach, self.id, out_mapping=out_mapping
77
+ )
79
78
  return Tree(n_nodes, **ndata, source=source, names=names)
80
79
 
81
80
  def is_root(self) -> bool:
@@ -107,9 +106,11 @@ class Tree(DictSWC):
107
106
  # TODO: should returns `Tree.Node`
108
107
  """Neural path."""
109
108
 
110
- class Segment(Segment["Tree"]):
109
+ class Compartment(Compartment["Tree"]):
111
110
  # TODO: should returns `Tree.Node`
112
- """Neural segment."""
111
+ """Neural compartment."""
112
+
113
+ Segment = Compartment # Alias
113
114
 
114
115
  class Branch(Branch["Tree"]):
115
116
  # TODO: should returns `Tree.Node`
@@ -119,33 +120,33 @@ class Tree(DictSWC):
119
120
  self,
120
121
  n_nodes: int,
121
122
  *,
122
- # pylint: disable-next=redefined-builtin
123
- id: Optional[npt.NDArray[np.int32]] = None,
124
- # pylint: disable-next=redefined-builtin
125
- type: Optional[npt.NDArray[np.int32]] = None,
126
- x: Optional[npt.NDArray[np.float32]] = None,
127
- y: Optional[npt.NDArray[np.float32]] = None,
128
- z: Optional[npt.NDArray[np.float32]] = None,
129
- r: Optional[npt.NDArray[np.float32]] = None,
130
- pid: Optional[npt.NDArray[np.int32]] = None,
131
123
  source: str = "",
132
124
  comments: Optional[Iterable[str]] = None,
133
125
  names: Optional[SWCNames] = None,
134
126
  **kwargs: npt.NDArray,
135
127
  ) -> None:
136
128
  names = get_names(names)
137
- id = np.arange(0, n_nodes, step=1, dtype=np.int32) if id is None else id
138
- pid = np.arange(-1, n_nodes - 1, step=1, dtype=np.int32) if pid is None else pid
129
+
130
+ if names.id not in kwargs:
131
+ kwargs[names.id] = np.arange(0, n_nodes, step=1, dtype=np.int32)
132
+
133
+ if names.pid not in kwargs:
134
+ kwargs[names.pid] = np.arange(-1, n_nodes - 1, step=1, dtype=np.int32)
139
135
 
140
136
  ndata = {
141
- names.id: padding1d(n_nodes, id, dtype=np.int32),
142
- names.type: padding1d(n_nodes, type, dtype=np.int32),
143
- names.x: padding1d(n_nodes, x),
144
- names.y: padding1d(n_nodes, y),
145
- names.z: padding1d(n_nodes, z),
146
- names.r: padding1d(n_nodes, r, padding_value=1),
147
- names.pid: padding1d(n_nodes, pid, dtype=np.int32),
137
+ names.id: padding1d(n_nodes, kwargs.pop(names.id, None), dtype=np.int32),
138
+ names.type: padding1d(
139
+ n_nodes, kwargs.pop(names.type, None), dtype=np.int32
140
+ ),
141
+ names.x: padding1d(n_nodes, kwargs.pop(names.x, None), dtype=np.float32),
142
+ names.y: padding1d(n_nodes, kwargs.pop(names.y, None), dtype=np.float32),
143
+ names.z: padding1d(n_nodes, kwargs.pop(names.z, None), dtype=np.float32),
144
+ names.r: padding1d(
145
+ n_nodes, kwargs.pop(names.r, None), dtype=np.float32, padding_value=1
146
+ ),
147
+ names.pid: padding1d(n_nodes, kwargs.pop(names.pid, None), dtype=np.int32),
148
148
  }
149
+ # ? padding other columns
149
150
  super().__init__(
150
151
  **ndata, **kwargs, source=source, comments=comments, names=names
151
152
  )
@@ -214,13 +215,16 @@ class Tree(DictSWC):
214
215
  tip_ids = np.setdiff1d(self.id(), self.pid(), assume_unique=True)
215
216
  return [self.node(i) for i in tip_ids]
216
217
 
217
- def get_segments(self) -> Segments[Segment]:
218
- return Segments(self.Segment(self, n.pid, n.id) for n in self[1:])
218
+ def get_compartments(self) -> Compartments[Compartment]:
219
+ return Compartments(self.Compartment(self, n.pid, n.id) for n in self[1:])
219
220
 
220
- def get_branches(self) -> List[Branch]:
221
- Info = Tuple[List[Tree.Branch], List[int]]
221
+ def get_segments(self) -> Compartments[Compartment]: # Alias
222
+ return self.get_compartments()
222
223
 
223
- def collect_branches(node: "Tree.Node", pre: List[Info]) -> Info:
224
+ def get_branches(self) -> List[Branch]:
225
+ def collect_branches(
226
+ node: "Tree.Node", pre: List[Tuple[List[Tree.Branch], List[int]]]
227
+ ) -> Tuple[List[Tree.Branch], List[int]]:
224
228
  if len(pre) == 1:
225
229
  branches, child = pre[0]
226
230
  child.append(node.id)
@@ -243,7 +247,6 @@ class Tree(DictSWC):
243
247
  def get_paths(self) -> List[Path]:
244
248
  """Get all path from soma to tips."""
245
249
  path_dic: Dict[int, List[int]] = {}
246
- Paths = List[List[int]]
247
250
 
248
251
  def assign_path(n: Tree.Node, pre_path: List[int] | None) -> List[int]:
249
252
  path = [] if pre_path is None else pre_path.copy()
@@ -251,7 +254,9 @@ class Tree(DictSWC):
251
254
  path_dic[n.id] = path
252
255
  return path
253
256
 
254
- def collect_path(n: Tree.Node, children: List[Paths]) -> Paths:
257
+ def collect_path(
258
+ n: Tree.Node, children: List[List[List[int]]]
259
+ ) -> List[List[int]]:
255
260
  if len(children) == 0:
256
261
  return [path_dic[n.id]]
257
262
 
@@ -260,11 +265,11 @@ class Tree(DictSWC):
260
265
  paths = self.traverse(enter=assign_path, leave=collect_path)
261
266
  return [self.Path(self, idx) for idx in paths]
262
267
 
263
- def get_neurites(self, type_check: bool = True) -> Iterable[Self]:
268
+ def get_neurites(self, type_check: bool = True) -> Iterable["Tree"]:
264
269
  """Get neurites from soma."""
265
270
  return (n.subtree() for n in self.soma(type_check).children())
266
271
 
267
- def get_dendrites(self, type_check: bool = True) -> Iterable[Self]:
272
+ def get_dendrites(self, type_check: bool = True) -> Iterable["Tree"]:
268
273
  """Get dendrites."""
269
274
  types = [self.types.apical_dendrite, self.types.basal_dendrite]
270
275
  children = self.soma(type_check).children()
@@ -312,15 +317,14 @@ class Tree(DictSWC):
312
317
  """Get length of tree."""
313
318
  return sum(s.length() for s in self.get_segments())
314
319
 
315
- @classmethod
320
+ @staticmethod
316
321
  def from_data_frame(
317
- cls,
318
322
  df: pd.DataFrame,
319
323
  source: str = "",
320
324
  *,
321
325
  comments: Optional[Iterable[str]] = None,
322
326
  names: Optional[SWCNames] = None,
323
- ) -> Self:
327
+ ) -> "Tree":
324
328
  """Read neuron tree from data frame."""
325
329
  names = get_names(names)
326
330
  tree = Tree(
@@ -333,7 +337,7 @@ class Tree(DictSWC):
333
337
  return tree
334
338
 
335
339
  @classmethod
336
- def from_swc(cls, swc_file: PathOrIO, **kwargs) -> Self:
340
+ def from_swc(cls, swc_file: PathOrIO, **kwargs) -> "Tree":
337
341
  """Read neuron tree from swc file.
338
342
 
339
343
  See Also
@@ -352,7 +356,7 @@ class Tree(DictSWC):
352
356
  @classmethod
353
357
  def from_eswc(
354
358
  cls, swc_file: str, extra_cols: Optional[List[str]] = None, **kwargs
355
- ) -> Self:
359
+ ) -> "Tree":
356
360
  """Read neuron tree from eswc file.
357
361
 
358
362
  See Also
swcgeom/core/tree_utils.py CHANGED
@@ -17,7 +17,7 @@ from swcgeom.core.swc_utils import (
17
17
  to_sub_topology,
18
18
  )
19
19
  from swcgeom.core.tree import Tree
20
- from swcgeom.core.tree_utils_impl import get_subtree_impl, to_subtree_impl
20
+ from swcgeom.core.tree_utils_impl import Mapping, get_subtree_impl, to_subtree_impl
21
21
 
22
22
  __all__ = [
23
23
  "sort_tree",
@@ -101,11 +101,15 @@ def to_sub_tree(swc_like: SWCLike, sub: Topology) -> Tuple[Tree, Dict[int, int]]
101
101
  but if the node you remove is not a leaf node, you need to use
102
102
  `propagate_remove` to remove all children.
103
103
 
104
+ .. deprecated:: 0.6.0
105
+ Use :meth:`to_subtree` instead.
106
+
104
107
  Returns
105
108
  -------
106
109
  tree : Tree
107
110
  id_map : Dict[int, int]
108
111
  """
112
+
109
113
  warnings.warn(
110
114
  "`to_sub_tree` will be removed in v0.6.0, it is replaced by "
111
115
  "`to_subtree` beacuse it is easy to use, and this will be "
@@ -128,7 +132,12 @@ def to_sub_tree(swc_like: SWCLike, sub: Topology) -> Tuple[Tree, Dict[int, int]]
128
132
  return subtree, id_map
129
133
 
130
134
 
131
- def to_subtree(swc_like: SWCLike, removals: Iterable[int]) -> Tree:
135
+ def to_subtree(
136
+ swc_like: SWCLike,
137
+ removals: Iterable[int],
138
+ *,
139
+ out_mapping: Optional[Mapping] = None,
140
+ ) -> Tree:
132
141
  """Create subtree from origin tree.
133
142
 
134
143
  Parameters
@@ -136,17 +145,24 @@ def to_subtree(swc_like: SWCLike, removals: Iterable[int]) -> Tree:
136
145
  swc_like : SWCLike
137
146
  removals : List of int
138
147
  A list of id of nodes to be removed.
148
+ out_mapping: List of int or Dict[int, int], optional
149
+ Map new id to old id.
139
150
  """
151
+
140
152
  new_ids = swc_like.id().copy()
141
153
  for i in removals:
142
154
  new_ids[i] = REMOVAL
143
155
 
144
156
  sub = propagate_removal((new_ids, swc_like.pid()))
145
- n_nodes, ndata, source, names = to_subtree_impl(swc_like, sub)
157
+ n_nodes, ndata, source, names = to_subtree_impl(
158
+ swc_like, sub, out_mapping=out_mapping
159
+ )
146
160
  return Tree(n_nodes, **ndata, source=source, names=names)
147
161
 
148
162
 
149
- def get_subtree(swc_like: SWCLike, n: int) -> Tree:
163
+ def get_subtree(
164
+ swc_like: SWCLike, n: int, *, out_mapping: Optional[Mapping] = None
165
+ ) -> Tree:
150
166
  """Get subtree rooted at n.
151
167
 
152
168
  Parameters
@@ -154,8 +170,13 @@ def get_subtree(swc_like: SWCLike, n: int) -> Tree:
154
170
  swc_like : SWCLike
155
171
  n : int
156
172
  Id of the root of the subtree.
173
+ out_mapping: List of int or Dict[int, int], optional
174
+ Map new id to old id.
157
175
  """
158
- n_nodes, ndata, source, names = get_subtree_impl(swc_like, n)
176
+
177
+ n_nodes, ndata, source, names = get_subtree_impl(
178
+ swc_like, n, out_mapping=out_mapping
179
+ )
159
180
  return Tree(n_nodes, **ndata, source=source, names=names)
160
181
 
161
182
 
@@ -171,6 +192,7 @@ def redirect_tree(tree: Tree, new_root: int, sort: bool = True) -> Tree:
171
192
  sort : bool, default `True`
172
193
  If true, sort indices of nodes after redirect.
173
194
  """
195
+
174
196
  tree = tree.copy()
175
197
  path = [tree.node(new_root)]
176
198
  while (p := path[-1].parent()) is not None:
swcgeom/core/tree_utils_impl.py CHANGED
@@ -5,7 +5,7 @@ Notes
5
5
  Do not import `Tree` and keep this file minimized.
6
6
  """
7
7
 
8
- from typing import Any, Dict, Tuple
8
+ from typing import Any, Dict, List, Optional, Tuple
9
9
 
10
10
  import numpy as np
11
11
  import numpy.typing as npt
@@ -15,10 +15,13 @@ from swcgeom.core.swc_utils import Topology, to_sub_topology, traverse
15
15
 
16
16
  __all__ = ["get_subtree_impl", "to_subtree_impl"]
17
17
 
18
+ Mapping = Dict[int, int] | List[int]
18
19
  TreeArgs = Tuple[int, Dict[str, npt.NDArray[Any]], str, SWCNames]
19
20
 
20
21
 
21
- def get_subtree_impl(swc_like: SWCLike, n: int) -> TreeArgs:
22
+ def get_subtree_impl(
23
+ swc_like: SWCLike, n: int, *, out_mapping: Optional[Mapping] = None
24
+ ) -> TreeArgs:
22
25
  ids = []
23
26
  topo = (swc_like.id(), swc_like.pid())
24
27
  traverse(topo, enter=lambda n, _: ids.append(n), root=n)
@@ -26,14 +29,27 @@ def get_subtree_impl(swc_like: SWCLike, n: int) -> TreeArgs:
26
29
  sub_ids = np.array(ids, dtype=np.int32)
27
30
  sub_pid = swc_like.pid()[sub_ids]
28
31
  sub_pid[0] = -1
29
- return to_subtree_impl(swc_like, (sub_ids, sub_pid))
32
+ return to_subtree_impl(swc_like, (sub_ids, sub_pid), out_mapping=out_mapping)
30
33
 
31
34
 
32
- def to_subtree_impl(swc_like: SWCLike, sub: Topology) -> TreeArgs:
33
- (new_id, new_pid), id_map = to_sub_topology(sub)
35
+ def to_subtree_impl(
36
+ swc_like: SWCLike,
37
+ sub: Topology,
38
+ *,
39
+ out_mapping: Optional[Mapping] = None,
40
+ ) -> TreeArgs:
41
+ (new_id, new_pid), mapping = to_sub_topology(sub)
34
42
 
35
43
  n_nodes = new_id.shape[0]
36
- ndata = {k: swc_like.get_ndata(k)[id_map].copy() for k in swc_like.keys()}
44
+ ndata = {k: swc_like.get_ndata(k)[mapping].copy() for k in swc_like.keys()}
37
45
  ndata.update(id=new_id, pid=new_pid)
38
46
 
47
+ if isinstance(out_mapping, list):
48
+ out_mapping.clear()
49
+ out_mapping.extend(mapping)
50
+ elif isinstance(out_mapping, dict):
51
+ out_mapping.clear()
52
+ for new_id, old_id in enumerate(mapping):
53
+ out_mapping[new_id] = old_id # returning a dict may leads to bad perf
54
+
39
55
  return n_nodes, ndata, swc_like.source, swc_like.names
swcgeom/images/augmentation.py CHANGED
@@ -1,4 +1,9 @@
1
- """Play augment in image stack."""
1
+ """Play augment in image stack.
2
+
3
+ Notes
4
+ -----
5
+ This is expremental code, and the API is subject to change.
6
+ """
2
7
 
3
8
  import random
4
9
  from typing import List, Literal, Optional
swcgeom/images/contrast.py ADDED
@@ -0,0 +1,107 @@
1
+ """The contrast of an image.
2
+
3
+ Notes
4
+ -----
5
+ This is expremental code, and the API is subject to change.
6
+ """
7
+
8
+ from typing import Optional, overload
9
+
10
+ import numpy as np
11
+ import numpy.typing as npt
12
+
13
+ __all__ = ["contrast_std", "contrast_michelson", "contrast_rms", "contrast_weber"]
14
+
15
+ Array3D = npt.NDArray[np.float32]
16
+
17
+
18
+ @overload
19
+ def contrast_std(image: Array3D) -> float:
20
+ """Get the std contrast of an image stack.
21
+
22
+ Parameters
23
+ ----------
24
+ imgs : ndarray
25
+
26
+ Returns
27
+ -------
28
+ contrast : float
29
+ """
30
+ ...
31
+
32
+
33
+ @overload
34
+ def contrast_std(image: Array3D, contrast: float) -> Array3D:
35
+ """Adjust the contrast of an image stack.
36
+
37
+ Parameters
38
+ ----------
39
+ imgs : ndarray
40
+ constrast : float
41
+ The contrast adjustment factor. 1.0 leaves the image unchanged.
42
+
43
+ Returns
44
+ -------
45
+ imgs : ndarray
46
+ The adjusted image.
47
+ """
48
+ ...
49
+
50
+
51
+ def contrast_std(image: Array3D, contrast: Optional[float] = None):
52
+ if contrast is None:
53
+ return np.std(image).item()
54
+ else:
55
+ return np.clip(contrast * image, 0, 1)
56
+
57
+
58
+ def contrast_michelson(image: Array3D) -> float:
59
+ """Get the Michelson contrast of an image stack.
60
+
61
+ Parameters
62
+ ----------
63
+ imgs : ndarray
64
+
65
+ Returns
66
+ -------
67
+ contrast : float
68
+ """
69
+
70
+ vmax = np.max(image)
71
+ vmin = np.min(image)
72
+ return ((vmax - vmin) / (vmax + vmin)).item()
73
+
74
+
75
+ def contrast_rms(imgs: npt.NDArray[np.float32]) -> float:
76
+ """Get the RMS contrast of an image stack.
77
+
78
+ Parameters
79
+ ----------
80
+ imgs : ndarray
81
+
82
+ Returns
83
+ -------
84
+ contrast : float
85
+ """
86
+
87
+ return np.sqrt(np.mean(imgs**2)).item()
88
+
89
+
90
+ def contrast_weber(imgs: Array3D, mask: npt.NDArray[np.bool_]) -> float:
91
+ """Get the Weber contrast of an image stack.
92
+
93
+ Parameters
94
+ ----------
95
+ imgs : ndarray
96
+ mask : ndarray of bool
97
+ The mask to segment the foreground and background. 1 for
98
+ foreground, 0 for background.
99
+
100
+ Returns
101
+ -------
102
+ contrast : float
103
+ """
104
+
105
+ l_foreground = np.mean(imgs, where=mask)
106
+ l_background = np.mean(imgs, where=np.logical_not(mask))
107
+ return ((l_foreground - l_background) / l_background).item()
swcgeom/images/folder.py CHANGED
@@ -1,57 +1,64 @@
1
1
  """Image stack folder."""
2
2
 
3
+ import math
3
4
  import os
4
5
  import re
5
- from abc import ABC, abstractmethod
6
- from typing import Callable, Generic, Iterable, List, Literal, Optional, Tuple, TypeVar
6
+ import warnings
7
+ from dataclasses import dataclass
8
+ from typing import (
9
+ Callable,
10
+ Generic,
11
+ Iterable,
12
+ List,
13
+ Literal,
14
+ Optional,
15
+ Tuple,
16
+ TypeVar,
17
+ overload,
18
+ )
7
19
 
8
20
  import numpy as np
9
21
  import numpy.typing as npt
22
+ from tqdm import tqdm
10
23
  from typing_extensions import Self
11
24
 
12
- from swcgeom.images.io import read_imgs
25
+ from swcgeom.images.io import ScalarType, read_imgs
13
26
  from swcgeom.transforms import Identity, Transform
14
27
 
15
- __all__ = [
16
- "ImageStackFolder",
17
- "LabeledImageStackFolder",
18
- "PathImageStackFolder",
19
- ]
28
+ __all__ = ["ImageStackFolder", "LabeledImageStackFolder", "PathImageStackFolder"]
20
29
 
21
30
  T = TypeVar("T")
22
31
 
23
32
 
24
- class ImageStackFolderBase(Generic[T], ABC):
33
+ class ImageStackFolderBase(Generic[ScalarType, T]):
25
34
  """Image stack folder base."""
26
35
 
27
36
  files: List[str]
28
- transform: Transform[npt.NDArray[np.float32], T]
37
+ transform: Transform[npt.NDArray[ScalarType], T]
29
38
 
30
- def __init__(
31
- self,
32
- files: Iterable[str],
33
- *,
34
- transform: Optional[Transform[npt.NDArray[np.float32], T]] = None,
35
- ) -> None:
39
+ # fmt: off
40
+ @overload
41
+ def __init__(self, files: Iterable[str], *, dtype: None = ..., transform: Optional[Transform[npt.NDArray[np.float32], T]] = None) -> None: ...
42
+ @overload
43
+ def __init__(self, files: Iterable[str], *, dtype: ScalarType, transform: Optional[Transform[npt.NDArray[ScalarType], T]] = None) -> None: ...
44
+ # fmt: on
45
+
46
+ def __init__(self, files: Iterable[str], *, dtype=None, transform=None) -> None:
36
47
  super().__init__()
37
48
  self.files = list(files)
49
+ self.dtype = dtype or np.float32
38
50
  self.transform = transform or Identity() # type: ignore
39
51
 
40
- @abstractmethod
41
- def __getitem__(self, key: str, /) -> T:
42
- raise NotImplementedError()
43
-
44
52
  def __len__(self) -> int:
45
53
  return len(self.files)
46
54
 
47
55
  def _get(self, fname: str) -> T:
48
- imgs = self.read_imgs(fname)
56
+ imgs = self._read(fname)
49
57
  imgs = self.transform(imgs)
50
58
  return imgs
51
59
 
52
- @staticmethod
53
- def read_imgs(fname: str) -> npt.NDArray[np.float32]:
54
- return read_imgs(fname).get_full()
60
+ def _read(self, fname: str) -> npt.NDArray[ScalarType]:
61
+ return read_imgs(fname, dtype=self.dtype).get_full() # type: ignore
55
62
 
56
63
  @staticmethod
57
64
  def scan(root: str, *, pattern: Optional[str] = None) -> List[str]:
@@ -63,13 +70,86 @@ class ImageStackFolderBase(Generic[T], ABC):
63
70
 
64
71
  return fs
65
72
 
73
+ @staticmethod
74
+ def read_imgs(fname: str) -> npt.NDArray[np.float32]:
75
+ """Read images.
76
+
77
+ .. deprecated:: 0.16.0
78
+ Use :meth:`~swcgeom.images.io.read_imgs(fname).get_full()` instead.
79
+ """
80
+
81
+ warnings.warn(
82
+ "`ImageStackFolderBase.read_imgs` serves as a "
83
+ "straightforward wrapper for `~swcgeom.images.io.read_imgs(fname).get_full()`. "
84
+ "However, as it is not utilized within our internal "
85
+ "processes, it is scheduled for removal in the "
86
+ "forthcoming version.",
87
+ DeprecationWarning,
88
+ )
89
+ return read_imgs(fname).get_full()
66
90
 
67
- class ImageStackFolder(Generic[T], ImageStackFolderBase[T]):
91
+
92
+ @dataclass(frozen=True)
93
+ class Statistics:
94
+ count: int = 0
95
+ minimum: float = math.nan
96
+ maximum: float = math.nan
97
+ mean: float = 0
98
+ variance: float = 0
99
+
100
+
101
+ class ImageStackFolder(ImageStackFolderBase[ScalarType, T]):
68
102
  """Image stack folder."""
69
103
 
70
104
  def __getitem__(self, idx: int, /) -> T:
71
105
  return self._get(self.files[idx])
72
106
 
107
+ def stat(self, *, transform: bool = False, verbose: bool = False) -> Statistics:
108
+ """Statistics of folder.
109
+
110
+ Parameters
111
+ ----------
112
+ transform : bool, default to False
113
+ Apply transform to the images. If True, you need to make
114
+ sure the transformed data is a ndarray.
115
+ verbose : bool, optional
116
+
117
+ Notes
118
+ -----
119
+ We are asserting that the images are of the same shape.
120
+ """
121
+
122
+ vmin, vmax = math.inf, -math.inf
123
+ n, mean, M2 = 0, None, None
124
+
125
+ for idx in tqdm(range(len(self))) if verbose else range(len(self)):
126
+ imgs = self[idx] if transform else self._read(self.files[idx])
127
+
128
+ vmin = min(vmin, np.min(imgs)) # type: ignore
129
+ vmax = max(vmax, np.max(imgs)) # type: ignore
130
+ # Welford algorithm to calculate mean and variance
131
+ if mean is None:
132
+ mean = np.zeros_like(imgs)
133
+ M2 = np.zeros_like(imgs)
134
+
135
+ n += 1
136
+ delta = imgs - mean # type: ignore
137
+ mean += delta / n
138
+ delta2 = imgs - mean
139
+ M2 += delta * delta2
140
+
141
+ if mean is None or M2 is None: # n = 0
142
+ raise ValueError("empty folder")
143
+
144
+ variance = M2 / (n - 1) if n > 1 else np.zeros_like(mean)
145
+ return Statistics(
146
+ count=len(self),
147
+ maximum=vmax,
148
+ minimum=vmin,
149
+ mean=np.mean(mean).item(),
150
+ variance=np.mean(variance).item(),
151
+ )
152
+
73
153
  @classmethod
74
154
  def from_dir(cls, root: str, *, pattern: Optional[str] = None, **kwargs) -> Self:
75
155
  """
@@ -81,10 +161,11 @@ class ImageStackFolder(Generic[T], ImageStackFolderBase[T]):
81
161
  **kwargs
82
162
  Pass to `cls.__init__`
83
163
  """
164
+
84
165
  return cls(cls.scan(root, pattern=pattern), **kwargs)
85
166
 
86
167
 
87
- class LabeledImageStackFolder(Generic[T], ImageStackFolderBase[T]):
168
+ class LabeledImageStackFolder(ImageStackFolderBase[ScalarType, T]):
88
169
  """Image stack folder with label."""
89
170
 
90
171
  labels: List[int]
@@ -93,8 +174,8 @@ class LabeledImageStackFolder(Generic[T], ImageStackFolderBase[T]):
93
174
  super().__init__(files, **kwargs)
94
175
  self.labels = list(labels)
95
176
 
96
- def __getitem__(self, idx: int) -> Tuple[npt.NDArray[np.float32], int]:
97
- return self.read_imgs(self.files[idx]), self.labels[idx]
177
+ def __getitem__(self, idx: int) -> Tuple[T, int]:
178
+ return self._get(self.files[idx]), self.labels[idx]
98
179
 
99
180
  @classmethod
100
181
  def from_dir(
@@ -115,7 +196,7 @@ class LabeledImageStackFolder(Generic[T], ImageStackFolderBase[T]):
115
196
  return cls(files, labels, **kwargs)
116
197
 
117
198
 
118
- class PathImageStackFolder(Generic[T], ImageStackFolder[T]):
199
+ class PathImageStackFolder(ImageStackFolderBase[ScalarType, T]):
119
200
  """Image stack folder with relpath."""
120
201
 
121
202
  root: str
@@ -139,6 +220,7 @@ class PathImageStackFolder(Generic[T], ImageStackFolder[T]):
139
220
  **kwargs
140
221
  Pass to `cls.__init__`
141
222
  """
223
+
142
224
  return cls(cls.scan(root, pattern=pattern), root=root, **kwargs)
143
225
 
144
226