fibphot 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,350 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any, Literal
5
+
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ import pandas as pd
9
+ from matplotlib.axes import Axes
10
+ from matplotlib.figure import Figure
11
+
12
+ from ..fit.regression import fit_irls
13
+ from ..state import PhotometryState
14
+ from ..types import FloatArray
15
+
16
# Names of the robust losses the sweep functions accept and pass on to fit_irls.
+ Loss = Literal["tukey", "huber"]
17
# Which signals snapshot a view refers to: the current signals, the raw
# snapshot, or the snapshot just before / just after a selected stage.
+ PlotView = Literal["current", "raw", "before_stage", "after_stage"]
18
+
19
+
20
+ @dataclass(frozen=True, slots=True)
21
+ class SweepSpec:
22
+ """How to select x/y from a state for the sweep."""
23
+
24
+ control: str = "iso"
25
+ channel: str = "gcamp"
26
+ view: PlotView = "current"
27
+ stage_name: str | None = None
28
+ stage_id: str | None = None
29
+ occurrence: int = -1
30
+
31
+
32
def _snapshot_at(state: PhotometryState, state_index: int) -> FloatArray:
    """
    Fetch the signals as they were once `state_index` stages had been applied.

    With n stages recorded in `state.summary`, valid indices run from 0 to n:
      - 0 is the raw snapshot (no stages applied),
      - j < n is read from `state.history[j]`,
      - n is the current `state.signals`.

    Raises:
        ValueError: if `state_index` lies outside [0, n].
    """
    n_stages = len(state.summary)
    if not 0 <= state_index <= n_stages:
        raise ValueError(
            f"state_index must be in [0, {n_stages}], got {state_index}."
        )

    # The final snapshot is not stored in the history; it is the live signals.
    # This also covers a state with no stages, where the only valid index is 0.
    if state_index == n_stages:
        return state.signals
    return state.history[state_index]
52
+
53
+
54
def _find_stage_index(
    state: PhotometryState,
    *,
    stage_name: str | None,
    stage_id: str | None,
    occurrence: int,
) -> int:
    """
    Return the position in `state.summary` of the requested stage.

    A record matches when its `stage_id` equals `stage_id`, or when its
    `name` equals `stage_name` ignoring case; if both criteria are given, a
    record satisfying either one counts. `occurrence` is then used as a
    normal Python index into the list of matching positions, so -1 picks the
    last match and 0 the first.

    Raises:
        ValueError: if neither `stage_name` nor `stage_id` is given.
        KeyError: if no record matches.
        IndexError: if `occurrence` is out of range for the matches found.
    """
    if stage_name is None and stage_id is None:
        raise ValueError(
            "Provide stage_name or stage_id for stage-based views."
        )

    def is_match(rec) -> bool:
        if stage_id is not None and rec.stage_id == stage_id:
            return True
        return (
            stage_name is not None
            and rec.name.lower() == stage_name.lower()
        )

    hits = [pos for pos, rec in enumerate(state.summary) if is_match(rec)]
    if not hits:
        # Report the identifier when one was supplied, otherwise the name.
        key = stage_name if stage_id is None else stage_id
        raise KeyError(f"Stage not found in summary: {key!r}")

    try:
        return hits[occurrence]
    except IndexError as exc:
        raise IndexError(
            f"Stage occurrence {occurrence} is out of range. "
            f"Found {len(hits)} occurrence(s)."
        ) from exc
84
+
85
+
86
def _signals_for_view(
    state: PhotometryState,
    *,
    view: PlotView,
    stage_name: str | None,
    stage_id: str | None,
    occurrence: int,
) -> FloatArray:
    """
    Return the signals snapshot selected by `view`.

    - "current": the state's current signals.
    - "raw": the snapshot with no stages applied.
    - "before_stage" / "after_stage": the snapshot immediately before / after
      the stage located from `stage_name`, `stage_id` and `occurrence`.

    The stage arguments are only consulted for the two stage-based views.

    Raises:
        ValueError: if `view` is not one of the four supported values, or if
            a stage-based view is requested without a stage name or id.
        KeyError / IndexError: propagated from the stage lookup when the
            stage or the requested occurrence cannot be found.
    """
    if view == "current":
        return state.signals

    if view == "raw":
        return _snapshot_at(state, 0)

    # Reject unknown views *before* the stage lookup, so a misspelt view is
    # reported as such instead of surfacing as an unrelated stage-lookup
    # error (missing stage name, stage not found, bad occurrence).
    if view not in ("before_stage", "after_stage"):
        raise ValueError(f"Unknown view: {view!r}")

    stage_pos = _find_stage_index(
        state,
        stage_name=stage_name,
        stage_id=stage_id,
        occurrence=occurrence,
    )
    # The snapshot before stage i is the one after i stages; the snapshot
    # after it has one more stage applied.
    offset = 1 if view == "after_stage" else 0
    return _snapshot_at(state, stage_pos + offset)
