fibphot 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,273 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass, field
4
+ from typing import Literal
5
+
6
+ import numpy as np
7
+ import scipy.ndimage as ndi
8
+ import scipy.signal
9
+
10
+ from ..state import PhotometryState
11
+ from ..types import FloatArray
12
+ from .base import StageOutput, UpdateStage, _resolve_channels
13
+
14
+
15
+ def _hampel_1d(
16
+ x: FloatArray,
17
+ window_size: int,
18
+ n_sigmas: float,
19
+ *,
20
+ mad_scale: float = 1.4826,
21
+ mode: str = "reflect",
22
+ match_edges: bool = True,
23
+ ) -> FloatArray:
24
+ """
25
+ Fast Hampel filter using rolling medians.
26
+
27
+ Parameters
28
+ ----------
29
+ mad_scale:
30
+ Scale factor so MAD estimates standard deviation under Normal noise.
31
+ mode:
32
+ Padding strategy for the rolling median ('reflect', 'nearest', ...).
33
+ match_edges:
34
+ If True, applies "shrinking window" behaviour at first/last k samples.
35
+ """
36
+
37
+ if window_size < 3:
38
+ raise ValueError("window_size must be >= 3.")
39
+ if window_size % 2 == 0:
40
+ window_size += 1
41
+
42
+ x = np.asarray(x, dtype=float)
43
+ n = int(x.shape[0])
44
+ k = window_size // 2
45
+
46
+ # Rolling median
47
+ med = ndi.median_filter(x, size=window_size, mode=mode)
48
+
49
+ # Rolling MAD = median(|x - med|)
50
+ abs_dev = np.abs(x - med)
51
+ mad = mad_scale * ndi.median_filter(abs_dev, size=window_size, mode=mode)
52
+
53
+ out = x.copy()
54
+ mask = (mad > 1e-12) & (abs_dev > (n_sigmas * mad))
55
+ out[mask] = med[mask]
56
+
57
+ if match_edges and k > 0 and n > 0:
58
+ left = min(k, n)
59
+ for i in range(left):
60
+ lo = 0
61
+ hi = min(n, i + k + 1)
62
+ w = x[lo:hi]
63
+ m = float(np.median(w))
64
+ s = mad_scale * float(np.median(np.abs(w - m)))
65
+ if s > 1e-12 and abs(x[i] - m) > n_sigmas * s:
66
+ out[i] = m
67
+
68
+ right_start = max(0, n - k)
69
+ for i in range(right_start, n):
70
+ lo = max(0, i - k)
71
+ hi = n
72
+ w = x[lo:hi]
73
+ m = float(np.median(w))
74
+ s = mad_scale * float(np.median(np.abs(w - m)))
75
+ if s > 1e-12 and abs(x[i] - m) > n_sigmas * s:
76
+ out[i] = m
77
+
78
+ return out
79
+
80
+
81
@dataclass(frozen=True, slots=True)
class HampelFilter(UpdateStage):
    """
    Replace outlier samples in the selected channels with the local median.

    This implementation is fast: rolling medians come from SciPy rather
    than a per-sample Python loop.

    Parameters
    ----------
    window_size:
        Size of the moving window (forced odd and >= 3).
    n_sigmas:
        Threshold in units of (scaled) MAD.
    channels:
        "all", a channel name, a list of names, or None for all.
    mad_scale:
        Scale factor converting MAD to sigma under Normal noise.
    mode:
        Padding mode used by the rolling median.
    match_edges:
        If True, uses shrinking-window behaviour at the edges.

    Context
    -------
    The Hampel filter is a robust outlier detector/corrector for time series:
    samples that deviate strongly from their local median are replaced by
    that median, which removes transient spikes without distorting the
    underlying signal.

    Unlike `MedianFilter`, which rewrites every sample with the median of
    its neighbours, this stage only alters samples it flags as outliers and
    leaves the rest of the signal untouched.
    """

    name: str = field(default="hampel_filter", init=False)

    window_size: int = 11
    n_sigmas: float = 3.0
    channels: str | list[str] | None = None

    mad_scale: float = 1.4826
    mode: str = "reflect"
    match_edges: bool = True

    def _params_for_summary(self) -> dict[str, object]:
        return dict(
            window_size=self.window_size,
            n_sigmas=self.n_sigmas,
            channels="all" if self.channels is None else self.channels,
            mad_scale=self.mad_scale,
            mode=self.mode,
            match_edges=self.match_edges,
        )

    def apply(self, state: PhotometryState) -> StageOutput:
        # Filter each requested channel independently on a copy of the
        # signal matrix; untouched channels pass through unchanged.
        filtered = state.signals.copy()
        for idx in _resolve_channels(state, self.channels):
            filtered[idx] = _hampel_1d(
                filtered[idx],
                self.window_size,
                self.n_sigmas,
                mad_scale=self.mad_scale,
                mode=self.mode,
                match_edges=self.match_edges,
            )
        return StageOutput(signals=filtered)
