celldetective 1.4.2__py3-none-any.whl → 1.5.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (152) hide show
  1. celldetective/__init__.py +25 -0
  2. celldetective/__main__.py +62 -43
  3. celldetective/_version.py +1 -1
  4. celldetective/extra_properties.py +477 -399
  5. celldetective/filters.py +192 -97
  6. celldetective/gui/InitWindow.py +541 -411
  7. celldetective/gui/__init__.py +0 -15
  8. celldetective/gui/about.py +44 -39
  9. celldetective/gui/analyze_block.py +120 -84
  10. celldetective/gui/base/__init__.py +0 -0
  11. celldetective/gui/base/channel_norm_generator.py +335 -0
  12. celldetective/gui/base/components.py +249 -0
  13. celldetective/gui/base/feature_choice.py +92 -0
  14. celldetective/gui/base/figure_canvas.py +52 -0
  15. celldetective/gui/base/list_widget.py +133 -0
  16. celldetective/gui/{styles.py → base/styles.py} +92 -36
  17. celldetective/gui/base/utils.py +33 -0
  18. celldetective/gui/base_annotator.py +900 -767
  19. celldetective/gui/classifier_widget.py +6 -22
  20. celldetective/gui/configure_new_exp.py +777 -671
  21. celldetective/gui/control_panel.py +635 -524
  22. celldetective/gui/dynamic_progress.py +449 -0
  23. celldetective/gui/event_annotator.py +2023 -1662
  24. celldetective/gui/generic_signal_plot.py +1292 -944
  25. celldetective/gui/gui_utils.py +899 -1289
  26. celldetective/gui/interactions_block.py +658 -0
  27. celldetective/gui/interactive_timeseries_viewer.py +447 -0
  28. celldetective/gui/json_readers.py +48 -15
  29. celldetective/gui/layouts/__init__.py +5 -0
  30. celldetective/gui/layouts/background_model_free_layout.py +537 -0
  31. celldetective/gui/layouts/channel_offset_layout.py +134 -0
  32. celldetective/gui/layouts/local_correction_layout.py +91 -0
  33. celldetective/gui/layouts/model_fit_layout.py +372 -0
  34. celldetective/gui/layouts/operation_layout.py +68 -0
  35. celldetective/gui/layouts/protocol_designer_layout.py +96 -0
  36. celldetective/gui/pair_event_annotator.py +3130 -2435
  37. celldetective/gui/plot_measurements.py +586 -267
  38. celldetective/gui/plot_signals_ui.py +724 -506
  39. celldetective/gui/preprocessing_block.py +395 -0
  40. celldetective/gui/process_block.py +1678 -1831
  41. celldetective/gui/seg_model_loader.py +580 -473
  42. celldetective/gui/settings/__init__.py +0 -7
  43. celldetective/gui/settings/_cellpose_model_params.py +181 -0
  44. celldetective/gui/settings/_event_detection_model_params.py +95 -0
  45. celldetective/gui/settings/_segmentation_model_params.py +159 -0
  46. celldetective/gui/settings/_settings_base.py +77 -65
  47. celldetective/gui/settings/_settings_event_model_training.py +752 -526
  48. celldetective/gui/settings/_settings_measurements.py +1133 -964
  49. celldetective/gui/settings/_settings_neighborhood.py +574 -488
  50. celldetective/gui/settings/_settings_segmentation_model_training.py +779 -564
  51. celldetective/gui/settings/_settings_signal_annotator.py +329 -305
  52. celldetective/gui/settings/_settings_tracking.py +1304 -1094
  53. celldetective/gui/settings/_stardist_model_params.py +98 -0
  54. celldetective/gui/survival_ui.py +422 -312
  55. celldetective/gui/tableUI.py +1665 -1701
  56. celldetective/gui/table_ops/_maths.py +295 -0
  57. celldetective/gui/table_ops/_merge_groups.py +140 -0
  58. celldetective/gui/table_ops/_merge_one_hot.py +95 -0
  59. celldetective/gui/table_ops/_query_table.py +43 -0
  60. celldetective/gui/table_ops/_rename_col.py +44 -0
  61. celldetective/gui/thresholds_gui.py +382 -179
  62. celldetective/gui/viewers/__init__.py +0 -0
  63. celldetective/gui/viewers/base_viewer.py +700 -0
  64. celldetective/gui/viewers/channel_offset_viewer.py +331 -0
  65. celldetective/gui/viewers/contour_viewer.py +394 -0
  66. celldetective/gui/viewers/size_viewer.py +153 -0
  67. celldetective/gui/viewers/spot_detection_viewer.py +341 -0
  68. celldetective/gui/viewers/threshold_viewer.py +309 -0
  69. celldetective/gui/workers.py +403 -126
  70. celldetective/log_manager.py +92 -0
  71. celldetective/measure.py +1895 -1478
  72. celldetective/napari/__init__.py +0 -0
  73. celldetective/napari/utils.py +1025 -0
  74. celldetective/neighborhood.py +1914 -1448
  75. celldetective/preprocessing.py +1620 -1220
  76. celldetective/processes/__init__.py +0 -0
  77. celldetective/processes/background_correction.py +271 -0
  78. celldetective/processes/compute_neighborhood.py +894 -0
  79. celldetective/processes/detect_events.py +246 -0
  80. celldetective/processes/downloader.py +137 -0
  81. celldetective/processes/measure_cells.py +565 -0
  82. celldetective/processes/segment_cells.py +760 -0
  83. celldetective/processes/track_cells.py +435 -0
  84. celldetective/processes/train_segmentation_model.py +694 -0
  85. celldetective/processes/train_signal_model.py +265 -0
  86. celldetective/processes/unified_process.py +292 -0
  87. celldetective/regionprops/_regionprops.py +358 -317
  88. celldetective/relative_measurements.py +987 -710
  89. celldetective/scripts/measure_cells.py +313 -212
  90. celldetective/scripts/measure_relative.py +90 -46
  91. celldetective/scripts/segment_cells.py +165 -104
  92. celldetective/scripts/segment_cells_thresholds.py +96 -68
  93. celldetective/scripts/track_cells.py +198 -149
  94. celldetective/scripts/train_segmentation_model.py +324 -201
  95. celldetective/scripts/train_signal_model.py +87 -45
  96. celldetective/segmentation.py +844 -749
  97. celldetective/signals.py +3514 -2861
  98. celldetective/tracking.py +30 -15
  99. celldetective/utils/__init__.py +0 -0
  100. celldetective/utils/cellpose_utils/__init__.py +133 -0
  101. celldetective/utils/color_mappings.py +42 -0
  102. celldetective/utils/data_cleaning.py +630 -0
  103. celldetective/utils/data_loaders.py +450 -0
  104. celldetective/utils/dataset_helpers.py +207 -0
  105. celldetective/utils/downloaders.py +235 -0
  106. celldetective/utils/event_detection/__init__.py +8 -0
  107. celldetective/utils/experiment.py +1782 -0
  108. celldetective/utils/image_augmenters.py +308 -0
  109. celldetective/utils/image_cleaning.py +74 -0
  110. celldetective/utils/image_loaders.py +926 -0
  111. celldetective/utils/image_transforms.py +335 -0
  112. celldetective/utils/io.py +62 -0
  113. celldetective/utils/mask_cleaning.py +348 -0
  114. celldetective/utils/mask_transforms.py +5 -0
  115. celldetective/utils/masks.py +184 -0
  116. celldetective/utils/maths.py +351 -0
  117. celldetective/utils/model_getters.py +325 -0
  118. celldetective/utils/model_loaders.py +296 -0
  119. celldetective/utils/normalization.py +380 -0
  120. celldetective/utils/parsing.py +465 -0
  121. celldetective/utils/plots/__init__.py +0 -0
  122. celldetective/utils/plots/regression.py +53 -0
  123. celldetective/utils/resources.py +34 -0
  124. celldetective/utils/stardist_utils/__init__.py +104 -0
  125. celldetective/utils/stats.py +90 -0
  126. celldetective/utils/types.py +21 -0
  127. {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/METADATA +1 -1
  128. celldetective-1.5.0b1.dist-info/RECORD +187 -0
  129. {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/WHEEL +1 -1
  130. tests/gui/test_new_project.py +129 -117
  131. tests/gui/test_project.py +127 -79
  132. tests/test_filters.py +39 -15
  133. tests/test_notebooks.py +8 -0
  134. tests/test_tracking.py +232 -13
  135. tests/test_utils.py +123 -77
  136. celldetective/gui/base_components.py +0 -23
  137. celldetective/gui/layouts.py +0 -1602
  138. celldetective/gui/processes/compute_neighborhood.py +0 -594
  139. celldetective/gui/processes/downloader.py +0 -111
  140. celldetective/gui/processes/measure_cells.py +0 -360
  141. celldetective/gui/processes/segment_cells.py +0 -499
  142. celldetective/gui/processes/track_cells.py +0 -303
  143. celldetective/gui/processes/train_segmentation_model.py +0 -270
  144. celldetective/gui/processes/train_signal_model.py +0 -108
  145. celldetective/gui/table_ops/merge_groups.py +0 -118
  146. celldetective/gui/viewers.py +0 -1354
  147. celldetective/io.py +0 -3663
  148. celldetective/utils.py +0 -3108
  149. celldetective-1.4.2.dist-info/RECORD +0 -123
  150. {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/entry_points.txt +0 -0
  151. {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/licenses/LICENSE +0 -0
  152. {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,341 @@
1
+ import os
2
+ from glob import glob
3
+ from pathlib import Path
4
+
5
+ import numpy as np
6
+ from PyQt5.QtCore import QSize
7
+ from PyQt5.QtGui import QDoubleValidator
8
+ from PyQt5.QtWidgets import QMessageBox, QHBoxLayout, QLabel, QComboBox, QLineEdit, QPushButton
9
+ from fonticon_mdi6 import MDI6
10
+ from natsort import natsorted
11
+ from superqt.fonticon import icon
12
+
13
+ from celldetective.gui.gui_utils import PreprocessingLayout2
14
+ from celldetective.gui.viewers.base_viewer import StackVisualizer
15
+ from celldetective.utils.image_loaders import load_frames
16
+ from celldetective import get_logger
17
+
18
+ logger = get_logger(__name__)
19
+
20
class SpotDetectionVisualizer(StackVisualizer):
    """
    Stack viewer for interactively tuning spot (blob) detection.

    Spots are detected with ``extract_blobs_in_image`` on a selectable
    detection channel, restricted to the cell labels of the displayed frame,
    and drawn as red circles over the image. The "Add measurement" button
    copies the chosen channel, diameter, threshold and preprocessing filters
    back to the parent widgets and closes the viewer.

    Parameters
    ----------
    parent_channel_cb : QComboBox, optional
        Parent combo box that receives the detection channel index.
    parent_diameter_le : QLineEdit, optional
        Parent line edit that receives the spot diameter text.
    parent_threshold_le : QLineEdit, optional
        Parent line edit that receives the detection threshold text.
    parent_preprocessing_list : optional
        Parent preprocessing list widget; must provide ``clear()``,
        ``addItemToList()`` and an ``items`` attribute.
    cell_type : str
        Population name. Without in-memory labels, label images are read from
        the ``labels_{cell_type}`` folder next to the folder holding the stack.
    labels : array-like, optional
        In-memory labels, 3-D with time as the first axis ("direct" mode).
        When None, labels are read from disk ("virtual" mode).
    *args, **kwargs
        Forwarded to :class:`StackVisualizer`.
    """

    def __init__(
        self,
        parent_channel_cb=None,
        parent_diameter_le=None,
        parent_threshold_le=None,
        parent_preprocessing_list=None,
        cell_type="targets",
        labels=None,
        *args,
        **kwargs,
    ):

        super().__init__(*args, **kwargs)

        self.cell_type = cell_type
        self.labels = labels
        self.detection_channel = self.target_channel
        # Not set to True anywhere in this class; when True, change_frame keeps
        # the displayed spots instead of clearing them.
        self.switch_from_channel = False

        self.parent_channel_cb = parent_channel_cb
        self.parent_diameter_le = parent_diameter_le
        self.parent_threshold_le = parent_threshold_le
        self.parent_preprocessing_list = parent_preprocessing_list

        # Latest detection results: (x, y) positions and spot radii in pixels.
        self.spot_positions = np.empty((0, 2))
        self.spot_sizes = []
        # Frame whose labels / detection image are currently loaded.
        self._frame_index = None
        self.floatValidator = QDoubleValidator()
        self.init_scatter()

        self.generate_detection_channel()
        self.detection_channel = self.detection_channel_cb.currentIndex()

        self.generate_spot_detection_params()
        self.generate_add_measurement_btn()
        self.load_labels()
        self.change_frame(self.mid_time)

        # Rescale the circles whenever the view is zoomed or panned.
        self.ax.callbacks.connect("xlim_changed", self.update_marker_sizes)
        self.ax.callbacks.connect("ylim_changed", self.update_marker_sizes)

        self.apply_diam_btn.clicked.connect(self.detect_and_display_spots)
        self.apply_thresh_btn.clicked.connect(self.detect_and_display_spots)

        self.channel_cb.setCurrentIndex(self.target_channel)
        # If the index changes, currentIndexChanged reloads the detection
        # image for the target channel via set_detection_channel_index.
        self.detection_channel_cb.setCurrentIndex(self.target_channel)

    def update_marker_sizes(self, event=None):
        """Resize the spot circles so their radius matches the spot radius on screen.

        Connected to the axes' xlim/ylim callbacks; ``event`` is ignored.
        """
        if len(self.spot_sizes) == 0:
            return

        width_px = self.ax.bbox.width
        height_px = self.ax.bbox.height
        if width_px <= 0 or height_px <= 0:
            # Axes not laid out yet: no meaningful data-to-pixel scale.
            return

        xlim = self.ax.get_xlim()
        ylim = self.ax.get_ylim()
        # Data units per display pixel. abs() because imshow inverts the y axis,
        # which would otherwise make y_scale negative and always "smallest".
        x_scale = abs(float(xlim[1]) - float(xlim[0])) / width_px
        y_scale = abs(float(ylim[1]) - float(ylim[0])) / height_px
        scale = min(x_scale, y_scale)
        if scale == 0:
            return

        radius_px = np.asarray(self.spot_sizes, dtype=float) / scale
        radius_pts = radius_px * (72.0 / self.fig.dpi)
        # scatter's `s` is the square of the marker size in points, and for
        # the circle marker that size is the diameter: s = (2 * r) ** 2.
        sizes = (2.0 * radius_pts) ** 2

        self.spot_scat.set_sizes(sizes)
        self.fig.canvas.draw_idle()

    def init_scatter(self):
        """Create the (initially empty) scatter artist used to draw spots."""
        self.spot_scat = self.ax.scatter(
            [], [], s=50, facecolors="none", edgecolors="tab:red", zorder=100
        )
        self.canvas.canvas.draw()

    def _current_frame_index(self):
        """Return the frame whose labels are loaded, falling back to the slider."""
        frame = getattr(self, "_frame_index", None)
        if frame is None:
            frame = self.frame_slider.value()
        return frame

    def change_frame(self, value):
        """Display frame ``value`` and load its labels and detection image.

        Spots drawn for the previous frame are cleared.
        """
        super().change_frame(value)
        if not self.switch_from_channel:
            self.reset_detection()

        self._frame_index = value
        if self.mode == "virtual":
            from tifffile import imread

            # mask_paths is empty when no label images were found on disk.
            self.init_label = (
                imread(self.mask_paths[value]) if self.mask_paths else None
            )
            # NOTE(review): no float cast here, unlike set_detection_channel_index;
            # confirm extract_blobs_in_image is insensitive to the input dtype.
            self.target_img = load_frames(
                self.img_num_per_channel[self.detection_channel, value],
                self.stack_path,
                normalize_input=False,
            )[:, :, 0]
        elif self.mode == "direct":
            self.init_label = self.labels[value, :, :]
            self.target_img = self.stack[value, :, :, self.detection_channel].copy()

    def detect_and_display_spots(self):
        """Detect spots in the current frame with the GUI parameters and draw them."""
        self.reset_detection()
        self.control_valid_parameters()  # sets self.diameter and self.thresh

        if self.init_label is None:
            logger.warning("No labels available for this frame; spot detection skipped.")
            return

        image_preprocessing = self.preprocessing.list.items
        if image_preprocessing == []:
            image_preprocessing = None

        from celldetective.measure import extract_blobs_in_image

        blobs_filtered = extract_blobs_in_image(
            self.target_img,
            self.init_label,
            threshold=self.thresh,
            diameter=self.diameter,
            image_preprocessing=image_preprocessing,
        )

        if blobs_filtered is not None and len(blobs_filtered) > 0:
            # Blobs are unpacked as (y, x, sigma); scatter offsets are (x, y).
            self.spot_positions = np.array([[x, y] for y, x, _ in blobs_filtered])
            # The spot radius is taken as sqrt(2) * sigma.
            self.spot_sizes = np.sqrt(2) * np.array(
                [sig for _, _, sig in blobs_filtered]
            )
            self.spot_scat.set_offsets(self.spot_positions)
            self.update_marker_sizes()
        else:
            # Forget any previous detection so it is not redrawn on this frame.
            self.spot_positions = np.empty((0, 2))
            self.spot_sizes = []
            self.spot_scat.set_offsets(np.ma.masked_array([0, 0], mask=True))
        self.canvas.canvas.draw()

    def reset_detection(self):
        """Remove the drawn spots and forget the last detection results."""
        self.spot_positions = np.empty((0, 2))
        self.spot_sizes = []
        # A single fully-masked offset clears the displayed markers.
        self.spot_scat.set_offsets(np.ma.masked_array([0, 0], mask=True))
        self.canvas.canvas.draw()

    def load_labels(self):
        """Pick the label source: in-memory ``labels`` ("direct") or files on disk ("virtual")."""
        if self.labels is not None:

            if isinstance(self.labels, list):
                self.labels = np.array(self.labels)

            assert (
                self.labels.ndim == 3
            ), "Wrong dimensions for the provided labels, expect TXY"
            assert (
                len(self.labels) == self.stack_length
            ), "The provided labels and the stack have a different number of frames"

            self.mode = "direct"
            self.init_label = self.labels[self.mid_time, :, :]
        else:
            self.mode = "virtual"
            assert isinstance(self.stack_path, str)
            assert self.stack_path.endswith(".tif")
            self.locate_labels_virtual()

    def locate_labels_virtual(self):
        """Find the label images on disk and load the one for the slider's frame.

        If no label image exists, the user is warned, the viewer is closed and
        ``init_label`` is set to None.
        """
        stack_parent = str(Path(self.stack_path).parent.parent)
        labels_dir = os.path.join(stack_parent, f"labels_{self.cell_type}")
        self.mask_paths = natsorted(glob(os.path.join(labels_dir, "*.tif")))

        if len(self.mask_paths) == 0:
            msgBox = QMessageBox()
            msgBox.setIcon(QMessageBox.Critical)
            msgBox.setText("No labels were found for the selected cells. Abort.")
            msgBox.setWindowTitle("Critical")
            msgBox.setStandardButtons(QMessageBox.Ok)
            msgBox.exec()
            self.init_label = None
            self.close()
            # Stop here: there is no label image to index.
            return

        from tifffile import imread

        self.init_label = imread(self.mask_paths[self.frame_slider.value()])

    def generate_detection_channel(self):
        """Add the detection-channel combo box and the preprocessing editor."""
        assert self.channel_names is not None
        assert len(self.channel_names) == self.n_channels

        channel_layout = QHBoxLayout()
        channel_layout.setContentsMargins(15, 0, 15, 0)
        channel_layout.addWidget(QLabel("Detection\nchannel: "), 25)

        self.detection_channel_cb = QComboBox()
        self.detection_channel_cb.addItems(self.channel_names)
        # Connected after addItems so populating the box does not fire the slot.
        self.detection_channel_cb.currentIndexChanged.connect(
            self.set_detection_channel_index
        )
        channel_layout.addWidget(self.detection_channel_cb, 75)

        self.canvas.layout.addLayout(channel_layout)

        self.preprocessing = PreprocessingLayout2(fraction=25, parent_window=self)
        self.preprocessing.setContentsMargins(15, 0, 15, 0)
        self.canvas.layout.addLayout(self.preprocessing)

    def set_detection_channel_index(self, value):
        """Switch the detection channel and reload its image for the current frame."""
        self.detection_channel = value
        # Use the frame whose labels are loaded so image and labels stay paired.
        frame = self._current_frame_index()
        if self.mode == "direct":
            self.target_img = self.stack[frame, :, :, self.detection_channel].copy()
        elif self.mode == "virtual":
            self.target_img = load_frames(
                self.img_num_per_channel[self.detection_channel, frame],
                self.stack_path,
                normalize_input=False,
            ).astype(float)[:, :, 0]

    def generate_spot_detection_params(self):
        """Add the diameter and threshold inputs with their "Set" buttons."""
        self.spot_diam_le = QLineEdit("1")
        self.spot_diam_le.setValidator(self.floatValidator)
        self.apply_diam_btn = QPushButton("Set")
        self.apply_diam_btn.setStyleSheet(self.button_style_sheet_2)

        self.spot_thresh_le = QLineEdit("0")
        self.spot_thresh_le.setValidator(self.floatValidator)
        self.apply_thresh_btn = QPushButton("Set")
        self.apply_thresh_btn.setStyleSheet(self.button_style_sheet_2)

        self.spot_diam_le.textChanged.connect(self.control_valid_parameters)
        self.spot_thresh_le.textChanged.connect(self.control_valid_parameters)

        spot_diam_layout = QHBoxLayout()
        spot_diam_layout.setContentsMargins(15, 0, 15, 0)
        spot_diam_layout.addWidget(QLabel("Spot diameter: "), 25)
        spot_diam_layout.addWidget(self.spot_diam_le, 65)
        spot_diam_layout.addWidget(self.apply_diam_btn, 10)
        self.canvas.layout.addLayout(spot_diam_layout)

        spot_thresh_layout = QHBoxLayout()
        spot_thresh_layout.setContentsMargins(15, 0, 15, 0)
        spot_thresh_layout.addWidget(QLabel("Detection\nthreshold: "), 25)
        spot_thresh_layout.addWidget(self.spot_thresh_le, 65)
        spot_thresh_layout.addWidget(self.apply_thresh_btn, 10)
        self.canvas.layout.addLayout(spot_thresh_layout)

    def generate_add_measurement_btn(self):
        """Add the centered "Add measurement" button."""
        add_hbox = QHBoxLayout()
        self.add_measurement_btn = QPushButton("Add measurement")
        self.add_measurement_btn.clicked.connect(self.set_measurement_in_parent_list)
        self.add_measurement_btn.setIcon(icon(MDI6.plus, color="white"))
        self.add_measurement_btn.setIconSize(QSize(20, 20))
        self.add_measurement_btn.setStyleSheet(self.button_style_sheet)
        add_hbox.addWidget(QLabel(""), 33)
        add_hbox.addWidget(self.add_measurement_btn, 33)
        add_hbox.addWidget(QLabel(""), 33)
        self.canvas.layout.addLayout(add_hbox)

    def control_valid_parameters(self):
        """Parse diameter and threshold; enable the action buttons only if both are valid.

        Commas are accepted as decimal separators. On success the values are
        stored in ``self.diameter`` and ``self.thresh``.
        """
        try:
            self.diameter = float(self.spot_diam_le.text().replace(",", "."))
            valid_diam = True
        except ValueError:
            valid_diam = False

        try:
            self.thresh = float(self.spot_thresh_le.text().replace(",", "."))
            valid_thresh = True
        except ValueError:
            valid_thresh = False

        enabled = valid_diam and valid_thresh
        for button in (
            self.apply_diam_btn,
            self.apply_thresh_btn,
            self.add_measurement_btn,
        ):
            button.setEnabled(enabled)

    def set_measurement_in_parent_list(self):
        """Copy the current detection settings to the parent widgets and close."""
        if self.parent_channel_cb is not None:
            self.parent_channel_cb.setCurrentIndex(self.detection_channel)
        if self.parent_diameter_le is not None:
            self.parent_diameter_le.setText(self.spot_diam_le.text())
        if self.parent_threshold_le is not None:
            self.parent_threshold_le.setText(self.spot_thresh_le.text())
        if self.parent_preprocessing_list is not None:
            self.parent_preprocessing_list.clear()
            items = self.preprocessing.list.getItems()
            for item in items:
                self.parent_preprocessing_list.addItemToList(item)
            # Give the parent its own list rather than aliasing this viewer's.
            own_items = self.preprocessing.list.items
            self.parent_preprocessing_list.items = (
                list(own_items) if own_items is not None else None
            )
        self.close()
@@ -0,0 +1,309 @@
1
+ from collections import OrderedDict
2
+
3
+ import numpy as np
4
+ from PyQt5.QtWidgets import QLineEdit, QHBoxLayout, QPushButton, QLabel
5
+ from superqt import QLabeledDoubleSlider
6
+
7
+ from celldetective.gui.gui_utils import QuickSliderLayout
8
+ from celldetective.gui.viewers.base_viewer import StackVisualizer
9
+ from celldetective import get_logger
10
+
11
+ logger = get_logger(__name__)
12
+
13
+
14
class ThresholdedStackVisualizer(StackVisualizer):
    """
    A widget for visualizing thresholded image stacks with interactive sliders and channel selection.

    Parameters:
    - preprocessing (list or None): A list of preprocessing filters to apply to the image before thresholding.
    - parent_le: The parent QLineEdit instance that receives the threshold (must provide ``set_threshold``).
    - initial_threshold (float, pair or None): Initial threshold value, or a (low, high) range.
    - initial_mask_alpha (float): Initial mask opacity value (also the opacity slider's start value).
    - show_opacity_slider (bool): Whether the opacity slider is added to the layout.
    - show_threshold_slider (bool): Whether the threshold slider is added to the layout.
    - fill_holes (bool): Forwarded to ``threshold_image`` when computing the mask.
    - args, kwargs: Additional arguments to pass to the parent class constructor.

    Methods:
    - generate_apply_btn(): Generate the apply button to set the threshold in the parent QLineEdit.
    - set_threshold_in_parent_le(): Set the threshold value in the parent QLineEdit.
    - generate_mask_imshow(): Generate the mask imshow.
    - generate_threshold_slider(): Generate the threshold slider.
    - generate_opacity_slider(): Generate the opacity slider for the mask.
    - change_mask_opacity(value): Change the opacity of the mask.
    - change_threshold(value): Change the threshold value.
    - change_frame(value): Change the displayed frame and update the threshold.
    - compute_mask(threshold_value): Compute the mask based on the threshold value.
    - preprocess_image(): Preprocess the image before thresholding.
    - set_preprocessing(activation_protocol): Replace the preprocessing and refresh the display.

    Notes:
    - This class extends the functionality of StackVisualizer to visualize thresholded image stacks
      with interactive sliders for threshold and mask opacity adjustment.
    """

    def __init__(
        self,
        preprocessing=None,
        parent_le=None,
        initial_threshold=5,
        initial_mask_alpha=0.5,
        show_opacity_slider=True,
        show_threshold_slider=True,
        fill_holes=True,
        *args,
        **kwargs,
    ):
        # Initialize the widget and its attributes
        super().__init__(*args, **kwargs)
        self.preprocessing = preprocessing
        self.thresh = initial_threshold
        self.mask_alpha = initial_mask_alpha
        self.fill_holes = fill_holes
        self.parent_le = parent_le
        self.show_opacity_slider = show_opacity_slider
        self.show_threshold_slider = show_threshold_slider
        self.thresholded = False
        self.mask = np.zeros_like(self.init_frame)
        # Threshold slider bounds; preprocess_image widens them to the image range.
        self.thresh_min = 0.0
        self.thresh_max = 30.0

        # LRU cache of preprocessed frames, keyed by (channel, time, preprocessing).
        self.processed_cache = OrderedDict()
        self.processed_image = None
        self.max_processed_cache_size = 128

        # The slider must exist before compute_mask, which updates its range.
        self.generate_threshold_slider()

        if self.thresh is not None:
            self.compute_mask(self.thresh)

        self.generate_mask_imshow()
        self.generate_scatter()
        self.generate_opacity_slider()
        if isinstance(self.parent_le, QLineEdit):
            self.generate_apply_btn()

    def generate_apply_btn(self):
        """Add a centered "Apply" button that sends the threshold to ``parent_le``."""
        apply_hbox = QHBoxLayout()
        self.apply_threshold_btn = QPushButton("Apply")
        self.apply_threshold_btn.clicked.connect(self.set_threshold_in_parent_le)
        self.apply_threshold_btn.setStyleSheet(self.button_style_sheet)
        apply_hbox.addWidget(QLabel(""), 33)
        apply_hbox.addWidget(self.apply_threshold_btn, 33)
        apply_hbox.addWidget(QLabel(""), 33)
        self.canvas.layout.addLayout(apply_hbox)

    def closeEvent(self, event):
        """Release cached preprocessed images before closing."""
        if hasattr(self, "processed_cache") and isinstance(
            self.processed_cache, OrderedDict
        ):
            self.processed_cache.clear()
        super().closeEvent(event)

    def set_threshold_in_parent_le(self):
        """Send the slider's threshold to the parent line edit and close."""
        self.parent_le.set_threshold(self.threshold_slider.value())
        self.close()

    def generate_mask_imshow(self):
        """Draw the binary mask overlay; background (0) pixels are masked out."""
        self.im_mask = self.ax.imshow(
            np.ma.masked_where(self.mask == 0, self.mask),
            alpha=self.mask_alpha,
            interpolation="none",
            vmin=0,
            vmax=1,
            cmap="Purples",
        )
        self.canvas.canvas.draw()

    def _update_mask_display(self):
        """Push ``self.mask`` to the overlay, hiding background pixels."""
        self.im_mask.set_data(np.ma.masked_where(self.mask == 0, self.mask))

    def generate_scatter(self):
        """Create an empty scatter artist for markers."""
        self.scat_markers = self.ax.scatter([], [], color="tab:red")

    def generate_threshold_slider(self):
        """Create the threshold slider; for a (low, high) threshold it shows ``low``."""
        self.threshold_slider = QLabeledDoubleSlider()
        if self.thresh is None:
            init_value = 1.0e5
        elif isinstance(self.thresh, (list, tuple, np.ndarray)):
            init_value = self.thresh[0]
        else:
            init_value = self.thresh
        thresh_layout = QuickSliderLayout(
            label="Threshold: ",
            slider=self.threshold_slider,
            slider_initial_value=init_value,
            slider_range=(self.thresh_min, np.amax([self.thresh_max, init_value])),
            decimal_option=True,
            precision=4,
        )
        thresh_layout.setContentsMargins(15, 0, 15, 0)
        self.threshold_slider.valueChanged.connect(self.change_threshold)
        if self.show_threshold_slider:
            self.canvas.layout.addLayout(thresh_layout)

    def generate_opacity_slider(self):
        """Create the mask opacity slider, starting at ``self.mask_alpha``."""
        self.opacity_slider = QLabeledDoubleSlider()
        opacity_layout = QuickSliderLayout(
            label="Opacity: ",
            slider=self.opacity_slider,
            # Start from the configured opacity so slider and overlay agree.
            slider_initial_value=self.mask_alpha,
            slider_range=(0, 1),
            decimal_option=True,
            precision=3,
        )
        opacity_layout.setContentsMargins(15, 0, 15, 0)
        self.opacity_slider.valueChanged.connect(self.change_mask_opacity)
        if self.show_opacity_slider:
            self.canvas.layout.addLayout(opacity_layout)

    def change_mask_opacity(self, value):
        """Set the overlay opacity to ``value``."""
        self.mask_alpha = value
        self.im_mask.set_alpha(self.mask_alpha)
        self.canvas.canvas.draw_idle()

    def change_threshold(self, value):
        """Apply a new threshold (scalar or (low, high)) and refresh the mask.

        When called from outside the slider (e.g. a wizard), the slider is
        synced without re-emitting valueChanged, so it does not later push a
        stale value back.
        """
        self.thresh = value

        if hasattr(self, "threshold_slider"):
            display_val = value
            if isinstance(value, (list, tuple, np.ndarray)):
                display_val = value[0]

            try:
                target = float(display_val)
                needs_sync = abs(self.threshold_slider.value() - target) > 1e-5
            except (TypeError, ValueError):
                # Non-numeric value (e.g. None): leave the slider untouched.
                needs_sync = False

            if needs_sync:
                self.threshold_slider.blockSignals(True)
                try:
                    self.threshold_slider.setValue(target)
                finally:
                    # Never leave the slider's signals blocked.
                    self.threshold_slider.blockSignals(False)

        if self.thresh is not None:
            self.compute_mask(self.thresh)
            self._update_mask_display()
            self.canvas.canvas.draw_idle()

    def change_frame(self, value):
        """Display frame ``value`` and recompute the mask for it."""
        # After set_preprocessing, let the base class re-derive the contrast.
        if self.thresholded:
            self.init_contrast = True
        super().change_frame(value)
        self.processed_image = None

        if self.thresh is not None:
            self.change_threshold(self.thresh)
        else:
            self.change_threshold(self.threshold_slider.value())

        if self.thresholded:
            self.thresholded = False
            self.init_contrast = False

    def compute_mask(self, threshold_value):
        """Threshold the preprocessed image into ``self.mask`` (0/1 ints).

        A scalar threshold keeps values above it; a (low, high) pair keeps
        values in that range. The unreliable border implied by the
        preprocessing filters is excluded.
        """
        if self.processed_image is None:
            self.preprocess_image()

        from celldetective.utils.image_transforms import (
            estimate_unreliable_edge,
            threshold_image,
        )

        edge = estimate_unreliable_edge(self.preprocessing)

        if isinstance(threshold_value, (list, np.ndarray, tuple)):
            low, high = threshold_value[0], threshold_value[1]
        else:
            low, high = threshold_value, np.inf

        self.mask = threshold_image(
            self.processed_image,
            low,
            high,
            foreground_value=1,
            fill_holes=self.fill_holes,
            edge_exclusion=edge,
        ).astype(int)

    def _widen_threshold_range(self, image):
        """Extend the threshold slider range to cover the values of ``image``."""
        # Subsample large images: an approximate min/max is enough for a slider.
        if image.size > 1000000:
            view = image[::30, ::30]
        else:
            view = image

        min_ = np.nanmin(view)
        max_ = np.nanmax(view)

        if min_ < self.thresh_min:
            self.thresh_min = min_
        if max_ > self.thresh_max:
            self.thresh_max = max_

        self.threshold_slider.setRange(self.thresh_min, self.thresh_max)

    def preprocess_image(self):
        """Fill ``self.processed_image`` from ``init_frame`` and the preprocessing filters.

        Filtered frames are cached (LRU, ``max_processed_cache_size`` entries)
        only when the current time index is known; otherwise distinct frames
        would collide on the same cache key.
        """
        if self.preprocessing is None:
            # No filters: a float copy of the frame, not worth caching.
            self.processed_image = self.init_frame.astype(float)
            return

        time_idx = getattr(self, "current_time_index", None)
        cache_key = None
        if time_idx is not None:
            cache_key = (self.target_channel, time_idx, str(self.preprocessing))
            if cache_key in self.processed_cache:
                self.processed_image = self.processed_cache[cache_key]
                self.processed_cache.move_to_end(cache_key)
                return

        assert isinstance(self.preprocessing, list)
        from celldetective.filters import filter_image

        # astype() already returns a new array, so the frame is not modified.
        self.processed_image = filter_image(
            self.init_frame.astype(float), filters=self.preprocessing
        )
        self._widen_threshold_range(self.processed_image)

        if cache_key is not None:
            self.processed_cache[cache_key] = self.processed_image
            if len(self.processed_cache) > self.max_processed_cache_size:
                self.processed_cache.popitem(last=False)

    def set_preprocessing(self, activation_protocol):
        """Replace the preprocessing filters and show the preprocessed image.

        Contrast limits are reset to the 1st-99.99th percentiles, and the mask
        overlay is recomputed so it matches the new preprocessing.
        """
        self.preprocessing = activation_protocol
        self.preprocess_image()

        self.im.set_data(self.processed_image)
        vmin = np.nanpercentile(self.processed_image, 1.0)
        vmax = np.nanpercentile(self.processed_image, 99.99)
        self.contrast_slider.setRange(
            np.nanmin(self.processed_image), np.nanmax(self.processed_image)
        )
        self.contrast_slider.setValue((vmin, vmax))
        self.im.set_clim(vmin, vmax)

        # Without this the overlay keeps the mask of the previous preprocessing.
        if self.thresh is not None and getattr(self, "im_mask", None) is not None:
            self.compute_mask(self.thresh)
            self._update_mask_display()

        self.canvas.canvas.draw()
        self.thresholded = True