singlebehaviorlab 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88) hide show
  1. sam2/__init__.py +11 -0
  2. sam2/automatic_mask_generator.py +454 -0
  3. sam2/benchmark.py +92 -0
  4. sam2/build_sam.py +174 -0
  5. sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
  6. sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
  7. sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
  8. sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
  9. sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
  10. sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
  11. sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
  12. sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
  13. sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
  14. sam2/modeling/__init__.py +5 -0
  15. sam2/modeling/backbones/__init__.py +5 -0
  16. sam2/modeling/backbones/hieradet.py +317 -0
  17. sam2/modeling/backbones/image_encoder.py +134 -0
  18. sam2/modeling/backbones/utils.py +93 -0
  19. sam2/modeling/memory_attention.py +169 -0
  20. sam2/modeling/memory_encoder.py +181 -0
  21. sam2/modeling/position_encoding.py +239 -0
  22. sam2/modeling/sam/__init__.py +5 -0
  23. sam2/modeling/sam/mask_decoder.py +295 -0
  24. sam2/modeling/sam/prompt_encoder.py +202 -0
  25. sam2/modeling/sam/transformer.py +311 -0
  26. sam2/modeling/sam2_base.py +913 -0
  27. sam2/modeling/sam2_utils.py +323 -0
  28. sam2/sam2_hiera_b+.yaml +113 -0
  29. sam2/sam2_hiera_l.yaml +117 -0
  30. sam2/sam2_hiera_s.yaml +116 -0
  31. sam2/sam2_hiera_t.yaml +118 -0
  32. sam2/sam2_image_predictor.py +466 -0
  33. sam2/sam2_video_predictor.py +1388 -0
  34. sam2/sam2_video_predictor_legacy.py +1172 -0
  35. sam2/utils/__init__.py +5 -0
  36. sam2/utils/amg.py +348 -0
  37. sam2/utils/misc.py +349 -0
  38. sam2/utils/transforms.py +118 -0
  39. singlebehaviorlab/__init__.py +4 -0
  40. singlebehaviorlab/__main__.py +130 -0
  41. singlebehaviorlab/_paths.py +100 -0
  42. singlebehaviorlab/backend/__init__.py +2 -0
  43. singlebehaviorlab/backend/augmentations.py +320 -0
  44. singlebehaviorlab/backend/data_store.py +420 -0
  45. singlebehaviorlab/backend/model.py +1290 -0
  46. singlebehaviorlab/backend/train.py +4667 -0
  47. singlebehaviorlab/backend/uncertainty.py +578 -0
  48. singlebehaviorlab/backend/video_processor.py +688 -0
  49. singlebehaviorlab/backend/video_utils.py +139 -0
  50. singlebehaviorlab/data/config/config.yaml +85 -0
  51. singlebehaviorlab/data/training_profiles.json +334 -0
  52. singlebehaviorlab/gui/__init__.py +4 -0
  53. singlebehaviorlab/gui/analysis_widget.py +2291 -0
  54. singlebehaviorlab/gui/attention_export.py +311 -0
  55. singlebehaviorlab/gui/clip_extraction_widget.py +481 -0
  56. singlebehaviorlab/gui/clustering_widget.py +3187 -0
  57. singlebehaviorlab/gui/inference_popups.py +1138 -0
  58. singlebehaviorlab/gui/inference_widget.py +4550 -0
  59. singlebehaviorlab/gui/inference_worker.py +651 -0
  60. singlebehaviorlab/gui/labeling_widget.py +2324 -0
  61. singlebehaviorlab/gui/main_window.py +754 -0
  62. singlebehaviorlab/gui/metadata_management_widget.py +1119 -0
  63. singlebehaviorlab/gui/motion_tracking.py +764 -0
  64. singlebehaviorlab/gui/overlay_export.py +1234 -0
  65. singlebehaviorlab/gui/plot_integration.py +729 -0
  66. singlebehaviorlab/gui/qt_helpers.py +29 -0
  67. singlebehaviorlab/gui/registration_widget.py +1485 -0
  68. singlebehaviorlab/gui/review_widget.py +1330 -0
  69. singlebehaviorlab/gui/segmentation_tracking_widget.py +2752 -0
  70. singlebehaviorlab/gui/tab_tutorial_dialog.py +312 -0
  71. singlebehaviorlab/gui/timeline_themes.py +131 -0
  72. singlebehaviorlab/gui/training_profiles.py +418 -0
  73. singlebehaviorlab/gui/training_widget.py +3719 -0
  74. singlebehaviorlab/gui/video_utils.py +233 -0
  75. singlebehaviorlab/licenses/SAM2-LICENSE +201 -0
  76. singlebehaviorlab/licenses/VideoPrism-LICENSE +202 -0
  77. singlebehaviorlab-2.0.0.dist-info/METADATA +447 -0
  78. singlebehaviorlab-2.0.0.dist-info/RECORD +88 -0
  79. singlebehaviorlab-2.0.0.dist-info/WHEEL +5 -0
  80. singlebehaviorlab-2.0.0.dist-info/entry_points.txt +2 -0
  81. singlebehaviorlab-2.0.0.dist-info/licenses/LICENSE +21 -0
  82. singlebehaviorlab-2.0.0.dist-info/top_level.txt +3 -0
  83. videoprism/__init__.py +0 -0
  84. videoprism/encoders.py +910 -0
  85. videoprism/layers.py +1136 -0
  86. videoprism/models.py +407 -0
  87. videoprism/tokenizers.py +167 -0
  88. videoprism/utils.py +168 -0
"""Registration widget for processing videos with masks from segmentation."""
import sys
import os
import glob
# JAX memory: grow on demand, capped at 45% so PyTorch (SAM2) keeps the rest.
# Must be set BEFORE jax is imported anywhere in the process, hence done here
# at the very top of the module via setdefault (won't clobber user overrides).
os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")
os.environ.setdefault("XLA_PYTHON_CLIENT_MEM_FRACTION", "0.45")