150
+
151
+
152
@dataclass(frozen=True, slots=True)
class MedianFilter(UpdateStage):
    """
    Replace each sample in the selected channels with its windowed median.

    `kernel_size` is coerced up to the next odd integer when even, as
    required by `scipy.signal.medfilt`.
    """

    name: str = field(default="median_filter", init=False)
    kernel_size: int = 5
    channels: str | list[str] | None = None

    def _params_for_summary(self) -> dict[str, object]:
        return dict(
            kernel_size=self.kernel_size,
            channels="all" if self.channels is None else self.channels,
        )

    def apply(self, state: PhotometryState) -> StageOutput:
        # medfilt requires an odd kernel; bump even sizes up by one.
        if self.kernel_size % 2:
            kernel = self.kernel_size
        else:
            kernel = self.kernel_size + 1

        smoothed = state.signals.copy()
        for idx in _resolve_channels(state, self.channels):
            smoothed[idx] = scipy.signal.medfilt(smoothed[idx], kernel_size=kernel)

        return StageOutput(signals=smoothed)
173
+
174
+
175
@dataclass(frozen=True, slots=True)
class LowPassFilter(UpdateStage):
    """
    Applies a zero-phase low-pass Butterworth filter to specified channels.

    Parameters
    ----------
    critical_frequency : float
        The critical frequency (in Hz) for the low-pass filter. This is where
        the filter begins to attenuate higher frequencies.
    order : int
        The order of the Butterworth filter. Higher order filters have a
        steeper roll-off.
    sampling_rate : float | None
        The sampling rate (in Hz) of the input signals. If None, uses the
        sampling rate from the PhotometryState.
    channels : str | list[str] | None
        The channels to which the filter should be applied. Can be "all", a
        single channel name, or a list of channel names. If None, defaults to
        "all".
    representation : Literal["sos", "ba"]
        The filter representation to use. "sos" for second-order sections
        (numerically stable), or "ba" for (b, a) coefficients.

    Context
    -------
    Biosensor kinetics typically operate on slower (e.g., sub-second) timescales
    relative to higher-frequency electrical noise. A low-pass filter keeps low
    frequencies and attenuates high frequencies.
    """

    name: str = field(default="low_pass_filter", init=False)
    critical_frequency: float = 10.0
    order: int = 2
    sampling_rate: float | None = None
    channels: str | list[str] | None = None
    representation: Literal["sos", "ba"] = "sos"

    def _params_for_summary(self) -> dict[str, object]:
        return {
            "critical_frequency": self.critical_frequency,
            "order": self.order,
            "sampling_rate": self.sampling_rate,
            "channels": self.channels if self.channels is not None else "all",
            "representation": self.representation,
        }

    def apply(self, state: PhotometryState) -> StageOutput:
        """
        Apply the filter forward-backward (zero phase) to each channel.

        Raises
        ------
        ValueError
            If `critical_frequency` is not in the open interval (0, fs/2),
            or if `representation` is not "sos" or "ba".
        """
        fs = (
            state.sampling_rate
            if self.sampling_rate is None
            else float(self.sampling_rate)
        )
        # Butterworth design is only meaningful below the Nyquist frequency.
        if not (0.0 < self.critical_frequency < 0.5 * fs):
            raise ValueError(
                "critical_frequency must be > 0 and < Nyquist (fs/2). "
                f"Got critical_frequency={self.critical_frequency}, fs={fs}."
            )
        if self.representation not in ("sos", "ba"):
            raise ValueError(f"Unknown representation: {self.representation!r}")

        idxs = _resolve_channels(state, self.channels)
        new = state.signals.copy()

        if self.representation == "sos":
            sos = scipy.signal.butter(
                N=self.order,
                Wn=self.critical_frequency,
                btype="low",
                fs=fs,
                output="sos",
            )
            # sosfiltfilt runs the filter forward and backward for zero phase.
            for i in idxs:
                new[i] = scipy.signal.sosfiltfilt(sos, new[i])
        else:
            # butter(output="ba") always returns a (b, a) pair, so the previous
            # `res is None` check and `assert len(res) == 2` were dead code
            # (and `assert` is stripped under -O, making it unsafe validation).
            b, a = scipy.signal.butter(
                N=self.order,
                Wn=self.critical_frequency,
                btype="low",
                fs=fs,
                output="ba",
            )
            for i in idxs:
                new[i] = scipy.signal.filtfilt(b, a, new[i])

        return StageOutput(signals=new)
