batch2p 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,342 @@
1
+ """Suite2P source extraction algorithm."""
2
+ import json
3
+ import shutil
4
+ import tempfile
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+
9
+ from .base import SourceExtractor
10
+
11
+
12
+ def _update_two_level_dict(d1: dict, d2: dict) -> dict:
13
+ """Update d1 with values from d2, respecting two-level dict structure.
14
+
15
+ Non-dict fields in d1 are overwritten by the corresponding field in d2.
16
+ Dict fields in d1 are updated at the second level with the corresponding
17
+ second-level dict in d2 (if present).
18
+ """
19
+ for key, val in d2.items():
20
+ if key in d1 and isinstance(d1[key], dict):
21
+ if isinstance(val, dict):
22
+ d1[key].update(val)
23
+ else:
24
+ d1[key] = val
25
+ return d1
26
+
27
+
28
+ def _detect_torch_device() -> str:
29
+ try:
30
+ import torch
31
+ if torch.cuda.is_available():
32
+ return "cuda"
33
+ except ImportError:
34
+ pass
35
+ try:
36
+ import cupy
37
+ if cupy.cuda.runtime.getDeviceCount() > 0:
38
+ return "cuda"
39
+ except Exception:
40
+ pass
41
+ return "cpu"
42
+
43
+
44
def _load_settings(params_path: Path) -> tuple[dict, dict]:
    """Load suite2p settings, returning (user_params, merged_settings).

    user_params: the params as read from the JSON file, with a top-level
        "comments" entry removed and "torch_device" filled in by
        _detect_torch_device() when the file does not set it. This happens
        before merging with suite2p defaults.
    merged_settings: suite2p's default_settings() updated with user_params via
        _update_two_level_dict (sections merged one level deep).

    The file is decoded as UTF-8, the JSON standard encoding, rather than with
    the platform's locale encoding (e.g. cp1252 on Windows).
    """
    from suite2p.parameters import default_settings

    with open(params_path, encoding="utf-8") as f:
        user_params = json.load(f)

    # "comments" is dropped so it is not merged into the suite2p settings.
    user_params.pop("comments", None)

    if "torch_device" not in user_params:
        user_params["torch_device"] = _detect_torch_device()

    merged = _update_two_level_dict(default_settings(), user_params)
    return user_params, merged
63
+
64
+
65
+ def _params_to_json_serializable(params: dict) -> dict:
66
+ result = {}
67
+ for k, v in params.items():
68
+ if isinstance(v, np.ndarray):
69
+ result[k] = v.tolist()
70
+ elif isinstance(v, dict):
71
+ result[k] = _params_to_json_serializable(v)
72
+ else:
73
+ result[k] = v
74
+ return result
75
+
76
+
77
def _get_tif_n_frames(tif_path: Path) -> int:
    """Return how many pages the TIFF file at *tif_path* contains."""
    import tifffile

    with tifffile.TiffFile(tif_path) as handle:
        n_pages = len(handle.pages)
    return n_pages
