napari-spatial-correlation-plotter 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1059 @@
1
+ import os
2
+ from pathlib import Path as PathL
3
+ from time import time
4
+
5
+ import matplotlib.pyplot as plt
6
+ import napari
7
+ import numpy as np
8
+ from magicgui.widgets import Container, EmptyWidget, create_widget
9
+ from matplotlib.backends.backend_qt5agg import (
10
+ FigureCanvasQTAgg as FigureCanvas,
11
+ )
12
+ from matplotlib.backends.backend_qt5agg import (
13
+ NavigationToolbar2QT as NavigationToolbar,
14
+ )
15
+ from matplotlib.figure import Figure
16
+ from matplotlib.path import Path
17
+ from matplotlib.widgets import LassoSelector, RectangleSelector
18
+ from napari.layers import Image, Labels, Layer
19
+ from napari.utils import DirectLabelColormap
20
+ from qtpy.QtCore import Qt
21
+ from qtpy.QtGui import QGuiApplication, QIcon
22
+ from skimage.measure import regionprops
23
+ from tapenade.analysis.spatial_correlation import SpatialCorrelationPlotter
24
+ from tapenade.preprocessing import masked_gaussian_smooth_dense_two_arrays_gpu
25
+ from vispy.color import Color
26
+
27
+ from napari_spatial_correlation_plotter._nice_colormap import get_nice_colormap
28
+
29
# Directory holding the custom toolbar icons (Pan/Zoom, checked/unchecked).
ICON_ROOT = PathL(__file__).parent / "icons"

# TODO:
# - add log scale to heatmap colors


def _rgba_from_hex(hex_name):
    """Return the RGBA color of ``hex_name`` as a float array in [0, 1].

    vispy's ``Color.RGBA`` is byte-valued (0-255), hence the division by 255.
    """
    return Color(hex_name).RGBA.astype("float") / 255


# Palette used to color manual clusters (indexed modulo its length).
colors = get_nice_colormap()
cmap = list(map(_rgba_from_hex, colors))
38
+
39
+
40
def in_bbox(min_x, max_x, min_y, max_y, xys):
    """Return a boolean mask of the points lying inside an axis-aligned box.

    Parameters
    ----------
    min_x, max_x, min_y, max_y : float
        Inclusive bounds of the box.
    xys : numpy.ndarray of shape (n_points, 2)
        Point coordinates, one ``(x, y)`` pair per row.

    Returns
    -------
    numpy.ndarray of bool, shape (n_points,)
        True where ``min <= coordinate <= max`` holds on both axes.
    """
    lower = np.array([min_x, min_y]).reshape(1, 2)
    upper = np.array([max_x, max_y]).reshape(1, 2)

    # Per-axis inclusion test, then require both axes to be inside.
    inside_per_axis = (xys >= lower) & (xys <= upper)
    return inside_per_axis[:, 0] & inside_per_axis[:, 1]
47
+
48
+
49
# Class below was based upon matplotlib lasso selection example:
# https://matplotlib.org/stable/gallery/widgets/lasso_selector_demo_sgskip.html
class SelectFromCollection:
    """
    Select points of a 2D scatter with a matplotlib `LassoSelector`.

    Unlike the matplotlib example it is based on, this class does not modify
    the plotted collection (no alpha fading). It only computes which points
    of ``xys`` fall inside the drawn lasso, stores that boolean mask in the
    ``ind_mask`` attribute and forwards it to
    ``parent.manual_clustering_method`` when that attribute is not None.

    Parameters
    ----------
    parent : object
        Object exposing a ``manual_clustering_method`` attribute (a callable
        taking a boolean mask, or None). In this module it is an `MplCanvas`.
    ax : `~matplotlib.axes.Axes`
        Axes to interact with.
    xys : array-like of shape (n_points, 2) or None
        Coordinates of the selectable points, in data units of ``ax``.
    alpha_other : float, optional
        Unused; kept for backward compatibility with the example's signature.
    """

    def __init__(self, parent, ax, xys, alpha_other=0.3):
        self.canvas = ax.figure.canvas
        self.parent = parent
        self.xys = xys

        # Left mouse button (button=1) draws the lasso.
        self.lasso = LassoSelector(ax, onselect=self.onselect, button=1)
        self.ind = []
        self.ind_mask = []

    def onselect(self, verts):
        """Compute the mask of points enclosed by the lasso polygon ``verts``.

        Does nothing when no points are available (nothing plotted yet) or
        when the lasso has no vertices.
        """
        if self.xys is None:
            return
        # Boolean indexing below requires an ndarray.
        xys = np.asarray(self.xys)
        if len(verts) == 0 or len(xys) == 0:
            return

        verts = np.asarray(verts)
        min_x, min_y = np.min(verts, axis=0)
        max_x, max_y = np.max(verts, axis=0)

        # Cheap bounding-box pre-filter so the polygon test only runs on
        # candidate points.
        ind_mask = in_bbox(min_x, max_x, min_y, max_y, xys)
        ind_mask[ind_mask] = Path(verts).contains_points(xys[ind_mask])
        self.ind_mask = ind_mask

        self.canvas.draw_idle()

        if self.parent.manual_clustering_method is not None:
            self.parent.manual_clustering_method(self.ind_mask)

    def disconnect(self):
        """Detach the lasso's event handlers and request a redraw."""
        self.lasso.disconnect_events()
        self.canvas.draw_idle()
