fibphot 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fibphot/__init__.py +6 -0
- fibphot/analysis/__init__.py +0 -0
- fibphot/analysis/aggregate.py +257 -0
- fibphot/analysis/auc.py +354 -0
- fibphot/analysis/irls.py +350 -0
- fibphot/analysis/peaks.py +1163 -0
- fibphot/analysis/photobleaching.py +290 -0
- fibphot/analysis/plotting.py +105 -0
- fibphot/analysis/report.py +56 -0
- fibphot/collection.py +207 -0
- fibphot/fit/__init__.py +0 -0
- fibphot/fit/regression.py +269 -0
- fibphot/io/__init__.py +6 -0
- fibphot/io/doric.py +435 -0
- fibphot/io/excel.py +76 -0
- fibphot/io/h5.py +321 -0
- fibphot/misc.py +11 -0
- fibphot/peaks.py +628 -0
- fibphot/pipeline.py +14 -0
- fibphot/plotting.py +594 -0
- fibphot/stages/__init__.py +22 -0
- fibphot/stages/base.py +101 -0
- fibphot/stages/baseline.py +354 -0
- fibphot/stages/control_dff.py +214 -0
- fibphot/stages/filters.py +273 -0
- fibphot/stages/normalisation.py +260 -0
- fibphot/stages/regression.py +139 -0
- fibphot/stages/smooth.py +442 -0
- fibphot/stages/trim.py +141 -0
- fibphot/state.py +309 -0
- fibphot/tags.py +130 -0
- fibphot/types.py +6 -0
- fibphot-0.1.0.dist-info/METADATA +63 -0
- fibphot-0.1.0.dist-info/RECORD +37 -0
- fibphot-0.1.0.dist-info/WHEEL +5 -0
- fibphot-0.1.0.dist-info/licenses/LICENSE.md +21 -0
- fibphot-0.1.0.dist-info/top_level.txt +1 -0
fibphot/plotting.py
ADDED
|
@@ -0,0 +1,594 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import contextlib
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from typing import TYPE_CHECKING, Literal
|
|
6
|
+
|
|
7
|
+
import matplotlib.pyplot as plt
|
|
8
|
+
import numpy as np
|
|
9
|
+
|
|
10
|
+
from .state import PhotometryState
|
|
11
|
+
from .types import FloatArray
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from matplotlib.axes import Axes
|
|
15
|
+
from matplotlib.figure import Figure, SubFigure
|
|
16
|
+
|
|
17
|
+
# Which snapshot of the signals a plotting call should draw.
PlotView = Literal[
    "current",
    "raw",
    "before_stage",
    "after_stage",
]
# How several traces share one axes: drawn on top of each other ("overlay")
# or shifted vertically by a constant step per trace ("stacked").
PlotMode = Literal["overlay", "stacked"]
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass(frozen=True, slots=True)
|
|
22
|
+
class PlotTheme:
|
|
23
|
+
figsize: tuple[float, float] = (6.0, 3.0)
|
|
24
|
+
dpi: int = 150
|
|
25
|
+
|
|
26
|
+
label_size: int = 8
|
|
27
|
+
tick_size: int = 8
|
|
28
|
+
legend_size: int = 8
|
|
29
|
+
title_size: int = 8
|
|
30
|
+
|
|
31
|
+
linewidth: float = 1.1
|
|
32
|
+
alpha: float = 1.0
|
|
33
|
+
|
|
34
|
+
# Vibrant, colour-blind friendly
|
|
35
|
+
signal_colour: str = "#0072B2" # blue
|
|
36
|
+
control_colour: str = "#000000" # black
|
|
37
|
+
baseline_colour: str = "#E69F00" # orange (baseline fit)
|
|
38
|
+
fit_colour: str = "#009E73" # green (motion/control fit)
|
|
39
|
+
difference_colour: str = "#D55E00" # vermillion
|
|
40
|
+
accent_purple: str = "#CC79A7" # purple
|
|
41
|
+
|
|
42
|
+
cycle: tuple[str, ...] = (
|
|
43
|
+
"#0072B2", # blue
|
|
44
|
+
"#009E73", # green
|
|
45
|
+
"#D55E00", # vermillion
|
|
46
|
+
"#E69F00", # orange
|
|
47
|
+
"#CC79A7", # purple
|
|
48
|
+
"#56B4E9", # light blue
|
|
49
|
+
)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def set_plot_defaults(theme: PlotTheme | None = None) -> None:
    """
    Set matplotlib rcParams for a consistent plotting style.

    theme:
        Theme whose DPI and font sizes are written into ``plt.rcParams``.
        Defaults to ``PlotTheme()``, so calling with no argument behaves as
        before, except that the legend font size is now applied too.

    This mutates matplotlib's global rcParams and affects every figure
    created afterwards in the process.
    """
    if theme is None:
        theme = PlotTheme()

    plt.rcParams.update(
        {
            "figure.dpi": theme.dpi,
            "savefig.dpi": theme.dpi,
            "font.size": theme.label_size,
            "axes.titlesize": theme.title_size,
            "axes.labelsize": theme.label_size,
            "xtick.labelsize": theme.tick_size,
            "ytick.labelsize": theme.tick_size,
            # Previously missing: the theme's legend size was never applied.
            "legend.fontsize": theme.legend_size,
        }
    )
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def _apply_theme(ax: Axes, theme: PlotTheme) -> None:
    """Style ``ax`` in place: theme font sizes and no top/right spines."""
    # Tick labels first, then the axis-label and title text objects.
    ax.tick_params(labelsize=theme.tick_size)
    for text, size in (
        (ax.xaxis.label, theme.label_size),
        (ax.yaxis.label, theme.label_size),
        (ax.title, theme.title_size),
    ):
        text.set_fontsize(size)

    # Open frame: hide the top and right border lines.
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def _new_axes(
    ax: Axes | None, theme: PlotTheme
) -> tuple[Figure | SubFigure, Axes]:
    """
    Return ``(figure, axes)`` ready for plotting, with the theme applied.

    When ``ax`` is given it is reused together with its parent figure;
    otherwise a new single-axes figure is created with the theme's size,
    DPI and constrained layout.
    """
    if ax is None:
        fig, ax = plt.subplots(
            figsize=theme.figsize,
            dpi=theme.dpi,
            constrained_layout=True,
        )
    else:
        fig = ax.figure

    _apply_theme(ax, theme)
    return fig, ax
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def _snapshot_at(state: PhotometryState, state_index: int) -> FloatArray:
|
|
98
|
+
"""
|
|
99
|
+
Return the signals snapshot after `state_index` stages.
|
|
100
|
+
|
|
101
|
+
For k applied stages:
|
|
102
|
+
- snapshot 0: raw (after 0 stages)
|
|
103
|
+
- snapshot j: after j stages
|
|
104
|
+
- snapshot k: current (after k stages)
|
|
105
|
+
|
|
106
|
+
With your state design, history has length k and stores snapshots for
|
|
107
|
+
state_index in [0, k-1]; current is state.signals.
|
|
108
|
+
"""
|
|
109
|
+
k = len(state.summary)
|
|
110
|
+
if state_index < 0 or state_index > k:
|
|
111
|
+
raise ValueError(f"state_index must be in [0, {k}], got {state_index}.")
|
|
112
|
+
|
|
113
|
+
if k == 0:
|
|
114
|
+
return state.signals
|
|
115
|
+
|
|
116
|
+
if state_index < k:
|
|
117
|
+
return state.history[state_index]
|
|
118
|
+
|
|
119
|
+
return state.signals
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def _find_stage_index(
|
|
123
|
+
state: PhotometryState,
|
|
124
|
+
*,
|
|
125
|
+
stage_name: str | None,
|
|
126
|
+
stage_id: str | None,
|
|
127
|
+
occurrence: int,
|
|
128
|
+
) -> int:
|
|
129
|
+
"""Return 0-based index into state.summary for the requested stage."""
|
|
130
|
+
if stage_name is None and stage_id is None:
|
|
131
|
+
raise ValueError("Provide stage_name or stage_id.")
|
|
132
|
+
|
|
133
|
+
matches: list[int] = []
|
|
134
|
+
for i, rec in enumerate(state.summary):
|
|
135
|
+
if (stage_id is not None and rec.stage_id == stage_id) or (
|
|
136
|
+
stage_name is not None and rec.name.lower() == stage_name.lower()
|
|
137
|
+
):
|
|
138
|
+
matches.append(i)
|
|
139
|
+
|
|
140
|
+
if not matches:
|
|
141
|
+
key = stage_id if stage_id is not None else stage_name
|
|
142
|
+
raise KeyError(f"Stage not found in summary: {key!r}")
|
|
143
|
+
|
|
144
|
+
try:
|
|
145
|
+
return matches[occurrence]
|
|
146
|
+
except IndexError as exc:
|
|
147
|
+
raise IndexError(
|
|
148
|
+
f"Stage occurrence {occurrence} is out of range. "
|
|
149
|
+
f"Found {len(matches)} occurrence(s)."
|
|
150
|
+
) from exc
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def _signals_for_view(
|
|
154
|
+
state: PhotometryState,
|
|
155
|
+
*,
|
|
156
|
+
view: PlotView,
|
|
157
|
+
stage_name: str | None,
|
|
158
|
+
stage_id: str | None,
|
|
159
|
+
occurrence: int,
|
|
160
|
+
) -> FloatArray:
|
|
161
|
+
"""Select which signals snapshot to plot."""
|
|
162
|
+
if view == "current":
|
|
163
|
+
return state.signals
|
|
164
|
+
|
|
165
|
+
if view == "raw":
|
|
166
|
+
return _snapshot_at(state, 0)
|
|
167
|
+
|
|
168
|
+
si = _find_stage_index(
|
|
169
|
+
state,
|
|
170
|
+
stage_name=stage_name,
|
|
171
|
+
stage_id=stage_id,
|
|
172
|
+
occurrence=occurrence,
|
|
173
|
+
)
|
|
174
|
+
if view == "before_stage":
|
|
175
|
+
return _snapshot_at(state, si)
|
|
176
|
+
if view == "after_stage":
|
|
177
|
+
return _snapshot_at(state, si + 1)
|
|
178
|
+
|
|
179
|
+
raise ValueError(f"Unknown view: {view!r}")
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def plot_current(
    state: PhotometryState,
    *,
    signal: str,
    control: str | None = None,
    baseline_key: str | None = None,
    motion_fit_key: str | None = None,
    overlay_offset: dict[str, float] | None = None,
    view: PlotView = "current",
    stage_name: str | None = None,
    stage_id: str | None = None,
    occurrence: int = -1,
    title: str | None = None,
    xlim: tuple[float, float] | None = None,
    ylim: tuple[float, float] | None = None,
    show_legend: bool = True,
    theme: PlotTheme | None = None,
    ax: Axes | None = None,
    label: str | None = None,
    colour: str | None = None,
    linestyle: str | None = None,
    alpha: float | None = None,
    linewidth: float | None = None,
) -> tuple[Figure | SubFigure, Axes]:
    """
    Plot one signal channel from a chosen snapshot, optionally with the
    control channel and derived overlays.

    Typical use after `DoubleExpBaseline(subtract=True)`:

        state.plot(
            signal="gcamp",
            baseline_key="double_exp_baseline",
            view="before_stage",
            stage_name="double_exp_baseline",
        )

    view / stage_name / stage_id / occurrence:
        Which snapshot the signal (and control) traces come from; see
        ``_signals_for_view``.

    baseline_key:
        Key of a derived array with shape == state.signals.shape, e.g.
        "double_exp_baseline". Only its `signal` row is drawn (dashed).

    motion_fit_key:
        Key of a derived array with shape == state.signals.shape, e.g.
        "motion_fit". Only its `signal` row is drawn (dotted).

    Overlay keys that are not present in ``state.derived`` are skipped
    silently. A present key with the wrong shape raises ValueError.

    overlay_offset:
        Optional constant offsets for overlays, keyed by overlay name, e.g.

            overlay_offset={"motion_fit": -0.01}

        Offsets apply only to overlays (never to the main signal/control).
        An offset overlay gets " (offset)" appended to its legend label.

    label / colour / linestyle / alpha / linewidth:
        Per-call overrides for the main signal trace only.

    Returns the (figure, axes) that were drawn on.
    """
    theme = theme or PlotTheme()
    fig, ax = _new_axes(ax, theme)
    offsets = overlay_offset or {}

    time = state.time_seconds
    snapshot = _signals_for_view(
        state,
        view=view,
        stage_name=stage_name,
        stage_id=stage_id,
        occurrence=occurrence,
    )
    sig_row = state.idx(signal)

    # Main trace: explicit per-call styling wins over theme defaults.
    ax.plot(
        time,
        snapshot[sig_row],
        label=label or f"{signal.lower()} ({view})",
        color=colour or theme.signal_colour,
        linewidth=theme.linewidth if linewidth is None else float(linewidth),
        alpha=theme.alpha if alpha is None else float(alpha),
        linestyle=linestyle or "-",
    )

    if control is not None:
        ax.plot(
            time,
            snapshot[state.idx(control)],
            label=f"{control.lower()} ({view})",
            color=theme.control_colour,
            linewidth=max(0.9, theme.linewidth * 0.9),
            alpha=0.75,
        )

    def _overlay(
        key: str | None,
        *,
        line_colour: str,
        width: float,
        opacity: float,
        dash: str,
    ) -> None:
        # Draw the `signal` row of derived[key], if that key exists.
        if key is None or key not in state.derived:
            return
        arr = np.asarray(state.derived[key], dtype=float)
        if arr.shape != state.signals.shape:
            raise ValueError(
                f"derived['{key}'] has shape {arr.shape}, "
                f"expected {state.signals.shape}."
            )
        shift = float(offsets.get(key, 0.0))
        ax.plot(
            time,
            arr[sig_row] + shift,
            label=key if shift == 0.0 else f"{key} (offset)",
            color=line_colour,
            # Overlays are drawn thicker than the theme width, never thinner
            # than the multiplier itself.
            linewidth=max(width, theme.linewidth * width),
            alpha=opacity,
            linestyle=dash,
        )

    _overlay(
        baseline_key,
        line_colour=theme.baseline_colour,
        width=1.6,
        opacity=1.0,
        dash="--",
    )
    _overlay(
        motion_fit_key,
        line_colour=theme.fit_colour,
        width=1.4,
        opacity=0.95,
        dash=":",
    )

    ax.set_xlabel("time (s)")
    ax.set_ylabel(signal.lower())

    if title is not None:
        ax.set_title(title)
    if xlim is not None:
        ax.set_xlim(*xlim)
    if ylim is not None:
        ax.set_ylim(*ylim)
    if show_legend:
        ax.legend(frameon=False, fontsize=theme.legend_size)

    return fig, ax