81
+
82
+
83
+ def _build_sync_indices(
84
+ tif_files: list[Path],
85
+ sync_results: list[dict],
86
+ block_size: int,
87
+ ) -> list[tuple[np.ndarray, np.ndarray]]:
88
+ """Return a list of (global_vol_indices, t_frames) per tif file.
89
+
90
+ global_vol_indices: volume-level indices into the suite2p combined array for
91
+ this tif's contribution (i.e. local volume indices + cumulative offset).
92
+ t_frames: timestamps corresponding to those indices, with cumulative time offset.
93
+ """
94
+ # Raw page counts are used for offsets; block_size converts pages → volumes.
95
+ frame_offsets = [0]
96
+ for tif in tif_files[:-1]:
97
+ frame_offsets.append(frame_offsets[-1] + _get_tif_n_frames(tif))
98
+
99
+ time_offsets = [0.0]
100
+ for stats in sync_results:
101
+ time_offsets.append(time_offsets[-1] + stats['frames_time_idx'].t[-1])
102
+
103
+ result = []
104
+ for tif, stats, frame_offset, time_offset in zip(tif_files, sync_results, frame_offsets, time_offsets):
105
+ frames_time_idx = stats['frames_time_idx']
106
+ local_indices = frames_time_idx.d.astype(int)
107
+
108
+ if block_size > 1:
109
+ # Assign each synced frame to its volume (integer division).
110
+ # Only keep volumes where all block_size frames are present.
111
+ frame_to_volume = local_indices // block_size
112
+ volumes, counts = np.unique(frame_to_volume, return_counts=True)
113
+ complete_volumes = volumes[counts == block_size]
114
+ if len(complete_volumes) < len(volumes):
115
+ n_dropped = len(volumes) - len(complete_volumes)
116
+ print(f" Warning: dropping {n_dropped} incomplete volume(s) "
117
+ f"(fewer than {block_size} synced frames) for {tif.name}")
118
+ local_indices = complete_volumes # now volume indices
119
+ # Find the position in the Tsd of each volume's first frame number.
120
+ # frames_time_idx.d contains frame numbers (with possible gaps), so
121
+ # we cannot use frame number directly as a positional index into .t.
122
+ first_frame_numbers = local_indices * block_size
123
+ pos = np.searchsorted(frames_time_idx.d.astype(int), first_frame_numbers)
124
+ t_frames = frames_time_idx.t[pos]
125
+ vol_offset = frame_offset // block_size
126
+ else:
127
+ t_frames = frames_time_idx.t
128
+ vol_offset = frame_offset
129
+
130
+ result.append((local_indices + vol_offset, t_frames + time_offset))
131
+ return result
132
+
133
+
134
def _compute_F_sub(source_dir: Path, settings: dict) -> bool:
    """Write baseline-subtracted, neuropil-corrected fluorescence to F_sub.npy.

    F_sub = dcnv.preprocess(F - neucoeff * Fneu), with F.npy / Fneu.npy read from
    source_dir and F_sub.npy saved next to them.

    Parameter lookup in the two-level settings dict:
    * neucoeff: settings['extraction:']['neuropil_coefficient'] (the section key
      really contains a trailing colon), else settings['neucoeff'], else 0.7;
    * baseline / win_baseline / sig_baseline / prctile_baseline: the
      'dcnv_preprocess' section, else the flat suite2p key of the same name,
      else 'maximin' / 60.0 / 10.0 / 8.0;
    * fs, batch_size, torch_device: flat keys only (defaults 10.0, 200, 'cpu').

    Returns False without writing anything if F.npy or Fneu.npy is missing,
    and True once F_sub.npy has been saved.
    """
    import torch
    from suite2p.extraction import dcnv

    f_file = source_dir / "F.npy"
    fneu_file = source_dir / "Fneu.npy"
    if not (f_file.exists() and fneu_file.exists()):
        return False

    raw = np.load(f_file)
    neuropil = np.load(fneu_file)

    extraction_section = settings.get('extraction:', {})
    neucoeff = float(extraction_section.get('neuropil_coefficient',
                                            settings.get('neucoeff', 0.7)))
    corrected = raw - neucoeff * neuropil

    dcnv_section = settings.get('dcnv_preprocess', {})

    def lookup(key, fallback):
        # Section value first, then suite2p's flat key, then the fallback.
        return dcnv_section.get(key, settings.get(key, fallback))

    device = torch.device(settings.get('torch_device', 'cpu'))
    F_sub = dcnv.preprocess(
        F=corrected,
        baseline=lookup('baseline', 'maximin'),
        win_baseline=float(lookup('win_baseline', 60.0)),
        sig_baseline=float(lookup('sig_baseline', 10.0)),
        fs=float(settings.get('fs', 10.0)),
        prctile_baseline=float(lookup('prctile_baseline', 8.0)),
        batch_size=int(settings.get('batch_size', 200)),
        device=device,
    )
    np.save(source_dir / "F_sub.npy", F_sub)
    return True
