celldetective 1.4.2__py3-none-any.whl → 1.5.0b0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (151)
  1. celldetective/__init__.py +25 -0
  2. celldetective/__main__.py +62 -43
  3. celldetective/_version.py +1 -1
  4. celldetective/extra_properties.py +477 -399
  5. celldetective/filters.py +192 -97
  6. celldetective/gui/InitWindow.py +541 -411
  7. celldetective/gui/__init__.py +0 -15
  8. celldetective/gui/about.py +44 -39
  9. celldetective/gui/analyze_block.py +120 -84
  10. celldetective/gui/base/__init__.py +0 -0
  11. celldetective/gui/base/channel_norm_generator.py +335 -0
  12. celldetective/gui/base/components.py +249 -0
  13. celldetective/gui/base/feature_choice.py +92 -0
  14. celldetective/gui/base/figure_canvas.py +52 -0
  15. celldetective/gui/base/list_widget.py +133 -0
  16. celldetective/gui/{styles.py → base/styles.py} +92 -36
  17. celldetective/gui/base/utils.py +33 -0
  18. celldetective/gui/base_annotator.py +900 -767
  19. celldetective/gui/classifier_widget.py +6 -22
  20. celldetective/gui/configure_new_exp.py +777 -671
  21. celldetective/gui/control_panel.py +635 -524
  22. celldetective/gui/dynamic_progress.py +449 -0
  23. celldetective/gui/event_annotator.py +2023 -1662
  24. celldetective/gui/generic_signal_plot.py +1292 -944
  25. celldetective/gui/gui_utils.py +899 -1289
  26. celldetective/gui/interactions_block.py +658 -0
  27. celldetective/gui/interactive_timeseries_viewer.py +447 -0
  28. celldetective/gui/json_readers.py +48 -15
  29. celldetective/gui/layouts/__init__.py +5 -0
  30. celldetective/gui/layouts/background_model_free_layout.py +537 -0
  31. celldetective/gui/layouts/channel_offset_layout.py +134 -0
  32. celldetective/gui/layouts/local_correction_layout.py +91 -0
  33. celldetective/gui/layouts/model_fit_layout.py +372 -0
  34. celldetective/gui/layouts/operation_layout.py +68 -0
  35. celldetective/gui/layouts/protocol_designer_layout.py +96 -0
  36. celldetective/gui/pair_event_annotator.py +3130 -2435
  37. celldetective/gui/plot_measurements.py +586 -267
  38. celldetective/gui/plot_signals_ui.py +724 -506
  39. celldetective/gui/preprocessing_block.py +395 -0
  40. celldetective/gui/process_block.py +1678 -1831
  41. celldetective/gui/seg_model_loader.py +580 -473
  42. celldetective/gui/settings/__init__.py +0 -7
  43. celldetective/gui/settings/_cellpose_model_params.py +181 -0
  44. celldetective/gui/settings/_event_detection_model_params.py +95 -0
  45. celldetective/gui/settings/_segmentation_model_params.py +159 -0
  46. celldetective/gui/settings/_settings_base.py +77 -65
  47. celldetective/gui/settings/_settings_event_model_training.py +752 -526
  48. celldetective/gui/settings/_settings_measurements.py +1133 -964
  49. celldetective/gui/settings/_settings_neighborhood.py +574 -488
  50. celldetective/gui/settings/_settings_segmentation_model_training.py +779 -564
  51. celldetective/gui/settings/_settings_signal_annotator.py +329 -305
  52. celldetective/gui/settings/_settings_tracking.py +1304 -1094
  53. celldetective/gui/settings/_stardist_model_params.py +98 -0
  54. celldetective/gui/survival_ui.py +422 -312
  55. celldetective/gui/tableUI.py +1665 -1701
  56. celldetective/gui/table_ops/_maths.py +295 -0
  57. celldetective/gui/table_ops/_merge_groups.py +140 -0
  58. celldetective/gui/table_ops/_merge_one_hot.py +95 -0
  59. celldetective/gui/table_ops/_query_table.py +43 -0
  60. celldetective/gui/table_ops/_rename_col.py +44 -0
  61. celldetective/gui/thresholds_gui.py +382 -179
  62. celldetective/gui/viewers/__init__.py +0 -0
  63. celldetective/gui/viewers/base_viewer.py +700 -0
  64. celldetective/gui/viewers/channel_offset_viewer.py +331 -0
  65. celldetective/gui/viewers/contour_viewer.py +394 -0
  66. celldetective/gui/viewers/size_viewer.py +153 -0
  67. celldetective/gui/viewers/spot_detection_viewer.py +341 -0
  68. celldetective/gui/viewers/threshold_viewer.py +309 -0
  69. celldetective/gui/workers.py +304 -126
  70. celldetective/log_manager.py +92 -0
  71. celldetective/measure.py +1895 -1478
  72. celldetective/napari/__init__.py +0 -0
  73. celldetective/napari/utils.py +1025 -0
  74. celldetective/neighborhood.py +1914 -1448
  75. celldetective/preprocessing.py +1620 -1220
  76. celldetective/processes/__init__.py +0 -0
  77. celldetective/processes/background_correction.py +271 -0
  78. celldetective/processes/compute_neighborhood.py +894 -0
  79. celldetective/processes/detect_events.py +246 -0
  80. celldetective/processes/measure_cells.py +565 -0
  81. celldetective/processes/segment_cells.py +760 -0
  82. celldetective/processes/track_cells.py +435 -0
  83. celldetective/processes/train_segmentation_model.py +694 -0
  84. celldetective/processes/train_signal_model.py +265 -0
  85. celldetective/processes/unified_process.py +292 -0
  86. celldetective/regionprops/_regionprops.py +358 -317
  87. celldetective/relative_measurements.py +987 -710
  88. celldetective/scripts/measure_cells.py +313 -212
  89. celldetective/scripts/measure_relative.py +90 -46
  90. celldetective/scripts/segment_cells.py +165 -104
  91. celldetective/scripts/segment_cells_thresholds.py +96 -68
  92. celldetective/scripts/track_cells.py +198 -149
  93. celldetective/scripts/train_segmentation_model.py +324 -201
  94. celldetective/scripts/train_signal_model.py +87 -45
  95. celldetective/segmentation.py +844 -749
  96. celldetective/signals.py +3514 -2861
  97. celldetective/tracking.py +30 -15
  98. celldetective/utils/__init__.py +0 -0
  99. celldetective/utils/cellpose_utils/__init__.py +133 -0
  100. celldetective/utils/color_mappings.py +42 -0
  101. celldetective/utils/data_cleaning.py +630 -0
  102. celldetective/utils/data_loaders.py +450 -0
  103. celldetective/utils/dataset_helpers.py +207 -0
  104. celldetective/utils/downloaders.py +197 -0
  105. celldetective/utils/event_detection/__init__.py +8 -0
  106. celldetective/utils/experiment.py +1782 -0
  107. celldetective/utils/image_augmenters.py +308 -0
  108. celldetective/utils/image_cleaning.py +74 -0
  109. celldetective/utils/image_loaders.py +926 -0
  110. celldetective/utils/image_transforms.py +335 -0
  111. celldetective/utils/io.py +62 -0
  112. celldetective/utils/mask_cleaning.py +348 -0
  113. celldetective/utils/mask_transforms.py +5 -0
  114. celldetective/utils/masks.py +184 -0
  115. celldetective/utils/maths.py +351 -0
  116. celldetective/utils/model_getters.py +325 -0
  117. celldetective/utils/model_loaders.py +296 -0
  118. celldetective/utils/normalization.py +380 -0
  119. celldetective/utils/parsing.py +465 -0
  120. celldetective/utils/plots/__init__.py +0 -0
  121. celldetective/utils/plots/regression.py +53 -0
  122. celldetective/utils/resources.py +34 -0
  123. celldetective/utils/stardist_utils/__init__.py +104 -0
  124. celldetective/utils/stats.py +90 -0
  125. celldetective/utils/types.py +21 -0
  126. {celldetective-1.4.2.dist-info → celldetective-1.5.0b0.dist-info}/METADATA +1 -1
  127. celldetective-1.5.0b0.dist-info/RECORD +187 -0
  128. {celldetective-1.4.2.dist-info → celldetective-1.5.0b0.dist-info}/WHEEL +1 -1
  129. tests/gui/test_new_project.py +129 -117
  130. tests/gui/test_project.py +127 -79
  131. tests/test_filters.py +39 -15
  132. tests/test_notebooks.py +8 -0
  133. tests/test_tracking.py +232 -13
  134. tests/test_utils.py +123 -77
  135. celldetective/gui/base_components.py +0 -23
  136. celldetective/gui/layouts.py +0 -1602
  137. celldetective/gui/processes/compute_neighborhood.py +0 -594
  138. celldetective/gui/processes/measure_cells.py +0 -360
  139. celldetective/gui/processes/segment_cells.py +0 -499
  140. celldetective/gui/processes/track_cells.py +0 -303
  141. celldetective/gui/processes/train_segmentation_model.py +0 -270
  142. celldetective/gui/processes/train_signal_model.py +0 -108
  143. celldetective/gui/table_ops/merge_groups.py +0 -118
  144. celldetective/gui/viewers.py +0 -1354
  145. celldetective/io.py +0 -3663
  146. celldetective/utils.py +0 -3108
  147. celldetective-1.4.2.dist-info/RECORD +0 -123
  148. celldetective/{gui/processes → processes}/downloader.py +0 -0
  149. {celldetective-1.4.2.dist-info → celldetective-1.5.0b0.dist-info}/entry_points.txt +0 -0
  150. {celldetective-1.4.2.dist-info → celldetective-1.5.0b0.dist-info}/licenses/LICENSE +0 -0
  151. {celldetective-1.4.2.dist-info → celldetective-1.5.0b0.dist-info}/top_level.txt +0 -0
