singlebehaviorlab 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sam2/__init__.py +11 -0
- sam2/automatic_mask_generator.py +454 -0
- sam2/benchmark.py +92 -0
- sam2/build_sam.py +174 -0
- sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
- sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
- sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
- sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
- sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
- sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
- sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
- sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
- sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
- sam2/modeling/__init__.py +5 -0
- sam2/modeling/backbones/__init__.py +5 -0
- sam2/modeling/backbones/hieradet.py +317 -0
- sam2/modeling/backbones/image_encoder.py +134 -0
- sam2/modeling/backbones/utils.py +93 -0
- sam2/modeling/memory_attention.py +169 -0
- sam2/modeling/memory_encoder.py +181 -0
- sam2/modeling/position_encoding.py +239 -0
- sam2/modeling/sam/__init__.py +5 -0
- sam2/modeling/sam/mask_decoder.py +295 -0
- sam2/modeling/sam/prompt_encoder.py +202 -0
- sam2/modeling/sam/transformer.py +311 -0
- sam2/modeling/sam2_base.py +913 -0
- sam2/modeling/sam2_utils.py +323 -0
- sam2/sam2_hiera_b+.yaml +113 -0
- sam2/sam2_hiera_l.yaml +117 -0
- sam2/sam2_hiera_s.yaml +116 -0
- sam2/sam2_hiera_t.yaml +118 -0
- sam2/sam2_image_predictor.py +466 -0
- sam2/sam2_video_predictor.py +1388 -0
- sam2/sam2_video_predictor_legacy.py +1172 -0
- sam2/utils/__init__.py +5 -0
- sam2/utils/amg.py +348 -0
- sam2/utils/misc.py +349 -0
- sam2/utils/transforms.py +118 -0
- singlebehaviorlab/__init__.py +4 -0
- singlebehaviorlab/__main__.py +130 -0
- singlebehaviorlab/_paths.py +100 -0
- singlebehaviorlab/backend/__init__.py +2 -0
- singlebehaviorlab/backend/augmentations.py +320 -0
- singlebehaviorlab/backend/data_store.py +420 -0
- singlebehaviorlab/backend/model.py +1290 -0
- singlebehaviorlab/backend/train.py +4667 -0
- singlebehaviorlab/backend/uncertainty.py +578 -0
- singlebehaviorlab/backend/video_processor.py +688 -0
- singlebehaviorlab/backend/video_utils.py +139 -0
- singlebehaviorlab/data/config/config.yaml +85 -0
- singlebehaviorlab/data/training_profiles.json +334 -0
- singlebehaviorlab/gui/__init__.py +4 -0
- singlebehaviorlab/gui/analysis_widget.py +2291 -0
- singlebehaviorlab/gui/attention_export.py +311 -0
- singlebehaviorlab/gui/clip_extraction_widget.py +481 -0
- singlebehaviorlab/gui/clustering_widget.py +3187 -0
- singlebehaviorlab/gui/inference_popups.py +1138 -0
- singlebehaviorlab/gui/inference_widget.py +4550 -0
- singlebehaviorlab/gui/inference_worker.py +651 -0
- singlebehaviorlab/gui/labeling_widget.py +2324 -0
- singlebehaviorlab/gui/main_window.py +754 -0
- singlebehaviorlab/gui/metadata_management_widget.py +1119 -0
- singlebehaviorlab/gui/motion_tracking.py +764 -0
- singlebehaviorlab/gui/overlay_export.py +1234 -0
- singlebehaviorlab/gui/plot_integration.py +729 -0
- singlebehaviorlab/gui/qt_helpers.py +29 -0
- singlebehaviorlab/gui/registration_widget.py +1485 -0
- singlebehaviorlab/gui/review_widget.py +1330 -0
- singlebehaviorlab/gui/segmentation_tracking_widget.py +2752 -0
- singlebehaviorlab/gui/tab_tutorial_dialog.py +312 -0
- singlebehaviorlab/gui/timeline_themes.py +131 -0
- singlebehaviorlab/gui/training_profiles.py +418 -0
- singlebehaviorlab/gui/training_widget.py +3719 -0
- singlebehaviorlab/gui/video_utils.py +233 -0
- singlebehaviorlab/licenses/SAM2-LICENSE +201 -0
- singlebehaviorlab/licenses/VideoPrism-LICENSE +202 -0
- singlebehaviorlab-2.0.0.dist-info/METADATA +447 -0
- singlebehaviorlab-2.0.0.dist-info/RECORD +88 -0
- singlebehaviorlab-2.0.0.dist-info/WHEEL +5 -0
- singlebehaviorlab-2.0.0.dist-info/entry_points.txt +2 -0
- singlebehaviorlab-2.0.0.dist-info/licenses/LICENSE +21 -0
- singlebehaviorlab-2.0.0.dist-info/top_level.txt +3 -0
- videoprism/__init__.py +0 -0
- videoprism/encoders.py +910 -0
- videoprism/layers.py +1136 -0
- videoprism/models.py +407 -0
- videoprism/tokenizers.py +167 -0
- videoprism/utils.py +168 -0
|
@@ -0,0 +1,4550 @@
|
|
|
1
|
+
from PyQt6.QtWidgets import (
|
|
2
|
+
QWidget, QVBoxLayout, QHBoxLayout, QPushButton, QLabel, QLineEdit,
|
|
3
|
+
QFileDialog, QGroupBox, QFormLayout, QMessageBox, QListWidget, QListWidgetItem,
|
|
4
|
+
QSpinBox, QComboBox, QTextEdit, QScrollArea, QDialog, QCheckBox, QSizePolicy,
|
|
5
|
+
QProgressBar, QProgressDialog, QDoubleSpinBox, QDialogButtonBox, QApplication,
|
|
6
|
+
QSplitter, QGridLayout,
|
|
7
|
+
)
|
|
8
|
+
from PyQt6.QtCore import QThread, pyqtSignal, Qt, QTimer, QEvent
|
|
9
|
+
from PyQt6.QtGui import QPixmap, QImage, QPainter, QFont, QColor, QMouseEvent
|
|
10
|
+
import copy
|
|
11
|
+
import cv2
|
|
12
|
+
import os
|
|
13
|
+
import torch
|
|
14
|
+
import numpy as np
|
|
15
|
+
import json
|
|
16
|
+
from singlebehaviorlab.backend.model import VideoPrismBackbone, BehaviorClassifier
|
|
17
|
+
from singlebehaviorlab.backend.video_utils import get_video_info, save_clip
|
|
18
|
+
from singlebehaviorlab.backend.data_store import AnnotationManager
|
|
19
|
+
from singlebehaviorlab.backend.uncertainty import save_uncertainty_report
|
|
20
|
+
from .timeline_themes import TIMELINE_COLOR_THEMES, DEFAULT_THEME, get_palette as get_timeline_palette
|
|
21
|
+
from .inference_worker import InferenceWorker, _sanitize_bbox_coords
|
|
22
|
+
from .inference_popups import ClipPopupDialog, FrameSegmentPopupDialog
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class TimelineWidget(QWidget):
    """Clickable timeline strip.

    The widget translates the horizontal position of a mouse press into
    either a frame index (frame mode) or a clip index (clip mode) and
    forwards that index to ``click_callback``. The attributes that drive
    the conversion (``clip_width``, ``num_clips``, ``_frame_mode``,
    ``_pixels_per_frame``, ``_total_frames``) are plain attributes that the
    owner of the widget is expected to update.
    """

    def __init__(self, parent=None):
        """Create an empty timeline with no clips and no click handler."""
        super().__init__(parent)
        self.clip_width = 20  # horizontal pixels occupied by one clip
        self.num_clips = 0
        self.click_callback = None  # callable(idx) / callable(idx, frame_mode=True)
        self._frame_mode = False
        self._pixels_per_frame = 1.0
        self._total_frames = 0
        self.setMinimumHeight(100)
        self.setStyleSheet("background-color: white; border: 1px solid gray;")

    def mousePressEvent(self, event: QMouseEvent):
        """Map a click to a frame or clip index and invoke ``click_callback``.

        In frame mode the callback receives ``(frame_idx, frame_mode=True)``;
        in clip mode it receives ``(clip_idx)``. Clicks that fall outside the
        valid index range are ignored, as are all clicks while no callback
        is installed.
        """
        if not self.click_callback:
            return

        x = event.position().x()
        if x < 0:
            # int() truncates toward zero, so a slightly negative position
            # would otherwise be reported as index 0.
            return

        if self._frame_mode and self._pixels_per_frame > 0:
            # Frame-aggregated mode: convert pixel to frame index.
            frame_idx = int(x / self._pixels_per_frame)
            if 0 <= frame_idx < self._total_frames:
                self.click_callback(frame_idx, frame_mode=True)
        elif self.num_clips > 0 and self.clip_width > 0:
            # Clip-based mode. The width guard mirrors the frame-mode guard
            # above and avoids a ZeroDivisionError on a zero clip width.
            clip_idx = int(x / self.clip_width)
            if 0 <= clip_idx < self.num_clips:
                self.click_callback(clip_idx)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class InferenceWidget(QWidget):
|
|
61
|
+
"""Widget for running inference on videos."""
|
|
62
|
+
|
|
63
|
+
# Emitted when inference finishes so the Review tab can load results.
|
|
64
|
+
# Payload: (results_dict, classes, is_ovr, clip_length, target_fps)
|
|
65
|
+
review_ready = pyqtSignal(dict, list, bool, int, int)
|
|
66
|
+
|
|
67
|
+
def __init__(self, config: dict):
    """Initialise inference state from *config* and build the UI.

    Args:
        config: Application configuration mapping. The ``inference_*`` keys
            read below are all optional; the fallback used for each one is
            visible at its ``get`` call.
    """
    super().__init__()
    self.config = config

    def _config_mapping(key):
        # A key that is present but holds None (for example an empty
        # mapping entry in a hand-edited config file) must behave like a
        # missing key instead of crashing on ``.items()`` / ``dict()``.
        return self.config.get(key) or {}

    def _per_class_int_map(key):
        # Normalise per-class overrides to ``{str(class): int(value)}``.
        return {str(name): int(value) for name, value in _config_mapping(key).items()}

    # --- Loaded model -----------------------------------------------------
    self.model = None
    self.classes = []
    self.attributes = []
    self.label_mapping = None

    # --- Selected videos and per-clip results -------------------------------
    self.video_path = None
    self.video_paths = []  # List of selected videos
    self.predictions = []
    self.confidences = []
    self.clip_probabilities = []
    self.clip_frame_probabilities = []  # per-frame probs from FrameClassificationHead
    self.attr_predictions = []  # Store attribute indices
    self.attr_confidences = []
    self.clip_starts = []
    self.localization_bboxes = []
    self.results_cache = {}  # Cache for multi-video results
    self.worker = None
    self.exported_video_path = None
    self.corrected_labels = {}  # Map clip_idx -> corrected_label_index
    self.corrected_attr_labels = {}  # Map clip_idx -> corrected_attr_index
    self.attributes_registry = None
    self.hierarchy_registry = None

    # --- Frame-level aggregation state --------------------------------------
    # Frame-level top-1 segments: [{'class': int, 'start': int, 'end': int}, ...]
    self.aggregated_segments = []
    self.aggregated_multiclass_segments = []  # OvR frame-level per-class segments
    self._aggregated_frame_scores_norm = None  # np.ndarray [frames, classes]
    self._aggregated_active_mask = None  # np.ndarray [frames, classes], bool
    self._aggregated_last_covered_frame = 0
    self.total_frames = 0  # Total frames in current video
    self.clip_popup_maximized = False  # Track if popup should be maximized
    self.infer_resolution = 288
    self._bbox_ema_alpha = 0.85  # Minimal smoothing: 85% current frame, 15% previous

    # --- Post-processing settings taken from the config ---------------------
    self._min_segment_frames = int(self.config.get("inference_min_segment_frames", 1))
    self._merge_gap_frames = int(self.config.get("inference_merge_gap_frames", 0))
    raw_smooth_win = int(self.config.get("inference_temporal_smoothing_window_frames", 1))
    # The window is clamped to at least 1 and bumped to the next odd value.
    self._temporal_smoothing_window_frames = max(1, raw_smooth_win)
    if self._temporal_smoothing_window_frames % 2 == 0:
        self._temporal_smoothing_window_frames += 1
    self.use_ignore_threshold = bool(self.config.get("inference_use_ignore_threshold", False))
    self.global_ignore_threshold = float(self.config.get("inference_ignore_threshold", 0.60))
    self.class_ignore_thresholds = dict(_config_mapping("inference_class_ignore_thresholds"))
    self.class_min_segment_frames = _per_class_int_map("inference_class_min_segment_frames")
    self.class_merge_gap_frames = _per_class_int_map("inference_class_merge_gap_frames")
    self.class_smoothing_window_frames = _per_class_int_map(
        "inference_class_smoothing_window_frames"
    )
    self.use_viterbi_decode = bool(self.config.get("inference_use_viterbi_decode", False))
    self.viterbi_switch_penalty = float(self.config.get("inference_viterbi_switch_penalty", 0.35))
    self.ignore_label_name = "Ignored / Unknown"

    # --- Model-dependent flags ----------------------------------------------
    self.model_training_config = {}
    self._use_ovr = False
    self._allowed_cooccurrence = set()
    self._ignore_threshold_user_modified = False
    self._applying_auto_threshold = False

    # The UI reads several of the attributes above, so it is built last.
    self._setup_ui()
    self._update_viterbi_ui_state()