178
+
179
+
180
+ def _sync_source_dir(
181
+ source_dir: Path,
182
+ out_dir: Path,
183
+ sync_indices: list[tuple[np.ndarray, np.ndarray]],
184
+ ) -> bool:
185
+ """Load F/Fneu/spks/F_sub from source_dir and save synced TsdFrames to out_dir.
186
+
187
+ Returns True if any arrays were found and processed.
188
+ """
189
+ import pynapple as nap
190
+
191
+ arrays_to_sync = {}
192
+ for name in ('F', 'Fneu', 'spks', 'F_sub'):
193
+ npy_path = source_dir / f"{name}.npy"
194
+ if npy_path.exists():
195
+ arrays_to_sync[name] = np.load(npy_path)
196
+
197
+ if not arrays_to_sync:
198
+ return False
199
+
200
+ out_dir.mkdir(parents=True, exist_ok=True)
201
+
202
+ # Accumulate selected columns across all tif files; timestamps are already
203
+ # globally offset so we can concatenate directly.
204
+ all_selected = {name: [] for name in arrays_to_sync}
205
+ all_t = []
206
+ for global_indices, t_frames in sync_indices:
207
+ for name, arr in arrays_to_sync.items():
208
+ valid_mask = global_indices < arr.shape[1]
209
+ gi = global_indices[valid_mask]
210
+ all_selected[name].append(arr[:, gi])
211
+ # Use the shortest valid time vector across arrays for this segment
212
+ n_valid = min(
213
+ (global_indices < arr.shape[1]).sum() for arr in arrays_to_sync.values()
214
+ )
215
+ all_t.append(t_frames[:n_valid])
216
+
217
+ t = np.concatenate(all_t)
218
+ for name, chunks in all_selected.items():
219
+ selected = np.concatenate(chunks, axis=1) # (n_cells, total_frames)
220
+ n = min(len(t), selected.shape[1])
221
+ tsd_frame = nap.TsdFrame(t=t[:n], d=selected[:, :n].T, time_units='s')
222
+ out_path = out_dir / f"{name}_sync.npz"
223
+ tsd_frame.save(out_path)
224
+ print(f" Saved {out_path.relative_to(out_dir.parent)} ({tsd_frame.shape})")
225
+
226
+ return True
227
+
228
+
229
def create_synced_outputs(
    tif_files: list[Path],
    sync_results: list[dict],
    suite2p_output_dir: Path,
    behavior_sync_dir: Path,
    block_size: int = 3,
) -> None:
    """Save suite2p traces restricted to behavior-synced frames as pynapple TsdFrames.

    The sync indices are built once from tif_files / sync_results. They are
    then applied to suite2p_output_dir/combined, whose results go directly
    into behavior_sync_dir, e.g. behavior_sync_dir/F_sync.npz. They are also
    applied to every suite2p_output_dir/plane<N> entry (sorted by name),
    whose results go into behavior_sync_dir/plane<N>/, e.g.
    behavior_sync_dir/plane0/F_sync.npz.
    Sources with no arrays are reported and skipped.
    """
    sync_indices = _build_sync_indices(tif_files, sync_results, block_size)

    combined_found = _sync_source_dir(suite2p_output_dir / "combined", behavior_sync_dir, sync_indices)
    if not combined_found:
        print(" No F/Fneu/spks arrays found in suite2p combined output, skipping.")

    for plane_dir in sorted(suite2p_output_dir.glob("plane[0-9]*")):
        if _sync_source_dir(plane_dir, behavior_sync_dir / plane_dir.name, sync_indices):
            continue
        print(f" No arrays found in {plane_dir.name}, skipping.")
