fibphot 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,354 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass, field
4
+ from typing import Any
5
+
6
+ import numpy as np
7
+ import scipy.optimize
8
+ import scipy.signal
9
+
10
+ from ..state import PhotometryState
11
+ from ..types import FloatArray
12
+ from .base import StageOutput, UpdateStage, _resolve_channels
13
+
14
+
15
def double_exponential(
    times: FloatArray,
    const: float,
    amp_fast: float,
    amp_slow: float,
    tau_fast: float,
    tau_slow: float,
) -> FloatArray:
    """Evaluate a constant-plus-two-exponentials decay model.

    Computes ``const + amp_slow * exp(-t / tau_slow) + amp_fast * exp(-t / tau_fast)``
    element-wise over *times*. Used as the model function for baseline
    curve fitting.
    """
    t = np.asarray(times, dtype=float)
    slow_term = amp_slow * np.exp(-t / tau_slow)
    fast_term = amp_fast * np.exp(-t / tau_fast)
    return const + slow_term + fast_term
30
+
31
+
32
+ def _r2(y: FloatArray, yhat: FloatArray) -> float:
33
+ """Coefficient of determination R^2."""
34
+ y = np.asarray(y, dtype=float)
35
+ yhat = np.asarray(yhat, dtype=float)
36
+ ss_res = float(np.sum((y - yhat) ** 2))
37
+ ss_tot = float(np.sum((y - np.mean(y)) ** 2))
38
+ if ss_tot < 1e-20:
39
+ return float("nan")
40
+ return 1.0 - ss_res / ss_tot
41
+
42
+
43
+ def _rmse(y: FloatArray, yhat: FloatArray) -> float:
44
+ """Root mean square error."""
45
+ y = np.asarray(y, dtype=float)
46
+ yhat = np.asarray(yhat, dtype=float)
47
+ return float(np.sqrt(np.mean((y - yhat) ** 2)))
48
+
49
+
50
+ def _initial_guess(y: FloatArray) -> list[float]:
51
+ """Initial guess for double exponential fitting."""
52
+ y = np.asarray(y, dtype=float)
53
+ n = y.shape[0]
54
+ tail = y[int(n * 0.9) :] if n >= 10 else y
55
+ const = float(np.median(tail))
56
+ amp = float(max(y[0] - const, 1e-6))
57
+ return [const, amp * 0.6, amp * 0.4, 300.0, 3000.0]
58
+
59
+
60
@dataclass(frozen=True, slots=True)
class DoubleExpBaseline(UpdateStage):
    """Fit a double-exponential photobleaching baseline per channel.

    For each selected channel the model
    ``const + amp_fast * exp(-t/tau_fast) + amp_slow * exp(-t/tau_slow)``
    is fitted with bounded least squares (optionally on a decimated copy of
    the signal for speed). The baseline evaluated on the full time axis is
    stored in ``derived['double_exp_baseline']``; when ``subtract`` is True
    it is also subtracted from the signals.
    """

    name: str = field(default="double_exp_baseline", init=False)

    # Subtract the fitted baseline from the signals (otherwise pass-through).
    subtract: bool = False
    # Channel selection; None means all channels.
    channels: str | list[str] | None = None
    # Target sampling rate for fitting; None or <= 0 disables decimation.
    decimate_to_hz: float | None = None
    # Max function evaluations passed to scipy.optimize.curve_fit.
    maxfev: int = 2000

    # (lower, upper) bounds for the two time constants.
    tau_fast_bounds: tuple[float, float] = (60.0, 600.0)
    tau_slow_bounds: tuple[float, float] = (600.0, 36000.0)

    def _params_for_summary(self) -> dict[str, object]:
        """Stage configuration recorded in the pipeline summary."""
        return {
            "subtract": self.subtract,
            "channels": self.channels if self.channels is not None else "all",
            "decimate_to_hz": self.decimate_to_hz,
            "maxfev": self.maxfev,
            "tau_fast_bounds": self.tau_fast_bounds,
            "tau_slow_bounds": self.tau_slow_bounds,
        }

    def apply(self, state: PhotometryState) -> StageOutput:
        """Fit the baseline for each selected channel and report fit metrics."""
        idxs = _resolve_channels(state, self.channels)
        t_full = state.time_seconds

        baseline = np.zeros_like(state.signals)
        params_out = np.full((state.n_signals, 5), np.nan, dtype=float)
        r2_out = np.full((state.n_signals,), np.nan, dtype=float)
        rmse_out = np.full((state.n_signals,), np.nan, dtype=float)

        fs = state.sampling_rate
        if self.decimate_to_hz is None or self.decimate_to_hz <= 0:
            decim = 1
        else:
            decim = max(1, int(round(fs / float(self.decimate_to_hz))))

        for i in idxs:
            y_full = state.signals[i]

            if decim > 1:
                # decimate both for fitting.
                y_fit = scipy.signal.decimate(
                    y_full, decim, ftype="fir", zero_phase=True
                )
                t_fit = t_full[::decim]

                # align lengths conservatively
                m = min(t_fit.shape[0], y_fit.shape[0])
                t_fit = t_fit[:m]
                y_fit = y_fit[:m]
            else:
                t_fit = t_full
                y_fit = y_full

            # Keep the amplitude upper bound strictly positive so (lo, hi)
            # remains a valid box even for all-negative/zero signals, where
            # np.max(y_fit) <= 0 would otherwise make hi <= lo and crash
            # curve_fit with "Each lower bound must be strictly less than
            # each upper bound".
            y_max = max(float(np.max(y_fit)), 1e-6)
            lo = [
                0.0,
                0.0,
                0.0,
                self.tau_fast_bounds[0],
                self.tau_slow_bounds[0],
            ]
            hi = [
                y_max,
                y_max,
                y_max,
                self.tau_fast_bounds[1],
                self.tau_slow_bounds[1],
            ]

            # Clip the heuristic start point into the feasible box: curve_fit
            # raises "x0 is infeasible" if p0 violates bounds (e.g. custom tau
            # bounds excluding the fixed 300/3000 guesses, or a negative
            # constant estimate from a detrended signal).
            guess = np.clip(_initial_guess(y_fit), lo, hi).tolist()

            popt, _ = scipy.optimize.curve_fit(
                f=double_exponential,
                xdata=t_fit,
                ydata=y_fit,
                p0=guess,
                bounds=(lo, hi),
                maxfev=self.maxfev,
            )

            # Goodness-of-fit is evaluated on the (possibly decimated)
            # fitting grid; the baseline itself on the full time axis.
            yhat_fit = double_exponential(t_fit, *popt)
            r2_out[i] = _r2(y_fit, yhat_fit)
            rmse_out[i] = _rmse(y_fit, yhat_fit)

            params_out[i] = popt
            baseline[i] = double_exponential(t_full, *popt)

        new_signals = state.signals.copy()
        if self.subtract:
            for i in idxs:
                new_signals[i] = new_signals[i] - baseline[i]

        metrics = {
            "mean_r2": float(np.nanmean(r2_out[idxs]))
            if idxs
            else float("nan"),
            "mean_rmse": float(np.nanmean(rmse_out[idxs]))
            if idxs
            else float("nan"),
            "decimate_factor": float(decim),
        }

        results = {
            "params": params_out,
            "r2": r2_out,
            "rmse": rmse_out,
            # plain ints, consistent with the other baseline stages
            "channels_fitted": [int(i) for i in idxs],
        }

        notes = (
            "Fitted double exponential baseline; parameters are stored per "
            "channel. Baseline curves are available in "
            "derived['double_exp_baseline']."
        )

        return StageOutput(
            signals=new_signals,
            derived={"double_exp_baseline": baseline},
            results=results,
            metrics=metrics,
            notes=notes,
        )