|
|
128
|
+
|
|
129
|
+
def _setup_ui(self):
    """Create all child widgets, lay them out and connect their signals.

    From top to bottom the tab holds: the model/video pickers next to the
    inference parameters, the row of action buttons, the progress bar and
    label, the timeline group (controls, single-row timeline, OvR
    multi-row timeline, preset buttons) and the results list next to the
    log view.
    """
    bars = Qt.ScrollBarPolicy
    grow = QSizePolicy.Policy

    def tip(*lines):
        """Join tooltip lines with newlines."""
        return "\n".join(lines)

    def make_spin(lowest, highest, initial):
        """Return a QSpinBox limited to [lowest, highest] showing *initial*."""
        box = QSpinBox()
        box.setRange(lowest, highest)
        box.setValue(initial)
        return box

    def make_double_spin(highest, initial):
        """Return a two-decimal QDoubleSpinBox on [0, highest], step 0.05."""
        box = QDoubleSpinBox()
        box.setDecimals(2)
        box.setRange(0.0, highest)
        box.setSingleStep(0.05)
        box.setValue(initial)
        return box

    def path_row(line_edit, button):
        """Return a horizontal layout holding a path field and its button."""
        row = QHBoxLayout()
        for child in (line_edit, button):
            row.addWidget(child)
        return row

    root = QVBoxLayout()

    # ---- Model picker ------------------------------------------------------
    model_box = QGroupBox("Model")
    model_form = QFormLayout()

    self.model_path_edit = QLineEdit()
    self.model_path_edit.setReadOnly(True)
    self.model_browse_btn = QPushButton("Load head weights...")
    self.model_browse_btn.clicked.connect(self._load_model)
    model_form.addRow("Model:", path_row(self.model_path_edit, self.model_browse_btn))

    self.classes_label = QLabel("No model loaded")
    model_form.addRow("Classes:", self.classes_label)
    model_box.setLayout(model_form)

    # ---- Video picker ------------------------------------------------------
    video_box = QGroupBox("Video selection")
    video_form = QFormLayout()

    self.video_path_edit = QLineEdit()
    self.video_path_edit.setReadOnly(True)
    self.video_browse_btn = QPushButton("Select video(s)...")
    self.video_browse_btn.clicked.connect(self._select_video)
    video_form.addRow("Video:", path_row(self.video_path_edit, self.video_browse_btn))

    self.video_info_label = QLabel("No video selected")
    video_form.addRow("Info:", self.video_info_label)
    video_box.setLayout(video_form)

    pickers_column = QVBoxLayout()
    for box in (model_box, video_box):
        pickers_column.addWidget(box)
    pickers_panel = QWidget()
    pickers_panel.setLayout(pickers_column)

    # ---- Inference parameters ----------------------------------------------
    params_box = QGroupBox("Inference parameters")
    params_form = QFormLayout()

    self.target_fps_spin = make_spin(1, 60, 16)
    params_form.addRow("Target FPS:", self.target_fps_spin)

    self.clip_length_spin = make_spin(1, 64, 16)
    params_form.addRow("Frames per clip:", self.clip_length_spin)

    # Step starts at half the clip length; handlers are attached only after
    # the initial values are in place.
    self.step_frames_spin = make_spin(1, 64, max(1, self.clip_length_spin.value() // 2))
    self.step_frames_spin.setToolTip(tip(
        "Number of subsampled frames to advance between clips.",
        "Defaults to half of 'Frames per clip' (50% overlap) but can be changed.",
        "Smaller values mean more overlap and finer temporal resolution.",
    ))
    self.step_frames_spin.valueChanged.connect(self._on_step_or_clip_changed)
    params_form.addRow("Step frames:", self.step_frames_spin)

    self.clip_length_spin.valueChanged.connect(self._on_clip_length_changed)

    self.resolution_spin = make_spin(64, 1024, 288)
    params_form.addRow("Resolution:", self.resolution_spin)

    self.override_resolution_check = QCheckBox("Override model resolution")
    self.override_resolution_check.setToolTip("Use the resolution above instead of model metadata")
    params_form.addRow("", self.override_resolution_check)

    self.collect_attention_check = QCheckBox("Collect attention maps")
    self.collect_attention_check.setToolTip(tip(
        "Record spatial attention weights during inference.",
        "Enables exporting a heatmap video showing what the model focuses on.",
        "Minimal performance impact.",
    ))
    params_form.addRow("", self.collect_attention_check)

    self.sample_inference_check = QCheckBox("Quick-check sampled inference")
    self.sample_inference_check.setToolTip(tip(
        "Run inference only on evenly spread chunks of each video.",
        "Useful for checking model behavior on long videos without running full inference.",
    ))
    self.sample_inference_check.toggled.connect(self._update_sample_inference_controls)
    params_form.addRow("", self.sample_inference_check)

    self.sample_duration_spin = make_spin(10, 300, 60)
    self.sample_duration_spin.setSuffix(" s")
    self.sample_duration_spin.setToolTip("Duration of each sampled chunk per video.")
    params_form.addRow("Chunk duration:", self.sample_duration_spin)

    self.sample_count_spin = make_spin(1, 50, 5)
    self.sample_count_spin.setToolTip("Number of sampled chunks spread evenly across each video.")
    params_form.addRow("Number of chunks:", self.sample_count_spin)
    # Called explicitly once both sampling spin boxes exist.
    self._update_sample_inference_controls(False)

    params_box.setLayout(params_form)

    upper_splitter = QSplitter(Qt.Orientation.Horizontal)
    upper_splitter.addWidget(pickers_panel)
    upper_splitter.addWidget(params_box)
    for pane_index in (0, 1):
        upper_splitter.setStretchFactor(pane_index, 1)
    root.addWidget(upper_splitter)

    # ---- Action buttons ----------------------------------------------------
    actions_row = QHBoxLayout()

    self.run_inference_btn = QPushButton("Run inference")
    self.run_inference_btn.clicked.connect(self._run_inference)
    self.run_inference_btn.setEnabled(False)
    actions_row.addWidget(self.run_inference_btn)

    self.stop_inference_btn = QPushButton("Stop inference")
    self.stop_inference_btn.clicked.connect(self._stop_inference)
    self.stop_inference_btn.setEnabled(False)
    self.stop_inference_btn.setToolTip("Stop batch inference; already-completed videos are kept.")
    actions_row.addWidget(self.stop_inference_btn)

    self.load_timeline_btn = QPushButton("Load timeline results")
    self.load_timeline_btn.clicked.connect(self._load_timeline_results)
    actions_row.addWidget(self.load_timeline_btn)

    self.export_btn = QPushButton("Export video with overlays")
    self.export_btn.clicked.connect(self._export_video_with_overlay)
    self.export_btn.setEnabled(False)
    actions_row.addWidget(self.export_btn)

    self.preview_btn = QPushButton("Preview video with overlays")
    self.preview_btn.clicked.connect(self._preview_video_with_overlay)
    self.preview_btn.setEnabled(False)
    actions_row.addWidget(self.preview_btn)

    self.export_attention_btn = QPushButton("Export attention heatmap")
    self.export_attention_btn.clicked.connect(self._export_attention_heatmap)
    self.export_attention_btn.setEnabled(False)
    self.export_attention_btn.setToolTip(
        "Export video with spatial attention overlay showing what the model focuses on"
    )
    actions_row.addWidget(self.export_attention_btn)

    root.addLayout(actions_row)

    # ---- Progress ----------------------------------------------------------
    self.progress_bar = QProgressBar()
    self.progress_bar.setVisible(False)
    self.progress_bar.setRange(0, 100)
    self.progress_bar.setValue(0)
    root.addWidget(self.progress_bar)

    self.progress_label = QLabel("")
    root.addWidget(self.progress_label)

    # ---- Timeline controls -------------------------------------------------
    timeline_box = QGroupBox("Timeline")
    timeline_column = QVBoxLayout()
    controls_row = QHBoxLayout()
    controls_row.setSpacing(8)
    controls_row.setContentsMargins(4, 2, 4, 2)
    add_control = controls_row.addWidget

    add_control(QLabel("Timeline:"))

    add_control(QLabel("Video:"))
    self.filter_video_combo = QComboBox()
    self.filter_video_combo.currentIndexChanged.connect(self._on_video_selection_changed)
    add_control(self.filter_video_combo)

    controls_row.addStretch()

    add_control(QLabel("Show behavior:"))
    self.filter_behavior_combo = QComboBox()
    self.filter_behavior_combo.addItem("All Behaviors")
    self.filter_behavior_combo.currentIndexChanged.connect(self._on_filter_changed)
    add_control(self.filter_behavior_combo)

    self.use_ignore_threshold_check = QCheckBox("Ignore low-confidence")
    self.use_ignore_threshold_check.setChecked(self.use_ignore_threshold)
    self.use_ignore_threshold_check.stateChanged.connect(self._on_ignore_threshold_changed)
    add_control(self.use_ignore_threshold_check)

    self.ignore_threshold_spin = make_double_spin(1.0, self.global_ignore_threshold)
    self.ignore_threshold_spin.setToolTip(tip(
        "Fallback threshold for classes without a per-class threshold.",
        "Clips below this confidence are grayed out.",
    ))
    self.ignore_threshold_spin.valueChanged.connect(self._on_ignore_threshold_changed)
    add_control(QLabel("Default τ:"))
    add_control(self.ignore_threshold_spin)

    self.per_class_thresh_btn = QPushButton("Per-class τ")
    self.per_class_thresh_btn.clicked.connect(self._open_per_class_thresholds_dialog)
    add_control(self.per_class_thresh_btn)

    add_control(QLabel("Theme:"))
    self.timeline_theme_combo = QComboBox()
    self.timeline_theme_combo.addItems(list(TIMELINE_COLOR_THEMES.keys()))
    self.timeline_theme_combo.setCurrentText(DEFAULT_THEME)
    self.timeline_theme_combo.setToolTip("Change timeline color theme")
    self.timeline_theme_combo.currentIndexChanged.connect(self._on_theme_changed)
    add_control(self.timeline_theme_combo)

    self.merge_timeline_check = QCheckBox("Merge consecutive identical behaviors")
    self.merge_timeline_check.setToolTip("Merge consecutive clips with the same predicted behavior")
    self.merge_timeline_check.stateChanged.connect(self._on_merge_changed)
    add_control(self.merge_timeline_check)

    self.frame_aggregation_check = QCheckBox("Precise frame boundaries")
    self.frame_aggregation_check.setToolTip(tip(
        "Use overlapping clip votes to determine precise behavior boundaries.",
        "Best used with step_frames < clip_length (e.g., step=4 for 16-frame clips).",
        "Each frame gets votes from all clips that cover it, weighted by confidence.",
    ))
    self.frame_aggregation_check.stateChanged.connect(self._on_frame_aggregation_changed)
    add_control(self.frame_aggregation_check)

    self.use_viterbi_check = QCheckBox("Viterbi decode")
    self.use_viterbi_check.setToolTip(tip(
        "Inference-only sequence decoding on the merged frame probabilities.",
        "Single-label models use classic Viterbi over classes.",
        "OvR models use binary per-class Viterbi with co-occurrence-aware cleanup.",
    ))
    self.use_viterbi_check.setChecked(self.use_viterbi_decode)
    self.use_viterbi_check.stateChanged.connect(self._on_viterbi_changed)
    add_control(self.use_viterbi_check)

    add_control(QLabel("Viterbi switch:"))
    self.viterbi_switch_penalty_spin = make_double_spin(5.0, self.viterbi_switch_penalty)
    self.viterbi_switch_penalty_spin.setToolTip(tip(
        "Penalty for changing behavior between adjacent frames.",
        "Higher values make the decoded sequence more stable.",
    ))
    self.viterbi_switch_penalty_spin.valueChanged.connect(self._on_viterbi_changed)
    add_control(self.viterbi_switch_penalty_spin)

    self.per_class_seg_btn = QPushButton("Per-class seg rules")
    self.per_class_seg_btn.setToolTip(
        "Set smooth window, gap fill, and minimum segment length per behavior class."
    )
    self.per_class_seg_btn.clicked.connect(self._open_per_class_segment_rules_dialog)
    add_control(self.per_class_seg_btn)

    self.ovr_rows_check = QCheckBox("Per-class rows")
    self.ovr_rows_check.setChecked(True)
    self.ovr_rows_check.setToolTip(
        "Per-class rows: one row per class. Uncheck for single-row timeline (OvR models only)."
    )
    # The lambda drops the signal's state argument before redrawing.
    self.ovr_rows_check.stateChanged.connect(lambda: self._draw_timeline())
    add_control(self.ovr_rows_check)

    self.ovr_show_all_check = QCheckBox("Show all classes")
    self.ovr_show_all_check.setChecked(False)
    self.ovr_show_all_check.setToolTip(tip(
        "When enabled, every class above its threshold is shown independently",
        "(no mutual exclusivity). When disabled, only the top-1 class and",
        "allowed co-occurring classes are shown.",
    ))
    self.ovr_show_all_check.stateChanged.connect(self._on_ovr_show_all_changed)
    add_control(self.ovr_show_all_check)

    add_control(QLabel("Zoom (px/s):"))
    self.timeline_zoom_spin = make_spin(10, 2000, 100)
    self.timeline_zoom_spin.setSingleStep(20)
    self.timeline_zoom_spin.setToolTip(tip(
        "Timeline horizontal zoom: pixels per second of video.",
        "Increase to spread the timeline out and make short segments clickable.",
        "The timeline area scrolls horizontally when zoomed in.",
    ))
    self.timeline_zoom_spin.valueChanged.connect(self._on_timeline_zoom_changed)
    add_control(self.timeline_zoom_spin)

    self.export_timeline_btn = QPushButton("Export CSV/SVG")
    self.export_timeline_btn.setToolTip(
        "Export timeline as CSV (behavior segments) and SVG (visualization) for external analysis"
    )
    self.export_timeline_btn.clicked.connect(self._export_timeline)
    self.export_timeline_btn.setEnabled(False)
    add_control(self.export_timeline_btn)

    self.save_results_btn = QPushButton("Save results")
    self.save_results_btn.setToolTip(
        "Save inference results (including corrections) for downstream analysis"
    )
    self.save_results_btn.clicked.connect(self._save_results)
    self.save_results_btn.setEnabled(False)
    add_control(self.save_results_btn)

    # The control strip lives in a horizontally scrolling area so every
    # setting stays reachable on narrow screens.
    controls_panel = QWidget()
    controls_panel.setLayout(controls_row)
    controls_panel.setMinimumHeight(40)

    controls_scroll = QScrollArea()
    controls_scroll.setWidget(controls_panel)
    controls_scroll.setWidgetResizable(True)
    controls_scroll.setHorizontalScrollBarPolicy(bars.ScrollBarAsNeeded)
    controls_scroll.setVerticalScrollBarPolicy(bars.ScrollBarAlwaysOff)
    controls_scroll.setFrameShape(QScrollArea.Shape.NoFrame)
    controls_scroll.setMinimumHeight(48)
    controls_scroll.setMaximumHeight(62)

    # ---- Single-row timeline -----------------------------------------------
    self.timeline_scroll = QScrollArea()
    self.timeline_scroll.setWidgetResizable(False)
    self.timeline_scroll.setHorizontalScrollBarPolicy(bars.ScrollBarAsNeeded)
    self.timeline_scroll.setVerticalScrollBarPolicy(bars.ScrollBarAlwaysOff)
    self.timeline_scroll.setMinimumHeight(100)
    self.timeline_scroll.setSizePolicy(grow.Expanding, grow.Expanding)

    self.timeline_widget = TimelineWidget()
    self.timeline_widget.click_callback = self._show_clip_popup
    self.timeline_scroll.setWidget(self.timeline_widget)

    # ---- OvR multi-row timeline: label column + timeline, scrolled together --
    self._ovr_timeline_container = QWidget()
    self._ovr_timeline_container.setVisible(False)
    self._ovr_timeline_container.setSizePolicy(grow.Expanding, grow.Preferred)
    ovr_row = QHBoxLayout(self._ovr_timeline_container)
    ovr_row.setContentsMargins(0, 0, 0, 0)
    ovr_row.setSpacing(0)

    # Left side: class labels; both scroll bars hidden, position follows the timeline.
    self._ovr_label_scroll = QScrollArea()
    self._ovr_label_scroll.setFixedWidth(100)
    self._ovr_label_scroll.setWidgetResizable(False)
    self._ovr_label_scroll.setHorizontalScrollBarPolicy(bars.ScrollBarAlwaysOff)
    self._ovr_label_scroll.setVerticalScrollBarPolicy(bars.ScrollBarAlwaysOff)
    self._ovr_label_scroll.setSizePolicy(grow.Fixed, grow.Expanding)
    self._ovr_label_panel = QLabel()
    self._ovr_label_scroll.setWidget(self._ovr_label_panel)
    ovr_row.addWidget(self._ovr_label_scroll)

    # Right side: the timeline itself.
    self._ovr_scroll = QScrollArea()
    self._ovr_scroll.setWidgetResizable(False)
    self._ovr_scroll.setHorizontalScrollBarPolicy(bars.ScrollBarAlwaysOn)
    self._ovr_scroll.setVerticalScrollBarPolicy(bars.ScrollBarAlwaysOff)
    self._ovr_scroll.setSizePolicy(grow.Expanding, grow.Expanding)
    self._ovr_scroll.setMinimumWidth(300)
    self._ovr_timeline_widget = QLabel()
    # Clicks on the label are routed to this widget's own handler.
    self._ovr_timeline_widget.mousePressEvent = self._ovr_timeline_click
    self._ovr_clip_width = 20
    self._ovr_num_clips = 0
    self._ovr_row_height = 24
    self._ovr_timeline_frame_mode = False
    self._ovr_pixels_per_frame = 1.0
    self._ovr_scroll.setWidget(self._ovr_timeline_widget)
    ovr_row.addWidget(self._ovr_scroll)
    self._ovr_scroll.viewport().installEventFilter(self)
    self._ovr_timeline_container.installEventFilter(self)

    # Keep the label column aligned with the timeline's vertical position.
    self._ovr_scroll.verticalScrollBar().valueChanged.connect(
        self._ovr_label_scroll.verticalScrollBar().setValue
    )

    timeline_column.addWidget(controls_scroll)
    timeline_column.addWidget(self.timeline_scroll, 1)
    timeline_column.addWidget(self._ovr_timeline_container, 1)

    # ---- Timeline preset buttons -------------------------------------------
    preset_row = QHBoxLayout()
    preset_row.addStretch()
    self.save_timeline_settings_btn = QPushButton("Save timeline settings")
    self.save_timeline_settings_btn.setToolTip(
        "Save ignore filtering and timeline postprocessing/display settings as a reusable preset."
    )
    self.save_timeline_settings_btn.clicked.connect(self._save_timeline_settings_preset)
    preset_row.addWidget(self.save_timeline_settings_btn)
    self.load_timeline_settings_btn = QPushButton("Load timeline settings")
    self.load_timeline_settings_btn.setToolTip("Load a previously saved timeline settings preset.")
    self.load_timeline_settings_btn.clicked.connect(self._load_timeline_settings_preset)
    preset_row.addWidget(self.load_timeline_settings_btn)
    timeline_column.addLayout(preset_row)

    timeline_box.setLayout(timeline_column)
    timeline_box.setSizePolicy(grow.Preferred, grow.Expanding)
    root.addWidget(timeline_box, 1)

    # ---- Results and logs --------------------------------------------------
    results_box = QGroupBox("Results")
    results_column = QVBoxLayout()
    self.results_list = QListWidget()
    results_column.addWidget(self.results_list)
    results_box.setLayout(results_column)

    log_box = QGroupBox("Logs")
    log_column = QVBoxLayout()
    self.log_text = QTextEdit()
    self.log_text.setReadOnly(True)
    log_column.addWidget(self.log_text)
    log_box.setLayout(log_column)

    lower_splitter = QSplitter(Qt.Orientation.Horizontal)
    lower_splitter.addWidget(results_box)
    lower_splitter.addWidget(log_box)
    for pane_index in (0, 1):
        lower_splitter.setStretchFactor(pane_index, 1)
    lower_splitter.setMaximumHeight(320)
    root.addWidget(lower_splitter, 0)

    self.setLayout(root)
|
|
538
|
+
|
|
539
|
+
def _collect_timeline_settings_payload(self) -> dict:
    """Snapshot the current timeline and ignore-filter settings as a plain dict.

    Values are read from the widgets and attributes on ``self`` and coerced
    to ``bool``/``int``/``float`` (per-class maps keep their keys and coerce
    only the values). The optional OvR checkboxes are reported as ``False``
    when the attribute is missing or ``None``.

    Returns:
        dict: settings keyed by the names the preset loader reads back.
    """

    def optional_box_checked(attr_name):
        # Missing/None checkbox attributes count as "unchecked".
        box = getattr(self, attr_name, None)
        return bool(box is not None and box.isChecked())

    def int_values(mapping):
        return {cls: int(v) for cls, v in mapping.items()}

    # Built entry by entry in the original key order so dumped presets keep
    # the same layout.
    settings = {}
    settings["frame_aggregation_enabled"] = bool(self.frame_aggregation_check.isChecked())
    settings["merge_consecutive_enabled"] = bool(self.merge_timeline_check.isChecked())
    settings["use_viterbi_decode"] = bool(self.use_viterbi_decode)
    settings["viterbi_switch_penalty"] = float(self.viterbi_switch_penalty)
    settings["use_ignore_threshold"] = bool(self.use_ignore_threshold)
    settings["ignore_threshold"] = float(self.global_ignore_threshold)
    settings["class_ignore_thresholds"] = {
        cls: float(t) for cls, t in self.class_ignore_thresholds.items()
    }
    settings["timeline_zoom"] = int(self.timeline_zoom_spin.value())
    settings["ovr_rows"] = optional_box_checked("ovr_rows_check")
    settings["ovr_show_all"] = optional_box_checked("ovr_show_all_check")
    settings["min_segment_frames"] = int(self._min_segment_frames)
    settings["merge_gap_frames"] = int(self._merge_gap_frames)
    settings["temporal_smoothing_window_frames"] = int(self._temporal_smoothing_window_frames)
    settings["class_min_segment_frames"] = int_values(self.class_min_segment_frames)
    settings["class_merge_gap_frames"] = int_values(self.class_merge_gap_frames)
    settings["class_smoothing_window_frames"] = int_values(self.class_smoothing_window_frames)
    return settings
|
|
566
|
+
|
|
567
|
+
def _apply_timeline_settings_payload(self, payload: dict):
    """Apply a timeline-settings dict to the widget state and refresh the view.

    Only keys present in ``payload`` are applied (``use_ignore_threshold``
    falls back to the current value). Per-class maps are replaced only when
    the stored value is itself a dict. A non-dict ``payload`` is ignored.

    Statement order is kept deliberately: widget setters may emit signals,
    so attributes are assigned before the matching widget is updated.
    """
    if not isinstance(payload, dict):
        return

    def has_dict(key):
        return key in payload and isinstance(payload[key], dict)

    def has_widget(attr_name):
        return getattr(self, attr_name, None) is not None

    # --- ignore filtering -------------------------------------------------
    self.use_ignore_threshold = bool(payload.get("use_ignore_threshold", self.use_ignore_threshold))
    self.use_ignore_threshold_check.setChecked(self.use_ignore_threshold)
    if "ignore_threshold" in payload:
        self.global_ignore_threshold = float(payload["ignore_threshold"])
        self.ignore_threshold_spin.setValue(self.global_ignore_threshold)
    if has_dict("class_ignore_thresholds"):
        self.class_ignore_thresholds = {
            str(name): float(limit)
            for name, limit in payload["class_ignore_thresholds"].items()
        }

    # --- display / decoding toggles ---------------------------------------
    if "timeline_zoom" in payload:
        self.timeline_zoom_spin.setValue(int(payload["timeline_zoom"]))
    if "frame_aggregation_enabled" in payload:
        self.frame_aggregation_check.setChecked(bool(payload["frame_aggregation_enabled"]))
    if "merge_consecutive_enabled" in payload:
        self.merge_timeline_check.setChecked(bool(payload["merge_consecutive_enabled"]))
    if "use_viterbi_decode" in payload:
        self.use_viterbi_decode = bool(payload["use_viterbi_decode"])
        self.use_viterbi_check.setChecked(self.use_viterbi_decode)
    if "viterbi_switch_penalty" in payload:
        self.viterbi_switch_penalty = float(payload["viterbi_switch_penalty"])
        self.viterbi_switch_penalty_spin.setValue(self.viterbi_switch_penalty)
    if has_widget("ovr_rows_check") and "ovr_rows" in payload:
        self.ovr_rows_check.setChecked(bool(payload["ovr_rows"]))
    if has_widget("ovr_show_all_check") and "ovr_show_all" in payload:
        self.ovr_show_all_check.setChecked(bool(payload["ovr_show_all"]))

    # --- segment post-processing rules ------------------------------------
    if "min_segment_frames" in payload:
        self._min_segment_frames = int(payload["min_segment_frames"])
    if "merge_gap_frames" in payload:
        self._merge_gap_frames = int(payload["merge_gap_frames"])
    if "temporal_smoothing_window_frames" in payload:
        window = int(payload["temporal_smoothing_window_frames"])
        # The window is forced to an odd size of at least 1.
        window = window if window % 2 else window + 1
        self._temporal_smoothing_window_frames = max(1, window)
    if has_dict("class_min_segment_frames"):
        self.class_min_segment_frames = {
            str(name): int(frames)
            for name, frames in payload["class_min_segment_frames"].items()
        }
    if has_dict("class_merge_gap_frames"):
        self.class_merge_gap_frames = {
            str(name): int(frames)
            for name, frames in payload["class_merge_gap_frames"].items()
        }
    if has_dict("class_smoothing_window_frames"):
        self.class_smoothing_window_frames = {
            str(name): int(frames)
            for name, frames in payload["class_smoothing_window_frames"].items()
        }

    # --- mirror into the shared config and redraw -------------------------
    self.config["inference_ignore_threshold"] = self.global_ignore_threshold
    self.config["inference_class_ignore_thresholds"] = dict(self.class_ignore_thresholds)
    self.config["inference_use_viterbi_decode"] = self.use_viterbi_decode
    self.config["inference_viterbi_switch_penalty"] = self.viterbi_switch_penalty
    self._sync_per_class_segment_rule_config()
    if not self.predictions:
        return
    if self.frame_aggregation_check.isChecked():
        self._compute_aggregated_timeline()
    self._display_results()
|
|
625
|
+
|
|
626
|
+
def _save_timeline_settings_preset(self):
    """Ask for a destination and write the current timeline settings as JSON.

    The file holds ``{"classes": [...], "timeline_settings": {...}}``.
    Nothing happens if the dialog is cancelled. Failures while building,
    serializing or writing the preset are reported in a warning dialog
    (mirroring the preset loader) instead of escaping from the Qt slot.
    """
    path, _ = QFileDialog.getSaveFileName(
        self,
        "Save Timeline Settings",
        self.config.get("experiment_path", os.getcwd()),
        "JSON Files (*.json)",
    )
    if not path:
        return
    try:
        payload = {
            "classes": list(self.classes),
            "timeline_settings": self._collect_timeline_settings_payload(),
        }
        # Serialize before opening the file so a serialization error cannot
        # leave a truncated preset behind.
        text = json.dumps(payload, indent=2)
        with open(path, "w", encoding="utf-8") as f:
            f.write(text)
    except (OSError, TypeError, ValueError) as e:
        QMessageBox.warning(self, "Save failed", f"Failed to save timeline settings:\n{e}")
        return
    self.log_text.append(f"Timeline settings saved: {path}")
|
|
642
|
+
|
|
643
|
+
def _load_timeline_settings_preset(self):
    """Ask for a preset JSON file and apply the timeline settings it holds.

    Accepts either a document with a ``"timeline_settings"`` entry or a bare
    settings mapping. Any failure while reading or applying is shown in a
    warning dialog. Cancelling the dialog does nothing.
    """
    start_dir = self.config.get("experiment_path", os.getcwd())
    chosen, _selected_filter = QFileDialog.getOpenFileName(
        self, "Load Timeline Settings", start_dir, "JSON Files (*.json)"
    )
    if not chosen:
        return
    try:
        with open(chosen, "r", encoding="utf-8") as handle:
            document = json.load(handle)
        # Older/bare presets store the settings at the top level.
        settings = document.get("timeline_settings", document)
        # NOTE(review): _apply_timeline_settings_payload returns early for a
        # non-dict payload, yet the "loaded" line below is still logged.
        self._apply_timeline_settings_payload(settings)
        self.log_text.append(f"Timeline settings loaded: {chosen}")
    except Exception as err:
        QMessageBox.warning(self, "Load failed", f"Failed to load timeline settings:\n{err}")
|
|
660
|
+
|
|
661
|
+
def _load_model(self):
    """Load a trained model head chosen by the user and prepare it for inference.

    Steps, in order:
      1. Ask for a weights file; return if the dialog is cancelled.
      2. Peek at the checkpoint to infer the embedding width and, from it,
         the backbone name (768 -> base, 1024 -> large).
      3. Read the optional ``<model_path>.meta.json`` sidecar: classes,
         attributes/registries (with reconstruction fallbacks), clip length,
         target FPS, resolution, crop parameters, OvR settings and the
         backbone recorded at training time.
      4. Resolve the class list (metadata first, annotation file otherwise).
      5. Build the backbone and ``BehaviorClassifier``, load the head
         weights, and refresh the widgets that depend on the model.

    Any unexpected exception is logged and shown in a critical message box.
    """
    model_path, _ = QFileDialog.getOpenFileName(
        self,
        "Load Model Head",
        self.config.get("models_dir", "models/behavior_heads"),
        "PyTorch Files (*.pt);;All Files (*)"
    )

    if not model_path:
        return

    try:
        self.log_text.append("Loading model...")

        pre_state_dict = None
        inferred_embed_dim = None
        inferred_backbone = None
        if os.path.exists(model_path):
            try:
                # NOTE(review): torch.load unpickles the file; only load
                # checkpoints from trusted sources. This call leaves
                # ``weights_only`` at the library default while the later
                # peek passes ``weights_only=False``.
                pre_state_dict = torch.load(model_path, map_location='cpu')
                # The first known head parameter found decides the width.
                for key in (
                    "head_root.query",
                    "head_class.query",
                    "head_root.ln1.weight",
                    "head_class.ln1.weight",
                    "head_root.fc.weight",
                    "head_class.fc.weight",
                    "query",
                    "ln1.weight",
                    "fc.weight",
                ):
                    if key in pre_state_dict:
                        tensor = pre_state_dict[key]
                        if tensor is None:
                            continue
                        if tensor.ndim == 3:
                            inferred_embed_dim = tensor.shape[-1]
                        elif tensor.ndim == 2:
                            inferred_embed_dim = tensor.shape[1]
                        elif tensor.ndim == 1:
                            inferred_embed_dim = tensor.shape[0]
                        if inferred_embed_dim:
                            break
                embed_to_backbone = {
                    768: "videoprism_public_v1_base",
                    1024: "videoprism_public_v1_large",
                }
                inferred_backbone = embed_to_backbone.get(inferred_embed_dim)
                if inferred_backbone:
                    self.log_text.append(
                        f"Inferred backbone from weights: {inferred_backbone} (embed_dim={inferred_embed_dim})"
                    )
            except Exception:
                # Best effort only: metadata/config decide the backbone.
                pre_state_dict = None
        # Attempt to load metadata (preferred)
        meta_classes = None
        meta_path = model_path + ".meta.json"
        metadata_backbone = None
        # Pre-initialised so later sections can consult them even when the
        # sidecar is missing or unreadable.
        meta_data = {}
        train_cfg = {}
        if os.path.exists(meta_path):
            try:
                with open(meta_path, "r", encoding="utf-8") as meta_file:
                    meta_data = json.load(meta_file)
                # Normalised once: a null/non-dict "training_config" used to
                # raise on ``.get`` and abort the rest of metadata handling.
                train_cfg = meta_data.get("training_config", {})
                if not isinstance(train_cfg, dict):
                    train_cfg = {}
                meta_classes = meta_data.get("classes")
                self.attributes = meta_data.get("attributes", [])  # Load attributes
                self.attributes_registry = meta_data.get("attributes_registry", None)
                self.hierarchy_registry = meta_data.get("hierarchy_registry", None)
                self.label_mapping = meta_data.get("label_mapping", None)
                if meta_classes:
                    self.log_text.append(f"Loaded {len(meta_classes)} classes from metadata: {meta_classes}")
                if self.attributes:
                    self.log_text.append(f"Loaded {len(self.attributes)} attributes: {self.attributes}")
                if self.hierarchy_registry:
                    self.log_text.append("Loaded deep hierarchy registry")
                    if not self.attributes:
                        # Leaves are nodes with no children (root excluded).
                        leaf_nodes = [
                            n for n, children in self.hierarchy_registry.items()
                            if not children and n != "__root__"
                        ]
                        self.attributes = sorted(set(leaf_nodes))
                        if self.attributes:
                            self.log_text.append(f"Derived {len(self.attributes)} leaf attributes from hierarchy")
                if self.attributes_registry:
                    self.log_text.append(f"Loaded hierarchical registry for {len(self.attributes_registry)} classes")
                elif self.attributes and meta_classes:
                    # Fallback: Reconstruct registry from attribute names if missing (for models trained before registry was saved)
                    self.log_text.append("Registry missing from metadata. Attempting to reconstruct from attribute names...")
                    reconstructed = {}
                    for cls in meta_classes:
                        reconstructed[cls] = []

                    for attr in self.attributes:
                        # Match attribute to class based on prefix (e.g. "Walk_Start" -> "Walk")
                        # We look for the longest matching class name prefix
                        best_match = None
                        for cls in meta_classes:
                            if attr.startswith(cls + "_") or attr == cls:
                                if best_match is None or len(cls) > len(best_match):
                                    best_match = cls

                        if best_match:
                            reconstructed[best_match].append(attr)

                    # Only use if we found matches
                    if any(reconstructed.values()):
                        for cls in reconstructed:
                            reconstructed[cls].sort()
                        self.attributes_registry = reconstructed
                        self.log_text.append(f"Reconstructed registry: { {k: len(v) for k, v in reconstructed.items()} }")
                elif self.label_mapping and meta_classes:
                    if any("path" in v for v in self.label_mapping.values()):
                        # Fallback: Build hierarchy from label_mapping paths
                        self.log_text.append("Hierarchy missing from metadata. Reconstructing from label_mapping...")
                        reconstructed = {"__root__": []}
                        for _, info in self.label_mapping.items():
                            path = info.get("path", [])
                            if not path:
                                continue
                            if path[0] not in reconstructed["__root__"]:
                                reconstructed["__root__"].append(path[0])
                            for i in range(1, len(path)):
                                parent = path[i - 1]
                                child = path[i]
                                reconstructed.setdefault(parent, [])
                                if child not in reconstructed[parent]:
                                    reconstructed[parent].append(child)
                            reconstructed.setdefault(path[-1], [])
                        for node in reconstructed:
                            reconstructed[node] = sorted(set(reconstructed[node]))
                        self.hierarchy_registry = reconstructed
                        leaf_nodes = [
                            n for n, children in reconstructed.items()
                            if not children and n != "__root__"
                        ]
                        self.attributes = sorted(set(leaf_nodes))
                        self.log_text.append("Reconstructed deep hierarchy registry from label_mapping")
                    else:
                        # Fallback: Build registry directly from label_mapping
                        self.log_text.append("Registry missing from metadata. Reconstructing from label_mapping...")
                        reconstructed = {cls: [] for cls in meta_classes}
                        for raw_label, info in self.label_mapping.items():
                            cls_name = info.get("class")
                            attr_name = info.get("attribute")
                            if cls_name in reconstructed and attr_name:
                                reconstructed[cls_name].append(attr_name)
                        if any(reconstructed.values()):
                            for cls in reconstructed:
                                reconstructed[cls] = sorted(set(reconstructed[cls]))
                            self.attributes_registry = reconstructed
                            self.attributes = sorted({a for attrs in reconstructed.values() for a in attrs})
                            self.log_text.append(f"Reconstructed registry from label_mapping: { {k: len(v) for k, v in reconstructed.items()} }")

                # Load clip length from metadata if available
                clip_len = meta_data.get("clip_length")
                if clip_len:
                    self.clip_length_spin.setValue(int(clip_len))
                    self.log_text.append(f"Automatically set 'Frames per clip' to {clip_len} from model metadata.")

                meta_fps = meta_data.get("target_fps") or train_cfg.get("target_fps")
                if meta_fps:
                    self.target_fps_spin.setValue(int(meta_fps))
                    self.log_text.append(f"Automatically set 'Target FPS' to {meta_fps} from model metadata.")

                resolution = meta_data.get("resolution") or train_cfg.get("resolution")
                if resolution:
                    self.infer_resolution = int(resolution)
                    self.resolution_spin.setValue(self.infer_resolution)
                    self.log_text.append(f"Using inference resolution {self.infer_resolution}x{self.infer_resolution} from model metadata.")

                if not resolution:
                    # Older models: fall back to the "<stem>_training_config.json" sidecar.
                    config_path = os.path.splitext(model_path)[0] + "_training_config.json"
                    if os.path.exists(config_path):
                        try:
                            with open(config_path, "r", encoding="utf-8") as cfg_file:
                                cfg_data = json.load(cfg_file)
                            if isinstance(cfg_data, dict):
                                cfg_train = cfg_data.get("training_config", cfg_data)
                                if isinstance(cfg_train, dict):
                                    self.model_training_config = dict(cfg_train)
                                cfg_resolution = cfg_data.get("resolution") or cfg_data.get("training_config", {}).get("resolution")
                                if cfg_resolution:
                                    self.infer_resolution = int(cfg_resolution)
                                    self.resolution_spin.setValue(self.infer_resolution)
                                    self.log_text.append(f"Using inference resolution {self.infer_resolution}x{self.infer_resolution} from training config.")
                        except Exception as cfg_err:
                            self.log_text.append(f"Warning: Failed to read training config {config_path}: {cfg_err}")

                # Load localization crop parameters from training metadata
                # NOTE(review): this overwrites any model_training_config
                # taken from the sidecar above — confirm that is intended.
                self.model_training_config = dict(train_cfg)
                self._crop_padding = float(train_cfg.get("classification_crop_padding", 0.35) or 0.35)
                self._crop_min_size = float(train_cfg.get("classification_crop_min_size_norm", 0.04) or 0.04)

                # OvR: use sigmoid scoring instead of softmax at inference
                self._use_ovr = bool(train_cfg.get("use_ovr", False))
                self._update_viterbi_ui_state()
                self._allowed_cooccurrence = set()
                self._ovr_temperatures = {}
                if self._use_ovr:
                    if getattr(self, "ovr_rows_check", None) is not None:
                        self.ovr_rows_check.setChecked(True)
                    self.log_text.append("OvR model detected: using sigmoid scoring at inference.")
                    # Only well-formed pairs are stored and logged; the log
                    # used to unpack the raw list, so one malformed entry
                    # aborted the rest of metadata handling.
                    valid_pairs = []
                    for pair in train_cfg.get("allowed_cooccurrence", []) or []:
                        if isinstance(pair, (list, tuple)) and len(pair) == 2:
                            valid_pairs.append((pair[0], pair[1]))
                            # Stored in both orders for symmetric lookup.
                            self._allowed_cooccurrence.add((pair[0], pair[1]))
                            self._allowed_cooccurrence.add((pair[1], pair[0]))
                    if self._allowed_cooccurrence:
                        pairs_str = ", ".join(f"{a}+{b}" for a, b in valid_pairs)
                        self.log_text.append(f"Allowed co-occurrence pairs: {pairs_str}")

                    # Per-head calibrated temperatures from training
                    ovr_temps = meta_data.get("ovr_temperatures", {})
                    if ovr_temps:
                        self._ovr_temperatures = {k: float(v) for k, v in ovr_temps.items()}
                        t_str = ", ".join(f"{k}={v}" for k, v in self._ovr_temperatures.items())
                        self.log_text.append(f"OvR calibrated temperatures: {t_str}")

                # Validate backbone model
                head_backbone = meta_data.get("backbone_model")
                current_backbone = self.config.get("backbone_model", "videoprism_public_v1_base")

                # Prefer inferred backbone if metadata disagrees with weights
                if inferred_backbone and head_backbone and head_backbone != inferred_backbone:
                    self.log_text.append(
                        f"WARNING: Metadata backbone '{head_backbone}' does not match weights. "
                        f"Using inferred backbone '{inferred_backbone}'."
                    )
                    head_backbone = inferred_backbone

                if head_backbone and head_backbone != current_backbone:
                    msg = (
                        f"Backbone Model Mismatch (auto-corrected for inference).\n"
                        f"Head backbone: '{head_backbone}', current config: '{current_backbone}'.\n"
                        f"Using '{head_backbone}' for inference."
                    )
                    self.log_text.append(f"WARNING: {msg}")

                metadata_backbone = head_backbone

            except Exception as meta_err:
                self.log_text.append(f"Warning: Failed to read metadata file {meta_path}: {meta_err}")

        if meta_classes:
            self.classes = meta_classes
        else:
            annotation_file = self.config.get("annotation_file", "data/annotations/annotations.json")
            if os.path.exists(annotation_file):
                annotation_manager = AnnotationManager(annotation_file)
                self.classes = annotation_manager.get_classes()
            else:
                QMessageBox.warning(self, "Warning", "Annotation file not found. Cannot determine classes.")
                return

        if not self.classes:
            QMessageBox.warning(self, "Error", "No classes found for this model.")
            return

        self.log_text.append(f"Using {len(self.classes)} classes: {self.classes}")

        # Weights-derived backbone wins, then metadata, then the app config.
        resolved_backbone = inferred_backbone or metadata_backbone
        if not resolved_backbone:
            resolved_backbone = self.config.get("backbone_model", "videoprism_public_v1_base")
        model_name = resolved_backbone
        self.log_text.append(f"Loading VideoPrism backbone ({model_name}) at resolution {self.infer_resolution}×{self.infer_resolution}...")

        backbone = VideoPrismBackbone(model_name=model_name, resolution=self.infer_resolution)

        # Head hyper-parameter defaults, overridden below by metadata.
        head_kwargs = {"num_heads": 4}
        dropout = 0.1
        use_localization = False
        use_frame_head = True
        frame_head_temporal_layers = 1
        temporal_pool_frames = 1
        num_stages = 3
        proj_dim = 256
        localization_hidden_dim = 256
        multi_scale = False
        use_temporal_decoder = True

        try:
            # ``meta_data`` is {} when the sidecar is absent or failed to
            # parse, so defaults stay in effect without the former
            # NameError-and-swallow path.
            if isinstance(meta_data, dict) and meta_data:
                head_cfg = meta_data.get("head", {})
                if isinstance(head_cfg, dict):
                    if "num_heads" in head_cfg:
                        head_kwargs["num_heads"] = head_cfg.get("num_heads", head_kwargs["num_heads"])
                    if "dropout" in head_cfg:
                        dropout = head_cfg.get("dropout", dropout)
                    use_localization = bool(
                        head_cfg.get("use_localization", False)
                        or train_cfg.get("use_localization", False)
                    )
                    localization_hidden_dim = int(head_cfg.get("localization_hidden_dim", 256))
                    use_temporal_decoder = bool(
                        head_cfg.get(
                            "use_temporal_decoder",
                            train_cfg.get("use_temporal_decoder", True),
                        )
                    )
                    frame_head_temporal_layers = int(head_cfg.get("frame_head_temporal_layers", 1))
                    temporal_pool_frames = int(head_cfg.get("temporal_pool_frames", 1))
                    num_stages = int(head_cfg.get("num_stages", 3))
                    proj_dim = int(head_cfg.get("proj_dim", 256))
                    multi_scale = bool(head_cfg.get("multi_scale", False))
                    # Non-positive head values fall back to the training config.
                    if frame_head_temporal_layers <= 0:
                        frame_head_temporal_layers = int(
                            train_cfg.get("frame_head_temporal_layers", 1)
                        )
                    if temporal_pool_frames <= 0:
                        temporal_pool_frames = int(
                            train_cfg.get("temporal_pool_frames", 1)
                        )
                    if num_stages <= 0:
                        num_stages = int(
                            train_cfg.get("num_stages", 3)
                        )
        except Exception as head_err:
            # Values parsed before the failure are kept; the rest stay default.
            self.log_text.append(f"Warning: Failed to parse head configuration from metadata: {head_err}")

        if frame_head_temporal_layers <= 0:
            frame_head_temporal_layers = 1
        if use_temporal_decoder:
            self.log_text.append(f"Frame head: temporal decoder on, layers={frame_head_temporal_layers}, proj_dim={proj_dim}")
        else:
            self.log_text.append(f"Frame head: direct per-frame classifier, proj_dim={proj_dim}")

        # Peek at checkpoint to detect multi_scale (most reliable source)
        try:
            _ckpt = torch.load(model_path, map_location='cpu', weights_only=False)
            if isinstance(_ckpt, dict) and "multi_scale" in _ckpt:
                multi_scale = bool(_ckpt["multi_scale"])
            if isinstance(_ckpt, dict) and "use_temporal_decoder" in _ckpt:
                use_temporal_decoder = bool(_ckpt["use_temporal_decoder"])
            del _ckpt
        except Exception as peek_err:
            self.log_text.append(f"Warning: Could not inspect checkpoint flags: {peek_err}")

        self.log_text.append(f"Creating classifier (kwargs={head_kwargs}, multi_scale={multi_scale})...")
        self.model = BehaviorClassifier(
            backbone,
            num_classes=len(self.classes),
            class_names=self.classes,
            dropout=dropout,
            freeze_backbone=True,
            head_kwargs=head_kwargs,
            use_localization=use_localization,
            localization_hidden_dim=localization_hidden_dim,
            use_frame_head=use_frame_head,
            use_temporal_decoder=use_temporal_decoder,
            frame_head_temporal_layers=frame_head_temporal_layers,
            temporal_pool_frames=temporal_pool_frames,
            proj_dim=proj_dim,
            num_stages=num_stages,
            multi_scale=multi_scale,
        )

        self.log_text.append(f"Loading head weights from {model_path}...")
        if getattr(self, "_use_ovr", False) and getattr(self.model, "frame_head", None) is not None:
            self.model.frame_head.use_ovr = True
        try:
            self.model.load_head(model_path)
        except RuntimeError as e:
            err_str = str(e)
            if "size mismatch" in err_str:
                msg = (
                    f"Model architecture mismatch.\n"
                    f"Error Details: {err_str}\n\n"
                    "Please ensure the weights file matches the current architecture."
                )
                self.log_text.append(f"ERROR: {msg}")
                QMessageBox.critical(self, "Model Load Error", msg)
                self.model = None
                return
            # Other runtime errors go to the generic handler below.
            raise

        if multi_scale:
            self.log_text.append(
                "Multi-scale mode active: inference runs backbone twice per clip "
                "(full fps + half fps). Expect ~2x backbone time per clip."
            )
        self.model_path_edit.setText(model_path)
        self.classes_label.setText(", ".join(self.classes))

        # Populate filter combo
        self.filter_behavior_combo.blockSignals(True)
        self.filter_behavior_combo.clear()
        self.filter_behavior_combo.addItem("All Behaviors")
        self.filter_behavior_combo.addItem(self.ignore_label_name)
        self.filter_behavior_combo.addItems(self.classes)
        self.filter_behavior_combo.blockSignals(False)

        self.log_text.append("Model loaded successfully!")
        QMessageBox.information(self, "Success", "Model loaded successfully!")

        # Offer autosave recovery if a crash interrupted a previous batch run
        self._check_autosave_recovery(model_path)

        if self.video_path:
            self.run_inference_btn.setEnabled(True)

    except Exception as e:
        error_msg = f"Failed to load model: {str(e)}"
        self.log_text.append(f"ERROR: {error_msg}")
        QMessageBox.critical(self, "Error", error_msg)
|
|
1067
|
+
|
|
1068
|
+
def _select_video(self):
    """Select video file(s) for inference."""
    start_dir = self.config.get("raw_videos_dir", self.config.get("data_dir", "data/raw_videos"))
    video_paths, _ = QFileDialog.getOpenFileNames(
        self,
        "Select Video File(s)",
        start_dir,
        "Video Files (*.mp4 *.avi *.mov *.mkv);;All Files (*)",
    )
    if not video_paths:
        return

    # Ensure videos are in experiment folder (batch operation)
    from .video_utils import ensure_videos_in_experiment
    self.video_paths = ensure_videos_in_experiment(video_paths, self.config, self)

    # NOTE(review): self.video_paths holds the helper's return value while
    # self.video_path keeps the path as picked in the dialog — confirm the
    # two cannot diverge.
    # The first selection also serves as the default for preview.
    self.video_path = video_paths[0]
    if len(video_paths) == 1:
        self.video_path_edit.setText(self.video_path)

        info = get_video_info(self.video_path)
        if info:
            info_text = f"FPS: {info['fps']:.2f}, Frames: {info['frame_count']}, Size: {info['width']}x{info['height']}"
            self.video_info_label.setText(info_text)
    else:
        self.video_path_edit.setText(f"{len(video_paths)} videos selected")
        self.video_info_label.setText(f"Selected {len(video_paths)} videos")

    # Inference needs both a video and a loaded model.
    if self.model:
        self.run_inference_btn.setEnabled(True)
|
|
1096
|
+
|
|
1097
|
+
def _update_sample_inference_controls(self, enabled: bool):
    """Enable or disable the sample duration and sample count spin boxes."""
    state = bool(enabled)
    for spin_name in ("sample_duration_spin", "sample_count_spin"):
        getattr(self, spin_name).setEnabled(state)
|
|
1100
|
+
|
|
1101
|
+
def _compute_sample_ranges_for_video(self, video_path: str):
    """Build evenly spread sample ranges [(start_frame, end_frame), ...] for one video.

    The window length is ``sample_duration_spin`` seconds converted to frames
    (clamped to the video length) and ``sample_count_spin`` windows are
    requested. Start frames run from 0 to the last start that still fits a
    full window; a single window is centred.

    Args:
        video_path: path handed to ``cv2.VideoCapture``.

    Returns:
        list[tuple[int, int]]: ``(start, end)`` frame pairs with ``end``
        exclusive-style (``start + window``, capped at the frame count).
        Empty when the path is empty, the video cannot be opened, or it
        reports no frames. Fewer windows than requested are returned when
        the video is too short for that many distinct start frames.
    """
    if not video_path:
        return []
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            return []
        fps = cap.get(cv2.CAP_PROP_FPS)
        frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT)
    finally:
        # Release on every path, including the early return above.
        cap.release()

    # ``not x > 0`` also rejects NaN, which ``x <= 0`` lets through.
    if not fps > 0:
        fps = 30.0
    total_frames = int(frame_count) if frame_count > 0 else 0
    if total_frames <= 0:
        return []

    dur_frames = int(round(self.sample_duration_spin.value() * fps))
    dur_frames = max(1, min(dur_frames, total_frames))
    n = max(1, int(self.sample_count_spin.value()))
    usable = max(0, total_frames - dur_frames)
    if n == 1 or usable == 0:
        starts = [usable // 2]
    else:
        starts = [int(round(i * usable / (n - 1))) for i in range(n)]
        # Rounding can repeat a start on short videos; drop the repeats so
        # the same window is not processed more than once.
        starts = list(dict.fromkeys(starts))
    return [(start, min(start + dur_frames, total_frames)) for start in starts]
|
|
1125
|
+
|
|
1126
|
+
def _run_inference(self):
    """Run inference on selected video."""
    if not self.model:
        QMessageBox.warning(self, "Error", "Please load a model first.")
        return
    if not self.video_path:
        QMessageBox.warning(self, "Error", "Please select a video first.")
        return

    # Put the UI into its "running" state and reset previous output.
    self.run_inference_btn.setEnabled(False)
    self.load_timeline_btn.setEnabled(False)
    self.progress_bar.setVisible(True)
    self.progress_bar.setValue(0)
    self.progress_label.setText("Running inference...")
    self.results_list.clear()
    self.log_text.clear()
    self.log_text.append("Starting inference...")

    # The spin-box resolution is used only when the override box is ticked.
    overriding = self.override_resolution_check.isChecked()
    if overriding:
        self.infer_resolution = int(self.resolution_spin.value())
        self.log_text.append(f"Inference resolution (override): {self.infer_resolution}x{self.infer_resolution}")
    else:
        self.log_text.append(f"Inference resolution: {self.infer_resolution}x{self.infer_resolution}")

    # Batch selection wins; otherwise fall back to the single selected video.
    video_paths = self.video_paths or ([self.video_path] if self.video_path else [])

    sample_ranges_by_video = {}
    if self.sample_inference_check.isChecked():
        for v_path in video_paths:
            ranges = self._compute_sample_ranges_for_video(v_path)
            if not ranges:
                self.log_text.append(
                    f"Warning: could not build sample ranges for {os.path.basename(v_path)}. "
                    "Falling back to full video."
                )
                continue
            sample_ranges_by_video[v_path] = ranges
            # FPS is re-read only to print the ranges in seconds.
            cap = cv2.VideoCapture(v_path)
            fps = cap.get(cv2.CAP_PROP_FPS)
            cap.release()
            if fps <= 0:
                fps = 30.0
            shown = [f"{start / fps:.1f}s-{end / fps:.1f}s" for start, end in ranges[:6]]
            if len(ranges) > 6:
                shown.append("...")
            summary = ", ".join(shown)
            self.log_text.append(
                f"Quick-check sampled inference for {os.path.basename(v_path)}: {summary}"
            )

    # Read but not used further in this method.
    model_path = self.model_path_edit.text()

    self.worker = InferenceWorker(
        self.model,
        video_paths,
        self.target_fps_spin.value(),
        self.clip_length_spin.value(),
        self.step_frames_spin.value(),
        resolution=self.infer_resolution,
        classes=self.classes,
        use_localization_pipeline=getattr(self.model, "use_localization", False),
        crop_padding=getattr(self, "_crop_padding", 0.35),
        crop_min_size=getattr(self, "_crop_min_size", 0.04),
        use_ovr=getattr(self, "_use_ovr", False),
        ovr_temperatures=getattr(self, "_ovr_temperatures", {}),
        collect_attention=self.collect_attention_check.isChecked(),
        sample_ranges_by_video=sample_ranges_by_video,
    )
    worker = self.worker
    worker.progress.connect(self._on_progress)
    worker.finished.connect(self._on_finished)
    worker.error.connect(self._on_error)
    worker.log_message.connect(self._on_log)
    worker.video_done.connect(self._on_video_done)
    self.run_inference_btn.setEnabled(False)
    self.stop_inference_btn.setEnabled(True)
    self.worker.start()
|
|
1202
|
+
|
|
1203
|
+
def _on_progress(self, current: int, total: int):
|
|
1204
|
+
"""Update progress (capped at 100%)."""
|
|
1205
|
+
if total > 0:
|
|
1206
|
+
progress = min(100, int(100 * current / total))
|
|
1207
|
+
self.progress_bar.setValue(progress)
|
|
1208
|
+
self.progress_label.setText(f"Processing: {current}/{total} items ({progress}%)")
|
|
1209
|
+
|
|
1210
|
+
def _on_finished(self, results: dict):
    """Handle inference completion.

    Stores *results* as the results cache, repopulates the video selector,
    resets the progress widgets and enables/disables the action buttons.
    When *results* is non-empty it additionally saves the results silently,
    removes the incremental autosave file, writes the uncertainty report,
    emits ``review_ready`` and shows a confirmation dialog.

    Args:
        results: Mapping of video path to that video's result entry.
    """
    self.results_cache = results
    self._ignore_threshold_user_modified = False

    # Initialize corrections storage for each video if not present
    for entry in self.results_cache.values():
        entry.setdefault("corrected_labels", {})
        entry.setdefault("corrected_attr_labels", {})

    # Populate video dropdown (sorted once; reused below to pick the first video)
    sorted_paths = sorted(results)
    self.filter_video_combo.blockSignals(True)
    self.filter_video_combo.clear()
    for path in sorted_paths:
        self.filter_video_combo.addItem(os.path.basename(path), path)
    self.filter_video_combo.blockSignals(False)

    # Select first video
    if results:
        has_any_frame_probs = any(
            isinstance(res.get("clip_frame_probabilities"), list)
            and len(res.get("clip_frame_probabilities", [])) > 0
            for res in results.values()
        )
        if has_any_frame_probs and not self.frame_aggregation_check.isChecked():
            # Automatically switch to precise frame-boundary mode when
            # frame-head outputs are available.
            self.frame_aggregation_check.setChecked(True)
            self.log_text.append("Detected frame-head predictions. Enabled 'Precise frame boundaries' mode.")
        idx = self.filter_video_combo.findData(sorted_paths[0])
        if idx >= 0:
            self.filter_video_combo.setCurrentIndex(idx)
            self._on_video_selection_changed(idx)

    # Entries without a "predictions" key count as zero clips instead of
    # raising KeyError (same tolerance as the per-video handler).
    total_clips = sum(len(res.get("predictions", [])) for res in results.values())
    self.progress_bar.setValue(100)
    self.progress_bar.setVisible(False)
    if results:
        self.progress_label.setText(f"Inference complete! Processed {len(results)} videos, {total_clips} clips")
    else:
        self.progress_label.setText("Stopped. No videos were completed.")
    self.run_inference_btn.setEnabled(True)
    self.stop_inference_btn.setEnabled(False)
    self.load_timeline_btn.setEnabled(True)

    has_results = bool(results)
    self.export_btn.setEnabled(has_results)
    self.export_timeline_btn.setEnabled(has_results)
    self.save_results_btn.setEnabled(has_results)
    # any() over an empty mapping is already False, so no extra guard is needed.
    has_attn = any("clip_attention_maps" in res for res in results.values())
    self.export_attention_btn.setEnabled(has_attn)

    if not results:
        # Nothing to persist, report or review.
        return

    self._save_results(silent=True)
    autosave_path = self._autosave_path()
    if autosave_path and os.path.exists(autosave_path):
        try:
            os.remove(autosave_path)
        except OSError as exc:
            # Best effort: a leftover autosave is not fatal, but say so.
            self.log_text.append(f"Warning: could not remove autosave file: {exc}")

    # The attribute is read defensively elsewhere in this class, so do the same here.
    use_ovr = bool(getattr(self, "_use_ovr", False))
    clip_length = self.clip_length_spin.value()
    target_fps = self.target_fps_spin.value()

    # Save uncertainty report and notify the Review tab.
    try:
        model_path = self.model_path_edit.text()
        if model_path:
            uncertainty_path = os.path.join(
                os.path.dirname(model_path),
                os.path.splitext(os.path.basename(model_path))[0] + "_uncertainty.json",
            )
        else:
            exp_path = self.config.get("experiment_path", "")
            uncertainty_path = os.path.join(exp_path, "results", "_uncertainty.json")
        save_uncertainty_report(
            results=results,
            classes=self.classes,
            output_path=uncertainty_path,
            is_ovr=use_ovr,
            n_per_class=25,
            clip_length=clip_length,
            target_fps=target_fps,
        )
        self.log_text.append(f"Uncertainty report saved to: {uncertainty_path}")
    except Exception as _ue:
        # The report is auxiliary; a failure must not block the review hand-off.
        self.log_text.append(f"Warning: could not save uncertainty report: {_ue}")

    self.review_ready.emit(
        results,
        list(self.classes),
        use_ovr,
        clip_length,
        target_fps,
    )

    QMessageBox.information(self, "Success", f"Inference complete! Processed {len(results)} videos.")
|
|
1307
|
+
|
|
1308
|
+
def _autosave_path(self) -> str:
|
|
1309
|
+
"""Return the path for the incremental autosave file next to the loaded model."""
|
|
1310
|
+
model_path = self.model_path_edit.text()
|
|
1311
|
+
if not model_path:
|
|
1312
|
+
return ""
|
|
1313
|
+
return os.path.join(os.path.dirname(model_path), "_inference_autosave.json")
|
|
1314
|
+
|
|
1315
|
+
def _on_video_done(self, video_path: str, res_entry: dict):
    """Called after each video finishes — accumulate and autosave incrementally.

    Adds *res_entry* to the results cache, enables the export/save buttons,
    makes sure the video appears in the selector, refreshes the view when
    the finished video is (or becomes) the selected one, and finally writes
    an autosave bundle next to the model if a model path is set.

    Args:
        video_path: Path of the video that just finished.
        res_entry: Result entry for that video; correction fields are added
            in place when missing.
    """
    # Initialise corrections fields
    res_entry.setdefault("corrected_labels", {})
    res_entry.setdefault("corrected_attr_labels", {})
    self.results_cache[video_path] = res_entry
    self.export_btn.setEnabled(True)
    self.export_timeline_btn.setEnabled(True)
    self.save_results_btn.setEnabled(True)
    if "clip_attention_maps" in res_entry:
        self.export_attention_btn.setEnabled(True)

    combo = self.filter_video_combo
    combo.blockSignals(True)
    try:
        if combo.findData(video_path) < 0:
            combo.addItem(os.path.basename(video_path), video_path)
            model = combo.model()
            if model is not None:
                model.sort(0)
    finally:
        # Always re-enable signals, even if the combo update raised.
        combo.blockSignals(False)

    current_video = combo.currentData()
    if not current_video:
        idx = combo.findData(video_path)
        if idx >= 0:
            combo.setCurrentIndex(idx)
            self._on_video_selection_changed(idx)
    elif current_video == video_path:
        # The displayed video was re-processed: reload it from the cache.
        self._on_video_selection_changed(combo.currentIndex())

    finished_count = len(self.results_cache)
    clip_count = len(res_entry.get("predictions", []))
    self.log_text.append(
        f"Finished {os.path.basename(video_path)}: {clip_count} clips ready to review "
        f"({finished_count} video{'s' if finished_count != 1 else ''} available so far)."
    )

    autosave_path = self._autosave_path()
    if not autosave_path:
        return
    try:
        save_data = {
            "classes": self.classes,
            "parameters": {
                "target_fps": self.target_fps_spin.value(),
                "clip_length": self.clip_length_spin.value(),
                "step_frames": self.step_frames_spin.value(),
                "sample_inference_enabled": bool(self.sample_inference_check.isChecked()),
                "sample_duration_seconds": int(self.sample_duration_spin.value()),
                "sample_num_chunks": int(self.sample_count_spin.value()),
                "frame_aggregation_enabled": self.frame_aggregation_check.isChecked(),
                "merge_consecutive_enabled": bool(self.merge_timeline_check.isChecked()),
                "use_viterbi_decode": bool(self.use_viterbi_decode),
                "viterbi_switch_penalty": float(self.viterbi_switch_penalty),
                "min_segment_frames": int(self._min_segment_frames),
                "merge_gap_frames": int(self._merge_gap_frames),
                "temporal_smoothing_window_frames": int(self._temporal_smoothing_window_frames),
                "class_min_segment_frames": {
                    cls: int(v) for cls, v in self.class_min_segment_frames.items()
                },
                "class_merge_gap_frames": {
                    cls: int(v) for cls, v in self.class_merge_gap_frames.items()
                },
                "class_smoothing_window_frames": {
                    cls: int(v) for cls, v in self.class_smoothing_window_frames.items()
                },
                "use_ovr": bool(self._use_ovr),
                "ovr_rows": bool(getattr(self, "ovr_rows_check", None) is not None and self.ovr_rows_check.isChecked()),
                "ovr_show_all": bool(getattr(self, "ovr_show_all_check", None) is not None and self.ovr_show_all_check.isChecked()),
                "use_ignore_threshold": bool(self.use_ignore_threshold),
                "ignore_threshold": float(self.global_ignore_threshold),
                "class_ignore_thresholds": {
                    cls: float(t) for cls, t in self.class_ignore_thresholds.items()
                },
                # Written in the same shape as the regular save so that
                # restoring an autosave does not drop the allowed pairs.
                "allowed_cooccurrence": [
                    [a, b]
                    for (a, b) in sorted(getattr(self, "_allowed_cooccurrence", ()))
                    if a <= b
                ],
                "timeline_zoom": int(self.timeline_zoom_spin.value()),
            },
            "results": self.results_cache,
            "_autosave": True,
        }
        self._write_results_bundle(autosave_path, save_data, pretty=False)
        self.log_text.append(
            f" [Autosave] {len(self.results_cache)} video(s) saved → {os.path.basename(autosave_path)}"
        )
    except Exception as e:
        # Autosave is best effort; inference must keep going if it fails.
        self.log_text.append(f" [Autosave] Warning: could not write autosave: {e}")
|
|
1401
|
+
|
|
1402
|
+
def _check_autosave_recovery(self, model_path: str):
    """After loading a model, check if an autosave from a crashed batch run exists.

    Looks for ``_inference_autosave.json`` in the directory of *model_path*.
    If it holds at least one video result, the user is asked whether to
    restore it; on "Yes" the file is handed to ``_load_timeline_results``.
    Any error while reading the file is logged rather than raised.

    Args:
        model_path: Path of the model that was just loaded.
    """
    autosave_path = os.path.join(os.path.dirname(model_path), "_inference_autosave.json")
    if not os.path.exists(autosave_path):
        return
    try:
        import json as _json
        # Read explicitly as UTF-8 (the results loader opens the same file
        # that way) instead of relying on the platform default encoding.
        with open(autosave_path, encoding="utf-8") as f:
            data = _json.load(f)
        n_videos = len(data.get("results", {}))
        if n_videos == 0:
            return
        reply = QMessageBox.question(
            self,
            "Recover interrupted batch run?",
            f"Found an autosave from a previous session with {n_videos} video(s) already processed.\n\n"
            f"Would you like to restore those results now?\n\n"
            f"(Autosave file: {os.path.basename(autosave_path)})",
            QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No,
        )
        if reply == QMessageBox.StandardButton.Yes:
            self._load_timeline_results(path=autosave_path)
            self.log_text.append(
                f"Recovered {n_videos} video(s) from autosave. "
                "You can continue running inference on remaining videos."
            )
    except Exception as e:
        # A broken autosave must never prevent the model from being used.
        self.log_text.append(f"[Autosave] Could not read autosave file: {e}")
|
|
1430
|
+
|
|
1431
|
+
def _save_results(self, silent=False):
    """Save current inference results to JSON.

    Does nothing when the results cache is empty. The default target is
    ``<experiment_path or data_dir>/results/inference_results.json``. For a
    non-silent save of results that were loaded from a file, the user is
    asked whether to overwrite that file, save under a new name, or cancel.

    When frame aggregation is enabled, aggregated segments are recomputed
    for every cached video and stored in its entry. This temporarily swaps
    per-video attributes on the widget; all of them — including the
    aggregated frame-score arrays — are restored afterwards.

    Args:
        silent: If true, skip all dialogs and only write to the log.
    """
    if not self.results_cache:
        return

    try:
        # Determine default save path based on experiment config
        exp_path = self.config.get("experiment_path")
        if exp_path:
            results_dir = os.path.join(exp_path, "results")
        else:
            results_dir = os.path.join(self.config.get("data_dir", "data"), "results")

        os.makedirs(results_dir, exist_ok=True)
        results_path = os.path.join(results_dir, "inference_results.json")

        # If manual save and loaded from file, ask user
        loaded_path = getattr(self, "loaded_results_path", None)
        if not silent and loaded_path:
            msg = QMessageBox()
            msg.setIcon(QMessageBox.Icon.Question)
            msg.setText("Save timeline results")
            msg.setInformativeText(f"Results loaded from:\n{loaded_path}\n\nDo you want to overwrite or save as new?")
            msg.setWindowTitle("Save results")
            overwrite_btn = msg.addButton("Overwrite", QMessageBox.ButtonRole.AcceptRole)
            save_as_btn = msg.addButton("Save As...", QMessageBox.ButtonRole.ActionRole)
            msg.addButton("Cancel", QMessageBox.ButtonRole.RejectRole)
            msg.exec()

            clicked = msg.clickedButton()
            if clicked == overwrite_btn:
                results_path = loaded_path
            elif clicked == save_as_btn:
                path, _ = QFileDialog.getSaveFileName(
                    self, "Save Results As", loaded_path, "JSON Files (*.json)"
                )
                if not path:
                    return
                results_path = path
            else:
                return

        # Persist current video state into cache so it is saved
        self._persist_current_video_state()

        # If frame aggregation is enabled, persist aggregated segments for all videos.
        frame_aggregation_enabled = self.frame_aggregation_check.isChecked()
        if frame_aggregation_enabled:
            state = {
                "video_path": self.video_path,
                "predictions": self.predictions,
                "confidences": self.confidences,
                "clip_probabilities": self.clip_probabilities,
                "clip_starts": self.clip_starts,
                "clip_frame_probabilities": getattr(self, "clip_frame_probabilities", []),
                "total_frames": self.total_frames,
                "corrected_labels": self.corrected_labels,
                "aggregated_segments": self.aggregated_segments,
                "aggregated_multiclass_segments": self.aggregated_multiclass_segments,
            }
            # The loop below also overwrites the aggregated frame-score
            # arrays; snapshot them so the displayed video does not end up
            # with another video's scores once the save is done.
            missing = object()
            frame_score_state = {
                name: getattr(self, name, missing)
                for name in ("_aggregated_frame_scores_norm", "_aggregated_last_covered_frame")
            }
            try:
                for v_path, entry in self.results_cache.items():
                    if not isinstance(entry, dict):
                        continue
                    preds = entry.get("predictions", [])
                    starts = entry.get("clip_starts", [])
                    if not preds or not starts:
                        continue
                    self.video_path = v_path
                    self.predictions = preds
                    self.confidences = entry.get("confidences", [])
                    self.clip_probabilities = entry.get("clip_probabilities", [])
                    self.clip_starts = starts
                    self.total_frames = int(entry.get("total_frames", 0) or 0)
                    self.corrected_labels = entry.get("corrected_labels", {})
                    precomputed = entry.get("aggregated_frame_probs")
                    if isinstance(precomputed, list):
                        precomputed = np.asarray(precomputed, dtype=np.float32)
                    if isinstance(precomputed, np.ndarray) and precomputed.ndim == 2 and precomputed.shape[0] > 0:
                        self._aggregated_frame_scores_norm = precomputed
                        self._aggregated_last_covered_frame = precomputed.shape[0]
                        self._build_timeline_segments()
                    else:
                        self.clip_frame_probabilities = entry.get("clip_frame_probabilities", [])
                        self.aggregated_segments = []
                        self.aggregated_multiclass_segments = []
                        self._compute_aggregated_timeline()
                    entry["aggregated_segments"] = copy.deepcopy(self.aggregated_segments)
                    entry["aggregated_multiclass_segments"] = copy.deepcopy(self.aggregated_multiclass_segments)
            finally:
                for name, value in state.items():
                    setattr(self, name, value)
                for name, value in frame_score_state.items():
                    if value is not missing:
                        setattr(self, name, value)
                    elif hasattr(self, name):
                        # The attribute did not exist before; remove what the loop created.
                        delattr(self, name)

        save_data = {
            "classes": self.classes,
            "parameters": {
                "target_fps": self.target_fps_spin.value(),
                "clip_length": self.clip_length_spin.value(),
                "step_frames": self.step_frames_spin.value(),
                "sample_inference_enabled": bool(self.sample_inference_check.isChecked()),
                "sample_duration_seconds": int(self.sample_duration_spin.value()),
                "sample_num_chunks": int(self.sample_count_spin.value()),
                "frame_aggregation_enabled": frame_aggregation_enabled,
                "merge_consecutive_enabled": bool(self.merge_timeline_check.isChecked()),
                "use_viterbi_decode": bool(self.use_viterbi_decode),
                "viterbi_switch_penalty": float(self.viterbi_switch_penalty),
                "min_segment_frames": int(self._min_segment_frames),
                "merge_gap_frames": int(self._merge_gap_frames),
                "temporal_smoothing_window_frames": int(self._temporal_smoothing_window_frames),
                "class_min_segment_frames": {
                    cls: int(v) for cls, v in self.class_min_segment_frames.items()
                },
                "class_merge_gap_frames": {
                    cls: int(v) for cls, v in self.class_merge_gap_frames.items()
                },
                "class_smoothing_window_frames": {
                    cls: int(v) for cls, v in self.class_smoothing_window_frames.items()
                },
                "use_ovr": bool(self._use_ovr),
                "ovr_rows": bool(getattr(self, "ovr_rows_check", None) is not None and self.ovr_rows_check.isChecked()),
                "ovr_show_all": bool(getattr(self, "ovr_show_all_check", None) is not None and self.ovr_show_all_check.isChecked()),
                "use_ignore_threshold": bool(self.use_ignore_threshold),
                "ignore_threshold": float(self.global_ignore_threshold),
                "class_ignore_thresholds": {
                    cls: float(t) for cls, t in self.class_ignore_thresholds.items()
                },
                # Each pair is written once; the a <= b filter drops the mirrored tuple.
                "allowed_cooccurrence": [
                    [a, b] for (a, b) in sorted(self._allowed_cooccurrence) if a <= b
                ],
                "timeline_zoom": int(self.timeline_zoom_spin.value()),
            },
            "results": self.results_cache,
        }
        self._write_results_bundle(results_path, save_data, pretty=True)

        if frame_aggregation_enabled and self.aggregated_segments:
            self.log_text.append(f"Inference results saved to: {results_path} (with {len(self.aggregated_segments)} aggregated segments)")
        else:
            self.log_text.append(f"Inference results saved to: {results_path}")

        if not silent:
            QMessageBox.information(self, "Saved", f"Results saved to:\n{results_path}")
            # Update loaded path if we saved to a new location
            self.loaded_results_path = results_path

    except Exception as e:
        self.log_text.append(f"Warning: Failed to save inference results: {e}")
        if not silent:
            QMessageBox.critical(self, "Error", f"Failed to save results:\n{str(e)}")
|
|
1585
|
+
|
|
1586
|
+
def _load_timeline_results(self, path: str = None):
    """Load previously saved timeline results.

    If *path* is provided (e.g. autosave recovery), skip the file dialog.

    The file is read and validated before any widget state is replaced, so
    an oversized, unreadable or invalid file leaves the current classes,
    results cache and ``loaded_results_path`` untouched. On success the
    saved parameters are pushed into the UI controls, the video selector is
    repopulated, the first video's data is loaded and displayed, and the
    export/save buttons are enabled.

    Args:
        path: Optional results file to load without prompting.
    """
    if path:
        file_path = path
    else:
        # Prompt user to select results file
        exp_path = self.config.get("experiment_path")
        if exp_path:
            default_dir = os.path.join(exp_path, "results")
        else:
            default_dir = os.path.join(self.config.get("data_dir", "data"), "results")

        file_path, _ = QFileDialog.getOpenFileName(
            self,
            "Load Timeline Results",
            default_dir,
            "JSON Files (*.json)"
        )

    if not file_path:
        return

    MAX_LOAD_SIZE_MB = 300

    def _as_frame_probs(raw):
        """Return *raw* as an array (lists become float32 arrays), else None."""
        if isinstance(raw, list):
            raw = np.asarray(raw, dtype=np.float32)
        return raw if isinstance(raw, np.ndarray) else None

    try:
        # Inside the try so a vanished/unreadable file is reported, not raised.
        if os.path.getsize(file_path) > MAX_LOAD_SIZE_MB * 1024 * 1024:
            QMessageBox.critical(
                self,
                "File too large",
                f"This results file is over {MAX_LOAD_SIZE_MB} MB and may cause a crash when loading.\n\n"
                "Re-run inference and save again (attention maps are now excluded from saves).\n"
                "Or load a smaller/simpler results file.",
            )
            return

        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        self._restore_external_arrays(file_path, data)

        # Extract data into locals first; commit to self only once valid.
        classes = data.get("classes", [])
        parameters = data.get("parameters", {})
        results = data.get("results", {})

        if not classes or not results:
            QMessageBox.warning(self, "Invalid File", "The selected file does not contain valid inference results.")
            return

        self.classes = classes
        self.results_cache = results
        self.loaded_results_path = file_path
        self._ignore_threshold_user_modified = False

        for cached_entry in self.results_cache.values():
            cached_entry.setdefault("corrected_labels", {})
            cached_entry.setdefault("corrected_attr_labels", {})

        # Update UI with loaded parameters
        self.target_fps_spin.setValue(parameters.get("target_fps", 16))
        self.clip_length_spin.setValue(parameters.get("clip_length", 16))
        # Prefer the saved step; fall back to half the clip length for
        # files that do not carry one.
        default_step = max(1, self.clip_length_spin.value() // 2)
        self.step_frames_spin.setValue(int(parameters.get("step_frames") or default_step))
        self.sample_inference_check.setChecked(bool(parameters.get("sample_inference_enabled", False)))
        self.sample_duration_spin.setValue(int(parameters.get("sample_duration_seconds", 60)))
        self.sample_count_spin.setValue(int(parameters.get("sample_num_chunks", 5)))
        loaded_use_ovr = parameters.get("use_ovr", None)
        if loaded_use_ovr is None:
            # Files without the flag: infer it from stored multi-class segments.
            loaded_use_ovr = any(
                bool((v or {}).get("aggregated_multiclass_segments", []))
                for v in self.results_cache.values()
                if isinstance(v, dict)
            )
        self._use_ovr = bool(loaded_use_ovr)
        self._allowed_cooccurrence = set()
        for pair in parameters.get("allowed_cooccurrence", []):
            if isinstance(pair, (list, tuple)) and len(pair) == 2:
                a, b = pair[0], pair[1]
                # Stored once per pair; kept in both orders in memory.
                self._allowed_cooccurrence.add((a, b))
                self._allowed_cooccurrence.add((b, a))
        if getattr(self, "ovr_rows_check", None) is not None and self._use_ovr:
            self.ovr_rows_check.setChecked(True)
        if getattr(self, "ovr_rows_check", None) is not None and "ovr_rows" in parameters:
            self.ovr_rows_check.setChecked(bool(parameters.get("ovr_rows", True)))
        if getattr(self, "ovr_show_all_check", None) is not None:
            self.ovr_show_all_check.setChecked(bool(parameters.get("ovr_show_all", False)))
        if "merge_consecutive_enabled" in parameters:
            self.merge_timeline_check.setChecked(bool(parameters.get("merge_consecutive_enabled", False)))
        if "use_viterbi_decode" in parameters:
            self.use_viterbi_decode = bool(parameters.get("use_viterbi_decode", False))
            self.use_viterbi_check.setChecked(self.use_viterbi_decode)
        if "viterbi_switch_penalty" in parameters:
            self.viterbi_switch_penalty = float(parameters.get("viterbi_switch_penalty", 0.35))
            self.viterbi_switch_penalty_spin.setValue(self.viterbi_switch_penalty)

        # Restore frame aggregation setting
        frame_aggregation_enabled = bool(parameters.get("frame_aggregation_enabled", False))
        self.frame_aggregation_check.setChecked(frame_aggregation_enabled)
        if "min_segment_frames" in parameters:
            self._min_segment_frames = int(parameters.get("min_segment_frames", 1))
        if "merge_gap_frames" in parameters:
            self._merge_gap_frames = int(parameters.get("merge_gap_frames", 0))
        if "temporal_smoothing_window_frames" in parameters:
            loaded_win = int(parameters.get("temporal_smoothing_window_frames", 1))
            if loaded_win % 2 == 0:
                # Even window sizes are bumped to the next odd value.
                loaded_win += 1
            self._temporal_smoothing_window_frames = max(1, loaded_win)
        if "class_min_segment_frames" in parameters:
            self.class_min_segment_frames = {
                str(cls): int(v) for cls, v in parameters.get("class_min_segment_frames", {}).items()
            }
        if "class_merge_gap_frames" in parameters:
            self.class_merge_gap_frames = {
                str(cls): int(v) for cls, v in parameters.get("class_merge_gap_frames", {}).items()
            }
        if "class_smoothing_window_frames" in parameters:
            self.class_smoothing_window_frames = {
                str(cls): int(v) for cls, v in parameters.get("class_smoothing_window_frames", {}).items()
            }
        self._sync_per_class_segment_rule_config()

        # Restore timeline zoom
        if "timeline_zoom" in parameters:
            self.timeline_zoom_spin.blockSignals(True)
            self.timeline_zoom_spin.setValue(int(parameters["timeline_zoom"]))
            self.timeline_zoom_spin.blockSignals(False)

        # Restore ignore threshold settings so the loaded timeline matches what was saved
        if "use_ignore_threshold" in parameters:
            self.use_ignore_threshold = bool(parameters["use_ignore_threshold"])
            self.use_ignore_threshold_check.setChecked(self.use_ignore_threshold)
        if "ignore_threshold" in parameters:
            self.global_ignore_threshold = float(parameters["ignore_threshold"])
            self.ignore_threshold_spin.setValue(self.global_ignore_threshold)
        if "class_ignore_thresholds" in parameters:
            self.class_ignore_thresholds = {
                cls: float(t) for cls, t in parameters["class_ignore_thresholds"].items()
            }
            self.config["inference_class_ignore_thresholds"] = dict(self.class_ignore_thresholds)
        self.config["inference_use_viterbi_decode"] = self.use_viterbi_decode
        self.config["inference_viterbi_switch_penalty"] = self.viterbi_switch_penalty
        self._update_viterbi_ui_state()

        self.classes_label.setText(", ".join(self.classes))

        # Populate video filter combo
        self.filter_video_combo.blockSignals(True)
        self.filter_video_combo.clear()

        video_paths = list(self.results_cache.keys())
        for vp in video_paths:
            video_name = os.path.basename(vp)
            self.filter_video_combo.addItem(video_name, vp)

        self.filter_video_combo.blockSignals(False)

        # Load first video's results
        if video_paths:
            first_video = video_paths[0]
            self.video_path = first_video
            self.video_paths = video_paths

            # Update video path display
            if len(video_paths) == 1:
                self.video_path_edit.setText(first_video)
            else:
                self.video_path_edit.setText(f"{len(video_paths)} videos loaded")

            # Load first video data
            entry = self.results_cache[first_video]
            self.predictions = entry["predictions"]
            self.confidences = entry["confidences"]
            self.clip_probabilities = entry.get("clip_probabilities", [])
            self.clip_frame_probabilities = entry.get("clip_frame_probabilities", [])
            self.attr_predictions = entry.get("attr_predictions", [])
            self.attr_confidences = entry.get("attr_confidences", [])
            self.clip_starts = entry["clip_starts"]
            self.localization_bboxes = entry.get("localization_bboxes", [])
            # A missing or null value is treated as 0 so the comparison below cannot fail.
            self.total_frames = int(entry.get("total_frames", 0) or 0)
            self.corrected_labels = entry.get("corrected_labels", {})
            self.corrected_attr_labels = entry.get("corrected_attr_labels", {})

            # Backwards compatibility: if total_frames not saved, try to get from video
            if self.total_frames <= 0 and os.path.exists(first_video):
                try:
                    cap = cv2.VideoCapture(first_video)
                    try:
                        self.total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
                    finally:
                        cap.release()
                    # Store for future use
                    entry["total_frames"] = self.total_frames
                except Exception as frame_err:
                    # Best effort only; the results remain usable without it.
                    self.log_text.append(f"Warning: could not read frame count from video: {frame_err}")

            # Load aggregated segments if saved, otherwise compute if frame aggregation enabled
            if "aggregated_segments" in entry:
                self.aggregated_segments = entry["aggregated_segments"]
                self.aggregated_multiclass_segments = entry.get("aggregated_multiclass_segments", [])
                self.log_text.append(f"Loaded {len(self.aggregated_segments)} saved aggregated segments")

                # Load precomputed frame probs if available (may be list when loaded from JSON)
                frame_probs = _as_frame_probs(entry.get("aggregated_frame_probs"))
                if frame_probs is not None:
                    self._aggregated_frame_scores_norm = frame_probs
                    self._aggregated_last_covered_frame = len(frame_probs)
            elif frame_aggregation_enabled and self._use_ovr and self.clip_probabilities:
                self._compute_aggregated_timeline()
            elif frame_aggregation_enabled:
                # Check for pre-computed frame probs (from worker or JSON load)
                frame_probs = _as_frame_probs(entry.get("aggregated_frame_probs"))
                if frame_probs is not None:
                    self._aggregated_frame_scores_norm = frame_probs
                    self._aggregated_last_covered_frame = len(frame_probs)
                    self._build_timeline_segments()
                else:
                    self._compute_aggregated_timeline()

        # Populate behavior filter
        self.filter_behavior_combo.blockSignals(True)
        self.filter_behavior_combo.clear()
        self.filter_behavior_combo.addItem("All Behaviors")
        self.filter_behavior_combo.addItem(self.ignore_label_name)
        self.filter_behavior_combo.addItems(self.classes)
        self.filter_behavior_combo.blockSignals(False)

        # Display results
        self._display_results()

        # Enable export/preview buttons
        self.export_btn.setEnabled(True)
        self.preview_btn.setEnabled(True)
        self.export_timeline_btn.setEnabled(True)
        self.save_results_btn.setEnabled(True)  # Enable saving loaded/corrected results

        self.log_text.append(f"Loaded timeline results from: {file_path}")
        self.log_text.append(f"Loaded {len(video_paths)} video(s) with {len(self.predictions)} predictions")

        QMessageBox.information(
            self,
            "Loaded",
            f"Successfully loaded timeline results.\n\n"
            f"Videos: {len(video_paths)}\n"
            f"Classes: {len(self.classes)}\n\n"
            f"You can now review clips, make corrections, and re-save."
        )

    except Exception as e:
        self.log_text.append(f"Error loading timeline results: {e}")
        QMessageBox.critical(self, "Error", f"Failed to load timeline results:\n{str(e)}")
|
|
1838
|
+
|
|
1839
|
+
def _on_video_selection_changed(self, index: int):
|
|
1840
|
+
"""Handle video selection change."""
|
|
1841
|
+
video_path = self.filter_video_combo.currentData()
|
|
1842
|
+
self._load_video_from_cache(video_path)
|
|
1843
|
+
|
|
1844
|
+
def _current_threshold_settings(self) -> dict:
|
|
1845
|
+
return {
|
|
1846
|
+
"use_ignore_threshold": bool(self.use_ignore_threshold),
|
|
1847
|
+
"ignore_threshold": float(self.global_ignore_threshold),
|
|
1848
|
+
"class_ignore_thresholds": {
|
|
1849
|
+
str(cls): float(t) for cls, t in self.class_ignore_thresholds.items()
|
|
1850
|
+
},
|
|
1851
|
+
"user_modified": bool(self._ignore_threshold_user_modified),
|
|
1852
|
+
}
|
|
1853
|
+
|
|
1854
|
+
def _apply_threshold_settings(self, settings: dict):
|
|
1855
|
+
use_ignore = bool(settings.get("use_ignore_threshold", self.use_ignore_threshold))
|
|
1856
|
+
tau = settings.get("ignore_threshold", self.global_ignore_threshold)
|
|
1857
|
+
try:
|
|
1858
|
+
tau = float(tau)
|
|
1859
|
+
except Exception:
|
|
1860
|
+
tau = self.global_ignore_threshold
|
|
1861
|
+
raw_per_class = settings.get("class_ignore_thresholds", {})
|
|
1862
|
+
if not isinstance(raw_per_class, dict):
|
|
1863
|
+
raw_per_class = {}
|
|
1864
|
+
per_class = {str(cls): float(t) for cls, t in raw_per_class.items()}
|
|
1865
|
+
|
|
1866
|
+
self._applying_auto_threshold = True
|
|
1867
|
+
try:
|
|
1868
|
+
self.use_ignore_threshold_check.blockSignals(True)
|
|
1869
|
+
self.ignore_threshold_spin.blockSignals(True)
|
|
1870
|
+
try:
|
|
1871
|
+
self.use_ignore_threshold_check.setChecked(use_ignore)
|
|
1872
|
+
self.ignore_threshold_spin.setValue(tau)
|
|
1873
|
+
finally:
|
|
1874
|
+
self.use_ignore_threshold_check.blockSignals(False)
|
|
1875
|
+
self.ignore_threshold_spin.blockSignals(False)
|
|
1876
|
+
finally:
|
|
1877
|
+
self._applying_auto_threshold = False
|
|
1878
|
+
|
|
1879
|
+
self.use_ignore_threshold = use_ignore
|
|
1880
|
+
self.global_ignore_threshold = tau
|
|
1881
|
+
self.class_ignore_thresholds = per_class
|
|
1882
|
+
self._ignore_threshold_user_modified = bool(settings.get("user_modified", False))
|
|
1883
|
+
self.config["inference_use_ignore_threshold"] = self.use_ignore_threshold
|
|
1884
|
+
self.config["inference_ignore_threshold"] = self.global_ignore_threshold
|
|
1885
|
+
self.config["inference_class_ignore_thresholds"] = dict(self.class_ignore_thresholds)
|
|
1886
|
+
|
|
1887
|
+
def _persist_current_video_state(self):
|
|
1888
|
+
if not self.video_path or self.video_path not in self.results_cache:
|
|
1889
|
+
return
|
|
1890
|
+
entry = self.results_cache[self.video_path]
|
|
1891
|
+
entry["corrected_labels"] = dict(self.corrected_labels)
|
|
1892
|
+
entry["corrected_attr_labels"] = dict(self.corrected_attr_labels)
|
|
1893
|
+
entry["threshold_settings"] = self._current_threshold_settings()
|
|
1894
|
+
|
|
1895
|
+
def _load_video_from_cache(
|
|
1896
|
+
self,
|
|
1897
|
+
video_path: str,
|
|
1898
|
+
refresh_display: bool = True,
|
|
1899
|
+
persist_current: bool = True,
|
|
1900
|
+
threshold_settings_override: dict | None = None,
|
|
1901
|
+
persist_loaded_thresholds: bool = True,
|
|
1902
|
+
) -> bool:
|
|
1903
|
+
if not video_path or video_path not in self.results_cache:
|
|
1904
|
+
return False
|
|
1905
|
+
|
|
1906
|
+
if persist_current:
|
|
1907
|
+
self._persist_current_video_state()
|
|
1908
|
+
|
|
1909
|
+
data = self.results_cache[video_path]
|
|
1910
|
+
self.predictions = data["predictions"]
|
|
1911
|
+
self.confidences = data["confidences"]
|
|
1912
|
+
self.clip_probabilities = data.get("clip_probabilities", [])
|
|
1913
|
+
self.clip_frame_probabilities = data.get("clip_frame_probabilities", [])
|
|
1914
|
+
self.attr_predictions = data.get("attr_predictions", [])
|
|
1915
|
+
self.attr_confidences = data.get("attr_confidences", [])
|
|
1916
|
+
self.clip_starts = data["clip_starts"]
|
|
1917
|
+
self.localization_bboxes = data.get("localization_bboxes", [])
|
|
1918
|
+
self.total_frames = data.get("total_frames", 0)
|
|
1919
|
+
self.corrected_labels = data.get("corrected_labels", {})
|
|
1920
|
+
self.corrected_attr_labels = data.get("corrected_attr_labels", {})
|
|
1921
|
+
self.video_path = video_path
|
|
1922
|
+
|
|
1923
|
+
threshold_settings = threshold_settings_override
|
|
1924
|
+
if threshold_settings is None:
|
|
1925
|
+
threshold_settings = data.get("threshold_settings")
|
|
1926
|
+
if isinstance(threshold_settings, dict):
|
|
1927
|
+
self._apply_threshold_settings(threshold_settings)
|
|
1928
|
+
else:
|
|
1929
|
+
self._auto_update_ignore_threshold()
|
|
1930
|
+
if persist_loaded_thresholds:
|
|
1931
|
+
data["threshold_settings"] = self._current_threshold_settings()
|
|
1932
|
+
|
|
1933
|
+
if self.clip_frame_probabilities and not self.frame_aggregation_check.isChecked():
|
|
1934
|
+
self.frame_aggregation_check.setChecked(True)
|
|
1935
|
+
self.log_text.append("Frame-head outputs found for this video. Switched to precise frame-boundary mode.")
|
|
1936
|
+
|
|
1937
|
+
if self.total_frames <= 0 and os.path.exists(video_path):
|
|
1938
|
+
try:
|
|
1939
|
+
cap = cv2.VideoCapture(video_path)
|
|
1940
|
+
self.total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
|
1941
|
+
cap.release()
|
|
1942
|
+
data["total_frames"] = self.total_frames
|
|
1943
|
+
except Exception:
|
|
1944
|
+
pass
|
|
1945
|
+
|
|
1946
|
+
if self.frame_aggregation_check.isChecked():
|
|
1947
|
+
precomputed_probs = data.get("aggregated_frame_probs")
|
|
1948
|
+
if isinstance(precomputed_probs, list):
|
|
1949
|
+
precomputed_probs = np.asarray(precomputed_probs, dtype=np.float32)
|
|
1950
|
+
if isinstance(precomputed_probs, np.ndarray):
|
|
1951
|
+
self._aggregated_frame_scores_norm = precomputed_probs
|
|
1952
|
+
self._aggregated_last_covered_frame = len(precomputed_probs)
|
|
1953
|
+
self._build_timeline_segments()
|
|
1954
|
+
elif self.use_ignore_threshold:
|
|
1955
|
+
self._compute_aggregated_timeline()
|
|
1956
|
+
elif "aggregated_segments" in data:
|
|
1957
|
+
self.aggregated_segments = data["aggregated_segments"]
|
|
1958
|
+
self.aggregated_multiclass_segments = data.get("aggregated_multiclass_segments", [])
|
|
1959
|
+
self._aggregated_frame_scores_norm = None
|
|
1960
|
+
self._aggregated_active_mask = None
|
|
1961
|
+
self._aggregated_last_covered_frame = 0
|
|
1962
|
+
if self._use_ovr:
|
|
1963
|
+
self._compute_aggregated_timeline()
|
|
1964
|
+
else:
|
|
1965
|
+
self._compute_aggregated_timeline()
|
|
1966
|
+
else:
|
|
1967
|
+
self.aggregated_segments = []
|
|
1968
|
+
self.aggregated_multiclass_segments = []
|
|
1969
|
+
self._aggregated_frame_scores_norm = None
|
|
1970
|
+
self._aggregated_active_mask = None
|
|
1971
|
+
self._aggregated_last_covered_frame = 0
|
|
1972
|
+
|
|
1973
|
+
if refresh_display:
|
|
1974
|
+
self._display_results()
|
|
1975
|
+
return True
|
|
1976
|
+
|
|
1977
|
+
def _on_error(self, error_msg: str):
    """Reset the inference controls after a failure and report the error.

    Hides and zeroes the progress bar, shows the message in the progress
    label, re-enables the run/load buttons, disables the stop button and
    finally opens a critical message box.
    """
    bar = self.progress_bar
    bar.setVisible(False)
    bar.setValue(0)
    self.progress_label.setText(f"Error: {error_msg}")

    button_states = (
        (self.run_inference_btn, True),
        (self.stop_inference_btn, False),
        (self.load_timeline_btn, True),
    )
    for button, enabled in button_states:
        button.setEnabled(enabled)

    QMessageBox.critical(self, "Inference Error", f"Inference failed:\n{error_msg}")
|
|
1986
|
+
|
|
1987
|
+
def _stop_inference(self):
|
|
1988
|
+
"""Request inference worker to stop; completed videos are kept."""
|
|
1989
|
+
if getattr(self, "worker", None) and self.worker.isRunning():
|
|
1990
|
+
self.worker.stop()
|
|
1991
|
+
self.progress_label.setText("Stopping... (keeping completed videos)")
|
|
1992
|
+
|
|
1993
|
+
def _on_log(self, message: str):
|
|
1994
|
+
"""Handle log message."""
|
|
1995
|
+
self.log_text.append(message)
|
|
1996
|
+
|
|
1997
|
+
def _display_results(self):
    """Display inference results in the results list and redraw the timeline.

    In precise frame mode (aggregation checked and segments available) one
    entry per aggregated segment is shown. Otherwise one entry per clip is
    shown using the effective (corrected / threshold-filtered) predictions,
    with the attribute label and, in OvR mode, co-occurring classes appended.
    Clips whose effective class index is beyond the class list produce no
    entry.
    """

    def _confidence_color(confidence):
        # Green above 70 %, amber above 50 %, red otherwise.
        if confidence > 0.7:
            return QColor(0, 150, 0)
        if confidence > 0.5:
            return QColor(200, 150, 0)
        return QColor(200, 0, 0)

    self._update_viterbi_ui_state()
    self.results_list.clear()

    # In precise frame mode, show segment-level results instead of per-clip
    # entries to avoid confusion when a clip contains multiple behaviors.
    if self.frame_aggregation_check.isChecked() and self.aggregated_segments:
        for i, seg in enumerate(self.aggregated_segments):
            cls_idx = int(seg.get("class", -1))
            start_f = int(seg.get("start", 0))
            end_f = int(seg.get("end", start_f))
            conf = float(seg.get("confidence", 0.0))
            if cls_idx < 0 or cls_idx >= len(self.classes):
                label = self.ignore_label_name
            else:
                label = self.classes[cls_idx]
            item_text = f"Segment {i+1}: {label} (frames {start_f}-{end_f}, confidence: {conf:.2%})"
            item = QListWidgetItem(item_text)
            item.setForeground(_confidence_color(conf))
            self.results_list.addItem(item)
        self._draw_timeline()
        return

    has_attrs = bool(self.attr_predictions and self.attributes)

    effective_preds = self._effective_predictions()
    for i, (pred_idx, conf) in enumerate(zip(effective_preds, self.confidences)):
        if 0 <= pred_idx < len(self.classes):
            label = self.classes[pred_idx]

            # Append Attribute if available
            attr_conf = None
            attr_label = None
            if has_attrs and i < len(self.attr_predictions):
                attr_idx = self.attr_predictions[i]
                # Require a non-negative index: a negative value would wrap
                # around and silently pick an attribute from the end.
                if isinstance(attr_idx, int) and 0 <= attr_idx < len(self.attributes):
                    attr_label = self.attributes[attr_idx]
                    if self.attr_confidences and i < len(self.attr_confidences):
                        attr_conf = self.attr_confidences[i]
                    label = f"{label} ({attr_label})"

            extra_labels = ""
            if self._use_ovr and i < len(self.clip_probabilities):
                probs_i = self.clip_probabilities[i]
                if isinstance(probs_i, (list, tuple)) and len(probs_i) == len(self.classes):
                    active = self._filter_cooccurrence(probs_i, 0.3)
                    # Every active class except the primary prediction.
                    above = [
                        f"{self.classes[ci]}:{float(probs_i[ci]):.0%}"
                        for ci in sorted(active)
                        if ci != pred_idx
                    ]
                    if above:
                        extra_labels = " + " + ", ".join(above)

            if attr_conf is not None:
                item_text = f"Clip {i+1}: {label} (class: {conf:.2%}, {attr_label}: {attr_conf:.2%}){extra_labels}"
            else:
                item_text = f"Clip {i+1}: {label} (confidence: {conf:.2%}){extra_labels}"
            item = QListWidgetItem(item_text)
            item.setForeground(_confidence_color(conf))
            self.results_list.addItem(item)
        elif pred_idx < 0:
            # Negative effective predictions are shown as the ignore label in grey.
            item = QListWidgetItem(f"Clip {i+1}: {self.ignore_label_name} (confidence: {conf:.2%})")
            item.setForeground(QColor(120, 120, 120))
            self.results_list.addItem(item)

    self._draw_timeline()
|
|
2076
|
+
|
|
2077
|
+
def _threshold_for_pred(self, pred_idx: int) -> float:
|
|
2078
|
+
if pred_idx < 0 or pred_idx >= len(self.classes):
|
|
2079
|
+
return self.global_ignore_threshold
|
|
2080
|
+
cls = self.classes[pred_idx]
|
|
2081
|
+
if cls in self.class_ignore_thresholds:
|
|
2082
|
+
return float(self.class_ignore_thresholds[cls])
|
|
2083
|
+
return self.global_ignore_threshold
|
|
2084
|
+
|
|
2085
|
+
def _effective_prediction_for_clip(self, clip_idx: int) -> int:
|
|
2086
|
+
pred_idx = self.corrected_labels[clip_idx] if clip_idx in self.corrected_labels else self.predictions[clip_idx]
|
|
2087
|
+
if not self.use_ignore_threshold:
|
|
2088
|
+
return pred_idx
|
|
2089
|
+
if clip_idx >= len(self.confidences):
|
|
2090
|
+
return pred_idx
|
|
2091
|
+
# Round to 3 decimals so 0.9997 (displayed as ~100%) isn't below a 1.000 threshold
|
|
2092
|
+
conf = round(float(self.confidences[clip_idx]), 3)
|
|
2093
|
+
if conf < self._threshold_for_pred(pred_idx):
|
|
2094
|
+
return -1
|
|
2095
|
+
return pred_idx
|
|
2096
|
+
|
|
2097
|
+
def _effective_predictions(self):
|
|
2098
|
+
if not self.predictions:
|
|
2099
|
+
return []
|
|
2100
|
+
return [self._effective_prediction_for_clip(i) for i in range(len(self.predictions))]
|
|
2101
|
+
|
|
2102
|
+
def _on_ignore_threshold_changed(self, *_args):
|
|
2103
|
+
self.use_ignore_threshold = bool(self.use_ignore_threshold_check.isChecked())
|
|
2104
|
+
self.global_ignore_threshold = float(self.ignore_threshold_spin.value())
|
|
2105
|
+
self.config["inference_use_ignore_threshold"] = self.use_ignore_threshold
|
|
2106
|
+
self.config["inference_ignore_threshold"] = self.global_ignore_threshold
|
|
2107
|
+
if not self._applying_auto_threshold:
|
|
2108
|
+
self._ignore_threshold_user_modified = True
|
|
2109
|
+
self._persist_current_video_state()
|
|
2110
|
+
if self.predictions:
|
|
2111
|
+
if self.frame_aggregation_check.isChecked():
|
|
2112
|
+
self._compute_aggregated_timeline()
|
|
2113
|
+
self._display_results()
|
|
2114
|
+
|
|
2115
|
+
def _update_viterbi_ui_state(self):
|
|
2116
|
+
"""Keep Viterbi controls enabled and explain current mode."""
|
|
2117
|
+
self.use_viterbi_check.setEnabled(True)
|
|
2118
|
+
self.viterbi_switch_penalty_spin.setEnabled(bool(self.use_viterbi_decode))
|
|
2119
|
+
if self._use_ovr:
|
|
2120
|
+
self.use_viterbi_check.setToolTip(
|
|
2121
|
+
"Inference-only sequence decoding on the merged frame probabilities.\n"
|
|
2122
|
+
"OvR uses binary per-class Viterbi, then keeps top classes consistent\n"
|
|
2123
|
+
"with allowed co-occurrence rules."
|
|
2124
|
+
)
|
|
2125
|
+
else:
|
|
2126
|
+
self.use_viterbi_check.setToolTip(
|
|
2127
|
+
"Inference-only sequence decoding for the single-label timeline.\n"
|
|
2128
|
+
"Reduces rapid frame-to-frame class switching without retraining."
|
|
2129
|
+
)
|
|
2130
|
+
|
|
2131
|
+
def _on_viterbi_changed(self, *_args):
|
|
2132
|
+
self.use_viterbi_decode = bool(self.use_viterbi_check.isChecked())
|
|
2133
|
+
self.viterbi_switch_penalty = float(self.viterbi_switch_penalty_spin.value())
|
|
2134
|
+
self.config["inference_use_viterbi_decode"] = self.use_viterbi_decode
|
|
2135
|
+
self.config["inference_viterbi_switch_penalty"] = self.viterbi_switch_penalty
|
|
2136
|
+
if self.use_viterbi_decode and not self.frame_aggregation_check.isChecked():
|
|
2137
|
+
self.frame_aggregation_check.setChecked(True)
|
|
2138
|
+
self._update_viterbi_ui_state()
|
|
2139
|
+
self._persist_current_video_state()
|
|
2140
|
+
if self.predictions and self.frame_aggregation_check.isChecked():
|
|
2141
|
+
self._compute_aggregated_timeline()
|
|
2142
|
+
self._display_results()
|
|
2143
|
+
|
|
2144
|
+
def _auto_update_ignore_threshold(self):
|
|
2145
|
+
"""Auto-update ignore thresholds, preferring validation calibration."""
|
|
2146
|
+
if self._ignore_threshold_user_modified:
|
|
2147
|
+
return
|
|
2148
|
+
cfg = self.model_training_config if isinstance(self.model_training_config, dict) else {}
|
|
2149
|
+
calibrated = cfg.get("validation_calibrated_ignore_thresholds")
|
|
2150
|
+
if isinstance(calibrated, dict):
|
|
2151
|
+
tau = calibrated.get("global_threshold", self.global_ignore_threshold)
|
|
2152
|
+
try:
|
|
2153
|
+
tau = float(tau)
|
|
2154
|
+
except Exception:
|
|
2155
|
+
tau = self.global_ignore_threshold
|
|
2156
|
+
tau = max(0.35, min(0.90, tau))
|
|
2157
|
+
raw_per_class = calibrated.get("per_class_thresholds", {})
|
|
2158
|
+
per_class = {}
|
|
2159
|
+
if isinstance(raw_per_class, dict):
|
|
2160
|
+
for cls_name in self.classes:
|
|
2161
|
+
if cls_name in raw_per_class:
|
|
2162
|
+
try:
|
|
2163
|
+
per_class[cls_name] = max(0.35, min(0.90, float(raw_per_class[cls_name])))
|
|
2164
|
+
except Exception:
|
|
2165
|
+
per_class[cls_name] = tau
|
|
2166
|
+
else:
|
|
2167
|
+
per_class[cls_name] = tau
|
|
2168
|
+
if per_class:
|
|
2169
|
+
self._applying_auto_threshold = True
|
|
2170
|
+
try:
|
|
2171
|
+
self.ignore_threshold_spin.blockSignals(True)
|
|
2172
|
+
try:
|
|
2173
|
+
self.ignore_threshold_spin.setValue(tau)
|
|
2174
|
+
finally:
|
|
2175
|
+
self.ignore_threshold_spin.blockSignals(False)
|
|
2176
|
+
finally:
|
|
2177
|
+
self._applying_auto_threshold = False
|
|
2178
|
+
self.global_ignore_threshold = tau
|
|
2179
|
+
self.class_ignore_thresholds = per_class
|
|
2180
|
+
self.config["inference_ignore_threshold"] = tau
|
|
2181
|
+
self.config["inference_class_ignore_thresholds"] = dict(per_class)
|
|
2182
|
+
source = str(calibrated.get("source", "validation"))
|
|
2183
|
+
self.log_text.append(
|
|
2184
|
+
f"Auto-set ignore thresholds from model validation: global τ={tau:.2f}, "
|
|
2185
|
+
f"per-class τ for {len(per_class)} classes ({source})"
|
|
2186
|
+
)
|
|
2187
|
+
return
|
|
2188
|
+
|
|
2189
|
+
if not self.confidences:
|
|
2190
|
+
return
|
|
2191
|
+
confs = np.array(self.confidences, dtype=float)
|
|
2192
|
+
confs = confs[np.isfinite(confs)]
|
|
2193
|
+
if confs.size == 0:
|
|
2194
|
+
return
|
|
2195
|
+
|
|
2196
|
+
has_validation = bool(cfg) and (cfg.get("use_all_for_training") is False)
|
|
2197
|
+
if has_validation:
|
|
2198
|
+
best_val_f1 = cfg.get("best_val_f1")
|
|
2199
|
+
try:
|
|
2200
|
+
best_val_f1 = float(best_val_f1) if best_val_f1 is not None else None
|
|
2201
|
+
except Exception:
|
|
2202
|
+
best_val_f1 = None
|
|
2203
|
+
if best_val_f1 is not None:
|
|
2204
|
+
quality = max(0.0, min(1.0, (best_val_f1 - 40.0) / 50.0))
|
|
2205
|
+
quantile = 0.80 - 0.20 * quality
|
|
2206
|
+
else:
|
|
2207
|
+
quantile = 0.68
|
|
2208
|
+
source = "current-video fallback (validation-aware quantile)"
|
|
2209
|
+
else:
|
|
2210
|
+
quantile = 0.70
|
|
2211
|
+
source = "current-video fallback"
|
|
2212
|
+
|
|
2213
|
+
tau = float(np.quantile(confs, quantile))
|
|
2214
|
+
tau = max(0.35, min(0.90, tau))
|
|
2215
|
+
self._applying_auto_threshold = True
|
|
2216
|
+
try:
|
|
2217
|
+
self.ignore_threshold_spin.blockSignals(True)
|
|
2218
|
+
try:
|
|
2219
|
+
self.ignore_threshold_spin.setValue(tau)
|
|
2220
|
+
finally:
|
|
2221
|
+
self.ignore_threshold_spin.blockSignals(False)
|
|
2222
|
+
finally:
|
|
2223
|
+
self._applying_auto_threshold = False
|
|
2224
|
+
self.global_ignore_threshold = tau
|
|
2225
|
+
self.config["inference_ignore_threshold"] = tau
|
|
2226
|
+
|
|
2227
|
+
min_class_support = 10
|
|
2228
|
+
per_class = {}
|
|
2229
|
+
for cls_idx, cls_name in enumerate(self.classes):
|
|
2230
|
+
cls_confs = [
|
|
2231
|
+
float(self.confidences[i])
|
|
2232
|
+
for i, pred_idx in enumerate(self.predictions)
|
|
2233
|
+
if pred_idx == cls_idx and i < len(self.confidences) and np.isfinite(self.confidences[i])
|
|
2234
|
+
]
|
|
2235
|
+
if len(cls_confs) >= min_class_support:
|
|
2236
|
+
cls_tau = float(np.quantile(np.array(cls_confs, dtype=float), quantile))
|
|
2237
|
+
cls_tau = max(0.35, min(0.90, cls_tau))
|
|
2238
|
+
per_class[cls_name] = cls_tau
|
|
2239
|
+
else:
|
|
2240
|
+
per_class[cls_name] = tau
|
|
2241
|
+
|
|
2242
|
+
self.class_ignore_thresholds = per_class
|
|
2243
|
+
self.config["inference_class_ignore_thresholds"] = dict(per_class)
|
|
2244
|
+
self.log_text.append(
|
|
2245
|
+
f"Auto-set ignore thresholds: global τ={tau:.2f}, per-class τ for {len(per_class)} classes ({source})"
|
|
2246
|
+
)
|
|
2247
|
+
|
|
2248
|
+
def _open_per_class_thresholds_dialog(self):
    """Let the user edit one ignore threshold per class in a modal dialog.

    Each class gets a spin box (range 0.0–1.0, 3 decimals) pre-filled with
    its current threshold, or the global one when it has none. Accepting the
    dialog stores the values and triggers ``_on_ignore_threshold_changed``.
    """
    if not self.classes:
        QMessageBox.information(self, "No classes", "Load a model first.")
        return

    dialog = QDialog(self)
    dialog.setWindowTitle("Per-class Ignore Thresholds")
    form = QFormLayout(dialog)

    editors = {}
    for cls in self.classes:
        editor = QDoubleSpinBox()
        editor.setDecimals(3)
        editor.setRange(0.0, 1.0)
        editor.setSingleStep(0.01)
        current = self.class_ignore_thresholds.get(cls, self.global_ignore_threshold)
        editor.setValue(float(current))
        form.addRow(cls, editor)
        editors[cls] = editor

    buttons = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel)
    form.addRow(buttons)
    buttons.accepted.connect(dialog.accept)
    buttons.rejected.connect(dialog.reject)

    if not dialog.exec():
        return

    chosen = {}
    for cls, editor in editors.items():
        chosen[cls] = float(editor.value())
    self.class_ignore_thresholds = chosen
    self.config["inference_class_ignore_thresholds"] = dict(self.class_ignore_thresholds)
    self._on_ignore_threshold_changed()
|
|
2272
|
+
|
|
2273
|
+
def _open_per_class_segment_rules_dialog(self):
    """Let the user override smoothing, gap and minimum-segment rules per class.

    Every class gets three spin boxes pre-filled with its override or the
    global value. On accept, only values that differ from the global
    settings are stored as overrides; the config is synced and, when
    predictions exist and aggregation is on, the timeline is recomputed.
    Smoothing windows are forced to odd values.
    """
    if not self.classes:
        QMessageBox.information(self, "No classes", "Load a model first.")
        return

    dialog = QDialog(self)
    dialog.setWindowTitle("Per-class Segment Rules")
    dialog.resize(620, 520)
    outer = QVBoxLayout(dialog)
    outer.addWidget(QLabel(
        "Override precise-boundary postprocessing per class. "
        "Values matching the global controls act like defaults."
    ))

    scroll_area = QScrollArea()
    scroll_area.setWidgetResizable(True)
    container = QWidget()
    table = QGridLayout(container)
    for column, title in enumerate(("Class", "Smooth", "Gap", "Min seg")):
        table.addWidget(QLabel(title), 0, column)

    # Global values act as the defaults shown for classes without overrides.
    default_smooth = int(max(1, self._temporal_smoothing_window_frames))
    if default_smooth % 2 == 0:
        default_smooth += 1
    default_gap = int(max(0, self._merge_gap_frames))
    default_min = int(max(1, self._min_segment_frames))

    editors = {}
    for row, cls in enumerate(self.classes, start=1):
        table.addWidget(QLabel(cls), row, 0)

        smooth_box = QSpinBox()
        smooth_box.setRange(1, 99)
        smooth_box.setSingleStep(2)
        initial_smooth = int(self.class_smoothing_window_frames.get(cls, default_smooth))
        if initial_smooth % 2 == 0:
            initial_smooth += 1
        smooth_box.setValue(max(1, initial_smooth))
        table.addWidget(smooth_box, row, 1)

        gap_box = QSpinBox()
        gap_box.setRange(0, 200)
        gap_box.setValue(int(self.class_merge_gap_frames.get(cls, default_gap)))
        table.addWidget(gap_box, row, 2)

        min_box = QSpinBox()
        min_box.setRange(1, 200)
        min_box.setValue(int(self.class_min_segment_frames.get(cls, default_min)))
        table.addWidget(min_box, row, 3)

        editors[cls] = (smooth_box, gap_box, min_box)

    scroll_area.setWidget(container)
    outer.addWidget(scroll_area, stretch=1)

    buttons = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel)
    outer.addWidget(buttons)
    buttons.accepted.connect(dialog.accept)
    buttons.rejected.connect(dialog.reject)

    if not dialog.exec():
        return

    smooth_overrides = {}
    gap_overrides = {}
    min_overrides = {}
    for cls, (smooth_box, gap_box, min_box) in editors.items():
        chosen_smooth = int(smooth_box.value())
        if chosen_smooth % 2 == 0:
            chosen_smooth += 1
        chosen_gap = int(gap_box.value())
        chosen_min = int(min_box.value())
        # Only keep entries that actually deviate from the global setting.
        if chosen_smooth != default_smooth:
            smooth_overrides[cls] = chosen_smooth
        if chosen_gap != default_gap:
            gap_overrides[cls] = chosen_gap
        if chosen_min != default_min:
            min_overrides[cls] = chosen_min

    self.class_smoothing_window_frames = smooth_overrides
    self.class_merge_gap_frames = gap_overrides
    self.class_min_segment_frames = min_overrides
    self._sync_per_class_segment_rule_config()
    if self.predictions and self.frame_aggregation_check.isChecked():
        self._compute_aggregated_timeline()
        self._display_results()
|
|
2358
|
+
|
|
2359
|
+
def _get_localization_bbox_for_clip_frame(self, clip_idx: int, frame_idx: int):
|
|
2360
|
+
"""Return normalized xyxy localization bbox for a clip frame, or None."""
|
|
2361
|
+
if clip_idx < 0 or clip_idx >= len(self.localization_bboxes):
|
|
2362
|
+
return None
|
|
2363
|
+
|
|
2364
|
+
raw = self.localization_bboxes[clip_idx]
|
|
2365
|
+
if not isinstance(raw, (list, tuple)) or len(raw) == 0:
|
|
2366
|
+
return None
|
|
2367
|
+
|
|
2368
|
+
def _ema_smooth_boxes(boxes, alpha: float):
|
|
2369
|
+
if not boxes:
|
|
2370
|
+
return boxes
|
|
2371
|
+
prev = [float(v) for v in boxes[0]]
|
|
2372
|
+
smoothed = [prev]
|
|
2373
|
+
for i in range(1, len(boxes)):
|
|
2374
|
+
curr = [float(v) for v in boxes[i]]
|
|
2375
|
+
prev = [
|
|
2376
|
+
alpha * curr[0] + (1.0 - alpha) * prev[0],
|
|
2377
|
+
alpha * curr[1] + (1.0 - alpha) * prev[1],
|
|
2378
|
+
alpha * curr[2] + (1.0 - alpha) * prev[2],
|
|
2379
|
+
alpha * curr[3] + (1.0 - alpha) * prev[3],
|
|
2380
|
+
]
|
|
2381
|
+
smoothed.append(prev)
|
|
2382
|
+
return smoothed
|
|
2383
|
+
|
|
2384
|
+
# Per-clip bbox: [4]
|
|
2385
|
+
if len(raw) == 4 and all(not isinstance(v, (list, tuple)) for v in raw):
|
|
2386
|
+
vals = raw
|
|
2387
|
+
# Per-frame bboxes: [T,4]
|
|
2388
|
+
elif isinstance(raw[0], (list, tuple)) and len(raw[0]) == 4:
|
|
2389
|
+
idx = max(0, min(int(frame_idx), len(raw) - 1))
|
|
2390
|
+
vals = _ema_smooth_boxes(raw, self._bbox_ema_alpha)[idx]
|
|
2391
|
+
else:
|
|
2392
|
+
return None
|
|
2393
|
+
|
|
2394
|
+
try:
|
|
2395
|
+
x1, y1, x2, y2 = [float(v) for v in vals]
|
|
2396
|
+
except Exception:
|
|
2397
|
+
return None
|
|
2398
|
+
|
|
2399
|
+
x1 = max(0.0, min(1.0, x1))
|
|
2400
|
+
y1 = max(0.0, min(1.0, y1))
|
|
2401
|
+
x2 = max(0.0, min(1.0, x2))
|
|
2402
|
+
y2 = max(0.0, min(1.0, y2))
|
|
2403
|
+
if x2 < x1:
|
|
2404
|
+
x1, x2 = x2, x1
|
|
2405
|
+
if y2 < y1:
|
|
2406
|
+
y1, y2 = y2, y1
|
|
2407
|
+
|
|
2408
|
+
if (x2 - x1) < 1e-4 or (y2 - y1) < 1e-4:
|
|
2409
|
+
return None
|
|
2410
|
+
return x1, y1, x2, y2
|
|
2411
|
+
|
|
2412
|
+
def _get_classification_roi_bbox_for_clip_frame(self, clip_idx: int):
|
|
2413
|
+
"""Return the normalized xyxy ROI used by classification cropping."""
|
|
2414
|
+
if clip_idx < 0 or clip_idx >= len(self.localization_bboxes):
|
|
2415
|
+
return None
|
|
2416
|
+
|
|
2417
|
+
raw = self.localization_bboxes[clip_idx]
|
|
2418
|
+
if not isinstance(raw, (list, tuple)) or len(raw) == 0:
|
|
2419
|
+
return None
|
|
2420
|
+
|
|
2421
|
+
# Match InferenceWorker._build_refined_clips:
|
|
2422
|
+
# if temporal boxes exist, use first-frame box as fixed clip ROI.
|
|
2423
|
+
if isinstance(raw[0], (list, tuple)) and len(raw[0]) == 4:
|
|
2424
|
+
vals = raw[0]
|
|
2425
|
+
elif len(raw) == 4 and all(not isinstance(v, (list, tuple)) for v in raw):
|
|
2426
|
+
vals = raw
|
|
2427
|
+
else:
|
|
2428
|
+
return None
|
|
2429
|
+
|
|
2430
|
+
try:
|
|
2431
|
+
x1, y1, x2, y2 = [float(v) for v in vals]
|
|
2432
|
+
except Exception:
|
|
2433
|
+
return None
|
|
2434
|
+
|
|
2435
|
+
crop_padding = getattr(self, "_crop_padding", 0.35)
|
|
2436
|
+
crop_min_size = getattr(self, "_crop_min_size", 0.04)
|
|
2437
|
+
x1, y1, x2, y2 = _sanitize_bbox_coords(x1, y1, x2, y2, crop_padding, crop_min_size)
|
|
2438
|
+
if (x2 - x1) < 1e-4 or (y2 - y1) < 1e-4:
|
|
2439
|
+
return None
|
|
2440
|
+
return x1, y1, x2, y2
|
|
2441
|
+
|
|
2442
|
+
def _get_saved_frame_interval(self, video_path: str, orig_fps: float) -> int:
|
|
2443
|
+
"""Return inference-time frame interval for a video when available."""
|
|
2444
|
+
if video_path and isinstance(self.results_cache, dict):
|
|
2445
|
+
entry = self.results_cache.get(video_path, {})
|
|
2446
|
+
if isinstance(entry, dict):
|
|
2447
|
+
stored = entry.get("frame_interval", None)
|
|
2448
|
+
try:
|
|
2449
|
+
stored_val = int(stored)
|
|
2450
|
+
except Exception:
|
|
2451
|
+
stored_val = 0
|
|
2452
|
+
if stored_val > 0:
|
|
2453
|
+
return stored_val
|
|
2454
|
+
|
|
2455
|
+
target_fps = max(1, int(self.target_fps_spin.value()))
|
|
2456
|
+
return max(1, int(round(float(orig_fps) / float(target_fps))))
|
|
2457
|
+
|
|
2458
|
+
def _build_center_merge_weights(self, length: int) -> np.ndarray:
|
|
2459
|
+
"""Match inference-worker overlap merge weighting."""
|
|
2460
|
+
if length <= 1:
|
|
2461
|
+
return np.ones((max(1, length),), dtype=np.float32)
|
|
2462
|
+
w = np.hanning(length).astype(np.float32)
|
|
2463
|
+
if not np.any(w > 0):
|
|
2464
|
+
return np.ones((length,), dtype=np.float32)
|
|
2465
|
+
return np.clip(0.1 + 0.9 * w, 1e-3, None).astype(np.float32)
|
|
2466
|
+
|
|
2467
|
+
def _get_precomputed_aggregated_probs(self, video_path: str = None):
|
|
2468
|
+
"""Return cached worker-built frame probabilities for the active video when available."""
|
|
2469
|
+
v_path = video_path or self.video_path
|
|
2470
|
+
if not v_path or not isinstance(self.results_cache, dict):
|
|
2471
|
+
return None
|
|
2472
|
+
entry = self.results_cache.get(v_path, {})
|
|
2473
|
+
if not isinstance(entry, dict):
|
|
2474
|
+
return None
|
|
2475
|
+
precomputed = entry.get("aggregated_frame_probs")
|
|
2476
|
+
if isinstance(precomputed, list):
|
|
2477
|
+
try:
|
|
2478
|
+
precomputed = np.asarray(precomputed, dtype=np.float32)
|
|
2479
|
+
except Exception:
|
|
2480
|
+
return None
|
|
2481
|
+
if isinstance(precomputed, np.ndarray) and precomputed.ndim == 2 and precomputed.shape[0] > 0:
|
|
2482
|
+
return precomputed
|
|
2483
|
+
return None
|
|
2484
|
+
|
|
2485
|
+
def _smooth_win_for_class(self, cls_idx: int) -> int:
|
|
2486
|
+
base = int(max(1, getattr(self, "_temporal_smoothing_window_frames", 1)))
|
|
2487
|
+
if 0 <= cls_idx < len(self.classes):
|
|
2488
|
+
base = int(self.class_smoothing_window_frames.get(self.classes[cls_idx], base))
|
|
2489
|
+
if base % 2 == 0:
|
|
2490
|
+
base += 1
|
|
2491
|
+
return max(1, base)
|
|
2492
|
+
|
|
2493
|
+
def _gap_fill_for_class(self, cls_idx: int) -> int:
|
|
2494
|
+
base = int(max(0, getattr(self, "_merge_gap_frames", 0)))
|
|
2495
|
+
if 0 <= cls_idx < len(self.classes):
|
|
2496
|
+
base = int(self.class_merge_gap_frames.get(self.classes[cls_idx], base))
|
|
2497
|
+
return max(0, base)
|
|
2498
|
+
|
|
2499
|
+
def _min_seg_for_class(self, cls_idx: int) -> int:
|
|
2500
|
+
base = int(max(1, getattr(self, "_min_segment_frames", 1)))
|
|
2501
|
+
if 0 <= cls_idx < len(self.classes):
|
|
2502
|
+
base = int(self.class_min_segment_frames.get(self.classes[cls_idx], base))
|
|
2503
|
+
return max(1, base)
|
|
2504
|
+
|
|
2505
|
+
def _sync_per_class_segment_rule_config(self):
|
|
2506
|
+
self.config["inference_class_min_segment_frames"] = {
|
|
2507
|
+
cls: int(v) for cls, v in self.class_min_segment_frames.items()
|
|
2508
|
+
}
|
|
2509
|
+
self.config["inference_class_merge_gap_frames"] = {
|
|
2510
|
+
cls: int(v) for cls, v in self.class_merge_gap_frames.items()
|
|
2511
|
+
}
|
|
2512
|
+
self.config["inference_class_smoothing_window_frames"] = {
|
|
2513
|
+
cls: int(v) for cls, v in self.class_smoothing_window_frames.items()
|
|
2514
|
+
}
|
|
2515
|
+
|
|
2516
|
+
def _json_safe_result_value(self, value):
|
|
2517
|
+
"""Recursively convert numpy values into JSON-safe Python types."""
|
|
2518
|
+
if isinstance(value, np.ndarray):
|
|
2519
|
+
return value.tolist()
|
|
2520
|
+
if isinstance(value, np.generic):
|
|
2521
|
+
return value.item()
|
|
2522
|
+
if isinstance(value, dict):
|
|
2523
|
+
skip = {"clip_attention_maps"}
|
|
2524
|
+
return {str(k): self._json_safe_result_value(v) for k, v in value.items() if str(k) not in skip}
|
|
2525
|
+
if isinstance(value, (list, tuple)):
|
|
2526
|
+
return [self._json_safe_result_value(v) for v in value]
|
|
2527
|
+
return value
|
|
2528
|
+
|
|
2529
|
+
def _arrays_sidecar_path(self, json_path: str) -> str:
|
|
2530
|
+
base, ext = os.path.splitext(json_path)
|
|
2531
|
+
if ext.lower() == ".json":
|
|
2532
|
+
return base + ".arrays.npz"
|
|
2533
|
+
return json_path + ".arrays.npz"
|
|
2534
|
+
|
|
2535
|
+
def _coerce_external_result_array(self, key: str, value):
|
|
2536
|
+
if value is None:
|
|
2537
|
+
return None
|
|
2538
|
+
if isinstance(value, np.ndarray):
|
|
2539
|
+
arr = value
|
|
2540
|
+
elif isinstance(value, (list, tuple)) and len(value) > 0:
|
|
2541
|
+
try:
|
|
2542
|
+
arr = np.asarray(value)
|
|
2543
|
+
except Exception:
|
|
2544
|
+
return None
|
|
2545
|
+
else:
|
|
2546
|
+
return None
|
|
2547
|
+
if arr.size <= 0 or arr.dtype == object:
|
|
2548
|
+
return None
|
|
2549
|
+
if np.issubdtype(arr.dtype, np.floating):
|
|
2550
|
+
arr = arr.astype(np.float32, copy=False)
|
|
2551
|
+
elif np.issubdtype(arr.dtype, np.integer):
|
|
2552
|
+
arr = arr.astype(np.int32, copy=False)
|
|
2553
|
+
return np.ascontiguousarray(arr)
|
|
2554
|
+
|
|
2555
|
+
def _prepare_results_for_storage(self, results_cache: dict):
|
|
2556
|
+
heavy_keys = {"clip_probabilities", "clip_frame_probabilities", "aggregated_frame_probs"}
|
|
2557
|
+
json_results = {}
|
|
2558
|
+
external_arrays = {}
|
|
2559
|
+
for video_idx, (video_path, entry) in enumerate(results_cache.items()):
|
|
2560
|
+
if not isinstance(entry, dict):
|
|
2561
|
+
json_results[str(video_path)] = self._json_safe_result_value(entry)
|
|
2562
|
+
continue
|
|
2563
|
+
entry_json = {}
|
|
2564
|
+
entry_refs = {}
|
|
2565
|
+
for key, value in entry.items():
|
|
2566
|
+
if str(key) == "clip_attention_maps":
|
|
2567
|
+
continue
|
|
2568
|
+
if key in heavy_keys:
|
|
2569
|
+
arr = self._coerce_external_result_array(key, value)
|
|
2570
|
+
if arr is not None:
|
|
2571
|
+
store_key = f"video_{video_idx:05d}__{key}"
|
|
2572
|
+
external_arrays[store_key] = arr
|
|
2573
|
+
entry_refs[key] = store_key
|
|
2574
|
+
continue
|
|
2575
|
+
entry_json[str(key)] = self._json_safe_result_value(value)
|
|
2576
|
+
if entry_refs:
|
|
2577
|
+
entry_json["_external_arrays"] = entry_refs
|
|
2578
|
+
json_results[str(video_path)] = entry_json
|
|
2579
|
+
return json_results, external_arrays
|
|
2580
|
+
|
|
2581
|
+
def _write_results_bundle(self, results_path: str, payload: dict, *, pretty: bool):
    """Write ``payload`` as JSON plus an optional ``.npz`` array sidecar.

    The ``results`` part of the payload is split by
    ``_prepare_results_for_storage``. When arrays were extracted they are
    saved compressed next to the JSON and referenced through
    ``external_array_store``. Otherwise that key is removed and a leftover
    sidecar file is deleted on a best-effort basis. ``pretty`` selects
    two-space indentation for the JSON.
    """
    bundle = dict(payload)
    cache = payload.get("results", {}) or {}
    json_results, external_arrays = self._prepare_results_for_storage(cache)
    bundle["results"] = json_results
    sidecar_path = self._arrays_sidecar_path(results_path)

    if not external_arrays:
        bundle.pop("external_array_store", None)
        if os.path.exists(sidecar_path):
            # Best effort: a stale sidecar that cannot be removed is ignored.
            try:
                os.remove(sidecar_path)
            except Exception:
                pass
    else:
        np.savez_compressed(sidecar_path, **external_arrays)
        bundle["external_array_store"] = {
            "format": "npz",
            "file": os.path.basename(sidecar_path),
        }

    dump_options = {"indent": 2} if pretty else {}
    with open(results_path, "w", encoding="utf-8") as f:
        json.dump(bundle, f, **dump_options)
|
|
2604
|
+
|
|
2605
|
+
def _restore_external_arrays(self, file_path: str, data: dict):
|
|
2606
|
+
if not isinstance(data, dict):
|
|
2607
|
+
return
|
|
2608
|
+
results = data.get("results", {})
|
|
2609
|
+
if not isinstance(results, dict):
|
|
2610
|
+
return
|
|
2611
|
+
store_info = data.get("external_array_store", {})
|
|
2612
|
+
sidecar_file = None
|
|
2613
|
+
if isinstance(store_info, dict):
|
|
2614
|
+
sidecar_file = store_info.get("file")
|
|
2615
|
+
sidecar_path = (
|
|
2616
|
+
os.path.join(os.path.dirname(file_path), sidecar_file)
|
|
2617
|
+
if sidecar_file else self._arrays_sidecar_path(file_path)
|
|
2618
|
+
)
|
|
2619
|
+
if not os.path.exists(sidecar_path):
|
|
2620
|
+
return
|
|
2621
|
+
try:
|
|
2622
|
+
with np.load(sidecar_path, allow_pickle=False) as npz_file:
|
|
2623
|
+
for entry in results.values():
|
|
2624
|
+
if not isinstance(entry, dict):
|
|
2625
|
+
continue
|
|
2626
|
+
refs = entry.pop("_external_arrays", None)
|
|
2627
|
+
if not isinstance(refs, dict):
|
|
2628
|
+
continue
|
|
2629
|
+
for field_name, store_key in refs.items():
|
|
2630
|
+
if store_key not in npz_file:
|
|
2631
|
+
continue
|
|
2632
|
+
arr = np.asarray(npz_file[store_key])
|
|
2633
|
+
if field_name == "aggregated_frame_probs":
|
|
2634
|
+
entry[field_name] = arr.astype(np.float32, copy=False)
|
|
2635
|
+
else:
|
|
2636
|
+
entry[field_name] = arr.tolist()
|
|
2637
|
+
except Exception as exc:
|
|
2638
|
+
self.log_text.append(f"Warning: Failed to load companion results arrays: {exc}")
|
|
2639
|
+
|
|
2640
|
+
def _get_clips_dir(self) -> str:
|
|
2641
|
+
"""Resolved clips directory from config; creates dir if needed."""
|
|
2642
|
+
clips_dir = self.config.get("clips_dir", "data/clips")
|
|
2643
|
+
if not os.path.isabs(clips_dir):
|
|
2644
|
+
exp_path = self.config.get("experiment_path")
|
|
2645
|
+
if exp_path:
|
|
2646
|
+
clips_dir = os.path.join(exp_path, clips_dir)
|
|
2647
|
+
else:
|
|
2648
|
+
clips_dir = os.path.abspath(clips_dir)
|
|
2649
|
+
os.makedirs(clips_dir, exist_ok=True)
|
|
2650
|
+
return clips_dir
|
|
2651
|
+
|
|
2652
|
+
def _get_annotation_file(self) -> str:
|
|
2653
|
+
"""Path to annotations JSON from config."""
|
|
2654
|
+
return self.config.get("annotation_file", "data/annotations/annotations.json")
|
|
2655
|
+
|
|
2656
|
+
def _clip_path_to_id(self, clip_path: str, clips_dir: str) -> str:
|
|
2657
|
+
"""Convert absolute clip path to annotation clip ID (relative, forward slashes)."""
|
|
2658
|
+
clip_id = os.path.relpath(clip_path, clips_dir).replace("\\", "/")
|
|
2659
|
+
if clip_id.startswith("../"):
|
|
2660
|
+
return os.path.basename(clip_path)
|
|
2661
|
+
for prefix in ("../clips/", "clips/", "data/clips/"):
|
|
2662
|
+
if clip_id.startswith(prefix):
|
|
2663
|
+
return clip_id[len(prefix):]
|
|
2664
|
+
return clip_id
|
|
2665
|
+
|
|
2666
|
+
def _get_video_fps(self, video_path: str) -> float:
|
|
2667
|
+
"""Return video FPS from path; 30.0 if unavailable or invalid."""
|
|
2668
|
+
if not video_path:
|
|
2669
|
+
return 30.0
|
|
2670
|
+
try:
|
|
2671
|
+
cap = cv2.VideoCapture(video_path)
|
|
2672
|
+
fps = cap.get(cv2.CAP_PROP_FPS)
|
|
2673
|
+
cap.release()
|
|
2674
|
+
return float(fps) if fps and fps > 0 else 30.0
|
|
2675
|
+
except Exception:
|
|
2676
|
+
return 30.0
|
|
2677
|
+
|
|
2678
|
+
def _video_basename(self) -> str:
|
|
2679
|
+
"""Basename of current video path without extension."""
|
|
2680
|
+
if not self.video_path:
|
|
2681
|
+
return ""
|
|
2682
|
+
return os.path.splitext(os.path.basename(self.video_path))[0]
|
|
2683
|
+
|
|
2684
|
+
def _get_timeline_qcolors(self) -> list:
    """Return the timeline palette converted to a new list of ``QColor``."""
    qcolors = []
    for red, green, blue in self._get_timeline_palette():
        qcolors.append(QColor(red, green, blue))
    return qcolors
|
|
2687
|
+
|
|
2688
|
+
def _unique_clip_path(self, clip_path: str) -> str:
|
|
2689
|
+
"""Return a path that does not exist by appending _1, _2, ... to stem."""
|
|
2690
|
+
base_clip_path = clip_path
|
|
2691
|
+
suffix = 1
|
|
2692
|
+
while os.path.exists(clip_path):
|
|
2693
|
+
base, ext = os.path.splitext(base_clip_path)
|
|
2694
|
+
clip_path = f"{base}_{suffix}{ext}"
|
|
2695
|
+
suffix += 1
|
|
2696
|
+
return clip_path
|
|
2697
|
+
|
|
2698
|
+
def _merge_predictions(self, predictions, confidences, clip_starts):
|
|
2699
|
+
"""Merge consecutive identical predictions."""
|
|
2700
|
+
if not predictions:
|
|
2701
|
+
return predictions, confidences, clip_starts
|
|
2702
|
+
|
|
2703
|
+
merged_preds = []
|
|
2704
|
+
merged_confs = []
|
|
2705
|
+
merged_starts = []
|
|
2706
|
+
|
|
2707
|
+
current_pred = predictions[0]
|
|
2708
|
+
current_conf = confidences[0]
|
|
2709
|
+
current_start = clip_starts[0]
|
|
2710
|
+
|
|
2711
|
+
for i in range(1, len(predictions)):
|
|
2712
|
+
if predictions[i] == current_pred:
|
|
2713
|
+
# Same behavior, continue merging (keep max confidence)
|
|
2714
|
+
current_conf = max(current_conf, confidences[i])
|
|
2715
|
+
else:
|
|
2716
|
+
# Different behavior, save current segment and start new one
|
|
2717
|
+
merged_preds.append(current_pred)
|
|
2718
|
+
merged_confs.append(current_conf)
|
|
2719
|
+
merged_starts.append(current_start)
|
|
2720
|
+
|
|
2721
|
+
current_pred = predictions[i]
|
|
2722
|
+
current_conf = confidences[i]
|
|
2723
|
+
current_start = clip_starts[i]
|
|
2724
|
+
|
|
2725
|
+
# Add last segment
|
|
2726
|
+
merged_preds.append(current_pred)
|
|
2727
|
+
merged_confs.append(current_conf)
|
|
2728
|
+
merged_starts.append(current_start)
|
|
2729
|
+
|
|
2730
|
+
return merged_preds, merged_confs, merged_starts
|
|
2731
|
+
|
|
2732
|
+
def _draw_timeline(self):
    """Draw colored timeline visualization.

    Chooses one of three renderers:

    * the OvR multi-row view, when the rows checkbox exists and is checked,
      OvR is in use and clip probabilities are available;
    * the frame-aggregated view, when frame aggregation is checked and
      aggregated segments exist;
    * otherwise the clip-based single-row strip painted inline below.

    Nothing is drawn when there are no predictions or no classes.
    """
    if not self.predictions or not self.classes:
        return

    # OvR: per-class rows when checkbox on; non-OvR: single-row timeline
    ovr_rows = (
        getattr(self, "ovr_rows_check", None) is not None
        and self.ovr_rows_check.isChecked()
        and self._use_ovr
        and bool(self.clip_probabilities)
    )
    # Exactly one of the two timeline containers is visible at a time.
    self.timeline_scroll.setVisible(not ovr_rows)
    self._ovr_timeline_container.setVisible(ovr_rows)

    frame_aggregation_enabled = self.frame_aggregation_check.isChecked()

    # OvR per-class row timeline
    if ovr_rows:
        if frame_aggregation_enabled:
            # Recompute the aggregated timeline only when its cached results
            # are missing or empty.
            need_precise_recompute = (
                not self.aggregated_segments
                or not isinstance(self._aggregated_frame_scores_norm, np.ndarray)
                or int(self._aggregated_last_covered_frame) <= 0
            )
            if need_precise_recompute:
                self._compute_aggregated_timeline()
        self._draw_ovr_multirow_timeline()
        return

    # Use frame-aggregated segments if enabled
    if frame_aggregation_enabled and self.aggregated_segments:
        self._draw_frame_aggregated_timeline()
        return

    # Original clip-based timeline drawing
    corrected_preds = self._effective_predictions()

    # Apply merging if enabled
    merge_enabled = self.merge_timeline_check.isChecked()
    if merge_enabled:
        display_preds, display_confs, display_starts = self._merge_predictions(
            corrected_preds, self.confidences, self.clip_starts
        )
    else:
        display_preds = corrected_preds
        display_confs = self.confidences
        display_starts = self.clip_starts

    num_segments = len(display_preds)
    if num_segments == 0:
        return

    # Calculate total width using zoom spinbox (px per second).
    total_clips = len(self.predictions)
    # Falls back to 100 when the spinbox is missing or its value is falsy (e.g. 0).
    px_per_sec_clip = float(getattr(self, "timeline_zoom_spin", None) and self.timeline_zoom_spin.value() or 100)
    # NOTE(review): orig_fps_clip is not read anywhere else in this method.
    orig_fps_clip = self._get_video_fps(self.video_path) if self.video_path else 30.0
    # base_clip_width: pixels per clip based on clip_length and zoom
    clip_dur_sec = self.clip_length_spin.value() / max(1.0, float(self.target_fps_spin.value()))
    base_clip_width = max(4, int(clip_dur_sec * px_per_sec_clip))
    total_width = max(800, total_clips * base_clip_width)
    height = 60

    self.timeline_widget.setFixedSize(total_width, height)
    pixmap = QPixmap(total_width, height)
    pixmap.fill(QColor(255, 255, 255))

    painter = QPainter(pixmap)
    painter.setRenderHint(QPainter.RenderHint.Antialiasing)

    colors = self._get_timeline_qcolors()

    x_pos = 0
    selected_behavior = self.filter_behavior_combo.currentText()
    # Filter entries of the form "Attr: <name>" select by attribute, not class.
    selected_attr = None
    if selected_behavior.startswith("Attr: "):
        selected_attr = selected_behavior.replace("Attr: ", "", 1)

    # NOTE(review): a pred_idx >= len(self.classes) matches neither branch
    # below, so that segment is not painted and x_pos does not advance.
    for seg_idx, (pred_idx, conf) in enumerate(zip(display_preds, display_confs)):
        if pred_idx < len(self.classes) and pred_idx >= 0:
            # Calculate segment width based on number of clips it spans
            if seg_idx < len(display_starts) - 1:
                # Calculate number of clips in this segment
                seg_start_clip = display_starts[seg_idx]
                seg_end_clip = display_starts[seg_idx + 1]
                # Find how many original clips this spans
                clip_count = sum(1 for orig_start in self.clip_starts
                                 if seg_start_clip <= orig_start < seg_end_clip)
                seg_width = clip_count * base_clip_width if clip_count > 0 else base_clip_width
            else:
                # Last segment - count remaining clips
                seg_start_clip = display_starts[seg_idx]
                clip_count = sum(1 for orig_start in self.clip_starts if orig_start >= seg_start_clip)
                seg_width = clip_count * base_clip_width if clip_count > 0 else base_clip_width

            label = self.classes[pred_idx]
            if selected_behavior == "All Behaviors":
                is_selected = True
            elif selected_behavior == self.ignore_label_name:
                # The ignore filter never highlights a real class segment.
                is_selected = False
            elif selected_attr is not None:
                if selected_attr in self.attributes:
                    attr_target_idx = self.attributes.index(selected_attr)
                    # Find clip indices covered by this segment
                    seg_indices = [
                        idx for idx, orig_start in enumerate(self.clip_starts)
                        if seg_start_clip <= orig_start < (display_starts[seg_idx + 1] if seg_idx < len(display_starts) - 1 else float("inf"))
                    ]
                    # Selected when any covered clip carries the chosen attribute.
                    is_selected = any(
                        (self._get_attr_idx(i) == attr_target_idx)
                        for i in seg_indices
                    )
                else:
                    is_selected = False
            else:
                is_selected = (label == selected_behavior)

            if is_selected:
                # Class color with opacity proportional to the confidence.
                color = colors[pred_idx % len(colors)]
                alpha = int(255 * conf)
                color.setAlpha(alpha)
            else:
                # Draw as light gray if not selected to preserve timeline structure
                color = QColor(240, 240, 240)
                color.setAlpha(255)

            painter.fillRect(int(x_pos), 0, int(seg_width), height, color)

            # Only label segments wide enough to hold some text.
            if seg_width >= 30 and is_selected:
                painter.setPen(QColor(0, 0, 0))
                painter.setFont(QFont("Arial", 8))
                painter.drawText(int(x_pos + 5), 20, label)

            x_pos += seg_width
        elif pred_idx < 0:
            # Negative index: painted as an "ignored" segment, sized like a class segment.
            if seg_idx < len(display_starts) - 1:
                seg_start_clip = display_starts[seg_idx]
                seg_end_clip = display_starts[seg_idx + 1]
                clip_count = sum(1 for orig_start in self.clip_starts if seg_start_clip <= orig_start < seg_end_clip)
                seg_width = clip_count * base_clip_width if clip_count > 0 else base_clip_width
            else:
                seg_start_clip = display_starts[seg_idx]
                clip_count = sum(1 for orig_start in self.clip_starts if orig_start >= seg_start_clip)
                seg_width = clip_count * base_clip_width if clip_count > 0 else base_clip_width
            is_selected = selected_behavior in ("All Behaviors", self.ignore_label_name)
            color = QColor(180, 180, 180) if is_selected else QColor(240, 240, 240)
            color.setAlpha(230 if is_selected else 255)
            painter.fillRect(int(x_pos), 0, int(seg_width), height, color)
            if seg_width >= 55 and is_selected:
                painter.setPen(QColor(40, 40, 40))
                painter.setFont(QFont("Arial", 8))
                painter.drawText(int(x_pos + 5), 20, "ignored")
            x_pos += seg_width

    # Summary caption along the bottom edge of the strip.
    painter.setPen(QColor(0, 0, 0))
    painter.setFont(QFont("Arial", 8))
    merge_text = " (merged)" if merge_enabled else ""
    painter.drawText(5, height - 5, f"Timeline: {num_segments} segments, {len(self.predictions)} clips{merge_text} (click to view)")

    painter.end()

    label = QLabel()
    label.setPixmap(pixmap)
    label.setFixedSize(total_width, height)

    # Reuse the widget's layout when present (emptying it first); otherwise
    # install a margin-free vertical layout.
    old_layout = self.timeline_widget.layout()
    if old_layout:
        while old_layout.count():
            child = old_layout.takeAt(0)
            if child.widget():
                child.widget().deleteLater()
    else:
        layout = QVBoxLayout()
        layout.setContentsMargins(0, 0, 0, 0)
        self.timeline_widget.setLayout(layout)

    self.timeline_widget.layout().addWidget(label)
    self.timeline_widget.setFixedSize(total_width, height)
    # Geometry stored on the widget; presumably read by its click handling — TODO confirm.
    self.timeline_widget.clip_width = base_clip_width
    self.timeline_widget.num_clips = len(self.predictions)
    self.timeline_widget._frame_mode = False
|
|
2913
|
+
|
|
2914
|
+
def _filter_cooccurrence(self, probs, threshold):
|
|
2915
|
+
"""Return set of class indices to display, respecting co-occurrence rules.
|
|
2916
|
+
|
|
2917
|
+
When "Show all classes" is on, every class above threshold is returned.
|
|
2918
|
+
Otherwise: top-1 class is always shown, additional classes only if they
|
|
2919
|
+
form an allowed co-occurrence pair with the top-1 class.
|
|
2920
|
+
"""
|
|
2921
|
+
show_all = getattr(self, "ovr_show_all_check", None) is not None and self.ovr_show_all_check.isChecked()
|
|
2922
|
+
n = min(len(probs), len(self.classes))
|
|
2923
|
+
scored = [(ci, float(probs[ci])) for ci in range(n) if float(probs[ci]) >= threshold]
|
|
2924
|
+
if not scored:
|
|
2925
|
+
return set()
|
|
2926
|
+
if show_all:
|
|
2927
|
+
return {ci for ci, _ in scored}
|
|
2928
|
+
scored.sort(key=lambda x: x[1], reverse=True)
|
|
2929
|
+
top_ci = scored[0][0]
|
|
2930
|
+
top_name = self.classes[top_ci]
|
|
2931
|
+
active = {top_ci}
|
|
2932
|
+
for ci, sc in scored[1:]:
|
|
2933
|
+
name = self.classes[ci]
|
|
2934
|
+
if (top_name, name) in self._allowed_cooccurrence:
|
|
2935
|
+
active.add(ci)
|
|
2936
|
+
return active
|
|
2937
|
+
|
|
2938
|
+
def _active_ovr_indices_from_scores(self, probs_row, threshold_override: float | None = None):
|
|
2939
|
+
"""Active OvR class indices at one frame using thresholds.
|
|
2940
|
+
|
|
2941
|
+
When "Show all classes" is enabled, every class above its threshold is
|
|
2942
|
+
returned (fully independent). Otherwise the top-1 class is returned
|
|
2943
|
+
plus any class that forms an allowed co-occurrence pair with it.
|
|
2944
|
+
"""
|
|
2945
|
+
show_all = getattr(self, "ovr_show_all_check", None) is not None and self.ovr_show_all_check.isChecked()
|
|
2946
|
+
n = min(len(probs_row), len(self.classes))
|
|
2947
|
+
scored = []
|
|
2948
|
+
for ci in range(n):
|
|
2949
|
+
s = float(probs_row[ci])
|
|
2950
|
+
thr = float(threshold_override) if threshold_override is not None else self._threshold_for_pred(ci)
|
|
2951
|
+
if s >= thr:
|
|
2952
|
+
scored.append((ci, s))
|
|
2953
|
+
if not scored:
|
|
2954
|
+
return []
|
|
2955
|
+
scored.sort(key=lambda x: x[1], reverse=True)
|
|
2956
|
+
|
|
2957
|
+
if show_all:
|
|
2958
|
+
return [ci for ci, _ in scored]
|
|
2959
|
+
|
|
2960
|
+
top_ci = scored[0][0]
|
|
2961
|
+
if not self._allowed_cooccurrence:
|
|
2962
|
+
return [top_ci]
|
|
2963
|
+
top_name = self.classes[top_ci]
|
|
2964
|
+
active = [top_ci]
|
|
2965
|
+
for ci, _ in scored[1:]:
|
|
2966
|
+
name = self.classes[ci]
|
|
2967
|
+
if (top_name, name) in self._allowed_cooccurrence:
|
|
2968
|
+
active.append(ci)
|
|
2969
|
+
return active
|
|
2970
|
+
|
|
2971
|
+
def _get_precise_active_for_frame(self, frame_idx: int):
|
|
2972
|
+
"""Return [(class_idx, score), ...] active at frame_idx, sorted desc.
|
|
2973
|
+
|
|
2974
|
+
For OvR: uses _aggregated_active_mask when available so that
|
|
2975
|
+
min-segment and gap-fill filtering are respected.
|
|
2976
|
+
"""
|
|
2977
|
+
if not isinstance(self._aggregated_frame_scores_norm, np.ndarray):
|
|
2978
|
+
return []
|
|
2979
|
+
if frame_idx < 0 or frame_idx >= int(self._aggregated_last_covered_frame):
|
|
2980
|
+
return []
|
|
2981
|
+
scores = self._aggregated_frame_scores_norm[frame_idx]
|
|
2982
|
+
if self._use_ovr:
|
|
2983
|
+
# Prefer the active mask (has min-segment + gap-fill applied)
|
|
2984
|
+
if isinstance(self._aggregated_active_mask, np.ndarray) and frame_idx < self._aggregated_active_mask.shape[0]:
|
|
2985
|
+
mask_row = self._aggregated_active_mask[frame_idx]
|
|
2986
|
+
out = [(ci, float(scores[ci])) for ci in range(len(mask_row)) if mask_row[ci]]
|
|
2987
|
+
else:
|
|
2988
|
+
thr = None if self.use_ignore_threshold else 0.35
|
|
2989
|
+
active = self._active_ovr_indices_from_scores(scores, threshold_override=thr)
|
|
2990
|
+
out = [(ci, float(scores[ci])) for ci in active]
|
|
2991
|
+
out.sort(key=lambda x: x[1], reverse=True)
|
|
2992
|
+
return out
|
|
2993
|
+
if len(scores) == 0:
|
|
2994
|
+
return []
|
|
2995
|
+
ci = int(np.argmax(scores))
|
|
2996
|
+
if 0 <= ci < len(self.classes):
|
|
2997
|
+
if self.use_ignore_threshold and float(scores[ci]) < self._threshold_for_pred(ci):
|
|
2998
|
+
return []
|
|
2999
|
+
return [(ci, float(scores[ci]))]
|
|
3000
|
+
return []
|
|
3001
|
+
|
|
3002
|
+
def _ovr_timeline_click(self, event):
|
|
3003
|
+
"""Handle click on OvR multi-row timeline to show clip popup."""
|
|
3004
|
+
x = event.position().x()
|
|
3005
|
+
y = event.position().y()
|
|
3006
|
+
row_height = max(1, int(getattr(self, "_ovr_row_height", 24)))
|
|
3007
|
+
clicked_class = int(y / row_height) if row_height > 0 else -1
|
|
3008
|
+
if clicked_class < 0 or clicked_class >= len(self.classes):
|
|
3009
|
+
clicked_class = -1
|
|
3010
|
+
|
|
3011
|
+
if getattr(self, "_ovr_timeline_frame_mode", False):
|
|
3012
|
+
if self._ovr_pixels_per_frame > 0:
|
|
3013
|
+
frame_idx = int(x / self._ovr_pixels_per_frame)
|
|
3014
|
+
if 0 <= frame_idx < int(self._aggregated_last_covered_frame):
|
|
3015
|
+
self._show_clip_popup(frame_idx, frame_mode=True, ovr_class_idx=clicked_class)
|
|
3016
|
+
return
|
|
3017
|
+
if self._ovr_num_clips > 0 and self._ovr_clip_width > 0:
|
|
3018
|
+
clip_idx = int(x / self._ovr_clip_width)
|
|
3019
|
+
if 0 <= clip_idx < self._ovr_num_clips:
|
|
3020
|
+
self._show_clip_popup(clip_idx, ovr_class_idx=clicked_class)
|
|
3021
|
+
|
|
3022
|
+
def _draw_ovr_multirow_timeline(self):
    """Draw per-class row timeline: fixed labels on left, scrollable bars on right.

    One row is painted per class. In "precise" mode (frame aggregation
    checked, OvR in use and aggregated frame scores available) each row shows
    runs of active frames taken from ``_aggregated_active_mask``. Otherwise
    each clip contributes one bar in the row of every class that
    ``_active_ovr_indices_from_scores`` reports as active. The geometry used
    for drawing is stored on ``self`` (``_ovr_*`` attributes). Nothing is
    drawn when there are no clips or no classes.
    """
    num_classes = len(self.classes)
    num_clips = len(self.predictions)
    if num_clips == 0 or num_classes == 0:
        return

    # Zoom in pixels per second; 100 when the spinbox is missing or its value is falsy.
    px_per_sec_ovr = float(getattr(self, "timeline_zoom_spin", None) and self.timeline_zoom_spin.value() or 100)
    orig_fps_ovr = self._get_video_fps(self.video_path) if self.video_path else 30.0
    clip_dur_sec_ovr = self.clip_length_spin.value() / max(1.0, float(self.target_fps_spin.value()))
    base_clip_width = max(4, int(clip_dur_sec_ovr * px_per_sec_ovr))

    use_precise = (
        self.frame_aggregation_check.isChecked()
        and self._use_ovr
        and isinstance(self._aggregated_frame_scores_norm, np.ndarray)
        and int(self._aggregated_last_covered_frame) > 0
    )
    if use_precise:
        # Frame mode: horizontal scale is pixels per video frame (at least 0.5).
        total_frames = int(self._aggregated_last_covered_frame)
        pixels_per_frame = max(0.5, px_per_sec_ovr / orig_fps_ovr)
        total_width = max(800, int(total_frames * pixels_per_frame))
        self._ovr_timeline_frame_mode = True
        self._ovr_pixels_per_frame = pixels_per_frame
    else:
        # Clip mode: every clip gets the same fixed width.
        total_width = max(800, num_clips * base_clip_width)
        self._ovr_timeline_frame_mode = False
        self._ovr_pixels_per_frame = 1.0
    self._ovr_clip_width = base_clip_width
    self._ovr_num_clips = num_clips
    # Row height is derived from the visible height: scroll viewport first,
    # then the container minus 8 px, with a floor of 120 px.
    viewport_h = 0
    if getattr(self, "_ovr_scroll", None) is not None and self._ovr_scroll.viewport() is not None:
        viewport_h = int(self._ovr_scroll.viewport().height())
    if viewport_h <= 0 and getattr(self, "_ovr_timeline_container", None) is not None:
        viewport_h = max(0, int(self._ovr_timeline_container.height()) - 8)
    viewport_h = max(120, viewport_h)
    row_height = max(12, viewport_h // max(1, num_classes))
    self._ovr_row_height = row_height
    canvas_height = num_classes * row_height
    # Fixed score threshold; used for bar opacity and, in clip mode, as the
    # activation threshold when ignore thresholds are off.
    activation_threshold = 0.35

    colors = self._get_timeline_qcolors()

    # --- Fixed label panel ---
    label_pm = QPixmap(95, canvas_height)
    label_pm.fill(QColor(245, 245, 245))
    lp = QPainter(label_pm)
    lp.setFont(QFont("Arial", max(7, min(10, row_height - 10))))
    for ci in range(num_classes):
        y = ci * row_height
        c = colors[ci % len(colors)]
        # Color swatch, class name, then a separator line under the row.
        swatch_h = max(4, row_height - 8)
        lp.fillRect(2, y + max(2, (row_height - swatch_h) // 2), 10, swatch_h, c)
        lp.setPen(QColor(0, 0, 0))
        text_y = y + min(row_height - 4, max(10, int(row_height * 0.72)))
        lp.drawText(16, text_y, self.classes[ci])
        lp.setPen(QColor(210, 210, 210))
        lp.drawLine(0, y + row_height - 1, 95, y + row_height - 1)
    lp.end()
    self._ovr_label_panel.setPixmap(label_pm)
    self._ovr_label_panel.setFixedHeight(canvas_height)

    # --- Scrollable activation timeline ---
    tl_pm = QPixmap(total_width, canvas_height)
    tl_pm.fill(QColor(255, 255, 255))
    tp = QPainter(tl_pm)
    tp.setRenderHint(QPainter.RenderHint.Antialiasing)

    if use_precise:
        total_frames = int(self._aggregated_last_covered_frame)
        active_mask = self._aggregated_active_mask
        frame_scores = self._aggregated_frame_scores_norm
        for ci in range(num_classes):
            y = ci * row_height
            color = colors[ci % len(colors)]
            bar_h = max(2, row_height - 4)
            # Rows without a usable mask column stay empty.
            if not isinstance(active_mask, np.ndarray) or ci >= active_mask.shape[1]:
                continue
            # Run-length scan: paint one rectangle per run of active frames.
            run_start = None
            for fi in range(total_frames):
                # Frames beyond the mask's length count as inactive.
                is_active = bool(active_mask[fi, ci]) if fi < active_mask.shape[0] else False
                if is_active and run_start is None:
                    run_start = fi
                # A run closes on the first inactive frame or on the last frame.
                if run_start is not None and (not is_active or fi == total_frames - 1):
                    run_end = fi if (is_active and fi == total_frames - 1) else fi - 1
                    if run_end >= run_start:
                        x0 = int(run_start * self._ovr_pixels_per_frame)
                        x1 = int((run_end + 1) * self._ovr_pixels_per_frame)
                        seg_w = max(1, x1 - x0)
                        # Opacity follows the run's mean score: alpha 120 at
                        # the threshold (or below) up to 255 at a score of 1.0.
                        seg_conf = float(np.mean(frame_scores[run_start:run_end + 1, ci]))
                        c = QColor(color)
                        t = (seg_conf - activation_threshold) / max(0.01, 1.0 - activation_threshold)
                        alpha = int(120 + 135 * min(1.0, max(0.0, t)))
                        c.setAlpha(alpha)
                        tp.fillRect(x0, y + 2, seg_w, bar_h, c)
                    run_start = None
    else:
        # None lets _active_ovr_indices_from_scores apply per-class thresholds.
        ovr_threshold = None if self.use_ignore_threshold else activation_threshold
        for clip_i in range(num_clips):
            if clip_i >= len(self.clip_probabilities):
                break
            probs_i = self.clip_probabilities[clip_i]
            if not isinstance(probs_i, (list, tuple, np.ndarray)):
                continue

            # Determine active classes with same threshold/co-occurrence logic
            # used by precise OvR mode and overlays.
            active = self._active_ovr_indices_from_scores(
                probs_i,
                threshold_override=ovr_threshold,
            )

            for ci in active:
                score = float(probs_i[ci])
                y = ci * row_height
                color = colors[ci % len(colors)]
                bar_h = max(2, row_height - 4)
                x = clip_i * base_clip_width
                c = QColor(color)
                # Opacity grows with the score's margin over the threshold in effect.
                ci_threshold = self._threshold_for_pred(ci) if self.use_ignore_threshold else activation_threshold
                t = (score - ci_threshold) / max(0.01, 1.0 - ci_threshold)
                alpha = int(120 + 135 * min(1.0, t))
                c.setAlpha(alpha)
                tp.fillRect(int(x), y + 2, base_clip_width, bar_h, c)

    for ci in range(num_classes):
        y = ci * row_height
        # Row separator
        tp.setPen(QColor(210, 210, 210))
        tp.drawLine(0, y + row_height - 1, total_width, y + row_height - 1)

    tp.end()
    self._ovr_timeline_widget.setPixmap(tl_pm)
    self._ovr_timeline_widget.setFixedSize(total_width, canvas_height)
    self._ovr_label_panel.setFixedSize(95, canvas_height)
|
|
3157
|
+
|
|
3158
|
+
def eventFilter(self, obj, event):
    """Schedule a timeline redraw when the OvR container or viewport resizes.

    The redraw is queued with a zero-delay single-shot timer and only while
    the OvR timeline container is visible. The event is always passed on to
    the base-class filter.
    """
    if event.type() == QEvent.Type.Resize:
        container = getattr(self, "_ovr_timeline_container", None)
        scroll = getattr(self, "_ovr_scroll", None)
        viewport = scroll.viewport() if scroll is not None else None
        if obj in {container, viewport}:
            if container is not None and container.isVisible():
                QTimer.singleShot(0, self._draw_timeline)
    return super().eventFilter(obj, event)
|
|
3166
|
+
|
|
3167
|
+
def _draw_frame_aggregated_timeline(self):
    """Draw timeline using frame-level aggregated segments with precise boundaries.

    Paints every entry of ``self.aggregated_segments`` as a coloured bar on a
    60 px high pixmap whose width follows the zoom spinbox (pixels per second
    of video), installs the pixmap in ``self.timeline_widget`` and stores the
    frame/pixel mapping on that widget for click handling. Returns without
    drawing when there are no aggregated segments or no classes.
    """
    if not self.aggregated_segments or not self.classes:
        return

    segments = self.aggregated_segments
    num_segments = len(segments)

    # Use original video length so timeline matches full video
    total_frames = self.total_frames
    if total_frames <= 0 and self.video_path and os.path.exists(self.video_path):
        cap = None
        try:
            cap = cv2.VideoCapture(self.video_path)
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        except Exception:
            # Best effort only: the segment-derived length below is the fallback.
            pass
        finally:
            # Release the capture even when the frame-count query raised.
            if cap is not None:
                cap.release()
    if total_frames <= 0 and segments:
        total_frames = segments[-1]['end'] + 1

    # Calculate timeline dimensions using the zoom spinbox (px per second of video).
    # This removes the artificial max_width cap — the scroll area handles overflow.
    zoom_spin = getattr(self, "timeline_zoom_spin", None)
    zoom_value = zoom_spin.value() if zoom_spin is not None else 0
    px_per_sec = float(zoom_value or 100)  # missing spinbox or a zero value -> 100

    # Look the frame rate up once and reuse it for both the pixel scale and
    # the duration text; guard against a non-positive value before dividing.
    orig_fps = self._get_video_fps(self.video_path) if self.video_path else 30.0
    if not orig_fps or orig_fps <= 0:
        orig_fps = 30.0

    pixels_per_frame = max(0.5, px_per_sec / orig_fps)
    total_width = max(800, int(total_frames * pixels_per_frame))
    height = 60

    self.timeline_widget.setMinimumWidth(total_width)
    pixmap = QPixmap(total_width, height)
    pixmap.fill(QColor(255, 255, 255))

    painter = QPainter(pixmap)
    painter.setRenderHint(QPainter.RenderHint.Antialiasing)

    colors = self._get_timeline_qcolors()
    text_font = QFont("Arial", 8)

    selected_behavior = self.filter_behavior_combo.currentText()

    for seg in segments:
        pred_idx = seg['class']
        conf = seg.get('confidence', 1.0)
        start_frame = seg['start']
        end_frame = seg['end']

        # Calculate pixel positions
        x_start = int(start_frame * pixels_per_frame)
        x_end = int((end_frame + 1) * pixels_per_frame)
        seg_width = max(1, x_end - x_start)

        if pred_idx < 0:
            seg_label = "Filtered"
        elif pred_idx < len(self.classes):
            seg_label = self.classes[pred_idx]
        else:
            # Class index outside the known class list: nothing to draw.
            continue

        if selected_behavior == "All Behaviors":
            is_selected = True
        elif selected_behavior == self.ignore_label_name:
            is_selected = (pred_idx < 0)
        else:
            is_selected = (seg_label == selected_behavior)

        if is_selected:
            if pred_idx < 0:
                color = QColor(180, 180, 180)
                color.setAlpha(160)
            else:
                # Copy before changing alpha so the palette entry is not mutated.
                color = QColor(colors[pred_idx % len(colors)])
                alpha = int(min(255, 128 + 127 * min(1.0, conf)))
                color.setAlpha(alpha)
        else:
            color = QColor(240, 240, 240)
            color.setAlpha(255)

        painter.fillRect(x_start, 0, seg_width, height, color)

        # Draw label if segment is wide enough
        if seg_width >= 40 and is_selected:
            painter.setPen(QColor(0, 0, 0))
            painter.setFont(text_font)
            painter.drawText(x_start + 3, 20, seg_label)

        # Draw boundary markers
        if is_selected:
            painter.setPen(QColor(0, 0, 0, 100))
            painter.drawLine(x_start, 0, x_start, height)

    # Draw info text
    painter.setPen(QColor(0, 0, 0))
    painter.setFont(text_font)

    duration_sec = total_frames / orig_fps
    painter.drawText(5, height - 5,
                     f"Frame Timeline: {num_segments} segments, {total_frames} frames ({duration_sec:.1f}s) - precise boundaries (click to view)")

    painter.end()

    pixmap_label = QLabel()
    pixmap_label.setPixmap(pixmap)
    pixmap_label.setFixedSize(total_width, height)

    old_layout = self.timeline_widget.layout()
    if old_layout:
        # Reuse the existing layout: drop whatever it currently shows.
        while old_layout.count():
            child = old_layout.takeAt(0)
            if child.widget():
                child.widget().deleteLater()
    else:
        layout = QVBoxLayout()
        layout.setContentsMargins(0, 0, 0, 0)
        self.timeline_widget.setLayout(layout)

    self.timeline_widget.layout().addWidget(pixmap_label)
    self.timeline_widget.setFixedSize(total_width, height)

    # Store info for click handling in frame mode
    self.timeline_widget._frame_mode = True
    self.timeline_widget._pixels_per_frame = pixels_per_frame
    self.timeline_widget._total_frames = total_frames
|
|
3289
|
+
|
|
3290
|
+
def _on_filter_changed(self, index: int):
|
|
3291
|
+
"""Handle behavior filter change."""
|
|
3292
|
+
if self.predictions:
|
|
3293
|
+
self._draw_timeline()
|
|
3294
|
+
|
|
3295
|
+
def _on_theme_changed(self, index: int):
|
|
3296
|
+
"""Handle timeline theme change."""
|
|
3297
|
+
if self.predictions:
|
|
3298
|
+
self._draw_timeline()
|
|
3299
|
+
|
|
3300
|
+
def _get_timeline_palette(self) -> list[tuple[int, int, int]]:
    """Return the palette for the currently selected timeline theme.

    The theme name is read from ``timeline_theme_combo`` when that attribute
    exists; otherwise ``DEFAULT_THEME`` is used. The lookup itself is
    delegated to ``get_timeline_palette``.
    """
    if hasattr(self, "timeline_theme_combo"):
        theme_name = self.timeline_theme_combo.currentText()
    else:
        theme_name = DEFAULT_THEME
    return get_timeline_palette(theme_name)
|
|
3303
|
+
|
|
3304
|
+
def _get_attr_idx(self, clip_idx: int):
|
|
3305
|
+
if clip_idx in self.corrected_attr_labels:
|
|
3306
|
+
return self.corrected_attr_labels[clip_idx]
|
|
3307
|
+
if self.attr_predictions and clip_idx < len(self.attr_predictions):
|
|
3308
|
+
return self.attr_predictions[clip_idx]
|
|
3309
|
+
return None
|
|
3310
|
+
|
|
3311
|
+
def _on_clip_length_changed(self, value: int):
|
|
3312
|
+
clip_length = int(value)
|
|
3313
|
+
self.step_frames_spin.blockSignals(True)
|
|
3314
|
+
self.step_frames_spin.setValue(max(1, clip_length // 2))
|
|
3315
|
+
self.step_frames_spin.blockSignals(False)
|
|
3316
|
+
self._on_step_or_clip_changed(self.step_frames_spin.value())
|
|
3317
|
+
|
|
3318
|
+
def _on_step_or_clip_changed(self, value: int):
|
|
3319
|
+
step_frames = self.step_frames_spin.value()
|
|
3320
|
+
clip_length = self.clip_length_spin.value()
|
|
3321
|
+
if step_frames != clip_length:
|
|
3322
|
+
if not self.frame_aggregation_check.isChecked():
|
|
3323
|
+
self.frame_aggregation_check.setChecked(True)
|
|
3324
|
+
self.log_text.append(
|
|
3325
|
+
f"Auto-enabled 'Precise frame boundaries' (step={step_frames} ≠ clip={clip_length})"
|
|
3326
|
+
)
|
|
3327
|
+
|
|
3328
|
+
def _on_merge_changed(self, state: int):
|
|
3329
|
+
"""Handle merge checkbox change."""
|
|
3330
|
+
if self.predictions:
|
|
3331
|
+
self._draw_timeline()
|
|
3332
|
+
|
|
3333
|
+
def _on_frame_aggregation_changed(self, state: int):
|
|
3334
|
+
"""Handle frame aggregation checkbox change."""
|
|
3335
|
+
if self.predictions:
|
|
3336
|
+
if state:
|
|
3337
|
+
self._compute_aggregated_timeline()
|
|
3338
|
+
self._draw_timeline()
|
|
3339
|
+
|
|
3340
|
+
def _on_timeline_zoom_changed(self, *_args):
|
|
3341
|
+
"""Redraw timeline when zoom level changes."""
|
|
3342
|
+
if self.predictions:
|
|
3343
|
+
self._draw_timeline()
|
|
3344
|
+
|
|
3345
|
+
def _on_ovr_show_all_changed(self, *_args):
|
|
3346
|
+
"""Toggle between co-occurrence-restricted and fully independent OvR display."""
|
|
3347
|
+
if self.predictions and self._use_ovr:
|
|
3348
|
+
if self.frame_aggregation_check.isChecked():
|
|
3349
|
+
self._build_timeline_segments()
|
|
3350
|
+
self._display_results()
|
|
3351
|
+
|
|
3352
|
+
def _on_smoothing_changed(self, *_args):
|
|
3353
|
+
"""Legacy hook: keep config synced when smoothing settings change."""
|
|
3354
|
+
self.config["inference_min_segment_frames"] = self._min_segment_frames
|
|
3355
|
+
self.config["inference_merge_gap_frames"] = self._merge_gap_frames
|
|
3356
|
+
self.config["inference_temporal_smoothing_window_frames"] = self._temporal_smoothing_window_frames
|
|
3357
|
+
self._sync_per_class_segment_rule_config()
|
|
3358
|
+
if self.predictions and self.frame_aggregation_check.isChecked():
|
|
3359
|
+
self._compute_aggregated_timeline()
|
|
3360
|
+
self._display_results()
|
|
3361
|
+
|
|
3362
|
+
def _smooth_frame_labels(self, frame_labels: np.ndarray) -> np.ndarray:
|
|
3363
|
+
"""Apply simple temporal smoothing to frame-wise top-1 labels."""
|
|
3364
|
+
if frame_labels.size == 0:
|
|
3365
|
+
return frame_labels
|
|
3366
|
+
|
|
3367
|
+
labels = frame_labels.copy()
|
|
3368
|
+
T = int(labels.shape[0])
|
|
3369
|
+
|
|
3370
|
+
if T > 1:
|
|
3371
|
+
majority_labels = labels.copy()
|
|
3372
|
+
for i in range(T):
|
|
3373
|
+
center = int(labels[i])
|
|
3374
|
+
if center < 0:
|
|
3375
|
+
continue
|
|
3376
|
+
win = self._smooth_win_for_class(center)
|
|
3377
|
+
if win <= 1:
|
|
3378
|
+
continue
|
|
3379
|
+
half = win // 2
|
|
3380
|
+
left = max(0, i - half)
|
|
3381
|
+
right = min(T, i + half + 1)
|
|
3382
|
+
window_vals = labels[left:right]
|
|
3383
|
+
valid_vals = window_vals[window_vals >= 0]
|
|
3384
|
+
if valid_vals.size == 0:
|
|
3385
|
+
continue
|
|
3386
|
+
counts = np.bincount(valid_vals.astype(np.int64))
|
|
3387
|
+
max_count = int(np.max(counts))
|
|
3388
|
+
winners = np.where(counts == max_count)[0]
|
|
3389
|
+
if winners.size == 1:
|
|
3390
|
+
majority_labels[i] = int(winners[0])
|
|
3391
|
+
else:
|
|
3392
|
+
majority_labels[i] = center if center in winners else int(winners[0])
|
|
3393
|
+
labels = majority_labels
|
|
3394
|
+
return labels
|
|
3395
|
+
|
|
3396
|
+
def _apply_gap_merge_and_min_seg(self, frame_labels: np.ndarray, T: int) -> np.ndarray:
|
|
3397
|
+
"""Merge short gaps between identical classes and remove short runs.
|
|
3398
|
+
|
|
3399
|
+
Operates on a copy so the caller's array is unchanged if not reassigned.
|
|
3400
|
+
Gap-merge: fills runs of -1 that are <= merge_gap_frames long when the
|
|
3401
|
+
class on both sides is the same.
|
|
3402
|
+
Min-segment: removes runs shorter than min_segment_frames by replacing
|
|
3403
|
+
them with their neighbour.
|
|
3404
|
+
"""
|
|
3405
|
+
max_gap = int(max(0, getattr(self, "_merge_gap_frames", 0)))
|
|
3406
|
+
min_len = int(max(1, getattr(self, "_min_segment_frames", 1)))
|
|
3407
|
+
has_class_overrides = bool(
|
|
3408
|
+
self.class_merge_gap_frames
|
|
3409
|
+
or self.class_min_segment_frames
|
|
3410
|
+
)
|
|
3411
|
+
if max_gap == 0 and min_len <= 1 and not has_class_overrides:
|
|
3412
|
+
return frame_labels
|
|
3413
|
+
|
|
3414
|
+
labels = frame_labels.copy()
|
|
3415
|
+
|
|
3416
|
+
if max_gap > 0 or self.class_merge_gap_frames:
|
|
3417
|
+
i = 0
|
|
3418
|
+
while i < T:
|
|
3419
|
+
if labels[i] != -1:
|
|
3420
|
+
i += 1
|
|
3421
|
+
continue
|
|
3422
|
+
j = i
|
|
3423
|
+
while j + 1 < T and labels[j + 1] == -1:
|
|
3424
|
+
j += 1
|
|
3425
|
+
gap_len = j - i + 1
|
|
3426
|
+
left = int(labels[i - 1]) if i > 0 else -1
|
|
3427
|
+
right = int(labels[j + 1]) if j + 1 < T else -1
|
|
3428
|
+
gap_thr = self._gap_fill_for_class(left) if left >= 0 and left == right else max_gap
|
|
3429
|
+
if gap_len <= gap_thr and left >= 0 and left == right:
|
|
3430
|
+
labels[i:j + 1] = left
|
|
3431
|
+
i = j + 1
|
|
3432
|
+
|
|
3433
|
+
if min_len > 1 or self.class_min_segment_frames:
|
|
3434
|
+
changed = True
|
|
3435
|
+
while changed:
|
|
3436
|
+
changed = False
|
|
3437
|
+
i = 0
|
|
3438
|
+
while i < T:
|
|
3439
|
+
cls = int(labels[i])
|
|
3440
|
+
j = i
|
|
3441
|
+
while j + 1 < T and int(labels[j + 1]) == cls:
|
|
3442
|
+
j += 1
|
|
3443
|
+
run_len = j - i + 1
|
|
3444
|
+
min_len_cls = self._min_seg_for_class(cls) if cls >= 0 else min_len
|
|
3445
|
+
if cls >= 0 and run_len < min_len_cls:
|
|
3446
|
+
left = int(labels[i - 1]) if i > 0 else -1
|
|
3447
|
+
right = int(labels[j + 1]) if j + 1 < T else -1
|
|
3448
|
+
if left >= 0 and right >= 0:
|
|
3449
|
+
repl = left if left == right else left
|
|
3450
|
+
elif left >= 0:
|
|
3451
|
+
repl = left
|
|
3452
|
+
elif right >= 0:
|
|
3453
|
+
repl = right
|
|
3454
|
+
else:
|
|
3455
|
+
repl = -1
|
|
3456
|
+
if repl != cls:
|
|
3457
|
+
labels[i:j + 1] = repl
|
|
3458
|
+
changed = True
|
|
3459
|
+
i = j + 1
|
|
3460
|
+
|
|
3461
|
+
return labels
|
|
3462
|
+
|
|
3463
|
+
def _viterbi_decode_dense(self, probs: np.ndarray) -> np.ndarray:
|
|
3464
|
+
"""Single-label Viterbi decode for a contiguous covered frame range."""
|
|
3465
|
+
if probs.ndim != 2 or probs.shape[0] == 0 or probs.shape[1] == 0:
|
|
3466
|
+
return np.zeros((0,), dtype=np.int32)
|
|
3467
|
+
T, C = probs.shape
|
|
3468
|
+
if C == 1:
|
|
3469
|
+
return np.zeros((T,), dtype=np.int32)
|
|
3470
|
+
|
|
3471
|
+
emissions = np.log(np.clip(probs.astype(np.float32, copy=False), 1e-8, 1.0))
|
|
3472
|
+
switch_penalty = float(max(0.0, self.viterbi_switch_penalty))
|
|
3473
|
+
trans = np.full((C, C), -switch_penalty, dtype=np.float32)
|
|
3474
|
+
np.fill_diagonal(trans, 0.0)
|
|
3475
|
+
|
|
3476
|
+
dp = np.empty((T, C), dtype=np.float32)
|
|
3477
|
+
backptr = np.zeros((T, C), dtype=np.int32)
|
|
3478
|
+
dp[0] = emissions[0]
|
|
3479
|
+
|
|
3480
|
+
for t in range(1, T):
|
|
3481
|
+
scores = dp[t - 1][:, np.newaxis] + trans
|
|
3482
|
+
backptr[t] = np.argmax(scores, axis=0).astype(np.int32)
|
|
3483
|
+
dp[t] = emissions[t] + scores[backptr[t], np.arange(C)]
|
|
3484
|
+
|
|
3485
|
+
labels = np.zeros((T,), dtype=np.int32)
|
|
3486
|
+
labels[-1] = int(np.argmax(dp[-1]))
|
|
3487
|
+
for t in range(T - 2, -1, -1):
|
|
3488
|
+
labels[t] = int(backptr[t + 1, labels[t + 1]])
|
|
3489
|
+
return labels
|
|
3490
|
+
|
|
3491
|
+
def _decode_viterbi_labels(self, fs: np.ndarray, covered_mask: np.ndarray) -> np.ndarray:
|
|
3492
|
+
"""Run Viterbi only on contiguous covered ranges, leaving gaps as -1."""
|
|
3493
|
+
T = int(fs.shape[0])
|
|
3494
|
+
decoded = np.full((T,), -1, dtype=np.int32)
|
|
3495
|
+
i = 0
|
|
3496
|
+
while i < T:
|
|
3497
|
+
if not bool(covered_mask[i]):
|
|
3498
|
+
i += 1
|
|
3499
|
+
continue
|
|
3500
|
+
j = i
|
|
3501
|
+
while j + 1 < T and bool(covered_mask[j + 1]):
|
|
3502
|
+
j += 1
|
|
3503
|
+
decoded[i:j + 1] = self._viterbi_decode_dense(fs[i:j + 1])
|
|
3504
|
+
i = j + 1
|
|
3505
|
+
return decoded
|
|
3506
|
+
|
|
3507
|
+
def _binary_viterbi_decode(self, probs: np.ndarray, threshold: float) -> np.ndarray:
|
|
3508
|
+
"""Binary Viterbi decode for one OvR class over a contiguous covered range."""
|
|
3509
|
+
p = np.clip(np.asarray(probs, dtype=np.float32).reshape(-1), 1e-6, 1.0 - 1e-6)
|
|
3510
|
+
T = int(p.shape[0])
|
|
3511
|
+
if T == 0:
|
|
3512
|
+
return np.zeros((0,), dtype=bool)
|
|
3513
|
+
|
|
3514
|
+
tau = float(np.clip(threshold, 1e-4, 1.0 - 1e-4))
|
|
3515
|
+
emit_off = np.log(1.0 - p) - np.log(1.0 - tau)
|
|
3516
|
+
emit_on = np.log(p) - np.log(tau)
|
|
3517
|
+
switch_penalty = float(max(0.0, self.viterbi_switch_penalty))
|
|
3518
|
+
|
|
3519
|
+
dp = np.empty((T, 2), dtype=np.float32)
|
|
3520
|
+
backptr = np.zeros((T, 2), dtype=np.int8)
|
|
3521
|
+
dp[0, 0] = emit_off[0]
|
|
3522
|
+
dp[0, 1] = emit_on[0]
|
|
3523
|
+
|
|
3524
|
+
for t in range(1, T):
|
|
3525
|
+
stay_off = dp[t - 1, 0]
|
|
3526
|
+
on_to_off = dp[t - 1, 1] - switch_penalty
|
|
3527
|
+
if stay_off >= on_to_off:
|
|
3528
|
+
dp[t, 0] = emit_off[t] + stay_off
|
|
3529
|
+
backptr[t, 0] = 0
|
|
3530
|
+
else:
|
|
3531
|
+
dp[t, 0] = emit_off[t] + on_to_off
|
|
3532
|
+
backptr[t, 0] = 1
|
|
3533
|
+
|
|
3534
|
+
off_to_on = dp[t - 1, 0] - switch_penalty
|
|
3535
|
+
stay_on = dp[t - 1, 1]
|
|
3536
|
+
if off_to_on >= stay_on:
|
|
3537
|
+
dp[t, 1] = emit_on[t] + off_to_on
|
|
3538
|
+
backptr[t, 1] = 0
|
|
3539
|
+
else:
|
|
3540
|
+
dp[t, 1] = emit_on[t] + stay_on
|
|
3541
|
+
backptr[t, 1] = 1
|
|
3542
|
+
|
|
3543
|
+
states = np.zeros((T,), dtype=np.int8)
|
|
3544
|
+
states[-1] = 1 if dp[-1, 1] >= dp[-1, 0] else 0
|
|
3545
|
+
for t in range(T - 2, -1, -1):
|
|
3546
|
+
states[t] = backptr[t + 1, states[t + 1]]
|
|
3547
|
+
return states.astype(bool)
|
|
3548
|
+
|
|
3549
|
+
def _build_timeline_segments(self):
|
|
3550
|
+
"""Build timeline segments from precomputed frame probabilities."""
|
|
3551
|
+
if not isinstance(self._aggregated_frame_scores_norm, np.ndarray):
|
|
3552
|
+
self.aggregated_segments = []
|
|
3553
|
+
self.aggregated_multiclass_segments = []
|
|
3554
|
+
self._aggregated_active_mask = None
|
|
3555
|
+
self._aggregated_last_covered_frame = 0
|
|
3556
|
+
return
|
|
3557
|
+
|
|
3558
|
+
fs = self._aggregated_frame_scores_norm
|
|
3559
|
+
if fs.ndim != 2 or fs.shape[0] == 0 or fs.shape[1] == 0:
|
|
3560
|
+
self.aggregated_segments = []
|
|
3561
|
+
self.aggregated_multiclass_segments = []
|
|
3562
|
+
self._aggregated_active_mask = None
|
|
3563
|
+
self._aggregated_last_covered_frame = 0
|
|
3564
|
+
return
|
|
3565
|
+
|
|
3566
|
+
T, C = fs.shape
|
|
3567
|
+
self._aggregated_last_covered_frame = T
|
|
3568
|
+
covered_mask = np.sum(fs, axis=1) > 1e-8
|
|
3569
|
+
if self.use_viterbi_decode and not self._use_ovr:
|
|
3570
|
+
frame_labels = self._decode_viterbi_labels(fs, covered_mask)
|
|
3571
|
+
else:
|
|
3572
|
+
frame_labels = np.argmax(fs, axis=1)
|
|
3573
|
+
frame_labels[~covered_mask] = -1
|
|
3574
|
+
# Smooth raw argmax labels (majority-vote temporal smoothing).
|
|
3575
|
+
frame_labels = self._smooth_frame_labels(frame_labels)
|
|
3576
|
+
|
|
3577
|
+
# Apply ignore threshold after smoothing.
|
|
3578
|
+
if self.use_ignore_threshold and not self._use_ovr:
|
|
3579
|
+
for fi in range(T):
|
|
3580
|
+
ci = int(frame_labels[fi])
|
|
3581
|
+
if ci < 0:
|
|
3582
|
+
continue
|
|
3583
|
+
thr = self._threshold_for_pred(ci)
|
|
3584
|
+
if float(fs[fi, ci]) < thr:
|
|
3585
|
+
frame_labels[fi] = -1
|
|
3586
|
+
|
|
3587
|
+
# Merge-gap and min-segment run ONCE as the final cleanup, after all
|
|
3588
|
+
# other preprocessing (smoothing + threshold) is done.
|
|
3589
|
+
frame_labels = self._apply_gap_merge_and_min_seg(frame_labels, T)
|
|
3590
|
+
|
|
3591
|
+
segments = []
|
|
3592
|
+
cur_cls = int(frame_labels[0])
|
|
3593
|
+
cur_start = 0
|
|
3594
|
+
for i in range(1, T):
|
|
3595
|
+
if int(frame_labels[i]) != cur_cls:
|
|
3596
|
+
conf = float(np.mean(fs[cur_start:i, cur_cls])) if cur_cls >= 0 else 0.0
|
|
3597
|
+
segments.append({
|
|
3598
|
+
"class": int(cur_cls),
|
|
3599
|
+
"start": int(cur_start),
|
|
3600
|
+
"end": int(i - 1),
|
|
3601
|
+
"confidence": conf,
|
|
3602
|
+
})
|
|
3603
|
+
cur_cls = int(frame_labels[i])
|
|
3604
|
+
cur_start = i
|
|
3605
|
+
conf = float(np.mean(fs[cur_start:T, cur_cls])) if cur_cls >= 0 else 0.0
|
|
3606
|
+
segments.append({
|
|
3607
|
+
"class": int(cur_cls),
|
|
3608
|
+
"start": int(cur_start),
|
|
3609
|
+
"end": int(T - 1),
|
|
3610
|
+
"confidence": conf,
|
|
3611
|
+
})
|
|
3612
|
+
self.aggregated_segments = segments
|
|
3613
|
+
|
|
3614
|
+
self.aggregated_multiclass_segments = []
|
|
3615
|
+
self._aggregated_active_mask = None
|
|
3616
|
+
if self._use_ovr:
|
|
3617
|
+
active_mask = np.zeros((T, C), dtype=bool)
|
|
3618
|
+
ovr_threshold = None if self.use_ignore_threshold else 0.35
|
|
3619
|
+
show_all = getattr(self, "ovr_show_all_check", None) is not None and self.ovr_show_all_check.isChecked()
|
|
3620
|
+
if self.use_viterbi_decode:
|
|
3621
|
+
for ci in range(C):
|
|
3622
|
+
thr = self._threshold_for_pred(ci) if self.use_ignore_threshold else float(ovr_threshold)
|
|
3623
|
+
active_mask[:, ci] = self._binary_viterbi_decode(fs[:, ci], thr)
|
|
3624
|
+
if not show_all:
|
|
3625
|
+
pruned_mask = np.zeros_like(active_mask)
|
|
3626
|
+
for fi in range(T):
|
|
3627
|
+
active_idx = np.flatnonzero(active_mask[fi])
|
|
3628
|
+
if active_idx.size == 0:
|
|
3629
|
+
continue
|
|
3630
|
+
scores = fs[fi, active_idx]
|
|
3631
|
+
top_ci = int(active_idx[int(np.argmax(scores))])
|
|
3632
|
+
pruned_mask[fi, top_ci] = True
|
|
3633
|
+
if self._allowed_cooccurrence:
|
|
3634
|
+
top_name = self.classes[top_ci]
|
|
3635
|
+
for ci in active_idx:
|
|
3636
|
+
ci = int(ci)
|
|
3637
|
+
if ci == top_ci:
|
|
3638
|
+
continue
|
|
3639
|
+
name = self.classes[ci]
|
|
3640
|
+
if (top_name, name) in self._allowed_cooccurrence:
|
|
3641
|
+
pruned_mask[fi, ci] = True
|
|
3642
|
+
active_mask = pruned_mask
|
|
3643
|
+
else:
|
|
3644
|
+
for fi in range(T):
|
|
3645
|
+
for ci in self._active_ovr_indices_from_scores(fs[fi], threshold_override=ovr_threshold):
|
|
3646
|
+
if 0 <= ci < C:
|
|
3647
|
+
active_mask[fi, ci] = True
|
|
3648
|
+
self._aggregated_active_mask = active_mask
|
|
3649
|
+
|
|
3650
|
+
max_gap = int(max(0, getattr(self, "_merge_gap_frames", 0)))
|
|
3651
|
+
min_len = int(max(1, getattr(self, "_min_segment_frames", 1)))
|
|
3652
|
+
|
|
3653
|
+
# Per-class gap-fill and min-segment cleanup for OvR.
|
|
3654
|
+
for ci in range(C):
|
|
3655
|
+
col = active_mask[:, ci]
|
|
3656
|
+
smooth_win = self._smooth_win_for_class(ci)
|
|
3657
|
+
if (not self.use_viterbi_decode) and smooth_win > 1 and T > 1:
|
|
3658
|
+
half = smooth_win // 2
|
|
3659
|
+
smooth_col = col.copy()
|
|
3660
|
+
for i in range(T):
|
|
3661
|
+
left = max(0, i - half)
|
|
3662
|
+
right = min(T, i + half + 1)
|
|
3663
|
+
window = col[left:right]
|
|
3664
|
+
on_count = int(np.count_nonzero(window))
|
|
3665
|
+
off_count = int(window.size - on_count)
|
|
3666
|
+
if on_count > off_count:
|
|
3667
|
+
smooth_col[i] = True
|
|
3668
|
+
elif off_count > on_count:
|
|
3669
|
+
smooth_col[i] = False
|
|
3670
|
+
col = smooth_col
|
|
3671
|
+
|
|
3672
|
+
gap_thr = self._gap_fill_for_class(ci)
|
|
3673
|
+
min_len_cls = self._min_seg_for_class(ci)
|
|
3674
|
+
# Gap-fill: a short "off" gap between two "on" runs of the same class.
|
|
3675
|
+
if gap_thr > 0:
|
|
3676
|
+
i = 0
|
|
3677
|
+
while i < T:
|
|
3678
|
+
if col[i]:
|
|
3679
|
+
i += 1
|
|
3680
|
+
continue
|
|
3681
|
+
j = i
|
|
3682
|
+
while j + 1 < T and not col[j + 1]:
|
|
3683
|
+
j += 1
|
|
3684
|
+
gap_len = j - i + 1
|
|
3685
|
+
left_on = col[i - 1] if i > 0 else False
|
|
3686
|
+
right_on = col[j + 1] if j + 1 < T else False
|
|
3687
|
+
if gap_len <= gap_thr and left_on and right_on:
|
|
3688
|
+
col[i:j + 1] = True
|
|
3689
|
+
i = j + 1
|
|
3690
|
+
# Min-segment: remove short "on" runs.
|
|
3691
|
+
if min_len_cls > 1:
|
|
3692
|
+
i = 0
|
|
3693
|
+
while i < T:
|
|
3694
|
+
if not col[i]:
|
|
3695
|
+
i += 1
|
|
3696
|
+
continue
|
|
3697
|
+
j = i
|
|
3698
|
+
while j + 1 < T and col[j + 1]:
|
|
3699
|
+
j += 1
|
|
3700
|
+
run_len = j - i + 1
|
|
3701
|
+
if run_len < min_len_cls:
|
|
3702
|
+
col[i:j + 1] = False
|
|
3703
|
+
i = j + 1
|
|
3704
|
+
active_mask[:, ci] = col
|
|
3705
|
+
|
|
3706
|
+
if not show_all:
|
|
3707
|
+
pruned_mask = np.zeros_like(active_mask)
|
|
3708
|
+
for fi in range(T):
|
|
3709
|
+
active_idx = np.flatnonzero(active_mask[fi])
|
|
3710
|
+
if active_idx.size == 0:
|
|
3711
|
+
continue
|
|
3712
|
+
scores = fs[fi, active_idx]
|
|
3713
|
+
top_ci = int(active_idx[int(np.argmax(scores))])
|
|
3714
|
+
pruned_mask[fi, top_ci] = True
|
|
3715
|
+
if self._allowed_cooccurrence:
|
|
3716
|
+
top_name = self.classes[top_ci]
|
|
3717
|
+
for ci in active_idx:
|
|
3718
|
+
ci = int(ci)
|
|
3719
|
+
if ci == top_ci:
|
|
3720
|
+
continue
|
|
3721
|
+
name = self.classes[ci]
|
|
3722
|
+
if (top_name, name) in self._allowed_cooccurrence:
|
|
3723
|
+
pruned_mask[fi, ci] = True
|
|
3724
|
+
active_mask = pruned_mask
|
|
3725
|
+
self._aggregated_active_mask = active_mask
|
|
3726
|
+
|
|
3727
|
+
multi_segments = []
|
|
3728
|
+
for ci in range(C):
|
|
3729
|
+
run_start = None
|
|
3730
|
+
for fi in range(T):
|
|
3731
|
+
is_on = bool(active_mask[fi, ci])
|
|
3732
|
+
if is_on and run_start is None:
|
|
3733
|
+
run_start = fi
|
|
3734
|
+
if run_start is not None and (not is_on or fi == T - 1):
|
|
3735
|
+
run_end = fi if (is_on and fi == T - 1) else fi - 1
|
|
3736
|
+
if run_end >= run_start:
|
|
3737
|
+
conf = float(np.mean(fs[run_start:run_end + 1, ci]))
|
|
3738
|
+
multi_segments.append({
|
|
3739
|
+
"class": int(ci),
|
|
3740
|
+
"start": int(run_start),
|
|
3741
|
+
"end": int(run_end),
|
|
3742
|
+
"confidence": conf,
|
|
3743
|
+
})
|
|
3744
|
+
run_start = None
|
|
3745
|
+
self.aggregated_multiclass_segments = multi_segments
|
|
3746
|
+
|
|
3747
|
+
def _compute_aggregated_timeline(self):
|
|
3748
|
+
"""
|
|
3749
|
+
Aggregate overlapping clip predictions into precise frame-level segments.
|
|
3750
|
+
Uses confidence-weighted voting: each clip votes for its predicted class
|
|
3751
|
+
with weight equal to its confidence score.
|
|
3752
|
+
"""
|
|
3753
|
+
if not self.predictions or not self.classes:
|
|
3754
|
+
self.aggregated_segments = []
|
|
3755
|
+
self.aggregated_multiclass_segments = []
|
|
3756
|
+
self._aggregated_frame_scores_norm = None
|
|
3757
|
+
self._aggregated_active_mask = None
|
|
3758
|
+
self._aggregated_last_covered_frame = 0
|
|
3759
|
+
return
|
|
3760
|
+
|
|
3761
|
+
# Prefer the exact worker-built merged frame timeline when available.
|
|
3762
|
+
# This preserves center-weighted overlap merge.
|
|
3763
|
+
if not self.corrected_labels:
|
|
3764
|
+
precomputed = self._get_precomputed_aggregated_probs(self.video_path)
|
|
3765
|
+
if precomputed is not None:
|
|
3766
|
+
self._aggregated_frame_scores_norm = precomputed.copy()
|
|
3767
|
+
self._aggregated_last_covered_frame = int(precomputed.shape[0])
|
|
3768
|
+
self._build_timeline_segments()
|
|
3769
|
+
return
|
|
3770
|
+
|
|
3771
|
+
# Get parameters
|
|
3772
|
+
clip_length = self.clip_length_spin.value()
|
|
3773
|
+
step_frames = self.step_frames_spin.value()
|
|
3774
|
+
target_fps = self.target_fps_spin.value()
|
|
3775
|
+
num_classes = len(self.classes)
|
|
3776
|
+
|
|
3777
|
+
# Get total frames from video or estimate from clip_starts
|
|
3778
|
+
if self.total_frames > 0:
|
|
3779
|
+
total_frames = self.total_frames
|
|
3780
|
+
elif self.clip_starts:
|
|
3781
|
+
# Estimate: last clip start + clip_length
|
|
3782
|
+
total_frames = self.clip_starts[-1] + clip_length + 1
|
|
3783
|
+
else:
|
|
3784
|
+
self.aggregated_segments = []
|
|
3785
|
+
self.aggregated_multiclass_segments = []
|
|
3786
|
+
self._aggregated_frame_scores_norm = None
|
|
3787
|
+
self._aggregated_active_mask = None
|
|
3788
|
+
self._aggregated_last_covered_frame = 0
|
|
3789
|
+
return
|
|
3790
|
+
|
|
3791
|
+
orig_fps = self._get_video_fps(self.video_path) if self.video_path else 30.0
|
|
3792
|
+
frame_interval = self._get_saved_frame_interval(self.video_path, orig_fps)
|
|
3793
|
+
|
|
3794
|
+
# Apply corrections + ignore-threshold gating
|
|
3795
|
+
corrected_preds = self._effective_predictions()
|
|
3796
|
+
|
|
3797
|
+
# Initialize score matrix: [total_frames, num_classes] and coverage count
|
|
3798
|
+
frame_scores = np.zeros((total_frames, num_classes), dtype=np.float32)
|
|
3799
|
+
frame_coverage = np.zeros(total_frames, dtype=np.float32)
|
|
3800
|
+
used_full_probability_voting = False
|
|
3801
|
+
probs_available = (
|
|
3802
|
+
isinstance(self.clip_probabilities, list)
|
|
3803
|
+
and len(self.clip_probabilities) > 0
|
|
3804
|
+
)
|
|
3805
|
+
# Per-frame probabilities from FrameClassificationHead give much finer
|
|
3806
|
+
# temporal resolution: each frame within a clip gets its own prediction
|
|
3807
|
+
# instead of the entire clip being smeared with one probability vector.
|
|
3808
|
+
frame_probs_available = (
|
|
3809
|
+
isinstance(self.clip_frame_probabilities, list)
|
|
3810
|
+
and len(self.clip_frame_probabilities) > 0
|
|
3811
|
+
)
|
|
3812
|
+
|
|
3813
|
+
# Accumulate votes from each clip
|
|
3814
|
+
for clip_i, (pred_class, conf, start_frame) in enumerate(zip(corrected_preds, self.confidences, self.clip_starts)):
|
|
3815
|
+
end_frame = min(start_frame + clip_length * frame_interval, total_frames)
|
|
3816
|
+
|
|
3817
|
+
if start_frame >= total_frames:
|
|
3818
|
+
continue
|
|
3819
|
+
|
|
3820
|
+
# Best path: per-frame probabilities from frame head. Do NOT gate this
|
|
3821
|
+
# by clip-level class/confidence, since a mixed clip can contain multiple
|
|
3822
|
+
# behaviors with low clip confidence but useful frame-wise predictions.
|
|
3823
|
+
if frame_probs_available and clip_i < len(self.clip_frame_probabilities):
|
|
3824
|
+
fp = self.clip_frame_probabilities[clip_i]
|
|
3825
|
+
if isinstance(fp, (list, np.ndarray)):
|
|
3826
|
+
fp_arr = np.asarray(fp, dtype=np.float32) # [T, C]
|
|
3827
|
+
if fp_arr.ndim == 2 and fp_arr.shape[1] == num_classes:
|
|
3828
|
+
T_clip = fp_arr.shape[0]
|
|
3829
|
+
merge_w = self._build_center_merge_weights(T_clip)
|
|
3830
|
+
for t in range(T_clip):
|
|
3831
|
+
f_start = start_frame + t * frame_interval
|
|
3832
|
+
f_end = min(f_start + frame_interval, total_frames)
|
|
3833
|
+
if f_start >= total_frames:
|
|
3834
|
+
break
|
|
3835
|
+
if f_end <= f_start:
|
|
3836
|
+
continue
|
|
3837
|
+
probs_t = np.clip(fp_arr[t], 0.0, None)
|
|
3838
|
+
w = float(merge_w[t])
|
|
3839
|
+
frame_scores[f_start:f_end, :] += probs_t[np.newaxis, :] * w
|
|
3840
|
+
frame_coverage[f_start:f_end] += w
|
|
3841
|
+
used_full_probability_voting = True
|
|
3842
|
+
continue
|
|
3843
|
+
|
|
3844
|
+
# Remaining paths are clip-level and require a valid clip prediction.
|
|
3845
|
+
if not (0 <= pred_class < num_classes):
|
|
3846
|
+
continue
|
|
3847
|
+
|
|
3848
|
+
# Fallback: clip-level probabilities smeared across the clip
|
|
3849
|
+
if probs_available and clip_i < len(self.clip_probabilities):
|
|
3850
|
+
raw_probs = self.clip_probabilities[clip_i]
|
|
3851
|
+
if isinstance(raw_probs, (list, tuple, np.ndarray)) and len(raw_probs) == num_classes:
|
|
3852
|
+
probs_vec = np.asarray(raw_probs, dtype=np.float32)
|
|
3853
|
+
if np.all(np.isfinite(probs_vec)):
|
|
3854
|
+
probs_vec = np.clip(probs_vec, 0.0, None)
|
|
3855
|
+
s = float(np.sum(probs_vec))
|
|
3856
|
+
if s > 1e-8:
|
|
3857
|
+
if not self._use_ovr:
|
|
3858
|
+
probs_vec = probs_vec / s
|
|
3859
|
+
if (not self._use_ovr) and int(np.argmax(probs_vec)) != int(pred_class):
|
|
3860
|
+
probs_vec = np.zeros(num_classes, dtype=np.float32)
|
|
3861
|
+
probs_vec[int(pred_class)] = 1.0
|
|
3862
|
+
frame_scores[start_frame:end_frame, :] += probs_vec[np.newaxis, :]
|
|
3863
|
+
frame_coverage[start_frame:end_frame] += 1.0
|
|
3864
|
+
used_full_probability_voting = True
|
|
3865
|
+
continue
|
|
3866
|
+
|
|
3867
|
+
# Last fallback for older result files without per-class probabilities.
|
|
3868
|
+
frame_scores[start_frame:end_frame, pred_class] += float(conf)
|
|
3869
|
+
frame_coverage[start_frame:end_frame] += 1.0
|
|
3870
|
+
|
|
3871
|
+
# Normalize scores by clip coverage so confidence stays in [0, 1]
|
|
3872
|
+
frame_scores_norm = np.divide(
|
|
3873
|
+
frame_scores,
|
|
3874
|
+
np.maximum(frame_coverage[:, np.newaxis], 1.0),
|
|
3875
|
+
)
|
|
3876
|
+
|
|
3877
|
+
# For Softmax, ensure final probabilities sum to 1 (averaging usually does, but small errors can creep in)
|
|
3878
|
+
if not self._use_ovr:
|
|
3879
|
+
sums = np.sum(frame_scores_norm, axis=1, keepdims=True)
|
|
3880
|
+
valid_sums = sums > 1e-8
|
|
3881
|
+
frame_scores_norm = np.where(
|
|
3882
|
+
valid_sums,
|
|
3883
|
+
frame_scores_norm / sums,
|
|
3884
|
+
frame_scores_norm
|
|
3885
|
+
)
|
|
3886
|
+
|
|
3887
|
+
# Find the last frame with actual clip coverage to avoid phantom
|
|
3888
|
+
# class-0 labels from uncovered tail frames (argmax of all-zeros = 0).
|
|
3889
|
+
# Interior uncovered frames (filtered clips) are marked -1 ("Filtered").
|
|
3890
|
+
# Use coverage count (number of clips that touched each frame) rather than
|
|
3891
|
+
# score sum, which can be > 0 for uncovered frames in OvR (sigmoid > 0 always).
|
|
3892
|
+
covered_mask = frame_coverage > 0
|
|
3893
|
+
if not np.any(covered_mask):
|
|
3894
|
+
self.aggregated_segments = []
|
|
3895
|
+
self.aggregated_multiclass_segments = []
|
|
3896
|
+
self._aggregated_frame_scores_norm = None
|
|
3897
|
+
self._aggregated_active_mask = None
|
|
3898
|
+
self._aggregated_last_covered_frame = 0
|
|
3899
|
+
return
|
|
3900
|
+
last_covered_frame = int(np.max(np.where(covered_mask))) + 1 # exclusive end
|
|
3901
|
+
self._aggregated_last_covered_frame = last_covered_frame
|
|
3902
|
+
self._aggregated_frame_scores_norm = frame_scores_norm[:last_covered_frame].copy()
|
|
3903
|
+
|
|
3904
|
+
self._build_timeline_segments()
|
|
3905
|
+
|
|
3906
|
+
# Log summary
|
|
3907
|
+
if self.aggregated_segments:
|
|
3908
|
+
self.log_text.append(f"Frame aggregation: {len(self.aggregated_segments)} segments from {len(self.predictions)} clips")
|
|
3909
|
+
if self._use_ovr and self.aggregated_multiclass_segments:
|
|
3910
|
+
self.log_text.append(
|
|
3911
|
+
f" OvR precise multi-class segments: {len(self.aggregated_multiclass_segments)}"
|
|
3912
|
+
)
|
|
3913
|
+
if used_full_probability_voting:
|
|
3914
|
+
mode_name = "sigmoid" if getattr(self, "_use_ovr", False) else "softmax"
|
|
3915
|
+
self.log_text.append(f" Evidence mode: full {mode_name} probability voting")
|
|
3916
|
+
else:
|
|
3917
|
+
self.log_text.append(" Evidence mode: top-1 confidence voting")
|
|
3918
|
+
step_info = f"step={step_frames}, clip_length={clip_length}"
|
|
3919
|
+
if step_frames >= clip_length:
|
|
3920
|
+
self.log_text.append(f" Note: No overlap ({step_info}). For better boundaries, use step_frames < clip_length.")
|
|
3921
|
+
else:
|
|
3922
|
+
overlap_pct = (1 - step_frames / clip_length) * 100
|
|
3923
|
+
self.log_text.append(f" Overlap: {overlap_pct:.0f}% ({step_info})")
|
|
3924
|
+
|
|
3925
|
+
def _export_timeline(self):
    """Export the behavior timeline (CSV + SVG) for one or more cached videos.

    Steps: persist the current video's state, ask which videos and classes
    to export, ask for a destination (a file base name for a single video,
    a folder for several), then load each selected video from the results
    cache and export it. The video that was displayed when the export
    started is reloaded with its threshold settings afterwards, whether or
    not the export succeeded.
    """
    if not self.predictions or not self.video_path:
        QMessageBox.warning(self, "Error", "No predictions available to export.")
        return

    self._persist_current_video_state()
    available_videos = [
        vp for vp, cached in self.results_cache.items() if isinstance(cached, dict)
    ]
    if not available_videos:
        available_videos = [self.video_path]

    export_selection = self._prompt_timeline_export_options(available_videos)
    if not export_selection:
        return
    selected_videos, selected_classes = export_selection
    n_videos = len(selected_videos)
    single_video = n_videos == 1

    if single_video:
        default_base = os.path.splitext(selected_videos[0])[0] + "_timeline"
        base_path, _ = QFileDialog.getSaveFileName(
            self,
            "Export Timeline",
            default_base,
            "All Files (*)"
        )
        if not base_path:
            return
        export_root = os.path.splitext(base_path)[0]
    else:
        default_dir = os.path.dirname(self.video_path) if self.video_path else os.getcwd()
        export_root = QFileDialog.getExistingDirectory(self, "Select Export Folder", default_dir)
        if not export_root:
            return

    original_video_path = self.video_path
    original_threshold_settings = self._current_threshold_settings()
    # Videos without their own saved thresholds are exported with a copy of
    # the thresholds that are active right now.
    shared_threshold_settings = dict(original_threshold_settings)
    exported = []
    show_progress = n_videos > 1
    progress = None
    export_canceled = False
    export_error = None

    if show_progress:
        progress = QProgressDialog(
            "Exporting timeline 1 / {}...".format(n_videos),
            "Cancel",
            0,
            n_videos,
            self,
        )
        progress.setWindowTitle("Timeline export")
        progress.setWindowModality(Qt.WindowModality.WindowModal)
        progress.setMinimumDuration(0)
        progress.setValue(0)
        progress.show()
        QApplication.processEvents()

    try:
        for vi, video_path in enumerate(selected_videos):
            if progress and progress.wasCanceled():
                export_canceled = True
                break
            if show_progress:
                progress.setValue(vi)
                progress.setLabelText(
                    "Exporting timeline {} / {}: {}".format(
                        vi + 1, n_videos, os.path.basename(video_path)
                    )
                )
                QApplication.processEvents()

            entry = self.results_cache.get(video_path)
            if not isinstance(entry, dict):
                # Missing or malformed cache entry: treat as "no saved settings"
                # instead of failing on entry.get below.
                entry = {}
            threshold_override = None
            if not isinstance(entry.get("threshold_settings"), dict):
                threshold_override = shared_threshold_settings
            ok = self._load_video_from_cache(
                video_path,
                refresh_display=False,
                persist_current=False,
                threshold_settings_override=threshold_override,
                persist_loaded_thresholds=False,
            )
            if not ok:
                continue

            if single_video:
                base_path = export_root
            else:
                base_path = os.path.join(
                    export_root,
                    os.path.splitext(os.path.basename(video_path))[0] + "_timeline",
                )
            csv_path, svg_path, mode_text = self._export_current_timeline_to_base_path(
                base_path,
                selected_classes=selected_classes,
            )
            exported.append((video_path, csv_path, svg_path, mode_text))
            if show_progress:
                progress.setValue(vi + 1)
                QApplication.processEvents()
    except Exception as e:
        # UI boundary: remember the failure and report it only after the
        # progress dialog is closed and the original video is restored.
        export_error = e
    finally:
        if progress:
            progress.close()
        if original_video_path and original_video_path in self.results_cache:
            self._load_video_from_cache(
                original_video_path,
                refresh_display=True,
                persist_current=False,
                threshold_settings_override=original_threshold_settings,
                persist_loaded_thresholds=False,
            )
            idx = self.filter_video_combo.findData(original_video_path)
            if idx >= 0:
                # Block signals so re-selecting the entry does not trigger
                # the combo box's change handlers.
                self.filter_video_combo.blockSignals(True)
                self.filter_video_combo.setCurrentIndex(idx)
                self.filter_video_combo.blockSignals(False)

    if export_error is not None:
        QMessageBox.critical(self, "Export Error", f"Failed to export timeline:\n{str(export_error)}")
        return

    if not exported:
        msg = "Export canceled." if export_canceled else "No timelines were exported."
        QMessageBox.warning(self, "Export", msg)
        return

    if len(exported) == 1:
        _, csv_path, svg_path, mode_text = exported[0]
        QMessageBox.information(
            self,
            "Export Complete",
            f"Timeline exported successfully!\n\n"
            f"Mode: {mode_text}\n"
            f"CSV: {csv_path}\n"
            f"SVG: {svg_path}"
        )
    else:
        QMessageBox.information(
            self,
            "Batch Export Complete",
            f"Exported timelines for {len(exported)} videos to:\n{export_root}"
        )
|
|
4067
|
+
|
|
4068
|
+
def _prompt_timeline_export_options(self, available_videos):
    """Ask the user which videos and classes to include in a timeline export.

    Shows a modal dialog with two checkable lists (videos on the left,
    classes on the right). Returns ``(selected_videos, selected_classes)``
    where the first is a list of video paths and the second a set of class
    indices, or ``None`` when the dialog is rejected or either list ends up
    with nothing checked.
    """
    checked = Qt.CheckState.Checked
    unchecked = Qt.CheckState.Unchecked
    payload_role = Qt.ItemDataRole.UserRole
    current_video = self.video_path

    def _add_checkable(target, text, payload, state):
        # Append one user-checkable entry carrying `payload` in UserRole.
        entry = QListWidgetItem(text)
        entry.setData(payload_role, payload)
        entry.setFlags(entry.flags() | Qt.ItemFlag.ItemIsUserCheckable | Qt.ItemFlag.ItemIsEnabled | Qt.ItemFlag.ItemIsSelectable)
        entry.setCheckState(state)
        target.addItem(entry)

    def _entries(target):
        # All items of a list widget, in row order.
        return [target.item(row) for row in range(target.count())]

    dlg = QDialog(self)
    dlg.setWindowTitle("Batch Timeline Export")
    dlg.resize(760, 460)
    layout = QVBoxLayout(dlg)
    layout.addWidget(QLabel("Choose which videos and classes to export. All classes are included by default."))

    content_row = QHBoxLayout()

    # Left column: videos.
    video_col = QVBoxLayout()
    video_col.addWidget(QLabel("Videos"))
    video_button_row = QHBoxLayout()
    select_all_btn = QPushButton("Select all")
    current_btn = QPushButton("Current video")
    clear_btn = QPushButton("Clear")
    for button in (select_all_btn, current_btn, clear_btn):
        video_button_row.addWidget(button)
    video_col.addLayout(video_button_row)

    list_widget = QListWidget()
    only_one_video = len(available_videos) == 1
    for vp in available_videos:
        preselect = only_one_video or vp == current_video
        _add_checkable(list_widget, os.path.basename(vp), vp, checked if preselect else unchecked)
    video_col.addWidget(list_widget, stretch=1)
    content_row.addLayout(video_col, stretch=1)

    # Right column: classes (all checked initially).
    class_col = QVBoxLayout()
    class_col.addWidget(QLabel("Classes"))
    class_button_row = QHBoxLayout()
    class_all_btn = QPushButton("All")
    class_none_btn = QPushButton("None")
    class_button_row.addWidget(class_all_btn)
    class_button_row.addWidget(class_none_btn)
    class_col.addLayout(class_button_row)

    class_list_widget = QListWidget()
    for cls_idx, cls_name in enumerate(self.classes):
        _add_checkable(class_list_widget, cls_name, cls_idx, checked)
    class_col.addWidget(class_list_widget, stretch=1)
    content_row.addLayout(class_col, stretch=1)

    layout.addLayout(content_row, stretch=1)

    def _set_all(target, state):
        for entry in _entries(target):
            entry.setCheckState(state)

    def _select_current_only():
        for entry in _entries(list_widget):
            if entry.data(payload_role) == current_video:
                entry.setCheckState(checked)
            else:
                entry.setCheckState(unchecked)

    # Zero-argument lambdas so the handlers never receive the signal's argument.
    select_all_btn.clicked.connect(lambda: _set_all(list_widget, checked))
    clear_btn.clicked.connect(lambda: _set_all(list_widget, unchecked))
    current_btn.clicked.connect(_select_current_only)
    class_all_btn.clicked.connect(lambda: _set_all(class_list_widget, checked))
    class_none_btn.clicked.connect(lambda: _set_all(class_list_widget, unchecked))

    btns = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel)
    layout.addWidget(btns)
    btns.accepted.connect(dlg.accept)
    btns.rejected.connect(dlg.reject)

    if not dlg.exec():
        return None

    selected = [
        entry.data(payload_role)
        for entry in _entries(list_widget)
        if entry.checkState() == checked
    ]
    if not selected:
        QMessageBox.information(self, "No videos selected", "Select at least one video to export.")
        return None

    selected_classes = {
        int(entry.data(payload_role))
        for entry in _entries(class_list_widget)
        if entry.checkState() == checked
    }
    if not selected_classes:
        QMessageBox.information(self, "No classes selected", "Select at least one class to export.")
        return None

    return selected, selected_classes
|
|
4170
|
+
|
|
4171
|
+
def _export_current_timeline_to_base_path(self, base_path: str, selected_classes: set[int] | None = None):
|
|
4172
|
+
if not self.predictions or not self.video_path:
|
|
4173
|
+
raise RuntimeError("No predictions available to export.")
|
|
4174
|
+
|
|
4175
|
+
orig_fps = self._get_video_fps(self.video_path)
|
|
4176
|
+
frame_interval = self._get_saved_frame_interval(self.video_path, orig_fps)
|
|
4177
|
+
clip_length = self.clip_length_spin.value()
|
|
4178
|
+
frame_aggregation_enabled = self.frame_aggregation_check.isChecked()
|
|
4179
|
+
|
|
4180
|
+
csv_path = base_path + "_behaviors.csv"
|
|
4181
|
+
import csv
|
|
4182
|
+
|
|
4183
|
+
if frame_aggregation_enabled and self.aggregated_segments:
|
|
4184
|
+
with open(csv_path, 'w', newline='') as f:
|
|
4185
|
+
writer = csv.writer(f)
|
|
4186
|
+
writer.writerow(['Behavior', 'Start Time (s)', 'End Time (s)', 'Start Frame', 'End Frame', 'Duration (s)', 'Confidence'])
|
|
4187
|
+
|
|
4188
|
+
seg_source = self.aggregated_segments
|
|
4189
|
+
if self._use_ovr and self.aggregated_multiclass_segments:
|
|
4190
|
+
seg_source = self.aggregated_multiclass_segments
|
|
4191
|
+
|
|
4192
|
+
for seg in seg_source:
|
|
4193
|
+
pred_idx = seg['class']
|
|
4194
|
+
if pred_idx < len(self.classes) and (selected_classes is None or pred_idx in selected_classes):
|
|
4195
|
+
behavior = self.classes[pred_idx]
|
|
4196
|
+
start_frame = seg['start']
|
|
4197
|
+
end_frame = seg['end']
|
|
4198
|
+
conf = seg.get('confidence', 1.0)
|
|
4199
|
+
|
|
4200
|
+
start_time = start_frame / orig_fps
|
|
4201
|
+
end_time = (end_frame + 1) / orig_fps
|
|
4202
|
+
duration = end_time - start_time
|
|
4203
|
+
|
|
4204
|
+
writer.writerow([
|
|
4205
|
+
behavior,
|
|
4206
|
+
f"{start_time:.3f}",
|
|
4207
|
+
f"{end_time:.3f}",
|
|
4208
|
+
start_frame,
|
|
4209
|
+
end_frame,
|
|
4210
|
+
f"{duration:.3f}",
|
|
4211
|
+
f"{conf:.3f}"
|
|
4212
|
+
])
|
|
4213
|
+
|
|
4214
|
+
svg_path = base_path + "_timeline.svg"
|
|
4215
|
+
self._export_frame_aggregated_svg(svg_path, orig_fps, selected_classes=selected_classes)
|
|
4216
|
+
mode_text = "frame-aggregated (precise boundaries)"
|
|
4217
|
+
else:
|
|
4218
|
+
corrected_preds = self._effective_predictions()
|
|
4219
|
+
|
|
4220
|
+
if self.merge_timeline_check.isChecked():
|
|
4221
|
+
display_preds, display_confs, display_starts = self._merge_predictions(
|
|
4222
|
+
corrected_preds, self.confidences, self.clip_starts
|
|
4223
|
+
)
|
|
4224
|
+
else:
|
|
4225
|
+
display_preds = corrected_preds
|
|
4226
|
+
display_confs = self.confidences
|
|
4227
|
+
display_starts = self.clip_starts
|
|
4228
|
+
|
|
4229
|
+
with open(csv_path, 'w', newline='') as f:
|
|
4230
|
+
writer = csv.writer(f)
|
|
4231
|
+
writer.writerow(['Behavior', 'Start Time (s)', 'End Time (s)', 'Start Frame', 'End Frame', 'Confidence'])
|
|
4232
|
+
|
|
4233
|
+
for i, (pred_idx, conf, start_frame) in enumerate(zip(display_preds, display_confs, display_starts)):
|
|
4234
|
+
if pred_idx < len(self.classes) and pred_idx >= 0 and (selected_classes is None or pred_idx in selected_classes):
|
|
4235
|
+
behavior = self.classes[pred_idx]
|
|
4236
|
+
if i < len(display_starts) - 1:
|
|
4237
|
+
end_frame = display_starts[i + 1]
|
|
4238
|
+
else:
|
|
4239
|
+
end_frame = start_frame + (clip_length * frame_interval)
|
|
4240
|
+
start_time = start_frame / orig_fps
|
|
4241
|
+
end_time = end_frame / orig_fps
|
|
4242
|
+
writer.writerow([behavior, f"{start_time:.3f}", f"{end_time:.3f}", start_frame, end_frame, f"{conf:.3f}"])
|
|
4243
|
+
|
|
4244
|
+
svg_path = base_path + "_timeline.svg"
|
|
4245
|
+
self._export_timeline_svg(
|
|
4246
|
+
svg_path,
|
|
4247
|
+
display_preds,
|
|
4248
|
+
display_confs,
|
|
4249
|
+
display_starts,
|
|
4250
|
+
orig_fps,
|
|
4251
|
+
frame_interval,
|
|
4252
|
+
clip_length,
|
|
4253
|
+
selected_classes=selected_classes,
|
|
4254
|
+
)
|
|
4255
|
+
mode_text = "clip-based"
|
|
4256
|
+
|
|
4257
|
+
return csv_path, svg_path, mode_text
|
|
4258
|
+
|
|
4259
|
+
def _export_timeline_svg(self, svg_path, display_preds, display_confs, display_starts, orig_fps, frame_interval, clip_length, selected_classes: set[int] | None = None):
|
|
4260
|
+
"""Export timeline as SVG."""
|
|
4261
|
+
num_segments = len(display_preds)
|
|
4262
|
+
if num_segments == 0:
|
|
4263
|
+
return
|
|
4264
|
+
|
|
4265
|
+
# Calculate dimensions
|
|
4266
|
+
base_clip_width = 20
|
|
4267
|
+
total_clips = len(self.predictions)
|
|
4268
|
+
width = max(1200, total_clips * base_clip_width)
|
|
4269
|
+
height = 80
|
|
4270
|
+
|
|
4271
|
+
colors = self._get_timeline_palette()
|
|
4272
|
+
|
|
4273
|
+
with open(svg_path, 'w') as f:
|
|
4274
|
+
f.write(f'<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">\n')
|
|
4275
|
+
f.write(f'<rect width="{width}" height="{height}" fill="white"/>\n')
|
|
4276
|
+
|
|
4277
|
+
x_pos = 0
|
|
4278
|
+
for seg_idx, (pred_idx, conf) in enumerate(zip(display_preds, display_confs)):
|
|
4279
|
+
if seg_idx < len(display_starts) - 1:
|
|
4280
|
+
seg_start_clip = display_starts[seg_idx]
|
|
4281
|
+
seg_end_clip = display_starts[seg_idx + 1]
|
|
4282
|
+
clip_count = sum(1 for orig_start in self.clip_starts if seg_start_clip <= orig_start < seg_end_clip)
|
|
4283
|
+
seg_width = clip_count * base_clip_width
|
|
4284
|
+
else:
|
|
4285
|
+
seg_start_clip = display_starts[seg_idx]
|
|
4286
|
+
clip_count = sum(1 for orig_start in self.clip_starts if orig_start >= seg_start_clip)
|
|
4287
|
+
seg_width = clip_count * base_clip_width
|
|
4288
|
+
|
|
4289
|
+
if pred_idx < len(self.classes) and (selected_classes is None or pred_idx in selected_classes):
|
|
4290
|
+
color = colors[pred_idx % len(colors)]
|
|
4291
|
+
r, g, b = color
|
|
4292
|
+
alpha = conf
|
|
4293
|
+
|
|
4294
|
+
f.write(f'<rect x="{x_pos}" y="0" width="{seg_width}" height="{height}" '
|
|
4295
|
+
f'fill="rgb({r},{g},{b})" opacity="{alpha:.2f}"/>\n')
|
|
4296
|
+
|
|
4297
|
+
if seg_width >= 30:
|
|
4298
|
+
behavior = self.classes[pred_idx]
|
|
4299
|
+
f.write(f'<text x="{x_pos + 5}" y="{height//2 + 5}" font-family="Arial" font-size="10">{behavior}</text>\n')
|
|
4300
|
+
|
|
4301
|
+
x_pos += seg_width
|
|
4302
|
+
|
|
4303
|
+
f.write('</svg>\n')
|
|
4304
|
+
|
|
4305
|
+
def _export_frame_aggregated_svg(self, svg_path: str, orig_fps: float, selected_classes: set[int] | None = None):
|
|
4306
|
+
"""Export frame-aggregated timeline as SVG with precise boundaries."""
|
|
4307
|
+
if not self.aggregated_segments:
|
|
4308
|
+
return
|
|
4309
|
+
|
|
4310
|
+
use_multiclass_rows = bool(self._use_ovr and self.aggregated_multiclass_segments)
|
|
4311
|
+
segments = self.aggregated_multiclass_segments if use_multiclass_rows else self.aggregated_segments
|
|
4312
|
+
row_classes = [ci for ci in range(len(self.classes)) if selected_classes is None or ci in selected_classes]
|
|
4313
|
+
if not row_classes:
|
|
4314
|
+
return
|
|
4315
|
+
row_lookup = {cls_idx: row_idx for row_idx, cls_idx in enumerate(row_classes)}
|
|
4316
|
+
total_frames = self.total_frames
|
|
4317
|
+
if total_frames <= 0 and self.video_path and os.path.exists(self.video_path):
|
|
4318
|
+
try:
|
|
4319
|
+
cap = cv2.VideoCapture(self.video_path)
|
|
4320
|
+
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
|
4321
|
+
cap.release()
|
|
4322
|
+
except Exception:
|
|
4323
|
+
pass
|
|
4324
|
+
if total_frames <= 0 and segments:
|
|
4325
|
+
total_frames = segments[-1]['end'] + 1
|
|
4326
|
+
|
|
4327
|
+
# Calculate dimensions
|
|
4328
|
+
min_width = 1200
|
|
4329
|
+
max_width = 4000
|
|
4330
|
+
pixels_per_frame = max(0.1, min(2.0, max_width / max(1, total_frames)))
|
|
4331
|
+
width = max(min_width, int(total_frames * pixels_per_frame))
|
|
4332
|
+
if use_multiclass_rows:
|
|
4333
|
+
row_h = 20
|
|
4334
|
+
height = max(80, row_h * len(row_classes))
|
|
4335
|
+
else:
|
|
4336
|
+
row_h = 80
|
|
4337
|
+
height = 80
|
|
4338
|
+
|
|
4339
|
+
colors = self._get_timeline_palette()
|
|
4340
|
+
|
|
4341
|
+
with open(svg_path, 'w') as f:
|
|
4342
|
+
f.write(f'<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">\n')
|
|
4343
|
+
f.write(f'<rect width="{width}" height="{height}" fill="white"/>\n')
|
|
4344
|
+
|
|
4345
|
+
for seg in segments:
|
|
4346
|
+
pred_idx = seg['class']
|
|
4347
|
+
conf = seg.get('confidence', 1.0)
|
|
4348
|
+
start_frame = seg['start']
|
|
4349
|
+
end_frame = seg['end']
|
|
4350
|
+
|
|
4351
|
+
if pred_idx >= len(self.classes) or (selected_classes is not None and pred_idx not in selected_classes):
|
|
4352
|
+
continue
|
|
4353
|
+
|
|
4354
|
+
x_start = int(start_frame * pixels_per_frame)
|
|
4355
|
+
x_end = int((end_frame + 1) * pixels_per_frame)
|
|
4356
|
+
seg_width = max(1, x_end - x_start)
|
|
4357
|
+
|
|
4358
|
+
color = colors[pred_idx % len(colors)]
|
|
4359
|
+
r, g, b = color
|
|
4360
|
+
# Normalize confidence for opacity
|
|
4361
|
+
alpha = min(1.0, 0.5 + 0.5 * min(1.0, conf))
|
|
4362
|
+
|
|
4363
|
+
y0 = (row_lookup[pred_idx] * row_h) if use_multiclass_rows else 0
|
|
4364
|
+
h0 = row_h if use_multiclass_rows else height
|
|
4365
|
+
f.write(f'<rect x="{x_start}" y="{y0}" width="{seg_width}" height="{h0}" '
|
|
4366
|
+
f'fill="rgb({r},{g},{b})" opacity="{alpha:.2f}"/>\n')
|
|
4367
|
+
|
|
4368
|
+
# Draw boundary line
|
|
4369
|
+
f.write(f'<line x1="{x_start}" y1="{y0}" x2="{x_start}" y2="{y0 + h0}" '
|
|
4370
|
+
f'stroke="black" stroke-width="0.5" opacity="0.3"/>\n')
|
|
4371
|
+
|
|
4372
|
+
if seg_width >= 40:
|
|
4373
|
+
behavior = self.classes[pred_idx]
|
|
4374
|
+
start_time = start_frame / orig_fps
|
|
4375
|
+
end_time = (end_frame + 1) / orig_fps
|
|
4376
|
+
ty = y0 + min(15, h0 - 5)
|
|
4377
|
+
f.write(f'<text x="{x_start + 3}" y="{ty}" font-family="Arial" font-size="10">{behavior}</text>\n')
|
|
4378
|
+
if not use_multiclass_rows and h0 >= 30:
|
|
4379
|
+
f.write(f'<text x="{x_start + 3}" y="30" font-family="Arial" font-size="8" fill="gray">'
|
|
4380
|
+
f'{start_time:.2f}s-{end_time:.2f}s</text>\n')
|
|
4381
|
+
|
|
4382
|
+
# Add time axis markers
|
|
4383
|
+
duration_sec = total_frames / orig_fps
|
|
4384
|
+
tick_interval = max(1, int(duration_sec / 20)) # ~20 ticks
|
|
4385
|
+
for t in range(0, int(duration_sec) + 1, tick_interval):
|
|
4386
|
+
x = int(t * orig_fps * pixels_per_frame)
|
|
4387
|
+
if x < width:
|
|
4388
|
+
f.write(f'<line x1="{x}" y1="{height-10}" x2="{x}" y2="{height}" stroke="black" stroke-width="1"/>\n')
|
|
4389
|
+
f.write(f'<text x="{x}" y="{height-2}" font-family="Arial" font-size="8" text-anchor="middle">{t}s</text>\n')
|
|
4390
|
+
|
|
4391
|
+
f.write('</svg>\n')
|
|
4392
|
+
|
|
4393
|
+
def _show_clip_popup(self, idx: int, frame_mode: bool = False, ovr_class_idx: int = -1):
    """Show the popup dialog for a clicked clip or frame.

    Args:
        idx: Clip index (clip mode) or frame index (frame mode).
        frame_mode: If True and aggregated segments exist, ``idx`` is a
            frame index; the segment containing it is looked up and shown.
        ovr_class_idx: When a per-class OvR row was clicked, the class index
            of that row. If it is a valid class index, the popup shows a
            segment built for that class instead of the segment from the
            main aggregated timeline.
    """
    if not self.video_path:
        return

    if frame_mode and self.aggregated_segments:
        frame_idx = idx

        # A specific OvR class row was clicked: inspect that class at this
        # frame via a virtual segment.
        if 0 <= ovr_class_idx < len(self.classes):
            segment = self._build_ovr_class_segment(frame_idx, ovr_class_idx)
            if segment is not None:
                self._show_frame_segment_popup(frame_idx, segment)
                return

        # Default: first segment of the main aggregated timeline that
        # contains the frame (segment ends are inclusive).
        segment = next(
            (seg for seg in self.aggregated_segments if seg['start'] <= frame_idx <= seg['end']),
            None,
        )
        if segment is None or segment['class'] >= len(self.classes):
            return

        self._show_frame_segment_popup(frame_idx, segment)
        return

    # Original clip-based mode. Negative indices are rejected as well so a
    # stray value never reaches the dialog.
    clip_idx = idx
    if clip_idx < 0 or clip_idx >= len(self.predictions):
        return
    ClipPopupDialog(self, self, clip_idx)
|
|
4445
|
+
|
|
4446
|
+
def _build_ovr_class_segment(self, frame_idx: int, class_idx: int) -> dict | None:
|
|
4447
|
+
"""Build a virtual segment for a specific OvR class around frame_idx.
|
|
4448
|
+
|
|
4449
|
+
Finds the contiguous run of active frames for class_idx that contains
|
|
4450
|
+
frame_idx, using the stored per-frame scores / active mask.
|
|
4451
|
+
Returns a segment dict compatible with FrameSegmentPopupDialog, or None.
|
|
4452
|
+
"""
|
|
4453
|
+
active_mask = getattr(self, "_aggregated_active_mask", None)
|
|
4454
|
+
frame_scores = getattr(self, "_aggregated_frame_scores_norm", None)
|
|
4455
|
+
if not isinstance(active_mask, np.ndarray) or class_idx >= active_mask.shape[1]:
|
|
4456
|
+
return None
|
|
4457
|
+
total_frames = active_mask.shape[0]
|
|
4458
|
+
if frame_idx < 0 or frame_idx >= total_frames:
|
|
4459
|
+
return None
|
|
4460
|
+
|
|
4461
|
+
is_active = bool(active_mask[frame_idx, class_idx])
|
|
4462
|
+
# Even if the frame isn't active for this class, still show a
|
|
4463
|
+
# single-frame segment so the user can inspect its score.
|
|
4464
|
+
if not is_active:
|
|
4465
|
+
conf = float(frame_scores[frame_idx, class_idx]) if isinstance(frame_scores, np.ndarray) else 0.0
|
|
4466
|
+
return {"class": class_idx, "start": frame_idx, "end": frame_idx, "confidence": conf}
|
|
4467
|
+
|
|
4468
|
+
# Expand outward to find the contiguous active run
|
|
4469
|
+
start = frame_idx
|
|
4470
|
+
while start > 0 and bool(active_mask[start - 1, class_idx]):
|
|
4471
|
+
start -= 1
|
|
4472
|
+
end = frame_idx
|
|
4473
|
+
while end < total_frames - 1 and bool(active_mask[end + 1, class_idx]):
|
|
4474
|
+
end += 1
|
|
4475
|
+
|
|
4476
|
+
conf = float(np.mean(frame_scores[start:end + 1, class_idx])) if isinstance(frame_scores, np.ndarray) else 1.0
|
|
4477
|
+
return {"class": class_idx, "start": start, "end": end, "confidence": conf}
|
|
4478
|
+
|
|
4479
|
+
def _show_frame_segment_popup(self, frame_idx: int, segment: dict, segment_idx: int = None):
    """Open the frame-segment popup for a frame-aggregated segment.

    Args:
        frame_idx: The specific frame that was clicked.
        segment: Segment dict with 'class', 'start', 'end', 'confidence'.
        segment_idx: Position of the segment in ``self.aggregated_segments``
            (for navigation). When omitted, the first aggregated segment
            with the same start and end is used; it stays ``None`` if no
            such segment exists.
    """
    if not self.video_path:
        return

    if segment_idx is None:
        # Match on boundaries only, exactly as the lookup has always done.
        segment_idx = next(
            (
                pos
                for pos, candidate in enumerate(self.aggregated_segments)
                if candidate['start'] == segment['start'] and candidate['end'] == segment['end']
            ),
            None,
        )

    FrameSegmentPopupDialog(self, self, frame_idx, segment, segment_idx)
|
|
4498
|
+
|
|
4499
|
+
def _export_video_with_overlay(self):
    """Export the video with configurable overlays.

    Thin wrapper: the work is done by ``run_export_video_with_overlay`` in
    the sibling ``overlay_export`` module, imported lazily on first use.
    """
    from . import overlay_export

    overlay_export.run_export_video_with_overlay(self)
|
|
4503
|
+
|
|
4504
|
+
def _preview_video_with_overlay(self):
    """Open the video player to preview the video with overlays.

    Thin wrapper: the work is done by ``run_preview_video_with_overlay`` in
    the sibling ``overlay_export`` module, imported lazily on first use.
    """
    from . import overlay_export

    overlay_export.run_preview_video_with_overlay(self)
|
|
4508
|
+
|
|
4509
|
+
def _export_attention_heatmap(self):
    """Export the video with a spatial attention heatmap overlay.

    Thin wrapper: the work is done by ``export_attention_heatmap_video`` in
    the sibling ``attention_export`` module, imported lazily on first use.
    """
    from . import attention_export

    attention_export.export_attention_heatmap_video(self)
|
|
4513
|
+
|
|
4514
|
+
def update_config(self, config: dict):
|
|
4515
|
+
"""Apply a new configuration (experiment switch)."""
|
|
4516
|
+
self.config = config
|
|
4517
|
+
self.model = None
|
|
4518
|
+
self.classes = []
|
|
4519
|
+
self.model_path_edit.clear()
|
|
4520
|
+
self.video_path = None
|
|
4521
|
+
self.video_path_edit.clear()
|
|
4522
|
+
self.video_info_label.setText("No video selected")
|
|
4523
|
+
self.run_inference_btn.setEnabled(False)
|
|
4524
|
+
self.export_btn.setEnabled(False)
|
|
4525
|
+
self.export_timeline_btn.setEnabled(False)
|
|
4526
|
+
self.preview_btn.setEnabled(False)
|
|
4527
|
+
self.results_list.clear()
|
|
4528
|
+
self.progress_label.setText("")
|
|
4529
|
+
self.progress_bar.setVisible(False)
|
|
4530
|
+
self.progress_bar.setValue(0)
|
|
4531
|
+
self.predictions = []
|
|
4532
|
+
self.confidences = []
|
|
4533
|
+
self.clip_probabilities = []
|
|
4534
|
+
self.clip_frame_probabilities = []
|
|
4535
|
+
self.clip_starts = []
|
|
4536
|
+
self.corrected_labels = {}
|
|
4537
|
+
self.corrected_attr_labels = {}
|
|
4538
|
+
self.aggregated_segments = []
|
|
4539
|
+
self.aggregated_multiclass_segments = []
|
|
4540
|
+
self._aggregated_frame_scores_norm = None
|
|
4541
|
+
self._aggregated_active_mask = None
|
|
4542
|
+
self._aggregated_last_covered_frame = 0
|
|
4543
|
+
|
|
4544
|
+
timeline_layout = self.timeline_widget.layout()
|
|
4545
|
+
if timeline_layout:
|
|
4546
|
+
while timeline_layout.count():
|
|
4547
|
+
child = timeline_layout.takeAt(0)
|
|
4548
|
+
if child.widget():
|
|
4549
|
+
child.widget().deleteLater()
|
|
4550
|
+
|