import cv2
import numpy as np
import pandas as pd
import torch
import re
# Add parent directory to path for backend imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from PyQt6.QtWidgets import (
    QWidget, QVBoxLayout, QHBoxLayout, QPushButton, QLabel,
    QFileDialog, QGroupBox, QFormLayout, QSpinBox, QComboBox, QProgressBar,
    QTextEdit, QMessageBox, QListWidget, QListWidgetItem, QDialog, QDialogButtonBox,
    QSlider, QApplication, QCheckBox, QProgressDialog
)
from PyQt6.QtCore import QThread, pyqtSignal, Qt, QTimer
from PyQt6.QtGui import QImage, QPixmap
from singlebehaviorlab.backend.video_processor import process_video, process_video_to_clips, load_segmentation_data
from singlebehaviorlab.backend.model import VideoPrismBackbone
26
+
27
+
28
class VideoPlayerDialog(QDialog):
    """Dialog for playing video files using OpenCV with streaming playback (no full caching).

    Frames are decoded on demand from a ``cv2.VideoCapture`` (seek + read per
    displayed frame), so memory usage stays constant regardless of clip length.
    Supports stepping through a list of clips with Previous/Next buttons, a
    seek slider, and play/pause/stop driven by a ``QTimer`` at the clip's FPS.
    """
    def __init__(self, video_paths: list, start_index: int = 0, parent=None):
        super().__init__(parent)
        self.video_paths = video_paths            # list of clip file paths to browse
        self.current_video_idx = start_index      # index into video_paths

        self.cap = None  # VideoCapture for streaming
        self.total_frames = 0
        self.fps = 30.0
        self.current_frame_idx = 0
        self.is_playing = False
        self.timer = QTimer(self)
        self.timer.timeout.connect(self._advance_frame)
        self.slider_pressed = False
        self._current_frame = None  # Keep reference for QImage

        self._setup_ui()
        self._load_current_video()

    def _setup_ui(self):
        """Build the dialog layout: nav row, video label, seek slider, controls."""
        self.setMinimumSize(800, 600)
        layout = QVBoxLayout(self)

        # Navigation
        nav_layout = QHBoxLayout()
        self.prev_btn = QPushButton("Previous clip")
        self.prev_btn.clicked.connect(self._prev_video)
        nav_layout.addWidget(self.prev_btn)

        self.title_label = QLabel()
        self.title_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
        nav_layout.addWidget(self.title_label)

        self.next_btn = QPushButton("Next clip")
        self.next_btn.clicked.connect(self._next_video)
        nav_layout.addWidget(self.next_btn)
        layout.addLayout(nav_layout)

        # Video label
        self.video_label = QLabel()
        self.video_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
        self.video_label.setMinimumSize(640, 480)
        self.video_label.setStyleSheet("background-color: black;")
        layout.addWidget(self.video_label)

        # Slider
        self.slider = QSlider(Qt.Orientation.Horizontal)
        self.slider.sliderPressed.connect(self._on_slider_pressed)
        self.slider.sliderReleased.connect(self._on_slider_released)
        self.slider.valueChanged.connect(self._on_slider_changed)
        layout.addWidget(self.slider)

        # Controls
        controls_layout = QHBoxLayout()

        self.play_btn = QPushButton("Play")
        self.play_btn.clicked.connect(self._toggle_play)
        controls_layout.addWidget(self.play_btn)

        self.stop_btn = QPushButton("Stop")
        self.stop_btn.clicked.connect(self._stop)
        controls_layout.addWidget(self.stop_btn)

        self.frame_label = QLabel("Frame: 0 / 0")
        controls_layout.addWidget(self.frame_label)

        controls_layout.addStretch()

        close_btn = QPushButton("Close")
        close_btn.clicked.connect(self.accept)
        controls_layout.addWidget(close_btn)

        layout.addLayout(controls_layout)

    def _load_current_video(self):
        """Load current video for streaming playback."""
        if not self.video_paths:
            return

        video_path = self.video_paths[self.current_video_idx]
        self.setWindowTitle(f"Video Player - {os.path.basename(video_path)}")
        self.title_label.setText(f"{self.current_video_idx + 1} / {len(self.video_paths)}: {os.path.basename(video_path)}")

        # Update nav buttons
        self.prev_btn.setEnabled(self.current_video_idx > 0)
        self.next_btn.setEnabled(self.current_video_idx < len(self.video_paths) - 1)

        # Stop playback and release old capture
        # NOTE(review): _stop() calls _display_frame(0) while the *previous*
        # capture is still open, briefly repainting its first frame before the
        # release below — looks like a harmless flicker; confirm intended.
        self._stop()
        if self.cap is not None:
            self.cap.release()
            self.cap = None

        self._load_video_content(video_path)

    def _prev_video(self):
        # Step back one clip (button is disabled at index 0, this is a second guard).
        if self.current_video_idx > 0:
            self.current_video_idx -= 1
            self._load_current_video()

    def _next_video(self):
        # Step forward one clip (button is disabled at the last index).
        if self.current_video_idx < len(self.video_paths) - 1:
            self.current_video_idx += 1
            self._load_current_video()

    def _load_video_content(self, video_path):
        """Open video for streaming playback (memory efficient)."""
        try:
            self.video_label.setText("Loading video...")
            QApplication.processEvents()

            self.cap = cv2.VideoCapture(video_path)
            if not self.cap.isOpened():
                QMessageBox.critical(self, "Error", f"Could not open video: {video_path}")
                return

            # cv2 returns 0.0 when FPS is unknown; fall back to 30.
            self.fps = self.cap.get(cv2.CAP_PROP_FPS) or 30.0
            self.total_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))

            if self.total_frames <= 0:
                QMessageBox.warning(self, "Error", "Could not determine video length.")
                return

            # Setup slider and labels
            self.slider.setRange(0, self.total_frames - 1)
            self.frame_label.setText(f"Frame: 1 / {self.total_frames}")

            # Setup timer (ms per frame; 33 ms ≈ 30 fps fallback)
            interval = int(1000 / self.fps) if self.fps > 0 else 33
            self.timer.setInterval(interval)

            # Show first frame
            self._display_frame(0)

        except Exception as e:
            QMessageBox.critical(self, "Error", f"Failed to load video: {str(e)}")

    def _display_frame(self, idx):
        """Display frame at index by seeking in video (streaming)."""
        if self.cap is None or not self.cap.isOpened():
            return
        if not (0 <= idx < self.total_frames):
            return

        self.cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
        ret, frame = self.cap.read()
        if not ret:
            return

        # Convert BGR to RGB and keep reference to prevent QImage crash
        # (QImage does not copy the buffer; the numpy array must outlive it).
        self._current_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        if not self._current_frame.flags['C_CONTIGUOUS']:
            self._current_frame = np.ascontiguousarray(self._current_frame)

        h, w, ch = self._current_frame.shape
        bytes_per_line = ch * w

        q_image = QImage(self._current_frame.data, w, h, bytes_per_line, QImage.Format.Format_RGB888)
        pixmap = QPixmap.fromImage(q_image)

        scaled_pixmap = pixmap.scaled(
            self.video_label.size(),
            Qt.AspectRatioMode.KeepAspectRatio,
            Qt.TransformationMode.SmoothTransformation
        )
        self.video_label.setPixmap(scaled_pixmap)
        self.frame_label.setText(f"Frame: {idx + 1} / {self.total_frames}")

    def _advance_frame(self):
        """Advance to next frame."""
        if not self.is_playing or self.total_frames <= 0:
            return

        # Loop back to frame 0 at the end of the clip.
        self.current_frame_idx = (self.current_frame_idx + 1) % self.total_frames
        # Block signals so setValue does not re-enter _on_slider_changed.
        self.slider.blockSignals(True)
        self.slider.setValue(self.current_frame_idx)
        self.slider.blockSignals(False)
        self._display_frame(self.current_frame_idx)

    def _on_slider_changed(self, value):
        """Handle slider value change."""
        self.current_frame_idx = value
        self._display_frame(value)

    def _on_slider_pressed(self):
        """Handle slider press (pause playback)."""
        self.slider_pressed = True
        if self.is_playing:
            self.timer.stop()

    def _on_slider_released(self):
        """Handle slider release (resume if was playing)."""
        self.slider_pressed = False
        if self.is_playing:
            self.timer.start()

    def _toggle_play(self):
        """Toggle play/pause."""
        if self.is_playing:
            self.timer.stop()
            self.is_playing = False
            self.play_btn.setText("Play")
        else:
            self.timer.start()
            self.is_playing = True
            self.play_btn.setText("Pause")

    def _stop(self):
        """Stop playback and reset."""
        self.timer.stop()
        self.is_playing = False
        self.play_btn.setText("Play")
        self.current_frame_idx = 0
        self.slider.setValue(0)
        self.slider.setValue(0)
        self._display_frame(0)

    def closeEvent(self, event):
        """Clean up."""
        self.timer.stop()
        if self.cap is not None:
            self.cap.release()
            self.cap = None
        self._current_frame = None
        super().closeEvent(event)
253
+
254
+
255
class ProcessingWorker(QThread):
    """Background thread that turns each (video, mask) pair into cropped clips.

    Signals:
        progress(int, int, str): 1-based pair index, total pair count, video file name.
        finished(list): all clip descriptors produced across every pair
            (as returned by ``process_video_to_clips``).
        error(str): human-readable failure message (also emitted per failed video).
    """
    progress = pyqtSignal(int, int, str)  # current, total, video_name
    finished = pyqtSignal(list)  # List of output clip paths
    error = pyqtSignal(str)

    def __init__(self, video_mask_pairs, output_dir, params):
        super().__init__()
        self.video_mask_pairs = video_mask_pairs  # list of (video_path, mask_path)
        self.output_dir = output_dir              # root directory for per-video clip folders
        self.params = params                      # processing options dict (box_size, target_size, ...)
        self.should_stop = False                  # cooperative-stop flag
        self.output_paths = []                    # accumulated clip descriptors

    def stop(self):
        """Request a cooperative stop; checked at the top of each pair."""
        self.should_stop = True

    def run(self):
        """Process every (video, mask) pair sequentially, collecting clip paths."""
        try:
            pair_count = len(self.video_mask_pairs)
            opts = self.params
            for idx, (src_video, src_mask) in enumerate(self.video_mask_pairs):
                if self.should_stop:
                    break

                video_name = os.path.basename(src_video)
                self.progress.emit(idx + 1, pair_count, video_name)

                # Each video gets its own clip sub-directory named after its stem.
                stem, _ext = os.path.splitext(video_name)
                clips_dir = os.path.join(self.output_dir, stem)

                # Callback handed to the backend; currently acts only as a stop probe.
                def progress_cb(clip_num, total_clips=None, obj_id=None):
                    if self.should_stop:
                        return

                # May return multiple entries when the mask contains several objects.
                produced = process_video_to_clips(
                    video_path=src_video,
                    mask_path=src_mask,
                    output_dir=clips_dir,
                    box_size=opts['box_size'],
                    target_size=opts['target_size'],
                    background_mode=opts['background_mode'],
                    normalization_method=opts['normalization_method'],
                    mask_feather_px=opts['mask_feather_px'],
                    anchor_mode=opts['anchor_mode'],
                    target_fps=opts['target_fps'],
                    clip_length_frames=opts['clip_length_frames'],
                    step_frames=opts['step_frames'],
                    progress_callback=progress_cb
                )

                if not produced:
                    self.error.emit(f"Failed to process {video_name} - no clips created")
                else:
                    # Entries are (clip_path, start_frame, end_frame) tuples, kept
                    # for the later embedding-extraction step.
                    self.output_paths.extend(produced)

            if not self.should_stop:
                self.finished.emit(self.output_paths)

        except Exception as e:
            self.error.emit(str(e))