|
|
332
|
+
|
|
333
|
+
|
|
334
|
+
def _label_for_state_index(
|
|
335
|
+
state: PhotometryState,
|
|
336
|
+
state_index: int,
|
|
337
|
+
*,
|
|
338
|
+
include_index: bool,
|
|
339
|
+
include_current_tag: bool,
|
|
340
|
+
) -> str:
|
|
341
|
+
k = len(state.summary)
|
|
342
|
+
base = "raw" if state_index <= 0 else state.summary[state_index - 1].name
|
|
343
|
+
|
|
344
|
+
parts: list[str] = []
|
|
345
|
+
if include_index:
|
|
346
|
+
parts.append(f"{state_index:02d}")
|
|
347
|
+
parts.append(base)
|
|
348
|
+
|
|
349
|
+
label = " - ".join(parts)
|
|
350
|
+
if include_current_tag and state_index == k and k > 0:
|
|
351
|
+
label = f"{label} (current)"
|
|
352
|
+
return label
|
|
353
|
+
|
|
354
|
+
|
|
355
|
+
def _plot_before_after(
    ax: Axes,
    t: FloatArray,
    y_before: FloatArray,
    y_after: FloatArray,
    *,
    stage_name: str,
    theme: PlotTheme,
) -> None:
    """Draw the 'before'/'after' traces of one stage on ``ax`` plus a legend."""
    ax.plot(
        t,
        y_before,
        label=f"before {stage_name}",
        linewidth=theme.linewidth,
        alpha=theme.alpha,
        color=theme.control_colour,
    )
    ax.plot(
        t,
        y_after,
        label=f"after {stage_name}",
        linewidth=theme.linewidth,
        alpha=theme.alpha,
        color=theme.signal_colour,
    )
    ax.legend(frameon=False, fontsize=theme.legend_size)


