napari-spatial-correlation-plotter 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- napari_spatial_correlation_plotter/__init__.py +6 -0
- napari_spatial_correlation_plotter/_nice_colormap.py +260 -0
- napari_spatial_correlation_plotter/_tests/__init__.py +0 -0
- napari_spatial_correlation_plotter/_widget.py +1059 -0
- napari_spatial_correlation_plotter/napari.yaml +14 -0
- napari_spatial_correlation_plotter-0.0.1.dist-info/LICENSE +22 -0
- napari_spatial_correlation_plotter-0.0.1.dist-info/METADATA +141 -0
- napari_spatial_correlation_plotter-0.0.1.dist-info/RECORD +11 -0
- napari_spatial_correlation_plotter-0.0.1.dist-info/WHEEL +5 -0
- napari_spatial_correlation_plotter-0.0.1.dist-info/entry_points.txt +2 -0
- napari_spatial_correlation_plotter-0.0.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1059 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from pathlib import Path as PathL
|
|
3
|
+
from time import time
|
|
4
|
+
|
|
5
|
+
import matplotlib.pyplot as plt
|
|
6
|
+
import napari
|
|
7
|
+
import numpy as np
|
|
8
|
+
from magicgui.widgets import Container, EmptyWidget, create_widget
|
|
9
|
+
from matplotlib.backends.backend_qt5agg import (
|
|
10
|
+
FigureCanvasQTAgg as FigureCanvas,
|
|
11
|
+
)
|
|
12
|
+
from matplotlib.backends.backend_qt5agg import (
|
|
13
|
+
NavigationToolbar2QT as NavigationToolbar,
|
|
14
|
+
)
|
|
15
|
+
from matplotlib.figure import Figure
|
|
16
|
+
from matplotlib.path import Path
|
|
17
|
+
from matplotlib.widgets import LassoSelector, RectangleSelector
|
|
18
|
+
from napari.layers import Image, Labels, Layer
|
|
19
|
+
from napari.utils import DirectLabelColormap
|
|
20
|
+
from qtpy.QtCore import Qt
|
|
21
|
+
from qtpy.QtGui import QGuiApplication, QIcon
|
|
22
|
+
from skimage.measure import regionprops
|
|
23
|
+
from tapenade.analysis.spatial_correlation import SpatialCorrelationPlotter
|
|
24
|
+
from tapenade.preprocessing import masked_gaussian_smooth_dense_two_arrays_gpu
|
|
25
|
+
from vispy.color import Color
|
|
26
|
+
|
|
27
|
+
from napari_spatial_correlation_plotter._nice_colormap import get_nice_colormap
|
|
28
|
+
|
|
29
|
+
# Folder with the pan/zoom toolbar icons used by FigureToolbar.
ICON_ROOT = PathL(__file__).parent / "icons"

# TODO:
# - add log scale to heatmap colors

