lbm_caiman_python 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lbm_caiman_python/__init__.py +63 -0
- lbm_caiman_python/__main__.py +302 -0
- lbm_caiman_python/_version.py +8 -0
- lbm_caiman_python/batch.py +188 -0
- lbm_caiman_python/collation.py +125 -0
- lbm_caiman_python/default_ops.py +92 -0
- lbm_caiman_python/gui/__init__.py +3 -0
- lbm_caiman_python/gui/_store_model.py +170 -0
- lbm_caiman_python/gui/rungui.py +13 -0
- lbm_caiman_python/gui/widgets.py +114 -0
- lbm_caiman_python/helpers.py +262 -0
- lbm_caiman_python/postprocessing.py +319 -0
- lbm_caiman_python/run_lcp.py +1059 -0
- lbm_caiman_python/stdout.py +3 -0
- lbm_caiman_python/summary.py +569 -0
- lbm_caiman_python/util/__init__.py +87 -0
- lbm_caiman_python/util/exceptions.py +6 -0
- lbm_caiman_python/util/quality.py +366 -0
- lbm_caiman_python/util/signal.py +17 -0
- lbm_caiman_python/util/transform.py +208 -0
- lbm_caiman_python/visualize.py +522 -0
- lbm_caiman_python-0.2.0.dist-info/METADATA +161 -0
- lbm_caiman_python-0.2.0.dist-info/RECORD +27 -0
- lbm_caiman_python-0.2.0.dist-info/WHEEL +5 -0
- lbm_caiman_python-0.2.0.dist-info/entry_points.txt +2 -0
- lbm_caiman_python-0.2.0.dist-info/licenses/LICENSE.md +38 -0
- lbm_caiman_python-0.2.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
"""
|
|
2
|
+
default caiman parameters for lbm data processing.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def default_ops() -> dict:
    """
    Return default caiman parameters optimized for lbm microscopy data.

    A new dictionary is built on every call, so callers may modify the result
    without affecting later calls.

    Returns
    -------
    dict
        Flat dictionary of parameters for motion correction, cnmf, component
        evaluation and general settings.
    """
    return {
        # motion correction parameters
        "do_motion_correction": True,
        "max_shifts": (6, 6),
        "strides": (48, 48),
        "overlaps": (24, 24),
        "max_deviation_rigid": 3,
        "pw_rigid": True,
        "gSig_filt": (2, 2),
        "border_nan": "copy",
        "niter_rig": 1,
        "splits_rig": 14,
        "num_splits_to_process_rig": None,
        "splits_els": 14,
        "num_splits_to_process_els": None,
        "upsample_factor_grid": 4,
        "use_cuda": False,

        # cnmf parameters
        "do_cnmf": True,
        "K": 50,
        "gSig": (4, 4),
        "gSiz": None,
        "p": 1,
        "merge_thresh": 0.8,
        "min_SNR": 2.5,
        "rval_thr": 0.85,
        "decay_time": 0.4,
        "method_init": "greedy_roi",
        "ssub": 1,
        "tsub": 1,
        "rf": None,
        "stride": None,
        "nb": 1,
        "gnb": 1,
        "low_rank_background": True,
        "update_background_components": True,
        "rolling_sum": True,
        "only_init": False,
        "normalize_init": True,
        "ring_size_factor": 1.5,

        # component evaluation
        "min_cnn_thr": 0.9,
        "cnn_lowest": 0.1,
        "use_cnn": False,

        # general parameters
        "fr": 30.0,
        "n_processes": None,
        "dxy": (1.0, 1.0),
    }
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def mcorr_ops() -> dict:
    """
    Return the motion-correction subset of ``default_ops()``.

    The shared general parameters ``fr``, ``n_processes`` and ``dxy`` are
    included. Keys appear in the same order as in ``default_ops()``.
    """
    wanted = frozenset((
        "do_motion_correction", "max_shifts", "strides", "overlaps",
        "max_deviation_rigid", "pw_rigid", "gSig_filt", "border_nan",
        "niter_rig", "splits_rig", "num_splits_to_process_rig",
        "splits_els", "num_splits_to_process_els", "upsample_factor_grid",
        "use_cuda", "fr", "n_processes", "dxy",
    ))
    all_ops = default_ops()
    return {name: all_ops[name] for name in all_ops if name in wanted}
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def cnmf_ops() -> dict:
    """
    Return the cnmf and component-evaluation subset of ``default_ops()``.

    The shared general parameters ``fr``, ``n_processes`` and ``dxy`` are
    included. Keys appear in the same order as in ``default_ops()``.
    """
    keep = frozenset((
        "do_cnmf", "K", "gSig", "gSiz", "p", "merge_thresh", "min_SNR",
        "rval_thr", "decay_time", "method_init", "ssub", "tsub", "rf",
        "stride", "nb", "gnb", "low_rank_background",
        "update_background_components", "rolling_sum", "only_init",
        "normalize_init", "ring_size_factor", "min_cnn_thr", "cnn_lowest",
        "use_cnn", "fr", "n_processes", "dxy",
    ))
    selected = {}
    for name, value in default_ops().items():
        if name in keep:
            selected[name] = value
    return selected
|
|
@@ -0,0 +1,170 @@
|
|
|
1
|
+
from typing import *
|
|
2
|
+
import numpy as np
|
|
3
|
+
from fastplotlib import ImageGraphic, LinearSelector, ScatterGraphic, ImageWidget
|
|
4
|
+
from ipywidgets import IntSlider, FloatSlider
|
|
5
|
+
|
|
6
|
+
from fastplotlib.graphics._features import FeatureEvent
|
|
7
|
+
|
|
8
|
+
MARGIN: float = 1
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
# TODO: need to make a method for automatic MARGIN setting based on the data
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class TimeStoreComponent:
    """
    One subscriber registered with a ``TimeStore``.

    Holds the subscribed object together with the optional data array, data
    filter and time multiplier the store uses when it pushes a new time.
    """

    @property
    def subscriber(self) -> ImageWidget | ImageGraphic | ScatterGraphic | IntSlider | FloatSlider | LinearSelector:
        """The ipywidget or fastplotlib object kept in sync by the store."""
        return self._subscriber

    @property
    def data(self) -> np.ndarray | None:
        """Array indexed by the current time to update graphics with, or ``None``."""
        return self._data

    @property
    def multiplier(self) -> int | float | None:
        """Factor applied to the store's time for this subscriber (defaults to 1)."""
        return self._multiplier

    @property
    def data_filter(self) -> Callable | None:
        """Optional function applied to each indexed data frame before display."""
        return self._data_filter

    def __init__(self, subscriber, data=None, data_filter=None, multiplier=None):
        """
        Wrap a subscriber and its update settings.

        Parameters
        ----------
        subscriber : ImageWidget, ImageGraphic, ScatterGraphic, LinearSelector, IntSlider or FloatSlider
            Object to keep synchronized.
        data : array-like, optional
            Required for ``ImageGraphic``/``ScatterGraphic`` subscribers; must have a ``shape``.
        data_filter : callable, optional
            Function applied to each data frame before updating.
        multiplier : int | float, optional
            Scale between the store's time and the subscriber's time. ``None`` means 1.

        Raises
        ------
        ValueError
            If `subscriber` is an ``ImageGraphic`` or ``ScatterGraphic`` and `data`
            has no ``shape`` attribute.
        """
        # default to a 1:1 mapping between store time and subscriber time
        self._multiplier = 1 if multiplier is None else multiplier
        self._subscriber = subscriber

        # graphics are refreshed by indexing `data` with the current time, so they need data
        if isinstance(subscriber, (ImageGraphic, ScatterGraphic)):
            # LazyArrayRCM has no `__array__`, using `shape` for now
            if not hasattr(data, "shape"):
                raise ValueError("If passing in `ImageGraphic` or `ScatterGraphic` must provide associated "
                                 "`ndarray` object to update data with.")
        self._data = data
        self._data_filter = data_filter
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class TimeStore:
    """
    Synchronize the time index of several plot components.

    Sliders (``ipywidgets.IntSlider``/``FloatSlider``), fastplotlib
    ``LinearSelector`` objects and ``ImageWidget`` objects report time changes
    to the store. Every subscribed component is then moved to the new time.
    ``ImageGraphic``/``ScatterGraphic`` subscribers are updated by indexing
    their associated data array with the current time.
    """

    @property
    def time(self) -> int:
        """Current t value that items in the store are set at."""
        return self._time

    @time.setter
    def time(self, value: int | float):
        """Set the current time; non-integer values are truncated with ``int()``."""
        self._time = int(value)

    @property
    def store(self) -> List[TimeStoreComponent]:
        """Returns the items in the store."""
        return self._store

    def __init__(self):
        """
        Create an empty store whose time starts at 0.

        NOTE: If subscribing a ``fastplotlib.ImageGraphic`` or ``ScatterGraphic``,
        an associated ``ndarray`` must be given so the graphic can be updated.
        """
        self._store: List[TimeStoreComponent] = []
        self._time = 0

    def subscribe(self,
                  subscriber: ImageWidget | ImageGraphic | LinearSelector | ScatterGraphic | IntSlider | FloatSlider,
                  data: np.ndarray | None = None,
                  data_filter: Callable | None = None,
                  multiplier: int | float | None = None) -> None:
        """
        Add a subscriber to the store so it is kept synchronized.

        Parameters
        ----------
        subscriber : ImageWidget, ImageGraphic, ScatterGraphic, LinearSelector, IntSlider or FloatSlider
            ipywidget or fastplotlib object to be synchronized.
        data : np.ndarray, optional
            Required for ImageGraphic/ScatterGraphic: array indexed by time to update the graphic with.
        data_filter : callable, optional
            Function applied to each ImageGraphic frame before updating. Must return data in the
            same shape as input.
        multiplier : int | float, optional
            Scale the current time to reflect a differing timescale (used for LinearSelector).
            Defaults to 1.

        Raises
        ------
        ValueError
            If an ImageGraphic or ScatterGraphic is given without `data` that has a ``shape``.
        """
        component = TimeStoreComponent(subscriber=subscriber,
                                       data=data,
                                       data_filter=data_filter,
                                       multiplier=multiplier)
        self._store.append(component)

        # only these subscriber types can change the time; register the store as their listener
        if isinstance(subscriber, ImageWidget):
            subscriber.add_event_handler(self._update_store, "current_index")
        if isinstance(subscriber, (IntSlider, FloatSlider)):
            subscriber.observe(self._update_store, "value")
        if isinstance(subscriber, LinearSelector):
            subscriber.add_event_handler(self._update_store, "selection")

    def unsubscribe(self, subscriber: ImageGraphic | LinearSelector | IntSlider | FloatSlider):
        """
        Remove a subscriber from the store and detach the store's event handlers from it.

        Every component wrapping `subscriber` is removed.
        """
        # iterate over a copy: removing from the list being iterated would skip entries
        for component in list(self._store):
            if component.subscriber == subscriber:
                self._store.remove(component)
                target = component.subscriber
                # handlers were registered on the subscriber itself, not on the wrapper component
                if isinstance(target, (IntSlider, FloatSlider)):
                    target.unobserve(self._update_store, "value")
                elif isinstance(target, LinearSelector):
                    target.remove_event_handler(self._update_store, "selection")
                # NOTE(review): the ImageWidget "current_index" handler added in subscribe() is
                # not detached here; confirm ImageWidget's removal API before adding it.

    def _update_store(self, ev):
        """
        Handle a time change from any subscriber and push the new time to all components.

        ``ev`` is a fastplotlib ``FeatureEvent`` (from a LinearSelector), an ImageWidget
        index dict containing ``"t"``, or an ipywidgets change notification carrying the
        new value under ``"new"``.
        """
        # ipywidgets (traitlets) change notifications are dict-like too, so the presence of
        # "t" is what identifies an ImageWidget index event
        from_image_widget = isinstance(ev, dict) and "t" in ev

        if isinstance(ev, FeatureEvent):
            # map the selector's value back onto the store's time scale
            for component in self._store:
                if isinstance(component.subscriber, LinearSelector) and ev.graphic == component.subscriber:
                    self.time = ev.info["value"] / component.multiplier
        elif from_image_widget:
            self.time = ev["t"]
        else:
            self.time = ev["new"]

        for component in self._store:
            target = component.subscriber
            if isinstance(target, ImageWidget):
                # the user moved the ImageWidget itself; don't push the index back to it
                if not from_image_widget:
                    target.current_index = {"t": self.time}
            elif isinstance(target, ScatterGraphic):
                target.data = component.data[self.time]
            elif isinstance(target, ImageGraphic):
                # ImageGraphic data is refreshed on every time change
                frame = component.data[self.time]
                if component.data_filter is not None:
                    frame = component.data_filter(frame)
                    if frame.shape != target.data.value.shape:
                        raise ValueError(f"data filter function: {component.data_filter} must return data in "
                                         f"the same shape as the current data")
                target.data = frame
            elif isinstance(target, LinearSelector):
                # only update if different (presumably to avoid feedback through the selector's own event)
                # NOTE(review): with `>` a difference of exactly MARGIN is not propagated
                selector_value = self.time * component.multiplier
                if abs(target.selection - selector_value) > MARGIN:
                    target.selection = selector_value
            else:
                # sliders: only update if different
                if abs(target.value - self.time) > MARGIN:
                    target.value = self.time
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
|
|
3
|
+
from qtpy.QtWidgets import QApplication
|
|
4
|
+
from lbm_caiman_python.gui.widgets import LBMMainWindow
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def run_gui(path):
    """
    Show the LBM main window in a Qt application and run the Qt event loop.

    Parameters
    ----------
    path : str or Path
        Currently unused: ``LBMMainWindow`` is constructed without arguments.
        Kept for API compatibility.
    """
    # Qt allows a single QApplication per process; reuse one if it already exists
    # (e.g. when called from an interactive session with a Qt event loop running).
    app = QApplication.instance()
    if app is None:
        app = QApplication(sys.argv)
    main_window = LBMMainWindow()  # keep a reference so the window is not garbage collected
    main_window.show()
    app.exec()
