swcgeom 0.4.0__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of swcgeom might be problematic; see the package registry's advisory page for more details.

swcgeom/_version.py CHANGED
@@ -1,4 +1,4 @@
1
1
  # file generated by setuptools_scm
2
2
  # don't change, don't track in version control
3
- __version__ = version = '0.4.0'
4
- __version_tuple__ = version_tuple = (0, 4, 0)
3
+ __version__ = version = '0.5.0'
4
+ __version_tuple__ = version_tuple = (0, 5, 0)
@@ -2,32 +2,28 @@
2
2
 
3
3
  Notes
4
4
  -----
5
- For development, see method `Features.get_evaluator`
6
- to confirm the naming specification.
5
+ For development, see method `Features.get_evaluator` to confirm the
6
+ naming specification.
7
7
  """
8
8
 
9
9
  from functools import cached_property
10
- from typing import Any, Callable, Dict, List, Literal, Tuple, cast, overload
10
+ from itertools import chain
11
+ from os.path import basename
12
+ from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, overload
11
13
 
12
14
  import numpy as np
13
15
  import numpy.typing as npt
14
16
  import seaborn as sns
15
17
  from matplotlib.axes import Axes
16
- from matplotlib.figure import Figure
17
18
 
18
- from ..core import Population, Tree
19
- from ..utils import XYPair, get_fig_ax, padding1d, to_distribution
19
+ from ..core import Population, Populations, Tree
20
+ from ..utils import padding1d
20
21
  from .branch_features import BranchFeatures
21
22
  from .node_features import NodeFeatures
22
23
  from .path_features import PathFeatures
23
24
  from .sholl import Sholl
24
25
 
25
- __all__ = [
26
- "Feature",
27
- "FeatureExtractor",
28
- "PopulationFeatureExtractor",
29
- "extract_feature",
30
- ]
26
+ __all__ = ["Feature", "extract_feature"]
31
27
 
32
28
  Feature = Literal[
33
29
  "length",
@@ -43,6 +39,11 @@ Feature = Literal[
43
39
  "path_tortuosity",
44
40
  ]
45
41
 
42
+ Bins = int | npt.ArrayLike | str
43
+ Range = Optional[Tuple[float, float]]
44
+ HistAndBinEdges = Tuple[npt.NDArray, npt.NDArray]
45
+ FeatureWithKwargs = Feature | Tuple[Feature, Dict[str, Any]]
46
+
46
47
 
47
48
  class Features:
48
49
  """Tree features"""
@@ -62,10 +63,10 @@ class Features:
62
63
  def __init__(self, tree: Tree) -> None:
63
64
  self.tree = tree
64
65
 
65
- def get(self, feature: Feature, **kwargs) -> npt.NDArray[np.float32]:
66
- evaluator = self.get_evaluator(feature)
67
- feat = evaluator(**kwargs).astype(np.float32)
68
- return feat
66
+ def get(self, feature: FeatureWithKwargs, **kwargs) -> npt.NDArray[np.float32]:
67
+ feat, kwargs = _get_feature_and_kwargs(feature, **kwargs)
68
+ evaluator = self.get_evaluator(feat)
69
+ return evaluator(**kwargs)
69
70
 
70
71
  def get_evaluator(self, feature: Feature) -> Callable[[], npt.NDArray]:
71
72
  if callable(calc := getattr(self, f"get_{feature}", None)):
@@ -79,196 +80,204 @@ class Features:
79
80
 
80
81
  raise ValueError(f"Invalid feature: {feature}")
81
82
 
82
- def get_distribution(
83
- self, feature: Feature, step: float | None, **kwargs
84
- ) -> XYPair:
85
- if callable(method := getattr(self, f"get_{feature}_distribution", None)):
86
- if step is not None:
87
- kwargs.setdefault("step", step)
88
- return method(**kwargs) # custom feature distribution
89
-
90
- feat = self.get(feature, **kwargs)
91
- step = cast(float, feat.max() / 100) if step is None else step
92
- x, y = to_distribution(feat, step)
93
- return x, y
94
-
95
83
  # Features
96
84
 
97
85
  def get_length(self, **kwargs) -> npt.NDArray[np.float32]:
98
86
  return np.array([self.tree.length(**kwargs)], dtype=np.float32)
99
87
 
100
- def get_sholl(self, **kwargs) -> npt.NDArray[np.int32]:
101
- return Sholl(self.tree, **kwargs).get_count()
102
-
103
- def get_sholl_distribution(self, **kwargs) -> XYPair:
104
- x, y = Sholl(self.tree, **kwargs).get_distribution()
105
- return x, y.astype(np.float32)
88
+ def get_sholl(self, **kwargs) -> npt.NDArray[np.float32]:
89
+ return Sholl(self.tree, **kwargs).get().astype(np.float32)
106
90
 
107
91
 
108
92
  class FeatureExtractor:
109
- """Extract feature from tree."""
110
-
111
- features: Features
112
-
113
- def __init__(self, tree: Tree) -> None:
114
- self.features = Features(tree)
115
-
116
93
  # fmt:off
117
94
  @overload
118
95
  def get(self, feature: Feature, **kwargs) -> npt.NDArray[np.float32]: ...
119
96
  @overload
120
- def get(self, feature: List[Feature], **kwargs) -> List[npt.NDArray[np.float32]]: ...
97
+ def get(self, feature: List[FeatureWithKwargs]) -> List[npt.NDArray[np.float32]]: ...
121
98
  @overload
122
- def get(self, feature: Dict[Feature, Dict[str, Any]], **kwargs) -> Dict[str, npt.NDArray[np.float32]]: ...
99
+ def get(self, feature: Dict[Feature, Dict[str, Any]]) -> Dict[str, npt.NDArray[np.float32]]: ...
123
100
  # fmt:on
124
101
  def get(self, feature, **kwargs):
125
- """Get feature of shape (L,)."""
102
+ """Get feature.
103
+
104
+ Notes
105
+ -----
106
+ Shape of returned array is not uniform, `TreeFeatureExtractor`
107
+ returns array of shape (L, ), `PopulationFeatureExtracor`
108
+ returns array of shape (N, L).
109
+ """
126
110
  if isinstance(feature, dict):
127
- return {k: self.features.get(k, **v) for k, v in feature.items()}
111
+ return {k: self._get(k, **v) for k, v in feature.items()}
128
112
 
129
113
  if isinstance(feature, list):
130
- return [self.features.get(k) for k in feature]
114
+ return [self._get(k) for k in feature]
131
115
 
132
- return self.features.get(feature, **kwargs)
116
+ return self._get(feature, **kwargs)
133
117
 
134
- # fmt:off
135
- @overload
136
- def get_distribution(self, feature: Feature, step: float = ..., **kwargs) -> XYPair: ...
137
- @overload
138
- def get_distribution(self, feature: List[Feature], step: float = ..., **kwargs) -> List[XYPair]: ...
139
- @overload
140
- def get_distribution(self, feature: Dict[Feature, Dict[str, Any]], step: float = ..., **kwargs) -> Dict[str, XYPair]: ...
141
- # fmt:on
142
- def get_distribution(self, feature, step: float | None = None, **kwargs):
143
- """Get feature distribution of shape (S,)."""
144
- if isinstance(feature, dict):
145
- return {
146
- k: self.features.get_distribution(k, step, **v)
147
- for k, v in feature.items()
148
- }
118
+ def plot(
119
+ self, feature: FeatureWithKwargs, title: str | bool = True, **kwargs
120
+ ) -> Axes: # TODO: sholl
121
+ feat, _ = _get_feature_and_kwargs(feature)
122
+ if not callable(plot := getattr(self, f"_plot_{feat}", None)):
123
+ plot = self._plot # default plot
149
124
 
150
- if isinstance(feature, list):
151
- return [self.features.get_distribution(k, step) for k in feature]
125
+ vals = self._get(feature)
126
+ ax = plot(vals, **kwargs)
152
127
 
153
- return self.features.get_distribution(feature, step, **kwargs)
128
+ if isinstance(title, str):
129
+ ax.set_title(title)
130
+ elif title is True:
131
+ ax.set_title(feat.replace("_", " ").title())
154
132
 
155
- def plot_distribution(
156
- self,
157
- feature: Feature,
158
- feature_args: Dict[Any, Any] = {},
159
- fig: Figure | None = None,
160
- ax: Axes | None = None,
161
- **kwargs,
162
- ) -> Tuple[Figure, Axes]:
163
- x, y = self.get_distribution(feature, **feature_args)
164
- fig, ax = get_fig_ax(fig, ax)
165
- sns.lineplot(x=x, y=y, ax=ax, **kwargs)
166
- return fig, ax
133
+ return ax
167
134
 
135
+ def _get(self, feature: FeatureWithKwargs, **kwargs) -> npt.NDArray[np.float32]:
136
+ raise NotImplementedError()
168
137
 
169
- class PopulationFeatureExtractor:
170
- """Extract feature from population."""
138
+ def _plot(self, vals: npt.NDArray[np.float32], **kwargs) -> Axes:
139
+ raise NotImplementedError()
171
140
 
172
- population: Population
173
141
 
174
- @cached_property
175
- def _trees(self) -> List[Features]:
176
- return [Features(tree) for tree in self.population]
142
+ class TreeFeatureExtractor(FeatureExtractor):
143
+ """Extract feature from tree."""
177
144
 
178
- def __init__(self, population: Population) -> None:
179
- self.population = population
145
+ _tree: Tree
146
+ _features: Features
180
147
 
181
- # fmt:off
182
- @overload
183
- def get(self, feature: Feature, **kwargs) -> List[npt.NDArray[np.float32]]: ...
184
- @overload
185
- def get(self, feature: Dict[Feature, Dict[str, Any]], **kwargs) -> Dict[str, List[npt.NDArray[np.float32]]]: ...
186
- # fmt:on
187
- def get(self, feature, **kwargs):
188
- """Get feature list of array of shape (N, L_i).
148
+ def __init__(self, tree: Tree) -> None:
149
+ super().__init__()
150
+ self._tree = tree
151
+ self._features = Features(tree)
189
152
 
