stiminterp 0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
stiminterp/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ from importlib.metadata import PackageNotFoundError, version
2
+
3
try:
    __version__ = version("stiminterp")
except PackageNotFoundError:
    # Distribution metadata is unavailable (e.g. running from a source
    # checkout without installing). Expose a placeholder so that
    # `stiminterp.__version__` is always defined instead of raising
    # AttributeError.
    __version__ = "unknown"
8
+
9
+ from .stim_interpolate import remove_photostim_artefacts, StimInterpConfig
10
+ from .load_data.scanimage_metadata import ScanImageMetadata
11
+ from .pipeline import run_stiminterp
@@ -0,0 +1,124 @@
1
+ """
2
+ These functions are used to load data from a custom ScanImage data format
3
+ used in our experiments. The data are generated and saved by MATLAB scripts
4
+ that are not included in this repository.
5
+
6
+ ScanImage's DataRecorder should be used to record:
7
+ - FrameOnset
8
+ - StimTTL
9
+
10
+ Please also see:
11
+ BaselLaserMouse/ScanImageTools - Rob Campbell, Sumiya Kuroda
12
+ https://github.com/BaselLaserMouse/ScanImageTools/tree/2p-313
13
+ """
14
+
15
+ from pathlib import Path
16
+
17
+ import h5py
18
+ import numpy as np
19
+ import pandas as pd
20
+
21
+
22
def read_h5_array(h5_path: Path, dataset_path: str) -> np.ndarray:
    """
    Load a single HDF5 dataset fully into memory.

    The file is opened read-only and closed again before returning.

    Parameters
    ----------
    h5_path : Path or str
        Location of the ``.h5`` / ``.hdf5`` file.
    dataset_path : str
        Name of the dataset inside the file, for example ``'SyncTTL'`` or
        ``'/group/dataset'``.

    Returns
    -------
    np.ndarray
        The complete contents of the dataset.
    """
    with h5py.File(h5_path, mode="r") as h5file:
        dataset = h5file[dataset_path]
        contents = dataset[:]
    return contents
41
+
42
+
43
def digitize_ai_signal(ai_signal, digitizeThr=4):
    """
    Binarise an analog-input trace with a fixed threshold.

    Samples ``>= digitizeThr`` become 1 and samples ``< digitizeThr`` become 0.
    Samples for which neither comparison holds (e.g. NaN) are left unchanged.
    The output keeps the dtype of the input array.

    Parameters
    ----------
    ai_signal : array_like
        The analog trace, e.g. a NumPy array or a plain Python list.
    digitizeThr : float, default 4
        Threshold separating low from high.

    Returns
    -------
    np.ndarray
        A new array with the same shape as ``ai_signal``; the input is not
        modified.
    """
    signal = np.asarray(ai_signal)

    # Build both masks from the untouched array *before* writing anything, and
    # compare against the array (not the raw argument) so list input works.
    low = signal < digitizeThr
    high = signal >= digitizeThr

    digitized = signal.copy()
    digitized[low] = 0.0
    digitized[high] = 1.0
    return digitized
50
+
51
+
52
def find_edges(signal: np.ndarray, return_edges: str = "both"):
    """
    Locate the rising and falling edges of a binary trace.

    The trace is cast to bool, so any non-zero sample counts as high.

    Parameters
    ----------
    signal : array_like
        1-D trace, e.g. the output of ``digitize_ai_signal``.
    return_edges : {"both", "rising", "falling"}, default "both"
        Which edges to return.

    Returns
    -------
    rising : np.ndarray
        Index of the first high sample of each high run. Index 0 is included
        when the trace starts high.
    falling : np.ndarray
        Index of the first low sample after each high run. When the trace
        ends high, the last index (``len(signal) - 1``) is used instead.
    Either array alone for ``"rising"`` / ``"falling"``, or the tuple
    ``(rising, falling)`` for ``"both"``. Both arrays always have the same
    length.

    Raises
    ------
    ValueError
        If ``return_edges`` is not one of the accepted values.
    """
    signal = np.asarray(signal).astype(bool)

    if signal.size == 0:
        # No samples means no edges; avoids IndexError on signal[0] below.
        rising = np.array([], dtype=np.intp)
        falling = np.array([], dtype=np.intp)
    else:
        diff = np.diff(signal.astype(np.int8))

        rising = np.where(diff == 1)[0] + 1
        falling = np.where(diff == -1)[0] + 1

        # If signal starts HIGH, first rising edge is at index 0
        if signal[0]:
            rising = np.insert(rising, 0, 0)

        # If signal ends HIGH, final falling edge is at last index
        if signal[-1]:
            falling = np.append(falling, len(signal) - 1)

    if return_edges == "rising":
        return rising
    if return_edges == "falling":
        return falling
    if return_edges == "both":
        return rising, falling
    raise ValueError(
        "return_edges must be 'rising', 'falling' or 'both', "
        f"got {return_edges!r}"
    )
74
+
75
+
76
def get_artefact_dfs(
    path_to_h5: Path,
    channel_name_framettl: str,
    channel_name_stimttl: str,
    digitize_threshold: float = 4,
) -> tuple:
    """Turn two TTL channels from a ScanImage DataRecorder HDF5 file into
    start/stop interval tables.

    Each channel is read, thresholded with ``digitize_ai_signal`` and passed
    to ``find_edges``. The resulting edge indices become the ``start`` and
    ``stop`` columns of one DataFrame per channel.

    Parameters
    ----------
    path_to_h5 : Path
        HDF5 file written by the DataRecorder.
    channel_name_framettl : str
        Dataset name of the frame-clock channel.
    channel_name_stimttl : str
        Dataset name of the photostim TTL channel.
    digitize_threshold: float
        Threshold used to binarise both analog channels.

    Returns
    -------
    tuple
        ``(df_frames, df_stims)``: two pd.DataFrame objects with ``start`` and
        ``stop`` columns, in sample indices of the recording.
    """

    def _edges_table(channel_name: str) -> pd.DataFrame:
        analog = read_h5_array(path_to_h5, channel_name)
        binary = digitize_ai_signal(analog, digitizeThr=digitize_threshold)
        onsets, offsets = find_edges(binary)
        return pd.DataFrame({"start": onsets, "stop": offsets})

    frames_table = _edges_table(channel_name_framettl)
    stims_table = _edges_table(channel_name_stimttl)
    return frames_table, stims_table