# Palette returned by `get_nice_colormap` (presumably hex color strings),
# converted to RGBA float arrays scaled from 0-255 down to [0, 1].
colors = get_nice_colormap()
cmap = [Color(color_code).RGBA.astype("float") / 255 for color_code in colors]
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def in_bbox(min_x, max_x, min_y, max_y, xys):
    """Return a boolean mask of the points of *xys* lying inside a closed box.

    Parameters
    ----------
    min_x, max_x, min_y, max_y : float
        Inclusive bounds of the axis-aligned box.
    xys : array of shape (n_points, 2)
        Point coordinates; column 0 holds x and column 1 holds y.

    Returns
    -------
    numpy.ndarray of bool, shape (n_points,)
        True for every point whose x and y both fall within the bounds.
    """
    lower = np.array([min_x, min_y]).reshape(1, 2)
    upper = np.array([max_x, max_y]).reshape(1, 2)

    # Per-coordinate test first, then require both coordinates to pass.
    within_per_axis = (xys >= lower) & (xys <= upper)
    return within_per_axis[:, 0] & within_per_axis[:, 1]
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
# Class below was based upon matplotlib lasso selection example:
# https://matplotlib.org/stable/gallery/widgets/lasso_selector_demo_sgskip.html
class SelectFromCollection:
    """
    Select points of a plot with a matplotlib `LassoSelector`.

    On every lasso release, the boolean mask of the points enclosed by the
    drawn polygon is stored in `ind_mask` and passed to
    ``parent.manual_clustering_method`` (when that attribute is not None).
    Unlike the matplotlib example it is based on, this class does not modify
    the alpha values of any artist.

    Parameters
    ----------
    parent : object
        Object exposing a ``manual_clustering_method`` attribute, either None
        or a callable taking the boolean selection mask.
    ax : `~matplotlib.axes.Axes`
        Axes to interact with; the lasso reacts to mouse button 1.
    xys : array of shape (n_points, 2) or None
        Data coordinates of the selectable points. When None, selections
        are ignored.
    alpha_other : float, optional
        Unused; kept only so existing call signatures keep working.
    """

    def __init__(self, parent, ax, xys, alpha_other=0.3):
        self.canvas = ax.figure.canvas
        self.parent = parent
        self.xys = xys

        self.lasso = LassoSelector(ax, onselect=self.onselect, button=1)
        # `ind` is never filled by this class; it is kept as a public
        # attribute for backward compatibility. `ind_mask` holds the latest
        # selection mask.
        self.ind = []
        self.ind_mask = []

    def onselect(self, verts):
        """Compute the mask of points inside the lasso polygon *verts*."""
        if self.xys is None:
            # Nothing selectable has been plotted yet.
            return

        verts = np.asarray(verts)
        min_x, min_y = np.min(verts, axis=0)
        max_x, max_y = np.max(verts, axis=0)

        # Cheap bounding-box prefilter: only the points inside the polygon's
        # bounding box go through the exact point-in-polygon test.
        ind_mask = in_bbox(min_x, max_x, min_y, max_y, self.xys)
        path = Path(verts)
        ind_mask[ind_mask] = path.contains_points(self.xys[ind_mask])
        self.ind_mask = ind_mask

        self.canvas.draw_idle()

        if self.parent.manual_clustering_method is not None:
            self.parent.manual_clustering_method(self.ind_mask)

    def disconnect(self):
        """Stop listening to lasso events and request a redraw."""
        self.lasso.disconnect_events()
        self.canvas.draw_idle()
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
class MplCanvas(FigureCanvas):
    """Qt canvas showing a matplotlib figure styled for napari's dark theme.

    When ``create_selectors`` is True, a lasso selector (mouse button 1) and a
    rectangle selector (mouse button 3, right button) are attached to the
    main axes. Each selection becomes a boolean mask over ``xys`` that is
    passed to ``manual_clustering_method`` (when it is not None).

    Parameters
    ----------
    xys : array of shape (n_points, 2) or None
        Data coordinates of the selectable points.
    parent : matplotlib.figure.Figure, optional
        Existing figure to display. When None, a new ``width`` x ``height``
        inch figure with a single subplot is created.
    width, height : float
        Figure size in inches.
    manual_clustering_method : callable or None
        Receives the boolean selection mask after each selection.
    create_selectors : bool
        Whether to attach the lasso and rectangle selectors.
    """

    def __init__(
        self,
        xys,
        parent=None,
        width=7,
        height=4,
        manual_clustering_method=None,
        create_selectors=False,
    ):

        self.xys = xys

        if parent is None:
            self.fig = Figure(figsize=(width, height))
            self.axes = self.fig.add_subplot(111)
        else:
            self.fig = parent
            if len(self.fig.axes) == 0:
                self.fig.add_subplot(111)
            self.axes = self.fig.axes[0]
            # figure size
            self.fig.set_size_inches(width, height)
        self.manual_clustering_method = manual_clustering_method
        self.fig.tight_layout()

        super().__init__(self.fig)

        self.reset_params(create_selectors=create_selectors, xys=xys)

    def reset_params(self, create_selectors, xys):
        """(Re)apply the dark styling and optionally (re)create the selectors.

        The first axes of the figure is treated as the main plot. A second
        axes, if present, is assumed to hold the colorbar of the main axes'
        first image.
        """
        self.axes = self.fig.axes[0]
        # The rectangle selector reads self.xys and the lasso is built from
        # `xys`; keep both in sync.
        self.xys = xys

        if len(self.axes.collections) == 0:
            self.pts = self.axes.scatter([], [])
        self.pts = self.axes.collections[0]

        self.fig.patch.set_facecolor("#262930")

        # Plot background: white while selectors are active, otherwise the
        # napari main window color.
        self.axes.set_facecolor("white" if create_selectors else "#262930")

        # changing colors of all axes
        for side in ("bottom", "top", "right", "left"):
            self.axes.spines[side].set_color("white")
        self.axes.xaxis.label.set_color("white")
        self.axes.yaxis.label.set_color("white")

        # changing colors of axes labels
        self.axes.tick_params(axis="x", colors="white")
        self.axes.tick_params(axis="y", colors="white")

        # COLORBAR: restyle an already existing colorbar. Checking for an
        # image and for an attached colorbar avoids IndexError or
        # AttributeError when the extra axes is not an image colorbar.
        if len(self.fig.axes) > 1 and len(self.axes.images) > 0:
            cb = self.axes.images[0].colorbar
            if cb is not None:
                cb.set_label(cb.ax.get_ylabel(), color="white")
                cb.ax.yaxis.set_tick_params(color="white")
                cb.outline.set_edgecolor("white")
                plt.setp(plt.getp(cb.ax.axes, "yticklabels"), color="white")

        if create_selectors:
            self.selector = SelectFromCollection(self, self.axes, xys)
            # Rectangle
            self.rectangle_selector = RectangleSelector(
                self.axes,
                self.draw_rectangle,
                useblit=True,
                props=dict(edgecolor="#1f77b4", fill=False),
                button=3,  # right button
                minspanx=5,
                minspany=5,
                spancoords="pixels",
                interactive=False,
            )

    def draw_rectangle(self, eclick, erelease):
        """Select the points inside the dragged rectangle.

        eclick and erelease are the press and release events.
        """
        x0, y0 = eclick.xdata, eclick.ydata
        x1, y1 = erelease.xdata, erelease.ydata
        min_x, max_x = min(x0, x1), max(x0, x1)
        min_y, max_y = min(y0, y1), max(y0, y1)

        self.rect_ind_mask = in_bbox(min_x, max_x, min_y, max_y, self.xys)

        if self.manual_clustering_method is not None:
            self.manual_clustering_method(self.rect_ind_mask)

    def reset(self):
        """Clear the main axes."""
        self.axes.clear()
        # NOTE(review): `is_pressed` is not read anywhere in the visible code.
        self.is_pressed = None
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
class FigureToolbar(NavigationToolbar):
    """Matplotlib navigation toolbar for a dark-themed canvas.

    The canvas is expected to expose ``fig`` and ``axes`` attributes. The
    toolbar swaps the pan/zoom icons for checked/unchecked variants found in
    ``ICON_ROOT``. It saves figures with a light theme (white background,
    black axes and colorbar) and then restores the on-screen theme.
    """

    def __init__(self, canvas):
        super().__init__(canvas, None)
        self.canvas = canvas

    def _update_buttons_checked(self):
        """Refresh button states and use icons matching the pan/zoom state."""
        super()._update_buttons_checked()
        for action_key, icon_stem in (("pan", "Pan"), ("zoom", "Zoom")):
            if action_key not in self._actions:
                continue
            action = self._actions[action_key]
            suffix = "_checked" if action.isChecked() else ""
            action.setIcon(
                QIcon(os.path.join(ICON_ROOT, f"{icon_stem}{suffix}.png"))
            )

    def _apply_plot_theme(self, foreground, figure_background, axes_background):
        """Color the figure, main axes and (if present) its colorbar."""
        canvas = self.canvas
        canvas.fig.patch.set_facecolor(figure_background)
        canvas.axes.set_facecolor(axes_background)

        for side in ("bottom", "top", "right", "left"):
            canvas.axes.spines[side].set_color(foreground)
        canvas.axes.xaxis.label.set_color(foreground)
        canvas.axes.yaxis.label.set_color(foreground)
        canvas.axes.tick_params(axis="x", colors=foreground)
        canvas.axes.tick_params(axis="y", colors=foreground)

        # A colorbar can exist only when the figure has an extra axes and the
        # main axes holds an image with an attached colorbar. The previous
        # `len(fig.axes) > 0` test was always true and crashed on plots
        # without an image.
        if len(canvas.fig.axes) > 1 and len(canvas.axes.images) > 0:
            cb = canvas.axes.images[0].colorbar
            if cb is not None:
                cb.set_label(cb.ax.get_ylabel(), color=foreground)
                cb.ax.yaxis.set_tick_params(color=foreground)
                cb.outline.set_edgecolor(foreground)
                plt.setp(plt.getp(cb.ax.axes, "yticklabels"), color=foreground)

    def save_figure(self):
        """Save the figure with a light theme, then restore the on-screen theme."""
        canvas = self.canvas
        # Remember the current backgrounds rather than assuming the napari
        # dark color: the axes background is white while selectors are on.
        figure_background = canvas.fig.patch.get_facecolor()
        axes_background = canvas.axes.get_facecolor()

        self._apply_plot_theme("black", "white", "white")
        try:
            super().save_figure()
        finally:
            # Always restore the on-screen look, even if saving failed.
            self._apply_plot_theme("white", figure_background, axes_background)
            canvas.draw()
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
class PlotterWidget(Container):
|
|
321
|
+
def __init__(self, napari_viewer):
    """Build the widget: layer pickers, smoothing controls, plot and display options.

    Parameters
    ----------
    napari_viewer : napari.Viewer
        Viewer whose layers populate the pickers and which receives the
        smoothed layers.
    """
    super().__init__()

    # Runtime state used by the computation / plotting callbacks.
    self.cluster_ids = None
    self._viewer = napari_viewer
    self.cluster_labels_layer = None
    self.quantityX_smoothed_layer = None
    self.quantityY_smoothed_layer = None
    self.quantityX_labels_choices_displayed = False
    self.quantityY_labels_choices_displayed = False
    self.histogram_displayed = False
    self.figure = None
    self.labels_method_choices = ["cellular density", "volume fraction"]
    self._hidden_features = {}

    def as_container_child(qt_widget):
        # Give a raw Qt widget the attributes a magicgui Container reads from
        # its children, so it can be added directly.
        qt_widget.native = qt_widget
        qt_widget._explicitly_hidden = False
        qt_widget.name = ""

    def quantity_picker(label):
        # Layer combo plus the (initially hidden) labels-method row under it.
        layer_combo = create_widget(
            annotation=Layer,
            label=label,
            options={"choices": self._image_labels_layers_filter},
        )
        layer_combo.changed.connect(self._update_quantities_labels_choices)
        method_combo = create_widget(
            widget_type="ComboBox",
            options={"choices": self.labels_method_choices},
        )
        method_row = Container(
            widgets=[method_combo],
            labels=False,
            layout="horizontal",
        )
        return layer_combo, method_combo, method_row

    def redraw_on_change(widget):
        # Every display option triggers a redraw of the current heatmap.
        widget.changed.connect(self.parameters_changed)
        return widget

    # Placeholder canvas + toolbar shown until a heatmap is computed.
    self.graphics_widget = MplCanvas(
        manual_clustering_method=self.manual_clustering_method,
        xys=None,
    )
    self.toolbar = FigureToolbar(self.graphics_widget)
    as_container_child(self.toolbar)
    as_container_child(self.graphics_widget)
    self.graph_container = Container(
        widgets=[self.toolbar, self.graphics_widget],
        labels=False,
    )

    (
        self.quantityX_layer_combo,
        self.quantityX_labels_choices_combo,
        self.quantityX_labels_choices_container,
    ) = quantity_picker("Quantity X")
    (
        self.quantityY_layer_combo,
        self.quantityY_labels_choices_combo,
        self.quantityY_labels_choices_container,
    ) = quantity_picker("Quantity Y")

    self.mask_layer_combo = create_widget(
        annotation=Image,
        label="Mask layer",
        options={"nullable": True, "choices": self._bool_layers_filter},
    )
    self.labels_layer_combo = create_widget(
        annotation=Labels,
        label="Labels layer",
        options={"nullable": True},
    )
    self.blur_sigma_slider = create_widget(
        widget_type="IntSlider",
        label="Blur sigma",
        options={"min": 0, "max": 50, "value": 1},
    )

    self.run_button = create_widget(
        widget_type="PushButton",
        label="Compute correlation heatmap",
    )
    self.run_button.clicked.connect(self.run)

    self.show_individual_cells_checkbox = redraw_on_change(
        create_widget(annotation=bool, label="Show individual cells")
    )
    self.show_linear_fit_checkbox = redraw_on_change(
        create_widget(annotation=bool, label="Show linear fit")
    )
    # NOTE: a "Normalize quantities" checkbox is intentionally absent:
    # normalization is currently broken with manual selection.
    self.display_quadrants = redraw_on_change(
        create_widget(annotation=bool, label="Display quadrants")
    )

    self.options_container1 = Container(
        widgets=[
            self.show_individual_cells_checkbox,
            self.show_linear_fit_checkbox,
        ],
        labels=False,
        layout="horizontal",
    )
    self.options_container2 = Container(
        widgets=[self.display_quadrants],
        labels=False,
        layout="horizontal",
    )

    self.heatmap_binsX, self.heatmap_binsY = (
        redraw_on_change(
            create_widget(
                widget_type="IntSlider",
                label=axis_name,
                value=40,
                options={"min": 2, "max": 100, "tracking": True},
            )
        )
        for axis_name in ("X", "Y")
    )
    self.heatmap_bins_container = Container(
        widgets=[self.heatmap_binsX, self.heatmap_binsY],
        labels=True,
        label="Heatmap bins",
        layout="horizontal",
    )

    self.percentilesX, self.percentilesY = (
        redraw_on_change(
            create_widget(
                widget_type="FloatRangeSlider",
                label=axis_name,
                options={
                    "min": 0,
                    "max": 100,
                    "value": [0, 100],
                    "tracking": True,
                },
            )
        )
        for axis_name in ("X", "Y")
    )
    self.percentiles_container = Container(
        widgets=[self.percentilesX, self.percentilesY],
        labels=True,
        label="Percentiles",
        layout="horizontal",
    )

    parameters_heading = EmptyWidget(label="<u>Parameters:</u>")
    display_heading = EmptyWidget(label="<u>Display Parameters:</u>")

    self.extend(
        [
            parameters_heading,
            self.quantityX_layer_combo,
            self.quantityY_layer_combo,
            self.mask_layer_combo,
            self.labels_layer_combo,
            self.blur_sigma_slider,
            self.run_button,
            self.graph_container,
            display_heading,
            self.options_container1,
            self.options_container2,
            self.heatmap_bins_container,
            self.percentiles_container,
        ]
    )

    # takes care of case where this isn't set yet directly after init
    self.plot_cluster_name = None
    self.id = 0
