lbm_suite2p_python 3.2.0__tar.gz → 3.2.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. {lbm_suite2p_python-3.2.0/lbm_suite2p_python.egg-info → lbm_suite2p_python-3.2.2}/PKG-INFO +4 -2
  2. {lbm_suite2p_python-3.2.0 → lbm_suite2p_python-3.2.2}/lbm_suite2p_python/cli.py +143 -18
  3. {lbm_suite2p_python-3.2.0 → lbm_suite2p_python-3.2.2}/lbm_suite2p_python/run_lsp.py +115 -36
  4. lbm_suite2p_python-3.2.2/lbm_suite2p_python/utils.py +229 -0
  5. {lbm_suite2p_python-3.2.0 → lbm_suite2p_python-3.2.2/lbm_suite2p_python.egg-info}/PKG-INFO +4 -2
  6. {lbm_suite2p_python-3.2.0 → lbm_suite2p_python-3.2.2}/lbm_suite2p_python.egg-info/requires.txt +3 -1
  7. {lbm_suite2p_python-3.2.0 → lbm_suite2p_python-3.2.2}/pyproject.toml +13 -2
  8. lbm_suite2p_python-3.2.0/lbm_suite2p_python/utils.py +0 -144
  9. {lbm_suite2p_python-3.2.0 → lbm_suite2p_python-3.2.2}/LICENSE.md +0 -0
  10. {lbm_suite2p_python-3.2.0 → lbm_suite2p_python-3.2.2}/MANIFEST.in +0 -0
  11. {lbm_suite2p_python-3.2.0 → lbm_suite2p_python-3.2.2}/README.md +0 -0
  12. {lbm_suite2p_python-3.2.0 → lbm_suite2p_python-3.2.2}/lbm_suite2p_python/__init__.py +0 -0
  13. {lbm_suite2p_python-3.2.0 → lbm_suite2p_python-3.2.2}/lbm_suite2p_python/__main__.py +0 -0
  14. {lbm_suite2p_python-3.2.0 → lbm_suite2p_python-3.2.2}/lbm_suite2p_python/_benchmarking.py +0 -0
  15. {lbm_suite2p_python-3.2.0 → lbm_suite2p_python-3.2.2}/lbm_suite2p_python/cellpose.py +0 -0
  16. {lbm_suite2p_python-3.2.0 → lbm_suite2p_python-3.2.2}/lbm_suite2p_python/conversion.py +0 -0
  17. {lbm_suite2p_python-3.2.0 → lbm_suite2p_python-3.2.2}/lbm_suite2p_python/db_settings.py +0 -0
  18. {lbm_suite2p_python-3.2.0 → lbm_suite2p_python-3.2.2}/lbm_suite2p_python/default_ops.py +0 -0
  19. {lbm_suite2p_python-3.2.0 → lbm_suite2p_python-3.2.2}/lbm_suite2p_python/grid_search.py +0 -0
  20. {lbm_suite2p_python-3.2.0 → lbm_suite2p_python-3.2.2}/lbm_suite2p_python/gui.py +0 -0
  21. {lbm_suite2p_python-3.2.0 → lbm_suite2p_python-3.2.2}/lbm_suite2p_python/merging.py +0 -0
  22. {lbm_suite2p_python-3.2.0 → lbm_suite2p_python-3.2.2}/lbm_suite2p_python/postprocessing.py +0 -0
  23. {lbm_suite2p_python-3.2.0 → lbm_suite2p_python-3.2.2}/lbm_suite2p_python/volume.py +0 -0
  24. {lbm_suite2p_python-3.2.0 → lbm_suite2p_python-3.2.2}/lbm_suite2p_python/zplane.py +0 -0
  25. {lbm_suite2p_python-3.2.0 → lbm_suite2p_python-3.2.2}/lbm_suite2p_python.egg-info/SOURCES.txt +0 -0
  26. {lbm_suite2p_python-3.2.0 → lbm_suite2p_python-3.2.2}/lbm_suite2p_python.egg-info/dependency_links.txt +0 -0
  27. {lbm_suite2p_python-3.2.0 → lbm_suite2p_python-3.2.2}/lbm_suite2p_python.egg-info/entry_points.txt +0 -0
  28. {lbm_suite2p_python-3.2.0 → lbm_suite2p_python-3.2.2}/lbm_suite2p_python.egg-info/top_level.txt +0 -0
  29. {lbm_suite2p_python-3.2.0 → lbm_suite2p_python-3.2.2}/setup.cfg +0 -0
  30. {lbm_suite2p_python-3.2.0 → lbm_suite2p_python-3.2.2}/tests/test_frame_count_aliases.py +0 -0
  31. {lbm_suite2p_python-3.2.0 → lbm_suite2p_python-3.2.2}/tests/test_pipeline_parameters.py +0 -0
  32. {lbm_suite2p_python-3.2.0 → lbm_suite2p_python-3.2.2}/tests/test_refactored_pipeline.py +0 -0
  33. {lbm_suite2p_python-3.2.0 → lbm_suite2p_python-3.2.2}/tests/test_run_volume.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: lbm_suite2p_python
3
- Version: 3.2.0
3
+ Version: 3.2.2
4
4
  Summary: Calcium Imaging Pipeline built with Suite2p, Cellpose and Rastermap
5
5
  License-Expression: BSD-3-Clause
6
6
  Project-URL: homepage, https://github.com/MillerBrainObservatory/LBM-Suite2p-Python
@@ -11,9 +11,11 @@ Classifier: Programming Language :: Python :: 3 :: Only
11
11
  Requires-Python: <3.14,>=3.12.7
12
12
  Description-Content-Type: text/markdown
13
13
  License-File: LICENSE.md
14
- Requires-Dist: mbo_utilities>=3.2.1
14
+ Requires-Dist: mbo_utilities>=3.2.7
15
15
  Requires-Dist: suite2p>=1.0.0.1
16
16
  Requires-Dist: setuptools<81
17
+ Requires-Dist: torch
18
+ Requires-Dist: torchvision
17
19
  Provides-Extra: rastermap
18
20
  Requires-Dist: rastermap; extra == "rastermap"
19
21
  Provides-Extra: cellpose