190
- Which N is the number of tree of population, L is length of
191
- nodes.
192
- """
193
- if isinstance(feature, dict):
194
- return {k: self._get(k, **v) for k, v in feature.items()}
153
+ def _get(self, feature: FeatureWithKwargs, **kwargs) -> npt.NDArray[np.float32]:
154
+ return self._features.get(feature, **kwargs)
195
155
 
196
- return self._get(feature, **kwargs)
156
+ def _plot(self, vals: npt.NDArray[np.float32], **kwargs) -> Axes:
157
+ ax: Axes = sns.histplot(x=vals, **kwargs)
158
+ ax.set_ylabel("Count")
159
+ return ax
197
160
 
198
- # fmt:off
199
- @overload
200
- def get_distribution(self, feature: Feature, step: float = ..., **kwargs) -> XYPair: ...
201
- @overload
202
- def get_distribution(self, feature: List[Feature], step: float = ..., **kwargs) -> List[XYPair]: ...
203
- @overload
204
- def get_distribution(self, feature: Dict[Feature, Dict[str, Any]], step: float = ..., **kwargs) -> Dict[str, XYPair]: ...
205
- # fmt:on
206
- def get_distribution(self, feature, step: float | None = None, **kwargs):
207
- """Get feature distribution of shape (N, S).
208
-
209
- Which N is the number of tree of population, S is size of
210
- distrtibution.
211
-
212
- Returns
213
- -------
214
- x : npt.NDArray[np.float32]
215
- Array of shape (S,).
216
- y : npt.NDArray[np.float32]
217
- Array of shape (N, S).
218
- """
219
- if isinstance(feature, dict):
220
- return {
221
- k: self._get_distribution(k, step=step, **v) for k, v in feature.items()
222
- }
161
+ def _plot_length(self, vals: npt.NDArray[np.float32], **kwargs) -> Axes:
162
+ name = basename(self._tree.source)
163
+ ax: Axes = sns.barplot(x=[name], y=vals.squeeze(), **kwargs)
164
+ ax.set_ylabel("Length")
165
+ return ax
223
166
 
224
- if isinstance(feature, list):
225
- return [self._get_distribution(k, step=step) for k in feature]
226
-
227
- return self._get_distribution(feature, step=step, **kwargs)
228
-
229
- def plot_distribution(
230
- self,
231
- feature: Feature,
232
- feature_args: Dict[Any, Any] = {},
233
- fig: Figure | None = None,
234
- ax: Axes | None = None,
235
- **kwargs,
236
- ) -> Tuple[Figure, Axes]:
237
- x, y = self.get_distribution(feature, **feature_args)
238
- x, y = np.tile(x, y.shape[0]), y.flatten()
239
-
240
- fig, ax = get_fig_ax(fig, ax)
241
- sns.lineplot(x=x, y=y, ax=ax, **kwargs)
242
- return fig, ax
243
-
244
- def _get(self, feature: Feature, **kwargs) -> List[npt.NDArray[np.float32]]:
245
- return [ex.get(feature, **kwargs) for ex in self._trees]
246
-
247
- def _get_distribution(self, feature: Feature, **kwargs) -> XYPair:
248
- assert len(self._trees) != 0
249
-
250
- x, ys = np.array([], dtype=np.float32), list[npt.NDArray[np.float32]]()
251
- for features in self._trees:
252
- xx, y = features.get_distribution(feature, **kwargs)
253
- x = xx if xx.shape[0] > x.shape[0] else x
254
- ys.append(y)
255
-
256
- max_len_y = max(y.shape[0] for y in ys)
257
- y = np.stack([padding1d(max_len_y, y, 0) for y in ys])
258
- return x, y
259
-
260
-
261
- # fmt: off
262
- @overload
263
- def extract_feature(obj: Tree) -> FeatureExtractor: ...
264
- @overload
265
- def extract_feature(obj: Population) -> PopulationFeatureExtractor: ...
266
- # fmt: on
267
- def extract_feature(obj):
167
+
168
+ class PopulationFeatureExtractor(FeatureExtractor):
169
+ """Extract feature from population."""
170
+
171
+ _population: Population
172
+ _features: List[Features]
173
+
174
+ def __init__(self, population: Population) -> None:
175
+ super().__init__()
176
+ self._population = population
177
+ self._features = [Features(t) for t in self._population]
178
+
179
+ def _get(self, feature: FeatureWithKwargs, **kwargs) -> npt.NDArray[np.float32]:
180
+ vals = [f.get(feature, **kwargs) for f in self._features]
181
+ len_max = max(len(v) for v in vals)
182
+ v = np.stack([padding1d(len_max, v, dtype=np.float32) for v in vals])
183
+ return v
184
+
185
+ def _plot(
186
+ self, vals: npt.NDArray[np.float32], bins="auto", range=None, **kwargs
187
+ ) -> Axes:
188
+ bin_edges = np.histogram_bin_edges(vals, bins, range)
189
+ hists = [
190
+ np.histogram(v, bins=bin_edges, weights=(v != 0).astype(np.int32))[0]
191
+ for v in vals
192
+ ]
193
+ hist = np.concatenate(hists)
194
+ x = np.tile((bin_edges[:-1] + bin_edges[1:]) / 2, len(self._population))
195
+
196
+ ax: Axes = sns.lineplot(x=x, y=hist, **kwargs)
197
+ ax.set_ylabel("Count")
198
+ return ax
199
+
200
+ def _plot_length(self, vals: npt.NDArray[np.float32], **kwargs) -> Axes:
201
+ vals = vals.squeeze(axis=1)
202
+ x = [basename(t.source) for t in self._population]
203
+ y = vals.flatten()
204
+ ax: Axes = sns.barplot(x=x, y=y, **kwargs)
205
+ ax.axhline(y=y.mean(), ls="--", lw=1)
206
+ ax.set_ylabel("Length")
207
+ ax.set_xticks([])
208
+ return ax
209
+
210
+
211
+ class PopulationsFeatureExtractor(FeatureExtractor):
212
+ """Extract feature from population."""
213
+
214
+ _populations: Populations
215
+ _features: List[List[Features]]
216
+
217
+ def __init__(self, populations: Populations) -> None:
218
+ super().__init__()
219
+ self._populations = populations
220
+ self._features = [
221
+ [Features(t) for t in p] for p in self._populations.populations
222
+ ]
223
+
224
+ def _get(self, feature: FeatureWithKwargs, **kwargs) -> npt.NDArray[np.float32]:
225
+ vals = [[f.get(feature, **kwargs) for f in fs] for fs in self._features]
226
+ len_max1 = max(len(v) for v in vals)
227
+ len_max2 = max(*chain.from_iterable(((len(vv) for vv in v) for v in vals)))
228
+ out = np.zeros((len(vals), len_max1, len_max2), dtype=np.float32)
229
+ for i, v in enumerate(vals):
230
+ for j, vv in enumerate(v):
231
+ out[i, j, : len(vv)] = vv
232
+
233
+ return out
234
+
235
+ def _plot(
236
+ self, vals: npt.NDArray[np.float32], bins="auto", range=None, **kwargs
237
+ ) -> Axes:
238
+ bin_edges = np.histogram_bin_edges(vals, bins, range)
239
+ histogram = lambda v: np.histogram(
240
+ v, bins=bin_edges, weights=(v != 0).astype(np.int32)
241
+ )
242
+ hists = [[histogram(t)[0] for t in p] for p in vals]
243
+ hist = np.concatenate(hists).flatten()
244
+
245
+ repeats = np.prod(vals.shape[:2]).item()
246
+ x = np.tile((bin_edges[:-1] + bin_edges[1:]) / 2, repeats)
247
+
248
+ labels = self._populations.labels
249
+ length = (len(bin_edges) - 1) * vals.shape[1]
250
+ hue = np.concatenate([np.full(length, fill_value=i) for i in labels])
251
+
252
+ ax: Axes = sns.lineplot(x=x, y=hist, hue=hue, **kwargs)
253
+ ax.set_ylabel("Count")
254
+ return ax
255
+
256
+ def _plot_length(self, vals: npt.NDArray[np.float32], **kwargs) -> Axes:
257
+ vals = vals.squeeze(axis=2)
258
+ labels = self._populations.labels
259
+ x = np.concatenate([np.full(vals.shape[1], fill_value=i) for i in labels])
260
+ y = vals.flatten()
261
+ ax: Axes = sns.boxplot(x=x, y=y, **kwargs)
262
+ ax.set_ylabel("Length")
263
+ return ax
264
+
265
+
266
+ def extract_feature(obj: Tree | Population) -> FeatureExtractor:
268
267
  if isinstance(obj, Tree):
269
- return FeatureExtractor(obj)
268
+ return TreeFeatureExtractor(obj)
270
269
 
271
270
  if isinstance(obj, Population):
272
271
  return PopulationFeatureExtractor(obj)
273
272
 
273
+ if isinstance(obj, Populations):
274
+ return PopulationsFeatureExtractor(obj)
275
+
274
276
  raise TypeError("Invalid argument type.")
277
+
278
+
279
+ def _get_feature_and_kwargs(feature: FeatureWithKwargs, **kwargs):
280
+ if isinstance(feature, tuple):
281
+ return feature[0], {**feature[1], **kwargs}
282
+ else:
283
+ return feature, kwargs
@@ -1,14 +1,11 @@
1
1
  """Depth distribution of tree."""
2
2
 
3
-
4
3
  from functools import cached_property
5
- from typing import List
6
4
 
7
5
  import numpy as np
8
6
  import numpy.typing as npt
9
7
 
10
8
  from ..core import BranchTree, Tree
11
- from ..utils import XYPair, to_distribution
12
9
 
13
10
  __all__ = ["NodeFeatures"]
14
11
 
@@ -27,54 +24,22 @@ class NodeFeatures:
27
24
  Returns
28
25
  -------
29
26
  radial_distance : npt.NDArray[np.float32]
30
- Array of shape (N,), while N is the number of nodes.
27
+ Array of shape (N,).
31
28
  """