|
|
575
|
+
|
|
576
|
+
def manual_clustering_method(self, inside):
    """Record a manual point selection and redraw the plot with it.

    Parameters
    ----------
    inside : array-like of bool, shape (n_points,)
        Mask of the plotted points enclosed by the lasso/rectangle.

    Without modifiers, the selection replaces any previous one as a 0/1
    labelling. Holding Shift adds it as a new cluster with ID
    ``max(previous IDs) + 1`` on top of the previous selection. This only
    happens when that selection covers the same number of points; a stale
    selection made on other data, e.g. before re-running, is replaced.
    """
    # Selectors may hand over lists; bool dtype guarantees mask indexing.
    inside = np.asarray(inside, dtype=bool)
    if inside.size == 0:
        return  # if nothing was plotted yet, leave

    clustering_ID = "MANUAL_CLUSTER_ID"
    former_clusters = self._hidden_features.get(clustering_ID)

    modifiers = QGuiApplication.keyboardModifiers()
    if (
        modifiers == Qt.ShiftModifier
        and former_clusters is not None
        # Indexing a mask of another length would raise IndexError.
        and np.shape(former_clusters) == inside.shape
    ):
        former_clusters[inside] = np.max(former_clusters) + 1
        self._hidden_features[clustering_ID] = former_clusters
    else:
        self._hidden_features[clustering_ID] = inside.astype(int)

    # redraw the whole plot
    self.draw_cluster_labels(
        self._hidden_features,
        plot_cluster_name=clustering_ID,
    )
