fibphot 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fibphot/__init__.py +6 -0
- fibphot/analysis/__init__.py +0 -0
- fibphot/analysis/aggregate.py +257 -0
- fibphot/analysis/auc.py +354 -0
- fibphot/analysis/irls.py +350 -0
- fibphot/analysis/peaks.py +1163 -0
- fibphot/analysis/photobleaching.py +290 -0
- fibphot/analysis/plotting.py +105 -0
- fibphot/analysis/report.py +56 -0
- fibphot/collection.py +207 -0
- fibphot/fit/__init__.py +0 -0
- fibphot/fit/regression.py +269 -0
- fibphot/io/__init__.py +6 -0
- fibphot/io/doric.py +435 -0
- fibphot/io/excel.py +76 -0
- fibphot/io/h5.py +321 -0
- fibphot/misc.py +11 -0
- fibphot/peaks.py +628 -0
- fibphot/pipeline.py +14 -0
- fibphot/plotting.py +594 -0
- fibphot/stages/__init__.py +22 -0
- fibphot/stages/base.py +101 -0
- fibphot/stages/baseline.py +354 -0
- fibphot/stages/control_dff.py +214 -0
- fibphot/stages/filters.py +273 -0
- fibphot/stages/normalisation.py +260 -0
- fibphot/stages/regression.py +139 -0
- fibphot/stages/smooth.py +442 -0
- fibphot/stages/trim.py +141 -0
- fibphot/state.py +309 -0
- fibphot/tags.py +130 -0
- fibphot/types.py +6 -0
- fibphot-0.1.0.dist-info/METADATA +63 -0
- fibphot-0.1.0.dist-info/RECORD +37 -0
- fibphot-0.1.0.dist-info/WHEEL +5 -0
- fibphot-0.1.0.dist-info/licenses/LICENSE.md +21 -0
- fibphot-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,354 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import scipy.optimize
|
|
8
|
+
import scipy.signal
|
|
9
|
+
|
|
10
|
+
from ..state import PhotometryState
|
|
11
|
+
from ..types import FloatArray
|
|
12
|
+
from .base import StageOutput, UpdateStage, _resolve_channels
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def double_exponential(
    times: FloatArray,
    const: float,
    amp_fast: float,
    amp_slow: float,
    tau_fast: float,
    tau_slow: float,
) -> FloatArray:
    """Evaluate a constant-plus-two-exponential-decays model.

    Returns ``const + amp_slow * exp(-t / tau_slow) + amp_fast * exp(-t / tau_fast)``
    elementwise over *times*.
    """
    t = np.asarray(times, dtype=float)
    slow_term = amp_slow * np.exp(-t / tau_slow)
    fast_term = amp_fast * np.exp(-t / tau_fast)
    # Sum in the same order as before (slow first) for bitwise reproducibility.
    return const + slow_term + fast_term
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _r2(y: FloatArray, yhat: FloatArray) -> float:
|
|
33
|
+
"""Coefficient of determination R^2."""
|
|
34
|
+
y = np.asarray(y, dtype=float)
|
|
35
|
+
yhat = np.asarray(yhat, dtype=float)
|
|
36
|
+
ss_res = float(np.sum((y - yhat) ** 2))
|
|
37
|
+
ss_tot = float(np.sum((y - np.mean(y)) ** 2))
|
|
38
|
+
if ss_tot < 1e-20:
|
|
39
|
+
return float("nan")
|
|
40
|
+
return 1.0 - ss_res / ss_tot
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _rmse(y: FloatArray, yhat: FloatArray) -> float:
|
|
44
|
+
"""Root mean square error."""
|
|
45
|
+
y = np.asarray(y, dtype=float)
|
|
46
|
+
yhat = np.asarray(yhat, dtype=float)
|
|
47
|
+
return float(np.sqrt(np.mean((y - yhat) ** 2)))
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def _initial_guess(y: FloatArray) -> list[float]:
|
|
51
|
+
"""Initial guess for double exponential fitting."""
|
|
52
|
+
y = np.asarray(y, dtype=float)
|
|
53
|
+
n = y.shape[0]
|
|
54
|
+
tail = y[int(n * 0.9) :] if n >= 10 else y
|
|
55
|
+
const = float(np.median(tail))
|
|
56
|
+
amp = float(max(y[0] - const, 1e-6))
|
|
57
|
+
return [const, amp * 0.6, amp * 0.4, 300.0, 3000.0]
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
@dataclass(frozen=True, slots=True)
class DoubleExpBaseline(UpdateStage):
    """Fit a double-exponential photobleaching baseline to each channel.

    For every selected channel the model
    ``const + amp_fast * exp(-t / tau_fast) + amp_slow * exp(-t / tau_slow)``
    is fitted with bounded least squares (``scipy.optimize.curve_fit``).
    The fitted baseline curves are stored in
    ``derived['double_exp_baseline']`` and, when ``subtract`` is true,
    removed from the signals.
    """

    name: str = field(default="double_exp_baseline", init=False)

    # If True, subtract the fitted baseline from each fitted channel.
    subtract: bool = False
    # Channel selection passed to _resolve_channels; None means all channels.
    channels: str | list[str] | None = None
    # If set (> 0), decimate to roughly this rate before fitting to speed
    # up curve_fit; the baseline is still evaluated at the full rate.
    decimate_to_hz: float | None = None
    # Maximum number of function evaluations for curve_fit.
    maxfev: int = 2000

    # Allowed ranges (seconds) for the fast and slow time constants.
    tau_fast_bounds: tuple[float, float] = (60.0, 600.0)
    tau_slow_bounds: tuple[float, float] = (600.0, 36000.0)

    def _params_for_summary(self) -> dict[str, object]:
        """Return the stage configuration for inclusion in summaries."""
        return {
            "subtract": self.subtract,
            "channels": self.channels if self.channels is not None else "all",
            "decimate_to_hz": self.decimate_to_hz,
            "maxfev": self.maxfev,
            "tau_fast_bounds": self.tau_fast_bounds,
            "tau_slow_bounds": self.tau_slow_bounds,
        }

    def apply(self, state: PhotometryState) -> StageOutput:
        """Fit the baseline per selected channel and build the stage output.

        Returns a StageOutput with (optionally baseline-subtracted) signals,
        the full-shape baseline array in ``derived``, per-channel parameters
        and fit quality in ``results``, and aggregate metrics.
        """
        idxs = _resolve_channels(state, self.channels)
        t_full = state.time_seconds

        # Full-shape outputs; channels that are not fitted stay zero / NaN.
        baseline = np.zeros_like(state.signals)
        params_out = np.full((state.n_signals, 5), np.nan, dtype=float)
        r2_out = np.full((state.n_signals,), np.nan, dtype=float)
        rmse_out = np.full((state.n_signals,), np.nan, dtype=float)

        fs = state.sampling_rate
        if self.decimate_to_hz is None or self.decimate_to_hz <= 0:
            decim = 1
        else:
            decim = max(1, int(round(fs / float(self.decimate_to_hz))))

        for i in idxs:
            y_full = state.signals[i]

            if decim > 1:
                # decimate both for fitting.
                y_fit = scipy.signal.decimate(
                    y_full, decim, ftype="fir", zero_phase=True
                )
                t_fit = t_full[::decim]

                # align lengths conservatively
                m = min(t_fit.shape[0], y_fit.shape[0])
                t_fit = t_fit[:m]
                y_fit = y_fit[:m]
            else:
                t_fit = t_full
                y_fit = y_full

            guess = _initial_guess(y_fit)

            y_max = float(np.max(y_fit))
            lo = [
                0.0,
                0.0,
                0.0,
                self.tau_fast_bounds[0],
                self.tau_slow_bounds[0],
            ]
            hi = [
                y_max,
                y_max,
                y_max,
                self.tau_fast_bounds[1],
                self.tau_slow_bounds[1],
            ]

            # Clip the heuristic guess into the feasible box: curve_fit
            # raises "x0 is infeasible" when p0 violates the bounds (e.g. a
            # negative tail median gives const < 0, or custom tau bounds
            # exclude the default 300 s / 3000 s starting values).
            p0 = [
                min(max(g, lo_k), hi_k)
                for g, lo_k, hi_k in zip(guess, lo, hi)
            ]

            popt, _ = scipy.optimize.curve_fit(
                f=double_exponential,
                xdata=t_fit,
                ydata=y_fit,
                p0=p0,
                bounds=(lo, hi),
                maxfev=self.maxfev,
            )

            # Fit quality is measured on the (possibly decimated) fit data.
            yhat_fit = double_exponential(t_fit, *popt)
            r2_out[i] = _r2(y_fit, yhat_fit)
            rmse_out[i] = _rmse(y_fit, yhat_fit)

            params_out[i] = popt
            # Evaluate the fitted model at the full sampling rate.
            baseline[i] = double_exponential(t_full, *popt)

        new_signals = state.signals.copy()
        if self.subtract:
            for i in idxs:
                new_signals[i] = new_signals[i] - baseline[i]

        metrics = {
            "mean_r2": float(np.nanmean(r2_out[idxs]))
            if idxs
            else float("nan"),
            "mean_rmse": float(np.nanmean(rmse_out[idxs]))
            if idxs
            else float("nan"),
            "decimate_factor": float(decim),
        }

        results = {
            "params": params_out,
            "r2": r2_out,
            "rmse": rmse_out,
            "channels_fitted": idxs,
        }

        notes = (
            "Fitted double exponential baseline; parameters are stored per "
            "channel. Baseline curves are available in "
            "derived['double_exp_baseline']."
        )

        return StageOutput(
            signals=new_signals,
            derived={"double_exp_baseline": baseline},
            results=results,
            metrics=metrics,
            notes=notes,
        )
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def _summarise_pybaselines_params(params: dict[str, Any]) -> dict[str, Any]:
|
|
187
|
+
"""
|
|
188
|
+
Make pybaselines params safe to store in state.results:
|
|
189
|
+
- keep scalars
|
|
190
|
+
- keep very small arrays/lists
|
|
191
|
+
- otherwise store a short descriptor
|
|
192
|
+
"""
|
|
193
|
+
out: dict[str, Any] = {}
|
|
194
|
+
|
|
195
|
+
for k, v in params.items():
|
|
196
|
+
if isinstance(v, (str, int, float, bool)) or v is None:
|
|
197
|
+
out[k] = v
|
|
198
|
+
continue
|
|
199
|
+
|
|
200
|
+
if isinstance(v, np.ndarray):
|
|
201
|
+
if v.size <= 16:
|
|
202
|
+
out[k] = np.asarray(v).tolist()
|
|
203
|
+
else:
|
|
204
|
+
out[k] = {
|
|
205
|
+
"type": "ndarray",
|
|
206
|
+
"shape": v.shape,
|
|
207
|
+
"dtype": str(v.dtype),
|
|
208
|
+
}
|
|
209
|
+
continue
|
|
210
|
+
|
|
211
|
+
if isinstance(v, (list, tuple)):
|
|
212
|
+
if len(v) <= 16 and all(
|
|
213
|
+
isinstance(x, (int, float, str, bool)) for x in v
|
|
214
|
+
):
|
|
215
|
+
out[k] = list(v)
|
|
216
|
+
else:
|
|
217
|
+
out[k] = {"type": type(v).__name__, "len": len(v)}
|
|
218
|
+
continue
|
|
219
|
+
|
|
220
|
+
out[k] = repr(v)[:200]
|
|
221
|
+
|
|
222
|
+
return out
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
@dataclass(frozen=True, slots=True)
class PyBaselinesBaseline(UpdateStage):
    """
    Generic baseline estimation using `pybaselines`.

    pybaselines uses a Baseline(x_data=...) object; each algorithm is called as:
        baseline, params = baseline_fitter.<method>(y, **kwargs)

    This stage:
    - computes a baseline per selected channel
    - stores a full (n_signals, n_samples) baseline array in state.derived
    - optionally subtracts baseline from the signal(s)
    """

    name: str = field(default="pybaselines_baseline", init=False)

    # pybaselines algorithm name; looked up as an attribute on Baseline.
    method: str = "asls"
    # keyword arguments forwarded verbatim to the chosen algorithm
    method_kwargs: dict[str, Any] = field(default_factory=dict)

    # channel selection; None means all channels (see _resolve_channels)
    channels: str | list[str] | None = None

    # x-axis passed to Baseline(x_data=...)
    x_axis: str = "time"  # "time" or "index"

    # where to store the baseline in derived
    baseline_key: str | None = None

    # apply correction?
    subtract: bool = False

    # store full params (can be large); default stores a safe summary
    store_full_params: bool = False

    def _params_for_summary(self) -> dict[str, Any]:
        """Return the stage configuration for inclusion in summaries."""
        return {
            "method": self.method,
            "method_kwargs": self.method_kwargs,
            "channels": self.channels if self.channels is not None else "all",
            "x_axis": self.x_axis,
            "baseline_key": self.baseline_key,
            "subtract": self.subtract,
            "store_full_params": self.store_full_params,
        }

    def apply(self, state: PhotometryState) -> StageOutput:
        """Compute (and optionally subtract) a pybaselines baseline per channel.

        Raises:
            ImportError: if pybaselines is not installed.
            ValueError: for an unknown ``x_axis`` or ``method``, or when the
                algorithm returns a baseline whose shape differs from the input.
        """
        # Import lazily so pybaselines stays an optional dependency.
        try:
            from pybaselines import Baseline
        except Exception as exc:  # pragma: no cover
            raise ImportError(
                "pybaselines is required for PyBaselinesBaseline. "
                "Install it with `pip install pybaselines`."
            ) from exc

        idxs = _resolve_channels(state, self.channels)

        if self.x_axis == "time":
            x_data = np.asarray(state.time_seconds, dtype=float)
        elif self.x_axis == "index":
            x_data = np.arange(state.n_samples, dtype=float)
        else:
            raise ValueError("x_axis must be 'time' or 'index'.")

        fitter = Baseline(x_data=x_data)

        # Validate the algorithm name up front for a clearer error message.
        if not hasattr(fitter, self.method):
            raise ValueError(
                f"Unknown pybaselines method '{self.method}'. "
                "See pybaselines docs for available algorithms."
            )

        fn = getattr(fitter, self.method)

        # full-shape baseline so downstream code can rely on shape == signals.shape
        baseline = np.full_like(state.signals, np.nan, dtype=float)
        new = state.signals.copy()

        # Per-channel fit metadata, keyed by channel name.
        per_channel: dict[str, dict[str, Any]] = {}

        for i in idxs:
            y = np.asarray(state.signals[i], dtype=float)

            b, params = fn(y, **self.method_kwargs)

            # ensure dtype/shape sanity
            b = np.asarray(b, dtype=float)
            if b.shape != y.shape:
                raise ValueError(
                    f"pybaselines returned baseline shape {b.shape}, expected {y.shape}."
                )

            baseline[i] = b
            if self.subtract:
                new[i] = y - b

            per_channel[state.channel_names[i]] = {
                "method": self.method,
                "method_kwargs": dict(self.method_kwargs),
                # Full params can hold large arrays; summarise unless asked
                # to keep everything.
                "params": (
                    params
                    if self.store_full_params
                    else _summarise_pybaselines_params(params)
                ),
            }

        key = (
            self.baseline_key
            if self.baseline_key is not None
            else f"pybaselines_{self.method}"
        )

        notes = (
            f"Computed baseline using pybaselines.{self.method}; "
            f"stored in derived['{key}']."
            + (" Subtracted from signals." if self.subtract else "")
        )

        # NOTE(review): signals=None appears to signal "signals unchanged" to
        # the pipeline when subtract is False — confirm against StageOutput.
        return StageOutput(
            signals=new if self.subtract else None,
            derived={key: baseline},
            results={
                "baseline_key": key,
                "method": self.method,
                "method_kwargs": dict(self.method_kwargs),
                "channels_fitted": [int(i) for i in idxs],
                "channels": per_channel,
                "x_axis": self.x_axis,
                "subtract": self.subtract,
            },
            notes=notes,
        )
