singlebehaviorlab 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sam2/__init__.py +11 -0
- sam2/automatic_mask_generator.py +454 -0
- sam2/benchmark.py +92 -0
- sam2/build_sam.py +174 -0
- sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
- sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
- sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
- sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
- sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
- sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
- sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
- sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
- sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
- sam2/modeling/__init__.py +5 -0
- sam2/modeling/backbones/__init__.py +5 -0
- sam2/modeling/backbones/hieradet.py +317 -0
- sam2/modeling/backbones/image_encoder.py +134 -0
- sam2/modeling/backbones/utils.py +93 -0
- sam2/modeling/memory_attention.py +169 -0
- sam2/modeling/memory_encoder.py +181 -0
- sam2/modeling/position_encoding.py +239 -0
- sam2/modeling/sam/__init__.py +5 -0
- sam2/modeling/sam/mask_decoder.py +295 -0
- sam2/modeling/sam/prompt_encoder.py +202 -0
- sam2/modeling/sam/transformer.py +311 -0
- sam2/modeling/sam2_base.py +913 -0
- sam2/modeling/sam2_utils.py +323 -0
- sam2/sam2_hiera_b+.yaml +113 -0
- sam2/sam2_hiera_l.yaml +117 -0
- sam2/sam2_hiera_s.yaml +116 -0
- sam2/sam2_hiera_t.yaml +118 -0
- sam2/sam2_image_predictor.py +466 -0
- sam2/sam2_video_predictor.py +1388 -0
- sam2/sam2_video_predictor_legacy.py +1172 -0
- sam2/utils/__init__.py +5 -0
- sam2/utils/amg.py +348 -0
- sam2/utils/misc.py +349 -0
- sam2/utils/transforms.py +118 -0
- singlebehaviorlab/__init__.py +4 -0
- singlebehaviorlab/__main__.py +130 -0
- singlebehaviorlab/_paths.py +100 -0
- singlebehaviorlab/backend/__init__.py +2 -0
- singlebehaviorlab/backend/augmentations.py +320 -0
- singlebehaviorlab/backend/data_store.py +420 -0
- singlebehaviorlab/backend/model.py +1290 -0
- singlebehaviorlab/backend/train.py +4667 -0
- singlebehaviorlab/backend/uncertainty.py +578 -0
- singlebehaviorlab/backend/video_processor.py +688 -0
- singlebehaviorlab/backend/video_utils.py +139 -0
- singlebehaviorlab/data/config/config.yaml +85 -0
- singlebehaviorlab/data/training_profiles.json +334 -0
- singlebehaviorlab/gui/__init__.py +4 -0
- singlebehaviorlab/gui/analysis_widget.py +2291 -0
- singlebehaviorlab/gui/attention_export.py +311 -0
- singlebehaviorlab/gui/clip_extraction_widget.py +481 -0
- singlebehaviorlab/gui/clustering_widget.py +3187 -0
- singlebehaviorlab/gui/inference_popups.py +1138 -0
- singlebehaviorlab/gui/inference_widget.py +4550 -0
- singlebehaviorlab/gui/inference_worker.py +651 -0
- singlebehaviorlab/gui/labeling_widget.py +2324 -0
- singlebehaviorlab/gui/main_window.py +754 -0
- singlebehaviorlab/gui/metadata_management_widget.py +1119 -0
- singlebehaviorlab/gui/motion_tracking.py +764 -0
- singlebehaviorlab/gui/overlay_export.py +1234 -0
- singlebehaviorlab/gui/plot_integration.py +729 -0
- singlebehaviorlab/gui/qt_helpers.py +29 -0
- singlebehaviorlab/gui/registration_widget.py +1485 -0
- singlebehaviorlab/gui/review_widget.py +1330 -0
- singlebehaviorlab/gui/segmentation_tracking_widget.py +2752 -0
- singlebehaviorlab/gui/tab_tutorial_dialog.py +312 -0
- singlebehaviorlab/gui/timeline_themes.py +131 -0
- singlebehaviorlab/gui/training_profiles.py +418 -0
- singlebehaviorlab/gui/training_widget.py +3719 -0
- singlebehaviorlab/gui/video_utils.py +233 -0
- singlebehaviorlab/licenses/SAM2-LICENSE +201 -0
- singlebehaviorlab/licenses/VideoPrism-LICENSE +202 -0
- singlebehaviorlab-2.0.0.dist-info/METADATA +447 -0
- singlebehaviorlab-2.0.0.dist-info/RECORD +88 -0
- singlebehaviorlab-2.0.0.dist-info/WHEEL +5 -0
- singlebehaviorlab-2.0.0.dist-info/entry_points.txt +2 -0
- singlebehaviorlab-2.0.0.dist-info/licenses/LICENSE +21 -0
- singlebehaviorlab-2.0.0.dist-info/top_level.txt +3 -0
- videoprism/__init__.py +0 -0
- videoprism/encoders.py +910 -0
- videoprism/layers.py +1136 -0
- videoprism/models.py +407 -0
- videoprism/tokenizers.py +167 -0
- videoprism/utils.py +168 -0
|
@@ -0,0 +1,754 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import sys
|
|
3
|
+
from PyQt6.QtWidgets import (
|
|
4
|
+
QMainWindow, QTabWidget, QMenuBar, QMenu, QMessageBox, QFileDialog, QInputDialog,
|
|
5
|
+
QDialog, QVBoxLayout, QHBoxLayout, QRadioButton, QButtonGroup, QDialogButtonBox,
|
|
6
|
+
QLabel, QGroupBox, QWidget, QPushButton,
|
|
7
|
+
)
|
|
8
|
+
from PyQt6.QtCore import Qt
|
|
9
|
+
from PyQt6.QtGui import QKeySequence, QAction
|
|
10
|
+
from .segmentation_tracking_widget import SegmentationTrackingWidget
|
|
11
|
+
from .registration_widget import RegistrationWidget
|
|
12
|
+
from .clustering_widget import ClusteringWidget
|
|
13
|
+
from .labeling_widget import LabelingWidget
|
|
14
|
+
from .training_widget import TrainingWidget
|
|
15
|
+
from .inference_widget import InferenceWidget
|
|
16
|
+
from .analysis_widget import AnalysisWidget
|
|
17
|
+
from .review_widget import ReviewWidget
|
|
18
|
+
from .tab_tutorial_dialog import show_tab_tutorial
|
|
19
|
+
import os
|
|
20
|
+
import json
|
|
21
|
+
import yaml
|
|
22
|
+
|
|
23
|
+
# Module logger, named after this module; used by the dialog/window code below.
logger = logging.getLogger(__name__)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class LabelingSetupDialog(QDialog):
    """Modal dialog asking how the labeling list should be populated.

    After OK is pressed, ``result_data`` holds a dict with keys ``"mode"``
    (``"clustering"``, ``"raw_extraction"`` or ``"continue"``) and
    ``"submode"`` (``"segmented"`` / ``"raw"`` for the clustering mode,
    otherwise ``None``). It remains ``None`` if the dialog is cancelled.
    """

    def __init__(self, parent=None):
        super().__init__(parent)
        self.setWindowTitle("Labeling setup")
        self.setMinimumWidth(500)
        self.result_data = None

        root = QVBoxLayout(self)
        root.addWidget(QLabel("How would you like to populate the labeling list?"))

        # Choice 1: clustering, together with its nested sub-options.
        self.rb_clustering = QRadioButton("Use representative clips from clustering")
        self.rb_clustering.setChecked(True)
        root.addWidget(self.rb_clustering)

        self.clustering_group = QGroupBox("Clustering options")
        self.clustering_group_layout = QVBoxLayout()
        self.rb_segmented = QRadioButton("Use existing segmented clips (ROIs)")
        self.rb_segmented.setChecked(True)
        self.rb_raw = QRadioButton("Extract raw clips from original video (Full frame/No mask)")
        for sub_button in (self.rb_segmented, self.rb_raw):
            self.clustering_group_layout.addWidget(sub_button)
        self.clustering_group.setLayout(self.clustering_group_layout)
        root.addWidget(self.clustering_group)

        # Choices 2 and 3.
        self.rb_no_clustering = QRadioButton("Annotate raw videos on timeline (integrated in Labeling)")
        self.rb_continue = QRadioButton("Continue with existing/manual list")
        root.addWidget(self.rb_no_clustering)
        root.addWidget(self.rb_continue)

        # Group the three top-level choices together.
        self.bg_main = QButtonGroup(self)
        for top_choice in (self.rb_clustering, self.rb_no_clustering, self.rb_continue):
            self.bg_main.addButton(top_choice)

        # Sub-options are only editable while "clustering" is selected.
        self.rb_clustering.toggled.connect(self.clustering_group.setEnabled)
        self.clustering_group.setEnabled(True)

        buttons = QDialogButtonBox(
            QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel
        )
        buttons.accepted.connect(self._on_accept)
        buttons.rejected.connect(self.reject)
        root.addWidget(buttons)

    def _on_accept(self):
        """Store the selected mode/submode in ``result_data`` and accept the dialog."""
        if self.rb_clustering.isChecked():
            selection = {
                "mode": "clustering",
                "submode": "segmented" if self.rb_segmented.isChecked() else "raw",
            }
        elif self.rb_no_clustering.isChecked():
            selection = {"mode": "raw_extraction", "submode": None}
        else:
            selection = {"mode": "continue", "submode": None}
        self.result_data = selection
        self.accept()
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class MainWindow(QMainWindow):
|
|
89
|
+
"""Main application window."""
|
|
90
|
+
|
|
91
|
+
def __init__(self, config: dict):
    """Store *config*, then build the File menu and the tab widget.

    The tab widget becomes the window's central widget.
    """
    super().__init__()
    self.config = config
    # Reset whenever new clusters/experiments arrive so the labeling setup
    # dialog is offered again on the next visit to the Labeling tab.
    self._labeling_setup_prompted_for_clusters = False
    self._update_window_title()

    left, top, width, height = 100, 100, 1400, 900
    self.setGeometry(left, top, width, height)
    self.setMinimumSize(1200, 700)

    for build_step in (self._setup_menu, self._setup_tabs):
        build_step()

    self.setCentralWidget(self.tabs)