celldetective/gui/settings/_settings_segmentation_model_training.py +779 -564
@@ -1,580 +1,795 @@
1
- from PyQt5.QtWidgets import QRadioButton, QComboBox, QFrame, QFileDialog, QGridLayout, QLineEdit, QVBoxLayout, QLabel, QHBoxLayout, QPushButton
2
- from PyQt5.QtCore import Qt, QSize
3
- from celldetective.gui.gui_utils import generic_message
4
- from celldetective.gui.layouts import ChannelNormGenerator
5
-
6
- from superqt import QLabeledDoubleSlider,QLabeledSlider
1
+ from PyQt5.QtWidgets import (
2
+ QRadioButton,
3
+ QComboBox,
4
+ QFrame,
5
+ QFileDialog,
6
+ QGridLayout,
7
+ QLineEdit,
8
+ QVBoxLayout,
9
+ QLabel,
10
+ QHBoxLayout,
11
+ QPushButton,
12
+ QMessageBox,
13
+ )
14
+ from PyQt5.QtCore import Qt, QSize, QThreadPool, QThread
15
+ from celldetective.gui.base.components import generic_message
16
+ from celldetective.gui.base.channel_norm_generator import ChannelNormGenerator
17
+ import multiprocessing
18
+ from celldetective.gui.workers import Runner
19
+ from celldetective.gui.dynamic_progress import DynamicProgressDialog
20
+ from superqt import QLabeledDoubleSlider, QLabeledSlider
7
21
  from superqt.fonticon import icon
8
22
  from fonticon_mdi6 import MDI6
9
- from celldetective.io import get_segmentation_datasets_list, locate_segmentation_dataset, get_segmentation_models_list
10
- from celldetective.segmentation import train_segmentation_model
11
- from celldetective.gui.layouts import CellposeParamsWidget
23
+ from celldetective.utils.model_getters import (
24
+ get_segmentation_models_list,
25
+ get_segmentation_datasets_list,
26
+ )
27
+ from celldetective.utils.model_loaders import locate_segmentation_dataset
28
+
12
29
  import numpy as np
13
30
  import json
14
31
  import os
15
32
  from glob import glob
16
33
  from datetime import datetime
17
34
  from celldetective.gui.settings._settings_base import CelldetectiveSettingsPanel
35
+ from celldetective import get_logger
36
+
37
+ logger = get_logger(__name__)
38
+
39
+
40
+ class BackgroundLoader(QThread):
41
+ def run(self):
42
+ logger.info("Loading libraries...")
43
+ try:
44
+ from celldetective.processes.train_segmentation_model import (
45
+ TrainSegModelProcess,
46
+ )
47
+
48
+ self.TrainSegModelProcess = TrainSegModelProcess
49
+ except Exception:
50
+ logger.error("Librairies not loaded...")
51
+ logger.info("Librairies loaded...")
52
+
18
53
 
19
54
  class SettingsSegmentationModelTraining(CelldetectiveSettingsPanel):