@@ -0,0 +1,260 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass, field
4
+ from typing import Any, Literal
5
+
6
+ import numpy as np
7
+
8
+ from ..state import PhotometryState
9
+ from .base import StageOutput, UpdateStage, _resolve_channels
10
+
11
+ NormaliseMethod = Literal["baseline", "z_score", "null_z"]
12
+ BaselineMode = Literal["dff", "percent"]
13
+ NullZScale = Literal["rms", "mad"]
14
+
15
+
16
+ def _window_mask(
17
+ state: PhotometryState,
18
+ time_window: tuple[float, float] | None,
19
+ ) -> np.ndarray:
20
+ if time_window is None:
21
+ return np.ones(state.n_samples, dtype=bool)
22
+
23
+ t0, t1 = time_window
24
+ if t1 < t0:
25
+ raise ValueError("time_window must satisfy t0 <= t1.")
26
+
27
+ mask = (state.time_seconds >= t0) & (state.time_seconds <= t1)
28
+ if not np.any(mask):
29
+ raise ValueError(
30
+ f"time_window={time_window} selects no samples; check your time range."
31
+ )
32
+ return mask
33
+
34
+
35
@dataclass(frozen=True, slots=True)
class Normalise(UpdateStage):
    """
    Normalise photometry signals using one of several common schemes.

    Use the class constructors for clarity:

        Normalise.baseline(...)
        Normalise.z_score(...)
        Normalise.null_z(...)

    Methods
    -------
    - "baseline": divide each signal by a baseline array stored in
      ``state.derived[baseline_key]`` (times 100 when ``baseline_mode`` is
      "percent").
    - "z_score": subtract the window mean and divide by the window standard
      deviation (NaN-aware).
    - "null_z": divide by a zero-centred scale estimate over the window —
      RMS or scaled median-absolute-value.

    Notes
    -----
    This stage always operates on `state.signals` as they currently stand.
    For baseline normalisation, this typically means you should run motion
    correction first so your signals represent dF.
    """

    name: str = field(default="normalise", init=False)

    # Which normalisation scheme to apply and to which channels.
    method: NormaliseMethod = "baseline"
    channels: str | list[str] | None = None

    # baseline normalisation
    baseline_key: str | None = "double_exp_baseline"
    baseline_mode: BaselineMode = "percent"

    # z-score / null-z window
    time_window: tuple[float, float] | None = None
    ddof: int = 0

    # null-z options
    null_z_scale: NullZScale = "rms"
    mad_scale: float = 1.4826

    # numerical safety
    eps: float = 1e-12

    def __post_init__(self) -> None:
        # Validate eagerly so a misconfigured stage fails at construction
        # rather than mid-pipeline.
        if self.method == "baseline":
            if not self.baseline_key:
                raise ValueError(
                    "baseline_key must be set when method='baseline'."
                )
        else:
            # baseline parameters should not be used for non-baseline methods
            # (the field default "double_exp_baseline" is tolerated so the
            # generic constructor still works without explicitly passing None).
            if self.baseline_key not in (None, "double_exp_baseline"):
                raise ValueError(
                    "baseline_key is only valid when method='baseline'. "
                    "Use Normalise.baseline(...)."
                )

        if self.ddof < 0:
            raise ValueError("ddof must be >= 0.")
        if self.eps <= 0:
            raise ValueError("eps must be > 0.")
        if self.mad_scale <= 0:
            raise ValueError("mad_scale must be > 0.")
        if self.time_window is not None:
            t0, t1 = self.time_window
            if t1 < t0:
                raise ValueError("time_window must satisfy t0 <= t1.")

    @classmethod
    def baseline(
        cls,
        *,
        baseline_key: str = "double_exp_baseline",
        mode: BaselineMode = "percent",
        channels: str | list[str] | None = None,
        eps: float = 1e-12,
    ) -> Normalise:
        # Convenience constructor for baseline (dF/F or percent) normalisation.
        return cls(
            method="baseline",
            channels=channels,
            baseline_key=baseline_key,
            baseline_mode=mode,
            eps=eps,
        )

    @classmethod
    def z_score(
        cls,
        *,
        channels: str | list[str] | None = None,
        time_window: tuple[float, float] | None = None,
        ddof: int = 0,
        eps: float = 1e-12,
    ) -> Normalise:
        # Convenience constructor for classic z-scoring over an optional window.
        return cls(
            method="z_score",
            channels=channels,
            time_window=time_window,
            ddof=ddof,
            eps=eps,
            baseline_key=None,
        )

    @classmethod
    def null_z(
        cls,
        *,
        channels: str | list[str] | None = None,
        time_window: tuple[float, float] | None = None,
        scale: NullZScale = "rms",
        mad_scale: float = 1.4826,
        eps: float = 1e-12,
    ) -> Normalise:
        # Convenience constructor for scale-only ("null z") normalisation:
        # no mean subtraction, just division by an RMS or MAD-based scale.
        return cls(
            method="null_z",
            channels=channels,
            time_window=time_window,
            null_z_scale=scale,
            mad_scale=mad_scale,
            eps=eps,
            baseline_key=None,
        )

    def _params_for_summary(self) -> dict[str, Any]:
        return {
            "method": self.method,
            "channels": self.channels if self.channels is not None else "all",
            "baseline_key": self.baseline_key,
            "baseline_mode": self.baseline_mode,
            "time_window": self.time_window,
            "ddof": self.ddof,
            "null_z_scale": self.null_z_scale,
            "mad_scale": self.mad_scale,
            "eps": self.eps,
        }

    def apply(self, state: PhotometryState) -> StageOutput:
        idxs = _resolve_channels(state, self.channels)
        new = state.signals.copy()

        if self.method == "baseline":
            assert self.baseline_key is not None  # for type-checkers

            if self.baseline_key not in state.derived:
                raise KeyError(
                    f"Baseline '{self.baseline_key}' not found in state.derived. "
                    "Run the stage that produces this baseline first."
                )

            baseline = np.asarray(state.derived[self.baseline_key], dtype=float)
            if baseline.shape != state.signals.shape:
                raise ValueError(
                    f"Baseline shape {baseline.shape} does not match signals shape "
                    f"{state.signals.shape}."
                )

            # "percent" reports 100 * signal/baseline; "dff" leaves the raw ratio.
            scale = 100.0 if self.baseline_mode == "percent" else 1.0
            for i in idxs:
                denom = baseline[i]
                # Near-zero baseline samples would blow up the ratio; emit
                # NaN there instead of dividing by ~0.
                denom = np.where(np.abs(denom) < self.eps, np.nan, denom)
                new[i] = scale * (new[i] / denom)

            return StageOutput(
                signals=new,
                results={
                    "method": "baseline",
                    "baseline_key": self.baseline_key,
                    "baseline_mode": self.baseline_mode,
                    "channels_normalised": idxs,
                },
            )

        # z_score and null_z both compute their statistics over an optional
        # time window (all samples when time_window is None).
        mask = _window_mask(state, self.time_window)

        if self.method == "z_score":
            means: dict[str, float] = {}
            stds: dict[str, float] = {}

            for i in idxs:
                x = new[i]
                # NaN-aware statistics over the windowed samples only.
                mu = float(np.nanmean(x[mask]))
                sd = float(np.nanstd(x[mask], ddof=self.ddof))
                if not np.isfinite(sd) or sd < self.eps:
                    raise ValueError(
                        f"Standard deviation too small/invalid for channel "
                        f"'{state.channel_names[i]}': {sd}."
                    )
                new[i] = (x - mu) / sd
                means[state.channel_names[i]] = mu
                stds[state.channel_names[i]] = sd

            return StageOutput(
                signals=new,
                results={
                    "method": "z_score",
                    "means": means,
                    "stds": stds,
                    "time_window": self.time_window,
                    "ddof": self.ddof,
                },
            )

        # null_z
        scales: dict[str, float] = {}
        for i in idxs:
            x = new[i]
            xm = x[mask]

            # Scale estimate is taken around zero (no mean subtraction):
            # RMS of the windowed samples, or scaled median absolute value.
            if self.null_z_scale == "rms":
                s0 = float(np.sqrt(np.nanmean(xm * xm)))
            else:
                s0 = float(self.mad_scale * np.nanmedian(np.abs(xm)))

            if not np.isfinite(s0) or s0 < self.eps:
                raise ValueError(
                    f"Null-Z scale too small/invalid for channel "
                    f"'{state.channel_names[i]}': {s0}."
                )

            new[i] = x / s0
            scales[state.channel_names[i]] = s0

        return StageOutput(
            signals=new,
            results={
                "method": "null_z",
                "null_z_scale": self.null_z_scale,
                "scales": scales,
                "time_window": self.time_window,
            },
        )