|
|
103
|
+
|
|
104
|
+
def _setup_menu(self):
    """Populate the File menu: open video, experiment create/load/save, exit."""
    file_menu = self.menuBar().addMenu("File")

    def add_file_action(text, slot, shortcut=None):
        # Each action is parented to the window and appended to the File menu.
        action = QAction(text, self)
        if shortcut is not None:
            action.setShortcut(QKeySequence(shortcut))
        action.triggered.connect(slot)
        file_menu.addAction(action)
        return action

    add_file_action("Open Video...", self._open_video, "Ctrl+O")
    file_menu.addSeparator()

    add_file_action("Create Experiment...", self._create_experiment)
    add_file_action("Load Experiment...", self._load_experiment)
    add_file_action("Save Experiment", self._save_experiment)
    file_menu.addSeparator()

    add_file_action("Exit", self.close, "Ctrl+Q")
|
|
133
|
+
|
|
134
|
+
def _setup_tabs(self):
    """Create all workflow widgets, add them as tabs and connect cross-tab signals."""
    self.tabs = QTabWidget()

    # Construction order is kept fixed; it differs from the display order below.
    self.segmentation_widget = SegmentationTrackingWidget(self.config)
    self.registration_widget = RegistrationWidget(self.config)
    self.clustering_widget = ClusteringWidget(self.config)
    self.labeling_widget = LabelingWidget(self.config)
    self.training_widget = TrainingWidget(self.config)
    self.inference_widget = InferenceWidget(self.config)
    self.analysis_widget = AnalysisWidget(self.config)
    self.review_widget = ReviewWidget(self.config)

    # (widget, tutorial id, tab title) in display order. _on_tab_changed
    # dispatches on these titles, so keep them in sync.
    tab_specs = (
        (self.labeling_widget, "labeling", "Labeling"),
        (self.training_widget, "training", "Training Sequencing Model"),
        (self.inference_widget, "sequencing", "Sequencing"),
        (self.review_widget, "refine", "Refine"),
        (self.analysis_widget, "analysis", "Downstream Analysis"),
        (self.segmentation_widget, "segmentation", "Segmentation Tracking"),
        (self.registration_widget, "registration", "Registration"),
        (self.clustering_widget, "clustering", "Clustering"),
    )
    for widget, tutorial_id, title in tab_specs:
        self.tabs.addTab(self._wrap_tab_with_help(widget, tutorial_id), title)

    # Hand-offs between pipeline stages.
    self.segmentation_widget.tracking_completed.connect(self._on_tracking_completed)
    self.registration_widget.embeddings_extracted.connect(self._on_embeddings_extracted)
    self.inference_widget.review_ready.connect(self._on_review_ready)
    self.review_widget.annotations_updated.connect(self._on_annotations_updated)

    self.tabs.currentChanged.connect(self._on_tab_changed)