20
-
21
- """
22
- UI to set segmentation model training instructions.
23
-
24
- """
25
-
26
- def __init__(self, parent_window=None):
27
-
28
- self.parent_window = parent_window
29
- self.use_gpu = self.parent_window.use_gpu
30
- self.mode = self.parent_window.mode
31
- self.exp_dir = self.parent_window.exp_dir
32
- self.pretrained_model = None
33
- self.dataset_folder = None
34
- super().__init__(title="Train segmentation model")
35
-
36
- self.software_models_dir = os.sep.join([self._software_path, 'celldetective', 'models', f'segmentation_{self.mode}'])
37
- self._add_to_layout()
38
- self._load_previous_instructions()
39
-
40
- self._adjustSize()
41
- self.resize(int(self.width()), int(self._screen_height * 0.8))
42
-
43
- def _add_to_layout(self):
44
-
45
- self._layout.addWidget(self.model_frame)
46
- self._layout.addWidget(self.data_frame)
47
- self._layout.addWidget(self.hyper_frame)
48
- self._layout.addWidget(self.submit_warning)
49
- self._layout.addWidget(self.submit_btn)
50
-
51
- def _create_widgets(self):
52
-
53
- """
54
- Create the multibox design.
55
-
56
- """
57
-
58
- super()._create_widgets()
59
- self.model_frame = QFrame()
60
- self.model_frame.setFrameStyle(QFrame.StyledPanel | QFrame.Raised)
61
- self.populate_model_frame()
62
-
63
- self.data_frame = QFrame()
64
- self.data_frame.setFrameStyle(QFrame.StyledPanel | QFrame.Raised)
65
- self.populate_data_frame()
66
-
67
- self.hyper_frame = QFrame()
68
- self.hyper_frame.setFrameStyle(QFrame.StyledPanel | QFrame.Raised)
69
- self.populate_hyper_frame()
70
-
71
- self.submit_btn.setEnabled(False)
72
- self.submit_warning = QLabel('')
73
- self.submit_btn.setText("Train")
74
-
75
- self.spatial_calib_le.textChanged.connect(self.activate_train_btn)
76
- self.modelname_le.setText(f"Untitled_model_{datetime.today().strftime('%Y-%m-%d')}")
77
-
78
- def populate_hyper_frame(self):
79
-
80
- """
81
- Add widgets and layout in the POST-PROCESSING frame.
82
- """
83
-
84
- grid = QGridLayout(self.hyper_frame)
85
- grid.setContentsMargins(30,30,30,30)
86
- grid.setSpacing(30)
87
-
88
- self.hyper_lbl = QLabel("HYPERPARAMETERS")
89
- self.hyper_lbl.setStyleSheet("""
55
+ """
56
+ UI to set segmentation model training instructions.
57
+
58
+ """
59
+
60
+ def __init__(self, parent_window=None):
61
+
62
+ self.parent_window = parent_window
63
+ self.use_gpu = self.parent_window.use_gpu
64
+ self.mode = self.parent_window.mode
65
+ self.exp_dir = self.parent_window.exp_dir
66
+ self.pretrained_model = None
67
+ self.dataset_folder = None
68
+ super().__init__(title="Train segmentation model")
69
+
70
+ self.software_models_dir = os.sep.join(
71
+ [
72
+ self._software_path,
73
+ "celldetective",
74
+ "models",
75
+ f"segmentation_{self.mode}",
76
+ ]
77
+ )
78
+ self._add_to_layout()
79
+ self._load_previous_instructions()
80
+
81
+ self._adjust_size()
82
+ self.resize(int(self.width()), int(self._screen_height * 0.8))
83
+
84
+ self.bg_loader = BackgroundLoader()
85
+ self.bg_loader.start()
86
+
87
+ def _add_to_layout(self):
88
+
89
+ self._layout.addWidget(self.model_frame)
90
+ self._layout.addWidget(self.data_frame)
91
+ self._layout.addWidget(self.hyper_frame)
92
+ self._layout.addWidget(self.submit_warning)
93
+ self._layout.addWidget(self.submit_btn)
94
+
95
+ def _create_widgets(self):
96
+ """
97
+ Create the multibox design.
98
+
99
+ """
100
+
101
+ super()._create_widgets()
102
+ self.model_frame = QFrame()
103
+ self.model_frame.setFrameStyle(QFrame.StyledPanel | QFrame.Raised)
104
+ self.populate_model_frame()
105
+
106
+ self.data_frame = QFrame()
107
+ self.data_frame.setFrameStyle(QFrame.StyledPanel | QFrame.Raised)
108
+ self.populate_data_frame()
109
+
110
+ self.hyper_frame = QFrame()
111
+ self.hyper_frame.setFrameStyle(QFrame.StyledPanel | QFrame.Raised)
112
+ self.populate_hyper_frame()
113
+
114
+ self.submit_btn.setEnabled(False)
115
+ self.submit_warning = QLabel("")
116
+ self.submit_btn.setText("Train")
117
+
118
+ self.spatial_calib_le.textChanged.connect(self.activate_train_btn)
119
+ self.modelname_le.setText(
120
+ f"Untitled_model_{datetime.today().strftime('%Y-%m-%d')}"
121
+ )
122
+
123
+ def populate_hyper_frame(self):
124
+ """
125
+ Add widgets and layout in the POST-PROCESSING frame.
126
+ """
127
+
128
+ grid = QGridLayout(self.hyper_frame)
129
+ grid.setContentsMargins(30, 30, 30, 30)
130
+ grid.setSpacing(30)
131
+
132
+ self.hyper_lbl = QLabel("HYPERPARAMETERS")
133
+ self.hyper_lbl.setStyleSheet(
134
+ """
90
135
  font-weight: bold;
91
136
  padding: 0px;
92
- """)
93
- grid.addWidget(self.hyper_lbl, 0, 0, 1, 4, alignment=Qt.AlignCenter)
94
- self.generate_hyper_contents()
95
- grid.addWidget(self.ContentsHyper, 1, 0, 1, 4, alignment=Qt.AlignTop)
96
-
97
- def generate_hyper_contents(self):
98
-
99
- self.ContentsHyper = QFrame()
100
- layout = QVBoxLayout(self.ContentsHyper)
101
- layout.setContentsMargins(0,0,0,0)
102
-
103
- lr_layout = QHBoxLayout()
104
- lr_layout.addWidget(QLabel('learning rate: '),30)
105
- self.lr_le = QLineEdit('0,0003')
106
- self.lr_le.setValidator(self._floatValidator)
107
- lr_layout.addWidget(self.lr_le, 70)
108
- layout.addLayout(lr_layout)
109
-
110
- bs_layout = QHBoxLayout()
111
- bs_layout.addWidget(QLabel('batch size: '),30)
112
- self.bs_le = QLineEdit('8')
113
- self.bs_le.setValidator(self._intValidator)
114
- bs_layout.addWidget(self.bs_le, 70)
115
- layout.addLayout(bs_layout)
116
-
117
- epochs_layout = QHBoxLayout()
118
- epochs_layout.addWidget(QLabel('# epochs: '), 30)
119
- self.epochs_slider = QLabeledSlider()
120
- self.epochs_slider.setRange(1,300)
121
- self.epochs_slider.setSingleStep(1)
122
- self.epochs_slider.setTickInterval(1)
123
- self.epochs_slider.setOrientation(Qt.Horizontal)
124
- self.epochs_slider.setValue(100)
125
- epochs_layout.addWidget(self.epochs_slider, 70)
126
- layout.addLayout(epochs_layout)
127
-
128
- self.stardist_model.clicked.connect(self.rescale_slider)
129
- self.cellpose_model.clicked.connect(self.rescale_slider)
130
-
131
- def populate_data_frame(self):
132
-
133
- """
134
- Add widgets and layout in the POST-PROCESSING frame.
135
- """
136
-
137
- grid = QGridLayout(self.data_frame)
138
- grid.setContentsMargins(30,30,30,30)
139
- grid.setSpacing(30)
140
-
141
- self.data_lbl = QLabel("DATA")
142
- self.data_lbl.setStyleSheet("""
137
+ """
138
+ )
139
+ grid.addWidget(self.hyper_lbl, 0, 0, 1, 4, alignment=Qt.AlignCenter)
140
+ self.generate_hyper_contents()
141
+ grid.addWidget(self.ContentsHyper, 1, 0, 1, 4, alignment=Qt.AlignTop)
142
+
143
+ def generate_hyper_contents(self):
144
+
145
+ self.ContentsHyper = QFrame()
146
+ layout = QVBoxLayout(self.ContentsHyper)
147
+ layout.setContentsMargins(0, 0, 0, 0)
148
+
149
+ lr_layout = QHBoxLayout()
150
+ lr_layout.addWidget(QLabel("learning rate: "), 30)
151
+ self.lr_le = QLineEdit("0,0003")
152
+ self.lr_le.setValidator(self._floatValidator)
153
+ lr_layout.addWidget(self.lr_le, 70)
154
+ layout.addLayout(lr_layout)
155
+
156
+ bs_layout = QHBoxLayout()
157
+ bs_layout.addWidget(QLabel("batch size: "), 30)
158
+ self.bs_le = QLineEdit("8")
159
+ self.bs_le.setValidator(self._intValidator)
160
+ bs_layout.addWidget(self.bs_le, 70)
161
+ layout.addLayout(bs_layout)
162
+
163
+ epochs_layout = QHBoxLayout()
164
+ epochs_layout.addWidget(QLabel("# epochs: "), 30)
165
+ self.epochs_slider = QLabeledSlider()
166
+ self.epochs_slider.setRange(1, 300)
167
+ self.epochs_slider.setSingleStep(1)
168
+ self.epochs_slider.setTickInterval(1)
169
+ self.epochs_slider.setOrientation(Qt.Horizontal)
170
+ self.epochs_slider.setValue(100)
171
+ epochs_layout.addWidget(self.epochs_slider, 70)
172
+ layout.addLayout(epochs_layout)
173
+
174
+ self.stardist_model.clicked.connect(self.rescale_slider)
175
+ self.cellpose_model.clicked.connect(self.rescale_slider)
176
+
177
+ def populate_data_frame(self):
178
+ """
179
+ Add widgets and layout in the POST-PROCESSING frame.
180
+ """
181
+
182
+ grid = QGridLayout(self.data_frame)
183
+ grid.setContentsMargins(30, 30, 30, 30)
184
+ grid.setSpacing(30)
185
+
186
+ self.data_lbl = QLabel("DATA")
187
+ self.data_lbl.setStyleSheet(
188
+ """
143
189
  font-weight: bold;
144
190
  padding: 0px;
145
- """)
146
- grid.addWidget(self.data_lbl, 0, 0, 1, 4, alignment=Qt.AlignCenter)
147
- self.generate_data_contents()
148
- grid.addWidget(self.ContentsData, 1, 0, 1, 4, alignment=Qt.AlignTop)
149
-
150
- def populate_model_frame(self):
151
-
152
- """
153
- Add widgets and layout in the FEATURES frame.
154
- """
155
-
156
- grid = QGridLayout(self.model_frame)
157
- grid.setContentsMargins(30,30,30,30)
158
- grid.setSpacing(30)
159
-
160
- self.model_lbl = QLabel("MODEL")
161
- self.model_lbl.setStyleSheet("""
191
+ """
192
+ )
193
+ grid.addWidget(self.data_lbl, 0, 0, 1, 4, alignment=Qt.AlignCenter)
194
+ self.generate_data_contents()
195
+ grid.addWidget(self.ContentsData, 1, 0, 1, 4, alignment=Qt.AlignTop)
196
+
197
+ def populate_model_frame(self):
198
+ """
199
+ Add widgets and layout in the FEATURES frame.
200
+ """
201
+
202
+ grid = QGridLayout(self.model_frame)
203
+ grid.setContentsMargins(30, 30, 30, 30)
204
+ grid.setSpacing(30)
205
+
206
+ self.model_lbl = QLabel("MODEL")
207
+ self.model_lbl.setStyleSheet(
208
+ """
162
209
  font-weight: bold;
163
210
  padding: 0px;
164
- """)
165
- grid.addWidget(self.model_lbl, 0, 0, 1, 4, alignment=Qt.AlignCenter)
166
-
167
- self.generate_model_panel_contents()
168
- grid.addWidget(self.ContentsModel, 1, 0, 1, 4, alignment=Qt.AlignTop)
169
-
170
-
171
- def generate_data_contents(self):
172
-
173
- self.ContentsData = QFrame()
174
- layout = QVBoxLayout(self.ContentsData)
175
- layout.setContentsMargins(0,0,0,0)
176
-
177
- train_data_layout = QHBoxLayout()
178
- train_data_layout.addWidget(QLabel('Training data: '), 30)
179
- self.select_data_folder_btn = QPushButton('Choose folder')
180
- self.select_data_folder_btn.clicked.connect(self.showDialog_dataset)
181
- self.data_folder_label = QLabel('No folder chosen')
182
- train_data_layout.addWidget(self.select_data_folder_btn, 35)
183
- train_data_layout.addWidget(self.data_folder_label, 30)
184
-
185
- self.cancel_dataset = QPushButton()
186
- self.cancel_dataset.setIcon(icon(MDI6.close,color="black"))
187
- self.cancel_dataset.clicked.connect(self.clear_dataset)
188
- self.cancel_dataset.setStyleSheet(self.button_select_all)
189
- self.cancel_dataset.setIconSize(QSize(20, 20))
190
- self.cancel_dataset.setVisible(False)
191
- train_data_layout.addWidget(self.cancel_dataset, 5)
192
-
193
-
194
- layout.addLayout(train_data_layout)
195
-
196
- include_dataset_layout = QHBoxLayout()
197
- include_dataset_layout.addWidget(QLabel('include dataset: '),30)
198
- self.dataset_cb = QComboBox()
199
- available_datasets, self.datasets_path = get_segmentation_datasets_list(return_path=True)
200
- signal_datasets = ['--'] + available_datasets #[d.split('/')[-2] for d in available_datasets]
201
- self.dataset_cb.addItems(signal_datasets)
202
- include_dataset_layout.addWidget(self.dataset_cb, 70)
203
- layout.addLayout(include_dataset_layout)
204
-
205
- augmentation_hbox = QHBoxLayout()
206
- augmentation_hbox.addWidget(QLabel('augmentation\nfactor: '), 30)
207
- self.augmentation_slider = QLabeledDoubleSlider()
208
- self.augmentation_slider.setSingleStep(0.01)
209
- self.augmentation_slider.setTickInterval(0.01)
210
- self.augmentation_slider.setOrientation(Qt.Horizontal)
211
- self.augmentation_slider.setRange(0.01, 3)
212
- self.augmentation_slider.setValue(2.0)
213
-
214
- augmentation_hbox.addWidget(self.augmentation_slider, 70)
215
- layout.addLayout(augmentation_hbox)
216
-
217
- validation_split_layout = QHBoxLayout()
218
- validation_split_layout.addWidget(QLabel('validation split: '),30)
219
- self.validation_slider = QLabeledDoubleSlider()
220
- self.validation_slider.setSingleStep(0.01)
221
- self.validation_slider.setTickInterval(0.01)
222
- self.validation_slider.setOrientation(Qt.Horizontal)
223
- self.validation_slider.setRange(0,0.9)
224
- self.validation_slider.setValue(0.2)
225
- validation_split_layout.addWidget(self.validation_slider, 70)
226
- layout.addLayout(validation_split_layout)
227
-
228
-
229
- def generate_model_panel_contents(self):
230
-
231
- self.ContentsModel = QFrame()
232
- layout = QVBoxLayout(self.ContentsModel)
233
- layout.setContentsMargins(0,0,0,0)
234
-
235
- model_type_layout = QHBoxLayout()
236
- model_type_layout.setContentsMargins(30,5,30,15)
237
- self.cellpose_model = QRadioButton('Cellpose')
238
- self.stardist_model = QRadioButton('StarDist')
239
- self.stardist_model.setChecked(True)
240
- model_type_layout.addWidget(self.stardist_model,50, alignment=Qt.AlignCenter)
241
- model_type_layout.addWidget(self.cellpose_model,50, alignment=Qt.AlignCenter)
242
- layout.addLayout(model_type_layout)
243
-
244
- modelname_layout = QHBoxLayout()
245
- modelname_layout.addWidget(QLabel('Model name: '), 30)
246
- self.modelname_le = QLineEdit()
247
- self.modelname_le.setText(f"Untitled_model_{datetime.today().strftime('%Y-%m-%d')}")
248
- self.modelname_le.textChanged.connect(self.activate_train_btn)
249
- modelname_layout.addWidget(self.modelname_le, 70)
250
- layout.addLayout(modelname_layout)
251
-
252
- pretrained_layout = QHBoxLayout()
253
- pretrained_layout.setContentsMargins(0,0,0,0)
254
- pretrained_layout.addWidget(QLabel('Pretrained model: '), 30)
255
-
256
- self.browse_pretrained_btn = QPushButton('Choose folder')
257
- self.browse_pretrained_btn.clicked.connect(self.showDialog_pretrained)
258
- pretrained_layout.addWidget(self.browse_pretrained_btn, 35)
259
-
260
- self.pretrained_lbl = QLabel('No folder chosen')
261
- pretrained_layout.addWidget(self.pretrained_lbl, 30)
262
-
263
- self.cancel_pretrained = QPushButton()
264
- self.cancel_pretrained.setIcon(icon(MDI6.close,color="black"))
265
- self.cancel_pretrained.clicked.connect(self.clear_pretrained)
266
- self.cancel_pretrained.setStyleSheet(self.button_select_all)
267
- self.cancel_pretrained.setIconSize(QSize(20, 20))
268
- self.cancel_pretrained.setVisible(False)
269
- pretrained_layout.addWidget(self.cancel_pretrained, 5)
270
-
271
- layout.addLayout(pretrained_layout)
272
-
273
- # recompile_layout = QHBoxLayout()
274
- # recompile_layout.addWidget(QLabel('Recompile: '), 30)
275
- # self.recompile_option = QCheckBox()
276
- # self.recompile_option.setEnabled(False)
277
- # recompile_layout.addWidget(self.recompile_option, 70)
278
- # layout.addLayout(recompile_layout)
279
-
280
- self.max_nbr_channels = 5
281
- self.ch_norm = ChannelNormGenerator(self, mode='channels')
282
- layout.addLayout(self.ch_norm)
283
-
284
- spatial_calib_layout = QHBoxLayout()
285
- spatial_calib_layout.addWidget(QLabel('input spatial\ncalibration'), 30)
286
- parent_pxtoum = f"{self.parent_window.parent_window.PxToUm}"
287
- self.spatial_calib_le = QLineEdit(parent_pxtoum.replace('.',','))
288
- self.spatial_calib_le.setPlaceholderText('e.g. 0.1 µm per pixel')
289
- self.spatial_calib_le.setValidator(self._floatValidator)
290
- spatial_calib_layout.addWidget(self.spatial_calib_le, 70)
291
- layout.addLayout(spatial_calib_layout)
292
-
293
- def activate_train_btn(self):
294
-
295
- current_name = self.modelname_le.text()
296
- models = get_segmentation_models_list(mode=self.mode, return_path=False)
297
- if not current_name in models and not self.spatial_calib_le.text()=='' and not np.all([cb.currentText()=='--' for cb in self.ch_norm.channel_cbs]):
298
- self.submit_btn.setEnabled(True)
299
- self.submit_warning.setText('')
300
- else:
301
- self.submit_btn.setEnabled(False)
302
- if current_name in models:
303
- self.submit_warning.setText('A model with this name already exists... Please pick another.')
304
- elif self.spatial_calib_le.text()=='':
305
- self.submit_warning.setText('Please provide a valid spatial calibration...')
306
- elif np.all([cb.currentText()=='--' for cb in self.ch_norm.channel_cbs]):
307
- self.submit_warning.setText('Please provide valid channels...')
308
-
309
- def rescale_slider(self):
310
- if self.stardist_model.isChecked():
311
- self.epochs_slider.setRange(1,500)
312
- self.lr_le.setText('0,0003')
313
- else:
314
- self.epochs_slider.setRange(1,10000)
315
- self.lr_le.setText('0,01')
316
-
317
-
318
- def showDialog_pretrained(self):
319
-
320
- self.clear_pretrained()
321
- self.pretrained_model = None
322
- self.pretrained_model = QFileDialog.getExistingDirectory(
323
- self, "Open Directory",
324
- os.sep.join([self._software_path, 'celldetective', 'models', f'segmentation_generic','']),
325
- QFileDialog.ShowDirsOnly | QFileDialog.DontResolveSymlinks,
326
- )
327
-
328
- if self.pretrained_model=='':
329
- return None
330
-
331
- if self.pretrained_model is None:
332
- return None
333
-
334
- else:
335
- self.pretrained_model = self.pretrained_model.replace('\\','/')
336
- self.pretrained_model = rf"{self.pretrained_model}"
337
-
338
- subfiles = glob(os.sep.join([self.pretrained_model,"*"]))
339
- subfiles = [s.replace('\\','/') for s in subfiles]
340
- subfiles = [rf"{s}" for s in subfiles]
341
-
342
- if "/".join([self.pretrained_model,"config_input.json"]) in subfiles:
343
- self.load_pretrained_config()
344
- self.pretrained_lbl.setText(self.pretrained_model.split("/")[-1])
345
- self.cancel_pretrained.setVisible(True)
346
- #self.recompile_option.setEnabled(True)
347
- self.modelname_le.setText(f"{self.pretrained_model.split('/')[-1]}_{datetime.today().strftime('%Y-%m-%d')}")
348
- else:
349
- self.pretrained_model = None
350
- self.pretrained_lbl.setText('No folder chosen')
351
- #self.recompile_option.setEnabled(False)
352
- self.cancel_pretrained.setVisible(False)
353
- return None
354
-
355
- self.seg_folder = self.pretrained_model.split('/')[-2]
356
- self.model_name = self.pretrained_model.split('/')[-1]
357
- if self.model_name.startswith('CP') and self.seg_folder=='segmentation_generic':
358
-
359
- self.diamWidget = CellposeParamsWidget(self, model_name=self.model_name)
360
- self.diamWidget.show()
361
-
362
- def set_cellpose_scale(self):
363
-
364
- scale = self.parent_window.parent_window.PxToUm * float(self.diamWidget.diameter_le.text().replace(',','.')) / 30.0
365
- if self.model_name=="CP_nuclei":
366
- scale = self.parent_window.parent_window.PxToUm * float(self.diamWidget.diameter_le.text().replace(',','.')) / 17.0
367
- self.spatial_calib_le.setText(str(scale).replace('.',','))
368
-
369
- for k in range(len(self.diamWidget.cellpose_channel_cb)):
370
- ch = self.diamWidget.cellpose_channel_cb[k].currentText()
371
- idx = self.ch_norm.channel_cbs[k].findText(ch)
372
- self.ch_norm.channel_cbs[k].setCurrentIndex(idx)
373
-
374
- self.diamWidget.close()
375
-
376
-
377
- def showDialog_dataset(self):
378
-
379
- self.dataset_folder = QFileDialog.getExistingDirectory(
380
- self, "Open Directory",
381
- self.exp_dir,
382
- QFileDialog.ShowDirsOnly | QFileDialog.DontResolveSymlinks,
383
- )
384
- if self.dataset_folder is not None:
385
-
386
- subfiles = glob(self.dataset_folder+os.sep+"*.tif")
387
- if len(subfiles)>0:
388
- print(f'found {len(subfiles)} files in folder')
389
- self.data_folder_label.setText(self.dataset_folder[:16]+'...')
390
- self.data_folder_label.setToolTip(self.dataset_folder)
391
- self.cancel_dataset.setVisible(True)
392
- else:
393
- self.data_folder_label.setText('No folder chosen')
394
- self.data_folder_label.setToolTip('')
395
- self.dataset_folder = None
396
- self.cancel_dataset.setVisible(False)
397
-
398
- def clear_pretrained(self):
399
-
400
- self.pretrained_model = None
401
- self.pretrained_lbl.setText('No folder chosen')
402
- for i in range(len(self.ch_norm.channel_cbs)):
403
- self.ch_norm.channel_cbs[i].setEnabled(True)
404
- self.ch_norm.normalization_mode_btns[i].setEnabled(True)
405
- self.ch_norm.normalization_max_value_le[i].setEnabled(True)
406
- self.ch_norm.normalization_min_value_le[i].setEnabled(True)
407
- self.ch_norm.normalization_clip_btns[i].setEnabled(True)
408
- self.ch_norm.normalization_min_value_lbl[i].setEnabled(True)
409
- self.ch_norm.normalization_max_value_lbl[i].setEnabled(True)
410
- self.ch_norm.add_col_btn.setEnabled(True)
411
-
412
- self.cancel_pretrained.setVisible(False)
413
- self.modelname_le.setText(f"Untitled_model_{datetime.today().strftime('%Y-%m-%d')}")
414
-
415
- def clear_dataset(self):
416
-
417
- self.dataset_folder = None
418
- self.data_folder_label.setText('No folder chosen')
419
- self.data_folder_label.setToolTip('')
420
- self.cancel_dataset.setVisible(False)
421
-
422
def load_stardist_train_config(self):
    """Prefill batch size and learning rate from the pretrained model's ``config.json``.

    Does nothing when the file is absent. Each value found is written to its
    line edit as text, with the decimal point replaced by a comma.
    """
    config_path = os.sep.join([self.pretrained_model, "config.json"])
    if not os.path.exists(config_path):
        return

    with open(config_path, 'r') as handle:
        settings = json.load(handle)

    # (key in config.json, name of the line edit that displays it)
    for key, field in (("train_batch_size", "bs_le"), ("train_learning_rate", "lr_le")):
        if key in settings:
            getattr(self, field).setText(str(settings[key]).replace('.', ','))
434
-
435
def load_pretrained_config(self):
    """Populate the form from the pretrained model's ``config_input.json``.

    Reads channels, normalization settings, spatial calibration and model
    type from the file, stores ``self.seg_folder`` / ``self.model_name``
    (the last two components of ``self.pretrained_model`` split on "/"),
    selects the matching radio button, fills the channel slots, and disables
    the slots the pretrained model does not use.

    Raises:
        OSError: if ``config_input.json`` cannot be opened.
        KeyError: if an expected key is missing from the file.
    """
    config_path = os.sep.join([self.pretrained_model, "config_input.json"])
    # Context manager: the file handle used to be left open.
    with open(config_path) as f:
        data = json.load(f)

    channels = data["channels"]
    path_parts = self.pretrained_model.split('/')
    self.seg_folder = path_parts[-2]
    self.model_name = path_parts[-1]

    # Models in the generic folder get a fixed channel list instead of the
    # one stored in the file.
    if self.seg_folder == 'segmentation_generic':
        if self.model_name.startswith('CP'):
            channels = ['brightfield_channel', 'live_nuclei_channel']
            if self.model_name == "CP_nuclei":
                channels = ['live_nuclei_channel', 'None']
        if self.model_name.startswith('SD'):
            channels = ['live_nuclei_channel']
            if self.model_name == "SD_versatile_he":
                channels = ["H&E_1", "H&E_2", "H&E_3"]

    normalization_percentile = data['normalization_percentile']
    normalization_clip = data['normalization_clip']
    normalization_values = data['normalization_values']
    spatial_calib = data['spatial_calibration']
    model_type = data['model_type']

    is_stardist = model_type == 'stardist'
    self.stardist_model.setChecked(is_stardist)
    self.cellpose_model.setChecked(not is_stardist)
    if is_stardist:
        self.load_stardist_train_config()

    ch_norm = self.ch_norm
    for c, cb in zip(channels, ch_norm.channel_cbs):
        cb.setCurrentIndex(cb.findText(c))

    for i in range(len(channels)):
        # The toggles are only reachable through a click, so click when the
        # current state differs from the stored one.
        if ch_norm.clip_option[i] != normalization_clip[i]:
            ch_norm.normalization_clip_btns[i].click()
        if ch_norm.normalization_mode[i] != normalization_percentile[i]:
            ch_norm.normalization_mode_btns[i].click()

        ch_norm.normalization_min_value_le[i].setText(str(normalization_values[i][0]))
        ch_norm.normalization_max_value_le[i].setText(str(normalization_values[i][1]))

    n_slots = len(ch_norm.channel_cbs)
    if len(channels) < n_slots:
        # Lock the slots beyond what the pretrained model takes as input.
        for slot in range(len(channels), n_slots):
            ch_norm.channel_cbs[slot].setCurrentIndex(0)
            ch_norm.channel_cbs[slot].setEnabled(False)
            ch_norm.normalization_mode_btns[slot].setEnabled(False)
            ch_norm.normalization_max_value_le[slot].setEnabled(False)
            ch_norm.normalization_min_value_le[slot].setEnabled(False)
            ch_norm.normalization_min_value_lbl[slot].setEnabled(False)
            ch_norm.normalization_max_value_lbl[slot].setEnabled(False)
            ch_norm.normalization_clip_btns[slot].setEnabled(False)
        ch_norm.add_col_btn.setEnabled(False)

    self.spatial_calib_le.setText(str(spatial_calib).replace('.', ','))
495
-
496
def _write_instructions(self):
    """Save the form as ``training_instructions.json`` and train the model.

    Collects the selected channels (slots left on "--" are ignored) with
    their normalization settings, the data folders and the hyper-parameters,
    writes them into a folder named after the model under
    ``self.software_models_dir``, calls ``train_segmentation_model`` on that
    file, then refreshes the parent's model list and selects the new model.

    Returns ``None``; returns early after showing a message when a numeric
    field cannot be parsed.
    """

    def _to_float(text):
        # Fields display decimals with a comma; float() needs a point.
        return float(text.replace(',', '.'))

    model_name = self.modelname_le.text()
    pretrained_model = self.pretrained_model
    ch_norm = self.ch_norm

    selections = [cb.currentText() for cb in ch_norm.channel_cbs]
    slots_to_keep = [i for i, name in enumerate(selections) if name != '--']
    channels = [selections[i] for i in slots_to_keep]

    # Parse only the slots in use, so stray text in an unused slot cannot
    # abort the export.
    try:
        norm_values = [
            [
                _to_float(ch_norm.normalization_min_value_le[i].text()),
                _to_float(ch_norm.normalization_max_value_le[i].text()),
            ]
            for i in slots_to_keep
        ]
    except ValueError:
        generic_message('Invalid value encountered for the normalization range.')
        return None

    clip_values = [bool(ch_norm.clip_option[i]) for i in slots_to_keep]
    normalization_mode = [bool(ch_norm.normalization_mode[i]) for i in slots_to_keep]

    data_folders = []
    if self.dataset_folder is not None:
        data_folders.append(self.dataset_folder)
    if self.dataset_cb.currentText() != '--':
        data_folders.append(locate_segmentation_dataset(self.dataset_cb.currentText()))

    aug_factor = round(self.augmentation_slider.value(), 2)
    val_split = round(self.validation_slider.value(), 2)
    model_type = 'stardist' if self.stardist_model.isChecked() else 'cellpose'

    try:
        lr = _to_float(self.lr_le.text())
    except ValueError:
        generic_message('Invalid value encountered for the learning rate.')
        return None

    try:
        bs = int(self.bs_le.text())
    except ValueError:
        generic_message('Invalid value encountered for the batch size.')
        return None

    epochs = self.epochs_slider.value()

    try:
        spatial_calib = _to_float(self.spatial_calib_le.text())
    except ValueError:
        generic_message('Invalid value encountered for the spatial calibration.')
        return None

    self.training_instructions = {
        'model_name': model_name,
        'model_type': model_type,
        'pretrained': pretrained_model,
        'spatial_calibration': spatial_calib,
        'channel_option': channels,
        'normalization_percentile': normalization_mode,
        'normalization_clip': clip_values,
        'normalization_values': norm_values,
        'ds': data_folders,
        'augmentation_factor': aug_factor,
        'validation_split': val_split,
        'learning_rate': lr,
        'batch_size': bs,
        'epochs': epochs,
    }

    # The trailing '' makes the joined path end with a separator.
    model_folder = os.sep.join([self.software_models_dir, model_name, ''])
    print(model_folder)
    os.makedirs(model_folder, exist_ok=True)

    self.training_instructions.update({'target_directory': self.software_models_dir})
    print(f"Set of instructions: {self.training_instructions}")

    self.instructions = model_folder + "training_instructions.json"
    with open(self.instructions, 'w') as f:
        json.dump(self.training_instructions, f, indent=4)

    train_segmentation_model(
        self.instructions,
        use_gpu=self.parent_window.parent_window.parent_window.use_gpu,
    )

    self.parent_window.init_seg_model_list()
    idx = self.parent_window.seg_model_list.findText(model_name)
    self.parent_window.seg_model_list.setCurrentIndex(idx)
578
-
579
- def _load_previous_instructions(self):
580
- pass
211
+ """
212
+ )
213
+ grid.addWidget(self.model_lbl, 0, 0, 1, 4, alignment=Qt.AlignCenter)
214
+
215
+ self.generate_model_panel_contents()
216
+ grid.addWidget(self.ContentsModel, 1, 0, 1, 4, alignment=Qt.AlignTop)
217
+
218
def generate_data_contents(self):
    """Build ``self.ContentsData``, the frame with the training-data controls.

    Rows, top to bottom: folder chooser (with a hidden cancel button), a
    combo box of bundled datasets, the augmentation-factor slider and the
    validation-split slider.
    """

    def captioned_row(caption):
        # Every row starts with a caption taking a stretch factor of 30.
        row = QHBoxLayout()
        row.addWidget(QLabel(caption), 30)
        return row

    def ratio_slider(lower, upper, start):
        # Horizontal double slider with a 0.01 step.
        slider = QLabeledDoubleSlider()
        slider.setSingleStep(0.01)
        slider.setTickInterval(0.01)
        slider.setOrientation(Qt.Horizontal)
        slider.setRange(lower, upper)
        slider.setValue(start)
        return slider

    self.ContentsData = QFrame()
    panel = QVBoxLayout(self.ContentsData)
    panel.setContentsMargins(0, 0, 0, 0)

    # --- training data folder ---
    folder_row = captioned_row("Training data: ")
    self.select_data_folder_btn = QPushButton("Choose folder")
    self.select_data_folder_btn.clicked.connect(self.showDialog_dataset)
    self.data_folder_label = QLabel("No folder chosen")
    folder_row.addWidget(self.select_data_folder_btn, 35)
    folder_row.addWidget(self.data_folder_label, 30)

    cancel_btn = QPushButton()
    cancel_btn.setIcon(icon(MDI6.close, color="black"))
    cancel_btn.clicked.connect(self.clear_dataset)
    cancel_btn.setStyleSheet(self.button_select_all)
    cancel_btn.setIconSize(QSize(20, 20))
    cancel_btn.setVisible(False)  # shown once a folder has been picked
    self.cancel_dataset = cancel_btn
    folder_row.addWidget(cancel_btn, 5)
    panel.addLayout(folder_row)

    # --- bundled dataset ---
    dataset_row = captioned_row("include dataset: ")
    self.dataset_cb = QComboBox()
    available_datasets, self.datasets_path = get_segmentation_datasets_list(
        return_path=True
    )
    # "--" is the "no bundled dataset" entry.
    self.dataset_cb.addItems(["--"] + available_datasets)
    dataset_row.addWidget(self.dataset_cb, 70)
    panel.addLayout(dataset_row)

    # --- augmentation factor ---
    augmentation_row = captioned_row("augmentation\nfactor: ")
    self.augmentation_slider = ratio_slider(0.01, 3, 2.0)
    augmentation_row.addWidget(self.augmentation_slider, 70)
    panel.addLayout(augmentation_row)

    # --- validation split ---
    validation_row = captioned_row("validation split: ")
    self.validation_slider = ratio_slider(0, 0.9, 0.2)
    validation_row.addWidget(self.validation_slider, 70)
    panel.addLayout(validation_row)
