singlebehaviorlab-2.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. sam2/__init__.py +11 -0
  2. sam2/automatic_mask_generator.py +454 -0
  3. sam2/benchmark.py +92 -0
  4. sam2/build_sam.py +174 -0
  5. sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
  6. sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
  7. sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
  8. sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
  9. sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
  10. sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
  11. sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
  12. sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
  13. sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
  14. sam2/modeling/__init__.py +5 -0
  15. sam2/modeling/backbones/__init__.py +5 -0
  16. sam2/modeling/backbones/hieradet.py +317 -0
  17. sam2/modeling/backbones/image_encoder.py +134 -0
  18. sam2/modeling/backbones/utils.py +93 -0
  19. sam2/modeling/memory_attention.py +169 -0
  20. sam2/modeling/memory_encoder.py +181 -0
  21. sam2/modeling/position_encoding.py +239 -0
  22. sam2/modeling/sam/__init__.py +5 -0
  23. sam2/modeling/sam/mask_decoder.py +295 -0
  24. sam2/modeling/sam/prompt_encoder.py +202 -0
  25. sam2/modeling/sam/transformer.py +311 -0
  26. sam2/modeling/sam2_base.py +913 -0
  27. sam2/modeling/sam2_utils.py +323 -0
  28. sam2/sam2_hiera_b+.yaml +113 -0
  29. sam2/sam2_hiera_l.yaml +117 -0
  30. sam2/sam2_hiera_s.yaml +116 -0
  31. sam2/sam2_hiera_t.yaml +118 -0
  32. sam2/sam2_image_predictor.py +466 -0
  33. sam2/sam2_video_predictor.py +1388 -0
  34. sam2/sam2_video_predictor_legacy.py +1172 -0
  35. sam2/utils/__init__.py +5 -0
  36. sam2/utils/amg.py +348 -0
  37. sam2/utils/misc.py +349 -0
  38. sam2/utils/transforms.py +118 -0
  39. singlebehaviorlab/__init__.py +4 -0
  40. singlebehaviorlab/__main__.py +130 -0
  41. singlebehaviorlab/_paths.py +100 -0
  42. singlebehaviorlab/backend/__init__.py +2 -0
  43. singlebehaviorlab/backend/augmentations.py +320 -0
  44. singlebehaviorlab/backend/data_store.py +420 -0
  45. singlebehaviorlab/backend/model.py +1290 -0
  46. singlebehaviorlab/backend/train.py +4667 -0
  47. singlebehaviorlab/backend/uncertainty.py +578 -0
  48. singlebehaviorlab/backend/video_processor.py +688 -0
  49. singlebehaviorlab/backend/video_utils.py +139 -0
  50. singlebehaviorlab/data/config/config.yaml +85 -0
  51. singlebehaviorlab/data/training_profiles.json +334 -0
  52. singlebehaviorlab/gui/__init__.py +4 -0
  53. singlebehaviorlab/gui/analysis_widget.py +2291 -0
  54. singlebehaviorlab/gui/attention_export.py +311 -0
  55. singlebehaviorlab/gui/clip_extraction_widget.py +481 -0
  56. singlebehaviorlab/gui/clustering_widget.py +3187 -0
  57. singlebehaviorlab/gui/inference_popups.py +1138 -0
  58. singlebehaviorlab/gui/inference_widget.py +4550 -0
  59. singlebehaviorlab/gui/inference_worker.py +651 -0
  60. singlebehaviorlab/gui/labeling_widget.py +2324 -0
  61. singlebehaviorlab/gui/main_window.py +754 -0
  62. singlebehaviorlab/gui/metadata_management_widget.py +1119 -0
  63. singlebehaviorlab/gui/motion_tracking.py +764 -0
  64. singlebehaviorlab/gui/overlay_export.py +1234 -0
  65. singlebehaviorlab/gui/plot_integration.py +729 -0
  66. singlebehaviorlab/gui/qt_helpers.py +29 -0
  67. singlebehaviorlab/gui/registration_widget.py +1485 -0
  68. singlebehaviorlab/gui/review_widget.py +1330 -0
  69. singlebehaviorlab/gui/segmentation_tracking_widget.py +2752 -0
  70. singlebehaviorlab/gui/tab_tutorial_dialog.py +312 -0
  71. singlebehaviorlab/gui/timeline_themes.py +131 -0
  72. singlebehaviorlab/gui/training_profiles.py +418 -0
  73. singlebehaviorlab/gui/training_widget.py +3719 -0
  74. singlebehaviorlab/gui/video_utils.py +233 -0
  75. singlebehaviorlab/licenses/SAM2-LICENSE +201 -0
  76. singlebehaviorlab/licenses/VideoPrism-LICENSE +202 -0
  77. singlebehaviorlab-2.0.0.dist-info/METADATA +447 -0
  78. singlebehaviorlab-2.0.0.dist-info/RECORD +88 -0
  79. singlebehaviorlab-2.0.0.dist-info/WHEEL +5 -0
  80. singlebehaviorlab-2.0.0.dist-info/entry_points.txt +2 -0
  81. singlebehaviorlab-2.0.0.dist-info/licenses/LICENSE +21 -0
  82. singlebehaviorlab-2.0.0.dist-info/top_level.txt +3 -0
  83. videoprism/__init__.py +0 -0
  84. videoprism/encoders.py +910 -0
  85. videoprism/layers.py +1136 -0
  86. videoprism/models.py +407 -0
  87. videoprism/tokenizers.py +167 -0
  88. videoprism/utils.py +168 -0
@@ -0,0 +1,3719 @@
1
+ import logging
2
+ from PyQt6.QtWidgets import (
3
+ QWidget, QVBoxLayout, QHBoxLayout, QPushButton, QLabel, QSpinBox,
4
+ QDoubleSpinBox, QPlainTextEdit, QProgressBar, QGroupBox, QFormLayout,
5
+ QFileDialog, QMessageBox, QLineEdit, QCheckBox, QToolButton,
6
+ QScrollArea, QListWidget, QListWidgetItem, QTableWidget, QTableWidgetItem,
7
+ QHeaderView, QDialog, QGridLayout, QSizePolicy
8
+ )
9
+
10
+ logger = logging.getLogger(__name__)
11
+ from PyQt6.QtCore import QThread, pyqtSignal, Qt
12
+ from PyQt6.QtGui import QPixmap
13
+ import matplotlib
14
+ matplotlib.use('QtAgg')
15
+ from matplotlib.backends.backend_qtagg import FigureCanvasQTAgg as FigureCanvas
16
+ from matplotlib.figure import Figure
17
+ import os
18
+ import copy
19
+ import shutil
20
+ import tempfile
21
+ import json
22
+ import time
23
+ import glob
24
+ import torch
25
+ import yaml
26
+ import numpy as np
27
+ import cv2
28
+ import random
29
+ import re
30
+ from singlebehaviorlab.backend.data_store import AnnotationManager
31
+ from singlebehaviorlab.backend.model import VideoPrismBackbone, BehaviorClassifier
32
+ from singlebehaviorlab.backend.train import BehaviorDataset, train_model
33
+ from singlebehaviorlab.backend.video_utils import load_clip_frames
34
+ from .training_profiles import TrainingProfileDialog
35
+
36
+
37
+ class TrainingWorker(QThread):
38
+ """Worker thread for training."""
39
+ log_message = pyqtSignal(str)
40
+ progress = pyqtSignal(int, int)
41
+ finished = pyqtSignal()
42
+ error = pyqtSignal(str)
43
+ training_complete = pyqtSignal(float, float, float, dict) # (best_val_acc, best_val_f1, final_train_acc, per_class_f1)
44
+ epoch_complete = pyqtSignal(dict) # New signal for real-time metrics
45
+
46
+ def __init__(self, config, train_config, annotation_file, clips_dir, output_path):
47
+ super().__init__()
48
+ self.config = config
49
+ self.train_config = train_config
50
+ self.annotation_file = annotation_file
51
+ self.clips_dir = clips_dir
52
+ self.output_path = output_path
53
+ self.should_stop = False
54
+
55
+ def stop(self):
56
+ """Request training stop."""
57
+ self.should_stop = True
58
+
59
+ def _build_model_for_config(self, train_dataset, cfg, log_fn):
60
+ """Create a fresh classifier for one training run."""
61
+ model_name = cfg.get("backbone_model", self.config.get("backbone_model", "videoprism_public_v1_base"))
62
+ resolution = cfg.get("resolution", 288)
63
+ log_fn(f"Loading VideoPrism backbone ({model_name}) at resolution {resolution}×{resolution}...")
64
+ backbone = VideoPrismBackbone(model_name=model_name, resolution=resolution, log_fn=log_fn)
65
+ log_fn("VideoPrism backbone loaded successfully")
66
+
67
+ head_kwargs = cfg.get("head_kwargs", {}).copy()
68
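+ # per_class_query is not passed to the classifier; strip it and keep the remaining kwargs for the saved metadata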
+ head_kwargs.pop("per_class_query", None)
69
+ head_kwargs_for_metadata = head_kwargs.copy()
70
+ dropout = cfg.get("dropout", 0.1)
71
+ localization_dropout = 0.0
72
+
73
+ log_fn("Creating classifier model...")
74
+ if cfg.get("use_temporal_decoder", True):
75
+ log_fn("Using MAP head + temporal decoder")
76
+ else:
77
+ log_fn("Using MAP head + direct per-frame linear classifier")
78
+ log_fn(f"MAP head kwargs: {head_kwargs}")
79
+
80
+ use_loc = cfg.get("use_localization", False)
81
+ num_stages = cfg.get("num_stages", 3)
82
+ if use_loc and num_stages > 1:
83
+ num_stages = 1
84
+ log_fn("Localization ON: reducing MS-TCN to 1 stage (no refinement) to prevent overfitting on tight crops.")
85
+
86
+ multi_scale = cfg.get("multi_scale", False) and not use_loc
87
+ model = BehaviorClassifier(
88
+ backbone,
89
+ num_classes=len(train_dataset.classes),
90
+ class_names=train_dataset.classes,
91
+ dropout=dropout,
92
+ freeze_backbone=True,
93
+ head_kwargs=head_kwargs,
94
+ use_localization=use_loc,
95
+ localization_hidden_dim=cfg.get("localization_hidden_dim", 256),
96
+ localization_dropout=localization_dropout,
97
+ use_frame_head=True,
98
+ use_temporal_decoder=cfg.get("use_temporal_decoder", True),
99
+ frame_head_temporal_layers=cfg.get("frame_head_temporal_layers", 1),
100
+ temporal_pool_frames=cfg.get("temporal_pool_frames", 1),
101
+ proj_dim=cfg.get("proj_dim", 256),
102
+ num_stages=num_stages,
103
+ multi_scale=multi_scale,
104
+ )
105
+ log_fn("Classifier model created successfully")
106
+ return model, head_kwargs_for_metadata, dropout, localization_dropout, num_stages, multi_scale
107
+
108
+ def _build_backend_train_config(
109
+ self,
110
+ cfg,
111
+ output_path,
112
+ classes,
113
+ augmentation_options,
114
+ head_kwargs_for_metadata,
115
+ dropout,
116
+ localization_dropout,
117
+ num_stages,
118
+ multi_scale,
119
+ ):
120
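+ """Assemble the flat configuration dict passed to the backend train_model() call."""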
+ return {
121
+ "batch_size": cfg["batch_size"],
122
+ "epochs": cfg["epochs"],
123
+ "lr": cfg.get("classification_lr", cfg["lr"]),
124
+ "localization_lr": cfg.get("localization_lr", cfg["lr"]),
125
+ "classification_lr": cfg.get("classification_lr", cfg["lr"]),
126
+ "use_scheduler": cfg.get("use_scheduler", True),
127
+ "use_ema": cfg.get("use_ema", True),
128
+ "weight_decay": cfg.get("weight_decay", 1e-3),
129
+ "output_path": output_path,
130
+ "save_best": True,
131
+ "use_class_weights": cfg.get("use_class_weights", False),
132
+ "use_focal_loss": cfg.get("use_focal_loss", False),
133
+ "focal_gamma": cfg.get("focal_gamma", 2.0),
134
+ "use_frame_loss": True,
135
+ "use_temporal_decoder": cfg.get("use_temporal_decoder", True),
136
+ "frame_head_temporal_layers": cfg.get("frame_head_temporal_layers", 1),
137
+ "temporal_pool_frames": cfg.get("temporal_pool_frames", 1),
138
+ "num_stages": num_stages,
139
+ "proj_dim": cfg.get("proj_dim", 256),
140
+ "multi_scale": multi_scale,
141
+ "use_frame_bout_balance": cfg.get("use_frame_bout_balance", True),
142
+ "frame_bout_balance_power": cfg.get("frame_bout_balance_power", 1.0),
143
+ "boundary_loss_weight": cfg.get("boundary_loss_weight", 0.3),
144
+ "boundary_tolerance": cfg.get("boundary_tolerance", 2),
145
+ "smoothness_loss_weight": cfg.get("smoothness_loss_weight", 0.05),
146
+ "use_localization": cfg.get("use_localization", False),
147
+ "use_manual_localization_switch": cfg.get("use_manual_localization_switch", False),
148
+ "manual_localization_switch_epoch": cfg.get("manual_localization_switch_epoch", 20),
149
+ "localization_hidden_dim": cfg.get("localization_hidden_dim", 256),
150
+ "classification_crop_padding": cfg.get("classification_crop_padding", 0.35),
151
+ "crop_jitter": cfg.get("crop_jitter", False),
152
+ "crop_jitter_strength": cfg.get("crop_jitter_strength", 0.15),
153
+ "emb_aug_versions": cfg.get("emb_aug_versions", 1),
154
+ "clip_length": cfg.get("clip_length", 8),
155
+ "use_ovr": cfg.get("use_ovr", False),
156
+ "ovr_label_smoothing": cfg.get("ovr_label_smoothing", 0.05),
157
+ "use_asl": cfg.get("use_asl", False),
158
+ "asl_gamma_neg": cfg.get("asl_gamma_neg", 2.0),
159
+ "asl_gamma_pos": cfg.get("asl_gamma_pos", 0.0),
160
+ "asl_clip": cfg.get("asl_clip", 0.05),
161
+ "use_hard_pair_mining": cfg.get("use_hard_pair_mining", False),
162
+ "hard_pairs": cfg.get("hard_pairs", []),
163
+ "hard_pair_margin": cfg.get("hard_pair_margin", 0.5),
164
+ "hard_pair_loss_weight": cfg.get("hard_pair_loss_weight", 0.2),
165
+ "hard_pair_confusion_boost": cfg.get("hard_pair_confusion_boost", 1.5),
166
+ "use_confusion_sampler": cfg.get("use_confusion_sampler", True),
167
+ "confusion_sampler_temperature": cfg.get("confusion_sampler_temperature", 2.0),
168
+ "confusion_sampler_warmup_pct": cfg.get("confusion_sampler_warmup_pct", 0.2),
169
+ "use_weighted_sampler": cfg.get("use_weighted_sampler", False),
170
+ "use_augmentation": cfg.get("use_augmentation", False),
171
+ "augmentation_options": augmentation_options,
172
+ "virtual_expansion": cfg.get("virtual_expansion", 5),
173
+ "stitch_augmentation_prob": cfg.get("stitch_augmentation_prob", 0.0),
174
+ "f1_exclude_classes": cfg.get("f1_exclude_classes", []),
175
+ "ovr_pos_weight_f1_excluded": cfg.get("ovr_pos_weight_f1_excluded", 1.5),
176
+ "val_split": cfg.get("val_split", 0.2),
177
+ "limit_classes": cfg.get("limit_classes", False),
178
+ "selected_classes": cfg.get("selected_classes", []),
179
+ "limit_per_class": cfg.get("limit_per_class", False),
180
+ "per_class_limits": cfg.get("per_class_limits", {}),
181
+ "per_class_val_limits": cfg.get("per_class_val_limits", {}),
182
+ "backbone_model": cfg.get("backbone_model", "videoprism_public_v1_base"),
183
+ "resolution": cfg.get("resolution", 288),
184
+ "use_all_for_training": cfg.get("use_all_for_training", False),
185
+ "use_embedding_diversity": cfg.get("use_embedding_diversity", False),
186
+ "class_names": cfg.get("class_names", classes),
187
+ "pretrained_path": cfg.get("pretrained_path"),
188
+ "head_kwargs": head_kwargs_for_metadata,
189
+ "dropout": dropout,
190
+ "localization_dropout": localization_dropout,
191
+ }
192
+
193
+ def _generate_autotune_candidates(self, base_cfg, num_runs: int):
194
+ """Generate a small random search around the current config."""
195
+ def _uniq_float(values):
196
+ out = []
197
+ for v in values:
198
+ v = float(max(1e-6, min(1.0, v)))
199
+ if all(abs(v - prev) > 1e-12 for prev in out):
200
+ out.append(v)
201
+ return out
202
+
203
+ def _uniq_int(values, lo, hi):
204
+ out = []
205
+ for v in values:
206
+ v = int(max(lo, min(hi, int(round(v)))))
207
+ if v not in out:
208
+ out.append(v)
209
+ return out
210
+
211
+ lr0 = float(base_cfg.get("classification_lr", base_cfg.get("lr", 1e-4)))
212
+ wd0 = float(base_cfg.get("weight_decay", 1e-3))
213
+ drop0 = float(base_cfg.get("dropout", 0.3))
214
+ heads0 = int(base_cfg.get("head_kwargs", {}).get("num_heads", 4))
215
+ layers0 = int(base_cfg.get("frame_head_temporal_layers", 4))
216
+ ovr_ls0 = float(base_cfg.get("ovr_label_smoothing", 0.05))
217
+ use_temporal_decoder = bool(base_cfg.get("use_temporal_decoder", True))
218
+
219
+ lr_vals = _uniq_float([3e-5, 1e-4, 3e-4, 1e-3, lr0 / 3.0, lr0, lr0 * 3.0])
220
+ wd_vals = _uniq_float([1e-5, 1e-4, 5e-4, 1e-3, wd0 / 10.0, wd0, wd0 * 3.0])
221
+ drop_vals = _uniq_float([0.1, 0.2, 0.3, 0.4, drop0])
222
+ head_vals = _uniq_int([2, 4, 8, heads0, heads0 - 2, heads0 + 2], 1, 16)
223
+ layer_vals = [layers0] if not use_temporal_decoder else _uniq_int([1, 2, 3, 4, layers0, layers0 - 1, layers0 + 1], 1, 8)
224
+ ovr_ls_vals = _uniq_float([0.0, 0.02, 0.05, ovr_ls0])
225
+
226
+ rng = random.Random(42)
227
+ candidates = []
228
+ seen = set()
229
+
230
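+ # Hashable signature of the tuned hyper-parameters, used to skip duplicate trial configs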
+ def _trial_tuple(cfg):
231
+ return (
232
+ round(float(cfg.get("classification_lr", cfg.get("lr", 0.0))), 12),
233
+ round(float(cfg.get("weight_decay", 0.0)), 12),
234
+ round(float(cfg.get("dropout", 0.0)), 12),
235
+ int(cfg.get("head_kwargs", {}).get("num_heads", 0)),
236
+ int(cfg.get("frame_head_temporal_layers", 0)) if cfg.get("use_temporal_decoder", True) else None,
237
+ round(float(cfg.get("ovr_label_smoothing", 0.0)), 12) if cfg.get("use_ovr", False) else None,
238
+ )
239
+
240
+ base_copy = copy.deepcopy(base_cfg)
241
+ candidates.append(base_copy)
242
+ seen.add(_trial_tuple(base_copy))
243
+
244
+ max_attempts = max(20, num_runs * 20)
245
+ attempts = 0
246
+ while len(candidates) < max(1, int(num_runs)) and attempts < max_attempts:
247
+ attempts += 1
248
+ cfg = copy.deepcopy(base_cfg)
249
+ new_lr = rng.choice(lr_vals)
250
+ cfg["lr"] = new_lr
251
+ cfg["classification_lr"] = new_lr
252
+ cfg["weight_decay"] = rng.choice(wd_vals)
253
+ cfg["dropout"] = rng.choice(drop_vals)
254
+ cfg.setdefault("head_kwargs", {})
255
+ cfg["head_kwargs"]["num_heads"] = rng.choice(head_vals)
256
+ cfg["frame_head_temporal_layers"] = rng.choice(layer_vals)
257
+ if cfg.get("use_ovr", False):
258
+ cfg["ovr_label_smoothing"] = rng.choice(ovr_ls_vals)
259
+ key = _trial_tuple(cfg)
260
+ if key in seen:
261
+ continue
262
+ seen.add(key)
263
+ candidates.append(cfg)
264
+ return candidates[:max(1, int(num_runs))]
265
+
266
+ def _reset_runtime_dataset_caches(self, dataset):
267
+ """Clear per-run disk cache flags so a new training run rebuilds them cleanly."""
268
+ if dataset is None:
269
+ return
270
+ if hasattr(dataset, "_emb_cache_mode"):
271
+ dataset._emb_cache_mode = False
272
+ if hasattr(dataset, "_emb_cache_dir"):
273
+ dataset._emb_cache_dir = None
274
+ if hasattr(dataset, "_roi_cache_mode"):
275
+ dataset._roi_cache_mode = False
276
+ if hasattr(dataset, "_roi_cache_dir"):
277
+ dataset._roi_cache_dir = None
278
+
279
+ def _cleanup_autotune_trial_outputs(self, trial_output, log_fn=None):
280
+ """Delete temporary files produced by one auto-tune trial."""
281
+ output_dir = os.path.dirname(trial_output)
282
+ basename = os.path.splitext(os.path.basename(trial_output))[0]
283
+ paths = [
284
+ trial_output,
285
+ trial_output + ".meta.json",
286
+ os.path.join(output_dir, f"{basename}_best.pt"),
287
+ os.path.join(output_dir, f"{basename}_best.pt.meta.json"),
288
+ os.path.join(output_dir, f"{basename}_training_config.json"),
289
+ os.path.join(output_dir, f"{basename}_checkpoints"),
290
+ os.path.join(output_dir, f"{basename}_crop_progress"),
291
+ ]
292
+ paths.extend(glob.glob(os.path.join(output_dir, f"{basename}_training_log_*.csv")))
293
+ paths.extend(glob.glob(os.path.join(output_dir, f"{basename}_training_plot_*.pdf")))
294
+ for path in paths:
295
+ try:
296
+ if os.path.isdir(path):
297
+ shutil.rmtree(path, ignore_errors=True)
298
+ elif os.path.exists(path):
299
+ os.remove(path)
300
+ except Exception as e:
301
+ if log_fn:
302
+ log_fn(f"Warning: could not delete autotune temp output {path}: {e}")
303
+
304
+ def _run_autotune_search(self, train_dataset, val_dataset, base_cfg, classes, augmentation_options, log_fn, progress_cb, check_stop):
305
+ num_runs = int(max(1, base_cfg.get("auto_tune_runs", 8)))
306
+ search_epochs = int(max(1, base_cfg.get("auto_tune_epochs", min(12, int(base_cfg.get("epochs", 20))))))
307
+ candidates = self._generate_autotune_candidates(base_cfg, num_runs)
308
+ trial_root = tempfile.mkdtemp(prefix="autotune_trials_", dir=os.path.dirname(self.output_path))
309
+ best_cfg = None
310
+ best_result = None
311
+ best_score = float("-inf")
312
+
313
+ log_fn(f"Auto-tune enabled: evaluating {len(candidates)} candidate config(s) for {search_epochs} epoch(s) each.")
314
+ log_fn("Search parameters: classification LR, weight decay, dropout, MAP num_heads, frame temporal layers"
315
+ + (", OvR label smoothing" if base_cfg.get("use_ovr", False) else ""))
316
+ try:
317
+ for idx, candidate in enumerate(candidates, start=1):
318
+ if check_stop():
319
+ return None, None
320
+ self._reset_runtime_dataset_caches(train_dataset)
321
+ self._reset_runtime_dataset_caches(val_dataset)
322
+ trial_cfg = copy.deepcopy(candidate)
323
+ trial_cfg["epochs"] = search_epochs
324
+ trial_output = os.path.join(trial_root, f"trial_{idx:02d}.pt")
325
+ log_fn(
326
+ f"[Auto-tune {idx}/{len(candidates)}] "
327
+ f"lr={trial_cfg['classification_lr']:.2e}, "
328
+ f"wd={trial_cfg['weight_decay']:.2e}, "
329
+ f"dropout={trial_cfg['dropout']:.2f}, "
330
+ f"heads={trial_cfg['head_kwargs'].get('num_heads', 4)}, "
331
+ f"layers={trial_cfg['frame_head_temporal_layers']}"
332
+ + (f", ovr_ls={trial_cfg.get('ovr_label_smoothing', 0.0):.2f}" if trial_cfg.get("use_ovr", False) else "")
333
+ )
334
+ model, head_kwargs_meta, dropout, localization_dropout, num_stages, multi_scale = self._build_model_for_config(
335
+ train_dataset, trial_cfg, log_fn
336
+ )
337
+ backend_cfg = self._build_backend_train_config(
338
+ trial_cfg,
339
+ trial_output,
340
+ classes,
341
+ augmentation_options,
342
+ head_kwargs_meta,
343
+ dropout,
344
+ localization_dropout,
345
+ num_stages,
346
+ multi_scale,
347
+ )
348
+ result = train_model(
349
+ model,
350
+ train_dataset,
351
+ val_dataset,
352
+ backend_cfg,
353
+ log_fn=log_fn,
354
+ progress_callback=progress_cb,
355
+ stop_callback=check_stop,
356
+ metrics_callback=None,
357
+ )
358
+ score = float(result.get("best_val_f1", 0.0) or 0.0)
359
+ acc = float(result.get("best_val_acc", 0.0) or 0.0)
360
+ log_fn(f"[Auto-tune {idx}/{len(candidates)}] best val F1={score:.2f}, val acc={acc:.2f}")
361
+ if score > best_score:
362
+ best_score = score
363
+ best_cfg = copy.deepcopy(trial_cfg)
364
+ best_result = dict(result)
365
+ self._cleanup_autotune_trial_outputs(trial_output, log_fn=log_fn)
366
+ return best_cfg, best_result
367
+ finally:
368
+ try:
369
+ shutil.rmtree(trial_root)
370
+ except Exception as e:
371
+ logger.debug("Could not remove autotune trial root: %s", e)
372
+
373
+ def _save_autotuned_profile(self, tuned_cfg, best_search_result, log_fn):
374
+ """Persist the selected auto-tune winner as a reusable training profile."""
375
+ profiles_path = self.train_config.get("training_profiles_path", "")
376
+ if not profiles_path:
377
+ return
378
+ try:
379
+ os.makedirs(os.path.dirname(profiles_path), exist_ok=True)
380
+ profiles = {}
381
+ if os.path.exists(profiles_path):
382
+ with open(profiles_path, "r", encoding="utf-8") as f:
383
+ loaded = json.load(f)
384
+ if isinstance(loaded, dict):
385
+ profiles = loaded
386
+
387
+ profile_cfg = copy.deepcopy(tuned_cfg)
388
+ profile_cfg["auto_tune_before_final"] = False
389
+ profile_cfg["auto_tune_selected_from_search"] = True
390
+ profile_cfg["auto_tune_selected_at"] = time.strftime("%Y-%m-%d %H:%M:%S")
391
+ if best_search_result:
392
+ profile_cfg["auto_tune_best_search_val_f1"] = float(best_search_result.get("best_val_f1", 0.0) or 0.0)
393
+ profile_cfg["auto_tune_best_search_val_acc"] = float(best_search_result.get("best_val_acc", 0.0) or 0.0)
394
+
395
+ base_name = os.path.splitext(os.path.basename(self.output_path))[0] or "training"
396
+ profile_name = f"autotune_{base_name}_{time.strftime('%Y%m%d_%H%M%S')}"
397
+ suffix = 2
398
+ while profile_name in profiles:
399
+ profile_name = f"autotune_{base_name}_{time.strftime('%Y%m%d_%H%M%S')}_{suffix}"
400
+ suffix += 1
401
+
402
+ profiles[profile_name] = profile_cfg
403
+ with open(profiles_path, "w", encoding="utf-8") as f:
404
+ json.dump(profiles, f, indent=2)
405
+ log_fn(f"Saved auto-tuned winner as training profile: {profile_name}")
406
+ except Exception as e:
407
+ log_fn(f"Warning: could not save auto-tuned profile: {e}")
408
+
409
+ def run(self):
410
+ """Run training."""
411
+ import traceback
412
+ try:
413
+ def log_fn(msg):
414
+ self.log_message.emit(msg)
415
+
416
+ def progress_cb(epoch, total):
417
+ self.progress.emit(epoch, total)
418
+
419
+ def metrics_cb(metrics):
420
+ self.epoch_complete.emit(metrics)
421
+
422
+ log_fn("Initializing training...")
423
+ log_fn(f"Annotation file: {self.annotation_file}")
424
+ log_fn(f"Clips directory: {self.clips_dir}")
425
+
426
+ annotation_manager = AnnotationManager(self.annotation_file)
427
+ labeled_clips = annotation_manager.get_labeled_clips()
428
+
429
+ if not labeled_clips:
430
+ error_msg = "No labeled clips found. Please label some clips first."
431
+ log_fn(f"ERROR: {error_msg}")
432
+ self.error.emit(error_msg)
433
+ return
434
+
435
+ classes = annotation_manager.get_classes()
436
+
437
+ log_fn("Using class-only training pipeline (hierarchical/attribution disabled).")
438
+
439
+ if len(classes) < 2:
440
+ error_msg = "Need at least 2 behavior classes for training."
441
+ log_fn(f"ERROR: {error_msg}")
442
+ self.error.emit(error_msg)
443
+ return
444
+
445
+ # Filter classes if class selection is enabled
446
+ use_ovr = self.train_config.get("use_ovr", False)
447
+ hybrid_ovr_bg = bool(self.train_config.get("ovr_background_as_negative", False) and use_ovr)
448
+ hybrid_bg_classes = set(self.train_config.get("ovr_background_class_names", []))
449
+ selected_classes = None
450
+ if self.train_config.get("limit_classes", False):
451
+ selected_classes = self.train_config.get("selected_classes", [])
452
+ if selected_classes:
453
+ classes = [c for c in classes if c in selected_classes]
454
+ log_fn(f"Limiting training to {len(classes)} selected classes: {classes}")
455
+ if len(classes) < 2:
456
+ error_msg = "Need at least 2 selected classes for training."
457
+ log_fn(f"ERROR: {error_msg}")
458
+ self.error.emit(error_msg)
459
+ return
460
+ # Filter clips: keep selected classes + near_negative clips when OvR
461
+ if use_ovr:
462
+ allowed = set(selected_classes)
463
+ if hybrid_ovr_bg:
464
+ allowed.update(hybrid_bg_classes)
465
+ labeled_clips = [
466
+ clip for clip in labeled_clips
467
+ if clip.get("label") in allowed
468
+ or clip.get("label", "").startswith("near_negative")
469
+ ]
470
+ else:
471
+ labeled_clips = [clip for clip in labeled_clips if clip.get("label") in classes]
472
+
473
+ # OvR mode: filter near_negative_* from class list (they become suppression-only)
474
+ if use_ovr:
475
+ real_classes = [c for c in classes if not c.startswith("near_negative")]
476
+ hn_classes = [c for c in classes if c.startswith("near_negative")]
477
+ bg_classes = [c for c in real_classes if c in hybrid_bg_classes] if hybrid_ovr_bg else []
478
+ if hybrid_ovr_bg and bg_classes:
479
+ real_classes = [c for c in real_classes if c not in bg_classes]
480
+ if len(real_classes) < 2:
481
+ error_msg = "OvR mode needs at least 2 non-near-negative classes."
482
+ log_fn(f"ERROR: {error_msg}")
483
+ self.error.emit(error_msg)
484
+ return
485
+ log_fn(f"OvR mode: {len(real_classes)} real classes, {len(hn_classes)} near-negative classes (suppression only)")
486
+ if hybrid_ovr_bg and bg_classes:
487
+ log_fn(
488
+ f"Hybrid OvR backgrounds: {bg_classes} kept as negative-only clips "
489
+ f"(not trained as explicit heads)"
490
+ )
491
+ classes = real_classes
492
+ # Near-negative clips stay in labeled_clips (they get label=-1 in the dataset)
493
+
494
+ log_fn(f"Found {len(labeled_clips)} labeled clips")
495
+ log_fn(f"Classes: {classes}")
496
+
497
+ from collections import Counter
498
+ label_counts = Counter([clip["label"] for clip in labeled_clips])
499
+ log_fn("Class distribution (before limiting):")
500
+ for label, count in sorted(label_counts.items()):
501
+ log_fn(f" {label}: {count} ({100.0*count/len(labeled_clips):.1f}%)")
502
+
503
+ # Multi-class breakdown
504
+ mc_combos: dict[tuple, int] = {}
505
+ mc_per_label: dict[str, int] = {}
506
+ exc_per_label: dict[str, int] = {}
507
+ for clip in labeled_clips:
508
+ lbl_list = clip.get("labels")
509
+ if not isinstance(lbl_list, list) or not lbl_list:
510
+ lbl_list = [clip.get("label", "")]
511
+ if len(lbl_list) > 1:
512
+ key = tuple(sorted(lbl_list))
513
+ mc_combos[key] = mc_combos.get(key, 0) + 1
514
+ for lbl in lbl_list:
515
+ mc_per_label[lbl] = mc_per_label.get(lbl, 0) + 1
516
+ else:
517
+ exc_per_label[lbl_list[0]] = exc_per_label.get(lbl_list[0], 0) + 1
518
+ if mc_combos:
519
+ total_mc = sum(mc_combos.values())
520
+ log_fn(f"Multi-class clips: {total_mc} of {len(labeled_clips)} ({100.0*total_mc/max(1,len(labeled_clips)):.1f}%)")
521
+ for combo, cnt in sorted(mc_combos.items(), key=lambda x: -x[1]):
522
+ log_fn(f" {' + '.join(combo)}: {cnt}")
523
+ for lbl in sorted(mc_per_label):
524
+ exc = exc_per_label.get(lbl, 0)
525
+ sh = mc_per_label[lbl]
526
+ log_fn(f" {lbl}: {exc} exclusive, {sh} in multi-class clips")
527
+
528
+ use_all_for_training = self.train_config.get("use_all_for_training", False)
529
+ val_split = self.train_config.get("val_split", 0.2)
530
+
531
+ def _split_train_val_clip_stratified(clips, split_ratio, seed=42):
532
+ """Split clips by randomly assigning ~split_ratio per class to val (stratified by clip label)."""
533
+ if not clips or split_ratio <= 0:
534
+ return list(clips), []
535
+ by_label = {}
536
+ for c in clips:
537
+ lbl = c.get("label")
538
+ by_label.setdefault(lbl, []).append(c)
539
+ rng = random.Random(seed)
540
+ train_out, val_out = [], []
541
+ for lbl, group in by_label.items():
542
+ rng.shuffle(group)
543
+ n_val = max(0, int(round(len(group) * split_ratio)))
544
+ n_val = min(n_val, len(group) - 1) if len(group) > 1 else 0
545
+ train_out.extend(group[n_val:])
546
+ val_out.extend(group[:n_val])
547
+ log_fn(f"Clip-stratified split: train={len(train_out)} clips, val={len(val_out)} clips")
548
+ if val_out:
549
+ log_fn(f" Val clip distribution: {dict(Counter(c.get('label') for c in val_out))}")
550
+ return train_out, val_out
551
+
552
+ def _split_train_val_frame_stratified(clips, classes, split_ratio, seed=42):
553
+ """Split by randomly selecting frames per class for val, then assign whole clips to val if they contain any selected frame.
554
+ Ensures roughly equal proportion of each class in validation."""
555
+ if not clips or split_ratio <= 0:
556
+ return list(clips), []
557
+ class_set = set(classes)
558
+ _clip_len = self.train_config.get("clip_length", 8)
559
+ # Pool (clip_idx, frame_idx) for every labeled frame.
560
+ # Clips without per-frame labels contribute _clip_len entries so their
561
+ # weight in stratification matches what validation will actually count.
562
+ pool_by_class = {c: [] for c in classes}
563
+ for clip_idx, clip in enumerate(clips):
564
+ fl = clip.get("frame_labels")
565
+ if not isinstance(fl, (list, tuple)) or len(fl) == 0:
566
+ primary = clip.get("label")
567
+ if primary in class_set:
568
+ for fi in range(_clip_len):
569
+ pool_by_class[primary].append((clip_idx, fi))
570
+ continue
571
+ for frame_idx, lbl in enumerate(fl):
572
+ if lbl in class_set:
573
+ pool_by_class[lbl].append((clip_idx, frame_idx))
574
+ rng = random.Random(seed)
575
+ val_frames = set()
576
+ for c in classes:
577
+ lst = pool_by_class[c]
578
+ if not lst:
579
+ continue
580
+ rng.shuffle(lst)
581
+ n_val = max(0, int(round(len(lst) * split_ratio)))
582
+ n_val = min(n_val, len(lst))
583
+ for i in range(n_val):
584
+ val_frames.add(lst[i])
585
+ val_clip_indices = {clip_idx for (clip_idx, _) in val_frames}
586
+ # Cap val clips so train gets enough data: at most split_ratio of clips in val.
587
+ max_val_clips = max(1, int(round(len(clips) * split_ratio)))
588
+ if len(val_clip_indices) > max_val_clips:
589
+ val_clip_list = list(val_clip_indices)
590
+ rng.shuffle(val_clip_list)
591
+ val_clip_indices = set(val_clip_list[:max_val_clips])
592
+ # Keep only frames belonging to clips that survived the cap.
593
+ val_frames = {f for f in val_frames if f[0] in val_clip_indices}
594
+ train_out = [c for i, c in enumerate(clips) if i not in val_clip_indices]
595
+ val_out = [c for i, c in enumerate(clips) if i in val_clip_indices]
596
+ log_fn(f"Frame-stratified split: {len(val_frames)} val frames across {len(val_clip_indices)} clips (cap ~{split_ratio:.0%} clips) → "
597
+ f"train={len(train_out)} clips, val={len(val_out)} clips")
598
+ if val_out:
599
+ val_label_counts = Counter()
600
+ for c in val_out:
601
+ fl = c.get("frame_labels")
602
+ if isinstance(fl, (list, tuple)):
603
+ for lbl in fl:
604
+ if lbl in class_set:
605
+ val_label_counts[lbl] += 1
606
+ else:
607
+ val_label_counts[c.get("label")] += 1
608
+ log_fn(f" Val frame distribution: {dict(val_label_counts)}")
609
+ return train_out, val_out
610
+
611
+ # Apply per-class limits to TRAINING ONLY (validation uses all remaining clips)
612
+ per_class_limits = self.train_config.get("per_class_limits", {})
613
+ per_class_val_limits = self.train_config.get("per_class_val_limits", {})
614
+ if per_class_limits:
615
+ log_fn("Applying per-class limits to TRAINING set only...")
616
+ log_fn("Validation will use ALL remaining clips (not limited)")
617
+
618
+ train_clips = []
619
+ val_clips = []
620
+
621
+ random.seed(42) # For reproducibility
622
+
623
+ # Check if embedding-based diversity selection is enabled
624
+ use_embedding_diversity = self.train_config.get("use_embedding_diversity", False)
625
+ backbone_model = self.train_config.get("backbone_model", "videoprism_public_v1_base")
626
+
627
+ resolution = self.train_config.get("resolution", 288)
628
+
629
+ # Initialize VideoPrism backbone if needed for diversity selection
630
+ backbone = None
631
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
632
+ if use_embedding_diversity:
633
+ try:
634
+ log_fn(f"Initializing VideoPrism for embedding-based diversity selection at {resolution}×{resolution}...")
635
+ backbone = VideoPrismBackbone(model_name=backbone_model, resolution=resolution)
636
+ backbone.eval()
637
+ backbone.to(device)
638
+ log_fn("VideoPrism loaded successfully")
639
+ except Exception as e:
640
+ log_fn(f"Failed to load VideoPrism: {e}")
641
+ log_fn(" Falling back to random selection")
642
+ use_embedding_diversity = False
643
+
644
+ def farthest_point_sampling(embeddings, n_samples, seed=42):
645
+ """Select n_samples using Farthest Point Sampling for maximum diversity."""
646
+ if len(embeddings) <= n_samples:
647
+ return list(range(len(embeddings)))
648
+
649
+ np.random.seed(seed)
650
+ embeddings = np.array(embeddings)
651
+ selected = []
652
+
653
+ # Start with random point
654
+ current = np.random.randint(0, len(embeddings))
655
+ selected.append(current)
656
+
657
+ # Track minimum distance to any selected point for each unselected point
658
+ min_distances = np.full(len(embeddings), np.inf)
+ min_distances[current] = 0  # seed point is already selected, so exclude it from the argmax
659
+
660
+ # Iteratively select farthest point from all selected
661
+ for _ in range(n_samples - 1):
662
+ # Update minimum distances to nearest selected point
663
+ for i in range(len(embeddings)):
664
+ if i not in selected:
665
+ # Distance to the most recently selected point
666
+ dist_to_current = np.linalg.norm(embeddings[i] - embeddings[current])
667
+ # Keep minimum distance to any selected point
668
+ min_distances[i] = min(min_distances[i], dist_to_current)
669
+
670
+ # Select point with maximum minimum distance (farthest from all selected)
671
+ current = np.argmax(min_distances)
672
+ selected.append(current)
673
+ min_distances[current] = 0 # Mark as selected
674
+
675
+ return selected
676
+
677
+ for label in sorted(set([clip["label"] for clip in labeled_clips])):
678
+ class_clips = [clip for clip in labeled_clips if clip.get("label") == label]
679
+
680
+ if label in per_class_limits:
681
+ raw_limit = per_class_limits[label]
682
+ limit = max(1, int(round(float(raw_limit))))
683
+ raw_val_limit = per_class_val_limits.get(label, float('inf'))
684
+ val_limit = max(1, int(round(float(raw_val_limit)))) if raw_val_limit != float('inf') else float('inf')
685
+ if len(class_clips) > limit:
686
+ log_fn(f" {label}: {len(class_clips)} total clips")
687
+
688
+ if use_embedding_diversity and backbone is not None:
689
+ try:
690
+ log_fn(f" Extracting orientation-invariant embeddings for diversity selection...")
691
+ embeddings = []
692
+ valid_clips = []
693
+
694
+ for clip_idx, clip in enumerate(class_clips):
695
+ if (clip_idx + 1) % 20 == 0:
696
+ log_fn(f" Processing clip {clip_idx + 1}/{len(class_clips)}...")
697
+
698
+ clip_id = clip.get("id", "")
699
+ clip_path = os.path.join(self.clips_dir, clip_id)
700
+
701
+ # Try to find clip file
702
+ if not os.path.exists(clip_path):
703
+ base_name, ext = os.path.splitext(clip_id)
704
+ for video_ext in ['.mp4', '.avi', '.mov', '.mkv']:
705
+ test_path = os.path.join(self.clips_dir, base_name + video_ext)
706
+ if os.path.exists(test_path):
707
+ clip_path = test_path
708
+ break
709
+
710
+ if not os.path.exists(clip_path):
711
+ continue
712
+
713
+ try:
714
+ # Load frames
715
+ frames_bgr = load_clip_frames(clip_path, target_size=(resolution, resolution))
716
+ if not frames_bgr:
717
+ continue
718
+
719
+ # Convert BGR to RGB and normalize
720
+ frames_rgb = [cv2.cvtColor(f, cv2.COLOR_BGR2RGB) for f in frames_bgr]
721
+ frames_array = np.stack(frames_rgb).astype(np.float32) / 255.0
722
+
723
+ # Extract embeddings with orientation augmentation (original + hflip + vflip + both)
724
+ embedding_list = []
725
+
726
+ for hflip, vflip in [(False, False), (True, False), (False, True), (True, True)]:
727
+ frames_aug = frames_array.copy()
728
+
729
+ # Apply flips
730
+ if hflip:
731
+ frames_aug = np.flip(frames_aug, axis=2) # Flip horizontally (width)
732
+ if vflip:
733
+ frames_aug = np.flip(frames_aug, axis=1) # Flip vertically (height)
734
+
735
+ # Copy after all flips to avoid negative strides (PyTorch requirement)
736
+ frames_aug = frames_aug.copy()
737
+
738
+ # Ensure correct shape: (T, H, W, C) -> (1, T, C, H, W) for VideoPrism
739
+ T, H, W, C = frames_aug.shape
740
+ frames_tensor = torch.from_numpy(frames_aug).permute(0, 3, 1, 2) # (T, C, H, W)
741
+ frames_tensor = frames_tensor.unsqueeze(0) # (1, T, C, H, W)
742
+
743
+ # Extract embedding
744
+ with torch.no_grad():
745
+ tokens = backbone(frames_tensor.to(device))
746
+ # Average pool tokens: (1, N, D) -> (D,)
747
+ emb = tokens.mean(dim=1).squeeze(0).cpu().numpy()
748
+ embedding_list.append(emb)
749
+ # Clear GPU tensors immediately
750
+ del tokens, frames_tensor
751
+
752
+ # Average embeddings from all orientations
753
+ embedding = np.mean(embedding_list, axis=0)
754
+
755
+ # L2 normalize for better diversity in angular space
756
+ embedding_norm = np.linalg.norm(embedding)
757
+ if embedding_norm > 0:
758
+ embedding = embedding / embedding_norm
759
+
760
+ embeddings.append(embedding)
761
+ valid_clips.append(clip)
762
+
763
+ # Clear intermediate variables
764
+ del embedding_list, frames_array, frames_rgb, frames_bgr
765
+
766
+ # Periodically clear CUDA cache to prevent accumulation
767
+ if (clip_idx + 1) % 50 == 0 and torch.cuda.is_available():
768
+ torch.cuda.empty_cache()
769
+ except Exception as e:
770
+ log_fn(f" Failed to extract embedding for {clip_id}: {e}")
771
+ continue
772
+
773
+ if len(embeddings) >= limit:
774
+ embeddings = np.array(embeddings)
775
+ log_fn(f" Using Farthest Point Sampling to select {limit} diverse clips...")
776
+ selected_indices = farthest_point_sampling(embeddings, limit, seed=42)
777
+ train_class_clips = [valid_clips[i] for i in selected_indices]
778
+ train_clip_ids = {id(c) for c in train_class_clips}
779
+ val_class_clips = [c for c in class_clips if id(c) not in train_clip_ids]
780
+ log_fn(f" Selected {len(train_class_clips)} diverse clips based on embeddings")
781
+ else:
782
+ log_fn(f" Only {len(embeddings)} valid embeddings, falling back to random")
783
+ random.shuffle(class_clips)
784
+ train_class_clips = class_clips[:limit]
785
+ val_class_clips = class_clips[limit:]
786
+ except Exception as e:
787
+ log_fn(f" Embedding-based selection failed: {e}")
788
+ log_fn(f" Falling back to random selection")
789
+ random.shuffle(class_clips)
790
+ train_class_clips = class_clips[:limit]
791
+ val_class_clips = class_clips[limit:]
792
+ else:
793
+ log_fn(f" Training: randomly selecting {limit} clips")
794
+ random.shuffle(class_clips)
795
+ train_class_clips = class_clips[:limit]
796
+ val_class_clips = class_clips[limit:]
797
+
798
+ # Apply validation limit if set
799
+ if len(val_class_clips) > val_limit:
800
+ log_fn(f" Validation: limiting to {val_limit} clips (from {len(val_class_clips)} available)")
801
+ random.shuffle(val_class_clips)
802
+ val_class_clips = val_class_clips[:val_limit]
803
+ else:
804
+ log_fn(f" Validation: using all remaining {len(val_class_clips)} clips")
805
+
806
+ train_clips.extend(train_class_clips)
807
+ val_clips.extend(val_class_clips)
808
+ else:
809
+ # Not enough clips to limit - use all for training, all for validation
810
+ log_fn(f" {label}: {len(class_clips)} total clips (below limit of {limit})")
811
+ log_fn(f" Training: using all {len(class_clips)} clips (can't limit below available)")
812
+ if use_all_for_training:
813
+ train_clips.extend(class_clips)
814
+ else:
815
+ # Keep train/val disjoint even for tiny classes.
816
+ # If there is only 1 clip, we cannot hold out validation without losing
817
+ # training signal, so we place it in training only.
818
+ if len(class_clips) <= 1:
819
+ train_clips.extend(class_clips)
820
+ log_fn(" Validation: 0 clips (class too small to split without overlap)")
821
+ else:
822
+ tmp = list(class_clips)
823
+ random.shuffle(tmp)
824
+ n = len(tmp)
825
+ n_val = int(round(n * float(val_split)))
826
+ if n_val <= 0:
827
+ n_val = 1
828
+ if n_val >= n:
829
+ n_val = n - 1
830
+ val_class_clips = tmp[:n_val]
831
+ train_class_clips = tmp[n_val:]
832
+ # Apply validation cap if requested
833
+ if val_limit != float('inf') and len(val_class_clips) > int(val_limit):
834
+ random.shuffle(val_class_clips)
835
+ val_class_clips = val_class_clips[:int(val_limit)]
836
+ train_clips.extend(train_class_clips)
837
+ val_clips.extend(val_class_clips)
838
+ log_fn(f" Validation: {len(val_class_clips)} clips (held out, no overlap)")
839
+ else:
840
+ # No limit for this class, split normally
841
+ if use_all_for_training:
842
+ train_clips.extend(class_clips)
843
+ else:
844
+ train_class, val_class = _split_train_val_clip_stratified(
845
+ class_clips,
846
+ val_split,
847
+ seed=42,
848
+ )
849
+ train_clips.extend(train_class)
850
+ val_clips.extend(val_class)
851
+
852
+ # Free GPU memory after diversity selection
853
+ if backbone is not None:
854
+ log_fn("Freeing VideoPrism backbone from GPU memory...")
855
+ del backbone
856
+ if torch.cuda.is_available():
857
+ torch.cuda.empty_cache()
858
+ log_fn("GPU memory freed")
859
+
860
+ log_fn(f"Final dataset sizes:")
861
+ log_fn(f" Training: {len(train_clips)} clips")
862
+ if not use_all_for_training:
863
+ log_fn(f" Validation: {len(val_clips)} clips")
864
+
865
+ def _log_split_distribution(clips, split_name):
866
+ counts = Counter([c["label"] for c in clips])
867
+ log_fn(f"{split_name} class distribution:")
868
+ mc = 0
869
+ for c in clips:
870
+ ll = c.get("labels")
871
+ if isinstance(ll, list) and len(ll) > 1:
872
+ mc += 1
873
+ for label, count in sorted(counts.items()):
874
+ log_fn(f" {label}: {count}")
875
+ if mc > 0:
876
+ log_fn(f" ({mc} multi-class clips in {split_name.lower()} set)")
877
+
878
+ _log_split_distribution(train_clips, "Training")
879
+ if not use_all_for_training:
880
+ _log_split_distribution(val_clips, "Validation")
881
+ else:
882
+ # No per-class limits, use standard train/val split
883
+ if use_all_for_training:
884
+ log_fn("Using all data for training (no validation split)")
885
+ train_clips = labeled_clips
886
+ val_clips = []
887
+ else:
888
+ log_fn(f"Splitting dataset into train/val (val_split={val_split:.1%}, stratified when possible)...")
889
+ has_frame_labels = any(
890
+ isinstance(c.get("frame_labels"), (list, tuple)) and len(c.get("frame_labels") or []) > 0
891
+ for c in labeled_clips
892
+ )
893
+ if has_frame_labels:
894
+ log_fn("Using frame-stratified split (equal proportion of each class in validation).")
895
+ train_clips, val_clips = _split_train_val_frame_stratified(
896
+ labeled_clips,
897
+ classes,
898
+ val_split,
899
+ seed=42,
900
+ )
901
+ else:
902
+ log_fn("Using clip-stratified split (no per-frame labels).")
903
+ train_clips, val_clips = _split_train_val_clip_stratified(
904
+ labeled_clips,
905
+ val_split,
906
+ seed=42,
907
+ )
908
+ log_fn(f"Train: {len(train_clips)}, Val: {len(val_clips)}")
909
+
910
+ log_fn(f"Train: {len(train_clips)} clips")
911
+
912
+ log_fn("Validating clip files exist...")
913
+ log_fn(f"Checking clips directory: {self.clips_dir}")
914
+
915
+ if not os.path.exists(self.clips_dir):
916
+ error_msg = f"Clips directory does not exist: {self.clips_dir}\n\nPlease check the clips directory path in the Training tab."
917
+ log_fn(f"ERROR: {error_msg}")
918
+ self.error.emit(error_msg)
919
+ return
920
+
921
+ video_extensions = ['.mp4', '.avi', '.mov', '.mkv', '.MP4', '.AVI', '.MOV', '.MKV']
922
+ all_video_files = []
923
+ for root, dirs, files in os.walk(self.clips_dir):
924
+ for file in files:
925
+ if any(file.lower().endswith(ext.lower()) for ext in video_extensions):
926
+ rel_path = os.path.relpath(os.path.join(root, file), self.clips_dir)
927
+ all_video_files.append(rel_path.replace('\\', '/'))
928
+
929
+ log_fn(f"Found {len(all_video_files)} video files in directory (including subdirectories)")
930
+ if len(all_video_files) > 0 and len(all_video_files) < 10:
931
+ log_fn(f"Video files found: {', '.join(all_video_files)}")
932
+
933
+ missing_clips = []
934
+
935
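+ # Resolve each training clip ID to an actual file, trying added extensions, bare basenames, and subdirectories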
+ for clip_info in train_clips:
936
+ clip_id = clip_info["id"]
937
+ clip_path = os.path.join(self.clips_dir, clip_id)
938
+ found = False
939
+
940
+ if os.path.exists(clip_path):
941
+ found = True
942
+ else:
943
+ base_name, ext = os.path.splitext(clip_id)
944
+ clip_basename = os.path.basename(clip_id)
945
+ clip_dir_part = os.path.dirname(clip_id) if os.path.dirname(clip_id) else None
946
+
947
+ if not ext:
948
+ for video_ext in video_extensions:
949
+ test_path = os.path.join(self.clips_dir, clip_id + video_ext)
950
+ if os.path.exists(test_path):
951
+ found = True
952
+ break
953
+ else:
954
+ base_name_only = os.path.basename(base_name)
955
+
956
+ for video_ext in video_extensions:
957
+ test_path = os.path.join(self.clips_dir, base_name + video_ext)
958
+ if os.path.exists(test_path):
959
+ found = True
960
+ break
961
+
962
+ if not found:
963
+ test_path = os.path.join(self.clips_dir, base_name_only + video_ext)
964
+ if os.path.exists(test_path):
965
+ found = True
966
+ break
967
+
968
+ if not found:
969
+ for root, dirs, files in os.walk(self.clips_dir):
970
+ for file in files:
971
+ file_base, file_ext = os.path.splitext(file)
972
+ if file_base == base_name_only or file_base == base_name:
973
+ if file_ext.lower() in [e.lower() for e in video_extensions]:
974
+ found = True
975
+ break
976
+ if found:
977
+ break
978
+
979
+ if not found and clip_dir_part:
980
+ subdir_path = os.path.join(self.clips_dir, clip_dir_part)
981
+ if os.path.exists(subdir_path):
982
+ for video_ext in video_extensions:
983
+ test_path = os.path.join(subdir_path, clip_basename)
984
+ if os.path.exists(test_path):
985
+ found = True
986
+ break
987
+ test_path = os.path.join(subdir_path, base_name_only + video_ext)
988
+ if os.path.exists(test_path):
989
+ found = True
990
+ break
991
+
992
+ if not found:
993
+ missing_clips.append(clip_id)
994
+
995
+ if missing_clips:
996
+ log_fn(f"WARNING: Found {len(missing_clips)} missing clip files")
997
+ log_fn("Searching for files in subdirectories...")
998
+
999
+ found_in_subdirs = {}
1000
+ sample_missing = missing_clips[:5]
1001
+
1002
+ for clip_id in sample_missing:
1003
+ base_name, ext = os.path.splitext(clip_id)
1004
+ for root, dirs, files in os.walk(self.clips_dir):
1005
+ for file in files:
1006
+ file_base, file_ext = os.path.splitext(file)
1007
+ if file_base == base_name or file_base == os.path.basename(base_name):
1008
+ if file_ext.lower() in [e.lower() for e in video_extensions]:
1009
+ rel_path = os.path.relpath(os.path.join(root, file), self.clips_dir)
1010
+ if clip_id not in found_in_subdirs:
1011
+ found_in_subdirs[clip_id] = rel_path
1012
+
1013
+ if found_in_subdirs:
1014
+ log_fn(f"Found {len(found_in_subdirs)} files in subdirectories:")
1015
+ for clip_id, found_path in list(found_in_subdirs.items())[:3]:
1016
+ log_fn(f" {clip_id} -> {found_path}")
1017
+ log_fn("SUGGESTION: Your clips are in subdirectories. Update the clip IDs in annotations.json")
1018
+ log_fn(" to include the subdirectory path, or move files to the root clips directory.")
1019
+
1020
+ error_msg = f"Found {len(missing_clips)} missing clip files (out of {len(train_clips)} total)\n\n"
1021
+ error_msg += f"Clips directory: {self.clips_dir}\n\n"
1022
+ error_msg += "Sample missing files:\n"
1023
+ for clip_id in missing_clips[:10]:
1024
+ error_msg += f" - {clip_id}\n"
1025
+ if len(missing_clips) > 10:
1026
+ error_msg += f" ... and {len(missing_clips) - 10} more\n"
1027
+
1028
+ if found_in_subdirs:
1029
+ error_msg += f"\nNOTE: Found {len(found_in_subdirs)} of the sample files in subdirectories.\n"
1030
+ error_msg += "Your annotation file may need to include subdirectory paths in clip IDs.\n"
1031
+ error_msg += "Example: 'subfolder/span10.avi' instead of 'span10.avi'\n"
1032
+
1033
+ error_msg += f"\nPlease:\n"
1034
+ error_msg += f"1. Check that clip files exist in: {self.clips_dir}\n"
1035
+ error_msg += f"2. Verify the clips directory path in the Training tab\n"
1036
+ error_msg += f"3. Update annotations.json to include correct paths if files are in subdirectories"
1037
+
1038
+ log_fn(f"ERROR: {error_msg}")
1039
+ self.error.emit(error_msg)
1040
+ return
1041
+
1042
+ log_fn(f"All {len(train_clips)} clips validated successfully")
1043
+
1044
+ log_fn("Creating datasets...")
1045
+ try:
1046
+ from singlebehaviorlab.backend.augmentations import ClipAugment
1047
+
1048
+ use_augmentation = self.train_config.get("use_augmentation", False)
1049
+ augmentation_options = self.train_config.get("augmentation_options", None)
1050
+ if not isinstance(augmentation_options, dict):
1051
+ augmentation_options = {}
1052
+ transform = None
1053
+ if use_augmentation:
1054
+ augmentation_defaults = {
1055
+ "use_horizontal_flip": True,
1056
+ "use_vertical_flip": False,
1057
+ "use_color_jitter": True,
1058
+ "use_gaussian_blur": True,
1059
+ "use_random_noise": True,
1060
+ "use_small_rotation": False,
1061
+ "use_speed_perturb": False,
1062
+ "use_random_shapes": False,
1063
+ "use_grayscale": False,
1064
+ "use_lighting_robustness": True,
1065
+ }
1066
+ for key, value in augmentation_defaults.items():
1067
+ if key not in augmentation_options:
1068
+ augmentation_options[key] = value
1069
+ if self.train_config.get("use_localization", False) and augmentation_options.get("use_small_rotation", False):
1070
+ # Rotation changes object coordinates; current spatial-label augmentation
1071
+ # only mirrors bboxes, so disable rotation for localization training.
1072
+ augmentation_options["use_small_rotation"] = False
1073
+ log_fn("Localization is enabled: disabling small rotation to keep bbox supervision aligned.")
1074
+ transform = ClipAugment(
1075
+ use_horizontal_flip=augmentation_options["use_horizontal_flip"],
1076
+ use_vertical_flip=augmentation_options["use_vertical_flip"],
1077
+ use_color_jitter=augmentation_options["use_color_jitter"],
1078
+ use_gaussian_blur=augmentation_options["use_gaussian_blur"],
1079
+ use_random_noise=augmentation_options["use_random_noise"],
1080
+ use_small_rotation=augmentation_options["use_small_rotation"],
1081
+ use_speed_perturb=augmentation_options.get("use_speed_perturb", False),
1082
+ use_random_shapes=augmentation_options.get("use_random_shapes", False),
1083
+ use_grayscale=augmentation_options.get("use_grayscale", False),
1084
+ use_lighting_robustness=augmentation_options.get("use_lighting_robustness", False),
1085
+ gaussian_blur_sigma=(0.1, 0.5),
1086
+ noise_std=0.02,
1087
+ rotation_degrees=5.0,
1088
+ )
1089
+ log_fn("Using selected data augmentation for training:")
1090
+ if augmentation_options["use_horizontal_flip"]:
1091
+ log_fn(" - Random horizontal flip")
1092
+ if augmentation_options["use_vertical_flip"]:
1093
+ log_fn(" - Random vertical flip")
1094
+ if augmentation_options["use_color_jitter"]:
1095
+ log_fn(" - Color jitter (brightness, contrast, saturation, hue)")
1096
+ if augmentation_options["use_gaussian_blur"]:
1097
+ log_fn(" - Gaussian blur (0.1-0.5 sigma)")
1098
+ if augmentation_options["use_random_noise"]:
1099
+ log_fn(" - Random noise (std=0.02)")
1100
+ if augmentation_options.get("use_speed_perturb", False):
1101
+ log_fn(" - Speed perturbation (0.7x-1.3x)")
1102
+ if augmentation_options.get("use_random_shapes", False):
1103
+ log_fn(" - Random shape overlays (occlusion)")
1104
+ if augmentation_options.get("use_grayscale", False):
1105
+ log_fn(" - Random grayscale (50% chance)")
1106
+ if augmentation_options.get("use_lighting_robustness", False):
1107
+ log_fn(" - Lighting/color robustness (gamma + channel gain)")
1108
+ if augmentation_options["use_small_rotation"]:
1109
+ log_fn(" - Small rotation (+/- 5 degrees)")
1110
+ log_fn(" - No cropping (content always preserved)")
1111
+ else:
1112
+ log_fn("No augmentation (using raw clips)")
1113
+
1114
+ clip_length = self.train_config.get("clip_length", 16)
1115
+ resolution = self.train_config.get("resolution", 288)
1116
+ log_fn(f"Using clip_length={clip_length} frames (center-cropped if clips are longer)")
1117
+ log_fn(f"Using input resolution: {resolution}x{resolution}")
1118
+
1119
+ # Virtual Dataset Expansion for small datasets
1120
+ num_train = len(train_clips)
1121
+ virtual_multiplier = 1
1122
+ if num_train > 0 and use_augmentation:
1123
+ virtual_multiplier = int(self.train_config.get("virtual_expansion", 5))
1124
+ log_fn(f"Small dataset detected ({num_train} clips). Using virtual expansion x{virtual_multiplier}")
1125
+ log_fn(f"Effective epoch size: {num_train * virtual_multiplier} samples (unique augmentations)")
1126
+
1127
+ stitch_prob = float(self.train_config.get("stitch_augmentation_prob", 0.0))
1128
+ emb_aug_versions = int(self.train_config.get("emb_aug_versions", 1))
1129
+ _multi_scale = self.train_config.get("multi_scale", False) and not self.train_config.get("use_localization", False)
1130
+ if self.train_config.get("use_localization", False):
1131
+ if stitch_prob > 0.0:
1132
+ stitch_prob = 0.0
1133
+ log_fn("Localization is enabled: disabling clip-stitch augmentation for this run.")
1134
+ aug_note = f" ({emb_aug_versions} aug version(s) per clip)" if emb_aug_versions > 1 else ""
1135
+ ms_note = " + short-scale (multi-scale)" if _multi_scale else ""
1136
+ log_fn(f"Embedding cache: always active{aug_note}{ms_note} — backbone skipped every training step")
1137
+ if stitch_prob > 0.0:
1138
+ log_fn(
1139
+ f"Clip-stitch augmentation: prob={stitch_prob:.0%}, fixed 50/50 split "
1140
+ f"(applied on cached embeddings during classification training)"
1141
+ )
1142
+
1143
+ use_crop_jitter = bool(self.train_config.get("crop_jitter", False))
1144
+ crop_jitter_strength = float(self.train_config.get("crop_jitter_strength", 0.15))
1145
+ if use_crop_jitter and self.train_config.get("use_localization", False):
1146
+ log_fn(f"Crop jitter: enabled (strength={crop_jitter_strength:.0%} of bbox size)")
1147
+
1148
+ train_dataset = BehaviorDataset(
1149
+ train_clips,
1150
+ annotation_manager,
1151
+ classes,
1152
+ self.clips_dir,
1153
+ transform=transform,
1154
+ target_size=(resolution, resolution),
1155
+ clip_length=clip_length,
1156
+ virtual_size_multiplier=virtual_multiplier,
1157
+ stitch_prob=stitch_prob,
1158
+ crop_jitter=bool(self.train_config.get("crop_jitter", False)),
1159
+ crop_jitter_strength=float(self.train_config.get("crop_jitter_strength", 0.15)),
1160
+ ovr_background_classes=self.train_config.get("ovr_background_class_names", []) if hybrid_ovr_bg else [],
1161
+ )
1162
+ log_fn(f"Train dataset created: {len(train_dataset)} virtual samples (from {num_train} unique clips)")
1163
+ except Exception as e:
1164
+ error_msg = f"Failed to create train dataset: {str(e)}\n{traceback.format_exc()}"
1165
+ log_fn(f"ERROR: {error_msg}")
1166
+ self.error.emit(error_msg)
1167
+ return
1168
+
1169
+ val_dataset = None
1170
+ if val_clips:
1171
+ try:
1172
+ clip_length = self.train_config.get("clip_length", 16)
1173
+ val_dataset = BehaviorDataset(
1174
+ val_clips,
1175
+ annotation_manager,
1176
+ classes,
1177
+ self.clips_dir,
1178
+ target_size=(resolution, resolution),
1179
+ clip_length=clip_length,
1180
+ ovr_background_classes=self.train_config.get("ovr_background_class_names", []) if hybrid_ovr_bg else [],
1181
+ )
1182
+ log_fn(f"Val dataset created: {len(val_dataset)} samples")
1183
+ except Exception as e:
1184
+ error_msg = f"Failed to create val dataset: {str(e)}\n{traceback.format_exc()}"
1185
+ log_fn(f"ERROR: {error_msg}")
1186
+ self.error.emit(error_msg)
1187
+ return
1188
+ else:
1189
+ log_fn("No validation dataset (using all data for training)")
1190
+
1191
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
1192
+
1193
+ def check_stop():
1194
+ return self.should_stop
1195
+
1196
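+ # Optional auto-tune: run short search trials on the current datasets, then retrain from scratch with the winning config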
+ if self.train_config.get("auto_tune_before_final", False):
1197
+ base_train_cfg = copy.deepcopy(self.train_config)
1198
+ if val_dataset is None:
1199
+ error_msg = "Auto-tune requires a validation set. Disable 'Use all data for training' or set a validation split."
1200
+ log_fn(f"ERROR: {error_msg}")
1201
+ self.error.emit(error_msg)
1202
+ return
1203
+ best_cfg, best_search_result = self._run_autotune_search(
1204
+ train_dataset,
1205
+ val_dataset,
1206
+ self.train_config,
1207
+ classes,
1208
+ augmentation_options,
1209
+ log_fn,
1210
+ progress_cb,
1211
+ check_stop,
1212
+ )
1213
+ if self.should_stop:
1214
+ log_fn("Auto-tune stopped.")
1215
+ self.training_complete.emit(0.0, 0.0, 0.0, {})
1216
+ self.finished.emit()
1217
+ return
1218
+ if not best_cfg:
1219
+ error_msg = "Auto-tune did not produce a valid candidate."
1220
+ log_fn(f"ERROR: {error_msg}")
1221
+ self.error.emit(error_msg)
1222
+ return
1223
+ self._reset_runtime_dataset_caches(train_dataset)
1224
+ self._reset_runtime_dataset_caches(val_dataset)
1225
+ self.train_config = copy.deepcopy(best_cfg)
1226
+ self.train_config["epochs"] = int(base_train_cfg.get("epochs", 1))
1227
+ log_fn("Auto-tune selected final config:")
1228
+ log_fn(
1229
+ f" lr={self.train_config.get('classification_lr', self.train_config.get('lr', 0.0)):.2e}, "
1230
+ f"wd={self.train_config.get('weight_decay', 0.0):.2e}, "
1231
+ f"dropout={self.train_config.get('dropout', 0.0):.2f}, "
1232
+ f"heads={self.train_config.get('head_kwargs', {}).get('num_heads', 4)}, "
1233
+ f"layers={self.train_config.get('frame_head_temporal_layers', 1)}"
1234
+ + (
1235
+ f", ovr_ls={self.train_config.get('ovr_label_smoothing', 0.0):.2f}"
1236
+ if self.train_config.get("use_ovr", False) else ""
1237
+ )
1238
+ )
1239
+ if best_search_result:
1240
+ log_fn(
1241
+ f" best search val F1={float(best_search_result.get('best_val_f1', 0.0) or 0.0):.2f}, "
1242
+ f"val acc={float(best_search_result.get('best_val_acc', 0.0) or 0.0):.2f}"
1243
+ )
1244
+ self._save_autotuned_profile(self.train_config, best_search_result, log_fn)
1245
+ log_fn("Starting final retrain from scratch with the selected config...")
1246
+
1247
+ try:
1248
+ model, head_kwargs_for_metadata, dropout, localization_dropout, num_stages, multi_scale = self._build_model_for_config(
1249
+ train_dataset, self.train_config, log_fn
1250
+ )
1251
+ except Exception as e:
1252
+ error_msg = f"Failed to create classifier: {str(e)}\n{traceback.format_exc()}"
1253
+ log_fn(f"ERROR: {error_msg}")
1254
+ self.error.emit(error_msg)
1255
+ return
1256
+
1257
+ train_config = self._build_backend_train_config(
1258
+ self.train_config,
1259
+ self.output_path,
1260
+ classes,
1261
+ augmentation_options,
1262
+ head_kwargs_for_metadata,
1263
+ dropout,
1264
+ localization_dropout,
1265
+ num_stages,
1266
+ multi_scale,
1267
+ )
1268
+
1269
+ log_fn("Starting training loop...")
1270
+
1271
+ def check_stop():
1272
+ return self.should_stop
1273
+
1274
+ try:
1275
+ result = train_model(
1276
+ model,
1277
+ train_dataset,
1278
+ val_dataset,
1279
+ train_config,
1280
+ log_fn=log_fn,
1281
+ progress_callback=progress_cb,
1282
+ stop_callback=check_stop,
1283
+ metrics_callback=metrics_cb
1284
+ )
1285
+
1286
+ if self.should_stop:
1287
+ log_fn("Training stopped.")
1288
+ self.training_complete.emit(0.0, 0.0, 0.0, {})
1289
+ else:
1290
+ log_fn("Training completed successfully!")
1291
+ best_val = result.get("best_val_acc", 0.0) if result else 0.0
1292
+ best_f1 = result.get("best_val_f1", 0.0) if result else 0.0
1293
+ final_train = result.get("final_train_acc", 0.0) if result else 0.0
1294
+ per_class_f1 = result.get("per_class_f1", {})
1295
+
1296
+ self.training_complete.emit(best_val, best_f1, final_train, per_class_f1)
1297
+
1298
+ self.finished.emit()
1299
+ except Exception as e:
1300
+ error_msg = f"Training failed: {str(e)}\n{traceback.format_exc()}"
1301
+ log_fn(f"ERROR: {error_msg}")
1302
+ self.error.emit(error_msg)
1303
+ return
1304
+
1305
+ except Exception as e:
1306
+ error_msg = f"Unexpected error: {str(e)}\n{traceback.format_exc()}"
1307
+ self.log_message.emit(f"ERROR: {error_msg}")
1308
+ self.error.emit(error_msg)
1309
+
1310
+
1311
+ class TrainingVisualizationDialog(QDialog):
1312
+ """Real-time training visualization with adaptive layout based on active metrics."""
1313
+
1314
+ _COLORS = {
1315
+ "train": "#2196F3",
1316
+ "val": "#FF9800",
1317
+ "train_cls": "#64B5F6",
1318
+ "macro_f1": "#4CAF50",
1319
+ "grid": "#e0e0e0",
1320
+ "bg": "#fafafa",
1321
+ "text": "#333333",
1322
+ "loc_iou": "#AB47BC",
1323
+ "loc_cerr": "#EF5350",
1324
+ "loc_vrate": "#26A69A",
1325
+ }
1326
+ _CLASS_PALETTE = [
1327
+ "#e6194b", "#3cb44b", "#4363d8", "#f58231", "#911eb4",
1328
+ "#42d4f4", "#f032e6", "#bfef45", "#fabed4", "#469990",
1329
+ "#dcbeff", "#9A6324", "#800000", "#aaffc3", "#808000",
1330
+ "#000075", "#a9a9a9",
1331
+ ]
1332
+
1333
+ def __init__(self, parent=None):
1334
+ super().__init__(parent)
1335
+ self.setWindowTitle("Training Monitor")
1336
+ self.resize(1150, 850)
1337
+
1338
+ root = QVBoxLayout()
1339
+ root.setContentsMargins(6, 6, 6, 6)
1340
+ self.setLayout(root)
1341
+
1342
+ top_bar = QHBoxLayout()
1343
+ self._status_label = QLabel("Waiting for first epoch...")
1344
+ self._status_label.setStyleSheet("font-weight: bold; font-size: 13px; padding: 4px 0;")
1345
+ top_bar.addWidget(self._status_label, stretch=1)
1346
+
1347
+ # F1 class filter: toggle which per-class lines are visible
1348
+ self._f1_filter_layout = QHBoxLayout()
1349
+ self._f1_filter_layout.setContentsMargins(0, 0, 0, 0)
1350
+ self._f1_filter_label = QLabel("F1 classes:")
1351
+ self._f1_filter_label.setStyleSheet("font-size: 11px; color: #666;")
1352
+ self._f1_filter_layout.addWidget(self._f1_filter_label)
1353
+ self._f1_filter_checks: dict[str, QCheckBox] = {}
1354
+ self._f1_filter_container = QWidget()
1355
+ self._f1_filter_container.setLayout(self._f1_filter_layout)
1356
+ self._f1_filter_container.setVisible(False)
1357
+
1358
+ root.addLayout(top_bar)
1359
+ root.addWidget(self._f1_filter_container)
1360
+
1361
+ # Horizontal splitter: charts left, crop preview right
1362
+ from PyQt6.QtWidgets import QSplitter, QScrollArea
1363
+ self._splitter = QSplitter(Qt.Orientation.Horizontal)
1364
+ root.addWidget(self._splitter, stretch=1)
1365
+
1366
+ # Left: matplotlib charts
1367
+ chart_widget = QWidget()
1368
+ chart_layout = QVBoxLayout()
1369
+ chart_layout.setContentsMargins(0, 0, 0, 0)
1370
+ chart_widget.setLayout(chart_layout)
1371
+ self.figure = Figure(figsize=(10, 8), dpi=100)
1372
+ self.figure.patch.set_facecolor(self._COLORS["bg"])
1373
+ self.canvas = FigureCanvas(self.figure)
1374
+ self.canvas.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
1375
+ chart_layout.addWidget(self.canvas)
1376
+ self._splitter.addWidget(chart_widget)
1377
+
1378
+ # Right: crop progress preview (hidden until localization is active)
1379
+ self._crop_panel = QWidget()
1380
+ crop_layout = QVBoxLayout()
1381
+ crop_layout.setContentsMargins(4, 0, 4, 0)
1382
+ self._crop_panel.setLayout(crop_layout)
1383
+
1384
+ crop_header = QLabel("Localization Crop Preview")
1385
+ crop_header.setStyleSheet("font-weight: bold; font-size: 12px; padding: 2px 0;")
1386
+ crop_header.setAlignment(Qt.AlignmentFlag.AlignCenter)
1387
+ crop_layout.addWidget(crop_header)
1388
+
1389
+ self._crop_epoch_label = QLabel("")
1390
+ self._crop_epoch_label.setStyleSheet("font-size: 11px; color: #666; padding: 0 0 4px 0;")
1391
+ self._crop_epoch_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
1392
+ crop_layout.addWidget(self._crop_epoch_label)
1393
+
1394
+ # Scrollable area for crop images
1395
+ scroll = QScrollArea()
1396
+ scroll.setWidgetResizable(True)
1397
+ scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff)
1398
+ scroll.setStyleSheet("QScrollArea { border: none; background: #fafafa; }")
1399
+ self._crop_container = QWidget()
1400
+ self._crop_container_layout = QVBoxLayout()
1401
+ self._crop_container_layout.setContentsMargins(0, 0, 0, 0)
1402
+ self._crop_container_layout.setSpacing(6)
1403
+ self._crop_container.setLayout(self._crop_container_layout)
1404
+ scroll.setWidget(self._crop_container)
1405
+ crop_layout.addWidget(scroll, stretch=1)
1406
+
1407
+ self._crop_panel.setVisible(False)
1408
+ self._splitter.addWidget(self._crop_panel)
1409
+ self._splitter.setStretchFactor(0, 3)
1410
+ self._splitter.setStretchFactor(1, 1)
1411
+
1412
+ self._axes = {}
1413
+ self._crop_labels: list = []
1414
+ self._init_data()
1415
+
1416
+ def _init_data(self):
1417
+ self.epochs = []
1418
+ self._confusion_warmup_epoch = 0
1419
+ self.train_acc = []
1420
+ self.val_acc = []
1421
+ self.train_loss = []
1422
+ self.train_loss_class = []
1423
+ self.val_loss = []
1424
+ self.val_f1 = []
1425
+ self.per_class_f1 = {}
1426
+ self.loc_iou = []
1427
+ self.loc_center_error = []
1428
+ self.loc_valid_rate = []
1429
+ self._has_localization = False
1430
+ self._has_validation = False
1431
+ self._phase = "classification"
1432
+ self._crop_progress_dir = None
1433
+ self._last_crop_epoch = -1
1434
+ self._f1_classes_to_show = None # None = show all; set of str = only these in F1 graph
1435
+
1436
+ def reset(self, f1_classes_to_show=None, confusion_warmup_epoch: int = 0):
1437
+ """Clear all data for a new training run.
1438
+
1439
+ f1_classes_to_show: when limit_classes is used, set of class names whose
1440
+ checkboxes start checked; None = all checked by default.
1441
+ confusion_warmup_epoch: epoch after which confusion sampler activates (0 = off/no line).
1442
+ """
1443
+ self._init_data()
1444
+ self._confusion_warmup_epoch = confusion_warmup_epoch
1445
+ self._f1_classes_to_show = f1_classes_to_show
1446
+ for cb in self._f1_filter_checks.values():
1447
+ cb.setParent(None)
1448
+ cb.deleteLater()
1449
+ self._f1_filter_checks.clear()
1450
+ self._f1_filter_container.setVisible(False)
1451
+ self.figure.clear()
1452
+ self._axes.clear()
1453
+ self.canvas.draw()
1454
+ self._status_label.setText("Waiting for first epoch...")
1455
+ self._crop_panel.setVisible(False)
1456
+ self._crop_epoch_label.setText("")
1457
+ self._clear_crop_images()
1458
+
1459
+ def _clear_crop_images(self):
1460
+ for lbl in self._crop_labels:
1461
+ lbl.setParent(None)
1462
+ lbl.deleteLater()
1463
+ self._crop_labels.clear()
1464
+
1465
+ # Layout helpers.
1466
+
1467
+ def _rebuild_layout(self):
1468
+ """Create subplot grid based on which metrics are active."""
1469
+ self.figure.clear()
1470
+ self._axes.clear()
1471
+
1472
+ panels = ["loss", "acc", "f1"]
1473
+ if self._has_localization:
1474
+ panels.append("loc")
1475
+
1476
+ n = len(panels)
1477
+ for i, key in enumerate(panels):
1478
+ ax = self.figure.add_subplot(n, 1, i + 1)
1479
+ ax.set_facecolor(self._COLORS["bg"])
1480
+ self._axes[key] = ax
1481
+
1482
+ def _style_ax(self, ax, title, ylabel, xlabel=None):
1483
+ ax.set_title(title, fontsize=10, fontweight="bold", color=self._COLORS["text"], pad=6)
1484
+ ax.set_ylabel(ylabel, fontsize=9, color=self._COLORS["text"])
1485
+ if xlabel:
1486
+ ax.set_xlabel(xlabel, fontsize=9, color=self._COLORS["text"])
1487
+ ax.grid(True, linewidth=0.5, color=self._COLORS["grid"], alpha=0.7)
1488
+ ax.tick_params(labelsize=8, colors=self._COLORS["text"])
1489
+ for spine in ax.spines.values():
1490
+ spine.set_color(self._COLORS["grid"])
1491
+
1492
+ # Crop preview.
1493
+
1494
+ def _update_crop_preview(self, epoch, crop_dir):
1495
+ """Load latest crop progress PNGs from disk and display them."""
1496
+ if not crop_dir:
1497
+ return
1498
+ self._crop_progress_dir = crop_dir
1499
+
1500
+ # Crop progress images are saved every 2 epochs as epoch_NNN_sample_M.png
1501
+ pattern = os.path.join(crop_dir, f"epoch_{epoch:03d}_sample_*.png")
1502
+ files = sorted(glob.glob(pattern))
1503
+ if not files:
1504
+ return
1505
+ if epoch == self._last_crop_epoch:
1506
+ return
1507
+ self._last_crop_epoch = epoch
1508
+
1509
+ self._clear_crop_images()
1510
+ self._crop_panel.setVisible(True)
1511
+ self._crop_epoch_label.setText(f"Epoch {epoch} · {len(files)} sample(s)")
1512
+
1513
+ panel_width = max(280, self._crop_panel.width() - 20)
1514
+ for fpath in files:
1515
+ try:
1516
+ pixmap = QPixmap(fpath)
1517
+ if pixmap.isNull():
1518
+ continue
1519
+ scaled = pixmap.scaledToWidth(panel_width, Qt.TransformationMode.SmoothTransformation)
1520
+ lbl = QLabel()
1521
+ lbl.setPixmap(scaled)
1522
+ lbl.setAlignment(Qt.AlignmentFlag.AlignCenter)
1523
+ lbl.setStyleSheet("border: 1px solid #ccc; border-radius: 3px; background: white; padding: 2px;")
1524
+ lbl.setToolTip(os.path.basename(fpath))
1525
+ self._crop_container_layout.addWidget(lbl)
1526
+ self._crop_labels.append(lbl)
1527
+ except Exception as e:
1528
+ logger.debug("Could not load crop preview image: %s", e)
1529
+
1530
+ # Main update.
1531
+
1532
+ def update_plots(self, metrics):
1533
+ """Update plots with new epoch metrics dict."""
1534
+ epoch = metrics["epoch"]
1535
+ self.epochs.append(epoch)
1536
+ self.train_acc.append(metrics["train_acc"])
1537
+ # In frame-level training, clip-level val_acc is not meaningful/updated.
1538
+ # Prefer val_frame_acc for monitor accuracy visualization.
1539
+ self.val_acc.append(metrics.get("val_frame_acc", metrics.get("val_acc", 0.0)))
1540
+ self.train_loss.append(metrics["train_loss"])
1541
+ self.train_loss_class.append(metrics.get("train_loss_class", 0.0))
1542
+ self.val_loss.append(metrics.get("val_loss", 0.0))
1543
+ self.val_f1.append(metrics.get("val_f1", 0.0))
1544
+ self._phase = metrics.get("training_phase", "classification")
1545
+
1546
+ if not self._has_validation and (
1547
+ any(v > 0 for v in self.val_acc)
1548
+ or any(v > 0 for v in self.val_loss)
1549
+ or any(v > 0 for v in self.val_f1)
1550
+ ):
1551
+ self._has_validation = True
1552
+
1553
+ # Localization metrics
1554
+ iou = metrics.get("loc_val_iou", 0.0)
1555
+ cerr = metrics.get("loc_val_center_error", 0.0)
1556
+ vrate = metrics.get("loc_val_valid_rate", 0.0)
1557
+ self.loc_iou.append(iou)
1558
+ self.loc_center_error.append(cerr)
1559
+ self.loc_valid_rate.append(vrate)
1560
+ if not self._has_localization and (iou > 0 or vrate > 0):
1561
+ self._has_localization = True
1562
+
1563
+ # Per-class F1
1564
+ for cls, score in metrics.get("per_class_f1", {}).items():
1565
+ if cls not in self.per_class_f1:
1566
+ self.per_class_f1[cls] = [0.0] * (len(self.epochs) - 1)
1567
+ self.per_class_f1[cls].append(score)
1568
+ if cls not in self._f1_filter_checks:
1569
+ checked = self._f1_classes_to_show is None or cls in self._f1_classes_to_show
1570
+ cb = QCheckBox(cls)
1571
+ cb.setChecked(checked)
1572
+ cb.setStyleSheet("font-size: 11px;")
1573
+ cb.stateChanged.connect(lambda _: self._redraw_f1_only())
1574
+ self._f1_filter_checks[cls] = cb
1575
+ self._f1_filter_layout.addWidget(cb)
1576
+ self._f1_filter_container.setVisible(True)
1577
+ for cls in self.per_class_f1:
1578
+ while len(self.per_class_f1[cls]) < len(self.epochs):
1579
+ self.per_class_f1[cls].append(0.0)
1580
+
1581
+ self._rebuild_layout()
1582
+ self._draw_loss()
1583
+ self._draw_acc()
1584
+ self._draw_f1()
1585
+ if self._has_localization and "loc" in self._axes:
1586
+ self._draw_loc()
1587
+
1588
+ # Only the bottom subplot gets an x-label
1589
+ panels = list(self._axes.values())
1590
+ if panels:
1591
+ panels[-1].set_xlabel("Epoch", fontsize=9, color=self._COLORS["text"])
1592
+
1593
+ self.figure.tight_layout(h_pad=1.2)
1594
+ self.canvas.draw()
1595
+
1596
+ # Crop preview: load latest images when available
1597
+ crop_dir = metrics.get("crop_progress_dir")
1598
+ if crop_dir:
1599
+ self._update_crop_preview(epoch, crop_dir)
1600
+
1601
+ # Status bar
1602
+ phase_tag = f"[{self._phase}]" if self._phase != "classification" else ""
1603
+ best_f1 = max(self.val_f1) if self.val_f1 else 0.0
1604
+ self._status_label.setText(
1605
+ f"Epoch {epoch} {phase_tag} | Train Loss: {self.train_loss[-1]:.4f} | "
1606
+ f"Val F1: {self.val_f1[-1]:.1f}% (best {best_f1:.1f}%) | "
1607
+ f"Val Frame Acc: {self.val_acc[-1]:.1f}%"
1608
+ )
1609
+
1610
+ # Individual panel drawers.
1611
+
1612
+ def _draw_loss(self):
1613
+ ax = self._axes["loss"]
1614
+ ax.plot(self.epochs, self.train_loss, color=self._COLORS["train"],
1615
+ linewidth=2, label="Train Loss")
1616
+ if any(v > 0 for v in self.train_loss_class):
1617
+ ax.plot(self.epochs, self.train_loss_class, color=self._COLORS["train_cls"],
1618
+ linewidth=1.3, linestyle="--", alpha=0.8, label="Classification Loss")
1619
+ if self._has_validation:
1620
+ ax.plot(self.epochs, self.val_loss, color=self._COLORS["val"],
1621
+ linewidth=2, label="Val Loss")
1622
+ self._style_ax(ax, "Loss", "Loss")
1623
+ ax.legend(fontsize=8, loc="upper right", framealpha=0.8)
1624
+
1625
+ def _draw_acc(self):
1626
+ ax = self._axes["acc"]
1627
+ ax.plot(self.epochs, self.train_acc, color=self._COLORS["train"],
1628
+ linewidth=2, label="Train Acc")
1629
+ if self._has_validation:
1630
+ ax.plot(self.epochs, self.val_acc, color=self._COLORS["val"],
1631
+ linewidth=2, label="Val Frame Acc")
1632
+ self._style_ax(ax, "Accuracy", "Accuracy (%)")
1633
+ ax.legend(fontsize=8, loc="lower right", framealpha=0.8)
1634
+
1635
+ def _redraw_f1_only(self):
1636
+ """Redraw just the F1 panel when class visibility toggles change."""
1637
+ if "f1" not in self._axes or not self.epochs:
1638
+ return
1639
+ self._axes["f1"].clear()
1640
+ self._draw_f1()
1641
+ self.canvas.draw()
1642
+
1643
+ def _draw_f1(self):
1644
+ ax = self._axes["f1"]
1645
+ ax.plot(self.epochs, self.val_f1, color=self._COLORS["macro_f1"],
1646
+ linewidth=2.5, label="Macro F1")
1647
+ visible_items = []
1648
+ for i, (cls, scores) in enumerate(self.per_class_f1.items()):
1649
+ cb = self._f1_filter_checks.get(cls)
1650
+ if cb is not None and not cb.isChecked():
1651
+ continue
1652
+ color = self._CLASS_PALETTE[i % len(self._CLASS_PALETTE)]
1653
+ ax.plot(self.epochs, scores, color=color,
1654
+ linewidth=1.2, linestyle="--", alpha=0.7, label=cls)
1655
+ visible_items.append(cls)
1656
+ if self._confusion_warmup_epoch > 0:
1657
+ ax.axvline(x=self._confusion_warmup_epoch, color="#9E9E9E",
1658
+ linestyle=":", linewidth=1.2, alpha=0.8)
1659
+ ax.text(self._confusion_warmup_epoch + 0.3, ax.get_ylim()[1] * 0.97,
1660
+ "confusion sampler", fontsize=6, color="#757575",
1661
+ va="top", ha="left", rotation=0)
1662
+ self._style_ax(ax, "Val Macro F1 (frame)", "F1 (%)")
1663
+ ncol = min(4, max(1, len(visible_items) + 1))
1664
+ ax.legend(fontsize=7, ncol=ncol, loc="lower right", framealpha=0.8)
1665
+
1666
+ def _draw_loc(self):
1667
+ ax = self._axes["loc"]
1668
+ ax.plot(self.epochs, self.loc_iou, color=self._COLORS["loc_iou"],
1669
+ linewidth=2, label="IoU")
1670
+ ax.plot(self.epochs, self.loc_valid_rate, color=self._COLORS["loc_vrate"],
1671
+ linewidth=2, label="Valid Rate")
1672
+ ax2 = ax.twinx()
1673
+ ax2.plot(self.epochs, self.loc_center_error, color=self._COLORS["loc_cerr"],
1674
+ linewidth=1.5, linestyle="--", alpha=0.8, label="Center Err")
1675
+ ax2.set_ylabel("Center Error", fontsize=8, color=self._COLORS["loc_cerr"])
1676
+ ax2.tick_params(labelsize=8, colors=self._COLORS["loc_cerr"])
1677
+ self._style_ax(ax, "Localization", "IoU / Valid Rate")
1678
+ lines1, labels1 = ax.get_legend_handles_labels()
1679
+ lines2, labels2 = ax2.get_legend_handles_labels()
1680
+ ax.legend(lines1 + lines2, labels1 + labels2, fontsize=8, loc="lower right", framealpha=0.8)
1681
+
1682
+ class TrainingWidget(QWidget):
1683
+ """Widget for training the behavior classifier."""
1684
+
1685
+ def __init__(self, config: dict):
1686
+ super().__init__()
1687
+ self.config = config
1688
+ self.augmentation_options = self._default_augmentation_options()
1689
+ self.worker = None
1690
+ self.annotation_manager = AnnotationManager(
1691
+ self.config.get(
1692
+ "training_annotation_file",
1693
+ self.config.get("annotation_file", "data/annotations/annotations.json"),
1694
+ )
1695
+ )
1696
+ self._config_initialized = False
1697
+ self.profile_dialog = None
1698
+ self.training_queue = []
1699
+ self.is_batch_training = False
1700
+ self.batch_results = []
1701
+ self.batch_results_path = None
1702
+ self.current_profile_name = None
1703
+ self.visualization_dialog = None
1704
+ self._resolution = int(self.config.get("resolution", 288))
1705
+ self._setup_ui()
1706
+ self._load_current_config(force=True)
1707
+ self.refresh_annotation_info()
1708
+
1709
+ def _load_current_config(self, force: bool = False):
1710
+ """Apply current config paths to UI."""
1711
+ if self._config_initialized and not force:
1712
+ return
1713
+
1714
+ annotation_file = self.config.get(
1715
+ "training_annotation_file",
1716
+ self.config.get("annotation_file", "data/annotations/annotations.json"),
1717
+ )
1718
+ clips_dir = self.config.get(
1719
+ "training_clips_dir",
1720
+ self.config.get("clips_dir", "data/clips"),
1721
+ )
1722
+ models_dir = self.config.get("models_dir", "models/behavior_heads")
1723
+
1724
+ if hasattr(self, "annotation_file_edit"):
1725
+ self.annotation_file_edit.setText(annotation_file)
1726
+ if hasattr(self, "clips_dir_edit"):
1727
+ self.clips_dir_edit.setText(clips_dir)
1728
+ if hasattr(self, "output_path_edit"):
1729
+ default_output = os.path.join(models_dir, "head.pt")
1730
+ self.output_path_edit.setText(default_output)
1731
+ self._resolution = int(self.config.get("resolution", 288))
1732
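+ # Round the configured resolution down to the nearest multiple of 18, e.g. 300 -> 288.
+ # (The multiple-of-18 constraint is assumed to come from the backbone patch/stride.)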
+ if self._resolution % 18 != 0:
1733
+ self._resolution = (self._resolution // 18) * 18
1734
+ if hasattr(self, "weight_decay_spin"):
1735
+ self.weight_decay_spin.setValue(float(self.config.get("default_weight_decay", 0.001)))
1736
+ if hasattr(self, "use_supcon_check"):
1737
+ default_use_supcon = bool(self.config.get("default_use_supcon_loss", False))
1738
+ self.use_supcon_check.setChecked(default_use_supcon)
1739
+ if hasattr(self, "supcon_weight_spin"):
1740
+ self.supcon_weight_spin.setValue(float(self.config.get("default_supcon_weight", 0.2)))
1741
+ self.supcon_weight_spin.setEnabled(default_use_supcon)
1742
+ if hasattr(self, "supcon_temp_spin"):
1743
+ self.supcon_temp_spin.setValue(float(self.config.get("default_supcon_temperature", 0.1)))
1744
+ self.supcon_temp_spin.setEnabled(default_use_supcon)
1745
+
1746
+ self._config_initialized = True
1747
+
1748
+ def _setup_ui(self):
1749
+ """Setup UI components."""
1750
+ layout = QVBoxLayout()
1751
+
1752
+ # Dataset Info with scrollbar
1753
+ info_group = QGroupBox("Dataset info")
1754
+ info_layout = QVBoxLayout()
1755
+ self.info_label = QLabel("Loading...")
1756
+ self.info_label.setWordWrap(True)
1757
+ info_layout.addWidget(self.info_label)
1758
+ info_group.setLayout(info_layout)
1759
+
1760
+ info_scroll = QScrollArea()
1761
+ info_scroll.setWidget(info_group)
1762
+ info_scroll.setWidgetResizable(True)
1763
+ info_scroll.setMinimumHeight(120)
1764
+ info_scroll.setMaximumHeight(200)
1765
+
1766
+ # Training Configuration inside a scrollable container with grouped sections
1767
+ config_container = QWidget()
1768
+ config_vbox = QVBoxLayout(config_container)
1769
+ config_vbox.setContentsMargins(2, 2, 2, 2)
1770
+ config_vbox.setSpacing(6)
1771
+
1772
+ # --- Paths & Files ---
1773
+ paths_group = QGroupBox("Paths && Files")
1774
+ config_layout = QFormLayout()
1775
+
1776
+ self.annotation_file_edit = QLineEdit()
1777
+ self.annotation_file_edit.setText(
1778
+ self.config.get(
1779
+ "training_annotation_file",
1780
+ self.config.get("annotation_file", "data/annotations/annotations.json"),
1781
+ )
1782
+ )
1783
+ self.annotation_browse_btn = QPushButton("Browse...")
1784
+ self.annotation_browse_btn.clicked.connect(self._browse_annotation)
1785
+ annotation_layout = QHBoxLayout()
1786
+ annotation_layout.addWidget(self.annotation_file_edit)
1787
+ annotation_layout.addWidget(self.annotation_browse_btn)
1788
+ config_layout.addRow("Annotation file:", annotation_layout)
1789
+
1790
+ self.clips_dir_edit = QLineEdit()
1791
+ self.clips_dir_edit.setText(
1792
+ self.config.get(
1793
+ "training_clips_dir",
1794
+ self.config.get("clips_dir", "data/clips"),
1795
+ )
1796
+ )
1797
+ self.clips_browse_btn = QPushButton("Browse...")
1798
+ self.clips_browse_btn.clicked.connect(self._browse_clips_dir)
1799
+ clips_layout = QHBoxLayout()
1800
+ clips_layout.addWidget(self.clips_dir_edit)
1801
+ clips_layout.addWidget(self.clips_browse_btn)
1802
+ config_layout.addRow("Clips directory:", clips_layout)
1803
+
1804
+ self.output_path_edit = QLineEdit()
1805
+ self.output_path_edit.setText(self.config.get("models_dir", "models/behavior_heads") + "/head.pt")
1806
+ self.output_browse_btn = QPushButton("Browse...")
1807
+ self.output_browse_btn.clicked.connect(self._browse_output)
1808
+ output_layout = QHBoxLayout()
1809
+ output_layout.addWidget(self.output_path_edit)
1810
+ output_layout.addWidget(self.output_browse_btn)
1811
+ config_layout.addRow("Output model:", output_layout)
1812
+
1813
+ paths_group.setLayout(config_layout)
1814
+
1815
+ # --- Training Hyperparameters ---
1816
+ hyper_group = QGroupBox("Training Hyperparameters")
1817
+ config_layout = QFormLayout()
1818
+
1819
+ clip_length_layout = QHBoxLayout()
1820
+ self.clip_length_spin = QSpinBox()
1821
+ self.clip_length_spin.setRange(1, 64)
1822
+ self.clip_length_spin.setValue(int(self.config.get("default_clip_length", 8)))
1823
+ self.clip_length_spin.setToolTip(
1824
+ "Number of frames to use for training.\n"
1825
+ "Can be equal to or less than the actual clip length.\n"
1826
+ "If less, the middle N frames are used (temporal center-crop)."
1827
+ )
1828
+ clip_length_layout.addWidget(self.clip_length_spin)
1829
+
1830
+ info_btn = QToolButton()
1831
+ info_btn.setText("?")
1832
+ info_btn.setToolButtonStyle(Qt.ToolButtonStyle.ToolButtonTextOnly)
1833
+ info_btn.setStyleSheet("""
1834
+ QToolButton {
1835
+ background-color: #4A90E2;
1836
+ color: white;
1837
+ border: none;
1838
+ border-radius: 10px;
1839
+ width: 20px;
1840
+ height: 20px;
1841
+ font-weight: bold;
1842
+ font-size: 12px;
1843
+ }
1844
+ QToolButton:hover {
1845
+ background-color: #357ABD;
1846
+ }
1847
+ """)
1848
+ info_btn.setFixedSize(20, 20)
1849
+ info_btn.setToolTip("Click for information about frames per clip")
1850
+ info_btn.clicked.connect(self._show_clip_length_info)
1851
+ clip_length_layout.addWidget(info_btn)
1852
+ clip_length_layout.addStretch()
1853
+
1854
+ config_layout.addRow("Frames per clip:", clip_length_layout)
1855
+
1856
+ self.batch_size_spin = QSpinBox()
1857
+ self.batch_size_spin.setRange(1, 32)
1858
+ self.batch_size_spin.setValue(8)
1859
+ config_layout.addRow("Batch size:", self.batch_size_spin)
1860
+
1861
+ self.epochs_spin = QSpinBox()
1862
+ self.epochs_spin.setRange(1, 1000)
1863
+ self.epochs_spin.setValue(60)
1864
+ config_layout.addRow("Epochs:", self.epochs_spin)
1865
+
1866
+ lr_row = QHBoxLayout()
1867
+ self.loc_lr_spin = QDoubleSpinBox()
1868
+ self.loc_lr_spin.setRange(1e-6, 1.0)
1869
+ self.loc_lr_spin.setValue(1e-4)
1870
+ self.loc_lr_spin.setDecimals(6)
1871
+ self.loc_lr_spin.setSingleStep(1e-4)
1872
+ self.loc_lr_spin.setToolTip("Learning rate used during localization phase (localization head only).")
1873
+ lr_row.addWidget(self.loc_lr_spin)
1874
+ lr_row.addWidget(QLabel("Loc LR"))
1875
+ self.class_lr_spin = QDoubleSpinBox()
1876
+ self.class_lr_spin.setRange(1e-6, 1.0)
1877
+ self.class_lr_spin.setValue(1e-4)
1878
+ self.class_lr_spin.setDecimals(6)
1879
+ self.class_lr_spin.setSingleStep(1e-4)
1880
+ self.class_lr_spin.setToolTip("Learning rate used during classification phase (backbone + MAP head).")
1881
+ lr_row.addWidget(self.class_lr_spin)
1882
+ lr_row.addWidget(QLabel("Class LR"))
1883
+ lr_row.addStretch()
1884
+ config_layout.addRow("Learning rates:", lr_row)
1885
+
1886
+ self.weight_decay_spin = QDoubleSpinBox()
1887
+ self.weight_decay_spin.setRange(0.0, 1.0)
1888
+ self.weight_decay_spin.setValue(float(self.config.get("default_weight_decay", 0.001)))
1889
+ self.weight_decay_spin.setDecimals(6)
1890
+ config_layout.addRow("Weight decay:", self.weight_decay_spin)
1891
+
1892
+ hyper_group.setLayout(config_layout)
1893
+
1894
+ # --- Model Architecture ---
1895
+ arch_group = QGroupBox("Model Architecture")
1896
+ config_layout = QFormLayout()
1897
+
1898
+ self.map_num_heads_spin = QSpinBox()
1899
+ self.map_num_heads_spin.setRange(1, 16)
1900
+ self.map_num_heads_spin.setValue(4)
1901
+ config_layout.addRow("MAP num_heads:", self.map_num_heads_spin)
1902
+
1903
+ self.proj_dim_spin = QSpinBox()
1904
+ self.proj_dim_spin.setRange(64, 1024)
1905
+ self.proj_dim_spin.setSingleStep(64)
1906
+ self.proj_dim_spin.setValue(256)
1907
+ self.proj_dim_spin.setToolTip("Projection dimension for spatial attention pool in the frame head.")
1908
+ config_layout.addRow("Spatial pool proj dim:", self.proj_dim_spin)
1909
+
1910
+ self.use_multi_scale_check = QCheckBox("Multi-scale temporal context")
1911
+ self.use_multi_scale_check.setToolTip(
1912
+ "Cache backbone embeddings at two temporal scales: full fps and half fps\n"
1913
+ "(same clip duration, fewer frames). The temporal head sees both fine-grained\n"
1914
+ "local motion and broader context per frame, improving precision for subtle\n"
1915
+ "and short behaviors.\n\n"
1916
+ "Doubles backbone precomputation time. Requires clip_length ≥ 4.\n"
1917
+ "Disabled when localization is active."
1918
+ )
1919
+ self.use_multi_scale_check.setChecked(False)
1920
+ config_layout.addRow("", self.use_multi_scale_check)
1921
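+ # Example of the two cached temporal scales described in the tooltip above
+ # (frame counts illustrative): clip_length = 16 gives a full-rate stream of 16
+ # frames plus a half-rate stream of 8 frames spanning the same duration.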
+
1922
+ self.map_dropout_spin = QDoubleSpinBox()
1923
+ self.map_dropout_spin.setRange(0.0, 0.9)
1924
+ self.map_dropout_spin.setValue(0.3)
1925
+ self.map_dropout_spin.setDecimals(2)
1926
+ self.map_dropout_spin.setToolTip(
1927
+ "Dropout for classification head (spatial pool + temporal TCN + boundary head).\n"
1928
+ "0.2-0.3 recommended. Higher values may hurt temporal conv layers.\n"
1929
+ "Localization head dropout is fixed at 0.0."
1930
+ )
1931
+ config_layout.addRow("Classification head dropout:", self.map_dropout_spin)
1932
+
1933
+ # Class-balanced loss weights
1934
+ self.use_class_weights_check = QCheckBox("Use class-balanced loss weights")
1935
+ self.use_class_weights_check.setToolTip(
1936
+ "Weight the loss by inverse class frequency to compensate for class imbalance.\n"
1937
+ "Rare classes get higher loss weight. Recommended when class counts are uneven."
1938
+ )
1939
+ self.use_class_weights_check.setChecked(False)
1940
+ config_layout.addRow("", self.use_class_weights_check)
1941
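+ # A common inverse-frequency weighting consistent with the tooltip (assumed form,
+ # not necessarily the backend's exact normalization):
+ #   weight_c = total_count / (num_classes * count_c)
+ # so a class with half the average count gets roughly double the loss weight.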
+
1942
+ self.use_supcon_check = QCheckBox("Use supervised contrastive loss on MAP embeddings")
1943
+ self.use_supcon_check.setToolTip(
1944
+ "Add an auxiliary supervised contrastive loss on the attention-pooled\n"
1945
+ "frame embeddings before the final classifier. Can be used with or without\n"
1946
+ "the temporal decoder."
1947
+ )
1948
+ self.use_supcon_check.setChecked(False)
1949
+ self.use_supcon_check.stateChanged.connect(
1950
+ lambda state: (
1951
+ self.supcon_weight_spin.setEnabled(bool(state)),
1952
+ self.supcon_temp_spin.setEnabled(bool(state)),
1953
+ )
1954
+ )
1955
+ config_layout.addRow("", self.use_supcon_check)
1956
+
1957
+ self.supcon_weight_spin = QDoubleSpinBox()
1958
+ self.supcon_weight_spin.setRange(0.0, 5.0)
1959
+ self.supcon_weight_spin.setSingleStep(0.05)
1960
+ self.supcon_weight_spin.setValue(0.2)
1961
+ self.supcon_weight_spin.setDecimals(2)
1962
+ self.supcon_weight_spin.setEnabled(False)
1963
+ self.supcon_weight_spin.setToolTip("Weight of the supervised contrastive loss term.")
1964
+ config_layout.addRow("SupCon weight:", self.supcon_weight_spin)
1965
+
1966
+ self.supcon_temp_spin = QDoubleSpinBox()
1967
+ self.supcon_temp_spin.setRange(0.01, 2.0)
1968
+ self.supcon_temp_spin.setSingleStep(0.01)
1969
+ self.supcon_temp_spin.setValue(0.10)
1970
+ self.supcon_temp_spin.setDecimals(2)
1971
+ self.supcon_temp_spin.setEnabled(False)
1972
+ self.supcon_temp_spin.setToolTip("Temperature used in the supervised contrastive loss.")
1973
+ config_layout.addRow("SupCon temperature:", self.supcon_temp_spin)
1974
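+ # Reference form of supervised contrastive loss (Khosla et al., 2020) that the
+ # weight and temperature above control; a sketch, not necessarily the backend's code:
+ #   L_i = -1/|P(i)| * sum_{p in P(i)} log( exp(z_i . z_p / tau) / sum_{a != i} exp(z_i . z_a / tau) )
+ # where P(i) are the same-class samples in the batch and tau is the temperature.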
+
1975
+ # Frame head settings (always active)
1976
+ self.use_temporal_decoder_check = QCheckBox("Use temporal decoder / refinement head")
1977
+ self.use_temporal_decoder_check.setToolTip(
1978
+ "Enable the temporal MS-TCN decoder after spatial attention pooling.\n"
1979
+ "If disabled, training uses the simpler baseline: spatial attention pooling\n"
1980
+ "+ direct per-frame classifier, while still producing framewise predictions."
1981
+ )
1982
+ self.use_temporal_decoder_check.setChecked(True)
1983
+ self.use_temporal_decoder_check.stateChanged.connect(self._on_temporal_decoder_toggled)
1984
+ config_layout.addRow("", self.use_temporal_decoder_check)
1985
+
1986
+ self.frame_head_layers_spin = QSpinBox()
1987
+ self.frame_head_layers_spin.setRange(1, 8)
1988
+ self.frame_head_layers_spin.setValue(4)
1989
+ self.frame_head_layers_spin.setToolTip(
1990
+ "Number of dilated temporal conv layers in the frame head.\n"
1991
+ "Higher values increase temporal receptive field."
1992
+ )
1993
+ config_layout.addRow("Frame head temporal layers:", self.frame_head_layers_spin)
1994
+
1995
+ self.temporal_pool_spin = QSpinBox()
1996
+ self.temporal_pool_spin.setRange(1, 4)
1997
+ self.temporal_pool_spin.setValue(1)
1998
+ self.temporal_pool_spin.setToolTip(
1999
+ "Average this many adjacent frames before temporal classification.\n"
2000
+ "1 = per-frame (no pooling), 2 = average pairs."
2001
+ )
2002
+ config_layout.addRow("Temporal pool (frames):", self.temporal_pool_spin)
2003
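+ # Example: with 16 frames per clip and pool = 2, the temporal head operates on
+ # 8 averaged steps; per-frame outputs are assumed to be broadcast back to the
+ # original 16 frames.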
+
2004
+ # Boundary loss
2005
+ self.boundary_loss_weight_spin = QDoubleSpinBox()
2006
+ self.boundary_loss_weight_spin.setRange(0.0, 5.0)
2007
+ self.boundary_loss_weight_spin.setSingleStep(0.1)
2008
+ self.boundary_loss_weight_spin.setValue(0.3)
2009
+ self.boundary_loss_weight_spin.setDecimals(2)
2010
+ self.boundary_loss_weight_spin.setToolTip("Weight of boundary detection loss (change-point prediction).")
2011
+ config_layout.addRow("Boundary loss weight:", self.boundary_loss_weight_spin)
2012
+
2013
+ self.boundary_tolerance_spin = QSpinBox()
2014
+ self.boundary_tolerance_spin.setRange(0, 10)
2015
+ self.boundary_tolerance_spin.setValue(2)
2016
+ self.boundary_tolerance_spin.setToolTip("Tolerance in frames around labeled transition for boundary target.")
2017
+ config_layout.addRow("Boundary tolerance:", self.boundary_tolerance_spin)
2018
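+ # Example of the boundary target: with tolerance = 2, a labeled transition at
+ # frame 10 marks frames 8-12 as positives for the change-point (boundary) head.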
+
2019
+ # Smoothness loss
2020
+ self.smoothness_loss_weight_spin = QDoubleSpinBox()
2021
+ self.smoothness_loss_weight_spin.setRange(0.0, 1.0)
2022
+ self.smoothness_loss_weight_spin.setSingleStep(0.01)
2023
+ self.smoothness_loss_weight_spin.setValue(0.05)
2024
+ self.smoothness_loss_weight_spin.setDecimals(3)
2025
+ self.smoothness_loss_weight_spin.setToolTip("Weight for temporal smoothness regularization on frame predictions.")
2026
+ config_layout.addRow("Smoothness loss weight:", self.smoothness_loss_weight_spin)
2027
+
2028
+ # Bout balance weighting
2029
+ self.use_bout_balance_check = QCheckBox("Use bout balance weighting")
2030
+ self.use_bout_balance_check.setToolTip(
2031
+ "Weight each frame inversely to its contiguous segment length.\n"
2032
+ "Prevents long bouts from dominating the loss over short actions."
2033
+ )
2034
+ self.use_bout_balance_check.setChecked(True)
2035
+ config_layout.addRow("", self.use_bout_balance_check)
2036
+
2037
+ self.bout_balance_power_spin = QDoubleSpinBox()
2038
+ self.bout_balance_power_spin.setRange(0.1, 2.0)
2039
+ self.bout_balance_power_spin.setSingleStep(0.1)
2040
+ self.bout_balance_power_spin.setValue(1.0)
2041
+ self.bout_balance_power_spin.setDecimals(1)
2042
+ self.bout_balance_power_spin.setToolTip(
2043
+ "Exponent for bout balance weighting: weight = segment_length ^ (-power).\n"
2044
+ "1.0 = linear (default, most aggressive).\n"
2045
+ "0.5 = square-root (softer).\n"
2046
+ "Higher = shorter segments get relatively more weight."
2047
+ )
2048
+ config_layout.addRow("Bout balance power:", self.bout_balance_power_spin)
2049
+ self.use_bout_balance_check.stateChanged.connect(
2050
+ lambda s: self.bout_balance_power_spin.setEnabled(bool(s))
2051
+ )
2052
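+ # Worked example of the formula from the tooltip (weight = segment_length ** -power):
+ #   power = 1.0 -> a 20-frame bout weighs 0.05 per frame, a 2-frame bout weighs 0.5
+ #   power = 0.5 -> the same bouts weigh ~0.22 and ~0.71, a softer rebalancing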
+
2053
+ arch_group.setLayout(config_layout)
2054
+
2055
+ # --- Localization ---
2056
+ loc_group = QGroupBox("Localization (requires bbox labels)")
2057
+ self.loc_group = loc_group
2058
+ loc_group.setEnabled(False)
2059
+ loc_group.setToolTip(
2060
+ "Optional: localizes individual animals when multiple are in camera view.\n"
2061
+ "To activate, draw and save bounding boxes on at least some clips in the Labeling tab."
2062
+ )
2063
+ config_layout = QFormLayout()
2064
+
2065
+ self.use_localization_check = QCheckBox("Use Localization Supervision (bbox)")
2066
+ self.use_localization_check.setToolTip(
2067
+ "Autonomous 2-stage training: first learn localization from bbox labels, then train classifier on localized crops."
2068
+ )
2069
+ self.use_localization_check.setChecked(False)
2070
+ config_layout.addRow("", self.use_localization_check)
2071
+
2072
+ self.use_manual_loc_switch_check = QCheckBox("Manual switch to classification epoch")
2073
+ self.use_manual_loc_switch_check.setToolTip(
2074
+ "If enabled, switch from localization phase to classification phase at the selected epoch."
2075
+ )
2076
+ self.use_manual_loc_switch_check.setChecked(False)
2077
+ self.use_manual_loc_switch_check.stateChanged.connect(
2078
+ lambda state: self.manual_loc_switch_epoch_spin.setEnabled(
2079
+ bool(state) and self.use_localization_check.isChecked()
2080
+ )
2081
+ )
2082
+ config_layout.addRow("", self.use_manual_loc_switch_check)
2083
+
2084
+ self.manual_loc_switch_epoch_spin = QSpinBox()
2085
+ self.manual_loc_switch_epoch_spin.setRange(1, 10000)
2086
+ self.manual_loc_switch_epoch_spin.setValue(20)
2087
+ self.manual_loc_switch_epoch_spin.setEnabled(False)
2088
+ self.manual_loc_switch_epoch_spin.setToolTip(
2089
+ "Epoch number at which localization stops and classification starts."
2090
+ )
2091
+ config_layout.addRow("Switch epoch:", self.manual_loc_switch_epoch_spin)
2092
+
2093
+ self.use_localization_check.stateChanged.connect(
2094
+ lambda state: self.manual_loc_switch_epoch_spin.setEnabled(
2095
+ bool(state) and self.use_manual_loc_switch_check.isChecked()
2096
+ )
2097
+ )
2098
+ self.use_localization_check.stateChanged.connect(self._on_localization_toggled)
2099
+
2100
+ self.crop_padding_spin = QDoubleSpinBox()
2101
+ self.crop_padding_spin.setRange(-0.45, 1.0)
2102
+ self.crop_padding_spin.setSingleStep(0.05)
2103
+ self.crop_padding_spin.setDecimals(2)
2104
+ self.crop_padding_spin.setValue(0.35)
2105
+ self.crop_padding_spin.setToolTip(
2106
+ "Fractional padding around the predicted bbox for classification crops.\n"
2107
+ "Positive: expand crop (0.35 = 70% larger than bbox).\n"
2108
+ "Negative: shrink/zoom in (-0.2 = crop is 60% of bbox, zooming into center).\n"
2109
+ "Use 0.05-0.10 for tight crops, negative for extreme close-ups."
2110
+ )
2111
+ config_layout.addRow("Crop padding:", self.crop_padding_spin)
2112
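+ # Assuming the padding is applied on each side of the bbox (as the tooltip's
+ # numbers imply), the crop scales as bbox_size * (1 + 2 * padding):
+ #   padding = 0.35 -> crop is 1.7x the bbox (70% larger)
+ #   padding = -0.20 -> crop is 0.6x the bbox (zoomed into the center)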
+
2113
+ self.use_crop_jitter_check = QCheckBox("Crop jitter augmentation")
2114
+ self.use_crop_jitter_check.setToolTip(
2115
+ "Randomly shift the crop center during training to prevent\n"
2116
+ "the model from memorizing background/location cues.\n"
2117
+ "Only available when localization is enabled."
2118
+ )
2119
+ self.use_crop_jitter_check.setChecked(False)
2120
+ self.use_crop_jitter_check.setEnabled(False)
2121
+ config_layout.addRow("", self.use_crop_jitter_check)
2122
+
2123
+ self.crop_jitter_strength_spin = QDoubleSpinBox()
2124
+ self.crop_jitter_strength_spin.setRange(0.01, 0.5)
2125
+ self.crop_jitter_strength_spin.setSingleStep(0.05)
2126
+ self.crop_jitter_strength_spin.setDecimals(2)
2127
+ self.crop_jitter_strength_spin.setValue(0.15)
2128
+ self.crop_jitter_strength_spin.setEnabled(False)
2129
+ self.crop_jitter_strength_spin.setToolTip(
2130
+ "Max random shift as a fraction of bbox size.\n"
2131
+ "0.15 = shift up to 15% of bbox width/height in any direction."
2132
+ )
2133
+ config_layout.addRow("Crop jitter strength:", self.crop_jitter_strength_spin)
2134
+
2135
+ self.use_crop_jitter_check.stateChanged.connect(
2136
+ lambda s: self.crop_jitter_strength_spin.setEnabled(bool(s))
2137
+ )
2138
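+ # Illustrative jitter magnitude (values hypothetical; the sampling lives in the backend):
+ #   strength = 0.15 on a 200x120 px bbox -> center shifted by up to ±30 px
+ #   horizontally and ±18 px vertically before the padded crop is taken.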
+
2139
+ loc_group.setLayout(config_layout)
2140
+
2141
+ # --- OvR & Hard Mining ---
2142
+ ovr_group = QGroupBox("OvR && Hard Mining")
2143
+ config_layout = QFormLayout()
2144
+
2145
+ self.use_ovr_check = QCheckBox("One-vs-Rest heads (OvR)")
2146
+ self.use_ovr_check.setToolTip(
2147
+ "Train independent binary heads per class instead of shared softmax.\n"
2148
+ "Better when 'other'/background is heterogeneous.\n"
2149
+ "near_negative_* clips automatically suppress their matched class.\n"
2150
+ "At inference: sigmoid scores + per-class threshold → Ignore if all low."
2151
+ )
2152
+ self.use_ovr_check.setChecked(False)
2153
+ self.use_ovr_check.stateChanged.connect(self._on_ovr_toggled)
2154
+ config_layout.addRow("", self.use_ovr_check)
2155
+
2156
+ self.ovr_background_negative_check = QCheckBox("Treat Other/background as OvR negatives")
2157
+ self.ovr_background_negative_check.setToolTip(
2158
+ "Hybrid OvR mode: remove helper classes like Other/Background from the trained heads,\n"
2159
+ "but keep those clips as all-zero negative supervision for the target heads.\n"
2160
+ "Useful when Other is heterogeneous and hurts learning real behaviors."
2161
+ )
2162
+ self.ovr_background_negative_check.setChecked(False)
2163
+ self.ovr_background_negative_check.setEnabled(False)
2164
+ config_layout.addRow("", self.ovr_background_negative_check)
2165
+
2166
+ self.ovr_label_smoothing_spin = QDoubleSpinBox()
2167
+ self.ovr_label_smoothing_spin.setRange(0.0, 0.3)
2168
+ self.ovr_label_smoothing_spin.setSingleStep(0.01)
2169
+ self.ovr_label_smoothing_spin.setValue(0.05)
2170
+ self.ovr_label_smoothing_spin.setDecimals(2)
2171
+ self.ovr_label_smoothing_spin.setToolTip(
2172
+ "Smooth binary targets from [0,1] to [eps, 1-eps].\n"
2173
+ "Prevents overconfident predictions and improves generalization.\n"
2174
+ "0 = no smoothing, 0.05 = recommended default, 0.1+ = strong regularization."
2175
+ )
2176
+ config_layout.addRow("OvR label smoothing:", self.ovr_label_smoothing_spin)
2177
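+ # Example of the [eps, 1-eps] smoothing from the tooltip, with eps = 0.05:
+ #   positive target 1.0 -> 0.95, negative target 0.0 -> 0.05
+ #   (i.e. target * (1 - 2 * eps) + eps, assuming symmetric smoothing).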
+
2178
+ self.use_asl_check = QCheckBox("Use Asymmetric Loss (ASL)")
2179
+ self.use_asl_check.setToolTip(
2180
+ "Asymmetric Loss down-weights easy negatives more aggressively than positives.\n"
2181
+ "Recommended for OvR training where negatives dominate.\n"
2182
+ "γ- controls negative focusing, γ+ controls positive focusing."
2183
+ )
2184
+ self.use_asl_check.setChecked(True)
2185
+ self.use_asl_check.setEnabled(False)
2186
+ self.use_asl_check.stateChanged.connect(self._on_asl_toggled)
2187
+ config_layout.addRow("", self.use_asl_check)
2188
+
2189
+ asl_params_layout = QHBoxLayout()
2190
+ self.asl_gamma_neg_spin = QDoubleSpinBox()
2191
+ self.asl_gamma_neg_spin.setRange(0.0, 10.0)
2192
+ self.asl_gamma_neg_spin.setSingleStep(0.5)
2193
+ self.asl_gamma_neg_spin.setValue(2.0)
2194
+ self.asl_gamma_neg_spin.setDecimals(1)
2195
+ self.asl_gamma_neg_spin.setToolTip("Focusing parameter for negative samples (higher = more suppression)")
2196
+ asl_params_layout.addWidget(QLabel("γ-:"))
2197
+ asl_params_layout.addWidget(self.asl_gamma_neg_spin)
2198
+
2199
+ self.asl_gamma_pos_spin = QDoubleSpinBox()
2200
+ self.asl_gamma_pos_spin.setRange(0.0, 10.0)
2201
+ self.asl_gamma_pos_spin.setSingleStep(0.5)
2202
+ self.asl_gamma_pos_spin.setValue(0.0)
2203
+ self.asl_gamma_pos_spin.setDecimals(1)
2204
+ self.asl_gamma_pos_spin.setToolTip("Focusing parameter for positive samples (0 = no down-weighting)")
2205
+ asl_params_layout.addWidget(QLabel("γ+:"))
2206
+ asl_params_layout.addWidget(self.asl_gamma_pos_spin)
2207
+
2208
+ self.asl_clip_spin = QDoubleSpinBox()
2209
+ self.asl_clip_spin.setRange(0.0, 0.5)
2210
+ self.asl_clip_spin.setSingleStep(0.01)
2211
+ self.asl_clip_spin.setValue(0.05)
2212
+ self.asl_clip_spin.setDecimals(2)
2213
+ self.asl_clip_spin.setToolTip("Probability margin for hard thresholding negatives")
2214
+ asl_params_layout.addWidget(QLabel("clip:"))
2215
+ asl_params_layout.addWidget(self.asl_clip_spin)
2216
+
2217
+ self.asl_gamma_neg_spin.setEnabled(False)
2218
+ self.asl_gamma_pos_spin.setEnabled(False)
2219
+ self.asl_clip_spin.setEnabled(False)
2220
+ config_layout.addRow("ASL parameters:", asl_params_layout)
2221
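+ # Reference form of Asymmetric Loss (Ridnik et al., 2021) that these parameters
+ # map onto; a sketch, not necessarily the exact backend implementation:
+ #   p_m   = max(p - clip, 0)                    # shifted probability for negatives
+ #   L_pos = (1 - p) ** gamma_pos * log(p)
+ #   L_neg = p_m ** gamma_neg * log(1 - p_m)     # easy negatives are down-weighted
+ #   loss  = -(y * L_pos + (1 - y) * L_neg)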
+
2222
+ # Confusion-aware sampler (OvR only)
2223
+ self.use_confusion_sampler_check = QCheckBox("Confusion-aware hard mining")
2224
+ self.use_confusion_sampler_check.setToolTip(
2225
+ "After each epoch, clips that trigger the wrong OvR heads (e.g. groom clips\n"
2226
+ "that also activate dig's head) are sampled more often in the next epoch.\n"
2227
+ "Helps the model learn hard distinctions between visually similar behaviours.\n"
2228
+ "Temperature: how sharply to focus on the hardest clips (higher = more aggressive)."
2229
+ )
2230
+ self.use_confusion_sampler_check.setChecked(True)
2231
+ self.use_confusion_sampler_check.setEnabled(False)
2232
+ def _on_confusion_sampler_toggled(state):
2233
+ on = bool(state) and self.use_ovr_check.isChecked()
2234
+ self.confusion_temperature_spin.setEnabled(on)
2235
+ self.confusion_warmup_spin.setEnabled(on)
2236
+ self.use_confusion_sampler_check.stateChanged.connect(_on_confusion_sampler_toggled)
2237
+ config_layout.addRow("", self.use_confusion_sampler_check)
2238
+
2239
+ confusion_temp_layout = QHBoxLayout()
2240
+ self.confusion_temperature_spin = QDoubleSpinBox()
2241
+ self.confusion_temperature_spin.setRange(0.5, 8.0)
2242
+ self.confusion_temperature_spin.setSingleStep(0.5)
2243
+ self.confusion_temperature_spin.setValue(2.0)
2244
+ self.confusion_temperature_spin.setDecimals(1)
2245
+ self.confusion_temperature_spin.setToolTip(
2246
+ "Sharpness of the hard-mining distribution.\n"
2247
+ "1.0 = mild (all clips get similar weight).\n"
2248
+ "2.0 = moderate (default).\n"
2249
+ "4.0+ = aggressive (almost only hardest clips sampled)."
2250
+ )
2251
+ self.confusion_temperature_spin.setEnabled(False)
2252
+ confusion_temp_layout.addWidget(self.confusion_temperature_spin)
2253
+ config_layout.addRow("Confusion temperature:", confusion_temp_layout)
2254
+
2255
+ confusion_warmup_layout = QHBoxLayout()
2256
+ self.confusion_warmup_spin = QSpinBox()
2257
+ self.confusion_warmup_spin.setRange(0, 80)
2258
+ self.confusion_warmup_spin.setSingleStep(5)
2259
+ self.confusion_warmup_spin.setValue(20)
2260
+ self.confusion_warmup_spin.setSuffix("%")
2261
+ self.confusion_warmup_spin.setToolTip(
2262
+ "Percentage of total epochs to use uniform sampling before\n"
2263
+ "activating confusion-based hard mining.\n"
2264
+ "0% = active from start, 20% = default warmup."
2265
+ )
2266
+ self.confusion_warmup_spin.setEnabled(False)
2267
+ confusion_warmup_layout.addWidget(self.confusion_warmup_spin)
2268
+ config_layout.addRow("Confusion warmup:", confusion_warmup_layout)
2269
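+ # Hypothetical sketch of how temperature could sharpen sampling (the real scoring
+ # is computed in the training backend from per-clip OvR confusion):
+ #   w_i = confusion_score_i ** temperature;  p_i = w_i / sum(w)
+ # temperature 1.0 keeps the distribution mild, 4.0+ concentrates it on the hardest clips.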
+
2270
+ self.use_hard_pair_mining_check = QCheckBox("Hard-pair mining")
2271
+ self.use_hard_pair_mining_check.setToolTip(
2272
+ "Add extra pairwise margin pressure for specific confusing class pairs.\n"
2273
+ "Use this for cases like rear vs digg where standard OvR is not enough."
2274
+ )
2275
+ self.use_hard_pair_mining_check.setChecked(False)
2276
+ self.use_hard_pair_mining_check.setEnabled(False)
2277
+ self.use_hard_pair_mining_check.stateChanged.connect(self._on_hard_pair_toggled)
2278
+ config_layout.addRow("", self.use_hard_pair_mining_check)
2279
+
2280
+ self.hard_pair_edit = QLineEdit()
2281
+ self.hard_pair_edit.setPlaceholderText("rear:digg, move:digg")
2282
+ self.hard_pair_edit.setEnabled(False)
2283
+ self.hard_pair_edit.setToolTip(
2284
+ "Comma-separated hard pairs in the form class_a:class_b.\n"
2285
+ "Example: rear:digg, move:digg"
2286
+ )
2287
+ config_layout.addRow("Hard pairs:", self.hard_pair_edit)
2288
+
2289
+ self.hard_pair_loss_weight_spin = QDoubleSpinBox()
2290
+ self.hard_pair_loss_weight_spin.setRange(0.0, 5.0)
2291
+ self.hard_pair_loss_weight_spin.setSingleStep(0.05)
2292
+ self.hard_pair_loss_weight_spin.setValue(0.2)
2293
+ self.hard_pair_loss_weight_spin.setDecimals(3)
2294
+ self.hard_pair_loss_weight_spin.setEnabled(False)
2295
+ self.hard_pair_loss_weight_spin.setToolTip(
2296
+ "Weight of the extra pair-margin loss added on top of the main frame loss."
2297
+ )
2298
+ config_layout.addRow("Hard-pair loss weight:", self.hard_pair_loss_weight_spin)
2299
+
2300
+ self.hard_pair_margin_spin = QDoubleSpinBox()
2301
+ self.hard_pair_margin_spin.setRange(0.0, 5.0)
2302
+ self.hard_pair_margin_spin.setSingleStep(0.05)
2303
+ self.hard_pair_margin_spin.setValue(0.5)
2304
+ self.hard_pair_margin_spin.setDecimals(2)
2305
+ self.hard_pair_margin_spin.setEnabled(False)
2306
+ self.hard_pair_margin_spin.setToolTip(
2307
+ "Required logit gap between the true class and its configured rival."
2308
+ )
2309
+ config_layout.addRow("Hard-pair margin:", self.hard_pair_margin_spin)
2310
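+ # A typical pairwise hinge consistent with the tooltip (illustrative only):
+ #   pair_loss = max(0, margin - (logit_true - logit_rival))
+ # added to the frame loss, scaled by the hard-pair loss weight above.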
+
2311
+ self.hard_pair_confusion_boost_spin = QDoubleSpinBox()
2312
+ self.hard_pair_confusion_boost_spin.setRange(1.0, 5.0)
2313
+ self.hard_pair_confusion_boost_spin.setSingleStep(0.1)
2314
+ self.hard_pair_confusion_boost_spin.setValue(1.5)
2315
+ self.hard_pair_confusion_boost_spin.setDecimals(2)
2316
+ self.hard_pair_confusion_boost_spin.setEnabled(False)
2317
+ self.hard_pair_confusion_boost_spin.setToolTip(
2318
+ "Extra multiplier for confusion-sampler scores when the top rival is a configured hard pair."
2319
+ )
2320
+ config_layout.addRow("Hard-pair sampler boost:", self.hard_pair_confusion_boost_spin)
2321
+
2322
+ self.use_weighted_sampler_check = QCheckBox("Use weighted random sampler")
2323
+ self.use_weighted_sampler_check.setToolTip("Oversample rare classes during training")
2324
+ self.use_weighted_sampler_check.setChecked(False)
2325
+ config_layout.addRow("", self.use_weighted_sampler_check)
2326
+
2327
+ ovr_group.setLayout(config_layout)
2328
+
2329
+ # --- Augmentation & Data ---
2330
+ data_group = QGroupBox("Augmentation && Data")
2331
+ config_layout = QFormLayout()
2332
+
2333
+ self.use_augmentation_check = QCheckBox("Use data augmentation")
2334
+ self.use_augmentation_check.setToolTip("Apply selected augmentations to training clips")
2335
+ self.use_augmentation_check.setChecked(False)
2336
+ self.use_augmentation_check.stateChanged.connect(self._on_use_augmentation_changed)
2337
+
2338
+ self.augmentation_options_btn = QPushButton("Augmentation options...")
2339
+ self.augmentation_options_btn.setToolTip("Choose which augmentations to apply during training")
2340
+ self.augmentation_options_btn.setEnabled(False)
2341
+ self.augmentation_options_btn.clicked.connect(self._open_augmentation_options_dialog)
2342
+
2343
+ augmentation_row = QHBoxLayout()
2344
+ augmentation_row.addWidget(self.use_augmentation_check)
2345
+ augmentation_row.addWidget(self.augmentation_options_btn)
2346
+ config_layout.addRow("", augmentation_row)
2347
+
2348
+ self.virtual_expansion_spin = QSpinBox()
2349
+ self.virtual_expansion_spin.setRange(1, 20)
2350
+ self.virtual_expansion_spin.setValue(5)
2351
+ self.virtual_expansion_spin.setEnabled(False)
2352
+ self.virtual_expansion_spin.setToolTip(
2353
+ "Virtual dataset expansion multiplier (only active when augmentation is on).\n"
2354
+ "Each unique clip is sampled this many times per epoch with different augmentations.\n"
2355
+ "Higher = more augmented variety per epoch; useful for small datasets."
2356
+ )
2357
+ config_layout.addRow("Virtual expansion (x):", self.virtual_expansion_spin)
2358
+ self.use_augmentation_check.stateChanged.connect(
2359
+ lambda s: self.virtual_expansion_spin.setEnabled(bool(s))
2360
+ )
2361
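+ # Example: 120 unique clips with expansion 5 yield 600 virtual samples per epoch,
+ # each drawn with an independently sampled augmentation.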
+
2362
+ self.use_stitch_check = QCheckBox("Clip-stitch augmentation")
2363
+ self.use_stitch_check.setToolTip(
2364
+ "Splice two clips from different classes with a fixed 50/50 split.\n"
2365
+ "Teaches the model per-frame behavior regardless of clip-level context,\n"
2366
+ "which improves inference on clips containing multiple behaviors."
2367
+ )
2368
+ self.use_stitch_check.setChecked(False)
2369
+ config_layout.addRow("", self.use_stitch_check)
2370
+
2371
+ self.stitch_prob_spin = QDoubleSpinBox()
2372
+ self.stitch_prob_spin.setRange(0.0, 1.0)
2373
+ self.stitch_prob_spin.setSingleStep(0.05)
2374
+ self.stitch_prob_spin.setValue(0.3)
2375
+ self.stitch_prob_spin.setDecimals(2)
2376
+ self.stitch_prob_spin.setEnabled(False)
2377
+ self.stitch_prob_spin.setToolTip(
2378
+ "Probability per sample of applying clip-stitch augmentation (0–1).\n"
2379
+ "0.3 means ~30% of training samples will be stitched mixed clips."
2380
+ )
2381
+ config_layout.addRow("Stitch probability:", self.stitch_prob_spin)
2382
+
2383
+ self.emb_aug_versions_spin = QSpinBox()
2384
+ self.emb_aug_versions_spin.setRange(1, 20)
2385
+ self.emb_aug_versions_spin.setValue(5)
2386
+ self.emb_aug_versions_spin.setEnabled(False)
2387
+ self.emb_aug_versions_spin.setToolTip(
2388
+ "Number of augmented embedding versions to pre-compute per clip.\n"
2389
+ "Each version applies a different random augmentation before the backbone.\n"
2390
+ "Higher = more diversity (larger cache, slower precompute, same training speed).\n"
2391
+ "Requires augmentation to be enabled. Recommended: 3–8."
2392
+ )
2393
+ config_layout.addRow("Cached aug versions:", self.emb_aug_versions_spin)
2394
+
2395
+ def _sync_emb_cache_controls():
2396
+ loc_on = self.use_localization_check.isChecked()
2397
+ aug_on = self.use_augmentation_check.isChecked()
2398
+ # Cached aug versions only useful when augmentation is on
2399
+ self.emb_aug_versions_spin.setEnabled(aug_on and not loc_on)
2400
+ # Multi-scale requires no localization
2401
+ self.use_multi_scale_check.setEnabled(not loc_on)
2402
+ if loc_on:
2403
+ self.use_multi_scale_check.setChecked(False)
2404
+
2405
+ self._sync_emb_cache_controls = _sync_emb_cache_controls
2406
+ self.use_augmentation_check.stateChanged.connect(lambda _: _sync_emb_cache_controls())
2407
+
2408
+ self.use_stitch_check.stateChanged.connect(
2409
+ lambda s: self.stitch_prob_spin.setEnabled(bool(s))
2410
+ )
2411
+ self.use_stitch_check.stateChanged.connect(lambda _s: self._sync_stitch_controls())
2412
+ self.use_localization_check.stateChanged.connect(lambda _s: self._sync_stitch_controls())
2413
+ self._sync_stitch_controls()
2414
+
2415
+ self.use_all_for_training_check = QCheckBox("Use all data for training (no validation)")
2416
+ self.use_all_for_training_check.setToolTip("Enable this for small datasets. Disables validation split.")
2417
+ self.use_all_for_training_check.stateChanged.connect(self._on_use_all_changed)
2418
+ config_layout.addRow("", self.use_all_for_training_check)
2419
+
2420
+ self.val_split_spin = QDoubleSpinBox()
2421
+ self.val_split_spin.setRange(0.0, 0.5)
2422
+ self.val_split_spin.setValue(0.2)
2423
+ self.val_split_spin.setDecimals(2)
2424
+ self.val_split_spin.setSingleStep(0.05)
2425
+ self.val_split_spin.setSuffix(" (20% = 0.2)")
2426
+ config_layout.addRow("Validation split:", self.val_split_spin)
2427
+
2428
+ self.auto_tune_check = QCheckBox("Auto-tune before final training")
2429
+ self.auto_tune_check.setToolTip(
2430
+ "Run a small random search over a few important training settings,\n"
2431
+ "then retrain once from scratch using the best candidate.\n"
2432
+ "Requires a validation split."
2433
+ )
2434
+ self.auto_tune_check.setChecked(False)
2435
+ self.auto_tune_check.stateChanged.connect(self._on_auto_tune_changed)
2436
+ config_layout.addRow("", self.auto_tune_check)
2437
+
2438
+ auto_tune_row = QHBoxLayout()
2439
+ self.auto_tune_runs_spin = QSpinBox()
2440
+ self.auto_tune_runs_spin.setRange(1, 32)
2441
+ self.auto_tune_runs_spin.setValue(8)
2442
+ self.auto_tune_runs_spin.setToolTip("Number of short candidate runs to evaluate before the final retrain.")
2443
+ auto_tune_row.addWidget(self.auto_tune_runs_spin)
2444
+ auto_tune_row.addWidget(QLabel("runs"))
2445
+ self.auto_tune_epochs_spin = QSpinBox()
2446
+ self.auto_tune_epochs_spin.setRange(1, 200)
2447
+ self.auto_tune_epochs_spin.setValue(12)
2448
+ self.auto_tune_epochs_spin.setToolTip("Epoch budget for each short auto-tune trial.")
2449
+ auto_tune_row.addWidget(self.auto_tune_epochs_spin)
2450
+ auto_tune_row.addWidget(QLabel("search epochs"))
2451
+ auto_tune_row.addStretch()
2452
+ config_layout.addRow("Auto-tune search:", auto_tune_row)
2453
+ self._on_auto_tune_changed(int(self.auto_tune_check.isChecked()))
2454
+
2455
+ self.select_classes_check = QCheckBox("Limit classes for training")
2456
+ self.select_classes_check.setToolTip("Select specific classes to use for training (useful for testing minimum examples needed)")
2457
+ self.select_classes_check.stateChanged.connect(self._on_select_classes_changed)
2458
+ config_layout.addRow("", self.select_classes_check)
2459
+
2460
+ self.class_selection_list = QListWidget()
2461
+ self.class_selection_list.setMaximumHeight(150)
2462
+ self.class_selection_list.setEnabled(False)
2463
+ config_layout.addRow("Selected classes:", self.class_selection_list)
2464
+
2465
+ self.limit_per_class_check = QCheckBox("Limit annotations per class")
2466
+ self.limit_per_class_check.setToolTip("Limit the maximum number of clips used per class for training")
2467
+ self.limit_per_class_check.stateChanged.connect(self._on_limit_per_class_changed)
2468
+ config_layout.addRow("", self.limit_per_class_check)
2469
+
2470
+ self.per_class_limit_table = QTableWidget()
2471
+ self.per_class_limit_table.setColumnCount(3)
2472
+ self.per_class_limit_table.setHorizontalHeaderLabels(["Class", "Max Train", "Max Val"])
2473
+ self.per_class_limit_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.ResizeMode.Stretch)
2474
+ self.per_class_limit_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeMode.ResizeToContents)
2475
+ self.per_class_limit_table.horizontalHeader().setSectionResizeMode(2, QHeaderView.ResizeMode.ResizeToContents)
2476
+ self.per_class_limit_table.setMaximumHeight(150)
2477
+ self.per_class_limit_table.setEnabled(False)
2478
+ config_layout.addRow("Per-class limits:", self.per_class_limit_table)
2479
+
2480
+ self.use_embedding_diversity_check = QCheckBox("Use embedding-based diversity selection")
2481
+ self.use_embedding_diversity_check.setToolTip("When limiting per-class samples, use VideoPrism embeddings to select the most diverse clips (Farthest Point Sampling). Better than random for capturing variety.")
2482
+ self.use_embedding_diversity_check.setChecked(False)
2483
+ config_layout.addRow("", self.use_embedding_diversity_check)
2484
+
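The diversity option above refers to Farthest Point Sampling over per-clip embeddings. A minimal NumPy sketch of FPS, shown for illustration only; the package's actual selection code may differ:

    import numpy as np

    def farthest_point_sampling(embeddings: np.ndarray, k: int) -> list[int]:
        # Greedily pick k clip indices that spread out over the embedding space.
        n = embeddings.shape[0]
        k = min(k, n)
        if k <= 0:
            return []
        selected = [0]  # start from an arbitrary clip
        dists = np.linalg.norm(embeddings - embeddings[0], axis=1)
        for _ in range(1, k):
            idx = int(np.argmax(dists))  # clip farthest from everything chosen so far
            selected.append(idx)
            dists = np.minimum(dists, np.linalg.norm(embeddings - embeddings[idx], axis=1))
        return selected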
2485
+ # Fine-tuning controls
2486
+ self.finetune_check = QCheckBox("Fine-tune existing model")
2487
+ self.finetune_check.setToolTip("Load weights from an existing model instead of training from scratch")
2488
+ self.finetune_check.stateChanged.connect(self._on_finetune_changed)
2489
+ config_layout.addRow("", self.finetune_check)
2490
+
2491
+ self.pretrained_path_edit = QLineEdit()
2492
+ self.pretrained_path_edit.setPlaceholderText("Path to existing .pt model")
2493
+ self.pretrained_path_edit.setEnabled(False)
2494
+ self.pretrained_browse_btn = QPushButton("Browse...")
2495
+ self.pretrained_browse_btn.setEnabled(False)
2496
+ self.pretrained_browse_btn.clicked.connect(self._browse_pretrained)
2497
+
2498
+ pretrained_layout = QHBoxLayout()
2499
+ pretrained_layout.addWidget(self.pretrained_path_edit)
2500
+ pretrained_layout.addWidget(self.pretrained_browse_btn)
2501
+ config_layout.addRow("Pretrained model:", pretrained_layout)
2502
+
2503
+ data_group.setLayout(config_layout)
2504
+
2505
+ row1 = QHBoxLayout()
2506
+ row1.addWidget(paths_group, 1)
2507
+ row1.addWidget(info_scroll, 1)
2508
+ config_vbox.addLayout(row1)
2509
+
2510
+ row2 = QHBoxLayout()
2511
+ row2.addWidget(hyper_group, 1)
2512
+ row2.addWidget(loc_group, 1)
2513
+ config_vbox.addLayout(row2)
2514
+
2515
+ row3 = QHBoxLayout()
2516
+ row3.addWidget(arch_group, 1)
2517
+ row3.addWidget(ovr_group, 1)
2518
+ config_vbox.addLayout(row3)
2519
+
2520
+ config_vbox.addWidget(data_group)
2521
+
2522
+ config_scroll = QScrollArea()
2523
+ config_scroll.setWidget(config_container)
2524
+ config_scroll.setWidgetResizable(True)
2525
+ config_scroll.setMinimumHeight(300)
2526
+ layout.addWidget(config_scroll, 1)
2527
+
2528
+ control_layout = QHBoxLayout()
2529
+ self.visualize_btn = QPushButton("Visualize training")
2530
+ self.visualize_btn.setToolTip("Open real-time training visualization")
2531
+ self.visualize_btn.clicked.connect(self._open_visualization)
2532
+ self.visualize_btn.setEnabled(False)
2533
+ control_layout.addWidget(self.visualize_btn)
2534
+
2535
+ self.advanced_btn = QPushButton("Advanced: Profiles")
2536
+ self.advanced_btn.setToolTip("Manage training profiles for batch experiments")
2537
+ self.advanced_btn.clicked.connect(self._open_profile_manager)
2538
+ control_layout.addWidget(self.advanced_btn)
2539
+
2540
+ control_layout.addStretch()
2541
+
2542
+ self.batch_train_check = QCheckBox("Batch Train Selected Profiles")
2543
+ self.batch_train_check.setToolTip("If checked, training will run sequentially for all profiles selected in the Advanced menu.")
2544
+ control_layout.addWidget(self.batch_train_check)
2545
+
2546
+ self.train_btn = QPushButton("Start training")
2547
+ self.train_btn.clicked.connect(self._start_training)
2548
+ self.stop_btn = QPushButton("Stop")
2549
+ self.stop_btn.clicked.connect(self._stop_training)
2550
+ self.stop_btn.setEnabled(False)
2551
+ control_layout.addWidget(self.train_btn)
2552
+ control_layout.addWidget(self.stop_btn)
2553
+ layout.addLayout(control_layout)
2554
+
2555
+ self.progress_bar = QProgressBar()
2556
+ self.progress_bar.setVisible(False)
2557
+ layout.addWidget(self.progress_bar)
2558
+
2559
+ log_group = QGroupBox("Training logs")
2560
+ log_layout = QVBoxLayout()
2561
+ self.log_text = QPlainTextEdit()
2562
+ self.log_text.setReadOnly(True)
2563
+ log_layout.addWidget(self.log_text)
2564
+ log_group.setLayout(log_layout)
2565
+ layout.addWidget(log_group)
2566
+
2567
+ self.setLayout(layout)
2568
+ self._on_temporal_decoder_toggled(int(self.use_temporal_decoder_check.isChecked()))
2569
+
2570
+ def _on_use_all_changed(self, state: int):
2571
+ """Enable/disable validation split controls."""
2572
+ use_all = self.use_all_for_training_check.isChecked()
2573
+ self.val_split_spin.setEnabled(not use_all)
2574
+
2575
+ if use_all:
2576
+ self.log_text.appendPlainText("Using all data for training - validation disabled")
2577
+ else:
2578
+ self.log_text.appendPlainText(f"Validation split enabled: {self.val_split_spin.value():.1%}")
2579
+
2580
+ def _on_temporal_decoder_toggled(self, state: int):
2581
+ """Enable only the controls that affect the temporal decoder branch."""
2582
+ use_decoder = self.use_temporal_decoder_check.isChecked()
2583
+ self.frame_head_layers_spin.setEnabled(use_decoder)
2584
+ self.temporal_pool_spin.setEnabled(use_decoder)
2585
+ self.boundary_loss_weight_spin.setEnabled(use_decoder)
2586
+ self.boundary_tolerance_spin.setEnabled(use_decoder)
2587
+ if not use_decoder:
2588
+ self.boundary_loss_weight_spin.setValue(0.0)
2589
+ elif self.boundary_loss_weight_spin.value() == 0.0:
2590
+ self.boundary_loss_weight_spin.setValue(0.3)
2591
+
2592
+ def _on_ovr_toggled(self, state: int):
2593
+ """OvR has built-in per-head class balancing, so the redundant class-weight option is disabled."""
2594
+ is_on = self.use_ovr_check.isChecked()
2595
+ self.ovr_background_negative_check.setEnabled(is_on)
2596
+ self.ovr_label_smoothing_spin.setEnabled(is_on)
2597
+ self.use_asl_check.setEnabled(is_on)
2598
+ self.asl_gamma_neg_spin.setEnabled(is_on and self.use_asl_check.isChecked())
2599
+ self.asl_gamma_pos_spin.setEnabled(is_on and self.use_asl_check.isChecked())
2600
+ self.asl_clip_spin.setEnabled(is_on and self.use_asl_check.isChecked())
2601
+ self.use_confusion_sampler_check.setEnabled(is_on)
2602
+ self.confusion_temperature_spin.setEnabled(
2603
+ is_on and self.use_confusion_sampler_check.isChecked()
2604
+ )
2605
+ self.confusion_warmup_spin.setEnabled(
2606
+ is_on and self.use_confusion_sampler_check.isChecked()
2607
+ )
2608
+ self.use_hard_pair_mining_check.setEnabled(is_on)
2609
+ self._on_hard_pair_toggled(state)
2610
+ if is_on:
2611
+ self.use_class_weights_check.setChecked(False)
2612
+ self.use_class_weights_check.setEnabled(False)
2613
+ else:
2614
+ self.use_class_weights_check.setEnabled(True)
2615
+
2616
+ def _on_asl_toggled(self, state: int):
2617
+ """Enable/disable ASL parameter controls."""
2618
+ is_on = self.use_asl_check.isChecked() and self.use_ovr_check.isChecked()
2619
+ self.asl_gamma_neg_spin.setEnabled(is_on)
2620
+ self.asl_gamma_pos_spin.setEnabled(is_on)
2621
+ self.asl_clip_spin.setEnabled(is_on)
2622
+
2623
+ def _on_hard_pair_toggled(self, state: int):
2624
+ """Enable/disable hard-pair controls."""
2625
+ is_on = self.use_ovr_check.isChecked() and self.use_hard_pair_mining_check.isChecked()
2626
+ self.hard_pair_edit.setEnabled(is_on)
2627
+ self.hard_pair_loss_weight_spin.setEnabled(is_on)
2628
+ self.hard_pair_margin_spin.setEnabled(is_on)
2629
+ self.hard_pair_confusion_boost_spin.setEnabled(is_on)
2630
+
2631
+ def _on_use_augmentation_changed(self, state: int):
2632
+ """Enable/disable the augmentation options button and the cached augmentation versions control."""
2633
+ is_on = self.use_augmentation_check.isChecked()
2634
+ if hasattr(self, "augmentation_options_btn"):
2635
+ self.augmentation_options_btn.setEnabled(is_on)
2636
+ if not is_on and hasattr(self, "emb_aug_versions_spin"):
2637
+ self.emb_aug_versions_spin.setValue(1)
2638
+ if hasattr(self, "_sync_emb_cache_controls"):
2639
+ self._sync_emb_cache_controls()
2640
+
2641
+ def _on_localization_toggled(self, state: int):
2642
+ """Manage crop jitter controls when localization is toggled."""
2643
+ loc_on = bool(state)
2644
+ if loc_on:
2645
+ self.use_crop_jitter_check.setEnabled(True)
2646
+ self.crop_jitter_strength_spin.setEnabled(self.use_crop_jitter_check.isChecked())
2647
+ else:
2648
+ self.use_crop_jitter_check.setChecked(False)
2649
+ self.use_crop_jitter_check.setEnabled(False)
2650
+ self.crop_jitter_strength_spin.setEnabled(False)
2651
+
2652
+ def _sync_stitch_controls(self):
2653
+ """Disable stitching whenever localization supervision is enabled."""
2654
+ if not hasattr(self, "use_stitch_check"):
2655
+ return
2656
+ localization_on = bool(self.use_localization_check.isChecked()) if hasattr(self, "use_localization_check") else False
2657
+ if localization_on:
2658
+ self.use_stitch_check.setChecked(False)
2659
+ self.use_stitch_check.setEnabled(False)
2660
+ self.stitch_prob_spin.setEnabled(False)
2661
+ else:
2662
+ self.use_stitch_check.setEnabled(True)
2663
+ self.stitch_prob_spin.setEnabled(self.use_stitch_check.isChecked())
2664
+ if hasattr(self, "_sync_emb_cache_controls"):
2665
+ self._sync_emb_cache_controls()
2666
+
2667
+ def _parse_hard_pairs_text(self, text: str) -> list[list[str]]:
2668
+ """Parse comma-separated hard-pair text into [[class_a, class_b], ...]."""
2669
+ pairs = []
2670
+ seen = set()
2671
+ for chunk in (text or "").split(","):
2672
+ item = chunk.strip()
2673
+ if not item:
2674
+ continue
2675
+ parts = [p.strip() for p in re.split(r"\s*(?::|>|/|\bvs\b)\s*", item, maxsplit=1, flags=re.IGNORECASE) if p.strip()]
2676
+ if len(parts) != 2 or parts[0] == parts[1]:
2677
+ continue
2678
+ key = tuple(sorted(parts))
2679
+ if key in seen:
2680
+ continue
2681
+ seen.add(key)
2682
+ pairs.append([parts[0], parts[1]])
2683
+ return pairs
2684
+
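For reference, the parser above accepts ':', '>', '/' or 'vs' as pair separators and drops duplicate unordered pairs. A quick illustration, assuming widget is an instance of this training widget and using made-up class names:

    widget._parse_hard_pairs_text("groom:rear, rear vs groom, walk/run")
    # -> [["groom", "rear"], ["walk", "run"]]   (the duplicate "rear vs groom" is dropped)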
2685
+ def _format_hard_pairs_text(self, pairs) -> str:
2686
+ """Format stored hard-pair config back into the UI text box."""
2687
+ if not isinstance(pairs, (list, tuple)):
2688
+ return ""
2689
+ items = []
2690
+ for pair in pairs:
2691
+ if not isinstance(pair, (list, tuple)) or len(pair) != 2:
2692
+ continue
2693
+ a_name = str(pair[0]).strip()
2694
+ b_name = str(pair[1]).strip()
2695
+ if not a_name or not b_name:
2696
+ continue
2697
+ items.append(f"{a_name}:{b_name}")
2698
+ return ", ".join(items)
2699
+
2700
+ def _default_augmentation_options(self) -> dict:
2701
+ return {
2702
+ "use_horizontal_flip": True,
2703
+ "use_vertical_flip": False,
2704
+ "use_color_jitter": True,
2705
+ "use_gaussian_blur": True,
2706
+ "use_random_noise": True,
2707
+ "use_small_rotation": False,
2708
+ "use_speed_perturb": False,
2709
+ "use_random_shapes": False,
2710
+ "use_grayscale": False,
2711
+ "use_lighting_robustness": True,
2712
+ }
2713
+
2714
+ def _normalize_augmentation_options(self, options: dict) -> dict:
2715
+ defaults = self._default_augmentation_options()
2716
+ if not isinstance(options, dict):
2717
+ return defaults
2718
+ for key in defaults:
2719
+ defaults[key] = bool(options.get(key, defaults[key]))
2720
+ return defaults
2721
+
2722
+ def _open_augmentation_options_dialog(self):
2723
+ """Open dialog to select augmentations."""
2724
+ dialog = QDialog(self)
2725
+ dialog.setWindowTitle("Augmentation Options")
2726
+
2727
+ layout = QVBoxLayout(dialog)
2728
+ grid = QGridLayout()
2729
+
2730
+ options = self._normalize_augmentation_options(self.augmentation_options)
2731
+
2732
+ hflip_check = QCheckBox("Random horizontal flip")
2733
+ hflip_check.setChecked(options["use_horizontal_flip"])
2734
+ grid.addWidget(hflip_check, 0, 0)
2735
+
2736
+ vflip_check = QCheckBox("Random vertical flip")
2737
+ vflip_check.setChecked(options["use_vertical_flip"])
2738
+ grid.addWidget(vflip_check, 1, 0)
2739
+
2740
+ color_check = QCheckBox("Color jitter (brightness/contrast/saturation/hue)")
2741
+ color_check.setChecked(options["use_color_jitter"])
2742
+ grid.addWidget(color_check, 2, 0)
2743
+
2744
+ blur_check = QCheckBox("Gaussian blur (0.1-0.5 sigma)")
2745
+ blur_check.setChecked(options["use_gaussian_blur"])
2746
+ grid.addWidget(blur_check, 3, 0)
2747
+
2748
+ noise_check = QCheckBox("Random noise (std=0.02)")
2749
+ noise_check.setChecked(options["use_random_noise"])
2750
+ grid.addWidget(noise_check, 4, 0)
2751
+
2752
+ rot_check = QCheckBox("Small rotation (+/- 5 degrees)")
2753
+ rot_check.setChecked(options["use_small_rotation"])
2754
+ grid.addWidget(rot_check, 5, 0)
2755
+
2756
+ speed_check = QCheckBox("Speed perturbation (0.7x - 1.3x)")
2757
+ speed_check.setChecked(options.get("use_speed_perturb", False))
2758
+ grid.addWidget(speed_check, 6, 0)
2759
+
2760
+ shapes_check = QCheckBox("Random shape overlays (occlusion)")
2761
+ shapes_check.setChecked(options.get("use_random_shapes", False))
2762
+ grid.addWidget(shapes_check, 7, 0)
2763
+
2764
+ gray_check = QCheckBox("Random grayscale (50% chance)")
2765
+ gray_check.setChecked(options.get("use_grayscale", False))
2766
+ grid.addWidget(gray_check, 8, 0)
2767
+
2768
+ light_check = QCheckBox("Lighting / color robustness")
2769
+ light_check.setToolTip("Clip-consistent gamma and per-channel gain jitter to reduce brightness/color bias.")
2770
+ light_check.setChecked(options.get("use_lighting_robustness", False))
2771
+ grid.addWidget(light_check, 9, 0)
2772
+
2773
+ layout.addLayout(grid)
2774
+
2775
+ buttons_layout = QHBoxLayout()
2776
+ buttons_layout.addStretch()
2777
+ ok_btn = QPushButton("OK")
2778
+ cancel_btn = QPushButton("Cancel")
2779
+ ok_btn.clicked.connect(dialog.accept)
2780
+ cancel_btn.clicked.connect(dialog.reject)
2781
+ buttons_layout.addWidget(ok_btn)
2782
+ buttons_layout.addWidget(cancel_btn)
2783
+ layout.addLayout(buttons_layout)
2784
+
2785
+ if dialog.exec() == QDialog.DialogCode.Accepted:
2786
+ self.augmentation_options = {
2787
+ "use_horizontal_flip": hflip_check.isChecked(),
2788
+ "use_vertical_flip": vflip_check.isChecked(),
2789
+ "use_color_jitter": color_check.isChecked(),
2790
+ "use_gaussian_blur": blur_check.isChecked(),
2791
+ "use_random_noise": noise_check.isChecked(),
2792
+ "use_small_rotation": rot_check.isChecked(),
2793
+ "use_speed_perturb": speed_check.isChecked(),
2794
+ "use_random_shapes": shapes_check.isChecked(),
2795
+ "use_grayscale": gray_check.isChecked(),
2796
+ "use_lighting_robustness": light_check.isChecked(),
2797
+ }
2798
+
2799
+ def _browse_annotation(self):
2800
+ """Browse for annotation file."""
2801
+ file_path, _ = QFileDialog.getOpenFileName(
2802
+ self,
2803
+ "Select Annotation File",
2804
+ self.config.get("data_dir", "data"),
2805
+ "JSON Files (*.json);;All Files (*)"
2806
+ )
2807
+ if file_path:
2808
+ self.annotation_file_edit.setText(file_path)
2809
+ self.refresh_annotation_info()
2810
+
2811
+ def _browse_clips_dir(self):
2812
+ """Browse for clips directory."""
2813
+ dir_path = QFileDialog.getExistingDirectory(
2814
+ self,
2815
+ "Select Clips Directory",
2816
+ self.config.get("clips_dir", "data/clips")
2817
+ )
2818
+ if dir_path:
2819
+ self.clips_dir_edit.setText(dir_path)
2820
+
2821
+ # Check for metadata
2822
+ import json
2823
+ meta_path = os.path.join(dir_path, "clips_metadata.json")
2824
+ if os.path.exists(meta_path):
2825
+ try:
2826
+ with open(meta_path, 'r') as f:
2827
+ meta = json.load(f)
2828
+
2829
+ clip_len = meta.get("clip_length")
2830
+ if clip_len:
2831
+ self.clip_length_spin.setValue(int(clip_len))
2832
+ self.log_text.appendPlainText(f"Automatically set 'Frames per clip' to {clip_len} from metadata.")
2833
+ except Exception as e:
2834
+ logger.error("Error reading clips metadata: %s", e)
2835
+
2836
+ def _show_clip_length_info(self):
2837
+ """Show information about frames per clip."""
2838
+ QMessageBox.information(
2839
+ self,
2840
+ "Frames per Clip - Information",
2841
+ "Number of frames the model will use for training.\n\n"
2842
+ "-Can be equal to or less than the actual clip length\n"
2843
+ "-If less, the middle N frames are selected (temporal center-crop)\n"
2844
+ " e.g. 16-frame clips with this set to 8 → frames 4-11 are used\n"
2845
+ "-Useful for testing whether shorter temporal context is sufficient\n\n"
2846
+ "Important:\n"
2847
+ "-Use the same value in the Inference tab when running predictions\n"
2848
+ "-Shorter clips = faster training and lower memory usage\n"
2849
+ "-If clips have fewer frames than this value, they are padded"
2850
+ )
2851
+
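A minimal sketch of the temporal center-crop described in the dialog above (illustrative only; the actual frame selection happens in the training backend):

    def center_crop_indices(clip_len: int, target_len: int) -> list[int]:
        # Middle target_len frame indices of a clip_len-frame clip.
        if target_len >= clip_len:
            return list(range(clip_len))  # shorter clips are padded elsewhere
        start = (clip_len - target_len) // 2
        return list(range(start, start + target_len))

    center_crop_indices(16, 8)  # -> [4, 5, 6, 7, 8, 9, 10, 11]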
2852
+ def _browse_output(self):
2853
+ """Browse for output model path."""
2854
+ file_path, _ = QFileDialog.getSaveFileName(
2855
+ self,
2856
+ "Save Model",
2857
+ self.config.get("models_dir", "models/behavior_heads"),
2858
+ "PyTorch Files (*.pt);;All Files (*)"
2859
+ )
2860
+ if file_path:
2861
+ self.output_path_edit.setText(file_path)
2862
+
2863
+ def _browse_pretrained(self):
2864
+ """Browse for pretrained model path."""
2865
+ file_path, _ = QFileDialog.getOpenFileName(
2866
+ self,
2867
+ "Select Pretrained Model",
2868
+ self.config.get("models_dir", "models/behavior_heads"),
2869
+ "PyTorch Files (*.pt);;All Files (*)"
2870
+ )
2871
+ if file_path:
2872
+ self.pretrained_path_edit.setText(file_path)
2873
+
2874
+ def _on_finetune_changed(self, state: int):
2875
+ """Enable/disable pretrained path inputs."""
2876
+ enabled = self.finetune_check.isChecked()
2877
+ self.pretrained_path_edit.setEnabled(enabled)
2878
+ self.pretrained_browse_btn.setEnabled(enabled)
2879
+
2880
+ def _open_visualization(self):
2881
+ """Open the training visualization dialog."""
2882
+ if self.visualization_dialog is None:
2883
+ self.visualization_dialog = TrainingVisualizationDialog(self)
2884
+
2885
+ self.visualization_dialog.show()
2886
+ self.visualization_dialog.raise_()
2887
+ self.visualization_dialog.activateWindow()
2888
+
2889
+ def refresh_annotation_info(self):
2890
+ """Refresh dataset info display."""
2891
+ try:
2892
+ from collections import Counter
2893
+
2894
+ annotation_file = self.annotation_file_edit.text().strip()
2895
+ if not annotation_file:
2896
+ annotation_file = self.config.get(
2897
+ "training_annotation_file",
2898
+ self.config.get("annotation_file", "data/annotations/annotations.json"),
2899
+ )
2900
+
2901
+ self.annotation_manager = AnnotationManager(annotation_file)
2902
+
2903
+ labeled_clips = self.annotation_manager.get_labeled_clips()
2904
+ has_bboxes = any(
2905
+ clip.get("spatial_bbox") or clip.get("spatial_bbox_frames")
2906
+ for clip in labeled_clips
2907
+ )
2908
+ self.loc_group.setEnabled(has_bboxes)
2909
+ if has_bboxes:
2910
+ self.loc_group.setTitle("Localization")
2911
+ self.loc_group.setToolTip("")
2912
+ else:
2913
+ self.use_localization_check.setChecked(False)
2914
+ self.use_manual_loc_switch_check.setChecked(False)
2915
+ self.use_crop_jitter_check.setChecked(False)
2916
+ self.loc_group.setTitle("Localization (requires bbox labels)")
2917
+ self.loc_group.setToolTip(
2918
+ "Optional: localizes individual animals when multiple are in camera view.\n"
2919
+ "To activate, draw and save bounding boxes on at least some clips in the Labeling tab."
2920
+ )
2921
+ classes = sorted(self.annotation_manager.get_classes())
2922
+ real_classes = [c for c in classes if not c.startswith("near_negative")]
2923
+ hard_negative_classes = [c for c in classes if c.startswith("near_negative")]
2924
+ counts = self.annotation_manager.get_clip_count_by_label()
2925
+ primary_counts = Counter(
2926
+ clip.get("label", "")
2927
+ for clip in labeled_clips
2928
+ if clip.get("label")
2929
+ )
2930
+
2931
+ ml_stats = self.annotation_manager.get_multilabel_stats()
2932
+ exclusive = ml_stats["exclusive"]
2933
+ shared = ml_stats["shared"]
2934
+ combos = ml_stats["combos"]
2935
+
2936
+ info_text = f"Labeled clips: {len(labeled_clips)}\n"
2937
+ info_text += f"Behavior classes: {len(real_classes)}\n"
2938
+ if real_classes:
2939
+ info_text += f"Behavior names: {', '.join(real_classes)}\n"
2940
+ if hard_negative_classes:
2941
+ hn_clip_count = sum(primary_counts.get(label, 0) for label in hard_negative_classes)
2942
+ info_text += f"Hard-negative helper labels: {len(hard_negative_classes)} ({hn_clip_count} clips)\n"
2943
+
2944
+ info_text += "\nPrimary-label training counts:\n"
2945
+ for label in real_classes:
2946
+ primary_count = int(primary_counts.get(label, 0))
2947
+ membership_count = int(counts.get(label, 0))
2948
+ exc = int(exclusive.get(label, 0))
2949
+ sh = int(shared.get(label, 0))
2950
+ if sh > 0:
2951
+ info_text += (
2952
+ f" {label}: {primary_count} primary"
2953
+ f" ({membership_count} label memberships: {exc} exclusive, {sh} multi-class)\n"
2954
+ )
2955
+ else:
2956
+ info_text += f" {label}: {primary_count} primary\n"
2957
+
2958
+ if hard_negative_classes:
2959
+ info_text += "\nHard-negative suppression clips:\n"
2960
+ for label in hard_negative_classes:
2961
+ info_text += f" {label}: {int(primary_counts.get(label, 0))}\n"
2962
+
2963
+ real_combos = {
2964
+ combo: cnt
2965
+ for combo, cnt in combos.items()
2966
+ if combo and all(lbl in real_classes for lbl in combo)
2967
+ }
2968
+ if real_combos:
2969
+ total_mc = sum(real_combos.values())
2970
+ info_text += f"\nMulti-class clips: {total_mc}\n"
2971
+ for combo, cnt in sorted(real_combos.items(), key=lambda x: -x[1]):
2972
+ info_text += f" {' + '.join(combo)}: {cnt}\n"
2973
+
2974
+ self.info_label.setText(info_text)
2975
+
2976
+ # Update class selection list (exclude near_negative_* — they're auto-included in OvR)
2977
+ self.class_selection_list.clear()
2978
+ for class_name in real_classes:
2979
+ item = QListWidgetItem(class_name)
2980
+ item.setCheckState(Qt.CheckState.Checked)
2981
+ self.class_selection_list.addItem(item)
2982
+
2983
+ # Update per-class limit table
2984
+ self.per_class_limit_table.setRowCount(len(real_classes))
2985
+ for row, class_name in enumerate(real_classes):
2986
+ class_item = QTableWidgetItem(class_name)
2987
+ class_item.setFlags(class_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
2988
+ self.per_class_limit_table.setItem(row, 0, class_item)
2989
+
2990
+ count = int(primary_counts.get(class_name, 0))
2991
+
2992
+ # Max Train
2993
+ train_item = QTableWidgetItem(str(count))
2994
+ train_item.setTextAlignment(Qt.AlignmentFlag.AlignCenter)
2995
+ self.per_class_limit_table.setItem(row, 1, train_item)
2996
+
2997
+ # Max Val (default to count i.e. unlimited/all)
2998
+ val_item = QTableWidgetItem(str(count))
2999
+ val_item.setTextAlignment(Qt.AlignmentFlag.AlignCenter)
3000
+ self.per_class_limit_table.setItem(row, 2, val_item)
3001
+ except Exception as e:
3002
+ self.info_label.setText(f"Error loading info: {e}")
3003
+
3004
+ def _on_select_classes_changed(self, state: int):
3005
+ """Enable/disable class selection list."""
3006
+ self.class_selection_list.setEnabled(self.select_classes_check.isChecked())
3007
+
3008
+ def _on_limit_per_class_changed(self, state: int):
3009
+ """Enable/disable per-class limit table."""
3010
+ self.per_class_limit_table.setEnabled(self.limit_per_class_check.isChecked())
3011
+
3012
+ def _on_auto_tune_changed(self, state: int):
3013
+ """Enable/disable auto-tune controls."""
3014
+ is_on = self.auto_tune_check.isChecked()
3015
+ self.auto_tune_runs_spin.setEnabled(is_on)
3016
+ self.auto_tune_epochs_spin.setEnabled(is_on)
3017
+
3018
+ def _get_training_profiles_path(self):
3019
+ """Return the profile storage path for the current experiment."""
3020
+ config_path = self.config.get("config_path")
3021
+ if config_path:
3022
+ exp_dir = os.path.dirname(config_path)
3023
+ return os.path.join(exp_dir, "training_profiles.json")
3024
+ from singlebehaviorlab._paths import get_training_profiles_path
3025
+ return str(get_training_profiles_path())
3026
+
3027
+ def _open_profile_manager(self):
3028
+ """Open the training profiles manager dialog."""
3029
+ profiles_path = self._get_training_profiles_path()
3030
+
3031
+ if not self.profile_dialog:
3032
+ self.profile_dialog = TrainingProfileDialog(self, profiles_file=profiles_path)
3033
+ else:
3034
+ self.profile_dialog.reload_profiles(profiles_path)
3035
+
3036
+ self.profile_dialog.show()
3037
+ self.profile_dialog.raise_()
3038
+ self.profile_dialog.activateWindow()
3039
+
3040
+ def get_training_config(self):
3041
+ """Extract current training configuration from UI components."""
3042
+ head_kwargs = {
3043
+ "num_heads": self.map_num_heads_spin.value(),
3044
+ }
3045
+
3046
+ selected_classes = []
3047
+ if self.select_classes_check.isChecked():
3048
+ for i in range(self.class_selection_list.count()):
3049
+ item = self.class_selection_list.item(i)
3050
+ if item.checkState() == Qt.CheckState.Checked:
3051
+ selected_classes.append(item.text())
3052
+
3053
+ per_class_limits = {}
3054
+ per_class_val_limits = {}
3055
+ if self.limit_per_class_check.isChecked():
3056
+ for row in range(self.per_class_limit_table.rowCount()):
3057
+ class_item = self.per_class_limit_table.item(row, 0)
3058
+ train_item = self.per_class_limit_table.item(row, 1)
3059
+ val_item = self.per_class_limit_table.item(row, 2)
3060
+
3061
+ if class_item:
3062
+ class_name = class_item.text()
3063
+ # Train limit (allow float e.g. 1.5 → applied as 2)
3064
+ if train_item:
3065
+ try:
3066
+ val = float(train_item.text())
3067
+ if val > 0: per_class_limits[class_name] = val
3068
+ except ValueError: pass
3069
+
3070
+ # Val limit (allow float)
3071
+ if val_item:
3072
+ try:
3073
+ val = float(val_item.text())
3074
+ if val > 0: per_class_val_limits[class_name] = val
3075
+ except ValueError: pass
3076
+
3077
+ # Use selected classes for metadata if class limiting is enabled, otherwise use all classes
3078
+ if hasattr(self, 'annotation_manager'):
3079
+ all_classes = self.annotation_manager.get_classes()
3080
+ else:
3081
+ all_classes = []
3082
+ classes_for_metadata = selected_classes if self.select_classes_check.isChecked() and selected_classes else all_classes
3083
+ background_class_names = [
3084
+ c for c in classes_for_metadata
3085
+ if c.lower() in ("other", "background", "bg", "none")
3086
+ ]
3087
+ if self.use_ovr_check.isChecked() and self.ovr_background_negative_check.isChecked():
3088
+ classes_for_metadata = [c for c in classes_for_metadata if c not in background_class_names]
3089
+
3090
+ # Auto-detect helper classes (Other/Background) to exclude from F1 metrics
3091
+ _f1_exclude = [c for c in classes_for_metadata
3092
+ if c.lower() in ("other", "background", "bg", "none")]
3093
+
3094
+ # Get pretrained path if enabled
3095
+ pretrained_path = None
3096
+ if self.finetune_check.isChecked():
3097
+ pretrained_path = self.pretrained_path_edit.text().strip()
3098
+ hard_pairs = self._parse_hard_pairs_text(self.hard_pair_edit.text())
3099
+ resolution_cfg = int(self.config.get("resolution", self._resolution if hasattr(self, "_resolution") else 288))
3100
+ if resolution_cfg % 18 != 0:
3101
+ resolution_cfg = max(18, (resolution_cfg // 18) * 18)
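+ # e.g. 300 → (300 // 18) * 18 = 288; values below 18 are clamped up to 18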
3102
+
3103
+ return {
3104
+ "batch_size": self.batch_size_spin.value(),
3105
+ "epochs": self.epochs_spin.value(),
3106
+ "lr": self.class_lr_spin.value(),
3107
+ "localization_lr": self.loc_lr_spin.value(),
3108
+ "classification_lr": self.class_lr_spin.value(),
3109
+ "use_scheduler": bool(self.config.get("default_use_scheduler", True)),
3110
+ "use_ema": bool(self.config.get("default_use_ema", True)),
3111
+ "weight_decay": self.weight_decay_spin.value(),
3112
+ "head_kwargs": head_kwargs,
3113
+ "dropout": self.map_dropout_spin.value(),
3114
+ "clip_length": self.clip_length_spin.value(),
3115
+ "target_fps": int(self.config.get("default_target_fps", 16)),
3116
+ "resolution": resolution_cfg,
3117
+ "use_all_for_training": self.use_all_for_training_check.isChecked(),
3118
+ "val_split": self.val_split_spin.value(),
3119
+ "auto_tune_before_final": self.auto_tune_check.isChecked(),
3120
+ "auto_tune_runs": self.auto_tune_runs_spin.value(),
3121
+ "auto_tune_epochs": self.auto_tune_epochs_spin.value(),
3122
+ "use_class_weights": bool(getattr(self, "use_class_weights_check", None) and self.use_class_weights_check.isChecked()),
3123
+ "use_focal_loss": False,
3124
+ "focal_gamma": 2.0,
3125
+ "use_supcon_loss": self.use_supcon_check.isChecked(),
3126
+ "supcon_weight": self.supcon_weight_spin.value(),
3127
+ "supcon_temperature": self.supcon_temp_spin.value(),
3128
+ "use_frame_loss": True,
3129
+ "use_temporal_decoder": self.use_temporal_decoder_check.isChecked(),
3130
+ "frame_head_temporal_layers": self.frame_head_layers_spin.value(),
3131
+ "temporal_pool_frames": self.temporal_pool_spin.value(),
3132
+ "proj_dim": self.proj_dim_spin.value(),
3133
+ "use_frame_bout_balance": self.use_bout_balance_check.isChecked(),
3134
+ "frame_bout_balance_power": self.bout_balance_power_spin.value(),
3135
+ "boundary_loss_weight": self.boundary_loss_weight_spin.value() if self.use_temporal_decoder_check.isChecked() else 0.0,
3136
+ "boundary_tolerance": self.boundary_tolerance_spin.value(),
3137
+ "smoothness_loss_weight": self.smoothness_loss_weight_spin.value(),
3138
+ "use_localization": self.use_localization_check.isChecked(),
3139
+ "use_manual_localization_switch": self.use_manual_loc_switch_check.isChecked(),
3140
+ "manual_localization_switch_epoch": self.manual_loc_switch_epoch_spin.value(),
3141
+ "localization_hidden_dim": 256,
3142
+ "classification_crop_padding": self.crop_padding_spin.value(),
3143
+ "crop_jitter": self.use_crop_jitter_check.isChecked() and self.use_localization_check.isChecked(),
3144
+ "crop_jitter_strength": self.crop_jitter_strength_spin.value(),
3145
+ "use_ovr": self.use_ovr_check.isChecked(),
3146
+ "ovr_background_as_negative": (
3147
+ self.ovr_background_negative_check.isChecked()
3148
+ and self.use_ovr_check.isChecked()
3149
+ ),
3150
+ "ovr_background_class_names": background_class_names,
3151
+ "ovr_label_smoothing": self.ovr_label_smoothing_spin.value(),
3152
+ "use_asl": self.use_asl_check.isChecked() and self.use_ovr_check.isChecked(),
3153
+ "asl_gamma_neg": self.asl_gamma_neg_spin.value(),
3154
+ "asl_gamma_pos": self.asl_gamma_pos_spin.value(),
3155
+ "asl_clip": self.asl_clip_spin.value(),
3156
+ "use_hard_pair_mining": (
3157
+ self.use_hard_pair_mining_check.isChecked()
3158
+ and self.use_ovr_check.isChecked()
3159
+ and bool(hard_pairs)
3160
+ ),
3161
+ "hard_pairs": hard_pairs,
3162
+ "hard_pair_loss_weight": self.hard_pair_loss_weight_spin.value(),
3163
+ "hard_pair_margin": self.hard_pair_margin_spin.value(),
3164
+ "hard_pair_confusion_boost": self.hard_pair_confusion_boost_spin.value(),
3165
+ "use_confusion_sampler": (
3166
+ self.use_confusion_sampler_check.isChecked()
3167
+ and self.use_ovr_check.isChecked()
3168
+ ),
3169
+ "confusion_sampler_temperature": self.confusion_temperature_spin.value(),
3170
+ "confusion_sampler_warmup_pct": self.confusion_warmup_spin.value() / 100.0,
3171
+ "use_weighted_sampler": self.use_weighted_sampler_check.isChecked(),
3172
+ "use_augmentation": self.use_augmentation_check.isChecked(),
3173
+ "virtual_expansion": self.virtual_expansion_spin.value(),
3174
+ "stitch_augmentation_prob": self.stitch_prob_spin.value() if self.use_stitch_check.isChecked() else 0.0,
3175
+ "emb_aug_versions": self.emb_aug_versions_spin.value() if hasattr(self, "emb_aug_versions_spin") else 1,
3176
+ "multi_scale": self.use_multi_scale_check.isChecked() if hasattr(self, "use_multi_scale_check") else False,
3177
+ "augmentation_options": dict(self.augmentation_options),
3178
+ "limit_classes": self.select_classes_check.isChecked(),
3179
+ "selected_classes": selected_classes,
3180
+ "limit_per_class": self.limit_per_class_check.isChecked(),
3181
+ "per_class_limits": per_class_limits,
3182
+ "per_class_val_limits": per_class_val_limits,
3183
+ "use_embedding_diversity": self.use_embedding_diversity_check.isChecked(),
3184
+ "backbone_model": self.config.get("backbone_model", "videoprism_public_v1_base"),
3185
+ "class_names": classes_for_metadata,
3186
+ "pretrained_path": pretrained_path,
3187
+ "f1_exclude_classes": _f1_exclude,
3188
+ "ovr_pos_weight_f1_excluded": getattr(self, "_ovr_pos_weight_f1_excluded", 1.5),
3189
+ }
3190
+
3191
+ def apply_training_config(self, config):
3192
+ """Apply a configuration dictionary to the UI components."""
3193
+ try:
3194
+ self.use_localization_check.setChecked(False)
3195
+ self.use_manual_loc_switch_check.setChecked(False)
3196
+ self.manual_loc_switch_epoch_spin.setValue(20)
3197
+ self.crop_padding_spin.setValue(0.35)
3198
+ self.use_crop_jitter_check.setChecked(False)
3199
+ self.crop_jitter_strength_spin.setValue(0.15)
3200
+
3201
+ if "batch_size" in config: self.batch_size_spin.setValue(config["batch_size"])
3202
+ if "epochs" in config: self.epochs_spin.setValue(config["epochs"])
3203
+ if "lr" in config:
3204
+ self.class_lr_spin.setValue(config["lr"])
3205
+ if "classification_lr" in config:
3206
+ self.class_lr_spin.setValue(config["classification_lr"])
3207
+ if "localization_lr" in config:
3208
+ self.loc_lr_spin.setValue(config["localization_lr"])
3209
+ elif "classification_lr" in config:
3210
+ self.loc_lr_spin.setValue(config["classification_lr"])
3211
+ elif "lr" in config:
3212
+ self.loc_lr_spin.setValue(config["lr"])
3213
+ if "weight_decay" in config: self.weight_decay_spin.setValue(config["weight_decay"])
3214
+ if "dropout" in config: self.map_dropout_spin.setValue(config["dropout"])
3215
+ if "clip_length" in config: self.clip_length_spin.setValue(config["clip_length"])
3216
+ if "use_all_for_training" in config: self.use_all_for_training_check.setChecked(config["use_all_for_training"])
3217
+ if "val_split" in config: self.val_split_spin.setValue(config["val_split"])
3218
+ if "auto_tune_before_final" in config: self.auto_tune_check.setChecked(bool(config["auto_tune_before_final"]))
3219
+ if "auto_tune_runs" in config: self.auto_tune_runs_spin.setValue(int(config["auto_tune_runs"]))
3220
+ if "auto_tune_epochs" in config: self.auto_tune_epochs_spin.setValue(int(config["auto_tune_epochs"]))
3221
+
3222
+ # Loss settings
3223
+ if "use_class_weights" in config: self.use_class_weights_check.setChecked(config["use_class_weights"])
3224
+
3225
+ if "use_supcon_loss" in config:
3226
+ self.use_supcon_check.setChecked(bool(config["use_supcon_loss"]))
3227
+ if "supcon_weight" in config:
3228
+ self.supcon_weight_spin.setValue(float(config["supcon_weight"]))
3229
+ if "supcon_temperature" in config:
3230
+ self.supcon_temp_spin.setValue(float(config["supcon_temperature"]))
3231
+ if "use_temporal_decoder" in config:
3232
+ self.use_temporal_decoder_check.setChecked(bool(config["use_temporal_decoder"]))
3233
+
3234
+ if "frame_head_temporal_layers" in config:
3235
+ self.frame_head_layers_spin.setValue(int(config["frame_head_temporal_layers"]))
3236
+ if "temporal_pool_frames" in config:
3237
+ self.temporal_pool_spin.setValue(int(config["temporal_pool_frames"]))
3238
+ if "proj_dim" in config:
3239
+ self.proj_dim_spin.setValue(int(config["proj_dim"]))
3240
+ if "multi_scale" in config and hasattr(self, "use_multi_scale_check"):
3241
+ self.use_multi_scale_check.setChecked(bool(config["multi_scale"]))
3242
+ if "boundary_loss_weight" in config:
3243
+ self.boundary_loss_weight_spin.setValue(config["boundary_loss_weight"])
3244
+ if "boundary_tolerance" in config:
3245
+ self.boundary_tolerance_spin.setValue(int(config["boundary_tolerance"]))
3246
+ if "smoothness_loss_weight" in config:
3247
+ self.smoothness_loss_weight_spin.setValue(config["smoothness_loss_weight"])
3248
+ if "use_frame_bout_balance" in config:
3249
+ self.use_bout_balance_check.setChecked(bool(config["use_frame_bout_balance"]))
3250
+ if "frame_bout_balance_power" in config and config["frame_bout_balance_power"] is not None:
3251
+ self.bout_balance_power_spin.setValue(float(config["frame_bout_balance_power"]))
3252
+ if "use_localization" in config:
3253
+ self.use_localization_check.setChecked(config["use_localization"])
3254
+
3255
+ if "use_manual_localization_switch" in config:
3256
+ self.use_manual_loc_switch_check.setChecked(config["use_manual_localization_switch"])
3257
+ if "manual_localization_switch_epoch" in config:
3258
+ self.manual_loc_switch_epoch_spin.setValue(config["manual_localization_switch_epoch"])
3259
+ if "classification_crop_padding" in config:
3260
+ self.crop_padding_spin.setValue(float(config["classification_crop_padding"]))
3261
+ if "crop_jitter" in config:
3262
+ self.use_crop_jitter_check.setChecked(bool(config["crop_jitter"]))
3263
+ if "crop_jitter_strength" in config:
3264
+ self.crop_jitter_strength_spin.setValue(float(config["crop_jitter_strength"]))
3265
+
3266
+ if "use_ovr" in config: self.use_ovr_check.setChecked(config["use_ovr"])
3267
+ if "ovr_background_as_negative" in config:
3268
+ self.ovr_background_negative_check.setChecked(bool(config["ovr_background_as_negative"]))
3269
+ if "ovr_pos_weight_f1_excluded" in config:
3270
+ self._ovr_pos_weight_f1_excluded = float(config["ovr_pos_weight_f1_excluded"])
3271
+ if "ovr_label_smoothing" in config: self.ovr_label_smoothing_spin.setValue(config["ovr_label_smoothing"])
3272
+ if "use_asl" in config: self.use_asl_check.setChecked(config["use_asl"])
3273
+ if "asl_gamma_neg" in config: self.asl_gamma_neg_spin.setValue(float(config["asl_gamma_neg"]))
3274
+ if "asl_gamma_pos" in config: self.asl_gamma_pos_spin.setValue(float(config["asl_gamma_pos"]))
3275
+ if "asl_clip" in config: self.asl_clip_spin.setValue(float(config["asl_clip"]))
3276
+ if "hard_pairs" in config:
3277
+ self.hard_pair_edit.setText(self._format_hard_pairs_text(config["hard_pairs"]))
3278
+ if "use_hard_pair_mining" in config:
3279
+ self.use_hard_pair_mining_check.setChecked(bool(config["use_hard_pair_mining"]))
3280
+ if "hard_pair_loss_weight" in config and config["hard_pair_loss_weight"] is not None:
3281
+ self.hard_pair_loss_weight_spin.setValue(float(config["hard_pair_loss_weight"]))
3282
+ if "hard_pair_margin" in config and config["hard_pair_margin"] is not None:
3283
+ self.hard_pair_margin_spin.setValue(float(config["hard_pair_margin"]))
3284
+ if "hard_pair_confusion_boost" in config and config["hard_pair_confusion_boost"] is not None:
3285
+ self.hard_pair_confusion_boost_spin.setValue(float(config["hard_pair_confusion_boost"]))
3286
+ if "use_confusion_sampler" in config:
3287
+ self.use_confusion_sampler_check.setChecked(bool(config["use_confusion_sampler"]))
3288
+ if "confusion_sampler_temperature" in config:
3289
+ self.confusion_temperature_spin.setValue(float(config["confusion_sampler_temperature"]))
3290
+ if "confusion_sampler_warmup_pct" in config:
3291
+ self.confusion_warmup_spin.setValue(int(float(config["confusion_sampler_warmup_pct"]) * 100))
3292
+ if "use_weighted_sampler" in config: self.use_weighted_sampler_check.setChecked(config["use_weighted_sampler"])
3293
+ if "use_augmentation" in config: self.use_augmentation_check.setChecked(config["use_augmentation"])
3294
+ if "virtual_expansion" in config: self.virtual_expansion_spin.setValue(int(config["virtual_expansion"]))
3295
+ if "stitch_augmentation_prob" in config:
3296
+ prob = float(config["stitch_augmentation_prob"])
3297
+ self.use_stitch_check.setChecked(prob > 0.0)
3298
+ self.stitch_prob_spin.setValue(prob)
3299
+ if "emb_aug_versions" in config and hasattr(self, "emb_aug_versions_spin"):
3300
+ self.emb_aug_versions_spin.setValue(int(config["emb_aug_versions"]))
3301
+ if "augmentation_options" in config:
3302
+ self.augmentation_options = self._normalize_augmentation_options(config["augmentation_options"])
3303
+
3304
+ if "head_kwargs" in config:
3305
+ hk = config["head_kwargs"]
3306
+ if "num_heads" in hk: self.map_num_heads_spin.setValue(hk["num_heads"])
3307
+
3308
+ if "limit_classes" in config:
3309
+ self.select_classes_check.setChecked(config["limit_classes"])
3310
+ # Restore selected classes if possible
3311
+ if config["limit_classes"] and "selected_classes" in config:
3312
+ selected_set = set(config["selected_classes"])
3313
+ # Ensure list is populated (might need refresh if empty, but usually populated)
3314
+ if self.class_selection_list.count() == 0:
3315
+ self.refresh_annotation_info()
3316
+
3317
+ for i in range(self.class_selection_list.count()):
3318
+ item = self.class_selection_list.item(i)
3319
+ if item.text() in selected_set:
3320
+ item.setCheckState(Qt.CheckState.Checked)
3321
+ else:
3322
+ item.setCheckState(Qt.CheckState.Unchecked)
3323
+
3324
+ if "limit_per_class" in config:
3325
+ self.limit_per_class_check.setChecked(config["limit_per_class"])
3326
+ # Restore per-class limits
3327
+ if config["limit_per_class"]:
3328
+ limits = config.get("per_class_limits", {})
3329
+ val_limits = config.get("per_class_val_limits", {})
3330
+
3331
+ if self.per_class_limit_table.rowCount() == 0:
3332
+ self.refresh_annotation_info()
3333
+
3334
+ for row in range(self.per_class_limit_table.rowCount()):
3335
+ class_item = self.per_class_limit_table.item(row, 0)
3336
+ if class_item:
3337
+ name = class_item.text()
3338
+ if name in limits:
3339
+ self.per_class_limit_table.setItem(row, 1, QTableWidgetItem(str(limits[name])))
3340
+ if name in val_limits:
3341
+ self.per_class_limit_table.setItem(row, 2, QTableWidgetItem(str(val_limits[name])))
3342
+
3343
+ if "use_embedding_diversity" in config: self.use_embedding_diversity_check.setChecked(config["use_embedding_diversity"])
3344
+ if "pretrained_path" in config:
3345
+ if config["pretrained_path"]:
3346
+ self.finetune_check.setChecked(True)
3347
+ self.pretrained_path_edit.setText(config["pretrained_path"])
3348
+ else:
3349
+ self.finetune_check.setChecked(False)
3350
+ self._sync_stitch_controls()
3351
+
3352
+ except Exception as e:
3353
+ logger.error("Error applying config: %s", e)
3354
+ raise
3355
+
3356
+ def _run_next_batch_item(self):
3357
+ """Run the next profile in the batch queue."""
3358
+ if not self.training_queue:
3359
+ self.is_batch_training = False
3360
+
3361
+ # Show batch summary
3362
+ if self.batch_results:
3363
+ best_result = max(
3364
+ self.batch_results,
3365
+ key=lambda x: (
3366
+ x.get("best_val_f1", 0.0),
3367
+ x.get("best_val_acc", 0.0)
3368
+ )
3369
+ )
3370
+ summary = f"Batch training completed!\n\n"
3371
+ summary += f"Best Profile: {best_result['profile_name']}\n"
3372
+ summary += f"Best Val Macro F1: {best_result.get('best_val_f1', 0.0):.2f}%\n"
3373
+ summary += f"Best Val Acc: {best_result.get('best_val_acc', 0.0):.2f}%\n\n"
3374
+ summary += f"Results saved to:\n{self.batch_results_path}"
3375
+ QMessageBox.information(self, "Batch Training Complete", summary)
3376
+ else:
3377
+ QMessageBox.information(self, "Batch Training", "Batch training completed!")
3378
+
3379
+ self.train_btn.setEnabled(True)
3380
+ self.stop_btn.setEnabled(False)
3381
+ self.progress_bar.setVisible(False)
3382
+ return
3383
+
3384
+ profile_name, config = self.training_queue.pop(0)
3385
+ self.current_profile_name = profile_name
3386
+ self.log_text.appendPlainText(f"\n=== Starting Batch Item: {profile_name} ===")
3387
+
3388
+ # Apply config to UI (so it's visible what's running)
3389
+ try:
3390
+ self.apply_training_config(config)
3391
+ except Exception as e:
3392
+ self.log_text.appendPlainText(f"Error applying profile {profile_name}: {e}")
3393
+ self._run_next_batch_item() # Skip bad profile
3394
+ return
3395
+
3396
+ # Determine output path with profile suffix
3397
+ base_output = self.output_path_edit.text().strip()
3398
+ dir_name = os.path.dirname(base_output)
3399
+ file_name = os.path.basename(base_output)
3400
+ name_root, ext = os.path.splitext(file_name)
3401
+ new_output = os.path.join(dir_name, f"{name_root}_{profile_name}{ext}")
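+ # e.g. a hypothetical profile "fast" turns models/heads/model.pt into models/heads/model_fast.pt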
3402
+
3403
+ self._start_training_internal(override_output_path=new_output, profile_name=profile_name)
3404
+
3405
+ def _start_training_internal(self, override_output_path=None, profile_name=None):
3406
+ """Internal method to start a single training run."""
3407
+ annotation_file = self.annotation_file_edit.text().strip()
3408
+ if not os.path.exists(annotation_file):
3409
+ if not self.is_batch_training: QMessageBox.warning(self, "Error", "Annotation file not found.")
3410
+ else: self.log_text.appendPlainText("Error: Annotation file not found.")
3411
+ return
3412
+
3413
+ clips_dir = self.clips_dir_edit.text().strip()
3414
+ if not os.path.exists(clips_dir):
3415
+ if not self.is_batch_training: QMessageBox.warning(self, "Error", "Clips directory not found.")
3416
+ else: self.log_text.appendPlainText("Error: Clips directory not found.")
3417
+ return
3418
+
3419
+ output_path = override_output_path or self.output_path_edit.text().strip()
3420
+ if not output_path:
3421
+ if not self.is_batch_training: QMessageBox.warning(self, "Error", "Please specify output path.")
3422
+ else: self.log_text.appendPlainText("Error: No output path.")
3423
+ return
3424
+
3425
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
3426
+
3427
+ # Get configuration from UI
3428
+ train_config = self.get_training_config()
3429
+ if profile_name:
3430
+ train_config["profile_name"] = profile_name
3431
+
3432
+ train_config["training_profiles_path"] = self._get_training_profiles_path()
3433
+
3434
+ raw_resolution = int(self.config.get("resolution", 288))
3435
+ adjusted_resolution = train_config.get("resolution", raw_resolution)
3436
+ if raw_resolution % 18 != 0:
3437
+ msg = (
3438
+ "Resolution must be a multiple of 18 (patch size).\n\n"
3439
+ f"Entered: {raw_resolution}\n"
3440
+ f"Adjusted to: {adjusted_resolution}"
3441
+ )
3442
+ if not self.is_batch_training:
3443
+ QMessageBox.information(self, "Resolution adjusted", msg)
3444
+ else:
3445
+ self.log_text.appendPlainText(f"Resolution adjusted: {msg}")
3446
+ self._resolution = adjusted_resolution
3447
+
3448
+ # Check pretrained path validity if needed
3449
+ if self.finetune_check.isChecked():
3450
+ if not train_config["pretrained_path"] or not os.path.exists(train_config["pretrained_path"]):
3451
+ if not self.is_batch_training: QMessageBox.warning(self, "Error", "Please select a valid pretrained model file.")
3452
+ else: self.log_text.appendPlainText("Error: Invalid pretrained path.")
3453
+ return
3454
+
3455
+ if train_config.get("auto_tune_before_final", False):
3456
+ if train_config.get("use_all_for_training", False) or float(train_config.get("val_split", 0.0)) <= 0.0:
3457
+ msg = "Auto-tune requires a validation split. Disable 'Use all data for training' and set validation split > 0."
3458
+ if not self.is_batch_training: QMessageBox.warning(self, "Error", msg)
3459
+ else: self.log_text.appendPlainText(f"Error: {msg}")
3460
+ return
3461
+
3462
+ head_kwargs = train_config["head_kwargs"]
3463
+
3464
+ # Persist actual training parameters to experiment config.yaml
3465
+ try:
3466
+ config_path = self.config.get("config_path")
3467
+ if config_path:
3468
+ # Update last-used defaults for convenience
3469
+ self.config["default_batch_size"] = train_config["batch_size"]
3470
+ self.config["default_epochs"] = train_config["epochs"]
3471
+ self.config["default_learning_rate"] = train_config["classification_lr"]
3472
+ self.config["default_localization_lr"] = train_config["localization_lr"]
3473
+ self.config["default_classification_lr"] = train_config["classification_lr"]
3474
+ self.config["default_use_scheduler"] = train_config["use_scheduler"]
3475
+ self.config["default_use_ema"] = train_config["use_ema"]
3476
+ self.config["default_weight_decay"] = train_config["weight_decay"]
3477
+ self.config["default_clip_length"] = train_config["clip_length"]
3478
+ self.config["default_use_focal_loss"] = train_config["use_focal_loss"]
3479
+ self.config["default_focal_gamma"] = train_config["focal_gamma"]
3480
+ self.config["default_use_supcon_loss"] = train_config["use_supcon_loss"]
3481
+ self.config["default_supcon_weight"] = train_config["supcon_weight"]
3482
+ self.config["default_supcon_temperature"] = train_config["supcon_temperature"]
3483
+ self.config["backbone_model"] = train_config["backbone_model"]
3484
+ self.config["resolution"] = train_config["resolution"]
3485
+ # Save full last training block
3486
+ self.config["last_training"] = {
3487
+ "parameters": {
3488
+ "batch_size": train_config["batch_size"],
3489
+ "epochs": train_config["epochs"],
3490
+ "lr": train_config["classification_lr"],
3491
+ "localization_lr": train_config["localization_lr"],
3492
+ "classification_lr": train_config["classification_lr"],
3493
+ "use_scheduler": train_config["use_scheduler"],
3494
+ "use_ema": train_config["use_ema"],
3495
+ "weight_decay": train_config["weight_decay"],
3496
+ "clip_length": train_config["clip_length"],
3497
+ "val_split": train_config["val_split"],
3498
+ "auto_tune_before_final": train_config["auto_tune_before_final"],
3499
+ "auto_tune_runs": train_config["auto_tune_runs"],
3500
+ "auto_tune_epochs": train_config["auto_tune_epochs"],
3501
+ "use_class_weights": train_config["use_class_weights"],
3502
+ "use_focal_loss": train_config["use_focal_loss"],
3503
+ "focal_gamma": train_config["focal_gamma"],
3504
+ "use_weighted_sampler": train_config["use_weighted_sampler"],
3505
+ "use_augmentation": train_config["use_augmentation"],
3506
+ "augmentation_options": train_config.get("augmentation_options", {}),
3507
+ "limit_classes": train_config["limit_classes"],
3508
+ "limit_per_class": train_config["limit_per_class"],
3509
+ },
3510
+ "selected_classes": train_config["selected_classes"],
3511
+ "per_class_limits": train_config["per_class_limits"],
3512
+ "head": {
3513
+ "dropout": train_config["dropout"],
3514
+ "map_head_kwargs": train_config["head_kwargs"],
3515
+ },
3516
+ "pretrained_path": train_config["pretrained_path"],
3517
+ "output_path": output_path
3518
+ }
3519
+ os.makedirs(os.path.dirname(config_path), exist_ok=True)
3520
+ with open(config_path, "w", encoding="utf-8") as f:
3521
+ yaml.safe_dump(dict(self.config), f, sort_keys=False)
3522
+ self.log_text.appendPlainText(f"Saved training parameters to: {config_path}")
3523
+ except Exception as e:
3524
+ self.log_text.appendPlainText(f"Warning: Could not save training parameters to config.yaml: {e}")
3525
+
3526
+ self.log_text.appendPlainText(f"Starting training run (Output: {os.path.basename(output_path)})...")
3527
+ self.progress_bar.setVisible(True)
3528
+ self.progress_bar.setValue(0)
3529
+ self.train_btn.setEnabled(False)
3530
+ self.stop_btn.setEnabled(True)
3531
+ self.visualize_btn.setEnabled(True)
3532
+
3533
+ # Open visualization automatically on start
3534
+ self._open_visualization()
3535
+
3536
+ self.worker = TrainingWorker(
3537
+ self.config,
3538
+ train_config,
3539
+ annotation_file,
3540
+ clips_dir,
3541
+ output_path
3542
+ )
3543
+ self.worker.log_message.connect(self._on_log)
3544
+ self.worker.progress.connect(self._on_progress)
3545
+ self.worker.training_complete.connect(self._on_training_complete)
3546
+ self.worker.finished.connect(self._on_finished)
3547
+ self.worker.error.connect(self._on_error)
3548
+ # Reset and connect visualization update; restrict F1 graph to selected classes when limit_classes is on
3549
+ if self.visualization_dialog:
3550
+ f1_exclude = set(train_config.get("f1_exclude_classes", []))
3551
+ f1_classes = None
3552
+ if train_config.get("limit_classes", False):
3553
+ sel = train_config.get("selected_classes", [])
3554
+ if sel:
3555
+ f1_classes = set(sel) - f1_exclude
3556
+ elif f1_exclude:
3557
+ f1_classes = set(train_config.get("class_names", [])) - f1_exclude
3558
+ confusion_warmup_ep = 0
3559
+ if train_config.get("use_confusion_sampler", False):
3560
+ pct = float(train_config.get("confusion_sampler_warmup_pct", 0.2))
3561
+ confusion_warmup_ep = int(train_config.get("epochs", 60) * pct)
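+ # e.g. 60 epochs with warmup_pct 0.2 gives confusion_warmup_ep = 12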
3562
+ self.visualization_dialog.reset(
3563
+ f1_classes_to_show=f1_classes if f1_classes else None,
3564
+ confusion_warmup_epoch=confusion_warmup_ep,
3565
+ )
3566
+ self.worker.epoch_complete.connect(self.visualization_dialog.update_plots)
3567
+
3568
+ self.worker.start()
3569
+
3570
+ def _start_training(self):
3571
+ """Start training (single or batch)."""
3572
+ if self.worker and self.worker.isRunning():
3573
+ QMessageBox.warning(self, "Training", "Training is already running.")
3574
+ return
3575
+
3576
+ if self.batch_train_check.isChecked():
3577
+ if not self.profile_dialog:
3578
+ self.profile_dialog = TrainingProfileDialog(self, profiles_file=self._get_training_profiles_path())
3579
+ else:
3580
+ self.profile_dialog.reload_profiles(self._get_training_profiles_path())
3581
+
3582
+ profiles = self.profile_dialog.get_selected_profiles_for_batch()
3583
+ if not profiles:
3584
+ QMessageBox.warning(self, "Batch Training", "No profiles selected in Advanced > Profiles.\nPlease open Advanced settings and check profiles to run.")
3585
+ return
3586
+
3587
+ self.training_queue = profiles
3588
+ self.is_batch_training = True
3589
+ self.batch_results = []
3590
+
3591
+ # Initialize batch results CSV path
3592
+ from datetime import datetime
3593
+ output_dir = os.path.dirname(self.output_path_edit.text().strip())
3594
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
3595
+ self.batch_results_path = os.path.join(output_dir, f"batch_training_results_{timestamp}.csv")
3596
+ self.log_text.appendPlainText(f"Batch results will be saved to: {self.batch_results_path}")
3597
+
3598
+ self.train_btn.setEnabled(False)
3599
+ self.stop_btn.setEnabled(True)
3600
+ self._run_next_batch_item()
3601
+ else:
3602
+ self._start_training_internal()
3603
+
3604
+ def _stop_training(self):
3605
+ """Stop training."""
3606
+ if self.worker and self.worker.isRunning():
3607
+ self.is_batch_training = False # Stop batch execution
3608
+ self.training_queue = [] # Clear queue
3609
+ self.worker.stop()
3610
+ self.log_text.appendPlainText("Stopping training...")
3611
+ self.worker.wait()
3612
+ self._on_finished()
3613
+
3614
+ def _on_log(self, message: str):
3615
+ """Handle log message."""
3616
+ self.log_text.appendPlainText(message)
3617
+
3618
+ def _on_progress(self, epoch: int, total: int):
3619
+ """Update progress bar."""
3620
+ if total > 0:
3621
+ progress = int(100 * epoch / total)
3622
+ self.progress_bar.setValue(progress)
3623
+
3624
+ def _on_training_complete(self, best_val_acc: float, best_val_f1: float, final_train_acc: float, per_class_f1: dict):
3625
+ """Handle training metrics from completed run."""
3626
+ summary_msg = (
3627
+ f"Training summary → Best Val Macro F1: {best_val_f1:.2f}%, "
3628
+ f"Best Val Acc: {best_val_acc:.2f}%, Final Train Acc: {final_train_acc:.2f}%"
3629
+ )
3630
+ self.log_text.appendPlainText(summary_msg)
3631
+
3632
+ if per_class_f1:
3633
+ self.log_text.appendPlainText("Per-class F1 scores (Final Epoch):")
3634
+ for cls_name, f1 in sorted(per_class_f1.items()):
3635
+ self.log_text.appendPlainText(f" - {cls_name}: {f1:.2f}%")
3636
+
3637
+ if self.is_batch_training and self.current_profile_name:
3638
+ # Get current config for this profile
3639
+ current_config = self.get_training_config()
3640
+
3641
+ result = {
3642
+ "profile_name": self.current_profile_name,
3643
+ "best_val_acc": round(best_val_acc, 2),
3644
+ "best_val_f1": round(best_val_f1, 2),
3645
+ "final_train_acc": round(final_train_acc, 2),
3646
+ "epochs": current_config.get("epochs", 0),
3647
+ "batch_size": current_config.get("batch_size", 0),
3648
+ "lr": current_config.get("lr", 0),
3649
+ "use_focal_loss": current_config.get("use_focal_loss", False),
3650
+ }
3651
+
3652
+ # Add per-class F1 columns
3653
+ if per_class_f1:
3654
+ for cls_name, f1 in per_class_f1.items():
3655
+ # Sanitize column name
3656
+ safe_name = f"F1_{cls_name}".replace(" ", "_")
3657
+ result[safe_name] = round(f1, 2)
3658
+
3659
+ self.batch_results.append(result)
3660
+
3661
+ # Save/update CSV after each profile
3662
+ self._save_batch_results_csv()
3663
+
3664
+ def _save_batch_results_csv(self):
3665
+ """Save batch results to CSV."""
3666
+ if not self.batch_results or not self.batch_results_path:
3667
+ return
3668
+ try:
3669
+ import pandas as pd
3670
+ df = pd.DataFrame(self.batch_results)
3671
+ # Sort by best_val_f1 (primary) then accuracy
3672
+ sort_cols = [col for col in ["best_val_f1", "best_val_acc"] if col in df.columns]
3673
+ if sort_cols:
3674
+ df = df.sort_values(sort_cols, ascending=False)
3675
+ df.to_csv(self.batch_results_path, index=False)
3676
+ self.log_text.appendPlainText(f" → Updated batch results: {self.batch_results_path}")
3677
+ except Exception as e:
3678
+ self.log_text.appendPlainText(f" Failed to save batch results: {e}")
3679
+
3680
+ def _on_finished(self):
3681
+ """Handle training completion."""
3682
+ if self.is_batch_training:
3683
+ self.refresh_annotation_info()
3684
+ self._run_next_batch_item()
3685
+ return
3686
+
3687
+ self.train_btn.setEnabled(True)
3688
+ self.stop_btn.setEnabled(False)
3689
+ self.visualize_btn.setEnabled(False)
3690
+ self.progress_bar.setVisible(False)
3691
+ if self.profile_dialog:
3692
+ self.profile_dialog.reload_profiles(self._get_training_profiles_path())
3693
+ QMessageBox.information(self, "Training", "Training completed!")
3694
+ self.refresh_annotation_info()
3695
+
3696
+ def _on_error(self, error_msg: str):
3697
+ """Handle training error."""
3698
+ self.log_text.appendPlainText(f"ERROR: {error_msg}")
3699
+
3700
+ if self.is_batch_training:
3701
+ QMessageBox.critical(self, "Training Error", f"Batch training failed on current item:\n{error_msg}\n\nBatch execution stopped.")
3702
+ # Stop batch on error
3703
+ self.is_batch_training = False
3704
+ self.train_btn.setEnabled(True)
3705
+ self.stop_btn.setEnabled(False)
3706
+ self.progress_bar.setVisible(False)
3707
+ else:
3708
+ self._on_finished()
3709
+ QMessageBox.critical(self, "Training Error", f"Training failed:\n{error_msg}")
3710
+
3711
+ def update_config(self, config: dict):
3712
+ """Apply a new configuration (experiment management)."""
3713
+ self.config = config
3714
+ self.annotation_manager = AnnotationManager(
3715
+ self.config.get("annotation_file", "data/annotations/annotations.json")
3716
+ )
3717
+ self._config_initialized = False
3718
+ self._load_current_config(force=True)
3719
+ self.refresh_annotation_info()