184
+
185
+
186
+ def _summarise_pybaselines_params(params: dict[str, Any]) -> dict[str, Any]:
187
+ """
188
+ Make pybaselines params safe to store in state.results:
189
+ - keep scalars
190
+ - keep very small arrays/lists
191
+ - otherwise store a short descriptor
192
+ """
193
+ out: dict[str, Any] = {}
194
+
195
+ for k, v in params.items():
196
+ if isinstance(v, (str, int, float, bool)) or v is None:
197
+ out[k] = v
198
+ continue
199
+
200
+ if isinstance(v, np.ndarray):
201
+ if v.size <= 16:
202
+ out[k] = np.asarray(v).tolist()
203
+ else:
204
+ out[k] = {
205
+ "type": "ndarray",
206
+ "shape": v.shape,
207
+ "dtype": str(v.dtype),
208
+ }
209
+ continue
210
+
211
+ if isinstance(v, (list, tuple)):
212
+ if len(v) <= 16 and all(
213
+ isinstance(x, (int, float, str, bool)) for x in v
214
+ ):
215
+ out[k] = list(v)
216
+ else:
217
+ out[k] = {"type": type(v).__name__, "len": len(v)}
218
+ continue
219
+
220
+ out[k] = repr(v)[:200]
221
+
222
+ return out
223
+
224
+
225
@dataclass(frozen=True, slots=True)
class PyBaselinesBaseline(UpdateStage):
    """
    Generic baseline estimation using `pybaselines`.

    pybaselines uses a Baseline(x_data=...) object; each algorithm is called as:
        baseline, params = baseline_fitter.<method>(y, **kwargs)

    This stage:
    - computes a baseline per selected channel
    - stores a full (n_signals, n_samples) baseline array in state.derived
    - optionally subtracts baseline from the signal(s)
    """

    name: str = field(default="pybaselines_baseline", init=False)

    method: str = "asls"
    method_kwargs: dict[str, Any] = field(default_factory=dict)

    channels: str | list[str] | None = None

    # x-axis passed to Baseline(x_data=...)
    x_axis: str = "time"  # "time" or "index"

    # where to store the baseline in derived
    baseline_key: str | None = None

    # apply correction?
    subtract: bool = False

    # store full params (can be large); default stores a safe summary
    store_full_params: bool = False

    def _params_for_summary(self) -> dict[str, Any]:
        """Stage configuration recorded in the pipeline summary."""
        channel_spec = "all" if self.channels is None else self.channels
        return {
            "method": self.method,
            "method_kwargs": self.method_kwargs,
            "channels": channel_spec,
            "x_axis": self.x_axis,
            "baseline_key": self.baseline_key,
            "subtract": self.subtract,
            "store_full_params": self.store_full_params,
        }

    def apply(self, state: PhotometryState) -> StageOutput:
        """Estimate (and optionally subtract) a baseline for each selected channel."""
        try:
            from pybaselines import Baseline
        except Exception as exc:  # pragma: no cover
            raise ImportError(
                "pybaselines is required for PyBaselinesBaseline. "
                "Install it with `pip install pybaselines`."
            ) from exc

        idxs = _resolve_channels(state, self.channels)

        if self.x_axis == "time":
            x_data = np.asarray(state.time_seconds, dtype=float)
        elif self.x_axis == "index":
            x_data = np.arange(state.n_samples, dtype=float)
        else:
            raise ValueError("x_axis must be 'time' or 'index'.")

        fitter = Baseline(x_data=x_data)
        if not hasattr(fitter, self.method):
            raise ValueError(
                f"Unknown pybaselines method '{self.method}'. "
                "See pybaselines docs for available algorithms."
            )
        algorithm = getattr(fitter, self.method)

        # full-shape baseline so downstream code can rely on shape == signals.shape
        baseline = np.full_like(state.signals, np.nan, dtype=float)
        corrected = state.signals.copy()

        per_channel: dict[str, dict[str, Any]] = {}

        for i in idxs:
            y = np.asarray(state.signals[i], dtype=float)
            estimate, params = algorithm(y, **self.method_kwargs)

            # ensure dtype/shape sanity
            estimate = np.asarray(estimate, dtype=float)
            if estimate.shape != y.shape:
                raise ValueError(
                    f"pybaselines returned baseline shape {estimate.shape}, expected {y.shape}."
                )

            baseline[i] = estimate
            if self.subtract:
                corrected[i] = y - estimate

            stored_params = (
                params
                if self.store_full_params
                else _summarise_pybaselines_params(params)
            )
            per_channel[state.channel_names[i]] = {
                "method": self.method,
                "method_kwargs": dict(self.method_kwargs),
                "params": stored_params,
            }

        if self.baseline_key is not None:
            key = self.baseline_key
        else:
            key = f"pybaselines_{self.method}"

        notes = (
            f"Computed baseline using pybaselines.{self.method}; "
            f"stored in derived['{key}']."
            + (" Subtracted from signals." if self.subtract else "")
        )

        return StageOutput(
            signals=corrected if self.subtract else None,
            derived={key: baseline},
            results={
                "baseline_key": key,
                "method": self.method,
                "method_kwargs": dict(self.method_kwargs),
                "channels_fitted": [int(i) for i in idxs],
                "channels": per_channel,
                "x_axis": self.x_axis,
                "subtract": self.subtract,
            },
            notes=notes,
        )