@@ -0,0 +1,159 @@
1
+ # https://github.com/AllenNeuralDynamics/aind-ophys-mesoscope-image-splitter/blob/c034d3d893dc7365498b61e5353337c1a4e45fb5/code/tiff_metadata.py#L19
2
+
3
+ import copy
4
+ import pathlib
5
+ from typing import List, Union
6
+
7
+ import tifffile
8
+
9
+
10
def _read_metadata(tiff_path: pathlib.Path):
    """
    Calls tifffile.read_scanimage_metadata on the specified
    path and returns the result. This method was factored
    out so that it could be easily mocked in unit tests.

    The file is opened in binary mode and is closed again before returning,
    even if parsing raises.
    """
    with open(tiff_path, "rb") as tiff_file:
        return tifffile.read_scanimage_metadata(tiff_file)
17
+
18
+
19
class ScanImageMetadata(object):
    """
    A class to handle reading and parsing the metadata that
    comes with the TIFF files produced by ScanImage

    Metadata is read once in the constructor via ``_read_metadata``. Derived
    values are computed lazily on first access and then cached on the
    instance.

    Parameters
    ----------
    tiff_path: pathlib.Path
        Path to the TIFF file whose metadata we are parsing

    Raises
    ------
    ValueError
        If ``tiff_path`` is not an existing file.
    """

    def __init__(self, tiff_path: pathlib.Path):
        self._file_path = tiff_path
        if not tiff_path.is_file():
            raise ValueError(f"{tiff_path.resolve().absolute()} is not a file")
        self._metadata = _read_metadata(tiff_path)

    @property
    def file_path(self) -> pathlib.Path:
        """Path of the TIFF file this metadata was read from."""
        return self._file_path

    @property
    def raw_metadata(self) -> tuple:
        """
        Return a copy of the raw metadata as read by
        tifffile.read_scanimage_metadata.
        """
        return copy.deepcopy(self._metadata)

    def _int_field(self, key: str) -> int:
        """Return ``key`` from the frame metadata, raising unless it is an int."""
        value = self._metadata[0][key]
        if not isinstance(value, int):
            raise ValueError(
                f"in {self._file_path}\n"
                f"{key} is a "
                f"{type(value)}; expected int"
            )
        return value

    @property
    def numVolumes(self) -> int:
        """
        The metadata field representing the number of volumes
        recorded by the rig (SI.hStackManager.actualNumVolumes).

        Raises ValueError if the stored value is not an int.
        """
        if not hasattr(self, "_numVolumes"):
            self._numVolumes = self._int_field(
                "SI.hStackManager.actualNumVolumes"
            )
        return self._numVolumes

    @property
    def numSlices(self) -> int:
        """
        The metadata field representing the number of slices
        recorded by the rig (SI.hStackManager.actualNumSlices).

        Raises ValueError if the stored value is not an int.
        """
        if not hasattr(self, "_numSlices"):
            self._numSlices = self._int_field(
                "SI.hStackManager.actualNumSlices"
            )
        return self._numSlices

    @property
    def channelSave(self) -> Union[int, List[int]]:
        """
        The metadata field representing which channels were saved
        in this TIFF. Either 1 or [1, 2]
        """
        if not hasattr(self, "_channelSave"):
            self._channelSave = self._metadata[0]["SI.hChannels.channelSave"]
        return self._channelSave

    @property
    def defined_rois(self) -> List[dict]:
        """
        Get the ROIs defined in this TIFF file

        This is list of dicts, each dict containing the ScanImage
        metadata for a given ROI. A single ROI stored as a bare dict is
        wrapped in a one-element list.

        In this context, an ROI is a 3-dimensional volume of the brain
        that was scanned by the microscope.

        Raises RuntimeError if the ROI entry is neither a dict nor a list.
        """
        if not hasattr(self, "_defined_rois"):
            roi_parent = self._metadata[1]["RoiGroups"]
            roi_group = roi_parent["imagingRoiGroup"]["rois"]
            if isinstance(roi_group, dict):
                self._defined_rois = [
                    roi_group,
                ]
            elif isinstance(roi_group, list):
                self._defined_rois = roi_group
            else:
                # Key spelled exactly as it is looked up above.
                msg = "unable to parse "
                msg += "self._metadata[1]['RoiGroups']"
                msg += "['imagingRoiGroup']['rois'] "
                msg += f"of type {type(roi_group)}"
                raise RuntimeError(msg)

        # use copy to make absolutely sure self._defined_rois
        # is not accidentally changed downstream
        return copy.deepcopy(self._defined_rois)

    @property
    def n_rois(self) -> int:
        """
        Number of ROIs defined in the metadata for this TIFF file.
        """
        if not hasattr(self, "_n_rois"):
            self._n_rois = len(self.defined_rois)
        return self._n_rois

    @property
    def n_chans(self) -> int:
        """
        Number of channels saved in this TIFF.

        ScanImage's SI.hChannels.channelSave is typically either:
        - 1 (meaning channel 1 only), or
        - [1, 2] (meaning channels 1 and 2)
        """
        ch = self.channelSave
        if isinstance(ch, int):
            return 1
        return len(ch)

    def zs_for_roi(self, i_roi: int) -> List[int]:
        """
        Return a list of the z-values at which the specified
        ROI was scanned

        Raises
        ------
        ValueError
            If ``i_roi`` is negative or not less than ``n_rois``. Negative
            indices are rejected rather than silently selecting an ROI from
            the end of the list.
        """
        n_rois = self.n_rois
        if i_roi < 0 or i_roi >= n_rois:
            msg = f"You asked for ROI {i_roi}; "
            msg += f"there are only {n_rois} "
            msg += "specified in this TIFF file"
            raise ValueError(msg)
        return self.defined_rois[i_roi]["zs"]
