singlebehaviorlab 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88) hide show
  1. sam2/__init__.py +11 -0
  2. sam2/automatic_mask_generator.py +454 -0
  3. sam2/benchmark.py +92 -0
  4. sam2/build_sam.py +174 -0
  5. sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
  6. sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
  7. sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
  8. sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
  9. sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
  10. sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
  11. sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
  12. sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
  13. sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
  14. sam2/modeling/__init__.py +5 -0
  15. sam2/modeling/backbones/__init__.py +5 -0
  16. sam2/modeling/backbones/hieradet.py +317 -0
  17. sam2/modeling/backbones/image_encoder.py +134 -0
  18. sam2/modeling/backbones/utils.py +93 -0
  19. sam2/modeling/memory_attention.py +169 -0
  20. sam2/modeling/memory_encoder.py +181 -0
  21. sam2/modeling/position_encoding.py +239 -0
  22. sam2/modeling/sam/__init__.py +5 -0
  23. sam2/modeling/sam/mask_decoder.py +295 -0
  24. sam2/modeling/sam/prompt_encoder.py +202 -0
  25. sam2/modeling/sam/transformer.py +311 -0
  26. sam2/modeling/sam2_base.py +913 -0
  27. sam2/modeling/sam2_utils.py +323 -0
  28. sam2/sam2_hiera_b+.yaml +113 -0
  29. sam2/sam2_hiera_l.yaml +117 -0
  30. sam2/sam2_hiera_s.yaml +116 -0
  31. sam2/sam2_hiera_t.yaml +118 -0
  32. sam2/sam2_image_predictor.py +466 -0
  33. sam2/sam2_video_predictor.py +1388 -0
  34. sam2/sam2_video_predictor_legacy.py +1172 -0
  35. sam2/utils/__init__.py +5 -0
  36. sam2/utils/amg.py +348 -0
  37. sam2/utils/misc.py +349 -0
  38. sam2/utils/transforms.py +118 -0
  39. singlebehaviorlab/__init__.py +4 -0
  40. singlebehaviorlab/__main__.py +130 -0
  41. singlebehaviorlab/_paths.py +100 -0
  42. singlebehaviorlab/backend/__init__.py +2 -0
  43. singlebehaviorlab/backend/augmentations.py +320 -0
  44. singlebehaviorlab/backend/data_store.py +420 -0
  45. singlebehaviorlab/backend/model.py +1290 -0
  46. singlebehaviorlab/backend/train.py +4667 -0
  47. singlebehaviorlab/backend/uncertainty.py +578 -0
  48. singlebehaviorlab/backend/video_processor.py +688 -0
  49. singlebehaviorlab/backend/video_utils.py +139 -0
  50. singlebehaviorlab/data/config/config.yaml +85 -0
  51. singlebehaviorlab/data/training_profiles.json +334 -0
  52. singlebehaviorlab/gui/__init__.py +4 -0
  53. singlebehaviorlab/gui/analysis_widget.py +2291 -0
  54. singlebehaviorlab/gui/attention_export.py +311 -0
  55. singlebehaviorlab/gui/clip_extraction_widget.py +481 -0
  56. singlebehaviorlab/gui/clustering_widget.py +3187 -0
  57. singlebehaviorlab/gui/inference_popups.py +1138 -0
  58. singlebehaviorlab/gui/inference_widget.py +4550 -0
  59. singlebehaviorlab/gui/inference_worker.py +651 -0
  60. singlebehaviorlab/gui/labeling_widget.py +2324 -0
  61. singlebehaviorlab/gui/main_window.py +754 -0
  62. singlebehaviorlab/gui/metadata_management_widget.py +1119 -0
  63. singlebehaviorlab/gui/motion_tracking.py +764 -0
  64. singlebehaviorlab/gui/overlay_export.py +1234 -0
  65. singlebehaviorlab/gui/plot_integration.py +729 -0
  66. singlebehaviorlab/gui/qt_helpers.py +29 -0
  67. singlebehaviorlab/gui/registration_widget.py +1485 -0
  68. singlebehaviorlab/gui/review_widget.py +1330 -0
  69. singlebehaviorlab/gui/segmentation_tracking_widget.py +2752 -0
  70. singlebehaviorlab/gui/tab_tutorial_dialog.py +312 -0
  71. singlebehaviorlab/gui/timeline_themes.py +131 -0
  72. singlebehaviorlab/gui/training_profiles.py +418 -0
  73. singlebehaviorlab/gui/training_widget.py +3719 -0
  74. singlebehaviorlab/gui/video_utils.py +233 -0
  75. singlebehaviorlab/licenses/SAM2-LICENSE +201 -0
  76. singlebehaviorlab/licenses/VideoPrism-LICENSE +202 -0
  77. singlebehaviorlab-2.0.0.dist-info/METADATA +447 -0
  78. singlebehaviorlab-2.0.0.dist-info/RECORD +88 -0
  79. singlebehaviorlab-2.0.0.dist-info/WHEEL +5 -0
  80. singlebehaviorlab-2.0.0.dist-info/entry_points.txt +2 -0
  81. singlebehaviorlab-2.0.0.dist-info/licenses/LICENSE +21 -0
  82. singlebehaviorlab-2.0.0.dist-info/top_level.txt +3 -0
  83. videoprism/__init__.py +0 -0
  84. videoprism/encoders.py +910 -0
  85. videoprism/layers.py +1136 -0
  86. videoprism/models.py +407 -0
  87. videoprism/tokenizers.py +167 -0
  88. videoprism/utils.py +168 -0
