lbm_caiman_python 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,92 @@
1
+ """
2
+ default caiman parameters for lbm data processing.
3
+ """
4
+
5
+
6
def default_ops() -> dict:
    """
    return default caiman parameters optimized for lbm microscopy data.

    a new dictionary is built on every call, so the caller may mutate the
    result without affecting later calls.

    returns
    -------
    dict
        dictionary of parameters for motion correction and cnmf.
    """
    return {
        # motion correction parameters
        "do_motion_correction": True,
        "max_shifts": (6, 6),
        "strides": (48, 48),
        "overlaps": (24, 24),
        # listed once: a repeated key in a dict literal silently overrides the
        # earlier entry, which hides accidental edits to only one of them
        "max_deviation_rigid": 3,
        "pw_rigid": True,
        "gSig_filt": (2, 2),
        "border_nan": "copy",
        "niter_rig": 1,
        "splits_rig": 14,
        "num_splits_to_process_rig": None,
        "splits_els": 14,
        "num_splits_to_process_els": None,
        "upsample_factor_grid": 4,
        "use_cuda": False,

        # cnmf parameters
        "do_cnmf": True,
        "K": 50,
        "gSig": (4, 4),
        "gSiz": None,
        "p": 1,
        "merge_thresh": 0.8,
        "min_SNR": 2.5,
        "rval_thr": 0.85,
        "decay_time": 0.4,
        "method_init": "greedy_roi",
        "ssub": 1,
        "tsub": 1,
        "rf": None,
        "stride": None,
        "nb": 1,
        "gnb": 1,
        "low_rank_background": True,
        "update_background_components": True,
        "rolling_sum": True,
        "only_init": False,
        "normalize_init": True,
        "ring_size_factor": 1.5,

        # component evaluation
        "min_cnn_thr": 0.9,
        "cnn_lowest": 0.1,
        "use_cnn": False,

        # general parameters
        "fr": 30.0,
        "n_processes": None,
        "dxy": (1.0, 1.0),
    }
68
+
69
+
70
def mcorr_ops() -> dict:
    """return the entries of default_ops() that belong to motion correction.

    the shared general parameters ("fr", "n_processes", "dxy") are included.
    keys keep the order they have in default_ops().
    """
    wanted = {
        "do_motion_correction", "max_shifts", "strides", "overlaps",
        "max_deviation_rigid", "pw_rigid", "gSig_filt", "border_nan",
        "niter_rig", "splits_rig", "num_splits_to_process_rig",
        "splits_els", "num_splits_to_process_els", "upsample_factor_grid",
        "use_cuda", "fr", "n_processes", "dxy",
    }
    selected = {}
    for name, value in default_ops().items():
        if name in wanted:
            selected[name] = value
    return selected
80
+
81
+
82
def cnmf_ops() -> dict:
    """return the entries of default_ops() that belong to cnmf.

    component-evaluation keys and the shared general parameters
    ("fr", "n_processes", "dxy") are included. keys keep the order they
    have in default_ops().
    """
    wanted = {
        "do_cnmf", "K", "gSig", "gSiz", "p", "merge_thresh", "min_SNR",
        "rval_thr", "decay_time", "method_init", "ssub", "tsub", "rf",
        "stride", "nb", "gnb", "low_rank_background",
        "update_background_components", "rolling_sum", "only_init",
        "normalize_init", "ring_size_factor", "min_cnn_thr", "cnn_lowest",
        "use_cnn", "fr", "n_processes", "dxy",
    }
    selected = {}
    for name, value in default_ops().items():
        if name in wanted:
            selected[name] = value
    return selected