@@ -23,6 +23,7 @@ Examples:
23
23
  """
24
24
 
25
25
  import argparse
26
+ import json
26
27
  import sys
27
28
  from pathlib import Path
28
29
  from typing import Any
@@ -67,6 +68,7 @@ def _get_ops_help() -> dict[str, str]:
67
68
  "cellprob_threshold": "cellpose cell probability threshold (lower = more cells)",
68
69
  "flow_threshold": "cellpose flow error threshold",
69
70
  "anatomical_only": "cellpose detection mode: 0=off, 1=max_proj, 2=mean, 3=enhanced, 4=max",
71
+ "algorithm": "detection algorithm: sparsery, sourcery, or cellpose",
70
72
  "pretrained_model": "cellpose model name (e.g., cpsam, cyto2, nuclei)",
71
73
  "do_registration": "whether to run motion correction",
72
74
  "nonrigid": "use nonrigid (piecewise) registration",
@@ -103,6 +105,10 @@ Examples:
103
105
  lsp data/ output/ --planes 1 2 3 # specific planes (1-indexed)
104
106
  lsp data/ output/ --num-timepoints 500 # quick test with 500 frames
105
107
  lsp data/ output/ --diameter 8 # custom cell diameter
108
+ lsp data/ output/ --tau 1.3 --algorithm cellpose # decay + detection algorithm
109
+ lsp data/ output/ --fix-phase --phasecorr-method median # reader phase correction
110
+ lsp data/ output/ --reader-kwargs '{"roi": 2}' # arbitrary imread kwargs
111
+ lsp data/ output/ --ops-file my_ops.json # load saved ops (CLI flags override)
106
112
  lsp --list-ops # show all suite2p parameters
107
113
  """,
108
114
  )
@@ -134,13 +140,21 @@ Examples:
134
140
  "--planes", nargs="*", type=int, dest="planes",
135
141
  help="z-planes to process (1-indexed, e.g., --planes 1 2 3)"
136
142
  )
143
+ pipeline.add_argument(
144
+ "--timepoints", nargs="*", type=int, dest="timepoints",
145
+ help="timepoints to process (1-indexed, e.g., --timepoints 1 50 100)"
146
+ )
137
147
  pipeline.add_argument(
138
148
  "--roi-mode", "--roi", type=int, dest="roi_mode",
139
149
  help="ROI mode: None=stitch, 0=split all, N=specific ROI"
140
150
  )
141
151
  pipeline.add_argument(
142
152
  "--num-timepoints", "--frames", type=int, dest="num_timepoints",
143
- help="number of frames/timepoints to process (for quick testing)"
153
+ help="number of timepoints to process (first N, for quick testing)"
154
+ )
155
+ pipeline.add_argument(
156
+ "--num-zplanes", type=int, dest="num_zplanes",
157
+ help="number of z-planes to process (first N)"
144
158
  )
145
159
  pipeline.add_argument(
146
160
  "--overwrite", action="store_true",
@@ -237,15 +251,57 @@ Examples:
237
251
  help="maximum cell diameter in pixels"
238
252
  )
239
253
 
240
- # reader options (for raw data)
254
+ # reader options (forwarded to mbo_utilities.imread)
241
255
  reader = parser.add_argument_group("reader options (raw scanimage data)")
242
256
  reader.add_argument(
243
- "--fix-phase", action="store_true",
244
- help="apply bidirectional phase correction"
257
+ "--fix-phase", action=argparse.BooleanOptionalAction, default=None,
258
+ help="bidirectional phase correction (reader default: on)"
259
+ )
260
+ reader.add_argument(
261
+ "--use-fft", action=argparse.BooleanOptionalAction, default=None,
262
+ help="FFT-based subpixel phase correction"
245
263
  )
246
264
  reader.add_argument(
247
- "--use-fft", action="store_true",
248
- help="use FFT for subpixel phase correction"
265
+ "--phasecorr-method", choices=["mean", "median", "max"], default=None,
266
+ help="phase-correction reduction method (default: mean)"
267
+ )
268
+ reader.add_argument(
269
+ "--channel", type=int, default=None,
270
+ help="zero-based color channel to read (multi-channel sources)"
271
+ )
272
+ reader.add_argument(
273
+ "--reader-kwargs", type=str, default=None, metavar="JSON",
274
+ help='extra imread kwargs as JSON, e.g. \'{"roi": 2}\''
275
+ )
276
+
277
+ # writer options (forwarded to the binary/zarr writer)
278
+ writer = parser.add_argument_group("writer options")
279
+ writer.add_argument(
280
+ "--writer-kwargs", type=str, default=None, metavar="JSON",
281
+ help='extra writer kwargs as JSON, e.g. \'{"target_chunk_mb": 200}\''
282
+ )
283
+
284
+ # rastermap options
285
+ rmap = parser.add_argument_group("rastermap options")
286
+ rmap.add_argument(
287
+ "--rastermap", action="store_true",
288
+ help="enable rastermap (planar + volumetric) with default settings"
289
+ )
290
+ rmap.add_argument(
291
+ "--rastermap-kwargs", type=str, default=None, metavar="JSON",
292
+ help='rastermap config JSON with "planar"/"volumetric" keys, '
293
+ 'e.g. \'{"planar": {"n_clusters": 50}}\''
294
+ )
295
+
296
+ # advanced
297
+ advanced = parser.add_argument_group("advanced")
298
+ advanced.add_argument(
299
+ "--ops-file", type=str, default=None,
300
+ help="base ops from a .npy/.json file or suite2p dir; CLI flags override"
301
+ )
302
+ advanced.add_argument(
303
+ "--replot", action=argparse.BooleanOptionalAction, default=True,
304
+ help="regenerate per-plane figures (default: on)"
249
305
  )
250
306
 
251
307
  # dynamically add all ops parameters
@@ -316,8 +372,9 @@ def list_ops():
316
372
  "Main Settings": ["nplanes", "nchannels", "fs", "tau", "frames_include"],
317
373
  "Registration": ["do_registration", "nonrigid", "batch_size", "maxregshift",
318
374
  "smooth_sigma", "nimg_init", "subpixel"],
319
- "Cell Detection": ["roidetect", "sparse_mode", "spatial_scale", "threshold_scaling",
320
- "max_overlap", "connected", "nbinned", "max_iterations"],
375
+ "Cell Detection": ["roidetect", "algorithm", "sparse_mode", "spatial_scale",
376
+ "threshold_scaling", "max_overlap", "connected", "nbinned",
377
+ "max_iterations"],
321
378
  "Cellpose": ["anatomical_only", "diameter", "cellprob_threshold", "flow_threshold",
322
379
  "pretrained_model", "spatial_hp_cp"],