@@ -0,0 +1,1138 @@
1
+ """Popup dialogs for inference clip and frame-segment inspection."""
2
+
3
+ from PyQt6.QtWidgets import (
4
+ QDialog, QVBoxLayout, QHBoxLayout, QLabel, QPushButton, QComboBox,
5
+ QGroupBox, QMessageBox, QSizePolicy, QSpinBox,
6
+ )
7
+ from PyQt6.QtCore import Qt, QTimer
8
+ from PyQt6.QtGui import QImage, QPixmap
9
+ import cv2
10
+ import os
11
+ import numpy as np
12
+ from singlebehaviorlab.backend.video_utils import save_clip
13
+ from singlebehaviorlab.backend.data_store import AnnotationManager
14
+
15
+
16
class ClipPopupDialog(QDialog):
    """Dialog showing a single inference clip with label, correction, and training controls."""

    def __init__(self, parent, widget, clip_idx):
        """Build the popup for one inference clip and run it modally.

        Args:
            parent: Qt parent widget for the dialog.
            widget: the owning inference widget; supplies predictions, classes,
                attributes, confidences, video path, and helper methods used here.
            clip_idx: index of the clip within the widget's prediction arrays.

        NOTE(review): this constructor calls ``self.exec()`` at the end, so the
        modal event loop runs inside ``__init__`` and the constructor does not
        return until the dialog is closed — confirm callers expect that.
        """
        super().__init__(parent)
        self._widget = widget
        self._clip_idx = clip_idx

        # Effective prediction honors any user correction; a negative index
        # means the clip falls under the "ignore" label.
        pred_idx = self._widget._effective_prediction_for_clip(self._clip_idx)
        conf = self._widget.confidences[self._clip_idx]
        label = self._widget.classes[pred_idx] if (0 <= pred_idx < len(self._widget.classes)) else self._widget.ignore_label_name

        # Optional attribute prediction (secondary label head), shown only when
        # the widget has attributes and a valid attribute index for this clip.
        attr_info = ""
        attr_idx = self._widget._get_attr_idx(self._clip_idx)
        if self._widget.attributes and isinstance(attr_idx, int) and attr_idx < len(self._widget.attributes):
            attr_label = self._widget.attributes[attr_idx]
            attr_conf = 0.0
            if self._widget.attr_confidences and self._clip_idx < len(self._widget.attr_confidences):
                attr_conf = self._widget.attr_confidences[self._clip_idx]
            attr_info = f"<br><b>Attribute:</b> {attr_label} ({attr_conf:.2%})"

        # Availability of a localization bbox is probed on the clip's first frame.
        bbox_info = "<br><b>Localization BBox:</b> unavailable"
        if self._widget._get_localization_bbox_for_clip_frame(self._clip_idx, 0) is not None:
            bbox_info = "<br><b>Localization BBox:</b> shown in green"

        self.setWindowTitle(f"Clip {self._clip_idx + 1}: {label} ({conf:.1%} confidence)")
        self.setMinimumSize(700, 650)
        self.setWindowFlags(
            Qt.WindowType.Window
            | Qt.WindowType.WindowMinMaxButtonsHint
            | Qt.WindowType.WindowMaximizeButtonHint
            | Qt.WindowType.WindowCloseButtonHint
        )
        self.setSizeGripEnabled(True)

        # Restore the maximized/normal state the user left the previous popup in.
        if self._widget.clip_popup_maximized:
            self.showMaximized()
        else:
            self.show()

        layout = QVBoxLayout()

        # --- Info header: predicted label, confidence, attribute, bbox, scores ---
        info_layout = QHBoxLayout()
        ovr_scores = ""
        # In one-vs-rest mode, list every class score sorted descending.
        if self._widget._use_ovr and self._clip_idx < len(self._widget.clip_probabilities):
            probs = self._widget.clip_probabilities[self._clip_idx]
            if isinstance(probs, (list, tuple)):
                scored = []
                for ci, sc in enumerate(probs):
                    if ci < len(self._widget.classes):
                        scored.append((self._widget.classes[ci], float(sc)))
                scored.sort(key=lambda x: x[1], reverse=True)
                parts = [f"{name}: {sc:.1%}" for name, sc in scored]
                ovr_scores = "<br><b>All scores:</b> " + " | ".join(parts)
        info_label = QLabel(f"<b>Predicted Label:</b> {label}<br><b>Confidence:</b> {conf:.2%}{attr_info}{bbox_info}{ovr_scores}")
        info_label.setStyleSheet("font-size: 14px; padding: 10px;")
        info_layout.addWidget(info_label)

        # Badge when this clip already carries a user correction (label or attribute).
        if self._clip_idx in self._widget.corrected_labels or self._clip_idx in self._widget.corrected_attr_labels:
            corrected_label = QLabel("<b style='color: green;'>Corrected</b>")
            corrected_label.setStyleSheet("font-size: 14px; padding: 10px;")
            info_layout.addWidget(corrected_label)

        layout.addLayout(info_layout)

        # --- Correction group: pick a different label/attribute for this clip ---
        correction_group = QGroupBox("Correct label")
        correction_layout = QVBoxLayout()

        label_combo = QComboBox()
        label_combo.addItems(self._widget.classes)
        if 0 <= pred_idx < len(self._widget.classes):
            label_combo.setCurrentIndex(pred_idx)

        # Index 0 of attr_combo is the "No attribute" sentinel; real attributes
        # are therefore offset by +1 relative to attribute indices.
        attr_combo = QComboBox()
        attr_combo.addItem("No attribute")
        if self._widget.attributes:
            attr_combo.addItems(self._widget.attributes)
            if isinstance(attr_idx, int) and attr_idx < len(self._widget.attributes):
                attr_combo.setCurrentIndex(attr_idx + 1)

        # Small colored dot previewing the timeline color of the chosen class.
        colors = self._widget._get_timeline_palette()
        color_indicator = QLabel()
        color_indicator.setFixedSize(20, 20)

        def update_indicator(index):
            # Sync the color dot with the combo's selected class.
            # NOTE(review): a negative index (e.g. pred_idx == -1 for the ignore
            # label) passes this guard and wraps via the modulo below, showing
            # the last palette color — confirm that is intended.
            if index < len(colors):
                c = colors[index % len(colors)]
                color_hex = f"rgb({c[0]},{c[1]},{c[2]})"
                color_indicator.setStyleSheet(f"background-color: {color_hex}; border-radius: 10px; border: 1px solid #666;")

        label_combo.currentIndexChanged.connect(update_indicator)
        update_indicator(pred_idx)

        combo_layout = QHBoxLayout()
        combo_layout.addWidget(QLabel("Select correct label:"))
        combo_layout.addWidget(color_indicator)
        combo_layout.addWidget(label_combo)
        correction_layout.addLayout(combo_layout)

        attr_layout = QHBoxLayout()
        attr_layout.addWidget(QLabel("Select attribute:"))
        attr_layout.addWidget(attr_combo)
        correction_layout.addLayout(attr_layout)

        save_correction_btn = QPushButton("Save correction")
        save_correction_btn.setStyleSheet("background-color: #4CAF50; color: white; font-weight: bold; padding: 5px;")

        def save_correction():
            # Store the correction only when it differs from the ORIGINAL model
            # prediction; re-selecting the original value clears the correction.
            new_label_idx = label_combo.currentIndex()
            original_pred_idx = self._widget.predictions[self._clip_idx]

            if new_label_idx != original_pred_idx:
                self._widget.corrected_labels[self._clip_idx] = new_label_idx
            else:
                if self._clip_idx in self._widget.corrected_labels:
                    del self._widget.corrected_labels[self._clip_idx]

            # -1 after the offset means "No attribute" was selected.
            new_attr_idx = attr_combo.currentIndex() - 1
            original_attr_idx = self._widget.attr_predictions[self._clip_idx] if (self._widget.attr_predictions and self._clip_idx < len(self._widget.attr_predictions)) else None
            if new_attr_idx >= 0 and new_attr_idx != original_attr_idx:
                self._widget.corrected_attr_labels[self._clip_idx] = new_attr_idx
            else:
                if self._clip_idx in self._widget.corrected_attr_labels:
                    del self._widget.corrected_attr_labels[self._clip_idx]

            # Redraw so the timeline reflects the corrected label immediately.
            self._widget._draw_timeline()
            QMessageBox.information(self, "Correction Saved", "Corrections saved.\n\nTimeline updated.")

        save_correction_btn.clicked.connect(save_correction)
        correction_layout.addWidget(save_correction_btn)

        correction_group.setLayout(correction_layout)
        layout.addWidget(correction_group)

        # --- Training group: export this clip back into the training dataset ---
        training_group = QGroupBox("Add to training dataset")
        training_layout = QVBoxLayout()

        add_to_training_btn = QPushButton("Add to training dataset")
        add_to_training_btn.setStyleSheet("background-color: #2196F3; color: white; font-weight: bold; padding: 5px;")

        def _extract_and_store_clip(selected_label: str, extra_meta: dict = None):
            """Re-extract this clip's frames from the source video, save it as an
            mp4 under the clips dir, and register it with the annotation manager.

            Returns (clip_path, selected_label) on success, (None, None) on any
            failure (a warning dialog is shown in that case).
            """
            clips_dir = self._widget._get_clips_dir()

            cap = cv2.VideoCapture(self._widget.video_path)
            if not cap.isOpened():
                QMessageBox.warning(self, "Error", "Could not open video file.")
                return None, None

            orig_fps = cap.get(cv2.CAP_PROP_FPS)
            if orig_fps <= 0:
                # OpenCV returns 0/negative when FPS metadata is missing.
                orig_fps = 30.0

            target_fps = self._widget.target_fps_spin.value()
            frame_interval = self._widget._get_saved_frame_interval(self._widget.video_path, orig_fps)
            clip_length = self._widget.clip_length_spin.value()
            clip_start_frame = self._widget.clip_starts[self._clip_idx]

            # Re-sample frames exactly as inference did: start at the clip's
            # source frame and keep every frame_interval-th frame.
            cap.set(cv2.CAP_PROP_POS_FRAMES, clip_start_frame)
            frames = []
            frame_count = 0
            while len(frames) < clip_length:
                ret, frame = cap.read()
                if not ret:
                    break
                if frame_count % frame_interval == 0:
                    frames.append(frame.copy())
                frame_count += 1
            cap.release()

            if not frames:
                QMessageBox.warning(self, "Error", "Could not extract frames from clip.")
                return None, None

            # Build a unique filename; append a counter suffix on collision.
            video_basename = self._widget._video_basename()
            clip_filename = f"{video_basename}_clip_{self._clip_idx:06d}_frame_{clip_start_frame}.mp4"
            clip_path = os.path.join(clips_dir, clip_filename)
            counter = 1
            while os.path.exists(clip_path):
                clip_filename = f"{video_basename}_clip_{self._clip_idx:06d}_frame_{clip_start_frame}_{counter}.mp4"
                clip_path = os.path.join(clips_dir, clip_filename)
                counter += 1

            save_clip(frames, clip_path, target_fps)
            # save_clip reports nothing, so verify the file actually landed on disk.
            if not os.path.exists(clip_path) or os.path.getsize(clip_path) == 0:
                QMessageBox.warning(
                    self, "Error",
                    f"Failed to save clip to disk.\nPath: {clip_path}\n"
                    "Check write permissions and disk space."
                )
                return None, None

            annotation_manager = AnnotationManager(self._widget._get_annotation_file())
            clip_id = self._widget._clip_path_to_id(clip_path, clips_dir)

            annotation_manager.add_class(selected_label)
            meta = {
                "source_video": os.path.basename(self._widget.video_path),
                "source_frame": clip_start_frame,
                "target_fps": target_fps,
                "clip_length": clip_length,
                "added_from_inference": True,
            }
            if extra_meta:
                meta.update(extra_meta)
            # add_clip may normalize/rename the id; use the returned one for frame labels.
            used_clip_id = annotation_manager.add_clip(clip_id, selected_label, meta=meta)
            frame_labels = [selected_label] * len(frames)
            annotation_manager.set_frame_labels(used_clip_id, frame_labels)
            return clip_path, selected_label

        def add_to_training():
            """Save the clip under the currently selected label (attribute wins if set)."""
            try:
                selected_label_idx = label_combo.currentIndex()
                selected_label = self._widget.classes[selected_label_idx]
                if self._widget.attributes:
                    # A selected attribute overrides the class as the stored label.
                    selected_attr_idx = attr_combo.currentIndex() - 1
                    if 0 <= selected_attr_idx < len(self._widget.attributes):
                        selected_label = self._widget.attributes[selected_attr_idx]
                clip_path, label = _extract_and_store_clip(selected_label)
                if clip_path:
                    QMessageBox.information(
                        self,
                        "Success",
                        f"Clip added to training dataset!\n\n"
                        f"Label: {label}\n"
                        f"Saved to: {clip_path}\n\n"
                        f"You can now retrain the model with this new data.",
                    )
            except Exception as e:
                QMessageBox.critical(self, "Error", f"Failed to add clip to training dataset:\n{str(e)}")

        def add_as_near_negative():
            """Save the clip as a hard-negative example: near_negative_<class>."""
            try:
                # Configurable label prefix; fall back to "near_negative".
                prefix = str(self._widget.config.get("near_negative_label", "near_negative")).strip() or "near_negative"
                target_class = None

                # Prefer the class selected in the combo; otherwise fall back to
                # the raw model prediction for this clip.
                sel_idx = label_combo.currentIndex()
                if 0 <= sel_idx < len(self._widget.classes):
                    target_class = self._widget.classes[sel_idx]
                else:
                    try:
                        raw_pred_idx = int(self._widget.predictions[self._clip_idx])
                    except Exception:
                        raw_pred_idx = -1
                    if 0 <= raw_pred_idx < len(self._widget.classes):
                        target_class = self._widget.classes[raw_pred_idx]

                if target_class:
                    # Sanitize the class name into a filesystem/label-safe token.
                    class_token = str(target_class).strip().replace(" ", "_").replace("/", "_").replace("\\", "_")
                    while "__" in class_token:
                        class_token = class_token.replace("__", "_")
                    class_token = class_token.strip("_")
                    if class_token:
                        # Avoid doubling the suffix when prefix already ends with it.
                        near_label = f"{prefix}_{class_token}" if not prefix.endswith(f"_{class_token}") else prefix
                    else:
                        near_label = prefix
                else:
                    near_label = prefix

                clip_path, label = _extract_and_store_clip(
                    near_label,
                    extra_meta={
                        "near_negative": True,
                        "hard_negative_candidate": True,
                        "hard_negative_for_class": target_class,
                    },
                )
                if clip_path:
                    self._widget.log_text.append(f"Near negative saved: label='{label}', clip='{clip_path}'")
                    QMessageBox.information(
                        self,
                        "Near negative added",
                        f"Clip saved as hard negative for this prediction.\n\n"
                        f"Label: {label}\n"
                        f"Saved to: {clip_path}\n\n"
                        "Train with multiple near_negative_* classes; at inference ignore them as background.",
                    )
            except Exception as e:
                QMessageBox.critical(self, "Error", f"Failed to add near negative clip:\n{str(e)}")

        add_to_training_btn.clicked.connect(add_to_training)
        training_layout.addWidget(add_to_training_btn)
        add_near_negative_btn = QPushButton("Mark as near negative")
        add_near_negative_btn.setToolTip(
            "Saves clip as near_negative_<predicted_class> (e.g. near_negative_jump). "
            "Train with all near_negative_* classes; at inference treat them as background."
        )
        add_near_negative_btn.setStyleSheet("background-color: #5a6578; color: white; font-weight: bold; padding: 5px;")
        add_near_negative_btn.clicked.connect(add_as_near_negative)
        training_layout.addWidget(add_near_negative_btn)

        training_group.setLayout(training_layout)
        layout.addWidget(training_group)

        # --- Video preview area (frames are drawn into this label) ---
        video_label = QLabel("Loading clip...")
        video_label.setMinimumSize(640, 360)
        video_label.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
        video_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
        video_label.setStyleSheet("background-color: black; color: white;")
        layout.addWidget(video_label, 1)

        controls_layout = QHBoxLayout()

        # --- Previous-clip navigation, honoring the active behavior filter ---
        prev_btn = QPushButton("< Previous")
        prev_btn.setToolTip("Previous Clip")

        has_prev = False
        selected_behavior = self._widget.filter_behavior_combo.currentText()
        selected_attr = None
        # Attribute filters are encoded in the combo text as "Attr: <name>".
        if selected_behavior.startswith("Attr: "):
            selected_attr = selected_behavior.replace("Attr: ", "", 1)
        prev_search_idx = self._clip_idx - 1

        if selected_behavior == "All Behaviors":
            if prev_search_idx >= 0:
                has_prev = True
        else:
            # Scan backwards for the nearest clip matching the filter
            # (attribute match, ignore-label match, or class-name match).
            while prev_search_idx >= 0:
                p_idx = self._widget._effective_prediction_for_clip(prev_search_idx)

                if selected_attr is not None:
                    if selected_attr in self._widget.attributes:
                        attr_target_idx = self._widget.attributes.index(selected_attr)
                        if self._widget._get_attr_idx(prev_search_idx) == attr_target_idx:
                            has_prev = True
                            break
                else:
                    if selected_behavior == self._widget.ignore_label_name and p_idx < 0:
                        has_prev = True
                        break
                    if p_idx < len(self._widget.classes) and p_idx >= 0 and self._widget.classes[p_idx] == selected_behavior:
                        has_prev = True
                        break
                prev_search_idx -= 1

        if not has_prev:
            prev_btn.setEnabled(False)

        def go_prev():
            # Persist maximized state, close this dialog, and reopen the popup
            # on the previous matching clip (same filter logic as has_prev).
            self._widget.clip_popup_maximized = self.isMaximized()
            self.close()

            target_idx = self._clip_idx - 1
            if selected_behavior == "All Behaviors":
                if target_idx >= 0:
                    self._widget._show_clip_popup(target_idx)
            else:
                while target_idx >= 0:
                    p_idx = self._widget._effective_prediction_for_clip(target_idx)

                    if selected_attr is not None:
                        if selected_attr in self._widget.attributes:
                            attr_target_idx = self._widget.attributes.index(selected_attr)
                            if self._widget._get_attr_idx(target_idx) == attr_target_idx:
                                self._widget._show_clip_popup(target_idx)
                                return
                    else:
                        if selected_behavior == self._widget.ignore_label_name and p_idx < 0:
                            self._widget._show_clip_popup(target_idx)
                            return
                        if p_idx < len(self._widget.classes) and p_idx >= 0 and self._widget.classes[p_idx] == selected_behavior:
                            self._widget._show_clip_popup(target_idx)
                            return
                    target_idx -= 1

        prev_btn.clicked.connect(go_prev)
        controls_layout.addWidget(prev_btn)

        # --- Playback state shared by the closures below ---
        # One-element lists are used so nested functions can mutate the values.
        play_pause_btn = QPushButton("Play")
        is_playing = [False]
        current_frame_idx = [0]
        video_frames = []
        self._play_timer = QTimer()

        def load_clip_frames():
            """Decode this clip's frames (RGB) with bbox overlays; returns [] on error."""
            try:
                cap = cv2.VideoCapture(self._widget.video_path)
                if not cap.isOpened():
                    video_label.setText("Error: Could not open video")
                    return []

                orig_fps = cap.get(cv2.CAP_PROP_FPS)
                if orig_fps <= 0:
                    orig_fps = 30.0

                target_fps = self._widget.target_fps_spin.value()
                frame_interval = self._widget._get_saved_frame_interval(self._widget.video_path, orig_fps)
                clip_length = self._widget.clip_length_spin.value()

                clip_start_frame = self._widget.clip_starts[self._clip_idx]

                cap.set(cv2.CAP_PROP_POS_FRAMES, clip_start_frame)

                frames = []
                frame_count = 0

                # Same sampling scheme as _extract_and_store_clip, but frames
                # are converted to RGB and annotated for display.
                while len(frames) < clip_length:
                    ret, frame = cap.read()
                    if not ret:
                        break

                    if frame_count % frame_interval == 0:
                        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                        bbox = self._widget._get_localization_bbox_for_clip_frame(self._clip_idx, len(frames))
                        crop_bbox = self._widget._get_classification_roi_bbox_for_clip_frame(self._clip_idx)
                        if bbox is not None:
                            # bbox coords appear to be normalized [0, 1]
                            # (scaled by frame size here) — green overlay.
                            h, w = frame_rgb.shape[:2]
                            x1, y1, x2, y2 = bbox
                            fx1 = max(0, min(int(round(x1 * w)), w - 1))
                            fy1 = max(0, min(int(round(y1 * h)), h - 1))
                            fx2 = max(fx1 + 1, min(int(round(x2 * w)), w))
                            fy2 = max(fy1 + 1, min(int(round(y2 * h)), h))
                            cv2.rectangle(frame_rgb, (fx1, fy1), (fx2 - 1, fy2 - 1), (0, 255, 0), 2)
                            cv2.putText(
                                frame_rgb,
                                "Raw localization",
                                (fx1, max(18, fy1 - 6)),
                                cv2.FONT_HERSHEY_SIMPLEX,
                                0.55,
                                (0, 255, 0),
                                2,
                                cv2.LINE_AA,
                            )
                        if crop_bbox is not None:
                            # Classification crop ROI — yellow overlay.
                            h, w = frame_rgb.shape[:2]
                            x1, y1, x2, y2 = crop_bbox
                            fx1 = max(0, min(int(round(x1 * w)), w - 1))
                            fy1 = max(0, min(int(round(y1 * h)), h - 1))
                            fx2 = max(fx1 + 1, min(int(round(x2 * w)), w))
                            fy2 = max(fy1 + 1, min(int(round(y2 * h)), h))
                            cv2.rectangle(frame_rgb, (fx1, fy1), (fx2 - 1, fy2 - 1), (0, 255, 255), 2)
                            cv2.putText(
                                frame_rgb,
                                "Cls crop ROI",
                                (fx1, min(h - 8, fy2 + 18)),
                                cv2.FONT_HERSHEY_SIMPLEX,
                                0.52,
                                (0, 255, 255),
                                2,
                                cv2.LINE_AA,
                            )
                        frames.append(frame_rgb)

                    frame_count += 1

                cap.release()

                if frames:
                    # Pace playback at the training/inference target FPS.
                    fps = target_fps
                    self._play_timer.setInterval(int(1000 / fps))

                return frames
            except Exception as e:
                video_label.setText(f"Error loading clip: {str(e)}")
                return []

        def update_frame():
            """Render the current frame to video_label and advance (with wrap-around)."""
            if current_frame_idx[0] < len(video_frames):
                frame = video_frames[current_frame_idx[0]]
                h, w, c = frame.shape
                bytes_per_line = int(c * w)
                # QImage wraps the numpy buffer directly, so it must be contiguous.
                if not frame.flags['C_CONTIGUOUS']:
                    frame = np.ascontiguousarray(frame)

                q_image = QImage(frame.data, int(w), int(h), bytes_per_line, QImage.Format.Format_RGB888)

                pixmap = QPixmap.fromImage(q_image)
                scaled_pixmap = pixmap.scaled(
                    video_label.size(),
                    Qt.AspectRatioMode.KeepAspectRatio,
                    Qt.TransformationMode.SmoothTransformation
                )
                video_label.setPixmap(scaled_pixmap)
                # Modulo makes playback loop back to the first frame.
                current_frame_idx[0] = (current_frame_idx[0] + 1) % len(video_frames)
            else:
                # No frame at this index (e.g. empty clip): stop playback.
                self._play_timer.stop()
                is_playing[0] = False
                play_pause_btn.setText("Play")

        def toggle_play():
            """Play/pause handler; lazily decodes frames on first use."""
            if not video_frames:
                loaded_frames = load_clip_frames()
                video_frames.extend(loaded_frames)
                if video_frames:
                    update_frame()

            if is_playing[0]:
                play_pause_btn.setText("Play")
                self._play_timer.stop()
                is_playing[0] = False
            else:
                play_pause_btn.setText("Pause")
                self._play_timer.start()
                is_playing[0] = True

        self._play_timer.timeout.connect(update_frame)
        # Default ~30 FPS interval; load_clip_frames overrides it with target_fps.
        self._play_timer.setInterval(33)

        play_pause_btn.clicked.connect(toggle_play)
        controls_layout.addWidget(play_pause_btn)

        restart_btn = QPushButton("Restart")

        def restart():
            # Rewind to frame 0; pause if currently playing, then show frame 0.
            current_frame_idx[0] = 0
            if is_playing[0]:
                toggle_play()
            update_frame()

        restart_btn.clicked.connect(restart)
        controls_layout.addWidget(restart_btn)

        # --- Next-clip navigation (mirror image of the Previous logic) ---
        next_btn = QPushButton("Next >")
        next_btn.setToolTip("Next Clip")

        has_next = False
        next_search_idx = self._clip_idx + 1

        if selected_behavior == "All Behaviors":
            if next_search_idx < len(self._widget.predictions):
                has_next = True
        else:
            while next_search_idx < len(self._widget.predictions):
                p_idx = self._widget._effective_prediction_for_clip(next_search_idx)

                if selected_attr is not None:
                    if selected_attr in self._widget.attributes:
                        attr_target_idx = self._widget.attributes.index(selected_attr)
                        if self._widget._get_attr_idx(next_search_idx) == attr_target_idx:
                            has_next = True
                            break
                else:
                    if selected_behavior == self._widget.ignore_label_name and p_idx < 0:
                        has_next = True
                        break
                    if p_idx < len(self._widget.classes) and p_idx >= 0 and self._widget.classes[p_idx] == selected_behavior:
                        has_next = True
                        break
                next_search_idx += 1

        if not has_next:
            next_btn.setEnabled(False)

        def go_next():
            # Persist maximized state, close, and reopen on the next matching clip.
            self._widget.clip_popup_maximized = self.isMaximized()
            self.close()

            target_idx = self._clip_idx + 1
            if selected_behavior == "All Behaviors":
                if target_idx < len(self._widget.predictions):
                    self._widget._show_clip_popup(target_idx)
            else:
                while target_idx < len(self._widget.predictions):
                    p_idx = self._widget._effective_prediction_for_clip(target_idx)

                    if selected_attr is not None:
                        if selected_attr in self._widget.attributes:
                            attr_target_idx = self._widget.attributes.index(selected_attr)
                            if self._widget._get_attr_idx(target_idx) == attr_target_idx:
                                self._widget._show_clip_popup(target_idx)
                                return
                    else:
                        if selected_behavior == self._widget.ignore_label_name and p_idx < 0:
                            self._widget._show_clip_popup(target_idx)
                            return
                        if p_idx < len(self._widget.classes) and p_idx >= 0 and self._widget.classes[p_idx] == selected_behavior:
                            self._widget._show_clip_popup(target_idx)
                            return
                    target_idx += 1

        next_btn.clicked.connect(go_next)
        controls_layout.addWidget(next_btn)

        controls_layout.addStretch()

        close_btn = QPushButton("Close")
        close_btn.clicked.connect(self.close)
        controls_layout.addWidget(close_btn)

        layout.addLayout(controls_layout)
        self.setLayout(layout)

        def load_and_show_first_frame():
            # Deferred initial load so the dialog paints before decoding starts.
            loaded_frames = load_clip_frames()
            video_frames.extend(loaded_frames)
            if video_frames:
                update_frame()

        QTimer.singleShot(100, load_and_show_first_frame)

        # Run the modal event loop; execution resumes here once the dialog closes.
        self.exec()

        # Belt-and-braces: make sure the playback timer is dead after exec().
        if self._play_timer.isActive():
            self._play_timer.stop()

    def closeEvent(self, event):
        """Stop the playback timer (if it was ever created) before closing."""
        if hasattr(self, '_play_timer') and self._play_timer.isActive():
            self._play_timer.stop()
        super().closeEvent(event)