320
+
321
+
322
class EmbeddingExtractionWorker(QThread):
    """Worker thread for extracting VideoPrism embeddings from clips.

    For each clip: load all frames, run them through ``VideoPrismBackbone``,
    mean-pool the token outputs to one vector, and collect per-clip metadata
    parsed from the file name. Results are saved as NPZ (primary) and Parquet
    (backup), optionally appended to the most recent existing matrix.

    Signals:
        progress(int, int, str): 1-based clip index, total clips, clip file name.
        finished(str, str): feature-matrix path, metadata path.
        error(str): fatal failure message.
        log_message(str): human-readable progress/log lines.
    """
    progress = pyqtSignal(int, int, str)  # current, total, clip_name
    finished = pyqtSignal(str, str)  # feature_matrix_path, metadata_path
    error = pyqtSignal(str)
    log_message = pyqtSignal(str)

    def __init__(self, clip_paths: list, output_dir: str, experiment_name: str = None, model_name: str = 'videoprism_public_v1_base', clip_frame_ranges: dict = None, append_to_existing: bool = False):
        super().__init__()
        self.clip_paths = clip_paths  # List of clip paths (strings)
        self.clip_frame_ranges = clip_frame_ranges or {}  # Dict mapping clip_path -> (start_frame, end_frame)
        self.output_dir = output_dir
        self.experiment_name = experiment_name  # optional tag used in output file names
        self.model_name = model_name
        self.should_stop = False  # cooperative-stop flag
        self.append_to_existing = append_to_existing  # merge into newest *_matrix.npz if possible

    def stop(self):
        """Request a cooperative stop; checked at the top of each clip."""
        self.should_stop = True

    def run(self):
        """Extract one embedding per clip, then persist matrix + metadata.

        Output matrix on disk is stored features x snippets (transposed from
        the in-memory snippets x features array built here).
        """
        try:
            self.log_message.emit(f"Loading VideoPrism model: {self.model_name}...")

            # Clear PyTorch cache to free up VRAM for JAX
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            backbone = VideoPrismBackbone(model_name=self.model_name, log_fn=self.log_message.emit)
            backbone.eval()

            embed_dim = backbone.get_embed_dim()
            self.log_message.emit(f"VideoPrism model loaded. Embedding dimension: {embed_dim}")

            feature_matrix = []  # list of per-clip embedding lists (snippets x features)
            metadata = []        # list of per-clip metadata dicts

            total = len(self.clip_paths)
            self.log_message.emit(f"Processing {total} clips...")

            for i, clip_path in enumerate(self.clip_paths):
                if self.should_stop:
                    break

                clip_name = os.path.basename(clip_path)
                self.progress.emit(i + 1, total, clip_name)

                # Load clip frames
                frames = self._load_clip_frames(clip_path)
                if frames is None or len(frames) == 0:
                    self.log_message.emit(f"Warning: Could not load frames from {clip_name}, skipping")
                    continue

                # Extract embedding
                embedding = self._extract_embedding(backbone, frames)

                # Free frames memory immediately after use
                del frames

                if embedding is None:
                    self.log_message.emit(f"Warning: Could not extract embedding from {clip_name}, skipping")
                    continue

                feature_matrix.append(embedding.tolist())

                # Free embedding after converting to list
                del embedding

                # Periodically clear CUDA cache to prevent GPU memory accumulation
                if (i + 1) % 50 == 0 and torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    import gc
                    gc.collect()

                # Parse clip name for metadata
                # Format: clip_XXXXXX_objY.mp4 or clip_XXXXXX.mp4
                base_name = os.path.splitext(clip_name)[0]
                # The clip's parent directory is named after the source video.
                video_dir = os.path.dirname(clip_path)
                video_name = os.path.basename(video_dir)

                # Extract object ID if present
                obj_match = re.search(r'_obj(\d+)', base_name)
                obj_id = obj_match.group(1) if obj_match else None

                # Extract clip index (falls back to loop index when unparseable)
                clip_match = re.search(r'clip_(\d+)', base_name)
                clip_idx = int(clip_match.group(1)) if clip_match else i

                # Get frame range if available
                start_frame = None
                end_frame = None
                if clip_path in self.clip_frame_ranges:
                    start_frame, end_frame = self.clip_frame_ranges[clip_path]

                metadata.append({
                    'snippet': f'snippet{i+1}',
                    'group': video_name,
                    'video_id': clip_name,
                    'object_id': obj_id if obj_id else '',
                    'clip_index': clip_idx,
                    'start_frame': start_frame if start_frame is not None else '',
                    'end_frame': end_frame if end_frame is not None else ''
                })

            if self.should_stop:
                self.log_message.emit("Extraction stopped by user.")
                return

            if not feature_matrix:
                self.error.emit("No embeddings extracted. Please check that clips are valid.")
                return

            # Convert to numpy array
            feature_matrix = np.array(feature_matrix, dtype=np.float32)  # downcast to save space
            self.log_message.emit(f"Extracted {len(feature_matrix)} embeddings. Shape: {feature_matrix.shape}")

            # Prepare data
            num_snippets = feature_matrix.shape[0]
            num_features = feature_matrix.shape[1]
            metadata_df = pd.DataFrame(metadata)

            # Decide whether to append to existing embeddings:
            # pick the most recently modified *_matrix.npz and its sibling metadata file.
            append_used = False
            existing_matrix_path = None
            existing_metadata_path = None
            if self.append_to_existing:
                candidates = sorted(glob.glob(os.path.join(self.output_dir, "*_matrix.npz")), key=os.path.getmtime, reverse=True)
                if candidates:
                    existing_matrix_path = candidates[0]
                    guess_meta = existing_matrix_path.replace("_matrix.npz", "_metadata.npz")
                    if os.path.exists(guess_meta):
                        existing_metadata_path = guess_meta
                if not existing_matrix_path or not existing_metadata_path:
                    self.log_message.emit("Append requested but no existing matrix/metadata found. Creating new files.")

            snippet_ids = [f'snippet{i+1}' for i in range(num_snippets)]
            feature_names = [f'behaviorome_embedding_{i}' for i in range(num_features)]

            # If appending, load existing and merge
            if existing_matrix_path and existing_metadata_path:
                try:
                    existing_npz = np.load(existing_matrix_path, allow_pickle=True)
                    existing_matrix = existing_npz["matrix"]  # features x snippets
                    existing_feature_names = existing_npz["feature_names"]
                    existing_snippet_ids = existing_npz["snippet_ids"]

                    # Validate feature dimension
                    if existing_matrix.shape[0] != feature_matrix.shape[1]:
                        self.log_message.emit("Append skipped: feature dimension mismatch. Writing new files instead.")
                    else:
                        # Load existing metadata
                        existing_meta_npz = np.load(existing_metadata_path, allow_pickle=True)
                        existing_meta = pd.DataFrame(existing_meta_npz["metadata"], columns=existing_meta_npz["columns"])

                        # Renumber new snippets to continue after the existing ones.
                        offset = existing_matrix.shape[1]
                        snippet_ids = [f'snippet{offset + i + 1}' for i in range(num_snippets)]
                        metadata_df['snippet'] = snippet_ids

                        # Combine matrices and metadata
                        feature_matrix = feature_matrix.T  # to features x new_snippets
                        combined_matrix = np.concatenate([existing_matrix, feature_matrix], axis=1)
                        combined_snippet_ids = np.concatenate([existing_snippet_ids, np.array(snippet_ids)])
                        combined_feature_names = existing_feature_names
                        combined_meta = pd.concat([existing_meta, metadata_df], ignore_index=True)

                        # Overwrite base name with existing
                        base_name = os.path.basename(existing_matrix_path).replace("_matrix.npz", "")
                        npz_matrix_path = existing_matrix_path
                        npz_metadata_path = existing_metadata_path

                        feature_matrix = combined_matrix.T  # convert back to snippets x features for downstream naming
                        feature_names = combined_feature_names.tolist()
                        snippet_ids = combined_snippet_ids.tolist()
                        metadata_df = combined_meta
                        append_used = True
                        self.log_message.emit(f"Appending to existing embeddings: {os.path.basename(existing_matrix_path)}")
                except Exception as e:
                    self.log_message.emit(f"Append failed, writing new files instead: {e}")

            if not append_used:
                # Generate filename with experiment name and timestamp
                from datetime import datetime
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

                # Build base filename
                if self.experiment_name:
                    base_name = f"behaviorome_{self.experiment_name}_{timestamp}"
                else:
                    base_name = f"behaviorome_{timestamp}"

                # Build paths
                npz_matrix_path = os.path.join(self.output_dir, f'{base_name}_matrix.npz')
                npz_metadata_path = os.path.join(self.output_dir, f'{base_name}_metadata.npz')

            # Save as NPZ (fastest, most efficient for large datasets)

            try:
                np.savez_compressed(
                    npz_matrix_path,
                    matrix=feature_matrix.T,  # features x snippets
                    feature_names=np.array(feature_names),
                    snippet_ids=np.array(snippet_ids),
                )
                self.log_message.emit(f"Saved feature matrix (NPZ) to {npz_matrix_path}")
            except Exception as e:
                self.log_message.emit(f"NPZ save failed (matrix): {e}")
                npz_matrix_path = None

            try:
                # Save metadata as NPZ
                np.savez_compressed(
                    npz_metadata_path,
                    metadata=metadata_df.values,
                    columns=np.array(metadata_df.columns),
                )
                self.log_message.emit(f"Saved metadata (NPZ) to {npz_metadata_path}")
            except Exception as e:
                self.log_message.emit(f"NPZ save failed (metadata): {e}")
                npz_metadata_path = None

            # Also save Parquet as backup (faster than CSV, still readable)
            try:
                matrix_df = pd.DataFrame(feature_matrix.T, index=feature_names, columns=snippet_ids)
                parquet_matrix_path = os.path.join(self.output_dir, f'{base_name}_matrix.parquet')
                matrix_df.to_parquet(parquet_matrix_path, index=True)
                self.log_message.emit(f"Saved feature matrix (Parquet) to {parquet_matrix_path}")
            except Exception as e:
                self.log_message.emit(f"Parquet save failed (matrix): {e}")

            try:
                parquet_metadata_path = os.path.join(self.output_dir, f'{base_name}_metadata.parquet')
                metadata_df.to_parquet(parquet_metadata_path, index=False)
                self.log_message.emit(f"Saved metadata (Parquet) to {parquet_metadata_path}")
            except Exception as e:
                self.log_message.emit(f"Parquet save failed (metadata): {e}")

            # Emit NPZ paths (primary format)
            # NOTE(review): if an NPZ save fails before the matching Parquet
            # assignment succeeds (e.g. DataFrame construction raises),
            # parquet_matrix_path / parquet_metadata_path may be unbound here
            # and this emit would raise NameError — worth pre-initializing both.
            self.finished.emit(npz_matrix_path if npz_matrix_path else parquet_matrix_path,
                               npz_metadata_path if npz_metadata_path else parquet_metadata_path)

        except Exception as e:
            import traceback
            error_msg = f"Error extracting embeddings: {str(e)}\n{traceback.format_exc()}"
            self.log_message.emit(error_msg)
            self.error.emit(str(e))

    def _load_clip_frames(self, clip_path: str) -> np.ndarray:
        """Load all frames from a clip video.

        Returns an (T, H, W, 3) RGB uint8 array, or None if the file cannot
        be opened or contains no decodable frames.
        """
        cap = cv2.VideoCapture(clip_path)
        if not cap.isOpened():
            return None

        frames = []
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # Convert BGR to RGB
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(frame_rgb)

        cap.release()
        return np.array(frames) if frames else None

    def _extract_embedding(self, backbone: VideoPrismBackbone, frames: np.ndarray) -> np.ndarray:
        """Extract mean-pooled VideoPrism embedding from frames.

        Returns a 1-D numpy vector of length embed_dim, or None on failure
        (the error is reported via log_message).
        """
        try:
            # Resize frames to 288x288 (VideoPrism expects this)
            target_size = 288
            processed_frames = []
            for frame in frames:
                resized = cv2.resize(frame, (target_size, target_size))
                processed_frames.append(resized)
            frames_resized = np.array(processed_frames)
            del processed_frames  # Free list memory

            # Convert to PyTorch format: (T, C, H, W) and normalize to [0, 1]
            frames_t = np.transpose(frames_resized, (0, 3, 1, 2))  # (T, C, H, W)
            del frames_resized  # Free numpy array

            frames_tensor = torch.from_numpy(frames_t).float() / 255.0
            del frames_t  # Free numpy array

            # Add batch dimension: (1, T, C, H, W)
            frames_tensor = frames_tensor.unsqueeze(0)

            with torch.no_grad():
                # VideoPrism returns (B, T*N, D) where N = 16*16 = 256
                tokens = backbone(frames_tensor)  # (1, T*256, D)
                del frames_tensor  # Free input tensor immediately

                # Mean pool over all tokens to get single embedding vector
                embedding = tokens.mean(dim=1).squeeze(0)  # (D,)
                del tokens  # Free large token tensor

                result = embedding.cpu().numpy()
                del embedding  # Free GPU tensor
                return result

        except Exception as e:
            self.log_message.emit(f"Error extracting embedding: {e}")
            return None