257
+
258
+
259
class Suite2PExtractor(SourceExtractor):
    """Source extractor that runs Suite2P (suite2p.run_s2p) on a set of TIFF files.

    Settings are read at construction time from data["params_file"] and merged
    with suite2p's default settings. A relative params_file is resolved
    against data["root_path"], which defaults to ".".
    """

    def __init__(self, data: dict):
        super().__init__(data)
        params_file = Path(data["params_file"])
        if not params_file.is_absolute():
            params_file = Path(data.get("root_path", ".")) / params_file
        # user_params: the params as supplied (plus an auto-detected torch_device
        # if absent); settings: suite2p defaults merged with user_params.
        self.user_params, self.settings = _load_settings(params_file)

    def get_job_subdir(self, job_id: str) -> str:
        # Suite2P writes all outputs directly into results_path, so there is no
        # separate job directory created under job_root_dir. Returning a name
        # that will not be created keeps the CLI copy-back step a no-op.
        return f"s2p-{job_id}"

    def save_reproducibility_info(self, results_path: Path) -> None:
        """Write the supplied and the effective settings as JSON into results_path.

        Creates params_supplied.json (user_params) and params_used.json
        (merged settings).
        """
        with open(results_path / "params_supplied.json", "w") as f:
            json.dump(_params_to_json_serializable(self.user_params), f, indent=2)
        with open(results_path / "params_used.json", "w") as f:
            json.dump(_params_to_json_serializable(self.settings), f, indent=2)

    def run(self, tifs: list[Path], job_root_dir: Path, job_id: str, results_path: Path) -> None:
        """Run suite2p on *tifs*, writing its outputs into *results_path*.

        The tiffs are copied into results_path/"collected_input", because
        suite2p reads one flat data_path folder. A unique scratch directory
        (prefixed "s2p_<job_id>_", created under data["temp_dir"] when set)
        serves as suite2p's fast_disk. Both are removed afterwards, even if
        copying or suite2p fails. If data["do_F_sub"] is truthy, F_sub.npy is
        then computed for results_path/"combined" and every plane* directory.
        job_root_dir is not used.

        Raises:
            ValueError: if two input tiffs share a file name, since they would
                overwrite each other in the flat input folder.
        """
        import suite2p

        # Files are copied by name into one folder; fail loudly rather than
        # silently dropping a recording whose name collides with another.
        first_seen = {}
        for tif in tifs:
            if tif.name in first_seen:
                raise ValueError(
                    f"Duplicate TIFF file name {tif.name!r} "
                    f"({first_seen[tif.name]} and {tif}); suite2p needs unique "
                    "file names in its flat input folder."
                )
            first_seen[tif.name] = tif

        collected_folder = results_path / "collected_input"
        # This method owns the folder (it is deleted in the finally block);
        # remove leftovers of an interrupted earlier run so suite2p does not
        # pick up stale tiffs.
        if collected_folder.exists():
            shutil.rmtree(collected_folder)

        # Create a unique scratch directory for suite2p's binary files.
        # Parent is taken from data["temp_dir"] (which the CLI may override with
        # --working-dir before calling run()).
        temp_dir_parent = self.data.get("temp_dir")
        if temp_dir_parent is not None:
            fast_disk_parent = Path(temp_dir_parent)
            fast_disk_parent.mkdir(parents=True, exist_ok=True)
        else:
            fast_disk_parent = None
        fast_disk = Path(tempfile.mkdtemp(prefix=f"s2p_{job_id}_", dir=fast_disk_parent))

        try:
            # Copying happens inside the try so a failed copy is cleaned up too.
            collected_folder.mkdir(parents=True, exist_ok=True)
            for tif in tifs:
                dst = collected_folder / tif.name
                shutil.copy2(tif, dst)
                print(f" Collected {tif.name} -> {dst}")

            db = {
                "data_path": [str(collected_folder)],
                "fast_disk": str(fast_disk),
                "delete_bin": False,
                "move_bin": True,
                "save_folder": str(results_path),
            }
            # Mirror these acquisition parameters from settings into db when set.
            for key in ('fs', 'tau', 'nplanes', 'nchannels', 'functional_chan',
                        'force_sktiff', 'ignore_flyback', 'keep_movie_raw'):
                if key in self.settings:
                    db[key] = self.settings[key]

            suite2p.run_s2p(db, self.settings)
        finally:
            if collected_folder.exists():
                shutil.rmtree(collected_folder)
                print(f" Cleaned up collected input folder: {collected_folder}")
            if fast_disk.exists():
                shutil.rmtree(fast_disk)
                print(f" Cleaned up fast_disk scratch directory: {fast_disk}")

        if self.data.get('do_F_sub', False):
            print("\nComputing F_sub (baseline-subtracted neuropil-corrected fluorescence)...")
            for sub_dir in [results_path / "combined"] + sorted(results_path.glob("plane[0-9]*")):
                if _compute_F_sub(sub_dir, self.settings):
                    print(f" Saved F_sub.npy in {sub_dir.name}/")

    def create_synced_outputs(
        self,
        tif_files: list[Path],
        sync_results: list[dict],
        results_path: Path,
        behavior_sync_dir: Path,
        block_size: int,
    ) -> None:
        """Save behavior-synced traces for results_path/combined and each plane dir."""
        # Suite2P saves combined outputs directly under results_path/combined/
        create_synced_outputs(
            tif_files, sync_results, results_path, behavior_sync_dir, block_size
        )