def _stack_step(y: FloatArray) -> float:
    """Vertical spacing for stacked traces: 95th percentile of |y|, else 1.0."""
    base_scale = float(np.nanpercentile(np.abs(y), 95))
    # Fall back to 1.0 when the scale is zero or NaN (NaN > 0 is False).
    return base_scale if base_scale > 0 else 1.0


def plot_history(
    state: PhotometryState,
    channel: str,
    *,
    stage_name: str | None = None,
    stage_id: str | None = None,
    occurrence: int = -1,
    around: bool = False,
    plot_difference: bool = False,
    difference_label: str | None = None,
    n_recent: int | None = None,
    include_raw: bool = True,
    include_current: bool = True,
    include_index: bool = True,
    include_current_tag: bool = True,
    mode: PlotMode = "overlay",
    theme: PlotTheme | None = None,
    ax: Axes | None = None,
    title: str | None = None,
) -> tuple[Figure | SubFigure, Axes | tuple[Axes, Axes]]:
    """
    Plot one channel across the saved history (and optionally current).

    Default behaviour: draw every snapshot (raw, after each stage, current).
    ``include_raw`` / ``include_current`` drop the first / last snapshot and
    ``n_recent`` keeps only the last N remaining ones. ``mode="overlay"``
    draws them on top of each other; ``mode="stacked"`` shifts each
    successive trace upwards by a constant step.

    If around=True, plot only the snapshot immediately before and immediately
    after a chosen stage (selected by stage_name / stage_id / occurrence).
    With plot_difference=True a new two-row figure is created (``ax`` must be
    None); the lower panel shows after - before and its y-axis is labelled
    ``difference_label`` (default "difference"). ``mode`` is ignored then.

    Returns (figure, axes), or (figure, (top_axes, difference_axes)) when
    around=True and plot_difference=True.

    Raises ValueError for an unknown ``mode``. The mode is checked before any
    figure is created, so no empty figure is left behind. ValueError is also
    raised for invalid selections: around=True with no stages, ``ax`` given
    together with plot_difference, no snapshots left, or n_recent < 1.
    """
    theme = theme or PlotTheme()
    ch_i = state.idx(channel)
    t = state.time_seconds
    k = len(state.summary)

    if around:
        if k == 0:
            raise ValueError(
                "No stages have been applied; cannot plot around a stage."
            )

        si = _find_stage_index(
            state,
            stage_name=stage_name,
            stage_id=stage_id,
            occurrence=occurrence,
        )
        stage = state.summary[si]
        y_before = _snapshot_at(state, si)[ch_i]
        y_after = _snapshot_at(state, si + 1)[ch_i]
        around_title = title or f"{channel.lower()} — around {stage.name}"

        if plot_difference:
            if ax is not None:
                raise ValueError(
                    "plot_difference=True requires ax=None so the function can "
                    "create a two-row figure."
                )

            fig, (ax0, ax1) = plt.subplots(
                nrows=2,
                sharex=True,
                figsize=theme.figsize,
                dpi=theme.dpi,
                constrained_layout=True,
                gridspec_kw={"height_ratios": [2, 1]},
            )
            _apply_theme(ax0, theme)
            _apply_theme(ax1, theme)

            _plot_before_after(
                ax0, t, y_before, y_after, stage_name=stage.name, theme=theme
            )

            ax1.plot(
                t,
                y_after - y_before,
                linewidth=theme.linewidth,
                alpha=theme.alpha,
                color=theme.difference_colour,
            )
            ax1.axhline(0.0, linewidth=1.0, color="k")

            ax0.set_ylabel(channel.lower())
            # Previously `difference_label` was accepted but ignored.
            ax1.set_ylabel(difference_label or "difference")
            ax1.set_xlabel("time (s)")
            ax0.set_title(around_title)

            return fig, (ax0, ax1)

        # Validate before creating a figure so an error leaves none behind.
        if mode not in ("overlay", "stacked"):
            raise ValueError(f"Unknown mode: {mode!r}")

        fig, ax0 = _new_axes(ax, theme)
        if mode == "stacked":
            # Lift the "after" trace above the "before" trace.
            y_after = y_after + _stack_step(y_after)
        _plot_before_after(
            ax0, t, y_before, y_after, stage_name=stage.name, theme=theme
        )

        ax0.set_xlabel("time (s)")
        ax0.set_ylabel(channel.lower())
        ax0.set_title(around_title)

        return fig, ax0

    # ---- default behaviour: plot multiple snapshots ----

    h = int(state.history.shape[0])
    if h == 0:
        snapshots: list[tuple[int, FloatArray]] = [(0, state.signals)]
    else:
        snapshots = [(j, state.history[j]) for j in range(h)]
        snapshots.append((h, state.signals))

    filtered: list[tuple[int, FloatArray]] = []
    for state_index, sig in snapshots:
        # With no history the only snapshot is both raw and current, so only
        # include_current can drop it.
        if state_index == 0 and not include_raw and h > 0:
            continue
        if state_index == k and not include_current:
            continue
        filtered.append((state_index, sig))

    if not filtered:
        raise ValueError(
            "No snapshots selected to plot (check include_* flags)."
        )

    if n_recent is not None:
        if n_recent < 1:
            raise ValueError("n_recent must be >= 1.")
        filtered = filtered[-n_recent:]

    # Validate before creating a figure so an error leaves none behind.
    if mode not in ("overlay", "stacked"):
        raise ValueError(f"Unknown mode: {mode!r}")

    fig, ax0 = _new_axes(ax, theme)
    # Best effort: keep matplotlib's default colours if the theme cycle is
    # rejected (e.g. an invalid colour string).
    with contextlib.suppress(TypeError, ValueError):
        ax0.set_prop_cycle(color=list(theme.cycle))

    stacked = mode == "stacked"
    step = _stack_step(state.signals[ch_i]) if stacked else 0.0
    offset = 0.0
    for state_index, sig in filtered:
        y = sig[ch_i]
        label = _label_for_state_index(
            state,
            state_index,
            include_index=include_index,
            include_current_tag=include_current_tag,
        )
        ax0.plot(
            t,
            y + offset if stacked else y,
            label=label,
            linewidth=theme.linewidth,
            alpha=theme.alpha,
        )
        if stacked:
            offset += step

    ax0.legend(frameon=False, fontsize=theme.legend_size)

    ax0.set_xlabel("time (s)")
    ax0.set_ylabel(channel.lower())
    if title is not None:
        ax0.set_title(title)

    return fig, ax0
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from .baseline import DoubleExpBaseline, PyBaselinesBaseline
|
|
4
|
+
from .control_dff import IsosbesticDff
|
|
5
|
+
from .filters import HampelFilter, LowPassFilter, MedianFilter
|
|
6
|
+
from .normalisation import Normalise
|
|
7
|
+
from .regression import IsosbesticRegression
|
|
8
|
+
from .smooth import Smooth
|
|
9
|
+
from .trim import Trim
|
|
10
|
+
|
|
11
|
+
__all__ = [
|
|
12
|
+
"DoubleExpBaseline",
|
|
13
|
+
"PyBaselinesBaseline",
|
|
14
|
+
"HampelFilter",
|
|
15
|
+
"IsosbesticDff",
|
|
16
|
+
"IsosbesticRegression",
|
|
17
|
+
"LowPassFilter",
|
|
18
|
+
"MedianFilter",
|
|
19
|
+
"Normalise",
|
|
20
|
+
"Smooth",
|
|
21
|
+
"Trim",
|
|
22
|
+
]
|
fibphot/stages/base.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
from ..state import PhotometryState, StageRecord
|
|
10
|
+
from ..types import FloatArray
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass(frozen=True, slots=True)
|
|
14
|
+
class StageOutput:
|
|
15
|
+
"""Output of a processing stage."""
|
|
16
|
+
|
|
17
|
+
signals: FloatArray | None = None
|
|
18
|
+
derived: dict[str, FloatArray] | None = None
|
|
19
|
+
results: dict[str, Any] | None = None
|
|
20
|
+
metrics: dict[str, float] | None = None
|
|
21
|
+
notes: str | None = None
|
|
22
|
+
data: dict[str, object] = field(default_factory=dict)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _resolve_channels(
|
|
26
|
+
state: PhotometryState, channels: str | list[str] | None
|
|
27
|
+
) -> list[int]:
|
|
28
|
+
"""Resolve channel names to indices."""
|
|
29
|
+
|
|
30
|
+
if channels is None or (
|
|
31
|
+
isinstance(channels, str) and channels.lower() == "all"
|
|
32
|
+
):
|
|
33
|
+
return list(range(state.n_signals))
|
|
34
|
+
|
|
35
|
+
if isinstance(channels, str):
|
|
36
|
+
return [state.idx(channels)]
|
|
37
|
+
|
|
38
|
+
return [state.idx(c) for c in channels]
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@dataclass(frozen=True, slots=True)
|
|
42
|
+
class UpdateStage(ABC):
|
|
43
|
+
name: str
|
|
44
|
+
|
|
45
|
+
@abstractmethod
|
|
46
|
+
def apply(self, state: PhotometryState) -> StageOutput:
|
|
47
|
+
"""Apply the stage to the given PhotometryState."""
|
|
48
|
+
raise NotImplementedError
|
|
49
|
+
|
|
50
|
+
def __call__(self, state: PhotometryState) -> PhotometryState:
|
|
51
|
+
"""Apply the stage and return an updated PhotometryState."""
|
|
52
|
+
|
|
53
|
+
state0 = state.push_history()
|
|
54
|
+
|
|
55
|
+
out = self.apply(state0)
|
|
56
|
+
|
|
57
|
+
time_seconds = np.asarray(
|
|
58
|
+
out.data.get("time_seconds", state0.time_seconds), dtype=float
|
|
59
|
+
)
|
|
60
|
+
history = np.asarray(
|
|
61
|
+
out.data.get("history", state0.history), dtype=float
|
|
62
|
+
)
|
|
63
|
+
|
|
64
|
+
new_signals = (
|
|
65
|
+
state0.signals
|
|
66
|
+
if out.signals is None
|
|
67
|
+
else np.asarray(out.signals, dtype=float)
|
|
68
|
+
)
|
|
69
|
+
|
|
70
|
+
stage_id = f"{len(state0.summary) + 1:04d}_{self.name.lower()}"
|
|
71
|
+
record = StageRecord(
|
|
72
|
+
stage_id=stage_id,
|
|
73
|
+
name=self.name,
|
|
74
|
+
params=self._params_for_summary(),
|
|
75
|
+
metrics=out.metrics or {},
|
|
76
|
+
notes=out.notes,
|
|
77
|
+
)
|
|
78
|
+
|
|
79
|
+
new_summary = (*state0.summary, record)
|
|
80
|
+
|
|
81
|
+
new_derived = dict(state0.derived)
|
|
82
|
+
if out.derived:
|
|
83
|
+
new_derived.update(out.derived)
|
|
84
|
+
|
|
85
|
+
new_results = dict(state0.results)
|
|
86
|
+
new_results[stage_id] = out.results or {}
|
|
87
|
+
|
|
88
|
+
return PhotometryState(
|
|
89
|
+
time_seconds=time_seconds,
|
|
90
|
+
signals=new_signals,
|
|
91
|
+
channel_names=state0.channel_names,
|
|
92
|
+
history=history,
|
|
93
|
+
summary=new_summary,
|
|
94
|
+
derived=new_derived,
|
|
95
|
+
results=new_results,
|
|
96
|
+
metadata=state0.metadata,
|
|
97
|
+
)
|
|
98
|
+
|
|
99
|
+
def _params_for_summary(self) -> dict[str, Any]:
|
|
100
|
+
"""Override to store concise parameters for reproducibility."""
|
|
101
|
+
return {}
|