fibphot 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fibphot/__init__.py +6 -0
- fibphot/analysis/__init__.py +0 -0
- fibphot/analysis/aggregate.py +257 -0
- fibphot/analysis/auc.py +354 -0
- fibphot/analysis/irls.py +350 -0
- fibphot/analysis/peaks.py +1163 -0
- fibphot/analysis/photobleaching.py +290 -0
- fibphot/analysis/plotting.py +105 -0
- fibphot/analysis/report.py +56 -0
- fibphot/collection.py +207 -0
- fibphot/fit/__init__.py +0 -0
- fibphot/fit/regression.py +269 -0
- fibphot/io/__init__.py +6 -0
- fibphot/io/doric.py +435 -0
- fibphot/io/excel.py +76 -0
- fibphot/io/h5.py +321 -0
- fibphot/misc.py +11 -0
- fibphot/peaks.py +628 -0
- fibphot/pipeline.py +14 -0
- fibphot/plotting.py +594 -0
- fibphot/stages/__init__.py +22 -0
- fibphot/stages/base.py +101 -0
- fibphot/stages/baseline.py +354 -0
- fibphot/stages/control_dff.py +214 -0
- fibphot/stages/filters.py +273 -0
- fibphot/stages/normalisation.py +260 -0
- fibphot/stages/regression.py +139 -0
- fibphot/stages/smooth.py +442 -0
- fibphot/stages/trim.py +141 -0
- fibphot/state.py +309 -0
- fibphot/tags.py +130 -0
- fibphot/types.py +6 -0
- fibphot-0.1.0.dist-info/METADATA +63 -0
- fibphot-0.1.0.dist-info/RECORD +37 -0
- fibphot-0.1.0.dist-info/WHEEL +5 -0
- fibphot-0.1.0.dist-info/licenses/LICENSE.md +21 -0
- fibphot-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,290 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Literal
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import pandas as pd
|
|
8
|
+
|
|
9
|
+
from ..state import PhotometryState
|
|
10
|
+
from ..types import FloatArray
|
|
11
|
+
|
|
12
|
+
# Names of the five double-exponential parameters, listed in the same column
# order that DoubleExpParams.from_row reads them (indices 0..4).
ParamOrder = Literal["const", "amp_fast", "amp_slow", "tau_fast", "tau_slow"]
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass(frozen=True, slots=True)
|
|
16
|
+
class DoubleExpParams:
|
|
17
|
+
const: float
|
|
18
|
+
amp_fast: float
|
|
19
|
+
amp_slow: float
|
|
20
|
+
tau_fast: float
|
|
21
|
+
tau_slow: float
|
|
22
|
+
|
|
23
|
+
@classmethod
|
|
24
|
+
def from_row(cls, row: FloatArray) -> DoubleExpParams:
|
|
25
|
+
r = np.asarray(row, dtype=float).ravel()
|
|
26
|
+
if r.shape[0] != 5:
|
|
27
|
+
raise ValueError(f"Expected 5 params, got shape {r.shape}.")
|
|
28
|
+
return cls(
|
|
29
|
+
const=float(r[0]),
|
|
30
|
+
amp_fast=float(r[1]),
|
|
31
|
+
amp_slow=float(r[2]),
|
|
32
|
+
tau_fast=float(r[3]),
|
|
33
|
+
tau_slow=float(r[4]),
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
def f0(self, t: FloatArray) -> FloatArray:
|
|
37
|
+
t = np.asarray(t, dtype=float)
|
|
38
|
+
return (
|
|
39
|
+
self.const
|
|
40
|
+
+ self.amp_fast * np.exp(-t / self.tau_fast)
|
|
41
|
+
+ self.amp_slow * np.exp(-t / self.tau_slow)
|
|
42
|
+
)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _last_stage_id(state: PhotometryState, stage_name: str) -> str:
|
|
46
|
+
target = stage_name.lower()
|
|
47
|
+
for rec in reversed(state.summary):
|
|
48
|
+
if rec.name.lower() == target:
|
|
49
|
+
return rec.stage_id
|
|
50
|
+
raise KeyError(f"Stage not found in summary: {stage_name!r}")
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _half_drop_time(t: FloatArray, y: FloatArray) -> float:
|
|
54
|
+
"""
|
|
55
|
+
Time at which y has completed half of its total drop from y[0] to y[-1].
|
|
56
|
+
|
|
57
|
+
If there is no net drop (y[-1] >= y[0]) returns NaN.
|
|
58
|
+
Uses linear interpolation between samples.
|
|
59
|
+
"""
|
|
60
|
+
t = np.asarray(t, dtype=float)
|
|
61
|
+
y = np.asarray(y, dtype=float)
|
|
62
|
+
|
|
63
|
+
if t.ndim != 1 or y.ndim != 1 or t.shape[0] != y.shape[0]:
|
|
64
|
+
raise ValueError("t and y must be 1D arrays of the same length.")
|
|
65
|
+
|
|
66
|
+
y0 = float(y[0])
|
|
67
|
+
y1 = float(y[-1])
|
|
68
|
+
drop = y0 - y1
|
|
69
|
+
if not np.isfinite(drop) or drop <= 0.0:
|
|
70
|
+
return float("nan")
|
|
71
|
+
|
|
72
|
+
target = y0 - 0.5 * drop
|
|
73
|
+
|
|
74
|
+
# Find the first index where we are at/below target (assuming decay).
|
|
75
|
+
idx = np.where(y <= target)[0]
|
|
76
|
+
if idx.size == 0:
|
|
77
|
+
return float("nan")
|
|
78
|
+
|
|
79
|
+
j = int(idx[0])
|
|
80
|
+
if j == 0:
|
|
81
|
+
return float(t[0])
|
|
82
|
+
|
|
83
|
+
# Linear interpolation between (j-1) and j
|
|
84
|
+
t_lo, t_hi = float(t[j - 1]), float(t[j])
|
|
85
|
+
y_lo, y_hi = float(y[j - 1]), float(y[j])
|
|
86
|
+
|
|
87
|
+
denom = y_hi - y_lo
|
|
88
|
+
if not np.isfinite(denom) or abs(denom) < 1e-20:
|
|
89
|
+
return float(t_hi)
|
|
90
|
+
|
|
91
|
+
frac = (target - y_lo) / denom
|
|
92
|
+
frac = float(np.clip(frac, 0.0, 1.0))
|
|
93
|
+
return t_lo + frac * (t_hi - t_lo)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def _norm_shape_rmse(a: FloatArray, b: FloatArray, eps: float = 1e-12) -> float:
|
|
97
|
+
"""
|
|
98
|
+
RMSE between two traces after normalising each by its first sample.
|
|
99
|
+
"""
|
|
100
|
+
a = np.asarray(a, dtype=float)
|
|
101
|
+
b = np.asarray(b, dtype=float)
|
|
102
|
+
|
|
103
|
+
if a.shape != b.shape:
|
|
104
|
+
raise ValueError(f"Shape mismatch: {a.shape} vs {b.shape}.")
|
|
105
|
+
|
|
106
|
+
a0 = float(a[0])
|
|
107
|
+
b0 = float(b[0])
|
|
108
|
+
if not (
|
|
109
|
+
np.isfinite(a0) or not np.isfinite(b0) or abs(a0) < eps or abs(b0) < eps
|
|
110
|
+
):
|
|
111
|
+
return float("nan")
|
|
112
|
+
|
|
113
|
+
an = a / a0
|
|
114
|
+
bn = b / b0
|
|
115
|
+
d = an - bn
|
|
116
|
+
return float(np.sqrt(np.nanmean(d * d)))
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def _resolve_control_channel(
    curves: dict[str, FloatArray], control: str
) -> str | None:
    """Return the key of ``curves`` that names ``control``, or None.

    The lower-cased name is tried first (the historical lookup); otherwise
    channel names are compared case-insensitively.
    """
    folded = control.lower()
    if folded in curves:
        return folded
    return next((name for name in curves if name.lower() == folded), None)


def _channel_summary_row(
    name: str,
    p: DoubleExpParams,
    t: FloatArray,
    f: FloatArray,
    fast_component_threshold: float,
) -> dict[str, float | str | bool]:
    """Summary metrics for one channel from its fitted params and baseline ``f``."""
    f_start = float(f[0])
    f_end = float(f[-1])

    amp_total = p.amp_fast + p.amp_slow
    if amp_total != 0.0:
        fast_frac = float(p.amp_fast / amp_total)
        slow_frac = float(p.amp_slow / amp_total)
    else:
        fast_frac = float("nan")
        slow_frac = float("nan")

    has_fast = bool(
        np.isfinite(fast_frac) and fast_frac > fast_component_threshold
    )

    if f_start != 0.0 and np.isfinite(f_start) and np.isfinite(f_end):
        percent_drop = 100.0 * float((f_start - f_end) / f_start)
    else:
        percent_drop = float("nan")

    return {
        "channel": name,
        "tau_fast_s": p.tau_fast,
        "tau_slow_s": p.tau_slow,
        "fast_amp_frac": fast_frac,
        "slow_amp_frac": slow_frac,
        "has_fast_component": has_fast,
        "half_drop_time_s": _half_drop_time(t, f),
        "percent_drop": percent_drop,
        "const": p.const,
        "amp_fast": p.amp_fast,
        "amp_slow": p.amp_slow,
        "f0_start": f_start,
        "f0_end": f_end,
    }


def photobleaching_summary(
    state: PhotometryState,
    *,
    stage_id: str | None = None,
    stage_name: str = "double_exp_baseline",
    fast_component_threshold: float = 0.05,
    control: str | None = None,
) -> pd.DataFrame:
    """
    Per-channel summary of a fitted double exponential baseline.

    Parameters are read from ``state.results[sid]["params"]``, which must be
    an ``(n, 5)`` array in ``const, amp_fast, amp_slow, tau_fast, tau_slow``
    order. Here ``sid`` is ``stage_id`` or, if not given, the id of the last
    summary record named ``stage_name`` (case-insensitive). Each channel's
    baseline F0 is re-evaluated on ``state.time_seconds``.

    Columns (one row per channel, sorted by ``channel``):

    - tau_fast_s, tau_slow_s: fitted time constants
    - fast_amp_frac, slow_amp_frac: amp / (amp_fast + amp_slow); NaN if the
      amplitudes sum to zero
    - has_fast_component: whether fast_amp_frac exceeds
      ``fast_component_threshold``
    - half_drop_time_s: time to complete half of the session-long drop of F0
    - percent_drop: 100 * (F0[0] - F0[-1]) / F0[0]
    - r2, rmse: copied from the stage results when present
    - optional control comparisons if ``control`` is provided (matched to a
      channel name case-insensitively):
        * control_percent_drop
        * percent_drop_minus_control
        * tau_slow_ratio_to_control
        * norm_shape_rmse_to_control
    - const, amp_fast, amp_slow, f0_start, f0_end: raw parameters and F0
      endpoints

    Raises
    ------
    KeyError
        If the stage, its results, or the control channel cannot be found.
    ValueError
        If ``params`` is not ``(n, 5)`` or has fewer rows than channels.
    """
    sid = stage_id or _last_stage_id(state, stage_name)
    res = state.results.get(sid)
    if res is None:
        raise KeyError(f"No results found for stage_id={sid!r}.")

    params = np.asarray(res.get("params"), dtype=float)
    r2 = np.asarray(res.get("r2"), dtype=float) if "r2" in res else None
    rmse = np.asarray(res.get("rmse"), dtype=float) if "rmse" in res else None

    if params.ndim != 2 or params.shape[1] != 5:
        raise ValueError(f"Expected params shape (n, 5), got {params.shape}.")

    channel_names = list(state.channel_names)
    if params.shape[0] < len(channel_names):
        # Fail with a clear message rather than an IndexError mid-loop.
        raise ValueError(
            f"params has {params.shape[0]} rows but the state has "
            f"{len(channel_names)} channels."
        )

    t = np.asarray(state.time_seconds, dtype=float)

    baseline_curves: dict[str, FloatArray] = {}
    rows: list[dict[str, float | str | bool]] = []

    for i, name in enumerate(channel_names):
        p = DoubleExpParams.from_row(params[i])
        f = p.f0(t)
        baseline_curves[name] = f

        row = _channel_summary_row(name, p, t, f, fast_component_threshold)
        if r2 is not None and i < r2.shape[0]:
            row["r2"] = float(r2[i])
        if rmse is not None and i < rmse.shape[0]:
            row["rmse"] = float(rmse[i])
        rows.append(row)

    df = pd.DataFrame(rows)

    if control is not None:
        ctl = _resolve_control_channel(baseline_curves, control)
        if ctl is None:
            raise KeyError(
                f"Control channel {control!r} not found. "
                f"Available: {state.channel_names}"
            )

        # ctl is a key of baseline_curves, so it always has a row in df.
        ctl_row = df.loc[df["channel"] == ctl]
        ctl_percent_drop = float(ctl_row["percent_drop"].iloc[0])
        ctl_tau_slow = float(ctl_row["tau_slow_s"].iloc[0])
        ctl_curve = baseline_curves[ctl]

        df["control_percent_drop"] = ctl_percent_drop
        df["percent_drop_minus_control"] = df["percent_drop"] - ctl_percent_drop
        df["tau_slow_ratio_to_control"] = df["tau_slow_s"] / ctl_tau_slow
        df["norm_shape_rmse_to_control"] = [
            _norm_shape_rmse(baseline_curves[ch], ctl_curve)
            for ch in df["channel"]
        ]

    preferred = [
        "channel",
        "tau_fast_s",
        "tau_slow_s",
        "fast_amp_frac",
        "slow_amp_frac",
        "has_fast_component",
        "half_drop_time_s",
        "percent_drop",
        "r2",
        "rmse",
    ]
    if control is not None:
        preferred += [
            "control_percent_drop",
            "percent_drop_minus_control",
            "tau_slow_ratio_to_control",
            "norm_shape_rmse_to_control",
        ]
    preferred += ["const", "amp_fast", "amp_slow", "f0_start", "f0_end"]

    out = df[[c for c in preferred if c in df.columns]]
    if "channel" not in out.columns:
        # No channels at all: nothing to sort (sort_values would raise).
        return out
    return out.sort_values("channel")
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
def get_baseline_trace(
    state: PhotometryState,
    *,
    baseline_key: str = "double_exp_baseline",
    channel: str,
    normalise_to_start: bool = False,
    eps: float = 1e-12,
) -> FloatArray:
    """
    Return one channel's baseline curve from ``state.derived[baseline_key]``.

    Parameters
    ----------
    state : provides ``derived``, ``signals`` and the ``idx`` channel lookup.
    baseline_key : key of the stored baseline array; its shape must equal
        ``state.signals.shape``.
    channel : channel name passed to ``state.idx`` to select the row.
    normalise_to_start : if True, divide the trace by its first sample. When
        that sample is non-finite or below ``eps`` in magnitude, an all-NaN
        trace is returned instead.
    eps : threshold used by ``normalise_to_start``.

    Returns
    -------
    A copy of the selected row, never a view into ``state.derived``.

    Raises
    ------
    KeyError
        If ``baseline_key`` is missing from ``state.derived``.
    ValueError
        If the stored array's shape differs from ``state.signals.shape``.
    """
    if baseline_key not in state.derived:
        raise KeyError(
            f"derived['{baseline_key}'] not found. "
            "Run the baseline stage first."
        )

    stored = np.asarray(state.derived[baseline_key], dtype=float)
    if stored.shape != state.signals.shape:
        raise ValueError(
            f"derived['{baseline_key}'] has shape {stored.shape}, "
            f"expected {state.signals.shape}."
        )

    trace = stored[state.idx(channel)].copy()
    if not normalise_to_start:
        return trace

    first = float(trace[0])
    if np.isfinite(first) and abs(first) >= eps:
        return trace / first

    # First sample unusable as a divisor: signal that with an all-NaN trace.
    trace[:] = np.nan
    return trace
|
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import matplotlib.pyplot as plt
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
from ..state import PhotometryState
|
|
7
|
+
from .report import AnalysisResult
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def plot_auc_result(
    state: PhotometryState,
    result: AnalysisResult,
    *,
    ax=None,
    label: str | None = None,
    signal_alpha: float = 0.9,
    window_alpha: float = 0.08,
    fill_alpha: float = 0.30,
    linewidth: float = 1.2,
    fontsize: int = 8,
    show_metrics: bool = True,
    colour_signal: str = "#1f77b4",  # vivid blue
    colour_baseline: str = "#444444",  # dark grey
    colour_fill: str = "#ff7f0e",  # vivid orange
):
    """
    Draw an AUC ``AnalysisResult`` on top of its source trace.

    Drawn in order:
    - the full trace of ``result.channel``
    - a dashed horizontal line at ``metrics["baseline"]`` (if finite)
    - a grey span over ``[metrics["t0"], metrics["t1"]]`` (if both present
      and finite)
    - the area between the baseline and ``baseline + arrays["contrib"]``
      over ``arrays["t_window"]`` (if both arrays have the same length >= 2
      and the baseline is finite)
    - axis labels, and optionally a text box with the AUC and baseline values

    A new 6x3 inch, 150 dpi figure is created when ``ax`` is None.

    Returns
    -------
    (fig, ax)
    """
    if ax is not None:
        fig = ax.figure
    else:
        fig, ax = plt.subplots(figsize=(6, 3), dpi=150)

    metrics = result.metrics
    arrays = result.arrays
    channel_name = result.channel

    times = np.asarray(state.time_seconds, float)
    trace = np.asarray(state.channel(channel_name), float)
    ax.plot(
        times,
        trace,
        linewidth=linewidth,
        alpha=signal_alpha,
        label=label or channel_name,
        color=colour_signal,
    )

    baseline_level = float(metrics.get("baseline", np.nan))
    baseline_ok = np.isfinite(baseline_level)
    if baseline_ok:
        ax.axhline(
            baseline_level,
            linewidth=1.0,
            alpha=0.85,
            linestyle="--",
            color=colour_baseline,
        )

    # Shade the analysis window only when both edges are usable numbers.
    win_start = metrics.get("t0")
    win_end = metrics.get("t1")
    window_ok = (
        win_start is not None
        and win_end is not None
        and np.isfinite(win_start)
        and np.isfinite(win_end)
    )
    if window_ok:
        ax.axvspan(
            float(win_start), float(win_end), alpha=window_alpha, color="grey"
        )

    # The filled area comes from the arrays stored on the result.
    empty = np.array([])
    t_window = np.asarray(arrays.get("t_window", empty), float)
    contrib = np.asarray(arrays.get("contrib", empty), float)
    if t_window.size >= 2 and contrib.size == t_window.size and baseline_ok:
        ax.fill_between(
            t_window,
            baseline_level,
            baseline_level + contrib,
            alpha=fill_alpha,
            color=colour_fill,
            label="AUC area",
        )

    ax.set_xlabel("time (s)", fontsize=fontsize)
    ax.set_ylabel(channel_name, fontsize=fontsize)
    ax.tick_params(axis="both", which="major", labelsize=fontsize)

    if show_metrics:
        auc_value = metrics.get("auc", np.nan)
        baseline_value = metrics.get("baseline", np.nan)
        ax.text(
            0.02,
            0.98,
            f"AUC: {auc_value:.3g}\nBaseline: {baseline_value:.3g}",
            transform=ax.transAxes,
            va="top",
            ha="left",
            fontsize=fontsize,
            color="#222222",
        )

    ax.legend(frameon=False, fontsize=fontsize)
    return fig, ax
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Iterable
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
from typing import Any, Literal
|
|
6
|
+
|
|
7
|
+
import numpy.typing as npt
|
|
8
|
+
|
|
9
|
+
from ..state import PhotometryState
|
|
10
|
+
|
|
11
|
+
# Unit of AnalysisWindow.start/end: "seconds" (values on the time axis) or
# "samples" (integer sample indices).
WindowRef = Literal["seconds", "samples"]
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass(frozen=True, slots=True)
|
|
15
|
+
class AnalysisWindow:
|
|
16
|
+
"""
|
|
17
|
+
A window over which an analysis is evaluated.
|
|
18
|
+
|
|
19
|
+
- ref="seconds": start/end are in seconds (state.time_seconds space)
|
|
20
|
+
- ref="samples": start/end are integer sample indices [start, end)
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
start: float | int
|
|
24
|
+
end: float | int
|
|
25
|
+
ref: WindowRef = "seconds"
|
|
26
|
+
label: str | None = None
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@dataclass(frozen=True, slots=True)
|
|
30
|
+
class AnalysisResult:
|
|
31
|
+
name: str # e.g. "auc"
|
|
32
|
+
channel: str # e.g. "gcamp"
|
|
33
|
+
window: AnalysisWindow | None # None means “whole trace”
|
|
34
|
+
params: dict[str, Any] = field(default_factory=dict)
|
|
35
|
+
|
|
36
|
+
metrics: dict[str, float] = field(default_factory=dict)
|
|
37
|
+
arrays: dict[str, npt.NDArray[Any]] = field(default_factory=dict)
|
|
38
|
+
|
|
39
|
+
notes: str | None = None
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@dataclass(frozen=True, slots=True)
|
|
43
|
+
class PhotometryReport:
|
|
44
|
+
state: PhotometryState
|
|
45
|
+
results: tuple[AnalysisResult, ...] = ()
|
|
46
|
+
|
|
47
|
+
def add(self, result: AnalysisResult) -> PhotometryReport:
|
|
48
|
+
return PhotometryReport(self.state, results=(*self.results, result))
|
|
49
|
+
|
|
50
|
+
def extend(self, results: Iterable[AnalysisResult]) -> PhotometryReport:
|
|
51
|
+
return PhotometryReport(
|
|
52
|
+
self.state, results=(*self.results, *tuple(results))
|
|
53
|
+
)
|
|
54
|
+
|
|
55
|
+
def find(self, name: str) -> tuple[AnalysisResult, ...]:
|
|
56
|
+
return tuple(r for r in self.results if r.name == name)
|
fibphot/collection.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Callable, Iterable, Sequence
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any, Literal
|
|
7
|
+
|
|
8
|
+
from .analysis.aggregate import (
|
|
9
|
+
AlignedSignals,
|
|
10
|
+
align_collection_signals,
|
|
11
|
+
mean_state_from_aligned,
|
|
12
|
+
)
|
|
13
|
+
from .state import PhotometryState
|
|
14
|
+
from .tags import TagTable, apply_tags, default_subject_getter, read_tag_table
|
|
15
|
+
|
|
16
|
+
# Callable that returns a subject identifier for a state (or None); accepted
# as the `subject_getter` argument of PhotometryCollection.with_tags*.
SubjectGetter = Callable[[PhotometryState], str | None]
|
|
17
|
+
# Alignment mode forwarded to align_collection_signals; presumably selects
# whether the shared time grid spans the intersection or the union of the
# states' time ranges (confirm in analysis/aggregate.py).
AlignMode = Literal["intersection", "union"]
|
|
18
|
+
# Interpolation scheme forwarded to align_collection_signals.
InterpKind = Literal["linear", "nearest"]
|
|
19
|
+
# Time reference forwarded to align_collection_signals; "start" presumably
# re-zeros each state's time axis at its first sample (TODO confirm).
TimeRef = Literal["absolute", "start"]
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass(frozen=True, slots=True)
class PhotometryCollection:
    """
    Immutable, ordered container of PhotometryState objects.

    Offers tag-based filtering, grouping and sorting, bulk application of
    pipeline stages, time alignment / averaging, and HDF5 round-tripping.
    Every method that "changes" the collection returns a new instance.
    """

    states: Sequence[PhotometryState]

    def __post_init__(self) -> None:
        # The dataclass is frozen, so bypass the generated __setattr__ to
        # store a tuple copy of whatever sequence was passed in.
        object.__setattr__(self, "states", tuple(self.states))

    @classmethod
    def from_iterable(
        cls, states: Iterable[PhotometryState]
    ) -> PhotometryCollection:
        """Create a collection from any iterable of states."""
        materialised = tuple(states)
        return cls(states=materialised)

    @classmethod
    def from_glob(
        cls,
        base_path: Path | str,
        pattern: str = "*.doric",
        *,
        reader: Callable[[Path], PhotometryState],
        metadata_fn: Callable[[Path], dict[str, Any]] | None = None,
        sort: bool = True,
    ) -> PhotometryCollection:
        """
        Load every path under ``base_path`` matching the glob ``pattern``.

        Each path is passed to ``reader``. When ``metadata_fn`` is given, the
        dict it returns for that path is applied via
        ``state.with_metadata(**...)``. Paths are sorted before loading
        unless ``sort`` is False, in which case ``Path.glob`` order is kept.
        """
        found = list(Path(base_path).glob(pattern))
        ordered = sorted(found) if sort else found

        loaded: list[PhotometryState] = []
        for path in ordered:
            state = reader(path)
            if metadata_fn is not None:
                state = state.with_metadata(**metadata_fn(path))
            loaded.append(state)
        return cls(states=tuple(loaded))

    def __len__(self) -> int:
        return len(self.states)

    def __iter__(self):
        # Iterate over the stored states in collection order.
        return iter(self.states)

    @property
    def subjects(self) -> tuple[str | None, ...]:
        """The ``subject`` of each state, in collection order."""
        return tuple(state.subject for state in self.states)

    def pipe(self, *stages: object) -> PhotometryCollection:
        """Run ``state.pipe(*stages)`` on every state and collect the results."""
        processed = [state.pipe(*stages) for state in self.states]
        return PhotometryCollection(states=tuple(processed))

    def with_tags_from_table(
        self,
        path: Path | str,
        *,
        subject_getter: SubjectGetter = default_subject_getter,
        overwrite: bool = False,
    ) -> PhotometryCollection:
        """Read a tag table from ``path`` and apply it via ``with_tags``."""
        return self.with_tags(
            read_tag_table(path),
            subject_getter=subject_getter,
            overwrite=overwrite,
        )

    def with_tags(
        self,
        table: TagTable,
        *,
        subject_getter: SubjectGetter = default_subject_getter,
        overwrite: bool = False,
    ) -> PhotometryCollection:
        """Return a new collection with ``apply_tags`` applied to every state."""
        tagged = [
            apply_tags(
                state,
                table,
                subject_getter=subject_getter,
                overwrite=overwrite,
            )
            for state in self.states
        ]
        return PhotometryCollection(states=tuple(tagged))

    def filter(self, **criteria: str) -> PhotometryCollection:
        """
        Keep only states whose tags equal every ``key=value`` criterion.

        Example: ``.filter(genotype="KO", context="A")``. States lacking a
        requested tag are excluded; with no criteria every state is kept.
        """
        wanted = criteria.items()
        kept: list[PhotometryState] = []
        for state in self.states:
            tags = state.tags
            if all(tags.get(key) == value for key, value in wanted):
                kept.append(state)
        return PhotometryCollection(states=tuple(kept))

    def groupby(self, key: str) -> dict[str, PhotometryCollection]:
        """
        Split the collection by the value of tag ``key``.

        Returns ``{tag_value: PhotometryCollection}`` in first-seen order;
        states without the tag are grouped under ``""``.
        """
        buckets: dict[str, list[PhotometryState]] = {}
        for state in self.states:
            bucket_key = state.tags.get(key, "")
            if bucket_key not in buckets:
                buckets[bucket_key] = []
            buckets[bucket_key].append(state)

        return {
            value: PhotometryCollection(states=tuple(members))
            for value, members in buckets.items()
        }

    def sort_by(self, *keys: str) -> PhotometryCollection:
        """
        Sort states by the values of one or more tag keys (lexicographic).

        A missing tag sorts as ``""``; the sort is stable.
        """

        def tag_tuple(state: PhotometryState) -> tuple[str, ...]:
            tags = state.tags
            return tuple([tags.get(k, "") for k in keys])

        ordered = sorted(self.states, key=tag_tuple)
        return PhotometryCollection(states=tuple(ordered))

    def map(
        self, fn: Callable[[PhotometryState], PhotometryState]
    ) -> PhotometryCollection:
        """Apply ``fn`` to each state and wrap the results in a new collection."""
        mapped = [fn(state) for state in self.states]
        return PhotometryCollection(states=tuple(mapped))

    def align(
        self,
        *,
        channels: Sequence[str] | None = None,
        align: AlignMode = "intersection",
        time_ref: TimeRef = "start",
        dt: float | None = None,
        target_fs: float | None = None,
        interpolation: InterpKind = "linear",
        fill: float = float("nan"),
    ) -> AlignedSignals:
        """Align the states' signals via ``align_collection_signals``."""
        return align_collection_signals(
            self.states,
            channels=channels,
            align=align,
            time_ref=time_ref,
            dt=dt,
            target_fs=target_fs,
            interpolation=interpolation,
            fill=fill,
        )

    def mean(
        self,
        *,
        channels: Sequence[str] | None = None,
        align: AlignMode = "intersection",
        time_ref: TimeRef = "start",
        dt: float | None = None,
        target_fs: float | None = None,
        interpolation: InterpKind = "linear",
        fill: float = float("nan"),
        name: str = "group_mean",
    ) -> PhotometryState:
        """Align the states (see ``align``) and build their mean state."""
        aligned = self.align(
            channels=channels,
            align=align,
            time_ref=time_ref,
            dt=dt,
            target_fs=target_fs,
            interpolation=interpolation,
            fill=fill,
        )
        return mean_state_from_aligned(aligned, name=name)

    def to_h5(self, path: Path | str) -> None:
        """Write the collection to ``path`` via ``save_collection_h5``."""
        # Imported lazily, keeping the HDF5 module off the import path
        # until it is needed.
        from .io.h5 import save_collection_h5

        save_collection_h5(self, path)

    @classmethod
    def from_h5(cls, path: Path | str) -> PhotometryCollection:
        """Load a collection from ``path`` via ``load_collection_h5``."""
        from .io.h5 import load_collection_h5

        return load_collection_h5(path)
|
fibphot/fit/__init__.py
ADDED
|
File without changes
|