@@ -0,0 +1,143 @@
1
+ """Suite3D source extraction algorithm."""
2
+ import json
3
+ from pathlib import Path
4
+
5
+ import numpy as np
6
+
7
+ from .base import SourceExtractor
8
+
9
+
10
+ def _load_params(params_path: Path) -> dict:
11
+ with open(params_path) as f:
12
+ params = json.load(f)
13
+ params.pop("comments", None)
14
+ if "planes" in params:
15
+ params["planes"] = np.array(params["planes"])
16
+ if "pc_size" in params:
17
+ params["pc_size"] = np.array(params["pc_size"])
18
+ return params
19
+
20
+
21
+ def _params_to_json_serializable(params: dict) -> dict:
22
+ return {k: v.tolist() if isinstance(v, np.ndarray) else v for k, v in params.items()}
23
+
24
+
25
def _get_tif_n_frames(tif_path: Path) -> int:
    """Return how many pages the TIFF file at *tif_path* contains."""
    import tifffile

    with tifffile.TiffFile(tif_path) as handle:
        n_pages = len(handle.pages)
    return n_pages
29
+
30
+
31
def create_synced_outputs(
    tif_files: list[Path],
    sync_results: list[dict],
    rois_dir: Path,
    behavior_sync_dir: Path,
    block_size: int = 3,
) -> None:
    """Select suite3d traces by synced frame indices and save as pynapple TsdFrames.

    Loads F, Fneu and spks from rois_dir. For each TIF, the synchronized frame
    numbers (frames_time_idx.d) are shifted by the cumulative page count of the
    preceding TIFs. When block_size > 1, frames are first grouped into volumes
    and only complete volumes are kept, and the page offset is integer-divided
    by block_size. Timestamps are shifted by the running sum of each preceding
    TIF's last synced timestamp. The selections from all TIFs are concatenated
    and saved once per array as F_sync.npz, Fneu_sync.npz and spks_sync.npz in
    behavior_sync_dir, which is created if needed. Indices beyond the shortest
    array's frame count are dropped together with their timestamps.

    Side effect: when arrays are found, each sync_results[i] gets a
    'time_offsets' entry holding the time offset applied to that TIF.
    """
    arrays_to_sync = {}
    for name in ('F', 'Fneu', 'spks'):
        npy_path = rois_dir / f"{name}.npy"
        if npy_path.exists():
            arrays_to_sync[name] = np.load(npy_path)

    if not arrays_to_sync:
        print(" No F/Fneu/spks arrays found in rois directory, skipping synced outputs.")
        return

    # Imported only when there is something to save.
    import pynapple as nap

    frame_offsets = [0]
    for tif in tif_files[:-1]:
        frame_offsets.append(frame_offsets[-1] + _get_tif_n_frames(tif))

    time_offsets = [0]
    for stats in sync_results:
        time_offsets.append(time_offsets[-1] + stats['frames_time_idx'].t[-1])

    for i, stats in enumerate(sync_results):
        stats['time_offsets'] = time_offsets[i]

    # Common frame limit so every array is indexed, and timestamped, identically.
    n_frames = min(arr.shape[1] for arr in arrays_to_sync.values())

    index_chunks = []
    t_chunks = []
    for tif, stats, frame_offset, time_offset in zip(tif_files, sync_results, frame_offsets, time_offsets):
        frames_time_idx = stats['frames_time_idx']
        local_indices = frames_time_idx.d.astype(int)

        if block_size > 1:
            # Assign each synced frame to its volume (integer division).
            # Only keep volumes where all block_size frames are present.
            frame_to_volume = local_indices // block_size
            volumes, counts = np.unique(frame_to_volume, return_counts=True)
            complete_volumes = volumes[counts == block_size]
            if len(complete_volumes) < len(volumes):
                n_dropped = len(volumes) - len(complete_volumes)
                print(f" Warning: dropping {n_dropped} incomplete volume(s) "
                      f"(fewer than {block_size} synced frames) for {tif.name}")
            local_indices = complete_volumes  # now volume indices
            # frames_time_idx.d contains frame numbers (with possible gaps), so
            # look up each volume's first frame position instead of indexing .t.
            first_frame_numbers = local_indices * block_size
            pos = np.searchsorted(frames_time_idx.d.astype(int), first_frame_numbers)
            t_frames = frames_time_idx.t[pos]
            vol_offset = frame_offset // block_size
        else:
            t_frames = frames_time_idx.t
            vol_offset = frame_offset

        # The arrays cover all TIFs back to back, so shift into the global index.
        global_indices = local_indices + vol_offset
        keep = global_indices < n_frames
        index_chunks.append(global_indices[keep])
        t_chunks.append(np.asarray(t_frames)[keep] + time_offset)

    all_indices = np.concatenate(index_chunks) if index_chunks else np.empty(0, dtype=int)
    t = np.concatenate(t_chunks) if t_chunks else np.empty(0)

    behavior_sync_dir.mkdir(parents=True, exist_ok=True)
    # Save once per array, after all TIFs have been accumulated.
    for name, arr in arrays_to_sync.items():
        selected = arr[:, all_indices]  # (n_cells, n_selected_frames)
        tsd_frame = nap.TsdFrame(t=t, d=selected.T, time_units='s')
        out_path = behavior_sync_dir / f"{name}_sync.npz"
        tsd_frame.save(out_path)
        print(f" Saved {out_path.name} ({tsd_frame.shape})")