624
+
625
+
626
+ class RegistrationWidget(QWidget):
627
+ """Widget for animal registration video processing."""
628
+
629
+ # Signal emitted when embeddings are extracted (for auto-loading in clustering tab)
630
+ embeddings_extracted = pyqtSignal(str, str) # matrix_path, metadata_path
631
+
632
+ def __init__(self, config: dict):
633
+ super().__init__()
634
+ self.config = config
635
+ self.video_mask_pairs = [] # List of (video_path, mask_path) tuples
636
+ self.worker = None
637
+ self.embedding_worker = None
638
+ self.output_dir = "" # Initialize output_dir
639
+ self.processed_videos = [] # List of processed video paths
640
+ self.clip_frame_ranges = {} # Dict mapping clip_path -> (start_frame, end_frame)
641
+ self._setup_ui()
642
+
643
+ def load_from_segmentation(self, video_path: str, mask_path: str):
644
+ """Load video and mask from segmentation tab."""
645
+ # Clear existing
646
+ self.video_list.clear()
647
+ self.mask_list.clear()
648
+ self.video_mask_pairs = []
649
+
650
+ # Add video
651
+ if os.path.exists(video_path):
652
+ item = QListWidgetItem(os.path.basename(video_path))
653
+ item.setData(Qt.ItemDataRole.UserRole, video_path)
654
+ self.video_list.addItem(item)
655
+
656
+ # Add mask
657
+ if os.path.exists(mask_path):
658
+ item = QListWidgetItem(os.path.basename(mask_path))
659
+ item.setData(Qt.ItemDataRole.UserRole, mask_path)
660
+ self.mask_list.addItem(item)
661
+
662
+ # Auto-match
663
+ self._match_videos_masks()
664
+
665
+ # Update output directory display
666
+ self._update_pairs_label()
667
+
668
+ # Check if clips already exist and enable extract embeddings button
669
+ self._check_existing_clips()
670
+
671
+ def _setup_ui(self):
672
+ layout = QVBoxLayout()
673
+
674
+ # Top row: File Selection and Processing Parameters side by side
675
+ top_row_layout = QHBoxLayout()
676
+
677
+ # File Selection (left side)
678
+ files_group = QGroupBox("File selection")
679
+ files_layout = QVBoxLayout()
680
+
681
+ # Videos
682
+ video_layout = QHBoxLayout()
683
+ self.video_list = QListWidget()
684
+ self.video_list.setMaximumHeight(100)
685
+ video_layout.addWidget(self.video_list)
686
+
687
+ video_btn_layout = QVBoxLayout()
688
+ self.add_video_btn = QPushButton("Add videos")
689
+ self.add_video_btn.clicked.connect(self._add_videos)
690
+ video_btn_layout.addWidget(self.add_video_btn)
691
+
692
+ self.remove_video_btn = QPushButton("Remove selected")
693
+ self.remove_video_btn.clicked.connect(self._remove_video)
694
+ video_btn_layout.addWidget(self.remove_video_btn)
695
+ video_btn_layout.addStretch()
696
+
697
+ video_layout.addLayout(video_btn_layout)
698
+ files_layout.addLayout(video_layout)
699
+
700
+ # Masks
701
+ mask_layout = QHBoxLayout()
702
+ self.mask_list = QListWidget()
703
+ self.mask_list.setMaximumHeight(100)
704
+ mask_layout.addWidget(self.mask_list)
705
+
706
+ mask_btn_layout = QVBoxLayout()
707
+ self.add_mask_btn = QPushButton("Add mask files")
708
+ self.add_mask_btn.clicked.connect(self._add_masks)
709
+ mask_btn_layout.addWidget(self.add_mask_btn)
710
+
711
+ self.remove_mask_btn = QPushButton("Remove selected")
712
+ self.remove_mask_btn.clicked.connect(self._remove_mask)
713
+ mask_btn_layout.addWidget(self.remove_mask_btn)
714
+ mask_btn_layout.addStretch()
715
+
716
+ mask_layout.addLayout(mask_btn_layout)
717
+ files_layout.addLayout(mask_layout)
718
+
719
+ # Help button and pairs label
720
+ help_layout = QHBoxLayout()
721
+ self.pairs_label = QLabel("0 video-mask pairs ready")
722
+ help_layout.addWidget(self.pairs_label)
723
+ help_layout.addStretch()
724
+
725
+ # Help button (small circular button)
726
+ self.help_btn = QPushButton("?")
727
+ self.help_btn.setMaximumSize(25, 25)
728
+ self.help_btn.setToolTip("Click for naming information")
729
+ self.help_btn.clicked.connect(self._show_naming_help)
730
+ help_layout.addWidget(self.help_btn)
731
+
732
+ files_layout.addLayout(help_layout)
733
+
734
+ # Load Clips Folder button for extracting embeddings from existing clips
735
+ self.load_clips_btn = QPushButton("Already have clips? Load")
736
+ self.load_clips_btn.clicked.connect(self._load_clips_folder)
737
+ self.load_clips_btn.setToolTip("Load a folder of existing processed clips to extract behaviorome embeddings")
738
+ files_layout.addWidget(self.load_clips_btn)
739
+
740
+ files_group.setLayout(files_layout)
741
+ top_row_layout.addWidget(files_group)
742
+
743
+ # Processing Parameters (right side)
744
+ params_container = QVBoxLayout()
745
+
746
+ # Processing Parameters
747
+ params_group = QGroupBox("Processing parameters")
748
+ params_layout = QFormLayout()
749
+
750
+ self.box_size_spin = QSpinBox()
751
+ self.box_size_spin.setRange(50, 1000)
752
+ self.box_size_spin.setValue(250)
753
+ params_layout.addRow("Crop Box Size (px):", self.box_size_spin)
754
+
755
+ self.target_size_spin = QSpinBox()
756
+ self.target_size_spin.setRange(50, 1000)
757
+ self.target_size_spin.setValue(288)
758
+ params_layout.addRow("Output Size (px):", self.target_size_spin)
759
+
760
+ self.background_combo = QComboBox()
761
+ self.background_combo.addItems(["white", "black", "gray", "blur", "none"])
762
+ self.background_combo.setCurrentText("white")
763
+ params_layout.addRow("Background Mode:", self.background_combo)
764
+
765
+ self.normalization_combo = QComboBox()
766
+ self.normalization_combo.addItems(["CLAHE", "Histogram Equalization", "Mean-Variance", "None"])
767
+ self.normalization_combo.setCurrentText("CLAHE")
768
+ params_layout.addRow("Normalization:", self.normalization_combo)
769
+
770
+ self.feather_spin = QSpinBox()
771
+ self.feather_spin.setRange(0, 50)
772
+ self.feather_spin.setValue(0)
773
+ params_layout.addRow("Mask Feathering (px):", self.feather_spin)
774
+
775
+ # Anchor mode checkbox
776
+ self.lock_roi_checkbox = QCheckBox("Lock ROI to first frame of clip")
777
+ self.lock_roi_checkbox.setChecked(False)
778
+ self.lock_roi_checkbox.setToolTip(
779
+ "When checked, the crop box stays fixed at the first frame's centroid,\n"
780
+ "allowing the object to move within the clip (preserves locomotion).\n"
781
+ "When unchecked, the crop box follows the object's centroid each frame."
782
+ )
783
+ params_layout.addRow("", self.lock_roi_checkbox)
784
+
785
+ params_group.setLayout(params_layout)
786
+ params_container.addWidget(params_group)
787
+
788
+ # Clip Extraction Parameters
789
+ clip_params_group = QGroupBox("Clip extraction parameters")
790
+ clip_params_layout = QFormLayout()
791
+
792
+ self.target_fps_spin = QSpinBox()
793
+ self.target_fps_spin.setRange(1, 60)
794
+ self.target_fps_spin.setValue(int(self.config.get("default_target_fps", 12)))
795
+ clip_params_layout.addRow("Target FPS:", self.target_fps_spin)
796
+
797
+ self.clip_length_spin = QSpinBox()
798
+ self.clip_length_spin.setRange(1, 64)
799
+ self.clip_length_spin.setValue(int(self.config.get("default_clip_length", 8)))
800
+ clip_params_layout.addRow("Frames per clip:", self.clip_length_spin)
801
+
802
+ self.step_frames_spin = QSpinBox()
803
+ self.step_frames_spin.setRange(1, 64)
804
+ self.step_frames_spin.setValue(int(self.config.get("default_step_frames", 8)))
805
+ clip_params_layout.addRow("Step frames:", self.step_frames_spin)
806
+
807
+ clip_params_group.setLayout(clip_params_layout)
808
+ params_container.addWidget(clip_params_group)
809
+
810
+ # Create a widget to hold the params container
811
+ params_widget = QWidget()
812
+ params_widget.setLayout(params_container)
813
+ top_row_layout.addWidget(params_widget)
814
+
815
+ layout.addLayout(top_row_layout)
816
+
817
+ # Output (auto-determined from experiment)
818
+ output_group = QGroupBox("Output")
819
+ output_layout = QVBoxLayout()
820
+
821
+ self.output_dir_label = QLabel("Clips will be saved to experiment folder")
822
+ output_layout.addWidget(self.output_dir_label)
823
+
824
+ self.append_embeddings_check = QCheckBox("Append to existing embeddings if present")
825
+ self.append_embeddings_check.setChecked(False)
826
+ self.append_embeddings_check.setToolTip("When enabled, if an existing behaviorome matrix/metadata is found in the experiment, new embeddings will be appended instead of creating a new file.")
827
+ output_layout.addWidget(self.append_embeddings_check)
828
+
829
+ output_group.setLayout(output_layout)
830
+ layout.addWidget(output_group)
831
+
832
+ # Actions
833
+ actions_group = QGroupBox("Actions")
834
+ actions_layout = QHBoxLayout()
835
+
836
+ self.process_btn = QPushButton("Process videos")
837
+ self.process_btn.clicked.connect(self._start_processing)
838
+ self.process_btn.setEnabled(False)
839
+ actions_layout.addWidget(self.process_btn)
840
+
841
+ self.stop_btn = QPushButton("Stop")
842
+ self.stop_btn.clicked.connect(self._stop_processing)
843
+ self.stop_btn.setEnabled(False)
844
+ actions_layout.addWidget(self.stop_btn)
845
+
846
+ self.view_videos_btn = QPushButton("View processed videos")
847
+ self.view_videos_btn.clicked.connect(self._view_processed_videos)
848
+ self.view_videos_btn.setEnabled(False)
849
+ actions_layout.addWidget(self.view_videos_btn)
850
+
851
+ self.extract_embeddings_btn = QPushButton("Extract behaviorome embeddings")
852
+ self.extract_embeddings_btn.clicked.connect(self._extract_embeddings)
853
+ self.extract_embeddings_btn.setEnabled(False)
854
+ actions_layout.addWidget(self.extract_embeddings_btn)
855
+
856
+ self.export_roi_btn = QPushButton("Export ROI videos (per object)")
857
+ self.export_roi_btn.setToolTip(
858
+ "Export one cropped video per tracked object.\n"
859
+ "Videos are saved in the experiment folder (roi_videos)."
860
+ )
861
+ self.export_roi_btn.clicked.connect(self._export_roi_videos)
862
+ self.export_roi_btn.setEnabled(False)
863
+ actions_layout.addWidget(self.export_roi_btn)
864
+
865
+ actions_group.setLayout(actions_layout)
866
+ layout.addWidget(actions_group)
867
+
868
+ # Progress
869
+ progress_group = QGroupBox("Progress")
870
+ progress_layout = QVBoxLayout()
871
+
872
+ self.progress_bar = QProgressBar()
873
+ progress_layout.addWidget(self.progress_bar)
874
+
875
+ self.status_label = QLabel("Ready")
876
+ progress_layout.addWidget(self.status_label)
877
+
878
+ self.log_text = QTextEdit()
879
+ self.log_text.setReadOnly(True)
880
+ self.log_text.setMaximumHeight(150)
881
+ progress_layout.addWidget(self.log_text)
882
+
883
+ progress_group.setLayout(progress_layout)
884
+ layout.addWidget(progress_group)
885
+
886
+ layout.addStretch()
887
+ self.setLayout(layout)
888
+
889
+ # Check for existing clips
890
+ self._check_existing_clips()
891
+
892
+ def _add_videos(self):
893
+ video_dir = self.config.get("raw_videos_dir", self.config.get("data_dir", "data/raw_videos"))
894
+ paths, _ = QFileDialog.getOpenFileNames(
895
+ self, "Select Video Files", video_dir, "Video Files (*.mp4 *.avi *.mov *.mkv)"
896
+ )
897
+ if paths:
898
+ # Ensure videos are in experiment folder (batch operation)
899
+ from .video_utils import ensure_videos_in_experiment
900
+ paths = ensure_videos_in_experiment(paths, self.config, self)
901
+ for path in paths:
902
+ item = QListWidgetItem(os.path.basename(path))
903
+ item.setData(Qt.ItemDataRole.UserRole, path)
904
+ self.video_list.addItem(item)
905
+ # Auto-match after adding videos
906
+ self._match_videos_masks()
907
+ self._update_pairs_label()
908
+
909
+ def _remove_video(self):
910
+ for item in self.video_list.selectedItems():
911
+ self.video_list.takeItem(self.video_list.row(item))
912
+ # Re-match after removal
913
+ self._match_videos_masks()
914
+ self._update_pairs_label()
915
+
916
+ def _add_masks(self):
917
+ # Check experiment folder first, then default location
918
+ experiment_path = self.config.get("experiment_path")
919
+ if experiment_path and os.path.exists(experiment_path):
920
+ masks_dir = os.path.join(experiment_path, "masks")
921
+ else:
922
+ masks_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data", "masks")
923
+
924
+ paths, _ = QFileDialog.getOpenFileNames(
925
+ self, "Select Mask Files", masks_dir if os.path.exists(masks_dir) else "", "HDF5 Files (*.h5 *.hdf5)"
926
+ )
927
+ if paths:
928
+ for path in paths:
929
+ item = QListWidgetItem(os.path.basename(path))
930
+ item.setData(Qt.ItemDataRole.UserRole, path)
931
+ self.mask_list.addItem(item)
932
+ # Auto-match after adding masks
933
+ self._match_videos_masks()
934
+ self._update_pairs_label()
935
+
936
+ def _remove_mask(self):
937
+ for item in self.mask_list.selectedItems():
938
+ self.mask_list.takeItem(self.mask_list.row(item))
939
+ # Re-match after removal
940
+ self._match_videos_masks()
941
+ self._update_pairs_label()
942
+
943
+ def _match_videos_masks(self):
944
+ """Match videos to masks based on filename."""
945
+ self.video_mask_pairs = []
946
+
947
+ videos = []
948
+ for i in range(self.video_list.count()):
949
+ item = self.video_list.item(i)
950
+ videos.append((item.data(Qt.ItemDataRole.UserRole), item.text()))
951
+
952
+ masks = []
953
+ for i in range(self.mask_list.count()):
954
+ item = self.mask_list.item(i)
955
+ masks.append((item.data(Qt.ItemDataRole.UserRole), item.text()))
956
+
957
+ # Match by base name
958
+ matched = []
959
+ unmatched_videos = []
960
+ unmatched_masks = []
961
+
962
+ for video_path, video_name in videos:
963
+ base_video = os.path.splitext(video_name)[0]
964
+ found = False
965
+ for mask_path, mask_name in masks:
966
+ base_mask = os.path.splitext(mask_name)[0]
967
+ # Remove "_masks" suffix from mask name (most common case)
968
+ base_mask_clean = base_mask.replace('_masks', '').replace('_mask', '').replace('_objects', '')
969
+ base_video_clean = base_video.replace('_output', '').replace('_segmented', '')
970
+
971
+ # Try exact match first
972
+ if base_video == base_mask:
973
+ matched.append((video_path, mask_path))
974
+ found = True
975
+ break
976
+ # Try matching video name with mask name after removing "_masks"
977
+ elif base_video == base_mask_clean:
978
+ matched.append((video_path, mask_path))
979
+ found = True
980
+ break
981
+ # Try matching cleaned versions
982
+ elif base_video_clean == base_mask_clean:
983
+ matched.append((video_path, mask_path))
984
+ found = True
985
+ break
986
+
987
+ if not found:
988
+ unmatched_videos.append(video_name)
989
+
990
+ # Check for unmatched masks
991
+ matched_mask_names = {os.path.basename(m) for _, m in matched}
992
+ for mask_path, mask_name in masks:
993
+ if mask_name not in matched_mask_names:
994
+ unmatched_masks.append(mask_name)
995
+
996
+ self.video_mask_pairs = matched
997
+ # Matching happens automatically, no message box needed
998
+
999
+ def _update_pairs_label(self):
1000
+ count = len(self.video_mask_pairs)
1001
+ self.pairs_label.setText(f"{count} video-mask pairs ready")
1002
+ # Auto-determine output directory from experiment
1003
+ experiment_path = self.config.get("experiment_path")
1004
+ if experiment_path and os.path.exists(experiment_path):
1005
+ self.output_dir = os.path.join(experiment_path, "registered_clips")
1006
+ self.output_dir_label.setText(f"Output: {self.output_dir}")
1007
+ else:
1008
+ self.output_dir = ""
1009
+ self.output_dir_label.setText("No experiment folder - please create an experiment first")
1010
+ self.process_btn.setEnabled(count > 0 and bool(self.output_dir))
1011
+ self.export_roi_btn.setEnabled(count > 0)
1012
+
1013
+ def _show_naming_help(self):
1014
+ """Show help dialog about video-mask naming."""
1015
+ QMessageBox.information(
1016
+ self,
1017
+ "Video-Mask Naming",
1018
+ "Videos and masks are automatically matched based on their filenames.\n\n"
1019
+ "Naming rules:\n"
1020
+ "-Video: video_name.mp4\n"
1021
+ "-Mask: video_name.h5 (or video_name_masks.h5)\n\n"
1022
+ "The mask filename should match the video filename (without extension),\n"
1023
+ "or have '_masks' suffix that will be automatically removed.\n\n"
1024
+ "Example:\n"
1025
+ "-Video: my_video.mp4\n"
1026
+ "- Mask: my_video.h5\n"
1027
+ "- Mask: my_video_masks.h5"
1028
+ )
1029
+
1030
    def _start_processing(self):
        """Validate inputs, gather processing parameters, and launch the worker.

        Requires at least one video-mask pair and an existing experiment
        folder; clips are written under ``<experiment>/registered_clips``.
        """
        if not self.video_mask_pairs:
            QMessageBox.warning(self, "Error", "Please add videos and masks first.")
            return

        # Auto-determine output directory from experiment
        experiment_path = self.config.get("experiment_path")
        if not experiment_path or not os.path.exists(experiment_path):
            QMessageBox.warning(self, "Error", "No experiment folder found. Please create an experiment first.")
            return

        output_dir = os.path.join(experiment_path, "registered_clips")
        os.makedirs(output_dir, exist_ok=True)
        self.output_dir = output_dir
        self.output_dir_label.setText(f"Output: {output_dir}")

        # Snapshot the current widget values into a plain dict for the worker.
        params = {
            'box_size': self.box_size_spin.value(),
            'target_size': self.target_size_spin.value(),
            'background_mode': self.background_combo.currentText(),
            'normalization_method': self.normalization_combo.currentText(),
            'mask_feather_px': self.feather_spin.value(),
            # 'first' pins the crop to the first frame's centroid; 'frame'
            # tracks the centroid every frame (see the checkbox tooltip).
            'anchor_mode': 'first' if self.lock_roi_checkbox.isChecked() else 'frame',
            'target_fps': self.target_fps_spin.value(),
            'clip_length_frames': self.clip_length_spin.value(),
            'step_frames': self.step_frames_spin.value()
        }

        # Create worker thread and wire its signals to our slots.
        self.worker = ProcessingWorker(self.video_mask_pairs, self.output_dir, params)
        self.worker.progress.connect(self._on_progress)
        self.worker.finished.connect(self._on_finished)
        self.worker.error.connect(self._on_error)

        # Reset UI state for a fresh run before starting the thread.
        self.process_btn.setEnabled(False)
        self.stop_btn.setEnabled(True)
        self.view_videos_btn.setEnabled(False)
        self.progress_bar.setMaximum(len(self.video_mask_pairs))
        self.progress_bar.setValue(0)
        self.log_text.clear()
        self.processed_videos = []

        self.worker.start()