@@ -0,0 +1,214 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass, field
4
+ from typing import Any, Literal
5
+
6
+ import numpy as np
7
+
8
+ from ..fit.regression import fit_irls, fit_ols
9
+ from ..state import PhotometryState
10
+ from .base import StageOutput, UpdateStage, _resolve_channels
11
+
12
# Closed sets of accepted configuration strings for the stages below.
RegressionMethod = Literal["ols", "irls_tukey", "irls_huber"]  # regression backend
DffMode = Literal["dff", "percent", "df"]  # scaling applied to the corrected trace
DenomPolicy = Literal["nan", "clip", "raise"]  # near-zero denominator handling
15
+
16
+
17
@dataclass(frozen=True, slots=True)
class IsosbesticDff(UpdateStage):
    """
    Combined control-fit normalisation workflow.

    Fits a control channel x (typically isosbestic) to each target channel y:

        y ≈ intercept + slope * x = y_hat

    Then computes:
        df: y - y_hat
        dff: (y - y_hat) / y_hat
        percent: 100 * (y - y_hat) / y_hat

    This workflow is only numerically stable if y_hat does not spend meaningful
    time near zero (e.g. when fitting raw-ish, positive signals). If signals have
    already been detrended to be near zero (e.g. double-exp subtract=True),
    prefer IsosbesticRegression (df) + Normalise.baseline(double_exp_baseline).
    """

    name: str = field(default="isosbestic_dff", init=False)

    # Name of the control channel regressed onto each target.
    control: str = "iso"
    # Target channel selection; None means all (the control is always excluded).
    channels: str | list[str] | None = None

    method: RegressionMethod = "irls_tukey"
    include_intercept: bool = True

    mode: DffMode = "percent"

    # IRLS settings
    tuning_constant: float = 4.685
    max_iter: int = 100
    tol: float = 1e-10
    store_weights: bool = False

    # denominator safety
    min_abs_denom: float = 1e-6
    denom_policy: DenomPolicy = "nan"
    max_near_zero_frac: float = 0.01

    # optional outputs
    store_fit: bool = True
    store_df: bool = False

    def _params_for_summary(self) -> dict[str, Any]:
        """Stage configuration recorded in the pipeline summary."""
        return {
            "control": self.control,
            "channels": self.channels if self.channels is not None else "all",
            "method": self.method,
            "include_intercept": self.include_intercept,
            "mode": self.mode,
            "tuning_constant": self.tuning_constant,
            "max_iter": self.max_iter,
            "tol": self.tol,
            "store_weights": self.store_weights,
            "min_abs_denom": self.min_abs_denom,
            "denom_policy": self.denom_policy,
            "max_near_zero_frac": self.max_near_zero_frac,
            "store_fit": self.store_fit,
            "store_df": self.store_df,
        }

    def apply(self, state: PhotometryState) -> StageOutput:
        """Fit the control to each target channel and normalise the targets.

        Raises:
            ValueError: if no target channels remain after excluding the
                control, or if ``denom_policy == "raise"`` and the fitted
                denominator is near zero for more than ``max_near_zero_frac``
                of the samples of any channel.
        """
        control_idx = state.idx(self.control)
        x = state.signals[control_idx]

        idxs = _resolve_channels(state, self.channels)
        # The control channel never normalises itself.
        idxs = [i for i in idxs if i != control_idx]
        if not idxs:
            raise ValueError(
                "No target channels remain after excluding the control channel."
            )

        new = state.signals.copy()

        fit_mat = (
            np.full_like(state.signals, np.nan, dtype=float)
            if self.store_fit
            else None
        )
        df_mat = (
            np.full_like(state.signals, np.nan, dtype=float)
            if self.store_df
            else None
        )

        per_channel: dict[str, dict[str, Any]] = {}
        r2s: list[float] = []

        for i in idxs:
            name = state.channel_names[i]
            y = state.signals[i]

            if self.method == "ols":
                fit = fit_ols(x, y, include_intercept=self.include_intercept)
                max_iter: int | None = None
            else:
                loss = "tukey" if self.method == "irls_tukey" else "huber"
                fit = fit_irls(
                    x,
                    y,
                    include_intercept=self.include_intercept,
                    loss=loss,
                    tuning_constant=self.tuning_constant,
                    max_iter=self.max_iter,
                    tol=self.tol,
                    store_weights=self.store_weights,
                )
                max_iter = self.max_iter

            y_hat = np.asarray(fit.fitted, dtype=float)
            df = y - y_hat

            if fit_mat is not None:
                fit_mat[i] = y_hat
            if df_mat is not None:
                df_mat[i] = df

            if self.mode == "df":
                # Pure subtraction: no denominator involved.
                corrected = df
                denom_used = None
                near_zero_frac = 0.0
            else:
                abs_d = np.abs(y_hat)
                near_zero = abs_d < self.min_abs_denom
                near_zero_frac = float(np.mean(near_zero))

                if near_zero_frac > self.max_near_zero_frac:
                    msg = (
                        f"Denominator y_hat is near zero for {near_zero_frac:.2%} "
                        f"of samples in channel '{name}'. This usually indicates "
                        "you ran this after detrending/subtraction (signals ~ 0). "
                        "Use IsosbesticRegression (df) + Normalise.baseline(...) "
                        "instead, or run IsosbesticDff earlier on raw-ish signals."
                    )
                    if self.denom_policy == "raise":
                        raise ValueError(msg)

                # NOTE(review): with denom_policy == "raise" and a sub-threshold
                # near-zero fraction, the remaining near-zero samples fall
                # through to the clip branch below — confirm this fallback is
                # the intended policy.
                if self.denom_policy == "nan":
                    denom = np.where(near_zero, np.nan, y_hat)
                else:
                    # Clip near-zero denominators to +/- min_abs_denom. Use an
                    # explicit +/-1 sign: np.sign(0.0) is 0, which would leave
                    # an exactly-zero y_hat with a zero denominator (inf/nan
                    # output) despite the clip policy.
                    signs = np.where(y_hat >= 0.0, 1.0, -1.0)
                    denom = np.where(
                        near_zero,
                        signs * self.min_abs_denom,
                        y_hat,
                    )

                scale = 100.0 if self.mode == "percent" else 1.0
                corrected = scale * (df / denom)
                denom_used = "y_hat"

            new[i] = corrected

            per_channel[name] = {
                "control": self.control,
                "intercept": fit.intercept,
                "slope": fit.slope,
                "r2": fit.r2,
                "method": fit.method,
                "n_iter": fit.n_iter,
                "max_iter": max_iter,
                "tuning_constant": fit.tuning_constant,
                "scale": fit.scale,
                "weights": fit.weights,
                "denom_used": denom_used,
                "near_zero_frac": float(near_zero_frac),
                "min_abs_denom": float(self.min_abs_denom),
                "denom_policy": self.denom_policy,
            }
            if np.isfinite(fit.r2):
                r2s.append(float(fit.r2))

        metrics: dict[str, float] = {}
        if r2s:
            metrics["mean_r2"] = float(np.mean(r2s))
            metrics["median_r2"] = float(np.median(r2s))

        derived: dict[str, np.ndarray] = {}
        if fit_mat is not None:
            derived["control_fit"] = fit_mat
        if df_mat is not None:
            derived["control_df"] = df_mat

        return StageOutput(
            signals=new,
            derived=derived or None,
            results={
                "control": self.control,
                "control_idx": control_idx,
                # plain ints, consistent with the other stages' results
                "channels_fitted": [int(i) for i in idxs],
                "method": self.method,
                "mode": self.mode,
                "include_intercept": self.include_intercept,
                "channels": per_channel,
            },
            metrics=metrics,
        )