|
|
@@ -0,0 +1,214 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from typing import Any, Literal
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
from ..fit.regression import fit_irls, fit_ols
|
|
9
|
+
from ..state import PhotometryState
|
|
10
|
+
from .base import StageOutput, UpdateStage, _resolve_channels
|
|
11
|
+
|
|
12
|
+
# Regression backend used to fit the control channel to each target channel.
RegressionMethod = Literal["ols", "irls_tukey", "irls_huber"]
# Output form: fractional dF/F, percent dF/F, or plain difference dF.
DffMode = Literal["dff", "percent", "df"]
# Behaviour when the fitted denominator y_hat is near zero.
DenomPolicy = Literal["nan", "clip", "raise"]
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass(frozen=True, slots=True)
class IsosbesticDff(UpdateStage):
    """
    Combined control-fit normalisation workflow.

    Fits a control channel x (typically isosbestic) to each target channel y:

        y ≈ intercept + slope * x = y_hat

    Then computes:
        df:      y - y_hat
        dff:     (y - y_hat) / y_hat
        percent: 100 * (y - y_hat) / y_hat

    This workflow is only numerically stable if y_hat does not spend meaningful
    time near zero (e.g. when fitting raw-ish, positive signals). If signals have
    already been detrended to be near zero (e.g. double-exp subtract=True),
    prefer IsosbesticRegression (df) + Normalise.baseline(double_exp_baseline).
    """

    name: str = field(default="isosbestic_dff", init=False)

    # name of the control channel (resolved via state.idx)
    control: str = "iso"
    # target channel selection; None means all (control is always excluded)
    channels: str | list[str] | None = None

    method: RegressionMethod = "irls_tukey"
    include_intercept: bool = True

    mode: DffMode = "percent"

    # IRLS settings
    tuning_constant: float = 4.685
    max_iter: int = 100
    tol: float = 1e-10
    store_weights: bool = False

    # denominator safety
    min_abs_denom: float = 1e-6
    denom_policy: DenomPolicy = "nan"
    max_near_zero_frac: float = 0.01

    # optional outputs
    store_fit: bool = True
    store_df: bool = False

    def _params_for_summary(self) -> dict[str, Any]:
        """Return the stage configuration for inclusion in summaries."""
        return {
            "control": self.control,
            "channels": self.channels if self.channels is not None else "all",
            "method": self.method,
            "include_intercept": self.include_intercept,
            "mode": self.mode,
            "tuning_constant": self.tuning_constant,
            "max_iter": self.max_iter,
            "tol": self.tol,
            "store_weights": self.store_weights,
            "min_abs_denom": self.min_abs_denom,
            "denom_policy": self.denom_policy,
            "max_near_zero_frac": self.max_near_zero_frac,
            "store_fit": self.store_fit,
            "store_df": self.store_df,
        }

    def apply(self, state: PhotometryState) -> StageOutput:
        """Fit the control to each target channel and normalise in place.

        Returns a StageOutput whose signals hold the corrected (df/dff/percent)
        traces, with the fitted control and/or raw df optionally in derived.

        Raises:
            ValueError: if no target channels remain after excluding the
                control, or if ``denom_policy == "raise"`` and too many
                denominator samples are near zero.
        """
        control_idx = state.idx(self.control)
        x = state.signals[control_idx]

        idxs = _resolve_channels(state, self.channels)
        # Never regress the control channel onto itself.
        idxs = [i for i in idxs if i != control_idx]
        if not idxs:
            raise ValueError(
                "No target channels remain after excluding the control channel."
            )

        new = state.signals.copy()

        # Full-shape optional outputs; unfitted channels stay NaN.
        fit_mat = (
            np.full_like(state.signals, np.nan, dtype=float)
            if self.store_fit
            else None
        )
        df_mat = (
            np.full_like(state.signals, np.nan, dtype=float)
            if self.store_df
            else None
        )

        per_channel: dict[str, dict[str, Any]] = {}
        r2s: list[float] = []

        for i in idxs:
            name = state.channel_names[i]
            y = state.signals[i]

            if self.method == "ols":
                fit = fit_ols(x, y, include_intercept=self.include_intercept)
                # OLS is a direct solve; no iteration limit applies.
                max_iter: int | None = None
            else:
                loss = "tukey" if self.method == "irls_tukey" else "huber"
                fit = fit_irls(
                    x,
                    y,
                    include_intercept=self.include_intercept,
                    loss=loss,
                    tuning_constant=self.tuning_constant,
                    max_iter=self.max_iter,
                    tol=self.tol,
                    store_weights=self.store_weights,
                )
                max_iter = self.max_iter

            y_hat = np.asarray(fit.fitted, dtype=float)
            df = y - y_hat

            if fit_mat is not None:
                fit_mat[i] = y_hat
            if df_mat is not None:
                df_mat[i] = df

            if self.mode == "df":
                # Plain difference: no division, so no denominator concerns.
                corrected = df
                denom_used = None
                near_zero_frac = 0.0
            else:
                abs_d = np.abs(y_hat)
                near_zero = abs_d < self.min_abs_denom
                near_zero_frac = float(np.mean(near_zero))

                if near_zero_frac > self.max_near_zero_frac:
                    # The message is only acted on under the "raise" policy;
                    # for "nan"/"clip" the per-sample handling below applies.
                    msg = (
                        f"Denominator y_hat is near zero for {near_zero_frac:.2%} "
                        f"of samples in channel '{name}'. This usually indicates "
                        "you ran this after detrending/subtraction (signals ~ 0). "
                        "Use IsosbesticRegression (df) + Normalise.baseline(...) "
                        "instead, or run IsosbesticDff earlier on raw-ish signals."
                    )
                    if self.denom_policy == "raise":
                        raise ValueError(msg)

                if self.denom_policy == "nan":
                    # Near-zero denominators produce NaN in the output.
                    denom = np.where(near_zero, np.nan, y_hat)
                else:
                    # "clip": replace near-zero denominators with the signed
                    # minimum magnitude to bound the ratio.
                    denom = np.where(
                        near_zero,
                        np.sign(y_hat) * self.min_abs_denom,
                        y_hat,
                    )

                scale = 100.0 if self.mode == "percent" else 1.0
                corrected = scale * (df / denom)
                denom_used = "y_hat"

            new[i] = corrected

            per_channel[name] = {
                "control": self.control,
                "intercept": fit.intercept,
                "slope": fit.slope,
                "r2": fit.r2,
                "method": fit.method,
                "n_iter": fit.n_iter,
                "max_iter": max_iter,
                "tuning_constant": fit.tuning_constant,
                "scale": fit.scale,
                "weights": fit.weights,
                "denom_used": denom_used,
                "near_zero_frac": float(near_zero_frac),
                "min_abs_denom": float(self.min_abs_denom),
                "denom_policy": self.denom_policy,
            }
            # Aggregate metrics skip non-finite R^2 values.
            if np.isfinite(fit.r2):
                r2s.append(float(fit.r2))

        metrics: dict[str, float] = {}
        if r2s:
            metrics["mean_r2"] = float(np.mean(r2s))
            metrics["median_r2"] = float(np.median(r2s))

        derived: dict[str, np.ndarray] = {}
        if fit_mat is not None:
            derived["control_fit"] = fit_mat
        if df_mat is not None:
            derived["control_df"] = df_mat

        return StageOutput(
            signals=new,
            derived=derived or None,
            results={
                "control": self.control,
                "control_idx": control_idx,
                "channels_fitted": idxs,
                "method": self.method,
                "mode": self.mode,
                "include_intercept": self.include_intercept,
                "channels": per_channel,
            },
            metrics=metrics,
        )
|