1074
+
1075
+ def _stop_processing(self):
1076
+ if self.worker:
1077
+ self.worker.stop()
1078
+ self.worker.wait()
1079
+ self.log_text.append("Processing stopped by user.")
1080
+ self.process_btn.setEnabled(True)
1081
+ self.stop_btn.setEnabled(False)
1082
+
1083
+ def _on_progress(self, current, total, video_name):
1084
+ self.progress_bar.setValue(current)
1085
+ self.status_label.setText(f"Processing: {video_name} ({current}/{total})")
1086
+
1087
    def _on_finished(self, output_paths):
        """Handle successful worker completion.

        ``output_paths`` items are either ``(clip_path, start_frame, end_frame)``
        tuples or bare path strings (legacy format).  Frame ranges are cached in
        ``self.clip_frame_ranges`` for later embedding extraction.
        """
        self.log_text.append("=" * 50)
        self.log_text.append("All videos processed successfully!")
        self.log_text.append(f"Output directory: {self.output_dir}")
        self.log_text.append(f"Created {len(output_paths)} clip(s)")

        # Extract clip paths and frame ranges from tuples
        clip_paths_list = []
        self.clip_frame_ranges = {}
        for item in output_paths:
            if isinstance(item, tuple) and len(item) == 3:
                clip_path, start_frame, end_frame = item
                clip_paths_list.append(clip_path)
                self.clip_frame_ranges[clip_path] = (start_frame, end_frame)
            else:
                # Legacy: just a path string
                clip_paths_list.append(item)

        # Group clips by video (using extracted paths) — each source video gets
        # its own subdirectory, so the dirname identifies the video.
        clips_by_video = {}
        for path in clip_paths_list:
            video_dir = os.path.dirname(path)
            if video_dir not in clips_by_video:
                clips_by_video[video_dir] = []
            clips_by_video[video_dir].append(path)

        for video_dir, clips in clips_by_video.items():
            video_name = os.path.basename(video_dir)
            self.log_text.append(f"  {video_name}: {len(clips)} clip(s)")

        # Restore UI state and enable the follow-up actions.
        self.processed_videos = clip_paths_list
        self.status_label.setText("Complete")
        self.process_btn.setEnabled(True)
        self.stop_btn.setEnabled(False)
        self.view_videos_btn.setEnabled(len(clip_paths_list) > 0)
        self.extract_embeddings_btn.setEnabled(len(clip_paths_list) > 0)

        msg = f"All videos have been processed into clips.\n\nCreated {len(clip_paths_list)} clip(s) total.\n\nWould you like to view them now?"
        reply = QMessageBox.question(
            self, "Success", msg,
            QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No,
            QMessageBox.StandardButton.Yes
        )
        if reply == QMessageBox.StandardButton.Yes:
            self._view_processed_videos()