|
|
600
|
+
|
|
601
|
+
def run(self):
    """Read the inputs, smooth both quantities and draw their correlation heatmap.

    The selected layers' data, names, colormaps and labels settings are
    stored on ``self``; ``parameters_changed`` reuses some of them for
    redraws. The smoothed quantities are shown as viewer layers, and an
    existing cluster labels layer is cleared.
    """
    # Check if all necessary layers are specified
    layer_X = self.quantityX_layer_combo.value
    if layer_X is None:
        napari.utils.notifications.show_warning(
            "Please specify quantityX_layer"
        )
        return
    (
        self.quantityX,
        self.quantityX_label,
        self.quantityX_colormap,
        self.quantityX_is_labels,
        self.quantityX_labels_choice,
    ) = self._read_quantity_layer_settings(
        layer_X, self.quantityX_labels_choices_combo
    )

    layer_Y = self.quantityY_layer_combo.value
    if layer_Y is None:
        napari.utils.notifications.show_warning(
            "Please specify quantityY_layer"
        )
        return
    (
        self.quantityY,
        self.quantityY_label,
        self.quantityY_colormap,
        self.quantityY_is_labels,
        self.quantityY_labels_choice,
    ) = self._read_quantity_layer_settings(
        layer_Y, self.quantityY_labels_choices_combo
    )

    # Both quantities are smoothed and correlated voxel by voxel.
    shape_X, shape_Y = np.shape(self.quantityX), np.shape(self.quantityY)
    if shape_X != shape_Y:
        napari.utils.notifications.show_warning(
            "Quantity X and Quantity Y must have the same shape, "
            f"got {shape_X} and {shape_Y}"
        )
        return

    mask_layer = self.mask_layer_combo.value
    self.mask = mask_layer.data if mask_layer is not None else None

    labels_layer = self.labels_layer_combo.value
    self.labels_image = labels_layer.data if labels_layer is not None else None

    if self.mask is not None:
        self.argwheres = np.argwhere(self.mask)
    else:
        # Coordinates of every voxel in C order, for data of any
        # dimensionality (the former np.mgrid version only handled 3D).
        self.argwheres = np.indices(shape_X).reshape(len(shape_X), -1).T

    # Blur the layers
    smoothedX, smoothedY = self._smooth_quantities(
        self.quantityX,
        self.quantityX_is_labels,
        self.quantityX_labels_choice,
        self.quantityY,
        self.quantityY_is_labels,
        self.quantityY_labels_choice,
        self.mask,
    )

    self._update_smoothed_layers(
        smoothedX,
        self.quantityX_colormap,
        smoothedY,
        self.quantityY_colormap,
    )
    self.plot_from_smoothed(
        smoothedX,
        self.quantityX_is_labels,
        self.quantityX_label,
        self.quantityX_labels_choice,
        smoothedY,
        self.quantityY_is_labels,
        self.quantityY_label,
        self.quantityY_labels_choice,
        self.mask,
        self.labels_image,
    )

    # Display-parameter widgets may now redraw the heatmap.
    self.histogram_displayed = True

    if self.cluster_labels_layer is not None:
        self.cluster_labels_layer.data = np.zeros_like(
            self.cluster_labels_layer.data
        )