|
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
import webbrowser
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from qtpy.QtWidgets import QMainWindow, QFileDialog
|
|
7
|
+
from qtpy import QtGui, QtCore
|
|
8
|
+
import fastplotlib as fpl
|
|
9
|
+
from fastplotlib.ui import EdgeWindow
|
|
10
|
+
|
|
11
|
+
from mbo_utilities import get_files, imread
|
|
12
|
+
|
|
13
|
+
try:
|
|
14
|
+
from imgui_bundle import imgui, icons_fontawesome_6 as fa
|
|
15
|
+
except ImportError:
|
|
16
|
+
raise ImportError("Please install imgui via `conda install -c conda-forge imgui-bundle`")
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def get_base_iw():
    """
    Return a placeholder ``ImageWidget`` showing a 100 x 100 x 100 block of random noise.

    Temporary workaround until the GUI can start with an empty canvas.
    """
    noise_shape = (100, 100, 100)
    placeholder = np.random.randn(*noise_shape)
    return fpl.ImageWidget(placeholder, histogram_widget=False)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def get_iw(path):
    """
    Load the files returned by ``get_files(path, "plane", 1)`` into an ``ImageWidget``.

    NOTE(review): the ``"plane"`` and ``1`` arguments are interpreted by
    ``mbo_utilities.get_files`` (presumably a filename filter and search depth); confirm.
    """
    plane_files = get_files(path, "plane", 1)
    return fpl.ImageWidget(imread(plane_files), histogram_widget=False)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class LBMMainWindow(QMainWindow):
    """Main Qt window of the LBM-CaImAn-Python GUI, hosting a fastplotlib ImageWidget."""

    @property
    def image_widget(self):
        """The fastplotlib ``ImageWidget`` shown as the window's central widget."""
        return self._image_widget

    def __init__(self):
        super().__init__()

        self.setGeometry(50, 50, 1500, 800)
        self.setWindowTitle("LBM-CaImAn-Python Pipeline")

        # the same SVG is registered at several sizes; a missing file yields an empty icon
        app_icon = QtGui.QIcon()
        icon_path = str(Path.home() / ".lbm" / "icons" / "icon_caiman_python.svg")
        for px in (16, 24, 32, 48, 64, 256):
            app_icon.addFile(icon_path, QtCore.QSize(px, px))
        self.setWindowIcon(app_icon)

        self.setStyleSheet("QMainWindow {background: 'black';}")
        self.stylePressed = ("QPushButton {Text-align: left; "
                             "background-color: rgb(100,50,100); "
                             "color:white;}")
        self.styleUnpressed = ("QPushButton {Text-align: left; "
                               "background-color: rgb(50,50,50); "
                               "color:white;}")
        self.styleInactive = ("QPushButton {Text-align: left; "
                              "background-color: rgb(50,50,50); "
                              "color:gray;}")

        self._image_widget = get_base_iw()
        # PreviewTracesWidget requires figure, location, title and image_widget in addition to
        # size; constructing it with `size` alone raises TypeError.
        gui = PreviewTracesWidget(
            figure=self._image_widget.figure,
            size=50,
            location="right",
            title="Preview Traces",
            image_widget=self._image_widget,
        )
        self._image_widget.figure.add_gui(gui)
        qwidget = self._image_widget.show()
        self.setCentralWidget(qwidget)
        self.resize(1200, 800)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class PreviewTracesWidget(EdgeWindow):
    """
    ImGui edge panel attached to an ImageWidget's figure.

    Clicking a pixel of the ImageWidget's first graphic plots that pixel's trace
    along axis 0 of the first data array in a separate fastplotlib figure. The
    panel itself shows a button that opens a folder dialog.
    """

    def __init__(self, figure, size, location, title, image_widget):
        """
        Parameters
        ----------
        figure
            fastplotlib figure the panel is attached to (passed to ``EdgeWindow``).
        size : int
            Panel size, passed to ``EdgeWindow``.
        location : str
            Figure edge for the panel, passed to ``EdgeWindow``.
        title : str
            Panel title.
        image_widget : fastplotlib.ImageWidget
            Widget whose first data array is traced.
        """
        super().__init__(figure=figure, size=size, location=location, title=title)
        self._image_widget = image_widget

        # whether or not a dimension is in play mode (not read anywhere in this class yet)
        self._playing: dict[str, bool] = {"t": False, "z": False}

        self.tfig = fpl.Figure()

        # one point per index along axis 0 of the first data array (presumably time)
        self.raw_trace = self.tfig[0, 0].add_line(np.zeros(self._image_widget.data[0].shape[0]))
        # pass the callback explicitly; giving only the event name registers no handler
        self._image_widget.managed_graphics[0].add_event_handler(self.pixel_clicked, "click")
        self.tfig.show()

    def pixel_clicked(self, ev):
        """
        Plot the trace of the clicked pixel.

        For 4D data the ImageWidget's current ``"z"`` index selects the plane.

        Raises
        ------
        ValueError
            If the ImageWidget is neither 3D nor 4D.
        """
        col, row = ev.pick_info["index"]
        if self._image_widget.ndim == 4:
            self.raw_trace.data[:, 1] = self._image_widget.data[0][:, self._image_widget.current_index["z"], row, col]
        elif self._image_widget.ndim == 3:
            self.raw_trace.data[:, 1] = self._image_widget.data[0][:, row, col]
        else:
            raise ValueError("ImageWidget has an unexpected number of dimensions. Expected 3 or 4.")
        self.tfig[0, 0].auto_scale(maintain_aspect=False)

    def update(self):
        """Draw the panel: a folder button that opens a directory dialog."""
        # pop the icon font right after the button so the font stack stays balanced
        # even if the dialog code below raises
        imgui.push_font(self._fa_icons)
        open_clicked = imgui.button(label=fa.ICON_FA_FOLDER_OPEN)
        imgui.pop_font()

        if open_clicked:
            # NOTE(review): nothing in this class sets `self.parent`; confirm it is provided
            # elsewhere, otherwise this raises AttributeError.
            name = QFileDialog.getExistingDirectory(parent=self.parent, caption="Open folder with z-planes")
            if name:  # an empty string means the dialog was cancelled
                self.parent.update_widget(name)

        if imgui.is_item_hovered(0):
            imgui.set_tooltip("Open a file dialog to load data")