stiminterp/pipeline.py ADDED
@@ -0,0 +1,51 @@
1
+ from pathlib import Path
2
+
3
+ from ScanImageTiffReader import ScanImageTiffReader
4
+ from tifffile import imwrite
5
+
6
+ from stiminterp import remove_photostim_artefacts
7
+ from stiminterp.load_data.custom_data_loader import get_artefact_dfs
8
+ from stiminterp.load_data.scanimage_metadata import ScanImageMetadata
9
+
10
+
11
def run_stiminterp(
    input_tif: str,
    input_h5: str | None = None,
    output_tif: str | None = None,
    channel_name_framettl: str = "FrameTTL",
    channel_name_stimttl: str = "SatsumaGateTTL",
) -> Path:
    """
    Remove photostim artefacts from one ScanImage TIFF and save the result.

    Parameters
    ----------
    input_tif : str
        Path to the ScanImage TIFF to correct.
    input_h5 : str, optional
        DataRecorder HDF5 file with the frame and stim TTL channels.
        Defaults to ``input_tif`` with its suffix replaced by ``.h5``.
    output_tif : str, optional
        Output file path, or an existing directory to write into.
        Defaults to ``<input stem>_corrected.tif`` next to the input; the
        same file name is used inside a directory.
    channel_name_framettl : str, default "FrameTTL"
        HDF5 dataset holding the frame clock.
    channel_name_stimttl : str, default "SatsumaGateTTL"
        HDF5 dataset holding the photostim TTL.

    Returns
    -------
    Path
        Path of the written corrected TIFF.

    Raises
    ------
    FileNotFoundError
        If the HDF5 file does not exist. This is checked before the TIFF
        pixel data is loaded.
    """
    tif_path = Path(input_tif)
    sim = ScanImageMetadata(tif_path)

    # infer h5 if not provided
    h5_path = tif_path.with_suffix(".h5") if input_h5 is None else Path(input_h5)
    # Fail fast, before loading the (potentially large) movie into memory.
    if not h5_path.is_file():
        raise FileNotFoundError(f"HDF5 file not found: {h5_path}")

    # determine output path
    default_name = f"{tif_path.stem}_corrected.tif"
    if output_tif is None:
        out_path = tif_path.with_name(default_name)
    else:
        out_path = Path(output_tif)
        if out_path.is_dir():
            out_path = out_path / default_name

    # Use the reader as a context manager so its file handle is released
    # as soon as the pixel data is in memory.
    with ScanImageTiffReader(str(tif_path)) as reader:
        vol = reader.data()

    df_frames, df_stims = get_artefact_dfs(
        h5_path,
        channel_name_framettl,
        channel_name_stimttl,
    )

    # NOTE(review): the ROI count is used as the plane-interleave count here
    # (frame_gap = n_rois - 1); confirm this matches the acquisition setup.
    corrected, _bad_mask, _df_split = remove_photostim_artefacts(
        vol,
        df_frames,
        df_stims,
        frame_gap=sim.n_rois - 1,
        num_channel=sim.n_chans,
    )

    imwrite(str(out_path), corrected)
    return out_path
@@ -0,0 +1,81 @@
1
+ import matplotlib.pyplot as plt
2
+ import numpy as np
3
+
4
+
5
def create_sanitycheck_axes(nchannel: int):
    """
    Build the figure grid used to inspect artefact removal.

    One row per channel and three columns: the uncorrected frame, the mask
    of replaced pixels (titled "NaN mask") and the corrected frame. All
    panels share x and y limits.

    Parameters
    ----------
    nchannel : int
        Number of rows (one per channel).

    Returns
    -------
    fig : matplotlib.figure.Figure
    axes : ndarray of shape (nchannel, 3)
        Always two-dimensional, including when ``nchannel == 1``.
    """
    column_titles = ("Uncorrected", "NaN mask", "Corrected")
    n_cols = len(column_titles)

    # squeeze=False keeps the axes array 2-D even for a single row.
    fig, axes = plt.subplots(
        nchannel,
        n_cols,
        figsize=(5 * n_cols, 4 * nchannel),
        sharex=True,
        sharey=True,
        constrained_layout=True,
        squeeze=False,
    )

    # Titles go on the top row only.
    for ax, title in zip(axes[0], column_titles):
        ax.set_title(title)

    return fig, axes