277
+
278
def generate_model_panel_contents(self):
    """Build ``self.ContentsModel``, the frame with the model controls.

    Rows, top to bottom: StarDist/Cellpose radio buttons, model name,
    pretrained-model chooser (with a hidden cancel button), the channel
    normalization generator and the input spatial calibration.
    """
    self.ContentsModel = QFrame()
    panel = QVBoxLayout(self.ContentsModel)
    panel.setContentsMargins(0, 0, 0, 0)

    # --- model type ---
    type_row = QHBoxLayout()
    type_row.setContentsMargins(30, 5, 30, 15)
    self.cellpose_model = QRadioButton("Cellpose")
    self.stardist_model = QRadioButton("StarDist")
    self.stardist_model.setChecked(True)
    for radio in (self.stardist_model, self.cellpose_model):
        type_row.addWidget(radio, 50, alignment=Qt.AlignCenter)
    panel.addLayout(type_row)

    # --- model name ---
    name_row = QHBoxLayout()
    name_row.addWidget(QLabel("Model name: "), 30)
    self.modelname_le = QLineEdit()
    default_name = f"Untitled_model_{datetime.today().strftime('%Y-%m-%d')}"
    self.modelname_le.setText(default_name)
    # Connected after the initial setText so the validation slot does not
    # run while the widgets it reads are still being created.
    self.modelname_le.textChanged.connect(self.activate_train_btn)
    name_row.addWidget(self.modelname_le, 70)
    panel.addLayout(name_row)

    # --- pretrained model ---
    pretrained_row = QHBoxLayout()
    pretrained_row.setContentsMargins(0, 0, 0, 0)
    pretrained_row.addWidget(QLabel("Pretrained model: "), 30)

    self.browse_pretrained_btn = QPushButton("Choose folder")
    self.browse_pretrained_btn.clicked.connect(self.showDialog_pretrained)
    pretrained_row.addWidget(self.browse_pretrained_btn, 35)

    self.pretrained_lbl = QLabel("No folder chosen")
    pretrained_row.addWidget(self.pretrained_lbl, 30)

    cancel_btn = QPushButton()
    cancel_btn.setIcon(icon(MDI6.close, color="black"))
    cancel_btn.clicked.connect(self.clear_pretrained)
    cancel_btn.setStyleSheet(self.button_select_all)
    cancel_btn.setIconSize(QSize(20, 20))
    cancel_btn.setVisible(False)  # shown once a pretrained model is loaded
    self.cancel_pretrained = cancel_btn
    pretrained_row.addWidget(cancel_btn, 5)
    panel.addLayout(pretrained_row)

    # --- channels and normalization ---
    # max_nbr_channels is set before the generator is built with this
    # widget as its first argument.
    self.max_nbr_channels = 5
    self.ch_norm = ChannelNormGenerator(self, mode="channels")
    panel.addLayout(self.ch_norm)

    # --- spatial calibration ---
    calib_row = QHBoxLayout()
    calib_row.addWidget(QLabel("input spatial\ncalibration"), 30)
    parent_pxtoum = f"{self.parent_window.parent_window.PxToUm}"
    # Decimals are shown with a comma in this form.
    self.spatial_calib_le = QLineEdit(parent_pxtoum.replace(".", ","))
    self.spatial_calib_le.setPlaceholderText("e.g. 0.1 µm per pixel")
    self.spatial_calib_le.setValidator(self._floatValidator)
    calib_row.addWidget(self.spatial_calib_le, 70)
    panel.addLayout(calib_row)