|
|
@@ -0,0 +1,262 @@
|
|
|
1
|
+
import matplotlib.pyplot as plt
|
|
2
|
+
import numpy as np
|
|
3
|
+
from typing import Any as ArrayLike
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def _get_30p_order():
|
|
7
|
+
return (np.array([
|
|
8
|
+
1, 5, 6, 7, 8, 9, 2, 10, 11, 12, 13, 14, 15, 16, 17, 3, 18, 19, 20, 21, 22, 23, 4, 24, 25, 26, 27, 28, 29, 30
|
|
9
|
+
]) - 1)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def extract_center_square(images, size):
    """
    Extract a square crop from the center of the input images.

    Parameters
    ----------
    images : numpy.ndarray
        Input array. Can be 2D (H x W) or 3D (T x H x W), where:
        - H is the height of the image(s).
        - W is the width of the image(s).
        - T is the number of frames (if 3D).
    size : int
        The size of the square crop. The output will have dimensions
        (size x size) for 2D inputs or (T x size x size) for 3D inputs.

    Returns
    -------
    numpy.ndarray
        A view of the square crop from the center of the input images, with dimensions:
        - (size x size) if the input is 2D.
        - (T x size x size) if the input is 3D.

    Raises
    ------
    ValueError
        If `images` is not a NumPy array.
        If `images` is not 2D or 3D.
        If `size` is not positive, or is larger than the height or width of the input images.

    Notes
    -----
    - For 3D arrays, the crop is applied uniformly across all frames (T).
    - Odd sizes are supported: exactly `size` pixels are taken along each axis.

    Examples
    --------
    Extract a center square from a 2D image:

    >>> import numpy as np
    >>> image = np.random.rand(600, 576)
    >>> cropped = extract_center_square(image, size=200)
    >>> cropped.shape
    (200, 200)

    Extract a center square from a 3D stack of images:

    >>> stack = np.random.rand(100, 600, 576)
    >>> cropped_stack = extract_center_square(stack, size=200)
    >>> cropped_stack.shape
    (100, 200, 200)
    """
    # NOTE(review): this module defines extract_center_square twice; a later definition
    # shadows this one at import time. Remove one copy.
    if not isinstance(images, np.ndarray):
        raise ValueError("Input must be a numpy array.")
    if images.ndim not in (2, 3):
        raise ValueError("Input array must be 2D or 3D.")

    height, width = images.shape[-2:]
    if size <= 0:
        raise ValueError(f"size must be a positive integer, got {size}.")
    if size > height or size > width:
        raise ValueError(f"size ({size}) is larger than the image dimensions ({height} x {width}).")

    # start at center - size // 2 and take exactly `size` pixels: odd sizes are honoured and,
    # given the checks above, the slice can never start at a negative (wrapping) index
    top = height // 2 - size // 2
    left = width // 2 - size // 2
    return images[..., top:top + size, left:left + size]