40
+
41
+
42
def plot_removal(
    axes_row,
    frame: int,
    y_frac_start: float,
    y_frac_stop: float,
    uncorrected: np.ndarray,
    corrected: np.ndarray,
    bad_mask: np.ndarray,
    channel: int,
    cmap: str = "gray",
):
    """
    Draw one row of the sanity-check figure for a single frame.

    Parameters
    ----------
    axes_row : sequence of matplotlib Axes
        At least three axes, used as: [0] uncorrected frame, [1] replaced-pixel
        mask, [2] corrected frame (e.g. one row of ``create_sanitycheck_axes``).
    frame : int
        Index into the first axis of ``uncorrected``, ``corrected`` and
        ``bad_mask``.
    y_frac_start, y_frac_stop : float
        Fractions of the frame height at which horizontal marker lines are
        drawn. The row is ``int(fraction * uncorrected.shape[1])``.
    uncorrected, corrected, bad_mask : np.ndarray
        Stacks indexed as ``array[frame]``.
    channel : int
        Channel number shown in the row label.
    cmap : str, default "gray"
        Colormap for the uncorrected and corrected images.
    """
    raw = uncorrected[frame]
    corr = corrected[frame]
    mask = bad_mask[frame]

    n_rows = uncorrected.shape[1]
    y_min = int(y_frac_start * n_rows)
    y_max = int(y_frac_stop * n_rows)

    # Contrast limits come from the raw frame so both images are comparable.
    vmin = np.percentile(raw, 1)
    vmax = np.percentile(raw, 99.5)

    ax_raw = axes_row[0]
    ax_mask = axes_row[1]
    ax_corr = axes_row[2]

    # --- Uncorrected ---
    ax_raw.imshow(raw, cmap=cmap, vmin=vmin, vmax=vmax)
    ax_raw.axhline(y_min, c="r", lw=2)
    ax_raw.axhline(y_max, c="r", lw=2)
    # axis("off") would also hide the y-label, so strip ticks and spines
    # instead to keep the "Ch/Frame" label visible.
    ax_raw.set_xticks([])
    ax_raw.set_yticks([])
    for spine in ax_raw.spines.values():
        spine.set_visible(False)
    ax_raw.set_ylabel(f"Ch {channel}\nFrame {frame}")

    # --- Bad mask ---
    # Pin the colour range so an all-False mask is not rescaled.
    ax_mask.imshow(mask, cmap="Reds", vmin=0, vmax=1)
    ax_mask.axhline(y_min, c="k", lw=2)
    ax_mask.axhline(y_max, c="k", lw=2)
    ax_mask.axis("off")

    # --- Corrected ---
    ax_corr.imshow(corr, cmap=cmap, vmin=vmin, vmax=vmax)
    ax_corr.axis("off")
