napari-musa 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,265 @@
1
+ """ """
2
+
3
+ import sys
4
+ from os.path import dirname
5
+
6
+ sys.path.append(dirname(dirname(__file__)))
7
+ import napari
8
+ import numpy as np
9
+ from magicgui.widgets import (
10
+ CheckBox,
11
+ ComboBox,
12
+ Container,
13
+ PushButton,
14
+ Select,
15
+ SpinBox,
16
+ )
17
+ from matplotlib.backends.backend_qt5agg import (
18
+ FigureCanvasQTAgg as FigureCanvas,
19
+ )
20
+ from matplotlib.backends.backend_qt5agg import (
21
+ NavigationToolbar2QT as NavigationToolbar,
22
+ )
23
+ from matplotlib.figure import Figure
24
+ from napari.utils.notifications import show_info
25
+ from qtpy.QtWidgets import (
26
+ QFileDialog,
27
+ QGroupBox,
28
+ QHBoxLayout,
29
+ QScrollArea,
30
+ QVBoxLayout,
31
+ QWidget,
32
+ )
33
+ from scipy.io import savemat
34
+
35
+ from napari_musa.modules.functions import (
36
+ NMF_analysis,
37
+ inverse_metrics,
38
+ )
39
+
40
+
41
class NMF(QWidget):
    """Widget exposing controls to run NMF and inspect/export the result.

    The widget is made of two group boxes inside a scroll area:

    * "NMF parameters": dataset choice, initialisation, number of
      components, a run button and a multi-selection list of bases.
    * "NMF spectra": a matplotlib canvas showing the selected bases, plus
      export buttons.

    Parameters
    ----------
    viewer : napari.Viewer
        Viewer that receives the NMF component maps as an image layer.
    data
        Project data container. This class reads ``modes``, ``hypercubes``,
        ``hypercubes_red``, ``hypercubes_masked``, ``wls``, ``fusion_modes``,
        ``fusion_norm`` and ``fusion_params`` from it and writes
        ``nmf_maps`` / ``nmf_basis`` (schema is defined outside this file).
    plot
        Project plotting helper providing ``customize_toolbar``,
        ``setup_plot`` and ``show_spectra``.
    """

    def __init__(self, viewer: napari.Viewer, data, plot):
        """Store the collaborators, initialise selection state, build the UI."""
        super().__init__()
        self.viewer = viewer
        self.data = data
        self.plot = plot
        # Selection state is filled in by run_nmf / on_basis_selection_changed.
        # It is initialised here so that the export handlers can detect
        # "nothing selected yet" instead of raising AttributeError.
        self.points = []
        self.basis_numbers = []
        self.selected_basis = None
        self.selected_basis_to_show = None
        self.init_ui()

    def init_ui(self):
        """Wrap the NMF groups in a resizable scroll area filling the widget."""
        scroll = QScrollArea()
        scroll.setWidgetResizable(True)
        content_widget = QWidget()
        content_layout = QVBoxLayout(content_widget)

        self.build_nmf_group(content_layout)
        content_layout.addStretch()

        scroll.setWidget(content_widget)
        # Constructing the layout with ``self`` as parent already installs it
        # on this widget, so no separate setLayout() call is needed.
        main_layout = QVBoxLayout(self)
        main_layout.addWidget(scroll)

    def build_nmf_group(self, layout):
        """Append the parameter box and the spectra box to ``layout``."""
        layout.addWidget(self.create_nmf_controls())
        layout.addWidget(self.create_mean_spectrum_area())
        layout.addStretch()

    def create_nmf_controls(self):
        """Build and return the "NMF parameters" group box.

        Creates the dataset check boxes, the mode and initialisation
        drop-downs, the component spin box, the run button and the basis
        multi-selection list, and wires their signals.
        """
        nmf_box = QGroupBox("NMF parameters")
        layout = QVBoxLayout()
        layout.addSpacing(10)

        # Row 1: reduced-dataset toggle next to the imaging-mode selector.
        self.reduced_dataset = CheckBox(text="Apply to reduced dataset")
        self.masked_dataset = CheckBox(text="Apply to masked dataset")
        self.modes_combobox = ComboBox(
            choices=self.data.modes, label="Select the imaging mode"
        )
        row1 = QHBoxLayout()
        row1.addWidget(self.reduced_dataset.native)
        row1.addWidget(self.modes_combobox.native)
        layout.addLayout(row1)

        # Row 2: masked-dataset toggle on its own line.
        row2 = QHBoxLayout()
        row2.addWidget(self.masked_dataset.native)
        layout.addLayout(row2)

        self.init_dropdown = ComboBox(
            choices=[
                "random",
                "nndsvd",
                "nndsvda",
                "nndsvdar",
            ],
            label="Select the initialization",
        )
        self.n_components = SpinBox(
            min=1, max=500, value=5, step=1, name="N Components"
        )
        layout.addWidget(
            Container(
                widgets=[
                    self.init_dropdown,
                    self.n_components,
                ]
            ).native
        )

        run_btn = PushButton(text="Run NMF")
        run_btn.clicked.connect(self.run_nmf)
        layout.addWidget(run_btn.native)

        # Row 3: list of bases; populated by run_nmf once results exist.
        self.nmf_basis_multiselecton = Select(label="Select Bases", choices=[])
        self.nmf_basis_multiselecton.changed.connect(
            self.on_basis_selection_changed
        )
        row3 = QHBoxLayout()
        row3.addWidget(self.nmf_basis_multiselecton.native)
        layout.addLayout(row3)

        nmf_box.setLayout(layout)
        return nmf_box

    def create_mean_spectrum_area(self):
        """Build and return the "NMF spectra" group box.

        Contains the matplotlib canvas with its navigation toolbar and the
        two export buttons.
        """
        nmf_box = QGroupBox("NMF spectra")
        layout = QVBoxLayout()
        layout.addSpacing(10)

        self.mean_plot = FigureCanvas(Figure(figsize=(5, 3)))
        self.mean_plot.setMinimumSize(300, 450)
        self.mean_plot_toolbar = NavigationToolbar(self.mean_plot, self)
        self.plot.customize_toolbar(self.mean_plot_toolbar)
        self.plot.setup_plot(self.mean_plot)

        layout.addWidget(self.mean_plot)
        layout.addWidget(self.mean_plot_toolbar)

        export_btn = PushButton(text="Export spectra")
        export_btn.clicked.connect(self.export_spectrum)
        export_nmf_btn = PushButton(text="Export NMF matrices as .mat")
        export_nmf_btn.clicked.connect(self.export_nmf)

        layout.addWidget(Container(widgets=[export_btn]).native)
        layout.addWidget(Container(widgets=[export_nmf_btn]).native)
        nmf_box.setLayout(layout)
        return nmf_box

    def run_nmf(self):
        """Run NMF on the chosen dataset and show the component maps.

        Dataset precedence: the masked check box wins over the reduced check
        box; with neither ticked the full hypercube is used. The results are
        stored in ``data.nmf_maps[mode]`` and ``data.nmf_basis[mode]``, the
        basis list is repopulated with ``"Basis i"`` entries, and the maps
        are added to the viewer with the component axis first.
        """
        mode = self.modes_combobox.value
        n_basis = self.n_components.value

        if self.masked_dataset.value:
            dataset = self.data.hypercubes_masked[mode]
            # Merge the first two axes so that every row is one position.
            flat = dataset.reshape(dataset.shape[0] * dataset.shape[1], -1)
            # Keep the indices of rows whose mean is not NaN (presumably the
            # masked-out pixels are stored as NaN; confirm with the masking
            # code).
            self.points = np.flatnonzero(~np.isnan(np.mean(flat, axis=1)))
        elif self.reduced_dataset.value:
            dataset = self.data.hypercubes_red[mode]
            self.points = []
        else:
            dataset = self.data.hypercubes[mode]
            self.points = []

        maps, basis = NMF_analysis(
            dataset,
            points=self.points,
            n_components=n_basis,
            init=self.init_dropdown.value,
        )
        self.data.nmf_maps[mode] = maps
        self.data.nmf_basis[mode] = basis

        self.nmf_basis_multiselecton.choices = [
            f"Basis {i}" for i in range(n_basis)
        ]

        # Move the component axis first so napari exposes it as a slider.
        self.viewer.add_image(
            self.data.nmf_maps[mode].transpose(2, 0, 1),
            name=str(mode) + " - NMF",
        )

        show_info("NMF analysis completed!")

    def on_basis_selection_changed(self, value):
        """Plot the bases currently selected in the multi-selection list.

        Parameters
        ----------
        value
            Iterable of the selected labels, each of the form ``"Basis <i>"``.
        """
        mode = self.modes_combobox.value
        try:
            basis = self.data.nmf_basis[mode]
        except KeyError:
            # The mode can be switched to one for which NMF was never run.
            show_info(f"No NMF result available for '{mode}'. Run NMF first.")
            return

        self.basis_numbers = sorted(int(label.split()[1]) for label in value)
        self.selected_basis = basis[:, self.basis_numbers]
        # Work on a copy so the in-place de-normalisation below only affects
        # what is displayed, not ``selected_basis`` itself.
        self.selected_basis_to_show = self.selected_basis.copy()

        if mode == "Fused":
            # Rows before ``fusion_point`` belong to the first fused mode,
            # the remaining rows to the second one.
            fusion_point = self.data.wls[self.data.fusion_modes[0]].shape[0]
            # TODO: adjust corrections
            self.selected_basis_to_show[:fusion_point, :] = inverse_metrics(
                self.selected_basis_to_show[:fusion_point, :],
                self.data.fusion_norm,
                self.data.fusion_params[0],
            )
            self.selected_basis_to_show[fusion_point:, :] = inverse_metrics(
                self.selected_basis_to_show[fusion_point:, :],
                self.data.fusion_norm,
                self.data.fusion_params[1],
            )

        self.plot.show_spectra(
            self.mean_plot,
            self.selected_basis_to_show,
            mode,
            basis_numbers=self.basis_numbers,
            export_txt_flag=False,
        )

    def export_spectrum(self):
        """Export the currently displayed basis spectra.

        Delegates to ``plot.show_spectra`` with ``export_txt_flag=True``.
        Does nothing (apart from a notification) if no basis has been
        selected yet.
        """
        if self.selected_basis_to_show is None:
            show_info("Select at least one NMF basis before exporting.")
            return

        mode = self.modes_combobox.value
        self.plot.show_spectra(
            self.mean_plot,
            self.selected_basis_to_show,
            mode,
            basis_numbers=self.basis_numbers,
            export_txt_flag=True,
        )

    def export_nmf(self):
        """Save the NMF maps and bases of the current mode to a ``.mat`` file.

        The file contains ``"H"`` (``data.nmf_maps[mode]``) and ``"W"``
        (``data.nmf_basis[mode]``). Nothing is written if the user cancels
        the file dialog or if no NMF result exists for the mode.
        """
        mode = self.modes_combobox.value
        try:
            save_dict = {
                "H": self.data.nmf_maps[mode],
                "W": self.data.nmf_basis[mode],
            }
        except KeyError:
            show_info(f"No NMF result available for '{mode}'. Run NMF first.")
            return

        filename, _ = QFileDialog.getSaveFileName(
            self, "Save nmf .mat", "", "mat (*.mat)"
        )
        # An empty filename means the dialog was cancelled.
        if filename:
            savemat(filename, save_dict)

    def update_number_H(self):
        """Show the current component index in the viewer text overlay.

        The index is the first entry of ``viewer.dims.current_step``, capped
        at ``n_components - 1``.
        """
        index = self.viewer.dims.current_step[0]
        index = min(index, self.n_components.value - 1)
        self.viewer.text_overlay.text = f"Component number: {index}"
