singlebehaviorlab 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88):
  1. sam2/__init__.py +11 -0
  2. sam2/automatic_mask_generator.py +454 -0
  3. sam2/benchmark.py +92 -0
  4. sam2/build_sam.py +174 -0
  5. sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
  6. sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
  7. sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
  8. sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
  9. sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
  10. sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
  11. sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
  12. sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
  13. sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
  14. sam2/modeling/__init__.py +5 -0
  15. sam2/modeling/backbones/__init__.py +5 -0
  16. sam2/modeling/backbones/hieradet.py +317 -0
  17. sam2/modeling/backbones/image_encoder.py +134 -0
  18. sam2/modeling/backbones/utils.py +93 -0
  19. sam2/modeling/memory_attention.py +169 -0
  20. sam2/modeling/memory_encoder.py +181 -0
  21. sam2/modeling/position_encoding.py +239 -0
  22. sam2/modeling/sam/__init__.py +5 -0
  23. sam2/modeling/sam/mask_decoder.py +295 -0
  24. sam2/modeling/sam/prompt_encoder.py +202 -0
  25. sam2/modeling/sam/transformer.py +311 -0
  26. sam2/modeling/sam2_base.py +913 -0
  27. sam2/modeling/sam2_utils.py +323 -0
  28. sam2/sam2_hiera_b+.yaml +113 -0
  29. sam2/sam2_hiera_l.yaml +117 -0
  30. sam2/sam2_hiera_s.yaml +116 -0
  31. sam2/sam2_hiera_t.yaml +118 -0
  32. sam2/sam2_image_predictor.py +466 -0
  33. sam2/sam2_video_predictor.py +1388 -0
  34. sam2/sam2_video_predictor_legacy.py +1172 -0
  35. sam2/utils/__init__.py +5 -0
  36. sam2/utils/amg.py +348 -0
  37. sam2/utils/misc.py +349 -0
  38. sam2/utils/transforms.py +118 -0
  39. singlebehaviorlab/__init__.py +4 -0
  40. singlebehaviorlab/__main__.py +130 -0
  41. singlebehaviorlab/_paths.py +100 -0
  42. singlebehaviorlab/backend/__init__.py +2 -0
  43. singlebehaviorlab/backend/augmentations.py +320 -0
  44. singlebehaviorlab/backend/data_store.py +420 -0
  45. singlebehaviorlab/backend/model.py +1290 -0
  46. singlebehaviorlab/backend/train.py +4667 -0
  47. singlebehaviorlab/backend/uncertainty.py +578 -0
  48. singlebehaviorlab/backend/video_processor.py +688 -0
  49. singlebehaviorlab/backend/video_utils.py +139 -0
  50. singlebehaviorlab/data/config/config.yaml +85 -0
  51. singlebehaviorlab/data/training_profiles.json +334 -0
  52. singlebehaviorlab/gui/__init__.py +4 -0
  53. singlebehaviorlab/gui/analysis_widget.py +2291 -0
  54. singlebehaviorlab/gui/attention_export.py +311 -0
  55. singlebehaviorlab/gui/clip_extraction_widget.py +481 -0
  56. singlebehaviorlab/gui/clustering_widget.py +3187 -0
  57. singlebehaviorlab/gui/inference_popups.py +1138 -0
  58. singlebehaviorlab/gui/inference_widget.py +4550 -0
  59. singlebehaviorlab/gui/inference_worker.py +651 -0
  60. singlebehaviorlab/gui/labeling_widget.py +2324 -0
  61. singlebehaviorlab/gui/main_window.py +754 -0
  62. singlebehaviorlab/gui/metadata_management_widget.py +1119 -0
  63. singlebehaviorlab/gui/motion_tracking.py +764 -0
  64. singlebehaviorlab/gui/overlay_export.py +1234 -0
  65. singlebehaviorlab/gui/plot_integration.py +729 -0
  66. singlebehaviorlab/gui/qt_helpers.py +29 -0
  67. singlebehaviorlab/gui/registration_widget.py +1485 -0
  68. singlebehaviorlab/gui/review_widget.py +1330 -0
  69. singlebehaviorlab/gui/segmentation_tracking_widget.py +2752 -0
  70. singlebehaviorlab/gui/tab_tutorial_dialog.py +312 -0
  71. singlebehaviorlab/gui/timeline_themes.py +131 -0
  72. singlebehaviorlab/gui/training_profiles.py +418 -0
  73. singlebehaviorlab/gui/training_widget.py +3719 -0
  74. singlebehaviorlab/gui/video_utils.py +233 -0
  75. singlebehaviorlab/licenses/SAM2-LICENSE +201 -0
  76. singlebehaviorlab/licenses/VideoPrism-LICENSE +202 -0
  77. singlebehaviorlab-2.0.0.dist-info/METADATA +447 -0
  78. singlebehaviorlab-2.0.0.dist-info/RECORD +88 -0
  79. singlebehaviorlab-2.0.0.dist-info/WHEEL +5 -0
  80. singlebehaviorlab-2.0.0.dist-info/entry_points.txt +2 -0
  81. singlebehaviorlab-2.0.0.dist-info/licenses/LICENSE +21 -0
  82. singlebehaviorlab-2.0.0.dist-info/top_level.txt +3 -0
  83. videoprism/__init__.py +0 -0
  84. videoprism/encoders.py +910 -0
  85. videoprism/layers.py +1136 -0
  86. videoprism/models.py +407 -0
  87. videoprism/tokenizers.py +167 -0
  88. videoprism/utils.py +168 -0