108
+
109
+
110
class MplCanvas(FigureCanvas):
    """Qt canvas showing a matplotlib figure styled for napari's dark theme.

    Parameters
    ----------
    xys : array-like of shape (n_points, 2) or None
        Coordinates of the selectable points, used by the selectors.
    parent : matplotlib.figure.Figure or None
        Existing figure to display. When None, an empty figure with a single
        subplot is created.
    width, height : float
        Figure size in inches.
    manual_clustering_method : callable or None
        Called with a boolean mask over ``xys`` when a lasso (left button) or
        rectangle (right button) selection is completed.
    create_selectors : bool
        Whether to attach the lasso and rectangle selectors.
    """

    def __init__(
        self,
        xys,
        parent=None,
        width=7,
        height=4,
        manual_clustering_method=None,
        create_selectors=False,
    ):
        self.xys = xys

        if parent is None:
            self.fig = Figure(figsize=(width, height))
            self.axes = self.fig.add_subplot(111)
        else:
            self.fig = parent
            if len(self.fig.axes) == 0:
                self.fig.add_subplot(111)
            self.axes = self.fig.axes[0]
        self.fig.set_size_inches(width, height)
        self.manual_clustering_method = manual_clustering_method
        self.fig.tight_layout()

        super().__init__(self.fig)

        self.reset_params(create_selectors=create_selectors, xys=xys)

    def reset_params(self, create_selectors, xys):
        """Apply the dark napari styling and optionally attach selectors.

        Parameters
        ----------
        create_selectors : bool
            When True, attach a lasso selector (left button) and a rectangle
            selector (right button), and use a white plotting area.
        xys : array-like of shape (n_points, 2) or None
            Selectable point coordinates; also stored as ``self.xys`` so the
            rectangle selector uses the same points as the lasso.
        """
        self.xys = xys
        self.axes = self.fig.axes[0]

        # Ensure a scatter collection exists so it can be exposed as `pts`.
        if len(self.axes.collections) == 0:
            self.axes.scatter([], [])
        self.pts = self.axes.collections[0]

        # Figure background matches the napari main window color.
        self.fig.patch.set_facecolor("#262930")
        # Plot area: white for the interactive heatmap, dark otherwise.
        self.axes.set_facecolor("white" if create_selectors else "#262930")

        # White spines, axis labels and tick labels on the dark background.
        for spine in ("bottom", "top", "right", "left"):
            self.axes.spines[spine].set_color("white")
        self.axes.xaxis.label.set_color("white")
        self.axes.yaxis.label.set_color("white")
        self.axes.tick_params(axis="x", colors="white")
        self.axes.tick_params(axis="y", colors="white")

        # Restyle the figure's existing colorbar, if there is one.
        cb = self._existing_colorbar()
        if cb is not None:
            cb.set_label(cb.ax.get_ylabel(), color="white")
            cb.ax.yaxis.set_tick_params(color="white")
            cb.outline.set_edgecolor("white")
            plt.setp(plt.getp(cb.ax.axes, "yticklabels"), color="white")

        if create_selectors:
            self.selector = SelectFromCollection(self, self.axes, xys)
            self.rectangle_selector = RectangleSelector(
                self.axes,
                self.draw_rectangle,
                useblit=True,
                props=dict(edgecolor="#1f77b4", fill=False),
                button=3,  # right button
                minspanx=5,
                minspany=5,
                spancoords="pixels",
                interactive=False,
            )

    def _existing_colorbar(self):
        """Return the colorbar of the first image on ``self.axes``, or None."""
        # An extra axes hints at a colorbar, but a figure may have extra axes
        # without an image or without a colorbar, so check both explicitly.
        if len(self.fig.axes) > 1 and self.axes.images:
            return self.axes.images[0].colorbar
        return None

    def draw_rectangle(self, eclick, erelease):
        """Select the points inside the dragged rectangle.

        ``eclick`` and ``erelease`` are the press and release events. The
        resulting mask is stored in ``rect_ind_mask`` and forwarded to
        ``manual_clustering_method`` when it is set.
        """
        corners = (eclick.xdata, eclick.ydata, erelease.xdata, erelease.ydata)
        if self.xys is None or any(value is None for value in corners):
            return
        x0, y0, x1, y1 = corners

        self.rect_ind_mask = in_bbox(
            min(x0, x1), max(x0, x1), min(y0, y1), max(y0, y1),
            np.asarray(self.xys),
        )

        if self.manual_clustering_method is not None:
            self.manual_clustering_method(self.rect_ind_mask)

    def reset(self):
        """Clear the axes."""
        self.axes.clear()
        self.is_pressed = None
