celldetective 1.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66) hide show
  1. celldetective/__init__.py +2 -0
  2. celldetective/__main__.py +432 -0
  3. celldetective/datasets/segmentation_annotations/blank +0 -0
  4. celldetective/datasets/signal_annotations/blank +0 -0
  5. celldetective/events.py +149 -0
  6. celldetective/extra_properties.py +100 -0
  7. celldetective/filters.py +89 -0
  8. celldetective/gui/__init__.py +20 -0
  9. celldetective/gui/about.py +44 -0
  10. celldetective/gui/analyze_block.py +563 -0
  11. celldetective/gui/btrack_options.py +898 -0
  12. celldetective/gui/classifier_widget.py +386 -0
  13. celldetective/gui/configure_new_exp.py +532 -0
  14. celldetective/gui/control_panel.py +438 -0
  15. celldetective/gui/gui_utils.py +495 -0
  16. celldetective/gui/json_readers.py +113 -0
  17. celldetective/gui/measurement_options.py +1425 -0
  18. celldetective/gui/neighborhood_options.py +452 -0
  19. celldetective/gui/plot_signals_ui.py +1042 -0
  20. celldetective/gui/process_block.py +1055 -0
  21. celldetective/gui/retrain_segmentation_model_options.py +706 -0
  22. celldetective/gui/retrain_signal_model_options.py +643 -0
  23. celldetective/gui/seg_model_loader.py +460 -0
  24. celldetective/gui/signal_annotator.py +2388 -0
  25. celldetective/gui/signal_annotator_options.py +340 -0
  26. celldetective/gui/styles.py +217 -0
  27. celldetective/gui/survival_ui.py +903 -0
  28. celldetective/gui/tableUI.py +608 -0
  29. celldetective/gui/thresholds_gui.py +1300 -0
  30. celldetective/icons/logo-large.png +0 -0
  31. celldetective/icons/logo.png +0 -0
  32. celldetective/icons/signals_icon.png +0 -0
  33. celldetective/icons/splash-test.png +0 -0
  34. celldetective/icons/splash.png +0 -0
  35. celldetective/icons/splash0.png +0 -0
  36. celldetective/icons/survival2.png +0 -0
  37. celldetective/icons/vignette_signals2.png +0 -0
  38. celldetective/icons/vignette_signals2.svg +114 -0
  39. celldetective/io.py +2050 -0
  40. celldetective/links/zenodo.json +561 -0
  41. celldetective/measure.py +1258 -0
  42. celldetective/models/segmentation_effectors/blank +0 -0
  43. celldetective/models/segmentation_generic/blank +0 -0
  44. celldetective/models/segmentation_targets/blank +0 -0
  45. celldetective/models/signal_detection/blank +0 -0
  46. celldetective/models/tracking_configs/mcf7.json +68 -0
  47. celldetective/models/tracking_configs/ricm.json +203 -0
  48. celldetective/models/tracking_configs/ricm2.json +203 -0
  49. celldetective/neighborhood.py +717 -0
  50. celldetective/scripts/analyze_signals.py +51 -0
  51. celldetective/scripts/measure_cells.py +275 -0
  52. celldetective/scripts/segment_cells.py +212 -0
  53. celldetective/scripts/segment_cells_thresholds.py +140 -0
  54. celldetective/scripts/track_cells.py +206 -0
  55. celldetective/scripts/train_segmentation_model.py +246 -0
  56. celldetective/scripts/train_signal_model.py +49 -0
  57. celldetective/segmentation.py +712 -0
  58. celldetective/signals.py +2826 -0
  59. celldetective/tracking.py +974 -0
  60. celldetective/utils.py +1681 -0
  61. celldetective-1.0.2.dist-info/LICENSE +674 -0
  62. celldetective-1.0.2.dist-info/METADATA +192 -0
  63. celldetective-1.0.2.dist-info/RECORD +66 -0
  64. celldetective-1.0.2.dist-info/WHEEL +5 -0
  65. celldetective-1.0.2.dist-info/entry_points.txt +2 -0
  66. celldetective-1.0.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,643 @@
