singlebehaviorlab 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88) hide show
  1. sam2/__init__.py +11 -0
  2. sam2/automatic_mask_generator.py +454 -0
  3. sam2/benchmark.py +92 -0
  4. sam2/build_sam.py +174 -0
  5. sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
  6. sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
  7. sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
  8. sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
  9. sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
  10. sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
  11. sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
  12. sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
  13. sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
  14. sam2/modeling/__init__.py +5 -0
  15. sam2/modeling/backbones/__init__.py +5 -0
  16. sam2/modeling/backbones/hieradet.py +317 -0
  17. sam2/modeling/backbones/image_encoder.py +134 -0
  18. sam2/modeling/backbones/utils.py +93 -0
  19. sam2/modeling/memory_attention.py +169 -0
  20. sam2/modeling/memory_encoder.py +181 -0
  21. sam2/modeling/position_encoding.py +239 -0
  22. sam2/modeling/sam/__init__.py +5 -0
  23. sam2/modeling/sam/mask_decoder.py +295 -0
  24. sam2/modeling/sam/prompt_encoder.py +202 -0
  25. sam2/modeling/sam/transformer.py +311 -0
  26. sam2/modeling/sam2_base.py +913 -0
  27. sam2/modeling/sam2_utils.py +323 -0
  28. sam2/sam2_hiera_b+.yaml +113 -0
  29. sam2/sam2_hiera_l.yaml +117 -0
  30. sam2/sam2_hiera_s.yaml +116 -0
  31. sam2/sam2_hiera_t.yaml +118 -0
  32. sam2/sam2_image_predictor.py +466 -0
  33. sam2/sam2_video_predictor.py +1388 -0
  34. sam2/sam2_video_predictor_legacy.py +1172 -0
  35. sam2/utils/__init__.py +5 -0
  36. sam2/utils/amg.py +348 -0
  37. sam2/utils/misc.py +349 -0
  38. sam2/utils/transforms.py +118 -0
  39. singlebehaviorlab/__init__.py +4 -0
  40. singlebehaviorlab/__main__.py +130 -0
  41. singlebehaviorlab/_paths.py +100 -0
  42. singlebehaviorlab/backend/__init__.py +2 -0
  43. singlebehaviorlab/backend/augmentations.py +320 -0
  44. singlebehaviorlab/backend/data_store.py +420 -0
  45. singlebehaviorlab/backend/model.py +1290 -0
  46. singlebehaviorlab/backend/train.py +4667 -0
  47. singlebehaviorlab/backend/uncertainty.py +578 -0
  48. singlebehaviorlab/backend/video_processor.py +688 -0
  49. singlebehaviorlab/backend/video_utils.py +139 -0
  50. singlebehaviorlab/data/config/config.yaml +85 -0
  51. singlebehaviorlab/data/training_profiles.json +334 -0
  52. singlebehaviorlab/gui/__init__.py +4 -0
  53. singlebehaviorlab/gui/analysis_widget.py +2291 -0
  54. singlebehaviorlab/gui/attention_export.py +311 -0
  55. singlebehaviorlab/gui/clip_extraction_widget.py +481 -0
  56. singlebehaviorlab/gui/clustering_widget.py +3187 -0
  57. singlebehaviorlab/gui/inference_popups.py +1138 -0
  58. singlebehaviorlab/gui/inference_widget.py +4550 -0
  59. singlebehaviorlab/gui/inference_worker.py +651 -0
  60. singlebehaviorlab/gui/labeling_widget.py +2324 -0
  61. singlebehaviorlab/gui/main_window.py +754 -0
  62. singlebehaviorlab/gui/metadata_management_widget.py +1119 -0
  63. singlebehaviorlab/gui/motion_tracking.py +764 -0
  64. singlebehaviorlab/gui/overlay_export.py +1234 -0
  65. singlebehaviorlab/gui/plot_integration.py +729 -0
  66. singlebehaviorlab/gui/qt_helpers.py +29 -0
  67. singlebehaviorlab/gui/registration_widget.py +1485 -0
  68. singlebehaviorlab/gui/review_widget.py +1330 -0
  69. singlebehaviorlab/gui/segmentation_tracking_widget.py +2752 -0
  70. singlebehaviorlab/gui/tab_tutorial_dialog.py +312 -0
  71. singlebehaviorlab/gui/timeline_themes.py +131 -0
  72. singlebehaviorlab/gui/training_profiles.py +418 -0
  73. singlebehaviorlab/gui/training_widget.py +3719 -0
  74. singlebehaviorlab/gui/video_utils.py +233 -0
  75. singlebehaviorlab/licenses/SAM2-LICENSE +201 -0
  76. singlebehaviorlab/licenses/VideoPrism-LICENSE +202 -0
  77. singlebehaviorlab-2.0.0.dist-info/METADATA +447 -0
  78. singlebehaviorlab-2.0.0.dist-info/RECORD +88 -0
  79. singlebehaviorlab-2.0.0.dist-info/WHEEL +5 -0
  80. singlebehaviorlab-2.0.0.dist-info/entry_points.txt +2 -0
  81. singlebehaviorlab-2.0.0.dist-info/licenses/LICENSE +21 -0
  82. singlebehaviorlab-2.0.0.dist-info/top_level.txt +3 -0
  83. videoprism/__init__.py +0 -0
  84. videoprism/encoders.py +910 -0
  85. videoprism/layers.py +1136 -0
  86. videoprism/models.py +407 -0
  87. videoprism/tokenizers.py +167 -0
  88. videoprism/utils.py +168 -0
@@ -0,0 +1,4550 @@
1
+ from PyQt6.QtWidgets import (
2
+ QWidget, QVBoxLayout, QHBoxLayout, QPushButton, QLabel, QLineEdit,
3
+ QFileDialog, QGroupBox, QFormLayout, QMessageBox, QListWidget, QListWidgetItem,
4
+ QSpinBox, QComboBox, QTextEdit, QScrollArea, QDialog, QCheckBox, QSizePolicy,
5
+ QProgressBar, QProgressDialog, QDoubleSpinBox, QDialogButtonBox, QApplication,
6
+ QSplitter, QGridLayout,
7
+ )
8
+ from PyQt6.QtCore import QThread, pyqtSignal, Qt, QTimer, QEvent
9
+ from PyQt6.QtGui import QPixmap, QImage, QPainter, QFont, QColor, QMouseEvent
10
+ import copy
11
+ import cv2
12
+ import os
13
+ import torch
14
+ import numpy as np
15
+ import json
16
+ from singlebehaviorlab.backend.model import VideoPrismBackbone, BehaviorClassifier
17
+ from singlebehaviorlab.backend.video_utils import get_video_info, save_clip
18
+ from singlebehaviorlab.backend.data_store import AnnotationManager
19
+ from singlebehaviorlab.backend.uncertainty import save_uncertainty_report
20
+ from .timeline_themes import TIMELINE_COLOR_THEMES, DEFAULT_THEME, get_palette as get_timeline_palette
21
+ from .inference_worker import InferenceWorker, _sanitize_bbox_coords
22
+ from .inference_popups import ClipPopupDialog, FrameSegmentPopupDialog
23
+
24
+
25
class TimelineWidget(QWidget):
    """Clickable timeline strip that maps a click's x-position to an index.

    Two hit-testing modes are supported, selected by attributes that the
    owning widget is expected to set:

    * clip mode (default): each clip spans ``clip_width`` pixels and a valid
      click calls ``click_callback(clip_idx)``;
    * frame mode (``_frame_mode`` is True and ``_pixels_per_frame`` > 0):
      a valid click calls ``click_callback(frame_idx, frame_mode=True)``.
    """

    def __init__(self, parent=None):
        super().__init__(parent)
        self.clip_width = 20  # pixels per clip in clip mode
        self.num_clips = 0
        self.click_callback = None  # callable invoked on a valid click
        self._frame_mode = False
        self._pixels_per_frame = 1.0
        self._total_frames = 0
        self.setMinimumHeight(100)
        self.setStyleSheet("background-color: white; border: 1px solid gray;")

    def mousePressEvent(self, event: QMouseEvent):
        """Translate a click into a clip/frame index and forward it to ``click_callback``.

        Nothing happens when no callback is set, when the computed index is out
        of range, or when the relevant pixel scale is not positive.
        """
        if not self.click_callback:
            return

        x = event.position().x()

        if self._frame_mode and self._pixels_per_frame > 0:
            # Frame-aggregated mode: convert pixel to frame index.
            frame_idx = int(x / self._pixels_per_frame)
            if 0 <= frame_idx < self._total_frames:
                self.click_callback(frame_idx, frame_mode=True)
        elif self.num_clips > 0 and self.clip_width > 0:
            # Clip-based mode. clip_width is checked so a zero/negative width
            # (it is a public attribute) cannot raise ZeroDivisionError here.
            clip_idx = int(x / self.clip_width)
            if 0 <= clip_idx < self.num_clips:
                self.click_callback(clip_idx)
+
58
+
59
+
60
+ class InferenceWidget(QWidget):
61
+ """Widget for running inference on videos."""
62
+
63
+ # Emitted when inference finishes so the Review tab can load results.
64
+ # Payload: (results_dict, classes, is_ovr, clip_length, target_fps)
65
+ review_ready = pyqtSignal(dict, list, bool, int, int)
66
+
67
def __init__(self, config: dict):
    """Create the inference tab and seed its post-processing settings from *config*.

    Args:
        config: Application configuration mapping. Only the ``inference_*`` keys
            read below are consulted here. A key that is missing, or present
            with a null value (e.g. an empty YAML entry), falls back to the
            default given next to it instead of raising.
    """
    super().__init__()
    self.config = config

    def _cfg(key, default):
        # Treat an explicit None like a missing key so int()/float()/dict()
        # below do not raise TypeError on a null config entry.
        value = self.config.get(key)
        return default if value is None else value

    def _cfg_mapping(key, cast):
        # Per-class dicts: keys stringified, values coerced with *cast*, matching
        # how _apply_timeline_settings_payload normalises loaded presets.
        return {str(k): cast(v) for k, v in dict(_cfg(key, {})).items()}

    self.model = None
    self.classes = []
    self.attributes = []
    self.label_mapping = None
    self.video_path = None
    self.video_paths = []  # List of selected videos
    self.predictions = []
    self.confidences = []
    self.clip_probabilities = []
    self.clip_frame_probabilities = []  # per-frame probs from FrameClassificationHead
    self.attr_predictions = []  # Store attribute indices
    self.attr_confidences = []
    self.clip_starts = []
    self.localization_bboxes = []
    self.results_cache = {}  # Cache for multi-video results
    self.worker = None
    self.exported_video_path = None
    self.corrected_labels = {}  # Map clip_idx -> corrected_label_index
    self.corrected_attr_labels = {}  # Map clip_idx -> corrected_attr_index
    self.attributes_registry = None
    self.hierarchy_registry = None
    self.aggregated_segments = []  # Frame-level top-1 segments: [{'class': int, 'start': int, 'end': int}, ...]
    self.aggregated_multiclass_segments = []  # OvR frame-level per-class segments
    self._aggregated_frame_scores_norm = None  # np.ndarray [frames, classes]
    self._aggregated_active_mask = None  # np.ndarray [frames, classes], bool
    self._aggregated_last_covered_frame = 0
    self.total_frames = 0  # Total frames in current video
    self.clip_popup_maximized = False  # Track if popup should be maximized
    self.infer_resolution = 288
    self._bbox_ema_alpha = 0.85  # Minimal smoothing: 85% current frame, 15% previous
    self._min_segment_frames = int(_cfg("inference_min_segment_frames", 1))
    self._merge_gap_frames = int(_cfg("inference_merge_gap_frames", 0))
    # The smoothing window is clamped to >= 1 and bumped to an odd length
    # (presumably so it can be centred on each frame).
    smooth_win = max(1, int(_cfg("inference_temporal_smoothing_window_frames", 1)))
    if smooth_win % 2 == 0:
        smooth_win += 1
    self._temporal_smoothing_window_frames = smooth_win
    self.use_ignore_threshold = bool(_cfg("inference_use_ignore_threshold", False))
    self.global_ignore_threshold = float(_cfg("inference_ignore_threshold", 0.60))
    self.class_ignore_thresholds = _cfg_mapping("inference_class_ignore_thresholds", float)
    self.class_min_segment_frames = _cfg_mapping("inference_class_min_segment_frames", int)
    self.class_merge_gap_frames = _cfg_mapping("inference_class_merge_gap_frames", int)
    self.class_smoothing_window_frames = _cfg_mapping("inference_class_smoothing_window_frames", int)
    self.use_viterbi_decode = bool(_cfg("inference_use_viterbi_decode", False))
    self.viterbi_switch_penalty = float(_cfg("inference_viterbi_switch_penalty", 0.35))
    self.ignore_label_name = "Ignored / Unknown"
    self.model_training_config = {}
    self._use_ovr = False
    self._allowed_cooccurrence = set()
    self._ignore_threshold_user_modified = False
    self._applying_auto_threshold = False
    self._setup_ui()
    self._update_viterbi_ui_state()
128
+
129
def _setup_ui(self):
    """Build the inference tab.

    Top to bottom the tab contains: a splitter with the model/video pickers and
    the inference parameters, the action-button row, the progress bar and label,
    the timeline group (stretch 1), and a results/logs splitter (stretch 0).
    Each region is built by a ``_setup_ui_*`` helper; the helpers are called in
    the same order the widgets were originally constructed, so creation and
    parenting order is unchanged.
    """
    layout = QVBoxLayout()

    model_video_widget = self._setup_ui_model_video_panel()
    params_group = self._setup_ui_params_group()

    top_splitter = QSplitter(Qt.Orientation.Horizontal)
    top_splitter.addWidget(model_video_widget)
    top_splitter.addWidget(params_group)
    top_splitter.setStretchFactor(0, 1)
    top_splitter.setStretchFactor(1, 1)
    layout.addWidget(top_splitter)

    layout.addLayout(self._setup_ui_action_buttons())
    self._setup_ui_progress_widgets(layout)
    layout.addWidget(self._setup_ui_timeline_group(), 1)
    layout.addWidget(self._setup_ui_results_logs(), 0)

    self.setLayout(layout)


def _setup_ui_model_video_panel(self):
    """Build the "Model" and "Video selection" groups and return their container widget."""
    model_group = QGroupBox("Model")
    model_layout = QFormLayout()

    self.model_path_edit = QLineEdit()
    self.model_path_edit.setReadOnly(True)
    self.model_browse_btn = QPushButton("Load head weights...")
    self.model_browse_btn.clicked.connect(self._load_model)
    model_path_layout = QHBoxLayout()
    model_path_layout.addWidget(self.model_path_edit)
    model_path_layout.addWidget(self.model_browse_btn)
    model_layout.addRow("Model:", model_path_layout)

    self.classes_label = QLabel("No model loaded")
    model_layout.addRow("Classes:", self.classes_label)

    model_group.setLayout(model_layout)

    video_group = QGroupBox("Video selection")
    video_layout = QFormLayout()

    self.video_path_edit = QLineEdit()
    self.video_path_edit.setReadOnly(True)
    self.video_browse_btn = QPushButton("Select video(s)...")
    self.video_browse_btn.clicked.connect(self._select_video)
    video_path_layout = QHBoxLayout()
    video_path_layout.addWidget(self.video_path_edit)
    video_path_layout.addWidget(self.video_browse_btn)
    video_layout.addRow("Video:", video_path_layout)

    self.video_info_label = QLabel("No video selected")
    video_layout.addRow("Info:", self.video_info_label)

    video_group.setLayout(video_layout)

    model_video_layout = QVBoxLayout()
    model_video_layout.addWidget(model_group)
    model_video_layout.addWidget(video_group)
    model_video_widget = QWidget()
    # Installing the layout parents both groups to model_video_widget, which is
    # returned, so nothing built here is left ownerless when the locals go away.
    model_video_widget.setLayout(model_video_layout)
    return model_video_widget


def _setup_ui_params_group(self):
    """Build and return the "Inference parameters" group (FPS, clip/step, resolution, sampling)."""
    params_group = QGroupBox("Inference parameters")
    params_layout = QFormLayout()

    self.target_fps_spin = QSpinBox()
    self.target_fps_spin.setRange(1, 60)
    self.target_fps_spin.setValue(16)
    params_layout.addRow("Target FPS:", self.target_fps_spin)

    self.clip_length_spin = QSpinBox()
    self.clip_length_spin.setRange(1, 64)
    self.clip_length_spin.setValue(16)
    params_layout.addRow("Frames per clip:", self.clip_length_spin)

    self.step_frames_spin = QSpinBox()
    self.step_frames_spin.setRange(1, 64)
    self.step_frames_spin.setValue(max(1, self.clip_length_spin.value() // 2))
    self.step_frames_spin.setToolTip(
        "Number of subsampled frames to advance between clips.\n"
        "Defaults to half of 'Frames per clip' (50% overlap) but can be changed.\n"
        "Smaller values mean more overlap and finer temporal resolution."
    )
    self.step_frames_spin.valueChanged.connect(self._on_step_or_clip_changed)
    params_layout.addRow("Step frames:", self.step_frames_spin)

    # Connected only after step_frames_spin exists, since the handler may touch it.
    self.clip_length_spin.valueChanged.connect(self._on_clip_length_changed)

    self.resolution_spin = QSpinBox()
    self.resolution_spin.setRange(64, 1024)
    self.resolution_spin.setValue(288)
    params_layout.addRow("Resolution:", self.resolution_spin)

    self.override_resolution_check = QCheckBox("Override model resolution")
    self.override_resolution_check.setToolTip("Use the resolution above instead of model metadata")
    params_layout.addRow("", self.override_resolution_check)

    self.collect_attention_check = QCheckBox("Collect attention maps")
    self.collect_attention_check.setToolTip(
        "Record spatial attention weights during inference.\n"
        "Enables exporting a heatmap video showing what the model focuses on.\n"
        "Minimal performance impact."
    )
    params_layout.addRow("", self.collect_attention_check)

    self.sample_inference_check = QCheckBox("Quick-check sampled inference")
    self.sample_inference_check.setToolTip(
        "Run inference only on evenly spread chunks of each video.\n"
        "Useful for checking model behavior on long videos without running full inference."
    )
    self.sample_inference_check.toggled.connect(self._update_sample_inference_controls)
    params_layout.addRow("", self.sample_inference_check)

    self.sample_duration_spin = QSpinBox()
    self.sample_duration_spin.setRange(10, 300)
    self.sample_duration_spin.setValue(60)
    self.sample_duration_spin.setSuffix(" s")
    self.sample_duration_spin.setToolTip("Duration of each sampled chunk per video.")
    params_layout.addRow("Chunk duration:", self.sample_duration_spin)

    self.sample_count_spin = QSpinBox()
    self.sample_count_spin.setRange(1, 50)
    self.sample_count_spin.setValue(5)
    self.sample_count_spin.setToolTip("Number of sampled chunks spread evenly across each video.")
    params_layout.addRow("Number of chunks:", self.sample_count_spin)
    # Sync the sampling controls with the (initially unchecked) toggle.
    self._update_sample_inference_controls(False)

    params_group.setLayout(params_layout)
    return params_group


def _setup_ui_action_buttons(self):
    """Build and return the row of run/stop/load/export/preview buttons."""
    button_layout = QHBoxLayout()
    self.run_inference_btn = QPushButton("Run inference")
    self.run_inference_btn.clicked.connect(self._run_inference)
    self.run_inference_btn.setEnabled(False)
    button_layout.addWidget(self.run_inference_btn)

    self.stop_inference_btn = QPushButton("Stop inference")
    self.stop_inference_btn.clicked.connect(self._stop_inference)
    self.stop_inference_btn.setEnabled(False)
    self.stop_inference_btn.setToolTip("Stop batch inference; already-completed videos are kept.")
    button_layout.addWidget(self.stop_inference_btn)

    self.load_timeline_btn = QPushButton("Load timeline results")
    self.load_timeline_btn.clicked.connect(self._load_timeline_results)
    button_layout.addWidget(self.load_timeline_btn)

    self.export_btn = QPushButton("Export video with overlays")
    self.export_btn.clicked.connect(self._export_video_with_overlay)
    self.export_btn.setEnabled(False)
    button_layout.addWidget(self.export_btn)

    self.preview_btn = QPushButton("Preview video with overlays")
    self.preview_btn.clicked.connect(self._preview_video_with_overlay)
    self.preview_btn.setEnabled(False)
    button_layout.addWidget(self.preview_btn)

    self.export_attention_btn = QPushButton("Export attention heatmap")
    self.export_attention_btn.clicked.connect(self._export_attention_heatmap)
    self.export_attention_btn.setEnabled(False)
    self.export_attention_btn.setToolTip("Export video with spatial attention overlay showing what the model focuses on")
    button_layout.addWidget(self.export_attention_btn)

    return button_layout


def _setup_ui_progress_widgets(self, layout):
    """Add the (initially hidden) progress bar and the progress label to *layout*."""
    self.progress_bar = QProgressBar()
    self.progress_bar.setVisible(False)
    self.progress_bar.setRange(0, 100)
    self.progress_bar.setValue(0)
    layout.addWidget(self.progress_bar)

    self.progress_label = QLabel("")
    layout.addWidget(self.progress_label)


def _setup_ui_timeline_group(self):
    """Build and return the "Timeline" group: controls row, timeline views and preset buttons."""
    timeline_group = QGroupBox("Timeline")
    timeline_group_layout = QVBoxLayout()

    timeline_controls_scroll = self._setup_ui_timeline_controls()
    self._setup_ui_timeline_views()

    timeline_group_layout.addWidget(timeline_controls_scroll)
    timeline_group_layout.addWidget(self.timeline_scroll, 1)
    timeline_group_layout.addWidget(self._ovr_timeline_container, 1)
    timeline_group_layout.addLayout(self._setup_ui_timeline_preset_row())
    timeline_group.setLayout(timeline_group_layout)
    timeline_group.setSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Expanding)
    return timeline_group


def _setup_ui_timeline_controls(self):
    """Build the timeline control row and return it wrapped in a horizontal scroll area."""
    timeline_controls_layout = QHBoxLayout()
    timeline_controls_layout.setSpacing(8)
    timeline_controls_layout.setContentsMargins(4, 2, 4, 2)
    timeline_label_text = QLabel("Timeline:")
    timeline_controls_layout.addWidget(timeline_label_text)

    timeline_controls_layout.addWidget(QLabel("Video:"))
    self.filter_video_combo = QComboBox()
    self.filter_video_combo.currentIndexChanged.connect(self._on_video_selection_changed)
    timeline_controls_layout.addWidget(self.filter_video_combo)

    timeline_controls_layout.addStretch()

    timeline_controls_layout.addWidget(QLabel("Show behavior:"))
    self.filter_behavior_combo = QComboBox()
    self.filter_behavior_combo.addItem("All Behaviors")
    self.filter_behavior_combo.currentIndexChanged.connect(self._on_filter_changed)
    timeline_controls_layout.addWidget(self.filter_behavior_combo)
    # Initial values are set before the change signals are connected so that
    # building the UI does not trigger the handlers.
    self.use_ignore_threshold_check = QCheckBox("Ignore low-confidence")
    self.use_ignore_threshold_check.setChecked(self.use_ignore_threshold)
    self.use_ignore_threshold_check.stateChanged.connect(self._on_ignore_threshold_changed)
    timeline_controls_layout.addWidget(self.use_ignore_threshold_check)
    self.ignore_threshold_spin = QDoubleSpinBox()
    self.ignore_threshold_spin.setDecimals(2)
    self.ignore_threshold_spin.setRange(0.0, 1.0)
    self.ignore_threshold_spin.setSingleStep(0.05)
    self.ignore_threshold_spin.setValue(self.global_ignore_threshold)
    self.ignore_threshold_spin.setToolTip(
        "Fallback threshold for classes without a per-class threshold.\n"
        "Clips below this confidence are grayed out."
    )
    self.ignore_threshold_spin.valueChanged.connect(self._on_ignore_threshold_changed)
    timeline_controls_layout.addWidget(QLabel("Default τ:"))
    timeline_controls_layout.addWidget(self.ignore_threshold_spin)
    self.per_class_thresh_btn = QPushButton("Per-class τ")
    self.per_class_thresh_btn.clicked.connect(self._open_per_class_thresholds_dialog)
    timeline_controls_layout.addWidget(self.per_class_thresh_btn)

    timeline_controls_layout.addWidget(QLabel("Theme:"))
    self.timeline_theme_combo = QComboBox()
    self.timeline_theme_combo.addItems(list(TIMELINE_COLOR_THEMES.keys()))
    self.timeline_theme_combo.setCurrentText(DEFAULT_THEME)
    self.timeline_theme_combo.setToolTip("Change timeline color theme")
    self.timeline_theme_combo.currentIndexChanged.connect(self._on_theme_changed)
    timeline_controls_layout.addWidget(self.timeline_theme_combo)

    self.merge_timeline_check = QCheckBox("Merge consecutive identical behaviors")
    self.merge_timeline_check.setToolTip("Merge consecutive clips with the same predicted behavior")
    self.merge_timeline_check.stateChanged.connect(self._on_merge_changed)
    timeline_controls_layout.addWidget(self.merge_timeline_check)

    self.frame_aggregation_check = QCheckBox("Precise frame boundaries")
    self.frame_aggregation_check.setToolTip(
        "Use overlapping clip votes to determine precise behavior boundaries.\n"
        "Best used with step_frames < clip_length (e.g., step=4 for 16-frame clips).\n"
        "Each frame gets votes from all clips that cover it, weighted by confidence."
    )
    self.frame_aggregation_check.stateChanged.connect(self._on_frame_aggregation_changed)
    timeline_controls_layout.addWidget(self.frame_aggregation_check)

    self.use_viterbi_check = QCheckBox("Viterbi decode")
    self.use_viterbi_check.setToolTip(
        "Inference-only sequence decoding on the merged frame probabilities.\n"
        "Single-label models use classic Viterbi over classes.\n"
        "OvR models use binary per-class Viterbi with co-occurrence-aware cleanup."
    )
    self.use_viterbi_check.setChecked(self.use_viterbi_decode)
    self.use_viterbi_check.stateChanged.connect(self._on_viterbi_changed)
    timeline_controls_layout.addWidget(self.use_viterbi_check)

    timeline_controls_layout.addWidget(QLabel("Viterbi switch:"))
    self.viterbi_switch_penalty_spin = QDoubleSpinBox()
    self.viterbi_switch_penalty_spin.setDecimals(2)
    self.viterbi_switch_penalty_spin.setRange(0.0, 5.0)
    self.viterbi_switch_penalty_spin.setSingleStep(0.05)
    self.viterbi_switch_penalty_spin.setValue(self.viterbi_switch_penalty)
    self.viterbi_switch_penalty_spin.setToolTip(
        "Penalty for changing behavior between adjacent frames.\n"
        "Higher values make the decoded sequence more stable."
    )
    self.viterbi_switch_penalty_spin.valueChanged.connect(self._on_viterbi_changed)
    timeline_controls_layout.addWidget(self.viterbi_switch_penalty_spin)

    self.per_class_seg_btn = QPushButton("Per-class seg rules")
    self.per_class_seg_btn.setToolTip(
        "Set smooth window, gap fill, and minimum segment length per behavior class."
    )
    self.per_class_seg_btn.clicked.connect(self._open_per_class_segment_rules_dialog)
    timeline_controls_layout.addWidget(self.per_class_seg_btn)

    self.ovr_rows_check = QCheckBox("Per-class rows")
    self.ovr_rows_check.setChecked(True)
    self.ovr_rows_check.setToolTip(
        "Per-class rows: one row per class. Uncheck for single-row timeline (OvR models only)."
    )
    # The lambda drops the stateChanged argument before redrawing.
    self.ovr_rows_check.stateChanged.connect(lambda: self._draw_timeline())
    timeline_controls_layout.addWidget(self.ovr_rows_check)

    self.ovr_show_all_check = QCheckBox("Show all classes")
    self.ovr_show_all_check.setChecked(False)
    self.ovr_show_all_check.setToolTip(
        "When enabled, every class above its threshold is shown independently\n"
        "(no mutual exclusivity). When disabled, only the top-1 class and\n"
        "allowed co-occurring classes are shown."
    )
    self.ovr_show_all_check.stateChanged.connect(self._on_ovr_show_all_changed)
    timeline_controls_layout.addWidget(self.ovr_show_all_check)

    timeline_controls_layout.addWidget(QLabel("Zoom (px/s):"))
    self.timeline_zoom_spin = QSpinBox()
    self.timeline_zoom_spin.setRange(10, 2000)
    self.timeline_zoom_spin.setValue(100)
    self.timeline_zoom_spin.setSingleStep(20)
    self.timeline_zoom_spin.setToolTip(
        "Timeline horizontal zoom: pixels per second of video.\n"
        "Increase to spread the timeline out and make short segments clickable.\n"
        "The timeline area scrolls horizontally when zoomed in."
    )
    self.timeline_zoom_spin.valueChanged.connect(self._on_timeline_zoom_changed)
    timeline_controls_layout.addWidget(self.timeline_zoom_spin)

    self.export_timeline_btn = QPushButton("Export CSV/SVG")
    self.export_timeline_btn.setToolTip("Export timeline as CSV (behavior segments) and SVG (visualization) for external analysis")
    self.export_timeline_btn.clicked.connect(self._export_timeline)
    self.export_timeline_btn.setEnabled(False)
    timeline_controls_layout.addWidget(self.export_timeline_btn)

    self.save_results_btn = QPushButton("Save results")
    self.save_results_btn.setToolTip("Save inference results (including corrections) for downstream analysis")
    self.save_results_btn.clicked.connect(self._save_results)
    self.save_results_btn.setEnabled(False)
    timeline_controls_layout.addWidget(self.save_results_btn)

    # Put controls in a horizontal scroll area so all settings stay accessible
    # on smaller screens instead of being cramped into one visible row.
    timeline_controls_widget = QWidget()
    timeline_controls_widget.setLayout(timeline_controls_layout)
    timeline_controls_widget.setMinimumHeight(40)

    timeline_controls_scroll = QScrollArea()
    timeline_controls_scroll.setWidget(timeline_controls_widget)
    timeline_controls_scroll.setWidgetResizable(True)
    timeline_controls_scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAsNeeded)
    timeline_controls_scroll.setVerticalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff)
    timeline_controls_scroll.setFrameShape(QScrollArea.Shape.NoFrame)
    timeline_controls_scroll.setMinimumHeight(48)
    timeline_controls_scroll.setMaximumHeight(62)
    return timeline_controls_scroll