@@ -0,0 +1,3 @@
1
+ from .rungui import run_gui
2
+
3
+ __all__ = ['run_gui']
@@ -0,0 +1,170 @@
1
+ from typing import *
2
+ import numpy as np
3
+ from fastplotlib import ImageGraphic, LinearSelector, ScatterGraphic, ImageWidget
4
+ from ipywidgets import IntSlider, FloatSlider
5
+
6
+ from fastplotlib.graphics._features import FeatureEvent
7
+
8
+ MARGIN: float = 1
9
+
10
+
11
+ # TODO: need to make a method for automatic MARGIN setting based on the data
12
+
13
+
14
class TimeStoreComponent:
    """
    A single entry of a ``TimeStore``.

    Bundles a subscriber with the optional array it is updated from, an
    optional filter applied to that array's frames, and a multiplier that
    relates the subscriber's own scale to the store's time value.
    """

    def __init__(self, subscriber, data=None, data_filter=None, multiplier=None):
        """
        Parameters
        ----------
        subscriber
            The widget or graphic to keep synchronized.
        data : optional
            Object exposing a ``shape`` attribute. Required when ``subscriber``
            is an ``ImageGraphic`` or ``ScatterGraphic``.
        data_filter : callable, optional
            Function applied to a frame of ``data`` before it is displayed.
        multiplier : int | float, optional
            Scale factor for this subscriber; ``None`` is treated as ``1``.

        Raises
        ------
        ValueError
            If ``subscriber`` is an ``ImageGraphic`` or ``ScatterGraphic`` and
            ``data`` has no ``shape`` attribute.
        """
        self._multiplier = 1 if multiplier is None else multiplier
        self._subscriber = subscriber

        # graphics are redrawn from `data`, so it must be array-like
        if isinstance(self.subscriber, (ImageGraphic, ScatterGraphic)):
            # LazyArrayRCM has no `__array__`, using `shape` for now
            if not hasattr(data, 'shape'):
                raise ValueError("If passing in `ImageGraphic` must provide associated `ndarray` object to update "
                                 "data with.")

        # always assigned (also for sliders/selectors) so the `data` and
        # `data_filter` properties never raise AttributeError
        self._data = data
        self._data_filter = data_filter

    @property
    def subscriber(self) -> ImageGraphic | IntSlider | FloatSlider | LinearSelector:
        """The synchronized widget or graphic."""
        return self._subscriber

    @property
    def data(self) -> np.ndarray | None:
        """Array the subscriber is updated from, or ``None``."""
        return self._data

    @property
    def multiplier(self) -> int | float | None:
        """Scale factor between the store time and this subscriber."""
        return self._multiplier

    @property
    def data_filter(self) -> Callable | None:
        """Optional function applied to a frame before display."""
        return self._data_filter