614
+
615
+
616
+ class FrameSegmentPopupDialog(QDialog):
617
+ """Dialog showing a frame-aggregated segment with training and transition controls."""
618
+
619
+ def __init__(self, parent, widget, frame_idx, segment, segment_idx):
620
+ super().__init__(parent)
621
+ self._widget = widget
622
+ self._frame_idx = frame_idx
623
+ self._segment = segment
624
+ self._segment_idx = segment_idx
625
+
626
+ if not self._widget.video_path:
627
+ self.close()
628
+ return
629
+
630
+ if self._segment_idx is None:
631
+ for i, seg in enumerate(self._widget.aggregated_segments):
632
+ if seg['start'] == self._segment['start'] and seg['end'] == self._segment['end']:
633
+ self._segment_idx = i
634
+ break
635
+
636
+ pred_idx = self._segment['class']
637
+ conf = self._segment.get('confidence', 1.0)
638
+ start_frame = self._segment['start']
639
+ end_frame = self._segment['end']
640
+
641
+ if pred_idx >= len(self._widget.classes):
642
+ self.close()
643
+ return
644
+
645
+ label = self._widget.classes[pred_idx]
646
+
647
+ cap = cv2.VideoCapture(self._widget.video_path)
648
+ orig_fps = cap.get(cv2.CAP_PROP_FPS)
649
+ if orig_fps <= 0:
650
+ orig_fps = 30.0
651
+ cap.release()
652
+
653
+ start_time = start_frame / orig_fps
654
+ end_time = (end_frame + 1) / orig_fps
655
+ duration = end_time - start_time
656
+ clicked_time = self._frame_idx / orig_fps
657
+
658
+ self.setWindowTitle(f"Segment {(self._segment_idx + 1) if self._segment_idx is not None else '?'}/{len(self._widget.aggregated_segments)}: {label} (frames {start_frame}-{end_frame})")
659
+ self.setMinimumSize(700, 650)
660
+ self.setWindowFlags(
661
+ Qt.WindowType.Window
662
+ | Qt.WindowType.WindowMinMaxButtonsHint
663
+ | Qt.WindowType.WindowMaximizeButtonHint
664
+ | Qt.WindowType.WindowCloseButtonHint
665
+ )
666
+ self.setSizeGripEnabled(True)
667
+
668
+ layout = QVBoxLayout()
669
+
670
+ info_text = (
671
+ f"<b>Behavior:</b> {label}<br>"
672
+ f"<b>Aggregated Confidence:</b> {conf:.2f}<br>"
673
+ f"<b>Frame Range:</b> {start_frame} - {end_frame} ({end_frame - start_frame + 1} frames)<br>"
674
+ f"<b>Time Range:</b> {start_time:.2f}s - {end_time:.2f}s ({duration:.2f}s duration)<br>"
675
+ f"<b>Clicked Frame:</b> {self._frame_idx} ({clicked_time:.2f}s)"
676
+ )
677
+ info_label = QLabel(info_text)
678
+ info_label.setStyleSheet("font-size: 14px; padding: 10px;")
679
+ layout.addWidget(info_label)
680
+
681
+ training_group = QGroupBox("Add segment to training dataset")
682
+ training_layout = QVBoxLayout()
683
+ seg_label_row = QHBoxLayout()
684
+ seg_label_row.addWidget(QLabel("Training label:"))
685
+ seg_label_combo = QComboBox()
686
+ seg_label_combo.addItems(self._widget.classes)
687
+ if 0 <= pred_idx < len(self._widget.classes):
688
+ seg_label_combo.setCurrentIndex(pred_idx)
689
+ seg_label_row.addWidget(seg_label_combo)
690
+ training_layout.addLayout(seg_label_row)
691
+
692
+ add_segment_btn = QPushButton("Add segment chunks to training dataset")
693
+ add_segment_btn.setStyleSheet("background-color: #2196F3; color: white; font-weight: bold; padding: 5px;")
694
+ add_segment_btn.setToolTip(
695
+ "Creates consecutive clips over this segment.\n"
696
+ "Frames inside the segment are labeled; frames outside are set to None (ignored)."
697
+ )
698
+
699
def add_segment_chunks_to_training():
    """Slice the current segment into consecutive clips and add them to the training set.

    Each clip holds `clip_length` sampled frames. Frames that fall inside the
    segment get `selected_label`; trailing padded frames are labeled None so
    the frame-level loss ignores them.
    """
    try:
        selected_label = seg_label_combo.currentText().strip()
        if not selected_label:
            QMessageBox.warning(self, "Missing label", "Select a training label first.")
            return

        clips_dir = self._widget._get_clips_dir()
        annotation_manager = AnnotationManager(self._widget._get_annotation_file())
        annotation_manager.add_class(selected_label)

        clip_length = int(self._widget.clip_length_spin.value())
        if clip_length <= 0:
            QMessageBox.warning(self, "Invalid clip length", "Clip length must be > 0.")
            return

        # Stride (in raw video frames) between sampled frames.
        frame_interval = int(max(1, self._widget._get_saved_frame_interval(self._widget.video_path, orig_fps)))
        segment_sampled_frames = ((end_frame - start_frame) // frame_interval) + 1
        if segment_sampled_frames <= 0:
            QMessageBox.warning(self, "Empty segment", "Segment has no usable frames.")
            return

        # Ceiling division: the final chunk may be partial and is padded below.
        num_chunks = (segment_sampled_frames + clip_length - 1) // clip_length
        if num_chunks <= 0:
            QMessageBox.warning(self, "No chunks", "Could not create segment chunks.")
            return

        cap = cv2.VideoCapture(self._widget.video_path)
        if not cap.isOpened():
            QMessageBox.warning(self, "Error", "Could not open video file.")
            return

        target_fps = int(self._widget.target_fps_spin.value())
        video_basename = self._widget._video_basename()
        added_paths = []

        # FIX: previously an exception anywhere in this loop skipped
        # cap.release() (the outer `except` caught it), leaking the capture.
        try:
            for chunk_idx in range(num_chunks):
                chunk_start_vid_frame = int(start_frame + chunk_idx * clip_length * frame_interval)
                frames_in_segment = min(clip_length, segment_sampled_frames - chunk_idx * clip_length)

                if frames_in_segment <= 0:
                    continue

                cap.set(cv2.CAP_PROP_POS_FRAMES, chunk_start_vid_frame)
                frames = []
                frame_count = 0

                # Read raw frames, keeping every `frame_interval`-th one.
                while len(frames) < clip_length:
                    ret, frame = cap.read()
                    if not ret:
                        break
                    if frame_count % frame_interval == 0:
                        frames.append(frame.copy())
                    frame_count += 1

                # Pad short clips (end of video) by repeating the last frame.
                if frames and len(frames) < clip_length:
                    last_frame = frames[-1]
                    while len(frames) < clip_length:
                        frames.append(last_frame.copy())

                if not frames:
                    continue

                clip_filename = (
                    f"{video_basename}_seg_{start_frame}_{end_frame}_"
                    f"chunk_{chunk_idx:03d}_frame_{chunk_start_vid_frame}.mp4"
                )
                clip_path = os.path.join(clips_dir, clip_filename)
                clip_path = self._widget._unique_clip_path(clip_path)

                save_clip(frames, clip_path, target_fps)
                # Skip clips that failed to encode.
                if not os.path.exists(clip_path) or os.path.getsize(clip_path) == 0:
                    continue

                clip_id = self._widget._clip_path_to_id(clip_path, clips_dir)

                # Label only frames inside the segment; padded frames get None.
                frame_labels = [
                    selected_label if i < frames_in_segment else None
                    for i in range(clip_length)
                ]

                meta = {
                    "source_video": os.path.basename(self._widget.video_path),
                    "source_segment_start_frame": int(start_frame),
                    "source_segment_end_frame": int(end_frame),
                    "source_chunk_index": int(chunk_idx),
                    "source_frame": int(chunk_start_vid_frame),
                    "target_fps": int(target_fps),
                    "clip_length": int(clip_length),
                    "added_from_inference_segment": True,
                    "segment_label_frames": int(frames_in_segment),
                }

                # Defer saving so the annotation file is written once at the end.
                annotation_manager.add_clip(clip_id, selected_label, meta=meta, _defer_save=True)
                annotation_manager.set_frame_labels(clip_id, frame_labels, _defer_save=True)
                added_paths.append(clip_path)
        finally:
            # Always release the capture, even if extraction or saving raises.
            cap.release()

        if not added_paths:
            QMessageBox.warning(self, "Nothing added", "No segment clips were added.")
            return

        annotation_manager.save()
        self._widget.log_text.append(
            f"Added {len(added_paths)} segment chunk(s) to training dataset "
            f"for '{selected_label}' (segment {start_frame}-{end_frame})."
        )
        QMessageBox.information(
            self,
            "Segment added",
            f"Added {len(added_paths)} clip(s) to training dataset.\n\n"
            f"Label: {selected_label}\n"
            f"Segment: frames {start_frame}-{end_frame}\n\n"
            f"Each clip has frame labels only where behavior is inside this segment; "
            "other frames are set to None (ignored).",
        )
    except Exception as e:
        QMessageBox.critical(self, "Error", f"Failed to add segment chunks:\n{str(e)}")
820
+
821
add_segment_btn.clicked.connect(add_segment_chunks_to_training)
training_layout.addWidget(add_segment_btn)

# Controls for transition-clip extraction: clip length and boundary ignore radius.
transition_len_row = QHBoxLayout()
transition_len_row.addWidget(QLabel("Transition clip frames:"))
transition_len_spin = QSpinBox()
transition_len_spin.setRange(2, 64)
# Default to the project-wide clip length.
transition_len_spin.setValue(int(self._widget.clip_length_spin.value()))
transition_len_spin.setToolTip("Number of sampled frames to save for this transition clip.")
transition_len_row.addWidget(transition_len_spin)
transition_len_row.addWidget(QLabel("Ignore ±frames:"))
transition_ignore_spin = QSpinBox()
transition_ignore_spin.setRange(0, 8)
transition_ignore_spin.setValue(1)
transition_ignore_spin.setToolTip(
    "Frames around the exact boundary set to None (ignored during frame loss)."
)
transition_len_row.addWidget(transition_ignore_spin)
training_layout.addLayout(transition_len_row)
840
+
841
+ def _safe_label_token(label_text: str) -> str:
842
+ token = str(label_text or "").strip().replace(" ", "_").replace("/", "_").replace("\\", "_")
843
+ while "__" in token:
844
+ token = token.replace("__", "_")
845
+ return token.strip("_") or "class"
846
+
847
def _extract_transition_clip(left_seg: dict, right_seg: dict, boundary_name: str):
    """Save one fixed-length training clip centered on a segment boundary.

    Sampled frames before the boundary get the left segment's label, frames
    after it the right segment's; frames within ±`ignore_half` sampled frames
    of the boundary are labeled None (ignored by the frame loss).
    """
    try:
        left_idx = int(left_seg.get("class", -1))
        right_idx = int(right_seg.get("class", -1))
        if not (0 <= left_idx < len(self._widget.classes) and 0 <= right_idx < len(self._widget.classes)):
            QMessageBox.warning(self, "Invalid class", "Could not resolve neighboring segment labels.")
            return

        left_label = self._widget.classes[left_idx]
        right_label = self._widget.classes[right_idx]
        clip_len = int(transition_len_spin.value())
        if clip_len <= 1:
            QMessageBox.warning(self, "Invalid clip length", "Transition clip length must be >= 2.")
            return
        ignore_half = int(max(0, transition_ignore_spin.value()))

        clips_dir = self._widget._get_clips_dir()
        annotation_manager = AnnotationManager(self._widget._get_annotation_file())
        annotation_manager.add_class(left_label)
        annotation_manager.add_class(right_label)

        frame_interval = int(max(1, self._widget._get_saved_frame_interval(self._widget.video_path, orig_fps)))
        # Boundary = midpoint between left segment end and right segment start.
        boundary_frame = int((int(left_seg["end"]) + int(right_seg["start"])) // 2)
        center_idx = clip_len // 2
        clip_start_vid_frame = max(0, boundary_frame - center_idx * frame_interval)

        cap_local = cv2.VideoCapture(self._widget.video_path)
        if not cap_local.isOpened():
            QMessageBox.warning(self, "Error", "Could not open video file.")
            return
        # FIX: previously an exception while reading skipped cap_local.release()
        # (caught by the outer `except`), leaking the capture handle.
        try:
            cap_local.set(cv2.CAP_PROP_POS_FRAMES, clip_start_vid_frame)
            frames = []
            sampled_video_frames = []
            read_ctr = 0
            while len(frames) < clip_len:
                ret, frame = cap_local.read()
                if not ret:
                    break
                if read_ctr % frame_interval == 0:
                    frames.append(frame.copy())
                    sampled_video_frames.append(int(clip_start_vid_frame + read_ctr))
                read_ctr += 1
        finally:
            cap_local.release()

        if not frames:
            QMessageBox.warning(self, "No frames", "Could not extract transition clip frames.")
            return

        # Pad short reads (end of video) by repeating the last frame and
        # extrapolating its sampled frame index.
        if len(frames) < clip_len:
            last_frame = frames[-1]
            last_idx = sampled_video_frames[-1]
            while len(frames) < clip_len:
                frames.append(last_frame.copy())
                last_idx += frame_interval
                sampled_video_frames.append(int(last_idx))

        # Frames inside the ±ignore_half band around the boundary get None.
        left_end = int(left_seg["end"]) - ignore_half * frame_interval
        right_start = int(right_seg["start"]) + ignore_half * frame_interval
        frame_labels = []
        for vf in sampled_video_frames:
            if vf <= left_end:
                frame_labels.append(left_label)
            elif vf >= right_start:
                frame_labels.append(right_label)
            else:
                frame_labels.append(None)

        # Guarantee both classes appear at least once in the clip.
        if left_label not in frame_labels:
            frame_labels[0] = left_label
        if right_label not in frame_labels:
            frame_labels[-1] = right_label

        # Clip-level label = whichever side dominates the frame labels.
        non_none = [x for x in frame_labels if x is not None]
        primary_label = left_label
        if non_none:
            left_count = sum(1 for x in non_none if x == left_label)
            right_count = sum(1 for x in non_none if x == right_label)
            primary_label = left_label if left_count >= right_count else right_label

        video_basename = self._widget._video_basename()
        left_tok = _safe_label_token(left_label)
        right_tok = _safe_label_token(right_label)
        clip_filename = (
            f"{video_basename}_transition_{left_tok}_to_{right_tok}_"
            f"frame_{clip_start_vid_frame}_len_{clip_len}.mp4"
        )
        clip_path = os.path.join(clips_dir, clip_filename)
        clip_path = self._widget._unique_clip_path(clip_path)

        target_fps = int(self._widget.target_fps_spin.value())
        save_clip(frames, clip_path, target_fps)
        if not os.path.exists(clip_path) or os.path.getsize(clip_path) == 0:
            QMessageBox.warning(self, "Save failed", "Failed to save transition clip.")
            return

        clip_id = self._widget._clip_path_to_id(clip_path, clips_dir)
        meta = {
            "source_video": os.path.basename(self._widget.video_path),
            "source_frame": int(clip_start_vid_frame),
            "target_fps": int(target_fps),
            "clip_length": int(clip_len),
            "added_from_inference_transition": True,
            "transition_direction": boundary_name,
            "transition_from_label": left_label,
            "transition_to_label": right_label,
            "transition_boundary_frame": int(boundary_frame),
            "transition_ignore_half_frames": int(ignore_half),
        }
        # add_clip may rename on collision; use the id it actually stored.
        used_clip_id = annotation_manager.add_clip(clip_id, primary_label, meta=meta)
        annotation_manager.set_frame_labels(used_clip_id, frame_labels)

        n_left = sum(1 for x in frame_labels if x == left_label)
        n_right = sum(1 for x in frame_labels if x == right_label)
        n_ignored = sum(1 for x in frame_labels if x is None)
        self._widget.log_text.append(
            f"Added transition clip ({boundary_name}): {left_label}->{right_label}, "
            f"frames={clip_len}, labels=({n_left}/{n_right}/ignored={n_ignored})"
        )
        QMessageBox.information(
            self,
            "Transition clip added",
            f"Saved transition training clip.\n\n"
            f"From: {left_label}\n"
            f"To: {right_label}\n"
            f"Direction: {boundary_name}\n"
            f"Clip: {os.path.basename(clip_path)}\n"
            f"Frame labels: {n_left} left, {n_right} right, {n_ignored} ignored",
        )
    except Exception as e:
        QMessageBox.critical(self, "Error", f"Failed to add transition clip:\n{str(e)}")
977
+
978
# Buttons to extract a transition clip at either boundary of this segment.
transition_btn_row = QHBoxLayout()
prev_transition_btn = QPushButton("Add prev -> current transition clip")
next_transition_btn = QPushButton("Add current -> next transition clip")
prev_transition_btn.setStyleSheet("background-color: #4b7bec; color: white; font-weight: bold; padding: 5px;")
next_transition_btn.setStyleSheet("background-color: #4b7bec; color: white; font-weight: bold; padding: 5px;")

# Enable each button only when the corresponding neighbor segment exists.
has_prev_seg = self._segment_idx is not None and self._segment_idx > 0
has_next_seg = self._segment_idx is not None and self._segment_idx < (len(self._widget.aggregated_segments) - 1)
prev_transition_btn.setEnabled(has_prev_seg)
next_transition_btn.setEnabled(has_next_seg)
prev_transition_btn.setToolTip("Create one fixed-length transition clip around the previous->current boundary.")
next_transition_btn.setToolTip("Create one fixed-length transition clip around the current->next boundary.")
990
+
991
def _add_prev_transition():
    """Extract the transition clip at the previous->current boundary."""
    if not has_prev_seg:
        return
    segments = self._widget.aggregated_segments
    _extract_transition_clip(
        segments[self._segment_idx - 1],
        segments[self._segment_idx],
        "prev_to_current",
    )
997
+
998
def _add_next_transition():
    """Extract the transition clip at the current->next boundary."""
    if not has_next_seg:
        return
    segments = self._widget.aggregated_segments
    _extract_transition_clip(
        segments[self._segment_idx],
        segments[self._segment_idx + 1],
        "current_to_next",
    )
1004
+
1005
prev_transition_btn.clicked.connect(_add_prev_transition)
next_transition_btn.clicked.connect(_add_next_transition)
transition_btn_row.addWidget(prev_transition_btn)
transition_btn_row.addWidget(next_transition_btn)
training_layout.addLayout(transition_btn_row)

training_group.setLayout(training_layout)
layout.addWidget(training_group)

# --- Segment video preview ---
video_label = QLabel("Loading segment...")
video_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
video_label.setMinimumSize(640, 360)
video_label.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
layout.addWidget(video_label)

# --- Playback / navigation controls ---
controls_layout = QHBoxLayout()

prev_seg_btn = QPushButton("< Prev Segment")
prev_seg_btn.setToolTip("Go to previous behavior segment")
has_prev = self._segment_idx is not None and self._segment_idx > 0
prev_seg_btn.setEnabled(has_prev)
1026
+
1027
def go_prev_segment():
    """Close this popup and reopen it on the previous segment's midpoint."""
    if self._segment_idx is None or self._segment_idx <= 0:
        return
    self.close()
    seg = self._widget.aggregated_segments[self._segment_idx - 1]
    midpoint = (seg['start'] + seg['end']) // 2
    self._widget._show_frame_segment_popup(midpoint, seg, self._segment_idx - 1)
1033
+
1034
prev_seg_btn.clicked.connect(go_prev_segment)
controls_layout.addWidget(prev_seg_btn)

# Frame selector spans exactly this segment's frame range.
frame_slider = QSpinBox()
frame_slider.setRange(start_frame, end_frame)
frame_slider.setValue(self._frame_idx)
controls_layout.addWidget(QLabel("Frame:"))
controls_layout.addWidget(frame_slider)

play_pause_btn = QPushButton("Play")
# One-element lists give the nested closures below mutable playback state.
is_playing = [False]
current_frame_idx = [self._frame_idx]
video_frames = {}  # cache: video frame index -> decoded RGB frame
self._play_timer = QTimer()
1048
+
1049
def load_frame(f_idx):
    """Return the RGB frame at `f_idx`, reading and caching it on first use."""
    cached = video_frames.get(f_idx)
    if cached is not None:
        return cached

    # Open a fresh capture per miss; cheap enough for single-frame seeks.
    capture = cv2.VideoCapture(self._widget.video_path)
    capture.set(cv2.CAP_PROP_POS_FRAMES, f_idx)
    ok, bgr = capture.read()
    capture.release()

    if not ok:
        return None
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    video_frames[f_idx] = rgb
    return rgb
1063
+
1064
def update_frame():
    # Render the frame at the current playback cursor into the preview label.
    frame = load_frame(current_frame_idx[0])
    if frame is not None:
        h, w, c = frame.shape
        # NOTE(review): QImage wraps frame.data without copying; the array must
        # outlive the image — the cache in load_frame keeps it alive. The
        # bytesPerLine value w * c assumes a contiguous array — confirm.
        q_img = QImage(frame.data, w, h, w * c, QImage.Format.Format_RGB888)
        pixmap = QPixmap.fromImage(q_img)
        # Scale to the label while preserving aspect ratio.
        scaled_pixmap = pixmap.scaled(
            video_label.size(),
            Qt.AspectRatioMode.KeepAspectRatio,
            Qt.TransformationMode.SmoothTransformation
        )
        video_label.setPixmap(scaled_pixmap)
1076
+
1077
def on_frame_changed(val):
    # Spinbox edits (and programmatic setValue) move the playback cursor.
    current_frame_idx[0] = val
    update_frame()
1080
+
1081
+ frame_slider.valueChanged.connect(on_frame_changed)
1082
+
1083
def toggle_play():
    """Start or stop looping playback at the source video's frame rate."""
    if not is_playing[0]:
        # Timer interval in ms derived from the original fps.
        self._play_timer.start(int(1000 / orig_fps))
        play_pause_btn.setText("Pause")
        is_playing[0] = True
    else:
        self._play_timer.stop()
        play_pause_btn.setText("Play")
        is_playing[0] = False
1092
+
1093
def advance_frame():
    """Step playback one frame forward, wrapping to the segment start."""
    nxt = current_frame_idx[0] + 1
    if nxt > end_frame:
        nxt = start_frame
    current_frame_idx[0] = nxt
    frame_slider.setValue(nxt)
    update_frame()
1099
+
1100
self._play_timer.timeout.connect(advance_frame)
play_pause_btn.clicked.connect(toggle_play)
controls_layout.addWidget(play_pause_btn)

next_seg_btn = QPushButton("Next Segment >")
next_seg_btn.setToolTip("Go to next behavior segment")
# Enabled only when a following segment exists.
has_next = self._segment_idx is not None and self._segment_idx < len(self._widget.aggregated_segments) - 1
next_seg_btn.setEnabled(has_next)
1108
+
1109
def go_next_segment():
    """Close this popup and reopen it on the next segment's midpoint."""
    last_idx = len(self._widget.aggregated_segments) - 1
    if self._segment_idx is None or self._segment_idx >= last_idx:
        return
    self.close()
    seg = self._widget.aggregated_segments[self._segment_idx + 1]
    midpoint = (seg['start'] + seg['end']) // 2
    self._widget._show_frame_segment_popup(midpoint, seg, self._segment_idx + 1)
1115
+
1116
next_seg_btn.clicked.connect(go_next_segment)
controls_layout.addWidget(next_seg_btn)

controls_layout.addStretch()

close_btn = QPushButton("Close")
close_btn.clicked.connect(self.close)
controls_layout.addWidget(close_btn)

layout.addLayout(controls_layout)
self.setLayout(layout)

# Defer the first render until the dialog is laid out and sized.
QTimer.singleShot(100, update_frame)

# Modal loop; blocks here until the dialog closes.
self.exec()

# Safety net: make sure playback is stopped once the dialog returns.
if self._play_timer.isActive():
    self._play_timer.stop()
1134
+
1135
def closeEvent(self, event):
    """Stop the playback timer (when present and running) before closing."""
    timer = getattr(self, '_play_timer', None)
    if timer is not None and timer.isActive():
        timer.stop()
    super().closeEvent(event)