def _setup_ui_timeline_views(self):
    """Create the single-row timeline scroll area and the (initially hidden) OvR multi-row container."""
    # Single-row timeline: clicks are forwarded to self._show_clip_popup.
    self.timeline_scroll = QScrollArea()
    self.timeline_scroll.setWidgetResizable(False)
    self.timeline_scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAsNeeded)
    self.timeline_scroll.setVerticalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff)
    self.timeline_scroll.setMinimumHeight(100)
    self.timeline_scroll.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)

    self.timeline_widget = TimelineWidget()
    self.timeline_widget.click_callback = self._show_clip_popup
    self.timeline_scroll.setWidget(self.timeline_widget)

    # Container for OvR multi-row: two scroll areas side by side with synced vertical scroll
    self._ovr_timeline_container = QWidget()
    self._ovr_timeline_container.setVisible(False)
    self._ovr_timeline_container.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Preferred)
    ovr_tl_layout = QHBoxLayout(self._ovr_timeline_container)
    ovr_tl_layout.setContentsMargins(0, 0, 0, 0)
    ovr_tl_layout.setSpacing(0)

    # Left: label scroll area (vertical only, scrollbar hidden, synced to timeline)
    self._ovr_label_scroll = QScrollArea()
    self._ovr_label_scroll.setFixedWidth(100)
    self._ovr_label_scroll.setWidgetResizable(False)
    self._ovr_label_scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff)
    self._ovr_label_scroll.setVerticalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff)
    self._ovr_label_scroll.setSizePolicy(QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Expanding)
    self._ovr_label_panel = QLabel()
    self._ovr_label_scroll.setWidget(self._ovr_label_panel)
    ovr_tl_layout.addWidget(self._ovr_label_scroll)

    # Right: timeline scroll area. The horizontal scrollbar is always shown; the
    # vertical one is hidden, but vertical scroll changes are still mirrored to
    # the label panel (see the valueChanged connection below).
    self._ovr_scroll = QScrollArea()
    self._ovr_scroll.setWidgetResizable(False)
    self._ovr_scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOn)
    self._ovr_scroll.setVerticalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff)
    self._ovr_scroll.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
    self._ovr_scroll.setMinimumWidth(300)
    self._ovr_timeline_widget = QLabel()
    # Clicks on this plain QLabel are routed to the OvR hit-testing handler by
    # overriding mousePressEvent on the instance.
    self._ovr_timeline_widget.mousePressEvent = self._ovr_timeline_click
    self._ovr_clip_width = 20
    self._ovr_num_clips = 0
    self._ovr_row_height = 24
    self._ovr_timeline_frame_mode = False
    self._ovr_pixels_per_frame = 1.0
    self._ovr_scroll.setWidget(self._ovr_timeline_widget)
    ovr_tl_layout.addWidget(self._ovr_scroll)
    # Events on the OvR viewport and container are passed to this widget's eventFilter.
    self._ovr_scroll.viewport().installEventFilter(self)
    self._ovr_timeline_container.installEventFilter(self)

    # Sync vertical scroll: when timeline scrolls vertically, labels follow
    self._ovr_scroll.verticalScrollBar().valueChanged.connect(
        self._ovr_label_scroll.verticalScrollBar().setValue
    )


def _setup_ui_timeline_preset_row(self):
    """Build and return the right-aligned row of save/load timeline-settings preset buttons."""
    timeline_preset_row = QHBoxLayout()
    timeline_preset_row.addStretch()
    self.save_timeline_settings_btn = QPushButton("Save timeline settings")
    self.save_timeline_settings_btn.setToolTip(
        "Save ignore filtering and timeline postprocessing/display settings as a reusable preset."
    )
    self.save_timeline_settings_btn.clicked.connect(self._save_timeline_settings_preset)
    timeline_preset_row.addWidget(self.save_timeline_settings_btn)
    self.load_timeline_settings_btn = QPushButton("Load timeline settings")
    self.load_timeline_settings_btn.setToolTip(
        "Load a previously saved timeline settings preset."
    )
    self.load_timeline_settings_btn.clicked.connect(self._load_timeline_settings_preset)
    timeline_preset_row.addWidget(self.load_timeline_settings_btn)
    return timeline_preset_row


def _setup_ui_results_logs(self):
    """Build and return the side-by-side "Results" list and "Logs" text splitter (max 320 px tall)."""
    results_group = QGroupBox("Results")
    results_layout = QVBoxLayout()
    self.results_list = QListWidget()
    results_layout.addWidget(self.results_list)
    results_group.setLayout(results_layout)

    log_group = QGroupBox("Logs")
    log_layout = QVBoxLayout()
    self.log_text = QTextEdit()
    self.log_text.setReadOnly(True)
    log_layout.addWidget(self.log_text)
    log_group.setLayout(log_layout)

    results_logs_splitter = QSplitter(Qt.Orientation.Horizontal)
    results_logs_splitter.addWidget(results_group)
    results_logs_splitter.addWidget(log_group)
    results_logs_splitter.setStretchFactor(0, 1)
    results_logs_splitter.setStretchFactor(1, 1)
    results_logs_splitter.setMaximumHeight(320)
    return results_logs_splitter
538
+
539
def _collect_timeline_settings_payload(self) -> dict:
    """Snapshot the current timeline settings as a plain dict of builtin values.

    The snapshot covers:

    * the display toggles: frame aggregation, merging, OvR rows / show-all, zoom;
    * the ignore-threshold settings, global and per-class;
    * the Viterbi decoding options;
    * the global and per-class segment post-processing rules.

    Scalars are coerced to ``bool``/``int``/``float``. Per-class mappings keep
    their keys and coerce only their values.
    """
    def optional_checkbox_state(attr_name):
        # The OvR checkboxes are looked up defensively; a missing one reads as unchecked.
        checkbox = getattr(self, attr_name, None)
        return bool(checkbox is not None and checkbox.isChecked())

    def as_float_map(mapping):
        return {key: float(val) for key, val in mapping.items()}

    def as_int_map(mapping):
        return {key: int(val) for key, val in mapping.items()}

    payload = {
        "frame_aggregation_enabled": bool(self.frame_aggregation_check.isChecked()),
        "merge_consecutive_enabled": bool(self.merge_timeline_check.isChecked()),
        "use_viterbi_decode": bool(self.use_viterbi_decode),
        "viterbi_switch_penalty": float(self.viterbi_switch_penalty),
        "use_ignore_threshold": bool(self.use_ignore_threshold),
        "ignore_threshold": float(self.global_ignore_threshold),
        "class_ignore_thresholds": as_float_map(self.class_ignore_thresholds),
        "timeline_zoom": int(self.timeline_zoom_spin.value()),
        "ovr_rows": optional_checkbox_state("ovr_rows_check"),
        "ovr_show_all": optional_checkbox_state("ovr_show_all_check"),
        "min_segment_frames": int(self._min_segment_frames),
        "merge_gap_frames": int(self._merge_gap_frames),
        "temporal_smoothing_window_frames": int(self._temporal_smoothing_window_frames),
        "class_min_segment_frames": as_int_map(self.class_min_segment_frames),
        "class_merge_gap_frames": as_int_map(self.class_merge_gap_frames),
        "class_smoothing_window_frames": as_int_map(self.class_smoothing_window_frames),
    }
    return payload
566
+
567
def _apply_timeline_settings_payload(self, payload: dict):
    """Apply a timeline-settings payload (the shape produced by
    ``_collect_timeline_settings_payload``) to this tab.

    Behaviour:
      * Non-dict payloads are ignored.
      * Missing keys keep their current values; per-class maps are only taken
        when the payload holds a JSON object for them.
      * The global smoothing window is forced to an odd value >= 1.
      * All recognised values are parsed and normalised *before* any
        attribute or widget is touched. A malformed value therefore raises
        ``ValueError``/``TypeError`` without leaving the settings half-applied.

    After applying, the inference config keys are synced, the per-class
    segment rules are re-synced and, if predictions exist, the timeline is
    recomputed (when frame aggregation is on) and redrawn.
    """
    if not isinstance(payload, dict):
        return

    def _class_map(key, cast):
        # None means "key absent or not a JSON object": keep the current map.
        raw = payload.get(key)
        if not isinstance(raw, dict):
            return None
        return {str(cls): cast(value) for cls, value in raw.items()}

    def _optional(key, cast):
        return cast(payload[key]) if key in payload else None

    # ---- Phase 1: parse and normalise (may raise; no side effects yet) ----
    use_ignore = bool(payload.get("use_ignore_threshold", self.use_ignore_threshold))
    ignore_threshold = _optional("ignore_threshold", float)
    class_thresholds = _class_map("class_ignore_thresholds", float)
    zoom = _optional("timeline_zoom", int)
    frame_aggregation = _optional("frame_aggregation_enabled", bool)
    merge_consecutive = _optional("merge_consecutive_enabled", bool)
    use_viterbi = _optional("use_viterbi_decode", bool)
    switch_penalty = _optional("viterbi_switch_penalty", float)
    ovr_rows = _optional("ovr_rows", bool)
    ovr_show_all = _optional("ovr_show_all", bool)
    min_segment = _optional("min_segment_frames", int)
    merge_gap = _optional("merge_gap_frames", int)
    smoothing = _optional("temporal_smoothing_window_frames", int)
    if smoothing is not None:
        # The smoothing window must be odd and at least 1 frame.
        if smoothing % 2 == 0:
            smoothing += 1
        smoothing = max(1, smoothing)
    class_min_segment = _class_map("class_min_segment_frames", int)
    class_merge_gap = _class_map("class_merge_gap_frames", int)
    class_smoothing = _class_map("class_smoothing_window_frames", int)

    # ---- Phase 2: apply, in the same order as the widgets were always updated ----
    self.use_ignore_threshold = use_ignore
    self.use_ignore_threshold_check.setChecked(self.use_ignore_threshold)
    if ignore_threshold is not None:
        self.global_ignore_threshold = ignore_threshold
        self.ignore_threshold_spin.setValue(self.global_ignore_threshold)
    if class_thresholds is not None:
        self.class_ignore_thresholds = class_thresholds
    if zoom is not None:
        self.timeline_zoom_spin.setValue(zoom)
    if frame_aggregation is not None:
        self.frame_aggregation_check.setChecked(frame_aggregation)
    if merge_consecutive is not None:
        self.merge_timeline_check.setChecked(merge_consecutive)
    if use_viterbi is not None:
        self.use_viterbi_decode = use_viterbi
        self.use_viterbi_check.setChecked(self.use_viterbi_decode)
    if switch_penalty is not None:
        self.viterbi_switch_penalty = switch_penalty
        self.viterbi_switch_penalty_spin.setValue(self.viterbi_switch_penalty)
    if ovr_rows is not None and getattr(self, "ovr_rows_check", None) is not None:
        self.ovr_rows_check.setChecked(ovr_rows)
    if ovr_show_all is not None and getattr(self, "ovr_show_all_check", None) is not None:
        self.ovr_show_all_check.setChecked(ovr_show_all)
    if min_segment is not None:
        self._min_segment_frames = min_segment
    if merge_gap is not None:
        self._merge_gap_frames = merge_gap
    if smoothing is not None:
        self._temporal_smoothing_window_frames = smoothing
    if class_min_segment is not None:
        self.class_min_segment_frames = class_min_segment
    if class_merge_gap is not None:
        self.class_merge_gap_frames = class_merge_gap
    if class_smoothing is not None:
        self.class_smoothing_window_frames = class_smoothing

    self.config["inference_ignore_threshold"] = self.global_ignore_threshold
    self.config["inference_class_ignore_thresholds"] = dict(self.class_ignore_thresholds)
    self.config["inference_use_viterbi_decode"] = self.use_viterbi_decode
    self.config["inference_viterbi_switch_penalty"] = self.viterbi_switch_penalty
    self._sync_per_class_segment_rule_config()
    if self.predictions:
        if self.frame_aggregation_check.isChecked():
            self._compute_aggregated_timeline()
        self._display_results()
625
+
626
def _save_timeline_settings_preset(self):
    """Ask for a destination file and save the current timeline settings as JSON.

    The file has the shape ``{"classes": [...], "timeline_settings": {...}}``,
    which ``_load_timeline_settings_preset`` reads back. The JSON text is built
    before the file is opened, so a serialisation error cannot leave a
    truncated preset on disk. Serialisation or write failures (e.g. permission
    denied) are logged and shown to the user rather than escaping the Qt slot.
    """
    # A missing/empty/None experiment path falls back to the working directory.
    start_dir = self.config.get("experiment_path") or os.getcwd()
    path, _ = QFileDialog.getSaveFileName(
        self,
        "Save Timeline Settings",
        start_dir,
        "JSON Files (*.json)"
    )
    if not path:
        return
    try:
        payload = {
            "classes": list(self.classes),
            "timeline_settings": self._collect_timeline_settings_payload(),
        }
        text = json.dumps(payload, indent=2)
        with open(path, "w", encoding="utf-8") as f:
            f.write(text)
    except (OSError, TypeError, ValueError) as e:
        self.log_text.append(f"ERROR: Failed to save timeline settings: {e}")
        QMessageBox.warning(self, "Save failed", f"Failed to save timeline settings:\n{e}")
        return
    self.log_text.append(f"Timeline settings saved: {path}")
642
+
643
def _load_timeline_settings_preset(self):
    """Ask for a saved timeline preset and apply it.

    Two file shapes are accepted:
      * a wrapper ``{"classes": [...], "timeline_settings": {...}}``, as
        written by ``_save_timeline_settings_preset``;
      * a bare settings object.

    Files that are not JSON objects, or whose ``timeline_settings`` entry is
    not an object, are rejected with a clear message. Any error while reading
    or applying is shown as a warning dialog. When the preset was saved for a
    different set of class names than the loaded model's, a note is logged:
    per-class settings are keyed by class name.
    """
    start_dir = self.config.get("experiment_path") or os.getcwd()
    path, _ = QFileDialog.getOpenFileName(
        self,
        "Load Timeline Settings",
        start_dir,
        "JSON Files (*.json)"
    )
    if not path:
        return
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        if not isinstance(data, dict):
            raise ValueError("preset file must contain a JSON object")
        payload = data.get("timeline_settings", data)
        if not isinstance(payload, dict):
            raise ValueError("'timeline_settings' must be a JSON object")
        self._apply_timeline_settings_payload(payload)
    except Exception as e:
        # UI boundary: report any read/parse/apply failure to the user.
        QMessageBox.warning(self, "Load failed", f"Failed to load timeline settings:\n{e}")
        return

    saved_classes = data.get("classes")
    current_classes = [str(c) for c in (getattr(self, "classes", None) or [])]
    if (
        current_classes
        and isinstance(saved_classes, list)
        and {str(c) for c in saved_classes} != set(current_classes)
    ):
        self.log_text.append(
            "Note: preset was saved for a different class list; "
            "per-class settings only apply to matching class names."
        )
    self.log_text.append(f"Timeline settings loaded: {path}")
660
+
661
def _load_model(self):
    """Load a trained behavior-classifier head selected by the user.

    Workflow:
      1. Ask for a head checkpoint (``*.pt``).
      2. Load the checkpoint once on CPU. This is used to
         (a) infer the VideoPrism backbone from the embedding width of known
             head tensors (768 -> base, 1024 -> large), and
         (b) read the optional top-level ``multi_scale`` /
             ``use_temporal_decoder`` flags, which override metadata.
      3. Read ``<checkpoint>.meta.json`` if present. It supplies:
         * class names;
         * attribute / hierarchy registries, reconstructed from attribute
           names or ``label_mapping`` when missing;
         * clip length, target FPS and resolution;
         * crop parameters, OvR settings and the backbone name.
         If the metadata has no resolution, ``<stem>_training_config.json``
         is consulted instead.
      4. Without metadata classes, fall back to the annotation file.
      5. Build the backbone and ``BehaviorClassifier``, then load the head
         weights.

    State handling:
      * Per-model state is reset before the new metadata is read, so a
        checkpoint without metadata never inherits a previous model's
        settings. This covers attributes, registries, label mapping,
        training config, crop parameters, OvR flag, temperatures and
        co-occurrence pairs.
      * ``self.model`` is replaced only after the head weights have loaded.
      * If loading aborts or fails before that point, any previously loaded
        model is discarded (its classes/settings may already have been
        overwritten) and the run button is disabled.
    """
    model_path, _ = QFileDialog.getOpenFileName(
        self,
        "Load Model Head",
        self.config.get("models_dir", "models/behavior_heads"),
        "PyTorch Files (*.pt);;All Files (*)"
    )

    if not model_path:
        return

    def _discard_model():
        # self.classes / attributes may already describe the new checkpoint,
        # so keeping an older model would pair it with the wrong labels.
        self.model = None
        self.run_inference_btn.setEnabled(False)

    committed = False
    try:
        self.log_text.append("Loading model...")

        # Reset per-model state so nothing leaks from a previously loaded model.
        self.attributes = []
        self.attributes_registry = None
        self.hierarchy_registry = None
        self.label_mapping = None
        self.model_training_config = {}
        self._crop_padding = 0.35
        self._crop_min_size = 0.04
        self._use_ovr = False
        self._allowed_cooccurrence = set()
        self._ovr_temperatures = {}

        # ---- Checkpoint inspection (the file is unpickled only once here) ----
        inferred_embed_dim = None
        inferred_backbone = None
        ckpt_flags = {}
        if os.path.exists(model_path):
            try:
                # NOTE(security): weights_only=False unpickles arbitrary Python
                # objects; only open checkpoints from trusted sources.
                checkpoint = torch.load(model_path, map_location="cpu", weights_only=False)
                if isinstance(checkpoint, dict):
                    # Top-level architecture flags are the most reliable source.
                    for flag in ("multi_scale", "use_temporal_decoder"):
                        if flag in checkpoint:
                            ckpt_flags[flag] = bool(checkpoint[flag])
                    # The embedding width of the first known head tensor
                    # identifies the backbone.
                    for key in (
                        "head_root.query",
                        "head_class.query",
                        "head_root.ln1.weight",
                        "head_class.ln1.weight",
                        "head_root.fc.weight",
                        "head_class.fc.weight",
                        "query",
                        "ln1.weight",
                        "fc.weight",
                    ):
                        tensor = checkpoint.get(key)
                        ndim = getattr(tensor, "ndim", None)
                        if ndim == 3:
                            inferred_embed_dim = int(tensor.shape[-1])
                        elif ndim == 2:
                            inferred_embed_dim = int(tensor.shape[1])
                        elif ndim == 1:
                            inferred_embed_dim = int(tensor.shape[0])
                        if inferred_embed_dim:
                            break
                del checkpoint
                inferred_backbone = {
                    768: "videoprism_public_v1_base",
                    1024: "videoprism_public_v1_large",
                }.get(inferred_embed_dim)
                if inferred_backbone:
                    self.log_text.append(
                        f"Inferred backbone from weights: {inferred_backbone} (embed_dim={inferred_embed_dim})"
                    )
            except Exception as ckpt_err:
                # Best effort: metadata / config still provide the backbone.
                self.log_text.append(f"Warning: Could not inspect checkpoint {model_path}: {ckpt_err}")

        # ---- Metadata sidecar (preferred source of truth) ----
        meta_classes = None
        meta_path = model_path + ".meta.json"
        metadata_backbone = None
        meta_data = {}
        meta_train = {}
        if os.path.exists(meta_path):
            try:
                with open(meta_path, "r", encoding="utf-8") as meta_file:
                    meta_data = json.load(meta_file)
                raw_train = meta_data.get("training_config", {})
                meta_train = raw_train if isinstance(raw_train, dict) else {}

                meta_classes = meta_data.get("classes")
                self.attributes = meta_data.get("attributes", []) or []
                self.attributes_registry = meta_data.get("attributes_registry", None)
                self.hierarchy_registry = meta_data.get("hierarchy_registry", None)
                self.label_mapping = meta_data.get("label_mapping", None)
                if meta_classes:
                    self.log_text.append(f"Loaded {len(meta_classes)} classes from metadata: {meta_classes}")
                if self.attributes:
                    self.log_text.append(f"Loaded {len(self.attributes)} attributes: {self.attributes}")
                if self.hierarchy_registry:
                    self.log_text.append("Loaded deep hierarchy registry")
                    if not self.attributes:
                        leaf_nodes = [
                            n for n, children in self.hierarchy_registry.items()
                            if not children and n != "__root__"
                        ]
                        self.attributes = sorted(set(leaf_nodes))
                        if self.attributes:
                            self.log_text.append(f"Derived {len(self.attributes)} leaf attributes from hierarchy")
                if self.attributes_registry:
                    self.log_text.append(f"Loaded hierarchical registry for {len(self.attributes_registry)} classes")
                elif self.attributes and meta_classes:
                    # Models trained before the registry was saved: assign each
                    # attribute to the longest class-name prefix
                    # (e.g. "Walk_Start" -> "Walk").
                    self.log_text.append("Registry missing from metadata. Attempting to reconstruct from attribute names...")
                    reconstructed = {cls: [] for cls in meta_classes}
                    for attr in self.attributes:
                        best_match = None
                        for cls in meta_classes:
                            if attr.startswith(cls + "_") or attr == cls:
                                if best_match is None or len(cls) > len(best_match):
                                    best_match = cls
                        if best_match:
                            reconstructed[best_match].append(attr)
                    # Only use the reconstruction if something matched.
                    if any(reconstructed.values()):
                        for cls in reconstructed:
                            reconstructed[cls].sort()
                        self.attributes_registry = reconstructed
                        self.log_text.append(f"Reconstructed registry: { {k: len(v) for k, v in reconstructed.items()} }")
                elif self.label_mapping and meta_classes:
                    if any("path" in v for v in self.label_mapping.values()):
                        # Build a parent -> children hierarchy from label paths.
                        self.log_text.append("Hierarchy missing from metadata. Reconstructing from label_mapping...")
                        reconstructed = {"__root__": []}
                        for _, info in self.label_mapping.items():
                            path = info.get("path", [])
                            if not path:
                                continue
                            if path[0] not in reconstructed["__root__"]:
                                reconstructed["__root__"].append(path[0])
                            for i in range(1, len(path)):
                                parent = path[i - 1]
                                child = path[i]
                                reconstructed.setdefault(parent, [])
                                if child not in reconstructed[parent]:
                                    reconstructed[parent].append(child)
                            reconstructed.setdefault(path[-1], [])
                        for node in reconstructed:
                            reconstructed[node] = sorted(set(reconstructed[node]))
                        self.hierarchy_registry = reconstructed
                        leaf_nodes = [
                            n for n, children in reconstructed.items()
                            if not children and n != "__root__"
                        ]
                        self.attributes = sorted(set(leaf_nodes))
                        self.log_text.append("Reconstructed deep hierarchy registry from label_mapping")
                    else:
                        # Build a class -> attributes registry from flat entries.
                        self.log_text.append("Registry missing from metadata. Reconstructing from label_mapping...")
                        reconstructed = {cls: [] for cls in meta_classes}
                        for _, info in self.label_mapping.items():
                            cls_name = info.get("class")
                            attr_name = info.get("attribute")
                            if cls_name in reconstructed and attr_name:
                                reconstructed[cls_name].append(attr_name)
                        if any(reconstructed.values()):
                            for cls in reconstructed:
                                reconstructed[cls] = sorted(set(reconstructed[cls]))
                            self.attributes_registry = reconstructed
                            self.attributes = sorted({a for attrs in reconstructed.values() for a in attrs})
                            self.log_text.append(f"Reconstructed registry from label_mapping: { {k: len(v) for k, v in reconstructed.items()} }")

                # Clip length / FPS / resolution from metadata.
                clip_len = meta_data.get("clip_length")
                if clip_len:
                    self.clip_length_spin.setValue(int(clip_len))
                    self.log_text.append(f"Automatically set 'Frames per clip' to {clip_len} from model metadata.")

                meta_fps = meta_data.get("target_fps") or meta_train.get("target_fps")
                if meta_fps:
                    self.target_fps_spin.setValue(int(meta_fps))
                    self.log_text.append(f"Automatically set 'Target FPS' to {meta_fps} from model metadata.")

                resolution = meta_data.get("resolution") or meta_train.get("resolution")
                if resolution:
                    self.infer_resolution = int(resolution)
                    self.resolution_spin.setValue(self.infer_resolution)
                    self.log_text.append(f"Using inference resolution {self.infer_resolution}x{self.infer_resolution} from model metadata.")
                else:
                    # Fall back to the sidecar training config for resolution.
                    config_path = os.path.splitext(model_path)[0] + "_training_config.json"
                    if os.path.exists(config_path):
                        try:
                            with open(config_path, "r", encoding="utf-8") as cfg_file:
                                cfg_data = json.load(cfg_file)
                            if isinstance(cfg_data, dict):
                                cfg_train = cfg_data.get("training_config", cfg_data)
                                if isinstance(cfg_train, dict):
                                    self.model_training_config = dict(cfg_train)
                                cfg_resolution = cfg_data.get("resolution") or (
                                    cfg_train.get("resolution") if isinstance(cfg_train, dict) else None
                                )
                                if cfg_resolution:
                                    self.infer_resolution = int(cfg_resolution)
                                    self.resolution_spin.setValue(self.infer_resolution)
                                    self.log_text.append(f"Using inference resolution {self.infer_resolution}x{self.infer_resolution} from training config.")
                        except Exception as cfg_err:
                            self.log_text.append(f"Warning: Failed to read training config {config_path}: {cfg_err}")

                # The metadata training config wins. Otherwise keep what the
                # *_training_config.json fallback above provided; previously it
                # was clobbered with {}.
                if meta_train:
                    self.model_training_config = dict(meta_train)
                train_cfg = self.model_training_config

                # Localization crop parameters.
                self._crop_padding = float(train_cfg.get("classification_crop_padding", 0.35) or 0.35)
                self._crop_min_size = float(train_cfg.get("classification_crop_min_size_norm", 0.04) or 0.04)

                # OvR: use sigmoid scoring instead of softmax at inference.
                self._use_ovr = bool(train_cfg.get("use_ovr", False))
                if self._use_ovr:
                    if getattr(self, "ovr_rows_check", None) is not None:
                        self.ovr_rows_check.setChecked(True)
                    self.log_text.append("OvR model detected: using sigmoid scoring at inference.")
                    valid_pairs = []
                    for pair in train_cfg.get("allowed_cooccurrence", []) or []:
                        if isinstance(pair, (list, tuple)) and len(pair) == 2:
                            valid_pairs.append((pair[0], pair[1]))
                            # Store both orderings for symmetric lookups.
                            self._allowed_cooccurrence.add((pair[0], pair[1]))
                            self._allowed_cooccurrence.add((pair[1], pair[0]))
                    if valid_pairs:
                        pairs_str = ", ".join(f"{a}+{b}" for a, b in valid_pairs)
                        self.log_text.append(f"Allowed co-occurrence pairs: {pairs_str}")

                    # Per-head calibrated temperatures from training.
                    ovr_temps = meta_data.get("ovr_temperatures", {})
                    if ovr_temps:
                        self._ovr_temperatures = {k: float(v) for k, v in ovr_temps.items()}
                        t_str = ", ".join(f"{k}={v}" for k, v in self._ovr_temperatures.items())
                        self.log_text.append(f"OvR calibrated temperatures: {t_str}")

                # Validate the backbone model recorded in metadata.
                head_backbone = meta_data.get("backbone_model")
                current_backbone = self.config.get("backbone_model", "videoprism_public_v1_base")

                # Prefer the backbone inferred from weights if metadata disagrees.
                if inferred_backbone and head_backbone and head_backbone != inferred_backbone:
                    self.log_text.append(
                        f"WARNING: Metadata backbone '{head_backbone}' does not match weights. "
                        f"Using inferred backbone '{inferred_backbone}'."
                    )
                    head_backbone = inferred_backbone

                if head_backbone and head_backbone != current_backbone:
                    msg = (
                        f"Backbone Model Mismatch (auto-corrected for inference).\n"
                        f"Head backbone: '{head_backbone}', current config: '{current_backbone}'.\n"
                        f"Using '{head_backbone}' for inference."
                    )
                    self.log_text.append(f"WARNING: {msg}")

                metadata_backbone = head_backbone

            except Exception as meta_err:
                self.log_text.append(f"Warning: Failed to read metadata file {meta_path}: {meta_err}")

        # Refresh Viterbi controls for the (possibly reset) OvR flag.
        self._update_viterbi_ui_state()

        # ---- Class names ----
        if meta_classes:
            self.classes = meta_classes
        else:
            annotation_file = self.config.get("annotation_file", "data/annotations/annotations.json")
            if not os.path.exists(annotation_file):
                _discard_model()
                QMessageBox.warning(self, "Warning", "Annotation file not found. Cannot determine classes.")
                return
            self.classes = AnnotationManager(annotation_file).get_classes()

        if not self.classes:
            _discard_model()
            QMessageBox.warning(self, "Error", "No classes found for this model.")
            return

        self.log_text.append(f"Using {len(self.classes)} classes: {self.classes}")

        # ---- Backbone ----
        model_name = (
            inferred_backbone
            or metadata_backbone
            or self.config.get("backbone_model", "videoprism_public_v1_base")
        )
        self.log_text.append(f"Loading VideoPrism backbone ({model_name}) at resolution {self.infer_resolution}×{self.infer_resolution}...")

        backbone = VideoPrismBackbone(model_name=model_name, resolution=self.infer_resolution)

        # ---- Head architecture: defaults, then metadata, then checkpoint flags ----
        head_kwargs = {"num_heads": 4}
        dropout = 0.1
        use_localization = False
        use_frame_head = True
        frame_head_temporal_layers = 1
        temporal_pool_frames = 1
        num_stages = 3
        proj_dim = 256
        localization_hidden_dim = 256
        multi_scale = False
        use_temporal_decoder = True

        head_cfg = meta_data.get("head", {}) if isinstance(meta_data, dict) else {}
        if isinstance(head_cfg, dict):
            try:
                if "num_heads" in head_cfg:
                    head_kwargs["num_heads"] = head_cfg.get("num_heads", head_kwargs["num_heads"])
                if "dropout" in head_cfg:
                    dropout = head_cfg.get("dropout", dropout)
                use_localization = bool(
                    head_cfg.get("use_localization", False)
                    or meta_train.get("use_localization", False)
                )
                localization_hidden_dim = int(head_cfg.get("localization_hidden_dim", 256))
                use_temporal_decoder = bool(
                    head_cfg.get(
                        "use_temporal_decoder",
                        meta_train.get("use_temporal_decoder", True),
                    )
                )
                frame_head_temporal_layers = int(head_cfg.get("frame_head_temporal_layers", 1))
                temporal_pool_frames = int(head_cfg.get("temporal_pool_frames", 1))
                num_stages = int(head_cfg.get("num_stages", 3))
                proj_dim = int(head_cfg.get("proj_dim", 256))
                multi_scale = bool(head_cfg.get("multi_scale", False))
                # Non-positive values fall back to the metadata training config.
                if frame_head_temporal_layers <= 0:
                    frame_head_temporal_layers = int(meta_train.get("frame_head_temporal_layers", 1))
                if temporal_pool_frames <= 0:
                    temporal_pool_frames = int(meta_train.get("temporal_pool_frames", 1))
                if num_stages <= 0:
                    num_stages = int(meta_train.get("num_stages", 3))
            except (TypeError, ValueError) as head_err:
                self.log_text.append(
                    f"Warning: Invalid head config in metadata ({head_err}); using defaults for the remaining fields."
                )

        # Checkpoint flags override metadata. Apply them before logging so the
        # log reflects the architecture that is actually built.
        if "multi_scale" in ckpt_flags:
            multi_scale = ckpt_flags["multi_scale"]
        if "use_temporal_decoder" in ckpt_flags:
            use_temporal_decoder = ckpt_flags["use_temporal_decoder"]

        if frame_head_temporal_layers <= 0:
            frame_head_temporal_layers = 1
        if use_temporal_decoder:
            self.log_text.append(f"Frame head: temporal decoder on, layers={frame_head_temporal_layers}, proj_dim={proj_dim}")
        else:
            self.log_text.append(f"Frame head: direct per-frame classifier, proj_dim={proj_dim}")

        # ---- Build the classifier and load head weights (into a local first) ----
        self.log_text.append(f"Creating classifier (kwargs={head_kwargs}, multi_scale={multi_scale})...")
        model = BehaviorClassifier(
            backbone,
            num_classes=len(self.classes),
            class_names=self.classes,
            dropout=dropout,
            freeze_backbone=True,
            head_kwargs=head_kwargs,
            use_localization=use_localization,
            localization_hidden_dim=localization_hidden_dim,
            use_frame_head=use_frame_head,
            use_temporal_decoder=use_temporal_decoder,
            frame_head_temporal_layers=frame_head_temporal_layers,
            temporal_pool_frames=temporal_pool_frames,
            proj_dim=proj_dim,
            num_stages=num_stages,
            multi_scale=multi_scale,
        )

        self.log_text.append(f"Loading head weights from {model_path}...")
        if self._use_ovr and getattr(model, "frame_head", None) is not None:
            model.frame_head.use_ovr = True
        try:
            model.load_head(model_path)
        except RuntimeError as e:
            err_str = str(e)
            if "size mismatch" not in err_str:
                raise
            msg = (
                f"Model architecture mismatch.\n"
                f"Error Details: {err_str}\n\n"
                "Please ensure the weights file matches the current architecture."
            )
            self.log_text.append(f"ERROR: {msg}")
            _discard_model()
            QMessageBox.critical(self, "Model Load Error", msg)
            return

        # Publish the model only once its weights are in place.
        self.model = model
        committed = True

        if multi_scale:
            self.log_text.append(
                "Multi-scale mode active: inference runs backbone twice per clip "
                "(full fps + half fps). Expect ~2x backbone time per clip."
            )
        self.model_path_edit.setText(model_path)
        self.classes_label.setText(", ".join(self.classes))

        # Populate filter combo without triggering filter callbacks.
        self.filter_behavior_combo.blockSignals(True)
        self.filter_behavior_combo.clear()
        self.filter_behavior_combo.addItem("All Behaviors")
        self.filter_behavior_combo.addItem(self.ignore_label_name)
        self.filter_behavior_combo.addItems(self.classes)
        self.filter_behavior_combo.blockSignals(False)

        self.log_text.append("Model loaded successfully!")
        QMessageBox.information(self, "Success", "Model loaded successfully!")

        # Offer autosave recovery if a crash interrupted a previous batch run.
        self._check_autosave_recovery(model_path)

        if self.video_path:
            self.run_inference_btn.setEnabled(True)

    except Exception as e:
        # Never leave a half-built classifier (or a stale one paired with new
        # classes) usable for inference.
        if not committed:
            _discard_model()
        error_msg = f"Failed to load model: {str(e)}"
        self.log_text.append(f"ERROR: {error_msg}")
        QMessageBox.critical(self, "Error", error_msg)