48
+
49
+
50
class TimeStore:
    """
    Keeps a shared time index synchronized across plot components
    (i.e. ipywidgets.IntSlider, fastplotlib.LinearSelector,
    fastplotlib.ImageWidget, or fastplotlib.ImageGraphic).

    NOTE: If passing a `fastplotlib.ImageGraphic` (or `ScatterGraphic`), it is
    understood that there should be an associated `ndarray` given.
    """

    def __init__(self):
        """Create an empty store with the time set to zero."""
        # initialize store
        self._store = list()
        # by default, time is zero
        self._time = 0

    @property
    def time(self):
        """Current t value that items in the store are set at."""
        return self._time

    @time.setter
    def time(self, value: int | float):
        """Set the current time; the value is truncated with ``int()``."""
        self._time = int(value)

    @property
    def store(self) -> List[TimeStoreComponent]:
        """Returns the items in the store."""
        return self._store

    def subscribe(self,
                  subscriber: ImageWidget | ImageGraphic | LinearSelector | ScatterGraphic | IntSlider | FloatSlider,
                  data: np.ndarray = None,
                  data_filter: callable = None,
                  multiplier: int | float = None) -> None:
        """
        Method for adding a subscriber to the store to be synchronized.

        Parameters
        ----------
        subscriber: fastplotlib.ImageWidget, fastplotlib.ImageGraphic, fastplotlib.ScatterGraphic, fastplotlib.LinearSelector, ipywidgets.IntSlider, or ipywidgets.FloatSlider
            ipywidget or fastplotlib object to be synchronized
        data: np.ndarray, optional
            If subscriber is a fastplotlib.ImageGraphic or ScatterGraphic, must have an associated
            numpy.ndarray to update data with.
        data_filter: callable, optional
            Function to apply to data before updating. Must return data in the same shape as input.
        multiplier: int | float, optional
            Scale the current time to reflect differing timescale.
        """
        component = TimeStoreComponent(subscriber=subscriber,
                                       data=data,
                                       data_filter=data_filter,
                                       multiplier=multiplier)
        self._store.append(component)

        # only these subscriber kinds emit events that drive the store;
        # graphics are passive and are just redrawn from `_update_store`
        if isinstance(component.subscriber, ImageWidget):
            component.subscriber.add_event_handler(self._update_store, "current_index")
        if isinstance(component.subscriber, (IntSlider, FloatSlider)):
            component.subscriber.observe(self._update_store, "value")
        if isinstance(component.subscriber, LinearSelector):
            component.subscriber.add_event_handler(self._update_store, "selection")

    def unsubscribe(self, subscriber: ImageGraphic | LinearSelector | IntSlider | FloatSlider):
        """Remove a subscriber from the store and detach its event handler."""
        # iterate over a snapshot: removing from a list while iterating it
        # skips the element that follows each removal
        for component in list(self._store):
            if component.subscriber == subscriber:
                self._store.remove(component)
                target = component.subscriber
                # the type checks must look at the wrapped subscriber, not at
                # the TimeStoreComponent wrapper, otherwise nothing is detached
                if isinstance(target, (IntSlider, FloatSlider)):
                    target.unobserve(self._update_store, "value")
                if isinstance(target, LinearSelector):
                    target.remove_event_handler(self._update_store, "selection")
                # NOTE(review): ImageWidget subscribers registered in
                # `subscribe` are not detached here — confirm the ImageWidget
                # handler-removal API and add it.

    def _update_store(self, ev):
        """
        Called when event occurs and store needs to be updated.

        The new time is read from the event, then every component in the
        store is brought to that time.

        Raises
        ------
        ValueError
            If a component's ``data_filter`` returns data whose shape differs
            from the shape currently held by its ImageGraphic.
        """
        # parse event to see if it originated from a selector, an ImageWidget
        # (dict with key "t") or an ipywidget (change dict with key "new")
        if isinstance(ev, FeatureEvent):
            for component in self._store:
                if isinstance(component.subscriber, LinearSelector) and ev.graphic == component.subscriber:
                    # undo the selector's multiplier to get back to store time
                    self.time = ev.info["value"] / component.multiplier
        elif isinstance(ev, dict):
            self.time = ev["t"]
        else:
            self.time = ev["new"]

        # the ImageWidget that produced the event is already at this index
        from_image_widget = isinstance(ev, dict) and 't' in ev

        for component in self._store:
            target = component.subscriber
            if isinstance(target, ImageWidget):
                if not from_image_widget:
                    target.current_index = {"t": self.time}
            elif isinstance(target, ScatterGraphic):
                target.data = component.data[self.time]
            elif isinstance(target, ImageGraphic):
                # update ImageGraphic data no matter what
                if component.data_filter is None:
                    new_data = component.data[self.time]
                else:
                    new_data = component.data_filter(component.data[self.time])
                    if new_data.shape != target.data.value.shape:
                        raise ValueError(f"data filter function: {component.data_filter} must return data in the "
                                         f"same shape as the current data")
                target.data = new_data
            elif isinstance(target, LinearSelector):
                # only update if different, avoids feedback between subscribers
                scaled_time = self.time * component.multiplier
                if abs(target.selection - scaled_time) > MARGIN:
                    target.selection = scaled_time
            else:
                # sliders: only update if different
                if abs(target.value - self.time) > MARGIN:
                    target.value = self.time
@@ -0,0 +1,13 @@
1
+ import sys
2
+
3
+ from qtpy.QtWidgets import QApplication
4
+ from lbm_caiman_python.gui.widgets import LBMMainWindow
5
+
6
+
7
def run_gui(path):
    """
    Create the Qt application, show the LBM main window and run the event loop.

    Parameters
    ----------
    path
        Accepted but not used by this function.
    """
    qt_app = QApplication(sys.argv)
    window = LBMMainWindow()
    print('--')
    window.show()
    qt_app.exec()
    # fpl.loop.run()