343
+
344
def activate_train_btn(self):
    """Enable the submit button only when the form describes a trainable model.

    Submission is allowed when the model name is non-blank and not already
    in the list returned by ``get_segmentation_models_list``, the spatial
    calibration field is non-empty, and at least one channel slot is not
    "--". Otherwise the button is disabled and ``self.submit_warning``
    shows the first failing condition.
    """
    current_name = self.modelname_le.text()
    models = get_segmentation_models_list(mode=self.mode, return_path=False)

    # Checks run lazily, in order; the first failure sets the warning.
    if not current_name.strip():
        # A blank name would make the model folder resolve to the models
        # directory itself.
        warning = "Please provide a model name..."
    elif current_name in models:
        warning = "A model with this name already exists... Please pick another."
    elif self.spatial_calib_le.text() == "":
        warning = "Please provide a valid spatial calibration..."
    elif all(cb.currentText() == "--" for cb in self.ch_norm.channel_cbs):
        warning = "Please provide valid channels..."
    else:
        warning = ""

    self.submit_btn.setEnabled(warning == "")
    self.submit_warning.setText(warning)
369
+
370
def rescale_slider(self):
    """Set the epoch range and learning-rate text for the selected model type.

    StarDist: epochs 1-500, learning rate "0,0003".
    Otherwise: epochs 1-10000, learning rate "0,01".
    """
    if self.stardist_model.isChecked():
        max_epochs, default_lr = 500, "0,0003"
    else:
        max_epochs, default_lr = 10000, "0,01"
    self.epochs_slider.setRange(1, max_epochs)
    self.lr_le.setText(default_lr)