1067
+
1068
def _select_video(self):
    """Let the user pick one or more videos for inference.

    The chosen files are passed to ``ensure_videos_in_experiment``, per the
    original comment "ensure videos are in experiment folder"; its return
    value is stored as ``self.video_paths``. ``self.video_path`` (used for the
    preview/info line and as the single-video fallback) is now taken from
    that same resolved list, so both refer to the same files. If the helper
    returns nothing, the raw selection is used, as before.

    With a single video, its FPS / frame count / size are shown. With several,
    a count is shown. The run button is enabled when a model is loaded.
    """
    video_paths, _ = QFileDialog.getOpenFileNames(
        self,
        "Select Video File(s)",
        self.config.get("raw_videos_dir", self.config.get("data_dir", "data/raw_videos")),
        "Video Files (*.mp4 *.avi *.mov *.mkv);;All Files (*)"
    )
    if not video_paths:
        return

    # Ensure videos are in experiment folder (batch operation)
    from .video_utils import ensure_videos_in_experiment
    self.video_paths = ensure_videos_in_experiment(video_paths, self.config, self)
    # Keep video_path consistent with the list batch inference will process.
    selected = list(self.video_paths) if self.video_paths else list(video_paths)

    self.video_path = selected[0]
    if len(selected) == 1:
        self.video_path_edit.setText(self.video_path)
        info = get_video_info(self.video_path)
        if info:
            info_text = f"FPS: {info['fps']:.2f}, Frames: {info['frame_count']}, Size: {info['width']}x{info['height']}"
            self.video_info_label.setText(info_text)
    else:
        # First video doubles as the preview default.
        self.video_path_edit.setText(f"{len(selected)} videos selected")
        self.video_info_label.setText(f"Selected {len(selected)} videos")

    if self.model:
        self.run_inference_btn.setEnabled(True)
1096
+
1097
def _update_sample_inference_controls(self, enabled: bool):
    """Toggle the sample-duration and sample-count spin boxes together.

    ``enabled`` is coerced with ``bool`` before being passed to ``setEnabled``.
    """
    flag = bool(enabled)
    for spin in (self.sample_duration_spin, self.sample_count_spin):
        spin.setEnabled(flag)