@@ -0,0 +1,114 @@
1
+ import webbrowser
2
+
3
+ from pathlib import Path
4
+
5
+ import numpy as np
6
+ from qtpy.QtWidgets import QMainWindow, QFileDialog
7
+ from qtpy import QtGui, QtCore
8
+ import fastplotlib as fpl
9
+ from fastplotlib.ui import EdgeWindow
10
+
11
+ from mbo_utilities import get_files, imread
12
+
13
+ try:
14
+ from imgui_bundle import imgui, icons_fontawesome_6 as fa
15
+ except ImportError:
16
+ raise ImportError("Please install imgui via `conda install -c conda-forge imgui-bundle`")
17
+
18
+
19
def get_base_iw():
    """
    Build a placeholder ImageWidget showing random noise.

    Temporary stand-in until the window can start with an empty canvas.
    The placeholder array has shape (100, 100, 100) and the histogram
    widget is disabled.
    """
    placeholder = np.random.randn(100, 100, 100)
    return fpl.ImageWidget(placeholder, histogram_widget=False)
24
+
25
+
26
def get_iw(path):
    """
    Build an ImageWidget from the files found under `path`.

    Files are located with ``get_files(path, "plane", 1)`` and loaded with
    ``imread``; the histogram widget is disabled.
    """
    plane_files = get_files(path, "plane", 1)
    return fpl.ImageWidget(imread(plane_files), histogram_widget=False)
31
+
32
+
33
class LBMMainWindow(QMainWindow):
    """Main Qt window hosting the fastplotlib ImageWidget and its preview GUI."""

    # pixel sizes at which the application icon is registered
    _ICON_SIZES = (16, 24, 32, 48, 64, 256)

    def __init__(self):
        """Configure window geometry, icon and styles, then embed the image widget."""
        super().__init__()

        print('Setting up main window')
        self.setGeometry(50, 50, 1500, 800)
        self.setWindowTitle("LBM-CaImAn-Python Pipeline")

        app_icon = QtGui.QIcon()
        icon_path = str(Path.home() / ".lbm" / "icons" / "icon_caiman_python.svg")
        for edge in self._ICON_SIZES:
            app_icon.addFile(icon_path, QtCore.QSize(edge, edge))
        self.setWindowIcon(app_icon)

        self.setStyleSheet("QMainWindow {background: 'black';}")
        self.stylePressed = ("QPushButton {Text-align: left; "
                             "background-color: rgb(100,50,100); "
                             "color:white;}")
        self.styleUnpressed = ("QPushButton {Text-align: left; "
                               "background-color: rgb(50,50,50); "
                               "color:white;}")
        self.styleInactive = ("QPushButton {Text-align: left; "
                              "background-color: rgb(50,50,50); "
                              "color:gray;}")

        print('Setting up image widget')
        self._image_widget = get_base_iw()
        # PreviewTracesWidget requires figure, size, location, title and
        # image_widget; passing only `size` raises TypeError.
        # NOTE(review): location and title are assumptions — confirm the
        # intended edge and caption.
        gui = PreviewTracesWidget(
            figure=self._image_widget.figure,
            size=50,
            location="bottom",
            title="Preview Traces",
            image_widget=self._image_widget,
        )
        self._image_widget.figure.add_gui(gui)
        qwidget = self._image_widget.show()
        self.setCentralWidget(qwidget)
        self.resize(1200, 800)

    @property
    def image_widget(self):
        """The fastplotlib ImageWidget embedded as the central widget."""
        return self._image_widget