112
+
113
+
114
def _weight_stats(w: FloatArray | None) -> dict[str, float]:
    """
    Summarise a weight vector as a flat dictionary of floats.

    Non-finite entries are dropped first. The result always has the keys
    w_mean, w_median, w_min, w_p01, w_p05 (1% / 5% quantiles), w_zero_frac
    (fraction of weights <= 0) and w_lt_01_frac (fraction of weights < 0.1).
    When `w` is None or has no finite entries, every value is NaN.
    """
    stat_names = (
        "w_mean",
        "w_median",
        "w_min",
        "w_p01",
        "w_p05",
        "w_zero_frac",
        "w_lt_01_frac",
    )

    finite = None
    if w is not None:
        values = np.asarray(w, dtype=float)
        finite = values[np.isfinite(values)]

    # Nothing usable to summarise: report NaN for every statistic.
    if finite is None or finite.size == 0:
        return dict.fromkeys(stat_names, float("nan"))

    stats = (
        np.mean(finite),
        np.median(finite),
        np.min(finite),
        np.quantile(finite, 0.01),
        np.quantile(finite, 0.05),
        np.mean(finite <= 0.0),
        np.mean(finite < 0.1),
    )
    return {name: float(value) for name, value in zip(stat_names, stats)}
148
+
149
+
150
def irls_tuning_sweep_xy(
    x: FloatArray,
    y: FloatArray,
    *,
    tuning_constants: list[float] | FloatArray,
    loss: Loss = "tukey",
    include_intercept: bool = True,
    max_iter: int = 100,
    tol: float = 1e-10,
    store_weights: bool = True,
) -> pd.DataFrame:
    """
    Fit y ~ a + b x with IRLS once per tuning constant and tabulate the fits.

    Every value in `tuning_constants` is passed to `fit_irls` together with
    the remaining keyword options. The returned DataFrame has one row per
    constant, sorted by "tuning_constant", with the fit's slope, intercept,
    r2, n_iter and scale (NaN where the fit reports None), the weight
    summary columns (w_mean, w_median, ...) and a constant "loss" column.

    Raises:
        ValueError: if `tuning_constants` is not a non-empty 1D sequence, or
            contains a value that is non-finite or not strictly positive.
    """
    constants = np.asarray(tuning_constants, dtype=float)
    if constants.ndim != 1 or constants.size == 0:
        raise ValueError("tuning_constants must be a non-empty 1D sequence.")
    usable = np.isfinite(constants) & (constants > 0.0)
    if not np.all(usable):
        raise ValueError("All tuning_constants must be finite and > 0.")

    def float_or_nan(value) -> float:
        return float("nan") if value is None else float(value)

    records: list[dict[str, float]] = []
    for constant in map(float, constants):
        result = fit_irls(
            x,
            y,
            include_intercept=include_intercept,
            loss=loss,
            tuning_constant=constant,
            max_iter=max_iter,
            tol=tol,
            store_weights=store_weights,
        )
        record: dict[str, float] = {
            "tuning_constant": constant,
            "slope": float(result.slope),
            "intercept": float(result.intercept),
            "r2": float(result.r2),
            "n_iter": float_or_nan(result.n_iter),
            "scale": float_or_nan(result.scale),
        }
        record.update(_weight_stats(result.weights))
        records.append(record)

    table = pd.DataFrame(records)
    table = table.sort_values("tuning_constant")
    table = table.reset_index(drop=True)
    table["loss"] = loss
    return table
204
+
205
+
206
def irls_tuning_sweep(
    state: PhotometryState,
    *,
    tuning_constants: list[float] | FloatArray,
    loss: Loss = "tukey",
    include_intercept: bool = True,
    max_iter: int = 100,
    tol: float = 1e-10,
    store_weights: bool = True,
    spec: SweepSpec | None = None,
) -> pd.DataFrame:
    """
    Run the IRLS tuning-constant sweep on signals taken from a state.

    `spec` (default: `SweepSpec()`) selects the snapshot to read and the
    control (x) and channel (y) signals within it. The sweep itself is
    delegated to `irls_tuning_sweep_xy`; its table is returned with four
    extra columns describing the selection: "control" and "channel"
    (lower-cased names), "view", and "stage_name" ("" when none was given).

    Tips:
      - To analyse sensitivity on the *pre-regression* signals, call this on
        the *post* state that contains history, with
        spec=SweepSpec(view="before_stage",
                       stage_name="isosbestic_regression").
      - If regression has not been applied yet, view="current" is enough.
    """
    selection = SweepSpec() if spec is None else spec

    signals = _signals_for_view(
        state,
        view=selection.view,
        stage_name=selection.stage_name,
        stage_id=selection.stage_id,
        occurrence=selection.occurrence,
    )
    control_signal = signals[state.idx(selection.control)]
    channel_signal = signals[state.idx(selection.channel)]

    table = irls_tuning_sweep_xy(
        control_signal,
        channel_signal,
        tuning_constants=tuning_constants,
        loss=loss,
        include_intercept=include_intercept,
        max_iter=max_iter,
        tol=tol,
        store_weights=store_weights,
    )

    # Record which selection produced this table.
    table["control"] = selection.control.lower()
    table["channel"] = selection.channel.lower()
    table["view"] = selection.view
    if selection.stage_name is None:
        table["stage_name"] = ""
    else:
        table["stage_name"] = selection.stage_name
    return table