1100
+
1101
def _compute_sample_ranges_for_video(self, video_path: str):
    """Build evenly spread sample windows ``[(start_frame, end_frame), ...]`` for one video.

    Window length:
      * ``sample_duration_spin`` seconds, converted to frames with the
        video's FPS;
      * 30 FPS is assumed when the FPS is reported as non-positive or
        non-finite;
      * clamped to ``[1, total_frames]``.

    Placement:
      * ``sample_count_spin`` windows are spread evenly from the first to the
        last possible start frame;
      * a single window, or a video no longer than one window, is centred;
      * each end is ``start + length``, clipped to ``total_frames``.

    When the video is too short for that many distinct windows, repeated
    start frames are collapsed so the same frames are not sampled twice.

    Returns ``[]`` if the path is empty, the video cannot be opened, or the
    frame count is unknown/non-positive.
    """
    import math

    if not video_path:
        return []
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            return []
        fps = cap.get(cv2.CAP_PROP_FPS)
        frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT)
    finally:
        cap.release()

    if not math.isfinite(fps) or fps <= 0:
        fps = 30.0
    total_frames = int(frame_count) if math.isfinite(frame_count) else 0
    if total_frames <= 0:
        return []

    dur_frames = int(round(self.sample_duration_spin.value() * fps))
    dur_frames = max(1, min(dur_frames, total_frames))
    n = max(1, int(self.sample_count_spin.value()))
    usable = max(0, total_frames - dur_frames)
    if n == 1 or usable == 0:
        starts = [usable // 2]
    else:
        # With usable < n - 1, rounding maps several indices to the same
        # frame; dedupe so identical windows are not inferred repeatedly.
        starts = sorted({int(round(i * usable / (n - 1))) for i in range(n)})
    return [(start, min(start + dur_frames, total_frames)) for start in starts]
1125
+
1126
def _run_inference(self):
    """Start background inference on the selected video(s).

    Steps:
      1. Require a loaded model and a selected video.
      2. Reset the progress/log UI and report the inference resolution,
         honouring the override checkbox.
      3. When quick-check sampling is on, build per-video sample ranges;
         videos for which no ranges can be built fall back to full inference.
      4. Launch an ``InferenceWorker`` whose signals feed the progress,
         completion, error, log and per-video handlers.

    If building or starting the worker raises, the controls are restored and
    the error is reported. Previously the tab was left with its run/load
    buttons disabled.
    """
    if not self.model:
        QMessageBox.warning(self, "Error", "Please load a model first.")
        return

    if not self.video_path:
        QMessageBox.warning(self, "Error", "Please select a video first.")
        return

    self.run_inference_btn.setEnabled(False)
    self.load_timeline_btn.setEnabled(False)
    self.progress_bar.setVisible(True)
    self.progress_bar.setValue(0)
    self.progress_label.setText("Running inference...")
    self.results_list.clear()
    self.log_text.clear()
    self.log_text.append("Starting inference...")
    if self.override_resolution_check.isChecked():
        self.infer_resolution = int(self.resolution_spin.value())
        self.log_text.append(f"Inference resolution (override): {self.infer_resolution}x{self.infer_resolution}")
    else:
        self.log_text.append(f"Inference resolution: {self.infer_resolution}x{self.infer_resolution}")

    # video_path is non-empty here, so it is a valid single-video fallback.
    video_paths = list(self.video_paths) if self.video_paths else [self.video_path]

    try:
        sample_ranges_by_video = {}
        if self.sample_inference_check.isChecked():
            for v_path in video_paths:
                sample_ranges = self._compute_sample_ranges_for_video(v_path)
                if not sample_ranges:
                    self.log_text.append(
                        f"Warning: could not build sample ranges for {os.path.basename(v_path)}. "
                        "Falling back to full video."
                    )
                    continue
                sample_ranges_by_video[v_path] = sample_ranges
                cap = cv2.VideoCapture(v_path)
                try:
                    fps = cap.get(cv2.CAP_PROP_FPS)
                finally:
                    cap.release()
                if not (fps > 0):  # also rejects NaN
                    fps = 30.0
                summary = ", ".join(
                    f"{start / fps:.1f}s-{end / fps:.1f}s" for start, end in sample_ranges[:6]
                )
                if len(sample_ranges) > 6:
                    summary += ", ..."
                self.log_text.append(
                    f"Quick-check sampled inference for {os.path.basename(v_path)}: {summary}"
                )

        self.worker = InferenceWorker(
            self.model,
            video_paths,
            self.target_fps_spin.value(),
            self.clip_length_spin.value(),
            self.step_frames_spin.value(),
            resolution=self.infer_resolution,
            classes=self.classes,
            use_localization_pipeline=getattr(self.model, "use_localization", False),
            crop_padding=getattr(self, "_crop_padding", 0.35),
            crop_min_size=getattr(self, "_crop_min_size", 0.04),
            use_ovr=getattr(self, "_use_ovr", False),
            ovr_temperatures=getattr(self, "_ovr_temperatures", {}),
            collect_attention=self.collect_attention_check.isChecked(),
            sample_ranges_by_video=sample_ranges_by_video,
        )
        self.worker.progress.connect(self._on_progress)
        self.worker.finished.connect(self._on_finished)
        self.worker.error.connect(self._on_error)
        self.worker.log_message.connect(self._on_log)
        self.worker.video_done.connect(self._on_video_done)
        self.stop_inference_btn.setEnabled(True)
        self.worker.start()
    except Exception as exc:
        # Restore the idle UI so the user can retry instead of being stuck.
        error_msg = f"Failed to start inference: {exc}"
        self.log_text.append(f"ERROR: {error_msg}")
        self.progress_bar.setVisible(False)
        self.progress_label.setText(error_msg)
        self.stop_inference_btn.setEnabled(False)
        self.run_inference_btn.setEnabled(True)
        self.load_timeline_btn.setEnabled(True)
        QMessageBox.critical(self, "Error", error_msg)
1202
+
1203
def _on_progress(self, current: int, total: int):
    """Show worker progress as a percentage that never exceeds 100.

    Updates are skipped entirely when ``total`` is not positive.
    """
    if total <= 0:
        return
    percent = int(100 * current / total)
    if percent > 100:
        percent = 100
    self.progress_bar.setValue(percent)
    self.progress_label.setText(f"Processing: {current}/{total} items ({percent}%)")
1209
+
1210
def _on_finished(self, results: dict):
    """Handle completion (or stop) of the inference worker.

    ``results`` maps each completed video path to its result dict. The keys
    read here are ``predictions`` and, optionally, ``clip_frame_probabilities``
    and ``clip_attention_maps``.

    Always:
      * cache the results and add empty correction stores;
      * refill the video filter and restore the controls;
      * switch on precise frame boundaries when frame-head outputs exist.

    Only when at least one video finished:
      * save results and remove the autosave file;
      * write the uncertainty report;
      * notify the Review tab;
      * show the completion dialog.

    A stop with no completed videos shows only the "Stopped" status. The OvR
    flag is read with ``getattr`` because ``_use_ovr`` may never have been set
    for a model loaded without metadata.
    """
    self.results_cache = results
    self._ignore_threshold_user_modified = False
    use_ovr = bool(getattr(self, "_use_ovr", False))

    # Initialize corrections storage for each video if not present.
    for res in self.results_cache.values():
        res.setdefault("corrected_labels", {})
        res.setdefault("corrected_attr_labels", {})

    ordered_paths = sorted(results.keys())

    # Populate video dropdown without firing selection callbacks per item.
    self.filter_video_combo.blockSignals(True)
    self.filter_video_combo.clear()
    for path in ordered_paths:
        self.filter_video_combo.addItem(os.path.basename(path), path)
    self.filter_video_combo.blockSignals(False)

    if results:
        has_any_frame_probs = any(
            isinstance(res.get("clip_frame_probabilities"), list) and len(res["clip_frame_probabilities"]) > 0
            for res in results.values()
        )
        if has_any_frame_probs and not self.frame_aggregation_check.isChecked():
            # Automatically switch to precise frame-boundary mode when
            # frame-head outputs are available.
            self.frame_aggregation_check.setChecked(True)
            self.log_text.append("Detected frame-head predictions. Enabled 'Precise frame boundaries' mode.")
        # Select the first video.
        idx = self.filter_video_combo.findData(ordered_paths[0])
        if idx >= 0:
            self.filter_video_combo.setCurrentIndex(idx)
            self._on_video_selection_changed(idx)

    total_clips = sum(len(res.get("predictions") or []) for res in results.values())
    self.progress_bar.setValue(100)
    self.progress_bar.setVisible(False)
    if results:
        self.progress_label.setText(f"Inference complete! Processed {len(results)} videos, {total_clips} clips")
    else:
        self.progress_label.setText("Stopped. No videos were completed.")
    self.run_inference_btn.setEnabled(True)
    self.stop_inference_btn.setEnabled(False)
    self.load_timeline_btn.setEnabled(True)

    self.export_btn.setEnabled(bool(results))
    self.export_timeline_btn.setEnabled(bool(results))
    self.save_results_btn.setEnabled(bool(results))
    self.export_attention_btn.setEnabled(
        any("clip_attention_maps" in res for res in results.values())
    )

    if not results:
        return

    self._save_results(silent=True)
    # The run completed normally, so the crash-recovery autosave is obsolete.
    autosave_path = self._autosave_path()
    if autosave_path and os.path.exists(autosave_path):
        try:
            os.remove(autosave_path)
        except OSError as rm_err:
            self.log_text.append(f"Warning: could not remove autosave file {autosave_path}: {rm_err}")

    # Save uncertainty report and notify the Review tab.
    try:
        model_path = self.model_path_edit.text()
        if model_path:
            uncertainty_path = os.path.join(
                os.path.dirname(model_path),
                os.path.splitext(os.path.basename(model_path))[0] + "_uncertainty.json",
            )
        else:
            exp_path = self.config.get("experiment_path", "")
            uncertainty_path = os.path.join(exp_path, "results", "_uncertainty.json")
        save_uncertainty_report(
            results=results,
            classes=self.classes,
            output_path=uncertainty_path,
            is_ovr=use_ovr,
            n_per_class=25,
            clip_length=self.clip_length_spin.value(),
            target_fps=self.target_fps_spin.value(),
        )
        self.log_text.append(f"Uncertainty report saved to: {uncertainty_path}")
    except Exception as _ue:
        self.log_text.append(f"Warning: could not save uncertainty report: {_ue}")

    self.review_ready.emit(
        results,
        list(self.classes),
        use_ovr,
        self.clip_length_spin.value(),
        self.target_fps_spin.value(),
    )

    QMessageBox.information(self, "Success", f"Inference complete! Processed {len(results)} videos.")
1307
+
1308
def _autosave_path(self) -> str:
    """Build the incremental-autosave JSON path for the current model.

    The file sits in the same directory as the model entered in
    ``model_path_edit``. When no model path is set, an empty string is
    returned; callers in this class treat a falsy path as "no autosave".
    """
    model_file = self.model_path_edit.text()
    if model_file:
        model_dir = os.path.dirname(model_file)
        return os.path.join(model_dir, "_inference_autosave.json")
    return ""
1314
+
1315
def _on_video_done(self, video_path: str, res_entry: dict):
    """Register one finished video and incrementally autosave all results.

    Called after each video of a batch run finishes.

    Args:
        video_path: Path of the finished video. It is the key in
            ``self.results_cache`` and the item data in ``filter_video_combo``.
        res_entry: Result dict for that video. It is mutated in place to
            carry empty ``corrected_labels`` / ``corrected_attr_labels``
            dicts when they are missing.

    Side effects:
        - Enables the export and save buttons.
        - Adds the video to the video filter combo and sorts the combo
          model.
        - Reloads the display when this video is, or becomes, the selected
          one.
        - Writes every result collected so far to the file returned by
          ``_autosave_path()``. The parameter set matches what
          ``_save_results`` writes, so a recovered autosave restores the
          same settings. Errors while building or writing the autosave are
          logged rather than raised.
    """
    # Initialise correction fields so review edits have somewhere to go.
    res_entry.setdefault("corrected_labels", {})
    res_entry.setdefault("corrected_attr_labels", {})
    self.results_cache[video_path] = res_entry
    self.export_btn.setEnabled(True)
    self.export_timeline_btn.setEnabled(True)
    self.save_results_btn.setEnabled(True)
    if "clip_attention_maps" in res_entry:
        self.export_attention_btn.setEnabled(True)

    combo = self.filter_video_combo
    # Signals are blocked so adding/sorting items does not trigger a reload.
    combo.blockSignals(True)
    if combo.findData(video_path) < 0:
        combo.addItem(os.path.basename(video_path), video_path)
        combo_model = combo.model()
        if combo_model is not None:
            combo_model.sort(0)
    combo.blockSignals(False)

    current_video = combo.currentData()
    if not current_video:
        # Nothing selected yet: show this finished video right away.
        idx = combo.findData(video_path)
        if idx >= 0:
            combo.setCurrentIndex(idx)
            self._on_video_selection_changed(idx)
    elif current_video == video_path:
        # The finished video is the selected one: reload it from the cache.
        self._on_video_selection_changed(combo.currentIndex())

    finished_count = len(self.results_cache)
    clip_count = len(res_entry.get("predictions", []))
    self.log_text.append(
        f"Finished {os.path.basename(video_path)}: {clip_count} clips ready to review "
        f"({finished_count} video{'s' if finished_count != 1 else ''} available so far)."
    )

    autosave_path = self._autosave_path()
    if not autosave_path:
        return
    try:
        save_data = {
            "classes": self.classes,
            "parameters": {
                "target_fps": self.target_fps_spin.value(),
                "clip_length": self.clip_length_spin.value(),
                "step_frames": self.step_frames_spin.value(),
                "sample_inference_enabled": bool(self.sample_inference_check.isChecked()),
                "sample_duration_seconds": int(self.sample_duration_spin.value()),
                "sample_num_chunks": int(self.sample_count_spin.value()),
                "frame_aggregation_enabled": self.frame_aggregation_check.isChecked(),
                "merge_consecutive_enabled": bool(self.merge_timeline_check.isChecked()),
                "use_viterbi_decode": bool(self.use_viterbi_decode),
                "viterbi_switch_penalty": float(self.viterbi_switch_penalty),
                "min_segment_frames": int(self._min_segment_frames),
                "merge_gap_frames": int(self._merge_gap_frames),
                "temporal_smoothing_window_frames": int(self._temporal_smoothing_window_frames),
                "class_min_segment_frames": {
                    cls: int(v) for cls, v in self.class_min_segment_frames.items()
                },
                "class_merge_gap_frames": {
                    cls: int(v) for cls, v in self.class_merge_gap_frames.items()
                },
                "class_smoothing_window_frames": {
                    cls: int(v) for cls, v in self.class_smoothing_window_frames.items()
                },
                "use_ovr": bool(self._use_ovr),
                "ovr_rows": bool(getattr(self, "ovr_rows_check", None) is not None and self.ovr_rows_check.isChecked()),
                "ovr_show_all": bool(getattr(self, "ovr_show_all_check", None) is not None and self.ovr_show_all_check.isChecked()),
                "use_ignore_threshold": bool(self.use_ignore_threshold),
                "ignore_threshold": float(self.global_ignore_threshold),
                "class_ignore_thresholds": {
                    cls: float(t) for cls, t in self.class_ignore_thresholds.items()
                },
                # _load_timeline_results rebuilds the symmetric set from this
                # list, so keep one ordered pair per co-occurrence.
                "allowed_cooccurrence": [
                    [a, b]
                    for (a, b) in sorted(getattr(self, "_allowed_cooccurrence", ()))
                    if a <= b
                ],
                "timeline_zoom": int(self.timeline_zoom_spin.value()),
            },
            "results": self.results_cache,
            "_autosave": True,
        }
        self._write_results_bundle(autosave_path, save_data, pretty=False)
        self.log_text.append(
            f" [Autosave] {len(self.results_cache)} video(s) saved → {os.path.basename(autosave_path)}"
        )
    except Exception as e:
        self.log_text.append(f" [Autosave] Warning: could not write autosave: {e}")
1401
+
1402
def _check_autosave_recovery(self, model_path: str):
    """Offer to restore results from an interrupted batch run.

    The method looks for ``_inference_autosave.json`` in the directory of
    *model_path*. If the file contains at least one video result, the user
    is asked whether to restore it. On "Yes" the file is loaded through
    ``_load_timeline_results``. Read or parse errors are logged and never
    raised.

    Args:
        model_path: Path of the model that was just loaded. An empty path
            disables the check instead of probing the working directory.
    """
    import json

    if not model_path:
        return
    autosave_path = os.path.join(os.path.dirname(model_path), "_inference_autosave.json")
    if not os.path.exists(autosave_path):
        return
    try:
        # Read with an explicit UTF-8 encoding (as _load_timeline_results
        # does) instead of the locale-dependent platform default.
        with open(autosave_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        results = data.get("results") if isinstance(data, dict) else None
        n_videos = len(results) if isinstance(results, dict) else 0
        if n_videos == 0:
            return
        reply = QMessageBox.question(
            self,
            "Recover interrupted batch run?",
            f"Found an autosave from a previous session with {n_videos} video(s) already processed.\n\n"
            f"Would you like to restore those results now?\n\n"
            f"(Autosave file: {os.path.basename(autosave_path)})",
            QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No,
        )
        if reply == QMessageBox.StandardButton.Yes:
            self._load_timeline_results(path=autosave_path)
            self.log_text.append(
                f"Recovered {n_videos} video(s) from autosave. "
                "You can continue running inference on remaining videos."
            )
    except Exception as e:
        self.log_text.append(f"[Autosave] Could not read autosave file: {e}")
1430
+
1431
def _save_results(self, silent=False):
    """Save the cached inference results for all videos to a JSON bundle.

    The default destination is ``<experiment_path>/results/inference_results.json``.
    When no experiment path is configured, ``<data_dir>/results/`` is used
    instead. For a manual save (``silent=False``) of results that were loaded
    from a file, the user chooses Overwrite, Save As... or Cancel.

    When frame aggregation ("Precise frame boundaries") is enabled, each
    cached video's aggregated segments are recomputed and stored in its
    cache entry before writing. This temporarily swaps that video's data
    onto ``self``. Every swapped attribute, including the cached
    frame-score arrays, is restored afterwards, so the video currently on
    screen keeps its own state.

    Args:
        silent: When True, no dialogs are shown and failures are only
            logged.
    """
    if not self.results_cache:
        return

    try:
        # Determine default save path based on experiment config
        exp_path = self.config.get("experiment_path")
        if exp_path:
            results_dir = os.path.join(exp_path, "results")
        else:
            results_dir = os.path.join(self.config.get("data_dir", "data"), "results")
        os.makedirs(results_dir, exist_ok=True)
        results_path = os.path.join(results_dir, "inference_results.json")

        # Manual save of results that came from a file: ask where to write.
        loaded_path = getattr(self, "loaded_results_path", None)
        if not silent and loaded_path:
            msg = QMessageBox()
            msg.setIcon(QMessageBox.Icon.Question)
            msg.setText("Save timeline results")
            msg.setInformativeText(f"Results loaded from:\n{loaded_path}\n\nDo you want to overwrite or save as new?")
            msg.setWindowTitle("Save results")
            overwrite_btn = msg.addButton("Overwrite", QMessageBox.ButtonRole.AcceptRole)
            save_as_btn = msg.addButton("Save As...", QMessageBox.ButtonRole.ActionRole)
            msg.addButton("Cancel", QMessageBox.ButtonRole.RejectRole)
            msg.exec()

            clicked = msg.clickedButton()
            if clicked == overwrite_btn:
                results_path = loaded_path
            elif clicked == save_as_btn:
                path, _ = QFileDialog.getSaveFileName(
                    self, "Save Results As", loaded_path, "JSON Files (*.json)"
                )
                if not path:
                    return
                results_path = path
            else:
                return

        # Persist current video state into cache so it is saved
        self._persist_current_video_state()

        frame_aggregation_enabled = self.frame_aggregation_check.isChecked()
        if frame_aggregation_enabled:
            # Every attribute the loop below overwrites, including the
            # frame-score caches left behind by _compute_aggregated_timeline /
            # _build_timeline_segments.
            swapped_attrs = (
                "video_path", "predictions", "confidences", "clip_probabilities",
                "clip_starts", "clip_frame_probabilities", "total_frames",
                "corrected_labels", "aggregated_segments",
                "aggregated_multiclass_segments",
                "_aggregated_frame_scores_norm", "_aggregated_last_covered_frame",
                "_aggregated_active_mask",
            )
            snapshot = {name: getattr(self, name) for name in swapped_attrs if hasattr(self, name)}
            snapshot.setdefault("clip_frame_probabilities", [])
            # NOTE(review): segments are recomputed with the current
            # thresholds, not each entry's stored "threshold_settings"
            # (which _load_video_from_cache applies) — confirm intended.
            try:
                for v_path, entry in self.results_cache.items():
                    if not isinstance(entry, dict):
                        continue
                    preds = entry.get("predictions", [])
                    starts = entry.get("clip_starts", [])
                    if not preds or not starts:
                        continue
                    self.video_path = v_path
                    self.predictions = preds
                    self.confidences = entry.get("confidences", [])
                    self.clip_probabilities = entry.get("clip_probabilities", [])
                    self.clip_starts = starts
                    # Always use this video's frame-head outputs, not the
                    # previous iteration's / displayed video's.
                    self.clip_frame_probabilities = entry.get("clip_frame_probabilities", [])
                    self.total_frames = int(entry.get("total_frames", 0) or 0)
                    self.corrected_labels = entry.get("corrected_labels", {})
                    precomputed = entry.get("aggregated_frame_probs")
                    if isinstance(precomputed, list):
                        precomputed = np.asarray(precomputed, dtype=np.float32)
                    if isinstance(precomputed, np.ndarray) and precomputed.ndim == 2 and precomputed.shape[0] > 0:
                        self._aggregated_frame_scores_norm = precomputed
                        self._aggregated_last_covered_frame = precomputed.shape[0]
                        self._build_timeline_segments()
                    else:
                        self.aggregated_segments = []
                        self.aggregated_multiclass_segments = []
                        self._compute_aggregated_timeline()
                    entry["aggregated_segments"] = copy.deepcopy(self.aggregated_segments)
                    entry["aggregated_multiclass_segments"] = copy.deepcopy(self.aggregated_multiclass_segments)
            finally:
                for name, value in snapshot.items():
                    setattr(self, name, value)

        save_data = {
            "classes": self.classes,
            "parameters": {
                "target_fps": self.target_fps_spin.value(),
                "clip_length": self.clip_length_spin.value(),
                "step_frames": self.step_frames_spin.value(),
                "sample_inference_enabled": bool(self.sample_inference_check.isChecked()),
                "sample_duration_seconds": int(self.sample_duration_spin.value()),
                "sample_num_chunks": int(self.sample_count_spin.value()),
                "frame_aggregation_enabled": frame_aggregation_enabled,
                "merge_consecutive_enabled": bool(self.merge_timeline_check.isChecked()),
                "use_viterbi_decode": bool(self.use_viterbi_decode),
                "viterbi_switch_penalty": float(self.viterbi_switch_penalty),
                "min_segment_frames": int(self._min_segment_frames),
                "merge_gap_frames": int(self._merge_gap_frames),
                "temporal_smoothing_window_frames": int(self._temporal_smoothing_window_frames),
                "class_min_segment_frames": {
                    cls: int(v) for cls, v in self.class_min_segment_frames.items()
                },
                "class_merge_gap_frames": {
                    cls: int(v) for cls, v in self.class_merge_gap_frames.items()
                },
                "class_smoothing_window_frames": {
                    cls: int(v) for cls, v in self.class_smoothing_window_frames.items()
                },
                "use_ovr": bool(self._use_ovr),
                "ovr_rows": bool(getattr(self, "ovr_rows_check", None) is not None and self.ovr_rows_check.isChecked()),
                "ovr_show_all": bool(getattr(self, "ovr_show_all_check", None) is not None and self.ovr_show_all_check.isChecked()),
                "use_ignore_threshold": bool(self.use_ignore_threshold),
                "ignore_threshold": float(self.global_ignore_threshold),
                "class_ignore_thresholds": {
                    cls: float(t) for cls, t in self.class_ignore_thresholds.items()
                },
                "allowed_cooccurrence": [
                    [a, b] for (a, b) in sorted(self._allowed_cooccurrence) if a <= b
                ],
                "timeline_zoom": int(self.timeline_zoom_spin.value()),
            },
            "results": self.results_cache
        }
        self._write_results_bundle(results_path, save_data, pretty=True)

        if frame_aggregation_enabled and self.aggregated_segments:
            self.log_text.append(f"Inference results saved to: {results_path} (with {len(self.aggregated_segments)} aggregated segments)")
        else:
            self.log_text.append(f"Inference results saved to: {results_path}")

        if not silent:
            QMessageBox.information(self, "Saved", f"Results saved to:\n{results_path}")
            # Update loaded path if we saved to a new location
            self.loaded_results_path = results_path

    except Exception as e:
        self.log_text.append(f"Warning: Failed to save inference results: {e}")
        if not silent:
            QMessageBox.critical(self, "Error", f"Failed to save results:\n{str(e)}")
1585
+
1586
def _load_timeline_results(self, path: str = None):
    """Load previously saved timeline results into the tab.

    If *path* is provided (e.g. autosave recovery), the file dialog is
    skipped. Otherwise the user picks a JSON file, starting in
    ``<experiment_path>/results`` (or ``<data_dir>/results``).

    The file is checked before any tab state is replaced. An unreadable,
    oversized (> 300 MB) or invalid file leaves the current classes,
    results and ``loaded_results_path`` untouched. On success, the saved
    parameters are restored: sampling, stride, post-processing, thresholds
    and zoom. The video filter is then repopulated and the first video is
    displayed.
    """
    if path:
        file_path = path
    else:
        # Prompt user to select results file
        exp_path = self.config.get("experiment_path")
        if exp_path:
            default_dir = os.path.join(exp_path, "results")
        else:
            default_dir = os.path.join(self.config.get("data_dir", "data"), "results")

        file_path, _ = QFileDialog.getOpenFileName(
            self,
            "Load Timeline Results",
            default_dir,
            "JSON Files (*.json)"
        )

    if not file_path:
        return

    MAX_LOAD_SIZE_MB = 300
    try:
        file_size = os.path.getsize(file_path)
    except OSError as e:
        # e.g. the file vanished after it was chosen / detected.
        self.log_text.append(f"Error loading timeline results: {e}")
        QMessageBox.critical(self, "Error", f"Failed to load timeline results:\n{str(e)}")
        return
    if file_size > MAX_LOAD_SIZE_MB * 1024 * 1024:
        QMessageBox.critical(
            self,
            "File too large",
            f"This results file is over {MAX_LOAD_SIZE_MB} MB and may cause a crash when loading.\n\n"
            "Re-run inference and save again (attention maps are now excluded from saves).\n"
            "Or load a smaller/simpler results file.",
        )
        return

    def _as_frame_probs(value):
        """Return *value* as an ndarray (JSON stores it as nested lists), or None."""
        if isinstance(value, list):
            value = np.asarray(value, dtype=np.float32)
        return value if isinstance(value, np.ndarray) else None

    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        self._restore_external_arrays(file_path, data)

        # Validate into locals first so an invalid file cannot wipe the
        # classes/results that are currently loaded.
        is_bundle = isinstance(data, dict)
        classes = data.get("classes", []) if is_bundle else []
        results = data.get("results", {}) if is_bundle else {}
        if not classes or not results:
            QMessageBox.warning(self, "Invalid File", "The selected file does not contain valid inference results.")
            return
        parameters = data.get("parameters") or {}

        self.loaded_results_path = file_path
        self.classes = classes
        self.results_cache = results
        self._ignore_threshold_user_modified = False

        for entry in self.results_cache.values():
            entry.setdefault("corrected_labels", {})
            entry.setdefault("corrected_attr_labels", {})

        # Update UI with loaded parameters
        self.target_fps_spin.setValue(parameters.get("target_fps", 16))
        self.clip_length_spin.setValue(parameters.get("clip_length", 16))
        # Restore the saved stride; files without it fall back to half a clip.
        default_step = max(1, self.clip_length_spin.value() // 2)
        self.step_frames_spin.setValue(int(parameters.get("step_frames", default_step)))
        self.sample_inference_check.setChecked(bool(parameters.get("sample_inference_enabled", False)))
        self.sample_duration_spin.setValue(int(parameters.get("sample_duration_seconds", 60)))
        self.sample_count_spin.setValue(int(parameters.get("sample_num_chunks", 5)))
        loaded_use_ovr = parameters.get("use_ovr", None)
        if loaded_use_ovr is None:
            # Older files: infer OvR mode from saved multiclass segments.
            loaded_use_ovr = any(
                bool((v or {}).get("aggregated_multiclass_segments", []))
                for v in self.results_cache.values()
                if isinstance(v, dict)
            )
        self._use_ovr = bool(loaded_use_ovr)
        # Stored as one ordered pair per co-occurrence; rebuild both orders.
        self._allowed_cooccurrence = set()
        for pair in parameters.get("allowed_cooccurrence", []):
            if isinstance(pair, (list, tuple)) and len(pair) == 2:
                a, b = pair[0], pair[1]
                self._allowed_cooccurrence.add((a, b))
                self._allowed_cooccurrence.add((b, a))
        if getattr(self, "ovr_rows_check", None) is not None and self._use_ovr:
            self.ovr_rows_check.setChecked(True)
        if getattr(self, "ovr_rows_check", None) is not None and "ovr_rows" in parameters:
            self.ovr_rows_check.setChecked(bool(parameters.get("ovr_rows", True)))
        if getattr(self, "ovr_show_all_check", None) is not None:
            self.ovr_show_all_check.setChecked(bool(parameters.get("ovr_show_all", False)))
        if "merge_consecutive_enabled" in parameters:
            self.merge_timeline_check.setChecked(bool(parameters.get("merge_consecutive_enabled", False)))
        if "use_viterbi_decode" in parameters:
            self.use_viterbi_decode = bool(parameters.get("use_viterbi_decode", False))
            self.use_viterbi_check.setChecked(self.use_viterbi_decode)
        if "viterbi_switch_penalty" in parameters:
            self.viterbi_switch_penalty = float(parameters.get("viterbi_switch_penalty", 0.35))
            self.viterbi_switch_penalty_spin.setValue(self.viterbi_switch_penalty)

        # Restore frame aggregation setting
        frame_aggregation_enabled = parameters.get("frame_aggregation_enabled", False)
        self.frame_aggregation_check.setChecked(frame_aggregation_enabled)
        if "min_segment_frames" in parameters:
            self._min_segment_frames = int(parameters.get("min_segment_frames", 1))
        if "merge_gap_frames" in parameters:
            self._merge_gap_frames = int(parameters.get("merge_gap_frames", 0))
        if "temporal_smoothing_window_frames" in parameters:
            loaded_win = int(parameters.get("temporal_smoothing_window_frames", 1))
            # The smoothing window is kept odd.
            if loaded_win % 2 == 0:
                loaded_win += 1
            self._temporal_smoothing_window_frames = max(1, loaded_win)
        if "class_min_segment_frames" in parameters:
            self.class_min_segment_frames = {
                str(cls): int(v) for cls, v in parameters.get("class_min_segment_frames", {}).items()
            }
        if "class_merge_gap_frames" in parameters:
            self.class_merge_gap_frames = {
                str(cls): int(v) for cls, v in parameters.get("class_merge_gap_frames", {}).items()
            }
        if "class_smoothing_window_frames" in parameters:
            self.class_smoothing_window_frames = {
                str(cls): int(v) for cls, v in parameters.get("class_smoothing_window_frames", {}).items()
            }
        self._sync_per_class_segment_rule_config()

        # Restore timeline zoom
        if "timeline_zoom" in parameters:
            self.timeline_zoom_spin.blockSignals(True)
            self.timeline_zoom_spin.setValue(int(parameters["timeline_zoom"]))
            self.timeline_zoom_spin.blockSignals(False)

        # Restore ignore threshold settings so the loaded timeline matches what was saved
        if "use_ignore_threshold" in parameters:
            self.use_ignore_threshold = bool(parameters["use_ignore_threshold"])
            self.use_ignore_threshold_check.setChecked(self.use_ignore_threshold)
        if "ignore_threshold" in parameters:
            self.global_ignore_threshold = float(parameters["ignore_threshold"])
            self.ignore_threshold_spin.setValue(self.global_ignore_threshold)
        if "class_ignore_thresholds" in parameters:
            self.class_ignore_thresholds = {
                cls: float(t) for cls, t in parameters["class_ignore_thresholds"].items()
            }
            self.config["inference_class_ignore_thresholds"] = dict(self.class_ignore_thresholds)
        self.config["inference_use_viterbi_decode"] = self.use_viterbi_decode
        self.config["inference_viterbi_switch_penalty"] = self.viterbi_switch_penalty
        self._update_viterbi_ui_state()

        self.classes_label.setText(", ".join(self.classes))

        # Populate video filter combo
        self.filter_video_combo.blockSignals(True)
        self.filter_video_combo.clear()
        video_paths = list(self.results_cache.keys())
        for vp in video_paths:
            self.filter_video_combo.addItem(os.path.basename(vp), vp)
        self.filter_video_combo.blockSignals(False)

        # Load first video's results
        if video_paths:
            first_video = video_paths[0]
            self.video_path = first_video
            self.video_paths = video_paths

            # Update video path display
            if len(video_paths) == 1:
                self.video_path_edit.setText(first_video)
            else:
                self.video_path_edit.setText(f"{len(video_paths)} videos loaded")

            entry = self.results_cache[first_video]
            self.predictions = entry["predictions"]
            self.confidences = entry["confidences"]
            self.clip_probabilities = entry.get("clip_probabilities", [])
            self.clip_frame_probabilities = entry.get("clip_frame_probabilities", [])
            self.attr_predictions = entry.get("attr_predictions", [])
            self.attr_confidences = entry.get("attr_confidences", [])
            self.clip_starts = entry["clip_starts"]
            self.localization_bboxes = entry.get("localization_bboxes", [])
            # A missing or null frame count is treated as unknown (0).
            self.total_frames = int(entry.get("total_frames", 0) or 0)
            self.corrected_labels = entry.get("corrected_labels", {})
            self.corrected_attr_labels = entry.get("corrected_attr_labels", {})

            # Backwards compatibility: if total_frames not saved, try to get from video
            if self.total_frames <= 0 and os.path.exists(first_video):
                try:
                    cap = cv2.VideoCapture(first_video)
                    self.total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
                    cap.release()
                    # Store for future use
                    entry["total_frames"] = self.total_frames
                except Exception:
                    pass

            # Load aggregated segments if saved, otherwise compute if frame aggregation enabled
            if "aggregated_segments" in entry:
                self.aggregated_segments = entry["aggregated_segments"]
                self.aggregated_multiclass_segments = entry.get("aggregated_multiclass_segments", [])
                self.log_text.append(f"Loaded {len(self.aggregated_segments)} saved aggregated segments")
                precomputed_probs = _as_frame_probs(entry.get("aggregated_frame_probs"))
                if precomputed_probs is not None:
                    self._aggregated_frame_scores_norm = precomputed_probs
                    self._aggregated_last_covered_frame = len(precomputed_probs)
            elif frame_aggregation_enabled and self._use_ovr and self.clip_probabilities:
                self._compute_aggregated_timeline()
            elif frame_aggregation_enabled:
                # Pre-computed frame probs (from worker or JSON load) skip recomputation.
                precomputed_probs = _as_frame_probs(entry.get("aggregated_frame_probs"))
                if precomputed_probs is not None:
                    self._aggregated_frame_scores_norm = precomputed_probs
                    self._aggregated_last_covered_frame = len(precomputed_probs)
                    self._build_timeline_segments()
                else:
                    self._compute_aggregated_timeline()

        # Populate behavior filter
        self.filter_behavior_combo.blockSignals(True)
        self.filter_behavior_combo.clear()
        self.filter_behavior_combo.addItem("All Behaviors")
        self.filter_behavior_combo.addItem(self.ignore_label_name)
        self.filter_behavior_combo.addItems(self.classes)
        self.filter_behavior_combo.blockSignals(False)

        # Display results
        self._display_results()

        # Enable export/preview buttons
        self.export_btn.setEnabled(True)
        self.preview_btn.setEnabled(True)
        self.export_timeline_btn.setEnabled(True)
        self.save_results_btn.setEnabled(True)  # Enable saving loaded/corrected results

        self.log_text.append(f"Loaded timeline results from: {file_path}")
        self.log_text.append(f"Loaded {len(video_paths)} video(s) with {len(self.predictions)} predictions")

        QMessageBox.information(
            self,
            "Loaded",
            f"Successfully loaded timeline results.\n\n"
            f"Videos: {len(video_paths)}\n"
            f"Classes: {len(self.classes)}\n\n"
            f"You can now review clips, make corrections, and re-save."
        )

    except Exception as e:
        self.log_text.append(f"Error loading timeline results: {e}")
        QMessageBox.critical(self, "Error", f"Failed to load timeline results:\n{str(e)}")
1838
+
1839
def _on_video_selection_changed(self, index: int):
    """Show the video currently selected in the video filter combo.

    *index* is accepted for signal/slot compatibility (presumably the
    combo's index-changed signal — confirm at the connect site) but is not
    used. The path is taken from the combo's current item data and loaded
    from the results cache.
    """
    selected_path = self.filter_video_combo.currentData()
    self._load_video_from_cache(selected_path)
1843
+
1844
def _current_threshold_settings(self) -> dict:
    """Snapshot the ignore-threshold configuration as plain Python values.

    Flags become ``bool``, thresholds become ``float`` and class keys become
    ``str``. The snapshot can be stored in a results-cache entry and later
    fed back to ``_apply_threshold_settings``.
    """
    per_class = {}
    for cls, thr in self.class_ignore_thresholds.items():
        per_class[str(cls)] = float(thr)
    settings = dict(
        use_ignore_threshold=bool(self.use_ignore_threshold),
        ignore_threshold=float(self.global_ignore_threshold),
        class_ignore_thresholds=per_class,
        user_modified=bool(self._ignore_threshold_user_modified),
    )
    return settings
1853
+
1854
def _apply_threshold_settings(self, settings: dict):
    """Apply a threshold snapshot as produced by ``_current_threshold_settings``.

    The ignore-threshold checkbox and spin box are updated with their
    signals blocked, and with ``_applying_auto_threshold`` set while that
    happens. The values are then stored on ``self`` and mirrored into
    ``self.config``.

    Missing keys fall back to the current values. Unparseable thresholds
    are skipped instead of aborting the update: a bad global threshold
    keeps the current global value, and a bad per-class entry is dropped.

    Args:
        settings: Dict with optional keys ``use_ignore_threshold``,
            ``ignore_threshold``, ``class_ignore_thresholds`` and
            ``user_modified``.
    """
    use_ignore = bool(settings.get("use_ignore_threshold", self.use_ignore_threshold))
    tau = settings.get("ignore_threshold", self.global_ignore_threshold)
    try:
        tau = float(tau)
    except (TypeError, ValueError, OverflowError):
        tau = self.global_ignore_threshold
    raw_per_class = settings.get("class_ignore_thresholds", {})
    if not isinstance(raw_per_class, dict):
        raw_per_class = {}
    per_class = {}
    for cls, thr in raw_per_class.items():
        try:
            per_class[str(cls)] = float(thr)
        except (TypeError, ValueError, OverflowError):
            # One malformed entry (e.g. from a hand-edited file) should not
            # discard the remaining settings.
            continue

    self._applying_auto_threshold = True
    try:
        self.use_ignore_threshold_check.blockSignals(True)
        self.ignore_threshold_spin.blockSignals(True)
        try:
            self.use_ignore_threshold_check.setChecked(use_ignore)
            self.ignore_threshold_spin.setValue(tau)
        finally:
            self.use_ignore_threshold_check.blockSignals(False)
            self.ignore_threshold_spin.blockSignals(False)
    finally:
        self._applying_auto_threshold = False

    self.use_ignore_threshold = use_ignore
    self.global_ignore_threshold = tau
    self.class_ignore_thresholds = per_class
    self._ignore_threshold_user_modified = bool(settings.get("user_modified", False))
    self.config["inference_use_ignore_threshold"] = self.use_ignore_threshold
    self.config["inference_ignore_threshold"] = self.global_ignore_threshold
    self.config["inference_class_ignore_thresholds"] = dict(self.class_ignore_thresholds)
1886
+
1887
def _persist_current_video_state(self):
    """Write the on-screen video's review state back into ``results_cache``.

    Stores shallow copies of the label and attribute corrections, plus the
    active threshold snapshot, in the cache entry for ``self.video_path``.
    Nothing happens when no video is selected or the video has no cache
    entry.
    """
    video = self.video_path
    if video and video in self.results_cache:
        cached = self.results_cache[video]
        cached["corrected_labels"] = dict(self.corrected_labels)
        cached["corrected_attr_labels"] = dict(self.corrected_attr_labels)
        cached["threshold_settings"] = self._current_threshold_settings()
1894
+
1895
+ def _load_video_from_cache(
1896
+ self,
1897
+ video_path: str,
1898
+ refresh_display: bool = True,
1899
+ persist_current: bool = True,
1900
+ threshold_settings_override: dict | None = None,
1901
+ persist_loaded_thresholds: bool = True,
1902
+ ) -> bool:
1903
+ if not video_path or video_path not in self.results_cache:
1904
+ return False
1905
+
1906
+ if persist_current:
1907
+ self._persist_current_video_state()
1908
+
1909
+ data = self.results_cache[video_path]
1910
+ self.predictions = data["predictions"]
1911
+ self.confidences = data["confidences"]
1912
+ self.clip_probabilities = data.get("clip_probabilities", [])
1913
+ self.clip_frame_probabilities = data.get("clip_frame_probabilities", [])
1914
+ self.attr_predictions = data.get("attr_predictions", [])
1915
+ self.attr_confidences = data.get("attr_confidences", [])
1916
+ self.clip_starts = data["clip_starts"]
1917
+ self.localization_bboxes = data.get("localization_bboxes", [])
1918
+ self.total_frames = data.get("total_frames", 0)
1919
+ self.corrected_labels = data.get("corrected_labels", {})
1920
+ self.corrected_attr_labels = data.get("corrected_attr_labels", {})
1921
+ self.video_path = video_path
1922
+
1923
+ threshold_settings = threshold_settings_override
1924
+ if threshold_settings is None:
1925
+ threshold_settings = data.get("threshold_settings")
1926
+ if isinstance(threshold_settings, dict):
1927
+ self._apply_threshold_settings(threshold_settings)
1928
+ else:
1929
+ self._auto_update_ignore_threshold()
1930
+ if persist_loaded_thresholds:
1931
+ data["threshold_settings"] = self._current_threshold_settings()
1932
+
1933
+ if self.clip_frame_probabilities and not self.frame_aggregation_check.isChecked():
1934
+ self.frame_aggregation_check.setChecked(True)
1935
+ self.log_text.append("Frame-head outputs found for this video. Switched to precise frame-boundary mode.")
1936
+
1937
+ if self.total_frames <= 0 and os.path.exists(video_path):
1938
+ try:
1939
+ cap = cv2.VideoCapture(video_path)
1940
+ self.total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
1941
+ cap.release()
1942
+ data["total_frames"] = self.total_frames
1943
+ except Exception:
1944
+ pass
1945
+
1946
+ if self.frame_aggregation_check.isChecked():
1947
+ precomputed_probs = data.get("aggregated_frame_probs")
1948
+ if isinstance(precomputed_probs, list):
1949
+ precomputed_probs = np.asarray(precomputed_probs, dtype=np.float32)
1950
+ if isinstance(precomputed_probs, np.ndarray):
1951
+ self._aggregated_frame_scores_norm = precomputed_probs
1952
+ self._aggregated_last_covered_frame = len(precomputed_probs)
1953
+ self._build_timeline_segments()
1954
+ elif self.use_ignore_threshold:
1955
+ self._compute_aggregated_timeline()
1956
+ elif "aggregated_segments" in data:
1957
+ self.aggregated_segments = data["aggregated_segments"]
1958
+ self.aggregated_multiclass_segments = data.get("aggregated_multiclass_segments", [])
1959
+ self._aggregated_frame_scores_norm = None
1960
+ self._aggregated_active_mask = None
1961
+ self._aggregated_last_covered_frame = 0
1962
+ if self._use_ovr:
1963
+ self._compute_aggregated_timeline()
1964
+ else:
1965
+ self._compute_aggregated_timeline()
1966
+ else:
1967
+ self.aggregated_segments = []
1968
+ self.aggregated_multiclass_segments = []
1969
+ self._aggregated_frame_scores_norm = None
1970
+ self._aggregated_active_mask = None
1971
+ self._aggregated_last_covered_frame = 0
1972
+
1973
+ if refresh_display:
1974
+ self._display_results()
1975
+ return True
1976
+
1977
def _on_error(self, error_msg: str):
    """Reset the progress UI after an inference failure and report it.

    The method hides and zeroes the progress bar and puts the message in
    the progress label. It re-enables the run and load buttons, disables
    the stop button, and shows a critical message box.
    """
    bar = self.progress_bar
    bar.setVisible(False)
    bar.setValue(0)
    self.progress_label.setText(f"Error: {error_msg}")
    button_states = (
        (self.run_inference_btn, True),
        (self.stop_inference_btn, False),
        (self.load_timeline_btn, True),
    )
    for button, enabled in button_states:
        button.setEnabled(enabled)
    QMessageBox.critical(self, "Inference Error", f"Inference failed:\n{error_msg}")
1986
+
1987
+ def _stop_inference(self):
1988
+ """Request inference worker to stop; completed videos are kept."""
1989
+ if getattr(self, "worker", None) and self.worker.isRunning():
1990
+ self.worker.stop()
1991
+ self.progress_label.setText("Stopping... (keeping completed videos)")
1992
+
1993
+ def _on_log(self, message: str):
1994
+ """Handle log message."""
1995
+ self.log_text.append(message)
1996
+
1997
+ def _display_results(self):
1998
+ """Display inference results."""
1999
+ self._update_viterbi_ui_state()
2000
+ self.results_list.clear()
2001
+
2002
+ # In precise frame mode, show segment-level results instead of per-clip
2003
+ # entries to avoid confusion when a clip contains multiple behaviors.
2004
+ if self.frame_aggregation_check.isChecked() and self.aggregated_segments:
2005
+ for i, seg in enumerate(self.aggregated_segments):
2006
+ cls_idx = int(seg.get("class", -1))
2007
+ start_f = int(seg.get("start", 0))
2008
+ end_f = int(seg.get("end", start_f))
2009
+ conf = float(seg.get("confidence", 0.0))
2010
+ if cls_idx < 0 or cls_idx >= len(self.classes):
2011
+ label = self.ignore_label_name
2012
+ else:
2013
+ label = self.classes[cls_idx]
2014
+ item_text = f"Segment {i+1}: {label} (frames {start_f}-{end_f}, confidence: {conf:.2%})"
2015
+ item = QListWidgetItem(item_text)
2016
+ if conf > 0.7:
2017
+ item.setForeground(QColor(0, 150, 0))
2018
+ elif conf > 0.5:
2019
+ item.setForeground(QColor(200, 150, 0))
2020
+ else:
2021
+ item.setForeground(QColor(200, 0, 0))
2022
+ self.results_list.addItem(item)
2023
+ self._draw_timeline()
2024
+ return
2025
+
2026
+ has_attrs = bool(self.attr_predictions and self.attributes)
2027
+
2028
+ effective_preds = self._effective_predictions()
2029
+ for i, (pred_idx, conf) in enumerate(zip(effective_preds, self.confidences)):
2030
+ if pred_idx < len(self.classes) and pred_idx >= 0:
2031
+ label = self.classes[pred_idx]
2032
+
2033
+ # Append Attribute if available
2034
+ attr_conf = None
2035
+ if has_attrs and i < len(self.attr_predictions):
2036
+ attr_idx = self.attr_predictions[i]
2037
+ if isinstance(attr_idx, int) and attr_idx < len(self.attributes):
2038
+ attr_label = self.attributes[attr_idx]
2039
+ if self.attr_confidences and i < len(self.attr_confidences):
2040
+ attr_conf = self.attr_confidences[i]
2041
+ label = f"{label} ({attr_label})"
2042
+
2043
+ extra_labels = ""
2044
+ if self._use_ovr and i < len(self.clip_probabilities):
2045
+ probs_i = self.clip_probabilities[i]
2046
+ if isinstance(probs_i, (list, tuple)) and len(probs_i) == len(self.classes):
2047
+ active = self._filter_cooccurrence(probs_i, 0.3)
2048
+ above = []
2049
+ for ci in sorted(active):
2050
+ if ci == pred_idx:
2051
+ continue
2052
+ above.append(f"{self.classes[ci]}:{float(probs_i[ci]):.0%}")
2053
+ if above:
2054
+ extra_labels = " + " + ", ".join(above)
2055
+
2056
+ if attr_conf is not None:
2057
+ item_text = f"Clip {i+1}: {label} (class: {conf:.2%}, {attr_label}: {attr_conf:.2%}){extra_labels}"
2058
+ else:
2059
+ item_text = f"Clip {i+1}: {label} (confidence: {conf:.2%}){extra_labels}"
2060
+ item = QListWidgetItem(item_text)
2061
+
2062
+ if conf > 0.7:
2063
+ item.setForeground(QColor(0, 150, 0))
2064
+ elif conf > 0.5:
2065
+ item.setForeground(QColor(200, 150, 0))
2066
+ else:
2067
+ item.setForeground(QColor(200, 0, 0))
2068
+
2069
+ self.results_list.addItem(item)
2070
+ elif pred_idx < 0:
2071
+ item = QListWidgetItem(f"Clip {i+1}: {self.ignore_label_name} (confidence: {conf:.2%})")
2072
+ item.setForeground(QColor(120, 120, 120))
2073
+ self.results_list.addItem(item)
2074
+
2075
+ self._draw_timeline()
2076
+
2077
+ def _threshold_for_pred(self, pred_idx: int) -> float:
2078
+ if pred_idx < 0 or pred_idx >= len(self.classes):
2079
+ return self.global_ignore_threshold
2080
+ cls = self.classes[pred_idx]
2081
+ if cls in self.class_ignore_thresholds:
2082
+ return float(self.class_ignore_thresholds[cls])
2083
+ return self.global_ignore_threshold
2084
+
2085
+ def _effective_prediction_for_clip(self, clip_idx: int) -> int:
2086
+ pred_idx = self.corrected_labels[clip_idx] if clip_idx in self.corrected_labels else self.predictions[clip_idx]
2087
+ if not self.use_ignore_threshold:
2088
+ return pred_idx
2089
+ if clip_idx >= len(self.confidences):
2090
+ return pred_idx
2091
+ # Round to 3 decimals so 0.9997 (displayed as ~100%) isn't below a 1.000 threshold
2092
+ conf = round(float(self.confidences[clip_idx]), 3)
2093
+ if conf < self._threshold_for_pred(pred_idx):
2094
+ return -1
2095
+ return pred_idx
2096
+
2097
+ def _effective_predictions(self):
2098
+ if not self.predictions:
2099
+ return []
2100
+ return [self._effective_prediction_for_clip(i) for i in range(len(self.predictions))]
2101
+
2102
+ def _on_ignore_threshold_changed(self, *_args):
2103
+ self.use_ignore_threshold = bool(self.use_ignore_threshold_check.isChecked())
2104
+ self.global_ignore_threshold = float(self.ignore_threshold_spin.value())
2105
+ self.config["inference_use_ignore_threshold"] = self.use_ignore_threshold
2106
+ self.config["inference_ignore_threshold"] = self.global_ignore_threshold
2107
+ if not self._applying_auto_threshold:
2108
+ self._ignore_threshold_user_modified = True
2109
+ self._persist_current_video_state()
2110
+ if self.predictions:
2111
+ if self.frame_aggregation_check.isChecked():
2112
+ self._compute_aggregated_timeline()
2113
+ self._display_results()
2114
+
2115
+ def _update_viterbi_ui_state(self):
2116
+ """Keep Viterbi controls enabled and explain current mode."""
2117
+ self.use_viterbi_check.setEnabled(True)
2118
+ self.viterbi_switch_penalty_spin.setEnabled(bool(self.use_viterbi_decode))
2119
+ if self._use_ovr:
2120
+ self.use_viterbi_check.setToolTip(
2121
+ "Inference-only sequence decoding on the merged frame probabilities.\n"
2122
+ "OvR uses binary per-class Viterbi, then keeps top classes consistent\n"
2123
+ "with allowed co-occurrence rules."
2124
+ )
2125
+ else:
2126
+ self.use_viterbi_check.setToolTip(
2127
+ "Inference-only sequence decoding for the single-label timeline.\n"
2128
+ "Reduces rapid frame-to-frame class switching without retraining."
2129
+ )
2130
+
2131
+ def _on_viterbi_changed(self, *_args):
2132
+ self.use_viterbi_decode = bool(self.use_viterbi_check.isChecked())
2133
+ self.viterbi_switch_penalty = float(self.viterbi_switch_penalty_spin.value())
2134
+ self.config["inference_use_viterbi_decode"] = self.use_viterbi_decode
2135
+ self.config["inference_viterbi_switch_penalty"] = self.viterbi_switch_penalty
2136
+ if self.use_viterbi_decode and not self.frame_aggregation_check.isChecked():
2137
+ self.frame_aggregation_check.setChecked(True)
2138
+ self._update_viterbi_ui_state()
2139
+ self._persist_current_video_state()
2140
+ if self.predictions and self.frame_aggregation_check.isChecked():
2141
+ self._compute_aggregated_timeline()
2142
+ self._display_results()
2143
+
2144
+ def _auto_update_ignore_threshold(self):
2145
+ """Auto-update ignore thresholds, preferring validation calibration."""
2146
+ if self._ignore_threshold_user_modified:
2147
+ return
2148
+ cfg = self.model_training_config if isinstance(self.model_training_config, dict) else {}
2149
+ calibrated = cfg.get("validation_calibrated_ignore_thresholds")
2150
+ if isinstance(calibrated, dict):
2151
+ tau = calibrated.get("global_threshold", self.global_ignore_threshold)
2152
+ try:
2153
+ tau = float(tau)
2154
+ except Exception:
2155
+ tau = self.global_ignore_threshold
2156
+ tau = max(0.35, min(0.90, tau))
2157
+ raw_per_class = calibrated.get("per_class_thresholds", {})
2158
+ per_class = {}
2159
+ if isinstance(raw_per_class, dict):
2160
+ for cls_name in self.classes:
2161
+ if cls_name in raw_per_class:
2162
+ try:
2163
+ per_class[cls_name] = max(0.35, min(0.90, float(raw_per_class[cls_name])))
2164
+ except Exception:
2165
+ per_class[cls_name] = tau
2166
+ else:
2167
+ per_class[cls_name] = tau
2168
+ if per_class:
2169
+ self._applying_auto_threshold = True
2170
+ try:
2171
+ self.ignore_threshold_spin.blockSignals(True)
2172
+ try:
2173
+ self.ignore_threshold_spin.setValue(tau)
2174
+ finally:
2175
+ self.ignore_threshold_spin.blockSignals(False)
2176
+ finally:
2177
+ self._applying_auto_threshold = False
2178
+ self.global_ignore_threshold = tau
2179
+ self.class_ignore_thresholds = per_class
2180
+ self.config["inference_ignore_threshold"] = tau
2181
+ self.config["inference_class_ignore_thresholds"] = dict(per_class)
2182
+ source = str(calibrated.get("source", "validation"))
2183
+ self.log_text.append(
2184
+ f"Auto-set ignore thresholds from model validation: global τ={tau:.2f}, "
2185
+ f"per-class τ for {len(per_class)} classes ({source})"
2186
+ )
2187
+ return
2188
+
2189
+ if not self.confidences:
2190
+ return
2191
+ confs = np.array(self.confidences, dtype=float)
2192
+ confs = confs[np.isfinite(confs)]
2193
+ if confs.size == 0:
2194
+ return
2195
+
2196
+ has_validation = bool(cfg) and (cfg.get("use_all_for_training") is False)
2197
+ if has_validation:
2198
+ best_val_f1 = cfg.get("best_val_f1")
2199
+ try:
2200
+ best_val_f1 = float(best_val_f1) if best_val_f1 is not None else None
2201
+ except Exception:
2202
+ best_val_f1 = None
2203
+ if best_val_f1 is not None:
2204
+ quality = max(0.0, min(1.0, (best_val_f1 - 40.0) / 50.0))
2205
+ quantile = 0.80 - 0.20 * quality
2206
+ else:
2207
+ quantile = 0.68
2208
+ source = "current-video fallback (validation-aware quantile)"
2209
+ else:
2210
+ quantile = 0.70
2211
+ source = "current-video fallback"
2212
+
2213
+ tau = float(np.quantile(confs, quantile))
2214
+ tau = max(0.35, min(0.90, tau))
2215
+ self._applying_auto_threshold = True
2216
+ try:
2217
+ self.ignore_threshold_spin.blockSignals(True)
2218
+ try:
2219
+ self.ignore_threshold_spin.setValue(tau)
2220
+ finally:
2221
+ self.ignore_threshold_spin.blockSignals(False)
2222
+ finally:
2223
+ self._applying_auto_threshold = False
2224
+ self.global_ignore_threshold = tau
2225
+ self.config["inference_ignore_threshold"] = tau
2226
+
2227
+ min_class_support = 10
2228
+ per_class = {}
2229
+ for cls_idx, cls_name in enumerate(self.classes):
2230
+ cls_confs = [
2231
+ float(self.confidences[i])
2232
+ for i, pred_idx in enumerate(self.predictions)
2233
+ if pred_idx == cls_idx and i < len(self.confidences) and np.isfinite(self.confidences[i])
2234
+ ]
2235
+ if len(cls_confs) >= min_class_support:
2236
+ cls_tau = float(np.quantile(np.array(cls_confs, dtype=float), quantile))
2237
+ cls_tau = max(0.35, min(0.90, cls_tau))
2238
+ per_class[cls_name] = cls_tau
2239
+ else:
2240
+ per_class[cls_name] = tau
2241
+
2242
+ self.class_ignore_thresholds = per_class
2243
+ self.config["inference_class_ignore_thresholds"] = dict(per_class)
2244
+ self.log_text.append(
2245
+ f"Auto-set ignore thresholds: global τ={tau:.2f}, per-class τ for {len(per_class)} classes ({source})"
2246
+ )
2247
+
2248
+ def _open_per_class_thresholds_dialog(self):
2249
+ if not self.classes:
2250
+ QMessageBox.information(self, "No classes", "Load a model first.")
2251
+ return
2252
+ dlg = QDialog(self)
2253
+ dlg.setWindowTitle("Per-class Ignore Thresholds")
2254
+ layout = QFormLayout(dlg)
2255
+ spins = {}
2256
+ for cls in self.classes:
2257
+ sp = QDoubleSpinBox()
2258
+ sp.setDecimals(3)
2259
+ sp.setRange(0.0, 1.0)
2260
+ sp.setSingleStep(0.01)
2261
+ sp.setValue(float(self.class_ignore_thresholds.get(cls, self.global_ignore_threshold)))
2262
+ layout.addRow(cls, sp)
2263
+ spins[cls] = sp
2264
+ btns = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel)
2265
+ layout.addRow(btns)
2266
+ btns.accepted.connect(dlg.accept)
2267
+ btns.rejected.connect(dlg.reject)
2268
+ if dlg.exec():
2269
+ self.class_ignore_thresholds = {cls: float(sp.value()) for cls, sp in spins.items()}
2270
+ self.config["inference_class_ignore_thresholds"] = dict(self.class_ignore_thresholds)
2271
+ self._on_ignore_threshold_changed()
2272
+
2273
+ def _open_per_class_segment_rules_dialog(self):
2274
+ if not self.classes:
2275
+ QMessageBox.information(self, "No classes", "Load a model first.")
2276
+ return
2277
+
2278
+ dlg = QDialog(self)
2279
+ dlg.setWindowTitle("Per-class Segment Rules")
2280
+ dlg.resize(620, 520)
2281
+ root = QVBoxLayout(dlg)
2282
+ root.addWidget(QLabel(
2283
+ "Override precise-boundary postprocessing per class. "
2284
+ "Values matching the global controls act like defaults."
2285
+ ))
2286
+
2287
+ scroll = QScrollArea()
2288
+ scroll.setWidgetResizable(True)
2289
+ inner = QWidget()
2290
+ grid = QGridLayout(inner)
2291
+ grid.addWidget(QLabel("Class"), 0, 0)
2292
+ grid.addWidget(QLabel("Smooth"), 0, 1)
2293
+ grid.addWidget(QLabel("Gap"), 0, 2)
2294
+ grid.addWidget(QLabel("Min seg"), 0, 3)
2295
+
2296
+ controls = {}
2297
+ global_smooth = int(max(1, self._temporal_smoothing_window_frames))
2298
+ if global_smooth % 2 == 0:
2299
+ global_smooth += 1
2300
+ global_gap = int(max(0, self._merge_gap_frames))
2301
+ global_min_seg = int(max(1, self._min_segment_frames))
2302
+
2303
+ for row, cls in enumerate(self.classes, start=1):
2304
+ grid.addWidget(QLabel(cls), row, 0)
2305
+
2306
+ sp_smooth = QSpinBox()
2307
+ sp_smooth.setRange(1, 99)
2308
+ sp_smooth.setSingleStep(2)
2309
+ smooth_val = int(self.class_smoothing_window_frames.get(cls, global_smooth))
2310
+ if smooth_val % 2 == 0:
2311
+ smooth_val += 1
2312
+ sp_smooth.setValue(max(1, smooth_val))
2313
+ grid.addWidget(sp_smooth, row, 1)
2314
+
2315
+ sp_gap = QSpinBox()
2316
+ sp_gap.setRange(0, 200)
2317
+ sp_gap.setValue(int(self.class_merge_gap_frames.get(cls, global_gap)))
2318
+ grid.addWidget(sp_gap, row, 2)
2319
+
2320
+ sp_min = QSpinBox()
2321
+ sp_min.setRange(1, 200)
2322
+ sp_min.setValue(int(self.class_min_segment_frames.get(cls, global_min_seg)))
2323
+ grid.addWidget(sp_min, row, 3)
2324
+
2325
+ controls[cls] = (sp_smooth, sp_gap, sp_min)
2326
+
2327
+ scroll.setWidget(inner)
2328
+ root.addWidget(scroll, stretch=1)
2329
+
2330
+ btns = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel)
2331
+ root.addWidget(btns)
2332
+ btns.accepted.connect(dlg.accept)
2333
+ btns.rejected.connect(dlg.reject)
2334
+
2335
+ if dlg.exec():
2336
+ new_smooth = {}
2337
+ new_gap = {}
2338
+ new_min = {}
2339
+ for cls, (sp_smooth, sp_gap, sp_min) in controls.items():
2340
+ smooth_val = int(sp_smooth.value())
2341
+ if smooth_val % 2 == 0:
2342
+ smooth_val += 1
2343
+ gap_val = int(sp_gap.value())
2344
+ min_val = int(sp_min.value())
2345
+ if smooth_val != global_smooth:
2346
+ new_smooth[cls] = smooth_val
2347
+ if gap_val != global_gap:
2348
+ new_gap[cls] = gap_val
2349
+ if min_val != global_min_seg:
2350
+ new_min[cls] = min_val
2351
+ self.class_smoothing_window_frames = new_smooth
2352
+ self.class_merge_gap_frames = new_gap
2353
+ self.class_min_segment_frames = new_min
2354
+ self._sync_per_class_segment_rule_config()
2355
+ if self.predictions and self.frame_aggregation_check.isChecked():
2356
+ self._compute_aggregated_timeline()
2357
+ self._display_results()
2358
+
2359
+ def _get_localization_bbox_for_clip_frame(self, clip_idx: int, frame_idx: int):
2360
+ """Return normalized xyxy localization bbox for a clip frame, or None."""
2361
+ if clip_idx < 0 or clip_idx >= len(self.localization_bboxes):
2362
+ return None
2363
+
2364
+ raw = self.localization_bboxes[clip_idx]
2365
+ if not isinstance(raw, (list, tuple)) or len(raw) == 0:
2366
+ return None
2367
+
2368
+ def _ema_smooth_boxes(boxes, alpha: float):
2369
+ if not boxes:
2370
+ return boxes
2371
+ prev = [float(v) for v in boxes[0]]
2372
+ smoothed = [prev]
2373
+ for i in range(1, len(boxes)):
2374
+ curr = [float(v) for v in boxes[i]]
2375
+ prev = [
2376
+ alpha * curr[0] + (1.0 - alpha) * prev[0],
2377
+ alpha * curr[1] + (1.0 - alpha) * prev[1],
2378
+ alpha * curr[2] + (1.0 - alpha) * prev[2],
2379
+ alpha * curr[3] + (1.0 - alpha) * prev[3],
2380
+ ]
2381
+ smoothed.append(prev)
2382
+ return smoothed
2383
+
2384
+ # Per-clip bbox: [4]
2385
+ if len(raw) == 4 and all(not isinstance(v, (list, tuple)) for v in raw):
2386
+ vals = raw
2387
+ # Per-frame bboxes: [T,4]
2388
+ elif isinstance(raw[0], (list, tuple)) and len(raw[0]) == 4:
2389
+ idx = max(0, min(int(frame_idx), len(raw) - 1))
2390
+ vals = _ema_smooth_boxes(raw, self._bbox_ema_alpha)[idx]
2391
+ else:
2392
+ return None
2393
+
2394
+ try:
2395
+ x1, y1, x2, y2 = [float(v) for v in vals]
2396
+ except Exception:
2397
+ return None
2398
+
2399
+ x1 = max(0.0, min(1.0, x1))
2400
+ y1 = max(0.0, min(1.0, y1))
2401
+ x2 = max(0.0, min(1.0, x2))
2402
+ y2 = max(0.0, min(1.0, y2))
2403
+ if x2 < x1:
2404
+ x1, x2 = x2, x1
2405
+ if y2 < y1:
2406
+ y1, y2 = y2, y1
2407
+
2408
+ if (x2 - x1) < 1e-4 or (y2 - y1) < 1e-4:
2409
+ return None
2410
+ return x1, y1, x2, y2
2411
+
2412
+ def _get_classification_roi_bbox_for_clip_frame(self, clip_idx: int):
2413
+ """Return the normalized xyxy ROI used by classification cropping."""
2414
+ if clip_idx < 0 or clip_idx >= len(self.localization_bboxes):
2415
+ return None
2416
+
2417
+ raw = self.localization_bboxes[clip_idx]
2418
+ if not isinstance(raw, (list, tuple)) or len(raw) == 0:
2419
+ return None
2420
+
2421
+ # Match InferenceWorker._build_refined_clips:
2422
+ # if temporal boxes exist, use first-frame box as fixed clip ROI.
2423
+ if isinstance(raw[0], (list, tuple)) and len(raw[0]) == 4:
2424
+ vals = raw[0]
2425
+ elif len(raw) == 4 and all(not isinstance(v, (list, tuple)) for v in raw):
2426
+ vals = raw
2427
+ else:
2428
+ return None
2429
+
2430
+ try:
2431
+ x1, y1, x2, y2 = [float(v) for v in vals]
2432
+ except Exception:
2433
+ return None
2434
+
2435
+ crop_padding = getattr(self, "_crop_padding", 0.35)
2436
+ crop_min_size = getattr(self, "_crop_min_size", 0.04)
2437
+ x1, y1, x2, y2 = _sanitize_bbox_coords(x1, y1, x2, y2, crop_padding, crop_min_size)
2438
+ if (x2 - x1) < 1e-4 or (y2 - y1) < 1e-4:
2439
+ return None
2440
+ return x1, y1, x2, y2
2441
+
2442
+ def _get_saved_frame_interval(self, video_path: str, orig_fps: float) -> int:
2443
+ """Return inference-time frame interval for a video when available."""
2444
+ if video_path and isinstance(self.results_cache, dict):
2445
+ entry = self.results_cache.get(video_path, {})
2446
+ if isinstance(entry, dict):
2447
+ stored = entry.get("frame_interval", None)
2448
+ try:
2449
+ stored_val = int(stored)
2450
+ except Exception:
2451
+ stored_val = 0
2452
+ if stored_val > 0:
2453
+ return stored_val
2454
+
2455
+ target_fps = max(1, int(self.target_fps_spin.value()))
2456
+ return max(1, int(round(float(orig_fps) / float(target_fps))))
2457
+
2458
+ def _build_center_merge_weights(self, length: int) -> np.ndarray:
2459
+ """Match inference-worker overlap merge weighting."""
2460
+ if length <= 1:
2461
+ return np.ones((max(1, length),), dtype=np.float32)
2462
+ w = np.hanning(length).astype(np.float32)
2463
+ if not np.any(w > 0):
2464
+ return np.ones((length,), dtype=np.float32)
2465
+ return np.clip(0.1 + 0.9 * w, 1e-3, None).astype(np.float32)
2466
+
2467
+ def _get_precomputed_aggregated_probs(self, video_path: str = None):
2468
+ """Return cached worker-built frame probabilities for the active video when available."""
2469
+ v_path = video_path or self.video_path
2470
+ if not v_path or not isinstance(self.results_cache, dict):
2471
+ return None
2472
+ entry = self.results_cache.get(v_path, {})
2473
+ if not isinstance(entry, dict):
2474
+ return None
2475
+ precomputed = entry.get("aggregated_frame_probs")
2476
+ if isinstance(precomputed, list):
2477
+ try:
2478
+ precomputed = np.asarray(precomputed, dtype=np.float32)
2479
+ except Exception:
2480
+ return None
2481
+ if isinstance(precomputed, np.ndarray) and precomputed.ndim == 2 and precomputed.shape[0] > 0:
2482
+ return precomputed
2483
+ return None
2484
+
2485
+ def _smooth_win_for_class(self, cls_idx: int) -> int:
2486
+ base = int(max(1, getattr(self, "_temporal_smoothing_window_frames", 1)))
2487
+ if 0 <= cls_idx < len(self.classes):
2488
+ base = int(self.class_smoothing_window_frames.get(self.classes[cls_idx], base))
2489
+ if base % 2 == 0:
2490
+ base += 1
2491
+ return max(1, base)
2492
+
2493
+ def _gap_fill_for_class(self, cls_idx: int) -> int:
2494
+ base = int(max(0, getattr(self, "_merge_gap_frames", 0)))
2495
+ if 0 <= cls_idx < len(self.classes):
2496
+ base = int(self.class_merge_gap_frames.get(self.classes[cls_idx], base))
2497
+ return max(0, base)
2498
+
2499
+ def _min_seg_for_class(self, cls_idx: int) -> int:
2500
+ base = int(max(1, getattr(self, "_min_segment_frames", 1)))
2501
+ if 0 <= cls_idx < len(self.classes):
2502
+ base = int(self.class_min_segment_frames.get(self.classes[cls_idx], base))
2503
+ return max(1, base)
2504
+
2505
+ def _sync_per_class_segment_rule_config(self):
2506
+ self.config["inference_class_min_segment_frames"] = {
2507
+ cls: int(v) for cls, v in self.class_min_segment_frames.items()
2508
+ }
2509
+ self.config["inference_class_merge_gap_frames"] = {
2510
+ cls: int(v) for cls, v in self.class_merge_gap_frames.items()
2511
+ }
2512
+ self.config["inference_class_smoothing_window_frames"] = {
2513
+ cls: int(v) for cls, v in self.class_smoothing_window_frames.items()
2514
+ }
2515
+
2516
+ def _json_safe_result_value(self, value):
2517
+ """Recursively convert numpy values into JSON-safe Python types."""
2518
+ if isinstance(value, np.ndarray):
2519
+ return value.tolist()
2520
+ if isinstance(value, np.generic):
2521
+ return value.item()
2522
+ if isinstance(value, dict):
2523
+ skip = {"clip_attention_maps"}
2524
+ return {str(k): self._json_safe_result_value(v) for k, v in value.items() if str(k) not in skip}
2525
+ if isinstance(value, (list, tuple)):
2526
+ return [self._json_safe_result_value(v) for v in value]
2527
+ return value
2528
+
2529
+ def _arrays_sidecar_path(self, json_path: str) -> str:
2530
+ base, ext = os.path.splitext(json_path)
2531
+ if ext.lower() == ".json":
2532
+ return base + ".arrays.npz"
2533
+ return json_path + ".arrays.npz"
2534
+
2535
+ def _coerce_external_result_array(self, key: str, value):
2536
+ if value is None:
2537
+ return None
2538
+ if isinstance(value, np.ndarray):
2539
+ arr = value
2540
+ elif isinstance(value, (list, tuple)) and len(value) > 0:
2541
+ try:
2542
+ arr = np.asarray(value)
2543
+ except Exception:
2544
+ return None
2545
+ else:
2546
+ return None
2547
+ if arr.size <= 0 or arr.dtype == object:
2548
+ return None
2549
+ if np.issubdtype(arr.dtype, np.floating):
2550
+ arr = arr.astype(np.float32, copy=False)
2551
+ elif np.issubdtype(arr.dtype, np.integer):
2552
+ arr = arr.astype(np.int32, copy=False)
2553
+ return np.ascontiguousarray(arr)
2554
+
2555
+ def _prepare_results_for_storage(self, results_cache: dict):
2556
+ heavy_keys = {"clip_probabilities", "clip_frame_probabilities", "aggregated_frame_probs"}
2557
+ json_results = {}
2558
+ external_arrays = {}
2559
+ for video_idx, (video_path, entry) in enumerate(results_cache.items()):
2560
+ if not isinstance(entry, dict):
2561
+ json_results[str(video_path)] = self._json_safe_result_value(entry)
2562
+ continue
2563
+ entry_json = {}
2564
+ entry_refs = {}
2565
+ for key, value in entry.items():
2566
+ if str(key) == "clip_attention_maps":
2567
+ continue
2568
+ if key in heavy_keys:
2569
+ arr = self._coerce_external_result_array(key, value)
2570
+ if arr is not None:
2571
+ store_key = f"video_{video_idx:05d}__{key}"
2572
+ external_arrays[store_key] = arr
2573
+ entry_refs[key] = store_key
2574
+ continue
2575
+ entry_json[str(key)] = self._json_safe_result_value(value)
2576
+ if entry_refs:
2577
+ entry_json["_external_arrays"] = entry_refs
2578
+ json_results[str(video_path)] = entry_json
2579
+ return json_results, external_arrays
2580
+
2581
+ def _write_results_bundle(self, results_path: str, payload: dict, *, pretty: bool):
2582
+ results_copy = dict(payload)
2583
+ json_results, external_arrays = self._prepare_results_for_storage(payload.get("results", {}) or {})
2584
+ results_copy["results"] = json_results
2585
+ sidecar_path = self._arrays_sidecar_path(results_path)
2586
+ if external_arrays:
2587
+ np.savez_compressed(sidecar_path, **external_arrays)
2588
+ results_copy["external_array_store"] = {
2589
+ "format": "npz",
2590
+ "file": os.path.basename(sidecar_path),
2591
+ }
2592
+ else:
2593
+ results_copy.pop("external_array_store", None)
2594
+ if os.path.exists(sidecar_path):
2595
+ try:
2596
+ os.remove(sidecar_path)
2597
+ except Exception:
2598
+ pass
2599
+ with open(results_path, "w", encoding="utf-8") as f:
2600
+ if pretty:
2601
+ json.dump(results_copy, f, indent=2)
2602
+ else:
2603
+ json.dump(results_copy, f)
2604
+
2605
+ def _restore_external_arrays(self, file_path: str, data: dict):
2606
+ if not isinstance(data, dict):
2607
+ return
2608
+ results = data.get("results", {})
2609
+ if not isinstance(results, dict):
2610
+ return
2611
+ store_info = data.get("external_array_store", {})
2612
+ sidecar_file = None
2613
+ if isinstance(store_info, dict):
2614
+ sidecar_file = store_info.get("file")
2615
+ sidecar_path = (
2616
+ os.path.join(os.path.dirname(file_path), sidecar_file)
2617
+ if sidecar_file else self._arrays_sidecar_path(file_path)
2618
+ )
2619
+ if not os.path.exists(sidecar_path):
2620
+ return
2621
+ try:
2622
+ with np.load(sidecar_path, allow_pickle=False) as npz_file:
2623
+ for entry in results.values():
2624
+ if not isinstance(entry, dict):
2625
+ continue
2626
+ refs = entry.pop("_external_arrays", None)
2627
+ if not isinstance(refs, dict):
2628
+ continue
2629
+ for field_name, store_key in refs.items():
2630
+ if store_key not in npz_file:
2631
+ continue
2632
+ arr = np.asarray(npz_file[store_key])
2633
+ if field_name == "aggregated_frame_probs":
2634
+ entry[field_name] = arr.astype(np.float32, copy=False)
2635
+ else:
2636
+ entry[field_name] = arr.tolist()
2637
+ except Exception as exc:
2638
+ self.log_text.append(f"Warning: Failed to load companion results arrays: {exc}")
2639
+
2640
+ def _get_clips_dir(self) -> str:
2641
+ """Resolved clips directory from config; creates dir if needed."""
2642
+ clips_dir = self.config.get("clips_dir", "data/clips")
2643
+ if not os.path.isabs(clips_dir):
2644
+ exp_path = self.config.get("experiment_path")
2645
+ if exp_path:
2646
+ clips_dir = os.path.join(exp_path, clips_dir)
2647
+ else:
2648
+ clips_dir = os.path.abspath(clips_dir)
2649
+ os.makedirs(clips_dir, exist_ok=True)
2650
+ return clips_dir
2651
+
2652
+ def _get_annotation_file(self) -> str:
2653
+ """Path to annotations JSON from config."""
2654
+ return self.config.get("annotation_file", "data/annotations/annotations.json")
2655
+
2656
+ def _clip_path_to_id(self, clip_path: str, clips_dir: str) -> str:
2657
+ """Convert absolute clip path to annotation clip ID (relative, forward slashes)."""
2658
+ clip_id = os.path.relpath(clip_path, clips_dir).replace("\\", "/")
2659
+ if clip_id.startswith("../"):
2660
+ return os.path.basename(clip_path)
2661
+ for prefix in ("../clips/", "clips/", "data/clips/"):
2662
+ if clip_id.startswith(prefix):
2663
+ return clip_id[len(prefix):]
2664
+ return clip_id
2665
+
2666
+ def _get_video_fps(self, video_path: str) -> float:
2667
+ """Return video FPS from path; 30.0 if unavailable or invalid."""
2668
+ if not video_path:
2669
+ return 30.0
2670
+ try:
2671
+ cap = cv2.VideoCapture(video_path)
2672
+ fps = cap.get(cv2.CAP_PROP_FPS)
2673
+ cap.release()
2674
+ return float(fps) if fps and fps > 0 else 30.0
2675
+ except Exception:
2676
+ return 30.0
2677
+
2678
+ def _video_basename(self) -> str:
2679
+ """Basename of current video path without extension."""
2680
+ if not self.video_path:
2681
+ return ""
2682
+ return os.path.splitext(os.path.basename(self.video_path))[0]
2683
+
2684
+ def _get_timeline_qcolors(self) -> list:
2685
+ """Timeline palette as list of QColor for drawing."""
2686
+ return [QColor(r, g, b) for r, g, b in self._get_timeline_palette()]
2687
+
2688
+ def _unique_clip_path(self, clip_path: str) -> str:
2689
+ """Return a path that does not exist by appending _1, _2, ... to stem."""
2690
+ base_clip_path = clip_path
2691
+ suffix = 1
2692
+ while os.path.exists(clip_path):
2693
+ base, ext = os.path.splitext(base_clip_path)
2694
+ clip_path = f"{base}_{suffix}{ext}"
2695
+ suffix += 1
2696
+ return clip_path
2697
+
2698
+ def _merge_predictions(self, predictions, confidences, clip_starts):
2699
+ """Merge consecutive identical predictions."""
2700
+ if not predictions:
2701
+ return predictions, confidences, clip_starts
2702
+
2703
+ merged_preds = []
2704
+ merged_confs = []
2705
+ merged_starts = []
2706
+
2707
+ current_pred = predictions[0]
2708
+ current_conf = confidences[0]
2709
+ current_start = clip_starts[0]
2710
+
2711
+ for i in range(1, len(predictions)):
2712
+ if predictions[i] == current_pred:
2713
+ # Same behavior, continue merging (keep max confidence)
2714
+ current_conf = max(current_conf, confidences[i])
2715
+ else:
2716
+ # Different behavior, save current segment and start new one
2717
+ merged_preds.append(current_pred)
2718
+ merged_confs.append(current_conf)
2719
+ merged_starts.append(current_start)
2720
+
2721
+ current_pred = predictions[i]
2722
+ current_conf = confidences[i]
2723
+ current_start = clip_starts[i]
2724
+
2725
+ # Add last segment
2726
+ merged_preds.append(current_pred)
2727
+ merged_confs.append(current_conf)
2728
+ merged_starts.append(current_start)
2729
+
2730
+ return merged_preds, merged_confs, merged_starts
2731
+
2732
+ def _draw_timeline(self):
2733
+ """Draw colored timeline visualization."""
2734
+ if not self.predictions or not self.classes:
2735
+ return
2736
+
2737
+ # OvR: per-class rows when checkbox on; non-OvR: single-row timeline
2738
+ ovr_rows = (
2739
+ getattr(self, "ovr_rows_check", None) is not None
2740
+ and self.ovr_rows_check.isChecked()
2741
+ and self._use_ovr
2742
+ and bool(self.clip_probabilities)
2743
+ )
2744
+ self.timeline_scroll.setVisible(not ovr_rows)
2745
+ self._ovr_timeline_container.setVisible(ovr_rows)
2746
+
2747
+ frame_aggregation_enabled = self.frame_aggregation_check.isChecked()
2748
+
2749
+ # OvR per-class row timeline
2750
+ if ovr_rows:
2751
+ if frame_aggregation_enabled:
2752
+ need_precise_recompute = (
2753
+ not self.aggregated_segments
2754
+ or not isinstance(self._aggregated_frame_scores_norm, np.ndarray)
2755
+ or int(self._aggregated_last_covered_frame) <= 0
2756
+ )
2757
+ if need_precise_recompute:
2758
+ self._compute_aggregated_timeline()
2759
+ self._draw_ovr_multirow_timeline()
2760
+ return
2761
+
2762
+ # Use frame-aggregated segments if enabled
2763
+ if frame_aggregation_enabled and self.aggregated_segments:
2764
+ self._draw_frame_aggregated_timeline()
2765
+ return
2766
+
2767
+ # Original clip-based timeline drawing
2768
+ corrected_preds = self._effective_predictions()
2769
+
2770
+ # Apply merging if enabled
2771
+ merge_enabled = self.merge_timeline_check.isChecked()
2772
+ if merge_enabled:
2773
+ display_preds, display_confs, display_starts = self._merge_predictions(
2774
+ corrected_preds, self.confidences, self.clip_starts
2775
+ )
2776
+ else:
2777
+ display_preds = corrected_preds
2778
+ display_confs = self.confidences
2779
+ display_starts = self.clip_starts
2780
+
2781
+ num_segments = len(display_preds)
2782
+ if num_segments == 0:
2783
+ return
2784
+
2785
+ # Calculate total width using zoom spinbox (px per second).
2786
+ total_clips = len(self.predictions)
2787
+ px_per_sec_clip = float(getattr(self, "timeline_zoom_spin", None) and self.timeline_zoom_spin.value() or 100)
2788
+ orig_fps_clip = self._get_video_fps(self.video_path) if self.video_path else 30.0
2789
+ # base_clip_width: pixels per clip based on clip_length and zoom
2790
+ clip_dur_sec = self.clip_length_spin.value() / max(1.0, float(self.target_fps_spin.value()))
2791
+ base_clip_width = max(4, int(clip_dur_sec * px_per_sec_clip))
2792
+ total_width = max(800, total_clips * base_clip_width)
2793
+ height = 60
2794
+
2795
+ self.timeline_widget.setFixedSize(total_width, height)
2796
+ pixmap = QPixmap(total_width, height)
2797
+ pixmap.fill(QColor(255, 255, 255))
2798
+
2799
+ painter = QPainter(pixmap)
2800
+ painter.setRenderHint(QPainter.RenderHint.Antialiasing)
2801
+
2802
+ colors = self._get_timeline_qcolors()
2803
+
2804
+ x_pos = 0
2805
+ selected_behavior = self.filter_behavior_combo.currentText()
2806
+ selected_attr = None
2807
+ if selected_behavior.startswith("Attr: "):
2808
+ selected_attr = selected_behavior.replace("Attr: ", "", 1)
2809
+
2810
+ for seg_idx, (pred_idx, conf) in enumerate(zip(display_preds, display_confs)):
2811
+ if pred_idx < len(self.classes) and pred_idx >= 0:
2812
+ # Calculate segment width based on number of clips it spans
2813
+ if seg_idx < len(display_starts) - 1:
2814
+ # Calculate number of clips in this segment
2815
+ seg_start_clip = display_starts[seg_idx]
2816
+ seg_end_clip = display_starts[seg_idx + 1]
2817
+ # Find how many original clips this spans
2818
+ clip_count = sum(1 for orig_start in self.clip_starts
2819
+ if seg_start_clip <= orig_start < seg_end_clip)
2820
+ seg_width = clip_count * base_clip_width if clip_count > 0 else base_clip_width
2821
+ else:
2822
+ # Last segment - count remaining clips
2823
+ seg_start_clip = display_starts[seg_idx]
2824
+ clip_count = sum(1 for orig_start in self.clip_starts if orig_start >= seg_start_clip)
2825
+ seg_width = clip_count * base_clip_width if clip_count > 0 else base_clip_width
2826
+
2827
+ label = self.classes[pred_idx]
2828
+ if selected_behavior == "All Behaviors":
2829
+ is_selected = True
2830
+ elif selected_behavior == self.ignore_label_name:
2831
+ is_selected = False
2832
+ elif selected_attr is not None:
2833
+ if selected_attr in self.attributes:
2834
+ attr_target_idx = self.attributes.index(selected_attr)
2835
+ # Find clip indices covered by this segment
2836
+ seg_indices = [
2837
+ idx for idx, orig_start in enumerate(self.clip_starts)
2838
+ if seg_start_clip <= orig_start < (display_starts[seg_idx + 1] if seg_idx < len(display_starts) - 1 else float("inf"))
2839
+ ]
2840
+ is_selected = any(
2841
+ (self._get_attr_idx(i) == attr_target_idx)
2842
+ for i in seg_indices
2843
+ )
2844
+ else:
2845
+ is_selected = False
2846
+ else:
2847
+ is_selected = (label == selected_behavior)
2848
+
2849
+ if is_selected:
2850
+ color = colors[pred_idx % len(colors)]
2851
+ alpha = int(255 * conf)
2852
+ color.setAlpha(alpha)
2853
+ else:
2854
+ # Draw as light gray if not selected to preserve timeline structure
2855
+ color = QColor(240, 240, 240)
2856
+ color.setAlpha(255)
2857
+
2858
+ painter.fillRect(int(x_pos), 0, int(seg_width), height, color)
2859
+
2860
+ if seg_width >= 30 and is_selected:
2861
+ painter.setPen(QColor(0, 0, 0))
2862
+ painter.setFont(QFont("Arial", 8))
2863
+ painter.drawText(int(x_pos + 5), 20, label)
2864
+
2865
+ x_pos += seg_width
2866
+ elif pred_idx < 0:
2867
+ if seg_idx < len(display_starts) - 1:
2868
+ seg_start_clip = display_starts[seg_idx]
2869
+ seg_end_clip = display_starts[seg_idx + 1]
2870
+ clip_count = sum(1 for orig_start in self.clip_starts if seg_start_clip <= orig_start < seg_end_clip)
2871
+ seg_width = clip_count * base_clip_width if clip_count > 0 else base_clip_width
2872
+ else:
2873
+ seg_start_clip = display_starts[seg_idx]
2874
+ clip_count = sum(1 for orig_start in self.clip_starts if orig_start >= seg_start_clip)
2875
+ seg_width = clip_count * base_clip_width if clip_count > 0 else base_clip_width
2876
+ is_selected = selected_behavior in ("All Behaviors", self.ignore_label_name)
2877
+ color = QColor(180, 180, 180) if is_selected else QColor(240, 240, 240)
2878
+ color.setAlpha(230 if is_selected else 255)
2879
+ painter.fillRect(int(x_pos), 0, int(seg_width), height, color)
2880
+ if seg_width >= 55 and is_selected:
2881
+ painter.setPen(QColor(40, 40, 40))
2882
+ painter.setFont(QFont("Arial", 8))
2883
+ painter.drawText(int(x_pos + 5), 20, "ignored")
2884
+ x_pos += seg_width
2885
+
2886
+ painter.setPen(QColor(0, 0, 0))
2887
+ painter.setFont(QFont("Arial", 8))
2888
+ merge_text = " (merged)" if merge_enabled else ""
2889
+ painter.drawText(5, height - 5, f"Timeline: {num_segments} segments, {len(self.predictions)} clips{merge_text} (click to view)")
2890
+
2891
+ painter.end()
2892
+
2893
+ label = QLabel()
2894
+ label.setPixmap(pixmap)
2895
+ label.setFixedSize(total_width, height)
2896
+
2897
+ old_layout = self.timeline_widget.layout()
2898
+ if old_layout:
2899
+ while old_layout.count():
2900
+ child = old_layout.takeAt(0)
2901
+ if child.widget():
2902
+ child.widget().deleteLater()
2903
+ else:
2904
+ layout = QVBoxLayout()
2905
+ layout.setContentsMargins(0, 0, 0, 0)
2906
+ self.timeline_widget.setLayout(layout)
2907
+
2908
+ self.timeline_widget.layout().addWidget(label)
2909
+ self.timeline_widget.setFixedSize(total_width, height)
2910
+ self.timeline_widget.clip_width = base_clip_width
2911
+ self.timeline_widget.num_clips = len(self.predictions)
2912
+ self.timeline_widget._frame_mode = False
2913
+
2914
def _filter_cooccurrence(self, probs, threshold):
    """Pick the class indices to display for one score vector.

    Args:
        probs: Indexable per-class scores. Only the first
            ``min(len(probs), len(self.classes))`` entries are considered.
        threshold: Minimum score (inclusive) for a class to be a candidate.

    Returns:
        A set of class indices, empty when no score reaches ``threshold``.
        With the "Show all classes" toggle checked, every candidate is
        returned. Otherwise the highest-scoring candidate is always kept
        (the lowest index wins a tie). Another candidate is added only when
        ``(top_name, other_name)`` is in ``self._allowed_cooccurrence``.
    """
    toggle = getattr(self, "ovr_show_all_check", None)
    show_all = toggle is not None and self.ovr_show_all_check.isChecked()

    limit = min(len(probs), len(self.classes))
    candidates = []
    for idx in range(limit):
        value = float(probs[idx])
        if value >= threshold:
            candidates.append((idx, value))

    if not candidates:
        return set()
    if show_all:
        return {idx for idx, _value in candidates}

    # max() returns the first maximal item, i.e. the same element a stable
    # descending sort would put first.
    lead_idx = max(candidates, key=lambda pair: pair[1])[0]
    lead_name = self.classes[lead_idx]
    chosen = {lead_idx}
    chosen.update(
        idx
        for idx, _value in candidates
        if idx != lead_idx and (lead_name, self.classes[idx]) in self._allowed_cooccurrence
    )
    return chosen
2937
+
2938
+ def _active_ovr_indices_from_scores(self, probs_row, threshold_override: float | None = None):
2939
+ """Active OvR class indices at one frame using thresholds.
2940
+
2941
+ When "Show all classes" is enabled, every class above its threshold is
2942
+ returned (fully independent). Otherwise the top-1 class is returned
2943
+ plus any class that forms an allowed co-occurrence pair with it.
2944
+ """
2945
+ show_all = getattr(self, "ovr_show_all_check", None) is not None and self.ovr_show_all_check.isChecked()
2946
+ n = min(len(probs_row), len(self.classes))
2947
+ scored = []
2948
+ for ci in range(n):
2949
+ s = float(probs_row[ci])
2950
+ thr = float(threshold_override) if threshold_override is not None else self._threshold_for_pred(ci)
2951
+ if s >= thr:
2952
+ scored.append((ci, s))
2953
+ if not scored:
2954
+ return []
2955
+ scored.sort(key=lambda x: x[1], reverse=True)
2956
+
2957
+ if show_all:
2958
+ return [ci for ci, _ in scored]
2959
+
2960
+ top_ci = scored[0][0]
2961
+ if not self._allowed_cooccurrence:
2962
+ return [top_ci]
2963
+ top_name = self.classes[top_ci]
2964
+ active = [top_ci]
2965
+ for ci, _ in scored[1:]:
2966
+ name = self.classes[ci]
2967
+ if (top_name, name) in self._allowed_cooccurrence:
2968
+ active.append(ci)
2969
+ return active
2970
+
2971
def _get_precise_active_for_frame(self, frame_idx: int):
    """Return ``[(class_idx, score), ...]`` active at ``frame_idx``, best first.

    Returns ``[]`` when no aggregated frame-score array exists or when
    ``frame_idx`` lies outside ``[0, self._aggregated_last_covered_frame)``.

    OvR mode: when ``self._aggregated_active_mask`` covers the frame, every
    class flagged in that mask row is reported. The original note says this
    mask already has min-segment and gap-fill filtering applied. Without a
    mask, ``_active_ovr_indices_from_scores`` decides, using per-class
    thresholds when ``use_ignore_threshold`` is set and a flat 0.35 cutoff
    otherwise.

    Single-label mode: only the argmax class is reported. It is dropped when
    the ignore threshold is enabled and its score is below
    ``_threshold_for_pred``.
    """
    frame_scores = self._aggregated_frame_scores_norm
    if not isinstance(frame_scores, np.ndarray):
        return []
    if not 0 <= frame_idx < int(self._aggregated_last_covered_frame):
        return []
    scores = frame_scores[frame_idx]

    if not self._use_ovr:
        if len(scores) == 0:
            return []
        best = int(np.argmax(scores))
        if not 0 <= best < len(self.classes):
            return []
        best_score = float(scores[best])
        if self.use_ignore_threshold and best_score < self._threshold_for_pred(best):
            return []
        return [(best, best_score)]

    mask = self._aggregated_active_mask
    if isinstance(mask, np.ndarray) and frame_idx < mask.shape[0]:
        row = mask[frame_idx]
        pairs = [(ci, float(scores[ci])) for ci, flag in enumerate(row) if flag]
    else:
        override = None if self.use_ignore_threshold else 0.35
        chosen = self._active_ovr_indices_from_scores(scores, threshold_override=override)
        pairs = [(ci, float(scores[ci])) for ci in chosen]
    pairs.sort(key=lambda pair: pair[1], reverse=True)
    return pairs
3001
+
3002
def _ovr_timeline_click(self, event):
    """Open the clip popup for a click on the OvR multi-row timeline.

    The row under the cursor selects the class passed as ``ovr_class_idx``.
    It is ``-1`` when the click is outside every class row.

    In frame mode (``self._ovr_timeline_frame_mode``), the x coordinate is
    converted to a frame index and the popup opens with ``frame_mode=True``.
    Otherwise x is converted to a clip index.

    Clicks outside the drawn data are ignored: left of the canvas, or in the
    padding past the last covered frame or last clip. Frame mode never falls
    back to clip geometry, because the clip attributes may be missing or
    stale from an earlier clip-mode draw.
    """
    pos = event.position()
    x = pos.x()
    y = pos.y()

    row_height = max(1, int(getattr(self, "_ovr_row_height", 24)))
    clicked_class = int(y / row_height) if y >= 0 else -1
    if not 0 <= clicked_class < len(self.classes):
        clicked_class = -1

    if x < 0:
        # int() truncates toward zero, so small negative x would map to 0.
        return

    if getattr(self, "_ovr_timeline_frame_mode", False):
        pixels_per_frame = float(getattr(self, "_ovr_pixels_per_frame", 0.0) or 0.0)
        if pixels_per_frame > 0:
            frame_idx = int(x / pixels_per_frame)
            if frame_idx < int(self._aggregated_last_covered_frame):
                self._show_clip_popup(frame_idx, frame_mode=True, ovr_class_idx=clicked_class)
        return

    clip_width = getattr(self, "_ovr_clip_width", 0) or 0
    num_clips = int(getattr(self, "_ovr_num_clips", 0) or 0)
    if num_clips > 0 and clip_width > 0:
        clip_idx = int(x / clip_width)
        if clip_idx < num_clips:
            self._show_clip_popup(clip_idx, ovr_class_idx=clicked_class)
3021
+
3022
def _draw_ovr_multirow_timeline(self):
    """Draw the OvR timeline with one row per class.

    Two pixmaps are produced:

    * a fixed 95 px label panel (colour swatch and class name per row),
      shown in ``self._ovr_label_panel``;
    * the scrollable activation canvas, shown in
      ``self._ovr_timeline_widget``.

    The layout depends on the available data:

    * Frame mode (frame aggregation checked, OvR enabled, aggregated frame
      scores present): x is measured in frames. Each contiguous run of
      truthy values in ``self._aggregated_active_mask[:, ci]`` becomes one
      bar.
    * Clip mode (otherwise): x is measured in clips. A clip gets a bar in
      every row returned by ``_active_ovr_indices_from_scores``.

    Bar opacity scales from 120 to 255 with how far the score (mean score of
    a run in frame mode) exceeds the activation threshold. Row height
    follows the visible viewport height, with at least 12 px per row.

    The geometry read by ``_ovr_timeline_click`` is written to ``self`` on
    every call: ``_ovr_timeline_frame_mode``, ``_ovr_pixels_per_frame``,
    ``_ovr_clip_width``, ``_ovr_num_clips`` and ``_ovr_row_height``.
    """
    num_classes = len(self.classes)
    num_clips = len(self.predictions)
    if num_clips == 0 or num_classes == 0:
        return

    px_per_sec = float(getattr(self, "timeline_zoom_spin", None) and self.timeline_zoom_spin.value() or 100)
    orig_fps = self._get_video_fps(self.video_path) if self.video_path else 30.0
    if not orig_fps or orig_fps <= 0:
        # Guard the px-per-frame division against a missing/invalid FPS.
        orig_fps = 30.0
    clip_dur_sec = self.clip_length_spin.value() / max(1.0, float(self.target_fps_spin.value()))
    base_clip_width = max(4, int(clip_dur_sec * px_per_sec))

    use_precise = (
        self.frame_aggregation_check.isChecked()
        and self._use_ovr
        and isinstance(self._aggregated_frame_scores_norm, np.ndarray)
        and int(self._aggregated_last_covered_frame) > 0
    )

    # Clip geometry is recorded in both modes, so the click handler never
    # reads a missing value or a stale one from an earlier draw.
    self._ovr_clip_width = base_clip_width
    self._ovr_num_clips = num_clips
    if use_precise:
        total_frames = int(self._aggregated_last_covered_frame)
        pixels_per_frame = max(0.5, px_per_sec / orig_fps)
        total_width = max(800, int(total_frames * pixels_per_frame))
        self._ovr_timeline_frame_mode = True
    else:
        total_frames = 0
        pixels_per_frame = 1.0
        total_width = max(800, num_clips * base_clip_width)
        self._ovr_timeline_frame_mode = False
    self._ovr_pixels_per_frame = pixels_per_frame

    viewport_h = 0
    scroll = getattr(self, "_ovr_scroll", None)
    if scroll is not None and scroll.viewport() is not None:
        viewport_h = int(scroll.viewport().height())
    container = getattr(self, "_ovr_timeline_container", None)
    if viewport_h <= 0 and container is not None:
        viewport_h = max(0, int(container.height()) - 8)
    viewport_h = max(120, viewport_h)
    row_height = max(12, viewport_h // max(1, num_classes))
    self._ovr_row_height = row_height
    canvas_height = num_classes * row_height
    activation_threshold = 0.35
    label_width = 95

    colors = self._get_timeline_qcolors()

    # --- Fixed label panel ---
    label_pm = QPixmap(label_width, canvas_height)
    label_pm.fill(QColor(245, 245, 245))
    lp = QPainter(label_pm)
    lp.setFont(QFont("Arial", max(7, min(10, row_height - 10))))
    swatch_h = max(4, row_height - 8)
    for ci in range(num_classes):
        y = ci * row_height
        lp.fillRect(2, y + max(2, (row_height - swatch_h) // 2), 10, swatch_h, colors[ci % len(colors)])
        lp.setPen(QColor(0, 0, 0))
        text_y = y + min(row_height - 4, max(10, int(row_height * 0.72)))
        lp.drawText(16, text_y, self.classes[ci])
        lp.setPen(QColor(210, 210, 210))
        lp.drawLine(0, y + row_height - 1, label_width, y + row_height - 1)
    lp.end()
    self._ovr_label_panel.setPixmap(label_pm)
    self._ovr_label_panel.setFixedHeight(canvas_height)

    # --- Scrollable activation timeline ---
    tl_pm = QPixmap(total_width, canvas_height)
    tl_pm.fill(QColor(255, 255, 255))
    tp = QPainter(tl_pm)
    tp.setRenderHint(QPainter.RenderHint.Antialiasing)
    bar_h = max(2, row_height - 4)

    if use_precise:
        active_mask = self._aggregated_active_mask
        frame_scores = self._aggregated_frame_scores_norm
        alpha_span = max(0.01, 1.0 - activation_threshold)
        if isinstance(active_mask, np.ndarray):
            covered = min(total_frames, active_mask.shape[0])
            for ci in range(min(num_classes, active_mask.shape[1])):
                # Pad the column with False on both sides. Rising/falling
                # edges then give inclusive run starts and exclusive run
                # stops. This is vectorised instead of a per-frame Python
                # loop that reran on every resize.
                padded = np.zeros(total_frames + 2, dtype=bool)
                padded[1:covered + 1] = active_mask[:covered, ci]
                edges = np.flatnonzero(padded[1:] != padded[:-1])
                y = ci * row_height
                color = colors[ci % len(colors)]
                for run_start, run_stop in zip(edges[0::2].tolist(), edges[1::2].tolist()):
                    x0 = int(run_start * pixels_per_frame)
                    x1 = int(run_stop * pixels_per_frame)
                    seg_w = max(1, x1 - x0)
                    seg_conf = float(np.mean(frame_scores[run_start:run_stop, ci]))
                    t = (seg_conf - activation_threshold) / alpha_span
                    c = QColor(color)
                    c.setAlpha(int(120 + 135 * min(1.0, max(0.0, t))))
                    tp.fillRect(x0, y + 2, seg_w, bar_h, c)
    else:
        ovr_threshold = None if self.use_ignore_threshold else activation_threshold
        for clip_i in range(min(num_clips, len(self.clip_probabilities))):
            probs_i = self.clip_probabilities[clip_i]
            if not isinstance(probs_i, (list, tuple, np.ndarray)):
                continue

            # Same threshold/co-occurrence logic as precise mode and overlays.
            active = self._active_ovr_indices_from_scores(
                probs_i,
                threshold_override=ovr_threshold,
            )

            x = clip_i * base_clip_width
            for ci in active:
                score = float(probs_i[ci])
                ci_threshold = self._threshold_for_pred(ci) if self.use_ignore_threshold else activation_threshold
                t = (score - ci_threshold) / max(0.01, 1.0 - ci_threshold)
                c = QColor(colors[ci % len(colors)])
                c.setAlpha(int(120 + 135 * min(1.0, max(0.0, t))))
                tp.fillRect(int(x), ci * row_height + 2, base_clip_width, bar_h, c)

    # Row separators.
    tp.setPen(QColor(210, 210, 210))
    for ci in range(num_classes):
        y = ci * row_height
        tp.drawLine(0, y + row_height - 1, total_width, y + row_height - 1)

    tp.end()
    self._ovr_timeline_widget.setPixmap(tl_pm)
    self._ovr_timeline_widget.setFixedSize(total_width, canvas_height)
    self._ovr_label_panel.setFixedSize(label_width, canvas_height)
3157
+
3158
def eventFilter(self, obj, event):
    """Qt event filter that queues a timeline redraw when the OvR view resizes.

    Resize events on the OvR timeline container, or on its scroll area's
    viewport, schedule ``_draw_timeline`` through a zero-delay single-shot
    timer, but only while the container is visible. The event is always
    forwarded to the base-class filter and its result returned.
    """
    if event.type() == QEvent.Type.Resize:
        container = getattr(self, "_ovr_timeline_container", None)
        scroll = getattr(self, "_ovr_scroll", None)
        viewport = scroll.viewport() if scroll is not None else None
        if obj in (container, viewport) and container is not None and container.isVisible():
            QTimer.singleShot(0, self._draw_timeline)
    return super().eventFilter(obj, event)
3166
+
3167
def _draw_frame_aggregated_timeline(self):
    """Draw the single-row timeline from frame-level aggregated segments.

    Each entry of ``self.aggregated_segments`` is a dict with:

    * ``'class'``: ``-1`` means filtered frames;
    * ``'start'`` / ``'end'``: inclusive frame indices;
    * ``'confidence'``: optional.

    The timeline spans the whole video. The length is ``self.total_frames``
    when positive, otherwise the frame count reported by OpenCV, otherwise
    one past the last segment's end. Width follows the zoom spinbox (pixels
    per second of video). Segments that do not match the behaviour filter
    are drawn light grey, so the layout is preserved.

    Click-handling geometry (``_frame_mode``, ``_pixels_per_frame``,
    ``_total_frames``) is stored on ``self.timeline_widget``.
    """
    if not self.aggregated_segments or not self.classes:
        return

    segments = self.aggregated_segments
    num_segments = len(segments)

    # Use the original video length so the timeline matches the full video.
    total_frames = self.total_frames or 0
    if total_frames <= 0 and self.video_path and os.path.exists(self.video_path):
        cap = None
        try:
            cap = cv2.VideoCapture(self.video_path)
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        except Exception:
            pass  # best effort: fall back to the segment extent below
        finally:
            # Release the capture even when reading the frame count failed.
            if cap is not None:
                try:
                    cap.release()
                except Exception:
                    pass
    if total_frames <= 0 and segments:
        total_frames = segments[-1]['end'] + 1

    # Width follows the zoom spinbox (px per second of video). The scroll
    # area handles overflow, so there is no maximum width.
    px_per_sec = float(getattr(self, "timeline_zoom_spin", None) and self.timeline_zoom_spin.value() or 100)
    orig_fps = self._get_video_fps(self.video_path) if self.video_path else 30.0
    if not orig_fps or orig_fps <= 0:
        # Avoid dividing by a missing/invalid FPS (width and duration below).
        orig_fps = 30.0
    pixels_per_frame = max(0.5, px_per_sec / orig_fps)
    total_width = max(800, int(total_frames * pixels_per_frame))
    height = 60

    self.timeline_widget.setMinimumWidth(total_width)
    pixmap = QPixmap(total_width, height)
    pixmap.fill(QColor(255, 255, 255))

    painter = QPainter(pixmap)
    painter.setRenderHint(QPainter.RenderHint.Antialiasing)

    colors = self._get_timeline_qcolors()
    selected_behavior = self.filter_behavior_combo.currentText()

    for seg in segments:
        pred_idx = seg['class']
        conf = seg.get('confidence', 1.0)
        start_frame = seg['start']
        end_frame = seg['end']

        # Pixel span of the inclusive frame range.
        x_start = int(start_frame * pixels_per_frame)
        x_end = int((end_frame + 1) * pixels_per_frame)
        seg_width = max(1, x_end - x_start)

        if pred_idx < 0:
            label = "Filtered"
        elif pred_idx < len(self.classes):
            label = self.classes[pred_idx]
        else:
            continue

        if selected_behavior == "All Behaviors":
            is_selected = True
        elif selected_behavior == self.ignore_label_name:
            is_selected = (pred_idx < 0)
        else:
            is_selected = (label == selected_behavior)

        if is_selected and pred_idx < 0:
            color = QColor(180, 180, 180)
            color.setAlpha(160)
        elif is_selected:
            # Copy the palette entry so setting alpha never mutates a colour
            # list that _get_timeline_qcolors() might cache or share.
            color = QColor(colors[pred_idx % len(colors)])
            color.setAlpha(int(min(255, 128 + 127 * min(1.0, conf))))
        else:
            color = QColor(240, 240, 240)
            color.setAlpha(255)

        painter.fillRect(x_start, 0, seg_width, height, color)

        # Draw the label only if the segment is wide enough.
        if seg_width >= 40 and is_selected:
            painter.setPen(QColor(0, 0, 0))
            painter.setFont(QFont("Arial", 8))
            painter.drawText(x_start + 3, 20, label)

        # Boundary marker at the segment start.
        if is_selected:
            painter.setPen(QColor(0, 0, 0, 100))
            painter.drawLine(x_start, 0, x_start, height)

    # Info text.
    painter.setPen(QColor(0, 0, 0))
    painter.setFont(QFont("Arial", 8))
    duration_sec = total_frames / orig_fps
    painter.drawText(5, height - 5,
                     f"Frame Timeline: {num_segments} segments, {total_frames} frames ({duration_sec:.1f}s) - precise boundaries (click to view)")

    painter.end()

    label = QLabel()
    label.setPixmap(pixmap)
    label.setFixedSize(total_width, height)

    old_layout = self.timeline_widget.layout()
    if old_layout:
        while old_layout.count():
            child = old_layout.takeAt(0)
            if child.widget():
                child.widget().deleteLater()
    else:
        layout = QVBoxLayout()
        layout.setContentsMargins(0, 0, 0, 0)
        self.timeline_widget.setLayout(layout)

    self.timeline_widget.layout().addWidget(label)
    self.timeline_widget.setFixedSize(total_width, height)

    # Store info for click handling in frame mode.
    self.timeline_widget._frame_mode = True
    self.timeline_widget._pixels_per_frame = pixels_per_frame
    self.timeline_widget._total_frames = total_frames
3289
+
3290
def _on_filter_changed(self, index: int):
    """Redraw the timeline after the behaviour-filter selection changes.

    ``index`` (the new combo-box index) is unused. Nothing is drawn until
    predictions exist.
    """
    if not self.predictions:
        return
    self._draw_timeline()
3294
+
3295
def _on_theme_changed(self, index: int):
    """Redraw the timeline after the colour-theme selection changes.

    ``index`` is unused. Nothing is drawn until predictions exist.
    """
    if not self.predictions:
        return
    self._draw_timeline()
3299
+
3300
def _get_timeline_palette(self) -> list[tuple[int, int, int]]:
    """Return the RGB palette for the currently selected timeline theme.

    Uses ``DEFAULT_THEME`` when this object has no ``timeline_theme_combo``
    attribute yet.
    """
    if hasattr(self, "timeline_theme_combo"):
        theme_name = self.timeline_theme_combo.currentText()
    else:
        theme_name = DEFAULT_THEME
    return get_timeline_palette(theme_name)
3303
+
3304
def _get_attr_idx(self, clip_idx: int):
    """Return the attribute label index for a clip, or ``None``.

    A manual correction in ``self.corrected_attr_labels`` takes precedence
    over the model output in ``self.attr_predictions``. ``None`` is returned
    when neither source has an entry for ``clip_idx``.
    """
    corrections = self.corrected_attr_labels
    if clip_idx in corrections:
        return corrections[clip_idx]
    predicted = self.attr_predictions
    if predicted and clip_idx < len(predicted):
        return predicted[clip_idx]
    return None
3310
+
3311
def _on_clip_length_changed(self, value: int):
    """Set the step size to half the new clip length, then re-check settings.

    The step spinbox is updated with its signals blocked, so it emits no
    change signal for this programmatic update. ``_on_step_or_clip_changed``
    is then called explicitly with the new step.

    The previous signal-blocking state is restored in ``finally``, so:

    * signals a caller had already blocked stay blocked;
    * a failing ``setValue`` cannot leave them blocked.
    """
    clip_length = int(value)
    spin = self.step_frames_spin
    # QObject.blockSignals returns the previous blocking state.
    was_blocked = spin.blockSignals(True)
    try:
        spin.setValue(max(1, clip_length // 2))
    finally:
        spin.blockSignals(was_blocked)
    self._on_step_or_clip_changed(spin.value())
3317
+
3318
def _on_step_or_clip_changed(self, value: int):
    """Turn on "Precise frame boundaries" when step and clip length differ.

    ``value`` is ignored; both spinboxes are re-read. When the step differs
    from the clip length and frame aggregation is off, it is switched on and
    a note is appended to the log.
    """
    step = self.step_frames_spin.value()
    clip = self.clip_length_spin.value()
    if step == clip or self.frame_aggregation_check.isChecked():
        return
    self.frame_aggregation_check.setChecked(True)
    self.log_text.append(
        f"Auto-enabled 'Precise frame boundaries' (step={step} ≠ clip={clip})"
    )
3327
+
3328
def _on_merge_changed(self, state: int):
    """Redraw the timeline when the "merge" checkbox toggles.

    ``state`` is unused. Nothing is drawn until predictions exist.
    """
    if not self.predictions:
        return
    self._draw_timeline()
3332
+
3333
def _on_frame_aggregation_changed(self, state: int):
    """React to the frame-aggregation checkbox.

    Once predictions exist: when ``state`` is truthy (checked), the
    aggregated timeline is recomputed first; the timeline is redrawn in
    either case.
    """
    if not self.predictions:
        return
    if state:
        self._compute_aggregated_timeline()
    self._draw_timeline()
3339
+
3340
def _on_timeline_zoom_changed(self, *_args):
    """Redraw the timeline at the new zoom level (signal arguments ignored).

    Nothing is drawn until predictions exist.
    """
    if not self.predictions:
        return
    self._draw_timeline()
3344
+
3345
def _on_ovr_show_all_changed(self, *_args):
    """Switch between co-occurrence-restricted and fully independent OvR display.

    Only acts once predictions exist and OvR mode is on. With frame
    aggregation checked, timeline segments are rebuilt first. The results
    view is then refreshed via ``_display_results``.
    """
    if not (self.predictions and self._use_ovr):
        return
    if self.frame_aggregation_check.isChecked():
        self._build_timeline_segments()
    self._display_results()
3351
+
3352
def _on_smoothing_changed(self, *_args):
    """Keep the config in sync with the smoothing settings and refresh.

    Copies the minimum-segment, merge-gap and temporal-smoothing window
    values into ``self.config``, then syncs the per-class rules. If
    predictions exist and frame aggregation is checked, the aggregated
    timeline is recomputed. ``_display_results`` is always called last.
    """
    for config_key, attr_name in (
        ("inference_min_segment_frames", "_min_segment_frames"),
        ("inference_merge_gap_frames", "_merge_gap_frames"),
        ("inference_temporal_smoothing_window_frames", "_temporal_smoothing_window_frames"),
    ):
        self.config[config_key] = getattr(self, attr_name)
    self._sync_per_class_segment_rule_config()
    needs_recompute = self.predictions and self.frame_aggregation_check.isChecked()
    if needs_recompute:
        self._compute_aggregated_timeline()
    self._display_results()
3361
+
3362
def _smooth_frame_labels(self, frame_labels: np.ndarray) -> np.ndarray:
    """Majority-vote smoothing of per-frame top-1 labels.

    For each frame with a valid label (``>= 0``), the frame takes the most
    frequent valid label in a centred window. The window holds
    ``self._smooth_win_for_class(label)`` frames, clipped at the array ends;
    ``-1`` entries are ignored. On a tie, the frame keeps its own label if
    that label is among the winners, otherwise the smallest tied label wins.

    Windows are always read from the unsmoothed input, so updates do not
    cascade. Frames labelled ``-1`` and windows of size ``<= 1`` are left
    untouched.

    Returns a new array for non-empty input; empty input is returned as-is.
    """
    if frame_labels.size == 0:
        return frame_labels

    source = frame_labels.copy()
    n_frames = int(source.shape[0])
    if n_frames <= 1:
        return source

    smoothed = source.copy()
    for pos, raw in enumerate(source):
        label = int(raw)
        if label < 0:
            continue
        window = self._smooth_win_for_class(label)
        if window <= 1:
            continue
        radius = window // 2
        neighbourhood = source[max(0, pos - radius):min(n_frames, pos + radius + 1)]
        valid = neighbourhood[neighbourhood >= 0]
        if valid.size == 0:
            continue
        tally = np.bincount(valid.astype(np.int64))
        tied = np.flatnonzero(tally == tally.max())
        if tied.size == 1:
            smoothed[pos] = int(tied[0])
        elif label in tied:
            smoothed[pos] = label
        else:
            smoothed[pos] = int(tied[0])
    return smoothed
3395
+
3396
def _apply_gap_merge_and_min_seg(self, frame_labels: np.ndarray, T: int) -> np.ndarray:
    """Fill short ``-1`` gaps between equal labels, then remove short runs.

    Args:
        frame_labels: Per-frame class labels, with ``-1`` meaning "no class".
        T: Number of frames to process (the first ``T`` entries).

    Gap fill: a run of ``-1`` is replaced by its neighbour's label when the
    labels on both sides are the same class ``c >= 0`` and the run is at
    most ``self._gap_fill_for_class(c)`` frames long.

    Minimum segment: a run of class ``c >= 0`` shorter than
    ``self._min_seg_for_class(c)`` is relabelled. The left neighbour is used
    whenever it is a real class, otherwise the right one, otherwise ``-1``.
    Passes repeat until nothing changes.

    Returns ``frame_labels`` itself when no gap/min-segment setting or
    per-class override is active; otherwise a modified copy.
    """
    default_gap = int(max(0, getattr(self, "_merge_gap_frames", 0)))
    default_min = int(max(1, getattr(self, "_min_segment_frames", 1)))
    gap_overrides = self.class_merge_gap_frames
    min_overrides = self.class_min_segment_frames
    if default_gap == 0 and default_min <= 1 and not (gap_overrides or min_overrides):
        return frame_labels

    out = frame_labels.copy()

    if default_gap > 0 or gap_overrides:
        pos = 0
        while pos < T:
            if out[pos] != -1:
                pos += 1
                continue
            stop = pos
            while stop + 1 < T and out[stop + 1] == -1:
                stop += 1
            before = int(out[pos - 1]) if pos > 0 else -1
            after = int(out[stop + 1]) if stop + 1 < T else -1
            if before >= 0 and before == after and stop - pos + 1 <= self._gap_fill_for_class(before):
                out[pos:stop + 1] = before
            pos = stop + 1

    if default_min > 1 or min_overrides:
        dirty = True
        while dirty:
            dirty = False
            pos = 0
            while pos < T:
                label = int(out[pos])
                stop = pos
                while stop + 1 < T and int(out[stop + 1]) == label:
                    stop += 1
                if label >= 0 and stop - pos + 1 < self._min_seg_for_class(label):
                    before = int(out[pos - 1]) if pos > 0 else -1
                    after = int(out[stop + 1]) if stop + 1 < T else -1
                    if before >= 0:
                        fill = before
                    elif after >= 0:
                        fill = after
                    else:
                        fill = -1
                    if fill != label:
                        out[pos:stop + 1] = fill
                        dirty = True
                pos = stop + 1

    return out
3462
+
3463
def _viterbi_decode_dense(self, probs: np.ndarray) -> np.ndarray:
    """Most likely single-label path through a (frames, classes) score matrix.

    Emission log-scores are ``log(clip(probs, 1e-8, 1))``. Staying on a class
    costs nothing; switching costs ``max(0, self.viterbi_switch_penalty)``.

    Returns an ``int32`` label array of length ``probs.shape[0]``. Input that
    is not 2-D or has an empty axis yields an empty array; a single-class
    matrix yields all zeros.
    """
    if probs.ndim != 2 or 0 in probs.shape:
        return np.zeros((0,), dtype=np.int32)
    n_frames, n_classes = probs.shape
    if n_classes == 1:
        return np.zeros((n_frames,), dtype=np.int32)

    log_emit = np.log(np.clip(probs.astype(np.float32, copy=False), 1e-8, 1.0))
    penalty = float(max(0.0, self.viterbi_switch_penalty))
    transition = np.full((n_classes, n_classes), -penalty, dtype=np.float32)
    np.fill_diagonal(transition, 0.0)

    best = np.empty((n_frames, n_classes), dtype=np.float32)
    came_from = np.zeros((n_frames, n_classes), dtype=np.int32)
    best[0] = log_emit[0]
    columns = np.arange(n_classes)
    for t in range(1, n_frames):
        # candidate[i, j]: best score ending in class i at t-1, then moving to j.
        candidate = best[t - 1][:, np.newaxis] + transition
        came_from[t] = np.argmax(candidate, axis=0)
        best[t] = log_emit[t] + candidate[came_from[t], columns]

    path = np.zeros((n_frames,), dtype=np.int32)
    path[-1] = int(np.argmax(best[-1]))
    for t in reversed(range(n_frames - 1)):
        path[t] = came_from[t + 1, path[t + 1]]
    return path
3490
+
3491
def _decode_viterbi_labels(self, fs: np.ndarray, covered_mask: np.ndarray) -> np.ndarray:
    """Viterbi-decode each contiguous covered stretch of frames separately.

    Every maximal run of frames where ``covered_mask`` is truthy is decoded
    independently with ``_viterbi_decode_dense``, left to right. Uncovered
    frames are labelled ``-1``. Returns an ``int32`` array of length
    ``fs.shape[0]``.
    """
    n_frames = int(fs.shape[0])
    decoded = np.full((n_frames,), -1, dtype=np.int32)
    run_start = None
    # Iterate one step past the end so a run reaching the last frame is flushed.
    for idx in range(n_frames + 1):
        inside = idx < n_frames and bool(covered_mask[idx])
        if inside and run_start is None:
            run_start = idx
        elif not inside and run_start is not None:
            decoded[run_start:idx] = self._viterbi_decode_dense(fs[run_start:idx])
            run_start = None
    return decoded
3506
+
3507
def _binary_viterbi_decode(self, probs: np.ndarray, threshold: float) -> np.ndarray:
    """Two-state (off/on) Viterbi decode of one OvR class's score track.

    Scores are flattened and clipped to ``[1e-6, 1 - 1e-6]``. ``threshold``
    is clipped to ``[1e-4, 1 - 1e-4]``. Emissions are log-ratios against the
    threshold, so a frame favours "on" exactly when its score exceeds it.
    Each on/off switch costs ``max(0, self.viterbi_switch_penalty)``.

    Ties favour arriving from "off" during the forward pass, and favour
    ending "on" at the last frame. Returns a boolean array (``True`` = on)
    with one entry per input score.
    """
    p = np.clip(np.asarray(probs, dtype=np.float32).reshape(-1), 1e-6, 1.0 - 1e-6)
    n_frames = int(p.shape[0])
    if n_frames == 0:
        return np.zeros((0,), dtype=bool)

    tau = float(np.clip(threshold, 1e-4, 1.0 - 1e-4))
    off_score = np.log(1.0 - p) - np.log(1.0 - tau)
    on_score = np.log(p) - np.log(tau)
    penalty = float(max(0.0, self.viterbi_switch_penalty))

    best = np.empty((n_frames, 2), dtype=np.float32)
    came_from = np.zeros((n_frames, 2), dtype=np.int8)
    best[0, 0] = off_score[0]
    best[0, 1] = on_score[0]

    for t in range(1, n_frames):
        prev_off = best[t - 1, 0]
        prev_on = best[t - 1, 1]

        # Into OFF: stay off, or drop from ON paying the penalty.
        from_on = prev_on - penalty
        src_off = 0 if prev_off >= from_on else 1
        best[t, 0] = off_score[t] + (prev_off if src_off == 0 else from_on)
        came_from[t, 0] = src_off

        # Into ON: rise from OFF paying the penalty, or stay on.
        from_off = prev_off - penalty
        src_on = 0 if from_off >= prev_on else 1
        best[t, 1] = on_score[t] + (from_off if src_on == 0 else prev_on)
        came_from[t, 1] = src_on

    path = np.zeros((n_frames,), dtype=np.int8)
    path[-1] = 1 if best[-1, 1] >= best[-1, 0] else 0
    for t in reversed(range(n_frames - 1)):
        path[t] = came_from[t + 1, path[t + 1]]
    return path.astype(bool)
3548
+
3549
+ def _build_timeline_segments(self):
3550
+ """Build timeline segments from precomputed frame probabilities."""
3551
+ if not isinstance(self._aggregated_frame_scores_norm, np.ndarray):
3552
+ self.aggregated_segments = []
3553
+ self.aggregated_multiclass_segments = []
3554
+ self._aggregated_active_mask = None
3555
+ self._aggregated_last_covered_frame = 0
3556
+ return
3557
+
3558
+ fs = self._aggregated_frame_scores_norm
3559
+ if fs.ndim != 2 or fs.shape[0] == 0 or fs.shape[1] == 0:
3560
+ self.aggregated_segments = []
3561
+ self.aggregated_multiclass_segments = []
3562
+ self._aggregated_active_mask = None
3563
+ self._aggregated_last_covered_frame = 0
3564
+ return
3565
+
3566
+ T, C = fs.shape
3567
+ self._aggregated_last_covered_frame = T
3568
+ covered_mask = np.sum(fs, axis=1) > 1e-8
3569
+ if self.use_viterbi_decode and not self._use_ovr:
3570
+ frame_labels = self._decode_viterbi_labels(fs, covered_mask)
3571
+ else:
3572
+ frame_labels = np.argmax(fs, axis=1)
3573
+ frame_labels[~covered_mask] = -1
3574
+ # Smooth raw argmax labels (majority-vote temporal smoothing).
3575
+ frame_labels = self._smooth_frame_labels(frame_labels)
3576
+
3577
+ # Apply ignore threshold after smoothing.
3578
+ if self.use_ignore_threshold and not self._use_ovr:
3579
+ for fi in range(T):
3580
+ ci = int(frame_labels[fi])
3581
+ if ci < 0:
3582
+ continue
3583
+ thr = self._threshold_for_pred(ci)
3584
+ if float(fs[fi, ci]) < thr:
3585
+ frame_labels[fi] = -1
3586
+
3587
+ # Merge-gap and min-segment run ONCE as the final cleanup, after all
3588
+ # other preprocessing (smoothing + threshold) is done.
3589
+ frame_labels = self._apply_gap_merge_and_min_seg(frame_labels, T)
3590
+
3591
+ segments = []
3592
+ cur_cls = int(frame_labels[0])
3593
+ cur_start = 0
3594
+ for i in range(1, T):
3595
+ if int(frame_labels[i]) != cur_cls:
3596
+ conf = float(np.mean(fs[cur_start:i, cur_cls])) if cur_cls >= 0 else 0.0
3597
+ segments.append({
3598
+ "class": int(cur_cls),
3599
+ "start": int(cur_start),
3600
+ "end": int(i - 1),
3601
+ "confidence": conf,
3602
+ })
3603
+ cur_cls = int(frame_labels[i])
3604
+ cur_start = i
3605
+ conf = float(np.mean(fs[cur_start:T, cur_cls])) if cur_cls >= 0 else 0.0
3606
+ segments.append({
3607
+ "class": int(cur_cls),
3608
+ "start": int(cur_start),
3609
+ "end": int(T - 1),
3610
+ "confidence": conf,
3611
+ })
3612
+ self.aggregated_segments = segments
3613
+
3614
+ self.aggregated_multiclass_segments = []
3615
+ self._aggregated_active_mask = None
3616
+ if self._use_ovr:
3617
+ active_mask = np.zeros((T, C), dtype=bool)
3618
+ ovr_threshold = None if self.use_ignore_threshold else 0.35
3619
+ show_all = getattr(self, "ovr_show_all_check", None) is not None and self.ovr_show_all_check.isChecked()
3620
+ if self.use_viterbi_decode:
3621
+ for ci in range(C):
3622
+ thr = self._threshold_for_pred(ci) if self.use_ignore_threshold else float(ovr_threshold)
3623
+ active_mask[:, ci] = self._binary_viterbi_decode(fs[:, ci], thr)
3624
+ if not show_all:
3625
+ pruned_mask = np.zeros_like(active_mask)
3626
+ for fi in range(T):
3627
+ active_idx = np.flatnonzero(active_mask[fi])
3628
+ if active_idx.size == 0:
3629
+ continue
3630
+ scores = fs[fi, active_idx]
3631
+ top_ci = int(active_idx[int(np.argmax(scores))])
3632
+ pruned_mask[fi, top_ci] = True
3633
+ if self._allowed_cooccurrence:
3634
+ top_name = self.classes[top_ci]
3635
+ for ci in active_idx:
3636
+ ci = int(ci)
3637
+ if ci == top_ci:
3638
+ continue
3639
+ name = self.classes[ci]
3640
+ if (top_name, name) in self._allowed_cooccurrence:
3641
+ pruned_mask[fi, ci] = True
3642
+ active_mask = pruned_mask
3643
+ else:
3644
+ for fi in range(T):
3645
+ for ci in self._active_ovr_indices_from_scores(fs[fi], threshold_override=ovr_threshold):
3646
+ if 0 <= ci < C:
3647
+ active_mask[fi, ci] = True
3648
+ self._aggregated_active_mask = active_mask
3649
+
3650
+ max_gap = int(max(0, getattr(self, "_merge_gap_frames", 0)))
3651
+ min_len = int(max(1, getattr(self, "_min_segment_frames", 1)))
3652
+
3653
+ # Per-class gap-fill and min-segment cleanup for OvR.
3654
+ for ci in range(C):
3655
+ col = active_mask[:, ci]
3656
+ smooth_win = self._smooth_win_for_class(ci)
3657
+ if (not self.use_viterbi_decode) and smooth_win > 1 and T > 1:
3658
+ half = smooth_win // 2
3659
+ smooth_col = col.copy()
3660
+ for i in range(T):
3661
+ left = max(0, i - half)
3662
+ right = min(T, i + half + 1)
3663
+ window = col[left:right]
3664
+ on_count = int(np.count_nonzero(window))
3665
+ off_count = int(window.size - on_count)
3666
+ if on_count > off_count:
3667
+ smooth_col[i] = True
3668
+ elif off_count > on_count:
3669
+ smooth_col[i] = False
3670
+ col = smooth_col
3671
+
3672
+ gap_thr = self._gap_fill_for_class(ci)
3673
+ min_len_cls = self._min_seg_for_class(ci)
3674
+ # Gap-fill: a short "off" gap between two "on" runs of the same class.
3675
+ if gap_thr > 0:
3676
+ i = 0
3677
+ while i < T:
3678
+ if col[i]:
3679
+ i += 1
3680
+ continue
3681
+ j = i
3682
+ while j + 1 < T and not col[j + 1]:
3683
+ j += 1
3684
+ gap_len = j - i + 1
3685
+ left_on = col[i - 1] if i > 0 else False
3686
+ right_on = col[j + 1] if j + 1 < T else False
3687
+ if gap_len <= gap_thr and left_on and right_on:
3688
+ col[i:j + 1] = True
3689
+ i = j + 1
3690
+ # Min-segment: remove short "on" runs.
3691
+ if min_len_cls > 1:
3692
+ i = 0
3693
+ while i < T:
3694
+ if not col[i]:
3695
+ i += 1
3696
+ continue
3697
+ j = i
3698
+ while j + 1 < T and col[j + 1]:
3699
+ j += 1
3700
+ run_len = j - i + 1
3701
+ if run_len < min_len_cls:
3702
+ col[i:j + 1] = False
3703
+ i = j + 1
3704
+ active_mask[:, ci] = col
3705
+
3706
+ if not show_all:
3707
+ pruned_mask = np.zeros_like(active_mask)
3708
+ for fi in range(T):
3709
+ active_idx = np.flatnonzero(active_mask[fi])
3710
+ if active_idx.size == 0:
3711
+ continue
3712
+ scores = fs[fi, active_idx]
3713
+ top_ci = int(active_idx[int(np.argmax(scores))])
3714
+ pruned_mask[fi, top_ci] = True
3715
+ if self._allowed_cooccurrence:
3716
+ top_name = self.classes[top_ci]
3717
+ for ci in active_idx:
3718
+ ci = int(ci)
3719
+ if ci == top_ci:
3720
+ continue
3721
+ name = self.classes[ci]
3722
+ if (top_name, name) in self._allowed_cooccurrence:
3723
+ pruned_mask[fi, ci] = True
3724
+ active_mask = pruned_mask
3725
+ self._aggregated_active_mask = active_mask
3726
+
3727
+ multi_segments = []
3728
+ for ci in range(C):
3729
+ run_start = None
3730
+ for fi in range(T):
3731
+ is_on = bool(active_mask[fi, ci])
3732
+ if is_on and run_start is None:
3733
+ run_start = fi
3734
+ if run_start is not None and (not is_on or fi == T - 1):
3735
+ run_end = fi if (is_on and fi == T - 1) else fi - 1
3736
+ if run_end >= run_start:
3737
+ conf = float(np.mean(fs[run_start:run_end + 1, ci]))
3738
+ multi_segments.append({
3739
+ "class": int(ci),
3740
+ "start": int(run_start),
3741
+ "end": int(run_end),
3742
+ "confidence": conf,
3743
+ })
3744
+ run_start = None
3745
+ self.aggregated_multiclass_segments = multi_segments
3746
+
3747
+ def _compute_aggregated_timeline(self):
3748
+ """
3749
+ Aggregate overlapping clip predictions into precise frame-level segments.
3750
+ Uses confidence-weighted voting: each clip votes for its predicted class
3751
+ with weight equal to its confidence score.
3752
+ """
3753
+ if not self.predictions or not self.classes:
3754
+ self.aggregated_segments = []
3755
+ self.aggregated_multiclass_segments = []
3756
+ self._aggregated_frame_scores_norm = None
3757
+ self._aggregated_active_mask = None
3758
+ self._aggregated_last_covered_frame = 0
3759
+ return
3760
+
3761
+ # Prefer the exact worker-built merged frame timeline when available.
3762
+ # This preserves center-weighted overlap merge.
3763
+ if not self.corrected_labels:
3764
+ precomputed = self._get_precomputed_aggregated_probs(self.video_path)
3765
+ if precomputed is not None:
3766
+ self._aggregated_frame_scores_norm = precomputed.copy()
3767
+ self._aggregated_last_covered_frame = int(precomputed.shape[0])
3768
+ self._build_timeline_segments()
3769
+ return
3770
+
3771
+ # Get parameters
3772
+ clip_length = self.clip_length_spin.value()
3773
+ step_frames = self.step_frames_spin.value()
3774
+ target_fps = self.target_fps_spin.value()
3775
+ num_classes = len(self.classes)
3776
+
3777
+ # Get total frames from video or estimate from clip_starts
3778
+ if self.total_frames > 0:
3779
+ total_frames = self.total_frames
3780
+ elif self.clip_starts:
3781
+ # Estimate: last clip start + clip_length
3782
+ total_frames = self.clip_starts[-1] + clip_length + 1
3783
+ else:
3784
+ self.aggregated_segments = []
3785
+ self.aggregated_multiclass_segments = []
3786
+ self._aggregated_frame_scores_norm = None
3787
+ self._aggregated_active_mask = None
3788
+ self._aggregated_last_covered_frame = 0
3789
+ return
3790
+
3791
+ orig_fps = self._get_video_fps(self.video_path) if self.video_path else 30.0
3792
+ frame_interval = self._get_saved_frame_interval(self.video_path, orig_fps)
3793
+
3794
+ # Apply corrections + ignore-threshold gating
3795
+ corrected_preds = self._effective_predictions()
3796
+
3797
+ # Initialize score matrix: [total_frames, num_classes] and coverage count
3798
+ frame_scores = np.zeros((total_frames, num_classes), dtype=np.float32)
3799
+ frame_coverage = np.zeros(total_frames, dtype=np.float32)
3800
+ used_full_probability_voting = False
3801
+ probs_available = (
3802
+ isinstance(self.clip_probabilities, list)
3803
+ and len(self.clip_probabilities) > 0
3804
+ )
3805
+ # Per-frame probabilities from FrameClassificationHead give much finer
3806
+ # temporal resolution: each frame within a clip gets its own prediction
3807
+ # instead of the entire clip being smeared with one probability vector.
3808
+ frame_probs_available = (
3809
+ isinstance(self.clip_frame_probabilities, list)
3810
+ and len(self.clip_frame_probabilities) > 0
3811
+ )
3812
+
3813
+ # Accumulate votes from each clip
3814
+ for clip_i, (pred_class, conf, start_frame) in enumerate(zip(corrected_preds, self.confidences, self.clip_starts)):
3815
+ end_frame = min(start_frame + clip_length * frame_interval, total_frames)
3816
+
3817
+ if start_frame >= total_frames:
3818
+ continue
3819
+
3820
+ # Best path: per-frame probabilities from frame head. Do NOT gate this
3821
+ # by clip-level class/confidence, since a mixed clip can contain multiple
3822
+ # behaviors with low clip confidence but useful frame-wise predictions.
3823
+ if frame_probs_available and clip_i < len(self.clip_frame_probabilities):
3824
+ fp = self.clip_frame_probabilities[clip_i]
3825
+ if isinstance(fp, (list, np.ndarray)):
3826
+ fp_arr = np.asarray(fp, dtype=np.float32) # [T, C]
3827
+ if fp_arr.ndim == 2 and fp_arr.shape[1] == num_classes:
3828
+ T_clip = fp_arr.shape[0]
3829
+ merge_w = self._build_center_merge_weights(T_clip)
3830
+ for t in range(T_clip):
3831
+ f_start = start_frame + t * frame_interval
3832
+ f_end = min(f_start + frame_interval, total_frames)
3833
+ if f_start >= total_frames:
3834
+ break
3835
+ if f_end <= f_start:
3836
+ continue
3837
+ probs_t = np.clip(fp_arr[t], 0.0, None)
3838
+ w = float(merge_w[t])
3839
+ frame_scores[f_start:f_end, :] += probs_t[np.newaxis, :] * w
3840
+ frame_coverage[f_start:f_end] += w
3841
+ used_full_probability_voting = True
3842
+ continue
3843
+
3844
+ # Remaining paths are clip-level and require a valid clip prediction.
3845
+ if not (0 <= pred_class < num_classes):
3846
+ continue
3847
+
3848
+ # Fallback: clip-level probabilities smeared across the clip
3849
+ if probs_available and clip_i < len(self.clip_probabilities):
3850
+ raw_probs = self.clip_probabilities[clip_i]
3851
+ if isinstance(raw_probs, (list, tuple, np.ndarray)) and len(raw_probs) == num_classes:
3852
+ probs_vec = np.asarray(raw_probs, dtype=np.float32)
3853
+ if np.all(np.isfinite(probs_vec)):
3854
+ probs_vec = np.clip(probs_vec, 0.0, None)
3855
+ s = float(np.sum(probs_vec))
3856
+ if s > 1e-8:
3857
+ if not self._use_ovr:
3858
+ probs_vec = probs_vec / s
3859
+ if (not self._use_ovr) and int(np.argmax(probs_vec)) != int(pred_class):
3860
+ probs_vec = np.zeros(num_classes, dtype=np.float32)
3861
+ probs_vec[int(pred_class)] = 1.0
3862
+ frame_scores[start_frame:end_frame, :] += probs_vec[np.newaxis, :]
3863
+ frame_coverage[start_frame:end_frame] += 1.0
3864
+ used_full_probability_voting = True
3865
+ continue
3866
+
3867
+ # Last fallback for older result files without per-class probabilities.
3868
+ frame_scores[start_frame:end_frame, pred_class] += float(conf)
3869
+ frame_coverage[start_frame:end_frame] += 1.0
3870
+
3871
+ # Normalize scores by clip coverage so confidence stays in [0, 1]
3872
+ frame_scores_norm = np.divide(
3873
+ frame_scores,
3874
+ np.maximum(frame_coverage[:, np.newaxis], 1.0),
3875
+ )
3876
+
3877
+ # For Softmax, ensure final probabilities sum to 1 (averaging usually does, but small errors can creep in)
3878
+ if not self._use_ovr:
3879
+ sums = np.sum(frame_scores_norm, axis=1, keepdims=True)
3880
+ valid_sums = sums > 1e-8
3881
+ frame_scores_norm = np.where(
3882
+ valid_sums,
3883
+ frame_scores_norm / sums,
3884
+ frame_scores_norm
3885
+ )
3886
+
3887
+ # Find the last frame with actual clip coverage to avoid phantom
3888
+ # class-0 labels from uncovered tail frames (argmax of all-zeros = 0).
3889
+ # Interior uncovered frames (filtered clips) are marked -1 ("Filtered").
3890
+ # Use coverage count (number of clips that touched each frame) rather than
3891
+ # score sum, which can be > 0 for uncovered frames in OvR (sigmoid > 0 always).
3892
+ covered_mask = frame_coverage > 0
3893
+ if not np.any(covered_mask):
3894
+ self.aggregated_segments = []
3895
+ self.aggregated_multiclass_segments = []
3896
+ self._aggregated_frame_scores_norm = None
3897
+ self._aggregated_active_mask = None
3898
+ self._aggregated_last_covered_frame = 0
3899
+ return
3900
+ last_covered_frame = int(np.max(np.where(covered_mask))) + 1 # exclusive end
3901
+ self._aggregated_last_covered_frame = last_covered_frame
3902
+ self._aggregated_frame_scores_norm = frame_scores_norm[:last_covered_frame].copy()
3903
+
3904
+ self._build_timeline_segments()
3905
+
3906
+ # Log summary
3907
+ if self.aggregated_segments:
3908
+ self.log_text.append(f"Frame aggregation: {len(self.aggregated_segments)} segments from {len(self.predictions)} clips")
3909
+ if self._use_ovr and self.aggregated_multiclass_segments:
3910
+ self.log_text.append(
3911
+ f" OvR precise multi-class segments: {len(self.aggregated_multiclass_segments)}"
3912
+ )
3913
+ if used_full_probability_voting:
3914
+ mode_name = "sigmoid" if getattr(self, "_use_ovr", False) else "softmax"
3915
+ self.log_text.append(f" Evidence mode: full {mode_name} probability voting")
3916
+ else:
3917
+ self.log_text.append(" Evidence mode: top-1 confidence voting")
3918
+ step_info = f"step={step_frames}, clip_length={clip_length}"
3919
+ if step_frames >= clip_length:
3920
+ self.log_text.append(f" Note: No overlap ({step_info}). For better boundaries, use step_frames < clip_length.")
3921
+ else:
3922
+ overlap_pct = (1 - step_frames / clip_length) * 100
3923
+ self.log_text.append(f" Overlap: {overlap_pct:.0f}% ({step_info})")
3924
+
3925
def _export_timeline(self):
    """Export timeline(s) as SVG plus a CSV of behavior segments.

    Asks which videos/classes to export (``_prompt_timeline_export_options``),
    then for a base file path (one video) or an output folder (several
    videos). Each selected video is loaded from ``self.results_cache`` and
    written with ``_export_current_timeline_to_base_path``. Videos without
    their own saved threshold settings use the settings that were active when
    the export started. The originally displayed video and its thresholds
    are restored afterwards, even if an export fails or is canceled.
    """
    if not self.predictions or not self.video_path:
        QMessageBox.warning(self, "Error", "No predictions available to export.")
        return

    self._persist_current_video_state()
    available_videos = [vp for vp, cached in self.results_cache.items() if isinstance(cached, dict)]
    if not available_videos:
        available_videos = [self.video_path]

    export_selection = self._prompt_timeline_export_options(available_videos)
    if not export_selection:
        return
    selected_videos, selected_classes = export_selection
    single_video = len(selected_videos) == 1

    if single_video:
        default_base = os.path.splitext(selected_videos[0])[0] + "_timeline"
        base_path, _ = QFileDialog.getSaveFileName(
            self,
            "Export Timeline",
            default_base,
            "All Files (*)"
        )
        if not base_path:
            return
        export_root = os.path.splitext(base_path)[0]
    else:
        default_dir = os.path.dirname(self.video_path) if self.video_path else os.getcwd()
        export_root = QFileDialog.getExistingDirectory(self, "Select Export Folder", default_dir)
        if not export_root:
            return

    original_video_path = self.video_path
    original_threshold_settings = self._current_threshold_settings()
    shared_threshold_settings = dict(original_threshold_settings)
    exported = []
    n_videos = len(selected_videos)
    progress = None  # only shown for multi-video exports
    export_canceled = False
    if n_videos > 1:
        progress = QProgressDialog(
            "Exporting timeline 1 / {}...".format(n_videos),
            "Cancel",
            0,
            n_videos,
            self,
        )
        progress.setWindowTitle("Timeline export")
        progress.setWindowModality(Qt.WindowModality.WindowModal)
        progress.setMinimumDuration(0)
        progress.setValue(0)
        progress.show()
        QApplication.processEvents()

    try:
        for vi, video_path in enumerate(selected_videos):
            if progress is not None:
                if progress.wasCanceled():
                    export_canceled = True
                    break
                progress.setValue(vi)
                progress.setLabelText(
                    "Exporting timeline {} / {}: {}".format(
                        vi + 1, n_videos, os.path.basename(video_path)
                    )
                )
                QApplication.processEvents()

            entry = self.results_cache.get(video_path)
            if not isinstance(entry, dict):
                # The fallback current video may have no dict entry in the
                # cache; treat it as having no saved threshold settings.
                entry = {}
            threshold_override = None
            if not isinstance(entry.get("threshold_settings"), dict):
                threshold_override = shared_threshold_settings
            ok = self._load_video_from_cache(
                video_path,
                refresh_display=False,
                persist_current=False,
                threshold_settings_override=threshold_override,
                persist_loaded_thresholds=False,
            )
            if not ok:
                continue

            if single_video:
                base_path = export_root
            else:
                base_path = os.path.join(
                    export_root,
                    os.path.splitext(os.path.basename(video_path))[0] + "_timeline",
                )
            csv_path, svg_path, mode_text = self._export_current_timeline_to_base_path(
                base_path,
                selected_classes=selected_classes,
            )
            exported.append((video_path, csv_path, svg_path, mode_text))
            if progress is not None:
                progress.setValue(vi + 1)
                QApplication.processEvents()
    except Exception as e:
        QMessageBox.critical(self, "Export Error", f"Failed to export timeline:\n{str(e)}")
        return  # `finally` still closes the progress dialog and restores state
    finally:
        if progress is not None:
            progress.close()
        if original_video_path and original_video_path in self.results_cache:
            self._load_video_from_cache(
                original_video_path,
                refresh_display=True,
                persist_current=False,
                threshold_settings_override=original_threshold_settings,
                persist_loaded_thresholds=False,
            )
            idx = self.filter_video_combo.findData(original_video_path)
            if idx >= 0:
                # Sync the combo box without re-triggering a video switch.
                self.filter_video_combo.blockSignals(True)
                self.filter_video_combo.setCurrentIndex(idx)
                self.filter_video_combo.blockSignals(False)

    if not exported:
        msg = "Export canceled." if export_canceled else "No timelines were exported."
        QMessageBox.warning(self, "Export", msg)
        return

    if len(exported) == 1:
        _, csv_path, svg_path, mode_text = exported[0]
        QMessageBox.information(
            self,
            "Export Complete",
            f"Timeline exported successfully!\n\n"
            f"Mode: {mode_text}\n"
            f"CSV: {csv_path}\n"
            f"SVG: {svg_path}"
        )
    else:
        QMessageBox.information(
            self,
            "Batch Export Complete",
            f"Exported timelines for {len(exported)} videos to:\n{export_root}"
        )

def _prompt_timeline_export_options(self, available_videos):
    """Ask which videos and classes to include in a timeline export.

    Shows a modal dialog with two checkable lists. By default the current
    video (or the only available video) and every class are checked.

    Returns:
        ``(video_paths, class_indices)`` -- a list of checked video paths and
        a set of checked class indices -- or ``None`` if the dialog is
        canceled or either list has nothing checked (an information box
        explains the latter).
    """
    role = Qt.ItemDataRole.UserRole
    checked = Qt.CheckState.Checked
    unchecked = Qt.CheckState.Unchecked
    checkable_flags = Qt.ItemFlag.ItemIsUserCheckable | Qt.ItemFlag.ItemIsEnabled | Qt.ItemFlag.ItemIsSelectable
    current_video = self.video_path

    def add_checkable(widget, text, payload, start_checked):
        entry = QListWidgetItem(text)
        entry.setData(role, payload)
        entry.setFlags(entry.flags() | checkable_flags)
        entry.setCheckState(checked if start_checked else unchecked)
        widget.addItem(entry)

    def apply_states(widget, state_for_payload):
        # state_for_payload maps an item's stored payload to its new check state.
        for row in range(widget.count()):
            entry = widget.item(row)
            entry.setCheckState(state_for_payload(entry.data(role)))

    def checked_payloads(widget):
        return [
            widget.item(row).data(role)
            for row in range(widget.count())
            if widget.item(row).checkState() == checked
        ]

    dlg = QDialog(self)
    dlg.setWindowTitle("Batch Timeline Export")
    dlg.resize(760, 460)
    layout = QVBoxLayout(dlg)
    layout.addWidget(QLabel("Choose which videos and classes to export. All classes are included by default."))

    content_row = QHBoxLayout()

    # Left column: videos.
    video_col = QVBoxLayout()
    video_col.addWidget(QLabel("Videos"))
    video_button_row = QHBoxLayout()
    select_all_btn = QPushButton("Select all")
    current_btn = QPushButton("Current video")
    clear_btn = QPushButton("Clear")
    for button in (select_all_btn, current_btn, clear_btn):
        video_button_row.addWidget(button)
    video_col.addLayout(video_button_row)

    list_widget = QListWidget()
    only_one_video = len(available_videos) == 1
    for vp in available_videos:
        add_checkable(list_widget, os.path.basename(vp), vp, only_one_video or vp == current_video)
    video_col.addWidget(list_widget, stretch=1)
    content_row.addLayout(video_col, stretch=1)

    # Right column: classes.
    class_col = QVBoxLayout()
    class_col.addWidget(QLabel("Classes"))
    class_button_row = QHBoxLayout()
    class_all_btn = QPushButton("All")
    class_none_btn = QPushButton("None")
    for button in (class_all_btn, class_none_btn):
        class_button_row.addWidget(button)
    class_col.addLayout(class_button_row)

    class_list_widget = QListWidget()
    for cls_idx, cls_name in enumerate(self.classes):
        add_checkable(class_list_widget, cls_name, cls_idx, True)
    class_col.addWidget(class_list_widget, stretch=1)
    content_row.addLayout(class_col, stretch=1)

    layout.addLayout(content_row, stretch=1)

    select_all_btn.clicked.connect(lambda: apply_states(list_widget, lambda _vp: checked))
    clear_btn.clicked.connect(lambda: apply_states(list_widget, lambda _vp: unchecked))
    current_btn.clicked.connect(
        lambda: apply_states(list_widget, lambda vp: checked if vp == current_video else unchecked)
    )
    class_all_btn.clicked.connect(lambda: apply_states(class_list_widget, lambda _ci: checked))
    class_none_btn.clicked.connect(lambda: apply_states(class_list_widget, lambda _ci: unchecked))

    btns = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel)
    layout.addWidget(btns)
    btns.accepted.connect(dlg.accept)
    btns.rejected.connect(dlg.reject)

    if not dlg.exec():
        return None

    chosen_videos = checked_payloads(list_widget)
    if not chosen_videos:
        QMessageBox.information(self, "No videos selected", "Select at least one video to export.")
        return None

    chosen_classes = {int(ci) for ci in checked_payloads(class_list_widget)}
    if not chosen_classes:
        QMessageBox.information(self, "No classes selected", "Select at least one class to export.")
        return None

    return chosen_videos, chosen_classes

+ def _export_current_timeline_to_base_path(self, base_path: str, selected_classes: set[int] | None = None):
4172
+ if not self.predictions or not self.video_path:
4173
+ raise RuntimeError("No predictions available to export.")
4174
+
4175
+ orig_fps = self._get_video_fps(self.video_path)
4176
+ frame_interval = self._get_saved_frame_interval(self.video_path, orig_fps)
4177
+ clip_length = self.clip_length_spin.value()
4178
+ frame_aggregation_enabled = self.frame_aggregation_check.isChecked()
4179
+
4180
+ csv_path = base_path + "_behaviors.csv"
4181
+ import csv
4182
+
4183
+ if frame_aggregation_enabled and self.aggregated_segments:
4184
+ with open(csv_path, 'w', newline='') as f:
4185
+ writer = csv.writer(f)
4186
+ writer.writerow(['Behavior', 'Start Time (s)', 'End Time (s)', 'Start Frame', 'End Frame', 'Duration (s)', 'Confidence'])
4187
+
4188
+ seg_source = self.aggregated_segments
4189
+ if self._use_ovr and self.aggregated_multiclass_segments:
4190
+ seg_source = self.aggregated_multiclass_segments
4191
+
4192
+ for seg in seg_source:
4193
+ pred_idx = seg['class']
4194
+ if pred_idx < len(self.classes) and (selected_classes is None or pred_idx in selected_classes):
4195
+ behavior = self.classes[pred_idx]
4196
+ start_frame = seg['start']
4197
+ end_frame = seg['end']
4198
+ conf = seg.get('confidence', 1.0)
4199
+
4200
+ start_time = start_frame / orig_fps
4201
+ end_time = (end_frame + 1) / orig_fps
4202
+ duration = end_time - start_time
4203
+
4204
+ writer.writerow([
4205
+ behavior,
4206
+ f"{start_time:.3f}",
4207
+ f"{end_time:.3f}",
4208
+ start_frame,
4209
+ end_frame,
4210
+ f"{duration:.3f}",
4211
+ f"{conf:.3f}"
4212
+ ])
4213
+
4214
+ svg_path = base_path + "_timeline.svg"
4215
+ self._export_frame_aggregated_svg(svg_path, orig_fps, selected_classes=selected_classes)
4216
+ mode_text = "frame-aggregated (precise boundaries)"
4217
+ else:
4218
+ corrected_preds = self._effective_predictions()
4219
+
4220
+ if self.merge_timeline_check.isChecked():
4221
+ display_preds, display_confs, display_starts = self._merge_predictions(
4222
+ corrected_preds, self.confidences, self.clip_starts
4223
+ )
4224
+ else:
4225
+ display_preds = corrected_preds
4226
+ display_confs = self.confidences
4227
+ display_starts = self.clip_starts
4228
+
4229
+ with open(csv_path, 'w', newline='') as f:
4230
+ writer = csv.writer(f)
4231
+ writer.writerow(['Behavior', 'Start Time (s)', 'End Time (s)', 'Start Frame', 'End Frame', 'Confidence'])
4232
+
4233
+ for i, (pred_idx, conf, start_frame) in enumerate(zip(display_preds, display_confs, display_starts)):
4234
+ if pred_idx < len(self.classes) and pred_idx >= 0 and (selected_classes is None or pred_idx in selected_classes):
4235
+ behavior = self.classes[pred_idx]
4236
+ if i < len(display_starts) - 1:
4237
+ end_frame = display_starts[i + 1]
4238
+ else:
4239
+ end_frame = start_frame + (clip_length * frame_interval)
4240
+ start_time = start_frame / orig_fps
4241
+ end_time = end_frame / orig_fps
4242
+ writer.writerow([behavior, f"{start_time:.3f}", f"{end_time:.3f}", start_frame, end_frame, f"{conf:.3f}"])
4243
+
4244
+ svg_path = base_path + "_timeline.svg"
4245
+ self._export_timeline_svg(
4246
+ svg_path,
4247
+ display_preds,
4248
+ display_confs,
4249
+ display_starts,
4250
+ orig_fps,
4251
+ frame_interval,
4252
+ clip_length,
4253
+ selected_classes=selected_classes,
4254
+ )
4255
+ mode_text = "clip-based"
4256
+
4257
+ return csv_path, svg_path, mode_text
4258
+
4259
+ def _export_timeline_svg(self, svg_path, display_preds, display_confs, display_starts, orig_fps, frame_interval, clip_length, selected_classes: set[int] | None = None):
4260
+ """Export timeline as SVG."""
4261
+ num_segments = len(display_preds)
4262
+ if num_segments == 0:
4263
+ return
4264
+
4265
+ # Calculate dimensions
4266
+ base_clip_width = 20
4267
+ total_clips = len(self.predictions)
4268
+ width = max(1200, total_clips * base_clip_width)
4269
+ height = 80
4270
+
4271
+ colors = self._get_timeline_palette()
4272
+
4273
+ with open(svg_path, 'w') as f:
4274
+ f.write(f'<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">\n')
4275
+ f.write(f'<rect width="{width}" height="{height}" fill="white"/>\n')
4276
+
4277
+ x_pos = 0
4278
+ for seg_idx, (pred_idx, conf) in enumerate(zip(display_preds, display_confs)):
4279
+ if seg_idx < len(display_starts) - 1:
4280
+ seg_start_clip = display_starts[seg_idx]
4281
+ seg_end_clip = display_starts[seg_idx + 1]
4282
+ clip_count = sum(1 for orig_start in self.clip_starts if seg_start_clip <= orig_start < seg_end_clip)
4283
+ seg_width = clip_count * base_clip_width
4284
+ else:
4285
+ seg_start_clip = display_starts[seg_idx]
4286
+ clip_count = sum(1 for orig_start in self.clip_starts if orig_start >= seg_start_clip)
4287
+ seg_width = clip_count * base_clip_width
4288
+
4289
+ if pred_idx < len(self.classes) and (selected_classes is None or pred_idx in selected_classes):
4290
+ color = colors[pred_idx % len(colors)]
4291
+ r, g, b = color
4292
+ alpha = conf
4293
+
4294
+ f.write(f'<rect x="{x_pos}" y="0" width="{seg_width}" height="{height}" '
4295
+ f'fill="rgb({r},{g},{b})" opacity="{alpha:.2f}"/>\n')
4296
+
4297
+ if seg_width >= 30:
4298
+ behavior = self.classes[pred_idx]
4299
+ f.write(f'<text x="{x_pos + 5}" y="{height//2 + 5}" font-family="Arial" font-size="10">{behavior}</text>\n')
4300
+
4301
+ x_pos += seg_width
4302
+
4303
+ f.write('</svg>\n')
4304
+
4305
+ def _export_frame_aggregated_svg(self, svg_path: str, orig_fps: float, selected_classes: set[int] | None = None):
4306
+ """Export frame-aggregated timeline as SVG with precise boundaries."""
4307
+ if not self.aggregated_segments:
4308
+ return
4309
+
4310
+ use_multiclass_rows = bool(self._use_ovr and self.aggregated_multiclass_segments)
4311
+ segments = self.aggregated_multiclass_segments if use_multiclass_rows else self.aggregated_segments
4312
+ row_classes = [ci for ci in range(len(self.classes)) if selected_classes is None or ci in selected_classes]
4313
+ if not row_classes:
4314
+ return
4315
+ row_lookup = {cls_idx: row_idx for row_idx, cls_idx in enumerate(row_classes)}
4316
+ total_frames = self.total_frames
4317
+ if total_frames <= 0 and self.video_path and os.path.exists(self.video_path):
4318
+ try:
4319
+ cap = cv2.VideoCapture(self.video_path)
4320
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
4321
+ cap.release()
4322
+ except Exception:
4323
+ pass
4324
+ if total_frames <= 0 and segments:
4325
+ total_frames = segments[-1]['end'] + 1
4326
+
4327
+ # Calculate dimensions
4328
+ min_width = 1200
4329
+ max_width = 4000
4330
+ pixels_per_frame = max(0.1, min(2.0, max_width / max(1, total_frames)))
4331
+ width = max(min_width, int(total_frames * pixels_per_frame))
4332
+ if use_multiclass_rows:
4333
+ row_h = 20
4334
+ height = max(80, row_h * len(row_classes))
4335
+ else:
4336
+ row_h = 80
4337
+ height = 80
4338
+
4339
+ colors = self._get_timeline_palette()
4340
+
4341
+ with open(svg_path, 'w') as f:
4342
+ f.write(f'<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">\n')
4343
+ f.write(f'<rect width="{width}" height="{height}" fill="white"/>\n')
4344
+
4345
+ for seg in segments:
4346
+ pred_idx = seg['class']
4347
+ conf = seg.get('confidence', 1.0)
4348
+ start_frame = seg['start']
4349
+ end_frame = seg['end']
4350
+
4351
+ if pred_idx >= len(self.classes) or (selected_classes is not None and pred_idx not in selected_classes):
4352
+ continue
4353
+
4354
+ x_start = int(start_frame * pixels_per_frame)
4355
+ x_end = int((end_frame + 1) * pixels_per_frame)
4356
+ seg_width = max(1, x_end - x_start)
4357
+
4358
+ color = colors[pred_idx % len(colors)]
4359
+ r, g, b = color
4360
+ # Normalize confidence for opacity
4361
+ alpha = min(1.0, 0.5 + 0.5 * min(1.0, conf))
4362
+
4363
+ y0 = (row_lookup[pred_idx] * row_h) if use_multiclass_rows else 0
4364
+ h0 = row_h if use_multiclass_rows else height
4365
+ f.write(f'<rect x="{x_start}" y="{y0}" width="{seg_width}" height="{h0}" '
4366
+ f'fill="rgb({r},{g},{b})" opacity="{alpha:.2f}"/>\n')
4367
+
4368
+ # Draw boundary line
4369
+ f.write(f'<line x1="{x_start}" y1="{y0}" x2="{x_start}" y2="{y0 + h0}" '
4370
+ f'stroke="black" stroke-width="0.5" opacity="0.3"/>\n')
4371
+
4372
+ if seg_width >= 40:
4373
+ behavior = self.classes[pred_idx]
4374
+ start_time = start_frame / orig_fps
4375
+ end_time = (end_frame + 1) / orig_fps
4376
+ ty = y0 + min(15, h0 - 5)
4377
+ f.write(f'<text x="{x_start + 3}" y="{ty}" font-family="Arial" font-size="10">{behavior}</text>\n')
4378
+ if not use_multiclass_rows and h0 >= 30:
4379
+ f.write(f'<text x="{x_start + 3}" y="30" font-family="Arial" font-size="8" fill="gray">'
4380
+ f'{start_time:.2f}s-{end_time:.2f}s</text>\n')
4381
+
4382
+ # Add time axis markers
4383
+ duration_sec = total_frames / orig_fps
4384
+ tick_interval = max(1, int(duration_sec / 20)) # ~20 ticks
4385
+ for t in range(0, int(duration_sec) + 1, tick_interval):
4386
+ x = int(t * orig_fps * pixels_per_frame)
4387
+ if x < width:
4388
+ f.write(f'<line x1="{x}" y1="{height-10}" x2="{x}" y2="{height}" stroke="black" stroke-width="1"/>\n')
4389
+ f.write(f'<text x="{x}" y="{height-2}" font-family="Arial" font-size="8" text-anchor="middle">{t}s</text>\n')
4390
+
4391
+ f.write('</svg>\n')
4392
+
4393
def _show_clip_popup(self, idx: int, frame_mode: bool = False, ovr_class_idx: int = -1):
    """Show a popup dialog with clip video, label, and confidence.

    Args:
        idx: Clip index (clip mode) or frame index (frame mode).
        frame_mode: If True, ``idx`` is a frame index; the segment covering
            that frame is looked up and shown via ``_show_frame_segment_popup``.
        ovr_class_idx: When clicking a per-class OvR row, the class index of the
            clicked row. If it is a valid class index, the popup shows this
            specific class instead of the top-1 class from the aggregated timeline.

    Does nothing when no video is loaded, or when no matching segment / clip
    exists for ``idx``.
    """
    if not self.video_path:
        return

    if frame_mode:
        # ``idx`` is a frame index here. Without aggregated segments there is
        # nothing to show; never fall through to clip mode, which would
        # misinterpret the frame index as a clip index.
        if not self.aggregated_segments:
            return
        frame_idx = idx

        # A specific OvR class row was clicked: build a virtual segment for
        # that class around this frame so the user can inspect it.
        if 0 <= ovr_class_idx < len(self.classes):
            segment = self._build_ovr_class_segment(frame_idx, ovr_class_idx)
            if segment is not None:
                self._show_frame_segment_popup(frame_idx, segment)
                return

        # Default: the main aggregated-timeline segment covering this frame.
        segment = next(
            (seg for seg in self.aggregated_segments
             if seg['start'] <= frame_idx <= seg['end']),
            None,
        )
        # Skip segments whose class index has no entry in self.classes.
        if segment is None or segment['class'] >= len(self.classes):
            return
        self._show_frame_segment_popup(frame_idx, segment)
        return

    # Clip mode: ``idx`` indexes self.predictions; reject out-of-range
    # (including negative) indices rather than wrapping around.
    if not 0 <= idx < len(self.predictions):
        return
    ClipPopupDialog(self, self, idx)
4445
+
4446
+ def _build_ovr_class_segment(self, frame_idx: int, class_idx: int) -> dict | None:
4447
+ """Build a virtual segment for a specific OvR class around frame_idx.
4448
+
4449
+ Finds the contiguous run of active frames for class_idx that contains
4450
+ frame_idx, using the stored per-frame scores / active mask.
4451
+ Returns a segment dict compatible with FrameSegmentPopupDialog, or None.
4452
+ """
4453
+ active_mask = getattr(self, "_aggregated_active_mask", None)
4454
+ frame_scores = getattr(self, "_aggregated_frame_scores_norm", None)
4455
+ if not isinstance(active_mask, np.ndarray) or class_idx >= active_mask.shape[1]:
4456
+ return None
4457
+ total_frames = active_mask.shape[0]
4458
+ if frame_idx < 0 or frame_idx >= total_frames:
4459
+ return None
4460
+
4461
+ is_active = bool(active_mask[frame_idx, class_idx])
4462
+ # Even if the frame isn't active for this class, still show a
4463
+ # single-frame segment so the user can inspect its score.
4464
+ if not is_active:
4465
+ conf = float(frame_scores[frame_idx, class_idx]) if isinstance(frame_scores, np.ndarray) else 0.0
4466
+ return {"class": class_idx, "start": frame_idx, "end": frame_idx, "confidence": conf}
4467
+
4468
+ # Expand outward to find the contiguous active run
4469
+ start = frame_idx
4470
+ while start > 0 and bool(active_mask[start - 1, class_idx]):
4471
+ start -= 1
4472
+ end = frame_idx
4473
+ while end < total_frames - 1 and bool(active_mask[end + 1, class_idx]):
4474
+ end += 1
4475
+
4476
+ conf = float(np.mean(frame_scores[start:end + 1, class_idx])) if isinstance(frame_scores, np.ndarray) else 1.0
4477
+ return {"class": class_idx, "start": start, "end": end, "confidence": conf}
4478
+
4479
def _show_frame_segment_popup(self, frame_idx: int, segment: dict, segment_idx: int | None = None):
    """Show the popup for a frame-aggregated segment.

    Args:
        frame_idx: The specific frame that was clicked.
        segment: Segment dict with 'class', 'start', 'end' and 'confidence'.
        segment_idx: Index of the segment in self.aggregated_segments, passed
            on to the dialog (for navigation). When None, it is looked up by
            matching start, end AND class, so a virtual segment (e.g. built
            for an OvR row) is never paired with a different class's segment
            that happens to span the same frames. It stays None when nothing
            matches.
    """
    if not self.video_path:
        return

    if segment_idx is None:
        wanted = (segment['start'], segment['end'], segment.get('class'))
        segment_idx = next(
            (i for i, seg in enumerate(self.aggregated_segments or [])
             if (seg['start'], seg['end'], seg.get('class')) == wanted),
            None,
        )

    FrameSegmentPopupDialog(self, self, frame_idx, segment, segment_idx)
4498
+
4499
def _export_video_with_overlay(self):
    """Export the video with configurable overlays.

    Thin wrapper around ``overlay_export.run_export_video_with_overlay``,
    which receives this widget. The import happens at call time (presumably
    to keep module import light or avoid an import cycle — confirm).
    """
    from .overlay_export import run_export_video_with_overlay as _run_export

    _run_export(self)
4503
+
4504
def _preview_video_with_overlay(self):
    """Open a video player previewing the overlay export.

    Thin wrapper around ``overlay_export.run_preview_video_with_overlay``,
    which receives this widget; the import is deferred until the call.
    """
    from .overlay_export import run_preview_video_with_overlay as _run_preview

    _run_preview(self)
4508
+
4509
def _export_attention_heatmap(self):
    """Export a video with a spatial attention heatmap overlay.

    Thin wrapper around ``attention_export.export_attention_heatmap_video``,
    which receives this widget; the import is deferred until the call.
    """
    from .attention_export import export_attention_heatmap_video as _export_heatmap

    _export_heatmap(self)
4513
+
4514
+ def update_config(self, config: dict):
4515
+ """Apply a new configuration (experiment switch)."""
4516
+ self.config = config
4517
+ self.model = None
4518
+ self.classes = []
4519
+ self.model_path_edit.clear()
4520
+ self.video_path = None
4521
+ self.video_path_edit.clear()
4522
+ self.video_info_label.setText("No video selected")
4523
+ self.run_inference_btn.setEnabled(False)
4524
+ self.export_btn.setEnabled(False)
4525
+ self.export_timeline_btn.setEnabled(False)
4526
+ self.preview_btn.setEnabled(False)
4527
+ self.results_list.clear()
4528
+ self.progress_label.setText("")
4529
+ self.progress_bar.setVisible(False)
4530
+ self.progress_bar.setValue(0)
4531
+ self.predictions = []
4532
+ self.confidences = []
4533
+ self.clip_probabilities = []
4534
+ self.clip_frame_probabilities = []
4535
+ self.clip_starts = []
4536
+ self.corrected_labels = {}
4537
+ self.corrected_attr_labels = {}
4538
+ self.aggregated_segments = []
4539
+ self.aggregated_multiclass_segments = []
4540
+ self._aggregated_frame_scores_norm = None
4541
+ self._aggregated_active_mask = None
4542
+ self._aggregated_last_covered_frame = 0
4543
+
4544
+ timeline_layout = self.timeline_widget.layout()
4545
+ if timeline_layout:
4546
+ while timeline_layout.count():
4547
+ child = timeline_layout.takeAt(0)
4548
+ if child.widget():
4549
+ child.widget().deleteLater()
4550
+