fibphot 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fibphot/__init__.py +6 -0
- fibphot/analysis/__init__.py +0 -0
- fibphot/analysis/aggregate.py +257 -0
- fibphot/analysis/auc.py +354 -0
- fibphot/analysis/irls.py +350 -0
- fibphot/analysis/peaks.py +1163 -0
- fibphot/analysis/photobleaching.py +290 -0
- fibphot/analysis/plotting.py +105 -0
- fibphot/analysis/report.py +56 -0
- fibphot/collection.py +207 -0
- fibphot/fit/__init__.py +0 -0
- fibphot/fit/regression.py +269 -0
- fibphot/io/__init__.py +6 -0
- fibphot/io/doric.py +435 -0
- fibphot/io/excel.py +76 -0
- fibphot/io/h5.py +321 -0
- fibphot/misc.py +11 -0
- fibphot/peaks.py +628 -0
- fibphot/pipeline.py +14 -0
- fibphot/plotting.py +594 -0
- fibphot/stages/__init__.py +22 -0
- fibphot/stages/base.py +101 -0
- fibphot/stages/baseline.py +354 -0
- fibphot/stages/control_dff.py +214 -0
- fibphot/stages/filters.py +273 -0
- fibphot/stages/normalisation.py +260 -0
- fibphot/stages/regression.py +139 -0
- fibphot/stages/smooth.py +442 -0
- fibphot/stages/trim.py +141 -0
- fibphot/state.py +309 -0
- fibphot/tags.py +130 -0
- fibphot/types.py +6 -0
- fibphot-0.1.0.dist-info/METADATA +63 -0
- fibphot-0.1.0.dist-info/RECORD +37 -0
- fibphot-0.1.0.dist-info/WHEEL +5 -0
- fibphot-0.1.0.dist-info/licenses/LICENSE.md +21 -0
- fibphot-0.1.0.dist-info/top_level.txt +1 -0
fibphot/stages/smooth.py
ADDED
|
@@ -0,0 +1,442 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from typing import Any, Literal
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
from ..state import PhotometryState
|
|
9
|
+
from ..types import FloatArray
|
|
10
|
+
from .base import StageOutput, UpdateStage, _resolve_channels
|
|
11
|
+
|
|
12
|
+
# Window shapes accepted by smooth_1d / the Smooth stage.
WindowType = Literal["flat", "hanning", "hamming", "bartlett", "blackman"]
# Edge-padding strategies applied before convolution in smooth_1d.
PadMode = Literal["reflect", "edge"]
# State-space models supported by kalman_smooth_1d / the KalmanSmooth stage.
KalmanModel = Literal["local_level", "local_linear_trend"]
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _window(window: WindowType, window_len: int) -> FloatArray:
|
|
18
|
+
if window == "flat":
|
|
19
|
+
return np.ones(window_len, dtype=float)
|
|
20
|
+
if window == "hanning":
|
|
21
|
+
return np.hanning(window_len).astype(float)
|
|
22
|
+
if window == "hamming":
|
|
23
|
+
return np.hamming(window_len).astype(float)
|
|
24
|
+
if window == "bartlett":
|
|
25
|
+
return np.bartlett(window_len).astype(float)
|
|
26
|
+
if window == "blackman":
|
|
27
|
+
return np.blackman(window_len).astype(float)
|
|
28
|
+
|
|
29
|
+
raise ValueError(
|
|
30
|
+
"window must be one of: "
|
|
31
|
+
"'flat', 'hanning', 'hamming', 'bartlett', 'blackman'."
|
|
32
|
+
)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def smooth_1d(
|
|
36
|
+
x: FloatArray,
|
|
37
|
+
*,
|
|
38
|
+
window_len: int,
|
|
39
|
+
window: WindowType = "flat",
|
|
40
|
+
pad_mode: PadMode = "reflect",
|
|
41
|
+
match_edges: bool = True,
|
|
42
|
+
) -> FloatArray:
|
|
43
|
+
"""
|
|
44
|
+
Convolution smoothing with explicit edge handling.
|
|
45
|
+
|
|
46
|
+
Public helper so analysis code (e.g. peakfinding) can smooth *without*
|
|
47
|
+
creating a new state / touching state.history.
|
|
48
|
+
"""
|
|
49
|
+
x = np.asarray(x, dtype=float)
|
|
50
|
+
|
|
51
|
+
if x.ndim != 1:
|
|
52
|
+
raise ValueError("smooth_1d only accepts 1D arrays.")
|
|
53
|
+
if window_len < 3:
|
|
54
|
+
return x.copy()
|
|
55
|
+
if x.size < window_len:
|
|
56
|
+
raise ValueError("Input vector needs to be bigger than window size.")
|
|
57
|
+
if window_len % 2 == 0:
|
|
58
|
+
raise ValueError("window_len must be odd.")
|
|
59
|
+
|
|
60
|
+
w = _window(window, window_len)
|
|
61
|
+
w = w / float(np.sum(w))
|
|
62
|
+
|
|
63
|
+
half = window_len // 2
|
|
64
|
+
|
|
65
|
+
if pad_mode == "reflect":
|
|
66
|
+
xp = np.pad(x, pad_width=(half, half), mode="reflect")
|
|
67
|
+
elif pad_mode == "edge":
|
|
68
|
+
xp = np.pad(x, pad_width=(half, half), mode="edge")
|
|
69
|
+
else:
|
|
70
|
+
raise ValueError("pad_mode must be 'reflect' or 'edge'.")
|
|
71
|
+
|
|
72
|
+
y = np.convolve(xp, w, mode="valid")
|
|
73
|
+
|
|
74
|
+
if match_edges:
|
|
75
|
+
y = y.copy()
|
|
76
|
+
y[:half] = x[:half]
|
|
77
|
+
y[-half:] = x[-half:]
|
|
78
|
+
|
|
79
|
+
return y
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def savgol_smooth_1d(
    x: FloatArray,
    *,
    window_len: int,
    polyorder: int = 3,
    mode: Literal["interp", "mirror", "nearest", "constant", "wrap"] = "interp",
) -> FloatArray:
    """
    Savitzky–Golay smoothing (shape-preserving; good for peaks).

    Requires scipy (already used elsewhere in fibphot).

    Parameters
    ----------
    x:
        1D signal. NaNs are linearly interpolated before filtering and
        restored afterwards; an all-NaN input returns an all-NaN array.
    window_len:
        Odd filter window length in samples; values < 3 are a no-op.
    polyorder:
        Polynomial order of the local fit; must be < window_len.
    mode:
        Edge-handling mode, passed straight to scipy.signal.savgol_filter.
    """
    from scipy.signal import savgol_filter

    x = np.asarray(x, dtype=float)
    if x.ndim != 1:
        raise ValueError("savgol_smooth_1d only accepts 1D arrays.")
    if window_len < 3:
        return x.copy()
    if window_len % 2 == 0:
        raise ValueError("window_len must be odd.")
    if polyorder >= window_len:
        raise ValueError("polyorder must be < window_len.")
    if x.size < window_len:
        raise ValueError("Input vector needs to be bigger than window size.")

    # savgol_filter does not like NaNs; simple strategy:
    # interpolate NaNs linearly before filtering, then restore NaNs.
    m = np.isfinite(x)
    if not np.all(m):
        if not np.any(m):
            # Fix: all-NaN input has nothing to interpolate from and would
            # crash np.interp on an empty sample set; return NaNs instead.
            return np.full_like(x, np.nan)
        xi = np.arange(x.size, dtype=float)
        xp = np.interp(xi, xi[m], x[m])
        y = savgol_filter(
            xp, window_length=window_len, polyorder=polyorder, mode=mode
        )
        y[~m] = np.nan  # type: ignore
        return np.asarray(y, dtype=float)

    return np.asarray(
        savgol_filter(
            x, window_length=window_len, polyorder=polyorder, mode=mode
        ),
        dtype=float,
    )
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def _mad_sigma(x: FloatArray) -> float:
|
|
129
|
+
x = np.asarray(x, dtype=float)
|
|
130
|
+
x = x[np.isfinite(x)]
|
|
131
|
+
if x.size == 0:
|
|
132
|
+
return float("nan")
|
|
133
|
+
med = float(np.median(x))
|
|
134
|
+
mad = float(np.median(np.abs(x - med)))
|
|
135
|
+
return 1.4826 * mad
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def _estimate_obs_var_from_diff(y: FloatArray) -> float:
|
|
139
|
+
"""
|
|
140
|
+
Robustly estimate observation variance R from first differences.
|
|
141
|
+
|
|
142
|
+
If y_t = x_t + v_t with white noise v_t ~ N(0, R),
|
|
143
|
+
then diff(y) has variance about 2R (ignoring process noise).
|
|
144
|
+
"""
|
|
145
|
+
y = np.asarray(y, dtype=float)
|
|
146
|
+
m = np.isfinite(y)
|
|
147
|
+
if np.sum(m) < 3:
|
|
148
|
+
return float("nan")
|
|
149
|
+
|
|
150
|
+
yy = y[m]
|
|
151
|
+
d = np.diff(yy)
|
|
152
|
+
s = _mad_sigma(d) # robust std of diffs
|
|
153
|
+
if not np.isfinite(s) or s <= 1e-20:
|
|
154
|
+
return float("nan")
|
|
155
|
+
|
|
156
|
+
# Var(diff) ~ 2R => R ~ (s^2)/2
|
|
157
|
+
return float((s * s) / 2.0)
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def kalman_smooth_1d(
    y: FloatArray,
    *,
    model: KalmanModel = "local_level",
    r: float | Literal["auto"] = "auto",
    q: float | None = None,
    q_scale: float = 1e-3,
) -> FloatArray:
    """
    Fast Kalman RTS smoother (NumPy-only).

    model="local_level":
        x_t = x_{t-1} + w_t
        y_t = x_t + v_t

    model="local_linear_trend":
        [level, trend] evolves with small noise; smoother for slow drift.

    Parameters
    ----------
    r:
        Observation variance. "auto" estimates from first differences.
    q:
        Process variance (or process scale). If None, uses q = q_scale * r.
        Smaller q => smoother.
    q_scale:
        Used only if q is None.
    """
    y = np.asarray(y, dtype=float)
    n = y.shape[0]
    if y.ndim != 1:
        raise ValueError("kalman_smooth_1d expects 1D input.")

    # observation variance
    # (a numeric r never equals the string "auto", so the comparison is safe)
    if r == "auto":
        r_var = _estimate_obs_var_from_diff(y)
        if not np.isfinite(r_var) or r_var <= 1e-20:
            # fallback: variance of residual around median
            s = _mad_sigma(y)
            r_var = float(s * s) if np.isfinite(s) else 1e-6
    else:
        r_var = float(r)

    # Process variance defaults to a small fraction of the observation
    # variance, which keeps the smoother scale-invariant.
    q_var = float(q_scale) * float(r_var) if q is None else float(q)

    if model == "local_level":
        # scalar filter
        # x_f/p_f: filtered state/variance; x_p/p_p: one-step predictions
        # (the predictions are kept because the RTS backward pass needs them).
        x_f = np.full(n, np.nan, dtype=float)
        p_f = np.full(n, np.nan, dtype=float)
        x_p = np.full(n, np.nan, dtype=float)
        p_p = np.full(n, np.nan, dtype=float)

        # init from first finite sample
        m0 = np.isfinite(y)
        if not np.any(m0):
            return np.full_like(y, np.nan, dtype=float)
        i0 = int(np.argmax(m0))
        x = float(y[i0])
        # Diffuse-ish prior: start uncertain (10x observation variance).
        p = float(r_var) * 10.0

        for t in range(n):
            # predict
            x_pred = x
            p_pred = p + q_var
            x_p[t] = x_pred
            p_p[t] = p_pred

            if np.isfinite(y[t]):
                # update
                k = p_pred / (p_pred + r_var)
                x = x_pred + k * (float(y[t]) - x_pred)
                p = (1.0 - k) * p_pred
            else:
                # missing obs: keep prediction
                x = x_pred
                p = p_pred

            x_f[t] = x
            p_f[t] = p

        # RTS smoother
        # Backward pass: blend each filtered estimate with the smoothed
        # estimate one step ahead using gain c = p_f / p_pred.
        x_s = x_f.copy()
        p_s = p_f.copy()
        for t in range(n - 2, -1, -1):
            denom = p_p[t + 1]
            if not np.isfinite(denom) or denom <= 1e-30:
                # Degenerate prediction variance: leave filtered value as-is.
                continue
            c = p_f[t] / denom
            x_s[t] = x_f[t] + c * (x_s[t + 1] - x_p[t + 1])
            p_s[t] = p_f[t] + c * c * (p_s[t + 1] - p_p[t + 1])

        return x_s

    if model == "local_linear_trend":
        # 2D state: [level, trend]
        # level_t = level_{t-1} + trend_{t-1} + w1
        # trend_t = trend_{t-1} + w2
        # y_t = level_t + v
        x_f = np.full((n, 2), np.nan, dtype=float)
        p_f = np.full((n, 2, 2), np.nan, dtype=float)
        x_p = np.full((n, 2), np.nan, dtype=float)
        p_p = np.full((n, 2, 2), np.nan, dtype=float)

        # F: state transition; H: observation (we only see the level).
        F = np.array([[1.0, 1.0], [0.0, 1.0]], dtype=float)
        H = np.array([[1.0, 0.0]], dtype=float)

        # split q_var between level/trend
        # (trend gets 10% of the process noise so it drifts slowly)
        Q = np.array([[q_var, 0.0], [0.0, q_var * 0.1]], dtype=float)
        R = np.array([[r_var]], dtype=float)

        m0 = np.isfinite(y)
        if not np.any(m0):
            return np.full_like(y, np.nan, dtype=float)
        i0 = int(np.argmax(m0))
        # Initial state: first finite sample as level, zero trend.
        x = np.array([float(y[i0]), 0.0], dtype=float)
        P = np.eye(2, dtype=float) * float(r_var) * 10.0

        for t in range(n):
            # predict
            x_pred = F @ x
            P_pred = F @ P @ F.T + Q
            x_p[t] = x_pred
            p_p[t] = P_pred

            if np.isfinite(y[t]):
                # Standard Kalman measurement update (S is 1x1 here).
                yt = np.array([[float(y[t])]], dtype=float)
                S = H @ P_pred @ H.T + R
                K = (P_pred @ H.T) @ np.linalg.inv(S)
                x = x_pred + (K @ (yt - (H @ x_pred).reshape(1, 1))).ravel()
                P = (np.eye(2) - K @ H) @ P_pred
            else:
                # Missing observation: carry the prediction forward.
                x = x_pred
                P = P_pred

            x_f[t] = x
            p_f[t] = P

        # RTS smoother
        # Backward pass with matrix gain C = P_f F^T P_pred^{-1}.
        x_s = x_f.copy()
        P_s = p_f.copy()
        for t in range(n - 2, -1, -1):
            P_pred_next = p_p[t + 1]
            if not np.all(np.isfinite(P_pred_next)):
                continue
            C = p_f[t] @ F.T @ np.linalg.inv(P_pred_next)
            x_s[t] = x_f[t] + C @ (x_s[t + 1] - x_p[t + 1])
            P_s[t] = p_f[t] + C @ (P_s[t + 1] - P_pred_next) @ C.T

        # Only the level component is returned to the caller.
        return x_s[:, 0]

    raise ValueError("model must be 'local_level' or 'local_linear_trend'.")
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
@dataclass(frozen=True, slots=True)
class Smooth(UpdateStage):
    """
    Smooth signals by convolving with a window function.

    Wraps smooth_1d; see that helper for the edge-handling semantics.
    """

    name: str = field(default="smooth", init=False)

    window_len: int = 11
    window: WindowType = "flat"
    pad_mode: PadMode = "reflect"
    match_edges: bool = True
    channels: str | list[str] | None = None

    def _params_for_summary(self) -> dict[str, Any]:
        # Report "all" rather than None so summaries stay human-readable.
        summary: dict[str, Any] = {
            "window_len": self.window_len,
            "window": self.window,
            "pad_mode": self.pad_mode,
            "match_edges": self.match_edges,
        }
        summary["channels"] = "all" if self.channels is None else self.channels
        return summary

    def apply(self, state: PhotometryState) -> StageOutput:
        # A tiny window cannot smooth anything; pass signals through.
        if self.window_len < 3:
            return StageOutput(
                signals=state.signals.copy(), notes="No-op: window_len < 3."
            )
        if self.window_len % 2 == 0:
            raise ValueError("window_len must be odd and >= 3.")

        smoothed = state.signals.copy()
        for idx in _resolve_channels(state, self.channels):
            smoothed[idx] = smooth_1d(
                smoothed[idx],
                window_len=self.window_len,
                window=self.window,
                pad_mode=self.pad_mode,
                match_edges=self.match_edges,
            )

        return StageOutput(signals=smoothed)