@@ -0,0 +1,139 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass, field
4
+ from typing import Any, Literal
5
+
6
+ import numpy as np
7
+
8
+ from ..fit.regression import fit_irls, fit_ols
9
+ from ..state import PhotometryState
10
+ from .base import StageOutput, UpdateStage, _resolve_channels
11
+
12
+ RegressionMethod = Literal["ols", "irls_tukey", "irls_huber"]
13
+
14
+
15
@dataclass(frozen=True, slots=True)
class IsosbesticRegression(UpdateStage):
    """
    Regress a control channel (typically isosbestic) onto one or more channels.

    For each target channel y and control x, fit:

        y ≈ intercept + slope * x

    Output:
        dF = y - y_hat

    Also stores:
        derived["motion_fit"] = y_hat (per channel; shape matches signals)

    Notes
    -----
    `motion_fit` is a nuisance estimate used for subtraction/diagnostics. It is
    not necessarily suitable as a denominator for dF/F, especially if signals
    have been detrended (e.g. double exponential subtraction).
    """

    name: str = field(default="isosbestic_regression", init=False)

    control: str = "iso"
    channels: str | list[str] | None = None

    method: RegressionMethod = "irls_tukey"
    include_intercept: bool = True

    # IRLS settings
    tuning_constant: float = 4.685
    max_iter: int = 100
    tol: float = 1e-10
    store_weights: bool = False

    def _params_for_summary(self) -> dict[str, Any]:
        return dict(
            control=self.control,
            channels="all" if self.channels is None else self.channels,
            method=self.method,
            include_intercept=self.include_intercept,
            tuning_constant=self.tuning_constant,
            max_iter=self.max_iter,
            tol=self.tol,
            store_weights=self.store_weights,
        )

    def apply(self, state: PhotometryState) -> StageOutput:
        # Resolve the control channel and exclude it from the targets so we
        # never regress the control onto itself.
        ctrl_idx = state.idx(self.control)
        ctrl = state.signals[ctrl_idx]

        targets = [
            i for i in _resolve_channels(state, self.channels) if i != ctrl_idx
        ]
        if not targets:
            raise ValueError(
                "No target channels remain after excluding the control channel."
            )

        corrected = state.signals.copy()
        # Rows for channels that were not fitted stay NaN.
        fitted_motion = np.full_like(state.signals, np.nan, dtype=float)

        per_channel: dict[str, dict[str, Any]] = {}
        r2_values: list[float] = []

        for idx in targets:
            channel_name = state.channel_names[idx]
            y = state.signals[idx]

            if self.method == "ols":
                fit = fit_ols(ctrl, y, include_intercept=self.include_intercept)
                iter_cap: int | None = None
            else:
                fit = fit_irls(
                    ctrl,
                    y,
                    include_intercept=self.include_intercept,
                    loss="tukey" if self.method == "irls_tukey" else "huber",
                    tuning_constant=self.tuning_constant,
                    max_iter=self.max_iter,
                    tol=self.tol,
                    store_weights=self.store_weights,
                )
                iter_cap = self.max_iter

            # Subtract the motion estimate; keep the estimate for diagnostics.
            fitted_motion[idx] = fit.fitted
            corrected[idx] = y - fit.fitted

            per_channel[channel_name] = {
                "control": self.control,
                "intercept": fit.intercept,
                "slope": fit.slope,
                "r2": fit.r2,
                "method": fit.method,
                "n_iter": fit.n_iter,
                "max_iter": iter_cap,
                "tuning_constant": fit.tuning_constant,
                "scale": fit.scale,
                "weights": fit.weights,
            }
            if np.isfinite(fit.r2):
                r2_values.append(float(fit.r2))

        metrics: dict[str, float] = {}
        if r2_values:
            metrics["mean_r2"] = float(np.mean(r2_values))
            metrics["median_r2"] = float(np.median(r2_values))

        return StageOutput(
            signals=corrected,
            derived={
                "motion_fit": fitted_motion,
            },
            results={
                "control": self.control,
                "control_idx": ctrl_idx,
                "channels_fitted": targets,
                "method": self.method,
                "include_intercept": self.include_intercept,
                "channels": per_channel,
            },
            metrics=metrics,
        )
+ )