32
29
  xyz = self.tree.xyz() - self.tree.soma().xyz()
33
30
  radial_distance = np.linalg.norm(xyz, axis=1)
34
31
  return radial_distance
35
32
 
36
- def get_radial_distance_distribution(
37
- self,
38
- step: float = 1,
39
- /,
40
- filter_bifurcation: bool = False,
41
- filter_tip: bool = False,
42
- filter_other: bool = True,
43
- ) -> XYPair:
44
- """Get radial distance distribution of tree.
45
-
46
- Parameters
47
- ----------
48
- filter_bifurcation : bool, default `False`
49
- Filter bifurcation nodes.
50
- filter_tip : bool, default `False`
51
- Filter tip nodes.
52
- filter_other : bool, default `False`
53
- Filter nodes that are not bifurcations or tips.
54
-
55
- Returns
56
- -------
57
- radial_distance : npt.NDArray[np.float32]
58
- Array of shape (N,), while N is the number of nodes.
59
- """
60
- return self._to_distribution(
61
- self.get_radial_distance(),
62
- step,
63
- filter_bifurcation=filter_bifurcation,
64
- filter_tip=filter_tip,
65
- filter_other=filter_other,
66
- )
67
-
68
33
  def get_branch_order(self) -> npt.NDArray[np.int32]:
69
34
  """Get branch order of tree.
70
35
 
71
- Bifurcation order is the number of bifurcations between current position
72
- and the root.
36
+ Bifurcation order is the number of bifurcations between current
37
+ position and the root.
73
38
 
74
39
  Returns
75
40
  -------
76
41
  order : npt.NDArray[np.int32]
77
- Array of shape (N,), while N is the number of nodes.
42
+ Array of shape (k,), which k is the number of branchs.
78
43
  """
79
44
  order = np.zeros_like(self._branch_tree.id(), dtype=np.int32)
80
45
 
@@ -86,69 +51,6 @@ class NodeFeatures:
86
51
  self._branch_tree.traverse(enter=assign_depth)
87
52
  return order
88
53
 
89
- def get_branch_order_distribution(
90
- self,
91
- step: int = 1,
92
- /,
93
- filter_bifurcation: bool = False,
94
- filter_tip: bool = False,
95
- filter_other: bool = True,
96
- ) -> XYPair:
97
- """Get branch order distribution of tree.
98
-
99
- Parameters
100
- ----------
101
- filter_bifurcation : bool, default `False`
102
- Filter bifurcation nodes.
103
- filter_tip : bool, default `False`
104
- Filter tip nodes.
105
- filter_other : bool, default `False`
106
- Filter nodes that are not bifurcations or tips.
107
- """
108
- return self._to_distribution(
109
- self.get_branch_order(),
110
- step,
111
- filter_bifurcation=filter_bifurcation,
112
- filter_tip=filter_tip,
113
- filter_other=filter_other,
114
- )
115
-
116
- def _to_distribution(
117
- self,
118
- x: npt.NDArray,
119
- step: float,
120
- /,
121
- filter_bifurcation: bool,
122
- filter_tip: bool,
123
- filter_other: bool,
124
- ) -> XYPair:
125
- if filter_bifurcation:
126
- x[self._bifurcations] = -1
127
-
128
- if filter_tip:
129
- x[self._tips] = -1
130
-
131
- if filter_other:
132
- x[self._other] = -1
133
-
134
- x = x[x != -1]
135
- return to_distribution(x, step)
136
-
137
54
  @cached_property
138
55
  def _branch_tree(self) -> BranchTree:
139
56
  return BranchTree.from_tree(self.tree)
140
-
141
- @cached_property
142
- def _bifurcations(self) -> List[int]:
143
- return [n.id for n in self.tree.get_bifurcations()]
144
-
145
- @cached_property
146
- def _tips(self) -> List[int]:
147
- return [n.id for n in self.tree.get_tips()]
148
-
149
- @cached_property
150
- def _other(self) -> npt.NDArray[np.int32]:
151
- other = self.tree.id()
152
- other = np.setdiff1d(other, self._bifurcations)
153
- other = np.setdiff1d(other, self._tips)
154
- return other