|
|
183
|
+
|
|
184
|
+
def _wrap_tab_with_help(self, inner: QWidget, tab_id: str) -> QWidget:
    """Return a container placing *inner* below a right-aligned "Tab Guide" button.

    Clicking the button opens the tutorial identified by *tab_id*.
    """
    button_style = "".join((
        "QPushButton {",
        "padding: 6px 12px;",
        "font-weight: 600;",
        "border: 1px solid #8a7b00;",
        "border-radius: 8px;",
        "background-color: #fff3b0;",
        "color: #3d3200;",
        "}",
        "QPushButton:hover {",
        "background-color: #ffe680;",
        "}",
        "QPushButton:pressed {",
        "background-color: #f7d154;",
        "}",
    ))

    guide_button = QPushButton("📖 Tab Guide")
    guide_button.setToolTip("Open a detailed guide for this tab and what to do next (NOR example)")
    guide_button.setStyleSheet(button_style)
    # tab_id is bound as a default argument; the first parameter absorbs the
    # `checked` flag that `clicked` may pass.
    guide_button.clicked.connect(lambda _checked=False, tid=tab_id: show_tab_tutorial(self, tid))

    header_row = QHBoxLayout()
    header_row.addStretch(1)
    header_row.addWidget(guide_button)

    container = QWidget()
    column = QVBoxLayout(container)
    column.setContentsMargins(4, 4, 4, 0)
    column.setSpacing(2)
    column.addLayout(header_row)
    column.addWidget(inner, 1)
    return container
|
|
214
|
+
|
|
215
|
+
def _on_tab_changed(self, index: int):
    """React to switching to the tab at *index*.

    "Labeling": if clustering results exist and the setup dialog has not been
    offered yet, run the labeling setup once; then refresh the clip list.
    "Training Sequencing Model": reload its config and annotation summary.
    The window size captured on entry is restored at the end (presumably to
    undo any resize caused by the new tab's contents).
    """
    saved_size = self.size()
    title = self.tabs.tabText(index)

    if title == "Labeling":
        needs_setup = (
            self.clustering_widget.clusters is not None
            and not self._labeling_setup_prompted_for_clusters
        )
        if needs_setup:
            self._handle_labeling_setup()
            self._labeling_setup_prompted_for_clusters = True
        self.labeling_widget.refresh_clip_list()
    elif title == "Training Sequencing Model":
        self.training_widget._load_current_config()
        self.training_widget.refresh_annotation_info()

    self.resize(saved_size)
|
|
232
|
+
|
|
233
|
+
def _handle_labeling_setup(self):
    """Show the labeling setup dialog and act on the user's choice.

    * ``"clustering"``: prepare representative clips via
      ``_prepare_clustering_labeling_data`` with the chosen submode.
    * ``"raw_extraction"``: switch to the Labeling tab and open its timeline
      import dialog, with a short explanatory message.
    * ``"continue"`` or a cancelled dialog: do nothing.
    """
    dialog = LabelingSetupDialog(self)
    if not dialog.exec():
        return
    choice = dialog.result_data
    if not choice:
        return

    mode = choice.get("mode")
    if mode == "clustering":
        self._prepare_clustering_labeling_data(choice.get("submode"))
    elif mode == "raw_extraction":
        # Tab pages are the help-button containers built by _wrap_tab_with_help,
        # so labeling_widget itself is not a page and setCurrentWidget() on it
        # would be a silent no-op. Walk up to the ancestor the QTabWidget owns.
        page = self.labeling_widget
        while page is not None and self.tabs.indexOf(page) < 0:
            page = page.parentWidget()
        if page is not None:
            self.tabs.setCurrentWidget(page)
        self.labeling_widget.open_timeline_import_dialog()
        QMessageBox.information(
            self,
            "Timeline labeling",
            "Use the Timeline Annotation section in Labeling to add raw videos, mark intervals, and generate clips.",
        )