217
+
218
+
219
class FigureToolbar(NavigationToolbar):
    """Matplotlib Qt navigation toolbar with custom pan/zoom icons.

    Saving temporarily switches the figure to a white background with black
    text, then restores the on-screen colors.
    """

    def __init__(self, canvas):
        super().__init__(canvas, None)
        self.canvas = canvas

    def _update_buttons_checked(self):
        super()._update_buttons_checked()
        # Swap the pan/zoom icons depending on whether the tool is active.
        for action_name, icon_stem in (("pan", "Pan"), ("zoom", "Zoom")):
            action = self._actions.get(action_name)
            if action is None:
                continue
            suffix = "_checked" if action.isChecked() else ""
            action.setIcon(
                QIcon(os.path.join(ICON_ROOT, f"{icon_stem}{suffix}.png"))
            )

    def _colorbar(self):
        """Return the colorbar of the first image on the canvas axes, or None."""
        images = self.canvas.axes.images
        if images:
            return images[0].colorbar
        return None

    def _apply_colors(self, fig_facecolor, axes_facecolor, foreground):
        """Set background colors and the color of spines, labels and ticks."""
        axes = self.canvas.axes
        self.canvas.fig.patch.set_facecolor(fig_facecolor)
        axes.set_facecolor(axes_facecolor)

        for spine in ("bottom", "top", "right", "left"):
            axes.spines[spine].set_color(foreground)
        axes.xaxis.label.set_color(foreground)
        axes.yaxis.label.set_color(foreground)
        axes.tick_params(axis="x", colors=foreground)
        axes.tick_params(axis="y", colors=foreground)

        cb = self._colorbar()
        if cb is not None:
            cb.set_label(cb.ax.get_ylabel(), color=foreground)
            cb.ax.yaxis.set_tick_params(color=foreground)
            cb.outline.set_edgecolor(foreground)
            plt.setp(plt.getp(cb.ax.axes, "yticklabels"), color=foreground)

    def save_figure(self, *args):
        """Save the figure using a light theme, then restore on-screen colors.

        The previous figure and axes background colors are remembered, so a
        white heatmap plotting area stays white after saving. Colors are
        restored even if saving fails.
        """
        fig_facecolor = self.canvas.fig.patch.get_facecolor()
        axes_facecolor = self.canvas.axes.get_facecolor()

        self._apply_colors("white", "white", "black")
        try:
            return super().save_figure(*args)
        finally:
            self._apply_colors(fig_facecolor, axes_facecolor, "white")
            self.canvas.draw()