377
+
378
def showDialog_pretrained(self):
    """Let the user pick a pretrained model folder and load its settings.

    The current selection is cleared first. A folder is accepted only when
    it contains ``config_input.json``; its settings are then loaded, the
    label and default model name are updated, and for "CP*" models located
    in ``segmentation_generic`` a ``CellposeParamsWidget`` is shown.
    Cancelling the dialog or choosing an invalid folder leaves
    ``self.pretrained_model`` set to ``None``.
    """
    self.clear_pretrained()

    start_dir = os.sep.join(
        [self._software_path, "celldetective", "models", "segmentation_generic", ""]
    )
    chosen = QFileDialog.getExistingDirectory(
        self,
        "Open Directory",
        start_dir,
        QFileDialog.ShowDirsOnly | QFileDialog.DontResolveSymlinks,
    )

    if not chosen:
        # The dialog yields an empty string on cancel; keep the attribute
        # at None so the exported "pretrained" entry is not "".
        self.pretrained_model = None
        return None

    # Forward slashes: the rest of the class splits this path on "/".
    self.pretrained_model = chosen.replace("\\", "/")

    # Direct file test instead of globbing the folder, so paths containing
    # glob metacharacters ("[", "*", "?") are handled correctly.
    if not os.path.isfile(os.path.join(self.pretrained_model, "config_input.json")):
        self.pretrained_model = None
        self.pretrained_lbl.setText("No folder chosen")
        self.cancel_pretrained.setVisible(False)
        return None

    self.load_pretrained_config()
    folder_name = self.pretrained_model.split("/")[-1]
    self.pretrained_lbl.setText(folder_name)
    self.cancel_pretrained.setVisible(True)
    self.modelname_le.setText(
        f"{folder_name}_{datetime.today().strftime('%Y-%m-%d')}"
    )

    self.seg_folder = self.pretrained_model.split("/")[-2]
    self.model_name = folder_name
    if (
        self.model_name.startswith("CP")
        and self.seg_folder == "segmentation_generic"
    ):
        # Imported here, only on the branch that needs it.
        from celldetective.gui.settings._cellpose_model_params import (
            CellposeParamsWidget,
        )

        self.diamWidget = CellposeParamsWidget(self, model_name=self.model_name)
        self.diamWidget.show()