|
|
254
|
+
|
|
255
|
+
def _prepare_clustering_labeling_data(self, submode):
    """Add representative clips from the clustering results to the labeling list.

    Representative snippets are requested from the clustering widget
    (``n_samples=10``) and turned into clip files under
    ``<experiment>/data/clips``:

    * ``submode == "segmented"``: the snippet's existing clip is copied into
      ``data/clips/segmented_clips``. An existing copy is reused.
    * ``submode == "raw"``: a full-frame clip is cut from the source video by
      ``_extract_raw_clip_from_snippet``.

    Each clip is registered with the labeling widget's annotation manager,
    using the cluster label as its class. Clips whose relative path is
    already known are skipped. Errors are logged and shown in a dialog
    rather than raised.

    Args:
        submode: ``"segmented"`` or ``"raw"``. Any other value adds no clips.
    """
    import shutil

    if self.clustering_widget.clusters is None:
        QMessageBox.warning(self, "No clusters", "No clustering data available. Please perform clustering first.")
        return

    try:
        rep_snippets = self.clustering_widget.get_representative_snippets(n_samples=10)
        if not rep_snippets:
            QMessageBox.warning(self, "No snippets", "Could not identify representative snippets.")
            return

        self.clustering_widget._build_snippet_to_clip_map()
        snippet_map = self.clustering_widget.snippet_to_clip_map

        experiment_path = self.config.get("experiment_path")
        if not experiment_path:
            QMessageBox.warning(self, "Error", "No experiment loaded.")
            return

        labeling_clips_dir = os.path.join(experiment_path, "data", "clips")
        os.makedirs(labeling_clips_dir, exist_ok=True)

        clips_to_add = []
        missing_clips = []

        for label, snippet_ids in rep_snippets.items():
            for snip in snippet_ids:
                if submode == "segmented":
                    clip_path = snippet_map.get(snip)
                    if not clip_path or not os.path.exists(clip_path):
                        missing_clips.append(snip)
                        continue

                    seg_dir = os.path.join(labeling_clips_dir, "segmented_clips")
                    os.makedirs(seg_dir, exist_ok=True)
                    target_path = os.path.join(seg_dir, os.path.basename(clip_path))

                    if not os.path.exists(target_path):
                        try:
                            shutil.copy2(clip_path, target_path)
                        except OSError as e:
                            logger.error("Error copying clip %s: %s", clip_path, e)
                            # Count it as missing so the summary log is accurate.
                            missing_clips.append(snip)
                            continue

                    clips_to_add.append({"path": target_path, "label": label, "snippet_id": snip})

                elif submode == "raw":
                    raw_clip_path = self._extract_raw_clip_from_snippet(snip, label)
                    if raw_clip_path and os.path.exists(raw_clip_path):
                        clips_to_add.append({"path": raw_clip_path, "label": label, "snippet_id": snip})
                    else:
                        missing_clips.append(snip)

        if submode == "raw" and not clips_to_add:
            QMessageBox.warning(self, "No clips",
                                "Could not extract raw clips. Make sure video files and mask data are available.")
            return

        if missing_clips:
            logger.warning("Could not find clip files for %d snippets.", len(missing_clips))

        if not clips_to_add:
            QMessageBox.warning(self, "No clips", "No valid clips found for labeling data.")
            return

        self.labeling_widget.clip_base_dir = labeling_clips_dir
        self.labeling_widget.config["clips_dir"] = labeling_clips_dir

        am = self.labeling_widget.annotation_manager
        existing_ids = {c.get("id") for c in am.get_all_clips()}
        added_count = 0

        for clip_data in clips_to_add:
            abs_path = clip_data["path"].replace('\\', '/')
            try:
                rel_path = os.path.relpath(abs_path, labeling_clips_dir).replace('\\', '/')
            except ValueError:
                # os.path.relpath raises on Windows when the paths are on different drives.
                rel_path = os.path.basename(abs_path)

            if rel_path in existing_ids:
                continue
            am.add_clip(rel_path, clip_data["label"], meta={"snippet_id": clip_data["snippet_id"]})
            am.add_class(clip_data["label"])
            # Record the new id too, so several snippets resolving to the same
            # clip file do not register duplicate entries.
            existing_ids.add(rel_path)
            added_count += 1

        if added_count > 0:
            annotation_path = am.annotation_file
            QMessageBox.information(self, "Data prepared",
                                    f"Added {added_count} representative clips from {len(rep_snippets)} clusters to the labeling list.\n\n"
                                    f"Annotations saved to:\n{annotation_path}\n\n"
                                    "You can now review and refine labels.")
            self.labeling_widget._update_class_combo()
            self.labeling_widget.refresh_clip_list()
        else:
            QMessageBox.information(self, "Data ready", "All representative clips are already in the labeling list.")

    except Exception as e:
        logger.error("Failed to prepare labeling data: %s", e, exc_info=True)
        QMessageBox.critical(self, "Error", f"Failed to prepare labeling data: {e}")
