celldetective 1.4.2__py3-none-any.whl → 1.5.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (152) hide show
  1. celldetective/__init__.py +25 -0
  2. celldetective/__main__.py +62 -43
  3. celldetective/_version.py +1 -1
  4. celldetective/extra_properties.py +477 -399
  5. celldetective/filters.py +192 -97
  6. celldetective/gui/InitWindow.py +541 -411
  7. celldetective/gui/__init__.py +0 -15
  8. celldetective/gui/about.py +44 -39
  9. celldetective/gui/analyze_block.py +120 -84
  10. celldetective/gui/base/__init__.py +0 -0
  11. celldetective/gui/base/channel_norm_generator.py +335 -0
  12. celldetective/gui/base/components.py +249 -0
  13. celldetective/gui/base/feature_choice.py +92 -0
  14. celldetective/gui/base/figure_canvas.py +52 -0
  15. celldetective/gui/base/list_widget.py +133 -0
  16. celldetective/gui/{styles.py → base/styles.py} +92 -36
  17. celldetective/gui/base/utils.py +33 -0
  18. celldetective/gui/base_annotator.py +900 -767
  19. celldetective/gui/classifier_widget.py +6 -22
  20. celldetective/gui/configure_new_exp.py +777 -671
  21. celldetective/gui/control_panel.py +635 -524
  22. celldetective/gui/dynamic_progress.py +449 -0
  23. celldetective/gui/event_annotator.py +2023 -1662
  24. celldetective/gui/generic_signal_plot.py +1292 -944
  25. celldetective/gui/gui_utils.py +899 -1289
  26. celldetective/gui/interactions_block.py +658 -0
  27. celldetective/gui/interactive_timeseries_viewer.py +447 -0
  28. celldetective/gui/json_readers.py +48 -15
  29. celldetective/gui/layouts/__init__.py +5 -0
  30. celldetective/gui/layouts/background_model_free_layout.py +537 -0
  31. celldetective/gui/layouts/channel_offset_layout.py +134 -0
  32. celldetective/gui/layouts/local_correction_layout.py +91 -0
  33. celldetective/gui/layouts/model_fit_layout.py +372 -0
  34. celldetective/gui/layouts/operation_layout.py +68 -0
  35. celldetective/gui/layouts/protocol_designer_layout.py +96 -0
  36. celldetective/gui/pair_event_annotator.py +3130 -2435
  37. celldetective/gui/plot_measurements.py +586 -267
  38. celldetective/gui/plot_signals_ui.py +724 -506
  39. celldetective/gui/preprocessing_block.py +395 -0
  40. celldetective/gui/process_block.py +1678 -1831
  41. celldetective/gui/seg_model_loader.py +580 -473
  42. celldetective/gui/settings/__init__.py +0 -7
  43. celldetective/gui/settings/_cellpose_model_params.py +181 -0
  44. celldetective/gui/settings/_event_detection_model_params.py +95 -0
  45. celldetective/gui/settings/_segmentation_model_params.py +159 -0
  46. celldetective/gui/settings/_settings_base.py +77 -65
  47. celldetective/gui/settings/_settings_event_model_training.py +752 -526
  48. celldetective/gui/settings/_settings_measurements.py +1133 -964
  49. celldetective/gui/settings/_settings_neighborhood.py +574 -488
  50. celldetective/gui/settings/_settings_segmentation_model_training.py +779 -564
  51. celldetective/gui/settings/_settings_signal_annotator.py +329 -305
  52. celldetective/gui/settings/_settings_tracking.py +1304 -1094
  53. celldetective/gui/settings/_stardist_model_params.py +98 -0
  54. celldetective/gui/survival_ui.py +422 -312
  55. celldetective/gui/tableUI.py +1665 -1701
  56. celldetective/gui/table_ops/_maths.py +295 -0
  57. celldetective/gui/table_ops/_merge_groups.py +140 -0
  58. celldetective/gui/table_ops/_merge_one_hot.py +95 -0
  59. celldetective/gui/table_ops/_query_table.py +43 -0
  60. celldetective/gui/table_ops/_rename_col.py +44 -0
  61. celldetective/gui/thresholds_gui.py +382 -179
  62. celldetective/gui/viewers/__init__.py +0 -0
  63. celldetective/gui/viewers/base_viewer.py +700 -0
  64. celldetective/gui/viewers/channel_offset_viewer.py +331 -0
  65. celldetective/gui/viewers/contour_viewer.py +394 -0
  66. celldetective/gui/viewers/size_viewer.py +153 -0
  67. celldetective/gui/viewers/spot_detection_viewer.py +341 -0
  68. celldetective/gui/viewers/threshold_viewer.py +309 -0
  69. celldetective/gui/workers.py +403 -126
  70. celldetective/log_manager.py +92 -0
  71. celldetective/measure.py +1895 -1478
  72. celldetective/napari/__init__.py +0 -0
  73. celldetective/napari/utils.py +1025 -0
  74. celldetective/neighborhood.py +1914 -1448
  75. celldetective/preprocessing.py +1620 -1220
  76. celldetective/processes/__init__.py +0 -0
  77. celldetective/processes/background_correction.py +271 -0
  78. celldetective/processes/compute_neighborhood.py +894 -0
  79. celldetective/processes/detect_events.py +246 -0
  80. celldetective/processes/downloader.py +137 -0
  81. celldetective/processes/measure_cells.py +565 -0
  82. celldetective/processes/segment_cells.py +760 -0
  83. celldetective/processes/track_cells.py +435 -0
  84. celldetective/processes/train_segmentation_model.py +694 -0
  85. celldetective/processes/train_signal_model.py +265 -0
  86. celldetective/processes/unified_process.py +292 -0
  87. celldetective/regionprops/_regionprops.py +358 -317
  88. celldetective/relative_measurements.py +987 -710
  89. celldetective/scripts/measure_cells.py +313 -212
  90. celldetective/scripts/measure_relative.py +90 -46
  91. celldetective/scripts/segment_cells.py +165 -104
  92. celldetective/scripts/segment_cells_thresholds.py +96 -68
  93. celldetective/scripts/track_cells.py +198 -149
  94. celldetective/scripts/train_segmentation_model.py +324 -201
  95. celldetective/scripts/train_signal_model.py +87 -45
  96. celldetective/segmentation.py +844 -749
  97. celldetective/signals.py +3514 -2861
  98. celldetective/tracking.py +30 -15
  99. celldetective/utils/__init__.py +0 -0
  100. celldetective/utils/cellpose_utils/__init__.py +133 -0
  101. celldetective/utils/color_mappings.py +42 -0
  102. celldetective/utils/data_cleaning.py +630 -0
  103. celldetective/utils/data_loaders.py +450 -0
  104. celldetective/utils/dataset_helpers.py +207 -0
  105. celldetective/utils/downloaders.py +235 -0
  106. celldetective/utils/event_detection/__init__.py +8 -0
  107. celldetective/utils/experiment.py +1782 -0
  108. celldetective/utils/image_augmenters.py +308 -0
  109. celldetective/utils/image_cleaning.py +74 -0
  110. celldetective/utils/image_loaders.py +926 -0
  111. celldetective/utils/image_transforms.py +335 -0
  112. celldetective/utils/io.py +62 -0
  113. celldetective/utils/mask_cleaning.py +348 -0
  114. celldetective/utils/mask_transforms.py +5 -0
  115. celldetective/utils/masks.py +184 -0
  116. celldetective/utils/maths.py +351 -0
  117. celldetective/utils/model_getters.py +325 -0
  118. celldetective/utils/model_loaders.py +296 -0
  119. celldetective/utils/normalization.py +380 -0
  120. celldetective/utils/parsing.py +465 -0
  121. celldetective/utils/plots/__init__.py +0 -0
  122. celldetective/utils/plots/regression.py +53 -0
  123. celldetective/utils/resources.py +34 -0
  124. celldetective/utils/stardist_utils/__init__.py +104 -0
  125. celldetective/utils/stats.py +90 -0
  126. celldetective/utils/types.py +21 -0
  127. {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/METADATA +1 -1
  128. celldetective-1.5.0b1.dist-info/RECORD +187 -0
  129. {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/WHEEL +1 -1
  130. tests/gui/test_new_project.py +129 -117
  131. tests/gui/test_project.py +127 -79
  132. tests/test_filters.py +39 -15
  133. tests/test_notebooks.py +8 -0
  134. tests/test_tracking.py +232 -13
  135. tests/test_utils.py +123 -77
  136. celldetective/gui/base_components.py +0 -23
  137. celldetective/gui/layouts.py +0 -1602
  138. celldetective/gui/processes/compute_neighborhood.py +0 -594
  139. celldetective/gui/processes/downloader.py +0 -111
  140. celldetective/gui/processes/measure_cells.py +0 -360
  141. celldetective/gui/processes/segment_cells.py +0 -499
  142. celldetective/gui/processes/track_cells.py +0 -303
  143. celldetective/gui/processes/train_segmentation_model.py +0 -270
  144. celldetective/gui/processes/train_signal_model.py +0 -108
  145. celldetective/gui/table_ops/merge_groups.py +0 -118
  146. celldetective/gui/viewers.py +0 -1354
  147. celldetective/io.py +0 -3663
  148. celldetective/utils.py +0 -3108
  149. celldetective-1.4.2.dist-info/RECORD +0 -123
  150. {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/entry_points.txt +0 -0
  151. {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/licenses/LICENSE +0 -0
  152. {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/top_level.txt +0 -0
@@ -1,544 +1,770 @@
1
- from PyQt5.QtWidgets import QMessageBox, QComboBox, QFrame, QCheckBox, QFileDialog, QGridLayout, QLineEdit, QVBoxLayout, QLabel, QHBoxLayout, QPushButton
2
- from PyQt5.QtCore import Qt, QSize
3
- from celldetective.gui.layouts import ChannelNormGenerator
1
+ from pathlib import Path
2
+
3
+ from PyQt5.QtWidgets import (
4
+ QMessageBox,
5
+ QComboBox,
6
+ QFrame,
7
+ QCheckBox,
8
+ QFileDialog,
9
+ QGridLayout,
10
+ QLineEdit,
11
+ QVBoxLayout,
12
+ QLabel,
13
+ QHBoxLayout,
14
+ QPushButton,
15
+ )
16
+ from PyQt5.QtCore import Qt, QSize, QThread
17
+ from celldetective.gui.base.channel_norm_generator import ChannelNormGenerator
4
18
  from superqt import QLabeledDoubleSlider, QLabeledSlider, QSearchableComboBox
5
19
  from superqt.fonticon import icon
6
20
  from fonticon_mdi6 import MDI6
7
- from celldetective.io import locate_signal_dataset, get_signal_datasets_list, load_experiment_tables
8
- from celldetective.signals import train_signal_model
9
21
  import numpy as np
10
22
  import json
11
23
  import os
12
24
  from glob import glob
13
25
  from datetime import datetime
14
26
  from pandas.api.types import is_numeric_dtype
15
- from celldetective.gui.processes.train_signal_model import TrainSignalModelProcess
16
- from celldetective.gui.workers import ProgressWindow
27
+ from celldetective.gui.workers import Runner
28
+ from celldetective.gui.dynamic_progress import DynamicProgressDialog
29
+ from PyQt5.QtCore import QThreadPool
17
30
  from celldetective.gui.settings._settings_base import CelldetectiveSettingsPanel
31
+ from celldetective.utils.data_loaders import load_experiment_tables
32
+ from celldetective.utils.model_getters import get_signal_datasets_list
33
+ from celldetective.utils.model_loaders import locate_signal_dataset
34
+ from celldetective import get_logger
35
+ import multiprocessing
36
+
37
+ logger = get_logger()
38
+
39
+
40
+ class BackgroundLoader(QThread):
41
+ def run(self):
42
+ logger.info("Loading libraries...")
43
+ try:
44
+ from celldetective.processes.train_signal_model import (
45
+ TrainSignalModelProcess,
46
+ )
47
+
48
+ self.TrainSignalModelProcess = TrainSignalModelProcess
49
+ except Exception:
50
+ logger.error("Librairies not loaded...")
51
+ logger.info("Librairies loaded...")
52
+
18
53
 
19
54
  class SettingsEventDetectionModelTraining(CelldetectiveSettingsPanel):
20
-
21
- """
22
- UI to set measurement instructions.
23
-
24
- """
25
-
26
- def __init__(self, parent_window=None, signal_mode='single-cells'):
27
-
28
- self.parent_window = parent_window
29
- self.mode = self.parent_window.mode
30
- self.exp_dir = self.parent_window.exp_dir
31
- self.pretrained_model = None
32
- self.dataset_folder = None
33
- self.current_neighborhood = None
34
- self.reference_population = None
35
- self.neighbor_population = None
36
- self.signal_mode = signal_mode
37
-
38
- super().__init__(title="Train event detection model")
39
-
40
- if self.signal_mode=='single-cells':
41
- self.signal_models_dir = self._software_path+os.sep+os.sep.join(['celldetective','models','signal_detection'])
42
- elif self.signal_mode=='pairs':
43
- self.signal_models_dir = self._software_path+os.sep+os.sep.join(['celldetective','models','pair_signal_detection'])
44
- self.mode = 'pairs'
45
-
46
- self._add_to_layout()
47
- self._load_previous_instructions()
48
-
49
- self._adjustSize()
50
- new_width = int(self.width()*1.2)
51
- self.resize(new_width, int(self._screen_height * 0.8))
52
- self.setMinimumWidth(new_width)
53
-
54
- def _add_to_layout(self):
55
- self._layout.addWidget(self.model_frame)
56
- self._layout.addWidget(self.data_frame)
57
- self._layout.addWidget(self.hyper_frame)
58
- self._layout.addWidget(self.submit_btn)
59
-
60
- def _create_widgets(self):
61
-
62
- """
63
- Create the multibox design.
64
-
65
- """
66
- super()._create_widgets()
67
-
68
- # first frame for FEATURES
69
- self.model_frame = QFrame()
70
- self.model_frame.setFrameStyle(QFrame.StyledPanel | QFrame.Raised)
71
- self.populate_model_frame()
72
-
73
- self.data_frame = QFrame()
74
- self.data_frame.setFrameStyle(QFrame.StyledPanel | QFrame.Raised)
75
- self.populate_data_frame()
76
-
77
- self.hyper_frame = QFrame()
78
- self.hyper_frame.setFrameStyle(QFrame.StyledPanel | QFrame.Raised)
79
- self.populate_hyper_frame()
80
-
81
- self.submit_btn.setEnabled(False)
82
- self.submit_btn.setText("Train")
83
-
84
-
85
- def populate_hyper_frame(self):
86
-
87
- """
88
- Add widgets and layout in the POST-PROCESSING frame.
89
- """
90
-
91
- grid = QGridLayout(self.hyper_frame)
92
- grid.setContentsMargins(30,30,30,30)
93
- grid.setSpacing(30)
94
-
95
- self.hyper_lbl = QLabel("HYPERPARAMETERS")
96
- self.hyper_lbl.setStyleSheet("""
55
+ """
56
+ UI to set measurement instructions.
57
+
58
+ """
59
+
60
+ def __init__(self, parent_window=None, signal_mode="single-cells"):
61
+
62
+ self.parent_window = parent_window
63
+ self.mode = self.parent_window.mode
64
+ self.exp_dir = self.parent_window.exp_dir
65
+ self.pretrained_model = None
66
+ self.dataset_folder = None
67
+ self.current_neighborhood = None
68
+ self.reference_population = None
69
+ self.neighbor_population = None
70
+ self.signal_mode = signal_mode
71
+
72
+ super().__init__(title="Train event detection model")
73
+
74
+ if self.signal_mode == "single-cells":
75
+ self.signal_models_dir = (
76
+ self._software_path
77
+ + os.sep
78
+ + os.sep.join(["celldetective", "models", "signal_detection"])
79
+ )
80
+ elif self.signal_mode == "pairs":
81
+ self.signal_models_dir = (
82
+ self._software_path
83
+ + os.sep
84
+ + os.sep.join(["celldetective", "models", "pair_signal_detection"])
85
+ )
86
+ self.mode = "pairs"
87
+
88
+ self._add_to_layout()
89
+ self._load_previous_instructions()
90
+
91
+ self._adjust_size()
92
+ new_width = int(self.width() * 1.01)
93
+ self.resize(new_width, int(self._screen_height * 0.8))
94
+ self.setMinimumWidth(new_width)
95
+
96
+ self.bg_loader = BackgroundLoader()
97
+ self.bg_loader.start()
98
+
99
+ def _add_to_layout(self):
100
+ self._layout.addWidget(self.model_frame)
101
+ self._layout.addWidget(self.data_frame)
102
+ self._layout.addWidget(self.hyper_frame)
103
+ self._layout.addWidget(self.submit_btn)
104
+ self._layout.addWidget(self.warning_label)
105
+
106
+ def _create_widgets(self):
107
+ """
108
+ Create the multibox design.
109
+
110
+ """
111
+ super()._create_widgets()
112
+
113
+ # first frame for FEATURES
114
+ self.model_frame = QFrame()
115
+ self.model_frame.setFrameStyle(QFrame.StyledPanel | QFrame.Raised)
116
+ self.populate_model_frame()
117
+
118
+ self.data_frame = QFrame()
119
+ self.data_frame.setFrameStyle(QFrame.StyledPanel | QFrame.Raised)
120
+ self.populate_data_frame()
121
+
122
+ self.hyper_frame = QFrame()
123
+ self.hyper_frame.setFrameStyle(QFrame.StyledPanel | QFrame.Raised)
124
+ self.populate_hyper_frame()
125
+
126
+ self.submit_btn.setEnabled(False)
127
+ self.submit_btn.setText("Train")
128
+
129
+ self.warning_label = QLabel("")
130
+ self.warning_label.setStyleSheet("color: red; font-weight: bold;")
131
+ self.warning_label.setAlignment(Qt.AlignCenter)
132
+ self.check_readiness()
133
+
134
+ def populate_hyper_frame(self):
135
+ """
136
+ Add widgets and layout in the POST-PROCESSING frame.
137
+ """
138
+
139
+ grid = QGridLayout(self.hyper_frame)
140
+ grid.setContentsMargins(30, 30, 30, 30)
141
+ grid.setSpacing(30)
142
+
143
+ self.hyper_lbl = QLabel("HYPERPARAMETERS")
144
+ self.hyper_lbl.setStyleSheet(
145
+ """
97
146
  font-weight: bold;
98
147
  padding: 0px;
99
- """)
100
- grid.addWidget(self.hyper_lbl, 0, 0, 1, 4, alignment=Qt.AlignCenter)
101
- self.generate_hyper_contents()
102
- grid.addWidget(self.ContentsHyper, 1, 0, 1, 4, alignment=Qt.AlignTop)
103
-
104
- def generate_hyper_contents(self):
105
-
106
- self.ContentsHyper = QFrame()
107
- layout = QVBoxLayout(self.ContentsHyper)
108
- layout.setContentsMargins(0,0,0,0)
109
-
110
- lr_layout = QHBoxLayout()
111
- lr_layout.addWidget(QLabel('learning rate: '),30)
112
- self.lr_le = QLineEdit('0,01')
113
- self.lr_le.setValidator(self._floatValidator)
114
- lr_layout.addWidget(self.lr_le, 70)
115
- layout.addLayout(lr_layout)
116
-
117
- bs_layout = QHBoxLayout()
118
- bs_layout.addWidget(QLabel('batch size: '),30)
119
- self.bs_le = QLineEdit('64')
120
- self.bs_le.setValidator(self._intValidator)
121
- bs_layout.addWidget(self.bs_le, 70)
122
- layout.addLayout(bs_layout)
123
-
124
- epochs_layout = QHBoxLayout()
125
- epochs_layout.addWidget(QLabel('# epochs: '), 30)
126
- self.epochs_slider = QLabeledSlider()
127
- self.epochs_slider.setRange(1,3000)
128
- self.epochs_slider.setSingleStep(1)
129
- self.epochs_slider.setTickInterval(1)
130
- self.epochs_slider.setOrientation(Qt.Horizontal)
131
- self.epochs_slider.setValue(300)
132
- epochs_layout.addWidget(self.epochs_slider, 70)
133
- layout.addLayout(epochs_layout)
134
-
135
-
136
- def populate_data_frame(self):
137
-
138
- """
139
- Add widgets and layout in the POST-PROCESSING frame.
140
- """
141
-
142
- grid = QGridLayout(self.data_frame)
143
- grid.setContentsMargins(30,30,30,30)
144
- grid.setSpacing(30)
145
-
146
- self.data_lbl = QLabel("DATA")
147
- self.data_lbl.setStyleSheet("""
148
+ """
149
+ )
150
+ grid.addWidget(self.hyper_lbl, 0, 0, 1, 4, alignment=Qt.AlignCenter)
151
+ self.generate_hyper_contents()
152
+ grid.addWidget(self.ContentsHyper, 1, 0, 1, 4, alignment=Qt.AlignTop)
153
+
154
+ def generate_hyper_contents(self):
155
+
156
+ self.ContentsHyper = QFrame()
157
+ layout = QVBoxLayout(self.ContentsHyper)
158
+ layout.setContentsMargins(0, 0, 0, 0)
159
+
160
+ lr_layout = QHBoxLayout()
161
+ lr_layout.addWidget(QLabel("learning rate: "), 30)
162
+ self.lr_le = QLineEdit("0,01")
163
+ self.lr_le.setValidator(self._floatValidator)
164
+ lr_layout.addWidget(self.lr_le, 70)
165
+ layout.addLayout(lr_layout)
166
+
167
+ bs_layout = QHBoxLayout()
168
+ bs_layout.addWidget(QLabel("batch size: "), 30)
169
+ self.bs_le = QLineEdit("64")
170
+ self.bs_le.setValidator(self._intValidator)
171
+ bs_layout.addWidget(self.bs_le, 70)
172
+ layout.addLayout(bs_layout)
173
+
174
+ epochs_layout = QHBoxLayout()
175
+ epochs_layout.addWidget(QLabel("# epochs: "), 30)
176
+ self.epochs_slider = QLabeledSlider()
177
+ self.epochs_slider.setRange(1, 3000)
178
+ self.epochs_slider.setSingleStep(1)
179
+ self.epochs_slider.setTickInterval(1)
180
+ self.epochs_slider.setOrientation(Qt.Horizontal)
181
+ self.epochs_slider.setValue(300)
182
+ epochs_layout.addWidget(self.epochs_slider, 70)
183
+ layout.addLayout(epochs_layout)
184
+
185
+ def populate_data_frame(self):
186
+ """
187
+ Add widgets and layout in the POST-PROCESSING frame.
188
+ """
189
+
190
+ grid = QGridLayout(self.data_frame)
191
+ grid.setContentsMargins(30, 30, 30, 30)
192
+ grid.setSpacing(30)
193
+
194
+ self.data_lbl = QLabel("DATA")
195
+ self.data_lbl.setStyleSheet(
196
+ """
148
197
  font-weight: bold;
149
198
  padding: 0px;
150
- """)
151
- grid.addWidget(self.data_lbl, 0, 0, 1, 4, alignment=Qt.AlignCenter)
152
- self.generate_data_contents()
153
- grid.addWidget(self.ContentsData, 1, 0, 1, 4, alignment=Qt.AlignTop)
154
-
155
- def populate_model_frame(self):
156
-
157
- """
158
- Add widgets and layout in the FEATURES frame.
159
- """
160
-
161
- grid = QGridLayout(self.model_frame)
162
- grid.setContentsMargins(30,30,30,30)
163
- grid.setSpacing(30)
164
-
165
- self.model_lbl = QLabel("MODEL")
166
- self.model_lbl.setStyleSheet("""
199
+ """
200
+ )
201
+ grid.addWidget(self.data_lbl, 0, 0, 1, 4, alignment=Qt.AlignCenter)
202
+ self.generate_data_contents()
203
+ grid.addWidget(self.ContentsData, 1, 0, 1, 4, alignment=Qt.AlignTop)
204
+
205
+ def populate_model_frame(self):
206
+ """
207
+ Add widgets and layout in the FEATURES frame.
208
+ """
209
+
210
+ grid = QGridLayout(self.model_frame)
211
+ grid.setContentsMargins(30, 30, 30, 30)
212
+ grid.setSpacing(30)
213
+
214
+ self.model_lbl = QLabel("MODEL")
215
+ self.model_lbl.setStyleSheet(
216
+ """
167
217
  font-weight: bold;
168
218
  padding: 0px;
169
- """)
170
- grid.addWidget(self.model_lbl, 0, 0, 1, 4, alignment=Qt.AlignCenter)
171
-
172
- self.generate_model_panel_contents()
173
- grid.addWidget(self.ContentsModel, 1, 0, 1, 4, alignment=Qt.AlignTop)
174
-
175
-
176
- def generate_data_contents(self):
177
-
178
- self.ContentsData = QFrame()
179
- layout = QVBoxLayout(self.ContentsData)
180
- layout.setContentsMargins(0,0,0,0)
181
-
182
- train_data_layout = QHBoxLayout()
183
- train_data_layout.addWidget(QLabel('Training data: '), 30)
184
- self.select_data_folder_btn = QPushButton('Choose folder')
185
- self.select_data_folder_btn.clicked.connect(self.showDialog_dataset)
186
- self.data_folder_label = QLabel('No folder chosen')
187
- train_data_layout.addWidget(self.select_data_folder_btn, 35)
188
- train_data_layout.addWidget(self.data_folder_label, 30)
189
-
190
- self.cancel_dataset = QPushButton()
191
- self.cancel_dataset.setIcon(icon(MDI6.close,color="black"))
192
- self.cancel_dataset.clicked.connect(self.clear_dataset)
193
- self.cancel_dataset.setStyleSheet(self.button_select_all)
194
- self.cancel_dataset.setIconSize(QSize(20, 20))
195
- self.cancel_dataset.setVisible(False)
196
- train_data_layout.addWidget(self.cancel_dataset, 5)
197
-
198
-
199
- layout.addLayout(train_data_layout)
200
-
201
- include_dataset_layout = QHBoxLayout()
202
- include_dataset_layout.addWidget(QLabel('include dataset: '),30)
203
- self.dataset_cb = QComboBox()
204
-
205
- available_datasets, self.datasets_path = get_signal_datasets_list(return_path=True)
206
- signal_datasets = ['--'] + available_datasets
207
-
208
- self.dataset_cb.addItems(signal_datasets)
209
- include_dataset_layout.addWidget(self.dataset_cb, 70)
210
- layout.addLayout(include_dataset_layout)
211
-
212
- augmentation_hbox = QHBoxLayout()
213
- augmentation_hbox.addWidget(QLabel('augmentation\nfactor: '), 30)
214
- self.augmentation_slider = QLabeledDoubleSlider()
215
- self.augmentation_slider.setSingleStep(0.01)
216
- self.augmentation_slider.setTickInterval(0.01)
217
- self.augmentation_slider.setOrientation(Qt.Horizontal)
218
- self.augmentation_slider.setRange(1, 5)
219
- self.augmentation_slider.setValue(2)
220
-
221
- augmentation_hbox.addWidget(self.augmentation_slider, 70)
222
- layout.addLayout(augmentation_hbox)
223
-
224
- validation_split_layout = QHBoxLayout()
225
- validation_split_layout.addWidget(QLabel('validation split: '),30)
226
- self.validation_slider = QLabeledDoubleSlider()
227
- self.validation_slider.setSingleStep(0.01)
228
- self.validation_slider.setTickInterval(0.01)
229
- self.validation_slider.setOrientation(Qt.Horizontal)
230
- self.validation_slider.setRange(0,0.9)
231
- self.validation_slider.setValue(0.25)
232
- validation_split_layout.addWidget(self.validation_slider, 70)
233
- layout.addLayout(validation_split_layout)
234
-
235
-
236
- def generate_model_panel_contents(self):
237
-
238
- self.ContentsModel = QFrame()
239
- layout = QVBoxLayout(self.ContentsModel)
240
- layout.setContentsMargins(0,0,0,0)
241
-
242
- modelname_layout = QHBoxLayout()
243
- modelname_layout.addWidget(QLabel('Model name: '), 30)
244
- self.modelname_le = QLineEdit()
245
- self.modelname_le.setText(f"Untitled_model_{datetime.today().strftime('%Y-%m-%d')}")
246
- modelname_layout.addWidget(self.modelname_le, 70)
247
- layout.addLayout(modelname_layout)
248
-
249
- if self.signal_mode=='pairs':
250
- neighborhood_layout = QHBoxLayout()
251
- neighborhood_layout.addWidget(QLabel('neighborhood of interest: '), 30)
252
- self.neighborhood_choice_cb = QSearchableComboBox()
253
- self.fill_available_neighborhoods()
254
- neighborhood_layout.addWidget(self.neighborhood_choice_cb, 70)
255
- layout.addLayout(neighborhood_layout)
256
-
257
- classname_layout = QHBoxLayout()
258
- classname_layout.addWidget(QLabel('event name: '), 30)
259
- self.class_name_le = QLineEdit()
260
- self.class_name_le.setText("")
261
- classname_layout.addWidget(self.class_name_le, 70)
262
- layout.addLayout(classname_layout)
263
-
264
- pretrained_layout = QHBoxLayout()
265
- pretrained_layout.setContentsMargins(0,0,0,0)
266
- pretrained_layout.addWidget(QLabel('Pretrained model: '), 30)
267
-
268
- self.browse_pretrained_btn = QPushButton('Choose folder')
269
- self.browse_pretrained_btn.clicked.connect(self.showDialog_pretrained)
270
- pretrained_layout.addWidget(self.browse_pretrained_btn, 35)
271
-
272
- self.pretrained_lbl = QLabel('No folder chosen')
273
- pretrained_layout.addWidget(self.pretrained_lbl, 30)
274
-
275
- self.cancel_pretrained = QPushButton()
276
- self.cancel_pretrained.setIcon(icon(MDI6.close,color="black"))
277
- self.cancel_pretrained.clicked.connect(self.clear_pretrained)
278
- self.cancel_pretrained.setStyleSheet(self.button_select_all)
279
- self.cancel_pretrained.setIconSize(QSize(20, 20))
280
- self.cancel_pretrained.setVisible(False)
281
- pretrained_layout.addWidget(self.cancel_pretrained, 5)
282
-
283
- layout.addLayout(pretrained_layout)
284
-
285
- recompile_layout = QHBoxLayout()
286
- recompile_layout.addWidget(QLabel('Recompile: '), 30)
287
- self.recompile_option = QCheckBox()
288
- self.recompile_option.setEnabled(False)
289
- recompile_layout.addWidget(self.recompile_option, 70)
290
- layout.addLayout(recompile_layout)
291
-
292
- self.max_nbr_channels = 5
293
- self.ch_norm = ChannelNormGenerator(self, mode='signals')
294
- layout.addLayout(self.ch_norm)
295
-
296
- if self.signal_mode=='pairs':
297
- self.neighborhood_choice_cb.currentIndexChanged.connect(self.neighborhood_changed)
298
- self.neighborhood_changed()
299
-
300
- model_length_layout = QHBoxLayout()
301
- model_length_layout.addWidget(QLabel('Max signal length: '), 30)
302
- self.model_length_slider = QLabeledSlider()
303
- self.model_length_slider.setSingleStep(1)
304
- self.model_length_slider.setTickInterval(1)
305
- self.model_length_slider.setSingleStep(1)
306
- self.model_length_slider.setOrientation(Qt.Horizontal)
307
- self.model_length_slider.setRange(0,1024)
308
- self.model_length_slider.setValue(128)
309
- model_length_layout.addWidget(self.model_length_slider, 70)
310
- layout.addLayout(model_length_layout)
311
-
312
- def neighborhood_changed(self):
313
-
314
- neigh = self.neighborhood_choice_cb.currentText()
315
- self.current_neighborhood = neigh
316
- for pop in self.dataframes.keys():
317
- self.current_neighborhood = self.current_neighborhood.replace(f'{pop}_ref_', '')
318
-
319
- self.reference_population = self.neighborhood_choice_cb.currentText().split('_')[0]
320
- if '_(' in self.current_neighborhood and ')_' in self.current_neighborhood:
321
- self.neighbor_population = self.current_neighborhood.split('_(')[-1].split(')_')[0].split('-')[-1]
322
- self.reference_population = self.current_neighborhood.split('_(')[-1].split(')_')[0].split('-')[0]
323
- else:
324
- if 'self' in self.current_neighborhood:
325
- self.neighbor_population = self.reference_population
326
-
327
- print(f'Current neighborhood: {self.current_neighborhood}')
328
- print(f'New reference population: {self.reference_population}')
329
- print(f'New neighbor population: {self.neighbor_population}')
330
-
331
- self.df_reference = self.dataframes[self.reference_population]
332
- self.df_neighbor = self.dataframes[self.neighbor_population]
333
- self.df_pairs = load_experiment_tables(self.parent_window.exp_dir, population='pairs', load_pickle=False)
334
-
335
- self.df_reference = self.df_reference.rename(columns=lambda x: 'reference_' + x)
336
- num_cols_reference = [c for c in list(self.df_reference.columns) if is_numeric_dtype(self.df_reference[c])]
337
- self.df_neighbor = self.df_neighbor.rename(columns=lambda x: 'neighbor_' + x)
338
- num_cols_neighbor = [c for c in list(self.df_neighbor.columns) if is_numeric_dtype(self.df_neighbor[c])]
339
- self.df_pairs = self.df_pairs.rename(columns=lambda x: 'pair_' + x)
340
- num_cols_pairs = [c for c in list(self.df_pairs.columns) if is_numeric_dtype(self.df_pairs[c])]
341
-
342
- self.signals = ['--'] + num_cols_pairs + num_cols_reference + num_cols_neighbor
343
-
344
- for cb in self.ch_norm.channel_cbs:
345
- cb.clear()
346
- cb.addItems(self.signals)
347
-
348
- def fill_available_neighborhoods(self):
349
-
350
- self.dataframes = {}
351
- self.neighborhood_cols = []
352
- for population in self.parent_window.parent_window.populations:
353
- df_pop = load_experiment_tables(self.parent_window.exp_dir, population=population, load_pickle=True)
354
- self.dataframes.update({population: df_pop})
355
- if df_pop is not None:
356
- self.neighborhood_cols.extend(
357
- [f'{population}_ref_' + c for c in list(df_pop.columns) if c.startswith('neighborhood')])
358
-
359
- self.neighborhood_choice_cb.addItems(self.neighborhood_cols)
360
-
361
- def showDialog_pretrained(self):
362
-
363
- self.pretrained_model = QFileDialog.getExistingDirectory(
364
- self, "Open Directory",
365
- os.sep.join([self.soft_path,'celldetective','models','signal_detection','']),
366
- QFileDialog.ShowDirsOnly | QFileDialog.DontResolveSymlinks,
367
- )
368
-
369
- if self.pretrained_model is not None:
370
- # self.foldername = self.file_dialog_pretrained.selectedFiles()[0]
371
- subfiles = glob(os.sep.join([self.pretrained_model,"*"]))
372
- if os.sep.join([self.pretrained_model,"config_input.json"]) in subfiles:
373
- self.load_pretrained_config()
374
- self.pretrained_lbl.setText(self.pretrained_model.split(os.sep)[-1])
375
- self.cancel_pretrained.setVisible(True)
376
- self.recompile_option.setEnabled(True)
377
- self.modelname_le.setText(f"{self.pretrained_model.split(os.sep)[-1]}_{datetime.today().strftime('%Y-%m-%d')}")
378
- else:
379
- self.pretrained_model = None
380
- self.pretrained_lbl.setText('No folder chosen')
381
- self.recompile_option.setEnabled(False)
382
- self.cancel_pretrained.setVisible(False)
383
- print(self.pretrained_model)
384
-
385
- def showDialog_dataset(self):
386
-
387
- self.dataset_folder = QFileDialog.getExistingDirectory(
388
- self, "Open Directory",
389
- self.exp_dir,
390
- QFileDialog.ShowDirsOnly | QFileDialog.DontResolveSymlinks,
391
- )
392
- if self.dataset_folder is not None:
393
-
394
- subfiles = glob(os.sep.join([self.dataset_folder,"*.npy"]))
395
- if len(subfiles)>0:
396
- print(f'found {len(subfiles)} files in folder')
397
- self.data_folder_label.setText(self.dataset_folder[:16]+'...')
398
- self.data_folder_label.setToolTip(self.dataset_folder)
399
- self.cancel_dataset.setVisible(True)
400
- else:
401
- self.data_folder_label.setText('No folder chosen')
402
- self.data_folder_label.setToolTip('')
403
- self.dataset_folder = None
404
- self.cancel_dataset.setVisible(False)
405
-
406
- def clear_pretrained(self):
407
-
408
- self.pretrained_model = None
409
- self.pretrained_lbl.setText('No folder chosen')
410
- for cb in self.ch_norm.channel_cbs:
411
- cb.setEnabled(True)
412
- self.ch_norm.add_col_btn.setEnabled(True)
413
- self.recompile_option.setEnabled(False)
414
- self.cancel_pretrained.setVisible(False)
415
- self.model_length_slider.setEnabled(True)
416
- self.class_name_le.setText('')
417
- self.modelname_le.setText(f"Untitled_model_{datetime.today().strftime('%Y-%m-%d')}")
418
-
419
def clear_dataset(self):
    """Forget the user-chosen training folder and reset its label and tooltip."""
    self.dataset_folder = None
    label = self.data_folder_label
    label.setText('No folder chosen')
    label.setToolTip('')
    self.cancel_dataset.setVisible(False)
425
-
426
-
427
def load_pretrained_config(self):
    """Fill the form from ``<self.pretrained_model>/config_input.json``.

    Reads ``channels``, ``model_signal_length`` and, if present, ``label``.
    The signal length is applied to the slider and the slider is disabled.
    Each channel combo box is set to the stored channel at the same position.
    When the config lists fewer channels than there are combo boxes, the
    extra boxes are reset to their first entry and disabled, and the
    add-column button is disabled.

    Raises
    ------
    OSError
        If the config file cannot be opened.
    ValueError
        If the file is not valid JSON.
    KeyError
        If ``channels`` or ``model_signal_length`` is missing.
    """
    config_path = os.path.join(self.pretrained_model, "config_input.json")
    # Context manager: the handle is closed even if parsing fails.
    with open(config_path, encoding="utf-8") as config_file:
        data = json.load(config_file)

    channels = data["channels"]
    signal_length = data["model_signal_length"]

    # "label" is optional; only a string can be shown in the line edit.
    label = data.get("label")
    if isinstance(label, str):
        self.class_name_le.setText(label)

    self.model_length_slider.setValue(int(signal_length))
    self.model_length_slider.setEnabled(False)

    for channel, cb in zip(channels, self.ch_norm.channel_cbs):
        # findText returns -1 when the channel is not offered in this box.
        cb.setCurrentIndex(cb.findText(channel))

    extra_cbs = self.ch_norm.channel_cbs[len(channels):]
    if extra_cbs:
        for cb in extra_cbs:
            cb.setCurrentIndex(0)
            cb.setEnabled(False)
        self.ch_norm.add_col_btn.setEnabled(False)
450
-
451
-
452
def adjustScrollArea(self):
    """
    Auto-adjust scroll area to fill space
    (from https://stackoverflow.com/questions/66417576/make-qscrollarea-use-all-available-space-of-qmainwindow-height-axis)

    Grows the widget height in 5-pixel steps while the scroll area's
    vertical scroll bar is visible and the height is below
    ``maximumHeight()``. Stops as soon as a resize fails to increase the
    height, which would otherwise make the loop spin forever.
    """
    step = 5
    while (
        self.scroll_area.verticalScrollBar().isVisible()
        and self.height() < self.maximumHeight()
    ):
        previous_height = self.height()
        self.resize(self.width(), previous_height + step)
        if self.height() <= previous_height:
            # The resize was not applied (e.g. constrained or deferred by
            # the platform), so no further progress is possible.
            break
462
-
463
def _write_instructions(self):
    """Gather the training settings from the form, save them and train.

    The settings are written as JSON to
    ``<signal_models_dir>/<model name>/training_instructions.json``.
    ``train_signal_model`` is then called on that file, and afterwards the
    parent window's model list is refreshed.

    Returns ``None`` early, after a warning box and without writing
    anything, when the model name is empty or a numeric field (learning
    rate, batch size, normalization values) cannot be parsed.
    """

    def warn(text):
        # Modal warning box; the caller aborts whichever button closed it.
        msg_box = QMessageBox()
        msg_box.setIcon(QMessageBox.Warning)
        msg_box.setText(text)
        msg_box.setWindowTitle("Warning")
        msg_box.setStandardButtons(QMessageBox.Ok)
        msg_box.exec()

    model_name = self.modelname_le.text()
    if not model_name.strip():
        # An empty name would make the model folder the models directory.
        warn("Please provide a model name.")
        return None

    pretrained_model = self.pretrained_model
    signal_length = self.model_length_slider.value()
    recompile_op = self.recompile_option.isChecked()

    all_channels = [cb.currentText() for cb in self.ch_norm.channel_cbs]
    # Slots where a channel was actually picked ('--' = unused). The
    # per-slot normalization settings are filtered with the same indices.
    slots_to_keep = np.where(np.array(all_channels) != '--')[0]
    channels = [c for c in all_channels if c != '--']

    min_les = self.ch_norm.normalization_min_value_le
    max_les = self.ch_norm.normalization_max_value_le
    try:
        # Only the kept slots are parsed; a comma is accepted as decimal mark.
        norm_values = [
            [float(min_les[i].text().replace(',', '.')),
             float(max_les[i].text().replace(',', '.'))]
            for i in slots_to_keep
        ]
    except ValueError:
        warn("Invalid value encountered for the normalization values.")
        return None

    clip_values = [bool(c) for c in np.array(self.ch_norm.clip_option)[slots_to_keep]]
    normalization_mode = [
        bool(m) for m in np.array(self.ch_norm.normalization_mode)[slots_to_keep]
    ]

    data_folders = []
    if self.dataset_folder is not None:
        data_folders.append(self.dataset_folder)
    if self.dataset_cb.currentText() != '--':
        data_folders.append(locate_signal_dataset(self.dataset_cb.currentText()))

    aug_factor = self.augmentation_slider.value()
    val_split = self.validation_slider.value()

    try:
        lr = float(self.lr_le.text().replace(',', '.'))
    except ValueError:
        # Abort however the box is closed: there is no usable learning rate.
        warn("Invalid value encountered for the learning rate.")
        return None

    try:
        bs = int(self.bs_le.text())
    except ValueError:
        warn("Invalid value encountered for the batch size.")
        return None

    epochs = self.epochs_slider.value()

    training_instructions = {
        'model_name': model_name,
        'pretrained': pretrained_model,
        'channel_option': channels,
        'normalization_percentile': normalization_mode,
        'normalization_clip': clip_values,
        'normalization_values': norm_values,
        'model_signal_length': signal_length,
        'recompile_pretrained': recompile_op,
        'ds': data_folders,
        'augmentation_factor': aug_factor,
        'validation_split': val_split,
        'learning_rate': lr,
        'batch_size': bs,
        'epochs': epochs,
        'label': self.class_name_le.text(),
        # NOTE(review): these three attributes are not assigned in this
        # method; getattr keeps a missing one from raising here.
        'neighborhood_of_interest': getattr(self, 'current_neighborhood', None),
        'reference_population': getattr(self, 'reference_population', None),
        'neighbor_population': getattr(self, 'neighbor_population', None),
    }

    model_folder = os.path.join(self.signal_models_dir, model_name)
    # makedirs also creates a missing models directory. exist_ok keeps the
    # previous behaviour of reusing an existing model folder.
    os.makedirs(model_folder, exist_ok=True)

    training_instructions.update({'target_directory': self.signal_models_dir})

    print(f"Set of instructions: {training_instructions}")
    instructions_path = os.path.join(model_folder, "training_instructions.json")
    with open(instructions_path, 'w') as f:
        json.dump(training_instructions, f, indent=4)

    train_signal_model(instructions_path)
    self.parent_window.refresh_signal_models()
542
-
543
- def _load_previous_instructions(self):
544
- pass
219
+ """
220
+ )
221
+ grid.addWidget(self.model_lbl, 0, 0, 1, 4, alignment=Qt.AlignCenter)
222
+
223
+ self.generate_model_panel_contents()
224
+ grid.addWidget(self.ContentsModel, 1, 0, 1, 4, alignment=Qt.AlignTop)
225
+
226
def generate_data_contents(self):
    """Build the "data" panel (``self.ContentsData``) of the training form.

    Rows, top to bottom:
    - the training-folder picker (browse button, chosen folder, clear button);
    - a combo box of available signal datasets ("--" means none);
    - sliders for the augmentation factor and the validation split.

    Changing the dataset combo box re-runs ``check_readiness``.
    """
    self.ContentsData = QFrame()
    vbox = QVBoxLayout(self.ContentsData)
    vbox.setContentsMargins(0, 0, 0, 0)

    def add_row(caption, widget):
        # Caption (stretch 30) next to its widget (stretch 70).
        row = QHBoxLayout()
        row.addWidget(QLabel(caption), 30)
        row.addWidget(widget, 70)
        vbox.addLayout(row)

    # Folder picker: caption, browse button, chosen folder, clear button.
    folder_row = QHBoxLayout()
    folder_row.addWidget(QLabel("Training data: "), 30)
    self.select_data_folder_btn = QPushButton("Choose folder")
    self.select_data_folder_btn.clicked.connect(self.show_dialog_dataset)
    self.data_folder_label = QLabel("No folder chosen")
    folder_row.addWidget(self.select_data_folder_btn, 35)
    folder_row.addWidget(self.data_folder_label, 30)

    clear_btn = self.cancel_dataset = QPushButton()
    clear_btn.setIcon(icon(MDI6.close, color="black"))
    clear_btn.clicked.connect(self.clear_dataset)
    clear_btn.setStyleSheet(self.button_select_all)
    clear_btn.setIconSize(QSize(20, 20))
    clear_btn.setVisible(False)  # shown once a valid folder is chosen
    folder_row.addWidget(clear_btn, 5)
    vbox.addLayout(folder_row)

    # Dataset names from get_signal_datasets_list; their paths are kept too.
    dataset_names, self.datasets_path = get_signal_datasets_list(
        return_path=True
    )
    self.dataset_cb = QComboBox()
    self.dataset_cb.addItems(["--"] + dataset_names)
    self.dataset_cb.currentTextChanged.connect(self.check_readiness)
    add_row("include dataset: ", self.dataset_cb)

    aug = self.augmentation_slider = QLabeledDoubleSlider()
    aug.setSingleStep(0.01)
    aug.setTickInterval(0.01)
    aug.setOrientation(Qt.Horizontal)
    aug.setRange(1, 5)
    aug.setValue(2)
    add_row("augmentation\nfactor: ", aug)

    split = self.validation_slider = QLabeledDoubleSlider()
    split.setSingleStep(0.01)
    split.setTickInterval(0.01)
    split.setOrientation(Qt.Horizontal)
    split.setRange(0, 0.9)
    split.setValue(0.25)
    add_row("validation split: ", split)
286
+
287
def generate_model_panel_contents(self):
    """Build the "model" panel (``self.ContentsModel``) of the training form.

    Rows, top to bottom:
    - model name;
    - neighborhood of interest (only when ``self.signal_mode == "pairs"``);
    - event name;
    - pretrained-model picker with a clear button;
    - recompile checkbox;
    - the channel/normalization block (``ChannelNormGenerator``);
    - the max-signal-length slider.
    """
    self.ContentsModel = QFrame()
    vbox = QVBoxLayout(self.ContentsModel)
    vbox.setContentsMargins(0, 0, 0, 0)
    pairs_mode = self.signal_mode == "pairs"

    def add_row(caption, widget):
        # Caption (stretch 30) next to its widget (stretch 70).
        row = QHBoxLayout()
        row.addWidget(QLabel(caption), 30)
        row.addWidget(widget, 70)
        vbox.addLayout(row)

    self.modelname_le = QLineEdit()
    self.modelname_le.setText(
        f"Untitled_model_{datetime.today().strftime('%Y-%m-%d')}"
    )
    add_row("Model name: ", self.modelname_le)

    if pairs_mode:
        self.neighborhood_choice_cb = QSearchableComboBox()
        self.fill_available_neighborhoods()
        add_row("neighborhood of interest: ", self.neighborhood_choice_cb)

    self.class_name_le = QLineEdit()
    self.class_name_le.setText("")
    add_row("event name: ", self.class_name_le)

    # Pretrained model: caption, browse button, chosen folder, clear button.
    pretrained_row = QHBoxLayout()
    pretrained_row.setContentsMargins(0, 0, 0, 0)
    pretrained_row.addWidget(QLabel("Pretrained model: "), 30)

    self.browse_pretrained_btn = QPushButton("Choose folder")
    self.browse_pretrained_btn.clicked.connect(self.show_dialog_pretrained)
    pretrained_row.addWidget(self.browse_pretrained_btn, 35)

    self.pretrained_lbl = QLabel("No folder chosen")
    pretrained_row.addWidget(self.pretrained_lbl, 30)

    clear_btn = self.cancel_pretrained = QPushButton()
    clear_btn.setIcon(icon(MDI6.close, color="black"))
    clear_btn.clicked.connect(self.clear_pretrained)
    clear_btn.setStyleSheet(self.button_select_all)
    clear_btn.setIconSize(QSize(20, 20))
    clear_btn.setVisible(False)
    pretrained_row.addWidget(clear_btn, 5)
    vbox.addLayout(pretrained_row)

    # Stays disabled until a pretrained model is loaded.
    self.recompile_option = QCheckBox()
    self.recompile_option.setEnabled(False)
    add_row("Recompile: ", self.recompile_option)

    # NOTE(review): presumably read by ChannelNormGenerator, so set it first.
    self.max_nbr_channels = 5
    self.ch_norm = ChannelNormGenerator(self, mode="signals")
    vbox.addLayout(self.ch_norm)

    if pairs_mode:
        # Wired only now: neighborhood_changed repopulates the combo boxes
        # of self.ch_norm, which did not exist before this point.
        self.neighborhood_choice_cb.currentIndexChanged.connect(
            self.neighborhood_changed
        )
        self.neighborhood_changed()

    length = self.model_length_slider = QLabeledSlider()
    length.setSingleStep(1)
    length.setTickInterval(1)
    length.setOrientation(Qt.Horizontal)
    length.setRange(0, 1024)
    length.setValue(128)
    add_row("Max signal length: ", length)
366
+
367
def neighborhood_changed(self):
    """Refresh population info and signal choices for the selected neighborhood.

    Combo-box entries look like ``"<population>_ref_<column>"`` (built by
    fill_available_neighborhoods).
    - The ``<population>_ref_`` prefix is stripped to give
      ``self.current_neighborhood``.
    - The reference and neighbor populations are read from a
      ``_(<reference>-<neighbor>)_`` part of the name when present.
      Otherwise the reference is the entry's first ``_`` token, and for
      "self" neighborhoods the neighbor equals the reference.
    - The reference, neighbor and ``pairs`` tables get "reference_",
      "neighbor_" and "pair_" column prefixes.
    - Their numeric columns, after a leading "--", become the items of every
      channel combo box.

    Logs a warning and returns, instead of raising, when nothing is selected
    or a population table is missing.
    """
    neigh = self.neighborhood_choice_cb.currentText()
    if not neigh:
        # Empty combo box (no neighborhood column found): there is no
        # reference population to look up.
        logger.warning("No neighborhood selected; signal choices left unchanged.")
        return

    current = neigh
    for pop in self.dataframes.keys():
        current = current.replace(f"{pop}_ref_", "")
    self.current_neighborhood = current

    reference = neigh.split("_")[0]
    # NOTE(review): if the name matches neither pattern below, the neighbor
    # population of the previous selection is kept (None if there was none).
    neighbor = getattr(self, "neighbor_population", None)
    if "_(" in current and ")_" in current:
        pair = current.split("_(")[-1].split(")_")[0].split("-")
        reference, neighbor = pair[0], pair[-1]
    elif "self" in current:
        neighbor = reference
    self.reference_population = reference
    self.neighbor_population = neighbor

    logger.info(f"Current neighborhood: {self.current_neighborhood}")
    logger.info(f"New reference population: {self.reference_population}")
    logger.info(f"New neighbor population: {self.neighbor_population}")

    # fill_available_neighborhoods stores None for populations without a table.
    df_reference = self.dataframes.get(reference)
    df_neighbor = self.dataframes.get(neighbor)
    if df_reference is None or df_neighbor is None:
        logger.warning(
            f"No table for reference population {reference!r} or neighbor "
            f"population {neighbor!r}; signal choices left unchanged."
        )
        return

    def numeric_columns(df):
        return [c for c in df.columns if is_numeric_dtype(df[c])]

    self.df_reference = df_reference.rename(columns=lambda x: "reference_" + x)
    self.df_neighbor = df_neighbor.rename(columns=lambda x: "neighbor_" + x)

    self.df_pairs = load_experiment_tables(
        self.parent_window.exp_dir, population="pairs", load_pickle=False
    )
    if self.df_pairs is None:
        logger.warning("No pairs table found; pair signals are not offered.")
        pair_cols = []
    else:
        self.df_pairs = self.df_pairs.rename(columns=lambda x: "pair_" + x)
        pair_cols = numeric_columns(self.df_pairs)

    self.signals = (
        ["--"]
        + pair_cols
        + numeric_columns(self.df_reference)
        + numeric_columns(self.df_neighbor)
    )

    for cb in self.ch_norm.channel_cbs:
        cb.clear()
        cb.addItems(self.signals)
422
+
423
def fill_available_neighborhoods(self):
    """Load every population table and offer its neighborhood columns.

    ``self.dataframes`` maps each population of the grand-parent window to
    its table (``None`` when none was loaded). Each column whose name starts
    with "neighborhood" is listed as ``"<population>_ref_<column>"`` in
    ``self.neighborhood_cols`` and appended to ``self.neighborhood_choice_cb``.
    """
    self.dataframes = {}
    self.neighborhood_cols = []
    exp_dir = self.parent_window.exp_dir

    for population in self.parent_window.parent_window.populations:
        table = load_experiment_tables(
            exp_dir, population=population, load_pickle=True
        )
        self.dataframes[population] = table
        if table is None:
            continue
        prefix = f"{population}_ref_"
        self.neighborhood_cols.extend(
            prefix + col for col in table.columns if col.startswith("neighborhood")
        )

    self.neighborhood_choice_cb.addItems(self.neighborhood_cols)
442
+
443
def show_dialog_pretrained(self):
    """Let the user choose a pretrained signal model folder.

    A folder is accepted only if it contains ``config_input.json``. In that
    case the configuration is loaded into the form (``load_pretrained_config``),
    the folder name is displayed, the clear button and the recompile option
    are enabled, and the model name defaults to
    ``<pretrained name>_<YYYY-MM-DD>``. Any other folder resets the
    selection to "No folder chosen". Cancelling the dialog changes nothing.
    """
    folder = QFileDialog.getExistingDirectory(
        self,
        "Open Directory",
        os.sep.join(
            [self._software_path, "celldetective", "models", "signal_detection", ""]
        ),
        QFileDialog.ShowDirsOnly | QFileDialog.DontResolveSymlinks,
    )
    # getExistingDirectory returns "" (never None) on cancel. Treating "" as
    # a folder would glob the filesystem root and drop the current choice.
    if not folder:
        return

    if os.path.isfile(os.path.join(folder, "config_input.json")):
        # load_pretrained_config reads self.pretrained_model, so set it first.
        self.pretrained_model = folder
        self.load_pretrained_config()
        name = Path(folder).name
        self.pretrained_lbl.setText(name)
        self.cancel_pretrained.setVisible(True)
        self.recompile_option.setEnabled(True)
        self.modelname_le.setText(
            f"{name}_{datetime.today().strftime('%Y-%m-%d')}"
        )
    else:
        self.pretrained_model = None
        self.pretrained_lbl.setText("No folder chosen")
        self.recompile_option.setEnabled(False)
        self.cancel_pretrained.setVisible(False)
    logger.info(self.pretrained_model)
471
+
472
def show_dialog_dataset(self):
    """Let the user pick a folder of training samples.

    The folder is kept only if it contains at least one ``.npy`` file.
    Otherwise the selection is reset to "No folder chosen". The submit
    button state is then refreshed via ``check_readiness``. Cancelling the
    dialog leaves the current selection unchanged.
    """
    folder = QFileDialog.getExistingDirectory(
        self,
        "Open Directory",
        self.exp_dir,
        QFileDialog.ShowDirsOnly | QFileDialog.DontResolveSymlinks,
    )
    # getExistingDirectory returns "" (never None) on cancel. Treating "" as
    # a folder would glob "<sep>*.npy" at the filesystem root and wipe the
    # previous selection.
    if not folder:
        return

    subfiles = glob(os.sep.join([folder, "*.npy"]))
    if subfiles:
        logger.info(f"found {len(subfiles)} files in folder")
        self.dataset_folder = folder
        # Show a shortened path; the full path goes in the tooltip.
        self.data_folder_label.setText(folder[:16] + "...")
        self.data_folder_label.setToolTip(folder)
        self.cancel_dataset.setVisible(True)
    else:
        self.dataset_folder = None
        self.data_folder_label.setText("No folder chosen")
        self.data_folder_label.setToolTip("")
        self.cancel_dataset.setVisible(False)
    self.check_readiness()
497
+
498
def clear_pretrained(self):
    """Drop the pretrained-model selection and unlock the settings it fixed.

    Re-enables every channel combo box, the add-column button and the
    signal-length slider. Disables the recompile option, hides the clear
    button, empties the event name and resets the model name to
    ``Untitled_model_<YYYY-MM-DD>``.
    """
    self.pretrained_model = None
    self.pretrained_lbl.setText("No folder chosen")

    channel_norm = self.ch_norm
    for combo in channel_norm.channel_cbs:
        combo.setEnabled(True)
    channel_norm.add_col_btn.setEnabled(True)

    self.recompile_option.setEnabled(False)
    self.cancel_pretrained.setVisible(False)
    self.model_length_slider.setEnabled(True)
    self.class_name_le.setText("")

    stamp = datetime.today().strftime("%Y-%m-%d")
    self.modelname_le.setText(f"Untitled_model_{stamp}")
512
+
513
def check_readiness(self):
    """Enable submission only when some training data has been selected.

    Data counts as selected when a folder was chosen or the dataset combo
    box is set to anything other than "--". Otherwise the submit button is
    disabled and a warning is shown.
    """
    has_data = (
        self.dataset_folder is not None or self.dataset_cb.currentText() != "--"
    )
    self.submit_btn.setEnabled(has_data)
    self.warning_label.setText(
        "" if has_data else "Please provide a dataset to train the model."
    )
520
+
521
def clear_dataset(self):
    """Forget the chosen training folder, reset its label and re-check readiness."""
    self.dataset_folder = None
    label = self.data_folder_label
    label.setText("No folder chosen")
    label.setToolTip("")
    self.cancel_dataset.setVisible(False)
    self.check_readiness()
529
+
530
def load_pretrained_config(self):
    """Fill the form from ``<self.pretrained_model>/config_input.json``.

    Reads ``channels``, ``model_signal_length`` and, if present, ``label``.
    The signal length is applied to the slider and the slider is disabled.
    Each channel combo box is set to the stored channel at the same position.
    When the config lists fewer channels than there are combo boxes, the
    extra boxes are reset to their first entry and disabled, and the
    add-column button is disabled.

    Raises
    ------
    OSError
        If the config file cannot be opened.
    ValueError
        If the file is not valid JSON.
    KeyError
        If ``channels`` or ``model_signal_length`` is missing.
    """
    config_path = os.path.join(self.pretrained_model, "config_input.json")
    # Context manager: the handle is closed even if parsing fails.
    with open(config_path, encoding="utf-8") as config_file:
        data = json.load(config_file)

    channels = data["channels"]
    signal_length = data["model_signal_length"]

    # "label" is optional; only a string can be shown in the line edit.
    label = data.get("label")
    if isinstance(label, str):
        self.class_name_le.setText(label)

    self.model_length_slider.setValue(int(signal_length))
    self.model_length_slider.setEnabled(False)

    for channel, cb in zip(channels, self.ch_norm.channel_cbs):
        # findText returns -1 when the channel is not offered in this box.
        cb.setCurrentIndex(cb.findText(channel))

    extra_cbs = self.ch_norm.channel_cbs[len(channels):]
    if extra_cbs:
        for cb in extra_cbs:
            cb.setCurrentIndex(0)
            cb.setEnabled(False)
        self.ch_norm.add_col_btn.setEnabled(False)
553
+
554
def adjust_scroll_area(self):
    """
    Auto-adjust scroll area to fill space
    (from https://stackoverflow.com/questions/66417576/make-qscrollarea-use-all-available-space-of-qmainwindow-height-axis)

    Grows the widget height in 5-pixel steps while the scroll area's
    vertical scroll bar is visible and the height is below
    ``maximumHeight()``. Stops as soon as a resize fails to increase the
    height, which would otherwise make the loop spin forever.
    """
    step = 5
    while (
        self.scroll_area.verticalScrollBar().isVisible()
        and self.height() < self.maximumHeight()
    ):
        previous_height = self.height()
        self.resize(self.width(), previous_height + step)
        if self.height() <= previous_height:
            # The resize was not applied (e.g. constrained or deferred by
            # the platform), so no further progress is possible.
            break
566
+
567
def _write_instructions(self):
    """Collect the training settings, save them as JSON and start training.

    The settings are written to
    ``<signal_models_dir>/<model name>/training_instructions.json``; that
    path is stored in ``self.instructions``. Training then runs
    ``TrainSignalModelProcess`` through a ``Runner`` started on the global
    ``QThreadPool``. Meanwhile ``self.progress_dialog`` (a
    ``DynamicProgressDialog``) shows progress, and this method blocks in
    its ``exec_()``. ``self.stop_event`` is handed to the process and is set
    by ``on_training_interrupt``.

    Returns ``None`` early, after a warning box and without writing
    anything, when the model name is empty or a numeric field (learning
    rate, batch size, normalization values) cannot be parsed.
    """
    # Use the class exposed by self.bg_loader once it has finished;
    # otherwise import it here.
    if self.bg_loader.isFinished() and hasattr(
        self.bg_loader, "TrainSignalModelProcess"
    ):
        TrainSignalModelProcess = self.bg_loader.TrainSignalModelProcess
    else:
        from celldetective.processes.train_signal_model import (
            TrainSignalModelProcess,
        )

    def warn(text):
        # Modal warning box; the caller aborts whichever button closed it.
        msg_box = QMessageBox()
        msg_box.setIcon(QMessageBox.Warning)
        msg_box.setText(text)
        msg_box.setWindowTitle("Warning")
        msg_box.setStandardButtons(QMessageBox.Ok)
        msg_box.exec()

    model_name = self.modelname_le.text()
    if not model_name.strip():
        # An empty name would make the model folder the models directory.
        warn("Please provide a model name.")
        return None

    pretrained_model = self.pretrained_model
    signal_length = self.model_length_slider.value()
    recompile_op = self.recompile_option.isChecked()

    all_channels = [cb.currentText() for cb in self.ch_norm.channel_cbs]
    # Slots where a channel was actually picked ("--" = unused). The
    # per-slot normalization settings are filtered with the same indices.
    slots_to_keep = np.where(np.array(all_channels) != "--")[0]
    channels = [c for c in all_channels if c != "--"]

    min_les = self.ch_norm.normalization_min_value_le
    max_les = self.ch_norm.normalization_max_value_le
    try:
        # Only the kept slots are parsed; a comma is accepted as decimal mark.
        norm_values = [
            [
                float(min_les[i].text().replace(",", ".")),
                float(max_les[i].text().replace(",", ".")),
            ]
            for i in slots_to_keep
        ]
    except ValueError:
        warn("Invalid value encountered for the normalization values.")
        return None

    clip_values = [
        bool(c) for c in np.array(self.ch_norm.clip_option)[slots_to_keep]
    ]
    normalization_mode = [
        bool(m) for m in np.array(self.ch_norm.normalization_mode)[slots_to_keep]
    ]

    data_folders = []
    if self.dataset_folder is not None:
        data_folders.append(self.dataset_folder)
    if self.dataset_cb.currentText() != "--":
        data_folders.append(locate_signal_dataset(self.dataset_cb.currentText()))

    aug_factor = self.augmentation_slider.value()
    val_split = self.validation_slider.value()

    try:
        lr = float(self.lr_le.text().replace(",", "."))
    except ValueError:
        # Abort however the box is closed: there is no usable learning rate.
        warn("Invalid value encountered for the learning rate.")
        return None

    try:
        bs = int(self.bs_le.text())
    except ValueError:
        warn("Invalid value encountered for the batch size.")
        return None

    epochs = self.epochs_slider.value()

    training_instructions = {
        "model_name": model_name,
        "pretrained": pretrained_model,
        "channel_option": channels,
        "normalization_percentile": normalization_mode,
        "normalization_clip": clip_values,
        "normalization_values": norm_values,
        "model_signal_length": signal_length,
        "recompile_pretrained": recompile_op,
        "ds": data_folders,
        "augmentation_factor": aug_factor,
        "validation_split": val_split,
        "learning_rate": lr,
        "batch_size": bs,
        "epochs": epochs,
        "label": self.class_name_le.text(),
        # NOTE(review): these are assigned by neighborhood_changed (pairs
        # mode). getattr keeps a missing one from raising in other modes.
        "neighborhood_of_interest": getattr(self, "current_neighborhood", None),
        "reference_population": getattr(self, "reference_population", None),
        "neighbor_population": getattr(self, "neighbor_population", None),
    }

    model_folder = os.path.join(self.signal_models_dir, model_name)
    logger.info(f"{self.signal_models_dir=} {model_name=}")
    # makedirs also creates a missing models directory. exist_ok keeps the
    # previous behaviour of reusing an existing model folder.
    os.makedirs(model_folder, exist_ok=True)

    training_instructions.update({"target_directory": self.signal_models_dir})

    logger.info(f"Set of instructions: {training_instructions}")
    self.instructions = os.path.join(model_folder, "training_instructions.json")
    with open(self.instructions, "w") as f:
        json.dump(training_instructions, f, indent=4)

    # Shared with the training process; on_training_interrupt sets it.
    self.stop_event = multiprocessing.Event()
    process_args = {
        "instructions": self.instructions,
        "stop_event": self.stop_event,
    }
    self.training_was_cancelled = False
    self.is_finished = False

    self.progress_dialog = DynamicProgressDialog(
        label_text="Preparing model training...",
        max_epochs=epochs,
        parent=self,
        title="Training Event Model",
    )

    self.runner = Runner(process=TrainSignalModelProcess, process_args=process_args)

    # Progress reporting from the runner to the dialog.
    signals = self.runner.signals
    signals.update_pos.connect(self.progress_dialog.update_progress)
    signals.update_pos_time.connect(
        lambda t: self.progress_dialog.status_label.setText(
            f"Training model... {t}"
        )
    )
    signals.update_plot.connect(self.progress_dialog.update_plot)
    signals.training_result.connect(self.progress_dialog.show_result)
    signals.update_status.connect(self.progress_dialog.update_status)

    signals.finished.connect(self.on_training_finished)
    signals.error.connect(self.on_training_error)

    # Cancel button and "Skip Model" interrupt.
    self.progress_dialog.canceled.connect(self.on_training_cancel)
    self.progress_dialog.interrupted.connect(self.on_training_interrupt)

    self.pool = QThreadPool.globalInstance()
    self.pool.start(self.runner)
    self.progress_dialog.exec_()
711
+
712
def on_training_finished(self):
    """Switch the progress dialog to its "finished" state and refresh models.

    Ignored once the user has cancelled. The dialog is left open so the
    displayed result can be inspected: its status reports completion, the
    cancel button becomes "Close" and the progress bar is filled. The runner
    is then closed and the parent window's model list refreshed.
    """
    if self.training_was_cancelled:
        return

    self.is_finished = True

    dialog = self.progress_dialog
    dialog.status_label.setText("Training Finished. Result displayed.")
    dialog.cancel_btn.setText("Close")
    bar = dialog.progress_bar
    bar.setValue(bar.maximum())

    self.runner.close()
    self.parent_window.refresh_signal_models()
730
+
731
def on_training_error(self, message):
    """Close the progress dialog and report a failed run.

    Errors arriving after the user cancelled are ignored.
    """
    if not self.training_was_cancelled:
        self.progress_dialog.close()
        QMessageBox.critical(self, "Error", f"Training failed: {message}")
736
+
737
def on_training_cancel(self):
    """Handle the progress dialog's cancel/close button.

    After a completed run (``self.is_finished``) this only closes the runner
    and the dialog. Otherwise it does three things:
    - flags the run as cancelled, so the finished/error handlers ignore late
      signals;
    - closes the runner;
    - deletes the model folder, but only if it is a direct sub-folder of the
      models directory.
    """
    if self.is_finished:
        logger.info("Training complete, dialog closed.")
        self.runner.close()
        self.progress_dialog.close()
        return

    self.training_was_cancelled = True
    self.runner.close()

    try:
        import shutil

        # Same base directory that _write_instructions creates the folder in.
        models_dir = (
            getattr(self, "signal_models_dir", None)
            or self.parent_window.signal_models_dir
        )
        model_path = os.path.join(models_dir, self.modelname_le.text())

        # Only ever delete a direct sub-folder of the models directory. An
        # empty, "." or ".." name (or a symlinked folder) would otherwise
        # point rmtree at the models directory itself or outside of it.
        resolved = os.path.realpath(model_path)
        if os.path.dirname(resolved) != os.path.realpath(models_dir):
            logger.warning(
                f"Refusing to delete {model_path}: not a model sub-folder of {models_dir}."
            )
        elif os.path.exists(model_path):
            # Short pause before deleting, presumably so the closed runner
            # can release its files.
            time.sleep(0.5)
            # NOTE(review): this also removes a folder that existed before
            # this run when an existing model name was reused.
            shutil.rmtree(model_path)
            logger.info(f"Cancelled training. Deleted model folder: {model_path}")
    except Exception as e:
        # Best effort: a failed cleanup must not break the cancel action.
        logger.error(f"Could not delete model folder after cancel: {e}")
    logger.info("Training cancelled.")
764
+
765
def on_training_interrupt(self):
    """React to "Skip Model": log it and set the event shared with the training process."""
    logger.info("Training interrupted by user (Skip Model).")
    stop_event = self.stop_event
    stop_event.set()
768
+
769
+ def _load_previous_instructions(self):
770
+ pass