73
+
74
+
75
class PreviewTracesWidget(EdgeWindow):
    """
    Edge GUI attached to an ImageWidget figure.

    Shows a folder-open button and plots, in a separate figure, the values
    along the first axis of the widget's data at the pixel the user clicks.
    """

    def __init__(self, figure, size, location, title, image_widget):
        """
        Parameters
        ----------
        figure, size, location, title
            Forwarded unchanged to ``EdgeWindow.__init__``.
        image_widget
            The ImageWidget whose first data array is sampled on click.
        """
        super().__init__(figure=figure, size=size, location=location, title=title)
        self._image_widget = image_widget

        # whether or not a dimension is in play mode
        self._playing: dict[str, bool] = {"t": False, "z": False}

        self.tfig = fpl.Figure()

        # one sample per index along the first axis of the data
        self.raw_trace = self.tfig[0, 0].add_line(np.zeros(self._image_widget.data[0].shape[0]))
        # the callback must be passed explicitly; with only the event name no
        # handler is attached and `pixel_clicked` would never run
        self._image_widget.managed_graphics[0].add_event_handler(self.pixel_clicked, "click")
        self.tfig.show()

    def pixel_clicked(self, ev):
        """
        Replace the trace with the values at the clicked pixel.

        Raises
        ------
        ValueError
            If the image widget is neither 3- nor 4-dimensional.
        """
        col, row = ev.pick_info["index"]
        if self._image_widget.ndim == 4:
            self.raw_trace.data[:, 1] = self._image_widget.data[0][:, self._image_widget.current_index["z"], row, col]
        elif self._image_widget.ndim == 3:
            self.raw_trace.data[:, 1] = self._image_widget.data[0][:, row, col]
        else:
            raise ValueError("ImageWidget has an unexpected number of dimensions. Expected 3 or 4.")
        self.tfig[0, 0].auto_scale(maintain_aspect=False)

    def update(self):
        """Draw the folder-open button and handle a click on it."""
        imgui.push_font(self._fa_icons)
        if imgui.button(label=fa.ICON_FA_FOLDER_OPEN):
            print("Opening file dialog")
            # NOTE(review): `self.parent` and `parent.update_widget` are not
            # defined in the visible code — confirm where they come from.
            dlg_kwargs = {
                "parent": self.parent,
                "caption": "Open folder with z-planes",
            }
            name = QFileDialog.getExistingDirectory(**dlg_kwargs)
            print(name)
            # an empty string means the dialog was cancelled
            if name:
                self.parent.update_widget(name)

        imgui.pop_font()
        if imgui.is_item_hovered(0):
            imgui.set_tooltip("Open a file dialog to load data")
@@ -0,0 +1,262 @@
1
+ import matplotlib.pyplot as plt
2
+ import numpy as np
3
+ from typing import Any as ArrayLike
4
+
5
+
6
+ def _get_30p_order():
7
+ return (np.array([
8
+ 1, 5, 6, 7, 8, 9, 2, 10, 11, 12, 13, 14, 15, 16, 17, 3, 18, 19, 20, 21, 22, 23, 4, 24, 25, 26, 27, 28, 29, 30
9
+ ]) - 1)
10
+
11
+
12
def extract_center_square(images, size):
    """
    Extract a square crop from the center of the input images.

    NOTE(review): this function is defined twice in this module; the later
    definition is the one in effect. One copy should be removed.

    Parameters
    ----------
    images : numpy.ndarray
        Input array. Can be 2D (H x W) or 3D (T x H x W), where:
        - H is the height of the image(s).
        - W is the width of the image(s).
        - T is the number of frames (if 3D).
    size : int
        The size of the square crop. The output will have dimensions
        (size x size) for 2D inputs or (T x size x size) for 3D inputs.

    Returns
    -------
    numpy.ndarray
        A square crop from the center of the input images. The returned array
        will have dimensions:
        - (size x size) if the input is 2D.
        - (T x size x size) if the input is 3D.

    Raises
    ------
    ValueError
        If `images` is not a NumPy array.
        If `images` is not 2D or 3D.
        If `size` is negative or larger than the height or width of the input images.

    Notes
    -----
    - For 3D arrays, the crop is applied uniformly across all frames (T).
    - The crop starts at ``center - size // 2`` along each axis, so odd sizes
      return exactly ``size`` pixels.

    Examples
    --------
    Extract a center square from a 2D image:

    >>> import numpy as np
    >>> image = np.random.rand(600, 576)
    >>> cropped = extract_center_square(image, size=200)
    >>> cropped.shape
    (200, 200)

    Extract a center square from a 3D stack of images:

    >>> stack = np.random.rand(100, 600, 576)
    >>> cropped_stack = extract_center_square(stack, size=200)
    >>> cropped_stack.shape
    (100, 200, 200)
    """
    if not isinstance(images, np.ndarray):
        raise ValueError("Input must be a numpy array.")
    if images.ndim not in (2, 3):
        raise ValueError("Input array must be 2D or 3D.")

    height, width = images.shape[-2:]
    # without this check a too-large size yields negative slice starts, which
    # numpy interprets relative to the end and returns a wrong or empty crop
    if size < 0 or size > height or size > width:
        raise ValueError(
            f"size ({size}) must be between 0 and the image dimensions ({height} x {width})."
        )

    top = height // 2 - size // 2
    left = width // 2 - size // 2
    # the ellipsis keeps the leading frame axis (if any) untouched
    return images[..., top:top + size, left:left + size]