438
+
439
def set_cellpose_scale(self):
    """Apply the values entered in ``self.diamWidget`` to the form.

    Writes ``PxToUm * diameter / reference`` to the spatial-calibration
    field, where the reference is 17.0 for "CP_nuclei" and 30.0 otherwise,
    copies the widget's channel choices into the channel slots, then closes
    the widget.
    """
    reference = 17.0 if self.model_name == "CP_nuclei" else 30.0
    diameter = float(self.diamWidget.diameter_le.text().replace(",", "."))
    scale = self.parent_window.parent_window.PxToUm * diameter / reference
    self.spatial_calib_le.setText(str(scale).replace(".", ","))

    for slot, source_cb in enumerate(self.diamWidget.cellpose_channel_cb):
        target_cb = self.ch_norm.channel_cbs[slot]
        target_cb.setCurrentIndex(target_cb.findText(source_cb.currentText()))

    self.diamWidget.close()
460
+
461
def showDialog_dataset(self):
    """Let the user pick a training-data folder; keep it only if it holds ``*.tif`` files.

    On success the folder is stored in ``self.dataset_folder``, shown
    (truncated to 16 characters) in the label with the full path as tooltip,
    and the cancel button becomes visible. If the dialog is cancelled or the
    folder has no ``*.tif`` file, the selection is reset to ``None``.
    """
    # Local import: only the glob() function is referenced at module level here.
    from glob import escape as escape_glob

    selected = QFileDialog.getExistingDirectory(
        self,
        "Open Directory",
        self.exp_dir,
        QFileDialog.ShowDirsOnly | QFileDialog.DontResolveSymlinks,
    )

    # The dialog returns "" on cancel; without this guard the pattern would
    # be "<sep>*.tif", i.e. a scan of the filesystem root. The folder is
    # escaped so glob metacharacters in its path are taken literally.
    if selected:
        subfiles = glob(escape_glob(selected) + os.sep + "*.tif")
    else:
        subfiles = []

    if len(subfiles) > 0:
        self.dataset_folder = selected
        print(f"found {len(subfiles)} files in folder")
        self.data_folder_label.setText(self.dataset_folder[:16] + "...")
        self.data_folder_label.setToolTip(self.dataset_folder)
        self.cancel_dataset.setVisible(True)
    else:
        self.data_folder_label.setText("No folder chosen")
        self.data_folder_label.setToolTip("")
        self.dataset_folder = None
        self.cancel_dataset.setVisible(False)
482
+
483
def clear_pretrained(self):
    """Drop the selected pretrained model and re-enable the channel inputs.

    Resets ``self.pretrained_model`` to ``None``, restores the placeholder
    label, enables every per-channel widget held by ``self.ch_norm``, hides
    the cancel button and puts a dated default name in the model-name field.
    """
    self.pretrained_model = None
    self.pretrained_lbl.setText("No folder chosen")

    norm = self.ch_norm
    for slot, _ in enumerate(norm.channel_cbs):
        # Same widget order as the per-slot sequence used when locking them.
        for widgets in (
            norm.channel_cbs,
            norm.normalization_mode_btns,
            norm.normalization_max_value_le,
            norm.normalization_min_value_le,
            norm.normalization_clip_btns,
            norm.normalization_min_value_lbl,
            norm.normalization_max_value_lbl,
        ):
            widgets[slot].setEnabled(True)
    norm.add_col_btn.setEnabled(True)

    self.cancel_pretrained.setVisible(False)
    today = datetime.today().strftime("%Y-%m-%d")
    self.modelname_le.setText(f"Untitled_model_{today}")
501
+
502
def clear_dataset(self):
    """Forget the chosen training-data folder and reset its label and cancel button."""
    folder_label = self.data_folder_label
    folder_label.setText("No folder chosen")
    folder_label.setToolTip("")
    self.cancel_dataset.setVisible(False)
    self.dataset_folder = None
508
+
509
def load_stardist_train_config(self):
    """Prefill batch size and learning rate from the pretrained model's ``config.json``.

    Does nothing when the file is absent. Each value found is written to its
    line edit as text, with the decimal point replaced by a comma.
    """
    config_path = os.sep.join([self.pretrained_model, "config.json"])
    if not os.path.exists(config_path):
        return

    with open(config_path, "r") as handle:
        settings = json.load(handle)

    # (key in config.json, name of the line edit that displays it)
    for key, field in (("train_batch_size", "bs_le"), ("train_learning_rate", "lr_le")):
        if key in settings:
            getattr(self, field).setText(str(settings[key]).replace(".", ","))
521
+
522
def load_pretrained_config(self):
    """Populate the form from the pretrained model's ``config_input.json``.

    Reads channels, normalization settings, spatial calibration and model
    type from the file, stores ``self.seg_folder`` / ``self.model_name``
    (the last two components of ``self.pretrained_model`` split on "/"),
    selects the matching radio button, fills the channel slots, and disables
    the slots the pretrained model does not use.

    Raises:
        OSError: if ``config_input.json`` cannot be opened.
        KeyError: if an expected key is missing from the file.
    """
    config_path = os.sep.join([self.pretrained_model, "config_input.json"])
    # Context manager: the file handle used to be left open.
    with open(config_path) as f:
        data = json.load(f)

    channels = data["channels"]
    path_parts = self.pretrained_model.split("/")
    self.seg_folder = path_parts[-2]
    self.model_name = path_parts[-1]

    # Models in the generic folder get a fixed channel list instead of the
    # one stored in the file.
    if self.seg_folder == "segmentation_generic":
        if self.model_name.startswith("CP"):
            channels = ["brightfield_channel", "live_nuclei_channel"]
            if self.model_name == "CP_nuclei":
                channels = ["live_nuclei_channel", "None"]
        if self.model_name.startswith("SD"):
            channels = ["live_nuclei_channel"]
            if self.model_name == "SD_versatile_he":
                channels = ["H&E_1", "H&E_2", "H&E_3"]

    normalization_percentile = data["normalization_percentile"]
    normalization_clip = data["normalization_clip"]
    normalization_values = data["normalization_values"]
    spatial_calib = data["spatial_calibration"]
    model_type = data["model_type"]

    is_stardist = model_type == "stardist"
    self.stardist_model.setChecked(is_stardist)
    self.cellpose_model.setChecked(not is_stardist)
    if is_stardist:
        self.load_stardist_train_config()

    ch_norm = self.ch_norm
    for c, cb in zip(channels, ch_norm.channel_cbs):
        cb.setCurrentIndex(cb.findText(c))

    for i in range(len(channels)):
        # The toggles are only reachable through a click, so click when the
        # current state differs from the stored one.
        if ch_norm.clip_option[i] != normalization_clip[i]:
            ch_norm.normalization_clip_btns[i].click()
        if ch_norm.normalization_mode[i] != normalization_percentile[i]:
            ch_norm.normalization_mode_btns[i].click()

        ch_norm.normalization_min_value_le[i].setText(str(normalization_values[i][0]))
        ch_norm.normalization_max_value_le[i].setText(str(normalization_values[i][1]))

    n_slots = len(ch_norm.channel_cbs)
    if len(channels) < n_slots:
        # Lock the slots beyond what the pretrained model takes as input.
        for slot in range(len(channels), n_slots):
            ch_norm.channel_cbs[slot].setCurrentIndex(0)
            ch_norm.channel_cbs[slot].setEnabled(False)
            ch_norm.normalization_mode_btns[slot].setEnabled(False)
            ch_norm.normalization_max_value_le[slot].setEnabled(False)
            ch_norm.normalization_min_value_le[slot].setEnabled(False)
            ch_norm.normalization_min_value_lbl[slot].setEnabled(False)
            ch_norm.normalization_max_value_lbl[slot].setEnabled(False)
            ch_norm.normalization_clip_btns[slot].setEnabled(False)
        ch_norm.add_col_btn.setEnabled(False)

    self.spatial_calib_le.setText(str(spatial_calib).replace(".", ","))