|
|
368
|
+
|
|
369
|
+
def _extract_raw_clip_from_snippet(self, snippet_id: str, label: str) -> str:
    """Cut a full-frame (unmasked) clip for *snippet_id* out of its source video.

    Frame range and source identifiers are read from the clustering
    metadata: ``start_frame``/``end_frame``, ``group`` and ``video_id``. The
    row is matched on the ``snippet`` column, or on ``span_id`` if ``snippet``
    is absent.

    The source video is searched for under the experiment directory, skipping
    folders whose name contains ``registered_clips`` or ``raw_clips``. Each
    candidate name is tried first as an exact basename match, then as a
    substring match.

    Frames ``start_frame..end_frame`` (inclusive) are written unchanged, with
    no mask or cropping, to
    ``<experiment>/data/clips/raw_clips/<video>_<label>_<snippet>.mp4``.

    Args:
        snippet_id: Snippet identifier as stored in the clustering metadata.
        label: Cluster label. It is sanitised and embedded in the output filename.

    Returns:
        The written clip's path, or ``None`` if the metadata, source video or
        frames are unavailable or extraction failed. Errors are logged, not
        raised. ``cv2`` is imported up front, so an ImportError still
        propagates to the caller.
    """
    import re

    import cv2

    def _stem(path):
        return os.path.splitext(os.path.basename(path))[0]

    try:
        metadata = self.clustering_widget.metadata
        if metadata is None:
            return None

        if 'snippet' in metadata.columns:
            snippet_col = 'snippet'
        elif 'span_id' in metadata.columns:
            snippet_col = 'span_id'
        else:
            return None

        snippet_row = metadata[metadata[snippet_col].astype(str) == str(snippet_id)]
        if len(snippet_row) == 0:
            return None

        row = snippet_row.iloc[0]
        video_id = row.get('video_id', '')
        group = row.get('group', '')
        try:
            start_frame = int(float(row.get('start_frame')))
            end_frame = int(float(row.get('end_frame')))
        except (ValueError, TypeError):
            return None

        experiment_path = self.config.get("experiment_path")
        if not experiment_path:
            return None

        video_name_candidates = []
        if group:
            video_name_candidates.append(str(group).strip())
        if video_id:
            vid_base = _stem(str(video_id))
            video_name_candidates.append(vid_base)
            # Clip-style ids ("<video>_clip_<n>[_obj<k>]") also name the source video.
            match = re.match(r'^(.+?)_clip_\d+(?:_obj\d+)?$', vid_base)
            if match:
                video_name_candidates.append(match.group(1))
        # An empty name would substring-match every video below.
        video_name_candidates = [name for name in video_name_candidates if name]

        video_exts = ('.mp4', '.avi', '.mov', '.mkv')
        skipped_dir_markers = ('registered_clips', 'raw_clips')
        all_videos = []
        for root, dirs, files in os.walk(experiment_path):
            # Prune derived-clip folders by directory name instead of testing the
            # whole root path, which would also match a marker inside the
            # experiment path itself and skip everything.
            dirs[:] = [d for d in dirs if not any(m in d for m in skipped_dir_markers)]
            all_videos.extend(os.path.join(root, f) for f in files if f.lower().endswith(video_exts))

        video_path = None
        for candidate in video_name_candidates:
            video_path = next((vp for vp in all_videos if _stem(vp) == candidate), None)
            if video_path is None:
                video_path = next(
                    (vp for vp in all_videos if candidate in _stem(vp) or _stem(vp) in candidate),
                    None,
                )
            if video_path:
                break

        if not video_path:
            logger.info("Could not find video for snippet %s", snippet_id)
            logger.info("  Tried video names: %s", video_name_candidates)
            logger.info("  Available videos: %s", [os.path.basename(v) for v in all_videos[:10]])
            return None

        # Non-empty here, because a video was found for one of the candidates.
        video_name = video_name_candidates[0]

        cap = cv2.VideoCapture(video_path)
        out = None
        try:
            if not cap.isOpened():
                return None

            fps = cap.get(cv2.CAP_PROP_FPS) or 30.0  # 0.0 means "unknown" in OpenCV
            frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

            raw_clips_dir = os.path.join(experiment_path, "data", "clips", "raw_clips")
            os.makedirs(raw_clips_dir, exist_ok=True)

            safe_label = re.sub(r'[^\w\-_]', '_', str(label))
            safe_snippet = re.sub(r'[^\w\-_]', '_', str(snippet_id))
            output_path = os.path.join(raw_clips_dir, f"{video_name}_{safe_label}_{safe_snippet}.mp4")

            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))
            if not out.isOpened():
                logger.error("Could not open video writer for %s", output_path)
                return None

            cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
            frames_written = 0
            for _ in range(start_frame, end_frame + 1):
                ok, frame = cap.read()
                if not ok:
                    break
                out.write(frame)
                frames_written += 1
        finally:
            # Always release native handles, even when writing fails midway.
            cap.release()
            if out is not None:
                out.release()

        if frames_written == 0:
            logger.warning("No frames could be read for snippet %s (frames %d-%d) from %s",
                           snippet_id, start_frame, end_frame, video_path)
            try:
                os.remove(output_path)
            except OSError:
                pass  # best-effort removal of the empty output file
            return None

        return output_path

    except Exception as e:
        logger.error("Error extracting raw clip from snippet: %s", e, exc_info=True)
        return None