1132
+
1133
+ def _on_error(self, error_msg):
1134
+ self.log_text.append(f"Error: {error_msg}")
1135
+ QMessageBox.critical(self, "Error", error_msg)
1136
+ self.process_btn.setEnabled(True)
1137
+ self.stop_btn.setEnabled(False)
1138
+
1139
    def _view_processed_videos(self):
        """Open video player dialog for processed videos.

        A single clip opens directly; multiple clips first show a chooser
        dialog so the user can pick which one the player starts on.
        """
        if not self.processed_videos:
            QMessageBox.information(self, "No Videos", "No processed videos available to view.")
            return

        # If multiple videos, let user choose start, otherwise open directly
        if len(self.processed_videos) == 1:
            dialog = VideoPlayerDialog(self.processed_videos, 0, self)
            dialog.exec()
        else:
            # Multiple videos - let user choose which one to start with
            from PyQt6.QtWidgets import QListWidget, QDialogButtonBox

            dialog = QDialog(self)
            dialog.setWindowTitle("Select video to view")
            dialog.setMinimumSize(400, 300)
            layout = QVBoxLayout(dialog)

            label = QLabel("Select a video to view:")
            layout.addWidget(label)

            # List shows basenames only; selection index maps back to the path.
            list_widget = QListWidget()
            for path in self.processed_videos:
                list_widget.addItem(os.path.basename(path))
            layout.addWidget(list_widget)

            buttons = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel)
            buttons.accepted.connect(dialog.accept)
            buttons.rejected.connect(dialog.reject)
            layout.addWidget(buttons)

            if dialog.exec() == QDialog.DialogCode.Accepted:
                selected_items = list_widget.selectedItems()
                if selected_items:
                    # Row index in the chooser equals the index into processed_videos.
                    idx = list_widget.row(selected_items[0])
                    player_dialog = VideoPlayerDialog(self.processed_videos, idx, self)
                    player_dialog.exec()