@@ -0,0 +1,372 @@
1
+ """
2
+ stim_interpolate.py
3
+
4
+ Photostimulation artefact removal via 1D interpolation.
5
+
6
+ Pipeline
7
+ --------
8
+ 1) Use frame timing + stim timing to determine artefact regions
9
+ (frame index + fraction within frame).
10
+ 2) Convert fractions -> scanline rows (Y).
11
+ 3) Set contaminated pixels to NaN.
12
+ 4) Fill NaNs per pixel by linear interpolation (np.interp) in *frame_index* space.
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ from dataclasses import dataclass
18
+ from typing import Optional, Tuple
19
+
20
+ import numpy as np
21
+ import pandas as pd
22
+
23
+ # -----------------------------------------------------------------------------
24
+ # Configuration
25
+ # -----------------------------------------------------------------------------
26
+
27
+
28
@dataclass
class StimInterpConfig:
    """Configuration for stim artefact removal.

    Attributes
    ----------
    pad_rows : int
        Extra scanlines marked bad above and below each detected stim
        region when fractional regions are converted to rows.
    require_n_good : int
        Minimum number of valid donor frames a pixel needs before its NaNs
        are filled; pixels with fewer donors are left as NaN.
    """

    pad_rows: int = 5
    # np.interp (linear) can fill from a single donor: the value is held constant.
    require_n_good: int = 1
34
+
35
+
36
+ # -----------------------------------------------------------------------------
37
+ # Public API
38
+ # -----------------------------------------------------------------------------
39
+
40
+
41
def remove_photostim_artefacts(
    movie: np.ndarray,
    df_frames: pd.DataFrame,
    df_stims: pd.DataFrame,
    frame_gap: Optional[int] = None,
    num_channel: Optional[int] = None,
    cfg: Optional[StimInterpConfig] = None,
) -> Tuple[np.ndarray, np.ndarray, pd.DataFrame]:
    """
    Remove photostimulation artefacts by linear temporal interpolation.

    Scanlines hit by a stim (plus ``cfg.pad_rows`` padding) are set to NaN.
    They are then refilled per pixel with ``np.interp`` along the frame axis,
    using only frames of the same channel (and plane, if ``frame_gap`` is set).

    Parameters
    ----------
    movie : np.ndarray
        Imaging movie, shape (T, Y, X).
        If multiple channels are saved interleaved per TTL frame, then:
        T = len(df_frames) * num_channel
    df_frames : pd.DataFrame
        Columns ["start", "stop"] per TTL frame.
        NOTE: TTL frames should already include plane interleaving if present.
    df_stims : pd.DataFrame
        Columns ["start", "stop"] per stim interval.
    frame_gap : int, optional
        Plane interleave indicator:
          - 1 => 2 planes interleaved (plane_id = ttl_frame % 2)
          - 2 => 3 planes interleaved (plane_id = ttl_frame % 3)
        If None, treat as single plane for interpolation grouping i.e, 0.
    num_channel : int, optional
        Number of interleaved channels in the TIFF (channel-fast ordering).
        If None, inferred as movie.shape[0] // len(df_frames).
    cfg : StimInterpConfig, optional
        Padding / donor settings; defaults to ``StimInterpConfig()``.

    Returns
    -------
    corrected : np.ndarray
        Corrected movie (float32).
    bad_mask : np.ndarray
        Boolean mask (T, Y, X) indicating pixels replaced.
    df_split : pd.DataFrame
        Per-TTL-frame stim segments: ["frame", "frac_start", "frac_stop"].

    Raises
    ------
    ValueError
        If ``movie`` is not 3-D, ``df_frames`` is empty, or
        ``movie.shape[0]`` does not equal ``len(df_frames) * num_channel``.
    """
    if cfg is None:
        cfg = StimInterpConfig()

    movie = np.asarray(movie)
    if movie.ndim != 3:
        raise ValueError("movie must have shape (T, Y, X)")
    T, Y, X = movie.shape

    n_ttl = len(df_frames)
    if n_ttl <= 0:
        raise ValueError("df_frames is empty")

    if T % n_ttl != 0:
        raise ValueError(
            f"movie.shape[0]={T} must be a multiple of len(df_frames)={n_ttl}."
        )

    nchan = int(num_channel) if num_channel is not None else T // n_ttl
    if nchan <= 0 or (n_ttl * nchan) != T:
        # Reached when an explicit num_channel disagrees with the movie, so
        # say that instead of repeating the "multiple of" message above.
        raise ValueError(
            f"num_channel={nchan} is inconsistent with movie.shape[0]={T} "
            f"and len(df_frames)={n_ttl} (expected {T // n_ttl})."
        )

    # --- stim regions in TTL frame space ---
    df_split = _artefact_regions(df_frames, df_stims)
    if df_split.empty:
        return (
            movie.astype(np.float32, copy=True),
            np.zeros_like(movie, dtype=bool),
            df_split,
        )

    # --- bad scanlines in TTL frame space: (n_ttl, Y) ---
    bad_lines_ttl = _build_bad_line_mask(
        df_split=df_split,
        T=n_ttl,
        Y=Y,
        pad_rows=cfg.pad_rows,
    )

    # --- expand across channels to movie time axis: (T, Y) ---
    # TTL frame t corresponds to movie indices t*nchan + c for each channel c
    bad_lines = np.repeat(bad_lines_ttl, repeats=nchan, axis=0)
    bad_mask = np.repeat(bad_lines[:, :, None], X, axis=2)

    corrected = movie.astype(np.float32, copy=True)
    corrected[bad_mask] = np.nan

    # Every TTL frame may donate values; NaN pixels are excluded per pixel.
    frame_index = np.arange(n_ttl, dtype=np.int32)
    donor_mask = np.ones(n_ttl, dtype=bool)

    corrected = interpolate_nan(
        corrected,
        frame_index=frame_index,
        donor_mask=donor_mask,
        require_n_good=cfg.require_n_good,
        num_channel=nchan,
        frame_gap=frame_gap,
    )

    return corrected, bad_mask, df_split
149
+
150
+
151
+ # -----------------------------------------------------------------------------
152
+ # Region detection (timing -> per-frame fractions)
153
+ # -----------------------------------------------------------------------------
154
+
155
+
156
def _artefact_regions(
    df_frames: pd.DataFrame, df_stims: pd.DataFrame
) -> pd.DataFrame:
    """
    Express stim intervals as fractional regions of individual frames.

    Parameters
    ----------
    df_frames : pd.DataFrame
        Frame TTL intervals with ``start`` / ``stop`` columns.
    df_stims : pd.DataFrame
        Stim TTL intervals with ``start`` / ``stop`` columns, in the same
        time units as ``df_frames``.

    Returns
    -------
    pd.DataFrame
        Columns ``frame``, ``frac_start``, ``frac_stop``. There is one row per
        frame touched by a stim, with fractions from 0 (frame start) to 1
        (frame stop). The frame is empty when no stim overlaps the span from
        the first frame start to the last frame stop.
    """
    frames = df_frames.sort_values("start").reset_index(drop=True)
    stims = df_stims.sort_values("start")

    if not stims.empty:
        # Drop stims that do not overlap the acquisition window.
        acq_begin = frames["start"].iloc[0]
        acq_end = frames["stop"].iloc[-1]
        overlaps = (stims["stop"] > acq_begin) & (stims["start"] < acq_end)
        stims = stims[overlaps]

    if stims.empty:
        return pd.DataFrame(columns=["frame", "frac_start", "frac_stop"])

    frame_starts = frames["start"].to_numpy()
    frame_stops = frames["stop"].to_numpy()

    # Interleaved boundaries: [start_0, stop_0, start_1, stop_1, ...]
    bounds = np.empty(2 * len(frames), dtype=float)
    bounds[0::2] = frame_starts
    bounds[1::2] = frame_stops

    onset_frame, onset_frac = _map_times_to_frame_frac(
        times=stims["start"].to_numpy(),
        frame_boundaries=frame_stops,
        all_boundaries=bounds,
        fill=0.0,
        offset=1,
    )
    offset_frame, offset_frac = _map_times_to_frame_frac(
        times=stims["stop"].to_numpy(),
        frame_boundaries=frame_starts,
        all_boundaries=bounds,
        fill=1.0,
        offset=0,
    )

    per_stim = pd.DataFrame(
        {
            "frame_start": onset_frame,
            "frac_start": onset_frac,
            "frame_stop": offset_frame,
            "frac_stop": offset_frac,
        },
        index=stims.index,
    )
    return _split_multi_frame_stims(per_stim)
201
+
202
+
203
def _map_times_to_frame_frac(
    times: np.ndarray,
    frame_boundaries: np.ndarray,
    all_boundaries: np.ndarray,
    fill: float,
    offset: int = 0,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Map times to a (frame index, fraction within frame) pair.

    Parameters
    ----------
    times : np.ndarray
        Times to map, in the same units as the boundaries.
    frame_boundaries : np.ndarray
        One boundary per frame, used to pick the frame index. In this module
        stim starts are mapped against frame stops with ``offset=1``, and
        stim stops against frame starts with ``offset=0``.
    all_boundaries : np.ndarray
        Interleaved ``[start_0, stop_0, start_1, stop_1, ...]``. Must be
        increasing, since it is used as ``xp`` in ``np.interp``.
    fill : float
        Fraction assigned to times that fall between one frame's stop and the
        next frame's start. A time exactly equal to a stop also counts.
    offset : int
        Added to the interpolated frame position before truncation to int.

    Returns
    -------
    frame : np.ndarray of int
        Frame index for each time.
    frac : np.ndarray of float
        Position within the frame, clipped to [0, 1]: 0 at its start, 1 at
        its stop, or ``fill`` for gap times.
    """
    # Fractional position among frame_boundaries. `left=-offset` makes times
    # before the first boundary land on position 0 once offset is added.
    frame = (
        np.interp(
            times,
            frame_boundaries,
            np.arange(len(frame_boundaries)),
            left=-offset,
        )
        + offset
    )
    # Truncation keeps the integer part as the frame index.
    frame = frame.astype(int)

    # Odd positions in all_boundaries are stops, so an odd integer part means
    # the time lies at/after stop_k and before start_{k+1} (an inter-frame gap).
    all_idx = np.interp(times, all_boundaries, np.arange(len(all_boundaries)))
    out_of_frame = (all_idx.astype(int) % 2) == 1

    # Interpolating against [0, 1, 0, 1, ...] yields the linear position
    # between each frame's start (0) and stop (1); gap values are overwritten.
    frac_template = np.tile([0.0, 1.0], len(frame_boundaries))
    frac = np.interp(times, all_boundaries, frac_template)
    frac[out_of_frame] = float(fill)

    return frame, np.clip(frac, 0.0, 1.0)
