singlebehaviorlab-2.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sam2/__init__.py +11 -0
- sam2/automatic_mask_generator.py +454 -0
- sam2/benchmark.py +92 -0
- sam2/build_sam.py +174 -0
- sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
- sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
- sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
- sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
- sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
- sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
- sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
- sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
- sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
- sam2/modeling/__init__.py +5 -0
- sam2/modeling/backbones/__init__.py +5 -0
- sam2/modeling/backbones/hieradet.py +317 -0
- sam2/modeling/backbones/image_encoder.py +134 -0
- sam2/modeling/backbones/utils.py +93 -0
- sam2/modeling/memory_attention.py +169 -0
- sam2/modeling/memory_encoder.py +181 -0
- sam2/modeling/position_encoding.py +239 -0
- sam2/modeling/sam/__init__.py +5 -0
- sam2/modeling/sam/mask_decoder.py +295 -0
- sam2/modeling/sam/prompt_encoder.py +202 -0
- sam2/modeling/sam/transformer.py +311 -0
- sam2/modeling/sam2_base.py +913 -0
- sam2/modeling/sam2_utils.py +323 -0
- sam2/sam2_hiera_b+.yaml +113 -0
- sam2/sam2_hiera_l.yaml +117 -0
- sam2/sam2_hiera_s.yaml +116 -0
- sam2/sam2_hiera_t.yaml +118 -0
- sam2/sam2_image_predictor.py +466 -0
- sam2/sam2_video_predictor.py +1388 -0
- sam2/sam2_video_predictor_legacy.py +1172 -0
- sam2/utils/__init__.py +5 -0
- sam2/utils/amg.py +348 -0
- sam2/utils/misc.py +349 -0
- sam2/utils/transforms.py +118 -0
- singlebehaviorlab/__init__.py +4 -0
- singlebehaviorlab/__main__.py +130 -0
- singlebehaviorlab/_paths.py +100 -0
- singlebehaviorlab/backend/__init__.py +2 -0
- singlebehaviorlab/backend/augmentations.py +320 -0
- singlebehaviorlab/backend/data_store.py +420 -0
- singlebehaviorlab/backend/model.py +1290 -0
- singlebehaviorlab/backend/train.py +4667 -0
- singlebehaviorlab/backend/uncertainty.py +578 -0
- singlebehaviorlab/backend/video_processor.py +688 -0
- singlebehaviorlab/backend/video_utils.py +139 -0
- singlebehaviorlab/data/config/config.yaml +85 -0
- singlebehaviorlab/data/training_profiles.json +334 -0
- singlebehaviorlab/gui/__init__.py +4 -0
- singlebehaviorlab/gui/analysis_widget.py +2291 -0
- singlebehaviorlab/gui/attention_export.py +311 -0
- singlebehaviorlab/gui/clip_extraction_widget.py +481 -0
- singlebehaviorlab/gui/clustering_widget.py +3187 -0
- singlebehaviorlab/gui/inference_popups.py +1138 -0
- singlebehaviorlab/gui/inference_widget.py +4550 -0
- singlebehaviorlab/gui/inference_worker.py +651 -0
- singlebehaviorlab/gui/labeling_widget.py +2324 -0
- singlebehaviorlab/gui/main_window.py +754 -0
- singlebehaviorlab/gui/metadata_management_widget.py +1119 -0
- singlebehaviorlab/gui/motion_tracking.py +764 -0
- singlebehaviorlab/gui/overlay_export.py +1234 -0
- singlebehaviorlab/gui/plot_integration.py +729 -0
- singlebehaviorlab/gui/qt_helpers.py +29 -0
- singlebehaviorlab/gui/registration_widget.py +1485 -0
- singlebehaviorlab/gui/review_widget.py +1330 -0
- singlebehaviorlab/gui/segmentation_tracking_widget.py +2752 -0
- singlebehaviorlab/gui/tab_tutorial_dialog.py +312 -0
- singlebehaviorlab/gui/timeline_themes.py +131 -0
- singlebehaviorlab/gui/training_profiles.py +418 -0
- singlebehaviorlab/gui/training_widget.py +3719 -0
- singlebehaviorlab/gui/video_utils.py +233 -0
- singlebehaviorlab/licenses/SAM2-LICENSE +201 -0
- singlebehaviorlab/licenses/VideoPrism-LICENSE +202 -0
- singlebehaviorlab-2.0.0.dist-info/METADATA +447 -0
- singlebehaviorlab-2.0.0.dist-info/RECORD +88 -0
- singlebehaviorlab-2.0.0.dist-info/WHEEL +5 -0
- singlebehaviorlab-2.0.0.dist-info/entry_points.txt +2 -0
- singlebehaviorlab-2.0.0.dist-info/licenses/LICENSE +21 -0
- singlebehaviorlab-2.0.0.dist-info/top_level.txt +3 -0
- videoprism/__init__.py +0 -0
- videoprism/encoders.py +910 -0
- videoprism/layers.py +1136 -0
- videoprism/models.py +407 -0
- videoprism/tokenizers.py +167 -0
- videoprism/utils.py +168 -0
singlebehaviorlab/gui/registration_widget.py
@@ -0,0 +1,1485 @@
"""Registration widget for processing videos with masks from segmentation."""
import sys
import os
import glob
# JAX memory: grow on demand, capped at 45% so PyTorch (SAM2) keeps the rest
os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")
os.environ.setdefault("XLA_PYTHON_CLIENT_MEM_FRACTION", "0.45")

import cv2
import numpy as np
import pandas as pd
import torch
import re
# Add parent directory to path for backend imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from PyQt6.QtWidgets import (
    QWidget, QVBoxLayout, QHBoxLayout, QPushButton, QLabel,
    QFileDialog, QGroupBox, QFormLayout, QSpinBox, QComboBox, QProgressBar,
    QTextEdit, QMessageBox, QListWidget, QListWidgetItem, QDialog, QDialogButtonBox,
    QSlider, QApplication, QCheckBox, QProgressDialog
)
from PyQt6.QtCore import QThread, pyqtSignal, Qt, QTimer
from PyQt6.QtGui import QImage, QPixmap
from singlebehaviorlab.backend.video_processor import process_video, process_video_to_clips, load_segmentation_data
from singlebehaviorlab.backend.model import VideoPrismBackbone


class VideoPlayerDialog(QDialog):
    """Dialog for playing video files using OpenCV with streaming playback (no full caching)."""
    def __init__(self, video_paths: list, start_index: int = 0, parent=None):
        super().__init__(parent)
        self.video_paths = video_paths
        self.current_video_idx = start_index

        self.cap = None  # VideoCapture for streaming
        self.total_frames = 0
        self.fps = 30.0
        self.current_frame_idx = 0
        self.is_playing = False
        self.timer = QTimer(self)
        self.timer.timeout.connect(self._advance_frame)
        self.slider_pressed = False
        self._current_frame = None  # Keep reference for QImage

        self._setup_ui()
        self._load_current_video()

    def _setup_ui(self):
        self.setMinimumSize(800, 600)
        layout = QVBoxLayout(self)

        # Navigation
        nav_layout = QHBoxLayout()
        self.prev_btn = QPushButton("Previous clip")
        self.prev_btn.clicked.connect(self._prev_video)
        nav_layout.addWidget(self.prev_btn)

        self.title_label = QLabel()
        self.title_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
        nav_layout.addWidget(self.title_label)

        self.next_btn = QPushButton("Next clip")
        self.next_btn.clicked.connect(self._next_video)
        nav_layout.addWidget(self.next_btn)
        layout.addLayout(nav_layout)

        # Video label
        self.video_label = QLabel()
        self.video_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
        self.video_label.setMinimumSize(640, 480)
        self.video_label.setStyleSheet("background-color: black;")
        layout.addWidget(self.video_label)

        # Slider
        self.slider = QSlider(Qt.Orientation.Horizontal)
        self.slider.sliderPressed.connect(self._on_slider_pressed)
        self.slider.sliderReleased.connect(self._on_slider_released)
        self.slider.valueChanged.connect(self._on_slider_changed)
        layout.addWidget(self.slider)

        # Controls
        controls_layout = QHBoxLayout()

        self.play_btn = QPushButton("Play")
        self.play_btn.clicked.connect(self._toggle_play)
        controls_layout.addWidget(self.play_btn)

        self.stop_btn = QPushButton("Stop")
        self.stop_btn.clicked.connect(self._stop)
        controls_layout.addWidget(self.stop_btn)

        self.frame_label = QLabel("Frame: 0 / 0")
        controls_layout.addWidget(self.frame_label)

        controls_layout.addStretch()

        close_btn = QPushButton("Close")
        close_btn.clicked.connect(self.accept)
        controls_layout.addWidget(close_btn)

        layout.addLayout(controls_layout)

    def _load_current_video(self):
        """Load current video for streaming playback."""
        if not self.video_paths:
            return

        video_path = self.video_paths[self.current_video_idx]
        self.setWindowTitle(f"Video Player - {os.path.basename(video_path)}")
        self.title_label.setText(f"{self.current_video_idx + 1} / {len(self.video_paths)}: {os.path.basename(video_path)}")

        # Update nav buttons
        self.prev_btn.setEnabled(self.current_video_idx > 0)
        self.next_btn.setEnabled(self.current_video_idx < len(self.video_paths) - 1)

        # Stop playback and release old capture
        self._stop()
        if self.cap is not None:
            self.cap.release()
            self.cap = None

        self._load_video_content(video_path)

    def _prev_video(self):
        if self.current_video_idx > 0:
            self.current_video_idx -= 1
            self._load_current_video()

    def _next_video(self):
        if self.current_video_idx < len(self.video_paths) - 1:
            self.current_video_idx += 1
            self._load_current_video()

    def _load_video_content(self, video_path):
        """Open video for streaming playback (memory efficient)."""
        try:
            self.video_label.setText("Loading video...")
            QApplication.processEvents()

            self.cap = cv2.VideoCapture(video_path)
            if not self.cap.isOpened():
                QMessageBox.critical(self, "Error", f"Could not open video: {video_path}")
                return

            self.fps = self.cap.get(cv2.CAP_PROP_FPS) or 30.0
            self.total_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))

            if self.total_frames <= 0:
                QMessageBox.warning(self, "Error", "Could not determine video length.")
                return

            # Setup slider and labels
            self.slider.setRange(0, self.total_frames - 1)
            self.frame_label.setText(f"Frame: 1 / {self.total_frames}")

            # Setup timer
            interval = int(1000 / self.fps) if self.fps > 0 else 33
            self.timer.setInterval(interval)

            # Show first frame
            self._display_frame(0)

        except Exception as e:
            QMessageBox.critical(self, "Error", f"Failed to load video: {str(e)}")

    def _display_frame(self, idx):
        """Display frame at index by seeking in video (streaming)."""
        if self.cap is None or not self.cap.isOpened():
            return
        if not (0 <= idx < self.total_frames):
            return

        self.cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
        ret, frame = self.cap.read()
        if not ret:
            return

        # Convert BGR to RGB and keep reference to prevent QImage crash
        self._current_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        if not self._current_frame.flags['C_CONTIGUOUS']:
            self._current_frame = np.ascontiguousarray(self._current_frame)

        h, w, ch = self._current_frame.shape
        bytes_per_line = ch * w

        q_image = QImage(self._current_frame.data, w, h, bytes_per_line, QImage.Format.Format_RGB888)
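        # QImage wraps the ndarray's buffer without copying it, so the
        # self._current_frame reference kept above must outlive the QImage.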
        pixmap = QPixmap.fromImage(q_image)

        scaled_pixmap = pixmap.scaled(
            self.video_label.size(),
            Qt.AspectRatioMode.KeepAspectRatio,
            Qt.TransformationMode.SmoothTransformation
        )
        self.video_label.setPixmap(scaled_pixmap)
        self.frame_label.setText(f"Frame: {idx + 1} / {self.total_frames}")

    def _advance_frame(self):
        """Advance to next frame."""
        if not self.is_playing or self.total_frames <= 0:
            return

        self.current_frame_idx = (self.current_frame_idx + 1) % self.total_frames
        self.slider.blockSignals(True)
        self.slider.setValue(self.current_frame_idx)
        self.slider.blockSignals(False)
        self._display_frame(self.current_frame_idx)

    def _on_slider_changed(self, value):
        """Handle slider value change."""
        self.current_frame_idx = value
        self._display_frame(value)

    def _on_slider_pressed(self):
        """Handle slider press (pause playback)."""
        self.slider_pressed = True
        if self.is_playing:
            self.timer.stop()

    def _on_slider_released(self):
        """Handle slider release (resume if was playing)."""
        self.slider_pressed = False
        if self.is_playing:
            self.timer.start()

    def _toggle_play(self):
        """Toggle play/pause."""
        if self.is_playing:
            self.timer.stop()
            self.is_playing = False
            self.play_btn.setText("Play")
        else:
            self.timer.start()
            self.is_playing = True
            self.play_btn.setText("Pause")

    def _stop(self):
        """Stop playback and reset."""
        self.timer.stop()
        self.is_playing = False
        self.play_btn.setText("Play")
        self.current_frame_idx = 0
        self.slider.setValue(0)
        self._display_frame(0)

    def closeEvent(self, event):
        """Clean up."""
        self.timer.stop()
        if self.cap is not None:
            self.cap.release()
            self.cap = None
        self._current_frame = None
        super().closeEvent(event)


class ProcessingWorker(QThread):
    """Worker thread for video processing into clips."""
    progress = pyqtSignal(int, int, str)  # current, total, video_name
    finished = pyqtSignal(list)  # List of output clip paths
    error = pyqtSignal(str)

    def __init__(self, video_mask_pairs, output_dir, params):
        super().__init__()
        self.video_mask_pairs = video_mask_pairs
        self.output_dir = output_dir
        self.params = params
        self.should_stop = False
        self.output_paths = []

    def stop(self):
        self.should_stop = True

    def run(self):
        try:
            total = len(self.video_mask_pairs)
            for i, (video_path, mask_path) in enumerate(self.video_mask_pairs):
                if self.should_stop:
                    break

                video_name = os.path.basename(video_path)
                self.progress.emit(i + 1, total, video_name)

                # Generate output directory for clips
                base_name = os.path.splitext(os.path.basename(video_path))[0]
                video_clips_dir = os.path.join(self.output_dir, base_name)

                # Progress callback
                def progress_cb(clip_num, total_clips=None, obj_id=None):
                    if self.should_stop:
                        return
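                # Note: progress_cb only consults the stop flag; per-video
                # progress is reported via the worker's progress signal above.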

                # Process video into clips (may return list if multiple objects)
                clip_paths = process_video_to_clips(
                    video_path=video_path,
                    mask_path=mask_path,
                    output_dir=video_clips_dir,
                    box_size=self.params['box_size'],
                    target_size=self.params['target_size'],
                    background_mode=self.params['background_mode'],
                    normalization_method=self.params['normalization_method'],
                    mask_feather_px=self.params['mask_feather_px'],
                    anchor_mode=self.params['anchor_mode'],
                    target_fps=self.params['target_fps'],
                    clip_length_frames=self.params['clip_length_frames'],
                    step_frames=self.params['step_frames'],
                    progress_callback=progress_cb
                )

                if clip_paths:
                    # clip_paths is now list of (clip_path, start_frame, end_frame) tuples
                    # Store them for later use in embedding extraction
                    self.output_paths.extend(clip_paths)
                else:
                    self.error.emit(f"Failed to process {video_name} - no clips created")

            if not self.should_stop:
                self.finished.emit(self.output_paths)

        except Exception as e:
            self.error.emit(str(e))


class EmbeddingExtractionWorker(QThread):
    """Worker thread for extracting VideoPrism embeddings from clips."""
    progress = pyqtSignal(int, int, str)  # current, total, clip_name
    finished = pyqtSignal(str, str)  # feature_matrix_path, metadata_path
    error = pyqtSignal(str)
    log_message = pyqtSignal(str)

    def __init__(self, clip_paths: list, output_dir: str, experiment_name: str = None, model_name: str = 'videoprism_public_v1_base', clip_frame_ranges: dict = None, append_to_existing: bool = False):
        super().__init__()
        self.clip_paths = clip_paths  # List of clip paths (strings)
        self.clip_frame_ranges = clip_frame_ranges or {}  # Dict mapping clip_path -> (start_frame, end_frame)
        self.output_dir = output_dir
        self.experiment_name = experiment_name
        self.model_name = model_name
        self.should_stop = False
        self.append_to_existing = append_to_existing

    def stop(self):
        self.should_stop = True

    def run(self):
        try:
            self.log_message.emit(f"Loading VideoPrism model: {self.model_name}...")

            # Clear PyTorch cache to free up VRAM for JAX
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            backbone = VideoPrismBackbone(model_name=self.model_name, log_fn=self.log_message.emit)
            backbone.eval()

            embed_dim = backbone.get_embed_dim()
            self.log_message.emit(f"VideoPrism model loaded. Embedding dimension: {embed_dim}")

            feature_matrix = []
            metadata = []

            total = len(self.clip_paths)
            self.log_message.emit(f"Processing {total} clips...")

            for i, clip_path in enumerate(self.clip_paths):
                if self.should_stop:
                    break

                clip_name = os.path.basename(clip_path)
                self.progress.emit(i + 1, total, clip_name)

                # Load clip frames
                frames = self._load_clip_frames(clip_path)
                if frames is None or len(frames) == 0:
                    self.log_message.emit(f"Warning: Could not load frames from {clip_name}, skipping")
                    continue

                # Extract embedding
                embedding = self._extract_embedding(backbone, frames)

                # Free frames memory immediately after use
                del frames

                if embedding is None:
                    self.log_message.emit(f"Warning: Could not extract embedding from {clip_name}, skipping")
                    continue

                feature_matrix.append(embedding.tolist())

                # Free embedding after converting to list
                del embedding

                # Periodically clear CUDA cache to prevent GPU memory accumulation
                if (i + 1) % 50 == 0 and torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    import gc
                    gc.collect()

                # Parse clip name for metadata
                # Format: clip_XXXXXX_objY.mp4 or clip_XXXXXX.mp4
                base_name = os.path.splitext(clip_name)[0]
                video_dir = os.path.dirname(clip_path)
                video_name = os.path.basename(video_dir)

                # Extract object ID if present
                obj_match = re.search(r'_obj(\d+)', base_name)
                obj_id = obj_match.group(1) if obj_match else None

                # Extract clip index
                clip_match = re.search(r'clip_(\d+)', base_name)
                clip_idx = int(clip_match.group(1)) if clip_match else i

                # Get frame range if available
                start_frame = None
                end_frame = None
                if clip_path in self.clip_frame_ranges:
                    start_frame, end_frame = self.clip_frame_ranges[clip_path]

                metadata.append({
                    'snippet': f'snippet{i+1}',
                    'group': video_name,
                    'video_id': clip_name,
                    'object_id': obj_id if obj_id else '',
                    'clip_index': clip_idx,
                    'start_frame': start_frame if start_frame is not None else '',
                    'end_frame': end_frame if end_frame is not None else ''
                })

            if self.should_stop:
                self.log_message.emit("Extraction stopped by user.")
                return

            if not feature_matrix:
                self.error.emit("No embeddings extracted. Please check that clips are valid.")
                return

            # Convert to numpy array
            feature_matrix = np.array(feature_matrix, dtype=np.float32)  # downcast to save space
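            # Resulting shape: (num_clips, embed_dim) - one pooled vector per clip.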
            self.log_message.emit(f"Extracted {len(feature_matrix)} embeddings. Shape: {feature_matrix.shape}")

            # Prepare data
            num_snippets = feature_matrix.shape[0]
            num_features = feature_matrix.shape[1]
            metadata_df = pd.DataFrame(metadata)

            # Decide whether to append to existing embeddings
            append_used = False
            existing_matrix_path = None
            existing_metadata_path = None
            if self.append_to_existing:
                candidates = sorted(glob.glob(os.path.join(self.output_dir, "*_matrix.npz")), key=os.path.getmtime, reverse=True)
                if candidates:
                    existing_matrix_path = candidates[0]
                    guess_meta = existing_matrix_path.replace("_matrix.npz", "_metadata.npz")
                    if os.path.exists(guess_meta):
                        existing_metadata_path = guess_meta
                if not existing_matrix_path or not existing_metadata_path:
                    self.log_message.emit("Append requested but no existing matrix/metadata found. Creating new files.")

            snippet_ids = [f'snippet{i+1}' for i in range(num_snippets)]
            feature_names = [f'behaviorome_embedding_{i}' for i in range(num_features)]

            # If appending, load existing and merge
            if existing_matrix_path and existing_metadata_path:
                try:
                    existing_npz = np.load(existing_matrix_path, allow_pickle=True)
                    existing_matrix = existing_npz["matrix"]  # features x snippets
                    existing_feature_names = existing_npz["feature_names"]
                    existing_snippet_ids = existing_npz["snippet_ids"]

                    # Validate feature dimension
                    if existing_matrix.shape[0] != feature_matrix.shape[1]:
                        self.log_message.emit("Append skipped: feature dimension mismatch. Writing new files instead.")
                    else:
                        # Load existing metadata
                        existing_meta_npz = np.load(existing_metadata_path, allow_pickle=True)
                        existing_meta = pd.DataFrame(existing_meta_npz["metadata"], columns=existing_meta_npz["columns"])

                        offset = existing_matrix.shape[1]
                        snippet_ids = [f'snippet{offset + i + 1}' for i in range(num_snippets)]
                        metadata_df['snippet'] = snippet_ids

                        # Combine matrices and metadata
                        feature_matrix = feature_matrix.T  # to features x new_snippets
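                        # Convention: stored matrices are features x snippets,
                        # while the in-memory feature_matrix is snippets x
                        # features, hence the transposes here and below.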
                        combined_matrix = np.concatenate([existing_matrix, feature_matrix], axis=1)
                        combined_snippet_ids = np.concatenate([existing_snippet_ids, np.array(snippet_ids)])
                        combined_feature_names = existing_feature_names
                        combined_meta = pd.concat([existing_meta, metadata_df], ignore_index=True)

                        # Overwrite base name with existing
                        base_name = os.path.basename(existing_matrix_path).replace("_matrix.npz", "")
                        npz_matrix_path = existing_matrix_path
                        npz_metadata_path = existing_metadata_path

                        feature_matrix = combined_matrix.T  # convert back to snippets x features for downstream naming
                        feature_names = combined_feature_names.tolist()
                        snippet_ids = combined_snippet_ids.tolist()
                        metadata_df = combined_meta
                        append_used = True
                        self.log_message.emit(f"Appending to existing embeddings: {os.path.basename(existing_matrix_path)}")
                except Exception as e:
                    self.log_message.emit(f"Append failed, writing new files instead: {e}")

            if not append_used:
                # Generate filename with experiment name and timestamp
                from datetime import datetime
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

                # Build base filename
                if self.experiment_name:
                    base_name = f"behaviorome_{self.experiment_name}_{timestamp}"
                else:
                    base_name = f"behaviorome_{timestamp}"

                # Build paths
                npz_matrix_path = os.path.join(self.output_dir, f'{base_name}_matrix.npz')
                npz_metadata_path = os.path.join(self.output_dir, f'{base_name}_metadata.npz')

            # Save as NPZ (fastest, most efficient for large datasets)

            try:
                np.savez_compressed(
                    npz_matrix_path,
                    matrix=feature_matrix.T,  # features x snippets
                    feature_names=np.array(feature_names),
                    snippet_ids=np.array(snippet_ids),
                )
                self.log_message.emit(f"Saved feature matrix (NPZ) to {npz_matrix_path}")
            except Exception as e:
                self.log_message.emit(f"NPZ save failed (matrix): {e}")
                npz_matrix_path = None

            try:
                # Save metadata as NPZ
                np.savez_compressed(
                    npz_metadata_path,
                    metadata=metadata_df.values,
                    columns=np.array(metadata_df.columns),
                )
                self.log_message.emit(f"Saved metadata (NPZ) to {npz_metadata_path}")
            except Exception as e:
                self.log_message.emit(f"NPZ save failed (metadata): {e}")
                npz_metadata_path = None
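            # To reload these later (sketch, mirroring the keys saved above):
            #   data = np.load(npz_matrix_path, allow_pickle=True)
            #   matrix = data["matrix"]  # features x snippets
            #   feature_names, snippet_ids = data["feature_names"], data["snippet_ids"]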

            # Also save Parquet as backup (faster than CSV, still readable)
            try:
                matrix_df = pd.DataFrame(feature_matrix.T, index=feature_names, columns=snippet_ids)
                parquet_matrix_path = os.path.join(self.output_dir, f'{base_name}_matrix.parquet')
                matrix_df.to_parquet(parquet_matrix_path, index=True)
                self.log_message.emit(f"Saved feature matrix (Parquet) to {parquet_matrix_path}")
            except Exception as e:
                self.log_message.emit(f"Parquet save failed (matrix): {e}")

            try:
                parquet_metadata_path = os.path.join(self.output_dir, f'{base_name}_metadata.parquet')
                metadata_df.to_parquet(parquet_metadata_path, index=False)
                self.log_message.emit(f"Saved metadata (Parquet) to {parquet_metadata_path}")
            except Exception as e:
                self.log_message.emit(f"Parquet save failed (metadata): {e}")

            # Emit NPZ paths (primary format)
            self.finished.emit(npz_matrix_path if npz_matrix_path else parquet_matrix_path,
                               npz_metadata_path if npz_metadata_path else parquet_metadata_path)

        except Exception as e:
            import traceback
            error_msg = f"Error extracting embeddings: {str(e)}\n{traceback.format_exc()}"
            self.log_message.emit(error_msg)
            self.error.emit(str(e))

    def _load_clip_frames(self, clip_path: str) -> np.ndarray:
        """Load all frames from a clip video."""
        cap = cv2.VideoCapture(clip_path)
        if not cap.isOpened():
            return None

        frames = []
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # Convert BGR to RGB
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(frame_rgb)

        cap.release()
        return np.array(frames) if frames else None

    def _extract_embedding(self, backbone: VideoPrismBackbone, frames: np.ndarray) -> np.ndarray:
        """Extract mean-pooled VideoPrism embedding from frames."""
        try:
            # Resize frames to 288x288 (VideoPrism expects this)
            target_size = 288
            processed_frames = []
            for frame in frames:
                resized = cv2.resize(frame, (target_size, target_size))
                processed_frames.append(resized)
            frames_resized = np.array(processed_frames)
            del processed_frames  # Free list memory

            # Convert to PyTorch format: (T, C, H, W) and normalize to [0, 1]
            frames_t = np.transpose(frames_resized, (0, 3, 1, 2))  # (T, C, H, W)
            del frames_resized  # Free numpy array

            frames_tensor = torch.from_numpy(frames_t).float() / 255.0
            del frames_t  # Free numpy array

            # Add batch dimension: (1, T, C, H, W)
            frames_tensor = frames_tensor.unsqueeze(0)

            with torch.no_grad():
                # VideoPrism returns (B, T*N, D) where N = 16*16 = 256
                tokens = backbone(frames_tensor)  # (1, T*256, D)
                del frames_tensor  # Free input tensor immediately

                # Mean pool over all tokens to get single embedding vector
                embedding = tokens.mean(dim=1).squeeze(0)  # (D,)
                del tokens  # Free large token tensor

                result = embedding.cpu().numpy()
                del embedding  # Free GPU tensor
                return result

        except Exception as e:
            self.log_message.emit(f"Error extracting embedding: {e}")
            return None


class RegistrationWidget(QWidget):
    """Widget for animal registration video processing."""

    # Signal emitted when embeddings are extracted (for auto-loading in clustering tab)
    embeddings_extracted = pyqtSignal(str, str)  # matrix_path, metadata_path

    def __init__(self, config: dict):
        super().__init__()
        self.config = config
        self.video_mask_pairs = []  # List of (video_path, mask_path) tuples
        self.worker = None
        self.embedding_worker = None
        self.output_dir = ""  # Initialize output_dir
        self.processed_videos = []  # List of processed video paths
        self.clip_frame_ranges = {}  # Dict mapping clip_path -> (start_frame, end_frame)
        self._setup_ui()

    def load_from_segmentation(self, video_path: str, mask_path: str):
        """Load video and mask from segmentation tab."""
        # Clear existing
        self.video_list.clear()
        self.mask_list.clear()
        self.video_mask_pairs = []

        # Add video
        if os.path.exists(video_path):
            item = QListWidgetItem(os.path.basename(video_path))
            item.setData(Qt.ItemDataRole.UserRole, video_path)
            self.video_list.addItem(item)

        # Add mask
        if os.path.exists(mask_path):
            item = QListWidgetItem(os.path.basename(mask_path))
            item.setData(Qt.ItemDataRole.UserRole, mask_path)
            self.mask_list.addItem(item)

        # Auto-match
        self._match_videos_masks()

        # Update output directory display
        self._update_pairs_label()

        # Check if clips already exist and enable extract embeddings button
        self._check_existing_clips()

    def _setup_ui(self):
        layout = QVBoxLayout()

        # Top row: File Selection and Processing Parameters side by side
        top_row_layout = QHBoxLayout()

        # File Selection (left side)
        files_group = QGroupBox("File selection")
        files_layout = QVBoxLayout()

        # Videos
        video_layout = QHBoxLayout()
        self.video_list = QListWidget()
        self.video_list.setMaximumHeight(100)
        video_layout.addWidget(self.video_list)

        video_btn_layout = QVBoxLayout()
        self.add_video_btn = QPushButton("Add videos")
        self.add_video_btn.clicked.connect(self._add_videos)
        video_btn_layout.addWidget(self.add_video_btn)

        self.remove_video_btn = QPushButton("Remove selected")
        self.remove_video_btn.clicked.connect(self._remove_video)
        video_btn_layout.addWidget(self.remove_video_btn)
        video_btn_layout.addStretch()

        video_layout.addLayout(video_btn_layout)
        files_layout.addLayout(video_layout)

        # Masks
        mask_layout = QHBoxLayout()
        self.mask_list = QListWidget()
        self.mask_list.setMaximumHeight(100)
        mask_layout.addWidget(self.mask_list)

        mask_btn_layout = QVBoxLayout()
        self.add_mask_btn = QPushButton("Add mask files")
        self.add_mask_btn.clicked.connect(self._add_masks)
        mask_btn_layout.addWidget(self.add_mask_btn)

        self.remove_mask_btn = QPushButton("Remove selected")
        self.remove_mask_btn.clicked.connect(self._remove_mask)
        mask_btn_layout.addWidget(self.remove_mask_btn)
        mask_btn_layout.addStretch()

        mask_layout.addLayout(mask_btn_layout)
        files_layout.addLayout(mask_layout)

        # Help button and pairs label
        help_layout = QHBoxLayout()
        self.pairs_label = QLabel("0 video-mask pairs ready")
        help_layout.addWidget(self.pairs_label)
        help_layout.addStretch()

        # Help button (small circular button)
        self.help_btn = QPushButton("?")
        self.help_btn.setMaximumSize(25, 25)
        self.help_btn.setToolTip("Click for naming information")
        self.help_btn.clicked.connect(self._show_naming_help)
        help_layout.addWidget(self.help_btn)

        files_layout.addLayout(help_layout)

        # Load Clips Folder button for extracting embeddings from existing clips
        self.load_clips_btn = QPushButton("Already have clips? Load")
        self.load_clips_btn.clicked.connect(self._load_clips_folder)
        self.load_clips_btn.setToolTip("Load a folder of existing processed clips to extract behaviorome embeddings")
        files_layout.addWidget(self.load_clips_btn)

        files_group.setLayout(files_layout)
        top_row_layout.addWidget(files_group)

        # Processing Parameters (right side)
        params_container = QVBoxLayout()

        # Processing Parameters
        params_group = QGroupBox("Processing parameters")
        params_layout = QFormLayout()

        self.box_size_spin = QSpinBox()
        self.box_size_spin.setRange(50, 1000)
        self.box_size_spin.setValue(250)
        params_layout.addRow("Crop Box Size (px):", self.box_size_spin)

        self.target_size_spin = QSpinBox()
        self.target_size_spin.setRange(50, 1000)
        self.target_size_spin.setValue(288)
        params_layout.addRow("Output Size (px):", self.target_size_spin)

        self.background_combo = QComboBox()
        self.background_combo.addItems(["white", "black", "gray", "blur", "none"])
        self.background_combo.setCurrentText("white")
        params_layout.addRow("Background Mode:", self.background_combo)

        self.normalization_combo = QComboBox()
        self.normalization_combo.addItems(["CLAHE", "Histogram Equalization", "Mean-Variance", "None"])
        self.normalization_combo.setCurrentText("CLAHE")
        params_layout.addRow("Normalization:", self.normalization_combo)

        self.feather_spin = QSpinBox()
        self.feather_spin.setRange(0, 50)
        self.feather_spin.setValue(0)
        params_layout.addRow("Mask Feathering (px):", self.feather_spin)

        # Anchor mode checkbox
        self.lock_roi_checkbox = QCheckBox("Lock ROI to first frame of clip")
        self.lock_roi_checkbox.setChecked(False)
        self.lock_roi_checkbox.setToolTip(
            "When checked, the crop box stays fixed at the first frame's centroid,\n"
            "allowing the object to move within the clip (preserves locomotion).\n"
            "When unchecked, the crop box follows the object's centroid each frame."
        )
        params_layout.addRow("", self.lock_roi_checkbox)

        params_group.setLayout(params_layout)
        params_container.addWidget(params_group)

        # Clip Extraction Parameters
        clip_params_group = QGroupBox("Clip extraction parameters")
        clip_params_layout = QFormLayout()

        self.target_fps_spin = QSpinBox()
        self.target_fps_spin.setRange(1, 60)
        self.target_fps_spin.setValue(int(self.config.get("default_target_fps", 12)))
        clip_params_layout.addRow("Target FPS:", self.target_fps_spin)

        self.clip_length_spin = QSpinBox()
        self.clip_length_spin.setRange(1, 64)
        self.clip_length_spin.setValue(int(self.config.get("default_clip_length", 8)))
        clip_params_layout.addRow("Frames per clip:", self.clip_length_spin)

        self.step_frames_spin = QSpinBox()
        self.step_frames_spin.setRange(1, 64)
        self.step_frames_spin.setValue(int(self.config.get("default_step_frames", 8)))
        clip_params_layout.addRow("Step frames:", self.step_frames_spin)

        clip_params_group.setLayout(clip_params_layout)
        params_container.addWidget(clip_params_group)

        # Create a widget to hold the params container
        params_widget = QWidget()
        params_widget.setLayout(params_container)
        top_row_layout.addWidget(params_widget)

        layout.addLayout(top_row_layout)

        # Output (auto-determined from experiment)
        output_group = QGroupBox("Output")
        output_layout = QVBoxLayout()

        self.output_dir_label = QLabel("Clips will be saved to experiment folder")
        output_layout.addWidget(self.output_dir_label)

        self.append_embeddings_check = QCheckBox("Append to existing embeddings if present")
        self.append_embeddings_check.setChecked(False)
        self.append_embeddings_check.setToolTip("When enabled, if an existing behaviorome matrix/metadata is found in the experiment, new embeddings will be appended instead of creating a new file.")
        output_layout.addWidget(self.append_embeddings_check)

        output_group.setLayout(output_layout)
        layout.addWidget(output_group)

        # Actions
        actions_group = QGroupBox("Actions")
        actions_layout = QHBoxLayout()

        self.process_btn = QPushButton("Process videos")
        self.process_btn.clicked.connect(self._start_processing)
        self.process_btn.setEnabled(False)
        actions_layout.addWidget(self.process_btn)

        self.stop_btn = QPushButton("Stop")
        self.stop_btn.clicked.connect(self._stop_processing)
        self.stop_btn.setEnabled(False)
        actions_layout.addWidget(self.stop_btn)

        self.view_videos_btn = QPushButton("View processed videos")
        self.view_videos_btn.clicked.connect(self._view_processed_videos)
        self.view_videos_btn.setEnabled(False)
        actions_layout.addWidget(self.view_videos_btn)

        self.extract_embeddings_btn = QPushButton("Extract behaviorome embeddings")
        self.extract_embeddings_btn.clicked.connect(self._extract_embeddings)
        self.extract_embeddings_btn.setEnabled(False)
        actions_layout.addWidget(self.extract_embeddings_btn)

        self.export_roi_btn = QPushButton("Export ROI videos (per object)")
        self.export_roi_btn.setToolTip(
            "Export one cropped video per tracked object.\n"
            "Videos are saved in the experiment folder (roi_videos)."
        )
        self.export_roi_btn.clicked.connect(self._export_roi_videos)
        self.export_roi_btn.setEnabled(False)
        actions_layout.addWidget(self.export_roi_btn)

        actions_group.setLayout(actions_layout)
        layout.addWidget(actions_group)

        # Progress
        progress_group = QGroupBox("Progress")
        progress_layout = QVBoxLayout()

        self.progress_bar = QProgressBar()
        progress_layout.addWidget(self.progress_bar)

        self.status_label = QLabel("Ready")
        progress_layout.addWidget(self.status_label)

        self.log_text = QTextEdit()
        self.log_text.setReadOnly(True)
        self.log_text.setMaximumHeight(150)
        progress_layout.addWidget(self.log_text)

        progress_group.setLayout(progress_layout)
        layout.addWidget(progress_group)

        layout.addStretch()
        self.setLayout(layout)

        # Check for existing clips
        self._check_existing_clips()

    def _add_videos(self):
        video_dir = self.config.get("raw_videos_dir", self.config.get("data_dir", "data/raw_videos"))
        paths, _ = QFileDialog.getOpenFileNames(
            self, "Select Video Files", video_dir, "Video Files (*.mp4 *.avi *.mov *.mkv)"
        )
        if paths:
            # Ensure videos are in experiment folder (batch operation)
            from .video_utils import ensure_videos_in_experiment
            paths = ensure_videos_in_experiment(paths, self.config, self)
            for path in paths:
                item = QListWidgetItem(os.path.basename(path))
                item.setData(Qt.ItemDataRole.UserRole, path)
                self.video_list.addItem(item)
            # Auto-match after adding videos
            self._match_videos_masks()
            self._update_pairs_label()

    def _remove_video(self):
        for item in self.video_list.selectedItems():
            self.video_list.takeItem(self.video_list.row(item))
        # Re-match after removal
        self._match_videos_masks()
        self._update_pairs_label()

    def _add_masks(self):
        # Check experiment folder first, then default location
        experiment_path = self.config.get("experiment_path")
        if experiment_path and os.path.exists(experiment_path):
            masks_dir = os.path.join(experiment_path, "masks")
        else:
            masks_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data", "masks")

        paths, _ = QFileDialog.getOpenFileNames(
            self, "Select Mask Files", masks_dir if os.path.exists(masks_dir) else "", "HDF5 Files (*.h5 *.hdf5)"
        )
        if paths:
            for path in paths:
                item = QListWidgetItem(os.path.basename(path))
                item.setData(Qt.ItemDataRole.UserRole, path)
                self.mask_list.addItem(item)
            # Auto-match after adding masks
            self._match_videos_masks()
            self._update_pairs_label()

    def _remove_mask(self):
        for item in self.mask_list.selectedItems():
            self.mask_list.takeItem(self.mask_list.row(item))
        # Re-match after removal
        self._match_videos_masks()
        self._update_pairs_label()

    def _match_videos_masks(self):
        """Match videos to masks based on filename."""
        self.video_mask_pairs = []

        videos = []
        for i in range(self.video_list.count()):
            item = self.video_list.item(i)
            videos.append((item.data(Qt.ItemDataRole.UserRole), item.text()))

        masks = []
        for i in range(self.mask_list.count()):
            item = self.mask_list.item(i)
            masks.append((item.data(Qt.ItemDataRole.UserRole), item.text()))

        # Match by base name
        matched = []
        unmatched_videos = []
        unmatched_masks = []

        for video_path, video_name in videos:
            base_video = os.path.splitext(video_name)[0]
            found = False
            for mask_path, mask_name in masks:
                base_mask = os.path.splitext(mask_name)[0]
                # Remove "_masks" suffix from mask name (most common case)
                base_mask_clean = base_mask.replace('_masks', '').replace('_mask', '').replace('_objects', '')
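                # e.g. "mouse1_masks" -> "mouse1", "mouse1_objects" -> "mouse1"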
|
|
969
|
+
base_video_clean = base_video.replace('_output', '').replace('_segmented', '')
|
|
970
|
+
|
|
971
|
+
# Try exact match first
|
|
972
|
+
if base_video == base_mask:
|
|
973
|
+
matched.append((video_path, mask_path))
|
|
974
|
+
found = True
|
|
975
|
+
break
|
|
976
|
+
# Try matching video name with mask name after removing "_masks"
|
|
977
|
+
elif base_video == base_mask_clean:
|
|
978
|
+
matched.append((video_path, mask_path))
|
|
979
|
+
found = True
|
|
980
|
+
break
|
|
981
|
+
# Try matching cleaned versions
|
|
982
|
+
elif base_video_clean == base_mask_clean:
|
|
983
|
+
matched.append((video_path, mask_path))
|
|
984
|
+
found = True
|
|
985
|
+
break
|
|
986
|
+
|
|
987
|
+
if not found:
|
|
988
|
+
unmatched_videos.append(video_name)
|
|
989
|
+
|
|
990
|
+
# Check for unmatched masks
|
|
991
|
+
matched_mask_names = {os.path.basename(m) for _, m in matched}
|
|
992
|
+
for mask_path, mask_name in masks:
|
|
993
|
+
if mask_name not in matched_mask_names:
|
|
994
|
+
unmatched_masks.append(mask_name)
|
|
995
|
+
|
|
996
|
+
self.video_mask_pairs = matched
|
|
997
|
+
# Matching happens automatically, no message box needed
|
|
998
|
+
|
|
999
|
+
def _update_pairs_label(self):
|
|
1000
|
+
count = len(self.video_mask_pairs)
|
|
1001
|
+
self.pairs_label.setText(f"{count} video-mask pairs ready")
|
|
1002
|
+
# Auto-determine output directory from experiment
|
|
1003
|
+
experiment_path = self.config.get("experiment_path")
|
|
1004
|
+
if experiment_path and os.path.exists(experiment_path):
|
|
1005
|
+
self.output_dir = os.path.join(experiment_path, "registered_clips")
|
|
1006
|
+
self.output_dir_label.setText(f"Output: {self.output_dir}")
|
|
1007
|
+
else:
|
|
1008
|
+
self.output_dir = ""
|
|
1009
|
+
self.output_dir_label.setText("No experiment folder - please create an experiment first")
|
|
1010
|
+
self.process_btn.setEnabled(count > 0 and bool(self.output_dir))
|
|
1011
|
+
self.export_roi_btn.setEnabled(count > 0)
|
|
1012
|
+
|
|
1013
|
+
def _show_naming_help(self):
|
|
1014
|
+
"""Show help dialog about video-mask naming."""
|
|
1015
|
+
QMessageBox.information(
|
|
1016
|
+
self,
|
|
1017
|
+
"Video-Mask Naming",
|
|
1018
|
+
"Videos and masks are automatically matched based on their filenames.\n\n"
|
|
1019
|
+
"Naming rules:\n"
|
|
1020
|
+
"-Video: video_name.mp4\n"
|
|
1021
|
+
"-Mask: video_name.h5 (or video_name_masks.h5)\n\n"
|
|
1022
|
+
"The mask filename should match the video filename (without extension),\n"
|
|
1023
|
+
"or have '_masks' suffix that will be automatically removed.\n\n"
|
|
1024
|
+
"Example:\n"
|
|
1025
|
+
"-Video: my_video.mp4\n"
|
|
1026
|
+
"- Mask: my_video.h5\n"
|
|
1027
|
+
"- Mask: my_video_masks.h5"
|
|
1028
|
+
)
|
|
1029
|
+
|
|
1030
|
+
def _start_processing(self):
|
|
1031
|
+
if not self.video_mask_pairs:
|
|
1032
|
+
QMessageBox.warning(self, "Error", "Please add videos and masks first.")
|
|
1033
|
+
return
|
|
1034
|
+
|
|
1035
|
+
# Auto-determine output directory from experiment
|
|
1036
|
+
experiment_path = self.config.get("experiment_path")
|
|
1037
|
+
if not experiment_path or not os.path.exists(experiment_path):
|
|
1038
|
+
QMessageBox.warning(self, "Error", "No experiment folder found. Please create an experiment first.")
|
|
1039
|
+
return
|
|
1040
|
+
|
|
1041
|
+
output_dir = os.path.join(experiment_path, "registered_clips")
|
|
1042
|
+
os.makedirs(output_dir, exist_ok=True)
|
|
1043
|
+
self.output_dir = output_dir
|
|
1044
|
+
self.output_dir_label.setText(f"Output: {output_dir}")
|
|
1045
|
+
|
|
1046
|
+
# Get parameters
|
|
1047
|
+
params = {
|
|
1048
|
+
'box_size': self.box_size_spin.value(),
|
|
1049
|
+
'target_size': self.target_size_spin.value(),
|
|
1050
|
+
'background_mode': self.background_combo.currentText(),
|
|
1051
|
+
'normalization_method': self.normalization_combo.currentText(),
|
|
1052
|
+
'mask_feather_px': self.feather_spin.value(),
|
|
1053
|
+
'anchor_mode': 'first' if self.lock_roi_checkbox.isChecked() else 'frame',
|
|
1054
|
+
'target_fps': self.target_fps_spin.value(),
|
|
1055
|
+
'clip_length_frames': self.clip_length_spin.value(),
|
|
1056
|
+
'step_frames': self.step_frames_spin.value()
|
|
1057
|
+
}
|
|
1058
|
+
|
|
1059
|
+
# Create worker
|
|
1060
|
+
self.worker = ProcessingWorker(self.video_mask_pairs, self.output_dir, params)
|
|
1061
|
+
self.worker.progress.connect(self._on_progress)
|
|
1062
|
+
self.worker.finished.connect(self._on_finished)
|
|
1063
|
+
self.worker.error.connect(self._on_error)
|
|
1064
|
+
|
|
1065
|
+
self.process_btn.setEnabled(False)
|
|
1066
|
+
self.stop_btn.setEnabled(True)
|
|
1067
|
+
self.view_videos_btn.setEnabled(False)
|
|
1068
|
+
self.progress_bar.setMaximum(len(self.video_mask_pairs))
|
|
1069
|
+
self.progress_bar.setValue(0)
|
|
1070
|
+
self.log_text.clear()
|
|
1071
|
+
self.processed_videos = []
|
|
1072
|
+
|
|
1073
|
+
self.worker.start()
|
|
1074
|
+
|
|
1075
|
+
def _stop_processing(self):
|
|
1076
|
+
if self.worker:
|
|
1077
|
+
self.worker.stop()
|
|
1078
|
+
self.worker.wait()
|
|
1079
|
+
self.log_text.append("Processing stopped by user.")
|
|
1080
|
+
self.process_btn.setEnabled(True)
|
|
1081
|
+
self.stop_btn.setEnabled(False)
|
|
1082
|
+
|
|
1083
|
+
def _on_progress(self, current, total, video_name):
|
|
1084
|
+
self.progress_bar.setValue(current)
|
|
1085
|
+
self.status_label.setText(f"Processing: {video_name} ({current}/{total})")
|
|
1086
|
+
|
|
1087
|
+
def _on_finished(self, output_paths):
|
|
1088
|
+
self.log_text.append("=" * 50)
|
|
1089
|
+
self.log_text.append("All videos processed successfully!")
|
|
1090
|
+
self.log_text.append(f"Output directory: {self.output_dir}")
|
|
1091
|
+
self.log_text.append(f"Created {len(output_paths)} clip(s)")
|
|
1092
|
+
|
|
1093
|
+
# Extract clip paths and frame ranges from tuples
|
|
1094
|
+
clip_paths_list = []
|
|
1095
|
+
self.clip_frame_ranges = {}
|
|
1096
|
+
for item in output_paths:
|
|
1097
|
+
if isinstance(item, tuple) and len(item) == 3:
|
|
1098
|
+
clip_path, start_frame, end_frame = item
|
|
1099
|
+
clip_paths_list.append(clip_path)
|
|
1100
|
+
self.clip_frame_ranges[clip_path] = (start_frame, end_frame)
|
|
1101
|
+
else:
|
|
1102
|
+
# Legacy: just a path string
|
|
1103
|
+
clip_paths_list.append(item)
|
|
1104
|
+
|
|
1105
|
+
# Group clips by video (using extracted paths)
|
|
1106
|
+
clips_by_video = {}
|
|
1107
|
+
for path in clip_paths_list:
|
|
1108
|
+
video_dir = os.path.dirname(path)
|
|
1109
|
+
if video_dir not in clips_by_video:
|
|
1110
|
+
clips_by_video[video_dir] = []
|
|
1111
|
+
clips_by_video[video_dir].append(path)
|
|
1112
|
+
|
|
1113
|
+
for video_dir, clips in clips_by_video.items():
|
|
1114
|
+
video_name = os.path.basename(video_dir)
|
|
1115
|
+
self.log_text.append(f" {video_name}: {len(clips)} clip(s)")
|
|
1116
|
+
|
|
1117
|
+
self.processed_videos = clip_paths_list
|
|
1118
|
+
self.status_label.setText("Complete")
|
|
1119
|
+
self.process_btn.setEnabled(True)
|
|
1120
|
+
self.stop_btn.setEnabled(False)
|
|
1121
|
+
self.view_videos_btn.setEnabled(len(clip_paths_list) > 0)
|
|
1122
|
+
self.extract_embeddings_btn.setEnabled(len(clip_paths_list) > 0)
|
|
1123
|
+
|
|
1124
|
+
msg = f"All videos have been processed into clips.\n\nCreated {len(clip_paths_list)} clip(s) total.\n\nWould you like to view them now?"
|
|
1125
|
+
reply = QMessageBox.question(
|
|
1126
|
+
self, "Success", msg,
|
|
1127
|
+
QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No,
|
|
1128
|
+
QMessageBox.StandardButton.Yes
|
|
1129
|
+
)
|
|
1130
|
+
if reply == QMessageBox.StandardButton.Yes:
|
|
1131
|
+
self._view_processed_videos()
|
|
1132
|
+
|
|
1133
|
+
def _on_error(self, error_msg):
|
|
1134
|
+
self.log_text.append(f"Error: {error_msg}")
|
|
1135
|
+
QMessageBox.critical(self, "Error", error_msg)
|
|
1136
|
+
self.process_btn.setEnabled(True)
|
|
1137
|
+
self.stop_btn.setEnabled(False)
|
|
1138
|
+
|
|
1139
|
+
+    def _view_processed_videos(self):
+        """Open video player dialog for processed videos."""
+        if not self.processed_videos:
+            QMessageBox.information(self, "No Videos", "No processed videos available to view.")
+            return
+
+        # Single video: open it directly; otherwise let the user choose where to start
+        if len(self.processed_videos) == 1:
+            dialog = VideoPlayerDialog(self.processed_videos, 0, self)
+            dialog.exec()
+        else:
+            from PyQt6.QtWidgets import QListWidget, QDialogButtonBox
+
+            dialog = QDialog(self)
+            dialog.setWindowTitle("Select video to view")
+            dialog.setMinimumSize(400, 300)
+            layout = QVBoxLayout(dialog)
+
+            label = QLabel("Select a video to view:")
+            layout.addWidget(label)
+
+            list_widget = QListWidget()
+            for path in self.processed_videos:
+                list_widget.addItem(os.path.basename(path))
+            layout.addWidget(list_widget)
+
+            buttons = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel)
+            buttons.accepted.connect(dialog.accept)
+            buttons.rejected.connect(dialog.reject)
+            layout.addWidget(buttons)
+
+            if dialog.exec() == QDialog.DialogCode.Accepted:
+                selected_items = list_widget.selectedItems()
+                if selected_items:
+                    idx = list_widget.row(selected_items[0])
+                    player_dialog = VideoPlayerDialog(self.processed_videos, idx, self)
+                    player_dialog.exec()
+
+    def _load_clips_folder(self):
+        """Load a folder of existing clips for embedding extraction."""
+        folder_path = QFileDialog.getExistingDirectory(self, "Select Clips Folder")
+        if not folder_path:
+            return
+
+        clip_paths = []
+        for root, dirs, files in os.walk(folder_path):
+            for file in files:
+                if file.endswith(('.mp4', '.avi', '.mov', '.mkv')):
+                    clip_paths.append(os.path.join(root, file))
+
+        if not clip_paths:
+            QMessageBox.warning(self, "No Clips", f"No video clips found in {folder_path}")
+            return
+
+        self.output_dir = folder_path
+        self.processed_videos = sorted(clip_paths)
+
+        self.log_text.append("=" * 50)
+        self.log_text.append(f"Loaded {len(clip_paths)} clips from: {folder_path}")
+        self.status_label.setText(f"Loaded {len(clip_paths)} clips")
+        self.extract_embeddings_btn.setEnabled(True)
+        self.view_videos_btn.setEnabled(True)
+
+        # Ask to extract immediately
+        reply = QMessageBox.question(
+            self,
+            "Extract Embeddings",
+            f"Found {len(clip_paths)} clips.\nDo you want to extract embeddings now?",
+            QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No,
+            QMessageBox.StandardButton.Yes
+        )
+
+        if reply == QMessageBox.StandardButton.Yes:
+            self._extract_embeddings()
+
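Note: str.endswith is case-sensitive, so the walk above would skip clips named e.g. CLIP.MP4. If that matters, a case-insensitive variant might look like the following sketch (find_clips and VIDEO_EXTS are illustrative names, not part of the package):

    import os

    VIDEO_EXTS = {".mp4", ".avi", ".mov", ".mkv"}

    def find_clips(folder_path):
        """Recursively collect video files, matching extensions case-insensitively."""
        clip_paths = []
        for root, _dirs, files in os.walk(folder_path):
            for name in files:
                if os.path.splitext(name)[1].lower() in VIDEO_EXTS:
                    clip_paths.append(os.path.join(root, name))
        return sorted(clip_paths)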
+    def _extract_embeddings(self):
+        """Extract VideoPrism embeddings from processed clips."""
+        if not self.processed_videos:
+            QMessageBox.warning(self, "No Clips", "No processed clips available. Please process videos first.")
+            return
+
+        # Get output directory (where clips are stored)
+        if not self.output_dir:
+            experiment_path = self.config.get("experiment_path")
+            if not experiment_path:
+                QMessageBox.warning(self, "No Experiment", "No active experiment. Please create or load an experiment first.")
+                return
+            self.output_dir = os.path.join(experiment_path, "registered_clips")
+
+        if not os.path.exists(self.output_dir):
+            QMessageBox.warning(self, "Directory Not Found", f"Clips directory not found: {self.output_dir}")
+            return
+
+        # Collect all clip files from the registered_clips directory
+        clip_paths = []
+        for root, dirs, files in os.walk(self.output_dir):
+            for file in files:
+                if file.endswith(('.mp4', '.avi', '.mov', '.mkv')):
+                    clip_paths.append(os.path.join(root, file))
+
+        if not clip_paths:
+            QMessageBox.warning(self, "No Clips", f"No video clips found in {self.output_dir}")
+            return
+
+        # Sort clips for consistent ordering
+        clip_paths.sort()
+
+        # Get model name from config
+        model_name = self.config.get("backbone_model", "videoprism_public_v1_base")
+
+        # Confirm with user
+        reply = QMessageBox.question(
+            self,
+            "Extract Embeddings",
+            f"Extract VideoPrism embeddings from {len(clip_paths)} clip(s)?\n\n"
+            f"Model: {model_name}\n"
+            f"Output directory: {self.output_dir}",
+            QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No,
+            QMessageBox.StandardButton.Yes
+        )
+
+        if reply != QMessageBox.StandardButton.Yes:
+            return
+
+        # Get experiment name from config
+        experiment_name = self.config.get("experiment_name", None)
+
+        # Start extraction worker with frame ranges if available
+        self.embedding_worker = EmbeddingExtractionWorker(
+            clip_paths,
+            self.output_dir,
+            experiment_name=experiment_name,
+            model_name=model_name,
+            clip_frame_ranges=getattr(self, 'clip_frame_ranges', None),
+            append_to_existing=self.append_embeddings_check.isChecked()
+        )
+        self.embedding_worker.progress.connect(self._on_embedding_progress)
+        self.embedding_worker.finished.connect(self._on_embedding_finished)
+        self.embedding_worker.error.connect(self._on_embedding_error)
+        self.embedding_worker.log_message.connect(self._on_embedding_log)
+
+        self.extract_embeddings_btn.setEnabled(False)
+        self.progress_bar.setMaximum(len(clip_paths))
+        self.progress_bar.setValue(0)
+        self.log_text.clear()
+        self.log_text.append(f"Starting VideoPrism embedding extraction from {len(clip_paths)} clips...")
+
+        self.embedding_worker.start()
+
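Note: EmbeddingExtractionWorker is defined elsewhere in the package; the signal shapes it must expose can be read off the slots connected above. A minimal skeleton under those assumptions (a stub only, not the package's implementation):

    from PyQt6.QtCore import QThread, pyqtSignal

    class EmbeddingWorkerSkeleton(QThread):
        # Shapes inferred from the connected slots; `finished` deliberately
        # shadows QThread's built-in no-arg signal, as the two-argument slot
        # _on_embedding_finished implies.
        progress = pyqtSignal(int, int, str)   # current, total, clip_name
        finished = pyqtSignal(str, str)        # feature_matrix_path, metadata_path
        error = pyqtSignal(str)                # error message
        log_message = pyqtSignal(str)          # free-form log line

        def __init__(self, clip_paths, output_dir, **kwargs):
            super().__init__()
            self.clip_paths = clip_paths
            self.output_dir = output_dir

        def run(self):
            # A real worker would load the backbone and embed each clip here;
            # this stub only reports progress.
            for i, path in enumerate(self.clip_paths, start=1):
                self.progress.emit(i, len(self.clip_paths), path)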
+    def _on_embedding_progress(self, current, total, clip_name):
+        self.progress_bar.setValue(current)
+        self.status_label.setText(f"Extracting embeddings: {clip_name} ({current}/{total})")
+
+    def _on_embedding_finished(self, feature_matrix_path, metadata_path):
+        self.log_text.append("=" * 50)
+        self.log_text.append("Embedding extraction completed successfully!")
+        self.log_text.append(f"Feature matrix: {feature_matrix_path}")
+        self.log_text.append(f"Metadata: {metadata_path}")
+
+        self.status_label.setText("Embedding extraction complete")
+        self.extract_embeddings_btn.setEnabled(True)
+
+        # Emit signal for auto-loading in clustering tab
+        self.embeddings_extracted.emit(feature_matrix_path, metadata_path)
+
+        QMessageBox.information(
+            self,
+            "Success",
+            f"Behaviorome embeddings extracted successfully!\n\n"
+            f"Feature matrix: {os.path.basename(feature_matrix_path)}\n"
+            f"Metadata: {os.path.basename(metadata_path)}\n\n"
+            f"Files saved to: {os.path.dirname(feature_matrix_path)}\n\n"
+            f"Data will be automatically loaded in the Clustering tab."
+        )
+
+    def _on_embedding_error(self, error_msg):
+        self.log_text.append(f"Error: {error_msg}")
+        QMessageBox.critical(self, "Error", f"Embedding extraction failed:\n{error_msg}")
+        self.extract_embeddings_btn.setEnabled(True)
+        self.status_label.setText("Error")
+
+    def _on_embedding_log(self, message):
+        self.log_text.append(message)
+
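Note: embeddings_extracted carries the two output paths so another tab can auto-load them. Schematic wiring only; clip_widget and load_from_files are placeholder names, not the package's actual attributes:

    # The receiving slot must accept (feature_matrix_path, metadata_path)
    clip_widget.embeddings_extracted.connect(clustering_tab.load_from_files)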
+    def _export_roi_videos(self):
+        """Export one cropped video per tracked object for each video-mask pair."""
+        if not self.video_mask_pairs:
+            QMessageBox.warning(self, "No Data", "Add video and mask files first.")
+            return
+
+        all_exported = []
+        for video_path, mask_path in self.video_mask_pairs:
+            try:
+                mask_data = load_segmentation_data(mask_path)
+            except Exception as e:
+                QMessageBox.warning(self, "Load Error", f"Could not load mask {os.path.basename(mask_path)}: {e}")
+                continue
+
+            frame_objects = mask_data.get("frame_objects", [])
+            if not frame_objects:
+                QMessageBox.warning(self, "No Frames", f"No mask frames in {os.path.basename(mask_path)}.")
+                continue
+
+            start_offset = mask_data.get("start_offset", 0)
+            all_obj_ids = set()
+            for frame_objs in frame_objects:
+                for obj in frame_objs:
+                    all_obj_ids.add(obj.get("obj_id", 0))
+            all_obj_ids = sorted(all_obj_ids)
+            if not all_obj_ids:
+                continue
+
+            experiment_path = self.config.get("experiment_path")
+            if experiment_path and os.path.exists(experiment_path):
+                out_dir = os.path.join(experiment_path, "roi_videos")
+            else:
+                out_dir = os.path.join(os.path.dirname(video_path), "roi_videos")
+            os.makedirs(out_dir, exist_ok=True)
+            video_basename = os.path.splitext(os.path.basename(video_path))[0]
+
+            # Per-object track: frame index -> integer bbox
+            obj_tracks = {oid: {} for oid in all_obj_ids}
+            for i, frame_objs in enumerate(frame_objects):
+                frame_idx = start_offset + i
+                for obj in frame_objs:
+                    oid = obj.get("obj_id", 0)
+                    bbox = obj.get("bbox")
+                    if bbox is not None:
+                        obj_tracks[oid][frame_idx] = tuple(int(x) for x in bbox)
+
+            cap = cv2.VideoCapture(video_path)
+            vid_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+            vid_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+            vid_fps = cap.get(cv2.CAP_PROP_FPS)
+            if vid_fps <= 0:
+                vid_fps = 30.0
+            cap.release()
+
+            frame_indices = sorted(set(f for track in obj_tracks.values() for f in track))
+            if not frame_indices:
+                continue
+            start_frame = frame_indices[0]
+            end_frame = frame_indices[-1]
+
+            # One square writer per object, sized to its largest padded bbox
+            crop_padding = 0.25
+            writers = {}
+            obj_crop_params = {}
+            obj_paths = {}
+            for obj_id, track in obj_tracks.items():
+                if not track:
+                    continue
+                max_w = max(x2 - x1 for x1, y1, x2, y2 in track.values())
+                max_h = max(y2 - y1 for x1, y1, x2, y2 in track.values())
+                crop_w = int(max_w * (1 + 2 * crop_padding))
+                crop_h = int(max_h * (1 + 2 * crop_padding))
+                crop_side = max(crop_w, crop_h, 64)
+                obj_crop_params[obj_id] = crop_side
+                out_path = os.path.join(out_dir, f"{video_basename}_object{obj_id}.mp4")
+                obj_paths[obj_id] = out_path
+                fourcc = cv2.VideoWriter_fourcc(*"mp4v")
+                writers[obj_id] = cv2.VideoWriter(out_path, fourcc, vid_fps, (crop_side, crop_side))
+
+            if not writers:
+                continue
+
+            progress = QProgressDialog(
+                f"Exporting ROI videos: {os.path.basename(video_path)}...", "Cancel",
+                start_frame, end_frame + 1, self
+            )
+            progress.setWindowTitle("Export ROI Videos")
+            progress.setMinimumDuration(0)
+            progress.setValue(start_frame)
+
+            cap = cv2.VideoCapture(video_path)
+            cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
+            last_bbox = {}
+            for fidx in range(start_frame, end_frame + 1):
+                if progress.wasCanceled():
+                    break
+                ret, frame = cap.read()
+                if not ret:
+                    break
+                progress.setValue(fidx)
+                QApplication.processEvents()
+                for obj_id, writer in writers.items():
+                    crop_side = obj_crop_params[obj_id]
+                    # Fall back to the last known bbox when a frame has no detection
+                    bbox = obj_tracks[obj_id].get(fidx) or last_bbox.get(obj_id)
+                    if bbox is None:
+                        writer.write(np.zeros((crop_side, crop_side, 3), dtype=np.uint8))
+                        continue
+                    last_bbox[obj_id] = bbox
+                    x1, y1, x2, y2 = bbox
+                    cx = (x1 + x2) // 2
+                    cy = (y1 + y2) // 2
+                    half = crop_side // 2
+                    rx1 = max(0, cx - half)
+                    ry1 = max(0, cy - half)
+                    rx2 = min(vid_w, rx1 + crop_side)
+                    ry2 = min(vid_h, ry1 + crop_side)
+                    rx1 = max(0, rx2 - crop_side)
+                    ry1 = max(0, ry2 - crop_side)
+                    crop = frame[ry1:ry2, rx1:rx2]
+                    if crop.shape[0] != crop_side or crop.shape[1] != crop_side:
+                        crop = cv2.resize(crop, (crop_side, crop_side), interpolation=cv2.INTER_AREA)
+                    writer.write(crop)
+
+            cap.release()
+            progress.close()
+            for w in writers.values():
+                w.release()
+            all_exported.append((out_dir, list(obj_paths.values())))
+
+        if not all_exported:
+            QMessageBox.warning(self, "No Tracks", "No object tracks with valid bboxes found.")
+            return
+        summary = []
+        for out_dir, paths in all_exported:
+            names = [os.path.basename(p) for p in paths]
+            summary.append(f"{out_dir}\n " + "\n ".join(names))
+        QMessageBox.information(
+            self, "Export Complete",
+            "Exported ROI videos to:\n\n" + "\n\n".join(summary)
+        )
+
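Note: the crop logic above centers a fixed square window on each bbox and shifts it back inside the frame when it overruns an edge. The same arithmetic as a standalone sketch (the helper name is ours, for illustration only):

    def clamp_square_crop(bbox, crop_side, frame_w, frame_h):
        """Center a crop_side x crop_side window on bbox, shifted to stay in frame."""
        x1, y1, x2, y2 = bbox
        cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
        half = crop_side // 2
        rx1 = max(0, cx - half)
        ry1 = max(0, cy - half)
        rx2 = min(frame_w, rx1 + crop_side)
        ry2 = min(frame_h, ry1 + crop_side)
        # Shift back if the window ran past the right/bottom edge.
        rx1 = max(0, rx2 - crop_side)
        ry1 = max(0, ry2 - crop_side)
        return rx1, ry1, rx2, ry2

    # A 100-px window on a bbox near the left edge of a 640x480 frame:
    # clamp_square_crop((5, 200, 25, 240), 100, 640, 480) -> (0, 170, 100, 270)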
+    def _check_existing_clips(self):
+        """Check whether processed clips exist and enable the extract embeddings button."""
+        experiment_path = self.config.get("experiment_path")
+        if not experiment_path:
+            return
+
+        clips_dir = os.path.join(experiment_path, "registered_clips")
+        if os.path.exists(clips_dir):
+            clip_files = []
+            for root, dirs, files in os.walk(clips_dir):
+                for file in files:
+                    if file.endswith(('.mp4', '.avi', '.mov', '.mkv')):
+                        clip_files.append(os.path.join(root, file))
+
+            if clip_files:
+                self.processed_videos = clip_files
+                self.extract_embeddings_btn.setEnabled(True)
+                self.output_dir = clips_dir
+
+    def update_config(self, config: dict):
+        """Update configuration."""
+        self.config = config