84
+
85
+ def _get_30p_order():
86
+ return (np.array([
87
+ 1, 5, 6, 7, 8, 9, 2, 10, 11, 12, 13, 14, 15, 16, 17, 3, 18, 19, 20, 21, 22, 23, 4, 24, 25, 26, 27, 28, 29, 30
88
+ ]) - 1)
89
+
90
+
91
def extract_center_square(images, size):
    """
    Extract a square crop from the center of the input images.

    NOTE(review): an earlier definition of this function exists in this
    module and is shadowed by this one. One copy should be removed.

    Parameters
    ----------
    images : numpy.ndarray
        Input array. Can be 2D (H x W) or 3D (T x H x W), where:
        - H is the height of the image(s).
        - W is the width of the image(s).
        - T is the number of frames (if 3D).
    size : int
        The size of the square crop. The output will have dimensions
        (size x size) for 2D inputs or (T x size x size) for 3D inputs.

    Returns
    -------
    numpy.ndarray
        A square crop from the center of the input images. The returned array
        will have dimensions:
        - (size x size) if the input is 2D.
        - (T x size x size) if the input is 3D.

    Raises
    ------
    ValueError
        If `images` is not a NumPy array.
        If `images` is not 2D or 3D.
        If `size` is negative or larger than the height or width of the input images.

    Notes
    -----
    - For 3D arrays, the crop is applied uniformly across all frames (T).
    - The crop starts at ``center - size // 2`` along each axis, so odd sizes
      return exactly ``size`` pixels.

    Examples
    --------
    Extract a center square from a 2D image:

    >>> import numpy as np
    >>> image = np.random.rand(600, 576)
    >>> cropped = extract_center_square(image, size=200)
    >>> cropped.shape
    (200, 200)

    Extract a center square from a 3D stack of images:

    >>> stack = np.random.rand(100, 600, 576)
    >>> cropped_stack = extract_center_square(stack, size=200)
    >>> cropped_stack.shape
    (100, 200, 200)
    """
    if not isinstance(images, np.ndarray):
        raise ValueError("Input must be a numpy array.")
    if images.ndim not in (2, 3):
        raise ValueError("Input array must be 2D or 3D.")

    height, width = images.shape[-2:]
    # without this check a too-large size yields negative slice starts, which
    # numpy interprets relative to the end and returns a wrong or empty crop
    if size < 0 or size > height or size > width:
        raise ValueError(
            f"size ({size}) must be between 0 and the image dimensions ({height} x {width})."
        )

    top = height // 2 - size // 2
    left = width // 2 - size // 2
    # the ellipsis keeps the leading frame axis (if any) untouched
    return images[..., top:top + size, left:left + size]
163
+
164
+
165
def get_single_patch_coords(dims, stride, overlap, patch_index):
    """
    Get coordinates of a single patch based on stride, overlap parameters of motion-correction.

    Patches are square with side ``stride + overlap`` and their top-left
    corners lie on a grid spaced ``stride`` pixels apart; only patches that
    fit entirely inside ``dims`` are part of the grid.

    Parameters
    ----------
    dims : tuple
        Dimensions of the image as (rows, cols).
    stride : int
        Spacing in pixels between the starts of neighbouring patches.
    overlap : int
        Number of pixels to overlap between patches.
    patch_index : tuple
        (row index, column index) of the patch within the grid.

    Returns
    -------
    tuple
        (y_start, y_end, x_start, x_end) of the requested patch.
    """
    side = stride + overlap
    n_rows, n_cols = dims[0], dims[1]
    i, j = patch_index

    # candidate start positions along each axis
    y_starts = np.arange(0, n_rows - side + 1, stride)
    x_starts = np.arange(0, n_cols - side + 1, stride)

    top = y_starts[i]
    left = x_starts[j]
    return top, top + side, left, left + side