229
+
230
+
231
+ def _split_multi_frame_stims(df: pd.DataFrame) -> pd.DataFrame:
232
+ out = []
233
+ for r in df.itertuples():
234
+ if r.frame_start == r.frame_stop:
235
+ out.append(
236
+ (int(r.frame_start), float(r.frac_start), float(r.frac_stop))
237
+ )
238
+ continue
239
+
240
+ out.append((int(r.frame_start), float(r.frac_start), 1.0))
241
+ for f in range(int(r.frame_start) + 1, int(r.frame_stop)):
242
+ out.append((f, 0.0, 1.0))
243
+ out.append((int(r.frame_stop), 0.0, float(r.frac_stop)))
244
+
245
+ return pd.DataFrame(out, columns=["frame", "frac_start", "frac_stop"])
246
+
247
+
248
+ # -----------------------------------------------------------------------------
249
+ # Mask building (fractions -> scanlines)
250
+ # -----------------------------------------------------------------------------
251
+
252
+
253
+ def _build_bad_line_mask(
254
+ df_split: pd.DataFrame,
255
+ T: int,
256
+ Y: int,
257
+ pad_rows: int = 0,
258
+ ) -> np.ndarray:
259
+ bad = np.zeros((T, Y), dtype=bool)
260
+
261
+ for r in df_split.itertuples(index=False):
262
+ t = int(r.frame)
263
+ if t < 0 or t >= T:
264
+ continue
265
+
266
+ y0 = int(np.floor(float(r.frac_start) * Y))
267
+ y1 = int(np.ceil(float(r.frac_stop) * Y))
268
+
269
+ y0 = max(0, y0 - pad_rows)
270
+ y1 = min(Y, y1 + pad_rows)
271
+
272
+ if y1 > y0:
273
+ bad[t, y0:y1] = True
274
+
275
+ return bad
276
+
277
+
278
+ # -----------------------------------------------------------------------------
279
+ # Linear 1D interpolation along the frame axis (np.interp)
280
+ # -----------------------------------------------------------------------------
281
+
282
+
283
def interpolate_nan(
    movie_float: np.ndarray,
    frame_index: np.ndarray,
    donor_mask: np.ndarray,
    require_n_good: int = 2,
    num_channel: int = 1,
    frame_gap: Optional[int] = None,
) -> np.ndarray:
    """
    Fill NaNs by linear interpolation along the frame axis, per pixel.

    The movie is split by channel (movie rows ``c::num_channel``). When
    ``frame_gap`` is given, each channel is further split by plane, using
    position along the TTL axis modulo ``frame_gap + 1``. Each group is
    interpolated independently, using only its own frames as donors.

    The fill writes through a reshaped view of ``movie_float`` when numpy can
    provide one, so the input may be modified in place. Always use the
    returned array.

    Parameters
    ----------
    movie_float : np.ndarray
        Float movie of shape (T, Y, X) with NaNs marking pixels to fill.
    frame_index : np.ndarray
        Interpolation coordinate for each TTL frame (length ``n_ttl``).
    donor_mask : np.ndarray of bool
        Length ``n_ttl``; only True frames may donate values.
    require_n_good : int, default 2
        Pixels with fewer valid donors are left as NaN.
    num_channel : int, default 1
        Number of interleaved channels; ``T == n_ttl * num_channel``.
    frame_gap : int, optional
        Plane interleave indicator (``frame_gap + 1`` planes). None or 0
        means a single plane.

    Returns
    -------
    np.ndarray
        Array of shape (T, Y, X) with NaNs filled where possible.

    Raises
    ------
    ValueError
        If ``T != len(frame_index) * num_channel`` or ``frame_gap`` is
        negative.
    """
    T, Y, X = movie_float.shape
    n_ttl = len(frame_index)

    if T != n_ttl * num_channel:
        raise ValueError(
            f"T ({T}) must equal len(frame_index)*num_channel "
            f"({n_ttl * num_channel})"
        )
    # A negative gap would give <= 0 planes. Nothing would be interpolated
    # and NaNs would be returned silently, so reject it up front.
    if frame_gap is not None and int(frame_gap) < 0:
        raise ValueError(f"frame_gap must be None or >= 0, got {frame_gap}")

    x = frame_index.astype(np.float32)
    flat = movie_float.reshape(T, -1)

    # Plane grouping
    num_planes = 1 if frame_gap is None else int(frame_gap) + 1
    plane_ids = np.arange(n_ttl) % num_planes

    for c in range(num_channel):
        # Basic slicing returns a view, so in-place fills land in `flat`.
        ys = flat[c::num_channel, :]  # (n_ttl, n_pixels)

        if num_planes == 1:
            _interp_block_numpy(ys, x, donor_mask, require_n_good)
            continue

        for p in range(num_planes):
            idx = np.where(plane_ids == p)[0]
            if len(idx) == 0:
                continue

            # Fancy indexing copies, so the filled block is written back.
            block = ys[idx, :]
            _interp_block_numpy(
                block,
                x[idx],
                donor_mask[idx],
                require_n_good,
            )
            ys[idx, :] = block

    return flat.reshape(T, Y, X)