255
+
256
+
257
def plot_irls_tuning_sweep(
    df: pd.DataFrame,
    *,
    figsize: tuple[float, float] = (6.0, 3.0),
    dpi: int = 150,
    font_size: int = 8,
    show_weights: bool = True,
) -> tuple[Figure, tuple[Any, Any, Any, Axes | None]]:
    """
    Plot fit sensitivity against the IRLS tuning constant.

    Two stacked panels share the x axis (the "tuning_constant" column):
      - top: slope on the left axis, intercept (dashed) on a right twin axis;
      - bottom: R² on the left axis and, when `show_weights` is true and the
        frame has a "w_zero_frac" column, the zero-weight fraction (dotted)
        on a right twin axis.

    Args:
        df: Sweep table; must contain "tuning_constant", "slope",
            "intercept" and "r2".
        figsize: Figure size passed to `plt.subplots`.
        dpi: Figure resolution passed to `plt.subplots`.
        font_size: Font size for tick labels, axis labels and legends.
        show_weights: Whether to draw the zero-weight fraction if available.

    Returns:
        (fig, (ax_top, ax_bot, ax_top_r, ax_bot_r)), where the `_r` axes are
        the right-hand twins; `ax_bot_r` is None when the weight curve is
        not drawn.

    Raises:
        ValueError: if any required column is missing from `df`.
    """
    required = {"tuning_constant", "slope", "intercept", "r2"}
    missing = required - set(df.columns)
    if missing:
        raise ValueError(
            f"DataFrame missing required columns: {sorted(missing)}"
        )

    x = np.asarray(df["tuning_constant"], dtype=float)

    fig, (ax_top, ax_bot) = plt.subplots(
        nrows=2,
        sharex=True,
        figsize=figsize,
        dpi=dpi,
        constrained_layout=True,
        gridspec_kw={"height_ratios": [2, 1]},
    )

    def style(ax: Axes) -> None:
        """Apply font sizes and hide the top/right spines of a main axis."""
        ax.tick_params(labelsize=font_size)
        ax.xaxis.label.set_fontsize(font_size)
        ax.yaxis.label.set_fontsize(font_size)
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)

    def make_twin(ax: Axes) -> Axes:
        """Create a right-hand twin of `ax`, styled like the main axes."""
        twin = ax.twinx()
        twin.tick_params(labelsize=font_size)
        twin.yaxis.label.set_fontsize(font_size)
        # The right spine stays visible: it carries the twin's scale.
        twin.spines["top"].set_visible(False)
        return twin

    style(ax_top)
    style(ax_bot)

    # --- Top panel: slope + intercept on twinx ---
    ax_top_r = make_twin(ax_top)

    top_lines = ax_top.plot(x, df["slope"], linewidth=1.2, label="slope")
    top_lines += ax_top_r.plot(
        x,
        df["intercept"],
        linewidth=1.2,
        linestyle="--",
        label="intercept",
    )

    ax_top.set_ylabel("slope")
    ax_top_r.set_ylabel("intercept")

    # One combined legend, since the two curves live on different axes.
    ax_top.legend(
        top_lines,
        [ln.get_label() for ln in top_lines],
        frameon=False,
        fontsize=font_size,
    )

    # --- Bottom panel: R² + optional weights on twinx ---
    ax_bot_r: Axes | None = None
    bot_lines = ax_bot.plot(x, df["r2"], linewidth=1.2, label="R²")
    ax_bot.set_ylabel("R²")
    ax_bot.set_xlabel("tuning constant")

    if show_weights and "w_zero_frac" in df.columns:
        ax_bot_r = make_twin(ax_bot)
        bot_lines += ax_bot_r.plot(
            x,
            df["w_zero_frac"],
            linewidth=1.0,
            linestyle=":",
            label="zero-weight frac",
        )
        ax_bot_r.set_ylabel("zero-weight frac")

    ax_bot.legend(
        bot_lines,
        [ln.get_label() for ln in bot_lines],
        frameon=False,
        fontsize=font_size,
    )

    return fig, (ax_top, ax_bot, ax_top_r, ax_bot_r)