|
|
357
|
+
|
|
358
|
+
|
|
359
|
+
@dataclass(frozen=True, slots=True)
class SavGolSmooth(UpdateStage):
    """
    Savitzky–Golay smoothing stage (shape-preserving).

    Wraps savgol_smooth_1d; NaN handling is delegated to that helper.
    """

    name: str = field(default="savgol_smooth", init=False)

    window_len: int = 11
    polyorder: int = 3
    mode: Literal["interp", "mirror", "nearest", "constant", "wrap"] = "interp"
    channels: str | list[str] | None = None

    def _params_for_summary(self) -> dict[str, Any]:
        # Report "all" rather than None so summaries stay human-readable.
        summary: dict[str, Any] = {
            "window_len": self.window_len,
            "polyorder": self.polyorder,
            "mode": self.mode,
        }
        summary["channels"] = "all" if self.channels is None else self.channels
        return summary

    def apply(self, state: PhotometryState) -> StageOutput:
        # Tiny windows cannot smooth; return the signals unchanged.
        if self.window_len < 3:
            return StageOutput(
                signals=state.signals.copy(), notes="No-op: window_len < 3."
            )
        if self.window_len % 2 == 0:
            raise ValueError("window_len must be odd and >= 3.")
        if self.polyorder >= self.window_len:
            raise ValueError("polyorder must be < window_len.")

        filtered = state.signals.copy()
        for idx in _resolve_channels(state, self.channels):
            filtered[idx] = savgol_smooth_1d(
                filtered[idx],
                window_len=self.window_len,
                polyorder=self.polyorder,
                mode=self.mode,
            )

        return StageOutput(signals=filtered)