339
+
340
+
341
+ def _interp_block_numpy(
342
+ block: np.ndarray,
343
+ x: np.ndarray,
344
+ donor_mask: np.ndarray,
345
+ require_n_good: int,
346
+ ) -> None:
347
+ """
348
+ In-place linear interpolation using np.interp.
349
+ """
350
+
351
+ N, P = block.shape
352
+
353
+ for j in range(P):
354
+ y = block[:, j]
355
+ nans = np.isnan(y)
356
+ if not nans.any():
357
+ continue
358
+
359
+ good = donor_mask & ~nans
360
+ if good.sum() < require_n_good:
361
+ continue
362
+
363
+ x_good = x[good]
364
+ y_good = y[good]
365
+
366
+ # NumPy requires sorted x
367
+ # frame_index already sorted, but keep safe
368
+ order = np.argsort(x_good)
369
+ x_good = x_good[order]
370
+ y_good = y_good[order]
371
+
372
+ y[nans] = np.interp(x[nans], x_good, y_good)
@@ -0,0 +1,114 @@
1
+ Metadata-Version: 2.4
2
+ Name: stiminterp
3
+ Version: 0.1
4
+ Summary: Photostimulation artifact removal via interpolation
5
+ Author-email: Sumiya Kuroda <s.kuroda@ucl.ac.uk>
6
+ License: BSD-3-Clause
7
+ Project-URL: Homepage, https://github.com/SainsburyWellcomeCentre/stiminterp
8
+ Project-URL: Bug Tracker, https://github.com/SainsburyWellcomeCentre/stiminterp/issues
9
+ Project-URL: Source Code, https://github.com/SainsburyWellcomeCentre/stiminterp
10
+ Project-URL: User Support, https://github.com/SainsburyWellcomeCentre/stiminterp/issues
11
+ Classifier: Development Status :: 2 - Pre-Alpha
12
+ Classifier: Programming Language :: Python
13
+ Classifier: Programming Language :: Python :: 3
14
+ Classifier: Programming Language :: Python :: 3.10
15
+ Classifier: Programming Language :: Python :: 3.11
16
+ Classifier: Programming Language :: Python :: 3.12
17
+ Classifier: Operating System :: OS Independent
18
+ Classifier: License :: OSI Approved :: BSD License
19
+ Requires-Python: >=3.10.0
20
+ Description-Content-Type: text/markdown
21
+ License-File: LICENSE
22
+ Requires-Dist: numpy
23
+ Requires-Dist: tifffile
24
+ Requires-Dist: matplotlib
25
+ Requires-Dist: scipy
26
+ Requires-Dist: PyYAML
27
+ Requires-Dist: fancylog
28
+ Requires-Dist: matplotlib
29
+ Requires-Dist: pandas
30
+ Requires-Dist: tqdm
31
+ Requires-Dist: scikit-learn
32
+ Requires-Dist: scikit-image
33
+ Requires-Dist: scanimage-tiff-reader
34
+ Requires-Dist: h5py
35
+ Provides-Extra: dev
36
+ Requires-Dist: pytest; extra == "dev"
37
+ Requires-Dist: pytest-cov; extra == "dev"
38
+ Requires-Dist: coverage; extra == "dev"
39
+ Requires-Dist: tox; extra == "dev"
40
+ Requires-Dist: black; extra == "dev"
41
+ Requires-Dist: mypy; extra == "dev"
42
+ Requires-Dist: pre-commit; extra == "dev"
43
+ Requires-Dist: ruff; extra == "dev"
44
+ Requires-Dist: setuptools_scm; extra == "dev"
45
+ Dynamic: license-file
46
+
47
+ [![Python
48
+ Version](https://img.shields.io/pypi/pyversions/stiminterp.svg)](https://pypi.org/project/stiminterp)
49
+ [![PyPI
50
+ Version](https://img.shields.io/pypi/v/stiminterp.svg)](https://pypi.org/project/stiminterp)
51
+ [![License](https://img.shields.io/badge/License-BSD_3--Clause-orange.svg)](https://opensource.org/licenses/BSD-3-Clause)
52
+
53
+ # stiminterp
54
+
55
+ **stiminterp** provides a 1D-interpolation-based solution for removing
56
+ photostimulation artefacts from multiphoton calcium imaging data.
57
+
58
+ The holographic stimulation saturates the PMTs and causes data loss. After identifying the scanlines contaminated by stimulation artefacts, this pipeline replaces those pixel rows with values linearly interpolated in time from the corresponding rows of the nearest unaffected preceding and following frames (for a single affected frame, this is simply their average).
59
+
60
+ ------------------------------------------------------------------------
61
+
62
+ ## Installation
63
+
64
+ Create a fresh environment and install via pip:
65
+
66
+ conda create -n stiminterp-env python=3.12
67
+ conda activate stiminterp-env
68
+ pip install stiminterp
69
+
70
+ ------------------------------------------------------------------------
71
+
72
+ ## Overview
73
+
74
+ Understanding the causal role of brain dynamics is one of the fundamental questions in systems neuroscience. Multiphoton holographic optogenetics, combined with multiphoton calcium imaging, enables causal testing of circuit models at single-cell resolution. However, photostimulation can saturate PMTs, producing line artefacts in the imaging data.
75
+
76
+ With `stiminterp` you can:
77
+
78
+ - Detect artefact-contaminated lines from HDF5 generated by ScanImage
79
+ - Refill those lines by per-pixel 1D temporal interpolation using `numpy.interp`
80
+ - Recover calcium imaging movies that can be fed into standard analysis pipelines such as `suite2p`
81
+ ------------------------------------------------------------------------
82
+
83
+ ## Data Source & Funding
84
+
85
+ Sample data used for examples will be publicly available in the near future.
86
+
87
+ All microscopy data has been acquired using a custom two-photon microscope by [Sumiya Kuroda](https://github.com/sumiya-kuroda) in the [Mrsic-Flogel Lab](https://www.sainsburywellcome.org/web/groups/mrsic-flogel-lab) and Dale Elgar from [COSYS Ltd.](https://www.cosys.org.uk/).
88
+
89
+ This work represents a joint collaboration between Stanford University and the Sainsbury Wellcome Centre for Neural Circuits and Behaviour, University College London, supported by the Gatsby Charitable Foundation.
90
+
91
+ ------------------------------------------------------------------------
92
+
93
+ ## References
94
+
95
+ Previous work on artefact removal of all-optical imaging movies:
96
+ - [Drinnenberg et al, 2025, bioRxiv](https://www.biorxiv.org/content/10.1101/2025.10.21.683734v1)
97
+ - [Attinger et al, 2025, bioRxiv](https://www.biorxiv.org/content/10.1101/2025.10.21.683723v1)
98
+
99
+ This package was inspired by [previous calcium imaging analysis pipeline at Deisseroth lab](https://github.com/deisseroth-lab/two-photon/tree/main).
100
+
101
+ This repo was made using [neuroinformatics-unit/python-cookiecutter](https://github.com/neuroinformatics-unit/python-cookiecutter). See [here](https://python-cookiecutter.neuroinformatics.dev/) for more info.
102
+
103
+ ------------------------------------------------------------------------
104
+
105
+ ## Contributing
106
+
107
+ Contributions are welcome. Please open an issue or submit a pull request
108
+ on GitHub.
109
+
110
+ ------------------------------------------------------------------------
111
+
112
+ ## License
113
+
114
+ BSD-3-Clause
@@ -0,0 +1,11 @@
1
+ stiminterp/__init__.py,sha256=JoNB-7HQLr1HndEMytwpXp_DwhsyEb9pWaV-OQGES3g,349
2
+ stiminterp/pipeline.py,sha256=tfVVt6V7M7iypFJlR-ELU9TC5-W2b0t7KOph8oUdZQ8,1328
3
+ stiminterp/stim_interpolate.py,sha256=SLPKi8zXpbd7TZSTtVl0SbAl8A6M9FEKWSQWRSPi6xc,10692
4
+ stiminterp/load_data/custom_data_loader.py,sha256=NIVMg0aKypvzrHuVtfR1Tjg9J8EstT_4s2NmsNgEJcA,3236
5
+ stiminterp/load_data/scanimage_metadata.py,sha256=Dv22tm1-qKonLxWotalJx5H5BwSqxxuqYY6yErsIq64,5137
6
+ stiminterp/plotting_hooks/sanity_check.py,sha256=iQxstWOihSfpVuKdOpR2IUyEVat9kes-JdmFOaNVk24,1959
7
+ stiminterp-0.1.dist-info/licenses/LICENSE,sha256=5m5N81VhYYpXw_lIVFMjIf_IIx-BJxn7RX4C5RuoYjQ,1482
8
+ stiminterp-0.1.dist-info/METADATA,sha256=STz7HVzyfMKgUK38M7IDDkAKPpYh02UWDTKwxl8kvRc,5099
9
+ stiminterp-0.1.dist-info/WHEEL,sha256=YCfwYGOYMi5Jhw2fU4yNgwErybb2IX5PEwBKV4ZbdBo,91
10
+ stiminterp-0.1.dist-info/top_level.txt,sha256=eBlYdz1-bU8B8sVmunLVsv4ixNKh4fGjC6BxknFft54,11
11
+ stiminterp-0.1.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (82.0.0)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1,28 @@
1
+
2
+ Copyright (c) 2026, Sumiya Kuroda
3
+ All rights reserved.
4
+
5
+ Redistribution and use in source and binary forms, with or without
6
+ modification, are permitted provided that the following conditions are met:
7
+
8
+ * Redistributions of source code must retain the above copyright notice, this
9
+ list of conditions and the following disclaimer.
10
+
11
+ * Redistributions in binary form must reproduce the above copyright notice,
12
+ this list of conditions and the following disclaimer in the documentation
13
+ and/or other materials provided with the distribution.
14
+
15
+ * Neither the name of stiminterp nor the names of its
16
+ contributors may be used to endorse or promote products derived from
17
+ this software without specific prior written permission.
18
+
19
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
@@ -0,0 +1 @@
1
+ stiminterp