1
+ from PyQt5.QtWidgets import QMainWindow, QApplication, QMessageBox, QScrollArea, QComboBox, QFrame, QCheckBox, QFileDialog, QGridLayout, QTextEdit, QLineEdit, QVBoxLayout, QWidget, QLabel, QHBoxLayout, QPushButton
2
+ from PyQt5.QtCore import Qt, QSize
3
+ from PyQt5.QtGui import QDoubleValidator, QIntValidator, QIcon
4
+ from celldetective.gui.gui_utils import center_window, FeatureChoice, ListWidget, QHSeperationLine, FigureCanvas, GeometryChoice, OperationChoice
5
+ from superqt import QLabeledDoubleRangeSlider, QLabeledDoubleSlider,QLabeledSlider
6
+ from superqt.fonticon import icon
7
+ from fonticon_mdi6 import MDI6
8
+ from celldetective.utils import extract_experiment_channels, get_software_location
9
+ from celldetective.io import interpret_tracking_configuration, load_frames, locate_signal_dataset, get_signal_datasets_list
10
+ from celldetective.measure import compute_haralick_features, contour_of_instance_segmentation
11
+ from celldetective.signals import train_signal_model
12
+ import numpy as np
13
+ import json
14
+ from shutil import copyfile
15
+ import os
16
+ import matplotlib.pyplot as plt
17
+ from mpl_toolkits.axes_grid1 import make_axes_locatable
18
+ from glob import glob
19
+ from natsort import natsorted
20
+ from tifffile import imread
21
+ from pathlib import Path, PurePath
22
+ from datetime import datetime
23
+ import pandas as pd
24
+ from functools import partial
25
+
26
class ConfigSignalModelTraining(QMainWindow):
    """
    UI to configure and launch the training of a signal-detection model.

    The window stacks three frames (MODEL, DATA, HYPERPARAMETERS) above a
    'Train' button. On submission the settings are written to a
    ``training_instructions.json`` file inside the model folder and that file
    is handed to ``train_signal_model``.

    Parameters
    ----------
    parent : object, optional
        Calling widget. Despite the ``None`` default a parent is required:
        the constructor reads ``parent.mode``, ``parent.exp_dir`` and
        ``parent.parent.parent.screen_height``; styling reads
        ``parent.parent.parent.button_style_sheet`` / ``button_select_all``;
        ``parent.refresh_signal_models()`` is called once training returns.
    """

    def __init__(self, parent=None):

        super().__init__()
        self.parent = parent
        self.setWindowTitle("Train signal model")
        # NOTE(review): relative path, resolved against the current working
        # directory rather than the package location.
        self.setWindowIcon(QIcon(os.sep.join(['celldetective','icons','mexican-hat.png'])))
        self.mode = self.parent.mode
        self.exp_dir = self.parent.exp_dir
        self.soft_path = get_software_location()
        self.pretrained_model = None
        self.dataset_folder = None
        self.signal_models_dir = self.soft_path+'/celldetective/models/signal_detection/'

        self.onlyFloat = QDoubleValidator()
        self.onlyInt = QIntValidator()

        self.screen_height = self.parent.parent.parent.screen_height
        center_window(self)

        self.setMinimumWidth(500)
        self.setMinimumHeight(int(0.3*self.screen_height))
        self.setMaximumHeight(int(0.8*self.screen_height))
        self.populate_widget()

    # ------------------------------------------------------------------
    # private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _default_model_name():
        """Return the placeholder model name, stamped with today's date."""
        return f"Untitled_model_{datetime.today().strftime('%Y-%m-%d')}"

    @staticmethod
    def _parse_float(text):
        """Convert *text* to float, accepting ',' as decimal separator.

        Raises ``ValueError`` when the text is not a number.
        """
        return float(text.replace(',','.'))

    def _show_warning(self, message):
        """Display a modal warning box with a single Ok button."""
        msgBox = QMessageBox()
        msgBox.setIcon(QMessageBox.Warning)
        msgBox.setText(message)
        msgBox.setWindowTitle("Warning")
        msgBox.setStandardButtons(QMessageBox.Ok)
        msgBox.exec()

    def _reset_pretrained_widgets(self):
        """Forget the pretrained model and unlock the widgets it had frozen."""
        self.pretrained_model = None
        self.pretrained_lbl.setText('No folder chosen')
        for cb in self.channel_cbs:
            cb.setEnabled(True)
        self.recompile_option.setEnabled(False)
        self.cancel_pretrained.setVisible(False)
        self.model_length_slider.setEnabled(True)

    # ------------------------------------------------------------------
    # layout
    # ------------------------------------------------------------------

    def populate_widget(self):
        """
        Create the multibox design: the three option frames and the 'Train'
        button, wrapped in a scroll area used as central widget. The window
        is shown at the end of this method.
        """

        self.scroll_area = QScrollArea(self)
        self.button_widget = QWidget()
        main_layout = QVBoxLayout()
        self.button_widget.setLayout(main_layout)
        main_layout.setContentsMargins(30,30,30,30)

        # MODEL frame
        self.model_frame = QFrame()
        self.model_frame.setFrameStyle(QFrame.StyledPanel | QFrame.Raised)
        self.populate_model_frame()
        main_layout.addWidget(self.model_frame)

        # DATA frame
        self.data_frame = QFrame()
        self.data_frame.setFrameStyle(QFrame.StyledPanel | QFrame.Raised)
        self.populate_data_frame()
        main_layout.addWidget(self.data_frame)

        # HYPERPARAMETERS frame
        self.hyper_frame = QFrame()
        self.hyper_frame.setFrameStyle(QFrame.StyledPanel | QFrame.Raised)
        self.populate_hyper_frame()
        main_layout.addWidget(self.hyper_frame)

        # The button stays disabled until at least one channel is selected
        # (see check_valid_channels).
        self.submit_btn = QPushButton('Train')
        self.submit_btn.setStyleSheet(self.parent.parent.parent.button_style_sheet)
        self.submit_btn.clicked.connect(self.prep_model)
        main_layout.addWidget(self.submit_btn)
        self.submit_btn.setEnabled(False)

        self.button_widget.adjustSize()

        self.scroll_area.setAlignment(Qt.AlignCenter)
        self.scroll_area.setWidget(self.button_widget)
        self.scroll_area.setVerticalScrollBarPolicy(Qt.ScrollBarAsNeeded)
        self.scroll_area.setHorizontalScrollBarPolicy(Qt.ScrollBarAsNeeded)
        self.scroll_area.setWidgetResizable(True)
        self.setCentralWidget(self.scroll_area)
        self.show()

        # Let Qt lay the window out before measuring the scroll bar.
        QApplication.processEvents()
        self.adjustScrollArea()

    def populate_hyper_frame(self):
        """
        Add the title and the widgets of the HYPERPARAMETERS frame.
        """

        grid = QGridLayout(self.hyper_frame)
        grid.setContentsMargins(30,30,30,30)
        grid.setSpacing(30)

        self.hyper_lbl = QLabel("HYPERPARAMETERS")
        self.hyper_lbl.setStyleSheet("""
            font-weight: bold;
            padding: 0px;
            """)
        grid.addWidget(self.hyper_lbl, 0, 0, 1, 4, alignment=Qt.AlignCenter)
        self.generate_hyper_contents()
        grid.addWidget(self.ContentsHyper, 1, 0, 1, 4, alignment=Qt.AlignTop)

    def generate_hyper_contents(self):
        """
        Build ``self.ContentsHyper``: learning rate, batch size and number
        of epochs.
        """

        self.ContentsHyper = QFrame()
        layout = QVBoxLayout(self.ContentsHyper)
        layout.setContentsMargins(0,0,0,0)

        lr_layout = QHBoxLayout()
        lr_layout.addWidget(QLabel('learning rate: '),30)
        # Decimal comma by default; prep_model accepts both ',' and '.'.
        self.lr_le = QLineEdit('0,01')
        self.lr_le.setValidator(self.onlyFloat)
        lr_layout.addWidget(self.lr_le, 70)
        layout.addLayout(lr_layout)

        bs_layout = QHBoxLayout()
        bs_layout.addWidget(QLabel('batch size: '),30)
        self.bs_le = QLineEdit('64')
        self.bs_le.setValidator(self.onlyInt)
        bs_layout.addWidget(self.bs_le, 70)
        layout.addLayout(bs_layout)

        epochs_layout = QHBoxLayout()
        epochs_layout.addWidget(QLabel('# epochs: '), 30)
        self.epochs_slider = QLabeledSlider()
        self.epochs_slider.setRange(1,3000)
        self.epochs_slider.setSingleStep(1)
        self.epochs_slider.setTickInterval(1)
        self.epochs_slider.setOrientation(Qt.Horizontal)
        self.epochs_slider.setValue(300)
        epochs_layout.addWidget(self.epochs_slider, 70)
        layout.addLayout(epochs_layout)

    def populate_data_frame(self):
        """
        Add the title and the widgets of the DATA frame.
        """

        grid = QGridLayout(self.data_frame)
        grid.setContentsMargins(30,30,30,30)
        grid.setSpacing(30)

        self.data_lbl = QLabel("DATA")
        self.data_lbl.setStyleSheet("""
            font-weight: bold;
            padding: 0px;
            """)
        grid.addWidget(self.data_lbl, 0, 0, 1, 4, alignment=Qt.AlignCenter)
        self.generate_data_contents()
        grid.addWidget(self.ContentsData, 1, 0, 1, 4, alignment=Qt.AlignTop)

    def populate_model_frame(self):
        """
        Add the title and the widgets of the MODEL frame.
        """

        grid = QGridLayout(self.model_frame)
        grid.setContentsMargins(30,30,30,30)
        grid.setSpacing(30)

        self.model_lbl = QLabel("MODEL")
        self.model_lbl.setStyleSheet("""
            font-weight: bold;
            padding: 0px;
            """)
        grid.addWidget(self.model_lbl, 0, 0, 1, 4, alignment=Qt.AlignCenter)

        self.generate_model_panel_contents()
        grid.addWidget(self.ContentsModel, 1, 0, 1, 4, alignment=Qt.AlignTop)

    def generate_data_contents(self):
        """
        Build ``self.ContentsData``: training-folder picker, optional
        built-in dataset, augmentation factor and validation split.
        """

        self.ContentsData = QFrame()
        layout = QVBoxLayout(self.ContentsData)
        layout.setContentsMargins(0,0,0,0)

        # --- user-chosen training folder
        train_data_layout = QHBoxLayout()
        train_data_layout.addWidget(QLabel('Training data: '), 30)
        self.select_data_folder_btn = QPushButton('Choose folder')
        self.select_data_folder_btn.clicked.connect(self.showDialog_dataset)
        self.data_folder_label = QLabel('No folder chosen')
        train_data_layout.addWidget(self.select_data_folder_btn, 35)
        train_data_layout.addWidget(self.data_folder_label, 30)

        # Small 'x' button, only visible once a folder has been chosen.
        self.cancel_dataset = QPushButton()
        self.cancel_dataset.setIcon(icon(MDI6.close,color="black"))
        self.cancel_dataset.clicked.connect(self.clear_dataset)
        self.cancel_dataset.setStyleSheet(self.parent.parent.parent.button_select_all)
        self.cancel_dataset.setIconSize(QSize(20, 20))
        self.cancel_dataset.setVisible(False)
        train_data_layout.addWidget(self.cancel_dataset, 5)

        layout.addLayout(train_data_layout)

        # --- datasets listed by get_signal_datasets_list ('--' means none)
        include_dataset_layout = QHBoxLayout()
        include_dataset_layout.addWidget(QLabel('include dataset: '),30)
        self.dataset_cb = QComboBox()

        available_datasets, self.datasets_path = get_signal_datasets_list(return_path=True)
        signal_datasets = ['--'] + available_datasets

        self.dataset_cb.addItems(signal_datasets)
        include_dataset_layout.addWidget(self.dataset_cb, 70)
        layout.addLayout(include_dataset_layout)

        # --- augmentation factor
        augmentation_hbox = QHBoxLayout()
        augmentation_hbox.addWidget(QLabel('augmentation\nfactor: '), 30)
        self.augmentation_slider = QLabeledDoubleSlider()
        self.augmentation_slider.setSingleStep(0.01)
        self.augmentation_slider.setTickInterval(0.01)
        self.augmentation_slider.setOrientation(Qt.Horizontal)
        self.augmentation_slider.setRange(1, 5)
        self.augmentation_slider.setValue(2)

        augmentation_hbox.addWidget(self.augmentation_slider, 70)
        layout.addLayout(augmentation_hbox)

        # --- validation split
        validation_split_layout = QHBoxLayout()
        validation_split_layout.addWidget(QLabel('validation split: '),30)
        self.validation_slider = QLabeledDoubleSlider()
        self.validation_slider.setSingleStep(0.01)
        self.validation_slider.setTickInterval(0.01)
        self.validation_slider.setOrientation(Qt.Horizontal)
        self.validation_slider.setRange(0,0.9)
        self.validation_slider.setValue(0.25)
        validation_split_layout.addWidget(self.validation_slider, 70)
        layout.addLayout(validation_split_layout)

    def generate_model_panel_contents(self):
        """
        Build ``self.ContentsModel``: model name, event name, pretrained-model
        picker, recompile option, one row per input channel (combo box plus
        normalisation bounds, clip and percentile/absolute toggles) and the
        maximum signal length.
        """

        self.ContentsModel = QFrame()
        layout = QVBoxLayout(self.ContentsModel)
        layout.setContentsMargins(0,0,0,0)
        select_all_style = self.parent.parent.parent.button_select_all

        # --- model name
        modelname_layout = QHBoxLayout()
        modelname_layout.addWidget(QLabel('Model name: '), 30)
        self.modelname_le = QLineEdit()
        self.modelname_le.setText(self._default_model_name())
        modelname_layout.addWidget(self.modelname_le, 70)
        layout.addLayout(modelname_layout)

        # --- event (class) name
        classname_layout = QHBoxLayout()
        classname_layout.addWidget(QLabel('event name: '), 30)
        self.class_name_le = QLineEdit()
        self.class_name_le.setText("")
        classname_layout.addWidget(self.class_name_le, 70)
        layout.addLayout(classname_layout)

        # --- pretrained model picker
        pretrained_layout = QHBoxLayout()
        pretrained_layout.setContentsMargins(0,0,0,0)
        pretrained_layout.addWidget(QLabel('Pretrained model: '), 30)

        self.browse_pretrained_btn = QPushButton('Choose folder')
        self.browse_pretrained_btn.clicked.connect(self.showDialog_pretrained)
        pretrained_layout.addWidget(self.browse_pretrained_btn, 35)

        self.pretrained_lbl = QLabel('No folder chosen')
        pretrained_layout.addWidget(self.pretrained_lbl, 30)

        self.cancel_pretrained = QPushButton()
        self.cancel_pretrained.setIcon(icon(MDI6.close,color="black"))
        self.cancel_pretrained.clicked.connect(self.clear_pretrained)
        self.cancel_pretrained.setStyleSheet(select_all_style)
        self.cancel_pretrained.setIconSize(QSize(20, 20))
        self.cancel_pretrained.setVisible(False)
        pretrained_layout.addWidget(self.cancel_pretrained, 5)

        layout.addLayout(pretrained_layout)

        # --- recompile option (enabled only with a pretrained model)
        recompile_layout = QHBoxLayout()
        recompile_layout.addWidget(QLabel('Recompile: '), 30)
        self.recompile_option = QCheckBox()
        self.recompile_option.setEnabled(False)
        recompile_layout.addWidget(self.recompile_option, 70)
        layout.addLayout(recompile_layout)

        # --- per-channel widgets and state
        self.max_nbr_channels = 5
        n_slots = self.max_nbr_channels
        self.channel_cbs = [QComboBox() for _ in range(n_slots)]
        self.normalization_mode_btns = [QPushButton('') for _ in range(n_slots)]
        # True = percentile normalisation, False = absolute values
        self.normalization_mode = [True for _ in range(n_slots)]

        self.normalization_clip_btns = [QPushButton('') for _ in range(n_slots)]
        self.clip_option = [False for _ in range(n_slots)]

        for i, (mode_btn, clip_btn) in enumerate(zip(self.normalization_mode_btns, self.normalization_clip_btns)):

            mode_btn.setIcon(icon(MDI6.percent_circle,color="#1565c0"))
            mode_btn.setIconSize(QSize(20, 20))
            mode_btn.setStyleSheet(select_all_style)
            mode_btn.setToolTip("Switch to absolute normalization values.")
            mode_btn.clicked.connect(partial(self.switch_normalization_mode, i))

            clip_btn.setIcon(icon(MDI6.content_cut,color="black"))
            clip_btn.setIconSize(QSize(20, 20))
            clip_btn.setStyleSheet(select_all_style)
            clip_btn.clicked.connect(partial(self.switch_clipping_mode, i))
            clip_btn.setToolTip('clip')

        self.normalization_min_value_lbl = [QLabel('Min %: ') for _ in range(n_slots)]
        self.normalization_min_value_le = [QLineEdit('0.1') for _ in range(n_slots)]

        self.normalization_max_value_lbl = [QLabel('Max %: ') for _ in range(n_slots)]
        self.normalization_max_value_le = [QLineEdit('99.99') for _ in range(n_slots)]

        # --- channel choices: column names of the experiment's trajectory
        # tables merged with a fixed list of generic measurements.
        # os.path.join keeps the pattern valid whether or not exp_dir ends
        # with a path separator.
        tables = glob(os.path.join(self.exp_dir, 'W*', '*', 'output', 'tables', f'trajectories_{self.mode}.csv'))
        print(tables)
        all_measurements = []
        for tab in tables:
            # Only the header is needed, so read a single row.
            all_measurements.extend(pd.read_csv(tab, nrows=1).columns.tolist())

        generic_measurements = ['brightfield_channel', 'live_nuclei_channel', 'dead_nuclei_channel',
                                'effector_fluo_channel', 'adhesion_channel', 'fluo_channel_1', 'fluo_channel_2',
                                "area", "area_bbox","area_convex","area_filled","major_axis_length",
                                "minor_axis_length",
                                "eccentricity",
                                "equivalent_diameter_area",
                                "euler_number",
                                "extent",
                                "feret_diameter_max",
                                "orientation",
                                "perimeter",
                                "perimeter_crofton",
                                "solidity",
                                "angular_second_moment",
                                "contrast",
                                "correlation",
                                "sum_of_square_variance",
                                "inverse_difference_moment",
                                "sum_average",
                                "sum_variance",
                                "sum_entropy",
                                "entropy",
                                "difference_variance",
                                "difference_entropy",
                                "information_measure_of_correlation_1",
                                "information_measure_of_correlation_2",
                                "maximal_correlation_coefficient",
                                "POSITION_X",
                                "POSITION_Y",
                                ]

        # Sorted, de-duplicated names with '--' (no channel) in first position.
        self.channel_items = np.unique(generic_measurements + all_measurements)
        self.channel_items = np.insert(self.channel_items, 0, '--')

        # Kept for backward compatibility; the original never populated it.
        self.channel_option_layouts = []
        for i, cb in enumerate(self.channel_cbs):
            ch_layout = QHBoxLayout()
            ch_layout.addWidget(QLabel(f'channel {i}: '), 30)
            cb.addItems(self.channel_items)
            cb.currentIndexChanged.connect(self.check_valid_channels)
            ch_layout.addWidget(cb, 70)
            layout.addLayout(ch_layout)

            channel_norm_options_layout = QHBoxLayout()
            channel_norm_options_layout.setContentsMargins(130,0,0,0)
            channel_norm_options_layout.addWidget(self.normalization_min_value_lbl[i])
            channel_norm_options_layout.addWidget(self.normalization_min_value_le[i])
            channel_norm_options_layout.addWidget(self.normalization_max_value_lbl[i])
            channel_norm_options_layout.addWidget(self.normalization_max_value_le[i])
            channel_norm_options_layout.addWidget(self.normalization_clip_btns[i])
            channel_norm_options_layout.addWidget(self.normalization_mode_btns[i])
            layout.addLayout(channel_norm_options_layout)

        # --- maximum signal length
        model_length_layout = QHBoxLayout()
        model_length_layout.addWidget(QLabel('Max signal length: '), 30)
        self.model_length_slider = QLabeledSlider()
        self.model_length_slider.setSingleStep(1)
        self.model_length_slider.setTickInterval(1)
        self.model_length_slider.setOrientation(Qt.Horizontal)
        self.model_length_slider.setRange(0,1024)
        self.model_length_slider.setValue(128)
        model_length_layout.addWidget(self.model_length_slider, 70)
        layout.addLayout(model_length_layout)

    # ------------------------------------------------------------------
    # folder selection
    # ------------------------------------------------------------------

    def showDialog_pretrained(self):
        """
        Let the user pick a pretrained model folder.

        A folder is accepted only if it contains a ``config_input.json`` file
        that can be loaded; otherwise the pretrained selection is reset.
        Cancelling the dialog leaves the current selection untouched.
        """

        folder = QFileDialog.getExistingDirectory(
            self, "Open Directory",
            os.sep.join([self.soft_path,'celldetective','models','signal_detection','']),
            QFileDialog.ShowDirsOnly | QFileDialog.DontResolveSymlinks,
        )

        # getExistingDirectory returns an empty string (not None) on cancel.
        if not folder:
            return

        if os.path.isfile(os.path.join(folder, "config_input.json")):
            self.pretrained_model = folder
            try:
                self.load_pretrained_config()
            except (OSError, ValueError, KeyError, TypeError):
                # Unreadable or malformed configuration: do not keep a
                # half-applied state (frozen slider / combo boxes).
                self._reset_pretrained_widgets()
                self._show_warning("The configuration of the pretrained model could not be read.")
            else:
                # Qt returns '/'-separated paths on every platform, so rely on
                # os.path rather than splitting on os.sep.
                model_basename = os.path.basename(os.path.normpath(folder))
                self.pretrained_lbl.setText(model_basename)
                self.cancel_pretrained.setVisible(True)
                self.recompile_option.setEnabled(True)
                self.modelname_le.setText(f"{model_basename}_{datetime.today().strftime('%Y-%m-%d')}")
        else:
            self._reset_pretrained_widgets()

        print(self.pretrained_model)

    def showDialog_dataset(self):
        """
        Let the user pick a folder of training data.

        The folder is accepted only if it directly contains at least one
        ``.npy`` file; otherwise the selection is cleared. Cancelling the
        dialog leaves the current selection untouched.
        """

        folder = QFileDialog.getExistingDirectory(
            self, "Open Directory",
            self.exp_dir,
            QFileDialog.ShowDirsOnly | QFileDialog.DontResolveSymlinks,
        )

        # getExistingDirectory returns an empty string (not None) on cancel.
        if not folder:
            return

        subfiles = glob(os.path.join(folder, "*.npy"))
        if subfiles:
            self.dataset_folder = folder
            print(f'found {len(subfiles)} files in folder')
            # Show a truncated path; the tooltip carries the full one.
            self.data_folder_label.setText(self.dataset_folder[:16]+'...')
            self.data_folder_label.setToolTip(self.dataset_folder)
            self.cancel_dataset.setVisible(True)
        else:
            self.clear_dataset()

    def clear_pretrained(self):
        """
        Drop the pretrained model, unlock the channel and signal-length
        widgets, and restore the default model and event names.
        """

        self._reset_pretrained_widgets()
        self.class_name_le.setText('')
        self.modelname_le.setText(self._default_model_name())

    def clear_dataset(self):
        """
        Drop the user-chosen training folder and reset its label.
        """

        self.dataset_folder = None
        self.data_folder_label.setText('No folder chosen')
        self.data_folder_label.setToolTip('')
        self.cancel_dataset.setVisible(False)

    def load_pretrained_config(self):
        """
        Apply ``config_input.json`` of ``self.pretrained_model`` to the UI.

        Reads the ``channels`` and ``model_signal_length`` entries (required)
        and ``label`` (optional). The signal-length slider is set and frozen,
        the first combo boxes are set to the model channels and the remaining
        ones are reset to '--' and disabled.

        Raises
        ------
        OSError, ValueError, KeyError, TypeError
            If the file cannot be read, is not valid JSON, or lacks the
            required entries.
        """

        with open(os.sep.join([self.pretrained_model,"config_input.json"])) as f:
            data = json.load(f)

        channels = data["channels"]
        signal_length = data["model_signal_length"]

        # The label is optional in the configuration file.
        label = data.get('label')
        if isinstance(label, str):
            self.class_name_le.setText(label)

        self.model_length_slider.setValue(int(signal_length))
        self.model_length_slider.setEnabled(False)

        # zip() stops at the shortest sequence: channels beyond the number of
        # available combo boxes are ignored.
        for c, cb in zip(channels, self.channel_cbs):
            index = cb.findText(c)
            if index < 0:
                # findText returns -1 for an unknown entry; add it so the
                # model's channel is not exported as an empty name.
                cb.addItem(c)
                index = cb.findText(c)
            # May have been disabled by a previously loaded, smaller model.
            cb.setEnabled(True)
            cb.setCurrentIndex(index)

        for cb in self.channel_cbs[len(channels):]:
            cb.setCurrentIndex(0)
            cb.setEnabled(False)

    def adjustScrollArea(self):
        """
        Auto-adjust scroll area to fill space
        (from https://stackoverflow.com/questions/66417576/make-qscrollarea-use-all-available-space-of-qmainwindow-height-axis)

        Grows the window in small steps while the vertical scroll bar is
        visible and the maximum height has not been reached.
        """

        step = 5
        while self.scroll_area.verticalScrollBar().isVisible() and self.height() < self.maximumHeight():
            self.resize(self.width(), self.height() + step)

    # ------------------------------------------------------------------
    # training
    # ------------------------------------------------------------------

    def prep_model(self):
        """
        Collect the settings, write ``training_instructions.json`` in the
        model folder and launch ``train_signal_model`` on it.

        Only channel slots whose combo box is not '--' are exported, together
        with their normalisation bounds, clip flag and percentile flag. A
        warning box is shown and nothing is written when the model name is
        empty or a numeric field cannot be parsed.

        Returns
        -------
        None
        """

        model_name = self.modelname_le.text().strip()
        if not model_name:
            # An empty name would resolve to the models root folder.
            self._show_warning("Please provide a name for the model.")
            return None

        pretrained_model = self.pretrained_model
        signal_length = self.model_length_slider.value()
        recompile_op = self.recompile_option.isChecked()

        selections = [cb.currentText() for cb in self.channel_cbs]
        slots_to_keep = [i for i, name in enumerate(selections) if name != '--']
        channels = [selections[i] for i in slots_to_keep]

        # Parse the bounds of the selected slots only, so that an unused slot
        # with an empty or invalid field cannot block the training.
        try:
            norm_values = [
                [self._parse_float(self.normalization_min_value_le[i].text()),
                 self._parse_float(self.normalization_max_value_le[i].text())]
                for i in slots_to_keep
            ]
        except ValueError:
            self._show_warning("Invalid value encountered for the normalization values.")
            return None

        clip_values = [bool(self.clip_option[i]) for i in slots_to_keep]
        normalization_mode = [bool(self.normalization_mode[i]) for i in slots_to_keep]

        data_folders = []
        if self.dataset_folder is not None:
            data_folders.append(self.dataset_folder)
        if self.dataset_cb.currentText()!='--':
            dataset = locate_signal_dataset(self.dataset_cb.currentText())
            data_folders.append(dataset)

        aug_factor = self.augmentation_slider.value()
        val_split = self.validation_slider.value()

        try:
            lr = self._parse_float(self.lr_le.text())
        except ValueError:
            self._show_warning("Invalid value encountered for the learning rate.")
            return None

        try:
            bs = int(self.bs_le.text())
        except ValueError:
            self._show_warning("Invalid value encountered for the batch size.")
            return None

        epochs = self.epochs_slider.value()

        training_instructions = {'model_name': model_name,'pretrained': pretrained_model, 'channel_option': channels, 'normalization_percentile': normalization_mode,
        'normalization_clip': clip_values,'normalization_values': norm_values, 'model_signal_length': signal_length,
        'recompile_pretrained': recompile_op, 'ds': data_folders, 'augmentation_factor': aug_factor, 'validation_split': val_split,
        'learning_rate': lr, 'batch_size': bs, 'epochs': epochs, 'label': self.class_name_le.text()}

        model_folder = self.signal_models_dir + model_name + os.sep
        os.makedirs(model_folder, exist_ok=True)

        training_instructions.update({'target_directory': self.signal_models_dir})

        print(f"Set of instructions: {training_instructions}")
        with open(model_folder+"training_instructions.json", 'w') as f:
            json.dump(training_instructions, f, indent=4)

        # NOTE(review): called synchronously from the GUI thread.
        train_signal_model(model_folder+"training_instructions.json")

        self.parent.refresh_signal_models()

    # ------------------------------------------------------------------
    # slots
    # ------------------------------------------------------------------

    def check_valid_channels(self):
        """
        Enable the 'Train' button as soon as one channel differs from '--'.
        """

        has_channel = any(cb.currentText()!='--' for cb in self.channel_cbs)
        self.submit_btn.setEnabled(has_channel)

    def switch_normalization_mode(self, index):
        """
        Use absolute or percentile values for the normalization of each individual channel.

        Toggles the mode of channel slot *index* and resets its bounds to the
        defaults of the new mode (0.1 / 99.99 percentiles, or 0 / 1000).
        """

        self.normalization_mode[index] = not self.normalization_mode[index]

        mode_btn = self.normalization_mode_btns[index]
        min_lbl = self.normalization_min_value_lbl[index]
        max_lbl = self.normalization_max_value_lbl[index]
        min_le = self.normalization_min_value_le[index]
        max_le = self.normalization_max_value_le[index]

        if self.normalization_mode[index]:
            # percentile mode
            mode_btn.setIcon(icon(MDI6.percent_circle,color="#1565c0"))
            tooltip = "Switch to absolute normalization values."
            min_text, max_text = 'Min %: ', 'Max %: '
            min_value, max_value = '0.1', '99.99'
        else:
            # absolute mode
            mode_btn.setIcon(icon(MDI6.percent_circle_outline,color="black"))
            tooltip = "Switch to percentile normalization values."
            min_text, max_text = 'Min: ', 'Max: '
            min_value, max_value = '0', '1000'

        mode_btn.setIconSize(QSize(20, 20))
        mode_btn.setStyleSheet(self.parent.parent.parent.button_select_all)
        mode_btn.setToolTip(tooltip)
        min_lbl.setText(min_text)
        max_lbl.setText(max_text)
        min_le.setText(min_value)
        max_le.setText(max_value)

    def switch_clipping_mode(self, index):
        """
        Toggle the clip flag of channel slot *index* and recolour its button
        (blue when clipping is on, black otherwise).
        """

        self.clip_option[index] = not self.clip_option[index]

        icon_color = "#1565c0" if self.clip_option[index] else "black"
        clip_btn = self.normalization_clip_btns[index]
        clip_btn.setIcon(icon(MDI6.content_cut,color=icon_color))
        clip_btn.setIconSize(QSize(20, 20))
        clip_btn.setStyleSheet(self.parent.parent.parent.button_select_all)