fibphot 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fibphot/__init__.py +6 -0
- fibphot/analysis/__init__.py +0 -0
- fibphot/analysis/aggregate.py +257 -0
- fibphot/analysis/auc.py +354 -0
- fibphot/analysis/irls.py +350 -0
- fibphot/analysis/peaks.py +1163 -0
- fibphot/analysis/photobleaching.py +290 -0
- fibphot/analysis/plotting.py +105 -0
- fibphot/analysis/report.py +56 -0
- fibphot/collection.py +207 -0
- fibphot/fit/__init__.py +0 -0
- fibphot/fit/regression.py +269 -0
- fibphot/io/__init__.py +6 -0
- fibphot/io/doric.py +435 -0
- fibphot/io/excel.py +76 -0
- fibphot/io/h5.py +321 -0
- fibphot/misc.py +11 -0
- fibphot/peaks.py +628 -0
- fibphot/pipeline.py +14 -0
- fibphot/plotting.py +594 -0
- fibphot/stages/__init__.py +22 -0
- fibphot/stages/base.py +101 -0
- fibphot/stages/baseline.py +354 -0
- fibphot/stages/control_dff.py +214 -0
- fibphot/stages/filters.py +273 -0
- fibphot/stages/normalisation.py +260 -0
- fibphot/stages/regression.py +139 -0
- fibphot/stages/smooth.py +442 -0
- fibphot/stages/trim.py +141 -0
- fibphot/state.py +309 -0
- fibphot/tags.py +130 -0
- fibphot/types.py +6 -0
- fibphot-0.1.0.dist-info/METADATA +63 -0
- fibphot-0.1.0.dist-info/RECORD +37 -0
- fibphot-0.1.0.dist-info/WHEEL +5 -0
- fibphot-0.1.0.dist-info/licenses/LICENSE.md +21 -0
- fibphot-0.1.0.dist-info/top_level.txt +1 -0
fibphot/stages/filters.py
@@ -0,0 +1,273 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Literal

import numpy as np
import scipy.ndimage as ndi
import scipy.signal

from ..state import PhotometryState
from ..types import FloatArray
from .base import StageOutput, UpdateStage, _resolve_channels


def _hampel_1d(
    x: FloatArray,
    window_size: int,
    n_sigmas: float,
    *,
    mad_scale: float = 1.4826,
    mode: str = "reflect",
    match_edges: bool = True,
) -> FloatArray:
    """
    Fast Hampel filter using rolling medians.

    Parameters
    ----------
    mad_scale:
        Scale factor so MAD estimates standard deviation under Normal noise.
    mode:
        Padding strategy for the rolling median ('reflect', 'nearest', ...).
    match_edges:
        If True, applies "shrinking window" behaviour at first/last k samples.
    """

    if window_size < 3:
        raise ValueError("window_size must be >= 3.")
    if window_size % 2 == 0:
        window_size += 1

    x = np.asarray(x, dtype=float)
    n = int(x.shape[0])
    k = window_size // 2

    # Rolling median
    med = ndi.median_filter(x, size=window_size, mode=mode)

    # Rolling MAD = median(|x - med|)
    abs_dev = np.abs(x - med)
    mad = mad_scale * ndi.median_filter(abs_dev, size=window_size, mode=mode)

    out = x.copy()
    mask = (mad > 1e-12) & (abs_dev > (n_sigmas * mad))
    out[mask] = med[mask]

    if match_edges and k > 0 and n > 0:
        left = min(k, n)
        for i in range(left):
            lo = 0
            hi = min(n, i + k + 1)
            w = x[lo:hi]
            m = float(np.median(w))
            s = mad_scale * float(np.median(np.abs(w - m)))
            if s > 1e-12 and abs(x[i] - m) > n_sigmas * s:
                out[i] = m

        right_start = max(0, n - k)
        for i in range(right_start, n):
            lo = max(0, i - k)
            hi = n
            w = x[lo:hi]
            m = float(np.median(w))
            s = mad_scale * float(np.median(np.abs(w - m)))
            if s > 1e-12 and abs(x[i] - m) > n_sigmas * s:
                out[i] = m

    return out


@dataclass(frozen=True, slots=True)
class HampelFilter(UpdateStage):
    """
    Applies a Hampel filter to specified channels.

    This implementation is fast: it uses rolling medians (SciPy) rather than
    a Python loop over all samples.

    Parameters
    ----------
    window_size:
        Size of the moving window (forced odd and >= 3).
    n_sigmas:
        Threshold in units of (scaled) MAD.
    channels:
        "all", a channel name, a list of names, or None for all.
    mad_scale:
        Scale factor converting MAD to sigma under Normal noise.
    mode:
        Padding mode used by the rolling median.
    match_edges:
        If True, uses shrinking-window behaviour at the edges.

    Context
    -------
    The Hampel filter is a robust method for outlier detection and correction
    in time series data. It replaces outliers with the median of neighbouring
    values within a specified window, making it effective for removing transient
    spikes or noise without significantly distorting the underlying signal.

    Compared to the `MedianFilter`, which replaces each point with the median of
    its neighbours, the Hampel filter specifically targets outliers based on
    their deviation from the local median. Hence, it does not alter the signal
    unless an outlier is detected.
    """

    name: str = field(default="hampel_filter", init=False)

    window_size: int = 11
    n_sigmas: float = 3.0
    channels: str | list[str] | None = None

    mad_scale: float = 1.4826
    mode: str = "reflect"
    match_edges: bool = True

    def _params_for_summary(self) -> dict[str, object]:
        return {
            "window_size": self.window_size,
            "n_sigmas": self.n_sigmas,
            "channels": self.channels if self.channels is not None else "all",
            "mad_scale": self.mad_scale,
            "mode": self.mode,
            "match_edges": self.match_edges,
        }

    def apply(self, state: PhotometryState) -> StageOutput:
        idxs = _resolve_channels(state, self.channels)
        new = state.signals.copy()
        for i in idxs:
            new[i] = _hampel_1d(
                new[i],
                self.window_size,
                self.n_sigmas,
                mad_scale=self.mad_scale,
                mode=self.mode,
                match_edges=self.match_edges,
            )
        return StageOutput(signals=new)


@dataclass(frozen=True, slots=True)
class MedianFilter(UpdateStage):
    name: str = field(default="median_filter", init=False)
    kernel_size: int = 5
    channels: str | list[str] | None = None

    def _params_for_summary(self) -> dict[str, object]:
        return {
            "kernel_size": self.kernel_size,
            "channels": self.channels if self.channels is not None else "all",
        }

    def apply(self, state: PhotometryState) -> StageOutput:
        k = self.kernel_size + (self.kernel_size % 2 == 0)
        idxs = _resolve_channels(state, self.channels)

        new = state.signals.copy()
        for i in idxs:
            new[i] = scipy.signal.medfilt(new[i], kernel_size=k)

        return StageOutput(signals=new)


@dataclass(frozen=True, slots=True)
class LowPassFilter(UpdateStage):
    """
    Applies a zero-phase low-pass Butterworth filter to specified channels.

    Parameters
    ----------
    critical_frequency : float
        The critical frequency (in Hz) for the low-pass filter. This is where
        the filter begins to attenuate higher frequencies.
    order : int
        The order of the Butterworth filter. Higher order filters have a
        steeper roll-off.
    sampling_rate : float | None
        The sampling rate (in Hz) of the input signals. If None, uses the
        sampling rate from the PhotometryState.
    channels : str | list[str] | None
        The channels to which the filter should be applied. Can be "all", a
        single channel name, or a list of channel names. If None, defaults to
        "all".
    representation : Literal["sos", "ba"]
        The filter representation to use. "sos" for second-order sections
        (numerically stable), or "ba" for (b, a) coefficients.

    Context
    -------
    Biosensor kinetics typically operate on slower (e.g., sub-second) timescales
    relative to higher-frequency electrical noise. A low-pass filter keeps low
    frequencies and attenuates high frequencies.
    """

    name: str = field(default="low_pass_filter", init=False)
    critical_frequency: float = 10.0
    order: int = 2
    sampling_rate: float | None = None
    channels: str | list[str] | None = None
    representation: Literal["sos", "ba"] = "sos"

    def _params_for_summary(self) -> dict[str, object]:
        return {
            "critical_frequency": self.critical_frequency,
            "order": self.order,
            "sampling_rate": self.sampling_rate,
            "channels": self.channels if self.channels is not None else "all",
            "representation": self.representation,
        }

    def apply(self, state: PhotometryState) -> StageOutput:
        fs = (
            state.sampling_rate
            if self.sampling_rate is None
            else float(self.sampling_rate)
        )
        if not (0.0 < self.critical_frequency < 0.5 * fs):
            raise ValueError(
                "critical_frequency must be > 0 and < Nyquist (fs/2). "
                f"Got critical_frequency={self.critical_frequency}, fs={fs}."
            )

        idxs = _resolve_channels(state, self.channels)
        new = state.signals.copy()

        if self.representation == "sos":
            sos = scipy.signal.butter(
                N=self.order,
                Wn=self.critical_frequency,
                btype="low",
                fs=fs,
                output="sos",
            )
            for i in idxs:
                new[i] = scipy.signal.sosfiltfilt(sos, new[i])
            return StageOutput(signals=new)

        if self.representation == "ba":
            res = scipy.signal.butter(
                N=self.order,
                Wn=self.critical_frequency,
                btype="low",
                fs=fs,
                output="ba",
            )

            if res is None:
                raise RuntimeError(
                    "scipy.signal.butter returned None; check filter params."
                )

            assert len(res) == 2, (
                "Expected (b,a) tuple from scipy.signal.butter."
            )
            b, a = res

            for i in idxs:
                new[i] = scipy.signal.filtfilt(b, a, new[i])

            return StageOutput(signals=new)

        raise ValueError(f"Unknown representation: {self.representation!r}")
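As a quick illustration of the decision rule `_hampel_1d` applies above (rolling median plus scaled rolling MAD; only points deviating by more than n_sigmas are replaced), here is a minimal standalone sketch on synthetic data. It uses only NumPy/SciPy rather than the fibphot pipeline, and the injected spike positions are illustrative assumptions:

import numpy as np
import scipy.ndimage as ndi

rng = np.random.default_rng(0)
x = np.sin(np.linspace(0, 4 * np.pi, 500)) + 0.05 * rng.standard_normal(500)
x[[50, 200, 350]] += 5.0  # inject three transient spikes

window_size, n_sigmas, mad_scale = 11, 3.0, 1.4826  # same defaults as HampelFilter
med = ndi.median_filter(x, size=window_size, mode="reflect")
mad = mad_scale * ndi.median_filter(np.abs(x - med), size=window_size, mode="reflect")
outliers = (mad > 1e-12) & (np.abs(x - med) > n_sigmas * mad)

cleaned = x.copy()
cleaned[outliers] = med[outliers]
print(int(outliers.sum()))  # expect roughly 3: only the injected spikes are flagged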
fibphot/stages/normalisation.py
@@ -0,0 +1,260 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Literal

import numpy as np

from ..state import PhotometryState
from .base import StageOutput, UpdateStage, _resolve_channels

NormaliseMethod = Literal["baseline", "z_score", "null_z"]
BaselineMode = Literal["dff", "percent"]
NullZScale = Literal["rms", "mad"]


def _window_mask(
    state: PhotometryState,
    time_window: tuple[float, float] | None,
) -> np.ndarray:
    if time_window is None:
        return np.ones(state.n_samples, dtype=bool)

    t0, t1 = time_window
    if t1 < t0:
        raise ValueError("time_window must satisfy t0 <= t1.")

    mask = (state.time_seconds >= t0) & (state.time_seconds <= t1)
    if not np.any(mask):
        raise ValueError(
            f"time_window={time_window} selects no samples; check your time range."
        )
    return mask


@dataclass(frozen=True, slots=True)
class Normalise(UpdateStage):
    """
    Normalise photometry signals using one of several common schemes.

    Use the class constructors for clarity:

        Normalise.baseline(...)
        Normalise.z_score(...)
        Normalise.null_z(...)

    Notes
    -----
    This stage always operates on `state.signals` as they currently stand.
    For baseline normalisation, this typically means you should run motion
    correction first so your signals represent dF.
    """

    name: str = field(default="normalise", init=False)

    method: NormaliseMethod = "baseline"
    channels: str | list[str] | None = None

    # baseline normalisation
    baseline_key: str | None = "double_exp_baseline"
    baseline_mode: BaselineMode = "percent"

    # z-score / null-z window
    time_window: tuple[float, float] | None = None
    ddof: int = 0

    # null-z options
    null_z_scale: NullZScale = "rms"
    mad_scale: float = 1.4826

    # numerical safety
    eps: float = 1e-12

    def __post_init__(self) -> None:
        if self.method == "baseline":
            if not self.baseline_key:
                raise ValueError(
                    "baseline_key must be set when method='baseline'."
                )
        else:
            # baseline parameters should not be used for non-baseline methods
            if self.baseline_key not in (None, "double_exp_baseline"):
                raise ValueError(
                    "baseline_key is only valid when method='baseline'. "
                    "Use Normalise.baseline(...)."
                )

        if self.ddof < 0:
            raise ValueError("ddof must be >= 0.")
        if self.eps <= 0:
            raise ValueError("eps must be > 0.")
        if self.mad_scale <= 0:
            raise ValueError("mad_scale must be > 0.")
        if self.time_window is not None:
            t0, t1 = self.time_window
            if t1 < t0:
                raise ValueError("time_window must satisfy t0 <= t1.")

    @classmethod
    def baseline(
        cls,
        *,
        baseline_key: str = "double_exp_baseline",
        mode: BaselineMode = "percent",
        channels: str | list[str] | None = None,
        eps: float = 1e-12,
    ) -> Normalise:
        return cls(
            method="baseline",
            channels=channels,
            baseline_key=baseline_key,
            baseline_mode=mode,
            eps=eps,
        )

    @classmethod
    def z_score(
        cls,
        *,
        channels: str | list[str] | None = None,
        time_window: tuple[float, float] | None = None,
        ddof: int = 0,
        eps: float = 1e-12,
    ) -> Normalise:
        return cls(
            method="z_score",
            channels=channels,
            time_window=time_window,
            ddof=ddof,
            eps=eps,
            baseline_key=None,
        )

    @classmethod
    def null_z(
        cls,
        *,
        channels: str | list[str] | None = None,
        time_window: tuple[float, float] | None = None,
        scale: NullZScale = "rms",
        mad_scale: float = 1.4826,
        eps: float = 1e-12,
    ) -> Normalise:
        return cls(
            method="null_z",
            channels=channels,
            time_window=time_window,
            null_z_scale=scale,
            mad_scale=mad_scale,
            eps=eps,
            baseline_key=None,
        )

    def _params_for_summary(self) -> dict[str, Any]:
        return {
            "method": self.method,
            "channels": self.channels if self.channels is not None else "all",
            "baseline_key": self.baseline_key,
            "baseline_mode": self.baseline_mode,
            "time_window": self.time_window,
            "ddof": self.ddof,
            "null_z_scale": self.null_z_scale,
            "mad_scale": self.mad_scale,
            "eps": self.eps,
        }

    def apply(self, state: PhotometryState) -> StageOutput:
        idxs = _resolve_channels(state, self.channels)
        new = state.signals.copy()

        if self.method == "baseline":
            assert self.baseline_key is not None  # for type-checkers

            if self.baseline_key not in state.derived:
                raise KeyError(
                    f"Baseline '{self.baseline_key}' not found in state.derived. "
                    "Run the stage that produces this baseline first."
                )

            baseline = np.asarray(state.derived[self.baseline_key], dtype=float)
            if baseline.shape != state.signals.shape:
                raise ValueError(
                    f"Baseline shape {baseline.shape} does not match signals shape "
                    f"{state.signals.shape}."
                )

            scale = 100.0 if self.baseline_mode == "percent" else 1.0
            for i in idxs:
                denom = baseline[i]
                denom = np.where(np.abs(denom) < self.eps, np.nan, denom)
                new[i] = scale * (new[i] / denom)

            return StageOutput(
                signals=new,
                results={
                    "method": "baseline",
                    "baseline_key": self.baseline_key,
                    "baseline_mode": self.baseline_mode,
                    "channels_normalised": idxs,
                },
            )

        mask = _window_mask(state, self.time_window)

        if self.method == "z_score":
            means: dict[str, float] = {}
            stds: dict[str, float] = {}

            for i in idxs:
                x = new[i]
                mu = float(np.nanmean(x[mask]))
                sd = float(np.nanstd(x[mask], ddof=self.ddof))
                if not np.isfinite(sd) or sd < self.eps:
                    raise ValueError(
                        f"Standard deviation too small/invalid for channel "
                        f"'{state.channel_names[i]}': {sd}."
                    )
                new[i] = (x - mu) / sd
                means[state.channel_names[i]] = mu
                stds[state.channel_names[i]] = sd

            return StageOutput(
                signals=new,
                results={
                    "method": "z_score",
                    "means": means,
                    "stds": stds,
                    "time_window": self.time_window,
                    "ddof": self.ddof,
                },
            )

        # null_z
        scales: dict[str, float] = {}
        for i in idxs:
            x = new[i]
            xm = x[mask]

            if self.null_z_scale == "rms":
                s0 = float(np.sqrt(np.nanmean(xm * xm)))
            else:
                s0 = float(self.mad_scale * np.nanmedian(np.abs(xm)))

            if not np.isfinite(s0) or s0 < self.eps:
                raise ValueError(
                    f"Null-Z scale too small/invalid for channel "
                    f"'{state.channel_names[i]}': {s0}."
                )

            new[i] = x / s0
            scales[state.channel_names[i]] = s0

        return StageOutput(
            signals=new,
            results={
                "method": "null_z",
                "null_z_scale": self.null_z_scale,
                "scales": scales,
                "time_window": self.time_window,
            },
        )
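To make the three schemes `Normalise.apply` implements concrete, here is a plain-NumPy sketch of the per-channel arithmetic: z_score centres and scales by window statistics, null_z divides by the window RMS (or scaled MAD) without centring, and baseline mode divides by a stored baseline (times 100 in "percent" mode). The sampling rate, window, and stand-in baseline below are illustrative assumptions, not package defaults:

import numpy as np

fs = 20.0
t = np.arange(0, 10, 1 / fs)
sig = 0.02 * np.random.default_rng(1).standard_normal(t.size)
sig[t > 5] += 0.1   # a sustained response after t = 5 s
mask = t <= 2.0     # plays the role of time_window=(0.0, 2.0)

# method="z_score": centre and scale by the baseline-window statistics
z = (sig - sig[mask].mean()) / sig[mask].std(ddof=0)

# method="null_z", scale="rms": no centring, divide by the window RMS
null_z = sig / np.sqrt(np.mean(sig[mask] ** 2))

# method="baseline", mode="percent": divide by a baseline array, times 100
baseline = np.full_like(sig, 1.0)  # stand-in for e.g. a fitted double-exponential
dff_percent = 100.0 * (sig / baseline)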
fibphot/stages/regression.py
@@ -0,0 +1,139 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Literal

import numpy as np

from ..fit.regression import fit_irls, fit_ols
from ..state import PhotometryState
from .base import StageOutput, UpdateStage, _resolve_channels

RegressionMethod = Literal["ols", "irls_tukey", "irls_huber"]


@dataclass(frozen=True, slots=True)
class IsosbesticRegression(UpdateStage):
    """
    Regress a control channel (typically isosbestic) onto one or more channels.

    For each target channel y and control x, fit:

        y ≈ intercept + slope * x

    Output:
        dF = y - y_hat

    Also stores:
        derived["motion_fit"] = y_hat (per channel; shape matches signals)

    Notes
    -----
    `motion_fit` is a nuisance estimate used for subtraction/diagnostics. It is
    not necessarily suitable as a denominator for dF/F, especially if signals
    have been detrended (e.g. double exponential subtraction).
    """

    name: str = field(default="isosbestic_regression", init=False)

    control: str = "iso"
    channels: str | list[str] | None = None

    method: RegressionMethod = "irls_tukey"
    include_intercept: bool = True

    # IRLS settings
    tuning_constant: float = 4.685
    max_iter: int = 100
    tol: float = 1e-10
    store_weights: bool = False

    def _params_for_summary(self) -> dict[str, Any]:
        return {
            "control": self.control,
            "channels": self.channels if self.channels is not None else "all",
            "method": self.method,
            "include_intercept": self.include_intercept,
            "tuning_constant": self.tuning_constant,
            "max_iter": self.max_iter,
            "tol": self.tol,
            "store_weights": self.store_weights,
        }

    def apply(self, state: PhotometryState) -> StageOutput:
        control_idx = state.idx(self.control)
        x = state.signals[control_idx]

        idxs = _resolve_channels(state, self.channels)
        idxs = [i for i in idxs if i != control_idx]
        if not idxs:
            raise ValueError(
                "No target channels remain after excluding the control channel."
            )

        new = state.signals.copy()
        motion_fit = np.full_like(state.signals, np.nan, dtype=float)

        per_channel: dict[str, dict[str, Any]] = {}
        r2s: list[float] = []

        for i in idxs:
            name = state.channel_names[i]
            y = state.signals[i]

            if self.method == "ols":
                fit = fit_ols(x, y, include_intercept=self.include_intercept)
                max_iter: int | None = None
            else:
                loss = "tukey" if self.method == "irls_tukey" else "huber"
                fit = fit_irls(
                    x,
                    y,
                    include_intercept=self.include_intercept,
                    loss=loss,
                    tuning_constant=self.tuning_constant,
                    max_iter=self.max_iter,
                    tol=self.tol,
                    store_weights=self.store_weights,
                )
                max_iter = self.max_iter

            y_hat = fit.fitted
            motion_fit[i] = y_hat
            new[i] = y - y_hat

            per_channel[name] = {
                "control": self.control,
                "intercept": fit.intercept,
                "slope": fit.slope,
                "r2": fit.r2,
                "method": fit.method,
                "n_iter": fit.n_iter,
                "max_iter": max_iter,
                "tuning_constant": fit.tuning_constant,
                "scale": fit.scale,
                "weights": fit.weights,
            }
            if np.isfinite(fit.r2):
                r2s.append(float(fit.r2))

        metrics: dict[str, float] = {}
        if r2s:
            metrics["mean_r2"] = float(np.mean(r2s))
            metrics["median_r2"] = float(np.median(r2s))

        return StageOutput(
            signals=new,
            derived={
                "motion_fit": motion_fit,
            },
            results={
                "control": self.control,
                "control_idx": control_idx,
                "channels_fitted": idxs,
                "method": self.method,
                "include_intercept": self.include_intercept,
                "channels": per_channel,
            },
            metrics=metrics,
        )
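The docstring above defines the model y ≈ intercept + slope * x and the output dF = y - y_hat. Below is a minimal sketch of the "ols" branch on synthetic data, with np.polyfit standing in for fit_ols (which lives in fibphot/fit/regression.py and is not part of this hunk); the IRLS branches fit the same line but iteratively reweight residuals with a Tukey or Huber loss before the identical subtraction:

import numpy as np

rng = np.random.default_rng(2)
n = 2000
iso = 0.01 * rng.standard_normal(n).cumsum()               # control channel (motion proxy)
signal = 0.8 * iso + 0.05 + 0.02 * rng.standard_normal(n)  # target contaminated by the control

slope, intercept = np.polyfit(iso, signal, deg=1)  # y ~ intercept + slope * x
y_hat = intercept + slope * iso                    # analogous to derived["motion_fit"]
dF = signal - y_hat                                # what the stage writes back into the signals

print(round(slope, 2), round(intercept, 2))  # close to the true 0.8 and 0.05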