323
380
  "Signal Extraction": ["neuropil_extract", "neucoeff", "spikedetect",
@@ -397,6 +454,65 @@ def build_cell_filters(args) -> list | None:
397
454
  return filters if filters else None
398
455
 
399
456
 
457
+ def _parse_json_arg(flag: str, value: str) -> dict:
458
+ """parse a JSON object from a CLI flag value; exit cleanly on error."""
459
+ try:
460
+ parsed = json.loads(value)
461
+ except json.JSONDecodeError as e:
462
+ raise SystemExit(f"Error: {flag} is not valid JSON: {e}")
463
+ if not isinstance(parsed, dict):
464
+ raise SystemExit(f"Error: {flag} must be a JSON object")
465
+ return parsed
466
+
467
+
468
+ def _load_ops_file(path: str) -> dict:
469
+ """load a base ops dict from a .json/.npy file or a suite2p directory."""
470
+ from lbm_suite2p_python import load_ops
471
+
472
+ p = Path(path).expanduser()
473
+ if p.suffix.lower() == ".json":
474
+ if not p.exists():
475
+ raise SystemExit(f"Error: --ops-file not found: {p}")
476
+ with open(p, "r", encoding="utf-8") as fh:
477
+ data = json.load(fh)
478
+ if not isinstance(data, dict):
479
+ raise SystemExit("Error: --ops-file JSON must be an object")
480
+ return data
481
+ return load_ops(p)
482
+
483
+
484
+ def build_reader_kwargs(args) -> dict | None:
485
+ """build imread kwargs from CLI reader args (None if empty)."""
486
+ kw = {}
487
+ if args.fix_phase is not None:
488
+ kw["fix_phase"] = args.fix_phase
489
+ if args.use_fft is not None:
490
+ kw["use_fft"] = args.use_fft
491
+ if args.phasecorr_method is not None:
492
+ kw["phasecorr_method"] = args.phasecorr_method
493
+ if args.channel is not None:
494
+ kw["channel"] = args.channel
495
+ if args.reader_kwargs:
496
+ kw.update(_parse_json_arg("--reader-kwargs", args.reader_kwargs))
497
+ return kw or None
498
+
499
+
500
+ def build_writer_kwargs(args) -> dict | None:
501
+ """build writer kwargs from CLI writer args (None if empty)."""
502
+ if args.writer_kwargs:
503
+ return _parse_json_arg("--writer-kwargs", args.writer_kwargs)
504
+ return None
505
+
506
+
507
+ def build_rastermap_kwargs(args) -> dict | None:
508
+ """build rastermap_kwargs from CLI args (None if disabled)."""
509
+ if args.rastermap_kwargs:
510
+ return _parse_json_arg("--rastermap-kwargs", args.rastermap_kwargs)
511
+ if args.rastermap:
512
+ return {"planar": {}, "volumetric": {}}
513
+ return None
514
+
515
+
400
516
  def build_ops(args, base_ops: dict) -> dict:
401
517
  """build ops dict from CLI args, overriding base_ops."""
402
518
  from lbm_suite2p_python.default_ops import s2p_ops
@@ -504,6 +620,13 @@ def main():
504
620
 
505
621
  output_path.mkdir(parents=True, exist_ok=True)
506
622
 
623
+ # parse config/advanced args up front so invalid JSON or a missing ops
624
+ # file fails before any logging or processing work begins
625
+ base_extra_ops = _load_ops_file(args.ops_file) if args.ops_file else None
626
+ reader_kwargs = build_reader_kwargs(args)
627
+ writer_kwargs = build_writer_kwargs(args)
628
+ rastermap_kwargs = build_rastermap_kwargs(args)
629
+
507
630
  log_path = output_path / "log.txt"
508
631
  log_file = open(log_path, "w", encoding="utf-8", buffering=1)
509
632
  _orig_stdout, _orig_stderr = sys.stdout, sys.stderr
@@ -530,17 +653,10 @@ def main():
530
653
  print(f"Input: {input_path}")
531
654
  print(f"Output: {output_path}")
532
655
 
533
- # build ops
534
- base_ops = lsp.default_ops()
656
+ # build ops (optionally starting from a user-supplied ops file)
657
+ base_ops = lsp.default_ops(ops=base_extra_ops) if base_extra_ops else lsp.default_ops()
535
658
  ops = build_ops(args, base_ops)
536
659
 
537
- # build reader kwargs
538
- reader_kwargs = {}
539
- if args.fix_phase:
540
- reader_kwargs["fix_phase"] = True
541
- if args.use_fft:
542
- reader_kwargs["use_fft"] = True
543
-
544
660
  # build cell filters
545
661
  cell_filters = build_cell_filters(args)
546
662
 
@@ -554,6 +670,10 @@ def main():
554
670
  print(f" Cellpose model: {ops.get('pretrained_model', 'cpsam')}")
555
671
  if cell_filters:
556
672
  print(f" Cell filters: {cell_filters}")
673
+ if reader_kwargs:
674
+ print(f" Reader: {reader_kwargs}")
675
+ if rastermap_kwargs:
676
+ print(f" Rastermap: {sorted(rastermap_kwargs)}")
557
677
 
558
678
  print(f"\n{'='*60}\n")
559
679
 
@@ -566,8 +686,10 @@ def main():
566
686
  save_path=output_path,
567
687
  ops=ops,
568
688
  planes=args.planes,
689
+ timepoints=args.timepoints,
569
690
  roi_mode=args.roi_mode,
570
691
  num_timepoints=args.num_timepoints,
692
+ num_zplanes=args.num_zplanes,
571
693
  keep_reg=args.keep_reg,
572
694
  keep_raw=args.keep_raw,
573
695
  force_reg=args.force_reg or args.overwrite,
@@ -580,7 +702,10 @@ def main():
580
702
  cell_filters=cell_filters,
581
703
  accept_all_cells=args.accept_all_cells,
582
704
  save_json=args.save_json,
583
- reader_kwargs=reader_kwargs if reader_kwargs else None,
705
+ reader_kwargs=reader_kwargs,
706
+ writer_kwargs=writer_kwargs,
707
+ rastermap_kwargs=rastermap_kwargs,
708
+ replot=args.replot,
584
709
  workers=args.workers,
585
710
  skip_volumetric=args.skip_volumetric,
586
711
  threads_per_worker=args.threads_per_worker,
@@ -446,14 +446,15 @@ def _is_valid_torch_checkpoint(path) -> bool:
446
446
  def _prewarm_cellpose_model(ops) -> None:
447
447
  """Download the cellpose model once, in the parent, before workers fan out.
448
448
 
449
- cellpose's cache_CPSAM_model_path downloads to a temp file then renames
450
- with no cross-process lock. Multiple workers hitting an empty cache at once
449
+ cellpose's model cache downloads to a temp file then renames with no
450
+ cross-process lock. Multiple workers hitting an empty cache at once
451
451
  race: one wins the rename, the rest fail (Windows WinError 32/183) or read a
452
452
  half-written file (PytorchStreamReader miniz error). Warming here serializes
453
453
  the download so workers only ever read a complete file. A corrupt leftover
454
454
  from a prior failed run is removed and re-downloaded.
455
455
  """
456
- if not (ops.get("roidetect", True) and ops.get("anatomical_only", 0) > 0):
456
+ if not (ops.get("roidetect", True)
457
+ and (ops.get("anatomical_only", 0) > 0 or ops.get("algorithm") == "cellpose")):
457
458
  return
458
459
  try:
459
460
  from cellpose import models as cp_models
@@ -467,7 +468,11 @@ def _prewarm_cellpose_model(ops) -> None:
467
468
  except OSError:
468
469
  pass
469
470
  try:
470
- cp_models.cache_CPSAM_model_path()
471
+ # cellpose 4.x renamed cache_CPSAM_model_path() -> cache_model_path(backbone)
472
+ if hasattr(cp_models, "cache_model_path"):
473
+ cp_models.cache_model_path("cpsam")
474
+ else:
475
+ cp_models.cache_CPSAM_model_path()
471
476
  except Exception as exc:
472
477
  print(
473
478
  f"Warning: could not pre-download cellpose model ({exc}); "
@@ -780,6 +785,38 @@ def _prepare_plane_ops(*, base_ops, plane_idx, num_planes, input_arr,
780
785
  return current_ops
781
786
 
782
787
 
788
+ def _resolve_timepoints(timepoints=None, frames=None, frame_indices=None):
789
+ """Resolve the canonical 1-based ``timepoints`` selection.
790
+
791
+ ``frames`` (1-based) and ``frame_indices`` (0-based) are deprecated
792
+ aliases and emit a DeprecationWarning. Returns a 1-based list, or
793
+ None for all timepoints.
794
+ """
795
+ import warnings
796
+
797
+ if frames is not None:
798
+ warnings.warn(
799
+ "'frames' is deprecated, use 'timepoints' (1-based)",
800
+ DeprecationWarning,
801
+ stacklevel=3,
802
+ )
803
+ if timepoints is None:
804
+ timepoints = frames
805
+ if frame_indices is not None:
806
+ warnings.warn(
807
+ "'frame_indices' is deprecated, use 'timepoints' (1-based)",
808
+ DeprecationWarning,
809
+ stacklevel=3,
810
+ )
811
+ if timepoints is None:
812
+ timepoints = [int(i) + 1 for i in frame_indices]
813
+ if timepoints is None:
814
+ return None
815
+ if isinstance(timepoints, (int, np.integer)):
816
+ return [int(timepoints)]
817
+ return [int(t) for t in timepoints]
818
+
819
+
783
820
  def pipeline(
784
821
  input_data,
785
822
  save_path: str | Path = None,
@@ -793,7 +830,9 @@ def pipeline(
793
830
  force_reg: bool = False,
794
831
  force_detect: bool = False,
795
832
  replot: bool = True,
833
+ timepoints: list | int | None = None,
796
834
  num_timepoints: int = None,
835
+ num_zplanes: int = None,
797
836
  frame_indices: list | None = None,
798
837
  dff_window_size: int = None,
799
838
  dff_percentile: int = 20,
@@ -859,16 +898,21 @@ def pipeline(
859
898
  Regenerate per-plane figures. Set False to skip per-plane figure
860
899
  regeneration (e.g. the volumetric aggregate over already-plotted
861
900
  planes); suite2p and the volumetric plots are unaffected.
862
- num_timepoints : int, optional
863
- Limit processing to first N frames (truncation only). For an
864
- explicit set of frames or a strided selection, use
865
- ``frame_indices`` instead.
866
- frame_indices : list[int], optional
867
- Explicit 0-based frame indices to process. Supports stride
868
- (e.g. ``list(range(0, 1574, 2))`` for every other frame).
901
+ timepoints : list[int] or int, optional
902
+ Explicit 1-based timepoints to process. Supports stride
903
+ (e.g. ``list(range(1, 1575, 2))`` for every other timepoint).
869
904
  When provided, the implicit stride is used by `OutputMetadata`
870
905
  to reactively scale `fs` (e.g. stride of 2 → `fs / 2` in the
871
906
  output ops.npy). Takes precedence over ``num_timepoints``.
907
+ num_timepoints : int, optional
908
+ Limit processing to first N timepoints (truncation only). For an
909
+ explicit set or a strided selection, use ``timepoints`` instead.
910
+ num_zplanes : int, optional
911
+ Limit processing to the first N z-planes. Shortcut for
912
+ ``planes=[1..N]``; ignored when ``planes`` is given.
913
+ frame_indices : list[int], optional
914
+ Deprecated alias for ``timepoints`` (0-based). Emits a
915
+ DeprecationWarning.
872
916
  dff_window_size : int, optional
873
917
  Window size for rolling percentile dF/F baseline (frames).
874
918
  If None, auto-calculated as ~10 * tau * fs.
@@ -981,7 +1025,15 @@ def pipeline(
981
1025
  DeprecationWarning,
982
1026
  stacklevel=2,
983
1027
  )
984
- num_timepoints = num_frames
1028
+ if num_timepoints is None:
1029
+ num_timepoints = num_frames
1030
+
1031
+ # canonical 1-based timepoint selection (frames/frame_indices deprecated).
1032
+ timepoints = _resolve_timepoints(timepoints, kwargs.pop("frames", None), frame_indices)
1033
+ frame_indices = None
1034
+ # num_zplanes is a count shortcut for planes=[1..N].
1035
+ if num_zplanes is not None and planes is None:
1036
+ planes = list(range(1, int(num_zplanes) + 1))
985
1037
 
986
1038
  # flatten (db, settings) into ops so downstream run_volume / run_plane
987
1039
  # don't each need to forward the pair. explicit ops keys still win.
@@ -991,14 +1043,10 @@ def pipeline(
991
1043
 
992
1044
  reader_kwargs = reader_kwargs or {}
993
1045
  writer_kwargs = writer_kwargs or {}
1046
+ # num_timepoints truncation reaches the writer; an explicit `timepoints`
1047
+ # selection is forwarded as a param and rebuilt by run_plane.
994
1048
  if num_timepoints is not None:
995
- writer_kwargs["num_frames"] = num_timepoints
996
-
997
- # 1-based frame numbers.
998
- if frame_indices is not None:
999
- writer_kwargs["frames"] = [int(i) + 1 for i in frame_indices]
1000
- # don't double-pass num_frames; len(frame_indices) is implicit
1001
- writer_kwargs.pop("num_frames", None)
1049
+ writer_kwargs["num_timepoints"] = num_timepoints
1002
1050
 
1003
1051
  # Always load array to check dimensions and ensure downstream functions have the array shape
1004
1052
  # If input is already array, this is fast. If path or list of paths, it loads lazy array.
@@ -1034,7 +1082,7 @@ def pipeline(
1034
1082
  force_reg=force_reg,
1035
1083
  force_detect=force_detect,
1036
1084
  replot=replot,
1037
- frame_indices=frame_indices,
1085
+ timepoints=timepoints,
1038
1086
  dff_window_size=dff_window_size,
1039
1087
  dff_percentile=dff_percentile,
1040
1088
  dff_smooth_window=dff_smooth_window,
@@ -1074,7 +1122,7 @@ def pipeline(
1074
1122
  force_reg=force_reg,
1075
1123
  force_detect=force_detect,
1076
1124
  replot=replot,
1077
- frame_indices=frame_indices,
1125
+ timepoints=timepoints,
1078
1126
  dff_window_size=dff_window_size,
1079
1127
  dff_percentile=dff_percentile,
1080
1128
  dff_smooth_window=dff_smooth_window,
@@ -1246,6 +1294,9 @@ def run_volume(
1246
1294
  force_reg: bool = False,
1247
1295
  force_detect: bool = False,
1248
1296
  replot: bool = True,
1297
+ timepoints: list | int | None = None,
1298
+ num_timepoints: int = None,
1299
+ num_zplanes: int = None,
1249
1300
  frame_indices: list | None = None,
1250
1301
  dff_window_size: int = None,
1251
1302
  dff_percentile: int = 20,
@@ -1337,6 +1388,17 @@ def run_volume(
1337
1388
  _resolve_gpu_env()
1338
1389
  _apply_thread_limits(threads_per_worker)
1339
1390
 
1391
+ # canonical 1-based timepoints (frames/frame_indices deprecated); keep a
1392
+ # 0-based frame_indices for this function's reactive-metadata plumbing,
1393
+ # and forward `timepoints` to run_plane.
1394
+ timepoints = _resolve_timepoints(timepoints, kwargs.pop("frames", None), frame_indices)
1395
+ frame_indices = [int(t) - 1 for t in timepoints] if timepoints is not None else None
1396
+ if num_zplanes is not None and planes is None:
1397
+ planes = list(range(1, int(num_zplanes) + 1))
1398
+ writer_kwargs = dict(writer_kwargs or {})
1399
+ if num_timepoints is not None:
1400
+ writer_kwargs.setdefault("num_timepoints", num_timepoints)
1401
+
1340
1402
  # Handle input data
1341
1403
  input_arr = None
1342
1404
  input_paths = []
@@ -1430,7 +1492,7 @@ def run_volume(
1430
1492
  force_reg=force_reg,
1431
1493
  force_detect=force_detect,
1432
1494
  replot=replot,
1433
- frame_indices=frame_indices,
1495
+ timepoints=timepoints,
1434
1496
  dff_window_size=dff_window_size,
1435
1497
  dff_percentile=dff_percentile,
1436
1498
  dff_smooth_window=dff_smooth_window,
@@ -2283,6 +2345,8 @@ def run_plane(
2283
2345
  force_reg: bool = False,
2284
2346
  force_detect: bool = False,
2285
2347
  replot: bool = True,
2348
+ timepoints: list | int | None = None,
2349
+ num_timepoints: int = None,
2286
2350
  frame_indices: list | None = None,
2287
2351
  dff_window_size: int = None,
2288
2352
  dff_percentile: int = 20,
@@ -2353,13 +2417,18 @@ def run_plane(
2353
2417
  Example: {"n_clusters": 50, "n_PCs": 64}.
2354
2418
  save_json : bool, default False
2355
2419
  Save ops as JSON.
2356
- frame_indices : list[int], optional
2357
- Explicit 0-based timepoint indices. Supports stride
2358
- (e.g. ``list(range(0, 1574, 2))`` for every other frame).
2420
+ timepoints : list[int] or int, optional
2421
+ Explicit 1-based timepoints. Supports stride
2422
+ (e.g. ``list(range(1, 1575, 2))`` for every other timepoint).
2359
2423
  When provided, the binary on disk contains exactly these
2360
- frames, and `OutputMetadata` reactively scales `fs` in ops.npy
2361
- based on the implicit stride. Takes precedence over any
2362
- ``num_frames`` in ``writer_kwargs``.
2424
+ timepoints, and `OutputMetadata` reactively scales `fs` in ops.npy
2425
+ based on the implicit stride. Takes precedence over
2426
+ ``num_timepoints``.
2427
+ num_timepoints : int, optional
2428
+ Limit processing to first N timepoints (truncation only).
2429
+ frame_indices : list[int], optional
2430
+ Deprecated alias for ``timepoints`` (0-based). Emits a
2431
+ DeprecationWarning.
2363
2432
  plane_name : str, optional
2364
2433
  Custom name for the plane subdirectory.
2365
2434
  reader_kwargs : dict, optional
@@ -2379,6 +2448,14 @@ def run_plane(
2379
2448
 
2380
2449
  _resolve_gpu_env()
2381
2450
 
2451
+ # canonical 1-based timepoints (frames/frame_indices deprecated); convert to
2452
+ # the 0-based frame_indices this function consumes internally.
2453
+ timepoints = _resolve_timepoints(timepoints, kwargs.pop("frames", None), frame_indices)
2454
+ frame_indices = [int(t) - 1 for t in timepoints] if timepoints is not None else None
2455
+ writer_kwargs = dict(writer_kwargs or {})
2456
+ if num_timepoints is not None:
2457
+ writer_kwargs.setdefault("num_timepoints", num_timepoints)
2458
+
2382
2459
  progress_callback = kwargs.pop("progress_callback", None)
2383
2460
 
2384
2461
  if "debug" in kwargs:
@@ -2637,7 +2714,8 @@ def run_plane(
2637
2714
  else:
2638
2715
  # prefer the user-specified frame limit over raw array shape
2639
2716
  nframes_hint = (
2640
- writer_kwargs.get("num_frames")
2717
+ writer_kwargs.get("num_timepoints")
2718
+ or writer_kwargs.get("num_frames")
2641
2719
  or ops.get("nframes")
2642
2720
  )
2643
2721
  if not nframes_hint and input_arr is not None and hasattr(input_arr, "shape"):
@@ -2666,8 +2744,8 @@ def run_plane(
2666
2744
  ops_file = plane_dir / "ops.npy"
2667
2745
 
2668
2746
  # extract expected dims from input for cache validation
2669
- # honors writer_kwargs["num_frames"] limit if user requested fewer frames
2670
- exp_nframes = writer_kwargs.get("num_frames")
2747
+ # honors num_timepoints truncation if user requested fewer frames
2748
+ exp_nframes = writer_kwargs.get("num_timepoints") or writer_kwargs.get("num_frames")
2671
2749
  exp_ly = exp_lx = None
2672
2750
  if input_arr is not None and hasattr(input_arr, "shape"):
2673
2751
  if exp_nframes is None:
@@ -2796,12 +2874,13 @@ def run_plane(
2796
2874
  write_planes = [plane] if _get_num_planes(file) > 1 else None
2797
2875
 
2798
2876
  write_kw = dict(writer_kwargs)
2799
- # If the caller gave us explicit frame indices, pass them as
2800
- # `frames=` (1-based) to imwrite. This wins over any stale
2801
- # `num_frames` truncation in writer_kwargs — strided semantics
2802
- # require an explicit index list, not a count.
2877
+ # If the caller gave us explicit timepoints, pass them as
2878
+ # `timepoints=` (1-based) to imwrite. This wins over any stale
2879
+ # truncation count in writer_kwargs — strided semantics require an
2880
+ # explicit index list, not a count.
2803
2881
  if frame_indices is not None:
2804
- write_kw["frames"] = [int(i) + 1 for i in frame_indices]
2882
+ write_kw["timepoints"] = [int(i) + 1 for i in frame_indices]
2883
+ write_kw.pop("num_timepoints", None)
2805
2884
  write_kw.pop("num_frames", None)
2806
2885
 
2807
2886
  imwrite(
@@ -0,0 +1,229 @@
1
+ import os
2
+ import numpy as np
3
+ from pathlib import Path
4
+
5
+
6
+ # mbo_utilities >= 4.0 exposes a single LazyArray base; isinstance covers
7
+ # every built-in and any third-party plugin. The class-name tuple is the
8
+ # fallback for mbo_utilities < 4.0, which has no shared base.
9
+ try:
10
+ from mbo_utilities import LazyArray as _LazyArray
11
+ except ImportError: # mbo_utilities < 4.0
12
+ _LazyArray = None
13
+
14
+ _LAZY_ARRAY_TYPES = (
15
+ "ScanImageArray",
16
+ "LBMArray",
17
+ "PiezoArray",
18
+ "SinglePlaneArray",
19
+ "Suite2pArray",
20
+ "MBOTiffArray",
21
+ "MboRawArray",
22
+ "TiffArray",
23
+ "ZarrArray",
24
+ "H5Array",
25
+ "NumpyArray",
26
+ "BinArray",
27
+ )
28
+
29
+
30
+ def _is_lazy_array(obj):
31
+ """Check if obj is an mbo_utilities lazy array type."""
32
+ if _LazyArray is not None and isinstance(obj, _LazyArray):
33
+ return True
34
+ return type(obj).__name__ in _LAZY_ARRAY_TYPES
35
+
36
+
37
+ def _get_num_planes(arr):
38
+ """
39
+ Get number of z-planes from a lazy array.
40
+
41
+ mbo_utilities arrays are always 5D TCZYX, so Z is at shape[2].
42
+ Falls back to legacy 4D TZYX (Z at shape[1]) and other heuristics
43
+ for non-mbo arrays.
44
+
45
+ Parameters
46
+ ----------
47
+ arr : array-like
48
+ Input array, typically from mbo_utilities.
49
+
50
+ Returns
51
+ -------
52
+ int
53
+ Number of z-planes (1 if no Z dimension).
54
+ """
55
+ # mbo_utilities Shape5DMixin
56
+ if hasattr(arr, "nz"):
57
+ return arr.nz
58
+ if hasattr(arr, "num_planes") and arr.num_planes is not None:
59
+ return arr.num_planes
60
+ shape = arr.shape
61
+ if len(shape) == 5:
62
+ return shape[2] # 5D TCZYX
63
+ if len(shape) == 4:
64
+ return shape[1] # legacy 4D TZYX
65
+ return 1
66
+
67
+
68
+ def _resize_masks_fit_crop(mask, target_shape):
69
+ """Centers a mask within the target shape, cropping if too large or padding if too small."""
70
+ sy, sx = mask.shape
71
+ ty, tx = target_shape
72
+
73
+ # If mask is larger, crop it
74
+ if sy > ty or sx > tx:
75
+ start_y = (sy - ty) // 2
76
+ start_x = (sx - tx) // 2
77
+ return mask[start_y : start_y + ty, start_x : start_x + tx]
78
+
79
+ # If mask is smaller, pad it
80
+ resized_mask = np.zeros(target_shape, dtype=mask.dtype)
81
+ start_y = (ty - sy) // 2
82
+ start_x = (tx - sx) // 2
83
+ resized_mask[start_y : start_y + sy, start_x : start_x + sx] = mask
84
+ return resized_mask
85
+
86
+
87
+ def get_common_path(ops_files: list | tuple):
88
+ """
89
+ Find the common parent path of all files.
90
+
91
+ Parameters
92
+ ----------
93
+ ops_files : list or tuple
94
+ List of file paths.
95
+
96
+ Returns
97
+ -------
98
+ Path
99
+ Common parent directory of all files.
100
+ """
101
+ if not isinstance(ops_files, (list, tuple)):
102
+ ops_files = [ops_files]
103
+ if len(ops_files) == 1:
104
+ path = Path(ops_files[0]).parent
105
+ while (
106
+ path.exists() and len(list(path.iterdir())) <= 1
107
+ ): # Traverse up if only one item exists
108
+ path = path.parent
109
+ return path
110
+ else:
111
+ return Path(os.path.commonpath(ops_files))
112
+
113
+
114
def estimate_peak_memory(ops, Ly, Lx, n_frames, device="cuda", workers=1):
    """
    Estimate the peak host-RAM and VRAM footprint of one Suite2p plane.

    Suite2p runs registration and detection one after the other, so a
    plane's peak is the larger of the two stages rather than their sum.

    Host RAM
        Registration holds ``nimg_init`` frames for reference-image
        selection (~10 bytes/pixel: int16 frames plus a float64
        correlation, computed on CPU). Detection holds the binned movie at
        4 bytes/pixel plus high-pass / sparsery working copies (~2.5x), also
        on CPU. The detection term stops growing once
        ``n_frames // bin_size`` exceeds ``nbins``.
    VRAM (only when ``device`` is cuda)
        Registration keeps roughly eight float32/FFT buffers of
        ``batch_size * Ly * Lx`` per channel on the GPU. Cellpose
        (``anatomical_only``) is charged a fixed 1.5 GB for cpsam weights
        and activations. The binned movie never enters VRAM, so no VRAM
        term scales with ``n_frames``.

    Parameters
    ----------
    ops : dict
        Flat Suite2p ops. Keys read: ``nimg_init``, ``batch_size``
        (registration batch), ``nbins``, ``bin_size``, ``tau``, ``fs``,
        ``nchannels``, ``anatomical_only``. Missing keys use the defaults
        hard-coded below.
    Ly, Lx : int
        Full frame height and width in pixels. The post-registration crop
        is not known yet, so these give an upper bound.
    n_frames : int
        Number of frames in the plane.
    device : str, optional
        Torch device string. VRAM is estimated only if it starts with
        ``"cuda"``. Default ``"cuda"``.
    workers : int, optional
        Number of planes processed concurrently. The ``*_total`` fields are
        the per-plane peaks times ``max(1, workers)``. Default 1.

    Returns
    -------
    dict
        Byte counts under ``host_per_plane``, ``vram_per_plane``,
        ``host_total`` and ``vram_total``. The VRAM entries are 0 for
        non-cuda devices.

    Notes
    -----
    The 2.5x detection factor and the cpsam constant are coarse. For tight
    bounds, measure ``torch.cuda.max_memory_allocated()`` and process RSS
    around the registration and detection calls.
    """
    use_vram = str(device).startswith("cuda")

    # Host, registration: frames used to pick the initial reference image.
    n_init = min(int(ops.get("nimg_init", 400)), n_frames)
    host_registration = n_init * Ly * Lx * 10  # int16 frames + float64 ref corr

    # Host, detection: binned movie plus its filtered working copies.
    max_bins = int(ops.get("nbins", 5000))
    bin_size = ops.get("bin_size")
    if not bin_size:
        # Unset (or 0): use the largest of 1, n_frames // max_bins and
        # round(tau * fs).
        bin_size = max(
            1, n_frames // max_bins, round(ops.get("tau", 1.0) * ops.get("fs", 10.0))
        )
    n_binned = min(max_bins, n_frames // bin_size)
    host_detection = int(2.5 * n_binned * Ly * Lx * 4)

    host_peak = max(host_registration, host_detection)

    # Device: registration batches or cellpose, whichever is larger.
    vram_peak = 0
    if use_vram:
        batch = int(ops.get("batch_size", 100))
        n_channels = int(ops.get("nchannels", 1))
        vram_registration = n_channels * 8 * batch * Ly * Lx * 4
        vram_cellpose = 1_500_000_000 if ops.get("anatomical_only", 0) else 0
        vram_peak = max(vram_registration, vram_cellpose)

    concurrency = max(1, workers)
    return {
        "host_per_plane": host_peak,
        "vram_per_plane": vram_peak,
        "host_total": host_peak * concurrency,
        "vram_total": vram_peak * concurrency,
    }
197
+
198
+
199
def bin1d(X, bin_size, axis=0):
    """
    Average consecutive groups of ``bin_size`` samples along ``axis``.

    Trailing samples that do not fill a whole bin are dropped.

    Parameters
    ----------
    X : np.ndarray
        Array to bin.
    bin_size : int
        Samples per bin. A value <= 0 disables binning and returns ``X``
        unchanged (the same object).
    axis : int, optional
        Axis to bin along. Default 0.

    Returns
    -------
    np.ndarray
        Array whose length along ``axis`` is ``X.shape[axis] // bin_size``;
        all other axes are unchanged.
    """
    if bin_size <= 0:
        return X

    n_bins = X.shape[axis] // bin_size
    # Bring the binned axis to the front so the reshape is a simple split.
    front = X.swapaxes(0, axis)
    usable = front[: n_bins * bin_size]
    grouped = usable.reshape((n_bins, bin_size, *front.shape[1:]))
    return grouped.mean(axis=1).swapaxes(axis, 0)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: lbm_suite2p_python
3
- Version: 3.2.0
3
+ Version: 3.2.2
4
4
  Summary: Calcium Imaging Pipeline built with Suite2p, Cellpose and Rastermap
5
5
  License-Expression: BSD-3-Clause
6
6
  Project-URL: homepage, https://github.com/MillerBrainObservatory/LBM-Suite2p-Python
@@ -11,9 +11,11 @@ Classifier: Programming Language :: Python :: 3 :: Only
11
11
  Requires-Python: <3.14,>=3.12.7
12
12
  Description-Content-Type: text/markdown
13
13
  License-File: LICENSE.md
14
- Requires-Dist: mbo_utilities>=3.2.1
14
+ Requires-Dist: mbo_utilities>=3.2.7
15
15
  Requires-Dist: suite2p>=1.0.0.1
16
16
  Requires-Dist: setuptools<81
17
+ Requires-Dist: torch
18
+ Requires-Dist: torchvision
17
19
  Provides-Extra: rastermap
18
20
  Requires-Dist: rastermap; extra == "rastermap"
19
21
  Provides-Extra: cellpose
@@ -1,6 +1,8 @@
1
- mbo_utilities>=3.2.1
1
+ mbo_utilities>=3.2.7
2
2
  suite2p>=1.0.0.1
3
3
  setuptools<81
4
+ torch
5
+ torchvision
4
6
 
5
7
  [all]
6
8
  lbm_suite2p_python[cellpose,rastermap]
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "lbm_suite2p_python"
7
- version = "3.2.0"
7
+ version = "3.2.2"
8
8
  description = "Calcium Imaging Pipeline built with Suite2p, Cellpose and Rastermap"
9
9
  readme = "README.md"
10
10
  license = "BSD-3-Clause"
@@ -18,9 +18,11 @@ classifiers=[
18
18
  ]
19
19
 
20
20
  dependencies = [
21
- "mbo_utilities>=3.2.1",
21
+ "mbo_utilities>=3.2.7",
22
22
  "suite2p>=1.0.0.1",
23
23
  "setuptools<81",
24
+ "torch",
25
+ "torchvision",
24
26
  ]
25
27
 
26
28
  [project.scripts]
@@ -67,6 +69,15 @@ docs = [
67
69
  "suite2p",
68
70
  ]
69
71
 
72
+ [tool.uv.sources]
73
+ torch = [{ index = "pytorch-cu126" }]
74
+ torchvision = [{ index = "pytorch-cu126" }]
75
+
76
+ [[tool.uv.index]]
77
+ name = "pytorch-cu126"
78
+ url = "https://download.pytorch.org/whl/cu126"
79
+ explicit = true
80
+
70
81
  # https://github.com/charliermarsh/ruff
71
82
  [tool.ruff]
72
83
  line-length = 88
@@ -1,144 +0,0 @@
1
- import os
2
- import numpy as np
3
- from pathlib import Path
4
-
5
-
6
# Lazy-array detection support.
# Newer mbo_utilities (>= 4.0) export a common ``LazyArray`` base class, so an
# isinstance check recognizes built-in and plugin arrays alike. Older releases
# have no shared base, so we fall back to matching known class names.
try:
    from mbo_utilities import LazyArray as _LazyArray
except ImportError:  # mbo_utilities < 4.0: no shared base class
    _LazyArray = None

# Class names recognized as lazy arrays when the base class is unavailable.
_LAZY_ARRAY_TYPES = tuple(
    """
    ScanImageArray LBMArray PiezoArray SinglePlaneArray Suite2pArray
    MBOTiffArray MboRawArray TiffArray ZarrArray H5Array NumpyArray BinArray
    """.split()
)
28
-
29
-
30
def _is_lazy_array(obj):
    """
    Return True when ``obj`` is an mbo_utilities lazy array.

    An ``isinstance`` check against ``_LazyArray`` is tried first when that
    base class was importable. Otherwise the object's class name is compared
    against ``_LAZY_ARRAY_TYPES``.
    """
    matches_base = _LazyArray is not None and isinstance(obj, _LazyArray)
    return matches_base or type(obj).__name__ in _LAZY_ARRAY_TYPES
35
-
36
-
37
def _get_num_planes(arr):
    """
    Return the number of z-planes in an array.

    Resolution order:

    1. ``arr.nz`` if the attribute exists (mbo_utilities Shape5DMixin).
    2. ``arr.num_planes`` if it exists and is not None.
    3. Shape heuristics: 5D is treated as TCZYX (Z = ``shape[2]``), 4D as
       legacy TZYX (Z = ``shape[1]``), and anything else as one plane.

    Parameters
    ----------
    arr : array-like
        Input array, typically from mbo_utilities.

    Returns
    -------
    int
        Number of z-planes (1 if no Z dimension is recognized).
    """
    if hasattr(arr, "nz"):
        return arr.nz

    planes = getattr(arr, "num_planes", None)
    if planes is not None:
        return planes

    dims = arr.shape
    z_axis_by_ndim = {5: 2, 4: 1}  # 5D TCZYX, legacy 4D TZYX
    z_axis = z_axis_by_ndim.get(len(dims))
    return 1 if z_axis is None else dims[z_axis]
66
-
67
-
68
def _resize_masks_fit_crop(mask, target_shape):
    """
    Center a 2D mask inside ``target_shape``, cropping or zero-padding per axis.

    Each axis is handled independently. An axis longer than its target is
    center-cropped; an axis shorter than its target is center-padded with
    zeros of ``mask.dtype``. This also covers the mixed case where one axis
    must be cropped while the other must be padded.

    Parameters
    ----------
    mask : np.ndarray
        2D array of shape ``(sy, sx)``.
    target_shape : tuple of int
        Desired output shape ``(ty, tx)``.

    Returns
    -------
    np.ndarray
        Array of shape ``target_shape``. A pure crop returns a view into
        ``mask``; any case that needs padding, or an unchanged shape,
        returns a new array.
    """
    sy, sx = mask.shape
    ty, tx = target_shape
    cropped = sy > ty or sx > tx

    # Center-crop any axis that is too large (slicing keeps a view).
    if sy > ty:
        y0 = (sy - ty) // 2
        mask = mask[y0 : y0 + ty, :]
    if sx > tx:
        x0 = (sx - tx) // 2
        mask = mask[:, x0 : x0 + tx]

    cy, cx = mask.shape
    if cropped and (cy, cx) == (ty, tx):
        return mask

    # Center-pad any remaining axis that is too small.
    resized_mask = np.zeros((ty, tx), dtype=mask.dtype)
    y0 = (ty - cy) // 2
    x0 = (tx - cx) // 2
    resized_mask[y0 : y0 + cy, x0 : x0 + cx] = mask
    return resized_mask
85
-
86
-
87
- def get_common_path(ops_files: list | tuple):
88
- """
89
- Find the common parent path of all files.
90
-
91
- Parameters
92
- ----------
93
- ops_files : list or tuple
94
- List of file paths.
95
-
96
- Returns
97
- -------
98
- Path
99
- Common parent directory of all files.
100
- """
101
- if not isinstance(ops_files, (list, tuple)):
102
- ops_files = [ops_files]
103
- if len(ops_files) == 1:
104
- path = Path(ops_files[0]).parent
105
- while (
106
- path.exists() and len(list(path.iterdir())) <= 1
107
- ): # Traverse up if only one item exists
108
- path = path.parent
109
- return path
110
- else:
111
- return Path(os.path.commonpath(ops_files))
112
-
113
-
114
def bin1d(X, bin_size, axis=0):
    """
    Average consecutive groups of ``bin_size`` samples along ``axis``.

    Trailing samples that do not fill a whole bin are dropped.

    Parameters
    ----------
    X : np.ndarray
        Array to bin.
    bin_size : int
        Samples per bin. A value <= 0 disables binning and returns ``X``
        unchanged (the same object).
    axis : int, optional
        Axis to bin along. Default 0.

    Returns
    -------
    np.ndarray
        Array whose length along ``axis`` is ``X.shape[axis] // bin_size``;
        all other axes are unchanged.
    """
    if bin_size <= 0:
        return X

    n_bins = X.shape[axis] // bin_size
    # Bring the binned axis to the front so the reshape is a simple split.
    front = X.swapaxes(0, axis)
    usable = front[: n_bins * bin_size]
    grouped = usable.reshape((n_bins, bin_size, *front.shape[1:]))
    return grouped.mean(axis=1).swapaxes(axis, 0)