603
+
604
def _write_instructions(self):
    """Save the form as ``training_instructions.json`` and launch training.

    Collects the selected channels (slots left on "--" are ignored) with
    their normalization settings, the data folders and the hyper-parameters,
    writes them into a folder named after the model under
    ``self.software_models_dir``, then starts a ``Runner`` on the global
    thread pool and shows a ``DynamicProgressDialog`` until it is closed.

    Returns ``None``; returns early after showing a message when a numeric
    field cannot be parsed.
    """
    # Use the class already attached to the background loader when it has
    # finished; otherwise import it now.
    if self.bg_loader.isFinished() and hasattr(
        self.bg_loader, "TrainSegModelProcess"
    ):
        TrainSegModelProcess = self.bg_loader.TrainSegModelProcess
    else:
        # NOTE(review): a segmentation process class is imported from the
        # *signal* training module -- confirm this module path is intended.
        from celldetective.processes.train_signal_model import (
            TrainSegModelProcess,
        )

    def _to_float(text):
        # Fields display decimals with a comma; float() needs a point.
        return float(text.replace(",", "."))

    model_name = self.modelname_le.text()
    pretrained_model = self.pretrained_model
    ch_norm = self.ch_norm

    selections = [cb.currentText() for cb in ch_norm.channel_cbs]
    slots_to_keep = [i for i, name in enumerate(selections) if name != "--"]
    channels = [selections[i] for i in slots_to_keep]

    # Parse only the slots in use, so stray text in an unused slot cannot
    # abort the export.
    try:
        norm_values = [
            [
                _to_float(ch_norm.normalization_min_value_le[i].text()),
                _to_float(ch_norm.normalization_max_value_le[i].text()),
            ]
            for i in slots_to_keep
        ]
    except ValueError:
        generic_message("Invalid value encountered for the normalization range.")
        return None

    clip_values = [bool(ch_norm.clip_option[i]) for i in slots_to_keep]
    normalization_mode = [bool(ch_norm.normalization_mode[i]) for i in slots_to_keep]

    data_folders = []
    if self.dataset_folder is not None:
        data_folders.append(self.dataset_folder)
    if self.dataset_cb.currentText() != "--":
        data_folders.append(locate_segmentation_dataset(self.dataset_cb.currentText()))

    aug_factor = round(self.augmentation_slider.value(), 2)
    val_split = round(self.validation_slider.value(), 2)
    model_type = "stardist" if self.stardist_model.isChecked() else "cellpose"

    try:
        lr = _to_float(self.lr_le.text())
    except ValueError:
        generic_message("Invalid value encountered for the learning rate.")
        return None

    try:
        bs = int(self.bs_le.text())
    except ValueError:
        generic_message("Invalid value encountered for the batch size.")
        return None

    epochs = self.epochs_slider.value()

    try:
        spatial_calib = _to_float(self.spatial_calib_le.text())
    except ValueError:
        generic_message("Invalid value encountered for the spatial calibration.")
        return None

    self.training_instructions = {
        "model_name": model_name,
        "model_type": model_type,
        "pretrained": pretrained_model,
        "spatial_calibration": spatial_calib,
        "channel_option": channels,
        "normalization_percentile": normalization_mode,
        "normalization_clip": clip_values,
        "normalization_values": norm_values,
        "ds": data_folders,
        "augmentation_factor": aug_factor,
        "validation_split": val_split,
        "learning_rate": lr,
        "batch_size": bs,
        "epochs": epochs,
    }

    # The trailing "" makes the joined path end with a separator.
    model_folder = os.sep.join([self.software_models_dir, model_name, ""])
    print(model_folder)
    os.makedirs(model_folder, exist_ok=True)

    self.training_instructions.update(
        {"target_directory": self.software_models_dir}
    )
    print(f"Set of instructions: {self.training_instructions}")

    self.instructions = model_folder + "training_instructions.json"
    with open(self.instructions, "w") as f:
        json.dump(self.training_instructions, f, indent=4)

    # State read by the finished / error / cancel / interrupt handlers.
    self.stop_event = multiprocessing.Event()
    process_args = {
        "instructions": self.instructions,
        "stop_event": self.stop_event,
        "use_gpu": self.use_gpu,
    }
    self.training_was_cancelled = False
    self.is_finished = False

    self.progress_dialog = DynamicProgressDialog(
        label_text="Preparing model training...",
        max_epochs=epochs,
        parent=self,
        title="Training Segmentation Model",
    )

    self.runner = Runner(process=TrainSegModelProcess, process_args=process_args)

    # Runner -> dialog updates.
    signals = self.runner.signals
    signals.update_pos.connect(self.progress_dialog.update_progress)
    signals.update_plot.connect(self.progress_dialog.update_plot)
    signals.training_result.connect(self.progress_dialog.show_result)
    signals.update_status.connect(self.progress_dialog.update_status)

    # Runner -> this widget.
    signals.finished.connect(self.on_training_finished)
    signals.error.connect(self.on_training_error)

    # Dialog -> this widget.
    self.progress_dialog.canceled.connect(self.on_training_cancel)
    self.progress_dialog.interrupted.connect(self.on_training_interrupt)

    self.pool = QThreadPool.globalInstance()
    self.pool.start(self.runner)
    self.progress_dialog.exec_()
742
+
743
def on_training_finished(self):
    """Handle the runner's ``finished`` signal.

    Ignored after a cancel. Otherwise marks the run as finished, switches
    the still-open progress dialog to its final state, closes the runner,
    and refreshes the parent's model list with the new model selected.
    """
    if self.training_was_cancelled:
        return

    self.is_finished = True

    # The dialog stays open so the result can be inspected.
    dialog = self.progress_dialog
    dialog.status_label.setText("Training Finished.")
    dialog.cancel_btn.setText("Close")
    bar = dialog.progress_bar
    bar.setValue(bar.maximum())

    self.runner.close()

    owner = self.parent_window
    owner.init_seg_model_list()
    model_list = owner.seg_model_list
    model_list.setCurrentIndex(model_list.findText(self.modelname_le.text()))
760
+
761
def on_training_error(self, message):
    """Handle the runner's ``error`` signal: close the dialog and report the failure.

    Nothing is shown when the training was cancelled.
    """
    if not self.training_was_cancelled:
        self.progress_dialog.close()
        QMessageBox.critical(self, "Error", f"Training failed: {message}")
766
+
767
def on_training_cancel(self):
    """Handle the dialog's ``canceled`` signal.

    After a finished run this only closes the runner and the dialog.
    Otherwise the run is flagged as cancelled, the runner is closed and the
    model folder is deleted. Deletion is best-effort: failures are logged,
    never raised.
    """
    if self.is_finished:
        self.runner.close()
        self.progress_dialog.close()
        return

    self.training_was_cancelled = True
    self.runner.close()

    # Remove the partially written model folder.
    try:
        import shutil

        models_root = os.path.abspath(self.parent_window.seg_models_dir)
        model_path = os.path.abspath(
            os.path.join(models_root, self.modelname_le.text())
        )
        # Only delete something strictly inside the models directory: a
        # blank (or "."/"..") name would otherwise resolve to the models
        # directory itself, or its parent, and wipe every stored model.
        if not model_path.startswith(models_root + os.sep):
            logger.warning(
                f"Refusing to delete '{model_path}': not a model folder inside '{models_root}'."
            )
            return

        if os.path.exists(model_path):
            # Wait briefly for the process to release file locks.
            time.sleep(0.5)
            shutil.rmtree(model_path)
            logger.info(f"Cancelled training. Deleted model folder: {model_path}")
    except Exception as e:
        logger.error(f"Could not delete model folder after cancel: {e}")
790
+
791
def on_training_interrupt(self):
    """Handle the dialog's ``interrupted`` signal by setting ``self.stop_event``."""
    stop_flag = self.stop_event
    stop_flag.set()
793
+
794
+ def _load_previous_instructions(self):
795
+ pass