|
|
84
|
+
|
|
85
|
+
def _get_30p_order():
|
|
86
|
+
return (np.array([
|
|
87
|
+
1, 5, 6, 7, 8, 9, 2, 10, 11, 12, 13, 14, 15, 16, 17, 3, 18, 19, 20, 21, 22, 23, 4, 24, 25, 26, 27, 28, 29, 30
|
|
88
|
+
]) - 1)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def extract_center_square(images, size):
    """
    Extract a square crop from the center of the input images.

    Parameters
    ----------
    images : numpy.ndarray
        Input array. Can be 2D (H x W) or 3D (T x H x W), where:
        - H is the height of the image(s).
        - W is the width of the image(s).
        - T is the number of frames (if 3D).
    size : int
        The size of the square crop. The output will have dimensions
        (size x size) for 2D inputs or (T x size x size) for 3D inputs.

    Returns
    -------
    numpy.ndarray
        A view of the square crop from the center of the input images, with dimensions:
        - (size x size) if the input is 2D.
        - (T x size x size) if the input is 3D.

    Raises
    ------
    ValueError
        If `images` is not a NumPy array.
        If `images` is not 2D or 3D.
        If `size` is not positive, or is larger than the height or width of the input images.

    Notes
    -----
    - For 3D arrays, the crop is applied uniformly across all frames (T).
    - Odd sizes are supported: exactly `size` pixels are taken along each axis.

    Examples
    --------
    Extract a center square from a 2D image:

    >>> import numpy as np
    >>> image = np.random.rand(600, 576)
    >>> cropped = extract_center_square(image, size=200)
    >>> cropped.shape
    (200, 200)

    Extract a center square from a 3D stack of images:

    >>> stack = np.random.rand(100, 600, 576)
    >>> cropped_stack = extract_center_square(stack, size=200)
    >>> cropped_stack.shape
    (100, 200, 200)
    """
    # NOTE(review): another definition of extract_center_square appears earlier in this
    # module; this later one is the binding used at import time. Remove one copy.
    if not isinstance(images, np.ndarray):
        raise ValueError("Input must be a numpy array.")
    if images.ndim not in (2, 3):
        raise ValueError("Input array must be 2D or 3D.")

    height, width = images.shape[-2:]
    if size <= 0:
        raise ValueError(f"size must be a positive integer, got {size}.")
    if size > height or size > width:
        raise ValueError(f"size ({size}) is larger than the image dimensions ({height} x {width}).")

    # start at center - size // 2 and take exactly `size` pixels: odd sizes are honoured and,
    # given the checks above, the slice can never start at a negative (wrapping) index
    top = height // 2 - size // 2
    left = width // 2 - size // 2
    return images[..., top:top + size, left:left + size]
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def get_single_patch_coords(dims, stride, overlap, patch_index):
    """
    Get the pixel bounds of one motion-correction patch.

    Patches are ``stride + overlap`` pixels on each side and start every
    ``stride`` pixels along each axis. Only patches that fit entirely inside
    ``dims`` belong to the grid.

    Parameters
    ----------
    dims : tuple
        Dimensions of the image as (rows, cols).
    stride : int
        Step, in pixels, between the starts of neighbouring patches.
    overlap : int
        Extra pixels added to `stride` to give the patch side length, i.e. the
        overlap between neighbouring patches.
    patch_index : tuple
        ``(row_index, col_index)`` of the patch within the patch grid.

    Returns
    -------
    tuple
        ``(y_start, y_end, x_start, x_end)``, suitable for slicing
        ``image[y_start:y_end, x_start:x_end]``.

    Raises
    ------
    IndexError
        If an entry of `patch_index` is out of range for the patch grid
        (negative indices count from the end, as in NumPy).
    """
    side = stride + overlap
    n_rows, n_cols = dims[0], dims[1]
    row_starts = np.arange(0, n_rows - side + 1, stride)
    col_starts = np.arange(0, n_cols - side + 1, stride)

    i, j = patch_index
    top, left = row_starts[i], col_starts[j]
    return top, top + side, left, left + side
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
def _pad_image_for_even_patches(image, stride, overlap):
|
|
193
|
+
patch_width = stride + overlap
|
|
194
|
+
padded_x = int(np.ceil(image.shape[0] / patch_width) * patch_width) - image.shape[0]
|
|
195
|
+
padded_y = int(np.ceil(image.shape[1] / patch_width) * patch_width) - image.shape[1]
|
|
196
|
+
return np.pad(image, ((0, padded_x), (0, padded_y)), mode='constant'), padded_x, padded_y
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def generate_patch_view(image: ArrayLike, pixel_resolution: float, target_patch_size: float = 40,
                        overlap_fraction: float = 0.5):
    """
    Generate a patch visualization for a 2D image with approximately square patches of a specified size in microns.

    Patches are evenly distributed across the image, using calculated strides and overlaps.

    Parameters
    ----------
    image : ndarray
        A 2D NumPy array representing the input image to be divided into patches.
    pixel_resolution : float
        The pixel resolution of the image in microns per pixel.
    target_patch_size : float, optional
        The desired patch spacing in microns, converted to the pixel stride. Default is 40 microns.
    overlap_fraction : float, optional
        The fraction of the stride to use as overlap between patches. Default is 0.5 (50%).

    Returns
    -------
    fig : matplotlib.figure.Figure
        A matplotlib figure containing the patch visualization.
    ax : matplotlib.axes.Axes
        A matplotlib axes object showing the patch layout on the (zero-padded) image.
    stride : int
        Patch stride in pixels, ``int(target_patch_size / pixel_resolution)``.
    overlap : int
        Patch overlap in pixels, ``int(overlap_fraction * stride)``.

    Raises
    ------
    ValueError
        If `target_patch_size` corresponds to less than one pixel (stride of 0).

    Examples
    --------
    >>> import numpy as np
    >>> from matplotlib import pyplot as plt
    >>> data = np.random.random((144, 600))  # Example 2D image
    >>> pixel_resolution = 0.5  # Microns per pixel
    >>> fig, ax, stride, overlap = generate_patch_view(data, pixel_resolution)
    >>> plt.show()
    """
    from caiman.utils.visualization import get_rectangle_coords, rect_draw

    # Calculate stride and overlap in pixels
    stride = int(target_patch_size / pixel_resolution)
    overlap = int(overlap_fraction * stride)
    if stride < 1:
        # a zero stride would otherwise surface as a ZeroDivisionError during padding
        raise ValueError(f"target_patch_size ({target_patch_size}) must be at least one pixel "
                         f"({pixel_resolution} microns) so that the stride is >= 1.")

    # pad the image (intended to mirror caiman's patch padding) using the module-level helper
    padded_image, _, _ = _pad_image_for_even_patches(image, stride, overlap)

    # Get patch coordinates
    patch_rows, patch_cols = get_rectangle_coords(padded_image.shape, stride, overlap)

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(padded_image, cmap='gray')

    # Draw patches using rect_draw
    for patch_row in patch_rows:
        for patch_col in patch_cols:
            rect_draw(patch_row, patch_col, color='white', alpha=0.2, ax=ax)

    ax.set_title(f"Stride: {stride} pixels (~{stride * pixel_resolution:.1f} μm)\n"
                 f"Overlap: {overlap} pixels (~{overlap * pixel_resolution:.1f} μm)\n")
    plt.tight_layout()
    return fig, ax, stride, overlap
|