def _read_quantity_layer_settings(self, layer, labels_choices_combo):
    """Return ``(data, name, colormap, is_labels, labels_choice)`` for *layer*.

    Labels layers are displayed with the "inferno" colormap; other layers
    keep their own colormap.
    """
    is_labels = isinstance(layer, Labels)
    colormap = "inferno" if is_labels else layer.colormap
    return (
        layer.data,
        layer.name,
        colormap,
        is_labels,
        labels_choices_combo.value,
    )
|
|
702
|
+
|
|
703
|
+
def _update_smoothed_layers(
    self, blurredX, X_colormap, blurredY, Y_colormap
):
    """Show the smoothed quantities as image layers in the viewer.

    A layer is created on the first call, or again if the user deleted it.
    On later calls it is updated in place: its data and colormap are
    replaced, and its contrast limits are reset to the new data. Previously
    only the data changed, so switching the Quantity X/Y layer between runs
    left a stale colormap and contrast range.
    """
    # The add_image calls stay inline (no helper) so napari's automatic layer
    # naming still sees the `blurredX` / `blurredY` variable names.
    if (
        self.quantityX_smoothed_layer is None
        or self.quantityX_smoothed_layer not in self._viewer.layers
    ):
        self.quantityX_smoothed_layer = self._viewer.add_image(
            blurredX, colormap=X_colormap
        )
    else:
        self.quantityX_smoothed_layer.data = blurredX
        self.quantityX_smoothed_layer.colormap = X_colormap
        self.quantityX_smoothed_layer.reset_contrast_limits()

    if (
        self.quantityY_smoothed_layer is None
        or self.quantityY_smoothed_layer not in self._viewer.layers
    ):
        self.quantityY_smoothed_layer = self._viewer.add_image(
            blurredY, colormap=Y_colormap
        )
    else:
        self.quantityY_smoothed_layer.data = blurredY
        self.quantityY_smoothed_layer.colormap = Y_colormap
        self.quantityY_smoothed_layer.reset_contrast_limits()
|
|
725
|
+
|
|
726
|
+
def _update_quantities_labels_choices(self, event):
    """Show the labels-method row under a quantity picker only for Labels layers.

    Runs whenever the Quantity X or Quantity Y combo changes; ``event`` is
    ignored. X is handled before Y, so Y's insertion index already accounts
    for an inserted X row.
    """
    pickers = (
        (
            self.quantityX_layer_combo,
            self.quantityX_labels_choices_container,
            "quantityX_labels_choices_displayed",
        ),
        (
            self.quantityY_layer_combo,
            self.quantityY_labels_choices_container,
            "quantityY_labels_choices_displayed",
        ),
    )
    for layer_combo, method_row, shown_flag in pickers:
        needs_row = isinstance(layer_combo.value, Labels)
        row_shown = getattr(self, shown_flag)
        if needs_row and not row_shown:
            self.insert(self.index(layer_combo) + 1, method_row)
            setattr(self, shown_flag, True)
        elif row_shown and not needs_row:
            self.remove(method_row)
            setattr(self, shown_flag, False)
