fibphot-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,290 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import Literal
+
+ import numpy as np
+ import pandas as pd
+
+ from ..state import PhotometryState
+ from ..types import FloatArray
+
+ ParamOrder = Literal["const", "amp_fast", "amp_slow", "tau_fast", "tau_slow"]
+
+
+ @dataclass(frozen=True, slots=True)
+ class DoubleExpParams:
+     const: float
+     amp_fast: float
+     amp_slow: float
+     tau_fast: float
+     tau_slow: float
+
+     @classmethod
+     def from_row(cls, row: FloatArray) -> DoubleExpParams:
+         r = np.asarray(row, dtype=float).ravel()
+         if r.shape[0] != 5:
+             raise ValueError(f"Expected 5 params, got shape {r.shape}.")
+         return cls(
+             const=float(r[0]),
+             amp_fast=float(r[1]),
+             amp_slow=float(r[2]),
+             tau_fast=float(r[3]),
+             tau_slow=float(r[4]),
+         )
+
+     def f0(self, t: FloatArray) -> FloatArray:
+         t = np.asarray(t, dtype=float)
+         return (
+             self.const
+             + self.amp_fast * np.exp(-t / self.tau_fast)
+             + self.amp_slow * np.exp(-t / self.tau_slow)
+         )
+
+
+ def _last_stage_id(state: PhotometryState, stage_name: str) -> str:
+     target = stage_name.lower()
+     for rec in reversed(state.summary):
+         if rec.name.lower() == target:
+             return rec.stage_id
+     raise KeyError(f"Stage not found in summary: {stage_name!r}")
+
+
+ def _half_drop_time(t: FloatArray, y: FloatArray) -> float:
+     """
+     Time at which y has completed half of its total drop from y[0] to y[-1].
+
+     If there is no net drop (y[-1] >= y[0]) returns NaN.
+     Uses linear interpolation between samples.
+     """
+     t = np.asarray(t, dtype=float)
+     y = np.asarray(y, dtype=float)
+
+     if t.ndim != 1 or y.ndim != 1 or t.shape[0] != y.shape[0]:
+         raise ValueError("t and y must be 1D arrays of the same length.")
+
+     y0 = float(y[0])
+     y1 = float(y[-1])
+     drop = y0 - y1
+     if not np.isfinite(drop) or drop <= 0.0:
+         return float("nan")
+
+     target = y0 - 0.5 * drop
+
+     # Find the first index where we are at/below target (assuming decay).
+     idx = np.where(y <= target)[0]
+     if idx.size == 0:
+         return float("nan")
+
+     j = int(idx[0])
+     if j == 0:
+         return float(t[0])
+
+     # Linear interpolation between (j-1) and j
+     t_lo, t_hi = float(t[j - 1]), float(t[j])
+     y_lo, y_hi = float(y[j - 1]), float(y[j])
+
+     denom = y_hi - y_lo
+     if not np.isfinite(denom) or abs(denom) < 1e-20:
+         return float(t_hi)
+
+     frac = (target - y_lo) / denom
+     frac = float(np.clip(frac, 0.0, 1.0))
+     return t_lo + frac * (t_hi - t_lo)
+
+
+ def _norm_shape_rmse(a: FloatArray, b: FloatArray, eps: float = 1e-12) -> float:
+     """
+     RMSE between two traces after normalising each by its first sample.
+     """
+     a = np.asarray(a, dtype=float)
+     b = np.asarray(b, dtype=float)
+
+     if a.shape != b.shape:
+         raise ValueError(f"Shape mismatch: {a.shape} vs {b.shape}.")
+
+     a0 = float(a[0])
+     b0 = float(b[0])
+     # Normalisation needs finite, non-negligible first samples.
+     if not (
+         np.isfinite(a0) and np.isfinite(b0) and abs(a0) >= eps and abs(b0) >= eps
+     ):
+         return float("nan")
+
+     an = a / a0
+     bn = b / b0
+     d = an - bn
+     return float(np.sqrt(np.nanmean(d * d)))
+
+
+ def photobleaching_summary(
+     state: PhotometryState,
+     *,
+     stage_id: str | None = None,
+     stage_name: str = "double_exp_baseline",
+     fast_component_threshold: float = 0.05,
+     control: str | None = None,
+ ) -> pd.DataFrame:
+     """
+     Per-channel summary of a fitted double exponential baseline.
+
+     Adds:
+     - has_fast_component: whether amp_fast / (amp_fast + amp_slow) exceeds
+       `fast_component_threshold`
+     - half_drop_time_s: time to complete half of the session-long drop
+     - optional control comparisons if `control` is provided:
+       * control_percent_drop
+       * percent_drop_minus_control
+       * tau_slow_ratio_to_control
+       * norm_shape_rmse_to_control
+     """
+     sid = stage_id or _last_stage_id(state, stage_name)
+     res = state.results.get(sid)
+     if res is None:
+         raise KeyError(f"No results found for stage_id={sid!r}.")
+
+     params = np.asarray(res.get("params"), dtype=float)
+     r2 = np.asarray(res.get("r2"), dtype=float) if "r2" in res else None
+     rmse = np.asarray(res.get("rmse"), dtype=float) if "rmse" in res else None
+
+     if params.ndim != 2 or params.shape[1] != 5:
+         raise ValueError(f"Expected params shape (n, 5), got {params.shape}.")
+
+     t = np.asarray(state.time_seconds, dtype=float)
+
+     baseline_curves: dict[str, FloatArray] = {}
+     rows: list[dict[str, float | str | bool]] = []
+
+     for i, name in enumerate(state.channel_names):
+         p = DoubleExpParams.from_row(params[i])
+
+         f = p.f0(t)
+         baseline_curves[name] = f
+
+         f_start = float(f[0])
+         f_end = float(f[-1])
+
+         amp_total = p.amp_fast + p.amp_slow
+         if amp_total != 0.0:
+             fast_frac = float(p.amp_fast / amp_total)
+             slow_frac = float(p.amp_slow / amp_total)
+         else:
+             fast_frac = float("nan")
+             slow_frac = float("nan")
+
+         has_fast = bool(
+             np.isfinite(fast_frac) and fast_frac > fast_component_threshold
+         )
+
+         if f_start != 0.0 and np.isfinite(f_start) and np.isfinite(f_end):
+             drop_frac = float((f_start - f_end) / f_start)
+             percent_drop = 100.0 * drop_frac
+         else:
+             percent_drop = float("nan")
+
+         row: dict[str, float | str | bool] = {
+             "channel": name,
+             "tau_fast_s": p.tau_fast,
+             "tau_slow_s": p.tau_slow,
+             "fast_amp_frac": fast_frac,
+             "slow_amp_frac": slow_frac,
+             "has_fast_component": has_fast,
+             "half_drop_time_s": _half_drop_time(t, f),
+             "percent_drop": percent_drop,
+             "const": p.const,
+             "amp_fast": p.amp_fast,
+             "amp_slow": p.amp_slow,
+             "f0_start": f_start,
+             "f0_end": f_end,
+         }
+         if r2 is not None and i < r2.shape[0]:
+             row["r2"] = float(r2[i])
+         if rmse is not None and i < rmse.shape[0]:
+             row["rmse"] = float(rmse[i])
+
+         rows.append(row)
+
+     df = pd.DataFrame(rows)
+
+     if control is not None:
+         # Resolve the control channel case-insensitively against the stored
+         # channel names, so mixed-case names still match.
+         matches = [n for n in baseline_curves if n.lower() == control.lower()]
+         if not matches:
+             raise KeyError(
+                 f"Control channel {control!r} not found. "
+                 f"Available: {state.channel_names}"
+             )
+         ctl = matches[0]
+
+         ctl_row = df.loc[df["channel"] == ctl]
+         if ctl_row.empty:
+             raise KeyError(f"Control channel {control!r} not found in summary.")
+         ctl_percent_drop = float(ctl_row["percent_drop"].iloc[0])
+         ctl_tau_slow = float(ctl_row["tau_slow_s"].iloc[0])
+         ctl_curve = baseline_curves[ctl]
+
+         df["control_percent_drop"] = ctl_percent_drop
+         df["percent_drop_minus_control"] = df["percent_drop"] - ctl_percent_drop
+         df["tau_slow_ratio_to_control"] = df["tau_slow_s"] / ctl_tau_slow
+
+         rmses: list[float] = []
+         for ch in df["channel"].tolist():
+             rmses.append(_norm_shape_rmse(baseline_curves[ch], ctl_curve))
+         df["norm_shape_rmse_to_control"] = rmses
+
+     preferred = [
+         "channel",
+         "tau_fast_s",
+         "tau_slow_s",
+         "fast_amp_frac",
+         "slow_amp_frac",
+         "has_fast_component",
+         "half_drop_time_s",
+         "percent_drop",
+         "r2",
+         "rmse",
+     ]
+     if control is not None:
+         preferred += [
+             "control_percent_drop",
+             "percent_drop_minus_control",
+             "tau_slow_ratio_to_control",
+             "norm_shape_rmse_to_control",
+         ]
+     preferred += ["const", "amp_fast", "amp_slow", "f0_start", "f0_end"]
+
+     cols = [c for c in preferred if c in df.columns]
+     return df[cols].sort_values("channel")
+
+
+ def get_baseline_trace(
+     state: PhotometryState,
+     *,
+     baseline_key: str = "double_exp_baseline",
+     channel: str,
+     normalise_to_start: bool = False,
+     eps: float = 1e-12,
+ ) -> FloatArray:
+     """
+     Fetch a baseline trace from state.derived and optionally normalise to start.
+     """
+     if baseline_key not in state.derived:
+         raise KeyError(
+             f"derived['{baseline_key}'] not found. "
+             "Run the baseline stage first."
+         )
+     base = np.asarray(state.derived[baseline_key], dtype=float)
+     if base.shape != state.signals.shape:
+         raise ValueError(
+             f"derived['{baseline_key}'] has shape {base.shape}, "
+             f"expected {state.signals.shape}."
+         )
+
+     i = state.idx(channel)
+     y = base[i].copy()
+
+     if normalise_to_start:
+         d0 = float(y[0])
+         if not np.isfinite(d0) or abs(d0) < eps:
+             y[:] = np.nan
+         else:
+             y = y / d0
+
+     return y
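
A minimal usage sketch for the two helpers above. `load_session` and the channel names "gcamp" and "isos" are placeholders for whatever reader and naming convention a given experiment uses; only `photobleaching_summary` and `get_baseline_trace` come from this file, and the state is assumed to have already been run through a `double_exp_baseline` stage.

state = load_session("mouse01.doric")  # placeholder reader, not part of this file

# Per-channel photobleaching table, using the isosbestic channel as control.
df = photobleaching_summary(state, control="isos")
print(df[["channel", "tau_slow_s", "percent_drop", "norm_shape_rmse_to_control"]])

# Fitted baseline for one channel, normalised to its first sample.
f0 = get_baseline_trace(state, channel="gcamp", normalise_to_start=True)
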
@@ -0,0 +1,105 @@
+ from __future__ import annotations
+
+ import matplotlib.pyplot as plt
+ import numpy as np
+
+ from ..state import PhotometryState
+ from .report import AnalysisResult
+
+
+ def plot_auc_result(
+     state: PhotometryState,
+     result: AnalysisResult,
+     *,
+     ax=None,
+     label: str | None = None,
+     signal_alpha: float = 0.9,
+     window_alpha: float = 0.08,
+     fill_alpha: float = 0.30,
+     linewidth: float = 1.2,
+     fontsize: int = 8,
+     show_metrics: bool = True,
+     colour_signal: str = "#1f77b4",  # vivid blue
+     colour_baseline: str = "#444444",  # dark grey
+     colour_fill: str = "#ff7f0e",  # vivid orange
+ ):
+     """
+     Visualise an AUC AnalysisResult:
+     - plots full trace
+     - shows baseline reference line
+     - shades the analysis window
+     - fills the integrated area (using stored window contrib)
+     """
+     if ax is None:
+         fig, ax = plt.subplots(figsize=(6, 3), dpi=150)
+     else:
+         fig = ax.figure
+
+     chan = result.channel
+     t = np.asarray(state.time_seconds, float)
+     y = np.asarray(state.channel(chan), float)
+
+     lab = label or chan
+     ax.plot(
+         t,
+         y,
+         linewidth=linewidth,
+         alpha=signal_alpha,
+         label=lab,
+         color=colour_signal,
+     )
+
+     b = float(result.metrics.get("baseline", np.nan))
+     if np.isfinite(b):
+         ax.axhline(
+             b, linewidth=1.0, alpha=0.85, linestyle="--", color=colour_baseline
+         )
+
+     # Window display
+     t0 = result.metrics.get("t0", None)
+     t1 = result.metrics.get("t1", None)
+     if (
+         t0 is not None
+         and t1 is not None
+         and np.isfinite(t0)
+         and np.isfinite(t1)
+     ):
+         ax.axvspan(float(t0), float(t1), alpha=window_alpha, color="grey")
+
+     # Fill based on stored arrays
+     tt = np.asarray(result.arrays.get("t_window", np.array([])), float)
+     contrib = np.asarray(result.arrays.get("contrib", np.array([])), float)
+
+     if tt.size >= 2 and contrib.size == tt.size and np.isfinite(b):
+         ax.fill_between(
+             tt,
+             b,
+             b + contrib,
+             alpha=fill_alpha,
+             color=colour_fill,
+             label="AUC area",
+         )
+
+     ax.set_xlabel("time (s)", fontsize=fontsize)
+     ax.set_ylabel(chan, fontsize=fontsize)
+
+     ax.tick_params(axis="both", which="major", labelsize=fontsize)
+
+     if show_metrics:
+         auc = result.metrics.get("auc", np.nan)
+         baseline = result.metrics.get("baseline", np.nan)
+         txt = f"AUC: {auc:.3g}\nBaseline: {baseline:.3g}"
+         ax.text(
+             0.02,
+             0.98,
+             txt,
+             transform=ax.transAxes,
+             va="top",
+             ha="left",
+             fontsize=fontsize,
+             color="#222222",
+         )
+
+     ax.legend(frameon=False, fontsize=fontsize)
+
+     return fig, ax
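
A hedged sketch of driving this plotting helper. The `AnalysisResult` is built by hand purely for illustration, filling only the keys that `plot_auc_result` reads ("auc", "baseline", "t0", "t1", "t_window", "contrib"); in practice it would come from whichever AUC routine produces these results. `state` is assumed to be a `PhotometryState` with a channel named "gcamp", and `AnalysisResult` / `AnalysisWindow` are assumed to be imported from the package's analysis modules.

import numpy as np

t_win = np.linspace(10.0, 20.0, 101)
contrib = np.sin(np.linspace(0.0, np.pi, t_win.size))  # non-negative bump above baseline

result = AnalysisResult(
    name="auc",
    channel="gcamp",
    window=AnalysisWindow(start=10.0, end=20.0, ref="seconds", label="cue"),
    metrics={
        "auc": float(np.sum(contrib) * (t_win[1] - t_win[0])),
        "baseline": 0.0,
        "t0": 10.0,
        "t1": 20.0,
    },
    arrays={"t_window": t_win, "contrib": contrib},
)

fig, ax = plot_auc_result(state, result, label="GCaMP (cue window)")
fig.savefig("auc_example.png", bbox_inches="tight")
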
@@ -0,0 +1,56 @@
+ from __future__ import annotations
+
+ from collections.abc import Iterable
+ from dataclasses import dataclass, field
+ from typing import Any, Literal
+
+ import numpy.typing as npt
+
+ from ..state import PhotometryState
+
+ WindowRef = Literal["seconds", "samples"]
+
+
+ @dataclass(frozen=True, slots=True)
+ class AnalysisWindow:
+     """
+     A window over which an analysis is evaluated.
+
+     - ref="seconds": start/end are in seconds (state.time_seconds space)
+     - ref="samples": start/end are integer sample indices [start, end)
+     """
+
+     start: float | int
+     end: float | int
+     ref: WindowRef = "seconds"
+     label: str | None = None
+
+
+ @dataclass(frozen=True, slots=True)
+ class AnalysisResult:
+     name: str  # e.g. "auc"
+     channel: str  # e.g. "gcamp"
+     window: AnalysisWindow | None  # None means "whole trace"
+     params: dict[str, Any] = field(default_factory=dict)
+
+     metrics: dict[str, float] = field(default_factory=dict)
+     arrays: dict[str, npt.NDArray[Any]] = field(default_factory=dict)
+
+     notes: str | None = None
+
+
+ @dataclass(frozen=True, slots=True)
+ class PhotometryReport:
+     state: PhotometryState
+     results: tuple[AnalysisResult, ...] = ()
+
+     def add(self, result: AnalysisResult) -> PhotometryReport:
+         return PhotometryReport(self.state, results=(*self.results, result))
+
+     def extend(self, results: Iterable[AnalysisResult]) -> PhotometryReport:
+         return PhotometryReport(
+             self.state, results=(*self.results, *tuple(results))
+         )
+
+     def find(self, name: str) -> tuple[AnalysisResult, ...]:
+         return tuple(r for r in self.results if r.name == name)
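
A short sketch of the intended report-building flow. `state` is a previously constructed `PhotometryState`, and `auc_result` / `peak_result` stand in for `AnalysisResult` objects produced elsewhere; only the `PhotometryReport` API itself comes from this file.

report = PhotometryReport(state)
report = report.add(auc_result)        # frozen dataclass: each call returns a new report
report = report.extend([peak_result])

for r in report.find("auc"):
    print(r.channel, r.metrics.get("auc"))
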
fibphot/collection.py ADDED
@@ -0,0 +1,207 @@
+ from __future__ import annotations
+
+ from collections.abc import Callable, Iterable, Sequence
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Any, Literal
+
+ from .analysis.aggregate import (
+     AlignedSignals,
+     align_collection_signals,
+     mean_state_from_aligned,
+ )
+ from .state import PhotometryState
+ from .tags import TagTable, apply_tags, default_subject_getter, read_tag_table
+
+ SubjectGetter = Callable[[PhotometryState], str | None]
+ AlignMode = Literal["intersection", "union"]
+ InterpKind = Literal["linear", "nearest"]
+ TimeRef = Literal["absolute", "start"]
+
+
+ @dataclass(frozen=True, slots=True)
+ class PhotometryCollection:
+     """
+     A thin immutable wrapper around multiple PhotometryState objects.
+
+     Provides tag-aware filtering, grouping, sorting, and bulk serialisation.
+     """
+
+     states: Sequence[PhotometryState]
+
+     def __post_init__(self) -> None:
+         object.__setattr__(self, "states", tuple(self.states))
+
+     @classmethod
+     def from_iterable(
+         cls, states: Iterable[PhotometryState]
+     ) -> PhotometryCollection:
+         return cls(states=tuple(states))
+
+     @classmethod
+     def from_glob(
+         cls,
+         base_path: Path | str,
+         pattern: str = "*.doric",
+         *,
+         reader: Callable[[Path], PhotometryState],
+         metadata_fn: Callable[[Path], dict[str, Any]] | None = None,
+         sort: bool = True,
+     ) -> PhotometryCollection:
+         base = Path(base_path)
+         paths = list(base.glob(pattern))
+         if sort:
+             paths.sort()
+
+         def _iter() -> Iterable[PhotometryState]:
+             for p in paths:
+                 st = reader(p)
+                 if metadata_fn is not None:
+                     st = st.with_metadata(**metadata_fn(p))
+                 yield st
+
+         return cls(states=tuple(_iter()))
+
+     def __len__(self) -> int:
+         return len(self.states)
+
+     def __iter__(self):
+         return iter(self.states)
+
+     @property
+     def subjects(self) -> tuple[str | None, ...]:
+         return tuple(s.subject for s in self.states)
+
+     def pipe(self, *stages: object) -> PhotometryCollection:
+         return PhotometryCollection(
+             states=tuple(s.pipe(*stages) for s in self.states)
+         )
+
+     def with_tags_from_table(
+         self,
+         path: Path | str,
+         *,
+         subject_getter: SubjectGetter = default_subject_getter,
+         overwrite: bool = False,
+     ) -> PhotometryCollection:
+         table = read_tag_table(path)
+         return self.with_tags(
+             table, subject_getter=subject_getter, overwrite=overwrite
+         )
+
+     def with_tags(
+         self,
+         table: TagTable,
+         *,
+         subject_getter: SubjectGetter = default_subject_getter,
+         overwrite: bool = False,
+     ) -> PhotometryCollection:
+         tagged = tuple(
+             apply_tags(
+                 s, table, subject_getter=subject_getter, overwrite=overwrite
+             )
+             for s in self.states
+         )
+         return PhotometryCollection(states=tagged)
+
+     def filter(self, **criteria: str) -> PhotometryCollection:
+         """
+         Filter by tags: e.g. .filter(genotype="KO", context="A")
+         """
+
+         def ok(s: PhotometryState) -> bool:
+             tags = s.tags
+             return all(tags.get(k) == v for k, v in criteria.items())
+
+         return PhotometryCollection(
+             states=tuple(s for s in self.states if ok(s))
+         )
+
+     def groupby(self, key: str) -> dict[str, PhotometryCollection]:
+         """
+         Group by a tag key: returns {tag_value: PhotometryCollection}.
+         Missing values go under "".
+         """
+         groups: dict[str, list[PhotometryState]] = {}
+         for s in self.states:
+             v = s.tags.get(key, "")
+             groups.setdefault(v, []).append(s)
+
+         return {
+             k: PhotometryCollection(states=tuple(v)) for k, v in groups.items()
+         }
+
+     def sort_by(self, *keys: str) -> PhotometryCollection:
+         """
+         Sort by one or more tag keys (lexicographic).
+         Missing values sort as "".
+         """
+
+         def sort_key(s: PhotometryState) -> tuple[str, ...]:
+             tags = s.tags
+             return tuple(tags.get(k, "") for k in keys)
+
+         return PhotometryCollection(
+             states=tuple(sorted(self.states, key=sort_key))
+         )
+
+     def map(
+         self, fn: Callable[[PhotometryState], PhotometryState]
+     ) -> PhotometryCollection:
+         return PhotometryCollection(states=tuple(fn(s) for s in self.states))
+
+     def align(
+         self,
+         *,
+         channels: Sequence[str] | None = None,
+         align: AlignMode = "intersection",
+         time_ref: TimeRef = "start",
+         dt: float | None = None,
+         target_fs: float | None = None,
+         interpolation: InterpKind = "linear",
+         fill: float = float("nan"),
+     ) -> AlignedSignals:
+         return align_collection_signals(
+             self.states,
+             channels=channels,
+             align=align,
+             time_ref=time_ref,
+             dt=dt,
+             target_fs=target_fs,
+             interpolation=interpolation,
+             fill=fill,
+         )
+
+     def mean(
+         self,
+         *,
+         channels: Sequence[str] | None = None,
+         align: AlignMode = "intersection",
+         time_ref: TimeRef = "start",
+         dt: float | None = None,
+         target_fs: float | None = None,
+         interpolation: InterpKind = "linear",
+         fill: float = float("nan"),
+         name: str = "group_mean",
+     ) -> PhotometryState:
+         aligned = self.align(
+             channels=channels,
+             align=align,
+             time_ref=time_ref,
+             dt=dt,
+             target_fs=target_fs,
+             interpolation=interpolation,
+             fill=fill,
+         )
+         return mean_state_from_aligned(aligned, name=name)
+
+     def to_h5(self, path: Path | str) -> None:
+         from .io.h5 import save_collection_h5
+
+         save_collection_h5(self, path)
+
+     @classmethod
+     def from_h5(cls, path: Path | str) -> PhotometryCollection:
+         from .io.h5 import load_collection_h5
+
+         return load_collection_h5(path)
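
An end-to-end sketch of the collection API. `read_doric` is a placeholder for whatever per-file reader the package provides, "tags.csv" is assumed to be in a format `read_tag_table` accepts, and the "gcamp" channel and "genotype" tag are illustrative; everything else uses only methods defined in fibphot/collection.py.

coll = PhotometryCollection.from_glob(
    "data/",
    "*.doric",
    reader=read_doric,  # placeholder reader
)
coll = coll.with_tags_from_table("tags.csv")

# Group-average trace for knockout animals, resampled onto a 20 Hz grid.
ko_mean = coll.filter(genotype="KO").mean(channels=["gcamp"], target_fs=20.0)

# One HDF5 file per genotype value.
for genotype, sub in coll.groupby("genotype").items():
    sub.to_h5(f"collection_{genotype or 'untagged'}.h5")
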