@@ -0,0 +1,212 @@
1
+ """ """
2
+
3
+ import sys
4
+ from os.path import dirname
5
+
6
+ sys.path.append(dirname(dirname(__file__)))
7
+ import napari
8
+ import numpy as np
9
+ import pyqtgraph as pg
10
+ from magicgui.widgets import (
11
+ CheckBox,
12
+ ComboBox,
13
+ Container,
14
+ PushButton,
15
+ SpinBox,
16
+ )
17
+ from qtpy.QtWidgets import (
18
+ QGroupBox,
19
+ QHBoxLayout,
20
+ QScrollArea,
21
+ QVBoxLayout,
22
+ QWidget,
23
+ )
24
+
25
+ from napari_musa.modules.functions import PCA_analysis, RGB_to_hex
26
+
27
+
28
class PCA(QWidget):
    """Widget exposing controls to run PCA and explore it in a scatterplot.

    The widget is made of two group boxes inside a scroll area:

    * "PCA parameters": dataset choice, imaging mode, number of components
      and a run button.
    * "PCA scatterplot": axis selection, colouring option, a pyqtgraph plot
      and buttons for interacting with it.

    Parameters
    ----------
    viewer : napari.Viewer
        Viewer that receives the PCA component maps as an image layer.
    data
        Project data container. This class reads ``modes``, ``hypercubes``,
        ``hypercubes_red``, ``rgb`` and ``rgb_red`` from it and writes
        ``pca_maps`` (schema is defined outside this file).
    plot
        Project plotting helper providing ``setup_scatterplot``,
        ``create_button``, ``polygon_selection``, ``save_image_button``,
        ``show_scatterplot`` and ``show_selected_points``.
    """

    def __init__(self, viewer: napari.Viewer, data, plot):
        """Store the collaborators, initialise plot state, build the UI."""
        super().__init__()
        self.viewer = viewer
        self.data = data
        self.plot = plot
        self.hex_reshaped = np.zeros(1)
        # Scatterplot state is filled in by PCA_show_plot_btn_f. It is
        # initialised here so that handle_selection can detect "no plot yet"
        # instead of raising AttributeError.
        self.PCA_dataset = None
        self.H_PCA_reshaped_selected = None
        self.points = []
        self.init_ui()

    def init_ui(self):
        """Wrap the PCA groups in a resizable scroll area filling the widget."""
        scroll = QScrollArea()
        scroll.setWidgetResizable(True)
        content_widget = QWidget()
        content_layout = QVBoxLayout(content_widget)

        self.build_sivm_group(content_layout)
        content_layout.addStretch()

        scroll.setWidget(content_widget)
        # Constructing the layout with ``self`` as parent already installs it
        # on this widget, so no separate setLayout() call is needed.
        main_layout = QVBoxLayout(self)
        main_layout.addWidget(scroll)

    def build_sivm_group(self, layout):
        """Append the PCA parameter box and scatterplot box to ``layout``."""
        layout.addWidget(self.create_pca_controls())
        layout.addWidget(self.create_pca_scatterplot())
        layout.addStretch()

    def create_pca_controls(self):
        """Build and return the "PCA parameters" group box."""
        PCA_box = QGroupBox("PCA parameters")
        PCA_main_layout = QVBoxLayout()
        PCA_main_layout.addSpacing(10)

        self.reduced_dataset = CheckBox(text="Apply to reduced dataset")
        self.modes_combobox = ComboBox(
            choices=self.data.modes, label="Select the imaging mode"
        )
        self.n_components = SpinBox(
            min=1, max=100, value=10, step=1, name="Number of components"
        )

        PCA_perform_btn = PushButton(text="Perform PCA")
        PCA_perform_btn.clicked.connect(self.PCA_perform_btn_f)
        PCA_main_layout.addWidget(
            Container(
                widgets=[
                    self.reduced_dataset,
                    self.modes_combobox,
                    self.n_components,
                    PCA_perform_btn,
                ]
            ).native
        )
        PCA_box.setLayout(PCA_main_layout)
        return PCA_box

    def create_pca_scatterplot(self):
        """Build and return the "PCA scatterplot" group box.

        Creates the axis spin boxes, the RGB colouring check box, the
        pyqtgraph plot widget, the row of interaction buttons and the point
        size spin box.
        """
        PCA_box = QGroupBox("PCA scatterplot")
        PCA_layout_plot_var = QVBoxLayout()
        PCA_layout_plot_var.addSpacing(10)

        # Components to plot on each axis (1-based in the UI).
        self.x_axis = SpinBox(min=1, max=100, value=1, step=1, name="X axis")
        self.y_axis = SpinBox(min=1, max=100, value=2, step=1, name="Y axis")
        PCA_layout_plot_var.addWidget(
            Container(widgets=[self.x_axis, self.y_axis]).native
        )

        PCA_layout_perform = QHBoxLayout()
        self.PCA_colorRGB = CheckBox(text="Scatterplot with True RGB")
        PCA_show_plot_btn = PushButton(text="Show PCA scatterplot")
        PCA_show_plot_btn.clicked.connect(self.PCA_show_plot_btn_f)
        PCA_layout_perform.addWidget(
            Container(
                widgets=[self.PCA_colorRGB, PCA_show_plot_btn],
                layout="horizontal",
            ).native
        )
        PCA_layout_plot_var.addLayout(PCA_layout_perform)

        self.pca_plot = pg.PlotWidget()
        self.plot.setup_scatterplot(self.pca_plot)

        # Icon name -> action for the scatterplot interaction buttons.
        btn_layout = QHBoxLayout()
        button_actions = [
            ("fa5s.home", lambda: self.pca_plot.getViewBox().autoRange()),
            (
                "fa5s.draw-polygon",
                lambda: self.plot.polygon_selection(self.pca_plot),
            ),
            ("ri.add-box-fill", self.handle_selection),
            (
                "mdi6.image-edit",
                lambda: self.plot.save_image_button(self.pca_plot),
            ),
        ]
        for icon, func in button_actions:
            btn = self.plot.create_button(icon)
            btn.clicked.connect(func)
            btn_layout.addWidget(btn)

        self.point_size = SpinBox(
            min=1, max=100, value=1, step=1, name="Point size"
        )
        btn_layout.addSpacing(30)
        btn_layout.addWidget(Container(widgets=[self.point_size]).native)
        PCA_layout_plot_var.addLayout(btn_layout)
        PCA_layout_plot_var.addWidget(self.pca_plot)
        PCA_box.setLayout(PCA_layout_plot_var)
        return PCA_box

    def PCA_perform_btn_f(self):
        """Run PCA on the chosen dataset and show the component maps.

        Uses the reduced hypercube when the check box is ticked, otherwise
        the full one. The maps are stored in ``data.pca_maps[mode]`` and
        added to the viewer with the component axis first.
        """
        mode = self.modes_combobox.value
        if self.reduced_dataset.value:
            self.PCA_dataset = self.data.hypercubes_red[mode]
        else:
            self.PCA_dataset = self.data.hypercubes[mode]

        # The second return value is not used by this widget.
        pca_map, _ = PCA_analysis(self.PCA_dataset, self.n_components.value)
        self.data.pca_maps[mode] = pca_map

        # Move the component axis first so napari exposes it as a slider.
        self.viewer.add_image(
            self.data.pca_maps[mode].transpose(2, 0, 1),
            name=str(mode) + " - PCA ",
            colormap="gray_r",
        )

    def PCA_show_plot_btn_f(self):
        """Draw the PCA scatterplot for the two selected components.

        The stored maps of the current mode are flattened to one row per
        position and the X/Y components are plotted against each other,
        coloured either with the RGB image of the dataset or with a single
        brush. Shows a notification and returns if PCA has not been run for
        the mode or if a requested component does not exist.
        """
        from napari.utils.notifications import show_info

        mode = self.modes_combobox.value
        try:
            pca_map = self.data.pca_maps[mode]
        except KeyError:
            show_info(f"No PCA result available for '{mode}'. Run PCA first.")
            return

        # Use the component count of the stored result, not the spin box:
        # the spin box may have been changed since PCA was performed.
        n_available = pca_map.shape[-1]
        pca_xaxis = self.x_axis.value - 1  # spin boxes are 1-based
        pca_yaxis = self.y_axis.value - 1
        if max(pca_xaxis, pca_yaxis) >= n_available:
            show_info(
                f"Only {n_available} PCA components are available "
                f"for '{mode}'."
            )
            return

        H_PCA_reshaped = pca_map.reshape(-1, n_available)
        self.H_PCA_reshaped_selected = np.stack(
            (H_PCA_reshaped[:, pca_xaxis], H_PCA_reshaped[:, pca_yaxis])
        ).T

        if self.PCA_colorRGB.value:
            rgb_source = (
                self.data.rgb_red[mode]
                if self.reduced_dataset.value
                else self.data.rgb[mode]
            )
            colors = np.array(RGB_to_hex(rgb_source)).reshape(-1)
        else:
            colors = pg.mkBrush("#262930")

        self.points = []
        self.plot.show_scatterplot(
            self.pca_plot,
            self.H_PCA_reshaped_selected,
            colors,
            self.points,
            self.point_size.value,
        )

    def handle_selection(self):
        """Forward the scatterplot selection to ``plot.show_selected_points``.

        Passes the plotted coordinates, the dataset matching the reduced
        check box, the mode and ``self.points``. Shows a notification and
        returns if no scatterplot has been drawn yet.
        """
        from napari.utils.notifications import show_info

        if self.H_PCA_reshaped_selected is None:
            show_info("Show the PCA scatterplot before selecting points.")
            return

        mode = self.modes_combobox.value
        if self.reduced_dataset.value:
            dataset = self.data.hypercubes_red[mode]
            self.points = []
        else:
            dataset = self.data.hypercubes[mode]

        self.plot.show_selected_points(
            self.H_PCA_reshaped_selected,
            dataset,
            mode,
            self.points,
        )

    def update_number_H(self):
        """Show the current component index in the viewer text overlay.

        The index is the first entry of ``viewer.dims.current_step``, capped
        at ``n_components - 1``.
        """
        index = self.viewer.dims.current_step[0]
        index = min(index, self.n_components.value - 1)
        self.viewer.text_overlay.text = f"PCA number: {index}"