napari-musa 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,581 @@
1
+ """ """
2
+
3
+ from contextlib import suppress
4
+
5
+ import napari
6
+ import numpy as np
7
+ import pyqtgraph as pg
8
+ import pyqtgraph.exporters
9
+ import qtawesome as qta # Icons
10
+ from magicgui.widgets import PushButton
11
+ from matplotlib.path import Path
12
+ from napari.utils.notifications import show_info, show_warning
13
+ from qtpy import QtCore
14
+ from qtpy.QtWidgets import QFileDialog, QWidget
15
+
16
+
17
class Plot(QWidget):
    """Spectra plots (matplotlib) and scatterplot/polygon tools (pyqtgraph).

    ``data`` is the plugin's dataset container. The attributes read here are
    ``wls``, ``fusion_modes``, ``hypercubes``, ``hypercubes_red`` and
    ``hypercubes_spatial_red``; their exact layout is defined elsewhere.
    """

    def __init__(self, viewer: napari.Viewer, data):
        """Store the viewer and dataset and initialise the interaction state.

        Args:
            viewer: napari viewer whose layers are read and extended.
            data: dataset container (see the class docstring).
        """
        super().__init__()
        self.viewer = viewer
        self.data = data
        # Matplotlib axes, (re)created by setup_plot.
        self.ax = self.ax2 = self.ax3 = None
        # Polygon-selection state for the scatterplot.
        self.plot = None  # pyqtgraph widget the current ROI was added to
        self.poly_roi = None
        self.temp_points = []
        self.drawing = False
        self.mouse_connected = False
        # Current pyqtgraph scatter item (replaced by show_scatterplot).
        self.scatter = None
        self.vertical_line = None

    # ------------------------------------------------------------------
    # %% Matplotlib spectra plots
    # ------------------------------------------------------------------
    def setup_plot(self, plot, fused=False):
        """Create and style the matplotlib axes on ``plot.figure``.

        A single axis is created normally. In fused mode one axis per fused
        dataset is created: two, or three when more than two datasets are
        fused. The axes are stored in ``self.ax``, ``self.ax2`` and
        ``self.ax3``; unused slots are ``None``.
        """
        self.ax = self.ax2 = self.ax3 = None
        plot.figure.patch.set_facecolor("#262930")

        if fused:
            n_axes = 3 if len(self.data.fusion_modes) > 2 else 2
        else:
            n_axes = 1

        axes = plot.figure.subplots(1, n_axes)
        # subplots() returns a bare Axes (not an array) when n_axes == 1.
        axes = [axes] if n_axes == 1 else list(axes)
        self.ax = axes[0]
        self.ax2 = axes[1] if n_axes > 1 else None
        self.ax3 = axes[2] if n_axes > 2 else None

        for ax in axes:
            ax.set_facecolor("#262930")
            ax.tick_params(axis="x", colors="#D3D4D5", labelsize=14)
            ax.tick_params(axis="y", colors="#D3D4D5", labelsize=14)
            ax.grid(
                True,
                linestyle="--",
                linewidth=0.5,
                color="#D3D4D5",
                alpha=0.5,
            )
            for position, spine in ax.spines.items():
                if position in ("left", "bottom"):
                    spine.set_color("#D3D4D5")
                    spine.set_linewidth(1)
                    spine.set_visible(True)
                else:
                    spine.set_visible(False)

    @staticmethod
    def _label_color(colormap, label):
        """Return the RGB color used for ``label`` from a colormap array.

        Label ``i`` uses row ``i``. Labels past the last row wrap around
        over rows ``1..n-1`` instead of raising ``IndexError``.
        NOTE(review): the wrap is meant to mirror napari's cyclic label
        colormap; confirm against the napari version in use.
        """
        n_colors = len(colormap)
        if label < n_colors or n_colors < 2:
            return colormap[label, :3]
        return colormap[(label - 1) % (n_colors - 1) + 1, :3]

    def _ask_export_filename(self):
        """Ask where to export spectra.

        Returns:
            ``(filename, "txt" | "csv")``, or ``None`` if the dialog was
            cancelled. When the typed name has no .txt/.csv extension (checked
            case-insensitively), one is appended based on the chosen filter.
        """
        filename, selected_filter = QFileDialog.getSaveFileName(
            self, "Save spectra", "", "CSV file (*.csv);;Text file (*.txt)"
        )
        if not filename:
            return None
        lowered = filename.lower()
        if lowered.endswith(".txt"):
            return filename, "txt"
        if lowered.endswith(".csv"):
            return filename, "csv"
        if selected_filter.startswith("Text"):
            return filename + ".txt", "txt"
        return filename + ".csv", "csv"

    def show_plot(
        self,
        plot,
        mode,
        std_flag=False,
        norm_flag=False,
        reduced_dataset_flag=False,
        export_txt_flag=False,
        derivative_flag=False,
    ):
        """Plot the mean spectrum of every label in the active Labels layer.

        Args:
            plot: widget exposing a matplotlib ``figure`` and ``draw()``.
            mode: dataset key into ``self.data.wls``. ``"Fused"`` draws each
                fused dataset on its own axis.
            std_flag: shade +/- one standard deviation around each mean.
            norm_flag: min-max normalise each mean spectrum.
            reduced_dataset_flag: read from the reduced hypercubes.
            export_txt_flag: ask for a file and export the spectra.
            derivative_flag: also plot derivative spectra on a twin y axis
                (not drawn in fused mode).
        """
        selected_layer = self.viewer.layers.selection.active
        if not isinstance(selected_layer, napari.layers.Labels):
            show_warning(
                "⚠️ The selected layer is not a label layer. Please, select a label layer."
            )
            return
        labels_data = selected_layer.data
        if labels_data is None or np.all(labels_data == 0):
            show_warning("⚠️ The selected label layer is empty")
            return

        fused_flag = mode == "Fused"
        plot.figure.clf()
        self.setup_plot(plot, fused=fused_flag)

        # Secondary y axis for the derivative (single-dataset mode only).
        ax_der = None
        if derivative_flag and not fused_flag:
            ax_der = self.ax.twinx()
            ax_der.tick_params(axis="y", colors="#FFA500", labelsize=14)
            ax_der.spines["right"].set_color("#FFA500")
            ax_der.set_ylabel("Derivative", color="#FFA500")

        num_classes = int(labels_data.max())
        colormap = np.array(selected_layer.colormap.colors)
        wavelengths = self.data.wls[mode]
        spectra, stds, spectra_der, stds_der = self.compute_spectra(
            wavelengths,
            mask=labels_data,
            mode=mode,
            reduced_flag=reduced_dataset_flag,
            num_classes=num_classes,
            normalize_flag=norm_flag,
            derivative_flag=derivative_flag,
        )

        for index in range(num_classes):
            color = self._label_color(colormap, index + 1)
            if fused_flag:
                self.plot_fused(index, spectra, stds, color, std_flag)
                continue
            self.plot_spectrum(
                self.ax,
                wavelengths,
                spectra[index],
                stds[index] if std_flag else None,
                color,
            )
            if ax_der is not None:
                self.plot_spectrum(
                    ax_der,
                    wavelengths,
                    spectra_der[index],
                    stds_der[index] if std_flag else None,
                    color,
                    linestyle="--",
                )

        if export_txt_flag:
            target = self._ask_export_filename()
            if target is not None:
                filename, file_format = target
                export_stds = stds if std_flag else np.zeros_like(spectra)
                # Transpose to (n_bands, n_classes): one column per label.
                self.export_spectra_txt(
                    filename,
                    wavelengths,
                    spectra.T,
                    export_stds.T,
                    mode=file_format,
                )
        plot.draw()

    def compute_spectra(
        self,
        wavelengths,
        mask,
        mode,
        reduced_flag,
        num_classes,
        normalize_flag,
        derivative_flag,
    ):
        """Compute per-label mean/std spectra, and derivatives if requested.

        Row ``i`` of every returned array refers to label ``i + 1`` in
        ``mask``. Labels without pixels keep all-zero rows.

        Returns:
            ``(spectra, stds, spectra_der, stds_der)``, each shaped
            ``(num_classes, len(wavelengths))``.
        """
        n_bands = len(wavelengths)
        spectra = np.zeros((num_classes, n_bands))
        stds = np.zeros_like(spectra)
        spectra_der = np.zeros_like(spectra)
        stds_der = np.zeros_like(spectra)

        def select_cubes(key):
            """Return the hypercube dictionary that ``key`` should be read from."""
            if not reduced_flag:
                return self.data.hypercubes
            if self.data.hypercubes_spatial_red.get(key) is not None:
                return self.data.hypercubes_spatial_red
            return self.data.hypercubes_red

        # The source dictionaries do not change inside the loop; resolve them once.
        cubes = select_cubes(mode)
        der_cube = None
        if derivative_flag:
            der_key = mode + " - derivative"
            der_cube = select_cubes(der_key).get(der_key)

        for idx in range(num_classes):
            coords = np.nonzero(mask == idx + 1)
            if coords[0].size == 0:
                continue  # label not present in the mask
            rows, cols = coords[0], coords[1]
            if mode == "Fused":
                # Concatenate the bands of every fused dataset, in fusion order.
                data_selected = np.concatenate(
                    [cubes[m][rows, cols, :] for m in self.data.fusion_modes],
                    axis=1,
                )
            else:
                data_selected = cubes[mode][rows, cols, :]

            mean_spec = np.mean(data_selected, axis=0)
            std_spec = np.std(data_selected, axis=0)
            if normalize_flag:
                low, high = np.min(mean_spec), np.max(mean_spec)
                if high > low:  # avoid division by zero on flat spectra
                    span = high - low
                    mean_spec = (mean_spec - low) / span
                    std_spec = std_spec / span
            spectra[idx] = mean_spec
            stds[idx] = std_spec

            if der_cube is not None:
                data_der = der_cube[rows, cols, :]
                spectra_der[idx] = np.mean(data_der, axis=0)
                stds_der[idx] = np.std(data_der, axis=0)
        return spectra, stds, spectra_der, stds_der

    def plot_spectrum(self, ax, x, y, std=None, color="blue", linestyle="-"):
        """Plot ``y`` against ``x`` on ``ax``, shading +/- ``std`` if given."""
        ax.plot(x, y, color=color, linewidth=2, linestyle=linestyle)
        if std is not None:
            ax.fill_between(x, y - std, y + std, color=color, alpha=0.3)

    def plot_fused(self, index, spectra, stds, color, std_flag):
        """Plot row ``index`` of fused spectra, one fused dataset per axis.

        The rows of ``spectra``/``stds`` are the concatenated bands of
        ``self.data.fusion_modes``, in that order. Each segment is drawn
        against its own wavelengths on ``self.ax``, ``self.ax2`` and
        ``self.ax3``, stopping at the first missing axis.
        """
        wls = [self.data.wls[m] for m in self.data.fusion_modes]
        ends = np.cumsum([w.shape[0] for w in wls])  # cumulative band counts
        axes = [self.ax, self.ax2, self.ax3]
        start = 0
        for ax, wl, end in zip(axes, wls, ends):
            if ax is None:
                break
            y = spectra[index, start:end]
            ax.plot(wl, y, color=color, linewidth=2)
            if std_flag:
                s = stds[index, start:end]
                ax.fill_between(wl, y - s, y + s, color=color, alpha=0.3)
            start = end

    # %% Export spectra
    def export_spectra_txt(self, filename, wavelengths, spectra, stds, mode):
        """Write spectra and standard deviations to a delimited text file.

        Args:
            filename: destination path.
            wavelengths: per-band values written as the first column
                (assumed length ``n_bands``).
            spectra, stds: arrays shaped ``(n_bands, n_spectra)``. Column
                ``j`` is written as ``Spectrum{j+1}`` and ``Std{j+1}``.
            mode: ``"txt"`` (tab-separated) or ``"csv"`` (comma-separated).

        Raises:
            ValueError: if ``mode`` is neither ``"txt"`` nor ``"csv"``.
        """
        delimiters = {"txt": "\t", "csv": ","}
        if mode not in delimiters:
            raise ValueError(f"Unsupported export mode: {mode!r}")
        delimiter = delimiters[mode]

        n_spectra = spectra.shape[1]
        columns = [wavelengths]
        header_parts = ["Wavelength"]
        for j in range(n_spectra):
            columns += [spectra[:, j], stds[:, j]]
            header_parts += [f"Spectrum{j + 1}", f"Std{j + 1}"]

        np.savetxt(
            filename,
            np.column_stack(columns),
            fmt="%.6f",
            delimiter=delimiter,
            header=delimiter.join(header_parts),
            comments="",  # plain header line, without the default "# " prefix
        )
        show_info(f"The .{mode} has been saved")

    # -----------------------------------------------------------------------------------------------
    # %% SCATTERPLOT
    # -----------------------------------------------------------------------------------------------
    def setup_scatterplot(self, plot):
        """Give a pyqtgraph plot a white background and hide its axes."""
        plot.setBackground("w")
        for axis in ("left", "bottom"):
            plot.getAxis(axis).setTicks([])
            plot.getAxis(axis).setStyle(tickLength=0)
            plot.getAxis(axis).setPen(None)
        plot.setMinimumSize(400, 400)

    def _disconnect_mouse(self):
        """Detach ``add_point_to_polygon`` from the current plot's scene."""
        if self.mouse_connected and self.plot is not None:
            # PyQt raises TypeError, PySide RuntimeError, if it is not connected.
            with suppress(TypeError, RuntimeError):
                self.plot.scene().sigMouseClicked.disconnect(
                    self.add_point_to_polygon
                )
        self.mouse_connected = False

    ## ICONS
    def polygon_selection(self, plot):
        """Start drawing a new polygon ROI on ``plot`` (a pyqtgraph widget).

        The previous ROI, if any, is removed from the widget it was added to.
        Each left click then adds a vertex, and a double click closes the
        polygon (see ``add_point_to_polygon``).
        """
        # Clean up against the *previous* widget before switching to ``plot``.
        self._disconnect_mouse()
        if self.poly_roi is not None and self.plot is not None:
            self.plot.removeItem(self.poly_roi)

        self.plot = plot
        self.temp_points = []
        self.poly_roi = pg.PolyLineROI(
            [],
            closed=False,
            pen="r",
            handlePen=pg.mkPen("red"),
        )
        self.plot.addItem(self.poly_roi)
        self.drawing = True
        self.plot.scene().sigMouseClicked.connect(self.add_point_to_polygon)
        self.mouse_connected = True

    def add_point_to_polygon(self, event):
        """Mouse-click slot that appends a vertex; a double click closes the polygon."""
        if not self.drawing:
            return
        if event.button() != QtCore.Qt.LeftButton:
            return
        pos = self.plot.plotItem.vb.mapSceneToView(event.scenePos())
        self.temp_points.append((pos.x(), pos.y()))
        self.poly_roi.setPoints(self.temp_points)
        if event.double():
            self.drawing = False
            self.poly_roi.closed = True
            # Re-set the points so the closing segment is drawn.
            self.poly_roi.setPoints(self.temp_points)
            # Stop listening and reset the flag so the next selection reconnects.
            self._disconnect_mouse()

    def show_selected_points(self, scatterdata, hsi_image, mode, points):
        """Label the pixels whose scatterplot points fall inside the polygon ROI.

        Args:
            scatterdata: scatterplot coordinates, one (x, y) row per point.
            hsi_image: image whose first two dimensions give the label shape.
            mode: dataset name, used in the layer name
                ``"{mode} SCATTERPLOT LABELS"``.
            points: if non-empty, maps each scatter point to a flat pixel
                index. If empty, scatter indices are used directly as flat
                pixel indices.

        Each call adds one new label value to that layer, creating the layer
        if it does not exist yet.
        """
        if self.poly_roi is None:
            show_warning("⚠️ No active selection!")
            return

        state = self.poly_roi.getState()
        # Handle positions are relative to the ROI origin. Add the ROI position
        # so the polygon stays in data coordinates even after the ROI is dragged.
        polygon = np.array(state["points"], dtype=float).reshape(-1, 2)
        polygon = polygon + np.asarray(state["pos"], dtype=float)
        if len(polygon) < 3:
            show_warning("⚠️ Draw a polygon with at least three vertices first")
            return

        inside = Path(polygon).contains_points(scatterdata)
        selected_indices = np.flatnonzero(inside)
        if len(points) > 0:
            selected_indices = np.asarray(points)[selected_indices]

        layer_name = f"{mode} SCATTERPLOT LABELS"
        labels_layer = next(
            (layer for layer in self.viewer.layers if layer.name == layer_name),
            None,
        )
        if labels_layer is not None:
            labels = labels_layer.data
            new_label_value = labels.max() + 1
        else:
            labels = np.zeros(hsi_image.shape[:2], dtype=np.int32)
            new_label_value = 1

        labels.flat[np.asarray(selected_indices, dtype=np.intp)] = new_label_value
        if labels_layer is not None:
            labels_layer.refresh()  # layer data was modified in place
        else:
            self.viewer.add_labels(labels, name=layer_name)
        self.temp_points = []

    def save_image_button(self, plot):
        """Ask for a path and export ``plot`` as a PNG image."""
        filename, _ = QFileDialog.getSaveFileName(
            self, "Save UMAP image", "", "png (*.png)"
        )
        if not filename:
            return
        if not filename.lower().endswith(".png"):
            filename += ".png"  # let the exporter infer the PNG format
        exporter = pg.exporters.ImageExporter(plot.getPlotItem())
        # NOTE(review): pyqtgraph may rescale one dimension to keep the aspect
        # ratio when the other is set; confirm the resulting size.
        exporter.parameters()["width"] = 2000
        exporter.parameters()["height"] = 2000
        exporter.export(filename)
        show_info("Image saved!")

    def show_scatterplot(self, plot, data, hex_colors, points, size):
        """Replace the current scatter item on ``plot`` with one built from ``data``.

        Point colors are ``hex_colors[points]`` when ``points`` is non-empty,
        otherwise ``hex_colors`` itself.
        """
        if self.scatter is not None:
            plot.removeItem(self.scatter)
            self.scatter = None
        brush = hex_colors[points] if len(points) > 0 else hex_colors
        self.scatter = pg.ScatterPlotItem(
            pos=data, pen=None, symbol="o", size=size, brush=brush
        )
        plot.addItem(self.scatter)
        plot.getViewBox().autoRange()
        plot.update()

    # %% Customization
    def customize_toolbar(self, toolbar):
        """Restyle a matplotlib navigation toolbar: dark background, qtawesome icons."""
        # Toolbar background.
        toolbar.setStyleSheet("background-color: #262930; border: none;")

        # Action text -> qtawesome icon name.
        icon_map = {
            "Home": "fa5s.home",
            "Back": "fa5s.arrow-left",
            "Forward": "fa5s.arrow-right",
            "Pan": "fa5s.expand-arrows-alt",
            "Zoom": "ei.zoom-in",
            "Subplots": "msc.settings",
            "Customize": "mdi.chart-scatter-plot",
            "Save": "fa5.save",
        }
        for action in toolbar.actions():
            icon_name = icon_map.get(action.text())
            if icon_name is not None:
                action.setIcon(qta.icon(icon_name, color="#D3D4D5"))

    def create_button(self, icon_name):
        """Create a 30x30 dark icon-only push button using qtawesome ``icon_name``."""
        btn = PushButton(text="").native
        btn.setIcon(qta.icon(icon_name, color="#D3D4D5"))
        btn.setStyleSheet(
            """
QPushButton {
background-color: #262930;
border-radius: 5px;
padding: 5px;
}
QPushButton:hover {
background-color: #3E3F40;
}"""
        )
        btn.setFixedSize(30, 30)  # fixed size
        return btn

    # %% Show spectra for endmembers extraction
    def show_spectra(
        self,
        plot,
        spectra,
        mode,
        basis_numbers,
        export_txt_flag=False,
    ):
        """Plot, and optionally export, endmember spectra.

        Args:
            plot: widget exposing a matplotlib ``figure`` and ``draw()``.
            spectra: array whose column ``i`` (``spectra[:, i]``) is the
                spectrum drawn for ``basis_numbers[i]``.
            mode: dataset key into ``self.data.wls``. ``"Fused"`` splits the
                bands across one axis per fused dataset.
            basis_numbers: identifiers used to pick each curve's color.
            export_txt_flag: ask for a file and export the spectra.
        """
        plot.figure.clf()
        fused = mode == "Fused"
        self.setup_plot(plot, fused=fused)

        wavelengths = self.data.wls[mode]
        colormap = np.array(napari.utils.colormaps.label_colormap().colors)

        if fused:
            # plot_fused expects one spectrum per row.
            spectra_rows = spectra.transpose()
            zero_stds = np.zeros_like(spectra_rows)
        for index, element in enumerate(basis_numbers):
            # NOTE(review): the "+ 3" colormap offset is kept as-is; its intent
            # is not visible here.
            color = colormap[element + 3, :3]
            if fused:
                self.plot_fused(index, spectra_rows, zero_stds, color, False)
            else:
                self.ax.plot(
                    wavelengths, spectra[:, index], color=color, linewidth=2
                )

        if export_txt_flag:
            target = self._ask_export_filename()
            if target is not None:
                filename, file_format = target
                self.export_spectra_txt(
                    filename,
                    wavelengths,
                    spectra,
                    np.zeros_like(spectra),  # no std for endmembers
                    mode=file_format,
                )

        plot.draw()
@@ -0,0 +1,15 @@
1
+ name: napari-musa
2
+ display_name: MUSA Plugin
3
+ # use 'hidden' to remove plugin from napari hub search results
4
+ visibility: public
5
+ # see https://napari.org/stable/plugins/technical_references/manifest.html#fields for valid categories
6
+ categories: ["Annotation", "Segmentation", "Acquisition"]
7
+ contributions:
8
+ commands:
9
+ - id: napari-musa.run_app
10
+ python_name: napari_musa.main:run_napari_app
11
+ title: RUN
12
+
13
+ widgets:
14
+ - command: napari-musa.run_app
15
+ display_name: MUSA - HSI Analysis