|
|
759
|
+
|
|
760
|
+
def _transform_labels_to_density(self, labels, method):
    """Convert a labels image into a boolean array to be smoothed.

    Parameters
    ----------
    labels : numpy.ndarray of int
        Labels image; 0 is treated as background, as by
        ``skimage.measure.regionprops``.
    method : str
        One of ``self.labels_method_choices``:

        - first choice ("cellular density"): True only at the centroid voxel
          of each labelled region, with centroid coordinates truncated to
          integers;
        - second choice ("volume fraction"): True wherever ``labels`` is
          non-zero.

    Returns
    -------
    numpy.ndarray of bool
        Same shape as ``labels``.

    Raises
    ------
    ValueError
        If ``method`` is not one of ``self.labels_method_choices``.
    """
    # NOTE(review): `test_value` is not read anywhere in the visible code;
    # kept for backward compatibility.
    self.test_value = True

    if method == self.labels_method_choices[0]:
        density = np.zeros(labels.shape, dtype=bool)
        centroids = np.array(
            [prop.centroid for prop in regionprops(labels)], dtype=float
        )
        # An image without any region yields no centroids; indexing the
        # empty array column-wise used to raise IndexError.
        if centroids.size > 0:
            # One index array per axis, so any dimensionality works.
            density[tuple(centroids.astype(int).T)] = True
        return density

    if method == self.labels_method_choices[1]:
        return labels.astype(bool)

    raise ValueError(
        f"Unknown labels method {method!r}; "
        f"expected one of {self.labels_method_choices}"
    )
|
|
775
|
+
|
|
776
|
+
def _smooth_quantities(
    self,
    quantityX,
    quantityX_is_labels,
    quantityX_labels_choice,
    quantityY,
    quantityY_is_labels,
    quantityY_labels_choice,
    mask,
):
    """Return the X and Y quantities, Gaussian-smoothed when sigma > 0.

    Labels quantities are first converted with
    ``_transform_labels_to_density`` (X before Y). With a positive blur
    sigma, both arrays are smoothed together by
    ``masked_gaussian_smooth_dense_two_arrays_gpu`` restricted to ``mask``.
    Otherwise they are returned as-is.
    """
    prepared = []
    for data, is_labels, labels_choice in (
        (quantityX, quantityX_is_labels, quantityX_labels_choice),
        (quantityY, quantityY_is_labels, quantityY_labels_choice),
    ):
        if is_labels:
            data = self._transform_labels_to_density(data, labels_choice)
        prepared.append(data)
    quantityX, quantityY = prepared

    # The per-quantity volume-mask bookkeeping only ever recorded None
    # entries, so it always reduced to None.
    # NOTE(review): confirm whether per-quantity volume masks were intended.
    volume_masks = None

    sigma = self.blur_sigma_slider.value
    if not sigma > 0:
        return quantityX, quantityY

    smoothedX, smoothedY = masked_gaussian_smooth_dense_two_arrays_gpu(
        datas=[quantityX, quantityY],
        sigmas=sigma,
        mask=mask,
        masks_for_volume=volume_masks,
    )
    return smoothedX, smoothedY
|
|
814
|
+
|
|
815
|
+
def plot_from_smoothed(
    self,
    smoothedX,
    quantityX_is_labels,
    quantityX_label,
    quantityX_labels_choice,
    smoothedY,
    quantityY_is_labels,
    quantityY_label,
    quantityY_labels_choice,
    mask,
    labels,
):
    """Build a SpatialCorrelationPlotter from smoothed data and display its heatmap.

    Axis labels use the labels-method name for Labels quantities and the
    layer name otherwise. The remaining plot options are read from the
    display-parameter widgets.
    """
    self.heatmap_plotter = SpatialCorrelationPlotter(
        quantity_X=smoothedX,
        quantity_Y=smoothedY,
        mask=mask,
        labels=labels,
    )

    axis_labels = {
        "label_X": quantityX_labels_choice if quantityX_is_labels else quantityX_label,
        "label_Y": quantityY_labels_choice if quantityY_is_labels else quantityY_label,
    }
    plot_options = dict(
        bins=(self.heatmap_binsX.value, self.heatmap_binsY.value),
        show_individual_cells=self.show_individual_cells_checkbox.value,
        show_linear_fit=self.show_linear_fit_checkbox.value,
        normalize_quantities=False,  #! normalize is currently broken with manual selection
        percentiles_X=self.percentilesX.value,
        percentiles_Y=self.percentilesY.value,
        figsize=self.graphics_widget.figure.get_size_inches(),
    )

    heatmap_figure, _ = self.heatmap_plotter.get_heatmap_figure(
        **plot_options, **axis_labels
    )

    # Display figure in graphics_widget
    self.plot_heatmap(heatmap_figure)
