celldetective 1.4.2__py3-none-any.whl → 1.5.0b0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- celldetective/__init__.py +25 -0
- celldetective/__main__.py +62 -43
- celldetective/_version.py +1 -1
- celldetective/extra_properties.py +477 -399
- celldetective/filters.py +192 -97
- celldetective/gui/InitWindow.py +541 -411
- celldetective/gui/__init__.py +0 -15
- celldetective/gui/about.py +44 -39
- celldetective/gui/analyze_block.py +120 -84
- celldetective/gui/base/__init__.py +0 -0
- celldetective/gui/base/channel_norm_generator.py +335 -0
- celldetective/gui/base/components.py +249 -0
- celldetective/gui/base/feature_choice.py +92 -0
- celldetective/gui/base/figure_canvas.py +52 -0
- celldetective/gui/base/list_widget.py +133 -0
- celldetective/gui/{styles.py → base/styles.py} +92 -36
- celldetective/gui/base/utils.py +33 -0
- celldetective/gui/base_annotator.py +900 -767
- celldetective/gui/classifier_widget.py +6 -22
- celldetective/gui/configure_new_exp.py +777 -671
- celldetective/gui/control_panel.py +635 -524
- celldetective/gui/dynamic_progress.py +449 -0
- celldetective/gui/event_annotator.py +2023 -1662
- celldetective/gui/generic_signal_plot.py +1292 -944
- celldetective/gui/gui_utils.py +899 -1289
- celldetective/gui/interactions_block.py +658 -0
- celldetective/gui/interactive_timeseries_viewer.py +447 -0
- celldetective/gui/json_readers.py +48 -15
- celldetective/gui/layouts/__init__.py +5 -0
- celldetective/gui/layouts/background_model_free_layout.py +537 -0
- celldetective/gui/layouts/channel_offset_layout.py +134 -0
- celldetective/gui/layouts/local_correction_layout.py +91 -0
- celldetective/gui/layouts/model_fit_layout.py +372 -0
- celldetective/gui/layouts/operation_layout.py +68 -0
- celldetective/gui/layouts/protocol_designer_layout.py +96 -0
- celldetective/gui/pair_event_annotator.py +3130 -2435
- celldetective/gui/plot_measurements.py +586 -267
- celldetective/gui/plot_signals_ui.py +724 -506
- celldetective/gui/preprocessing_block.py +395 -0
- celldetective/gui/process_block.py +1678 -1831
- celldetective/gui/seg_model_loader.py +580 -473
- celldetective/gui/settings/__init__.py +0 -7
- celldetective/gui/settings/_cellpose_model_params.py +181 -0
- celldetective/gui/settings/_event_detection_model_params.py +95 -0
- celldetective/gui/settings/_segmentation_model_params.py +159 -0
- celldetective/gui/settings/_settings_base.py +77 -65
- celldetective/gui/settings/_settings_event_model_training.py +752 -526
- celldetective/gui/settings/_settings_measurements.py +1133 -964
- celldetective/gui/settings/_settings_neighborhood.py +574 -488
- celldetective/gui/settings/_settings_segmentation_model_training.py +779 -564
- celldetective/gui/settings/_settings_signal_annotator.py +329 -305
- celldetective/gui/settings/_settings_tracking.py +1304 -1094
- celldetective/gui/settings/_stardist_model_params.py +98 -0
- celldetective/gui/survival_ui.py +422 -312
- celldetective/gui/tableUI.py +1665 -1701
- celldetective/gui/table_ops/_maths.py +295 -0
- celldetective/gui/table_ops/_merge_groups.py +140 -0
- celldetective/gui/table_ops/_merge_one_hot.py +95 -0
- celldetective/gui/table_ops/_query_table.py +43 -0
- celldetective/gui/table_ops/_rename_col.py +44 -0
- celldetective/gui/thresholds_gui.py +382 -179
- celldetective/gui/viewers/__init__.py +0 -0
- celldetective/gui/viewers/base_viewer.py +700 -0
- celldetective/gui/viewers/channel_offset_viewer.py +331 -0
- celldetective/gui/viewers/contour_viewer.py +394 -0
- celldetective/gui/viewers/size_viewer.py +153 -0
- celldetective/gui/viewers/spot_detection_viewer.py +341 -0
- celldetective/gui/viewers/threshold_viewer.py +309 -0
- celldetective/gui/workers.py +304 -126
- celldetective/log_manager.py +92 -0
- celldetective/measure.py +1895 -1478
- celldetective/napari/__init__.py +0 -0
- celldetective/napari/utils.py +1025 -0
- celldetective/neighborhood.py +1914 -1448
- celldetective/preprocessing.py +1620 -1220
- celldetective/processes/__init__.py +0 -0
- celldetective/processes/background_correction.py +271 -0
- celldetective/processes/compute_neighborhood.py +894 -0
- celldetective/processes/detect_events.py +246 -0
- celldetective/processes/measure_cells.py +565 -0
- celldetective/processes/segment_cells.py +760 -0
- celldetective/processes/track_cells.py +435 -0
- celldetective/processes/train_segmentation_model.py +694 -0
- celldetective/processes/train_signal_model.py +265 -0
- celldetective/processes/unified_process.py +292 -0
- celldetective/regionprops/_regionprops.py +358 -317
- celldetective/relative_measurements.py +987 -710
- celldetective/scripts/measure_cells.py +313 -212
- celldetective/scripts/measure_relative.py +90 -46
- celldetective/scripts/segment_cells.py +165 -104
- celldetective/scripts/segment_cells_thresholds.py +96 -68
- celldetective/scripts/track_cells.py +198 -149
- celldetective/scripts/train_segmentation_model.py +324 -201
- celldetective/scripts/train_signal_model.py +87 -45
- celldetective/segmentation.py +844 -749
- celldetective/signals.py +3514 -2861
- celldetective/tracking.py +30 -15
- celldetective/utils/__init__.py +0 -0
- celldetective/utils/cellpose_utils/__init__.py +133 -0
- celldetective/utils/color_mappings.py +42 -0
- celldetective/utils/data_cleaning.py +630 -0
- celldetective/utils/data_loaders.py +450 -0
- celldetective/utils/dataset_helpers.py +207 -0
- celldetective/utils/downloaders.py +197 -0
- celldetective/utils/event_detection/__init__.py +8 -0
- celldetective/utils/experiment.py +1782 -0
- celldetective/utils/image_augmenters.py +308 -0
- celldetective/utils/image_cleaning.py +74 -0
- celldetective/utils/image_loaders.py +926 -0
- celldetective/utils/image_transforms.py +335 -0
- celldetective/utils/io.py +62 -0
- celldetective/utils/mask_cleaning.py +348 -0
- celldetective/utils/mask_transforms.py +5 -0
- celldetective/utils/masks.py +184 -0
- celldetective/utils/maths.py +351 -0
- celldetective/utils/model_getters.py +325 -0
- celldetective/utils/model_loaders.py +296 -0
- celldetective/utils/normalization.py +380 -0
- celldetective/utils/parsing.py +465 -0
- celldetective/utils/plots/__init__.py +0 -0
- celldetective/utils/plots/regression.py +53 -0
- celldetective/utils/resources.py +34 -0
- celldetective/utils/stardist_utils/__init__.py +104 -0
- celldetective/utils/stats.py +90 -0
- celldetective/utils/types.py +21 -0
- {celldetective-1.4.2.dist-info → celldetective-1.5.0b0.dist-info}/METADATA +1 -1
- celldetective-1.5.0b0.dist-info/RECORD +187 -0
- {celldetective-1.4.2.dist-info → celldetective-1.5.0b0.dist-info}/WHEEL +1 -1
- tests/gui/test_new_project.py +129 -117
- tests/gui/test_project.py +127 -79
- tests/test_filters.py +39 -15
- tests/test_notebooks.py +8 -0
- tests/test_tracking.py +232 -13
- tests/test_utils.py +123 -77
- celldetective/gui/base_components.py +0 -23
- celldetective/gui/layouts.py +0 -1602
- celldetective/gui/processes/compute_neighborhood.py +0 -594
- celldetective/gui/processes/measure_cells.py +0 -360
- celldetective/gui/processes/segment_cells.py +0 -499
- celldetective/gui/processes/track_cells.py +0 -303
- celldetective/gui/processes/train_segmentation_model.py +0 -270
- celldetective/gui/processes/train_signal_model.py +0 -108
- celldetective/gui/table_ops/merge_groups.py +0 -118
- celldetective/gui/viewers.py +0 -1354
- celldetective/io.py +0 -3663
- celldetective/utils.py +0 -3108
- celldetective-1.4.2.dist-info/RECORD +0 -123
- /celldetective/{gui/processes → processes}/downloader.py +0 -0
- {celldetective-1.4.2.dist-info → celldetective-1.5.0b0.dist-info}/entry_points.txt +0 -0
- {celldetective-1.4.2.dist-info → celldetective-1.5.0b0.dist-info}/licenses/LICENSE +0 -0
- {celldetective-1.4.2.dist-info → celldetective-1.5.0b0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,449 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from PyQt5.QtCore import QTimer, Qt, pyqtSignal
|
|
3
|
+
from PyQt5.QtWidgets import QApplication, QComboBox, QDialog, QHBoxLayout, QLabel, QProgressBar, QPushButton, \
|
|
4
|
+
QSizePolicy, QVBoxLayout
|
|
5
|
+
from fonticon_mdi6 import MDI6
|
|
6
|
+
from matplotlib import pyplot as plt
|
|
7
|
+
from matplotlib.figure import Figure
|
|
8
|
+
from superqt.fonticon import icon
|
|
9
|
+
|
|
10
|
+
from celldetective.gui.base.styles import Styles
|
|
11
|
+
from celldetective.gui.base.figure_canvas import FigureCanvas
|
|
12
|
+
from celldetective import get_logger
|
|
13
|
+
|
|
14
|
+
# Module-level logger for this file, obtained through celldetective's get_logger helper.
logger = get_logger(__name__)
|
|
15
|
+
|
|
16
|
+
class DynamicProgressDialog(QDialog, Styles):
    """Modal dialog reporting training progress with a live metric plot.

    The dialog contains a status label, a progress bar, a Matplotlib canvas
    showing per-epoch train/validation curves for a user-selected metric
    (with an optional log-scaled y-axis), and "Interrupt & Skip" / "Cancel"
    buttons. Once training is done, :meth:`show_result` replaces the curves
    with a regression scatter plot or a row-normalized confusion matrix.

    Signals
    -------
    canceled
        Emitted when the "Cancel" button is pressed, just before the dialog
        is rejected.
    interrupted
        Emitted when the "Interrupt & Skip" button is pressed.
    """

    canceled = pyqtSignal()
    interrupted = pyqtSignal()

    # Minimum number of seconds between two throttled redraws in update_plot().
    plot_refresh_interval = 3.0

    def __init__(
        self,
        title="Training Progress",
        label_text="Launching the training script...",
        minimum=0,
        maximum=100,
        max_epochs=100,
        parent=None,
    ):
        super().__init__(parent)
        Styles.__init__(self)
        self.setWindowTitle(title)
        # Remove the "?" context-help button from the title bar.
        self.setWindowFlags(self.windowFlags() & ~Qt.WindowContextHelpButtonHint)
        self.setWindowModality(Qt.ApplicationModal)

        self.resize(600, 500)  # initial size, refined at the end of __init__

        # --- Training state -------------------------------------------------
        # Fallback for the x-axis extent when an update carries no "total_epochs".
        self.max_epochs = max_epochs
        self.current_epoch = 0
        # {metric_name: {"train": [...], "val": [...], "epochs": [...]}}
        self.metrics_history = {}
        self.current_model_name = None
        self.current_total_epochs = None
        # Metric currently drawn by train_line/val_line (None -> re-initialize).
        self.current_plot_metric = None
        self.train_line = None
        self.val_line = None
        # -inf guarantees the very first update_plot() call redraws.
        self.last_update_time = float("-inf")
        self.log_scale = False
        self.user_interrupted = False
        self.is_percentile_scaled = False

        # --- Layout ---------------------------------------------------------
        layout = QVBoxLayout(self)
        layout.setContentsMargins(30, 30, 30, 30)

        self.status_label = QLabel(label_text)
        layout.addWidget(self.status_label)

        self.progress_bar = QProgressBar()
        self.progress_bar.setRange(minimum, maximum)
        self.progress_bar.setStyleSheet(self.progress_bar_style)
        layout.addWidget(self.progress_bar)

        # Plot canvas with a transparent background.
        self.figure = Figure(figsize=(5, 4), dpi=100)
        self.figure.patch.set_alpha(0.0)
        self.canvas = FigureCanvas(self.figure)
        self.ax = self.figure.add_subplot(111)
        self.apply_plot_style()

        # Toolbar / controls: metric selector + log-scale toggle.
        controls_layout = QHBoxLayout()

        self.btn_log = QPushButton("")
        self.btn_log.setCheckable(True)
        self.btn_log.setIcon(icon(MDI6.math_log, color="black"))
        self.btn_log.clicked.connect(self.toggle_log_scale)
        self.btn_log.setStyleSheet(self.button_select_all)
        # Enabled by update_plot() once the first epoch of a model arrives.
        self.btn_log.setEnabled(False)

        # NOTE: auto_scale() is implemented but no widget is currently wired to it.

        self.metric_label = QLabel("Metric: ")
        self.metric_combo = QComboBox()
        self.metric_combo.currentIndexChanged.connect(self.force_update_plot)

        controls_layout.addWidget(self.metric_label, 10)
        controls_layout.addWidget(self.metric_combo, 85)
        controls_layout.addWidget(self.btn_log, 5, alignment=Qt.AlignRight)
        layout.addLayout(controls_layout)

        self.canvas.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
        self.canvas.setStyleSheet("background-color: transparent;")
        layout.addWidget(self.canvas)

        # Bottom buttons.
        btn_layout = QHBoxLayout()

        self.skip_btn = QPushButton("Interrupt && Skip")
        self.skip_btn.setStyleSheet(self.button_style_sheet_2)
        self.skip_btn.setIcon(icon(MDI6.skip_next, color=self.celldetective_blue))
        self.skip_btn.clicked.connect(self.on_skip)
        # Enabled by update_plot() once an epoch (>= 0) has been reported.
        self.skip_btn.setEnabled(False)
        btn_layout.addWidget(self.skip_btn, 50)

        self.cancel_btn = QPushButton("Cancel")
        self.cancel_btn.setStyleSheet(self.button_style_sheet)
        self.cancel_btn.clicked.connect(self.on_cancel)
        btn_layout.addWidget(self.cancel_btn, 50)

        layout.addLayout(btn_layout)

        # Final sizing: natural width (+1%), 70% of the available screen height.
        self._get_screen_height()
        self.adjustSize()
        new_width = int(self.width() * 1.01)
        self.resize(new_width, int(self._screen_height * 0.7))
        self.setMinimumWidth(new_width)

    def _get_screen_height(self):
        """Store the primary screen's available geometry size.

        Sets ``self._screen_width`` and ``self._screen_height`` from the
        width/height of ``QScreen.availableGeometry()``.
        """
        app = QApplication.instance()
        screen = app.primaryScreen()
        geometry = screen.availableGeometry()
        self._screen_width, self._screen_height = geometry.getRect()[-2:]

    @staticmethod
    def _run_later(func, delay_ms=100):
        """Run ``func`` once after ``delay_ms`` milliseconds via ``QTimer``.

        Errors raised by a deferred call cannot be caught where the timer is
        scheduled, so the guard lives inside the callback. PyQt raises
        ``RuntimeError`` when the underlying C++ widget has already been
        deleted (e.g. the dialog was closed before the timer fired); such a
        late call is simply skipped.
        """

        def _guarded():
            try:
                func()
            except RuntimeError as exc:
                logger.debug(f"Deferred GUI update skipped: {exc}")

        QTimer.singleShot(delay_ms, _guarded)

    def _nudge_resize(self, delta):
        """Schedule a resize of the dialog by ``delta`` pixels in each direction.

        Presumably used to force the embedded canvas to re-layout/repaint
        after a redraw (the resize itself is the only visible effect here).
        """
        self._run_later(
            lambda: self.resize(self.width() + delta, self.height() + delta)
        )

    def on_skip(self):
        """Request interruption of the current model's training.

        Emits ``interrupted``, disables the skip button and records the
        request in ``user_interrupted`` (reset when a new model starts).
        """
        self.interrupted.emit()
        self.skip_btn.setDisabled(True)
        self.user_interrupted = True
        self.status_label.setText(
            "Interrupting current model training [effective at the end of the current epoch]..."
        )

    def apply_plot_style(self):
        """Apply the common axes styling and the current y-scale to ``self.ax``."""
        self.ax.spines["top"].set_visible(False)
        self.ax.spines["right"].set_visible(False)
        self.ax.patch.set_alpha(0.0)
        self.ax.grid(True, which="both", linestyle="--", linewidth=0.5, alpha=0.7)
        self.ax.minorticks_on()
        self.ax.set_yscale("log" if getattr(self, "log_scale", False) else "linear")

    def show_result(self, results):
        """Display final results (regression plot or confusion matrix).

        Parameters
        ----------
        results : dict
            If it contains both ``"val_predictions"`` and
            ``"val_ground_truth"``, a predictions-vs-ground-truth scatter is
            drawn (the title shows ``"val_mse"`` when present). Otherwise, if
            it contains ``"val_confusion"`` or ``"test_confusion"``, that
            matrix is drawn normalized per row. Otherwise a placeholder
            message is shown.
        """
        self.ax.clear()
        self.apply_plot_style()
        self.ax.set_yscale("linear")
        self.ax.set_xscale("linear")
        # ax.clear() removed the live train/val lines: force their
        # re-creation if update_plot_display() runs again later.
        self.current_plot_metric = None
        self.metric_combo.hide()
        self.metric_label.hide()
        self.btn_log.hide()

        if "val_predictions" in results and "val_ground_truth" in results:
            self._plot_regression(results)
        elif "val_confusion" in results or "test_confusion" in results:
            self._plot_confusion(results)
        else:
            self._show_no_data_message()
        self.canvas.draw()

    def _show_no_data_message(self):
        """Write a centered placeholder message on the axes."""
        self.ax.text(
            0.5,
            0.5,
            "No visualization data found.",
            ha="center",
            va="center",
            transform=self.ax.transAxes,
        )

    def _plot_regression(self, results):
        """Scatter validation predictions against ground truth with y = x."""
        preds = np.asarray(results["val_predictions"]).ravel()
        gt = np.asarray(results["val_ground_truth"]).ravel()

        self.ax.scatter(gt, preds, alpha=0.5, c="white", edgecolors="C0")

        # Identity line spanning both value ranges (skipped for empty data,
        # where min/max would raise).
        if gt.size and preds.size:
            lo = min(np.nanmin(gt), np.nanmin(preds))
            hi = max(np.nanmax(gt), np.nanmax(preds))
            self.ax.plot([lo, hi], [lo, hi], "r--")

        self.ax.set_xlabel("Ground Truth")
        self.ax.set_ylabel("Predictions")
        val_mse = results.get("val_mse", "N/A")
        # NumPy scalar types (e.g. float32) are not Python floats but format fine.
        if isinstance(val_mse, (int, float, np.integer, np.floating)):
            title_str = f"Regression Result (MSE: {val_mse:.4f})"
        else:
            title_str = f"Regression Result (MSE: {val_mse})"
        self.ax.set_title(title_str)
        self.ax.set_aspect("equal", adjustable="box")

    def _plot_confusion(self, results):
        """Draw the row-normalized confusion matrix with per-cell annotations."""
        cm = results.get("val_confusion")
        if cm is None:
            cm = results.get("test_confusion")
        if cm is None:
            self._show_no_data_message()
            return

        cm = np.asarray(cm, dtype=float)
        row_sums = cm.sum(axis=1, keepdims=True)
        # Rows without samples stay NaN without triggering a divide-by-zero warning.
        norm_cm = np.divide(
            cm, row_sums, out=np.full_like(cm, np.nan), where=row_sums != 0
        )

        self.ax.imshow(
            norm_cm, interpolation="nearest", cmap=plt.cm.Blues, aspect="equal"
        )
        self.ax.set_title("Confusion Matrix (Normalized)")
        self.ax.set_ylabel("True label")
        self.ax.set_xlabel("Predicted label")

        n_rows, n_cols = norm_cm.shape
        self.ax.set_xticks(np.arange(n_cols))
        self.ax.set_yticks(np.arange(n_rows))
        if n_rows == 3 and n_cols == 3:
            labels = ["event", "no event", "else"]
            self.ax.set_xticklabels(labels)
            self.ax.set_yticklabels(labels)

        self.ax.grid(False)

        # Text color switches to white on the darker half of the colormap.
        # Threshold is computed from finite cells only, so one empty row
        # does not turn every annotation black.
        finite = norm_cm[np.isfinite(norm_cm)]
        thresh = finite.max() / 2.0 if finite.size else 0.5
        fmt = ".2f"
        for i in range(n_rows):
            for j in range(n_cols):
                self.ax.text(
                    j,
                    i,
                    format(norm_cm[i, j], fmt),
                    ha="center",
                    va="center",
                    color="white" if norm_cm[i, j] > thresh else "black",
                )

    def toggle_log_scale(self):
        """Switch the metric plot between linear and log y-scale.

        The icon color follows the button's checked state, so it stays in
        sync even when no metric data has been plotted yet.
        """
        self.log_scale = self.btn_log.isChecked()
        self.update_plot_display()
        self.ax.set_yscale("log" if self.log_scale else "linear")
        self.figure.tight_layout()
        if self.log_scale:
            self.btn_log.setIcon(icon(MDI6.math_log, color="white"))
            self._nudge_resize(+1)
        else:
            self.btn_log.setIcon(icon(MDI6.math_log, color="black"))
            self._nudge_resize(-1)

    def auto_scale(self):
        """Toggle the y-limits between the 1-99th percentile and full min/max.

        The first call clips the y-axis to the 1st-99th percentile of the
        selected metric's train+val values. The next call restores the full
        range with a 5% margin. The mode is tracked in ``is_percentile_scaled``.
        """
        target_metric = self.metric_combo.currentText()
        if not target_metric or target_metric not in self.metrics_history:
            return

        data = self.metrics_history[target_metric]
        y_values = np.asarray(
            [
                v
                for key in ("train", "val")
                for v in data.get(key, [])
                if v is not None
            ],
            dtype=float,
        )
        if y_values.size == 0:
            return

        if not self.is_percentile_scaled:
            try:
                p1, p99 = np.nanpercentile(y_values, [1, 99])
                if p1 != p99:
                    self.ax.set_ylim(p1, p99)
                    self.is_percentile_scaled = True
            except Exception as e:
                logger.warning(f"Could not compute percentiles: {e}")
        else:
            try:
                min_val, max_val = np.nanmin(y_values), np.nanmax(y_values)
                margin = (max_val - min_val) * 0.05
                if margin == 0:
                    margin = 0.1  # default padding for a constant series
                self.ax.set_ylim(min_val - margin, max_val + margin)
                self.is_percentile_scaled = False
            except Exception as e:
                logger.warning(f"Could not compute min/max: {e}")
                # Fall back to Matplotlib's own autoscaling.
                self.ax.relim()
                self.ax.autoscale_view()

        self.canvas.draw()

    def force_update_plot(self):
        """Redraw the plot immediately (slot for the metric selector)."""
        self.update_plot_display()

    def on_cancel(self):
        """Emit ``canceled`` and reject (close) the dialog."""
        self.canceled.emit()
        self.reject()

    def update_progress(self, value, text=None):
        """Set the progress bar value and, if ``text`` is non-empty, the status."""
        self.progress_bar.setValue(value)
        if text:
            self.status_label.setText(text)

    def update_plot(self, epoch_data):
        """Record one epoch of metrics and refresh the plot (throttled).

        Parameters
        ----------
        epoch_data : dict
            ``"epoch"`` (default 0), ``"metrics"`` (dict of training metric
            values), ``"val_metrics"`` (dict whose keys are the training keys
            prefixed with ``"val_"``), ``"model_name"`` (default
            ``"Unknown"``) and ``"total_epochs"`` (default: the
            ``max_epochs`` passed to the constructor).

        A change of ``model_name`` clears the history and rebuilds the metric
        selector. Redraws happen at most every ``plot_refresh_interval``
        seconds, except on the final epoch.
        """
        import time

        model_name = epoch_data.get("model_name", "Unknown")
        total_epochs = epoch_data.get("total_epochs", self.max_epochs)
        epoch = epoch_data.get("epoch", 0)
        metrics = epoch_data.get("metrics", {})
        val_metrics = epoch_data.get("val_metrics", {})

        if model_name != self.current_model_name:
            self._start_new_model(model_name, metrics)

        # Append this epoch to the history of every reported training metric.
        for key, value in metrics.items():
            history = self.metrics_history.setdefault(
                key, {"train": [], "val": [], "epochs": []}
            )
            history["epochs"].append(epoch)
            history["train"].append(value)
            # None keeps train/val aligned when no validation value exists.
            history["val"].append(val_metrics.get(f"val_{key}"))

        self.current_epoch = epoch
        self.current_total_epochs = total_epochs

        if epoch > -1 and not self.user_interrupted:
            self.skip_btn.setEnabled(True)

        now = time.monotonic()  # immune to wall-clock adjustments
        # Force the final redraw whether epochs are counted from 0 or from 1.
        is_last_epoch = epoch >= total_epochs - 1
        if now - self.last_update_time > self.plot_refresh_interval or is_last_epoch:
            self.update_plot_display()
            self.last_update_time = now

    def _start_new_model(self, model_name, metrics):
        """Reset history, selector and axes when a new model starts training."""
        self.metrics_history = {}
        self.current_model_name = model_name
        self.user_interrupted = False

        # Populate the selector from the first epoch's training metrics,
        # listing "iou", "loss" and "mse" first (stable sort keeps the rest
        # in their original order).
        priority = ("iou", "loss", "mse")
        potential_metrics = sorted(
            metrics.keys(), key=lambda name: 0 if name in priority else 1
        )
        self.metric_combo.blockSignals(True)
        self.metric_combo.clear()
        self.metric_combo.addItems(potential_metrics)
        self.metric_combo.blockSignals(False)

        self.status_label.setText(f"Training {model_name}...")
        self.ax.clear()
        self.apply_plot_style()
        self.metric_combo.show()
        self.metric_label.show()
        self.btn_log.show()
        self.btn_log.setEnabled(True)
        self.ax.set_aspect("auto")
        self.current_plot_metric = None
        self.update_plot_display()

    def update_plot_display(self):
        """Redraw the train/validation curves of the selected metric.

        Does nothing when no metric is selected or it has no history yet.
        Lines are (re)created when the selected metric changes; otherwise
        only their data is updated.
        """
        target_metric = self.metric_combo.currentText()
        if not target_metric or target_metric not in self.metrics_history:
            return

        data = self.metrics_history[target_metric]

        if self.current_plot_metric != target_metric or self.train_line is None:
            self.ax.clear()
            self.apply_plot_style()
            self.ax.set_xlabel("Epoch")
            self.ax.set_ylabel(target_metric)
            if self.current_total_epochs is not None:
                self.ax.set_xlim(0, self.current_total_epochs)
            (self.train_line,) = self.ax.plot(
                [], [], label="Train", marker=".", color="tab:blue"
            )
            (self.val_line,) = self.ax.plot(
                [], [], label="Validation", marker=".", color="tab:orange"
            )
            self.ax.legend()
            self.current_plot_metric = target_metric

        if any(v is not None for v in data["train"]):
            self.train_line.set_data(data["epochs"], data["train"])
        if any(v is not None for v in data["val"]):
            self.val_line.set_data(data["epochs"], data["val"])

        self.ax.set_yscale("log" if self.log_scale else "linear")
        self.ax.relim()
        self.ax.autoscale_view()
        self.canvas.draw()

        # Alternate +1 / -1 px resizes with epoch parity so consecutive
        # redraws always change the widget geometry.
        if data["epochs"]:
            self._nudge_resize(1 if max(data["epochs"]) % 2 else -1)

    def update_status(self, text):
        """Show ``text`` in the status label.

        A "Loading ... librar..." message is replaced by "Training model..."
        after 100 ms.
        """
        self.status_label.setText(text)
        if "Loading" in text and "librar" in text.lower():
            self._run_later(lambda: self.status_label.setText("Training model..."))