@@ -0,0 +1,754 @@
1
+ import logging
2
+ import sys
3
+ from PyQt6.QtWidgets import (
4
+ QMainWindow, QTabWidget, QMenuBar, QMenu, QMessageBox, QFileDialog, QInputDialog,
5
+ QDialog, QVBoxLayout, QHBoxLayout, QRadioButton, QButtonGroup, QDialogButtonBox,
6
+ QLabel, QGroupBox, QWidget, QPushButton,
7
+ )
8
+ from PyQt6.QtCore import Qt
9
+ from PyQt6.QtGui import QKeySequence, QAction
10
+ from .segmentation_tracking_widget import SegmentationTrackingWidget
11
+ from .registration_widget import RegistrationWidget
12
+ from .clustering_widget import ClusteringWidget
13
+ from .labeling_widget import LabelingWidget
14
+ from .training_widget import TrainingWidget
15
+ from .inference_widget import InferenceWidget
16
+ from .analysis_widget import AnalysisWidget
17
+ from .review_widget import ReviewWidget
18
+ from .tab_tutorial_dialog import show_tab_tutorial
19
+ import os
20
+ import json
21
+ import yaml
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
class LabelingSetupDialog(QDialog):
    """Modal dialog asking where the labeling list should come from.

    After the dialog is accepted, ``result_data`` holds a dict of the form
    ``{"mode": ..., "submode": ...}``:

    * ``mode`` is ``"clustering"``, ``"raw_extraction"`` or ``"continue"``.
    * ``submode`` is ``"segmented"`` or ``"raw"`` for the clustering mode and
      ``None`` otherwise.

    ``result_data`` stays ``None`` if the dialog is cancelled.
    """

    def __init__(self, parent=None):
        super().__init__(parent)
        self.setWindowTitle("Labeling setup")
        self.setMinimumWidth(500)
        self.result_data = None

        root_layout = QVBoxLayout(self)
        root_layout.addWidget(
            QLabel("How would you like to populate the labeling list?")
        )

        # Top-level choice 1: seed the list from clustering results.
        self.rb_clustering = QRadioButton("Use representative clips from clustering")
        self.rb_clustering.setChecked(True)
        root_layout.addWidget(self.rb_clustering)

        # Sub-options that only apply to the clustering choice.
        self.clustering_group = QGroupBox("Clustering options")
        self.clustering_group_layout = QVBoxLayout()

        self.rb_segmented = QRadioButton("Use existing segmented clips (ROIs)")
        self.rb_segmented.setChecked(True)
        self.clustering_group_layout.addWidget(self.rb_segmented)

        self.rb_raw = QRadioButton(
            "Extract raw clips from original video (Full frame/No mask)"
        )
        self.clustering_group_layout.addWidget(self.rb_raw)

        self.clustering_group.setLayout(self.clustering_group_layout)
        root_layout.addWidget(self.clustering_group)

        # Top-level choices 2 and 3.
        self.rb_no_clustering = QRadioButton(
            "Annotate raw videos on timeline (integrated in Labeling)"
        )
        root_layout.addWidget(self.rb_no_clustering)

        self.rb_continue = QRadioButton("Continue with existing/manual list")
        root_layout.addWidget(self.rb_continue)

        # The three top-level radio buttons form one button group.
        self.bg_main = QButtonGroup(self)
        for top_level_button in (
            self.rb_clustering,
            self.rb_no_clustering,
            self.rb_continue,
        ):
            self.bg_main.addButton(top_level_button)

        # Sub-options are only usable while the clustering choice is selected.
        self.rb_clustering.toggled.connect(self.clustering_group.setEnabled)
        self.clustering_group.setEnabled(True)

        ok_cancel = (
            QDialogButtonBox.StandardButton.Ok
            | QDialogButtonBox.StandardButton.Cancel
        )
        button_box = QDialogButtonBox(ok_cancel)
        button_box.accepted.connect(self._on_accept)
        button_box.rejected.connect(self.reject)
        root_layout.addWidget(button_box)

    def _on_accept(self):
        """Store the selected mode/submode in ``result_data`` and accept."""
        submode = None
        if self.rb_clustering.isChecked():
            mode = "clustering"
            if self.rb_segmented.isChecked():
                submode = "segmented"
            else:
                submode = "raw"
        elif self.rb_no_clustering.isChecked():
            mode = "raw_extraction"
        else:
            mode = "continue"

        self.result_data = {"mode": mode, "submode": submode}
        self.accept()