|
|
402
|
+
|
|
403
|
+
|
|
404
|
+
@dataclass(frozen=True, slots=True)
class KalmanSmooth(UpdateStage):
    """
    Kalman RTS smoothing stage (fast, NumPy-only).

    Use this mainly as a *detector helper* (e.g. peakfinding).
    """

    name: str = field(default="kalman_smooth", init=False)

    model: KalmanModel = "local_level"
    r: float | Literal["auto"] = "auto"
    q: float | None = None
    q_scale: float = 1e-3
    channels: str | list[str] | None = None

    def _params_for_summary(self) -> dict[str, Any]:
        # Report "all" rather than None so summaries stay human-readable.
        summary: dict[str, Any] = {
            "model": self.model,
            "r": self.r,
            "q": self.q,
            "q_scale": self.q_scale,
        }
        summary["channels"] = "all" if self.channels is None else self.channels
        return summary

    def apply(self, state: PhotometryState) -> StageOutput:
        smoothed = state.signals.copy()
        for idx in _resolve_channels(state, self.channels):
            smoothed[idx] = kalman_smooth_1d(
                smoothed[idx],
                model=self.model,
                r=self.r,
                q=self.q,
                q_scale=self.q_scale,
            )

        return StageOutput(signals=smoothed)
|
fibphot/stages/trim.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from typing import Any, Literal
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
from ..state import PhotometryState
|
|
9
|
+
from .base import StageOutput, UpdateStage
|
|
10
|
+
|
|
11
|
+
# Unit in which Trim.start / Trim.end are expressed.
TrimUnit = Literal["samples", "seconds"]
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _as_n_samples(value: float, unit: TrimUnit, fs: float) -> int:
|
|
15
|
+
if unit == "samples":
|
|
16
|
+
return int(value)
|
|
17
|
+
return int(round(float(value) * fs))
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _trim_derived_like_signals(
|
|
21
|
+
derived: dict[str, Any],
|
|
22
|
+
*,
|
|
23
|
+
n_signals: int,
|
|
24
|
+
n_samples: int,
|
|
25
|
+
sl: slice,
|
|
26
|
+
) -> dict[str, Any]:
|
|
27
|
+
"""
|
|
28
|
+
Trim derived arrays that match the signal shape conventions:
|
|
29
|
+
- (n_signals, n_samples)
|
|
30
|
+
- (n_samples,)
|
|
31
|
+
- (h, n_signals, n_samples) [rare, but we handle it]
|
|
32
|
+
Leave everything else untouched.
|
|
33
|
+
"""
|
|
34
|
+
out: dict[str, Any] = dict(derived)
|
|
35
|
+
for k, v in derived.items():
|
|
36
|
+
arr = (
|
|
37
|
+
np.asarray(v) if isinstance(v, (list, tuple, np.ndarray)) else None
|
|
38
|
+
)
|
|
39
|
+
if arr is None:
|
|
40
|
+
continue
|
|
41
|
+
|
|
42
|
+
if arr.shape == (n_signals, n_samples):
|
|
43
|
+
out[k] = arr[:, sl]
|
|
44
|
+
elif arr.shape == (n_samples,):
|
|
45
|
+
out[k] = arr[sl]
|
|
46
|
+
elif arr.ndim == 3 and arr.shape[1:] == (n_signals, n_samples):
|
|
47
|
+
out[k] = arr[:, :, sl]
|
|
48
|
+
|
|
49
|
+
return out
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@dataclass(frozen=True, slots=True)
class Trim(UpdateStage):
    """
    Discard datapoints from the start and/or end of the recording.

    You can specify trimming in either samples or seconds.

    Examples
    --------
    Trim 10 seconds from start:

        Trim(start=10, unit="seconds")

    Trim first 1 second and last 2 seconds:

        Trim(start=1, end=2, unit="seconds")

    Trim 100 samples from end:

        Trim(end=100, unit="samples")
    """

    name: str = field(default="trim", init=False)

    start: float = 0.0
    end: float = 0.0
    unit: TrimUnit = "seconds"

    def _params_for_summary(self) -> dict[str, Any]:
        return {"start": self.start, "end": self.end, "unit": self.unit}

    def apply(self, state: PhotometryState) -> StageOutput:
        if self.start < 0 or self.end < 0:
            raise ValueError("start/end must be >= 0.")

        fs = state.sampling_rate
        total = state.n_samples

        # Convert the requested amounts to sample counts, clamped to [0, total].
        head = min(max(_as_n_samples(self.start, self.unit, fs), 0), total)
        tail = min(max(_as_n_samples(self.end, self.unit, fs), 0), total)

        lo = head
        hi = total - tail
        if hi <= lo:
            raise ValueError(
                "Trim removes all samples: "
                f"n={total}, start={head}, end={tail}."
            )

        keep = slice(lo, hi)

        trimmed_time = state.time_seconds[keep]
        trimmed_signals = state.signals[:, keep]
        # History may be empty (no prior stages); only slice it when populated.
        trimmed_history = (
            state.history[:, :, keep] if state.history.size else state.history
        )

        trimmed_derived = _trim_derived_like_signals(
            state.derived,
            n_signals=state.n_signals,
            n_samples=state.n_samples,
            sl=keep,
        )

        return StageOutput(
            signals=trimmed_signals,
            derived=trimmed_derived,
            results={
                "unit": self.unit,
                "start": self.start,
                "end": self.end,
                "start_samples": int(head),
                "end_samples": int(tail),
                "slice": (int(lo), int(hi)),
                "old_n_samples": int(total),
                "new_n_samples": int(trimmed_time.shape[0]),
            },
            notes=(
                "Trimmed time/signals "
                "(and any derived arrays matching signal shape)."
            ),
            data={
                "time_seconds": trimmed_time,
                "history": trimmed_history,
            },
        )
|