|
|
862
|
+
|
|
863
|
+
def plot_heatmap(self, figure):
    """Swap the embedded plot for *figure*.

    The previously shown figure (if any) is closed, *figure* is wrapped in a
    fresh ``MplCanvas`` (selectors enabled, fed with the plotter's ``xys``)
    plus a ``FigureToolbar``, and the old graph container is replaced by a
    new one at the same position before the canvas is drawn.
    """
    previous_figure = self.figure
    if previous_figure is not None:
        plt.close(previous_figure)
    self.figure = figure

    point_coords = self.heatmap_plotter.xys

    self.graphics_widget = canvas = MplCanvas(
        parent=figure,
        manual_clustering_method=self.manual_clustering_method,
        create_selectors=True,
        xys=point_coords,
    )
    self.toolbar = FigureToolbar(canvas)

    # These are plain Qt widgets; they are given the attributes a magicgui
    # widget exposes (`native`, `_explicitly_hidden`, `name`) — presumably so
    # that magicgui's Container below accepts them.
    for qt_widget in (self.toolbar, self.graphics_widget):
        qt_widget.native = qt_widget
        qt_widget._explicitly_hidden = False
        qt_widget.name = ""

    replacement = Container(
        widgets=[self.toolbar, self.graphics_widget],
        labels=False,
    )

    # Put the new container exactly where the old one was.
    slot = self.index(self.graph_container)
    self.remove(self.graph_container)
    self.insert(slot, replacement)
    self.graph_container = replacement

    self.graphics_widget.draw()
|
|
902
|
+
|
|
903
|
+
def parameters_changed(self):
    """Re-render the heatmap using the current widget settings.

    Does nothing until ``self.histogram_displayed`` is true. Otherwise a new
    figure is requested from the existing ``self.heatmap_plotter`` with the
    current bins, display toggles, percentiles, axis labels and quadrant
    setting, and shown via ``plot_heatmap``. The time taken to build the
    figure is reported at DEBUG level instead of being printed to stdout.
    """
    import logging  # local: only needed for the debug timing message

    if not self.histogram_displayed:
        return

    labelX = (
        self.quantityX_labels_choice
        if self.quantityX_is_labels
        else self.quantityX_label
    )
    labelY = (
        self.quantityY_labels_choice
        if self.quantityY_is_labels
        else self.quantityY_label
    )

    t0 = time()
    figure, _ = self.heatmap_plotter.get_heatmap_figure(
        bins=(self.heatmap_binsX.value, self.heatmap_binsY.value),
        show_individual_cells=self.show_individual_cells_checkbox.value,
        show_linear_fit=self.show_linear_fit_checkbox.value,
        # normalize_quantities=self.normalize_quantities_checkbox.value, #! normalize is currently broken with manual selection
        normalize_quantities=False,
        percentiles_X=self.percentilesX.value,
        percentiles_Y=self.percentilesY.value,
        figsize=self.graphics_widget.figure.get_size_inches(),
        label_X=labelX,
        label_Y=labelY,
        display_quadrants=self.display_quadrants.value,
    )
    # Previously an unconditional print() on every parameter change.
    logging.getLogger(__name__).debug(
        "Time to get figure: %.3f s", time() - t0
    )

    self.plot_heatmap(figure)
|
|
936
|
+
|
|
937
|
+
def _image_labels_layers_filter(self, wdg):
    """Return the viewer layers that are ``Image`` or ``Labels`` layers.

    ``wdg`` is not used; presumably it is accepted because this method serves
    as a magicgui choices callback (confirm against the widget setup).
    """
    wanted_types = (Image, Labels)
    return list(
        filter(lambda lyr: isinstance(lyr, wanted_types), self._viewer.layers)
    )
|
|
943
|
+
|
|
944
|
+
def _bool_layers_filter(self, wdg):
    """Return the viewer's ``Image`` layers whose data dtype is ``bool``.

    ``wdg`` is not used; presumably it is accepted because this method serves
    as a magicgui choices callback (confirm against the widget setup).
    """
    boolean_images = []
    for candidate in self._viewer.layers:
        if not isinstance(candidate, Image):
            continue
        if candidate.data.dtype == bool:
            boolean_images.append(candidate)
    return boolean_images