86
+
87
+
88
class MainWindow(QMainWindow):
    """Main application window.

    Hosts one tab per workflow stage (labeling, training, sequencing, refine,
    analysis, segmentation tracking, registration, clustering), wires the
    signals that hand data from one stage to the next, and owns the
    experiment create / load / save actions. All widgets receive the same
    ``config`` dict; it is mutated in place by :meth:`_update_config`.
    """

    def __init__(self, config: dict):
        """Build the window, its menu bar and its tabs from ``config``."""
        super().__init__()
        self.config = config
        # Set once the labeling-setup dialog has been shown for the current
        # clustering data; reset when new embeddings or experiments arrive.
        self._labeling_setup_prompted_for_clusters = False
        self._update_window_title()
        self.setGeometry(100, 100, 1400, 900)
        self.setMinimumSize(1200, 700)

        self._setup_menu()
        self._setup_tabs()

        self.setCentralWidget(self.tabs)

    def _setup_menu(self):
        """Setup menu bar."""
        menubar = self.menuBar()

        file_menu = menubar.addMenu("File")
        open_action = QAction("Open Video...", self)
        open_action.setShortcut(QKeySequence("Ctrl+O"))
        open_action.triggered.connect(self._open_video)
        file_menu.addAction(open_action)
        file_menu.addSeparator()

        create_exp_action = QAction("Create Experiment...", self)
        create_exp_action.triggered.connect(self._create_experiment)
        file_menu.addAction(create_exp_action)

        load_exp_action = QAction("Load Experiment...", self)
        load_exp_action.triggered.connect(self._load_experiment)
        file_menu.addAction(load_exp_action)

        save_exp_action = QAction("Save Experiment", self)
        save_exp_action.triggered.connect(self._save_experiment)
        file_menu.addAction(save_exp_action)

        file_menu.addSeparator()

        exit_action = QAction("Exit", self)
        exit_action.setShortcut(QKeySequence("Ctrl+Q"))
        exit_action.triggered.connect(self.close)
        file_menu.addAction(exit_action)

    def _setup_tabs(self):
        """Setup tab widget with all tabs."""
        self.tabs = QTabWidget()

        self.segmentation_widget = SegmentationTrackingWidget(self.config)
        self.registration_widget = RegistrationWidget(self.config)
        self.clustering_widget = ClusteringWidget(self.config)
        self.labeling_widget = LabelingWidget(self.config)
        self.training_widget = TrainingWidget(self.config)
        self.inference_widget = InferenceWidget(self.config)
        self.analysis_widget = AnalysisWidget(self.config)
        self.review_widget = ReviewWidget(self.config)

        # (widget, tutorial id, tab title) in display order.
        tab_specs = [
            (self.labeling_widget, "labeling", "Labeling"),
            (self.training_widget, "training", "Training Sequencing Model"),
            (self.inference_widget, "sequencing", "Sequencing"),
            (self.review_widget, "refine", "Refine"),
            (self.analysis_widget, "analysis", "Downstream Analysis"),
            (self.segmentation_widget, "segmentation", "Segmentation Tracking"),
            (self.registration_widget, "registration", "Registration"),
            (self.clustering_widget, "clustering", "Clustering"),
        ]
        for widget, tab_id, title in tab_specs:
            self.tabs.addTab(self._wrap_tab_with_help(widget, tab_id), title)

        self.segmentation_widget.tracking_completed.connect(self._on_tracking_completed)
        self.registration_widget.embeddings_extracted.connect(self._on_embeddings_extracted)
        self.inference_widget.review_ready.connect(self._on_review_ready)
        self.review_widget.annotations_updated.connect(self._on_annotations_updated)

        self.tabs.currentChanged.connect(self._on_tab_changed)

    def _wrap_tab_with_help(self, inner: QWidget, tab_id: str) -> QWidget:
        """Return a container showing a "Tab Guide" button above ``inner``.

        The returned container, not ``inner``, is what gets added to the tab
        widget as a page.
        """
        outer = QWidget()
        layout = QVBoxLayout(outer)
        layout.setContentsMargins(4, 4, 4, 0)
        layout.setSpacing(2)
        top = QHBoxLayout()
        top.addStretch(1)
        help_btn = QPushButton("📖 Tab Guide")
        help_btn.setToolTip("Open a detailed guide for this tab and what to do next (NOR example)")
        help_btn.setStyleSheet(
            "QPushButton {"
            "padding: 6px 12px;"
            "font-weight: 600;"
            "border: 1px solid #8a7b00;"
            "border-radius: 8px;"
            "background-color: #fff3b0;"
            "color: #3d3200;"
            "}"
            "QPushButton:hover {"
            "background-color: #ffe680;"
            "}"
            "QPushButton:pressed {"
            "background-color: #f7d154;"
            "}"
        )
        # Bind tab_id as a default so each button keeps its own id.
        help_btn.clicked.connect(lambda _=False, tid=tab_id: show_tab_tutorial(self, tid))
        top.addWidget(help_btn)
        layout.addLayout(top)
        layout.addWidget(inner, 1)
        return outer

    def _show_tab(self, widget: QWidget) -> None:
        """Make the tab that hosts ``widget`` the current tab.

        Stage widgets are wrapped by :meth:`_wrap_tab_with_help`, so they are
        not themselves pages of ``self.tabs``. Walk up the parent chain until
        an actual page is found; do nothing if ``widget`` is in no tab.
        """
        page = widget
        while page is not None and self.tabs.indexOf(page) < 0:
            page = page.parentWidget()
        if page is not None:
            self.tabs.setCurrentWidget(page)

    def _on_tab_changed(self, index: int):
        """Handle tab change."""
        # Restore the window size afterwards so a tab switch cannot resize it.
        current_size = self.size()
        tab_name = self.tabs.tabText(index)

        if tab_name == "Labeling":
            clusters_available = self.clustering_widget.clusters is not None
            if clusters_available and not self._labeling_setup_prompted_for_clusters:
                self._handle_labeling_setup()
                self._labeling_setup_prompted_for_clusters = True
            self.labeling_widget.refresh_clip_list()
        elif tab_name == "Training Sequencing Model":
            self.training_widget._load_current_config()
            self.training_widget.refresh_annotation_info()

        self.resize(current_size)

    def _handle_labeling_setup(self):
        """Show labeling setup dialog and handle user choice."""
        dialog = LabelingSetupDialog(self)
        if not dialog.exec():
            return

        choice = dialog.result_data
        if not choice:
            return

        mode = choice.get("mode")
        submode = choice.get("submode")

        if mode == "clustering":
            self._prepare_clustering_labeling_data(submode)
        elif mode == "raw_extraction":
            self._show_tab(self.labeling_widget)
            self.labeling_widget.open_timeline_import_dialog()
            QMessageBox.information(
                self,
                "Timeline labeling",
                "Use the Timeline Annotation section in Labeling to add raw videos, mark intervals, and generate clips.",
            )

    def _prepare_clustering_labeling_data(self, submode):
        """Prepare labeling data from clustering results.

        For every representative snippet reported by the clustering widget,
        either copy its existing segmented clip (``submode == "segmented"``)
        or cut a full-frame clip from the source video
        (``submode == "raw"``), then register the clips with the labeling
        widget's annotation manager using the cluster label as class.
        Problems are reported to the user via message boxes.
        """
        if self.clustering_widget.clusters is None:
            QMessageBox.warning(self, "No clusters", "No clustering data available. Please perform clustering first.")
            return

        try:
            rep_snippets = self.clustering_widget.get_representative_snippets(n_samples=10)
            if not rep_snippets:
                QMessageBox.warning(self, "No snippets", "Could not identify representative snippets.")
                return

            clips_to_add = []
            missing_clips = []

            self.clustering_widget._build_snippet_to_clip_map()
            snippet_map = self.clustering_widget.snippet_to_clip_map

            experiment_path = self.config.get("experiment_path")
            if not experiment_path:
                QMessageBox.warning(self, "Error", "No experiment loaded.")
                return

            labeling_clips_dir = os.path.join(experiment_path, "data", "clips")
            os.makedirs(labeling_clips_dir, exist_ok=True)

            import shutil

            for cluster_label, snippet_ids in rep_snippets.items():
                for snip in snippet_ids:
                    if submode == "segmented":
                        clip_path = snippet_map.get(snip)
                        if not (clip_path and os.path.exists(clip_path)):
                            missing_clips.append(snip)
                            continue

                        seg_dir = os.path.join(labeling_clips_dir, "segmented_clips")
                        os.makedirs(seg_dir, exist_ok=True)
                        target_path = os.path.join(seg_dir, os.path.basename(clip_path))

                        # Reuse a previously copied clip instead of overwriting it.
                        if not os.path.exists(target_path):
                            try:
                                shutil.copy2(clip_path, target_path)
                            except Exception as e:
                                logger.error("Error copying clip %s: %s", clip_path, e)
                                continue

                        clips_to_add.append({
                            "path": target_path,
                            "label": cluster_label,
                            "snippet_id": snip,
                        })

                    elif submode == "raw":
                        raw_clip_path = self._extract_raw_clip_from_snippet(snip, cluster_label)
                        if raw_clip_path and os.path.exists(raw_clip_path):
                            clips_to_add.append({
                                "path": raw_clip_path,
                                "label": cluster_label,
                                "snippet_id": snip,
                            })
                        else:
                            missing_clips.append(snip)

            if submode == "raw" and not clips_to_add:
                QMessageBox.warning(self, "No clips",
                    "Could not extract raw clips. Make sure video files and mask data are available.")
                return

            if missing_clips:
                logger.warning("Could not find clip files for %d snippets.", len(missing_clips))

            if not clips_to_add:
                QMessageBox.warning(self, "No clips", "No valid clips found for labeling data.")
                return

            self.labeling_widget.clip_base_dir = labeling_clips_dir
            self.labeling_widget.config["clips_dir"] = labeling_clips_dir

            added_count = 0
            am = self.labeling_widget.annotation_manager
            existing_ids = {c.get("id") for c in am.get_all_clips()}

            for clip_data in clips_to_add:
                abs_path = clip_data["path"].replace('\\', '/')
                try:
                    rel_path = os.path.relpath(abs_path, labeling_clips_dir).replace('\\', '/')
                except ValueError:
                    # relpath fails e.g. across Windows drives; fall back to the file name.
                    rel_path = os.path.basename(abs_path)

                if rel_path in existing_ids:
                    continue

                am.add_clip(rel_path, clip_data["label"], meta={"snippet_id": clip_data["snippet_id"]})
                am.add_class(clip_data["label"])
                # Remember the id so a path repeated within this run is not added twice.
                existing_ids.add(rel_path)
                added_count += 1

            if added_count > 0:
                annotation_path = self.labeling_widget.annotation_manager.annotation_file
                QMessageBox.information(self, "Data prepared",
                    f"Added {added_count} representative clips from {len(rep_snippets)} clusters to the labeling list.\n\n"
                    f"Annotations saved to:\n{annotation_path}\n\n"
                    "You can now review and refine labels.")
                self.labeling_widget._update_class_combo()
                self.labeling_widget.refresh_clip_list()
            else:
                QMessageBox.information(self, "Data ready", "All representative clips are already in the labeling list.")

        except Exception as e:
            logger.error("Failed to prepare labeling data: %s", e, exc_info=True)
            QMessageBox.critical(self, "Error", f"Failed to prepare labeling data: {str(e)}")

    @staticmethod
    def _clean_name(value) -> str:
        """Return ``value`` as a stripped string, or ``""`` if it is unusable.

        ``None``, NaN and other falsy values yield ``""`` so they never become
        video-name candidates (an empty candidate would substring-match every
        video file).
        """
        if value is None:
            return ""
        if isinstance(value, float) and value != value:  # NaN is the only float != itself
            return ""
        if not value:
            return ""
        return str(value).strip()

    def _extract_raw_clip_from_snippet(self, snippet_id: str, label: str) -> str:
        """Extract a raw (unmasked) clip from the original video for a snippet.

        The snippet's frame range and video identity are looked up in the
        clustering widget's metadata; the source video is searched for under
        the experiment directory by name. Frames ``start_frame`` to
        ``end_frame`` (inclusive) are written full-frame to
        ``<experiment>/data/clips/raw_clips``.

        Returns the path to the extracted clip, or None if extraction failed
        (missing metadata, no matching video, unreadable video, or no frame
        could be written).
        """
        # Imported unconditionally: ``re`` is needed both for candidate
        # parsing and for building the output file name.
        import re
        import cv2

        try:
            metadata = self.clustering_widget.metadata
            if metadata is None:
                return None

            if 'snippet' in metadata.columns:
                snippet_col = 'snippet'
            elif 'span_id' in metadata.columns:
                snippet_col = 'span_id'
            else:
                return None

            snippet_row = metadata[metadata[snippet_col].astype(str) == str(snippet_id)]
            if len(snippet_row) == 0:
                return None

            row = snippet_row.iloc[0]
            try:
                start_frame = int(float(row.get('start_frame')))
                end_frame = int(float(row.get('end_frame')))
            except (ValueError, TypeError):
                return None

            experiment_path = self.config.get("experiment_path")
            if not experiment_path:
                return None

            # Candidate base names for the source video, most specific first.
            video_name_candidates = []

            group_name = self._clean_name(row.get('group', ''))
            if group_name:
                video_name_candidates.append(group_name)

            video_id_text = self._clean_name(row.get('video_id', ''))
            if video_id_text:
                vid_base = os.path.splitext(os.path.basename(video_id_text))[0]
                if vid_base:
                    video_name_candidates.append(vid_base)
                    # Clip names look like "<video>_clip_<n>[_obj<k>]".
                    match = re.match(r'^(.+?)_clip_\d+(?:_obj\d+)?$', vid_base)
                    if match:
                        video_name_candidates.append(match.group(1))

            video_extensions = ('.mp4', '.avi', '.mov', '.mkv')

            all_videos = []
            for root, dirs, files in os.walk(experiment_path):
                # Skip derived clips so only original recordings are matched.
                if 'registered_clips' in root or 'raw_clips' in root:
                    continue
                for f in files:
                    if f.lower().endswith(video_extensions):
                        all_videos.append(os.path.join(root, f))

            video_path = None
            for video_name in video_name_candidates:
                # Exact base-name match first, then substring match either way.
                for vp in all_videos:
                    vp_basename = os.path.splitext(os.path.basename(vp))[0]
                    if vp_basename == video_name:
                        video_path = vp
                        break

                if not video_path:
                    for vp in all_videos:
                        vp_basename = os.path.splitext(os.path.basename(vp))[0]
                        if video_name in vp_basename or vp_basename in video_name:
                            video_path = vp
                            break

                if video_path:
                    break

            if not video_path:
                logger.info("Could not find video for snippet %s", snippet_id)
                logger.info("  Tried video names: %s", video_name_candidates)
                logger.info("  Available videos: %s", [os.path.basename(v) for v in all_videos[:10]])
                return None

            # A video was found, so there is at least one candidate.
            video_name = video_name_candidates[0]

            raw_clips_dir = os.path.join(experiment_path, "data", "clips", "raw_clips")
            os.makedirs(raw_clips_dir, exist_ok=True)

            # Cluster labels are not necessarily strings.
            safe_label = re.sub(r'[^\w\-_]', '_', str(label))
            output_filename = f"{video_name}_{safe_label}_{snippet_id}.mp4"
            output_path = os.path.join(raw_clips_dir, output_filename)

            cap = cv2.VideoCapture(video_path)
            out = None
            frames_written = 0
            try:
                if not cap.isOpened():
                    return None

                fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
                frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

                fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))
                if not out.isOpened():
                    logger.error("Could not open video writer for %s", output_path)
                    return None

                cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)

                for _ in range(start_frame, end_frame + 1):
                    ret, frame = cap.read()
                    if not ret:
                        break
                    out.write(frame)
                    frames_written += 1
            finally:
                # Release both handles even if reading or writing raised.
                cap.release()
                if out is not None:
                    out.release()

            if frames_written == 0:
                # Do not hand back an empty container as if it were a clip.
                logger.error("No frames written for snippet %s from %s", snippet_id, video_path)
                try:
                    os.remove(output_path)
                except OSError:
                    pass
                return None

            return output_path

        except Exception as e:
            logger.error("Error extracting raw clip from snippet: %s", e, exc_info=True)
            return None

    def _on_review_ready(self, results: dict, classes: list, is_ovr: bool,
                         clip_length: int, target_fps: int):
        """Populate the Review tab when inference finishes and switch to it."""
        self.review_widget.load_from_inference(results, classes, is_ovr, clip_length, target_fps)
        self._show_tab(self.review_widget)

    def _on_annotations_updated(self):
        """Refresh training / labeling tabs after review saves new clips."""
        self.labeling_widget.refresh_clip_list()
        self.training_widget.refresh_annotation_info()

    def _on_tracking_completed(self, video_path: str, mask_path: str):
        """Handle tracking completion - switch to registration tab and load data."""
        self._show_tab(self.registration_widget)
        self.registration_widget.load_from_segmentation(video_path, mask_path)

    def _on_embeddings_extracted(self, matrix_path: str, metadata_path: str):
        """Handle embedding extraction completion - auto-load into clustering tab."""
        self._show_tab(self.clustering_widget)
        self.clustering_widget.load_from_registration(matrix_path, metadata_path)
        # New clustering input: allow the labeling-setup prompt to appear again.
        self._labeling_setup_prompted_for_clusters = False

    def _open_video(self):
        """Open video file dialog."""
        video_dir = self.config.get("raw_videos_dir", self.config.get("data_dir", "data/raw_videos"))
        video_path, _ = QFileDialog.getOpenFileName(
            self,
            "Open Video File",
            video_dir,
            "Video Files (*.mp4 *.avi *.mov *.mkv);;All Files (*)"
        )
        if not video_path:
            return

        from .video_utils import ensure_video_in_experiment
        video_path = ensure_video_in_experiment(video_path, self.config, self)
        self.labeling_widget.add_source_videos([video_path], select_last=True)
        self._show_tab(self.labeling_widget)

    def _create_experiment_layout(self, experiment_path: str) -> dict:
        """Create the on-disk skeleton of an experiment and return its paths.

        Creates the data / clips / annotations / models directories, an empty
        ``annotations.json`` and a ``training_profiles.json`` (copied from the
        packaged profiles when available), without overwriting existing
        files. Raises ``OSError`` if anything cannot be created.
        """
        os.makedirs(experiment_path, exist_ok=True)

        data_dir = os.path.join(experiment_path, "data")
        raw_videos_dir = os.path.join(data_dir, "raw_videos")
        clips_dir = os.path.join(data_dir, "clips")
        annotations_dir = os.path.join(data_dir, "annotations")
        models_dir = os.path.join(experiment_path, "models", "behavior_heads")

        for path in [data_dir, raw_videos_dir, clips_dir, annotations_dir, models_dir]:
            os.makedirs(path, exist_ok=True)

        annotation_file = os.path.join(annotations_dir, "annotations.json")
        if not os.path.exists(annotation_file):
            default_annotations = {"classes": [], "clips": []}
            with open(annotation_file, "w", encoding="utf-8") as f:
                json.dump(default_annotations, f, indent=2)

        profiles_file = os.path.join(experiment_path, "training_profiles.json")
        if not os.path.exists(profiles_file):
            from singlebehaviorlab._paths import get_training_profiles_path
            src = get_training_profiles_path()
            if src and os.path.exists(str(src)):
                import shutil
                shutil.copy2(str(src), profiles_file)
            else:
                with open(profiles_file, "w", encoding="utf-8") as f:
                    json.dump({"Default": {}}, f, indent=2)

        return {
            "data_dir": data_dir,
            "raw_videos_dir": raw_videos_dir,
            "clips_dir": clips_dir,
            "annotations_dir": annotations_dir,
            "annotation_file": annotation_file,
            "models_dir": models_dir,
            "profiles_file": profiles_file,
        }

    def _create_experiment(self):
        """Create a new experiment directory structure.

        Asks for a name, builds the directory skeleton under the experiments
        directory, switches the shared config to the new experiment, writes
        ``config.yaml`` and applies the ``LowInputData`` training profile if
        the profiles file defines one.
        """
        name, ok = QInputDialog.getText(self, "Create Experiment", "Experiment name:")
        if not ok or not name.strip():
            return
        name = name.strip()

        experiments_dir = self.config.get("experiments_dir")
        if not experiments_dir:
            from singlebehaviorlab._paths import get_experiments_dir
            experiments_dir = str(get_experiments_dir())
            self.config["experiments_dir"] = experiments_dir

        experiment_path = os.path.abspath(os.path.join(experiments_dir, name))
        try:
            os.makedirs(experiments_dir, exist_ok=True)
            already_populated = os.path.isdir(experiment_path) and bool(os.listdir(experiment_path))
        except OSError as exc:
            logger.error("Failed to prepare experiments directory %s: %s", experiments_dir, exc, exc_info=True)
            QMessageBox.critical(self, "Error", f"Failed to create experiment:\n{exc}")
            return

        if already_populated:
            reply = QMessageBox.question(
                self,
                "Overwrite Experiment",
                f"The experiment folder '{experiment_path}' already exists and is not empty.\n"
                "Do you want to reuse it? Existing files will be kept.",
                QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No,
                QMessageBox.StandardButton.No
            )
            if reply != QMessageBox.StandardButton.Yes:
                return

        try:
            paths = self._create_experiment_layout(experiment_path)
        except OSError as exc:
            logger.error("Failed to create experiment at %s: %s", experiment_path, exc, exc_info=True)
            QMessageBox.critical(self, "Error", f"Failed to create experiment:\n{exc}")
            return

        experiment_config_path = os.path.join(experiment_path, "config.yaml")

        new_config = dict(self.config)
        new_config.update({
            "experiment_name": name,
            "experiment_path": experiment_path,
            "data_dir": paths["data_dir"],
            "raw_videos_dir": paths["raw_videos_dir"],
            "clips_dir": paths["clips_dir"],
            "annotations_dir": paths["annotations_dir"],
            "annotation_file": paths["annotation_file"],
            "training_clips_dir": paths["clips_dir"],
            "training_annotation_file": paths["annotation_file"],
            "models_dir": paths["models_dir"],
            "config_path": experiment_config_path,
            # Use a conservative default that works well for small labeled datasets.
            "default_weight_decay": 0.001,
        })

        self._update_config(new_config)
        try:
            self._save_experiment_config()
        except (OSError, yaml.YAMLError) as exc:
            # The experiment stays usable in memory; the user can retry via
            # "Save Experiment".
            logger.error("Failed to write experiment config: %s", exc, exc_info=True)
            QMessageBox.warning(self, "Error", f"Failed to save experiment config:\n{exc}")
        self._apply_config_to_widgets()
        self._labeling_setup_prompted_for_clusters = False

        # Best effort: start new experiments on the LowInputData profile.
        try:
            with open(paths["profiles_file"], "r", encoding="utf-8") as f:
                profiles_data = json.load(f) or {}
            if isinstance(profiles_data, dict) and "LowInputData" in profiles_data:
                self.training_widget.apply_training_config(profiles_data["LowInputData"])
                self.training_widget.current_profile_name = "LowInputData"
        except Exception:
            logger.warning("Could not apply the LowInputData training profile", exc_info=True)

        self._update_window_title()
        QMessageBox.information(self, "Experiment created", f"Experiment '{name}' created at:\n{experiment_path}")

    def _load_experiment(self):
        """Load an existing experiment configuration.

        Reads a YAML config chosen by the user, fills in or absolutizes the
        standard experiment paths relative to the experiment directory, and
        replaces the shared config with the result.
        """
        start_dir = self.config.get("experiments_dir", os.getcwd())
        config_path, _ = QFileDialog.getOpenFileName(
            self,
            "Load Experiment",
            start_dir,
            "YAML Files (*.yaml *.yml);;All Files (*)"
        )
        if not config_path:
            return

        try:
            with open(config_path, "r", encoding="utf-8") as f:
                loaded = yaml.safe_load(f) or {}
        except Exception as exc:
            QMessageBox.critical(self, "Error", f"Failed to load experiment config:\n{exc}")
            return

        if not isinstance(loaded, dict):
            # Valid YAML, but not a key/value mapping (e.g. a list or scalar).
            QMessageBox.critical(
                self,
                "Error",
                "Failed to load experiment config:\nThe file does not contain a settings mapping.",
            )
            return

        experiment_path = loaded.get("experiment_path") or os.path.dirname(os.path.abspath(config_path))
        loaded["experiment_path"] = experiment_path
        loaded["config_path"] = config_path
        loaded.setdefault("experiments_dir", self.config.get("experiments_dir"))

        # Resolve experiment paths relative to the loaded config when needed.
        def _abs_path(key, default_subpath):
            if key not in loaded or not loaded[key]:
                loaded[key] = os.path.join(experiment_path, default_subpath)
            elif not os.path.isabs(loaded[key]):
                loaded[key] = os.path.join(experiment_path, loaded[key])

        _abs_path("data_dir", "data")
        _abs_path("raw_videos_dir", os.path.join("data", "raw_videos"))
        _abs_path("clips_dir", os.path.join("data", "clips"))
        _abs_path("annotations_dir", os.path.join("data", "annotations"))
        _abs_path("models_dir", os.path.join("models", "behavior_heads"))
        _abs_path("annotation_file", os.path.join("data", "annotations", "annotations.json"))

        self._update_config(loaded)
        self._apply_config_to_widgets()
        self._labeling_setup_prompted_for_clusters = False
        self._update_window_title()
        QMessageBox.information(self, "Experiment loaded", f"Loaded experiment from:\n{config_path}")

    def _save_experiment(self):
        """Save current experiment configuration."""
        config_path = self.config.get("config_path")
        if not config_path:
            QMessageBox.warning(self, "Save experiment", "No experiment is currently loaded.")
            return
        try:
            self._save_experiment_config()
        except (OSError, yaml.YAMLError) as exc:
            logger.error("Failed to save experiment config to %s: %s", config_path, exc, exc_info=True)
            QMessageBox.critical(self, "Error", f"Failed to save experiment config:\n{exc}")
            return
        QMessageBox.information(self, "Experiment saved", f"Experiment saved to:\n{config_path}")

    def _save_experiment_config(self):
        """Write the current config to disk.

        Does nothing when the config has no ``config_path``. Raises
        ``OSError`` / ``yaml.YAMLError`` if the file cannot be written.
        """
        config_path = self.config.get("config_path")
        if not config_path:
            return
        config_dir = os.path.dirname(config_path)
        if config_dir:
            # A bare file name has no directory part; makedirs('') would fail.
            os.makedirs(config_dir, exist_ok=True)
        with open(config_path, "w", encoding="utf-8") as f:
            yaml.safe_dump(dict(self.config), f, sort_keys=False)

    def _update_config(self, new_values: dict):
        """Update the shared config dictionary in place.

        The dict object itself is kept (widgets hold references to it); only
        its contents are replaced by ``new_values``.
        """
        # Copy first so passing self.config itself does not wipe it.
        snapshot = dict(new_values)
        self.config.clear()
        self.config.update(snapshot)

    def _apply_config_to_widgets(self):
        """Propagate config changes to all widgets."""
        for widget in (
            self.segmentation_widget,
            self.registration_widget,
            self.clustering_widget,
            self.labeling_widget,
            self.training_widget,
            self.inference_widget,
            self.review_widget,
            self.analysis_widget,
        ):
            widget.update_config(self.config)

    def _update_window_title(self):
        """Update window title with experiment name."""
        name = self.config.get("experiment_name")
        base_title = "SingleBehavior Lab"
        if name:
            self.setWindowTitle(f"{base_title} - {name}")
        else:
            self.setWindowTitle(base_title)
753
+
754
+