|
|
540
|
+
|
|
541
|
+
def _on_review_ready(self, results: dict, classes: list, is_ovr: bool,
                     clip_length: int, target_fps: int):
    """Load finished inference results into the Refine tab and switch to it.

    All arguments are forwarded unchanged to ``review_widget.load_from_inference``.
    """
    self.review_widget.load_from_inference(results, classes, is_ovr, clip_length, target_fps)
    # Tab pages may be wrapper containers around the real widget, and
    # setCurrentWidget() on a non-page widget is a silent no-op. Walk up to
    # the ancestor the QTabWidget actually holds.
    page = self.review_widget
    while page is not None and self.tabs.indexOf(page) < 0:
        page = page.parentWidget()
    if page is not None:
        self.tabs.setCurrentWidget(page)
|
|
546
|
+
|
|
547
|
+
def _on_annotations_updated(self):
    """Sync the Labeling clip list, then the Training annotation summary, after Refine saves clips."""
    for refresh in (self.labeling_widget.refresh_clip_list,
                    self.training_widget.refresh_annotation_info):
        refresh()
|
|
551
|
+
|
|
552
|
+
def _on_tracking_completed(self, video_path: str, mask_path: str):
    """Switch to the Registration tab and load the freshly tracked video/mask pair."""
    # Tab pages may be wrapper containers around the real widget, and
    # setCurrentWidget() on a non-page widget is a silent no-op. Walk up to
    # the ancestor the QTabWidget actually holds.
    page = self.registration_widget
    while page is not None and self.tabs.indexOf(page) < 0:
        page = page.parentWidget()
    if page is not None:
        self.tabs.setCurrentWidget(page)

    self.registration_widget.load_from_segmentation(video_path, mask_path)
|
|
557
|
+
|
|
558
|
+
def _on_embeddings_extracted(self, matrix_path: str, metadata_path: str):
    """Switch to the Clustering tab, load the new embeddings and re-arm the labeling prompt."""
    # Tab pages may be wrapper containers around the real widget, and
    # setCurrentWidget() on a non-page widget is a silent no-op. Walk up to
    # the ancestor the QTabWidget actually holds.
    page = self.clustering_widget
    while page is not None and self.tabs.indexOf(page) < 0:
        page = page.parentWidget()
    if page is not None:
        self.tabs.setCurrentWidget(page)

    self.clustering_widget.load_from_registration(matrix_path, metadata_path)
    # New clusters: offer the labeling setup dialog again on the next Labeling visit.
    self._labeling_setup_prompted_for_clusters = False
|
|
563
|
+
|
|
564
|
+
def _open_video(self):
    """Ask for a video file, add it to the Labeling tab's sources and show that tab.

    The dialog starts in ``raw_videos_dir`` (falling back to ``data_dir``,
    then ``"data/raw_videos"``). The chosen file is passed through
    ``ensure_video_in_experiment`` before being added.
    """
    video_dir = self.config.get("raw_videos_dir", self.config.get("data_dir", "data/raw_videos"))
    video_path, _ = QFileDialog.getOpenFileName(
        self,
        "Open Video File",
        video_dir,
        "Video Files (*.mp4 *.avi *.mov *.mkv);;All Files (*)"
    )
    if not video_path:
        return

    from .video_utils import ensure_video_in_experiment
    video_path = ensure_video_in_experiment(video_path, self.config, self)
    if not video_path:
        # NOTE(review): defensive; it is unclear from here whether
        # ensure_video_in_experiment can return a falsy value.
        return

    self.labeling_widget.add_source_videos([video_path], select_last=True)

    # Tab pages may be wrapper containers around the real widget, and
    # setCurrentWidget() on a non-page widget is a silent no-op. Walk up to
    # the ancestor the QTabWidget actually holds.
    page = self.labeling_widget
    while page is not None and self.tabs.indexOf(page) < 0:
        page = page.parentWidget()
    if page is not None:
        self.tabs.setCurrentWidget(page)