102
+
103
+
104
class Suite3DExtractor(SourceExtractor):
    """Source extractor that runs the suite3d Job pipeline on a set of TIFF files.

    Parameters are loaded at construction time from data["params_file"]. A
    relative path is resolved against data["root_path"], which defaults to ".".
    """

    def __init__(self, data: dict):
        super().__init__(data)
        params_path = Path(data["params_file"])
        if not params_path.is_absolute():
            params_path = Path(data.get("root_path", ".")) / params_path
        self.params = _load_params(params_path)

    def get_job_subdir(self, job_id: str) -> str:
        """Return the per-job subdirectory name, "s3d-<job_id>".

        NOTE(review): presumably matches the directory suite3d's Job creates
        under job_root_dir; confirm.
        """
        return f"s3d-{job_id}"

    def save_reproducibility_info(self, results_path: Path) -> None:
        """Write the loaded parameters (numpy values as lists) to results_path/params_used.json."""
        serializable = _params_to_json_serializable(self.params)
        out_file = results_path / "params_used.json"
        with open(out_file, "w") as handle:
            json.dump(serializable, handle, indent=2)

    def run(self, tifs: list[Path], job_root_dir: Path, job_id: str, results_path: Path) -> None:
        """Create a suite3d Job under job_root_dir, run every stage, and export to results_path.

        The stages run in order: init pass, registration, correlation map,
        ROI segmentation, neuropil masks, extraction/deconvolution. Results
        are exported with result_dir_name="rois".
        """
        from suite3d.job import Job

        job = Job(job_root_dir, job_id, tifs=tifs,
                  params=self.params, create=True, overwrite=True, verbosity=3)
        # Re-apply the loaded params on top of whatever the Job set up itself.
        job.params.update(self.params)
        pipeline = (
            job.run_init_pass,
            job.register,
            job.calculate_corr_map,
            job.segment_rois,
            job.compute_npil_masks,
            job.extract_and_deconvolve,
        )
        for stage in pipeline:
            stage()
        job.export_results(results_path, result_dir_name="rois")

    def create_synced_outputs(
        self,
        tif_files: list[Path],
        sync_results: list[dict],
        results_path: Path,
        behavior_sync_dir: Path,
        block_size: int,
    ) -> None:
        """Save behavior-synced traces read from results_path/"s3d-results-<results_path.name>".

        NOTE(review): run() exports with result_dir_name="rois"; confirm this
        directory name matches what Job.export_results actually writes.
        """
        create_synced_outputs(
            tif_files,
            sync_results,
            results_path / f"s3d-results-{results_path.name}",
            behavior_sync_dir,
            block_size,
        )