celldetective 1.4.2__py3-none-any.whl → 1.5.0b0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (151) hide show
  1. celldetective/__init__.py +25 -0
  2. celldetective/__main__.py +62 -43
  3. celldetective/_version.py +1 -1
  4. celldetective/extra_properties.py +477 -399
  5. celldetective/filters.py +192 -97
  6. celldetective/gui/InitWindow.py +541 -411
  7. celldetective/gui/__init__.py +0 -15
  8. celldetective/gui/about.py +44 -39
  9. celldetective/gui/analyze_block.py +120 -84
  10. celldetective/gui/base/__init__.py +0 -0
  11. celldetective/gui/base/channel_norm_generator.py +335 -0
  12. celldetective/gui/base/components.py +249 -0
  13. celldetective/gui/base/feature_choice.py +92 -0
  14. celldetective/gui/base/figure_canvas.py +52 -0
  15. celldetective/gui/base/list_widget.py +133 -0
  16. celldetective/gui/{styles.py → base/styles.py} +92 -36
  17. celldetective/gui/base/utils.py +33 -0
  18. celldetective/gui/base_annotator.py +900 -767
  19. celldetective/gui/classifier_widget.py +6 -22
  20. celldetective/gui/configure_new_exp.py +777 -671
  21. celldetective/gui/control_panel.py +635 -524
  22. celldetective/gui/dynamic_progress.py +449 -0
  23. celldetective/gui/event_annotator.py +2023 -1662
  24. celldetective/gui/generic_signal_plot.py +1292 -944
  25. celldetective/gui/gui_utils.py +899 -1289
  26. celldetective/gui/interactions_block.py +658 -0
  27. celldetective/gui/interactive_timeseries_viewer.py +447 -0
  28. celldetective/gui/json_readers.py +48 -15
  29. celldetective/gui/layouts/__init__.py +5 -0
  30. celldetective/gui/layouts/background_model_free_layout.py +537 -0
  31. celldetective/gui/layouts/channel_offset_layout.py +134 -0
  32. celldetective/gui/layouts/local_correction_layout.py +91 -0
  33. celldetective/gui/layouts/model_fit_layout.py +372 -0
  34. celldetective/gui/layouts/operation_layout.py +68 -0
  35. celldetective/gui/layouts/protocol_designer_layout.py +96 -0
  36. celldetective/gui/pair_event_annotator.py +3130 -2435
  37. celldetective/gui/plot_measurements.py +586 -267
  38. celldetective/gui/plot_signals_ui.py +724 -506
  39. celldetective/gui/preprocessing_block.py +395 -0
  40. celldetective/gui/process_block.py +1678 -1831
  41. celldetective/gui/seg_model_loader.py +580 -473
  42. celldetective/gui/settings/__init__.py +0 -7
  43. celldetective/gui/settings/_cellpose_model_params.py +181 -0
  44. celldetective/gui/settings/_event_detection_model_params.py +95 -0
  45. celldetective/gui/settings/_segmentation_model_params.py +159 -0
  46. celldetective/gui/settings/_settings_base.py +77 -65
  47. celldetective/gui/settings/_settings_event_model_training.py +752 -526
  48. celldetective/gui/settings/_settings_measurements.py +1133 -964
  49. celldetective/gui/settings/_settings_neighborhood.py +574 -488
  50. celldetective/gui/settings/_settings_segmentation_model_training.py +779 -564
  51. celldetective/gui/settings/_settings_signal_annotator.py +329 -305
  52. celldetective/gui/settings/_settings_tracking.py +1304 -1094
  53. celldetective/gui/settings/_stardist_model_params.py +98 -0
  54. celldetective/gui/survival_ui.py +422 -312
  55. celldetective/gui/tableUI.py +1665 -1701
  56. celldetective/gui/table_ops/_maths.py +295 -0
  57. celldetective/gui/table_ops/_merge_groups.py +140 -0
  58. celldetective/gui/table_ops/_merge_one_hot.py +95 -0
  59. celldetective/gui/table_ops/_query_table.py +43 -0
  60. celldetective/gui/table_ops/_rename_col.py +44 -0
  61. celldetective/gui/thresholds_gui.py +382 -179
  62. celldetective/gui/viewers/__init__.py +0 -0
  63. celldetective/gui/viewers/base_viewer.py +700 -0
  64. celldetective/gui/viewers/channel_offset_viewer.py +331 -0
  65. celldetective/gui/viewers/contour_viewer.py +394 -0
  66. celldetective/gui/viewers/size_viewer.py +153 -0
  67. celldetective/gui/viewers/spot_detection_viewer.py +341 -0
  68. celldetective/gui/viewers/threshold_viewer.py +309 -0
  69. celldetective/gui/workers.py +304 -126
  70. celldetective/log_manager.py +92 -0
  71. celldetective/measure.py +1895 -1478
  72. celldetective/napari/__init__.py +0 -0
  73. celldetective/napari/utils.py +1025 -0
  74. celldetective/neighborhood.py +1914 -1448
  75. celldetective/preprocessing.py +1620 -1220
  76. celldetective/processes/__init__.py +0 -0
  77. celldetective/processes/background_correction.py +271 -0
  78. celldetective/processes/compute_neighborhood.py +894 -0
  79. celldetective/processes/detect_events.py +246 -0
  80. celldetective/processes/measure_cells.py +565 -0
  81. celldetective/processes/segment_cells.py +760 -0
  82. celldetective/processes/track_cells.py +435 -0
  83. celldetective/processes/train_segmentation_model.py +694 -0
  84. celldetective/processes/train_signal_model.py +265 -0
  85. celldetective/processes/unified_process.py +292 -0
  86. celldetective/regionprops/_regionprops.py +358 -317
  87. celldetective/relative_measurements.py +987 -710
  88. celldetective/scripts/measure_cells.py +313 -212
  89. celldetective/scripts/measure_relative.py +90 -46
  90. celldetective/scripts/segment_cells.py +165 -104
  91. celldetective/scripts/segment_cells_thresholds.py +96 -68
  92. celldetective/scripts/track_cells.py +198 -149
  93. celldetective/scripts/train_segmentation_model.py +324 -201
  94. celldetective/scripts/train_signal_model.py +87 -45
  95. celldetective/segmentation.py +844 -749
  96. celldetective/signals.py +3514 -2861
  97. celldetective/tracking.py +30 -15
  98. celldetective/utils/__init__.py +0 -0
  99. celldetective/utils/cellpose_utils/__init__.py +133 -0
  100. celldetective/utils/color_mappings.py +42 -0
  101. celldetective/utils/data_cleaning.py +630 -0
  102. celldetective/utils/data_loaders.py +450 -0
  103. celldetective/utils/dataset_helpers.py +207 -0
  104. celldetective/utils/downloaders.py +197 -0
  105. celldetective/utils/event_detection/__init__.py +8 -0
  106. celldetective/utils/experiment.py +1782 -0
  107. celldetective/utils/image_augmenters.py +308 -0
  108. celldetective/utils/image_cleaning.py +74 -0
  109. celldetective/utils/image_loaders.py +926 -0
  110. celldetective/utils/image_transforms.py +335 -0
  111. celldetective/utils/io.py +62 -0
  112. celldetective/utils/mask_cleaning.py +348 -0
  113. celldetective/utils/mask_transforms.py +5 -0
  114. celldetective/utils/masks.py +184 -0
  115. celldetective/utils/maths.py +351 -0
  116. celldetective/utils/model_getters.py +325 -0
  117. celldetective/utils/model_loaders.py +296 -0
  118. celldetective/utils/normalization.py +380 -0
  119. celldetective/utils/parsing.py +465 -0
  120. celldetective/utils/plots/__init__.py +0 -0
  121. celldetective/utils/plots/regression.py +53 -0
  122. celldetective/utils/resources.py +34 -0
  123. celldetective/utils/stardist_utils/__init__.py +104 -0
  124. celldetective/utils/stats.py +90 -0
  125. celldetective/utils/types.py +21 -0
  126. {celldetective-1.4.2.dist-info → celldetective-1.5.0b0.dist-info}/METADATA +1 -1
  127. celldetective-1.5.0b0.dist-info/RECORD +187 -0
  128. {celldetective-1.4.2.dist-info → celldetective-1.5.0b0.dist-info}/WHEEL +1 -1
  129. tests/gui/test_new_project.py +129 -117
  130. tests/gui/test_project.py +127 -79
  131. tests/test_filters.py +39 -15
  132. tests/test_notebooks.py +8 -0
  133. tests/test_tracking.py +232 -13
  134. tests/test_utils.py +123 -77
  135. celldetective/gui/base_components.py +0 -23
  136. celldetective/gui/layouts.py +0 -1602
  137. celldetective/gui/processes/compute_neighborhood.py +0 -594
  138. celldetective/gui/processes/measure_cells.py +0 -360
  139. celldetective/gui/processes/segment_cells.py +0 -499
  140. celldetective/gui/processes/track_cells.py +0 -303
  141. celldetective/gui/processes/train_segmentation_model.py +0 -270
  142. celldetective/gui/processes/train_signal_model.py +0 -108
  143. celldetective/gui/table_ops/merge_groups.py +0 -118
  144. celldetective/gui/viewers.py +0 -1354
  145. celldetective/io.py +0 -3663
  146. celldetective/utils.py +0 -3108
  147. celldetective-1.4.2.dist-info/RECORD +0 -123
  148. /celldetective/{gui/processes → processes}/downloader.py +0 -0
  149. {celldetective-1.4.2.dist-info → celldetective-1.5.0b0.dist-info}/entry_points.txt +0 -0
  150. {celldetective-1.4.2.dist-info → celldetective-1.5.0b0.dist-info}/licenses/LICENSE +0 -0
  151. {celldetective-1.4.2.dist-info → celldetective-1.5.0b0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,449 @@
1
+ import numpy as np
2
+ from PyQt5.QtCore import QTimer, Qt, pyqtSignal
3
+ from PyQt5.QtWidgets import QApplication, QComboBox, QDialog, QHBoxLayout, QLabel, QProgressBar, QPushButton, \
4
+ QSizePolicy, QVBoxLayout
5
+ from fonticon_mdi6 import MDI6
6
+ from matplotlib import pyplot as plt
7
+ from matplotlib.figure import Figure
8
+ from superqt.fonticon import icon
9
+
10
+ from celldetective.gui.base.styles import Styles
11
+ from celldetective.gui.base.figure_canvas import FigureCanvas
12
+ from celldetective import get_logger
13
+
14
+ logger = get_logger(__name__)
15
+
16
class DynamicProgressDialog(QDialog, Styles):
    """Modal dialog showing the progress of a training run.

    The dialog stacks, from top to bottom: a status label, a progress bar, a
    metric selector with a log-scale toggle, a matplotlib canvas, and the
    "Interrupt && Skip" / "Cancel" buttons.

    While training runs, :meth:`update_plot` is fed one dict per epoch and the
    selected metric is drawn as train/validation curves. Once a model is
    finished, :meth:`show_result` replaces the curves with a regression scatter
    plot or a normalized confusion matrix.

    Signals
    -------
    canceled
        Emitted by :meth:`on_cancel`, just before the dialog is rejected.
    interrupted
        Emitted by :meth:`on_skip` when the user asks to skip the current model.
    """

    canceled = pyqtSignal()
    interrupted = pyqtSignal()

    # Delay (ms) used for every deferred, best-effort UI update.
    _DEFERRED_CALL_DELAY_MS = 100
    # Minimum time (s) between two throttled redraws of the metric curves.
    _PLOT_REFRESH_SECONDS = 3.0
    # Metrics listed first in the selector when they are reported.
    _PRIORITY_METRICS = ("iou", "loss", "mse")

    def __init__(
        self,
        title="Training Progress",
        label_text="Launching the training script...",
        minimum=0,
        maximum=100,
        max_epochs=100,
        parent=None,
    ):
        """Build the dialog and its widgets.

        Parameters
        ----------
        title
            Window title.
        label_text
            Initial text of the status label.
        minimum, maximum
            Range of the progress bar.
        max_epochs
            Expected number of epochs; used as the x-axis limit until an epoch
            report provides ``total_epochs``.
        parent
            Optional parent widget.
        """
        super().__init__(parent)
        Styles.__init__(self)
        self.setWindowTitle(title)
        self.setWindowFlags(self.windowFlags() & ~Qt.WindowContextHelpButtonHint)
        self.setWindowModality(Qt.ApplicationModal)

        self.resize(600, 500)  # Standard size

        # Training state
        self.max_epochs = max_epochs
        self.current_epoch = 0
        # Struct: {metric_name: {train: [], val: [], epochs: []}}
        self.metrics_history = {}
        self.current_model_name = None
        self.last_update_time = 0
        self.log_scale = False
        self.user_interrupted = False
        self.is_percentile_scaled = False

        # Plot state, initialised here so that no method has to probe for it.
        self.current_plot_metric = None
        self.current_total_epochs = max_epochs
        self.train_line = None
        self.val_line = None

        layout = QVBoxLayout(self)
        layout.setContentsMargins(30, 30, 30, 30)

        # Status label
        self.status_label = QLabel(label_text)
        layout.addWidget(self.status_label)

        # Progress bar
        self.progress_bar = QProgressBar()
        self.progress_bar.setRange(minimum, maximum)
        self.progress_bar.setStyleSheet(self.progress_bar_style)
        layout.addWidget(self.progress_bar)

        # Plot canvas (transparent figure so the dialog background shows through)
        self.figure = Figure(figsize=(5, 4), dpi=100)
        self.figure.patch.set_alpha(0.0)
        self.canvas = FigureCanvas(self.figure)
        self.ax = self.figure.add_subplot(111)
        self.apply_plot_style()

        # Controls row: metric selector + log-scale toggle
        controls_layout = QHBoxLayout()

        self.btn_log = QPushButton("")
        self.btn_log.setCheckable(True)
        self.btn_log.setIcon(icon(MDI6.math_log, color="black"))
        self.btn_log.clicked.connect(self.toggle_log_scale)
        self.btn_log.setStyleSheet(self.button_select_all)
        # Enabled by update_plot once a model starts reporting metrics.
        self.btn_log.setEnabled(False)

        # NOTE: auto_scale() is currently not wired to any button.

        self.metric_label = QLabel("Metric: ")
        self.metric_combo = QComboBox()
        self.metric_combo.currentIndexChanged.connect(self.force_update_plot)

        controls_layout.addWidget(self.metric_label, 10)
        controls_layout.addWidget(self.metric_combo, 85)
        controls_layout.addWidget(self.btn_log, 5, alignment=Qt.AlignRight)
        layout.addLayout(controls_layout)

        # Canvas
        self.canvas.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
        self.canvas.setStyleSheet("background-color: transparent;")
        layout.addWidget(self.canvas)

        # Buttons row
        btn_layout = QHBoxLayout()

        self.skip_btn = QPushButton("Interrupt && Skip")
        self.skip_btn.setStyleSheet(self.button_style_sheet_2)
        self.skip_btn.setIcon(icon(MDI6.skip_next, color=self.celldetective_blue))
        self.skip_btn.clicked.connect(self.on_skip)
        # Enabled by update_plot once epochs are being reported.
        self.skip_btn.setEnabled(False)
        btn_layout.addWidget(self.skip_btn, 50)

        self.cancel_btn = QPushButton("Cancel")
        self.cancel_btn.setStyleSheet(self.button_style_sheet)
        self.cancel_btn.clicked.connect(self.on_cancel)
        btn_layout.addWidget(self.cancel_btn, 50)

        layout.addLayout(btn_layout)

        # Final geometry: slightly wider than the size hint, 70% of screen height.
        self._get_screen_height()
        self.adjustSize()
        new_width = int(self.width() * 1.01)
        self.resize(new_width, int(self._screen_height * 0.7))
        self.setMinimumWidth(new_width)

    def _get_screen_height(self):
        """Store the primary screen's available width/height on the instance."""
        app = QApplication.instance()
        geometry = app.primaryScreen().availableGeometry()
        self._screen_width = geometry.width()
        self._screen_height = geometry.height()

    def _call_later(self, callback):
        """Run ``callback`` once after a short delay, on a best-effort basis.

        The callback runs from the Qt event loop, possibly after the dialog has
        been torn down, so failures are logged instead of propagating.
        """

        def _guarded():
            try:
                callback()
            except RuntimeError as exc:
                # Presumably the wrapped Qt object was deleted in the meantime.
                logger.warning(f"Deferred dialog update skipped: {exc}")

        try:
            QTimer.singleShot(self._DEFERRED_CALL_DELAY_MS, _guarded)
        except Exception as exc:
            logger.warning(f"Could not schedule deferred dialog update: {exc}")

    def _schedule_resize_nudge(self, delta):
        """Resize the dialog by ``delta`` pixels in both directions, deferred.

        NOTE(review): the one-pixel resize looks like a workaround to force the
        canvas to re-layout/repaint after a redraw - confirm before removing.
        """
        self._call_later(
            lambda: self.resize(self.width() + delta, self.height() + delta)
        )

    def on_skip(self):
        """Signal that the current model should be skipped and update the UI."""
        self.interrupted.emit()
        self.skip_btn.setDisabled(True)
        self.user_interrupted = True
        self.status_label.setText(
            "Interrupting current model training [effective at the end of the current epoch]..."
        )

    def apply_plot_style(self):
        """Apply the shared axes styling and the current y-scale (log/linear)."""
        self.ax.spines["top"].set_visible(False)
        self.ax.spines["right"].set_visible(False)
        self.ax.patch.set_alpha(0.0)
        self.ax.grid(True, which="both", linestyle="--", linewidth=0.5, alpha=0.7)
        self.ax.minorticks_on()
        self.ax.set_yscale("log" if self.log_scale else "linear")

    def show_result(self, results):
        """Display final results (Confusion Matrix or Regression Plot).

        Parameters
        ----------
        results : dict
            If it holds both ``"val_predictions"`` and ``"val_ground_truth"``,
            a predictions-vs-ground-truth scatter plot is drawn (with
            ``"val_mse"`` in the title when available). Otherwise, if it holds
            ``"val_confusion"`` or ``"test_confusion"``, a row-normalized
            confusion matrix is drawn. Otherwise a placeholder text is shown.
        """
        self.ax.clear()
        self.apply_plot_style()
        self.ax.set_yscale("linear")
        self.ax.set_xscale("linear")
        self.metric_combo.hide()
        self.metric_label.hide()
        self.btn_log.hide()
        # The axes were cleared: the cached curves are gone, so make the next
        # update_plot_display() call rebuild them.
        self.current_plot_metric = None

        preds = results.get("val_predictions")
        gt = results.get("val_ground_truth")
        confusion = results.get("val_confusion")
        if confusion is None:
            confusion = results.get("test_confusion")

        if preds is not None and gt is not None:
            self._plot_regression(gt, preds, results.get("val_mse", "N/A"))
        elif confusion is not None:
            self._plot_confusion_matrix(confusion)
        else:
            self.ax.text(
                0.5,
                0.5,
                "No visualization data found.",
                ha="center",
                va="center",
                transform=self.ax.transAxes,
            )
        self.canvas.draw()

    def _plot_regression(self, gt, preds, val_mse):
        """Draw predictions against ground truth with the identity line."""
        # Accept plain sequences as well as arrays.
        gt = np.asarray(gt, dtype=float).ravel()
        preds = np.asarray(preds, dtype=float).ravel()

        self.ax.scatter(gt, preds, alpha=0.5, c="white", edgecolors="C0")

        # Identity line over the finite data range; skipped when there is no
        # finite value to span.
        finite = np.concatenate([gt[np.isfinite(gt)], preds[np.isfinite(preds)]])
        if finite.size:
            low, high = finite.min(), finite.max()
            self.ax.plot([low, high], [low, high], "r--")

        self.ax.set_xlabel("Ground Truth")
        self.ax.set_ylabel("Predictions")
        # NumPy scalars (e.g. float32) are not instances of the builtin float.
        if isinstance(val_mse, (int, float, np.integer, np.floating)):
            title_str = f"Regression Result (MSE: {val_mse:.4f})"
        else:
            title_str = f"Regression Result (MSE: {val_mse})"
        self.ax.set_title(title_str)
        self.ax.set_aspect("equal", adjustable="box")

    def _plot_confusion_matrix(self, confusion):
        """Draw a row-normalized confusion matrix with per-cell annotations."""
        cm = np.asarray(confusion, dtype=float)
        row_totals = cm.sum(axis=1, keepdims=True)
        # Rows without any sample stay at 0 instead of becoming NaN, which
        # would otherwise also break the text-colour threshold below.
        norm_cm = np.divide(
            cm, row_totals, out=np.zeros_like(cm), where=row_totals > 0
        )

        self.ax.imshow(
            norm_cm, interpolation="nearest", cmap=plt.cm.Blues, aspect="equal"
        )
        self.ax.set_title("Confusion Matrix (Normalized)")
        self.ax.set_ylabel("True label")
        self.ax.set_xlabel("Predicted label")

        tick_marks = np.arange(len(norm_cm))
        self.ax.set_xticks(tick_marks)
        self.ax.set_yticks(tick_marks)

        # Class names are only attached for the three-class case.
        if len(norm_cm) == 3:
            labels = ["event", "no event", "else"]
            self.ax.set_xticklabels(labels)
            self.ax.set_yticklabels(labels)

        self.ax.grid(False)

        # White text on dark cells, black text on light cells.
        thresh = norm_cm.max() / 2.0 if norm_cm.size else 0.0
        for (i, j), value in np.ndenumerate(norm_cm):
            self.ax.text(
                j,
                i,
                format(value, ".2f"),
                ha="center",
                va="center",
                color="white" if value > thresh else "black",
            )

    def toggle_log_scale(self):
        """Switch the y-axis between linear and log according to the button."""
        self.log_scale = self.btn_log.isChecked()
        self.update_plot_display()
        self.figure.tight_layout()

        is_linear = self.ax.get_yscale() == "linear"
        self.btn_log.setIcon(
            icon(MDI6.math_log, color="black" if is_linear else "white")
        )
        self._schedule_resize_nudge(-1 if is_linear else 1)

    def auto_scale(self):
        """Toggle the y-limits of the selected metric between two modes.

        The first call clips the axis to the 1st-99th percentile of the
        recorded train/validation values; the next call restores the full
        min/max range with a 5% margin, and so on.
        """
        target_metric = self.metric_combo.currentText()
        if not target_metric or target_metric not in self.metrics_history:
            return

        data = self.metrics_history[target_metric]
        y_values = np.array(
            [
                value
                for series in ("train", "val")
                for value in data.get(series, [])
                if value is not None
            ],
            dtype=float,
        )
        if y_values.size == 0:
            return

        if not self.is_percentile_scaled:
            # Mode: Percentile 1-99
            try:
                p1, p99 = np.nanpercentile(y_values, [1, 99])
                if p1 != p99:
                    self.ax.set_ylim(p1, p99)
                self.is_percentile_scaled = True
            except Exception as e:
                logger.warning(f"Could not compute percentiles: {e}")
        else:
            # Mode: Min/Max (Standard Autoscale)
            try:
                min_val, max_val = np.nanmin(y_values), np.nanmax(y_values)
                # Add a small padding (5%)
                margin = (max_val - min_val) * 0.05
                if margin == 0:
                    margin = 0.1  # default padding if constant
                self.ax.set_ylim(min_val - margin, max_val + margin)
                self.is_percentile_scaled = False
            except Exception as e:
                logger.warning(f"Could not compute min/max: {e}")
            self.ax.relim()
            self.ax.autoscale_view()

        self.canvas.draw()

    def force_update_plot(self):
        """Redraw the curves immediately (slot of the metric selector)."""
        self.update_plot_display()

    def on_cancel(self):
        """Emit ``canceled`` and reject the dialog."""
        self.canceled.emit()
        self.reject()

    def update_progress(self, value, text=None):
        """Set the progress bar value and, if ``text`` is non-empty, the status."""
        self.progress_bar.setValue(value)
        if text:
            self.status_label.setText(text)

    def update_plot(self, epoch_data):
        """Record one epoch report and refresh the curves (throttled).

        Parameters
        ----------
        epoch_data : dict
            Recognised keys, all optional:

            - ``"model_name"`` (default ``"Unknown"``): a change of name resets
              the history, the metric selector and the plot.
            - ``"total_epochs"`` (default ``100``): upper x-axis limit.
            - ``"epoch"`` (default ``0``): index of the reported epoch.
            - ``"metrics"``: ``{name: value}`` training metrics.
            - ``"val_metrics"``: validation metrics, looked up as
              ``"val_<name>"`` for each training metric ``<name>``.
        """
        import time

        model_name = epoch_data.get("model_name", "Unknown")
        total_epochs = epoch_data.get("total_epochs", 100)
        epoch = epoch_data.get("epoch", 0)
        metrics = epoch_data.get("metrics") or {}
        val_metrics = epoch_data.get("val_metrics") or {}

        model_switched = model_name != self.current_model_name
        if model_switched:
            self._start_new_model(model_name, metrics)

        # Append this epoch to the history of every reported metric; a missing
        # validation value is stored as None to keep the series aligned.
        for name, value in metrics.items():
            history = self.metrics_history.setdefault(
                name, {"train": [], "val": [], "epochs": []}
            )
            history["epochs"].append(epoch)
            history["train"].append(value)
            history["val"].append(val_metrics.get(f"val_{name}"))

        # Store total epochs for limits
        self.current_total_epochs = total_epochs

        if epoch > -1 and not self.user_interrupted:
            self.skip_btn.setEnabled(True)

        # Redraw at most every few seconds, except on the first report of a new
        # model (so the previous curves do not linger) and once the epoch
        # counter reaches the total.
        current_time = time.time()
        throttle_elapsed = (
            current_time - self.last_update_time > self._PLOT_REFRESH_SECONDS
        )
        if model_switched or throttle_elapsed or epoch >= total_epochs:
            self.update_plot_display()
            self.last_update_time = current_time

    def _start_new_model(self, model_name, metrics):
        """Reset history, selector and axes for a newly reported model."""
        self.metrics_history = {}
        self.current_model_name = model_name
        self.user_interrupted = False

        # Repopulate the selector without triggering redraws; the prioritised
        # metrics come first, the others keep their reported order.
        self.metric_combo.blockSignals(True)
        self.metric_combo.clear()
        self.metric_combo.addItems(
            sorted(metrics, key=lambda name: name not in self._PRIORITY_METRICS)
        )
        self.metric_combo.blockSignals(False)

        self.status_label.setText(f"Training {model_name}...")
        self.ax.clear()
        self.apply_plot_style()
        self.metric_combo.show()
        self.metric_label.show()
        self.btn_log.show()
        self.btn_log.setEnabled(True)
        self.ax.set_aspect("auto")
        # Force update_plot_display() to rebuild the curves on the fresh axes.
        self.current_plot_metric = None

    def update_plot_display(self):
        """Draw the train/validation curves of the currently selected metric."""
        target_metric = self.metric_combo.currentText()
        if not target_metric or target_metric not in self.metrics_history:
            return

        data = self.metrics_history[target_metric]

        # (Re)build the axes and line artists when the metric changed or the
        # axes were cleared since the last draw.
        if self.current_plot_metric != target_metric:
            self.ax.clear()
            self.apply_plot_style()
            self.ax.set_xlabel("Epoch")
            self.ax.set_ylabel(target_metric)

            if self.current_total_epochs is not None:
                self.ax.set_xlim(0, self.current_total_epochs)

            (self.train_line,) = self.ax.plot(
                [], [], label="Train", marker=".", color="tab:blue"
            )
            (self.val_line,) = self.ax.plot(
                [], [], label="Validation", marker=".", color="tab:orange"
            )
            self.ax.legend()
            self.current_plot_metric = target_metric

        # Only push a series that holds at least one actual value.
        if any(v is not None for v in data["train"]):
            self.train_line.set_data(data["epochs"], data["train"])

        if any(v is not None for v in data["val"]):
            self.val_line.set_data(data["epochs"], data["val"])

        self.ax.set_yscale("log" if self.log_scale else "linear")

        self.ax.relim()
        self.ax.autoscale_view()
        self.canvas.draw()

        # Alternate the nudge direction with the parity of the latest epoch so
        # the dialog size does not drift.
        latest_epoch = max(data["epochs"]) if data["epochs"] else 0
        self._schedule_resize_nudge(1 if latest_epoch % 2 else -1)

    def update_status(self, text):
        """Set the status label.

        A message containing "Loading" and "librar" (case-insensitive for the
        latter) is replaced by "Training model..." after a short delay.
        """
        self.status_label.setText(text)
        if "Loading" in text and "librar" in text.lower():
            self._call_later(lambda: self.status_label.setText("Training model..."))