|
|
578
|
+
|
|
579
|
+
def _create_experiment(self):
    """Create (or reuse) an experiment folder and make it the active experiment.

    Prompts for a name and creates these folders under ``experiments_dir``
    (defaulting to ``get_experiments_dir()``):

    * ``data/raw_videos``
    * ``data/clips``
    * ``data/annotations``, with an empty ``annotations.json``
    * ``models/behavior_heads``

    It also writes ``training_profiles.json``, copied from the packaged
    profiles when available. Existing files are never overwritten, and
    reusing a non-empty folder requires confirmation.

    A config pointing at these paths is then applied and saved. If present,
    the ``LowInputData`` training profile is applied to the Training tab.
    """
    name, ok = QInputDialog.getText(self, "Create Experiment", "Experiment name:")
    if not ok or not name.strip():
        return
    name = name.strip()

    # The name must be a single folder under experiments_dir. Path separators,
    # absolute paths and "."/".." would nest or escape that directory.
    if name in (".", "..") or any(sep and sep in name for sep in (os.sep, os.altsep)):
        QMessageBox.warning(
            self,
            "Invalid name",
            "Experiment names cannot contain path separators or be '.' or '..'.",
        )
        return

    experiments_dir = self.config.get("experiments_dir")
    if not experiments_dir:
        from singlebehaviorlab._paths import get_experiments_dir
        experiments_dir = str(get_experiments_dir())
        self.config["experiments_dir"] = experiments_dir
    os.makedirs(experiments_dir, exist_ok=True)

    experiment_path = os.path.abspath(os.path.join(experiments_dir, name))
    if os.path.exists(experiment_path) and os.listdir(experiment_path):
        reply = QMessageBox.question(
            self,
            "Overwrite Experiment",
            f"The experiment folder '{experiment_path}' already exists and is not empty.\n"
            "Do you want to reuse it? Existing files will be kept.",
            QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No,
            QMessageBox.StandardButton.No
        )
        if reply != QMessageBox.StandardButton.Yes:
            return
    os.makedirs(experiment_path, exist_ok=True)

    data_dir = os.path.join(experiment_path, "data")
    raw_videos_dir = os.path.join(data_dir, "raw_videos")
    clips_dir = os.path.join(data_dir, "clips")
    annotations_dir = os.path.join(data_dir, "annotations")
    models_dir = os.path.join(experiment_path, "models", "behavior_heads")

    for path in (data_dir, raw_videos_dir, clips_dir, annotations_dir, models_dir):
        os.makedirs(path, exist_ok=True)

    annotation_file = os.path.join(annotations_dir, "annotations.json")
    if not os.path.exists(annotation_file):
        default_annotations = {"classes": [], "clips": []}
        with open(annotation_file, "w", encoding="utf-8") as f:
            json.dump(default_annotations, f, indent=2)

    profiles_file = os.path.join(experiment_path, "training_profiles.json")
    if not os.path.exists(profiles_file):
        from singlebehaviorlab._paths import get_training_profiles_path
        src = get_training_profiles_path()
        if src and os.path.exists(str(src)):
            import shutil
            shutil.copy2(str(src), profiles_file)
        else:
            with open(profiles_file, "w", encoding="utf-8") as f:
                json.dump({"Default": {}}, f, indent=2)

    experiment_config_path = os.path.join(experiment_path, "config.yaml")

    new_config = dict(self.config)
    new_config.update({
        "experiment_name": name,
        "experiment_path": experiment_path,
        "data_dir": data_dir,
        "raw_videos_dir": raw_videos_dir,
        "clips_dir": clips_dir,
        "annotations_dir": annotations_dir,
        "annotation_file": annotation_file,
        "training_clips_dir": clips_dir,
        "training_annotation_file": annotation_file,
        "models_dir": models_dir,
        "config_path": experiment_config_path,
        # Use a conservative default that works well for small labeled datasets.
        "default_weight_decay": 0.001,
    })

    self._update_config(new_config)
    self._save_experiment_config()
    self._apply_config_to_widgets()
    self._labeling_setup_prompted_for_clusters = False

    try:
        with open(profiles_file, "r", encoding="utf-8") as f:
            profiles_data = json.load(f) or {}
        if isinstance(profiles_data, dict) and "LowInputData" in profiles_data:
            self.training_widget.apply_training_config(profiles_data["LowInputData"])
            self.training_widget.current_profile_name = "LowInputData"
    except Exception:
        # Applying the default profile is best-effort, because the experiment
        # already exists at this point. Report the failure instead of hiding it.
        logger.warning("Could not apply default training profile from %s", profiles_file, exc_info=True)

    self._update_window_title()
    QMessageBox.information(self, "Experiment created", f"Experiment '{name}' created at:\n{experiment_path}")
|
|
666
|
+
|
|
667
|
+
def _load_experiment(self):
    """Load an existing experiment configuration chosen by the user.

    Opens a file dialog rooted at the configured experiments directory,
    parses the selected YAML file, resolves the experiment's directory
    layout to absolute paths, installs the result as the shared config and
    pushes it to every widget.

    Read/parse failures, and YAML files whose top level is not a mapping,
    are reported in an error dialog and leave the current config untouched.

    If the ``experiment_path`` recorded in the file is not an existing
    directory (e.g. the experiment folder was moved or copied elsewhere),
    the folder containing the selected config file is used as the
    experiment root. Absolute paths that were recorded under the old root
    are re-based onto it.
    """
    # ``or`` rather than a .get() default so an explicit None/"" also falls back.
    start_dir = self.config.get("experiments_dir") or os.getcwd()
    config_path, _ = QFileDialog.getOpenFileName(
        self,
        "Load Experiment",
        start_dir,
        "YAML Files (*.yaml *.yml);;All Files (*)"
    )
    if not config_path:
        return

    try:
        with open(config_path, "r", encoding="utf-8") as f:
            loaded = yaml.safe_load(f) or {}
    except Exception as exc:
        QMessageBox.critical(self, "Error", f"Failed to load experiment config:\n{exc}")
        return

    # A YAML list or scalar is not an experiment config; reject it here
    # instead of crashing on ``loaded.get`` below.
    if not isinstance(loaded, dict):
        QMessageBox.critical(
            self,
            "Error",
            "Failed to load experiment config:\n"
            f"Expected a mapping at the top level of {config_path}",
        )
        return

    config_dir = os.path.dirname(os.path.abspath(config_path))
    stored_root = loaded.get("experiment_path")
    if not isinstance(stored_root, str):
        stored_root = None

    if stored_root and os.path.isdir(stored_root):
        experiment_path = stored_root
    else:
        # Missing or stale root: the folder holding the chosen config file
        # is the experiment root.
        experiment_path = config_dir
        if stored_root:
            # Re-base absolute paths that lived under the old root so they
            # follow the experiment to its new location.
            for key in (
                "data_dir",
                "raw_videos_dir",
                "clips_dir",
                "annotations_dir",
                "annotation_file",
                "training_clips_dir",
                "training_annotation_file",
                "models_dir",
            ):
                value = loaded.get(key)
                if not isinstance(value, str) or not os.path.isabs(value):
                    continue
                try:
                    rel = os.path.relpath(value, stored_root)
                except ValueError:
                    # Different drive on Windows: not under the old root.
                    continue
                if rel == os.pardir or rel.startswith(os.pardir + os.sep):
                    # Outside the old experiment folder; leave it alone.
                    continue
                loaded[key] = os.path.normpath(os.path.join(experiment_path, rel))

    loaded["experiment_path"] = experiment_path
    loaded["config_path"] = config_path
    loaded.setdefault("experiments_dir", self.config.get("experiments_dir"))

    # Resolve experiment paths relative to the experiment root when needed:
    # missing/empty keys get the default layout, relative values are joined.
    def _abs_path(key, default_subpath):
        if key not in loaded or not loaded[key]:
            loaded[key] = os.path.join(experiment_path, default_subpath)
        elif not os.path.isabs(loaded[key]):
            loaded[key] = os.path.join(experiment_path, loaded[key])

    _abs_path("data_dir", "data")
    _abs_path("raw_videos_dir", os.path.join("data", "raw_videos"))
    _abs_path("clips_dir", os.path.join("data", "clips"))
    _abs_path("annotations_dir", os.path.join("data", "annotations"))
    _abs_path("models_dir", os.path.join("models", "behavior_heads"))
    _abs_path("annotation_file", os.path.join("data", "annotations", "annotations.json"))

    # The training_* aliases are resolved only when present and relative;
    # missing keys are left missing, as before.
    for key in ("training_clips_dir", "training_annotation_file"):
        value = loaded.get(key)
        if isinstance(value, str) and value and not os.path.isabs(value):
            loaded[key] = os.path.join(experiment_path, value)

    self._update_config(loaded)
    self._apply_config_to_widgets()
    self._labeling_setup_prompted_for_clusters = False
    self._update_window_title()
    QMessageBox.information(self, "Experiment loaded", f"Loaded experiment from:\n{config_path}")