|
|
950
|
+
|
|
951
|
+
def draw_cluster_labels(
    self,
    features,
    plot_cluster_name=None,
):
    """
    Paint the manually selected clusters into a "clustered labels" layer.

    Reads the cluster id column ``features[plot_cluster_name]`` into
    ``self.cluster_ids``, recreates the plot's lasso/rectangle selector,
    builds a label colormap, renders a cluster image, and either adds a new
    "clustered labels" layer or updates the existing one. The viewer's layer
    selection is restored afterwards.

    Parameters
    ----------
    features
        Table indexable by column name (e.g. a DataFrame) holding the
        cluster ids; assumed to hold one integer id per point/object -- confirm
        against caller.
    plot_cluster_name
        Column of ``features`` containing the cluster ids. The default
        ``None`` is passed straight to ``features[...]``.
    """

    # self.analysed_layer = self.labels_select.value
    # labels_layer = self.labels_layer_combo.value
    # mask_layer = self.mask_layer_combo.value
    # self.graphics_widget.reset()

    # NaN ids are NOT filled here: the `.fillna(-1)` call is commented out,
    # so the column is used as-is.
    self.cluster_ids = features[plot_cluster_name]  # .fillna(-1)

    # Drop the current selector and attach a new one over the same points.
    self.graphics_widget.selector.disconnect()
    self.graphics_widget.selector = SelectFromCollection(
        self.graphics_widget,
        self.graphics_widget.axes,
        xys=self.graphics_widget.xys,
    )

    # Colour per label value: the cluster images built below store
    # `id + 1` for every id > 0, so key `id + 1` gets cmap[id % len(cmap)]
    # (cycling when there are more ids than colours). Id 0 -> key 1 is
    # transparent. `cmap` is not defined in this method; presumably a
    # module-level colour list.
    cmap_dict = {
        int(prediction + 1): (
            cmap[int(prediction) % len(cmap)]
            if prediction > 0
            else [0, 0, 0, 0]
        )
        for prediction in range(np.max(self.cluster_ids) + 1)
    }
    # The None entry is transparent; presumably DirectLabelColormap uses it
    # as the fallback for values without an explicit entry (background 0).
    cmap_dict[None] = [0, 0, 0, 0]

    napari_cmap = DirectLabelColormap(color_dict=cmap_dict)

    # Remember the user's selection so it can be restored after the layer
    # is added or updated below.
    keep_selection = list(self._viewer.layers.selection)

    # Pick the rendering path: per-object (labels image), per-pixel inside a
    # mask, or per-pixel over the full X-quantity layer.
    if self.labels_image is not None:
        cluster_image = self.generate_cluster_image_from_labels(
            self.labels_image, self.cluster_ids
        )

    elif self.mask is not None:
        cluster_image = self.generate_cluster_image_from_points(
            self.argwheres,
            self.cluster_ids,
            shape=self.quantityX.shape,
        )
    else:
        cluster_image = self.generate_cluster_image_from_points(
            self.argwheres,
            self.cluster_ids,
            shape=self.quantityX_layer_combo.value.data.shape,
        )

    # Create the cluster layer if it doesn't exist yet (or was deleted from
    # the viewer); otherwise update its data and colormap in place.
    if (
        self.cluster_labels_layer is None
        or self.cluster_labels_layer not in self._viewer.layers
    ):
        self.cluster_labels_layer = self._viewer.add_labels(
            cluster_image,  # self.analysed_layer.data
            colormap=napari_cmap,  # cluster_id_dict
            name="clustered labels",
            opacity=1,
        )
    else:
        self.cluster_labels_layer.data = cluster_image
        self.cluster_labels_layer.colormap = napari_cmap

    # Restore the selection saved above.
    self._viewer.layers.selection.clear()
    for s in keep_selection:
        self._viewer.layers.selection.add(s)
|
|
1031
|
+
|
|
1032
|
+
def generate_cluster_image_from_labels(self, label_image, predictionlist):
    """Paint every labelled object with its cluster id.

    Parameters
    ----------
    label_image : ndarray
        Label image handed to ``regionprops``.
    predictionlist : array-like
        Cluster id per region, indexed in the same order as the list returned
        by ``regionprops(label_image)`` (assumed one entry per region --
        confirm against caller). Entries ``<= 0`` stay background.

    Returns
    -------
    ndarray
        Same shape as ``label_image``. Pixels of region ``i`` hold
        ``predictionlist[i] + 1`` when ``predictionlist[i] > 0``, else 0.
        The dtype is ``uint8`` unless a larger value needs a wider unsigned
        integer type.
    """
    props = regionprops(label_image)

    selected = np.argwhere(predictionlist > 0).flatten()
    values = [predictionlist[index] + 1 for index in selected]

    # A fixed uint8 image silently wraps ids >= 255 (255 + 1 -> 0, i.e. the
    # object vanishes into the background). Use the smallest unsigned type
    # that holds the largest value; this is still uint8 in the usual case.
    peak = int(max(values, default=0))
    cluster_image = np.zeros(label_image.shape, dtype=np.min_scalar_type(peak))

    for index, value in zip(selected, values):
        prop = props[index]
        # Only touch this region's bounding box, and within it only the
        # pixels that belong to this label.
        roi_data = label_image[prop.slice]
        cluster_image[prop.slice][roi_data == prop.label] = value

    return cluster_image
|
|
1047
|
+
|
|
1048
|
+
def generate_cluster_image_from_points(
    self, argwheres, predictionlist, shape
):
    """Paint individual pixels/voxels with their cluster id.

    Parameters
    ----------
    argwheres : ndarray of shape (n_points, ndim)
        Integer coordinates of each point; row ``i`` pairs with
        ``predictionlist[i]``.
    predictionlist : array-like of length n_points
        Cluster id per point; entries ``<= 0`` stay background.
    shape : tuple of int
        Shape of the output image.

    Returns
    -------
    ndarray
        Image of ``shape`` holding ``id + 1`` at each point whose id is
        ``> 0`` and 0 elsewhere. The dtype is ``uint8`` unless a larger value
        needs a wider unsigned integer type (a fixed uint8 silently wrapped
        ids >= 255).
    """
    predictions = np.asarray(predictionlist)
    selected = predictions > 0
    values = predictions[selected] + 1

    peak = int(values.max()) if values.size else 0
    cluster_image = np.zeros(shape, dtype=np.min_scalar_type(peak))

    # Transposing (n_points, ndim) -> (ndim, n_points) gives one index array
    # per axis, which is what fancy indexing expects.
    cluster_image[tuple(argwheres[selected].T)] = values

    return cluster_image
|