190
+
191
+
192
+ def _pad_image_for_even_patches(image, stride, overlap):
193
+ patch_width = stride + overlap
194
+ padded_x = int(np.ceil(image.shape[0] / patch_width) * patch_width) - image.shape[0]
195
+ padded_y = int(np.ceil(image.shape[1] / patch_width) * patch_width) - image.shape[1]
196
+ return np.pad(image, ((0, padded_x), (0, padded_y)), mode='constant'), padded_x, padded_y
197
+
198
+
199
def generate_patch_view(image: ArrayLike, pixel_resolution: float, target_patch_size: int = 40,
                        overlap_fraction: float = 0.5):
    """
    Generate a patch visualization for a 2D image with approximately square patches of a specified size in microns.
    The image is zero-padded to a multiple of the patch width before the patch grid is drawn on top of it.

    Parameters
    ----------
    image : ndarray
        A 2D NumPy array representing the input image to be divided into patches.
    pixel_resolution : float
        The pixel resolution of the image in microns per pixel. Must be positive.
    target_patch_size : float, optional
        The desired size of the patches in microns. Default is 40 microns.
    overlap_fraction : float, optional
        The fraction of the stride to use as overlap between patches. Default is 0.5 (50%).

    Returns
    -------
    fig : matplotlib.figure.Figure
        A matplotlib figure containing the patch visualization.
    ax : matplotlib.axes.Axes
        A matplotlib axes object showing the patch layout on the image.
    stride : int
        The stride in pixels derived from `target_patch_size` and `pixel_resolution`.
    overlap : int
        The overlap in pixels derived from `overlap_fraction` and `stride`.

    Raises
    ------
    ValueError
        If `pixel_resolution` is not positive, or if the resulting stride is
        smaller than one pixel.

    Examples
    --------
    >>> import numpy as np
    >>> from matplotlib import pyplot as plt
    >>> data = np.random.random((144, 600))  # Example 2D image
    >>> pixel_resolution = 0.5  # Microns per pixel
    >>> fig, ax, stride, overlap = generate_patch_view(data, pixel_resolution)
    >>> plt.show()
    """
    if pixel_resolution <= 0:
        raise ValueError(f"pixel_resolution must be positive, got {pixel_resolution}.")

    # Calculate stride and overlap in pixels
    stride = int(target_patch_size / pixel_resolution)
    if stride < 1:
        # a zero stride would make the patch width zero and divide by zero below
        raise ValueError(
            f"target_patch_size ({target_patch_size}) is smaller than one pixel "
            f"at pixel_resolution {pixel_resolution}."
        )
    overlap = int(overlap_fraction * stride)

    # imported lazily so the module can be imported without caiman installed
    from caiman.utils.visualization import get_rectangle_coords, rect_draw

    # pad the image like caiman does
    padded_image, pad_x, pad_y = _pad_image_for_even_patches(image, stride, overlap)

    # Get patch coordinates
    patch_rows, patch_cols = get_rectangle_coords(padded_image.shape, stride, overlap)

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(padded_image, cmap='gray')

    # Draw patches using rect_draw
    for patch_row in patch_rows:
        for patch_col in patch_cols:
            rect_draw(patch_row, patch_col, color='white', alpha=0.2, ax=ax)

    ax.set_title(f"Stride: {stride} pixels (~{stride * pixel_resolution:.1f} μm)\n"
                 f"Overlap: {overlap} pixels (~{overlap * pixel_resolution:.1f} μm)\n")
    plt.tight_layout()
    return fig, ax, stride, overlap