|
|
710
|
+
|
|
711
|
+
def _save_experiment(self):
    """Save the current experiment configuration and report the outcome.

    Shows a warning when no experiment is loaded (no ``config_path`` in the
    config). Shows an error dialog if the config cannot be written or
    serialized. Otherwise shows a confirmation naming the written file.
    """
    config_path = self.config.get("config_path")
    if not config_path:
        QMessageBox.warning(self, "Save experiment", "No experiment is currently loaded.")
        return
    try:
        self._save_experiment_config()
    except (OSError, yaml.YAMLError) as exc:
        # Surface the failure to the user instead of letting it escape the
        # UI handler (and instead of claiming success).
        QMessageBox.critical(self, "Error", f"Failed to save experiment config:\n{exc}")
        return
    QMessageBox.information(self, "Experiment saved", f"Experiment saved to:\n{config_path}")
|
|
719
|
+
|
|
720
|
+
def _save_experiment_config(self):
    """Write the current config to ``self.config["config_path"]`` as YAML.

    Does nothing when no config path is set. The YAML text is produced in
    full before any file is touched. It is then written to a sibling
    temporary file that atomically replaces the target, so a serialization
    or write failure leaves the previous config file intact instead of
    truncated.

    Raises:
        OSError: if the directory or file cannot be created/written.
        yaml.YAMLError: if a config value cannot be represented by
            ``yaml.safe_dump``.
    """
    config_path = self.config.get("config_path")
    if not config_path:
        return
    config_dir = os.path.dirname(config_path)
    if config_dir:  # a bare filename has dirname "" and os.makedirs("") fails
        os.makedirs(config_dir, exist_ok=True)

    # Serialize first: a RepresenterError here must not clobber the old file.
    text = yaml.safe_dump(dict(self.config), sort_keys=False)

    tmp_path = f"{config_path}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            f.write(text)
        os.replace(tmp_path, config_path)
    except BaseException:
        # Best-effort cleanup of the partial temp file; the original error
        # is re-raised either way.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
|
|
728
|
+
|
|
729
|
+
def _update_config(self, new_values: dict):
    """Replace the contents of the shared config dict in place.

    The dict object itself is kept and only its contents change, so anything
    holding a reference to ``self.config`` sees the new values.
    ``new_values`` is copied before clearing, so passing ``self.config``
    itself (or a mapping backed by it) is safe and leaves it unchanged.
    """
    # Snapshot first: if new_values aliases self.config, clear() would
    # otherwise empty the source before update() reads it.
    snapshot = dict(new_values)
    self.config.clear()
    self.config.update(snapshot)
|
|
733
|
+
|
|
734
|
+
def _apply_config_to_widgets(self):
    """Hand the shared config to each sub-widget via its ``update_config``.

    Widgets are notified in a fixed order, from segmentation through
    analysis. Each widget attribute is looked up only when its turn comes.
    """
    for widget_attr in (
        "segmentation_widget",
        "registration_widget",
        "clustering_widget",
        "labeling_widget",
        "training_widget",
        "inference_widget",
        "review_widget",
        "analysis_widget",
    ):
        getattr(self, widget_attr).update_config(self.config)
|
|
744
|
+
|
|
745
|
+
def _update_window_title(self):
    """Set the window title, appending the experiment name when one is set.

    A missing or empty ``experiment_name`` leaves just the application name.
    """
    title = "SingleBehavior Lab"
    experiment = self.config.get("experiment_name")
    if experiment:
        title = f"{title} - {experiment}"
    self.setWindowTitle(title)
|
|
753
|
+
|
|
754
|
+
|