singlebehaviorlab 2.0.0__py3-none-any.whl
- sam2/__init__.py +11 -0
- sam2/automatic_mask_generator.py +454 -0
- sam2/benchmark.py +92 -0
- sam2/build_sam.py +174 -0
- sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
- sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
- sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
- sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
- sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
- sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
- sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
- sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
- sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
- sam2/modeling/__init__.py +5 -0
- sam2/modeling/backbones/__init__.py +5 -0
- sam2/modeling/backbones/hieradet.py +317 -0
- sam2/modeling/backbones/image_encoder.py +134 -0
- sam2/modeling/backbones/utils.py +93 -0
- sam2/modeling/memory_attention.py +169 -0
- sam2/modeling/memory_encoder.py +181 -0
- sam2/modeling/position_encoding.py +239 -0
- sam2/modeling/sam/__init__.py +5 -0
- sam2/modeling/sam/mask_decoder.py +295 -0
- sam2/modeling/sam/prompt_encoder.py +202 -0
- sam2/modeling/sam/transformer.py +311 -0
- sam2/modeling/sam2_base.py +913 -0
- sam2/modeling/sam2_utils.py +323 -0
- sam2/sam2_hiera_b+.yaml +113 -0
- sam2/sam2_hiera_l.yaml +117 -0
- sam2/sam2_hiera_s.yaml +116 -0
- sam2/sam2_hiera_t.yaml +118 -0
- sam2/sam2_image_predictor.py +466 -0
- sam2/sam2_video_predictor.py +1388 -0
- sam2/sam2_video_predictor_legacy.py +1172 -0
- sam2/utils/__init__.py +5 -0
- sam2/utils/amg.py +348 -0
- sam2/utils/misc.py +349 -0
- sam2/utils/transforms.py +118 -0
- singlebehaviorlab/__init__.py +4 -0
- singlebehaviorlab/__main__.py +130 -0
- singlebehaviorlab/_paths.py +100 -0
- singlebehaviorlab/backend/__init__.py +2 -0
- singlebehaviorlab/backend/augmentations.py +320 -0
- singlebehaviorlab/backend/data_store.py +420 -0
- singlebehaviorlab/backend/model.py +1290 -0
- singlebehaviorlab/backend/train.py +4667 -0
- singlebehaviorlab/backend/uncertainty.py +578 -0
- singlebehaviorlab/backend/video_processor.py +688 -0
- singlebehaviorlab/backend/video_utils.py +139 -0
- singlebehaviorlab/data/config/config.yaml +85 -0
- singlebehaviorlab/data/training_profiles.json +334 -0
- singlebehaviorlab/gui/__init__.py +4 -0
- singlebehaviorlab/gui/analysis_widget.py +2291 -0
- singlebehaviorlab/gui/attention_export.py +311 -0
- singlebehaviorlab/gui/clip_extraction_widget.py +481 -0
- singlebehaviorlab/gui/clustering_widget.py +3187 -0
- singlebehaviorlab/gui/inference_popups.py +1138 -0
- singlebehaviorlab/gui/inference_widget.py +4550 -0
- singlebehaviorlab/gui/inference_worker.py +651 -0
- singlebehaviorlab/gui/labeling_widget.py +2324 -0
- singlebehaviorlab/gui/main_window.py +754 -0
- singlebehaviorlab/gui/metadata_management_widget.py +1119 -0
- singlebehaviorlab/gui/motion_tracking.py +764 -0
- singlebehaviorlab/gui/overlay_export.py +1234 -0
- singlebehaviorlab/gui/plot_integration.py +729 -0
- singlebehaviorlab/gui/qt_helpers.py +29 -0
- singlebehaviorlab/gui/registration_widget.py +1485 -0
- singlebehaviorlab/gui/review_widget.py +1330 -0
- singlebehaviorlab/gui/segmentation_tracking_widget.py +2752 -0
- singlebehaviorlab/gui/tab_tutorial_dialog.py +312 -0
- singlebehaviorlab/gui/timeline_themes.py +131 -0
- singlebehaviorlab/gui/training_profiles.py +418 -0
- singlebehaviorlab/gui/training_widget.py +3719 -0
- singlebehaviorlab/gui/video_utils.py +233 -0
- singlebehaviorlab/licenses/SAM2-LICENSE +201 -0
- singlebehaviorlab/licenses/VideoPrism-LICENSE +202 -0
- singlebehaviorlab-2.0.0.dist-info/METADATA +447 -0
- singlebehaviorlab-2.0.0.dist-info/RECORD +88 -0
- singlebehaviorlab-2.0.0.dist-info/WHEEL +5 -0
- singlebehaviorlab-2.0.0.dist-info/entry_points.txt +2 -0
- singlebehaviorlab-2.0.0.dist-info/licenses/LICENSE +21 -0
- singlebehaviorlab-2.0.0.dist-info/top_level.txt +3 -0
- videoprism/__init__.py +0 -0
- videoprism/encoders.py +910 -0
- videoprism/layers.py +1136 -0
- videoprism/models.py +407 -0
- videoprism/tokenizers.py +167 -0
- videoprism/utils.py +168 -0

singlebehaviorlab/gui/training_widget.py
@@ -0,0 +1,3719 @@
import logging
from PyQt6.QtWidgets import (
    QWidget, QVBoxLayout, QHBoxLayout, QPushButton, QLabel, QSpinBox,
    QDoubleSpinBox, QPlainTextEdit, QProgressBar, QGroupBox, QFormLayout,
    QFileDialog, QMessageBox, QLineEdit, QCheckBox, QToolButton,
    QScrollArea, QListWidget, QListWidgetItem, QTableWidget, QTableWidgetItem,
    QHeaderView, QDialog, QGridLayout, QSizePolicy
)

logger = logging.getLogger(__name__)
from PyQt6.QtCore import QThread, pyqtSignal, Qt
from PyQt6.QtGui import QPixmap
import matplotlib
matplotlib.use('QtAgg')
from matplotlib.backends.backend_qtagg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure
import os
import copy
import shutil
import tempfile
import json
import time
import glob
import torch
import yaml
import numpy as np
import cv2
import random
import re
from singlebehaviorlab.backend.data_store import AnnotationManager
from singlebehaviorlab.backend.model import VideoPrismBackbone, BehaviorClassifier
from singlebehaviorlab.backend.train import BehaviorDataset, train_model
from singlebehaviorlab.backend.video_utils import load_clip_frames
from .training_profiles import TrainingProfileDialog


class TrainingWorker(QThread):
    """Worker thread for training."""
    log_message = pyqtSignal(str)
    progress = pyqtSignal(int, int)
    finished = pyqtSignal()
    error = pyqtSignal(str)
    training_complete = pyqtSignal(float, float, float, dict)  # (best_val_acc, best_val_f1, final_train_acc, per_class_f1)
    epoch_complete = pyqtSignal(dict)  # New signal for real-time metrics

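    # Signal contract: log_message/progress/epoch_complete stream status out of the
    # worker thread while run() executes, error carries a user-facing failure
    # message, and training_complete delivers the final metrics. Slots should be
    # connected from the GUI thread before start(). Note: defining `finished` here
    # appears to shadow QThread's built-in finished signal, so in PyQt6 it likely
    # fires only on an explicit emit, not on thread termination.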
    def __init__(self, config, train_config, annotation_file, clips_dir, output_path):
        super().__init__()
        self.config = config
        self.train_config = train_config
        self.annotation_file = annotation_file
        self.clips_dir = clips_dir
        self.output_path = output_path
        self.should_stop = False

    def stop(self):
        """Request training stop."""
        self.should_stop = True

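    # Cooperative cancellation: stop() only sets a flag; the backend is handed a
    # stop callback (see check_stop in _run_autotune_search) and presumably polls
    # it between steps. The thread is never force-terminated.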
    def _build_model_for_config(self, train_dataset, cfg, log_fn):
        """Create a fresh classifier for one training run."""
        model_name = cfg.get("backbone_model", self.config.get("backbone_model", "videoprism_public_v1_base"))
        resolution = cfg.get("resolution", 288)
        log_fn(f"Loading VideoPrism backbone ({model_name}) at resolution {resolution}×{resolution}...")
        backbone = VideoPrismBackbone(model_name=model_name, resolution=resolution, log_fn=log_fn)
        log_fn("VideoPrism backbone loaded successfully")

        head_kwargs = cfg.get("head_kwargs", {}).copy()
        head_kwargs.pop("per_class_query", None)
        head_kwargs_for_metadata = head_kwargs.copy()
        dropout = cfg.get("dropout", 0.1)
        localization_dropout = 0.0

        log_fn("Creating classifier model...")
        if cfg.get("use_temporal_decoder", True):
            log_fn("Using MAP head + temporal decoder")
        else:
            log_fn("Using MAP head + direct per-frame linear classifier")
        log_fn(f"MAP head kwargs: {head_kwargs}")

        use_loc = cfg.get("use_localization", False)
        num_stages = cfg.get("num_stages", 3)
        if use_loc and num_stages > 1:
            num_stages = 1
            log_fn("Localization ON: reducing MS-TCN to 1 stage (no refinement) to prevent overfitting on tight crops.")

        multi_scale = cfg.get("multi_scale", False) and not use_loc
        model = BehaviorClassifier(
            backbone,
            num_classes=len(train_dataset.classes),
            class_names=train_dataset.classes,
            dropout=dropout,
            freeze_backbone=True,
            head_kwargs=head_kwargs,
            use_localization=use_loc,
            localization_hidden_dim=cfg.get("localization_hidden_dim", 256),
            localization_dropout=localization_dropout,
            use_frame_head=True,
            use_temporal_decoder=cfg.get("use_temporal_decoder", True),
            frame_head_temporal_layers=cfg.get("frame_head_temporal_layers", 1),
            temporal_pool_frames=cfg.get("temporal_pool_frames", 1),
            proj_dim=cfg.get("proj_dim", 256),
            num_stages=num_stages,
            multi_scale=multi_scale,
        )
        log_fn("Classifier model created successfully")
        return model, head_kwargs_for_metadata, dropout, localization_dropout, num_stages, multi_scale

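    # Head layout implied above: a MAP-style attention-pooling head feeds either a
    # temporal decoder (use_temporal_decoder=True) or a direct per-frame linear
    # classifier. With localization on, MS-TCN refinement collapses to one stage
    # and multi_scale is forced off, trading temporal smoothing for less
    # overfitting on tight crops.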
    def _build_backend_train_config(
        self,
        cfg,
        output_path,
        classes,
        augmentation_options,
        head_kwargs_for_metadata,
        dropout,
        localization_dropout,
        num_stages,
        multi_scale,
    ):
        return {
            "batch_size": cfg["batch_size"],
            "epochs": cfg["epochs"],
            "lr": cfg.get("classification_lr", cfg["lr"]),
            "localization_lr": cfg.get("localization_lr", cfg["lr"]),
            "classification_lr": cfg.get("classification_lr", cfg["lr"]),
            "use_scheduler": cfg.get("use_scheduler", True),
            "use_ema": cfg.get("use_ema", True),
            "weight_decay": cfg.get("weight_decay", 1e-3),
            "output_path": output_path,
            "save_best": True,
            "use_class_weights": cfg.get("use_class_weights", False),
            "use_focal_loss": cfg.get("use_focal_loss", False),
            "focal_gamma": cfg.get("focal_gamma", 2.0),
            "use_frame_loss": True,
            "use_temporal_decoder": cfg.get("use_temporal_decoder", True),
            "frame_head_temporal_layers": cfg.get("frame_head_temporal_layers", 1),
            "temporal_pool_frames": cfg.get("temporal_pool_frames", 1),
            "num_stages": num_stages,
            "proj_dim": cfg.get("proj_dim", 256),
            "multi_scale": multi_scale,
            "use_frame_bout_balance": cfg.get("use_frame_bout_balance", True),
            "frame_bout_balance_power": cfg.get("frame_bout_balance_power", 1.0),
            "boundary_loss_weight": cfg.get("boundary_loss_weight", 0.3),
            "boundary_tolerance": cfg.get("boundary_tolerance", 2),
            "smoothness_loss_weight": cfg.get("smoothness_loss_weight", 0.05),
            "use_localization": cfg.get("use_localization", False),
            "use_manual_localization_switch": cfg.get("use_manual_localization_switch", False),
            "manual_localization_switch_epoch": cfg.get("manual_localization_switch_epoch", 20),
            "localization_hidden_dim": cfg.get("localization_hidden_dim", 256),
            "classification_crop_padding": cfg.get("classification_crop_padding", 0.35),
            "crop_jitter": cfg.get("crop_jitter", False),
            "crop_jitter_strength": cfg.get("crop_jitter_strength", 0.15),
            "emb_aug_versions": cfg.get("emb_aug_versions", 1),
            "clip_length": cfg.get("clip_length", 8),
            "use_ovr": cfg.get("use_ovr", False),
            "ovr_label_smoothing": cfg.get("ovr_label_smoothing", 0.05),
            "use_asl": cfg.get("use_asl", False),
            "asl_gamma_neg": cfg.get("asl_gamma_neg", 2.0),
            "asl_gamma_pos": cfg.get("asl_gamma_pos", 0.0),
            "asl_clip": cfg.get("asl_clip", 0.05),
            "use_hard_pair_mining": cfg.get("use_hard_pair_mining", False),
            "hard_pairs": cfg.get("hard_pairs", []),
            "hard_pair_margin": cfg.get("hard_pair_margin", 0.5),
            "hard_pair_loss_weight": cfg.get("hard_pair_loss_weight", 0.2),
            "hard_pair_confusion_boost": cfg.get("hard_pair_confusion_boost", 1.5),
            "use_confusion_sampler": cfg.get("use_confusion_sampler", True),
            "confusion_sampler_temperature": cfg.get("confusion_sampler_temperature", 2.0),
            "confusion_sampler_warmup_pct": cfg.get("confusion_sampler_warmup_pct", 0.2),
            "use_weighted_sampler": cfg.get("use_weighted_sampler", False),
            "use_augmentation": cfg.get("use_augmentation", False),
            "augmentation_options": augmentation_options,
            "virtual_expansion": cfg.get("virtual_expansion", 5),
            "stitch_augmentation_prob": cfg.get("stitch_augmentation_prob", 0.0),
            "f1_exclude_classes": cfg.get("f1_exclude_classes", []),
            "ovr_pos_weight_f1_excluded": cfg.get("ovr_pos_weight_f1_excluded", 1.5),
            "val_split": cfg.get("val_split", 0.2),
            "limit_classes": cfg.get("limit_classes", False),
            "selected_classes": cfg.get("selected_classes", []),
            "limit_per_class": cfg.get("limit_per_class", False),
            "per_class_limits": cfg.get("per_class_limits", {}),
            "per_class_val_limits": cfg.get("per_class_val_limits", {}),
            "backbone_model": cfg.get("backbone_model", "videoprism_public_v1_base"),
            "resolution": cfg.get("resolution", 288),
            "use_all_for_training": cfg.get("use_all_for_training", False),
            "use_embedding_diversity": cfg.get("use_embedding_diversity", False),
            "class_names": cfg.get("class_names", classes),
            "pretrained_path": cfg.get("pretrained_path"),
            "head_kwargs": head_kwargs_for_metadata,
            "dropout": dropout,
            "localization_dropout": localization_dropout,
        }

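    # Shape sketch (hypothetical input): for a minimal cfg of
    # {"batch_size": 8, "epochs": 20, "lr": 1e-4}, the dict above resolves lr,
    # classification_lr and localization_lr all to 1e-4 and falls back to the
    # inline defaults for the rest (weight_decay=1e-3, clip_length=8,
    # proj_dim=256, use_scheduler/use_ema/save_best/use_frame_loss=True, ...).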
    def _generate_autotune_candidates(self, base_cfg, num_runs: int):
        """Generate a small random search around the current config."""
        def _uniq_float(values):
            out = []
            for v in values:
                v = float(max(1e-6, min(1.0, v)))
                if all(abs(v - prev) > 1e-12 for prev in out):
                    out.append(v)
            return out

        def _uniq_int(values, lo, hi):
            out = []
            for v in values:
                v = int(max(lo, min(hi, int(round(v)))))
                if v not in out:
                    out.append(v)
            return out

        lr0 = float(base_cfg.get("classification_lr", base_cfg.get("lr", 1e-4)))
        wd0 = float(base_cfg.get("weight_decay", 1e-3))
        drop0 = float(base_cfg.get("dropout", 0.3))
        heads0 = int(base_cfg.get("head_kwargs", {}).get("num_heads", 4))
        layers0 = int(base_cfg.get("frame_head_temporal_layers", 4))
        ovr_ls0 = float(base_cfg.get("ovr_label_smoothing", 0.05))
        use_temporal_decoder = bool(base_cfg.get("use_temporal_decoder", True))

        lr_vals = _uniq_float([3e-5, 1e-4, 3e-4, 1e-3, lr0 / 3.0, lr0, lr0 * 3.0])
        wd_vals = _uniq_float([1e-5, 1e-4, 5e-4, 1e-3, wd0 / 10.0, wd0, wd0 * 3.0])
        drop_vals = _uniq_float([0.1, 0.2, 0.3, 0.4, drop0])
        head_vals = _uniq_int([2, 4, 8, heads0, heads0 - 2, heads0 + 2], 1, 16)
        layer_vals = [layers0] if not use_temporal_decoder else _uniq_int([1, 2, 3, 4, layers0, layers0 - 1, layers0 + 1], 1, 8)
        ovr_ls_vals = _uniq_float([0.0, 0.02, 0.05, ovr_ls0])

        rng = random.Random(42)
        candidates = []
        seen = set()

        def _trial_tuple(cfg):
            return (
                round(float(cfg.get("classification_lr", cfg.get("lr", 0.0))), 12),
                round(float(cfg.get("weight_decay", 0.0)), 12),
                round(float(cfg.get("dropout", 0.0)), 12),
                int(cfg.get("head_kwargs", {}).get("num_heads", 0)),
                int(cfg.get("frame_head_temporal_layers", 0)) if cfg.get("use_temporal_decoder", True) else None,
                round(float(cfg.get("ovr_label_smoothing", 0.0)), 12) if cfg.get("use_ovr", False) else None,
            )

        base_copy = copy.deepcopy(base_cfg)
        candidates.append(base_copy)
        seen.add(_trial_tuple(base_copy))

        max_attempts = max(20, num_runs * 20)
        attempts = 0
        while len(candidates) < max(1, int(num_runs)) and attempts < max_attempts:
            attempts += 1
            cfg = copy.deepcopy(base_cfg)
            new_lr = rng.choice(lr_vals)
            cfg["lr"] = new_lr
            cfg["classification_lr"] = new_lr
            cfg["weight_decay"] = rng.choice(wd_vals)
            cfg["dropout"] = rng.choice(drop_vals)
            cfg.setdefault("head_kwargs", {})
            cfg["head_kwargs"]["num_heads"] = rng.choice(head_vals)
            cfg["frame_head_temporal_layers"] = rng.choice(layer_vals)
            if cfg.get("use_ovr", False):
                cfg["ovr_label_smoothing"] = rng.choice(ovr_ls_vals)
            key = _trial_tuple(cfg)
            if key in seen:
                continue
            seen.add(key)
            candidates.append(cfg)
        return candidates[:max(1, int(num_runs))]

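    # The search is a seeded random draw (random.Random(42)), so reruns on the same
    # base config yield the same candidate list; _trial_tuple deduplicates on the
    # tuned values, and max_attempts bounds the loop when the value grids are too
    # small to produce num_runs distinct combinations (fewer candidates come back).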
    def _reset_runtime_dataset_caches(self, dataset):
        """Clear per-run disk cache flags so a new training run rebuilds them cleanly."""
        if dataset is None:
            return
        if hasattr(dataset, "_emb_cache_mode"):
            dataset._emb_cache_mode = False
        if hasattr(dataset, "_emb_cache_dir"):
            dataset._emb_cache_dir = None
        if hasattr(dataset, "_roi_cache_mode"):
            dataset._roi_cache_mode = False
        if hasattr(dataset, "_roi_cache_dir"):
            dataset._roi_cache_dir = None

    def _cleanup_autotune_trial_outputs(self, trial_output, log_fn=None):
        """Delete temporary files produced by one auto-tune trial."""
        output_dir = os.path.dirname(trial_output)
        basename = os.path.splitext(os.path.basename(trial_output))[0]
        paths = [
            trial_output,
            trial_output + ".meta.json",
            os.path.join(output_dir, f"{basename}_best.pt"),
            os.path.join(output_dir, f"{basename}_best.pt.meta.json"),
            os.path.join(output_dir, f"{basename}_training_config.json"),
            os.path.join(output_dir, f"{basename}_checkpoints"),
            os.path.join(output_dir, f"{basename}_crop_progress"),
        ]
        paths.extend(glob.glob(os.path.join(output_dir, f"{basename}_training_log_*.csv")))
        paths.extend(glob.glob(os.path.join(output_dir, f"{basename}_training_plot_*.pdf")))
        for path in paths:
            try:
                if os.path.isdir(path):
                    shutil.rmtree(path, ignore_errors=True)
                elif os.path.exists(path):
                    os.remove(path)
            except Exception as e:
                if log_fn:
                    log_fn(f"Warning: could not delete autotune temp output {path}: {e}")

    def _run_autotune_search(self, train_dataset, val_dataset, base_cfg, classes, augmentation_options, log_fn, progress_cb, check_stop):
        num_runs = int(max(1, base_cfg.get("auto_tune_runs", 8)))
        search_epochs = int(max(1, base_cfg.get("auto_tune_epochs", min(12, int(base_cfg.get("epochs", 20))))))
        candidates = self._generate_autotune_candidates(base_cfg, num_runs)
        trial_root = tempfile.mkdtemp(prefix="autotune_trials_", dir=os.path.dirname(self.output_path))
        best_cfg = None
        best_result = None
        best_score = float("-inf")

        log_fn(f"Auto-tune enabled: evaluating {len(candidates)} candidate config(s) for {search_epochs} epoch(s) each.")
        log_fn("Search parameters: classification LR, weight decay, dropout, MAP num_heads, frame temporal layers"
               + (", OvR label smoothing" if base_cfg.get("use_ovr", False) else ""))
        try:
            for idx, candidate in enumerate(candidates, start=1):
                if check_stop():
                    return None, None
                self._reset_runtime_dataset_caches(train_dataset)
                self._reset_runtime_dataset_caches(val_dataset)
                trial_cfg = copy.deepcopy(candidate)
                trial_cfg["epochs"] = search_epochs
                trial_output = os.path.join(trial_root, f"trial_{idx:02d}.pt")
                log_fn(
                    f"[Auto-tune {idx}/{len(candidates)}] "
                    f"lr={trial_cfg['classification_lr']:.2e}, "
                    f"wd={trial_cfg['weight_decay']:.2e}, "
                    f"dropout={trial_cfg['dropout']:.2f}, "
                    f"heads={trial_cfg['head_kwargs'].get('num_heads', 4)}, "
                    f"layers={trial_cfg['frame_head_temporal_layers']}"
                    + (f", ovr_ls={trial_cfg.get('ovr_label_smoothing', 0.0):.2f}" if trial_cfg.get("use_ovr", False) else "")
                )
                model, head_kwargs_meta, dropout, localization_dropout, num_stages, multi_scale = self._build_model_for_config(
                    train_dataset, trial_cfg, log_fn
                )
                backend_cfg = self._build_backend_train_config(
                    trial_cfg,
                    trial_output,
                    classes,
                    augmentation_options,
                    head_kwargs_meta,
                    dropout,
                    localization_dropout,
                    num_stages,
                    multi_scale,
                )
                result = train_model(
                    model,
                    train_dataset,
                    val_dataset,
                    backend_cfg,
                    log_fn=log_fn,
                    progress_callback=progress_cb,
                    stop_callback=check_stop,
                    metrics_callback=None,
                )
                score = float(result.get("best_val_f1", 0.0) or 0.0)
                acc = float(result.get("best_val_acc", 0.0) or 0.0)
                log_fn(f"[Auto-tune {idx}/{len(candidates)}] best val F1={score:.2f}, val acc={acc:.2f}")
                if score > best_score:
                    best_score = score
                    best_cfg = copy.deepcopy(trial_cfg)
                    best_result = dict(result)
                self._cleanup_autotune_trial_outputs(trial_output, log_fn=log_fn)
            return best_cfg, best_result
        finally:
            try:
                shutil.rmtree(trial_root)
            except Exception as e:
                logger.debug("Could not remove autotune trial root: %s", e)

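    # Search protocol: each trial trains a fresh classifier for search_epochs and
    # is scored by best_val_f1 (ties keep the earlier trial). Per-trial weights,
    # metadata, logs and plots are deleted after scoring, and the temporary
    # trial_root is removed in the finally block even on early stop.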
    def _save_autotuned_profile(self, tuned_cfg, best_search_result, log_fn):
        """Persist the selected auto-tune winner as a reusable training profile."""
        profiles_path = self.train_config.get("training_profiles_path", "")
        if not profiles_path:
            return
        try:
            os.makedirs(os.path.dirname(profiles_path), exist_ok=True)
            profiles = {}
            if os.path.exists(profiles_path):
                with open(profiles_path, "r", encoding="utf-8") as f:
                    loaded = json.load(f)
                if isinstance(loaded, dict):
                    profiles = loaded

            profile_cfg = copy.deepcopy(tuned_cfg)
            profile_cfg["auto_tune_before_final"] = False
            profile_cfg["auto_tune_selected_from_search"] = True
            profile_cfg["auto_tune_selected_at"] = time.strftime("%Y-%m-%d %H:%M:%S")
            if best_search_result:
                profile_cfg["auto_tune_best_search_val_f1"] = float(best_search_result.get("best_val_f1", 0.0) or 0.0)
                profile_cfg["auto_tune_best_search_val_acc"] = float(best_search_result.get("best_val_acc", 0.0) or 0.0)

            base_name = os.path.splitext(os.path.basename(self.output_path))[0] or "training"
            profile_name = f"autotune_{base_name}_{time.strftime('%Y%m%d_%H%M%S')}"
            suffix = 2
            while profile_name in profiles:
                profile_name = f"autotune_{base_name}_{time.strftime('%Y%m%d_%H%M%S')}_{suffix}"
                suffix += 1

            profiles[profile_name] = profile_cfg
            with open(profiles_path, "w", encoding="utf-8") as f:
                json.dump(profiles, f, indent=2)
            log_fn(f"Saved auto-tuned winner as training profile: {profile_name}")
        except Exception as e:
            log_fn(f"Warning: could not save auto-tuned profile: {e}")

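    # Profiles are one flat JSON object mapping name -> config, e.g. (illustrative)
    # {"autotune_mymodel_20240101_120000": {"lr": 3e-4, ...,
    #   "auto_tune_selected_from_search": true}}; the name embeds the output
    # basename plus a timestamp, with a numeric suffix appended on collision.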
    def run(self):
        """Run training."""
        import traceback
        try:
            def log_fn(msg):
                self.log_message.emit(msg)

            def progress_cb(epoch, total):
                self.progress.emit(epoch, total)

            def metrics_cb(metrics):
                self.epoch_complete.emit(metrics)

            log_fn("Initializing training...")
            log_fn(f"Annotation file: {self.annotation_file}")
            log_fn(f"Clips directory: {self.clips_dir}")

            annotation_manager = AnnotationManager(self.annotation_file)
            labeled_clips = annotation_manager.get_labeled_clips()

            if not labeled_clips:
                error_msg = "No labeled clips found. Please label some clips first."
                log_fn(f"ERROR: {error_msg}")
                self.error.emit(error_msg)
                return

            classes = annotation_manager.get_classes()

            log_fn("Using class-only training pipeline (hierarchical/attribution disabled).")

            if len(classes) < 2:
                error_msg = "Need at least 2 behavior classes for training."
                log_fn(f"ERROR: {error_msg}")
                self.error.emit(error_msg)
                return

            # Filter classes if class selection is enabled
            use_ovr = self.train_config.get("use_ovr", False)
            hybrid_ovr_bg = bool(self.train_config.get("ovr_background_as_negative", False) and use_ovr)
            hybrid_bg_classes = set(self.train_config.get("ovr_background_class_names", []))
            selected_classes = None
            if self.train_config.get("limit_classes", False):
                selected_classes = self.train_config.get("selected_classes", [])
                if selected_classes:
                    classes = [c for c in classes if c in selected_classes]
                    log_fn(f"Limiting training to {len(classes)} selected classes: {classes}")
                    if len(classes) < 2:
                        error_msg = "Need at least 2 selected classes for training."
                        log_fn(f"ERROR: {error_msg}")
                        self.error.emit(error_msg)
                        return
                    # Filter clips: keep selected classes + near_negative clips when OvR
                    if use_ovr:
                        allowed = set(selected_classes)
                        if hybrid_ovr_bg:
                            allowed.update(hybrid_bg_classes)
                        labeled_clips = [
                            clip for clip in labeled_clips
                            if clip.get("label") in allowed
                            or clip.get("label", "").startswith("near_negative")
                        ]
                    else:
                        labeled_clips = [clip for clip in labeled_clips if clip.get("label") in classes]

            # OvR mode: filter near_negative_* from class list (they become suppression-only)
            if use_ovr:
                real_classes = [c for c in classes if not c.startswith("near_negative")]
                hn_classes = [c for c in classes if c.startswith("near_negative")]
                bg_classes = [c for c in real_classes if c in hybrid_bg_classes] if hybrid_ovr_bg else []
                if hybrid_ovr_bg and bg_classes:
                    real_classes = [c for c in real_classes if c not in bg_classes]
                if len(real_classes) < 2:
                    error_msg = "OvR mode needs at least 2 non-near-negative classes."
                    log_fn(f"ERROR: {error_msg}")
                    self.error.emit(error_msg)
                    return
                log_fn(f"OvR mode: {len(real_classes)} real classes, {len(hn_classes)} near-negative classes (suppression only)")
                if hybrid_ovr_bg and bg_classes:
                    log_fn(
                        f"Hybrid OvR backgrounds: {bg_classes} kept as negative-only clips "
                        f"(not trained as explicit heads)"
                    )
                classes = real_classes
                # Near-negative clips stay in labeled_clips (they get label=-1 in the dataset)

            log_fn(f"Found {len(labeled_clips)} labeled clips")
            log_fn(f"Classes: {classes}")

            from collections import Counter
            label_counts = Counter([clip["label"] for clip in labeled_clips])
            log_fn("Class distribution (before limiting):")
            for label, count in sorted(label_counts.items()):
                log_fn(f" {label}: {count} ({100.0*count/len(labeled_clips):.1f}%)")

            # Multi-class breakdown
            mc_combos: dict[tuple, int] = {}
            mc_per_label: dict[str, int] = {}
            exc_per_label: dict[str, int] = {}
            for clip in labeled_clips:
                lbl_list = clip.get("labels")
                if not isinstance(lbl_list, list) or not lbl_list:
                    lbl_list = [clip.get("label", "")]
                if len(lbl_list) > 1:
                    key = tuple(sorted(lbl_list))
                    mc_combos[key] = mc_combos.get(key, 0) + 1
                    for lbl in lbl_list:
                        mc_per_label[lbl] = mc_per_label.get(lbl, 0) + 1
                else:
                    exc_per_label[lbl_list[0]] = exc_per_label.get(lbl_list[0], 0) + 1
            if mc_combos:
                total_mc = sum(mc_combos.values())
                log_fn(f"Multi-class clips: {total_mc} of {len(labeled_clips)} ({100.0*total_mc/max(1,len(labeled_clips)):.1f}%)")
                for combo, cnt in sorted(mc_combos.items(), key=lambda x: -x[1]):
                    log_fn(f" {' + '.join(combo)}: {cnt}")
                for lbl in sorted(mc_per_label):
                    exc = exc_per_label.get(lbl, 0)
                    sh = mc_per_label[lbl]
                    log_fn(f" {lbl}: {exc} exclusive, {sh} in multi-class clips")

            use_all_for_training = self.train_config.get("use_all_for_training", False)
            val_split = self.train_config.get("val_split", 0.2)

            def _split_train_val_clip_stratified(clips, split_ratio, seed=42):
                """Split clips by randomly assigning ~split_ratio per class to val (stratified by clip label)."""
                if not clips or split_ratio <= 0:
                    return list(clips), []
                by_label = {}
                for c in clips:
                    lbl = c.get("label")
                    by_label.setdefault(lbl, []).append(c)
                rng = random.Random(seed)
                train_out, val_out = [], []
                for lbl, group in by_label.items():
                    rng.shuffle(group)
                    n_val = max(0, int(round(len(group) * split_ratio)))
                    n_val = min(n_val, len(group) - 1) if len(group) > 1 else 0
                    train_out.extend(group[n_val:])
                    val_out.extend(group[:n_val])
                log_fn(f"Clip-stratified split: train={len(train_out)} clips, val={len(val_out)} clips")
                if val_out:
                    log_fn(f" Val clip distribution: {dict(Counter(c.get('label') for c in val_out))}")
                return train_out, val_out

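            # Per-class cap above: n_val is clamped to len(group) - 1, so any class
            # with at least 2 clips keeps at least one training clip, and singleton
            # classes contribute nothing to validation.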
            def _split_train_val_frame_stratified(clips, classes, split_ratio, seed=42):
                """Split by randomly selecting frames per class for val, then assign whole clips to val if they contain any selected frame.
                Ensures roughly equal proportion of each class in validation."""
                if not clips or split_ratio <= 0:
                    return list(clips), []
                class_set = set(classes)
                _clip_len = self.train_config.get("clip_length", 8)
                # Pool (clip_idx, frame_idx) for every labeled frame.
                # Clips without per-frame labels contribute _clip_len entries so their
                # weight in stratification matches what validation will actually count.
                pool_by_class = {c: [] for c in classes}
                for clip_idx, clip in enumerate(clips):
                    fl = clip.get("frame_labels")
                    if not isinstance(fl, (list, tuple)) or len(fl) == 0:
                        primary = clip.get("label")
                        if primary in class_set:
                            for fi in range(_clip_len):
                                pool_by_class[primary].append((clip_idx, fi))
                        continue
                    for frame_idx, lbl in enumerate(fl):
                        if lbl in class_set:
                            pool_by_class[lbl].append((clip_idx, frame_idx))
                rng = random.Random(seed)
                val_frames = set()
                for c in classes:
                    lst = pool_by_class[c]
                    if not lst:
                        continue
                    rng.shuffle(lst)
                    n_val = max(0, int(round(len(lst) * split_ratio)))
                    n_val = min(n_val, len(lst))
                    for i in range(n_val):
                        val_frames.add(lst[i])
                val_clip_indices = {clip_idx for (clip_idx, _) in val_frames}
                # Cap val clips so train gets enough data: at most split_ratio of clips in val.
                max_val_clips = max(1, int(round(len(clips) * split_ratio)))
                if len(val_clip_indices) > max_val_clips:
                    val_clip_list = list(val_clip_indices)
                    rng.shuffle(val_clip_list)
                    val_clip_indices = set(val_clip_list[:max_val_clips])
                    # Keep only frames belonging to clips that survived the cap.
                    val_frames = {f for f in val_frames if f[0] in val_clip_indices}
                train_out = [c for i, c in enumerate(clips) if i not in val_clip_indices]
                val_out = [c for i, c in enumerate(clips) if i in val_clip_indices]
                log_fn(f"Frame-stratified split: {len(val_frames)} val frames across {len(val_clip_indices)} clips (cap ~{split_ratio:.0%} clips) → "
                       f"train={len(train_out)} clips, val={len(val_out)} clips")
                if val_out:
                    val_label_counts = Counter()
                    for c in val_out:
                        fl = c.get("frame_labels")
                        if isinstance(fl, (list, tuple)):
                            for lbl in fl:
                                if lbl in class_set:
                                    val_label_counts[lbl] += 1
                        else:
                            val_label_counts[c.get("label")] += 1
                    log_fn(f" Val frame distribution: {dict(val_label_counts)}")
                return train_out, val_out

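            # Leakage guard: frames are only used to *choose* clips; whole clips are
            # then moved to validation, so no clip contributes frames to both splits.
            # The clip-level cap re-applies split_ratio because frame sampling alone
            # can pull in far more clips than intended.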
            # Apply per-class limits to TRAINING ONLY (validation uses all remaining clips)
            per_class_limits = self.train_config.get("per_class_limits", {})
            per_class_val_limits = self.train_config.get("per_class_val_limits", {})
            if per_class_limits:
                log_fn("Applying per-class limits to TRAINING set only...")
                log_fn("Validation will use ALL remaining clips (not limited)")

                train_clips = []
                val_clips = []

                random.seed(42)  # For reproducibility

                # Check if embedding-based diversity selection is enabled
                use_embedding_diversity = self.train_config.get("use_embedding_diversity", False)
                backbone_model = self.train_config.get("backbone_model", "videoprism_public_v1_base")

                resolution = self.train_config.get("resolution", 288)

                # Initialize VideoPrism backbone if needed for diversity selection
                backbone = None
                device = 'cuda' if torch.cuda.is_available() else 'cpu'
                if use_embedding_diversity:
                    try:
                        log_fn(f"Initializing VideoPrism for embedding-based diversity selection at {resolution}×{resolution}...")
                        backbone = VideoPrismBackbone(model_name=backbone_model, resolution=resolution)
                        backbone.eval()
                        backbone.to(device)
                        log_fn("VideoPrism loaded successfully")
                    except Exception as e:
                        log_fn(f"Failed to load VideoPrism: {e}")
                        log_fn(" Falling back to random selection")
                        use_embedding_diversity = False

                def farthest_point_sampling(embeddings, n_samples, seed=42):
                    """Select n_samples using Farthest Point Sampling for maximum diversity."""
                    if len(embeddings) <= n_samples:
                        return list(range(len(embeddings)))

                    np.random.seed(seed)
                    embeddings = np.array(embeddings)
                    selected = []

                    # Start with random point
                    current = np.random.randint(0, len(embeddings))
                    selected.append(current)

                    # Track minimum distance to any selected point for each unselected point
                    min_distances = np.full(len(embeddings), np.inf)
                    # Zero out the seed point; leaving it at inf would make it win
                    # every argmax below and get selected twice.
                    min_distances[current] = 0

                    # Iteratively select farthest point from all selected
                    for _ in range(n_samples - 1):
                        # Update minimum distances to nearest selected point
                        for i in range(len(embeddings)):
                            if i not in selected:
                                # Distance to the most recently selected point
                                dist_to_current = np.linalg.norm(embeddings[i] - embeddings[current])
                                # Keep minimum distance to any selected point
                                min_distances[i] = min(min_distances[i], dist_to_current)

                        # Select point with maximum minimum distance (farthest from all selected)
                        current = np.argmax(min_distances)
                        selected.append(current)
                        min_distances[current] = 0  # Mark as selected

                    return selected

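                # farthest_point_sampling is the classic greedy k-center heuristic
                # (Gonzalez): each step adds the point farthest from its nearest
                # already-selected point, a 2-approximation to the k-center
                # objective. Cost is O(n_samples * N) distance evaluations plus the
                # O(len(selected)) `i not in selected` scan, fine at per-class clip
                # counts.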
                for label in sorted(set([clip["label"] for clip in labeled_clips])):
                    class_clips = [clip for clip in labeled_clips if clip.get("label") == label]

                    if label in per_class_limits:
                        raw_limit = per_class_limits[label]
                        limit = max(1, int(round(float(raw_limit))))
                        raw_val_limit = per_class_val_limits.get(label, float('inf'))
                        val_limit = max(1, int(round(float(raw_val_limit)))) if raw_val_limit != float('inf') else float('inf')
                        if len(class_clips) > limit:
                            log_fn(f" {label}: {len(class_clips)} total clips")

                            if use_embedding_diversity and backbone is not None:
                                try:
                                    log_fn(f" Extracting orientation-invariant embeddings for diversity selection...")
                                    embeddings = []
                                    valid_clips = []

                                    for clip_idx, clip in enumerate(class_clips):
                                        if (clip_idx + 1) % 20 == 0:
                                            log_fn(f" Processing clip {clip_idx + 1}/{len(class_clips)}...")

                                        clip_id = clip.get("id", "")
                                        clip_path = os.path.join(self.clips_dir, clip_id)

                                        # Try to find clip file
                                        if not os.path.exists(clip_path):
                                            base_name, ext = os.path.splitext(clip_id)
                                            for video_ext in ['.mp4', '.avi', '.mov', '.mkv']:
                                                test_path = os.path.join(self.clips_dir, base_name + video_ext)
                                                if os.path.exists(test_path):
                                                    clip_path = test_path
                                                    break

                                        if not os.path.exists(clip_path):
                                            continue

                                        try:
                                            # Load frames
                                            frames_bgr = load_clip_frames(clip_path, target_size=(resolution, resolution))
                                            if not frames_bgr:
                                                continue

                                            # Convert BGR to RGB and normalize
                                            frames_rgb = [cv2.cvtColor(f, cv2.COLOR_BGR2RGB) for f in frames_bgr]
                                            frames_array = np.stack(frames_rgb).astype(np.float32) / 255.0

                                            # Extract embeddings with orientation augmentation (original + hflip + vflip + both)
                                            embedding_list = []

                                            for hflip, vflip in [(False, False), (True, False), (False, True), (True, True)]:
                                                frames_aug = frames_array.copy()

                                                # Apply flips
                                                if hflip:
                                                    frames_aug = np.flip(frames_aug, axis=2)  # Flip horizontally (width)
                                                if vflip:
                                                    frames_aug = np.flip(frames_aug, axis=1)  # Flip vertically (height)

                                                # Copy after all flips to avoid negative strides (PyTorch requirement)
                                                frames_aug = frames_aug.copy()

                                                # Ensure correct shape: (T, H, W, C) -> (1, T, C, H, W) for VideoPrism
                                                T, H, W, C = frames_aug.shape
                                                frames_tensor = torch.from_numpy(frames_aug).permute(0, 3, 1, 2)  # (T, C, H, W)
                                                frames_tensor = frames_tensor.unsqueeze(0)  # (1, T, C, H, W)

                                                # Extract embedding
                                                with torch.no_grad():
                                                    tokens = backbone(frames_tensor.to(device))
                                                    # Average pool tokens: (1, N, D) -> (D,)
                                                    emb = tokens.mean(dim=1).squeeze(0).cpu().numpy()
                                                embedding_list.append(emb)
                                                # Clear GPU tensors immediately
                                                del tokens, frames_tensor

                                            # Average embeddings from all orientations
                                            embedding = np.mean(embedding_list, axis=0)

                                            # L2 normalize for better diversity in angular space
                                            embedding_norm = np.linalg.norm(embedding)
                                            if embedding_norm > 0:
                                                embedding = embedding / embedding_norm

                                            embeddings.append(embedding)
                                            valid_clips.append(clip)

                                            # Clear intermediate variables
                                            del embedding_list, frames_array, frames_rgb, frames_bgr

                                            # Periodically clear CUDA cache to prevent accumulation
                                            if (clip_idx + 1) % 50 == 0 and torch.cuda.is_available():
                                                torch.cuda.empty_cache()
                                        except Exception as e:
                                            log_fn(f" Failed to extract embedding for {clip_id}: {e}")
                                            continue

                                    if len(embeddings) >= limit:
                                        embeddings = np.array(embeddings)
                                        log_fn(f" Using Farthest Point Sampling to select {limit} diverse clips...")
                                        selected_indices = farthest_point_sampling(embeddings, limit, seed=42)
                                        train_class_clips = [valid_clips[i] for i in selected_indices]
                                        train_clip_ids = {id(c) for c in train_class_clips}
                                        val_class_clips = [c for c in class_clips if id(c) not in train_clip_ids]
                                        log_fn(f" Selected {len(train_class_clips)} diverse clips based on embeddings")
                                    else:
                                        log_fn(f" Only {len(embeddings)} valid embeddings, falling back to random")
                                        random.shuffle(class_clips)
                                        train_class_clips = class_clips[:limit]
                                        val_class_clips = class_clips[limit:]
                                except Exception as e:
                                    log_fn(f" Embedding-based selection failed: {e}")
                                    log_fn(f" Falling back to random selection")
                                    random.shuffle(class_clips)
                                    train_class_clips = class_clips[:limit]
                                    val_class_clips = class_clips[limit:]
                            else:
                                log_fn(f" Training: randomly selecting {limit} clips")
                                random.shuffle(class_clips)
                                train_class_clips = class_clips[:limit]
                                val_class_clips = class_clips[limit:]

                            # Apply validation limit if set
                            if len(val_class_clips) > val_limit:
                                log_fn(f" Validation: limiting to {val_limit} clips (from {len(val_class_clips)} available)")
                                random.shuffle(val_class_clips)
                                val_class_clips = val_class_clips[:val_limit]
                            else:
                                log_fn(f" Validation: using all remaining {len(val_class_clips)} clips")

                            train_clips.extend(train_class_clips)
                            val_clips.extend(val_class_clips)
                        else:
                            # Not enough clips to limit - use everything for training,
                            # with a small held-out validation split below
                            log_fn(f" {label}: {len(class_clips)} total clips (below limit of {limit})")
                            log_fn(f" Training: using all {len(class_clips)} clips (can't limit below available)")
                            if use_all_for_training:
                                train_clips.extend(class_clips)
                            else:
                                # Keep train/val disjoint even for tiny classes.
                                # If there is only 1 clip, we cannot hold out validation without losing
                                # training signal, so we place it in training only.
                                if len(class_clips) <= 1:
                                    train_clips.extend(class_clips)
                                    log_fn(" Validation: 0 clips (class too small to split without overlap)")
                                else:
                                    tmp = list(class_clips)
                                    random.shuffle(tmp)
                                    n = len(tmp)
                                    n_val = int(round(n * float(val_split)))
                                    if n_val <= 0:
                                        n_val = 1
                                    if n_val >= n:
                                        n_val = n - 1
                                    val_class_clips = tmp[:n_val]
                                    train_class_clips = tmp[n_val:]
                                    # Apply validation cap if requested
                                    if val_limit != float('inf') and len(val_class_clips) > int(val_limit):
                                        random.shuffle(val_class_clips)
                                        val_class_clips = val_class_clips[:int(val_limit)]
                                    train_clips.extend(train_class_clips)
                                    val_clips.extend(val_class_clips)
                                    log_fn(f" Validation: {len(val_class_clips)} clips (held out, no overlap)")
                    else:
                        # No limit for this class, split normally
                        if use_all_for_training:
                            train_clips.extend(class_clips)
                        else:
                            train_class, val_class = _split_train_val_clip_stratified(
                                class_clips,
                                val_split,
                                seed=42,
                            )
                            train_clips.extend(train_class)
                            val_clips.extend(val_class)

                # Free GPU memory after diversity selection
                if backbone is not None:
                    log_fn("Freeing VideoPrism backbone from GPU memory...")
                    del backbone
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                    log_fn("GPU memory freed")

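                # torch.cuda.empty_cache() only returns cached blocks to the driver
                # once all tensor references are gone, hence the explicit `del
                # backbone` first; this frees VRAM for the training-phase backbone
                # that _build_model_for_config loads later.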
                log_fn(f"Final dataset sizes:")
                log_fn(f" Training: {len(train_clips)} clips")
                if not use_all_for_training:
                    log_fn(f" Validation: {len(val_clips)} clips")

                def _log_split_distribution(clips, split_name):
                    counts = Counter([c["label"] for c in clips])
                    log_fn(f"{split_name} class distribution:")
                    mc = 0
                    for c in clips:
                        ll = c.get("labels")
                        if isinstance(ll, list) and len(ll) > 1:
                            mc += 1
                    for label, count in sorted(counts.items()):
                        log_fn(f" {label}: {count}")
                    if mc > 0:
                        log_fn(f" ({mc} multi-class clips in {split_name.lower()} set)")

                _log_split_distribution(train_clips, "Training")
                if not use_all_for_training:
                    _log_split_distribution(val_clips, "Validation")
            else:
                # No per-class limits, use standard train/val split
                if use_all_for_training:
                    log_fn("Using all data for training (no validation split)")
                    train_clips = labeled_clips
                    val_clips = []
                else:
                    log_fn(f"Splitting dataset into train/val (val_split={val_split:.1%}, stratified when possible)...")
                    has_frame_labels = any(
                        isinstance(c.get("frame_labels"), (list, tuple)) and len(c.get("frame_labels") or []) > 0
                        for c in labeled_clips
                    )
                    if has_frame_labels:
                        log_fn("Using frame-stratified split (equal proportion of each class in validation).")
                        train_clips, val_clips = _split_train_val_frame_stratified(
                            labeled_clips,
                            classes,
                            val_split,
                            seed=42,
                        )
                    else:
                        log_fn("Using clip-stratified split (no per-frame labels).")
                        train_clips, val_clips = _split_train_val_clip_stratified(
                            labeled_clips,
                            val_split,
                            seed=42,
                        )
                    log_fn(f"Train: {len(train_clips)}, Val: {len(val_clips)}")

            log_fn(f"Train: {len(train_clips)} clips")

            log_fn("Validating clip files exist...")
            log_fn(f"Checking clips directory: {self.clips_dir}")

            if not os.path.exists(self.clips_dir):
                error_msg = f"Clips directory does not exist: {self.clips_dir}\n\nPlease check the clips directory path in the Training tab."
                log_fn(f"ERROR: {error_msg}")
                self.error.emit(error_msg)
                return

            video_extensions = ['.mp4', '.avi', '.mov', '.mkv', '.MP4', '.AVI', '.MOV', '.MKV']
            all_video_files = []
            for root, dirs, files in os.walk(self.clips_dir):
                for file in files:
                    if any(file.lower().endswith(ext.lower()) for ext in video_extensions):
                        rel_path = os.path.relpath(os.path.join(root, file), self.clips_dir)
                        all_video_files.append(rel_path.replace('\\', '/'))

            log_fn(f"Found {len(all_video_files)} video files in directory (including subdirectories)")
            if len(all_video_files) > 0 and len(all_video_files) < 10:
                log_fn(f"Video files found: {', '.join(all_video_files)}")

            missing_clips = []

            for clip_info in train_clips:
                clip_id = clip_info["id"]
                clip_path = os.path.join(self.clips_dir, clip_id)
                found = False

                if os.path.exists(clip_path):
                    found = True
                else:
                    base_name, ext = os.path.splitext(clip_id)
                    clip_basename = os.path.basename(clip_id)
                    clip_dir_part = os.path.dirname(clip_id) if os.path.dirname(clip_id) else None

                    if not ext:
                        for video_ext in video_extensions:
                            test_path = os.path.join(self.clips_dir, clip_id + video_ext)
                            if os.path.exists(test_path):
                                found = True
                                break
                    else:
                        base_name_only = os.path.basename(base_name)

                        for video_ext in video_extensions:
                            test_path = os.path.join(self.clips_dir, base_name + video_ext)
                            if os.path.exists(test_path):
                                found = True
                                break

                            if not found:
                                test_path = os.path.join(self.clips_dir, base_name_only + video_ext)
                                if os.path.exists(test_path):
                                    found = True
                                    break

                        if not found:
                            for root, dirs, files in os.walk(self.clips_dir):
                                for file in files:
                                    file_base, file_ext = os.path.splitext(file)
                                    if file_base == base_name_only or file_base == base_name:
                                        if file_ext.lower() in [e.lower() for e in video_extensions]:
                                            found = True
                                            break
                                if found:
                                    break

                        if not found and clip_dir_part:
                            subdir_path = os.path.join(self.clips_dir, clip_dir_part)
                            if os.path.exists(subdir_path):
                                for video_ext in video_extensions:
                                    test_path = os.path.join(subdir_path, clip_basename)
                                    if os.path.exists(test_path):
                                        found = True
                                        break
                                    test_path = os.path.join(subdir_path, base_name_only + video_ext)
                                    if os.path.exists(test_path):
                                        found = True
                                        break

                if not found:
                    missing_clips.append(clip_id)

            if missing_clips:
                log_fn(f"WARNING: Found {len(missing_clips)} missing clip files")
                log_fn("Searching for files in subdirectories...")

                found_in_subdirs = {}
                sample_missing = missing_clips[:5]

                for clip_id in sample_missing:
                    base_name, ext = os.path.splitext(clip_id)
                    for root, dirs, files in os.walk(self.clips_dir):
                        for file in files:
                            file_base, file_ext = os.path.splitext(file)
                            if file_base == base_name or file_base == os.path.basename(base_name):
                                if file_ext.lower() in [e.lower() for e in video_extensions]:
                                    rel_path = os.path.relpath(os.path.join(root, file), self.clips_dir)
                                    if clip_id not in found_in_subdirs:
                                        found_in_subdirs[clip_id] = rel_path

                if found_in_subdirs:
                    log_fn(f"Found {len(found_in_subdirs)} files in subdirectories:")
                    for clip_id, found_path in list(found_in_subdirs.items())[:3]:
                        log_fn(f" {clip_id} -> {found_path}")
                    log_fn("SUGGESTION: Your clips are in subdirectories. Update the clip IDs in annotations.json")
                    log_fn(" to include the subdirectory path, or move files to the root clips directory.")

                error_msg = f"Found {len(missing_clips)} missing clip files (out of {len(train_clips)} total)\n\n"
                error_msg += f"Clips directory: {self.clips_dir}\n\n"
                error_msg += "Sample missing files:\n"
                for clip_id in missing_clips[:10]:
                    error_msg += f" - {clip_id}\n"
                if len(missing_clips) > 10:
                    error_msg += f" ... and {len(missing_clips) - 10} more\n"

                if found_in_subdirs:
                    error_msg += f"\nNOTE: Found {len(found_in_subdirs)} of the sample files in subdirectories.\n"
                    error_msg += "Your annotation file may need to include subdirectory paths in clip IDs.\n"
                    error_msg += "Example: 'subfolder/span10.avi' instead of 'span10.avi'\n"

                error_msg += f"\nPlease:\n"
                error_msg += f"1. Check that clip files exist in: {self.clips_dir}\n"
                error_msg += f"2. Verify the clips directory path in the Training tab\n"
                error_msg += f"3. Update annotations.json to include correct paths if files are in subdirectories"

                log_fn(f"ERROR: {error_msg}")
                self.error.emit(error_msg)
                return

            log_fn(f"All {len(train_clips)} clips validated successfully")

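            # Resolution order used above for each clip id: exact path, then id plus
            # each known video extension, then basename variants (with and without
            # the subdirectory part), then a recursive walk matched on basename,
            # then an explicit join with the id's own subdirectory. Anything still
            # missing is reported with subdirectory suggestions before aborting.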
            log_fn("Creating datasets...")
            try:
                from singlebehaviorlab.backend.augmentations import ClipAugment

                use_augmentation = self.train_config.get("use_augmentation", False)
                augmentation_options = self.train_config.get("augmentation_options", None)
                if not isinstance(augmentation_options, dict):
                    augmentation_options = {}
                transform = None
                if use_augmentation:
                    augmentation_defaults = {
                        "use_horizontal_flip": True,
                        "use_vertical_flip": False,
                        "use_color_jitter": True,
                        "use_gaussian_blur": True,
                        "use_random_noise": True,
                        "use_small_rotation": False,
                        "use_speed_perturb": False,
                        "use_random_shapes": False,
                        "use_grayscale": False,
                        "use_lighting_robustness": True,
                    }
                    for key, value in augmentation_defaults.items():
                        if key not in augmentation_options:
                            augmentation_options[key] = value
                    if self.train_config.get("use_localization", False) and augmentation_options.get("use_small_rotation", False):
                        # Rotation changes object coordinates; current spatial-label augmentation
                        # only mirrors bboxes, so disable rotation for localization training.
                        augmentation_options["use_small_rotation"] = False
                        log_fn("Localization is enabled: disabling small rotation to keep bbox supervision aligned.")
                    transform = ClipAugment(
                        use_horizontal_flip=augmentation_options["use_horizontal_flip"],
                        use_vertical_flip=augmentation_options["use_vertical_flip"],
                        use_color_jitter=augmentation_options["use_color_jitter"],
                        use_gaussian_blur=augmentation_options["use_gaussian_blur"],
                        use_random_noise=augmentation_options["use_random_noise"],
                        use_small_rotation=augmentation_options["use_small_rotation"],
                        use_speed_perturb=augmentation_options.get("use_speed_perturb", False),
                        use_random_shapes=augmentation_options.get("use_random_shapes", False),
                        use_grayscale=augmentation_options.get("use_grayscale", False),
                        use_lighting_robustness=augmentation_options.get("use_lighting_robustness", False),
                        gaussian_blur_sigma=(0.1, 0.5),
                        noise_std=0.02,
                        rotation_degrees=5.0,
                    )
                    log_fn("Using selected data augmentation for training:")
                    if augmentation_options["use_horizontal_flip"]:
                        log_fn(" - Random horizontal flip")
                    if augmentation_options["use_vertical_flip"]:
                        log_fn(" - Random vertical flip")
                    if augmentation_options["use_color_jitter"]:
                        log_fn(" - Color jitter (brightness, contrast, saturation, hue)")
                    if augmentation_options["use_gaussian_blur"]:
                        log_fn(" - Gaussian blur (0.1-0.5 sigma)")
                    if augmentation_options["use_random_noise"]:
                        log_fn(" - Random noise (std=0.02)")
                    if augmentation_options.get("use_speed_perturb", False):
                        log_fn(" - Speed perturbation (0.7x-1.3x)")
                    if augmentation_options.get("use_random_shapes", False):
                        log_fn(" - Random shape overlays (occlusion)")
                    if augmentation_options.get("use_grayscale", False):
                        log_fn(" - Random grayscale (50% chance)")
                    if augmentation_options.get("use_lighting_robustness", False):
                        log_fn(" - Lighting/color robustness (gamma + channel gain)")
                    if augmentation_options["use_small_rotation"]:
                        log_fn(" - Small rotation (+/- 5 degrees)")
                    log_fn(" - No cropping (content always preserved)")
                else:
                    log_fn("No augmentation (using raw clips)")

|
1114
|
+
clip_length = self.train_config.get("clip_length", 16)
|
|
1115
|
+
resolution = self.train_config.get("resolution", 288)
|
|
1116
|
+
log_fn(f"Using clip_length={clip_length} frames (center-cropped if clips are longer)")
|
|
1117
|
+
log_fn(f"Using input resolution: {resolution}x{resolution}")
|
|
1118
|
+
|
|
1119
|
+
# Virtual Dataset Expansion for small datasets
|
|
1120
|
+
num_train = len(train_clips)
|
|
1121
|
+
virtual_multiplier = 1
|
|
1122
|
+
if num_train > 0 and use_augmentation:
|
|
1123
|
+
virtual_multiplier = int(self.train_config.get("virtual_expansion", 5))
|
|
1124
|
+
log_fn(f"Small dataset detected ({num_train} clips). Using virtual expansion x{virtual_multiplier}")
|
|
1125
|
+
log_fn(f"Effective epoch size: {num_train * virtual_multiplier} samples (unique augmentations)")
|
|
1126
|
+
|
|
1127
|
+
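        # Illustrative sketch (hypothetical, not the BehaviorDataset internals):
        # a virtual multiplier M makes the dataset report M * num_clips samples
        # and folds every index back onto a real clip, so one epoch re-augments
        # each clip M times with fresh random transforms.
        def _sketch_virtual_expansion(base_dataset, multiplier):
            import torch.utils.data as td

            class Expanded(td.Dataset):
                def __len__(self):
                    return len(base_dataset) * multiplier

                def __getitem__(self, idx):
                    # the augmentation pipeline re-randomizes on every access
                    return base_dataset[idx % len(base_dataset)]

            return Expanded()
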
        stitch_prob = float(self.train_config.get("stitch_augmentation_prob", 0.0))
        emb_aug_versions = int(self.train_config.get("emb_aug_versions", 1))
        _multi_scale = self.train_config.get("multi_scale", False) and not self.train_config.get("use_localization", False)
        if self.train_config.get("use_localization", False):
            if stitch_prob > 0.0:
                stitch_prob = 0.0
                log_fn("Localization is enabled: disabling clip-stitch augmentation for this run.")
        aug_note = f" ({emb_aug_versions} aug version(s) per clip)" if emb_aug_versions > 1 else ""
        ms_note = " + short-scale (multi-scale)" if _multi_scale else ""
        log_fn(f"Embedding cache: always active{aug_note}{ms_note} — backbone skipped every training step")
        if stitch_prob > 0.0:
            log_fn(
                f"Clip-stitch augmentation: prob={stitch_prob:.0%}, fixed 50/50 split "
                "(applied on cached embeddings during classification training)"
            )

        use_crop_jitter = bool(self.train_config.get("crop_jitter", False))
        crop_jitter_strength = float(self.train_config.get("crop_jitter_strength", 0.15))
        if use_crop_jitter and self.train_config.get("use_localization", False):
            log_fn(f"Crop jitter: enabled (strength={crop_jitter_strength:.0%} of bbox size)")

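        # Illustrative sketch of the fixed 50/50 clip-stitch (assumed semantics,
        # not the backend code): concatenate the first half of one clip's cached
        # embeddings with the second half of another's, and stitch the per-frame
        # labels the same way.
        def _sketch_clip_stitch(emb_a, labels_a, emb_b, labels_b):
            import torch
            t = emb_a.shape[0] // 2  # embeddings are (T, D), labels are (T,)
            emb = torch.cat([emb_a[:t], emb_b[t:]], dim=0)
            labels = torch.cat([labels_a[:t], labels_b[t:]], dim=0)
            return emb, labels
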
        train_dataset = BehaviorDataset(
            train_clips,
            annotation_manager,
            classes,
            self.clips_dir,
            transform=transform,
            target_size=(resolution, resolution),
            clip_length=clip_length,
            virtual_size_multiplier=virtual_multiplier,
            stitch_prob=stitch_prob,
            crop_jitter=bool(self.train_config.get("crop_jitter", False)),
            crop_jitter_strength=float(self.train_config.get("crop_jitter_strength", 0.15)),
            ovr_background_classes=self.train_config.get("ovr_background_class_names", []) if hybrid_ovr_bg else [],
        )
        log_fn(f"Train dataset created: {len(train_dataset)} virtual samples (from {num_train} unique clips)")
    except Exception as e:
        error_msg = f"Failed to create train dataset: {str(e)}\n{traceback.format_exc()}"
        log_fn(f"ERROR: {error_msg}")
        self.error.emit(error_msg)
        return

    val_dataset = None
    if val_clips:
        try:
            clip_length = self.train_config.get("clip_length", 16)
            val_dataset = BehaviorDataset(
                val_clips,
                annotation_manager,
                classes,
                self.clips_dir,
                target_size=(resolution, resolution),
                clip_length=clip_length,
                ovr_background_classes=self.train_config.get("ovr_background_class_names", []) if hybrid_ovr_bg else [],
            )
            log_fn(f"Val dataset created: {len(val_dataset)} samples")
        except Exception as e:
            error_msg = f"Failed to create val dataset: {str(e)}\n{traceback.format_exc()}"
            log_fn(f"ERROR: {error_msg}")
            self.error.emit(error_msg)
            return
    else:
        log_fn("No validation dataset (using all data for training)")

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def check_stop():
        return self.should_stop

if self.train_config.get("auto_tune_before_final", False):
|
|
1197
|
+
base_train_cfg = copy.deepcopy(self.train_config)
|
|
1198
|
+
if val_dataset is None:
|
|
1199
|
+
error_msg = "Auto-tune requires a validation set. Disable 'Use all data for training' or set a validation split."
|
|
1200
|
+
log_fn(f"ERROR: {error_msg}")
|
|
1201
|
+
self.error.emit(error_msg)
|
|
1202
|
+
return
|
|
1203
|
+
best_cfg, best_search_result = self._run_autotune_search(
|
|
1204
|
+
train_dataset,
|
|
1205
|
+
val_dataset,
|
|
1206
|
+
self.train_config,
|
|
1207
|
+
classes,
|
|
1208
|
+
augmentation_options,
|
|
1209
|
+
log_fn,
|
|
1210
|
+
progress_cb,
|
|
1211
|
+
check_stop,
|
|
1212
|
+
)
|
|
1213
|
+
if self.should_stop:
|
|
1214
|
+
log_fn("Auto-tune stopped.")
|
|
1215
|
+
self.training_complete.emit(0.0, 0.0, 0.0, {})
|
|
1216
|
+
self.finished.emit()
|
|
1217
|
+
return
|
|
1218
|
+
if not best_cfg:
|
|
1219
|
+
error_msg = "Auto-tune did not produce a valid candidate."
|
|
1220
|
+
log_fn(f"ERROR: {error_msg}")
|
|
1221
|
+
self.error.emit(error_msg)
|
|
1222
|
+
return
|
|
1223
|
+
self._reset_runtime_dataset_caches(train_dataset)
|
|
1224
|
+
self._reset_runtime_dataset_caches(val_dataset)
|
|
1225
|
+
self.train_config = copy.deepcopy(best_cfg)
|
|
1226
|
+
self.train_config["epochs"] = int(base_train_cfg.get("epochs", 1))
|
|
1227
|
+
log_fn("Auto-tune selected final config:")
|
|
1228
|
+
log_fn(
|
|
1229
|
+
f" lr={self.train_config.get('classification_lr', self.train_config.get('lr', 0.0)):.2e}, "
|
|
1230
|
+
f"wd={self.train_config.get('weight_decay', 0.0):.2e}, "
|
|
1231
|
+
f"dropout={self.train_config.get('dropout', 0.0):.2f}, "
|
|
1232
|
+
f"heads={self.train_config.get('head_kwargs', {}).get('num_heads', 4)}, "
|
|
1233
|
+
f"layers={self.train_config.get('frame_head_temporal_layers', 1)}"
|
|
1234
|
+
+ (
|
|
1235
|
+
f", ovr_ls={self.train_config.get('ovr_label_smoothing', 0.0):.2f}"
|
|
1236
|
+
if self.train_config.get("use_ovr", False) else ""
|
|
1237
|
+
)
|
|
1238
|
+
)
|
|
1239
|
+
if best_search_result:
|
|
1240
|
+
log_fn(
|
|
1241
|
+
f" best search val F1={float(best_search_result.get('best_val_f1', 0.0) or 0.0):.2f}, "
|
|
1242
|
+
f"val acc={float(best_search_result.get('best_val_acc', 0.0) or 0.0):.2f}"
|
|
1243
|
+
)
|
|
1244
|
+
self._save_autotuned_profile(self.train_config, best_search_result, log_fn)
|
|
1245
|
+
log_fn("Starting final retrain from scratch with the selected config...")
|
|
1246
|
+
|
|
1247
|
+
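    # Illustrative sketch of the tune-then-retrain pattern above (hypothetical
    # helpers, not this module's API): run short trials, keep the config with
    # the best val F1, then retrain from scratch on the full epoch budget.
    def _sketch_autotune(candidates, short_trial, train_full):
        best_cfg, best_f1 = None, float("-inf")
        for cfg in candidates:
            f1 = short_trial(cfg)  # few-epoch trial returning val F1
            if f1 > best_f1:
                best_cfg, best_f1 = cfg, f1
        return train_full(best_cfg)  # final model from the winning config
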
    try:
        model, head_kwargs_for_metadata, dropout, localization_dropout, num_stages, multi_scale = self._build_model_for_config(
            train_dataset, self.train_config, log_fn
        )
    except Exception as e:
        error_msg = f"Failed to create classifier: {str(e)}\n{traceback.format_exc()}"
        log_fn(f"ERROR: {error_msg}")
        self.error.emit(error_msg)
        return

    train_config = self._build_backend_train_config(
        self.train_config,
        self.output_path,
        classes,
        augmentation_options,
        head_kwargs_for_metadata,
        dropout,
        localization_dropout,
        num_stages,
        multi_scale,
    )

    log_fn("Starting training loop...")

    def check_stop():
        return self.should_stop

    try:
        result = train_model(
            model,
            train_dataset,
            val_dataset,
            train_config,
            log_fn=log_fn,
            progress_callback=progress_cb,
            stop_callback=check_stop,
            metrics_callback=metrics_cb
        )

        if self.should_stop:
            log_fn("Training stopped.")
            self.training_complete.emit(0.0, 0.0, 0.0, {})
        else:
            log_fn("Training completed successfully!")
            best_val = result.get("best_val_acc", 0.0) if result else 0.0
            best_f1 = result.get("best_val_f1", 0.0) if result else 0.0
            final_train = result.get("final_train_acc", 0.0) if result else 0.0
            per_class_f1 = result.get("per_class_f1", {})

            self.training_complete.emit(best_val, best_f1, final_train, per_class_f1)

        self.finished.emit()
    except Exception as e:
        error_msg = f"Training failed: {str(e)}\n{traceback.format_exc()}"
        log_fn(f"ERROR: {error_msg}")
        self.error.emit(error_msg)
        return

except Exception as e:
    error_msg = f"Unexpected error: {str(e)}\n{traceback.format_exc()}"
    self.log_message.emit(f"ERROR: {error_msg}")
    self.error.emit(error_msg)


class TrainingVisualizationDialog(QDialog):
    """Real-time training visualization with adaptive layout based on active metrics."""

    _COLORS = {
        "train": "#2196F3",
        "val": "#FF9800",
        "train_cls": "#64B5F6",
        "macro_f1": "#4CAF50",
        "grid": "#e0e0e0",
        "bg": "#fafafa",
        "text": "#333333",
        "loc_iou": "#AB47BC",
        "loc_cerr": "#EF5350",
        "loc_vrate": "#26A69A",
    }
    _CLASS_PALETTE = [
        "#e6194b", "#3cb44b", "#4363d8", "#f58231", "#911eb4",
        "#42d4f4", "#f032e6", "#bfef45", "#fabed4", "#469990",
        "#dcbeff", "#9A6324", "#800000", "#aaffc3", "#808000",
        "#000075", "#a9a9a9",
    ]

    def __init__(self, parent=None):
        super().__init__(parent)
        self.setWindowTitle("Training Monitor")
        self.resize(1150, 850)

        root = QVBoxLayout()
        root.setContentsMargins(6, 6, 6, 6)
        self.setLayout(root)

        top_bar = QHBoxLayout()
        self._status_label = QLabel("Waiting for first epoch...")
        self._status_label.setStyleSheet("font-weight: bold; font-size: 13px; padding: 4px 0;")
        top_bar.addWidget(self._status_label, stretch=1)

        # F1 class filter: toggle which per-class lines are visible
        self._f1_filter_layout = QHBoxLayout()
        self._f1_filter_layout.setContentsMargins(0, 0, 0, 0)
        self._f1_filter_label = QLabel("F1 classes:")
        self._f1_filter_label.setStyleSheet("font-size: 11px; color: #666;")
        self._f1_filter_layout.addWidget(self._f1_filter_label)
        self._f1_filter_checks: dict[str, QCheckBox] = {}
        self._f1_filter_container = QWidget()
        self._f1_filter_container.setLayout(self._f1_filter_layout)
        self._f1_filter_container.setVisible(False)

        root.addLayout(top_bar)
        root.addWidget(self._f1_filter_container)

        # Horizontal splitter: charts left, crop preview right
        from PyQt6.QtWidgets import QSplitter, QScrollArea
        self._splitter = QSplitter(Qt.Orientation.Horizontal)
        root.addWidget(self._splitter, stretch=1)

        # Left: matplotlib charts
        chart_widget = QWidget()
        chart_layout = QVBoxLayout()
        chart_layout.setContentsMargins(0, 0, 0, 0)
        chart_widget.setLayout(chart_layout)
        self.figure = Figure(figsize=(10, 8), dpi=100)
        self.figure.patch.set_facecolor(self._COLORS["bg"])
        self.canvas = FigureCanvas(self.figure)
        self.canvas.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
        chart_layout.addWidget(self.canvas)
        self._splitter.addWidget(chart_widget)

        # Right: crop progress preview (hidden until localization is active)
        self._crop_panel = QWidget()
        crop_layout = QVBoxLayout()
        crop_layout.setContentsMargins(4, 0, 4, 0)
        self._crop_panel.setLayout(crop_layout)

        crop_header = QLabel("Localization Crop Preview")
        crop_header.setStyleSheet("font-weight: bold; font-size: 12px; padding: 2px 0;")
        crop_header.setAlignment(Qt.AlignmentFlag.AlignCenter)
        crop_layout.addWidget(crop_header)

        self._crop_epoch_label = QLabel("")
        self._crop_epoch_label.setStyleSheet("font-size: 11px; color: #666; padding: 0 0 4px 0;")
        self._crop_epoch_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
        crop_layout.addWidget(self._crop_epoch_label)

        # Scrollable area for crop images
        scroll = QScrollArea()
        scroll.setWidgetResizable(True)
        scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff)
        scroll.setStyleSheet("QScrollArea { border: none; background: #fafafa; }")
        self._crop_container = QWidget()
        self._crop_container_layout = QVBoxLayout()
        self._crop_container_layout.setContentsMargins(0, 0, 0, 0)
        self._crop_container_layout.setSpacing(6)
        self._crop_container.setLayout(self._crop_container_layout)
        scroll.setWidget(self._crop_container)
        crop_layout.addWidget(scroll, stretch=1)

        self._crop_panel.setVisible(False)
        self._splitter.addWidget(self._crop_panel)
        self._splitter.setStretchFactor(0, 3)
        self._splitter.setStretchFactor(1, 1)

        self._axes = {}
        self._crop_labels: list = []
        self._init_data()

    def _init_data(self):
        self.epochs = []
        self._confusion_warmup_epoch = 0
        self.train_acc = []
        self.val_acc = []
        self.train_loss = []
        self.train_loss_class = []
        self.val_loss = []
        self.val_f1 = []
        self.per_class_f1 = {}
        self.loc_iou = []
        self.loc_center_error = []
        self.loc_valid_rate = []
        self._has_localization = False
        self._has_validation = False
        self._phase = "classification"
        self._crop_progress_dir = None
        self._last_crop_epoch = -1
        self._f1_classes_to_show = None  # None = show all; set of str = only these in F1 graph

    def reset(self, f1_classes_to_show=None, confusion_warmup_epoch: int = 0):
        """Clear all data for a new training run.

        f1_classes_to_show: when limit_classes is used, set of class names whose
            checkboxes start checked; None = all checked by default.
        confusion_warmup_epoch: epoch after which confusion sampler activates (0 = off/no line).
        """
        self._init_data()
        self._confusion_warmup_epoch = confusion_warmup_epoch
        self._f1_classes_to_show = f1_classes_to_show
        for cb in self._f1_filter_checks.values():
            cb.setParent(None)
            cb.deleteLater()
        self._f1_filter_checks.clear()
        self._f1_filter_container.setVisible(False)
        self.figure.clear()
        self._axes.clear()
        self.canvas.draw()
        self._status_label.setText("Waiting for first epoch...")
        self._crop_panel.setVisible(False)
        self._crop_epoch_label.setText("")
        self._clear_crop_images()

    def _clear_crop_images(self):
        for lbl in self._crop_labels:
            lbl.setParent(None)
            lbl.deleteLater()
        self._crop_labels.clear()

    # Layout helpers.

    def _rebuild_layout(self):
        """Create subplot grid based on which metrics are active."""
        self.figure.clear()
        self._axes.clear()

        panels = ["loss", "acc", "f1"]
        if self._has_localization:
            panels.append("loc")

        n = len(panels)
        for i, key in enumerate(panels):
            ax = self.figure.add_subplot(n, 1, i + 1)
            ax.set_facecolor(self._COLORS["bg"])
            self._axes[key] = ax

    def _style_ax(self, ax, title, ylabel, xlabel=None):
        ax.set_title(title, fontsize=10, fontweight="bold", color=self._COLORS["text"], pad=6)
        ax.set_ylabel(ylabel, fontsize=9, color=self._COLORS["text"])
        if xlabel:
            ax.set_xlabel(xlabel, fontsize=9, color=self._COLORS["text"])
        ax.grid(True, linewidth=0.5, color=self._COLORS["grid"], alpha=0.7)
        ax.tick_params(labelsize=8, colors=self._COLORS["text"])
        for spine in ax.spines.values():
            spine.set_color(self._COLORS["grid"])

    # Crop preview.

    def _update_crop_preview(self, epoch, crop_dir):
        """Load latest crop progress PNGs from disk and display them."""
        if not crop_dir:
            return
        self._crop_progress_dir = crop_dir

        # Crop progress images are saved every 2 epochs as epoch_NNN_sample_M.png
        pattern = os.path.join(crop_dir, f"epoch_{epoch:03d}_sample_*.png")
        files = sorted(glob.glob(pattern))
        if not files:
            return
        if epoch == self._last_crop_epoch:
            return
        self._last_crop_epoch = epoch

        self._clear_crop_images()
        self._crop_panel.setVisible(True)
        self._crop_epoch_label.setText(f"Epoch {epoch} · {len(files)} sample(s)")

        panel_width = max(280, self._crop_panel.width() - 20)
        for fpath in files:
            try:
                pixmap = QPixmap(fpath)
                if pixmap.isNull():
                    continue
                scaled = pixmap.scaledToWidth(panel_width, Qt.TransformationMode.SmoothTransformation)
                lbl = QLabel()
                lbl.setPixmap(scaled)
                lbl.setAlignment(Qt.AlignmentFlag.AlignCenter)
                lbl.setStyleSheet("border: 1px solid #ccc; border-radius: 3px; background: white; padding: 2px;")
                lbl.setToolTip(os.path.basename(fpath))
                self._crop_container_layout.addWidget(lbl)
                self._crop_labels.append(lbl)
            except Exception as e:
                logger.debug("Could not load crop preview image: %s", e)

    # Main update.

    def update_plots(self, metrics):
        """Update plots with new epoch metrics dict."""
        epoch = metrics["epoch"]
        self.epochs.append(epoch)
        self.train_acc.append(metrics["train_acc"])
        # In frame-level training, clip-level val_acc is not meaningful/updated.
        # Prefer val_frame_acc for monitor accuracy visualization.
        self.val_acc.append(metrics.get("val_frame_acc", metrics.get("val_acc", 0.0)))
        self.train_loss.append(metrics["train_loss"])
        self.train_loss_class.append(metrics.get("train_loss_class", 0.0))
        self.val_loss.append(metrics.get("val_loss", 0.0))
        self.val_f1.append(metrics.get("val_f1", 0.0))
        self._phase = metrics.get("training_phase", "classification")

        if not self._has_validation and (
            any(v > 0 for v in self.val_acc)
            or any(v > 0 for v in self.val_loss)
            or any(v > 0 for v in self.val_f1)
        ):
            self._has_validation = True

        # Localization metrics
        iou = metrics.get("loc_val_iou", 0.0)
        cerr = metrics.get("loc_val_center_error", 0.0)
        vrate = metrics.get("loc_val_valid_rate", 0.0)
        self.loc_iou.append(iou)
        self.loc_center_error.append(cerr)
        self.loc_valid_rate.append(vrate)
        if not self._has_localization and (iou > 0 or vrate > 0):
            self._has_localization = True

        # Per-class F1
        for cls, score in metrics.get("per_class_f1", {}).items():
            if cls not in self.per_class_f1:
                self.per_class_f1[cls] = [0.0] * (len(self.epochs) - 1)
            self.per_class_f1[cls].append(score)
            if cls not in self._f1_filter_checks:
                checked = self._f1_classes_to_show is None or cls in self._f1_classes_to_show
                cb = QCheckBox(cls)
                cb.setChecked(checked)
                cb.setStyleSheet("font-size: 11px;")
                cb.stateChanged.connect(lambda _: self._redraw_f1_only())
                self._f1_filter_checks[cls] = cb
                self._f1_filter_layout.addWidget(cb)
                self._f1_filter_container.setVisible(True)
        for cls in self.per_class_f1:
            while len(self.per_class_f1[cls]) < len(self.epochs):
                self.per_class_f1[cls].append(0.0)

        self._rebuild_layout()
        self._draw_loss()
        self._draw_acc()
        self._draw_f1()
        if self._has_localization and "loc" in self._axes:
            self._draw_loc()

        # Only the bottom subplot gets an x-label
        panels = list(self._axes.values())
        if panels:
            panels[-1].set_xlabel("Epoch", fontsize=9, color=self._COLORS["text"])

        self.figure.tight_layout(h_pad=1.2)
        self.canvas.draw()

        # Crop preview: load latest images when available
        crop_dir = metrics.get("crop_progress_dir")
        if crop_dir:
            self._update_crop_preview(epoch, crop_dir)

        # Status bar
        phase_tag = f"[{self._phase}]" if self._phase != "classification" else ""
        best_f1 = max(self.val_f1) if self.val_f1 else 0.0
        self._status_label.setText(
            f"Epoch {epoch} {phase_tag} | Train Loss: {self.train_loss[-1]:.4f} | "
            f"Val F1: {self.val_f1[-1]:.1f}% (best {best_f1:.1f}%) | "
            f"Val Frame Acc: {self.val_acc[-1]:.1f}%"
        )

    # Individual panel drawers.

    def _draw_loss(self):
        ax = self._axes["loss"]
        ax.plot(self.epochs, self.train_loss, color=self._COLORS["train"],
                linewidth=2, label="Train Loss")
        if any(v > 0 for v in self.train_loss_class):
            ax.plot(self.epochs, self.train_loss_class, color=self._COLORS["train_cls"],
                    linewidth=1.3, linestyle="--", alpha=0.8, label="Classification Loss")
        if self._has_validation:
            ax.plot(self.epochs, self.val_loss, color=self._COLORS["val"],
                    linewidth=2, label="Val Loss")
        self._style_ax(ax, "Loss", "Loss")
        ax.legend(fontsize=8, loc="upper right", framealpha=0.8)

    def _draw_acc(self):
        ax = self._axes["acc"]
        ax.plot(self.epochs, self.train_acc, color=self._COLORS["train"],
                linewidth=2, label="Train Acc")
        if self._has_validation:
            ax.plot(self.epochs, self.val_acc, color=self._COLORS["val"],
                    linewidth=2, label="Val Frame Acc")
        self._style_ax(ax, "Accuracy", "Accuracy (%)")
        ax.legend(fontsize=8, loc="lower right", framealpha=0.8)

    def _redraw_f1_only(self):
        """Redraw just the F1 panel when class visibility toggles change."""
        if "f1" not in self._axes or not self.epochs:
            return
        self._axes["f1"].clear()
        self._draw_f1()
        self.canvas.draw()

    def _draw_f1(self):
        ax = self._axes["f1"]
        ax.plot(self.epochs, self.val_f1, color=self._COLORS["macro_f1"],
                linewidth=2.5, label="Macro F1")
        visible_items = []
        for i, (cls, scores) in enumerate(self.per_class_f1.items()):
            cb = self._f1_filter_checks.get(cls)
            if cb is not None and not cb.isChecked():
                continue
            color = self._CLASS_PALETTE[i % len(self._CLASS_PALETTE)]
            ax.plot(self.epochs, scores, color=color,
                    linewidth=1.2, linestyle="--", alpha=0.7, label=cls)
            visible_items.append(cls)
        if self._confusion_warmup_epoch > 0:
            ax.axvline(x=self._confusion_warmup_epoch, color="#9E9E9E",
                       linestyle=":", linewidth=1.2, alpha=0.8)
            ax.text(self._confusion_warmup_epoch + 0.3, ax.get_ylim()[1] * 0.97,
                    "confusion sampler", fontsize=6, color="#757575",
                    va="top", ha="left", rotation=0)
        self._style_ax(ax, "Val Macro F1 (frame)", "F1 (%)")
        ncol = min(4, max(1, len(visible_items) + 1))
        ax.legend(fontsize=7, ncol=ncol, loc="lower right", framealpha=0.8)

    def _draw_loc(self):
        ax = self._axes["loc"]
        ax.plot(self.epochs, self.loc_iou, color=self._COLORS["loc_iou"],
                linewidth=2, label="IoU")
        ax.plot(self.epochs, self.loc_valid_rate, color=self._COLORS["loc_vrate"],
                linewidth=2, label="Valid Rate")
        ax2 = ax.twinx()
        ax2.plot(self.epochs, self.loc_center_error, color=self._COLORS["loc_cerr"],
                 linewidth=1.5, linestyle="--", alpha=0.8, label="Center Err")
        ax2.set_ylabel("Center Error", fontsize=8, color=self._COLORS["loc_cerr"])
        ax2.tick_params(labelsize=8, colors=self._COLORS["loc_cerr"])
        self._style_ax(ax, "Localization", "IoU / Valid Rate")
        lines1, labels1 = ax.get_legend_handles_labels()
        lines2, labels2 = ax2.get_legend_handles_labels()
        ax.legend(lines1 + lines2, labels1 + labels2, fontsize=8, loc="lower right", framealpha=0.8)

class TrainingWidget(QWidget):
    """Widget for training the behavior classifier."""

    def __init__(self, config: dict):
        super().__init__()
        self.config = config
        self.augmentation_options = self._default_augmentation_options()
        self.worker = None
        self.annotation_manager = AnnotationManager(
            self.config.get(
                "training_annotation_file",
                self.config.get("annotation_file", "data/annotations/annotations.json"),
            )
        )
        self._config_initialized = False
        self.profile_dialog = None
        self.training_queue = []
        self.is_batch_training = False
        self.batch_results = []
        self.batch_results_path = None
        self.current_profile_name = None
        self.visualization_dialog = None
        self._resolution = int(self.config.get("resolution", 288))
        self._setup_ui()
        self._load_current_config(force=True)
        self.refresh_annotation_info()

    def _load_current_config(self, force: bool = False):
        """Apply current config paths to UI."""
        if self._config_initialized and not force:
            return

        annotation_file = self.config.get(
            "training_annotation_file",
            self.config.get("annotation_file", "data/annotations/annotations.json"),
        )
        clips_dir = self.config.get(
            "training_clips_dir",
            self.config.get("clips_dir", "data/clips"),
        )
        models_dir = self.config.get("models_dir", "models/behavior_heads")

        if hasattr(self, "annotation_file_edit"):
            self.annotation_file_edit.setText(annotation_file)
        if hasattr(self, "clips_dir_edit"):
            self.clips_dir_edit.setText(clips_dir)
        if hasattr(self, "output_path_edit"):
            default_output = os.path.join(models_dir, "head.pt")
            self.output_path_edit.setText(default_output)
        self._resolution = int(self.config.get("resolution", 288))
        if self._resolution % 18 != 0:
            self._resolution = (self._resolution // 18) * 18
        if hasattr(self, "weight_decay_spin"):
            self.weight_decay_spin.setValue(float(self.config.get("default_weight_decay", 0.001)))
        if hasattr(self, "use_supcon_check"):
            default_use_supcon = bool(self.config.get("default_use_supcon_loss", False))
            self.use_supcon_check.setChecked(default_use_supcon)
        if hasattr(self, "supcon_weight_spin"):
            self.supcon_weight_spin.setValue(float(self.config.get("default_supcon_weight", 0.2)))
            self.supcon_weight_spin.setEnabled(default_use_supcon)
        if hasattr(self, "supcon_temp_spin"):
            self.supcon_temp_spin.setValue(float(self.config.get("default_supcon_temperature", 0.1)))
            self.supcon_temp_spin.setEnabled(default_use_supcon)

        self._config_initialized = True

    def _setup_ui(self):
        """Setup UI components."""
        layout = QVBoxLayout()

        # Dataset Info with scrollbar
        info_group = QGroupBox("Dataset info")
        info_layout = QVBoxLayout()
        self.info_label = QLabel("Loading...")
        self.info_label.setWordWrap(True)
        info_layout.addWidget(self.info_label)
        info_group.setLayout(info_layout)

        info_scroll = QScrollArea()
        info_scroll.setWidget(info_group)
        info_scroll.setWidgetResizable(True)
        info_scroll.setMinimumHeight(120)
        info_scroll.setMaximumHeight(200)

        # Training Configuration inside a scrollable container with grouped sections
        config_container = QWidget()
        config_vbox = QVBoxLayout(config_container)
        config_vbox.setContentsMargins(2, 2, 2, 2)
        config_vbox.setSpacing(6)

        # --- Paths & Files ---
        paths_group = QGroupBox("Paths && Files")
        config_layout = QFormLayout()

        self.annotation_file_edit = QLineEdit()
        self.annotation_file_edit.setText(
            self.config.get(
                "training_annotation_file",
                self.config.get("annotation_file", "data/annotations/annotations.json"),
            )
        )
        self.annotation_browse_btn = QPushButton("Browse...")
        self.annotation_browse_btn.clicked.connect(self._browse_annotation)
        annotation_layout = QHBoxLayout()
        annotation_layout.addWidget(self.annotation_file_edit)
        annotation_layout.addWidget(self.annotation_browse_btn)
        config_layout.addRow("Annotation file:", annotation_layout)

        self.clips_dir_edit = QLineEdit()
        self.clips_dir_edit.setText(
            self.config.get(
                "training_clips_dir",
                self.config.get("clips_dir", "data/clips"),
            )
        )
        self.clips_browse_btn = QPushButton("Browse...")
        self.clips_browse_btn.clicked.connect(self._browse_clips_dir)
        clips_layout = QHBoxLayout()
        clips_layout.addWidget(self.clips_dir_edit)
        clips_layout.addWidget(self.clips_browse_btn)
        config_layout.addRow("Clips directory:", clips_layout)

        self.output_path_edit = QLineEdit()
        self.output_path_edit.setText(self.config.get("models_dir", "models/behavior_heads") + "/head.pt")
        self.output_browse_btn = QPushButton("Browse...")
        self.output_browse_btn.clicked.connect(self._browse_output)
        output_layout = QHBoxLayout()
        output_layout.addWidget(self.output_path_edit)
        output_layout.addWidget(self.output_browse_btn)
        config_layout.addRow("Output model:", output_layout)

        paths_group.setLayout(config_layout)

        # --- Training Hyperparameters ---
        hyper_group = QGroupBox("Training Hyperparameters")
        config_layout = QFormLayout()

        clip_length_layout = QHBoxLayout()
        self.clip_length_spin = QSpinBox()
        self.clip_length_spin.setRange(1, 64)
        self.clip_length_spin.setValue(int(self.config.get("default_clip_length", 8)))
        self.clip_length_spin.setToolTip(
            "Number of frames to use for training.\n"
            "Can be equal to or less than the actual clip length.\n"
            "If less, the middle N frames are used (temporal center-crop)."
        )
        clip_length_layout.addWidget(self.clip_length_spin)

        info_btn = QToolButton()
        info_btn.setText("?")
        info_btn.setToolButtonStyle(Qt.ToolButtonStyle.ToolButtonTextOnly)
        info_btn.setStyleSheet("""
            QToolButton {
                background-color: #4A90E2;
                color: white;
                border: none;
                border-radius: 10px;
                width: 20px;
                height: 20px;
                font-weight: bold;
                font-size: 12px;
            }
            QToolButton:hover {
                background-color: #357ABD;
            }
        """)
        info_btn.setFixedSize(20, 20)
        info_btn.setToolTip("Click for information about frames per clip")
        info_btn.clicked.connect(self._show_clip_length_info)
        clip_length_layout.addWidget(info_btn)
        clip_length_layout.addStretch()

        config_layout.addRow("Frames per clip:", clip_length_layout)

        self.batch_size_spin = QSpinBox()
        self.batch_size_spin.setRange(1, 32)
        self.batch_size_spin.setValue(8)
        config_layout.addRow("Batch size:", self.batch_size_spin)

        self.epochs_spin = QSpinBox()
        self.epochs_spin.setRange(1, 1000)
        self.epochs_spin.setValue(60)
        config_layout.addRow("Epochs:", self.epochs_spin)

        lr_row = QHBoxLayout()
        self.loc_lr_spin = QDoubleSpinBox()
        self.loc_lr_spin.setRange(1e-6, 1.0)
        self.loc_lr_spin.setValue(1e-4)
        self.loc_lr_spin.setDecimals(6)
        self.loc_lr_spin.setSingleStep(1e-4)
        self.loc_lr_spin.setToolTip("Learning rate used during localization phase (localization head only).")
        lr_row.addWidget(self.loc_lr_spin)
        lr_row.addWidget(QLabel("Loc LR"))
        self.class_lr_spin = QDoubleSpinBox()
        self.class_lr_spin.setRange(1e-6, 1.0)
        self.class_lr_spin.setValue(1e-4)
        self.class_lr_spin.setDecimals(6)
        self.class_lr_spin.setSingleStep(1e-4)
        self.class_lr_spin.setToolTip("Learning rate used during classification phase (backbone + MAP head).")
        lr_row.addWidget(self.class_lr_spin)
        lr_row.addWidget(QLabel("Class LR"))
        lr_row.addStretch()
        config_layout.addRow("Learning rates:", lr_row)

        self.weight_decay_spin = QDoubleSpinBox()
        self.weight_decay_spin.setRange(0.0, 1.0)
        self.weight_decay_spin.setValue(float(self.config.get("default_weight_decay", 0.001)))
        self.weight_decay_spin.setDecimals(6)
        config_layout.addRow("Weight decay:", self.weight_decay_spin)

        hyper_group.setLayout(config_layout)

        # --- Model Architecture ---
        arch_group = QGroupBox("Model Architecture")
        config_layout = QFormLayout()

        self.map_num_heads_spin = QSpinBox()
        self.map_num_heads_spin.setRange(1, 16)
        self.map_num_heads_spin.setValue(4)
        config_layout.addRow("MAP num_heads:", self.map_num_heads_spin)

        self.proj_dim_spin = QSpinBox()
        self.proj_dim_spin.setRange(64, 1024)
        self.proj_dim_spin.setSingleStep(64)
        self.proj_dim_spin.setValue(256)
        self.proj_dim_spin.setToolTip("Projection dimension for spatial attention pool in the frame head.")
        config_layout.addRow("Spatial pool proj dim:", self.proj_dim_spin)

        self.use_multi_scale_check = QCheckBox("Multi-scale temporal context")
        self.use_multi_scale_check.setToolTip(
            "Cache backbone embeddings at two temporal scales: full fps and half fps\n"
            "(same clip duration, fewer frames). The temporal head sees both fine-grained\n"
            "local motion and broader context per frame, improving precision for subtle\n"
            "and short behaviors.\n\n"
            "Doubles backbone precomputation time. Requires clip_length ≥ 4.\n"
            "Disabled when localization is active."
        )
        self.use_multi_scale_check.setChecked(False)
        config_layout.addRow("", self.use_multi_scale_check)

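        # Illustrative sketch of the two-scale idea in the tooltip above
        # (assumed mechanism; `backbone` is a hypothetical callable): same
        # clip span, half the frames for the coarse stream.
        def _sketch_two_scale(frames, backbone):
            fine = backbone(frames)         # full fps, (T, D) embeddings
            coarse = backbone(frames[::2])  # half fps, same time span
            return fine, coarse
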
        self.map_dropout_spin = QDoubleSpinBox()
        self.map_dropout_spin.setRange(0.0, 0.9)
        self.map_dropout_spin.setValue(0.3)
        self.map_dropout_spin.setDecimals(2)
        self.map_dropout_spin.setToolTip(
            "Dropout for classification head (spatial pool + temporal TCN + boundary head).\n"
            "0.2-0.3 recommended. Higher values may hurt temporal conv layers.\n"
            "Localization head dropout is fixed at 0.0."
        )
        config_layout.addRow("Classification head dropout:", self.map_dropout_spin)

        # Class-balanced loss weights
        self.use_class_weights_check = QCheckBox("Use class-balanced loss weights")
        self.use_class_weights_check.setToolTip(
            "Weight the loss by inverse class frequency to compensate for class imbalance.\n"
            "Rare classes get higher loss weight. Recommended when class counts are uneven."
        )
        self.use_class_weights_check.setChecked(False)
        config_layout.addRow("", self.use_class_weights_check)

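        # Illustrative sketch of inverse-frequency class weights (a standard
        # recipe; the backend's exact formula may differ). `labels` is an
        # array of non-negative integer class ids.
        def _sketch_class_weights(labels, num_classes):
            import numpy as np
            counts = np.bincount(labels, minlength=num_classes).astype(float)
            counts[counts == 0] = 1.0  # guard against empty classes
            return counts.sum() / (num_classes * counts)  # rare class -> weight > 1
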
        self.use_supcon_check = QCheckBox("Use supervised contrastive loss on MAP embeddings")
        self.use_supcon_check.setToolTip(
            "Add an auxiliary supervised contrastive loss on the attention-pooled\n"
            "frame embeddings before the final classifier. Can be used with or without\n"
            "the temporal decoder."
        )
        self.use_supcon_check.setChecked(False)
        self.use_supcon_check.stateChanged.connect(
            lambda state: (
                self.supcon_weight_spin.setEnabled(bool(state)),
                self.supcon_temp_spin.setEnabled(bool(state)),
            )
        )
        config_layout.addRow("", self.use_supcon_check)

        self.supcon_weight_spin = QDoubleSpinBox()
        self.supcon_weight_spin.setRange(0.0, 5.0)
        self.supcon_weight_spin.setSingleStep(0.05)
        self.supcon_weight_spin.setValue(0.2)
        self.supcon_weight_spin.setDecimals(2)
        self.supcon_weight_spin.setEnabled(False)
        self.supcon_weight_spin.setToolTip("Weight of the supervised contrastive loss term.")
        config_layout.addRow("SupCon weight:", self.supcon_weight_spin)

        self.supcon_temp_spin = QDoubleSpinBox()
        self.supcon_temp_spin.setRange(0.01, 2.0)
        self.supcon_temp_spin.setSingleStep(0.01)
        self.supcon_temp_spin.setValue(0.10)
        self.supcon_temp_spin.setDecimals(2)
        self.supcon_temp_spin.setEnabled(False)
        self.supcon_temp_spin.setToolTip("Temperature used in the supervised contrastive loss.")
        config_layout.addRow("SupCon temperature:", self.supcon_temp_spin)

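        # Illustrative sketch of supervised contrastive loss (Khosla et al.,
        # 2020) on embeddings z (N, D) with integer labels y (N,); the
        # backend's exact variant may differ.
        def _sketch_supcon_loss(z, y, temperature=0.1):
            import torch
            import torch.nn.functional as F
            z = F.normalize(z, dim=1)
            sim = z @ z.T / temperature
            self_mask = torch.eye(len(y), dtype=torch.bool, device=z.device)
            sim = sim.masked_fill(self_mask, float("-inf"))  # drop self-pairs
            log_prob = sim - sim.logsumexp(dim=1, keepdim=True)
            log_prob = log_prob.masked_fill(self_mask, 0.0)  # avoid -inf * 0
            pos = (y[:, None] == y[None, :]) & ~self_mask    # same-label pairs
            n_pos = pos.sum(1)
            loss = -(log_prob * pos).sum(1) / n_pos.clamp(min=1)
            return loss[n_pos > 0].mean()  # skip anchors with no positives
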
        # Frame head settings (always active)
        self.use_temporal_decoder_check = QCheckBox("Use temporal decoder / refinement head")
        self.use_temporal_decoder_check.setToolTip(
            "Enable the temporal MS-TCN decoder after spatial attention pooling.\n"
            "If disabled, training uses the simpler baseline: spatial attention pooling\n"
            "+ direct per-frame classifier, while still producing framewise predictions."
        )
        self.use_temporal_decoder_check.setChecked(True)
        self.use_temporal_decoder_check.stateChanged.connect(self._on_temporal_decoder_toggled)
        config_layout.addRow("", self.use_temporal_decoder_check)

        self.frame_head_layers_spin = QSpinBox()
        self.frame_head_layers_spin.setRange(1, 8)
        self.frame_head_layers_spin.setValue(4)
        self.frame_head_layers_spin.setToolTip(
            "Number of dilated temporal conv layers in the frame head.\n"
            "Higher values increase temporal receptive field."
        )
        config_layout.addRow("Frame head temporal layers:", self.frame_head_layers_spin)

        self.temporal_pool_spin = QSpinBox()
        self.temporal_pool_spin.setRange(1, 4)
        self.temporal_pool_spin.setValue(1)
        self.temporal_pool_spin.setToolTip(
            "Average this many adjacent frames before temporal classification.\n"
            "1 = per-frame (no pooling), 2 = average pairs."
        )
        config_layout.addRow("Temporal pool (frames):", self.temporal_pool_spin)

        # Boundary loss
        self.boundary_loss_weight_spin = QDoubleSpinBox()
        self.boundary_loss_weight_spin.setRange(0.0, 5.0)
        self.boundary_loss_weight_spin.setSingleStep(0.1)
        self.boundary_loss_weight_spin.setValue(0.3)
        self.boundary_loss_weight_spin.setDecimals(2)
        self.boundary_loss_weight_spin.setToolTip("Weight of boundary detection loss (change-point prediction).")
        config_layout.addRow("Boundary loss weight:", self.boundary_loss_weight_spin)

        self.boundary_tolerance_spin = QSpinBox()
        self.boundary_tolerance_spin.setRange(0, 10)
        self.boundary_tolerance_spin.setValue(2)
        self.boundary_tolerance_spin.setToolTip("Tolerance in frames around labeled transition for boundary target.")
        config_layout.addRow("Boundary tolerance:", self.boundary_tolerance_spin)

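        # Illustrative sketch of boundary targets with tolerance (assumed
        # construction): a frame is a positive target if a label change falls
        # within +/- tol frames of it.
        def _sketch_boundary_targets(labels, tol=2):
            import numpy as np
            labels = np.asarray(labels)
            target = np.zeros(len(labels), dtype=float)
            for i in np.flatnonzero(labels[1:] != labels[:-1]) + 1:
                target[max(0, i - tol):i + tol + 1] = 1.0
            return target
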
        # Smoothness loss
        self.smoothness_loss_weight_spin = QDoubleSpinBox()
        self.smoothness_loss_weight_spin.setRange(0.0, 1.0)
        self.smoothness_loss_weight_spin.setSingleStep(0.01)
        self.smoothness_loss_weight_spin.setValue(0.05)
        self.smoothness_loss_weight_spin.setDecimals(3)
        self.smoothness_loss_weight_spin.setToolTip("Weight for temporal smoothness regularization on frame predictions.")
        config_layout.addRow("Smoothness loss weight:", self.smoothness_loss_weight_spin)

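        # Illustrative sketch of a temporal smoothness penalty (the common
        # MS-TCN truncated-MSE form is assumed): penalize large jumps between
        # adjacent frames' log-probabilities.
        def _sketch_smoothness_loss(logits, clamp=16.0):
            import torch.nn.functional as F
            lp = F.log_softmax(logits, dim=-1)       # logits: (T, C)
            diff = (lp[1:] - lp[:-1].detach()) ** 2  # adjacent-frame change
            return diff.clamp(max=clamp).mean()
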
        # Bout balance weighting
        self.use_bout_balance_check = QCheckBox("Use bout balance weighting")
        self.use_bout_balance_check.setToolTip(
            "Weight each frame inversely to its contiguous segment length.\n"
            "Prevents long bouts from dominating the loss over short actions."
        )
        self.use_bout_balance_check.setChecked(True)
        config_layout.addRow("", self.use_bout_balance_check)

        self.bout_balance_power_spin = QDoubleSpinBox()
        self.bout_balance_power_spin.setRange(0.1, 2.0)
        self.bout_balance_power_spin.setSingleStep(0.1)
        self.bout_balance_power_spin.setValue(1.0)
        self.bout_balance_power_spin.setDecimals(1)
        self.bout_balance_power_spin.setToolTip(
            "Exponent for bout balance weighting: weight = segment_length ^ (-power).\n"
            "1.0 = linear (default, most aggressive).\n"
            "0.5 = square-root (softer).\n"
            "Higher = shorter segments get relatively more weight."
        )
        config_layout.addRow("Bout balance power:", self.bout_balance_power_spin)
        self.use_bout_balance_check.stateChanged.connect(
            lambda s: self.bout_balance_power_spin.setEnabled(bool(s))
        )

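        # Illustrative sketch of the formula in the tooltip above (the mean
        # normalization is an assumption): weight = segment_length ** (-power)
        # for every frame of a contiguous same-label run.
        def _sketch_bout_weights(labels, power=1.0):
            import numpy as np
            labels = np.asarray(labels)
            w = np.empty(len(labels), dtype=float)
            start = 0
            for i in range(1, len(labels) + 1):
                if i == len(labels) or labels[i] != labels[start]:
                    w[start:i] = float(i - start) ** (-power)
                    start = i
            return w / w.mean()  # keep the overall loss scale unchanged
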
        arch_group.setLayout(config_layout)

        # --- Localization ---
        loc_group = QGroupBox("Localization (requires bbox labels)")
        self.loc_group = loc_group
        loc_group.setEnabled(False)
        loc_group.setToolTip(
            "Optional: localizes individual animals when multiple are in camera view.\n"
            "To activate, draw and save bounding boxes on at least some clips in the Labeling tab."
        )
        config_layout = QFormLayout()

        self.use_localization_check = QCheckBox("Use Localization Supervision (bbox)")
        self.use_localization_check.setToolTip(
            "Autonomous 2-stage training: first learn localization from bbox labels, then train classifier on localized crops."
        )
        self.use_localization_check.setChecked(False)
        config_layout.addRow("", self.use_localization_check)

        self.use_manual_loc_switch_check = QCheckBox("Manual switch to classification epoch")
        self.use_manual_loc_switch_check.setToolTip(
            "If enabled, switch from localization phase to classification phase at the selected epoch."
        )
        self.use_manual_loc_switch_check.setChecked(False)
        self.use_manual_loc_switch_check.stateChanged.connect(
            lambda state: self.manual_loc_switch_epoch_spin.setEnabled(
                bool(state) and self.use_localization_check.isChecked()
            )
        )
        config_layout.addRow("", self.use_manual_loc_switch_check)

        self.manual_loc_switch_epoch_spin = QSpinBox()
        self.manual_loc_switch_epoch_spin.setRange(1, 10000)
        self.manual_loc_switch_epoch_spin.setValue(20)
        self.manual_loc_switch_epoch_spin.setEnabled(False)
        self.manual_loc_switch_epoch_spin.setToolTip(
            "Epoch number at which localization stops and classification starts."
        )
        config_layout.addRow("Switch epoch:", self.manual_loc_switch_epoch_spin)

        self.use_localization_check.stateChanged.connect(
            lambda state: self.manual_loc_switch_epoch_spin.setEnabled(
                bool(state) and self.use_manual_loc_switch_check.isChecked()
            )
        )
        self.use_localization_check.stateChanged.connect(self._on_localization_toggled)

        self.crop_padding_spin = QDoubleSpinBox()
        self.crop_padding_spin.setRange(-0.45, 1.0)
        self.crop_padding_spin.setSingleStep(0.05)
        self.crop_padding_spin.setDecimals(2)
        self.crop_padding_spin.setValue(0.35)
        self.crop_padding_spin.setToolTip(
            "Fractional padding around the predicted bbox for classification crops.\n"
            "Positive: expand crop (0.35 = 70% larger than bbox).\n"
            "Negative: shrink/zoom in (-0.2 = crop is 60% of bbox, zooming into center).\n"
            "Use 0.05-0.10 for tight crops, negative for extreme close-ups."
        )
        config_layout.addRow("Crop padding:", self.crop_padding_spin)

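        # Worked example of the padding math in the tooltip above (symmetric
        # padding on both sides is assumed):
        #   padding = 0.35  -> crop side = 1.70 * bbox side (70% larger)
        #   padding = -0.20 -> crop side = 0.60 * bbox side (zoomed into center)
        def _sketch_crop_side(bbox_side, padding):
            return bbox_side * (1.0 + 2.0 * padding)
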
        self.use_crop_jitter_check = QCheckBox("Crop jitter augmentation")
        self.use_crop_jitter_check.setToolTip(
            "Randomly shift the crop center during training to prevent\n"
            "the model from memorizing background/location cues.\n"
            "Only available when localization is enabled."
        )
        self.use_crop_jitter_check.setChecked(False)
        self.use_crop_jitter_check.setEnabled(False)
        config_layout.addRow("", self.use_crop_jitter_check)

        self.crop_jitter_strength_spin = QDoubleSpinBox()
        self.crop_jitter_strength_spin.setRange(0.01, 0.5)
        self.crop_jitter_strength_spin.setSingleStep(0.05)
        self.crop_jitter_strength_spin.setDecimals(2)
        self.crop_jitter_strength_spin.setValue(0.15)
        self.crop_jitter_strength_spin.setEnabled(False)
        self.crop_jitter_strength_spin.setToolTip(
            "Max random shift as a fraction of bbox size.\n"
            "0.15 = shift up to 15% of bbox width/height in any direction."
        )
        config_layout.addRow("Crop jitter strength:", self.crop_jitter_strength_spin)

        self.use_crop_jitter_check.stateChanged.connect(
            lambda s: self.crop_jitter_strength_spin.setEnabled(bool(s))
        )

        loc_group.setLayout(config_layout)

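        # Illustrative sketch of crop-center jitter (assumed mechanism):
        # shift the crop by up to `strength` of the bbox size per axis.
        def _sketch_jitter_center(cx, cy, bw, bh, strength=0.15):
            import random
            cx += random.uniform(-strength, strength) * bw
            cy += random.uniform(-strength, strength) * bh
            return cx, cy
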
        # --- OvR & Hard Mining ---
        ovr_group = QGroupBox("OvR && Hard Mining")
        config_layout = QFormLayout()

        self.use_ovr_check = QCheckBox("One-vs-Rest heads (OvR)")
        self.use_ovr_check.setToolTip(
            "Train independent binary heads per class instead of shared softmax.\n"
            "Better when 'other'/background is heterogeneous.\n"
            "near_negative_* clips automatically suppress their matched class.\n"
            "At inference: sigmoid scores + per-class threshold → Ignore if all low."
        )
        self.use_ovr_check.setChecked(False)
        self.use_ovr_check.stateChanged.connect(self._on_ovr_toggled)
        config_layout.addRow("", self.use_ovr_check)

        self.ovr_background_negative_check = QCheckBox("Treat Other/background as OvR negatives")
        self.ovr_background_negative_check.setToolTip(
            "Hybrid OvR mode: remove helper classes like Other/Background from the trained heads,\n"
            "but keep those clips as all-zero negative supervision for the target heads.\n"
            "Useful when Other is heterogeneous and hurts learning real behaviors."
        )
        self.ovr_background_negative_check.setChecked(False)
        self.ovr_background_negative_check.setEnabled(False)
        config_layout.addRow("", self.ovr_background_negative_check)

        self.ovr_label_smoothing_spin = QDoubleSpinBox()
        self.ovr_label_smoothing_spin.setRange(0.0, 0.3)
        self.ovr_label_smoothing_spin.setSingleStep(0.01)
        self.ovr_label_smoothing_spin.setValue(0.05)
        self.ovr_label_smoothing_spin.setDecimals(2)
        self.ovr_label_smoothing_spin.setToolTip(
            "Smooth binary targets from [0,1] to [eps, 1-eps].\n"
            "Prevents overconfident predictions and improves generalization.\n"
            "0 = no smoothing, 0.05 = recommended default, 0.1+ = strong regularization."
        )
        config_layout.addRow("OvR label smoothing:", self.ovr_label_smoothing_spin)

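        # Illustrative sketch of the smoothing in the tooltip above: binary
        # targets are mapped from {0, 1} into [eps, 1 - eps].
        def _sketch_smooth_targets(t, eps=0.05):
            return t * (1.0 - 2.0 * eps) + eps  # 0 -> eps, 1 -> 1 - eps
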
self.use_asl_check = QCheckBox("Use Asymmetric Loss (ASL)")
|
|
2179
|
+
self.use_asl_check.setToolTip(
|
|
2180
|
+
"Asymmetric Loss down-weights easy negatives more aggressively than positives.\n"
|
|
2181
|
+
"Recommended for OvR training where negatives dominate.\n"
|
|
2182
|
+
"γ- controls negative focusing, γ+ controls positive focusing."
|
|
2183
|
+
)
|
|
2184
|
+
self.use_asl_check.setChecked(True)
|
|
2185
|
+
self.use_asl_check.setEnabled(False)
|
|
2186
|
+
self.use_asl_check.stateChanged.connect(self._on_asl_toggled)
|
|
2187
|
+
config_layout.addRow("", self.use_asl_check)
|
|
2188
|
+
|
|
2189
|
+
asl_params_layout = QHBoxLayout()
|
|
2190
|
+
self.asl_gamma_neg_spin = QDoubleSpinBox()
|
|
2191
|
+
self.asl_gamma_neg_spin.setRange(0.0, 10.0)
|
|
2192
|
+
self.asl_gamma_neg_spin.setSingleStep(0.5)
|
|
2193
|
+
self.asl_gamma_neg_spin.setValue(2.0)
|
|
2194
|
+
self.asl_gamma_neg_spin.setDecimals(1)
|
|
2195
|
+
self.asl_gamma_neg_spin.setToolTip("Focusing parameter for negative samples (higher = more suppression)")
|
|
2196
|
+
asl_params_layout.addWidget(QLabel("γ-:"))
|
|
2197
|
+
asl_params_layout.addWidget(self.asl_gamma_neg_spin)
|
|
2198
|
+
|
|
2199
|
+
self.asl_gamma_pos_spin = QDoubleSpinBox()
|
|
2200
|
+
self.asl_gamma_pos_spin.setRange(0.0, 10.0)
|
|
2201
|
+
self.asl_gamma_pos_spin.setSingleStep(0.5)
|
|
2202
|
+
self.asl_gamma_pos_spin.setValue(0.0)
|
|
2203
|
+
self.asl_gamma_pos_spin.setDecimals(1)
|
|
2204
|
+
self.asl_gamma_pos_spin.setToolTip("Focusing parameter for positive samples (0 = no down-weighting)")
|
|
2205
|
+
asl_params_layout.addWidget(QLabel("γ+:"))
|
|
2206
|
+
asl_params_layout.addWidget(self.asl_gamma_pos_spin)
|
|
2207
|
+
|
|
2208
|
+
self.asl_clip_spin = QDoubleSpinBox()
|
|
2209
|
+
self.asl_clip_spin.setRange(0.0, 0.5)
|
|
2210
|
+
self.asl_clip_spin.setSingleStep(0.01)
|
|
2211
|
+
self.asl_clip_spin.setValue(0.05)
|
|
2212
|
+
self.asl_clip_spin.setDecimals(2)
|
|
2213
|
+
self.asl_clip_spin.setToolTip("Probability margin for hard thresholding negatives")
|
|
2214
|
+
asl_params_layout.addWidget(QLabel("clip:"))
|
|
2215
|
+
asl_params_layout.addWidget(self.asl_clip_spin)
|
|
2216
|
+
|
|
2217
|
+
self.asl_gamma_neg_spin.setEnabled(False)
|
|
2218
|
+
self.asl_gamma_pos_spin.setEnabled(False)
|
|
2219
|
+
self.asl_clip_spin.setEnabled(False)
|
|
2220
|
+
config_layout.addRow("ASL parameters:", asl_params_layout)

        # Confusion-aware sampler (OvR only)
        self.use_confusion_sampler_check = QCheckBox("Confusion-aware hard mining")
        self.use_confusion_sampler_check.setToolTip(
            "After each epoch, clips that trigger the wrong OvR heads (e.g. groom clips\n"
            "that also activate dig's head) are sampled more often in the next epoch.\n"
            "Helps the model learn hard distinctions between visually similar behaviors.\n"
            "Temperature: how sharply to focus on the hardest clips (higher = more aggressive)."
        )
        self.use_confusion_sampler_check.setChecked(True)
        self.use_confusion_sampler_check.setEnabled(False)
        def _on_confusion_sampler_toggled(state):
            on = bool(state) and self.use_ovr_check.isChecked()
            self.confusion_temperature_spin.setEnabled(on)
            self.confusion_warmup_spin.setEnabled(on)
        self.use_confusion_sampler_check.stateChanged.connect(_on_confusion_sampler_toggled)
        config_layout.addRow("", self.use_confusion_sampler_check)

        confusion_temp_layout = QHBoxLayout()
        self.confusion_temperature_spin = QDoubleSpinBox()
        self.confusion_temperature_spin.setRange(0.5, 8.0)
        self.confusion_temperature_spin.setSingleStep(0.5)
        self.confusion_temperature_spin.setValue(2.0)
        self.confusion_temperature_spin.setDecimals(1)
        self.confusion_temperature_spin.setToolTip(
            "Sharpness of the hard-mining distribution.\n"
            "1.0 = mild (all clips get similar weight).\n"
            "2.0 = moderate (default).\n"
            "4.0+ = aggressive (almost only hardest clips sampled)."
        )
        self.confusion_temperature_spin.setEnabled(False)
        confusion_temp_layout.addWidget(self.confusion_temperature_spin)
        config_layout.addRow("Confusion temperature:", confusion_temp_layout)

        confusion_warmup_layout = QHBoxLayout()
        self.confusion_warmup_spin = QSpinBox()
        self.confusion_warmup_spin.setRange(0, 80)
        self.confusion_warmup_spin.setSingleStep(5)
        self.confusion_warmup_spin.setValue(20)
        self.confusion_warmup_spin.setSuffix("%")
        self.confusion_warmup_spin.setToolTip(
            "Percentage of total epochs to use uniform sampling before\n"
            "activating confusion-based hard mining.\n"
            "0% = active from start, 20% = default warmup."
        )
        self.confusion_warmup_spin.setEnabled(False)
        confusion_warmup_layout.addWidget(self.confusion_warmup_spin)
        config_layout.addRow("Confusion warmup:", confusion_warmup_layout)

        self.use_hard_pair_mining_check = QCheckBox("Hard-pair mining")
        self.use_hard_pair_mining_check.setToolTip(
            "Add extra pairwise margin pressure for specific confusing class pairs.\n"
            "Use this for cases like rear vs digg where standard OvR is not enough."
        )
        self.use_hard_pair_mining_check.setChecked(False)
        self.use_hard_pair_mining_check.setEnabled(False)
        self.use_hard_pair_mining_check.stateChanged.connect(self._on_hard_pair_toggled)
        config_layout.addRow("", self.use_hard_pair_mining_check)

        self.hard_pair_edit = QLineEdit()
        self.hard_pair_edit.setPlaceholderText("rear:digg, move:digg")
        self.hard_pair_edit.setEnabled(False)
        self.hard_pair_edit.setToolTip(
            "Comma-separated hard pairs in the form class_a:class_b.\n"
            "Example: rear:digg, move:digg"
        )
        config_layout.addRow("Hard pairs:", self.hard_pair_edit)

        self.hard_pair_loss_weight_spin = QDoubleSpinBox()
        self.hard_pair_loss_weight_spin.setRange(0.0, 5.0)
        self.hard_pair_loss_weight_spin.setSingleStep(0.05)
        self.hard_pair_loss_weight_spin.setValue(0.2)
        self.hard_pair_loss_weight_spin.setDecimals(3)
        self.hard_pair_loss_weight_spin.setEnabled(False)
        self.hard_pair_loss_weight_spin.setToolTip(
            "Weight of the extra pair-margin loss added on top of the main frame loss."
        )
        config_layout.addRow("Hard-pair loss weight:", self.hard_pair_loss_weight_spin)

        self.hard_pair_margin_spin = QDoubleSpinBox()
        self.hard_pair_margin_spin.setRange(0.0, 5.0)
        self.hard_pair_margin_spin.setSingleStep(0.05)
        self.hard_pair_margin_spin.setValue(0.5)
        self.hard_pair_margin_spin.setDecimals(2)
        self.hard_pair_margin_spin.setEnabled(False)
        self.hard_pair_margin_spin.setToolTip(
            "Required logit gap between the true class and its configured rival."
        )
        config_layout.addRow("Hard-pair margin:", self.hard_pair_margin_spin)

        self.hard_pair_confusion_boost_spin = QDoubleSpinBox()
        self.hard_pair_confusion_boost_spin.setRange(1.0, 5.0)
        self.hard_pair_confusion_boost_spin.setSingleStep(0.1)
        self.hard_pair_confusion_boost_spin.setValue(1.5)
        self.hard_pair_confusion_boost_spin.setDecimals(2)
        self.hard_pair_confusion_boost_spin.setEnabled(False)
        self.hard_pair_confusion_boost_spin.setToolTip(
            "Extra multiplier for confusion-sampler scores when the top rival is a configured hard pair."
        )
        config_layout.addRow("Hard-pair sampler boost:", self.hard_pair_confusion_boost_spin)

        self.use_weighted_sampler_check = QCheckBox("Use weighted random sampler")
        self.use_weighted_sampler_check.setToolTip("Oversample rare classes during training")
        self.use_weighted_sampler_check.setChecked(False)
        config_layout.addRow("", self.use_weighted_sampler_check)

        ovr_group.setLayout(config_layout)

        # --- Augmentation & Data ---
        data_group = QGroupBox("Augmentation && Data")
        config_layout = QFormLayout()

        self.use_augmentation_check = QCheckBox("Use data augmentation")
        self.use_augmentation_check.setToolTip("Apply selected augmentations to training clips")
        self.use_augmentation_check.setChecked(False)
        self.use_augmentation_check.stateChanged.connect(self._on_use_augmentation_changed)

        self.augmentation_options_btn = QPushButton("Augmentation options...")
        self.augmentation_options_btn.setToolTip("Choose which augmentations to apply during training")
        self.augmentation_options_btn.setEnabled(False)
        self.augmentation_options_btn.clicked.connect(self._open_augmentation_options_dialog)

        augmentation_row = QHBoxLayout()
        augmentation_row.addWidget(self.use_augmentation_check)
        augmentation_row.addWidget(self.augmentation_options_btn)
        config_layout.addRow("", augmentation_row)

        self.virtual_expansion_spin = QSpinBox()
        self.virtual_expansion_spin.setRange(1, 20)
        self.virtual_expansion_spin.setValue(5)
        self.virtual_expansion_spin.setEnabled(False)
        self.virtual_expansion_spin.setToolTip(
            "Virtual dataset expansion multiplier (only active when augmentation is on).\n"
            "Each unique clip is sampled this many times per epoch with different augmentations.\n"
            "Higher = more augmented variety per epoch; useful for small datasets."
        )
        config_layout.addRow("Virtual expansion (x):", self.virtual_expansion_spin)
        self.use_augmentation_check.stateChanged.connect(
            lambda s: self.virtual_expansion_spin.setEnabled(bool(s))
        )

        self.use_stitch_check = QCheckBox("Clip-stitch augmentation")
        self.use_stitch_check.setToolTip(
            "Splice two clips from different classes with a fixed 50/50 split.\n"
            "Teaches the model per-frame behavior regardless of clip-level context,\n"
            "which improves inference on clips containing multiple behaviors."
        )
        self.use_stitch_check.setChecked(False)
        config_layout.addRow("", self.use_stitch_check)

        self.stitch_prob_spin = QDoubleSpinBox()
        self.stitch_prob_spin.setRange(0.0, 1.0)
        self.stitch_prob_spin.setSingleStep(0.05)
        self.stitch_prob_spin.setValue(0.3)
        self.stitch_prob_spin.setDecimals(2)
        self.stitch_prob_spin.setEnabled(False)
        self.stitch_prob_spin.setToolTip(
            "Probability per sample of applying clip-stitch augmentation (0–1).\n"
            "0.3 means ~30% of training samples will be stitched mixed clips."
        )
        config_layout.addRow("Stitch probability:", self.stitch_prob_spin)

        self.emb_aug_versions_spin = QSpinBox()
        self.emb_aug_versions_spin.setRange(1, 20)
        self.emb_aug_versions_spin.setValue(5)
        self.emb_aug_versions_spin.setEnabled(False)
        self.emb_aug_versions_spin.setToolTip(
            "Number of augmented embedding versions to pre-compute per clip.\n"
            "Each version applies a different random augmentation before the backbone.\n"
            "Higher = more diversity (larger cache, slower precompute, same training speed).\n"
            "Requires augmentation to be enabled. Recommended: 3–8."
        )
        config_layout.addRow("Cached aug versions:", self.emb_aug_versions_spin)

        def _sync_emb_cache_controls():
            loc_on = self.use_localization_check.isChecked()
            aug_on = self.use_augmentation_check.isChecked()
            # Cached aug versions only useful when augmentation is on
            self.emb_aug_versions_spin.setEnabled(aug_on and not loc_on)
            # Multi-scale requires no localization
            self.use_multi_scale_check.setEnabled(not loc_on)
            if loc_on:
                self.use_multi_scale_check.setChecked(False)

        self._sync_emb_cache_controls = _sync_emb_cache_controls
        self.use_augmentation_check.stateChanged.connect(lambda _: _sync_emb_cache_controls())

        self.use_stitch_check.stateChanged.connect(
            lambda s: self.stitch_prob_spin.setEnabled(bool(s))
        )
        self.use_stitch_check.stateChanged.connect(lambda _s: self._sync_stitch_controls())
        self.use_localization_check.stateChanged.connect(lambda _s: self._sync_stitch_controls())
        self._sync_stitch_controls()

        self.use_all_for_training_check = QCheckBox("Use all data for training (no validation)")
        self.use_all_for_training_check.setToolTip("Enable this for small datasets. Disables validation split.")
        self.use_all_for_training_check.stateChanged.connect(self._on_use_all_changed)
        config_layout.addRow("", self.use_all_for_training_check)

        self.val_split_spin = QDoubleSpinBox()
        self.val_split_spin.setRange(0.0, 0.5)
        self.val_split_spin.setValue(0.2)
        self.val_split_spin.setDecimals(2)
        self.val_split_spin.setSingleStep(0.05)
        self.val_split_spin.setSuffix(" (20% = 0.2)")
        config_layout.addRow("Validation split:", self.val_split_spin)

        self.auto_tune_check = QCheckBox("Auto-tune before final training")
        self.auto_tune_check.setToolTip(
            "Run a small random search over a few important training settings,\n"
            "then retrain once from scratch using the best candidate.\n"
            "Requires a validation split."
        )
        self.auto_tune_check.setChecked(False)
        self.auto_tune_check.stateChanged.connect(self._on_auto_tune_changed)
        config_layout.addRow("", self.auto_tune_check)

        auto_tune_row = QHBoxLayout()
        self.auto_tune_runs_spin = QSpinBox()
        self.auto_tune_runs_spin.setRange(1, 32)
        self.auto_tune_runs_spin.setValue(8)
        self.auto_tune_runs_spin.setToolTip("Number of short candidate runs to evaluate before the final retrain.")
        auto_tune_row.addWidget(self.auto_tune_runs_spin)
        auto_tune_row.addWidget(QLabel("runs"))
        self.auto_tune_epochs_spin = QSpinBox()
        self.auto_tune_epochs_spin.setRange(1, 200)
        self.auto_tune_epochs_spin.setValue(12)
        self.auto_tune_epochs_spin.setToolTip("Epoch budget for each short auto-tune trial.")
        auto_tune_row.addWidget(self.auto_tune_epochs_spin)
        auto_tune_row.addWidget(QLabel("search epochs"))
        auto_tune_row.addStretch()
        config_layout.addRow("Auto-tune search:", auto_tune_row)
        self._on_auto_tune_changed(int(self.auto_tune_check.isChecked()))

        self.select_classes_check = QCheckBox("Limit classes for training")
        self.select_classes_check.setToolTip("Select specific classes to use for training (useful for testing minimum examples needed)")
        self.select_classes_check.stateChanged.connect(self._on_select_classes_changed)
        config_layout.addRow("", self.select_classes_check)

        self.class_selection_list = QListWidget()
        self.class_selection_list.setMaximumHeight(150)
        self.class_selection_list.setEnabled(False)
        config_layout.addRow("Selected classes:", self.class_selection_list)

        self.limit_per_class_check = QCheckBox("Limit annotations per class")
        self.limit_per_class_check.setToolTip("Limit the maximum number of clips used per class for training")
        self.limit_per_class_check.stateChanged.connect(self._on_limit_per_class_changed)
        config_layout.addRow("", self.limit_per_class_check)

        self.per_class_limit_table = QTableWidget()
        self.per_class_limit_table.setColumnCount(3)
        self.per_class_limit_table.setHorizontalHeaderLabels(["Class", "Max Train", "Max Val"])
        self.per_class_limit_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.ResizeMode.Stretch)
        self.per_class_limit_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeMode.ResizeToContents)
        self.per_class_limit_table.horizontalHeader().setSectionResizeMode(2, QHeaderView.ResizeMode.ResizeToContents)
        self.per_class_limit_table.setMaximumHeight(150)
        self.per_class_limit_table.setEnabled(False)
        config_layout.addRow("Per-class limits:", self.per_class_limit_table)

        self.use_embedding_diversity_check = QCheckBox("Use embedding-based diversity selection")
        self.use_embedding_diversity_check.setToolTip("When limiting per-class samples, use VideoPrism embeddings to select the most diverse clips (Farthest Point Sampling). Better than random for capturing variety.")
        self.use_embedding_diversity_check.setChecked(False)
        config_layout.addRow("", self.use_embedding_diversity_check)

        # Fine-tuning controls
        self.finetune_check = QCheckBox("Fine-tune existing model")
        self.finetune_check.setToolTip("Load weights from an existing model instead of training from scratch")
        self.finetune_check.stateChanged.connect(self._on_finetune_changed)
        config_layout.addRow("", self.finetune_check)

        self.pretrained_path_edit = QLineEdit()
        self.pretrained_path_edit.setPlaceholderText("Path to existing .pt model")
        self.pretrained_path_edit.setEnabled(False)
        self.pretrained_browse_btn = QPushButton("Browse...")
        self.pretrained_browse_btn.setEnabled(False)
        self.pretrained_browse_btn.clicked.connect(self._browse_pretrained)

        pretrained_layout = QHBoxLayout()
        pretrained_layout.addWidget(self.pretrained_path_edit)
        pretrained_layout.addWidget(self.pretrained_browse_btn)
        config_layout.addRow("Pretrained model:", pretrained_layout)

        data_group.setLayout(config_layout)

        row1 = QHBoxLayout()
        row1.addWidget(paths_group, 1)
        row1.addWidget(info_scroll, 1)
        config_vbox.addLayout(row1)

        row2 = QHBoxLayout()
        row2.addWidget(hyper_group, 1)
        row2.addWidget(loc_group, 1)
        config_vbox.addLayout(row2)

        row3 = QHBoxLayout()
        row3.addWidget(arch_group, 1)
        row3.addWidget(ovr_group, 1)
        config_vbox.addLayout(row3)

        config_vbox.addWidget(data_group)

        config_scroll = QScrollArea()
        config_scroll.setWidget(config_container)
        config_scroll.setWidgetResizable(True)
        config_scroll.setMinimumHeight(300)
        layout.addWidget(config_scroll, 1)

        control_layout = QHBoxLayout()
        self.visualize_btn = QPushButton("Visualize training")
        self.visualize_btn.setToolTip("Open real-time training visualization")
        self.visualize_btn.clicked.connect(self._open_visualization)
        self.visualize_btn.setEnabled(False)
        control_layout.addWidget(self.visualize_btn)

        self.advanced_btn = QPushButton("Advanced: Profiles")
        self.advanced_btn.setToolTip("Manage training profiles for batch experiments")
        self.advanced_btn.clicked.connect(self._open_profile_manager)
        control_layout.addWidget(self.advanced_btn)

        control_layout.addStretch()

        self.batch_train_check = QCheckBox("Batch Train Selected Profiles")
        self.batch_train_check.setToolTip("If checked, training will run sequentially for all profiles selected in the Advanced menu.")
        control_layout.addWidget(self.batch_train_check)

        self.train_btn = QPushButton("Start training")
        self.train_btn.clicked.connect(self._start_training)
        self.stop_btn = QPushButton("Stop")
        self.stop_btn.clicked.connect(self._stop_training)
        self.stop_btn.setEnabled(False)
        control_layout.addWidget(self.train_btn)
        control_layout.addWidget(self.stop_btn)
        layout.addLayout(control_layout)

        self.progress_bar = QProgressBar()
        self.progress_bar.setVisible(False)
        layout.addWidget(self.progress_bar)

        log_group = QGroupBox("Training logs")
        log_layout = QVBoxLayout()
        self.log_text = QPlainTextEdit()
        self.log_text.setReadOnly(True)
        log_layout.addWidget(self.log_text)
        log_group.setLayout(log_layout)
        layout.addWidget(log_group)

        self.setLayout(layout)
        self._on_temporal_decoder_toggled(int(self.use_temporal_decoder_check.isChecked()))

    def _on_use_all_changed(self, state: int):
        """Enable/disable validation split controls."""
        use_all = self.use_all_for_training_check.isChecked()
        self.val_split_spin.setEnabled(not use_all)

        if use_all:
            self.log_text.appendPlainText("Using all data for training - validation disabled")
        else:
            self.log_text.appendPlainText(f"Validation split enabled: {self.val_split_spin.value():.1%}")

    def _on_temporal_decoder_toggled(self, state: int):
        """Enable only the controls that affect the temporal decoder branch."""
        use_decoder = self.use_temporal_decoder_check.isChecked()
        self.frame_head_layers_spin.setEnabled(use_decoder)
        self.temporal_pool_spin.setEnabled(use_decoder)
        self.boundary_loss_weight_spin.setEnabled(use_decoder)
        self.boundary_tolerance_spin.setEnabled(use_decoder)
        if not use_decoder:
            self.boundary_loss_weight_spin.setValue(0.0)
        elif self.boundary_loss_weight_spin.value() == 0.0:
            self.boundary_loss_weight_spin.setValue(0.3)

    def _on_ovr_toggled(self, state: int):
        """OvR has built-in per-head class balancing — disable redundant class weights."""
        is_on = self.use_ovr_check.isChecked()
        self.ovr_background_negative_check.setEnabled(is_on)
        self.ovr_label_smoothing_spin.setEnabled(is_on)
        self.use_asl_check.setEnabled(is_on)
        self.asl_gamma_neg_spin.setEnabled(is_on and self.use_asl_check.isChecked())
        self.asl_gamma_pos_spin.setEnabled(is_on and self.use_asl_check.isChecked())
        self.asl_clip_spin.setEnabled(is_on and self.use_asl_check.isChecked())
        self.use_confusion_sampler_check.setEnabled(is_on)
        self.confusion_temperature_spin.setEnabled(
            is_on and self.use_confusion_sampler_check.isChecked()
        )
        self.confusion_warmup_spin.setEnabled(
            is_on and self.use_confusion_sampler_check.isChecked()
        )
        self.use_hard_pair_mining_check.setEnabled(is_on)
        self._on_hard_pair_toggled(state)
        if is_on:
            self.use_class_weights_check.setChecked(False)
            self.use_class_weights_check.setEnabled(False)
        else:
            self.use_class_weights_check.setEnabled(True)

    def _on_asl_toggled(self, state: int):
        """Enable/disable ASL parameter controls."""
        is_on = self.use_asl_check.isChecked() and self.use_ovr_check.isChecked()
        self.asl_gamma_neg_spin.setEnabled(is_on)
        self.asl_gamma_pos_spin.setEnabled(is_on)
        self.asl_clip_spin.setEnabled(is_on)

    def _on_hard_pair_toggled(self, state: int):
        """Enable/disable hard-pair controls."""
        is_on = self.use_ovr_check.isChecked() and self.use_hard_pair_mining_check.isChecked()
        self.hard_pair_edit.setEnabled(is_on)
        self.hard_pair_loss_weight_spin.setEnabled(is_on)
        self.hard_pair_margin_spin.setEnabled(is_on)
        self.hard_pair_confusion_boost_spin.setEnabled(is_on)

    def _on_use_augmentation_changed(self, state: int):
        """Enable/disable augmentation options button and aug cache versions."""
        is_on = self.use_augmentation_check.isChecked()
        if hasattr(self, "augmentation_options_btn"):
            self.augmentation_options_btn.setEnabled(is_on)
        if not is_on and hasattr(self, "emb_aug_versions_spin"):
            self.emb_aug_versions_spin.setValue(1)
        if hasattr(self, "_sync_emb_cache_controls"):
            self._sync_emb_cache_controls()

    def _on_localization_toggled(self, state: int):
        """Manage crop jitter controls when localization is toggled."""
        loc_on = bool(state)
        if loc_on:
            self.use_crop_jitter_check.setEnabled(True)
            self.crop_jitter_strength_spin.setEnabled(self.use_crop_jitter_check.isChecked())
        else:
            self.use_crop_jitter_check.setChecked(False)
            self.use_crop_jitter_check.setEnabled(False)
            self.crop_jitter_strength_spin.setEnabled(False)

    def _sync_stitch_controls(self):
        """Disable stitching whenever localization supervision is enabled."""
        if not hasattr(self, "use_stitch_check"):
            return
        localization_on = bool(self.use_localization_check.isChecked()) if hasattr(self, "use_localization_check") else False
        if localization_on:
            self.use_stitch_check.setChecked(False)
            self.use_stitch_check.setEnabled(False)
            self.stitch_prob_spin.setEnabled(False)
        else:
            self.use_stitch_check.setEnabled(True)
            self.stitch_prob_spin.setEnabled(self.use_stitch_check.isChecked())
        if hasattr(self, "_sync_emb_cache_controls"):
            self._sync_emb_cache_controls()

    def _parse_hard_pairs_text(self, text: str) -> list[list[str]]:
        """Parse comma-separated hard-pair text into [[class_a, class_b], ...]."""
        pairs = []
        seen = set()
        for chunk in (text or "").split(","):
            item = chunk.strip()
            if not item:
                continue
            parts = [p.strip() for p in re.split(r"\s*(?::|>|/|\bvs\b)\s*", item, maxsplit=1, flags=re.IGNORECASE) if p.strip()]
            if len(parts) != 2 or parts[0] == parts[1]:
                continue
            key = tuple(sorted(parts))
            if key in seen:
                continue
            seen.add(key)
            pairs.append([parts[0], parts[1]])
        return pairs
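
    # Example: "rear:digg, rear vs digg, move/digg" parses to
    # [["rear", "digg"], ["move", "digg"]]; the second chunk is dropped as a
    # duplicate of the first, whichever separator (":", ">", "/", "vs") is used.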

    def _format_hard_pairs_text(self, pairs) -> str:
        """Format stored hard-pair config back into the UI text box."""
        if not isinstance(pairs, (list, tuple)):
            return ""
        items = []
        for pair in pairs:
            if not isinstance(pair, (list, tuple)) or len(pair) != 2:
                continue
            a_name = str(pair[0]).strip()
            b_name = str(pair[1]).strip()
            if not a_name or not b_name:
                continue
            items.append(f"{a_name}:{b_name}")
        return ", ".join(items)

    def _default_augmentation_options(self) -> dict:
        return {
            "use_horizontal_flip": True,
            "use_vertical_flip": False,
            "use_color_jitter": True,
            "use_gaussian_blur": True,
            "use_random_noise": True,
            "use_small_rotation": False,
            "use_speed_perturb": False,
            "use_random_shapes": False,
            "use_grayscale": False,
            "use_lighting_robustness": True,
        }

    def _normalize_augmentation_options(self, options: dict) -> dict:
        defaults = self._default_augmentation_options()
        if not isinstance(options, dict):
            return defaults
        for key in defaults:
            defaults[key] = bool(options.get(key, defaults[key]))
        return defaults
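
    # Note: unknown keys are dropped and known keys are coerced to bool, so a
    # stale saved profile such as {"use_grayscale": 1, "old_key": True} (a
    # hypothetical example) normalizes to the defaults with
    # use_grayscale=True and "old_key" discarded.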

    def _open_augmentation_options_dialog(self):
        """Open dialog to select augmentations."""
        dialog = QDialog(self)
        dialog.setWindowTitle("Augmentation Options")

        layout = QVBoxLayout(dialog)
        grid = QGridLayout()

        options = self._normalize_augmentation_options(self.augmentation_options)

        hflip_check = QCheckBox("Random horizontal flip")
        hflip_check.setChecked(options["use_horizontal_flip"])
        grid.addWidget(hflip_check, 0, 0)

        vflip_check = QCheckBox("Random vertical flip")
        vflip_check.setChecked(options["use_vertical_flip"])
        grid.addWidget(vflip_check, 1, 0)

        color_check = QCheckBox("Color jitter (brightness/contrast/saturation/hue)")
        color_check.setChecked(options["use_color_jitter"])
        grid.addWidget(color_check, 2, 0)

        blur_check = QCheckBox("Gaussian blur (0.1-0.5 sigma)")
        blur_check.setChecked(options["use_gaussian_blur"])
        grid.addWidget(blur_check, 3, 0)

        noise_check = QCheckBox("Random noise (std=0.02)")
        noise_check.setChecked(options["use_random_noise"])
        grid.addWidget(noise_check, 4, 0)

        rot_check = QCheckBox("Small rotation (+/- 5 degrees)")
        rot_check.setChecked(options["use_small_rotation"])
        grid.addWidget(rot_check, 5, 0)

        speed_check = QCheckBox("Speed perturbation (0.7x - 1.3x)")
        speed_check.setChecked(options.get("use_speed_perturb", False))
        grid.addWidget(speed_check, 6, 0)

        shapes_check = QCheckBox("Random shape overlays (occlusion)")
        shapes_check.setChecked(options.get("use_random_shapes", False))
        grid.addWidget(shapes_check, 7, 0)

        gray_check = QCheckBox("Random grayscale (50% chance)")
        gray_check.setChecked(options.get("use_grayscale", False))
        grid.addWidget(gray_check, 8, 0)

        light_check = QCheckBox("Lighting / color robustness")
        light_check.setToolTip("Clip-consistent gamma and per-channel gain jitter to reduce brightness/color bias.")
        light_check.setChecked(options.get("use_lighting_robustness", False))
        grid.addWidget(light_check, 9, 0)

        layout.addLayout(grid)

        buttons_layout = QHBoxLayout()
        buttons_layout.addStretch()
        ok_btn = QPushButton("OK")
        cancel_btn = QPushButton("Cancel")
        ok_btn.clicked.connect(dialog.accept)
        cancel_btn.clicked.connect(dialog.reject)
        buttons_layout.addWidget(ok_btn)
        buttons_layout.addWidget(cancel_btn)
        layout.addLayout(buttons_layout)

        if dialog.exec() == QDialog.DialogCode.Accepted:
            self.augmentation_options = {
                "use_horizontal_flip": hflip_check.isChecked(),
                "use_vertical_flip": vflip_check.isChecked(),
                "use_color_jitter": color_check.isChecked(),
                "use_gaussian_blur": blur_check.isChecked(),
                "use_random_noise": noise_check.isChecked(),
                "use_small_rotation": rot_check.isChecked(),
                "use_speed_perturb": speed_check.isChecked(),
                "use_random_shapes": shapes_check.isChecked(),
                "use_grayscale": gray_check.isChecked(),
                "use_lighting_robustness": light_check.isChecked(),
            }

    def _browse_annotation(self):
        """Browse for annotation file."""
        file_path, _ = QFileDialog.getOpenFileName(
            self,
            "Select Annotation File",
            self.config.get("data_dir", "data"),
            "JSON Files (*.json);;All Files (*)"
        )
        if file_path:
            self.annotation_file_edit.setText(file_path)
            self.refresh_annotation_info()

    def _browse_clips_dir(self):
        """Browse for clips directory."""
        dir_path = QFileDialog.getExistingDirectory(
            self,
            "Select Clips Directory",
            self.config.get("clips_dir", "data/clips")
        )
        if dir_path:
            self.clips_dir_edit.setText(dir_path)

            # Check for metadata
            import json
            meta_path = os.path.join(dir_path, "clips_metadata.json")
            if os.path.exists(meta_path):
                try:
                    with open(meta_path, 'r') as f:
                        meta = json.load(f)

                    clip_len = meta.get("clip_length")
                    if clip_len:
                        self.clip_length_spin.setValue(int(clip_len))
                        self.log_text.appendPlainText(f"Automatically set 'Frames per clip' to {clip_len} from metadata.")
                except Exception as e:
                    logger.error("Error reading clips metadata: %s", e)

    def _show_clip_length_info(self):
        """Show information about frames per clip."""
        QMessageBox.information(
            self,
            "Frames per Clip - Information",
            "Number of frames the model will use for training.\n\n"
            "- Can be equal to or less than the actual clip length\n"
            "- If less, the middle N frames are selected (temporal center-crop)\n"
            "  e.g. 16-frame clips with this set to 8 → frames 4-11 are used\n"
            "- Useful for testing whether shorter temporal context is sufficient\n\n"
            "Important:\n"
            "- Use the same value in the Inference tab when running predictions\n"
            "- Shorter clips = faster training and lower memory usage\n"
            "- If clips have fewer frames than this value, they are padded"
        )

    def _browse_output(self):
        """Browse for output model path."""
        file_path, _ = QFileDialog.getSaveFileName(
            self,
            "Save Model",
            self.config.get("models_dir", "models/behavior_heads"),
            "PyTorch Files (*.pt);;All Files (*)"
        )
        if file_path:
            self.output_path_edit.setText(file_path)

    def _browse_pretrained(self):
        """Browse for pretrained model path."""
        file_path, _ = QFileDialog.getOpenFileName(
            self,
            "Select Pretrained Model",
            self.config.get("models_dir", "models/behavior_heads"),
            "PyTorch Files (*.pt);;All Files (*)"
        )
        if file_path:
            self.pretrained_path_edit.setText(file_path)

    def _on_finetune_changed(self, state: int):
        """Enable/disable pretrained path inputs."""
        enabled = self.finetune_check.isChecked()
        self.pretrained_path_edit.setEnabled(enabled)
        self.pretrained_browse_btn.setEnabled(enabled)

    def _open_visualization(self):
        """Open the training visualization dialog."""
        if self.visualization_dialog is None:
            self.visualization_dialog = TrainingVisualizationDialog(self)

        self.visualization_dialog.show()
        self.visualization_dialog.raise_()
        self.visualization_dialog.activateWindow()

    def refresh_annotation_info(self):
        """Refresh dataset info display."""
        try:
            from collections import Counter

            annotation_file = self.annotation_file_edit.text().strip()
            if not annotation_file:
                annotation_file = self.config.get(
                    "training_annotation_file",
                    self.config.get("annotation_file", "data/annotations/annotations.json"),
                )

            self.annotation_manager = AnnotationManager(annotation_file)

            labeled_clips = self.annotation_manager.get_labeled_clips()
            has_bboxes = any(
                clip.get("spatial_bbox") or clip.get("spatial_bbox_frames")
                for clip in labeled_clips
            )
            self.loc_group.setEnabled(has_bboxes)
            if has_bboxes:
                self.loc_group.setTitle("Localization")
                self.loc_group.setToolTip("")
            else:
                self.use_localization_check.setChecked(False)
                self.use_manual_loc_switch_check.setChecked(False)
                self.use_crop_jitter_check.setChecked(False)
                self.loc_group.setTitle("Localization (requires bbox labels)")
                self.loc_group.setToolTip(
                    "Optional: localizes individual animals when multiple are in camera view.\n"
                    "To activate, draw and save bounding boxes on at least some clips in the Labeling tab."
                )
            classes = sorted(self.annotation_manager.get_classes())
            real_classes = [c for c in classes if not c.startswith("near_negative")]
            hard_negative_classes = [c for c in classes if c.startswith("near_negative")]
            counts = self.annotation_manager.get_clip_count_by_label()
            primary_counts = Counter(
                clip.get("label", "")
                for clip in labeled_clips
                if clip.get("label")
            )

            ml_stats = self.annotation_manager.get_multilabel_stats()
            exclusive = ml_stats["exclusive"]
            shared = ml_stats["shared"]
            combos = ml_stats["combos"]

            info_text = f"Labeled clips: {len(labeled_clips)}\n"
            info_text += f"Behavior classes: {len(real_classes)}\n"
            if real_classes:
                info_text += f"Behavior names: {', '.join(real_classes)}\n"
            if hard_negative_classes:
                hn_clip_count = sum(primary_counts.get(label, 0) for label in hard_negative_classes)
                info_text += f"Hard-negative helper labels: {len(hard_negative_classes)} ({hn_clip_count} clips)\n"

            info_text += "\nPrimary-label training counts:\n"
            for label in real_classes:
                primary_count = int(primary_counts.get(label, 0))
                membership_count = int(counts.get(label, 0))
                exc = int(exclusive.get(label, 0))
                sh = int(shared.get(label, 0))
                if sh > 0:
                    info_text += (
                        f" {label}: {primary_count} primary"
                        f" ({membership_count} label memberships: {exc} exclusive, {sh} multi-class)\n"
                    )
                else:
                    info_text += f" {label}: {primary_count} primary\n"

            if hard_negative_classes:
                info_text += "\nHard-negative suppression clips:\n"
                for label in hard_negative_classes:
                    info_text += f" {label}: {int(primary_counts.get(label, 0))}\n"

            real_combos = {
                combo: cnt
                for combo, cnt in combos.items()
                if combo and all(lbl in real_classes for lbl in combo)
            }
            if real_combos:
                total_mc = sum(real_combos.values())
                info_text += f"\nMulti-class clips: {total_mc}\n"
                for combo, cnt in sorted(real_combos.items(), key=lambda x: -x[1]):
                    info_text += f" {' + '.join(combo)}: {cnt}\n"

            self.info_label.setText(info_text)

            # Update class selection list (exclude near_negative_* — they're auto-included in OvR)
            self.class_selection_list.clear()
            for class_name in real_classes:
                item = QListWidgetItem(class_name)
                item.setCheckState(Qt.CheckState.Checked)
                self.class_selection_list.addItem(item)

            # Update per-class limit table
            self.per_class_limit_table.setRowCount(len(real_classes))
            for row, class_name in enumerate(real_classes):
                class_item = QTableWidgetItem(class_name)
                class_item.setFlags(class_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
                self.per_class_limit_table.setItem(row, 0, class_item)

                count = int(primary_counts.get(class_name, 0))

                # Max Train
                train_item = QTableWidgetItem(str(count))
                train_item.setTextAlignment(Qt.AlignmentFlag.AlignCenter)
                self.per_class_limit_table.setItem(row, 1, train_item)

                # Max Val (default to count i.e. unlimited/all)
                val_item = QTableWidgetItem(str(count))
                val_item.setTextAlignment(Qt.AlignmentFlag.AlignCenter)
                self.per_class_limit_table.setItem(row, 2, val_item)
        except Exception as e:
            self.info_label.setText(f"Error loading info: {e}")

    def _on_select_classes_changed(self, state: int):
        """Enable/disable class selection list."""
        self.class_selection_list.setEnabled(self.select_classes_check.isChecked())

    def _on_limit_per_class_changed(self, state: int):
        """Enable/disable per-class limit table."""
        self.per_class_limit_table.setEnabled(self.limit_per_class_check.isChecked())

    def _on_auto_tune_changed(self, state: int):
        """Enable/disable auto-tune controls."""
        is_on = self.auto_tune_check.isChecked()
        self.auto_tune_runs_spin.setEnabled(is_on)
        self.auto_tune_epochs_spin.setEnabled(is_on)

    def _get_training_profiles_path(self):
        """Return the profile storage path for the current experiment."""
        config_path = self.config.get("config_path")
        if config_path:
            exp_dir = os.path.dirname(config_path)
            return os.path.join(exp_dir, "training_profiles.json")
        from singlebehaviorlab._paths import get_training_profiles_path
        return str(get_training_profiles_path())

    def _open_profile_manager(self):
        """Open the training profiles manager dialog."""
        profiles_path = self._get_training_profiles_path()

        if not self.profile_dialog:
            self.profile_dialog = TrainingProfileDialog(self, profiles_file=profiles_path)
        else:
            self.profile_dialog.reload_profiles(profiles_path)

        self.profile_dialog.show()
        self.profile_dialog.raise_()
        self.profile_dialog.activateWindow()

    def get_training_config(self):
        """Extract current training configuration from UI components."""
        head_kwargs = {
            "num_heads": self.map_num_heads_spin.value(),
        }

        selected_classes = []
        if self.select_classes_check.isChecked():
            for i in range(self.class_selection_list.count()):
                item = self.class_selection_list.item(i)
                if item.checkState() == Qt.CheckState.Checked:
                    selected_classes.append(item.text())

        per_class_limits = {}
        per_class_val_limits = {}
        if self.limit_per_class_check.isChecked():
            for row in range(self.per_class_limit_table.rowCount()):
                class_item = self.per_class_limit_table.item(row, 0)
                train_item = self.per_class_limit_table.item(row, 1)
                val_item = self.per_class_limit_table.item(row, 2)

                if class_item:
                    class_name = class_item.text()
                    # Train limit (allow float e.g. 1.5 → applied as 2)
                    if train_item:
                        try:
                            val = float(train_item.text())
                            if val > 0: per_class_limits[class_name] = val
                        except ValueError: pass

                    # Val limit (allow float)
                    if val_item:
                        try:
                            val = float(val_item.text())
                            if val > 0: per_class_val_limits[class_name] = val
                        except ValueError: pass

        # Use selected classes for metadata if class limiting is enabled, otherwise use all classes
        if hasattr(self, 'annotation_manager'):
            all_classes = self.annotation_manager.get_classes()
        else:
            all_classes = []
        classes_for_metadata = selected_classes if self.select_classes_check.isChecked() and selected_classes else all_classes
        background_class_names = [
            c for c in classes_for_metadata
            if c.lower() in ("other", "background", "bg", "none")
        ]
        if self.use_ovr_check.isChecked() and self.ovr_background_negative_check.isChecked():
            classes_for_metadata = [c for c in classes_for_metadata if c not in background_class_names]

        # Auto-detect helper classes (Other/Background) to exclude from F1 metrics
        _f1_exclude = [c for c in classes_for_metadata
                       if c.lower() in ("other", "background", "bg", "none")]

        # Get pretrained path if enabled
        pretrained_path = None
        if self.finetune_check.isChecked():
            pretrained_path = self.pretrained_path_edit.text().strip()
        hard_pairs = self._parse_hard_pairs_text(self.hard_pair_edit.text())
        resolution_cfg = int(self.config.get("resolution", self._resolution if hasattr(self, "_resolution") else 288))
        if resolution_cfg % 18 != 0:
            resolution_cfg = max(18, (resolution_cfg // 18) * 18)
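        # Example: 288 passes through unchanged (16 * 18); 300 snaps down to 288;
        # anything below 18 snaps up to the 18 floor (the multiple-of-18 rule
        # presumably matches the backbone's spatial token grid).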

        return {
            "batch_size": self.batch_size_spin.value(),
            "epochs": self.epochs_spin.value(),
            "lr": self.class_lr_spin.value(),
            "localization_lr": self.loc_lr_spin.value(),
            "classification_lr": self.class_lr_spin.value(),
            "use_scheduler": bool(self.config.get("default_use_scheduler", True)),
            "use_ema": bool(self.config.get("default_use_ema", True)),
            "weight_decay": self.weight_decay_spin.value(),
            "head_kwargs": head_kwargs,
            "dropout": self.map_dropout_spin.value(),
            "clip_length": self.clip_length_spin.value(),
            "target_fps": int(self.config.get("default_target_fps", 16)),
            "resolution": resolution_cfg,
            "use_all_for_training": self.use_all_for_training_check.isChecked(),
            "val_split": self.val_split_spin.value(),
            "auto_tune_before_final": self.auto_tune_check.isChecked(),
            "auto_tune_runs": self.auto_tune_runs_spin.value(),
            "auto_tune_epochs": self.auto_tune_epochs_spin.value(),
            "use_class_weights": bool(getattr(self, "use_class_weights_check", None) and self.use_class_weights_check.isChecked()),
            "use_focal_loss": False,
            "focal_gamma": 2.0,
            "use_supcon_loss": self.use_supcon_check.isChecked(),
            "supcon_weight": self.supcon_weight_spin.value(),
            "supcon_temperature": self.supcon_temp_spin.value(),
            "use_frame_loss": True,
            "use_temporal_decoder": self.use_temporal_decoder_check.isChecked(),
            "frame_head_temporal_layers": self.frame_head_layers_spin.value(),
            "temporal_pool_frames": self.temporal_pool_spin.value(),
            "proj_dim": self.proj_dim_spin.value(),
            "use_frame_bout_balance": self.use_bout_balance_check.isChecked(),
            "frame_bout_balance_power": self.bout_balance_power_spin.value(),
            "boundary_loss_weight": self.boundary_loss_weight_spin.value() if self.use_temporal_decoder_check.isChecked() else 0.0,
            "boundary_tolerance": self.boundary_tolerance_spin.value(),
            "smoothness_loss_weight": self.smoothness_loss_weight_spin.value(),
            "use_localization": self.use_localization_check.isChecked(),
            "use_manual_localization_switch": self.use_manual_loc_switch_check.isChecked(),
            "manual_localization_switch_epoch": self.manual_loc_switch_epoch_spin.value(),
            "localization_hidden_dim": 256,
            "classification_crop_padding": self.crop_padding_spin.value(),
            "crop_jitter": self.use_crop_jitter_check.isChecked() and self.use_localization_check.isChecked(),
            "crop_jitter_strength": self.crop_jitter_strength_spin.value(),
            "use_ovr": self.use_ovr_check.isChecked(),
            "ovr_background_as_negative": (
                self.ovr_background_negative_check.isChecked()
                and self.use_ovr_check.isChecked()
            ),
            "ovr_background_class_names": background_class_names,
            "ovr_label_smoothing": self.ovr_label_smoothing_spin.value(),
            "use_asl": self.use_asl_check.isChecked() and self.use_ovr_check.isChecked(),
            "asl_gamma_neg": self.asl_gamma_neg_spin.value(),
            "asl_gamma_pos": self.asl_gamma_pos_spin.value(),
            "asl_clip": self.asl_clip_spin.value(),
            "use_hard_pair_mining": (
                self.use_hard_pair_mining_check.isChecked()
                and self.use_ovr_check.isChecked()
                and bool(hard_pairs)
            ),
            "hard_pairs": hard_pairs,
            "hard_pair_loss_weight": self.hard_pair_loss_weight_spin.value(),
            "hard_pair_margin": self.hard_pair_margin_spin.value(),
            "hard_pair_confusion_boost": self.hard_pair_confusion_boost_spin.value(),
            "use_confusion_sampler": (
                self.use_confusion_sampler_check.isChecked()
                and self.use_ovr_check.isChecked()
            ),
            "confusion_sampler_temperature": self.confusion_temperature_spin.value(),
            "confusion_sampler_warmup_pct": self.confusion_warmup_spin.value() / 100.0,
            "use_weighted_sampler": self.use_weighted_sampler_check.isChecked(),
            "use_augmentation": self.use_augmentation_check.isChecked(),
            "virtual_expansion": self.virtual_expansion_spin.value(),
            "stitch_augmentation_prob": self.stitch_prob_spin.value() if self.use_stitch_check.isChecked() else 0.0,
            "emb_aug_versions": self.emb_aug_versions_spin.value() if hasattr(self, "emb_aug_versions_spin") else 1,
            "multi_scale": self.use_multi_scale_check.isChecked() if hasattr(self, "use_multi_scale_check") else False,
            "augmentation_options": dict(self.augmentation_options),
            "limit_classes": self.select_classes_check.isChecked(),
            "selected_classes": selected_classes,
            "limit_per_class": self.limit_per_class_check.isChecked(),
            "per_class_limits": per_class_limits,
            "per_class_val_limits": per_class_val_limits,
            "use_embedding_diversity": self.use_embedding_diversity_check.isChecked(),
            "backbone_model": self.config.get("backbone_model", "videoprism_public_v1_base"),
            "class_names": classes_for_metadata,
            "pretrained_path": pretrained_path,
            "f1_exclude_classes": _f1_exclude,
            "ovr_pos_weight_f1_excluded": getattr(self, "_ovr_pos_weight_f1_excluded", 1.5),
        }

    def apply_training_config(self, config):
        """Apply a configuration dictionary to the UI components."""
        try:
            self.use_localization_check.setChecked(False)
            self.use_manual_loc_switch_check.setChecked(False)
            self.manual_loc_switch_epoch_spin.setValue(20)
            self.crop_padding_spin.setValue(0.35)
            self.use_crop_jitter_check.setChecked(False)
            self.crop_jitter_strength_spin.setValue(0.15)

            if "batch_size" in config: self.batch_size_spin.setValue(config["batch_size"])
            if "epochs" in config: self.epochs_spin.setValue(config["epochs"])
            if "lr" in config:
                self.class_lr_spin.setValue(config["lr"])
            if "classification_lr" in config:
                self.class_lr_spin.setValue(config["classification_lr"])
            if "localization_lr" in config:
                self.loc_lr_spin.setValue(config["localization_lr"])
            elif "classification_lr" in config:
                self.loc_lr_spin.setValue(config["classification_lr"])
            elif "lr" in config:
                self.loc_lr_spin.setValue(config["lr"])
            if "weight_decay" in config: self.weight_decay_spin.setValue(config["weight_decay"])
            if "dropout" in config: self.map_dropout_spin.setValue(config["dropout"])
            if "clip_length" in config: self.clip_length_spin.setValue(config["clip_length"])
            if "use_all_for_training" in config: self.use_all_for_training_check.setChecked(config["use_all_for_training"])
            if "val_split" in config: self.val_split_spin.setValue(config["val_split"])
            if "auto_tune_before_final" in config: self.auto_tune_check.setChecked(bool(config["auto_tune_before_final"]))
            if "auto_tune_runs" in config: self.auto_tune_runs_spin.setValue(int(config["auto_tune_runs"]))
            if "auto_tune_epochs" in config: self.auto_tune_epochs_spin.setValue(int(config["auto_tune_epochs"]))

            # Loss settings
            if "use_class_weights" in config: self.use_class_weights_check.setChecked(config["use_class_weights"])

            if "use_supcon_loss" in config:
                self.use_supcon_check.setChecked(bool(config["use_supcon_loss"]))
            if "supcon_weight" in config:
                self.supcon_weight_spin.setValue(float(config["supcon_weight"]))
            if "supcon_temperature" in config:
                self.supcon_temp_spin.setValue(float(config["supcon_temperature"]))
            if "use_temporal_decoder" in config:
                self.use_temporal_decoder_check.setChecked(bool(config["use_temporal_decoder"]))

            if "frame_head_temporal_layers" in config:
                self.frame_head_layers_spin.setValue(int(config["frame_head_temporal_layers"]))
            if "temporal_pool_frames" in config:
                self.temporal_pool_spin.setValue(int(config["temporal_pool_frames"]))
            if "proj_dim" in config:
                self.proj_dim_spin.setValue(int(config["proj_dim"]))
            if "multi_scale" in config and hasattr(self, "use_multi_scale_check"):
                self.use_multi_scale_check.setChecked(bool(config["multi_scale"]))
            if "boundary_loss_weight" in config:
                self.boundary_loss_weight_spin.setValue(config["boundary_loss_weight"])
            if "boundary_tolerance" in config:
                self.boundary_tolerance_spin.setValue(int(config["boundary_tolerance"]))
            if "smoothness_loss_weight" in config:
                self.smoothness_loss_weight_spin.setValue(config["smoothness_loss_weight"])
            if "use_frame_bout_balance" in config:
                self.use_bout_balance_check.setChecked(bool(config["use_frame_bout_balance"]))
            if "frame_bout_balance_power" in config and config["frame_bout_balance_power"] is not None:
                self.bout_balance_power_spin.setValue(float(config["frame_bout_balance_power"]))
            if "use_localization" in config:
                self.use_localization_check.setChecked(config["use_localization"])

            if "use_manual_localization_switch" in config:
                self.use_manual_loc_switch_check.setChecked(config["use_manual_localization_switch"])
            if "manual_localization_switch_epoch" in config:
                self.manual_loc_switch_epoch_spin.setValue(config["manual_localization_switch_epoch"])
            if "classification_crop_padding" in config:
                self.crop_padding_spin.setValue(float(config["classification_crop_padding"]))
            if "crop_jitter" in config:
                self.use_crop_jitter_check.setChecked(bool(config["crop_jitter"]))
            if "crop_jitter_strength" in config:
                self.crop_jitter_strength_spin.setValue(float(config["crop_jitter_strength"]))

            if "use_ovr" in config: self.use_ovr_check.setChecked(config["use_ovr"])
            if "ovr_background_as_negative" in config:
                self.ovr_background_negative_check.setChecked(bool(config["ovr_background_as_negative"]))
            if "ovr_pos_weight_f1_excluded" in config:
                self._ovr_pos_weight_f1_excluded = float(config["ovr_pos_weight_f1_excluded"])
            if "ovr_label_smoothing" in config: self.ovr_label_smoothing_spin.setValue(config["ovr_label_smoothing"])
            if "use_asl" in config: self.use_asl_check.setChecked(config["use_asl"])
            if "asl_gamma_neg" in config: self.asl_gamma_neg_spin.setValue(float(config["asl_gamma_neg"]))
            if "asl_gamma_pos" in config: self.asl_gamma_pos_spin.setValue(float(config["asl_gamma_pos"]))
            if "asl_clip" in config: self.asl_clip_spin.setValue(float(config["asl_clip"]))
            if "hard_pairs" in config:
                self.hard_pair_edit.setText(self._format_hard_pairs_text(config["hard_pairs"]))
            if "use_hard_pair_mining" in config:
                self.use_hard_pair_mining_check.setChecked(bool(config["use_hard_pair_mining"]))
            if "hard_pair_loss_weight" in config and config["hard_pair_loss_weight"] is not None:
                self.hard_pair_loss_weight_spin.setValue(float(config["hard_pair_loss_weight"]))
            if "hard_pair_margin" in config and config["hard_pair_margin"] is not None:
                self.hard_pair_margin_spin.setValue(float(config["hard_pair_margin"]))
            if "hard_pair_confusion_boost" in config and config["hard_pair_confusion_boost"] is not None:
                self.hard_pair_confusion_boost_spin.setValue(float(config["hard_pair_confusion_boost"]))
            if "use_confusion_sampler" in config:
                self.use_confusion_sampler_check.setChecked(bool(config["use_confusion_sampler"]))
            if "confusion_sampler_temperature" in config:
                self.confusion_temperature_spin.setValue(float(config["confusion_sampler_temperature"]))
            if "confusion_sampler_warmup_pct" in config:
                self.confusion_warmup_spin.setValue(int(float(config["confusion_sampler_warmup_pct"]) * 100))
            if "use_weighted_sampler" in config: self.use_weighted_sampler_check.setChecked(config["use_weighted_sampler"])
            if "use_augmentation" in config: self.use_augmentation_check.setChecked(config["use_augmentation"])
            if "virtual_expansion" in config: self.virtual_expansion_spin.setValue(int(config["virtual_expansion"]))
            if "stitch_augmentation_prob" in config:
                prob = float(config["stitch_augmentation_prob"])
                self.use_stitch_check.setChecked(prob > 0.0)
                self.stitch_prob_spin.setValue(prob)
            if "emb_aug_versions" in config and hasattr(self, "emb_aug_versions_spin"):
                self.emb_aug_versions_spin.setValue(int(config["emb_aug_versions"]))
            if "augmentation_options" in config:
                self.augmentation_options = self._normalize_augmentation_options(config["augmentation_options"])

            if "head_kwargs" in config:
                hk = config["head_kwargs"]
                if "num_heads" in hk: self.map_num_heads_spin.setValue(hk["num_heads"])

            if "limit_classes" in config:
                self.select_classes_check.setChecked(config["limit_classes"])
                # Restore selected classes if possible
                if config["limit_classes"] and "selected_classes" in config:
                    selected_set = set(config["selected_classes"])
                    # Ensure list is populated (might need refresh if empty, but usually populated)
                    if self.class_selection_list.count() == 0:
                        self.refresh_annotation_info()

                    for i in range(self.class_selection_list.count()):
                        item = self.class_selection_list.item(i)
                        if item.text() in selected_set:
                            item.setCheckState(Qt.CheckState.Checked)
3321
|
+
else:
|
|
3322
|
+
item.setCheckState(Qt.CheckState.Unchecked)
|
|
3323
|
+
|
|
3324
|
+
if "limit_per_class" in config:
|
|
3325
|
+
self.limit_per_class_check.setChecked(config["limit_per_class"])
|
|
3326
|
+
# Restore per-class limits
|
|
3327
|
+
if config["limit_per_class"]:
|
|
3328
|
+
limits = config.get("per_class_limits", {})
|
|
3329
|
+
val_limits = config.get("per_class_val_limits", {})
|
|
3330
|
+
|
|
3331
|
+
if self.per_class_limit_table.rowCount() == 0:
|
|
3332
|
+
self.refresh_annotation_info()
|
|
3333
|
+
|
|
3334
|
+
for row in range(self.per_class_limit_table.rowCount()):
|
|
3335
|
+
class_item = self.per_class_limit_table.item(row, 0)
|
|
3336
|
+
if class_item:
|
|
3337
|
+
name = class_item.text()
|
|
3338
|
+
if name in limits:
|
|
3339
|
+
self.per_class_limit_table.setItem(row, 1, QTableWidgetItem(str(limits[name])))
|
|
3340
|
+
if name in val_limits:
|
|
3341
|
+
self.per_class_limit_table.setItem(row, 2, QTableWidgetItem(str(val_limits[name])))
|
|
3342
|
+
|
|
3343
|
+
if "use_embedding_diversity" in config: self.use_embedding_diversity_check.setChecked(config["use_embedding_diversity"])
|
|
3344
|
+
if "pretrained_path" in config:
|
|
3345
|
+
if config["pretrained_path"]:
|
|
3346
|
+
self.finetune_check.setChecked(True)
|
|
3347
|
+
self.pretrained_path_edit.setText(config["pretrained_path"])
|
|
3348
|
+
else:
|
|
3349
|
+
self.finetune_check.setChecked(False)
|
|
3350
|
+
self._sync_stitch_controls()
|
|
3351
|
+
|
|
3352
|
+
except Exception as e:
|
|
3353
|
+
logger.error("Error applying config: %s", e)
|
|
3354
|
+
raise
|
|
3355
|
+
|
|
3356
|
+
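    # Illustrative sketch (not part of the shipped code): apply_training_config
    # only touches a control when its key is present, so partial profiles apply
    # cleanly. The same guarded-update pattern, stripped of Qt; "SpinStub" is a
    # hypothetical stand-in for a spin box:
    #
    #     class SpinStub:
    #         def setValue(self, v):
    #             print(f"setValue({v!r})")
    #
    #     spin = SpinStub()
    #     config = {"supcon_weight": "0.25"}  # profile values may arrive as strings
    #     if "supcon_weight" in config:
    #         spin.setValue(float(config["supcon_weight"]))  # -> setValue(0.25)
    #     # absent keys leave their controls at the current UI state
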
    def _run_next_batch_item(self):
        """Run the next profile in the batch queue."""
        if not self.training_queue:
            self.is_batch_training = False

            # Show batch summary
            if self.batch_results:
                best_result = max(
                    self.batch_results,
                    key=lambda x: (
                        x.get("best_val_f1", 0.0),
                        x.get("best_val_acc", 0.0)
                    )
                )
                summary = "Batch training completed!\n\n"
                summary += f"Best Profile: {best_result['profile_name']}\n"
                summary += f"Best Val Macro F1: {best_result.get('best_val_f1', 0.0):.2f}%\n"
                summary += f"Best Val Acc: {best_result.get('best_val_acc', 0.0):.2f}%\n\n"
                summary += f"Results saved to:\n{self.batch_results_path}"
                QMessageBox.information(self, "Batch Training Complete", summary)
            else:
                QMessageBox.information(self, "Batch Training", "Batch training completed!")

            self.train_btn.setEnabled(True)
            self.stop_btn.setEnabled(False)
            self.progress_bar.setVisible(False)
            return

        profile_name, config = self.training_queue.pop(0)
        self.current_profile_name = profile_name
        self.log_text.appendPlainText(f"\n=== Starting Batch Item: {profile_name} ===")

        # Apply config to UI (so it's visible what's running)
        try:
            self.apply_training_config(config)
        except Exception as e:
            self.log_text.appendPlainText(f"Error applying profile {profile_name}: {e}")
            self._run_next_batch_item()  # Skip bad profile
            return

        # Determine output path with profile suffix
        base_output = self.output_path_edit.text().strip()
        dir_name = os.path.dirname(base_output)
        file_name = os.path.basename(base_output)
        name_root, ext = os.path.splitext(file_name)
        new_output = os.path.join(dir_name, f"{name_root}_{profile_name}{ext}")

        self._start_training_internal(override_output_path=new_output, profile_name=profile_name)

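    # Worked example (values assumed for illustration): with a base output of
    # "models/run.pt" and a profile named "focal_v2", the suffixing above yields
    # "models/run_focal_v2.pt":
    #
    #     import os
    #     base = "models/run.pt"
    #     root, ext = os.path.splitext(os.path.basename(base))
    #     print(os.path.join(os.path.dirname(base), f"{root}_focal_v2{ext}"))
    #     # -> models/run_focal_v2.pt
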
    def _start_training_internal(self, override_output_path=None, profile_name=None):
        """Internal method to start a single training run."""
        annotation_file = self.annotation_file_edit.text().strip()
        if not os.path.exists(annotation_file):
            if not self.is_batch_training: QMessageBox.warning(self, "Error", "Annotation file not found.")
            else: self.log_text.appendPlainText("Error: Annotation file not found.")
            return

        clips_dir = self.clips_dir_edit.text().strip()
        if not os.path.exists(clips_dir):
            if not self.is_batch_training: QMessageBox.warning(self, "Error", "Clips directory not found.")
            else: self.log_text.appendPlainText("Error: Clips directory not found.")
            return

        output_path = override_output_path or self.output_path_edit.text().strip()
        if not output_path:
            if not self.is_batch_training: QMessageBox.warning(self, "Error", "Please specify output path.")
            else: self.log_text.appendPlainText("Error: No output path.")
            return

        if os.path.dirname(output_path):  # guard: a bare filename has no directory to create
            os.makedirs(os.path.dirname(output_path), exist_ok=True)

        # Get configuration from UI
        train_config = self.get_training_config()
        if profile_name:
            train_config["profile_name"] = profile_name

        train_config["training_profiles_path"] = self._get_training_profiles_path()

        raw_resolution = int(self.config.get("resolution", 288))
        adjusted_resolution = train_config.get("resolution", raw_resolution)
        if raw_resolution % 18 != 0:
            msg = (
                "Resolution must be a multiple of 18 (patch size).\n\n"
                f"Entered: {raw_resolution}\n"
                f"Adjusted to: {adjusted_resolution}"
            )
            if not self.is_batch_training:
                QMessageBox.information(self, "Resolution adjusted", msg)
            else:
                self.log_text.appendPlainText(f"Resolution adjusted: {msg}")
        self._resolution = adjusted_resolution

        # Check pretrained path validity if needed
        if self.finetune_check.isChecked():
            if not train_config["pretrained_path"] or not os.path.exists(train_config["pretrained_path"]):
                if not self.is_batch_training: QMessageBox.warning(self, "Error", "Please select a valid pretrained model file.")
                else: self.log_text.appendPlainText("Error: Invalid pretrained path.")
                return

        if train_config.get("auto_tune_before_final", False):
            if train_config.get("use_all_for_training", False) or float(train_config.get("val_split", 0.0)) <= 0.0:
                msg = "Auto-tune requires a validation split. Disable 'Use all data for training' and set validation split > 0."
                if not self.is_batch_training: QMessageBox.warning(self, "Error", msg)
                else: self.log_text.appendPlainText(f"Error: {msg}")
                return

        head_kwargs = train_config["head_kwargs"]

        # Persist actual training parameters to experiment config.yaml
        try:
            config_path = self.config.get("config_path")
            if config_path:
                # Update last-used defaults for convenience
                self.config["default_batch_size"] = train_config["batch_size"]
                self.config["default_epochs"] = train_config["epochs"]
                self.config["default_learning_rate"] = train_config["classification_lr"]
                self.config["default_localization_lr"] = train_config["localization_lr"]
                self.config["default_classification_lr"] = train_config["classification_lr"]
                self.config["default_use_scheduler"] = train_config["use_scheduler"]
                self.config["default_use_ema"] = train_config["use_ema"]
                self.config["default_weight_decay"] = train_config["weight_decay"]
                self.config["default_clip_length"] = train_config["clip_length"]
                self.config["default_use_focal_loss"] = train_config["use_focal_loss"]
                self.config["default_focal_gamma"] = train_config["focal_gamma"]
                self.config["default_use_supcon_loss"] = train_config["use_supcon_loss"]
                self.config["default_supcon_weight"] = train_config["supcon_weight"]
                self.config["default_supcon_temperature"] = train_config["supcon_temperature"]
                self.config["backbone_model"] = train_config["backbone_model"]
                self.config["resolution"] = train_config["resolution"]
                # Save full last training block
                self.config["last_training"] = {
                    "parameters": {
                        "batch_size": train_config["batch_size"],
                        "epochs": train_config["epochs"],
                        "lr": train_config["classification_lr"],
                        "localization_lr": train_config["localization_lr"],
                        "classification_lr": train_config["classification_lr"],
                        "use_scheduler": train_config["use_scheduler"],
                        "use_ema": train_config["use_ema"],
                        "weight_decay": train_config["weight_decay"],
                        "clip_length": train_config["clip_length"],
                        "val_split": train_config["val_split"],
                        "auto_tune_before_final": train_config["auto_tune_before_final"],
                        "auto_tune_runs": train_config["auto_tune_runs"],
                        "auto_tune_epochs": train_config["auto_tune_epochs"],
                        "use_class_weights": train_config["use_class_weights"],
                        "use_focal_loss": train_config["use_focal_loss"],
                        "focal_gamma": train_config["focal_gamma"],
                        "use_weighted_sampler": train_config["use_weighted_sampler"],
                        "use_augmentation": train_config["use_augmentation"],
                        "augmentation_options": train_config.get("augmentation_options", {}),
                        "limit_classes": train_config["limit_classes"],
                        "limit_per_class": train_config["limit_per_class"],
                    },
                    "selected_classes": train_config["selected_classes"],
                    "per_class_limits": train_config["per_class_limits"],
                    "head": {
                        "dropout": train_config["dropout"],
                        "map_head_kwargs": train_config["head_kwargs"],
                    },
                    "pretrained_path": train_config["pretrained_path"],
                    "output_path": output_path
                }
                os.makedirs(os.path.dirname(config_path), exist_ok=True)
                with open(config_path, "w", encoding="utf-8") as f:
                    yaml.safe_dump(dict(self.config), f, sort_keys=False)
                self.log_text.appendPlainText(f"Saved training parameters to: {config_path}")
        except Exception as e:
            self.log_text.appendPlainText(f"Warning: Could not save training parameters to config.yaml: {e}")

        self.log_text.appendPlainText(f"Starting training run (Output: {os.path.basename(output_path)})...")
        self.progress_bar.setVisible(True)
        self.progress_bar.setValue(0)
        self.train_btn.setEnabled(False)
        self.stop_btn.setEnabled(True)
        self.visualize_btn.setEnabled(True)

        # Open visualization automatically on start
        self._open_visualization()

        self.worker = TrainingWorker(
            self.config,
            train_config,
            annotation_file,
            clips_dir,
            output_path
        )
        self.worker.log_message.connect(self._on_log)
        self.worker.progress.connect(self._on_progress)
        self.worker.training_complete.connect(self._on_training_complete)
        self.worker.finished.connect(self._on_finished)
        self.worker.error.connect(self._on_error)
        # Reset and connect visualization update; restrict F1 graph to selected classes when limit_classes is on
        if self.visualization_dialog:
            f1_exclude = set(train_config.get("f1_exclude_classes", []))
            f1_classes = None
            if train_config.get("limit_classes", False):
                sel = train_config.get("selected_classes", [])
                if sel:
                    f1_classes = set(sel) - f1_exclude
            elif f1_exclude:
                f1_classes = set(train_config.get("class_names", [])) - f1_exclude
            confusion_warmup_ep = 0
            if train_config.get("use_confusion_sampler", False):
                pct = float(train_config.get("confusion_sampler_warmup_pct", 0.2))
                confusion_warmup_ep = int(train_config.get("epochs", 60) * pct)
            self.visualization_dialog.reset(
                f1_classes_to_show=f1_classes if f1_classes else None,
                confusion_warmup_epoch=confusion_warmup_ep,
            )
            self.worker.epoch_complete.connect(self.visualization_dialog.update_plots)

        self.worker.start()

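    # Worked example (hypothetical class names): with limit_classes enabled,
    # selected_classes = ["walk", "groom", "rear"] and f1_exclude_classes =
    # ["rear"], the reset above receives f1_classes_to_show = {"walk", "groom"}.
    # The confusion-sampler warmup epoch is plain arithmetic:
    #
    #     f1_classes = set(["walk", "groom", "rear"]) - {"rear"}  # {'walk', 'groom'}
    #     confusion_warmup_ep = int(60 * 0.2)  # epochs=60, warmup_pct=0.2 -> 12
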
    def _start_training(self):
        """Start training (single or batch)."""
        if self.worker and self.worker.isRunning():
            QMessageBox.warning(self, "Training", "Training is already running.")
            return

        if self.batch_train_check.isChecked():
            if not self.profile_dialog:
                self.profile_dialog = TrainingProfileDialog(self, profiles_file=self._get_training_profiles_path())
            else:
                self.profile_dialog.reload_profiles(self._get_training_profiles_path())

            profiles = self.profile_dialog.get_selected_profiles_for_batch()
            if not profiles:
                QMessageBox.warning(self, "Batch Training", "No profiles selected in Advanced > Profiles.\nPlease open Advanced settings and check profiles to run.")
                return

            self.training_queue = profiles
            self.is_batch_training = True
            self.batch_results = []

            # Initialize batch results CSV path
            from datetime import datetime
            output_dir = os.path.dirname(self.output_path_edit.text().strip())
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            self.batch_results_path = os.path.join(output_dir, f"batch_training_results_{timestamp}.csv")
            self.log_text.appendPlainText(f"Batch results will be saved to: {self.batch_results_path}")

            self.train_btn.setEnabled(False)
            self.stop_btn.setEnabled(True)
            self._run_next_batch_item()
        else:
            self._start_training_internal()

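    # Example (illustrative timestamp): a batch launched on 2026-01-05 at
    # 14:30:07 writes its running summary next to the model outputs as
    # "batch_training_results_20260105_143007.csv", per the
    # datetime.now().strftime("%Y%m%d_%H%M%S") format above.
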
    def _stop_training(self):
        """Stop training."""
        if self.worker and self.worker.isRunning():
            self.is_batch_training = False  # Stop batch execution
            self.training_queue = []  # Clear queue
            self.worker.stop()
            self.log_text.appendPlainText("Stopping training...")
            self.worker.wait()
            self._on_finished()

    def _on_log(self, message: str):
        """Handle log message."""
        self.log_text.appendPlainText(message)

    def _on_progress(self, epoch: int, total: int):
        """Update progress bar."""
        if total > 0:
            progress = int(100 * epoch / total)
            self.progress_bar.setValue(progress)

    def _on_training_complete(self, best_val_acc: float, best_val_f1: float, final_train_acc: float, per_class_f1: dict):
        """Handle training metrics from completed run."""
        summary_msg = (
            f"Training summary → Best Val Macro F1: {best_val_f1:.2f}%, "
            f"Best Val Acc: {best_val_acc:.2f}%, Final Train Acc: {final_train_acc:.2f}%"
        )
        self.log_text.appendPlainText(summary_msg)

        if per_class_f1:
            self.log_text.appendPlainText("Per-class F1 scores (Final Epoch):")
            for cls_name, f1 in sorted(per_class_f1.items()):
                self.log_text.appendPlainText(f"  - {cls_name}: {f1:.2f}%")

        if self.is_batch_training and self.current_profile_name:
            # Get current config for this profile
            current_config = self.get_training_config()

            result = {
                "profile_name": self.current_profile_name,
                "best_val_acc": round(best_val_acc, 2),
                "best_val_f1": round(best_val_f1, 2),
                "final_train_acc": round(final_train_acc, 2),
                "epochs": current_config.get("epochs", 0),
                "batch_size": current_config.get("batch_size", 0),
                "lr": current_config.get("lr", 0),
                "use_focal_loss": current_config.get("use_focal_loss", False),
            }

            # Add per-class F1 columns
            if per_class_f1:
                for cls_name, f1 in per_class_f1.items():
                    # Sanitize column name
                    safe_name = f"F1_{cls_name}".replace(" ", "_")
                    result[safe_name] = round(f1, 2)

            self.batch_results.append(result)

            # Save/update CSV after each profile
            self._save_batch_results_csv()

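    # Example (hypothetical class name): a behavior labeled "object carry"
    # becomes the CSV column "F1_object_carry":
    #
    #     safe_name = f"F1_{'object carry'}".replace(" ", "_")  # 'F1_object_carry'
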
    def _save_batch_results_csv(self):
        """Save batch results to CSV."""
        if not self.batch_results or not self.batch_results_path:
            return
        try:
            import pandas as pd
            df = pd.DataFrame(self.batch_results)
            # Sort by best_val_f1 (primary) then accuracy
            sort_cols = [col for col in ["best_val_f1", "best_val_acc"] if col in df.columns]
            if sort_cols:
                df = df.sort_values(sort_cols, ascending=False)
            df.to_csv(self.batch_results_path, index=False)
            self.log_text.appendPlainText(f"  → Updated batch results: {self.batch_results_path}")
        except Exception as e:
            self.log_text.appendPlainText(f"  Failed to save batch results: {e}")

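    # Resulting CSV, sketched with hypothetical per-class columns (F1_walk,
    # F1_groom) and made-up numbers, sorted best-first by macro F1 then accuracy:
    #
    #     profile_name,best_val_acc,best_val_f1,final_train_acc,epochs,batch_size,lr,use_focal_loss,F1_walk,F1_groom
    #     focal_v2,91.40,88.75,97.10,60,16,0.0003,True,90.12,87.38
    #     baseline,90.05,86.20,96.44,60,16,0.0003,False,88.91,83.49
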
    def _on_finished(self):
        """Handle training completion."""
        if self.is_batch_training:
            self.refresh_annotation_info()
            self._run_next_batch_item()
            return

        self.train_btn.setEnabled(True)
        self.stop_btn.setEnabled(False)
        self.visualize_btn.setEnabled(False)
        self.progress_bar.setVisible(False)
        if self.profile_dialog:
            self.profile_dialog.reload_profiles(self._get_training_profiles_path())
        QMessageBox.information(self, "Training", "Training completed!")
        self.refresh_annotation_info()

    def _on_error(self, error_msg: str):
        """Handle training error."""
        self.log_text.appendPlainText(f"ERROR: {error_msg}")

        if self.is_batch_training:
            QMessageBox.critical(self, "Training Error", f"Batch training failed on current item:\n{error_msg}\n\nBatch execution stopped.")
            # Stop batch on error
            self.is_batch_training = False
            self.train_btn.setEnabled(True)
            self.stop_btn.setEnabled(False)
            self.progress_bar.setVisible(False)
        else:
            self._on_finished()
            QMessageBox.critical(self, "Training Error", f"Training failed:\n{error_msg}")

    def update_config(self, config: dict):
        """Apply a new configuration (experiment management)."""
        self.config = config
        self.annotation_manager = AnnotationManager(
            self.config.get("annotation_file", "data/annotations/annotations.json")
        )
        self._config_initialized = False
        self._load_current_config(force=True)
        self.refresh_annotation_info()