318
+
319
+
320
class PlotterWidget(Container):
    """napari widget plotting the spatial correlation of two quantities.

    Two image or labels layers (quantities X and Y) are optionally smoothed
    with a masked Gaussian, added to the viewer, and compared through a 2D
    heatmap produced by `SpatialCorrelationPlotter`. Points selected on the
    heatmap are painted into a "clustered labels" layer. Holding Shift while
    selecting adds a new cluster instead of replacing the current selection.

    Parameters
    ----------
    napari_viewer : napari.Viewer
        Viewer whose layers populate the combo boxes and which receives the
        smoothed and cluster layers.
    """

    _MANUAL_CLUSTER_ID = "MANUAL_CLUSTER_ID"

    def __init__(self, napari_viewer):
        super().__init__()

        self.cluster_ids = None
        self._viewer = napari_viewer

        # Layers created by this widget (None until first created).
        self.cluster_labels_layer = None
        self.quantityX_smoothed_layer = None
        self.quantityY_smoothed_layer = None

        # Whether the "labels method" combo of each quantity is inserted.
        self.quantityX_labels_choices_displayed = False
        self.quantityY_labels_choices_displayed = False

        # Becomes True once `run` has produced a heatmap.
        self.histogram_displayed = False

        self.figure = None

        # How a Labels layer is turned into a quantity (index 0 / 1 are
        # dispatched on in `_transform_labels_to_density`).
        self.labels_method_choices = ["cellular density", "volume fraction"]
        self._hidden_features = {}

        # Empty canvas and toolbar shown until a heatmap is computed.
        self.graphics_widget = MplCanvas(
            manual_clustering_method=self.manual_clustering_method,
            xys=None,
        )
        self.toolbar = FigureToolbar(self.graphics_widget)
        self.graph_container = self._wrap_canvas(
            self.toolbar, self.graphics_widget
        )

        self.quantityX_layer_combo = create_widget(
            annotation=Layer,
            label="Quantity X",
            options={"choices": self._image_labels_layers_filter},
        )
        self.quantityX_layer_combo.changed.connect(
            self._update_quantities_labels_choices
        )
        self.quantityX_labels_choices_combo = create_widget(
            widget_type="ComboBox",
            options={"choices": self.labels_method_choices},
        )
        self.quantityX_labels_choices_container = Container(
            widgets=[self.quantityX_labels_choices_combo],
            labels=False,
            layout="horizontal",
        )

        self.quantityY_layer_combo = create_widget(
            annotation=Layer,
            label="Quantity Y",
            options={"choices": self._image_labels_layers_filter},
        )
        self.quantityY_layer_combo.changed.connect(
            self._update_quantities_labels_choices
        )
        self.quantityY_labels_choices_combo = create_widget(
            widget_type="ComboBox",
            options={"choices": self.labels_method_choices},
        )
        self.quantityY_labels_choices_container = Container(
            widgets=[self.quantityY_labels_choices_combo],
            labels=False,
            layout="horizontal",
        )

        self.mask_layer_combo = create_widget(
            annotation=Image,
            label="Mask layer",
            options={
                "nullable": True,
                "choices": self._bool_layers_filter,
            },
        )

        self.labels_layer_combo = create_widget(
            annotation=Labels,
            label="Labels layer",
            options={"nullable": True},
        )

        self.blur_sigma_slider = create_widget(
            widget_type="IntSlider",
            label="Blur sigma",
            options={"min": 0, "max": 50, "value": 1},
        )

        self.run_button = create_widget(
            widget_type="PushButton",
            label="Compute correlation heatmap",
        )
        self.run_button.clicked.connect(self.run)

        self.show_individual_cells_checkbox = create_widget(
            annotation=bool,
            label="Show individual cells",
        )
        self.show_individual_cells_checkbox.changed.connect(
            self.parameters_changed
        )

        self.show_linear_fit_checkbox = create_widget(
            annotation=bool,
            label="Show linear fit",
        )
        self.show_linear_fit_checkbox.changed.connect(self.parameters_changed)

        # NOTE: no "Normalize quantities" checkbox: normalization is
        # currently broken with manual selection.

        self.display_quadrants = create_widget(
            annotation=bool,
            label="Display quadrants",
        )
        self.display_quadrants.changed.connect(self.parameters_changed)

        self.options_container1 = Container(
            widgets=[
                self.show_individual_cells_checkbox,
                self.show_linear_fit_checkbox,
            ],
            labels=False,
            layout="horizontal",
        )
        self.options_container2 = Container(
            widgets=[self.display_quadrants],
            labels=False,
            layout="horizontal",
        )

        self.heatmap_binsX = create_widget(
            widget_type="IntSlider",
            label="X",
            value=40,
            options={"min": 2, "max": 100, "tracking": True},
        )
        self.heatmap_binsX.changed.connect(self.parameters_changed)

        self.heatmap_binsY = create_widget(
            widget_type="IntSlider",
            label="Y",
            value=40,
            options={"min": 2, "max": 100, "tracking": True},
        )
        self.heatmap_binsY.changed.connect(self.parameters_changed)

        self.heatmap_bins_container = Container(
            widgets=[self.heatmap_binsX, self.heatmap_binsY],
            labels=True,
            label="Heatmap bins",
            layout="horizontal",
        )

        self.percentilesX = create_widget(
            widget_type="FloatRangeSlider",
            label="X",
            options={
                "min": 0,
                "max": 100,
                "value": [0, 100],
                "tracking": True,
            },
        )
        self.percentilesX.changed.connect(self.parameters_changed)

        self.percentilesY = create_widget(
            widget_type="FloatRangeSlider",
            label="Y",
            options={
                "min": 0,
                "max": 100,
                "value": [0, 100],
                "tracking": True,
            },
        )
        self.percentilesY.changed.connect(self.parameters_changed)

        self.percentiles_container = Container(
            widgets=[self.percentilesX, self.percentilesY],
            labels=True,
            label="Percentiles",
            layout="horizontal",
        )

        parameters_text = EmptyWidget(label="<u>Parameters:</u>")
        display_parameters_text = EmptyWidget(
            label="<u>Display Parameters:</u>"
        )

        self.extend(
            [
                parameters_text,
                self.quantityX_layer_combo,
                self.quantityY_layer_combo,
                self.mask_layer_combo,
                self.labels_layer_combo,
                self.blur_sigma_slider,
                self.run_button,
                self.graph_container,
                display_parameters_text,
                self.options_container1,
                self.options_container2,
                self.heatmap_bins_container,
                self.percentiles_container,
            ]
        )

        # takes care of case where this isn't set yet directly after init
        self.plot_cluster_name = None

        self.id = 0

    @staticmethod
    def _wrap_canvas(toolbar, canvas):
        """Put a Qt toolbar and canvas into a magicgui `Container`.

        The raw Qt widgets are given ``native``, ``_explicitly_hidden`` and
        ``name`` attributes before insertion, presumably so magicgui's
        Container accepts them as widgets (TODO confirm against magicgui).
        """
        for qt_widget in (toolbar, canvas):
            qt_widget.native = qt_widget
            qt_widget._explicitly_hidden = False
            qt_widget.name = ""
        return Container(widgets=[toolbar, canvas], labels=False)

    def manual_clustering_method(self, inside):
        """Update the manual cluster assignment from a heatmap selection.

        Parameters
        ----------
        inside : array-like of bool
            One entry per heatmap point, True for selected points.

        A plain selection replaces all clusters (selected -> 1, others -> 0).
        With Shift held, the selection becomes a new cluster id on top of the
        existing ones, provided they refer to the same set of points.
        """
        inside = np.array(inside)  # leads to errors sometimes otherwise
        if len(inside) == 0:
            return  # if nothing was plotted yet, leave

        clustering_ID = self._MANUAL_CLUSTER_ID
        former_clusters = self._hidden_features.get(clustering_ID)

        modifiers = QGuiApplication.keyboardModifiers()
        if (
            modifiers == Qt.ShiftModifier
            and former_clusters is not None
            # Clusters from a run with a different number of points cannot
            # be merged with this selection.
            and len(former_clusters) == len(inside)
        ):
            former_clusters[inside] = np.max(former_clusters) + 1
            self._hidden_features[clustering_ID] = former_clusters
        else:
            self._hidden_features[clustering_ID] = inside.astype(int)

        # redraw the whole plot
        self.draw_cluster_labels(
            self._hidden_features,
            plot_cluster_name=clustering_ID,
        )

    def run(self):
        """Smooth the selected quantities and display their heatmap.

        Shows a warning and returns early when either quantity layer is not
        selected. Any previous manual clustering is cleared.
        """
        quantityX_layer = self.quantityX_layer_combo.value
        if quantityX_layer is None:
            napari.utils.notifications.show_warning(
                "Please specify quantityX_layer"
            )
            return
        quantityY_layer = self.quantityY_layer_combo.value
        if quantityY_layer is None:
            napari.utils.notifications.show_warning(
                "Please specify quantityY_layer"
            )
            return

        self.quantityX = quantityX_layer.data
        self.quantityX_label = quantityX_layer.name
        self.quantityX_is_labels = isinstance(quantityX_layer, Labels)
        self.quantityX_colormap = (
            "inferno" if self.quantityX_is_labels else quantityX_layer.colormap
        )
        self.quantityX_labels_choice = self.quantityX_labels_choices_combo.value

        self.quantityY = quantityY_layer.data
        self.quantityY_label = quantityY_layer.name
        self.quantityY_is_labels = isinstance(quantityY_layer, Labels)
        self.quantityY_colormap = (
            "inferno" if self.quantityY_is_labels else quantityY_layer.colormap
        )
        self.quantityY_labels_choice = self.quantityY_labels_choices_combo.value

        mask_layer = self.mask_layer_combo.value
        self.mask = None if mask_layer is None else mask_layer.data

        labels_layer = self.labels_layer_combo.value
        self.labels_image = None if labels_layer is None else labels_layer.data

        # Voxel coordinates of the plotted points; only needed to paint
        # clusters when there is no labels image (see draw_cluster_labels).
        if self.labels_image is not None:
            self.argwheres = None
        elif self.mask is not None:
            self.argwheres = np.argwhere(self.mask)
        else:
            shape = self.quantityX.shape
            # Every voxel, in C order (same as np.mgrid), for any ndim.
            self.argwheres = np.indices(shape).reshape(len(shape), -1).T

        # Blur the layers
        smoothedX, smoothedY = self._smooth_quantities(
            self.quantityX,
            self.quantityX_is_labels,
            self.quantityX_labels_choice,
            self.quantityY,
            self.quantityY_is_labels,
            self.quantityY_labels_choice,
            self.mask,
        )

        self._update_smoothed_layers(
            smoothedX,
            self.quantityX_colormap,
            smoothedY,
            self.quantityY_colormap,
        )
        self.plot_from_smoothed(
            smoothedX,
            self.quantityX_is_labels,
            self.quantityX_label,
            self.quantityX_labels_choice,
            smoothedY,
            self.quantityY_is_labels,
            self.quantityY_label,
            self.quantityY_labels_choice,
            self.mask,
            self.labels_image,
        )

        self.histogram_displayed = True

        # A new heatmap invalidates previous manual clusters.
        self._hidden_features.pop(self._MANUAL_CLUSTER_ID, None)
        if self.cluster_labels_layer is not None:
            self.cluster_labels_layer.data = np.zeros_like(
                self.cluster_labels_layer.data
            )

    def _update_smoothed_layers(
        self, blurredX, X_colormap, blurredY, Y_colormap
    ):
        """Show the smoothed quantities, reusing this widget's layers if present."""
        self.quantityX_smoothed_layer = self._show_image(
            self.quantityX_smoothed_layer, blurredX, X_colormap
        )
        self.quantityY_smoothed_layer = self._show_image(
            self.quantityY_smoothed_layer, blurredY, Y_colormap
        )

    def _show_image(self, layer, data, colormap):
        """Update ``layer`` with ``data``, or add a new image layer if it is gone.

        Returns the layer now displaying ``data``. The colormap is only used
        when a new layer is created.
        """
        if layer is None or layer not in self._viewer.layers:
            return self._viewer.add_image(data, colormap=colormap)
        layer.data = data
        return layer

    def _update_quantities_labels_choices(self, event):
        """Show the labels-method combo below each quantity that is a Labels layer."""
        self.quantityX_labels_choices_displayed = self._toggle_labels_choices(
            self.quantityX_layer_combo,
            self.quantityX_labels_choices_container,
            self.quantityX_labels_choices_displayed,
        )
        self.quantityY_labels_choices_displayed = self._toggle_labels_choices(
            self.quantityY_layer_combo,
            self.quantityY_labels_choices_container,
            self.quantityY_labels_choices_displayed,
        )

    def _toggle_labels_choices(self, layer_combo, choices_container, displayed):
        """Insert or remove ``choices_container`` below ``layer_combo``.

        Returns whether the container is displayed afterwards.
        """
        if isinstance(layer_combo.value, Labels):
            if not displayed:
                self.insert(self.index(layer_combo) + 1, choices_container)
            return True
        if displayed:
            self.remove(choices_container)
        return False

    def _transform_labels_to_density(self, labels, method):
        """Convert a label image into a boolean image that can be smoothed.

        - ``labels_method_choices[0]`` ("cellular density"): True only at the
          centroid voxel (coordinates truncated to int) of each region.
        - ``labels_method_choices[1]`` ("volume fraction"): True wherever
          ``labels`` is non-zero.

        Returns None for any other ``method``.
        """
        # NOTE(review): nothing in this module reads `test_value`.
        self.test_value = True
        if method == self.labels_method_choices[0]:
            density = np.zeros(labels.shape, dtype=bool)
            props = regionprops(labels)
            if props:  # an empty label image has no centroids
                centroids = np.array([prop.centroid for prop in props]).astype(
                    int
                )
                density[tuple(centroids.T)] = True
            return density

        if method == self.labels_method_choices[1]:
            return labels.astype(bool)

        return None

    def _smooth_quantities(
        self,
        quantityX,
        quantityX_is_labels,
        quantityX_labels_choice,
        quantityY,
        quantityY_is_labels,
        quantityY_labels_choice,
        mask,
    ):
        """Return both quantities, Gaussian-smoothed within ``mask``.

        Labels quantities are first converted with
        `_transform_labels_to_density`. When the blur sigma is 0 the
        (possibly converted) quantities are returned unsmoothed.
        """
        if quantityX_is_labels:
            quantityX = self._transform_labels_to_density(
                quantityX, quantityX_labels_choice
            )
        if quantityY_is_labels:
            quantityY = self._transform_labels_to_density(
                quantityY, quantityY_labels_choice
            )

        sigma = self.blur_sigma_slider.value
        if sigma <= 0:
            return quantityX, quantityY

        # No per-quantity volume masks are used.
        smoothedX, smoothedY = masked_gaussian_smooth_dense_two_arrays_gpu(
            datas=[quantityX, quantityY],
            sigmas=sigma,
            mask=mask,
            masks_for_volume=None,
        )
        return smoothedX, smoothedY

    def plot_from_smoothed(
        self,
        smoothedX,
        quantityX_is_labels,
        quantityX_label,
        quantityX_labels_choice,
        smoothedY,
        quantityY_is_labels,
        quantityY_label,
        quantityY_labels_choice,
        mask,
        labels,
    ):
        """Build a `SpatialCorrelationPlotter` and display its heatmap.

        Axis labels are the labels-method name for Labels quantities and the
        layer name otherwise.
        """
        self.heatmap_plotter = SpatialCorrelationPlotter(
            quantity_X=smoothedX,
            quantity_Y=smoothedY,
            mask=mask,
            labels=labels,
        )

        labelX = quantityX_labels_choice if quantityX_is_labels else quantityX_label
        labelY = quantityY_labels_choice if quantityY_is_labels else quantityY_label

        self.plot_heatmap(self._heatmap_figure(labelX, labelY))

    def _heatmap_figure(self, labelX, labelY):
        """Return the heatmap figure built from the current display parameters."""
        figure, _ = self.heatmap_plotter.get_heatmap_figure(
            bins=(self.heatmap_binsX.value, self.heatmap_binsY.value),
            show_individual_cells=self.show_individual_cells_checkbox.value,
            show_linear_fit=self.show_linear_fit_checkbox.value,
            #! normalize is currently broken with manual selection
            normalize_quantities=False,
            percentiles_X=self.percentilesX.value,
            percentiles_Y=self.percentilesY.value,
            figsize=self.graphics_widget.figure.get_size_inches(),
            label_X=labelX,
            label_Y=labelY,
            display_quadrants=self.display_quadrants.value,
        )
        return figure

    def plot_heatmap(self, figure):
        """Replace the displayed canvas with a new interactive one for ``figure``.

        The previously displayed figure is closed.
        """
        if self.figure is not None:
            plt.close(self.figure)
        self.figure = figure

        self.graphics_widget = MplCanvas(
            parent=figure,
            manual_clustering_method=self.manual_clustering_method,
            create_selectors=True,
            xys=self.heatmap_plotter.xys,
        )
        self.toolbar = FigureToolbar(self.graphics_widget)
        new_graph_container = self._wrap_canvas(
            self.toolbar, self.graphics_widget
        )

        widget_index = self.index(self.graph_container)
        self.remove(self.graph_container)
        self.insert(widget_index, new_graph_container)
        self.graph_container = new_graph_container
        self.graphics_widget.draw()

    def parameters_changed(self):
        """Redraw the heatmap with the current display parameters, if one exists."""
        if not self.histogram_displayed:
            return

        labelX = (
            self.quantityX_labels_choice
            if self.quantityX_is_labels
            else self.quantityX_label
        )
        labelY = (
            self.quantityY_labels_choice
            if self.quantityY_is_labels
            else self.quantityY_label
        )
        self.plot_heatmap(self._heatmap_figure(labelX, labelY))

    def _image_labels_layers_filter(self, wdg):
        """Return the viewer's Image and Labels layers."""
        return [
            layer
            for layer in self._viewer.layers
            if isinstance(layer, Image | Labels)
        ]

    def _bool_layers_filter(self, wdg):
        """Return the viewer's Image layers whose data dtype is bool."""
        return [
            layer
            for layer in self._viewer.layers
            if (isinstance(layer, Image) and layer.data.dtype == bool)
        ]

    def draw_cluster_labels(
        self,
        features,
        plot_cluster_name=None,
    ):
        """Paint the manually selected clusters into the "clustered labels" layer.

        Parameters
        ----------
        features : dict
            Maps a clustering name to an integer array holding one cluster id
            per heatmap point (0 = not selected).
        plot_cluster_name : hashable
            Key of ``features`` to display.
        """
        self.cluster_ids = features[plot_cluster_name]

        # Recreate the lasso selector on the current canvas.
        self.graphics_widget.selector.disconnect()
        self.graphics_widget.selector = SelectFromCollection(
            self.graphics_widget,
            self.graphics_widget.axes,
            xys=self.graphics_widget.xys,
        )

        # Cluster id k is painted with label value k + 1. Id 0 (label 1) and
        # the background are transparent; other ids cycle through `cmap`.
        cmap_dict = {
            int(prediction + 1): (
                cmap[int(prediction) % len(cmap)]
                if prediction > 0
                else [0, 0, 0, 0]
            )
            for prediction in range(np.max(self.cluster_ids) + 1)
        }
        # take care of background label
        cmap_dict[None] = [0, 0, 0, 0]

        napari_cmap = DirectLabelColormap(color_dict=cmap_dict)

        keep_selection = list(self._viewer.layers.selection)

        if self.labels_image is not None:
            cluster_image = self.generate_cluster_image_from_labels(
                self.labels_image, self.cluster_ids
            )
        else:
            # Use the data captured by `run`, not the current combo value,
            # which may have changed (or been cleared) since.
            cluster_image = self.generate_cluster_image_from_points(
                self.argwheres,
                self.cluster_ids,
                shape=self.quantityX.shape,
            )

        # Create the cluster layer if it does not exist, otherwise update it.
        if (
            self.cluster_labels_layer is None
            or self.cluster_labels_layer not in self._viewer.layers
        ):
            self.cluster_labels_layer = self._viewer.add_labels(
                cluster_image,
                colormap=napari_cmap,
                name="clustered labels",
                opacity=1,
            )
        else:
            self.cluster_labels_layer.data = cluster_image
            self.cluster_labels_layer.colormap = napari_cmap

        # Adding/updating a layer may change the selection; restore it.
        self._viewer.layers.selection.clear()
        for s in keep_selection:
            self._viewer.layers.selection.add(s)

    @staticmethod
    def _cluster_dtype(predictionlist):
        """Smallest unsigned dtype able to hold every painted value (max id + 1)."""
        max_value = int(np.max(predictionlist, initial=0)) + 1
        return np.min_scalar_type(max_value)

    def generate_cluster_image_from_labels(self, label_image, predictionlist):
        """Paint each labelled region with its cluster id + 1 (0 elsewhere).

        ``predictionlist[i]`` is the cluster id of ``regionprops(label_image)[i]``;
        regions with cluster id 0 are left at 0.
        """
        predictionlist = np.asarray(predictionlist)
        props = regionprops(label_image)

        cluster_image = np.zeros(
            label_image.shape, dtype=self._cluster_dtype(predictionlist)
        )

        for index in np.flatnonzero(predictionlist > 0):
            prop = props[index]
            roi_data = label_image[prop.slice]
            cluster_image[prop.slice][roi_data == prop.label] = (
                predictionlist[index] + 1
            )

        return cluster_image

    def generate_cluster_image_from_points(
        self, argwheres, predictionlist, shape
    ):
        """Paint each point with its cluster id + 1 in an image of ``shape``.

        ``argwheres[i]`` holds the voxel coordinates of point ``i`` and
        ``predictionlist[i]`` its cluster id; points with id 0 stay 0.
        """
        predictionlist = np.asarray(predictionlist)
        cluster_image = np.zeros(
            shape, dtype=self._cluster_dtype(predictionlist)
        )

        selected = predictionlist > 0
        points_to_display = argwheres[selected]
        cluster_image[tuple(points_to_display.T)] = predictionlist[selected] + 1

        return cluster_image