1177
+
1178
+ def _load_clips_folder(self):
1179
+ """Load a folder of existing clips for embedding extraction."""
1180
+ folder_path = QFileDialog.getExistingDirectory(self, "Select Clips Folder")
1181
+ if not folder_path:
1182
+ return
1183
+
1184
+ clip_paths = []
1185
+ for root, dirs, files in os.walk(folder_path):
1186
+ for file in files:
1187
+ if file.endswith(('.mp4', '.avi', '.mov', '.mkv')):
1188
+ clip_paths.append(os.path.join(root, file))
1189
+
1190
+ if not clip_paths:
1191
+ QMessageBox.warning(self, "No Clips", f"No video clips found in {folder_path}")
1192
+ return
1193
+
1194
+ self.output_dir = folder_path
1195
+ self.processed_videos = sorted(clip_paths)
1196
+
1197
+ self.log_text.append("=" * 50)
1198
+ self.log_text.append(f"Loaded {len(clip_paths)} clips from: {folder_path}")
1199
+ self.status_label.setText(f"Loaded {len(clip_paths)} clips")
1200
+ self.extract_embeddings_btn.setEnabled(True)
1201
+ self.view_videos_btn.setEnabled(True)
1202
+
1203
+ # Ask to extract immediately
1204
+ reply = QMessageBox.question(
1205
+ self,
1206
+ "Extract Embeddings",
1207
+ f"Found {len(clip_paths)} clips.\nDo you want to extract embeddings now?",
1208
+ QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No,
1209
+ QMessageBox.StandardButton.Yes
1210
+ )
1211
+
1212
+ if reply == QMessageBox.StandardButton.Yes:
1213
+ self._extract_embeddings()
1214
+
1215
    def _extract_embeddings(self):
        """Extract VideoPrism embeddings from processed clips.

        Rescans ``self.output_dir`` (falling back to the experiment's
        ``registered_clips`` folder) for clip files, confirms with the user,
        and launches an ``EmbeddingExtractionWorker`` whose signals drive the
        progress bar and log pane.
        """
        if not self.processed_videos:
            QMessageBox.warning(self, "No Clips", "No processed clips available. Please process videos first.")
            return

        # Get output directory (where clips are stored)
        if not self.output_dir:
            experiment_path = self.config.get("experiment_path")
            if not experiment_path:
                QMessageBox.warning(self, "No Experiment", "No active experiment. Please create or load an experiment first.")
                return
            self.output_dir = os.path.join(experiment_path, "registered_clips")

        if not os.path.exists(self.output_dir):
            QMessageBox.warning(self, "Directory Not Found", f"Clips directory not found: {self.output_dir}")
            return

        # Collect all clip files from registered_clips directory (fresh walk,
        # not self.processed_videos, so clips from earlier runs are included).
        clip_paths = []
        for root, dirs, files in os.walk(self.output_dir):
            for file in files:
                if file.endswith(('.mp4', '.avi', '.mov', '.mkv')):
                    clip_paths.append(os.path.join(root, file))

        if not clip_paths:
            QMessageBox.warning(self, "No Clips", f"No video clips found in {self.output_dir}")
            return

        # Sort clips for consistent ordering
        clip_paths.sort()

        # Get model name from config
        model_name = self.config.get("backbone_model", "videoprism_public_v1_base")

        # Confirm with user
        reply = QMessageBox.question(
            self,
            "Extract Embeddings",
            f"Extract VideoPrism embeddings from {len(clip_paths)} clip(s)?\n\n"
            f"Model: {model_name}\n"
            f"Output directory: {self.output_dir}",
            QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No,
            QMessageBox.StandardButton.Yes
        )

        if reply != QMessageBox.StandardButton.Yes:
            return

        # Get experiment name from config
        experiment_name = self.config.get("experiment_name", None)

        # Start extraction worker with frame ranges if available.
        # clip_frame_ranges only exists after a processing run in this session,
        # hence the hasattr guard.
        self.embedding_worker = EmbeddingExtractionWorker(
            clip_paths,
            self.output_dir,
            experiment_name=experiment_name,
            model_name=model_name,
            clip_frame_ranges=self.clip_frame_ranges if hasattr(self, 'clip_frame_ranges') else None,
            append_to_existing=self.append_embeddings_check.isChecked()
        )
        self.embedding_worker.progress.connect(self._on_embedding_progress)
        self.embedding_worker.finished.connect(self._on_embedding_finished)
        self.embedding_worker.error.connect(self._on_embedding_error)
        self.embedding_worker.log_message.connect(self._on_embedding_log)

        # Reset UI state for the extraction run.
        self.extract_embeddings_btn.setEnabled(False)
        self.progress_bar.setMaximum(len(clip_paths))
        self.progress_bar.setValue(0)
        self.log_text.clear()
        self.log_text.append(f"Starting VideoPrism embedding extraction from {len(clip_paths)} clips...")

        self.embedding_worker.start()
1288
+
1289
+ def _on_embedding_progress(self, current, total, clip_name):
1290
+ self.progress_bar.setValue(current)
1291
+ self.status_label.setText(f"Extracting embeddings: {clip_name} ({current}/{total})")
1292
+
1293
+ def _on_embedding_finished(self, feature_matrix_path, metadata_path):
1294
+ self.log_text.append("=" * 50)
1295
+ self.log_text.append("Embedding extraction completed successfully!")
1296
+ self.log_text.append(f"Feature matrix: {feature_matrix_path}")
1297
+ self.log_text.append(f"Metadata: {metadata_path}")
1298
+
1299
+ self.status_label.setText("Embedding extraction complete")
1300
+ self.extract_embeddings_btn.setEnabled(True)
1301
+
1302
+ # Emit signal for auto-loading in clustering tab
1303
+ self.embeddings_extracted.emit(feature_matrix_path, metadata_path)
1304
+
1305
+ QMessageBox.information(
1306
+ self,
1307
+ "Success",
1308
+ f"Behaviorome embeddings extracted successfully!\n\n"
1309
+ f"Feature matrix: {os.path.basename(feature_matrix_path)}\n"
1310
+ f"Metadata: {os.path.basename(metadata_path)}\n\n"
1311
+ f"Files saved to: {os.path.dirname(feature_matrix_path)}\n\n"
1312
+ f"Data will be automatically loaded in the Clustering tab."
1313
+ )
1314
+
1315
+ def _on_embedding_error(self, error_msg):
1316
+ self.log_text.append(f"Error: {error_msg}")
1317
+ QMessageBox.critical(self, "Error", f"Embedding extraction failed:\n{error_msg}")
1318
+ self.extract_embeddings_btn.setEnabled(True)
1319
+ self.status_label.setText("Error")
1320
+
1321
    def _on_embedding_log(self, message):
        """Forward a log line emitted by the embedding worker to the log pane."""
        self.log_text.append(message)
1323
+
1324
    def _export_roi_videos(self):
        """Export one cropped video per tracked object for each video-mask pair.

        For every pair: loads the mask data, builds per-object bbox tracks,
        sizes a fixed square crop per object (max bbox extent + 25% padding,
        min 64 px), then replays the source video once writing each object's
        crop to ``<experiment>/roi_videos/<video>_object<id>.mp4``.
        """
        if not self.video_mask_pairs:
            QMessageBox.warning(self, "No Data", "Add video and mask files first.")
            return

        all_exported = []
        for video_path, mask_path in self.video_mask_pairs:
            try:
                mask_data = load_segmentation_data(mask_path)
            except Exception as e:
                QMessageBox.warning(self, "Load Error", f"Could not load mask {os.path.basename(mask_path)}: {e}")
                continue

            # frame_objects: per-frame list of object dicts; presumably each has
            # "obj_id" and "bbox" (x1, y1, x2, y2) — see usage below.
            frame_objects = mask_data.get("frame_objects", [])
            if not frame_objects:
                QMessageBox.warning(self, "No Frames", f"No mask frames in {os.path.basename(mask_path)}.")
                continue

            # start_offset maps index 0 of frame_objects to a video frame index.
            start_offset = mask_data.get("start_offset", 0)
            all_obj_ids = set()
            for frame_objs in frame_objects:
                for obj in frame_objs:
                    all_obj_ids.add(obj.get("obj_id", 0))
            all_obj_ids = sorted(all_obj_ids)
            if not all_obj_ids:
                continue

            # Output folder: experiment's roi_videos, else next to the video.
            experiment_path = self.config.get("experiment_path")
            if experiment_path and os.path.exists(experiment_path):
                out_dir = os.path.join(experiment_path, "roi_videos")
            else:
                out_dir = os.path.join(os.path.dirname(video_path), "roi_videos")
            os.makedirs(out_dir, exist_ok=True)
            video_basename = os.path.splitext(os.path.basename(video_path))[0]

            # obj_tracks[obj_id][frame_idx] -> (x1, y1, x2, y2) as ints.
            obj_tracks = {oid: {} for oid in all_obj_ids}
            for i, frame_objs in enumerate(frame_objects):
                frame_idx = start_offset + i
                for obj in frame_objs:
                    oid = obj.get("obj_id", 0)
                    bbox = obj.get("bbox")
                    if bbox is not None:
                        obj_tracks[oid][frame_idx] = tuple(int(x) for x in bbox)

            # Probe video dimensions/fps, then release; reopened below for reading.
            cap = cv2.VideoCapture(video_path)
            vid_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            vid_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            vid_fps = cap.get(cv2.CAP_PROP_FPS)
            if vid_fps <= 0:
                vid_fps = 30.0
            cap.release()

            # Export only the frame span that has any tracked bbox.
            frame_indices = sorted(set(f for track in obj_tracks.values() for f in track))
            if not frame_indices:
                continue
            start_frame = frame_indices[0]
            end_frame = frame_indices[-1]

            # Fixed square crop per object: max bbox extent + 25% padding each
            # side, floor of 64 px, so the output size never changes mid-video.
            crop_padding = 0.25
            writers = {}
            obj_crop_params = {}
            obj_paths = {}
            for obj_id, track in obj_tracks.items():
                if not track:
                    continue
                max_w = max(x2 - x1 for x1, y1, x2, y2 in track.values())
                max_h = max(y2 - y1 for x1, y1, x2, y2 in track.values())
                crop_w = int(max_w * (1 + 2 * crop_padding))
                crop_h = int(max_h * (1 + 2 * crop_padding))
                crop_side = max(crop_w, crop_h, 64)
                obj_crop_params[obj_id] = crop_side
                out_path = os.path.join(out_dir, f"{video_basename}_object{obj_id}.mp4")
                obj_paths[obj_id] = out_path
                fourcc = cv2.VideoWriter_fourcc(*"mp4v")
                writers[obj_id] = cv2.VideoWriter(out_path, fourcc, vid_fps, (crop_side, crop_side))

            if not writers:
                continue

            progress = QProgressDialog(
                f"Exporting ROI videos: {os.path.basename(video_path)}...", "Cancel",
                start_frame, end_frame + 1, self
            )
            progress.setWindowTitle("Export ROI Videos")
            progress.setMinimumDuration(0)
            progress.setValue(start_frame)

            # Single sequential pass over the video; every writer is fed once
            # per frame so all object videos stay in sync.
            cap = cv2.VideoCapture(video_path)
            cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
            last_bbox = {}
            for fidx in range(start_frame, end_frame + 1):
                if progress.wasCanceled():
                    break
                ret, frame = cap.read()
                if not ret:
                    break
                progress.setValue(fidx)
                QApplication.processEvents()
                for obj_id, writer in writers.items():
                    crop_side = obj_crop_params[obj_id]
                    # Reuse the last known bbox when the object is missing this
                    # frame; write black frames until it has appeared at all.
                    bbox = obj_tracks[obj_id].get(fidx) or last_bbox.get(obj_id)
                    if bbox is None:
                        writer.write(np.zeros((crop_side, crop_side, 3), dtype=np.uint8))
                        continue
                    last_bbox[obj_id] = bbox
                    x1, y1, x2, y2 = bbox
                    cx = (x1 + x2) // 2
                    cy = (y1 + y2) // 2
                    half = crop_side // 2
                    # Center the crop on the bbox, then clamp it fully inside
                    # the frame (shift rather than shrink at the borders).
                    rx1 = max(0, cx - half)
                    ry1 = max(0, cy - half)
                    rx2 = min(vid_w, rx1 + crop_side)
                    ry2 = min(vid_h, ry1 + crop_side)
                    rx1 = max(0, rx2 - crop_side)
                    ry1 = max(0, ry2 - crop_side)
                    crop = frame[ry1:ry2, rx1:rx2]
                    if crop.shape[0] != crop_side or crop.shape[1] != crop_side:
                        crop = cv2.resize(crop, (crop_side, crop_side), interpolation=cv2.INTER_AREA)
                    writer.write(crop)

            cap.release()
            progress.close()
            for w in writers.values():
                w.release()
            all_exported.extend([(out_dir, list(obj_paths.values()))])

        if not all_exported:
            QMessageBox.warning(self, "No Tracks", "No object tracks with valid bboxes found.")
            return
        summary = []
        for out_dir, paths in all_exported:
            names = [os.path.basename(p) for p in paths]
            summary.append(f"{out_dir}\n  " + "\n  ".join(names))
        QMessageBox.information(
            self, "Export Complete",
            f"Exported ROI videos to:\n\n" + "\n\n".join(summary)
        )
1462
+
1463
+ def _check_existing_clips(self):
1464
+ """Check if processed clips exist and enable extract embeddings button."""
1465
+ experiment_path = self.config.get("experiment_path")
1466
+ if not experiment_path:
1467
+ return
1468
+
1469
+ clips_dir = os.path.join(experiment_path, "registered_clips")
1470
+ if os.path.exists(clips_dir):
1471
+ clip_files = []
1472
+ for root, dirs, files in os.walk(clips_dir):
1473
+ for file in files:
1474
+ if file.endswith(('.mp4', '.avi', '.mov', '.mkv')):
1475
+ clip_files.append(os.path.join(root, file))
1476
+
1477
+ if clip_files:
1478
+ self.processed_videos = clip_files
1479
+ self.extract_embeddings_btn.setEnabled(True)
1480
+ self.output_dir = clips_dir
1481
+
1482
    def update_config(self, config: dict):
        """Update configuration.

        Replaces the stored config reference; other methods read
        ``self.config`` at call time, so new values take effect immediately.
        """
        self.config = config
1485
+