singlebehaviorlab-2.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. sam2/__init__.py +11 -0
  2. sam2/automatic_mask_generator.py +454 -0
  3. sam2/benchmark.py +92 -0
  4. sam2/build_sam.py +174 -0
  5. sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
  6. sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
  7. sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
  8. sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
  9. sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
  10. sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
  11. sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
  12. sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
  13. sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
  14. sam2/modeling/__init__.py +5 -0
  15. sam2/modeling/backbones/__init__.py +5 -0
  16. sam2/modeling/backbones/hieradet.py +317 -0
  17. sam2/modeling/backbones/image_encoder.py +134 -0
  18. sam2/modeling/backbones/utils.py +93 -0
  19. sam2/modeling/memory_attention.py +169 -0
  20. sam2/modeling/memory_encoder.py +181 -0
  21. sam2/modeling/position_encoding.py +239 -0
  22. sam2/modeling/sam/__init__.py +5 -0
  23. sam2/modeling/sam/mask_decoder.py +295 -0
  24. sam2/modeling/sam/prompt_encoder.py +202 -0
  25. sam2/modeling/sam/transformer.py +311 -0
  26. sam2/modeling/sam2_base.py +913 -0
  27. sam2/modeling/sam2_utils.py +323 -0
  28. sam2/sam2_hiera_b+.yaml +113 -0
  29. sam2/sam2_hiera_l.yaml +117 -0
  30. sam2/sam2_hiera_s.yaml +116 -0
  31. sam2/sam2_hiera_t.yaml +118 -0
  32. sam2/sam2_image_predictor.py +466 -0
  33. sam2/sam2_video_predictor.py +1388 -0
  34. sam2/sam2_video_predictor_legacy.py +1172 -0
  35. sam2/utils/__init__.py +5 -0
  36. sam2/utils/amg.py +348 -0
  37. sam2/utils/misc.py +349 -0
  38. sam2/utils/transforms.py +118 -0
  39. singlebehaviorlab/__init__.py +4 -0
  40. singlebehaviorlab/__main__.py +130 -0
  41. singlebehaviorlab/_paths.py +100 -0
  42. singlebehaviorlab/backend/__init__.py +2 -0
  43. singlebehaviorlab/backend/augmentations.py +320 -0
  44. singlebehaviorlab/backend/data_store.py +420 -0
  45. singlebehaviorlab/backend/model.py +1290 -0
  46. singlebehaviorlab/backend/train.py +4667 -0
  47. singlebehaviorlab/backend/uncertainty.py +578 -0
  48. singlebehaviorlab/backend/video_processor.py +688 -0
  49. singlebehaviorlab/backend/video_utils.py +139 -0
  50. singlebehaviorlab/data/config/config.yaml +85 -0
  51. singlebehaviorlab/data/training_profiles.json +334 -0
  52. singlebehaviorlab/gui/__init__.py +4 -0
  53. singlebehaviorlab/gui/analysis_widget.py +2291 -0
  54. singlebehaviorlab/gui/attention_export.py +311 -0
  55. singlebehaviorlab/gui/clip_extraction_widget.py +481 -0
  56. singlebehaviorlab/gui/clustering_widget.py +3187 -0
  57. singlebehaviorlab/gui/inference_popups.py +1138 -0
  58. singlebehaviorlab/gui/inference_widget.py +4550 -0
  59. singlebehaviorlab/gui/inference_worker.py +651 -0
  60. singlebehaviorlab/gui/labeling_widget.py +2324 -0
  61. singlebehaviorlab/gui/main_window.py +754 -0
  62. singlebehaviorlab/gui/metadata_management_widget.py +1119 -0
  63. singlebehaviorlab/gui/motion_tracking.py +764 -0
  64. singlebehaviorlab/gui/overlay_export.py +1234 -0
  65. singlebehaviorlab/gui/plot_integration.py +729 -0
  66. singlebehaviorlab/gui/qt_helpers.py +29 -0
  67. singlebehaviorlab/gui/registration_widget.py +1485 -0
  68. singlebehaviorlab/gui/review_widget.py +1330 -0
  69. singlebehaviorlab/gui/segmentation_tracking_widget.py +2752 -0
  70. singlebehaviorlab/gui/tab_tutorial_dialog.py +312 -0
  71. singlebehaviorlab/gui/timeline_themes.py +131 -0
  72. singlebehaviorlab/gui/training_profiles.py +418 -0
  73. singlebehaviorlab/gui/training_widget.py +3719 -0
  74. singlebehaviorlab/gui/video_utils.py +233 -0
  75. singlebehaviorlab/licenses/SAM2-LICENSE +201 -0
  76. singlebehaviorlab/licenses/VideoPrism-LICENSE +202 -0
  77. singlebehaviorlab-2.0.0.dist-info/METADATA +447 -0
  78. singlebehaviorlab-2.0.0.dist-info/RECORD +88 -0
  79. singlebehaviorlab-2.0.0.dist-info/WHEEL +5 -0
  80. singlebehaviorlab-2.0.0.dist-info/entry_points.txt +2 -0
  81. singlebehaviorlab-2.0.0.dist-info/licenses/LICENSE +21 -0
  82. singlebehaviorlab-2.0.0.dist-info/top_level.txt +3 -0
  83. videoprism/__init__.py +0 -0
  84. videoprism/encoders.py +910 -0
  85. videoprism/layers.py +1136 -0
  86. videoprism/models.py +407 -0
  87. videoprism/tokenizers.py +167 -0
  88. videoprism/utils.py +168 -0
@@ -0,0 +1,2752 @@
+ import sys
+ import os
+ import gc
+ import logging
+ import cv2
+ import numpy as np
+ import torch
+
+ logger = logging.getLogger(__name__)
+ from PyQt6.QtWidgets import (
+     QWidget, QVBoxLayout, QHBoxLayout, QPushButton, QLabel, QFileDialog,
+     QRadioButton, QSlider, QButtonGroup, QMessageBox, QProgressBar,
+     QComboBox, QDoubleSpinBox, QSpinBox, QFormLayout, QCheckBox, QGroupBox,
+     QSizePolicy, QListWidget, QScrollArea, QProgressDialog, QApplication
+ )
+ from PyQt6.QtCore import Qt, QTimer, QThread, pyqtSignal, QPointF, QEvent
+ from PyQt6.QtGui import QImage, QPixmap, QPainter, QColor, QPen, QBrush
+ import shutil
+ from pathlib import Path
+ from contextlib import nullcontext
+ from importlib import metadata as importlib_metadata
+
+ # Motion tracking (Kalman filter, OC-SORT) in separate module
+ from .motion_tracking import MultiObjectMotionTracker
+
+
+ # Colors for different objects (R, G, B)
+ OBJ_COLORS = [
+     (0, 255, 0),    # 1: Green
+     (255, 0, 0),    # 2: Red
+     (0, 0, 255),    # 3: Blue
+     (255, 255, 0),  # 4: Yellow
+     (0, 255, 255),  # 5: Cyan
+     (255, 0, 255),  # 6: Magenta
+     (255, 128, 0),  # 7: Orange
+     (128, 0, 255),  # 8: Purple
+     (128, 128, 0),  # 9: Olive
+     (0, 128, 128),  # 10: Teal
+ ]
+
+ def get_obj_color(obj_id):
+     idx = (obj_id - 1) % len(OBJ_COLORS)
+     return OBJ_COLORS[idx]
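# Illustrative note (not from the packaged file): object ids beyond the
# ten-entry palette wrap around, so get_obj_color(11) == get_obj_color(1) == (0, 255, 0).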
+
+
+ class CheckpointDownloadWorker(QThread):
+     """Worker thread for downloading SAM2 checkpoints."""
+     progress = pyqtSignal(str)
+     finished = pyqtSignal(bool, str)
+
+     # Model URLs (SAM 2.1)
+     MODEL_URLS = {
+         "sam2.1_hiera_tiny.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt",
+         "sam2.1_hiera_small.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt",
+         "sam2.1_hiera_base_plus.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt",
+         "sam2.1_hiera_large.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt",
+         # SAM 2.0 (older versions)
+         "sam2_hiera_tiny.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt",
+         "sam2_hiera_small.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt",
+         "sam2_hiera_base_plus.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt",
+         "sam2_hiera_large.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt",
+     }
+
+     def __init__(self, checkpoint_name, checkpoint_path, checkpoint_url):
+         super().__init__()
+         self.checkpoint_name = checkpoint_name
+         self.checkpoint_path = checkpoint_path
+         self.checkpoint_url = checkpoint_url
+
+     def run(self):
+         try:
+             # Check if already downloaded
+             if os.path.exists(self.checkpoint_path):
+                 file_size = os.path.getsize(self.checkpoint_path) / (1024**2)  # MB
+                 if file_size > 10:  # Reasonable size check (should be >100MB)
+                     self.finished.emit(True, f"Checkpoint already exists ({file_size:.1f} MB)")
+                     return
+
+             self.progress.emit(f"Downloading {self.checkpoint_name}...")
+             self.progress.emit(f"URL: {self.checkpoint_url}")
+
+             # Download via urllib, so no external wget/curl is needed
+             import urllib.request
+
+             def show_progress(block_num, block_size, total_size):
+                 if total_size > 0:
+                     percent = min(100, (block_num * block_size * 100) / total_size)
+                     self.progress.emit(f"Downloading {self.checkpoint_name}: {percent:.1f}%")
+
+             # Download with progress
+             urllib.request.urlretrieve(
+                 self.checkpoint_url,
+                 self.checkpoint_path,
+                 reporthook=show_progress
+             )
+
+             # Verify download
+             if os.path.exists(self.checkpoint_path):
+                 file_size = os.path.getsize(self.checkpoint_path) / (1024**2)
+                 if file_size < 10:  # Suspiciously small
+                     os.remove(self.checkpoint_path)
+                     raise Exception(f"Downloaded file seems too small ({file_size:.1f} MB). Download may have failed.")
+                 self.finished.emit(True, f"Downloaded successfully ({file_size:.1f} MB)")
+             else:
+                 raise Exception("Download completed but file not found")
+
+         except Exception as e:
+             if os.path.exists(self.checkpoint_path):
+                 try:
+                     os.remove(self.checkpoint_path)
+                 except:
+                     pass
+             self.finished.emit(False, f"Download failed: {str(e)}")
+
+
+ class TrackingWorker(QThread):
+     """Worker thread for running tracking."""
+     progress_signal = pyqtSignal(int)
+     frame_result_signal = pyqtSignal(int, dict)  # frame_idx, {obj_id: mask}
+     finished_signal = pyqtSignal(dict)
+     error_signal = pyqtSignal(str)
+     log_message = pyqtSignal(str)
+
+     def __init__(self, predictor, video_path, user_points, start_frame, end_frame,
+                  mask_threshold=0.0, offload_video=True, offload_state=True,
+                  enable_memory_management=True, reseed_between_chunks=False,
+                  initial_masks=None, enable_motion_tracking=False,
+                  motion_score_threshold=0.3, motion_consecutive_low=3,
+                  motion_area_threshold=0.5, enable_ocsort=False,
+                  ocsort_inertia=0.2, use_cuda_bf16_autocast=True):
+         super().__init__()
+         self.predictor = predictor
+         self.video_path = video_path
+         self.user_points = user_points  # (frame_idx, obj_id) -> {'points': [], 'labels': []}
+         self.start_frame = start_frame
+         self.end_frame = end_frame
+         self.mask_threshold = mask_threshold
+         self.offload_video = offload_video
+         self.offload_state = offload_state
+         self.enable_memory_management = enable_memory_management
+         self.reseed_between_chunks = reseed_between_chunks
+         self.initial_masks = initial_masks or {}  # {(frame_idx, obj_id): mask_array} for resume conditioning
+         self.chunk_size = 200
+         self.should_stop = False
+         self.use_cuda_bf16_autocast = bool(use_cuda_bf16_autocast)
+
+         # Motion-aware tracking
+         self.enable_motion_tracking = enable_motion_tracking
+         self.motion_score_threshold = motion_score_threshold
+         self.motion_tracker = None
+         if enable_motion_tracking:
+             self.motion_tracker = MultiObjectMotionTracker(
+                 motion_score_threshold=motion_score_threshold,
+                 use_kalman=True,
+                 consecutive_low_threshold=motion_consecutive_low,
+                 area_change_threshold=motion_area_threshold,
+                 use_ocsort=enable_ocsort,
+                 ocsort_inertia=ocsort_inertia
+             )
+
+     def _use_cuda_bf16(self):
+         """Use bf16 autocast only when SAM2 runs on CUDA."""
+         dev = getattr(self.predictor, "device", None)
+         dev_type = getattr(dev, "type", str(dev))
+         return bool(
+             self.use_cuda_bf16_autocast
+             and torch.cuda.is_available()
+             and dev_type == "cuda"
+         )
+
+     def _sam2_autocast(self):
+         if self._use_cuda_bf16():
+             return torch.autocast(device_type="cuda", dtype=torch.bfloat16)
+         return nullcontext()
+
+     def _sam2_call(self, fn, *args, **kwargs):
+         with self._sam2_autocast():
+             return fn(*args, **kwargs)
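    # Illustrative note (not from the packaged file): a call such as
    #     self._sam2_call(self.predictor.add_new_mask, inference_state=state, ...)
    # runs under torch.autocast(device_type="cuda", dtype=torch.bfloat16) when the
    # predictor sits on CUDA and use_cuda_bf16_autocast is set; otherwise
    # nullcontext() leaves the call to run at the model's native precision.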
+
+     def stop(self):
+         """Request tracking stop."""
+         self.should_stop = True
+
+     def run(self):
+         """Run tracking with incremental processing."""
+         try:
+             all_video_segments = {}  # global_frame_idx -> {obj_id: mask}
+             MAX_MASKS_IN_MEMORY = 500  # Keep only recent masks, older ones are already emitted via signal
+             last_masks_for_reseed = None  # Store last frame masks of previous chunk for optional reseed
+
+             try:
+                 import decord
+             except ImportError:
+                 raise ImportError("decord not found. Please install it: pip install eva-decord")
+
+             from collections import OrderedDict
+
+             def load_frames(start, end):
+                 decord.bridge.set_bridge("torch")
+                 image_size = self.predictor.image_size
+                 vr = decord.VideoReader(self.video_path, width=image_size, height=image_size)
+                 target_dtype = getattr(self.predictor, "dtype", torch.float32)
+                 if self._use_cuda_bf16() and not self.offload_video:
+                     target_dtype = torch.bfloat16
+
+                 if end > len(vr):
+                     end = len(vr)
+                 indices = list(range(start, end))
+                 frames = vr.get_batch(indices)
+                 del vr  # Free VideoReader memory after loading frames
+                 images = frames.permute(0, 3, 1, 2).float() / 255.0
+                 del frames  # Free original frame tensor after processing
+
+                 img_mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32)[:, None, None]
+                 img_std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32)[:, None, None]
+
+                 if not self.offload_video:
+                     images = images.to(self.predictor.device, dtype=target_dtype)
+                     img_mean = img_mean.to(self.predictor.device, dtype=target_dtype)
+                     img_std = img_std.to(self.predictor.device, dtype=target_dtype)
+                 else:
+                     # Keep on CPU but ensure dtype matches model expectations
+                     images = images.to(dtype=target_dtype)
+                     img_mean = img_mean.to(dtype=target_dtype)
+                     img_std = img_std.to(dtype=target_dtype)
+
+                 images -= img_mean
+                 images /= img_std
+                 return images
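            # Note (not from the packaged file): the mean/std constants above are
            # the standard ImageNet normalization statistics, matching the defaults
            # used by the frame loader bundled in sam2/utils/misc.py of this wheel.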
+
+             def is_cuda_alloc_error(exc):
+                 msg = str(exc).lower()
+                 return (
+                     "cuda out of memory" in msg
+                     or "cublas_status_alloc_failed" in msg
+                     or ("cuda error" in msg and "alloc" in msg)
+                 )
+
+             def run_with_cuda_retry(op_name, fn):
+                 try:
+                     return fn()
+                 except RuntimeError as e:
+                     if not is_cuda_alloc_error(e):
+                         raise
+                     self.log_message.emit(
+                         f"[GPU] Memory error during {op_name}. Clearing cache and retrying once..."
+                     )
+                     if torch.cuda.is_available():
+                         torch.cuda.empty_cache()
+                     gc.collect()
+                     try:
+                         return fn()
+                     except RuntimeError as e2:
+                         if is_cuda_alloc_error(e2):
+                             raise RuntimeError(
+                                 "GPU memory allocation failed while running SAM2. "
+                                 "Try one or more: enable Offload Video to CPU, enable Offload State to CPU, "
+                                 "use a smaller SAM2 model, or track a shorter range."
+                             ) from e2
+                         raise
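            # Illustrative usage (not from the packaged file): SAM2 calls in this
            # worker are routed through this wrapper, e.g.
            #     run_with_cuda_retry("add_new_mask",
            #                         lambda: self._sam2_call(self.predictor.add_new_mask, ...))
            # so a single CUDA allocation failure triggers one cache-clearing retry
            # before the run is aborted with the actionable message above.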
+
+             # Initialize with first chunk
+             current_end = min(self.start_frame + self.chunk_size, self.end_frame)
+             self.log_message.emit(f"Initializing with chunk: {self.start_frame} to {current_end}")
+
+             images = load_frames(self.start_frame, current_end)
+
+             # Get original dimensions
+             vr_meta = decord.VideoReader(self.video_path)
+             vh, vw, _ = vr_meta[0].shape
+             del vr_meta  # Free memory immediately after getting dimensions
+
+             images_list = [images[i] for i in range(len(images))]
+
+             inference_state = {}
+             inference_state["images"] = images_list
+             inference_state["num_frames"] = len(images_list)
+             inference_state["offload_video_to_cpu"] = self.offload_video
+             inference_state["offload_state_to_cpu"] = self.offload_state
+             inference_state["video_height"] = vh
+             inference_state["video_width"] = vw
+             inference_state["device"] = self.predictor.device
+             inference_state["storage_device"] = torch.device("cpu") if self.offload_state else self.predictor.device
+             inference_state["point_inputs_per_obj"] = {}
+             inference_state["mask_inputs_per_obj"] = {}
+             inference_state["cached_features"] = {}
+             inference_state["constants"] = {}
+             inference_state["obj_id_to_idx"] = OrderedDict()
+             inference_state["obj_idx_to_id"] = OrderedDict()
+             inference_state["obj_ids"] = []
+             inference_state["output_dict_per_obj"] = {}
+             inference_state["temp_output_dict_per_obj"] = {}
+             inference_state["frames_tracked_per_obj"] = {}
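            # Note (not from the packaged file): this hand-built dict mirrors the
            # structure that SAM2VideoPredictor.init_state() would normally create,
            # which is what lets the worker append frames chunk by chunk below
            # instead of loading the whole video up front.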
+
+             # Warm up
+             try:
+                 self._sam2_call(self.predictor._get_image_feature, inference_state, frame_idx=0, batch_size=1)
+             except:
+                 pass
+
+             self.predictor.reset_state(inference_state)
+
+             # Loop through chunks with sliding window memory management
+             # Track the global offset (how many frames we've trimmed from the start)
+             # Only used when memory management is enabled
+             global_offset = 0  # Tracks how many frames we've dropped from the front
+             MAX_FRAMES_IN_MEMORY = 800  # Keep ~800 frames in memory
+
+             processed_up_to = self.start_frame
+
+             while processed_up_to < self.end_frame:
+                 if self.should_stop:
+                     break
+
+                 # inference_state["images"] grows with each processed frame;
+                 # frames correspond to self.start_frame + index. When memory
+                 # management is disabled, global_offset stays at 0.
+                 buffer_start = self.start_frame + (global_offset if self.enable_memory_management else 0)
+
+                 chunk_start = processed_up_to
+                 chunk_end = buffer_start + inference_state["num_frames"]  # End of current buffer
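                # Worked example (not from the packaged file): with start_frame=0
                # and MAX_FRAMES_IN_MEMORY=800, after 200 frames have been trimmed
                # global_offset=200 and buffer_start=200, so global frame f lives at
                # local index f - buffer_start and chunk_end = 200 + num_frames.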
+
+                 self.log_message.emit(f"Processing range: {chunk_start} to {chunk_end} (buffer: {buffer_start} to {chunk_end})")
+
+                 for (frame_idx, obj_id), data in self.user_points.items():
+                     if chunk_start <= frame_idx < chunk_end:
+                         # Local index relative to current buffer (after trimming)
+                         local_idx = frame_idx - buffer_start
+                         if local_idx < 0 or local_idx >= inference_state["num_frames"]:
+                             # Frame was trimmed, skip (shouldn't happen if logic is correct)
+                             continue
+                         pts = np.array(data['points'], dtype=np.float32)
+                         lbls = np.array(data['labels'], dtype=np.int32)
+
+                         run_with_cuda_retry(
+                             "add_new_points_or_box",
+                             lambda: self._sam2_call(self.predictor.add_new_points_or_box,
+                                 inference_state=inference_state,
+                                 frame_idx=local_idx,
+                                 obj_id=obj_id,
+                                 points=pts,
+                                 labels=lbls,
+                                 normalize_coords=True,
+                             ),
+                         )
+
+                 # Inject initial masks (e.g., from pause/resume refinement)
+                 for (frame_idx, obj_id), mask in self.initial_masks.items():
+                     if chunk_start <= frame_idx < chunk_end:
+                         local_idx = frame_idx - buffer_start
+                         if local_idx < 0 or local_idx >= inference_state["num_frames"]:
+                             continue
+                         try:
+                             # Resize mask to video dimensions if needed
+                             vh = inference_state["video_height"]
+                             vw = inference_state["video_width"]
+                             if mask.shape[0] != vh or mask.shape[1] != vw:
+                                 import cv2
+                                 mask_resized = cv2.resize(mask.astype(np.float32), (vw, vh), interpolation=cv2.INTER_NEAREST)
+                                 mask = (mask_resized > 0.5).astype(np.uint8)
+
+                             run_with_cuda_retry(
+                                 "add_new_mask",
+                                 lambda: self._sam2_call(self.predictor.add_new_mask,
+                                     inference_state=inference_state,
+                                     frame_idx=local_idx,
+                                     obj_id=obj_id,
+                                     mask=mask.astype(bool),
+                                 ),
+                             )
+                             self.log_message.emit(f"Injected refined mask for obj {obj_id} at frame {frame_idx}")
+                         except Exception as e:
+                             self.log_message.emit(f"Warning: Could not inject mask: {e}")
+
+                 # Propagate from chunk_start
+                 # We need local index for propagation start (relative to current buffer)
+                 prop_start_local = chunk_start - buffer_start
+                 if prop_start_local < 0:
+                     prop_start_local = 0  # Can't propagate from before buffer start
+
+                 # Memory trimming may drop the initial conditioning frame (the
+                 # user's first click). The bundled SAM2 fork modifies
+                 # propagate_in_video_preflight to allow propagation when only
+                 # tracking history (non_cond_frame_outputs) is present, so
+                 # explicit mask re-injection is not required here.
+
+                 with self._sam2_autocast():
+                     for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(
+                         inference_state,
+                         start_frame_idx=prop_start_local
+                     ):
+                         if self.should_stop:
+                             break
+
+                         # Convert local buffer index to global frame index
+                         global_idx = out_frame_idx + buffer_start
+
+                         if global_idx not in all_video_segments:
+                             all_video_segments[global_idx] = {}
+
+                         frame_scores = {}  # Track scores for this frame per object
+                         low_quality_objects = []  # Objects with sustained low scores this frame
+
+                         for i, o_id in enumerate(out_obj_ids):
+                             mask_logit = out_mask_logits[i]
+                             if mask_logit.ndim == 3:
+                                 mask_logit = mask_logit[0]
+                             mask = (mask_logit > self.mask_threshold).cpu().numpy().astype(np.uint8).squeeze()
+                             all_video_segments[global_idx][o_id] = mask
+
+                             # Motion-aware scoring
+                             if self.motion_tracker is not None:
+                                 score, should_use = self.motion_tracker.update(
+                                     o_id, mask, mask_logit, global_idx
+                                 )
+                                 frame_scores[o_id] = score
+
+                                 # Only filter SAM2 memory after sustained low quality.
+                                 # This avoids dropping useful memory on one-frame glitches.
+                                 if (not should_use) and self.motion_tracker.check_needs_correction(o_id):
+                                     low_quality_objects.append(o_id)
+
+                         # Log motion tracking info periodically
+                         if self.motion_tracker is not None and global_idx % 50 == 0:
+                             score_str = ", ".join([f"obj{k}:{v:.2f}" for k, v in frame_scores.items()])
+                             self.log_message.emit(f"Frame {global_idx} scores: {score_str}")
+
+                         # Memory filtering: remove low-quality frames from memory
+                         if self.motion_tracker is not None and low_quality_objects:
+                             for obj_id in low_quality_objects:
+                                 obj_idx = inference_state.get("obj_id_to_idx", {}).get(obj_id)
+                                 if obj_idx is not None:
+                                     # Recency-weighted memory filtering:
+                                     # keep recent memory frames and remove older low-score frames first.
+                                     obj_output = inference_state.get("output_dict_per_obj", {}).get(obj_idx, {})
+                                     non_cond = obj_output.get("non_cond_frame_outputs", {})
+                                     if not non_cond:
+                                         continue
+                                     threshold = self.motion_tracker.get_effective_threshold(obj_id)
+                                     recent_keep = 6
+                                     max_remove = 2
+                                     removed = 0
+                                     old_keys = sorted(
+                                         k for k in non_cond.keys() if k < (out_frame_idx - recent_keep)
+                                     )
+                                     for mem_local_idx in old_keys:
+                                         if removed >= max_remove:
+                                             break
+                                         mem_global_idx = mem_local_idx + buffer_start
+                                         mem_score = self.motion_tracker.get_frame_score(obj_id, mem_global_idx)
+                                         if mem_score is not None and mem_score < threshold:
+                                             del non_cond[mem_local_idx]
+                                             removed += 1
+
+                                     # Fallback if no older candidate was removable.
+                                     if removed == 0 and out_frame_idx in non_cond:
+                                         del non_cond[out_frame_idx]
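                        # Worked example (not from the packaged file): at local
                        # frame 120 with recent_keep=6 and max_remove=2, only
                        # memory entries for frames < 114 are eviction candidates,
                        # at most two below-threshold entries are dropped, and the
                        # current frame's memory is removed only as a last resort.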
+
+                         # Appearance memory re-seed: when object recovers from long occlusion,
+                         # inject golden mask so SAM2 remembers what the object looked like.
+                         if self.motion_tracker is not None and self.motion_tracker.appearance_memory is not None:
+                             for o_id in out_obj_ids:
+                                 amem = self.motion_tracker.appearance_memory
+                                 if amem.is_recovery_pending(o_id):
+                                     golden_mask = amem.pop_reseed_mask(o_id)
+                                     if golden_mask is not None:
+                                         try:
+                                             # Inject at current frame so SAM2 uses the golden
+                                             # mask immediately (not delayed by one frame).
+                                             reseed_local = out_frame_idx
+                                             i_vh = inference_state["video_height"]
+                                             i_vw = inference_state["video_width"]
+                                             gm = golden_mask
+                                             if gm.shape[0] != i_vh or gm.shape[1] != i_vw:
+                                                 import cv2 as _cv2
+                                                 gm = _cv2.resize(
+                                                     gm.astype(np.float32), (i_vw, i_vh),
+                                                     interpolation=_cv2.INTER_NEAREST
+                                                 )
+                                                 gm = (gm > 0.5).astype(np.uint8)
+                                             run_with_cuda_retry(
+                                                 "appearance_reseed_add_new_mask",
+                                                 lambda: self._sam2_call(
+                                                     self.predictor.add_new_mask,
+                                                     inference_state=inference_state,
+                                                     frame_idx=reseed_local,
+                                                     obj_id=o_id,
+                                                     mask=gm.astype(bool),
+                                                 ),
+                                             )
+                                             self.log_message.emit(
+                                                 f"[AppearanceMemory] Re-seeded obj {o_id} at frame {global_idx} with golden mask"
+                                             )
+                                         except Exception as e:
+                                             self.log_message.emit(f"[AppearanceMemory] Re-seed failed for obj {o_id}: {e}")
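                        # Note (not from the packaged file): the appearance-memory
                        # contract assumed by this block is that is_recovery_pending()
                        # flags an object that has just re-emerged from occlusion and
                        # pop_reseed_mask() returns (and clears) the stored
                        # pre-occlusion "golden" mask.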
+
+                         # Automatic prompt injection: when drift detected, inject predicted bbox
+                         if self.motion_tracker is not None:
+                             for o_id in out_obj_ids:
+                                 if self.motion_tracker.check_needs_correction(o_id):
+                                     # Get Kalman-predicted bbox
+                                     pred_bbox = self.motion_tracker.get_predicted_bbox_for_correction(o_id)
+                                     if pred_bbox is not None:
+                                         if not self.motion_tracker.is_correction_bbox_sane(o_id, pred_bbox):
+                                             self.log_message.emit(
+                                                 f"[Motion] Skipped correction for obj {o_id}: jump/scale too large"
+                                             )
+                                             self.motion_tracker.reset_correction_flag(o_id)
+                                             continue
+                                         try:
+                                             # Inject predicted bbox as new prompt for next frame
+                                             next_local_idx = out_frame_idx + 1
+                                             if next_local_idx < inference_state["num_frames"]:
+                                                 run_with_cuda_retry(
+                                                     "motion_correction_add_new_points_or_box",
+                                                     lambda: self._sam2_call(
+                                                         self.predictor.add_new_points_or_box,
+                                                         inference_state,
+                                                         frame_idx=next_local_idx,
+                                                         obj_id=o_id,
+                                                         box=pred_bbox,
+                                                         clear_old_points=True,
+                                                         normalize_coords=False,
+                                                     ),
+                                                 )
+                                                 self.log_message.emit(
+                                                     f"[Motion] Injected correction for obj {o_id} at frame {global_idx+1}"
+                                                 )
+                                                 self.motion_tracker.reset_correction_flag(o_id)
+                                         except Exception as e:
+                                             self.log_message.emit(f"Correction failed: {e}")
+
+                         self.progress_signal.emit(global_idx)
+                         # Emit real-time result for this frame
+                         self.frame_result_signal.emit(global_idx, all_video_segments[global_idx])
+
+                         # Clear old masks from memory (they're already emitted to main thread)
+                         # Keep only recent MAX_MASKS_IN_MEMORY masks for the final emit
+                         if len(all_video_segments) > MAX_MASKS_IN_MEMORY:
+                             oldest_frame = min(all_video_segments.keys())
+                             del all_video_segments[oldest_frame]
+
+                         # Periodically clear CUDA cache (every 100 frames) to prevent accumulation
+                         if global_idx % 100 == 0 and torch.cuda.is_available():
+                             torch.cuda.empty_cache()
+
+                 processed_up_to = chunk_end
+
+                 # Clear CUDA cache and run garbage collection after each chunk
+                 if torch.cuda.is_available():
+                     torch.cuda.empty_cache()
+                 gc.collect()
+
+                 # MEMORY MANAGEMENT: Trim old frames if we exceed MAX_FRAMES_IN_MEMORY
+                 # Only apply if memory management is enabled
+                 if self.enable_memory_management and inference_state["num_frames"] > MAX_FRAMES_IN_MEMORY:
+                     frames_to_trim = inference_state["num_frames"] - MAX_FRAMES_IN_MEMORY
+                     self.log_message.emit(f"Trimming {frames_to_trim} old frames from memory (keeping last {MAX_FRAMES_IN_MEMORY} frames)...")
+
+                     # 1. Trim images list (keep last MAX_FRAMES_IN_MEMORY frames)
+                     # Since images is a list, slicing copies only the pointers, not the frame tensors
+                     inference_state["images"] = inference_state["images"][-MAX_FRAMES_IN_MEMORY:]
+                     inference_state["num_frames"] = len(inference_state["images"])
+
+                     # 2. Update global offset
+                     global_offset += frames_to_trim
+
+                     # 3. Shift all indices in inference_state dictionaries
+                     def shift_dict_keys(d, shift):
+                         """Shift dictionary keys by subtracting shift, removing negative keys"""
+                         new_d = {}
+                         for k, v in d.items():
+                             new_k = k - shift
+                             if new_k >= 0:  # Only keep non-negative keys (frames still in buffer)
+                                 new_d[new_k] = v
+                         return new_d
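                    # Worked example (not from the packaged file):
                    #     shift_dict_keys({0: a, 795: b, 810: c}, 200) -> {595: b, 610: c}
                    # the entry for trimmed frame 0 (new key < 0) is dropped.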
+
+                     # Shift cached features
+                     inference_state["cached_features"] = shift_dict_keys(inference_state["cached_features"], frames_to_trim)
+
+                     # Shift per-object dictionaries
+                     for obj_idx in list(inference_state["point_inputs_per_obj"].keys()):
+                         inference_state["point_inputs_per_obj"][obj_idx] = shift_dict_keys(
+                             inference_state["point_inputs_per_obj"][obj_idx], frames_to_trim
+                         )
+                         inference_state["mask_inputs_per_obj"][obj_idx] = shift_dict_keys(
+                             inference_state["mask_inputs_per_obj"][obj_idx], frames_to_trim
+                         )
+
+                         # Shift output dicts (keep conditioning frames if they're still in range)
+                         obj_output = inference_state["output_dict_per_obj"][obj_idx]
+                         obj_output["cond_frame_outputs"] = shift_dict_keys(
+                             obj_output["cond_frame_outputs"], frames_to_trim
+                         )
+                         # SAM2's memory bank only requires the last num_maskmem
+                         # non_cond frames, but all non_cond frames still inside
+                         # the (already trimmed) buffer are retained.
+                         obj_output["non_cond_frame_outputs"] = shift_dict_keys(
+                             obj_output["non_cond_frame_outputs"], frames_to_trim
+                         )
+
+                         obj_temp = inference_state["temp_output_dict_per_obj"][obj_idx]
+                         obj_temp["cond_frame_outputs"] = shift_dict_keys(
+                             obj_temp["cond_frame_outputs"], frames_to_trim
+                         )
+                         obj_temp["non_cond_frame_outputs"] = shift_dict_keys(
+                             obj_temp["non_cond_frame_outputs"], frames_to_trim
+                         )
+
+                         # Shift frames_tracked metadata
+                         inference_state["frames_tracked_per_obj"][obj_idx] = shift_dict_keys(
+                             inference_state["frames_tracked_per_obj"][obj_idx], frames_to_trim
+                         )
+
+                     self.log_message.emit(f"Memory trimmed. Global offset now: {global_offset}")
+
+                     # Clear CUDA cache after memory trimming
+                     if torch.cuda.is_available():
+                         torch.cuda.empty_cache()
+
+                 # Load NEXT chunk if needed
+                 if processed_up_to < self.end_frame:
+                     next_end = min(processed_up_to + self.chunk_size, self.end_frame)
+                     self.log_message.emit(f"Loading next chunk: {processed_up_to} to {next_end}")
+
+                     # Capture last frame masks of current chunk for optional reseed
+                     if self.reseed_between_chunks:
+                         last_frame_idx = processed_up_to - 1
+                         last_masks_for_reseed = all_video_segments.get(last_frame_idx, None)
+
+                     new_images = load_frames(processed_up_to, next_end)
+
+                     # Append to inference_state
+                     # OPTIMIZATION: Since images is a list, we can extend it directly (O(1) per frame)
+                     # instead of torch.cat which would copy all existing frames (O(N))
+                     new_images_list = [new_images[i] for i in range(len(new_images))]
+                     inference_state["images"].extend(new_images_list)
+                     inference_state["num_frames"] = len(inference_state["images"])
+
+                     # Optional reseed: add mask from last frame of previous chunk
+                     if self.reseed_between_chunks and last_masks_for_reseed:
+                         try:
+                             seed_frame_local = processed_up_to - buffer_start  # first frame of new chunk in buffer coords
+                             vh = inference_state["video_height"]
+                             vw = inference_state["video_width"]
+
+                             for obj_id, mask in last_masks_for_reseed.items():
+                                 if mask is None or mask.max() == 0:
+                                     continue
+
+                                 # Resize mask to video dimensions if needed
+                                 if mask.shape[0] != vh or mask.shape[1] != vw:
+                                     import cv2
+                                     mask_resized = cv2.resize(mask.astype(np.float32), (vw, vh), interpolation=cv2.INTER_NEAREST)
+                                     mask = (mask_resized > 0.5).astype(np.uint8)
+
+                                 run_with_cuda_retry(
+                                     "chunk_reseed_add_new_mask",
+                                     lambda: self._sam2_call(self.predictor.add_new_mask,
+                                         inference_state=inference_state,
+                                         frame_idx=seed_frame_local,
+                                         obj_id=obj_id,
+                                         mask=mask.astype(bool),
+                                     ),
+                                 )
+                                 self.log_message.emit(f"Reseeded obj {obj_id} with mask at frame {processed_up_to}")
+                         except Exception as e:
+                             self.log_message.emit(f"Warning: reseed between chunks failed: {e}")
+
+             self.finished_signal.emit(all_video_segments)
+
+         except Exception as e:
+             import traceback
+             self.log_message.emit(f"ERROR: {str(e)}\n{traceback.format_exc()}")
+             self.error_signal.emit(str(e))
+
+
+ class VideoLabel(QLabel):
+     """Custom label for video display with click handling."""
+     click_signal = pyqtSignal(int, int)
+
+     def mousePressEvent(self, event):
+         if event.button() == Qt.MouseButton.LeftButton:
+             self.click_signal.emit(event.pos().x(), event.pos().y())
+
+
+ class SegmentationTrackingWidget(QWidget):
+     """Widget for segmentation and multi-object tracking using SAM2."""
+
+     # Signal emitted when tracking completes with video and mask paths
+     tracking_completed = pyqtSignal(str, str)  # video_path, mask_path
+
+     def __init__(self, config: dict):
+         super().__init__()
+         self.config = config
+         self.video_path = None
+         self.cap = None
+         self.total_frames = 0
+         self.current_frame_idx = 0
+         self.frame = None
+         self._frame_rgb = None  # Keep RGB frame reference for QImage memory safety
+         self.points = []  # List of (x, y, label, frame_idx, obj_id)
+         self.masks = {}  # frame_idx -> {obj_id: mask_array}
+         self.last_processed_frame = None
+         self.tracking_paused = False
+         self.resume_from_frame = None
+         self.resume_initial_masks = {}
+         self._base_display_pixmap = None
+         self.zoom_factor = 1.0
+         self.zoom_min = 0.5
+         self.zoom_max = 4.0
+         self.zoom_step = 0.2
+
+         self.obj_ids = [1]
+         self.current_obj_id = 1
+
+         self.predictor = None
+         self.inference_state = None
+         self.state_start_frame = 0
+
+         # Multi-video support
+         self.videos = []  # list of per-video state dicts
+         self.current_video_idx = None
+         self.batch_queue = []
+         self.batch_mode = False
+
+         # Settings
+         self.mask_threshold = 0.0
+         self.fill_hole_area = 0
+         self.offload_video = True
+         self.offload_state = True
+         self.use_cuda_bf16_autocast = True
+         self.enable_memory_management = True
+         self.max_frames_per_load = 200  # Limit frames loaded at once to prevent RAM issues
+         self.save_overlay_video = True
+
+         # Motion-aware tracking settings
+         self.enable_motion_tracking = False  # Off by default
+         self.motion_score_threshold = 0.3
+         self.motion_consecutive_low = 3  # Frames before auto-correction
+         self.motion_area_threshold = 0.5  # Max allowed area change ratio
+
+         # OC-SORT drift correction settings
+         self.enable_ocsort = False  # Off by default
+         self.ocsort_inertia = 0.2  # Paper default: 0.2
+
+         # SAM2 paths - resolved via _paths so this works both from source and pip install
+         from singlebehaviorlab._paths import get_sam2_backend_dir, get_sam2_checkpoints_dir
+         self.sam2_dir = str(get_sam2_checkpoints_dir())
+         self.sam2_backend_dir = str(get_sam2_backend_dir())
+         self.checkpoint_path = os.path.join(self.sam2_dir, "checkpoints", "sam2.1_hiera_large.pt")
+         self.model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
+
+         self.tracking_worker = None
+         self.download_worker = None
+
+         self._setup_ui()
+         self._check_sam2_availability()
+
+     def _use_cuda_bf16(self):
+         """Use bf16 autocast only for CUDA SAM2 inference."""
+         dev = getattr(self.predictor, "device", None)
+         dev_type = getattr(dev, "type", str(dev))
+         return bool(
+             self.use_cuda_bf16_autocast
+             and torch.cuda.is_available()
+             and dev_type == "cuda"
+         )
+
+     def _sam2_autocast(self):
+         if self._use_cuda_bf16():
+             return torch.autocast(device_type="cuda", dtype=torch.bfloat16)
+         return nullcontext()
+
+     def _sam2_call(self, fn, *args, **kwargs):
+         with self._sam2_autocast():
+             return fn(*args, **kwargs)
779
+
780
+ def _ensure_hydra_initialized(self):
781
+ """Ensure Hydra is initialized before using SAM2."""
782
+ try:
783
+ from hydra.core.global_hydra import GlobalHydra
784
+ from hydra import initialize_config_dir
785
+
786
+ # Check if Hydra is already initialized
787
+ if GlobalHydra.instance().is_initialized():
788
+ return True
789
+
790
+ # Find sam2 configs directory
791
+ sam2_configs_dir = None
792
+
793
+ # Try to find from installed package first (most reliable)
794
+ try:
795
+ import sam2
796
+ if hasattr(sam2, '__file__') and sam2.__file__:
797
+ sam2_path = os.path.dirname(sam2.__file__)
798
+ configs_path = os.path.join(sam2_path, "configs")
799
+ if os.path.exists(configs_path):
800
+ sam2_configs_dir = configs_path
801
+ elif hasattr(sam2, '__path__'):
802
+ # Handle namespace packages
803
+ for path in sam2.__path__:
804
+ configs_path = os.path.join(path, "configs")
805
+ if os.path.exists(configs_path):
806
+ sam2_configs_dir = configs_path
807
+ break
808
+ except ImportError:
809
+ pass
810
+
811
+ # Fall back to the bundled sam2_backend configs directory.
812
+ if not sam2_configs_dir:
813
+ sam2_backend_configs = os.path.join(self.sam2_backend_dir, "sam2", "configs")
814
+ if os.path.exists(sam2_backend_configs):
815
+ sam2_configs_dir = sam2_backend_configs
816
+
817
+ if sam2_configs_dir:
818
+ # Initialize Hydra with the config directory
819
+ initialize_config_dir(config_dir=sam2_configs_dir, version_base="1.2")
820
+ return True
821
+
822
+ # Fallback: try initialize_config_module (may work if SAM2 is properly installed)
823
+ try:
824
+ from hydra import initialize_config_module
825
+ initialize_config_module("sam2", version_base="1.2")
826
+ return True
827
+ except Exception:
828
+ pass
829
+
830
+ return False
831
+
832
+ except ImportError as e:
833
+ # Hydra not installed
834
+ return False
835
+ except Exception as e:
836
+ # Hydra initialization failed
837
+ return False
838
+
839
+ def _has_installed_sam2_distribution(self):
840
+ """Return True when SAM2 is importable as a Python package."""
841
+ for dist_name in ("SAM-2", "sam2"):
842
+ try:
843
+ importlib_metadata.distribution(dist_name)
844
+ return True
845
+ except importlib_metadata.PackageNotFoundError:
846
+ continue
847
+ except Exception:
848
+ continue
849
+ try:
850
+ import importlib.util
851
+ return importlib.util.find_spec("sam2") is not None
852
+ except Exception:
853
+ return False
854
+
855
+
856
+ def _check_sam2_availability(self):
857
+ """Check if SAM2 is available."""
858
+ has_installed_pkg = self._has_installed_sam2_distribution()
859
+
860
+ # Only report SAM2 as installed when it exists as an actual Python package
861
+ try:
862
+ if has_installed_pkg:
863
+ # Initialize Hydra before importing SAM2
864
+ self._ensure_hydra_initialized()
865
+ from sam2.build_sam import build_sam2_video_predictor
866
+ self.sam2_available = True
867
+ self.setup_status_label.setText("SAM2 is installed")
868
+ self.setup_status_label.setStyleSheet("color: green;")
869
+ self._populate_checkpoints()
870
+ # Set default model selection
871
+ for i in range(self.combo_model.count()):
872
+ if self.combo_model.itemData(i) == "sam2.1_hiera_large.pt":
873
+ self.combo_model.setCurrentIndex(i)
874
+ break
875
+ self._check_model_availability()
876
+ return
877
+ except (ImportError, RuntimeError) as e:
878
+ # If RuntimeError about parent directory, SAM2 needs to be properly installed
879
+ if isinstance(e, RuntimeError) and ("parent directory" in str(e) or "shadowed" in str(e)):
880
+ # Check if sam2_backend directory exists - if so, it needs to be reinstalled
881
+ if os.path.exists(self.sam2_backend_dir):
882
+ sam2_package = os.path.join(self.sam2_backend_dir, "sam2")
883
+ if os.path.exists(sam2_package):
884
+ # SAM2 exists but not properly installed - needs pip install -e
885
+ self.sam2_available = False
886
+ self.setup_status_label.setText("SAM2 needs reinstallation")
887
+ self.setup_status_label.setStyleSheet("color: orange;")
888
+ return
889
+
890
+ # If the source tree exists but the package is not installed, report that clearly.
891
+ sam2_folder = self.sam2_backend_dir
892
+ if os.path.exists(sam2_folder):
893
+ sam2_package = os.path.join(sam2_folder, "sam2")
894
+ if os.path.exists(sam2_package) and os.path.exists(os.path.join(sam2_package, "__init__.py")):
895
+ self.sam2_available = False
896
+ self.setup_status_label.setText("SAM2 source found, but not installed")
897
+ self.setup_status_label.setStyleSheet("color: orange;")
898
+ return
899
+
900
+ self.sam2_available = False
901
+ self.setup_status_label.setText("SAM2 not installed")
902
+ self.setup_status_label.setStyleSheet("color: red;")
903
+
904
+ def _setup_ui(self):
905
+ """Setup UI components."""
906
+ layout = QVBoxLayout(self)
907
+
908
+ # Top row: SAM2 Setup and Model Selection side by side
909
+ top_row_layout = QHBoxLayout()
910
+
911
+ # SAM2 Setup Section (left side)
912
+ setup_group = QGroupBox("SAM2 Setup")
913
+ setup_layout = QVBoxLayout()
914
+
915
+ setup_info_layout = QHBoxLayout()
916
+ setup_info_layout.addWidget(QLabel("Status:"))
917
+ self.setup_status_label = QLabel("Checking...")
918
+ setup_info_layout.addWidget(self.setup_status_label)
919
+ setup_info_layout.addStretch()
920
+ setup_layout.addLayout(setup_info_layout)
921
+
922
+ setup_path_layout = QHBoxLayout()
923
+ setup_path_layout.addWidget(QLabel("Package:"))
924
+ self.setup_path_label = QLabel(self.sam2_backend_dir)
925
+ self.setup_path_label.setWordWrap(True)
926
+ self.setup_path_label.setStyleSheet("color: gray;")
927
+ setup_path_layout.addWidget(self.setup_path_label, stretch=1)
928
+ setup_layout.addLayout(setup_path_layout)
929
+
930
+ ckpt_path_layout = QHBoxLayout()
931
+ ckpt_path_layout.addWidget(QLabel("Checkpoints:"))
932
+ self.ckpt_path_label = QLabel(self.sam2_dir)
933
+ self.ckpt_path_label.setWordWrap(True)
934
+ self.ckpt_path_label.setStyleSheet("color: gray;")
935
+ ckpt_path_layout.addWidget(self.ckpt_path_label, stretch=1)
936
+ setup_layout.addLayout(ckpt_path_layout)
937
+
938
+ setup_group.setLayout(setup_layout)
939
+ top_row_layout.addWidget(setup_group)
940
+
941
+ # Model Selection (right side) - matching Video Settings width
942
+ model_group = QGroupBox("Model selection")
943
+ model_group.setFixedWidth(380) # Match Video Settings & Controls width
944
+ model_layout = QVBoxLayout()
945
+
946
+ model_select_layout = QHBoxLayout()
947
+ model_select_layout.addWidget(QLabel("Model:"))
948
+ self.combo_model = QComboBox()
949
+ # Add all available models with user-friendly names
950
+ self.model_names = {
951
+ "sam2.1_hiera_tiny.pt": "SAM2.1 Tiny (38.9M, Fastest)",
952
+ "sam2.1_hiera_small.pt": "SAM2.1 Small (46M, Fast)",
953
+ "sam2.1_hiera_base_plus.pt": "SAM2.1 Base+ (80.8M, Balanced)",
954
+ "sam2.1_hiera_large.pt": "SAM2.1 Large (224.4M, Best Quality)",
955
+ "sam2_hiera_tiny.pt": "SAM2.0 Tiny (38.9M, Legacy)",
956
+ "sam2_hiera_small.pt": "SAM2.0 Small (46M, Legacy)",
957
+ "sam2_hiera_base_plus.pt": "SAM2.0 Base+ (80.8M, Legacy)",
958
+ "sam2_hiera_large.pt": "SAM2.0 Large (224.4M, Legacy)",
959
+ }
960
+ for model_file, display_name in self.model_names.items():
961
+ self.combo_model.addItem(display_name, model_file)
962
+ self.combo_model.currentIndexChanged.connect(self._on_model_selected)
963
+ model_select_layout.addWidget(self.combo_model)
964
+ model_layout.addLayout(model_select_layout)
965
+
966
+ self.model_status_label = QLabel("Select a model to check/download")
967
+ self.model_status_label.setWordWrap(True)
968
+ self.model_status_label.setStyleSheet("color: gray;")
969
+ model_layout.addWidget(self.model_status_label)
970
+
971
+ self.download_progress = QProgressBar()
972
+ self.download_progress.setVisible(False)
973
+ model_layout.addWidget(self.download_progress)
974
+
975
+ model_group.setLayout(model_layout)
976
+ top_row_layout.addWidget(model_group)
977
+
978
+ layout.addLayout(top_row_layout)
979
+
980
+ # Legacy checkpoint combo (hidden, kept for compatibility)
981
+ self.combo_ckpt = QComboBox()
982
+ self.combo_ckpt.currentTextChanged.connect(self._on_checkpoint_changed)
983
+
984
+ # Video Range Controls
985
+ range_group = QGroupBox("Processing range")
986
+ range_layout = QHBoxLayout()
987
+
988
+ self.chk_limit_range = QCheckBox("Limit Range")
989
+ self.chk_limit_range.toggled.connect(self._toggle_range_inputs)
990
+ range_layout.addWidget(self.chk_limit_range)
991
+
992
+ range_layout.addWidget(QLabel("Start:"))
993
+ self.spin_start = QSpinBox()
994
+ self.spin_start.setRange(0, 999999)
995
+ self.spin_start.setEnabled(False)
996
+ range_layout.addWidget(self.spin_start)
997
+
998
+ self.btn_set_start = QPushButton("Set")
999
+ self.btn_set_start.clicked.connect(self._set_range_start)
1000
+ self.btn_set_start.setEnabled(False)
1001
+ range_layout.addWidget(self.btn_set_start)
1002
+
1003
+ range_layout.addWidget(QLabel("End:"))
1004
+ self.spin_end = QSpinBox()
1005
+ self.spin_end.setRange(0, 999999)
1006
+ self.spin_end.setEnabled(False)
1007
+ range_layout.addWidget(self.spin_end)
1008
+
1009
+ self.btn_set_end = QPushButton("Set")
1010
+ self.btn_set_end.clicked.connect(self._set_range_end)
1011
+ self.btn_set_end.setEnabled(False)
1012
+ range_layout.addWidget(self.btn_set_end)
1013
+
1014
+ range_group.setLayout(range_layout)
1015
+ layout.addWidget(range_group)
1016
+
1017
+ # Video Display and Controls side by side
1018
+ video_row_layout = QHBoxLayout()
1019
+
1020
+ # Video Display (left side)
1021
+ self.video_scroll = QScrollArea()
1022
+ self.video_scroll.setStyleSheet("background-color: black; border: none;")
1023
+ self.video_scroll.setWidgetResizable(False)
1024
+ self.video_scroll.setAlignment(Qt.AlignmentFlag.AlignCenter)
1025
+
1026
+ self.video_label = VideoLabel()
1027
+ self.video_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
1028
+ self.video_label.setMinimumSize(1, 1)
1029
+ self.video_label.setStyleSheet("background-color: black;")
1030
+ self.video_label.click_signal.connect(self._handle_click)
1031
+ self.video_scroll.setWidget(self.video_label)
1032
+ self.video_scroll.viewport().installEventFilter(self)
1033
+
1034
+ self.btn_zoom_in = QPushButton("+", self.video_scroll.viewport())
1035
+ self.btn_zoom_out = QPushButton("-", self.video_scroll.viewport())
1036
+ self._style_zoom_button(self.btn_zoom_in)
1037
+ self._style_zoom_button(self.btn_zoom_out)
1038
+ self.btn_zoom_in.clicked.connect(self._zoom_in)
1039
+ self.btn_zoom_out.clicked.connect(self._zoom_out)
1040
+ self._position_zoom_buttons()
1041
+
1042
+ video_row_layout.addWidget(self.video_scroll, stretch=2)
1043
+
1044
+ # Controls Container (right side) - matching Model Selection width
1045
+ controls_group = QGroupBox("Video settings & controls")
1046
+ # Set width to match Model Selection container (approximately 350-400px)
1047
+ controls_group.setFixedWidth(380)
1048
+ controls_layout = QVBoxLayout()
1049
+
1050
+ # Load Video button
1051
+ self.btn_load = QPushButton("Load videos")
1052
+ self.btn_load.clicked.connect(self._load_video)
1053
+ controls_layout.addWidget(self.btn_load)
1054
+
1055
+ # Video list
1056
+ self.video_list_widget = QListWidget()
1057
+ self.video_list_widget.setSelectionMode(QListWidget.SelectionMode.SingleSelection)
1058
+ self.video_list_widget.currentRowChanged.connect(self._on_video_selected)
1059
+ controls_layout.addWidget(self.video_list_widget)
1060
+
1061
+ # Object Controls
1062
+ obj_layout = QHBoxLayout()
1063
+ obj_layout.addWidget(QLabel("Object:"))
1064
+
1065
+ self.combo_obj = QComboBox()
1066
+ self.combo_obj.addItem("Object 1", 1)
1067
+ self.combo_obj.currentIndexChanged.connect(self._on_object_changed)
1068
+ obj_layout.addWidget(self.combo_obj)
1069
+
1070
+ self.btn_add_obj = QPushButton("+")
1071
+ self.btn_add_obj.setFixedWidth(30)
1072
+ self.btn_add_obj.clicked.connect(self._add_object)
1073
+ obj_layout.addWidget(self.btn_add_obj)
1074
+
1075
+ controls_layout.addLayout(obj_layout)
1076
+
1077
+ # Point type
1078
+ point_type_layout = QHBoxLayout()
1079
+ self.radio_pos = QRadioButton("Positive (+)")
1080
+ self.radio_neg = QRadioButton("Negative (-)")
1081
+ self.radio_pos.setChecked(True)
1082
+ self.btn_group = QButtonGroup()
1083
+ self.btn_group.addButton(self.radio_pos)
1084
+ self.btn_group.addButton(self.radio_neg)
1085
+ point_type_layout.addWidget(self.radio_pos)
1086
+ point_type_layout.addWidget(self.radio_neg)
1087
+ controls_layout.addLayout(point_type_layout)
1088
+
1089
+ self.btn_clear_points = QPushButton("Clear points")
1090
+ self.btn_clear_points.clicked.connect(self._clear_points)
1091
+ controls_layout.addWidget(self.btn_clear_points)
1092
+
1093
+ controls_layout.addSpacing(10)
1094
+
1095
+ self.chk_auto_follow = QCheckBox("Auto-follow")
1096
+ self.chk_auto_follow.setChecked(True)
1097
+ self.chk_auto_follow.setToolTip("Automatically move slider to the frame being processed.")
1098
+ controls_layout.addWidget(self.chk_auto_follow)
1099
+
1100
+ self.btn_track = QPushButton("Run tracking (Current)")
1101
+ self.btn_track.clicked.connect(self._run_tracking)
1102
+ self.btn_track.setEnabled(False)
1103
+ controls_layout.addWidget(self.btn_track)
1104
+
1105
+ # Pause / Resume tracking controls
1106
+ pause_resume_layout = QHBoxLayout()
1107
+ self.btn_pause_tracking = QPushButton("Pause tracking")
1108
+ self.btn_pause_tracking.setEnabled(False)
1109
+ self.btn_pause_tracking.clicked.connect(self._pause_tracking)
1110
+ pause_resume_layout.addWidget(self.btn_pause_tracking)
1111
+
1112
+ self.btn_resume_tracking = QPushButton("Resume tracking from here")
1113
+ self.btn_resume_tracking.setEnabled(False)
1114
+ self.btn_resume_tracking.clicked.connect(self._resume_tracking)
1115
+ pause_resume_layout.addWidget(self.btn_resume_tracking)
1116
+
1117
+ controls_layout.addLayout(pause_resume_layout)
1118
+
1119
+ self.btn_track_all = QPushButton("Run tracking (All videos)")
1120
+ self.btn_track_all.clicked.connect(self._run_tracking_all)
1121
+ self.btn_track_all.setEnabled(False)
1122
+ controls_layout.addWidget(self.btn_track_all)
1123
+
1124
+ self.chk_save_overlay = QCheckBox("Save overlay video after tracking")
1125
+ self.chk_save_overlay.setChecked(self.save_overlay_video)
1126
+ self.chk_save_overlay.setToolTip(
1127
+ "Save an MP4 with colored mask overlays for later inspection.\n"
1128
+ "Also applies when tracking is paused."
1129
+ )
1130
+ self.chk_save_overlay.toggled.connect(lambda v: setattr(self, "save_overlay_video", bool(v)))
1131
+ controls_layout.addWidget(self.chk_save_overlay)
1132
+
1133
+ # SAM2 tracking resolution
1134
+ res_layout = QHBoxLayout()
1135
+ res_layout.addWidget(QLabel("Tracking resolution:"))
1136
+ self.tracking_res_combo = QComboBox()
1137
+ self.tracking_res_combo.addItem("256 (fastest, low quality)", 256)
1138
+ self.tracking_res_combo.addItem("384 (fast)", 384)
1139
+ self.tracking_res_combo.addItem("512 (balanced)", 512)
1140
+ self.tracking_res_combo.addItem("1024 (best quality, default)", 1024)
1141
+ self.tracking_res_combo.setCurrentIndex(3)
1142
+ self.tracking_res_combo.setToolTip(
1143
+ "Resolution at which SAM2 processes frames.\n"
1144
+ "Lower = faster tracking but less precise masks.\n"
1145
+ "512 is good for centroid/bbox extraction."
1146
+ )
1147
+ res_layout.addWidget(self.tracking_res_combo)
1148
+ controls_layout.addLayout(res_layout)
1149
+
1150
+ self.btn_preview = QPushButton("Preview frame")
1151
+ self.btn_preview.clicked.connect(self._preview_frame)
1152
+ self.btn_preview.setEnabled(False)
1153
+ controls_layout.addWidget(self.btn_preview)
1154
+
1155
+ self.btn_settings = QPushButton("Settings")
1156
+ self.btn_settings.clicked.connect(self._open_settings)
1157
+ controls_layout.addWidget(self.btn_settings)
1158
+
1159
+ controls_layout.addStretch()
1160
+
1161
+ controls_group.setLayout(controls_layout)
1162
+ video_row_layout.addWidget(controls_group)
1163
+
1164
+ layout.addLayout(video_row_layout)
1165
+
1166
+ # Slider
1167
+ self.slider = QSlider(Qt.Orientation.Horizontal)
1168
+ self.slider.sliderMoved.connect(self._set_frame)
1169
+ self.slider.setEnabled(False)
1170
+ layout.addWidget(self.slider)
1171
+
1172
+ # Frame navigation row
1173
+ nav_layout = QHBoxLayout()
1174
+
1175
+ self.btn_prev_frame = QPushButton("<")
1176
+ self.btn_prev_frame.setFixedWidth(40)
1177
+ self.btn_prev_frame.clicked.connect(self._prev_frame)
1178
+ self.btn_prev_frame.setEnabled(False)
1179
+ nav_layout.addWidget(self.btn_prev_frame)
1180
+
1181
+ self.lbl_frame = QLabel("Frame: 0 / 0")
1182
+ nav_layout.addWidget(self.lbl_frame, stretch=1)
1183
+
1184
+ self.btn_next_frame = QPushButton(">")
1185
+ self.btn_next_frame.setFixedWidth(40)
1186
+ self.btn_next_frame.clicked.connect(self._next_frame)
1187
+ self.btn_next_frame.setEnabled(False)
1188
+ nav_layout.addWidget(self.btn_next_frame)
1189
+
1190
+ layout.addLayout(nav_layout)
1191
+
1192
+ # Progress bar
1193
+ self.progress_bar = QProgressBar()
1194
+ self.progress_bar.setVisible(False)
1195
+ layout.addWidget(self.progress_bar)
1196
+
1197
+ # Log area
1198
+ self.log_text = QLabel("")
1199
+ self.log_text.setWordWrap(True)
1200
+ self.log_text.setMaximumHeight(80)
1201
+ layout.addWidget(self.log_text)
1202
+
1203
+ def _download_checkpoints(self):
1204
+ """Prompt user about checkpoint download (now handled automatically)."""
1205
+ QMessageBox.information(
1206
+ self,
1207
+ "Checkpoint Download",
1208
+ "Checkpoints are now downloaded automatically when you select a model.\n\n"
1209
+ "Simply select a model from the dropdown above, and it will be downloaded if not already present."
1210
+ )
1211
+
1212
+ def _check_model_availability(self):
1213
+ """Check if selected model checkpoint exists."""
1214
+ if not self.sam2_available:
1215
+ self.model_status_label.setText("SAM2 not installed. Run install.sh and reopen the app.")
1216
+ self.model_status_label.setStyleSheet("color: red;")
1217
+ return
1218
+
1219
+ idx = self.combo_model.currentIndex()
1220
+ if idx < 0:
1221
+ return
1222
+
1223
+ model_name = self.combo_model.itemData(idx)
1224
+ if not model_name:
1225
+ return
1226
+
1227
+ ckpt_path = os.path.join(self.sam2_dir, "checkpoints", model_name)
1228
+
1229
+ if os.path.exists(ckpt_path):
1230
+ file_size = os.path.getsize(ckpt_path) / (1024**2)
1231
+ if file_size > 10: # Reasonable size check
1232
+ self.model_status_label.setText(f"{model_name} available ({file_size:.1f} MB)")
1233
+ self.model_status_label.setStyleSheet("color: green;")
1234
+ else:
1235
+ self.model_status_label.setText(f"{model_name} file seems corrupted ({file_size:.1f} MB). Will re-download.")
1236
+ self.model_status_label.setStyleSheet("color: orange;")
1237
+ else:
1238
+ self.model_status_label.setText(f"{model_name} not found. Will download automatically when selected.")
1239
+ self.model_status_label.setStyleSheet("color: orange;")
1240
+
1241
+ def _on_model_selected(self):
1242
+ """Handle model selection change."""
1243
+ idx = self.combo_model.currentIndex()
1244
+ if idx < 0:
1245
+ return
1246
+
1247
+ model_name = self.combo_model.itemData(idx)
1248
+ if not model_name:
1249
+ return
1250
+
1251
+ ckpt_path = os.path.join(self.sam2_dir, "checkpoints", model_name)
1252
+
1253
+ # Check if checkpoint exists and is valid
1254
+ if os.path.exists(ckpt_path):
1255
+ file_size = os.path.getsize(ckpt_path) / (1024**2)
1256
+ if file_size > 10: # Reasonable size check (should be >100MB typically)
1257
+ self.model_status_label.setText(f"{model_name} ready ({file_size:.1f} MB)")
1258
+ self.model_status_label.setStyleSheet("color: green;")
1259
+ self.checkpoint_path = ckpt_path
1260
+ self._on_checkpoint_changed(model_name)
1261
+ return
1262
+
1263
+ # Checkpoint doesn't exist or is invalid, download it
1264
+ if model_name not in CheckpointDownloadWorker.MODEL_URLS:
1265
+ self.model_status_label.setText(f"Unknown model: {model_name}")
1266
+ self.model_status_label.setStyleSheet("color: red;")
1267
+ return
1268
+
1269
+ # Start download
1270
+ self._download_checkpoint(model_name, ckpt_path, CheckpointDownloadWorker.MODEL_URLS[model_name])
1271
+
1272
+ def _download_checkpoint(self, model_name, ckpt_path, model_url):
1273
+ """Download a checkpoint file."""
1274
+ if self.download_worker and self.download_worker.isRunning():
1275
+ QMessageBox.warning(self, "Download in progress", "A checkpoint download is already in progress.")
1276
+ return
1277
+
1278
+ # Ensure checkpoints directory exists
1279
+ ckpt_dir = os.path.dirname(ckpt_path)
1280
+ os.makedirs(ckpt_dir, exist_ok=True)
1281
+
1282
+ self.model_status_label.setText(f"Downloading {model_name}...")
1283
+ self.model_status_label.setStyleSheet("color: blue;")
1284
+ self.download_progress.setVisible(True)
1285
+ self.download_progress.setRange(0, 0) # Indeterminate
1286
+
1287
+ self.download_worker = CheckpointDownloadWorker(model_name, ckpt_path, model_url)
1288
+ self.download_worker.progress.connect(self._on_download_progress)
1289
+ self.download_worker.finished.connect(self._on_download_finished)
1290
+ self.download_worker.start()
1291
+
1292
+ def _on_download_progress(self, message):
1293
+ """Handle download progress updates."""
1294
+ self.model_status_label.setText(message)
1295
+
1296
+ def _on_download_finished(self, success, message):
1297
+ """Handle download completion."""
1298
+ self.download_progress.setVisible(False)
1299
+
1300
+ if success:
1301
+ self.model_status_label.setText(f"{message}")
1302
+ self.model_status_label.setStyleSheet("color: green;")
1303
+ model_name = self.combo_model.currentData()
+ if not model_name: # selection may have changed during the download
+ return
1304
+ self.checkpoint_path = os.path.join(self.sam2_dir, "checkpoints", model_name)
1305
+ self._on_checkpoint_changed(model_name)
1306
+ else:
1307
+ self.model_status_label.setText(f"{message}")
1308
+ self.model_status_label.setStyleSheet("color: red;")
1309
+ QMessageBox.critical(self, "Download failed", message)
1310
+
1311
+ def _populate_checkpoints(self):
1312
+ """Populate checkpoint combo box (legacy, for compatibility)."""
1313
+ self.combo_ckpt.clear()
1314
+ ckpt_dir = os.path.join(self.sam2_dir, "checkpoints")
+ checkpoints = []
1315
+ if os.path.exists(ckpt_dir):
1316
+ checkpoints = [f for f in os.listdir(ckpt_dir) if f.endswith(".pt")]
1317
+ checkpoints.sort()
1318
+ self.combo_ckpt.addItems(checkpoints)
1319
+
1320
+ default = "sam2.1_hiera_large.pt"
1321
+ if default in checkpoints:
1322
+ self.combo_ckpt.setCurrentText(default)
1323
+ elif checkpoints:
1324
+ self.combo_ckpt.setCurrentIndex(0)
1325
+
1326
+ def _on_checkpoint_changed(self, ckpt_name):
1327
+ """Handle checkpoint selection change."""
1328
+ if not ckpt_name:
1329
+ return
1330
+
1331
+ # Update checkpoint path
1332
+ self.checkpoint_path = os.path.join(self.sam2_dir, "checkpoints", ckpt_name)
1333
+ self.model_cfg = self._get_model_config(ckpt_name)
1334
+
1335
+ self.predictor = None
1336
+ self.inference_state = None
1337
+ self.masks = {}
1338
+ self.points = []
1339
+ # Clean up any incremental mask temp file
1340
+ if hasattr(self, '_incremental_mask_file') and self._incremental_mask_file is not None:
1341
+ try:
1342
+ self._incremental_mask_file.close()
1343
+ os.unlink(self._incremental_mask_file.name)
1344
+ except Exception:
1345
+ pass
1346
+ self._incremental_mask_file = None
1347
+ self._update_frame()
1348
+
1349
+ def _get_model_config(self, ckpt_name):
1350
+ """Map checkpoint name to config."""
1351
+ if "sam2.1" in ckpt_name:
1352
+ prefix = "configs/sam2.1/sam2.1_hiera_"
1353
+ else:
1354
+ prefix = "configs/sam2/sam2_hiera_"
1355
+
1356
+ if "large" in ckpt_name.lower():
1357
+ return prefix + "l.yaml"
1358
+ elif "base_plus" in ckpt_name.lower() or "b+" in ckpt_name.lower():
1359
+ return prefix + "b+.yaml"
1360
+ elif "small" in ckpt_name.lower():
1361
+ return prefix + "s.yaml"
1362
+ elif "tiny" in ckpt_name.lower():
1363
+ return prefix + "t.yaml"
1364
+
1365
+ return prefix + "l.yaml"
1366
+
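+ # Worked examples of the mapping above (a sketch; names follow the official
+ # SAM2 checkpoint naming convention):
+ #   _get_model_config("sam2.1_hiera_large.pt")     -> "configs/sam2.1/sam2.1_hiera_l.yaml"
+ #   _get_model_config("sam2_hiera_tiny.pt")        -> "configs/sam2/sam2_hiera_t.yaml"
+ #   _get_model_config("sam2.1_hiera_base_plus.pt") -> "configs/sam2.1/sam2.1_hiera_b+.yaml"
+ # Unrecognized size suffixes fall back to the large config.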
1367
+ def _toggle_range_inputs(self, checked):
1368
+ """Toggle range input widgets."""
1369
+ self.spin_start.setEnabled(checked)
1370
+ self.spin_end.setEnabled(checked)
1371
+ self.btn_set_start.setEnabled(checked)
1372
+ self.btn_set_end.setEnabled(checked)
1373
+
1374
+ def _set_range_start(self):
1375
+ """Apply the start frame chosen in the spin box (clamped to video length)."""
1376
+ if not self.cap:
1377
+ return
1378
+
1379
+ start_val = max(0, min(self.spin_start.value(), self.total_frames - 1))
1380
+
1381
+ # Clamp and update start value
1382
+ self.spin_start.blockSignals(True)
1383
+ self.spin_start.setValue(start_val)
1384
+ self.spin_start.blockSignals(False)
1385
+
1386
+ # Ensure start is not beyond end
1387
+ if start_val > self.spin_end.value():
1388
+ self.spin_end.setValue(start_val)
1389
+
1390
+ # Jump preview to start frame so the user sees what was set
1391
+ self.slider.setValue(start_val)
1392
+ self._set_frame(start_val)
1393
+
1394
+ def _set_range_end(self):
1395
+ """Apply the end frame chosen in the spin box (clamped to video length)."""
1396
+ if not self.cap:
1397
+ return
1398
+
1399
+ end_val = max(0, min(self.spin_end.value(), self.total_frames - 1))
1400
+
1401
+ # Clamp and update end value
1402
+ self.spin_end.blockSignals(True)
1403
+ self.spin_end.setValue(end_val)
1404
+ self.spin_end.blockSignals(False)
1405
+
1406
+ # Ensure end is not before start
1407
+ if end_val < self.spin_start.value():
1408
+ self.spin_start.setValue(end_val)
1409
+
1410
+ # Jump preview to end frame so the user sees what was set
1411
+ self.slider.setValue(end_val)
1412
+ self._set_frame(end_val)
1413
+
1414
+ def _add_object(self):
1415
+ """Add a new object ID."""
1416
+ new_id = max(self.obj_ids, default=0) + 1
1417
+ self.obj_ids.append(new_id)
1418
+ self.combo_obj.addItem(f"Object {new_id}", new_id)
1419
+ self.combo_obj.setCurrentIndex(self.combo_obj.count() - 1)
1420
+
1421
+ def _on_object_changed(self):
1422
+ """Handle object selection change."""
1423
+ self.current_obj_id = self.combo_obj.currentData()
1424
+
1425
+ def _create_video_state(self, path):
1426
+ """Create a per-video state dict."""
1427
+ cap = cv2.VideoCapture(path)
1428
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
1429
+ cap.release()
1430
+ return {
1431
+ "path": path,
1432
+ "total_frames": total_frames,
1433
+ "points": [],
1434
+ "masks": {},
1435
+ "obj_ids": [1],
1436
+ "current_obj_id": 1,
1437
+ "current_frame_idx": 0,
1438
+ "spin_start": 0,
1439
+ "spin_end": max(total_frames - 1, 0),
1440
+ "inference_state": None,
1441
+ "state_start_frame": 0,
1442
+ }
1443
+
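+ # Shape of the per-video state built above; every field mirrors a UI
+ # attribute so videos can be switched without losing in-progress work:
+ #   {
+ #       "path": "/path/to/video.mp4",
+ #       "total_frames": 1200,
+ #       "points": [(x, y, label, frame_idx, obj_id), ...],
+ #       "masks": {frame_idx: {obj_id: np.ndarray}, ...},
+ #       "obj_ids": [1, 2], "current_obj_id": 1,
+ #       "current_frame_idx": 0, "spin_start": 0, "spin_end": 1199,
+ #       "inference_state": None, "state_start_frame": 0,
+ #   }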
1444
+ def _save_current_video_state(self):
1445
+ """Persist current UI state into the active video entry."""
1446
+ if self.current_video_idx is None or self.current_video_idx >= len(self.videos):
1447
+ return
1448
+ try:
1449
+ v = self.videos[self.current_video_idx]
1450
+ v["points"] = list(self.points)
1451
+ v["masks"] = dict(self.masks)
1452
+ v["obj_ids"] = list(self.obj_ids)
1453
+ v["current_obj_id"] = self.current_obj_id
1454
+ v["current_frame_idx"] = self.current_frame_idx
1455
+ v["spin_start"] = self.spin_start.value()
1456
+ v["spin_end"] = self.spin_end.value()
1457
+ v["inference_state"] = self.inference_state
1458
+ v["state_start_frame"] = self.state_start_frame
1459
+ except Exception:
1460
+ pass
1461
+
1462
+ def _apply_video_state(self, idx: int):
1463
+ """Load a video's state into the UI and current attributes."""
1464
+ if idx < 0 or idx >= len(self.videos):
1465
+ return
1466
+ self.current_video_idx = idx
1467
+ v = self.videos[idx]
1468
+ self.video_path = v["path"]
1469
+ self.cap = cv2.VideoCapture(self.video_path)
1470
+ self.total_frames = v["total_frames"]
1471
+
1472
+ # Ranges
1473
+ self.slider.setRange(0, max(self.total_frames - 1, 0))
1474
+ self.spin_start.setRange(0, max(self.total_frames - 1, 0))
1475
+ self.spin_end.setRange(0, max(self.total_frames - 1, 0))
1476
+ self.spin_start.setValue(min(v["spin_start"], max(self.total_frames - 1, 0)))
1477
+ self.spin_end.setValue(min(v["spin_end"], max(self.total_frames - 1, 0)))
1478
+
1479
+ # Restore points/masks/objects
1480
+ self.points = list(v["points"])
1481
+ self.masks = dict(v["masks"])
1482
+ self.obj_ids = list(v["obj_ids"])
1483
+ self.current_obj_id = v["current_obj_id"]
1484
+ self.current_frame_idx = min(v["current_frame_idx"], max(self.total_frames - 1, 0))
1485
+ self.inference_state = v["inference_state"]
1486
+ self.state_start_frame = v.get("state_start_frame", 0)
1487
+
1488
+ # Rebuild object combo
1489
+ self.combo_obj.blockSignals(True)
1490
+ self.combo_obj.clear()
1491
+ for oid in self.obj_ids:
1492
+ self.combo_obj.addItem(f"Object {oid}", oid)
1493
+ idx_obj = self.combo_obj.findData(self.current_obj_id)
1494
+ if idx_obj >= 0:
1495
+ self.combo_obj.setCurrentIndex(idx_obj)
1496
+ self.combo_obj.blockSignals(False)
1497
+
1498
+ # Enable controls
1499
+ self.slider.setEnabled(True)
1500
+ self.btn_prev_frame.setEnabled(True)
1501
+ self.btn_next_frame.setEnabled(True)
1502
+ self.btn_track.setEnabled(self.sam2_available)
1503
+ self.btn_track_all.setEnabled(len(self.videos) > 0 and self.sam2_available)
1504
+ self.btn_preview.setEnabled(self.sam2_available)
1505
+
1506
+ # Move slider to current frame and refresh
1507
+ self.slider.blockSignals(True)
1508
+ self.slider.setValue(self.current_frame_idx)
1509
+ self.slider.blockSignals(False)
1510
+ self.zoom_factor = 1.0
1511
+ self._update_frame()
1512
+
1513
+ def _on_video_selected(self, row: int):
1514
+ """Handle selection change from the video list."""
1515
+ self._save_current_video_state()
1516
+ if row >= 0:
1517
+ self._apply_video_state(row)
1518
+
1519
+ def _load_video(self):
1520
+ """Load one or more video files."""
1521
+ video_dir = self.config.get("raw_videos_dir", self.config.get("data_dir", "data/raw_videos"))
1522
+ paths, _ = QFileDialog.getOpenFileNames(
1523
+ self, "Open Videos", video_dir, "Video Files (*.mp4 *.avi *.mov *.mkv)"
1524
+ )
1525
+ if not paths:
1526
+ return
1527
+
1528
+ from .video_utils import ensure_video_in_experiment
1529
+
1530
+ added_any = False
1531
+ for path in paths:
1532
+ path = ensure_video_in_experiment(path, self.config, self)
1533
+ # Avoid duplicates
1534
+ if any(v["path"] == path for v in self.videos):
1535
+ continue
1536
+ state = self._create_video_state(path)
1537
+ self.videos.append(state)
1538
+ self.video_list_widget.addItem(os.path.basename(path))
1539
+ added_any = True
1540
+
1541
+ if not added_any:
1542
+ return
1543
+
1544
+ # Auto-select first added video if none active
1545
+ if self.current_video_idx is None and self.videos:
1546
+ self.video_list_widget.setCurrentRow(0)
1547
+ self._apply_video_state(0)
1548
+ else:
1549
+ # Keep current selection, just enable batch controls
1550
+ self.btn_track_all.setEnabled(self.sam2_available and len(self.videos) > 0)
1551
+
1552
+ def _ensure_predictor(self):
1553
+ """Ensure SAM2 model is loaded (rebuilds if resolution changed)."""
1554
+ tracking_res = self.tracking_res_combo.currentData() or 1024
1555
+ if self.predictor is not None:
1556
+ if getattr(self.predictor, "image_size", 1024) == tracking_res:
1557
+ return True
1558
+ # Resolution changed; rebuild the predictor
1559
+ del self.predictor
1560
+ self.predictor = None
1561
+ if torch.cuda.is_available():
1562
+ torch.cuda.empty_cache()
1563
+
1564
+ if not self.sam2_available:
1565
+ QMessageBox.warning(
1566
+ self,
1567
+ "SAM2 not available",
1568
+ "SAM2 is not installed in this environment.\n\nRun bash install.sh and reopen the app.",
1569
+ )
1570
+ return False
1571
+
1572
+ idx = self.combo_model.currentIndex()
1573
+ if idx < 0:
1574
+ QMessageBox.warning(self, "No model selected", "Please select a model.")
1575
+ return False
1576
+
1577
+ model_name = self.combo_model.itemData(idx)
1578
+ if not model_name:
1579
+ QMessageBox.warning(self, "No model selected", "Please select a model.")
1580
+ return False
1581
+
1582
+ self.checkpoint_path = os.path.join(self.sam2_dir, "checkpoints", model_name)
1583
+ if not os.path.exists(self.checkpoint_path):
1584
+ QMessageBox.warning(
1585
+ self,
1586
+ "Checkpoint Missing",
1587
+ f"Checkpoint not found:\n{self.checkpoint_path}\n\n"
1588
+ "The checkpoint should download automatically when you select a model.\n"
1589
+ "Please wait for the download to complete or select the model again."
1590
+ )
1591
+ return False
1592
+
1593
+ self.model_cfg = self._get_model_config(model_name)
1594
+
1595
+ try:
1596
+ # Ensure Hydra is initialized before importing SAM2
1597
+ if not self._ensure_hydra_initialized():
1598
+ QMessageBox.critical(
1599
+ self,
1600
+ "Hydra Initialization Failed",
1601
+ "Failed to initialize Hydra configuration system.\n\n"
1602
+ "Please ensure hydra-core is installed:\n"
1603
+ "pip install hydra-core>=1.3.2"
1604
+ )
1605
+ return False
1606
+
1607
+ # Import sam2 first to trigger its Hydra initialization
1608
+ import sam2
1609
+ # Then import the build function
1610
+ from sam2.build_sam import build_sam2_video_predictor
1611
+
1612
+ device = "cuda" if torch.cuda.is_available() else "cpu"
1613
+ if device == "cpu":
1614
+ QMessageBox.warning(self, "CPU mode", "Running on CPU. This will be very slow.")
1615
+
1616
+ hydra_extra = [f"++model.image_size={tracking_res}"]
1617
+ self.predictor = build_sam2_video_predictor(
1618
+ self.model_cfg, self.checkpoint_path, device=device,
1619
+ hydra_overrides_extra=hydra_extra,
1620
+ )
1621
+ self.predictor.fill_hole_area = self.fill_hole_area
1622
+ return True
1623
+ except RuntimeError as e:
1624
+ if "parent directory" in str(e) or "shadowed" in str(e):
1625
+ QMessageBox.critical(
1626
+ self,
1627
+ "SAM2 Import Error",
1628
+ f"SAM2 import conflict:\n{e}\n\n"
1629
+ "Solution: Please install SAM2 to a location outside the behavior_labeling_app directory.\n"
1630
+ "Use the 'Change...' button to select a different installation location."
1631
+ )
1632
+ else:
1633
+ QMessageBox.critical(self, "Error", f"Failed to init SAM2 model:\n{e}")
1634
+ return False
1635
+ except Exception as e:
1636
+ QMessageBox.critical(self, "Error", f"Failed to init SAM2 model:\n{e}")
1637
+ return False
1638
+
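+ # The same model build outside the GUI, as a minimal sketch (the paths are
+ # hypothetical; the call mirrors the one in _ensure_predictor above):
+ #   import torch
+ #   from sam2.build_sam import build_sam2_video_predictor
+ #   device = "cuda" if torch.cuda.is_available() else "cpu"
+ #   predictor = build_sam2_video_predictor(
+ #       "configs/sam2.1/sam2.1_hiera_l.yaml",
+ #       "/path/to/checkpoints/sam2.1_hiera_large.pt",
+ #       device=device,
+ #       hydra_overrides_extra=["++model.image_size=1024"],
+ #   )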
1639
+ def _set_frame(self, frame_idx):
1640
+ """Set current frame index."""
1641
+ self.current_frame_idx = frame_idx
1642
+ self._update_frame()
1643
+
1644
+ def eventFilter(self, source, event):
1645
+ if source is self.video_scroll.viewport() and event.type() == QEvent.Type.Resize:
1646
+ self._position_zoom_buttons()
1647
+ return super().eventFilter(source, event)
1648
+
1649
+ def _style_zoom_button(self, btn):
1650
+ btn.setFixedSize(34, 34)
1651
+ btn.setCursor(Qt.CursorShape.PointingHandCursor)
1652
+ btn.setStyleSheet(
1653
+ "QPushButton {"
1654
+ "background-color: rgba(20, 20, 20, 190);"
1655
+ "color: white;"
1656
+ "border: 1px solid rgba(255, 255, 255, 120);"
1657
+ "border-radius: 17px;"
1658
+ "font-size: 18px;"
1659
+ "font-weight: bold;"
1660
+ "}"
1661
+ "QPushButton:hover {"
1662
+ "background-color: rgba(45, 45, 45, 220);"
1663
+ "}"
1664
+ )
1665
+
1666
+ def _position_zoom_buttons(self):
1667
+ if not hasattr(self, "video_scroll") or not hasattr(self, "btn_zoom_in"):
1668
+ return
1669
+ viewport = self.video_scroll.viewport()
1670
+ margin = 10
1671
+ spacing = 8
1672
+ x = viewport.width() - self.btn_zoom_in.width() - margin
1673
+ y = margin
1674
+ self.btn_zoom_in.move(x, y)
1675
+ self.btn_zoom_out.move(x, y + self.btn_zoom_in.height() + spacing)
1676
+ self.btn_zoom_in.raise_()
1677
+ self.btn_zoom_out.raise_()
1678
+
1679
+ def _zoom_in(self):
1680
+ self.zoom_factor = min(self.zoom_max, self.zoom_factor + self.zoom_step)
1681
+ self._apply_zoom()
1682
+
1683
+ def _zoom_out(self):
1684
+ self.zoom_factor = max(self.zoom_min, self.zoom_factor - self.zoom_step)
1685
+ self._apply_zoom()
1686
+
1687
+ def _apply_zoom(self):
1688
+ if self._base_display_pixmap is None:
1689
+ return
1690
+ w = max(1, int(self._base_display_pixmap.width() * self.zoom_factor))
1691
+ h = max(1, int(self._base_display_pixmap.height() * self.zoom_factor))
1692
+ scaled = self._base_display_pixmap.scaled(
1693
+ w,
1694
+ h,
1695
+ Qt.AspectRatioMode.KeepAspectRatio,
1696
+ Qt.TransformationMode.SmoothTransformation,
1697
+ )
1698
+ self.video_label.setPixmap(scaled)
1699
+ self.video_label.resize(scaled.size())
1700
+
1701
+ def _prev_frame(self):
1702
+ """Go to previous frame."""
1703
+ if self.current_frame_idx > 0:
1704
+ self.current_frame_idx -= 1
1705
+ self.slider.blockSignals(True)
1706
+ self.slider.setValue(self.current_frame_idx)
1707
+ self.slider.blockSignals(False)
1708
+ self._update_frame()
1709
+
1710
+ def _next_frame(self):
1711
+ """Go to next frame."""
1712
+ if self.current_frame_idx < self.total_frames - 1:
1713
+ self.current_frame_idx += 1
1714
+ self.slider.blockSignals(True)
1715
+ self.slider.setValue(self.current_frame_idx)
1716
+ self.slider.blockSignals(False)
1717
+ self._update_frame()
1718
+
1719
+ def _update_frame(self):
1720
+ """Update video display with current frame and overlays."""
1721
+ if not self.cap:
1722
+ return
1723
+
1724
+ self.cap.set(cv2.CAP_PROP_POS_FRAMES, self.current_frame_idx)
1725
+ ret, frame = self.cap.read()
1726
+ if ret:
1727
+ self.frame = frame
1728
+ # Keep RGB frame as instance var to prevent QImage memory issues
1729
+ self._frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
1730
+ h, w, ch = self._frame_rgb.shape
1731
+
1732
+ bytes_per_line = ch * w
1733
+ q_img = QImage(self._frame_rgb.data, w, h, bytes_per_line, QImage.Format.Format_RGB888)
1734
+
1735
+ pixmap = QPixmap.fromImage(q_img)
1736
+ painter = QPainter(pixmap)
1737
+
1738
+ # Draw masks
1739
+ if self.current_frame_idx in self.masks:
1740
+ frame_masks = self.masks[self.current_frame_idx]
1741
+ for obj_id, mask in frame_masks.items():
1742
+ if mask is not None and mask.max() > 0:
1743
+ mask_h, mask_w = mask.shape[:2]
1744
+ if mask_h != h or mask_w != w:
1745
+ mask = cv2.resize(mask.astype(np.float32), (w, h), interpolation=cv2.INTER_NEAREST).astype(np.uint8)
1746
+
1747
+ overlay = np.zeros((h, w, 4), dtype=np.uint8)
1748
+ color = get_obj_color(obj_id)
1749
+ overlay[mask > 0] = [color[0], color[1], color[2], 100]
1750
+
1751
+ overlay_img = QImage(overlay.data, w, h, w * 4, QImage.Format.Format_RGBA8888)
1752
+ painter.drawImage(0, 0, overlay_img)
1753
+
1754
+ # Draw points
1755
+ for p in self.points:
1756
+ if p[3] == self.current_frame_idx:
1757
+ x, y, label, obj_id = p[0], p[1], p[2], p[4]
1758
+ obj_color_rgb = get_obj_color(obj_id)
1759
+
1760
+ if label == 1:
1761
+ painter.setPen(QPen(QColor(*obj_color_rgb), 5))
1762
+ painter.drawPoint(x, y)
1763
+ else:
1764
+ painter.setPen(QPen(QColor(255, 0, 0), 5))
1765
+ painter.drawPoint(x, y)
1766
+
1767
+ painter.setPen(QPen(QColor(255, 255, 255), 1))
1768
+ painter.drawText(x + 5, y - 5, str(obj_id))
1769
+
1770
+ painter.end()
1771
+ self._base_display_pixmap = pixmap
1772
+ self._apply_zoom()
1773
+
1774
+ self.lbl_frame.setText(f"Frame: {self.current_frame_idx} / {self.total_frames}")
1775
+
1776
+ def _handle_click(self, x, y):
1777
+ """Handle click on video label."""
1778
+ if self.frame is None:
1779
+ return
1780
+
1781
+ pixmap = self.video_label.pixmap()
1782
+ if not pixmap:
1783
+ return
1784
+
1785
+ label_w = self.video_label.width()
1786
+ label_h = self.video_label.height()
1787
+ pix_w = pixmap.width()
1788
+ pix_h = pixmap.height()
1789
+
1790
+ x_offset = (label_w - pix_w) / 2
1791
+ y_offset = (label_h - pix_h) / 2
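+ # The CheckpointDownloadWorker contract, as inferred from the wiring above
+ # (stated here as documentation, not a new API):
+ #   worker = CheckpointDownloadWorker(model_name, ckpt_path, model_url)
+ #   worker.progress.connect(on_progress)   # emits str status messages
+ #   worker.finished.connect(on_finished)   # emits (success: bool, message: str)
+ #   worker.start()                         # QThread, so the GUI never blocks
+ # CheckpointDownloadWorker.MODEL_URLS maps checkpoint file names to URLs.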
1792
+
1793
+ img_x = x - x_offset
1794
+ img_y = y - y_offset
1795
+
1796
+ if 0 <= img_x < pix_w and 0 <= img_y < pix_h:
1797
+ orig_h, orig_w = self.frame.shape[:2]
1798
+ scale_x = orig_w / pix_w
1799
+ scale_y = orig_h / pix_h
1800
+
1801
+ final_x = int(img_x * scale_x)
1802
+ final_y = int(img_y * scale_y)
1803
+
1804
+ label = 1 if self.radio_pos.isChecked() else 0
1805
+ self.points.append((final_x, final_y, label, self.current_frame_idx, self.current_obj_id))
1806
+ self._preview_frame()
1807
+
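+ # The click-to-frame mapping above, written out (illustrative numbers only):
+ # the pixmap is centered in the label, so
+ #   img_x = click_x - (label_w - pix_w) / 2
+ # and the pixmap is a scaled copy of the frame, so
+ #   final_x = int(img_x * orig_w / pix_w)
+ # e.g. a click at x=500 in a 1000 px label showing an 800 px pixmap of a
+ # 1600 px wide frame gives img_x = 400 and final_x = 800.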
1808
+ def _clear_points(self):
1809
+ """Clear all points."""
1810
+ self.points = []
1811
+ self.masks = {}
1812
+ # Clean up any incremental mask temp file
1813
+ if hasattr(self, '_incremental_mask_file') and self._incremental_mask_file is not None:
1814
+ try:
1815
+ self._incremental_mask_file.close()
1816
+ os.unlink(self._incremental_mask_file.name)
1817
+ except Exception:
1818
+ pass
1819
+ self._incremental_mask_file = None
1820
+ # Reset inference state to clear any cached predictions
1821
+ if self.inference_state and self.predictor:
1822
+ try:
1823
+ self.predictor.reset_state(self.inference_state)
1824
+ except Exception:
1825
+ pass
1826
+ self._update_frame()
1827
+
1828
+ def _preview_frame(self):
1829
+ """Preview segmentation on current frame."""
1830
+ if not self.video_path:
1831
+ return
1832
+
1833
+ if not self._ensure_predictor():
1834
+ return
1835
+
1836
+ # Gather points for current frame and object
1837
+ current_points = []
1838
+ current_labels = []
1839
+ for x, y, label, frame_idx, obj_id in self.points:
1840
+ if frame_idx == self.current_frame_idx and obj_id == self.current_obj_id:
1841
+ current_points.append([x, y])
1842
+ current_labels.append(label)
1843
+
1844
+ if not current_points:
1845
+ return
1846
+
1847
+ # Force fresh state load for preview to avoid any stale cached data
1848
+ # This is slower but ensures accurate preview
1849
+ if not self._load_state(self.current_frame_idx, self.current_frame_idx + 1):
1850
+ return
1851
+
1852
+ # Reset any previous tracking state to get clean prediction
1853
+ try:
1854
+ self.predictor.reset_state(self.inference_state)
1855
+ except Exception:
1856
+ pass
1857
+
1858
+ try:
1859
+ pts = np.array(current_points, dtype=np.float32)
1860
+ lbls = np.array(current_labels, dtype=np.int32)
1861
+
1862
+ local_frame_idx = 0 # We just loaded a single frame, so local index is 0
1863
+
1864
+ _, out_obj_ids, out_mask_logits = self._sam2_call(
1865
+ self.predictor.add_new_points_or_box,
1866
+ inference_state=self.inference_state,
1867
+ frame_idx=local_frame_idx,
1868
+ obj_id=self.current_obj_id,
1869
+ points=pts,
1870
+ labels=lbls,
1871
+ clear_old_points=True,
1872
+ normalize_coords=True,
1873
+ )
1874
+
1875
+ if self.current_obj_id in out_obj_ids:
1876
+ idx = out_obj_ids.index(self.current_obj_id)
1877
+ mask_logit = out_mask_logits[idx]
1878
+ if mask_logit.ndim == 3:
1879
+ mask_logit = mask_logit[0]
1880
+ mask = (mask_logit > self.mask_threshold).cpu().numpy().astype(np.uint8).squeeze()
1881
+
1882
+ if self.current_frame_idx not in self.masks:
1883
+ self.masks[self.current_frame_idx] = {}
1884
+ self.masks[self.current_frame_idx][self.current_obj_id] = mask
1885
+ self._update_frame()
1886
+ else:
1887
+ # If preview did not return this object, clear stale overlay for it.
1888
+ if self.current_frame_idx in self.masks and self.current_obj_id in self.masks[self.current_frame_idx]:
1889
+ del self.masks[self.current_frame_idx][self.current_obj_id]
1890
+ self._update_frame()
1891
+ except Exception as e:
1892
+ QMessageBox.critical(self, "Error", f"Preview failed:\n{e}")
1893
+
1894
+ def _ensure_state_for_frame(self, frame_idx):
1895
+ """Ensure state is loaded for specific frame."""
1896
+ if self.inference_state:
1897
+ local_idx = frame_idx - self.state_start_frame
1898
+ if 0 <= local_idx < self.inference_state["num_frames"]:
1899
+ return True
1900
+
1901
+ return self._load_state(frame_idx, frame_idx + 1)
1902
+
1903
+ def _load_state(self, start_frame, end_frame):
1904
+ """Load video frames into SAM2 state."""
1905
+ if not self._ensure_predictor():
1906
+ return False
1907
+
1908
+ if not self.video_path:
1909
+ return False
1910
+
1911
+ try:
1912
+ try:
1913
+ import decord
1914
+ except ImportError:
1915
+ QMessageBox.warning(
1916
+ self,
1917
+ "Missing Dependency",
1918
+ "decord not found. Please install it:\n\n"
1919
+ "pip install eva-decord\n\n"
1920
+ "Or: conda install -c conda-forge decord"
1921
+ )
1922
+ return False
1923
+
1924
+ from collections import OrderedDict
1925
+
1926
+ decord.bridge.set_bridge("torch")
1927
+ compute_device = self.predictor.device
1928
+ image_size = self.predictor.image_size
1929
+
1930
+ # Get original video dimensions (needed for coordinate normalization)
1931
+ vr_meta = decord.VideoReader(self.video_path)
1932
+ video_height, video_width, _ = vr_meta[0].shape
1933
+ total_frames = len(vr_meta)
1934
+ del vr_meta # Free memory immediately after getting dimensions
1935
+
1936
+ # Load frames at SAM2's internal image_size (square)
1937
+ # SAM2 uses original video_height/video_width for coordinate normalization
1938
+ vr = decord.VideoReader(self.video_path, width=image_size, height=image_size)
1939
+ target_dtype = getattr(self.predictor, "dtype", torch.float32)
1940
+ if self._use_cuda_bf16() and not self.offload_video:
1941
+ target_dtype = torch.bfloat16
1942
+
1943
+ if end_frame is None or end_frame > total_frames:
1944
+ end_frame = total_frames
1945
+ start_frame = max(0, start_frame)
1946
+
1947
+ if start_frame >= end_frame:
1948
+ return False
1949
+
1950
+ # Limit number of frames loaded at once to prevent RAM issues
1951
+ num_frames = end_frame - start_frame
1952
+ if num_frames > self.max_frames_per_load:
1953
+ # Only load the last max_frames_per_load frames to keep memory usage reasonable
1954
+ start_frame = max(start_frame, end_frame - self.max_frames_per_load)
1955
+ if start_frame < 0:
1956
+ start_frame = 0
1957
+ QMessageBox.warning(
1958
+ self,
1959
+ "Frame Range Limited",
1960
+ f"Requested {num_frames} frames, but limiting to {self.max_frames_per_load} frames "
1961
+ f"to prevent RAM issues.\n\nLoading frames {start_frame} to {end_frame}."
1962
+ )
1963
+
1964
+ indices = list(range(start_frame, end_frame))
1965
+ frames = vr.get_batch(indices)
1966
+ del vr # Free VideoReader memory after loading frames
1967
+ images = frames.permute(0, 3, 1, 2).float() / 255.0
1968
+ del frames # Free original frame tensor after processing
1969
+
1970
+ img_mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32)[:, None, None]
1971
+ img_std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32)[:, None, None]
1972
+
1973
+ if not self.offload_video:
1974
+ images = images.to(compute_device, dtype=target_dtype)
1975
+ img_mean = img_mean.to(compute_device, dtype=target_dtype)
1976
+ img_std = img_std.to(compute_device, dtype=target_dtype)
1977
+ else:
1978
+ images = images.to(dtype=target_dtype)
1979
+ img_mean = img_mean.to(dtype=target_dtype)
1980
+ img_std = img_std.to(dtype=target_dtype)
1981
+
1982
+ images -= img_mean
1983
+ images /= img_std
1984
+
1985
+ # Convert to list format expected by SAM2
1986
+ images_list = [images[i] for i in range(len(images))]
1987
+
1988
+ inference_state = {}
1989
+ inference_state["images"] = images_list
1990
+ inference_state["num_frames"] = len(images_list)
1991
+ inference_state["offload_video_to_cpu"] = self.offload_video
1992
+ inference_state["offload_state_to_cpu"] = self.offload_state
1993
+ inference_state["video_height"] = video_height
1994
+ inference_state["video_width"] = video_width
1995
+ inference_state["device"] = compute_device
1996
+ if self.offload_state:
1997
+ inference_state["storage_device"] = torch.device("cpu")
1998
+ else:
1999
+ inference_state["storage_device"] = compute_device
2000
+
2001
+ inference_state["point_inputs_per_obj"] = {}
2002
+ inference_state["mask_inputs_per_obj"] = {}
2003
+ inference_state["cached_features"] = {}
2004
+ inference_state["constants"] = {}
2005
+ inference_state["obj_id_to_idx"] = OrderedDict()
2006
+ inference_state["obj_idx_to_id"] = OrderedDict()
2007
+ inference_state["obj_ids"] = []
2008
+ inference_state["output_dict_per_obj"] = {}
2009
+ inference_state["temp_output_dict_per_obj"] = {}
2010
+ inference_state["frames_tracked_per_obj"] = {}
2011
+
2012
+ try:
2013
+ self._sam2_call(self.predictor._get_image_feature, inference_state, frame_idx=0, batch_size=1)
2014
+ except Exception:
2015
+ pass # feature pre-caching is a best-effort warm-up; tracking proceeds without it
2016
+
2017
+ self.inference_state = inference_state
2018
+ self.state_start_frame = start_frame
2019
+ return True
2020
+ except Exception as e:
2021
+ QMessageBox.critical(self, "Error", f"Failed to load video state:\n{e}")
2022
+ return False
2023
+
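+ # Core of the preprocessing above as a standalone sketch (decord returns
+ # HWC uint8 frames; SAM2 expects normalized CHW float tensors):
+ #   import torch
+ #   frames = vr.get_batch(indices)                 # (N, H, W, 3) uint8
+ #   images = frames.permute(0, 3, 1, 2).float() / 255.0
+ #   mean = torch.tensor([0.485, 0.456, 0.406])[:, None, None]
+ #   std = torch.tensor([0.229, 0.224, 0.225])[:, None, None]
+ #   images = (images - mean) / std                 # ImageNet statistics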
2024
+ def _run_tracking(self, from_batch: bool = False):
2025
+ """Run tracking on current video."""
2026
+ if not self.video_path:
2027
+ QMessageBox.warning(self, "No Video", "Please load a video first.")
2028
+ return
2029
+ if not self.points:
2030
+ if from_batch and self.batch_mode and self.batch_queue:
2031
+ next_idx = self.batch_queue.pop(0)
2032
+ self._apply_video_state(next_idx)
2033
+ self._run_tracking(from_batch=True)
2034
+ return
2035
+ QMessageBox.warning(self, "No Points", "Please add some points first.")
2036
+ return
2037
+
2038
+ if not self._ensure_predictor():
2039
+ return
2040
+
2041
+ # Group points by frame and object
2042
+ points_grouped = {}
2043
+ for x, y, label, frame_idx, obj_id in self.points:
2044
+ key = (frame_idx, obj_id)
2045
+ if key not in points_grouped:
2046
+ points_grouped[key] = {'points': [], 'labels': []}
2047
+ points_grouped[key]['points'].append([x, y])
2048
+ points_grouped[key]['labels'].append(label)
2049
+
2050
+ if not points_grouped:
2051
+ if from_batch and self.batch_mode and self.batch_queue:
2052
+ next_idx = self.batch_queue.pop(0)
2053
+ self._apply_video_state(next_idx)
2054
+ self._run_tracking(from_batch=True)
2055
+ return
2056
+ QMessageBox.warning(self, "Warning", "No points in the selected range.")
2057
+ return
2058
+
2059
+ # Determine processing range to include user points
2060
+ min_frame = min(k[0] for k in points_grouped.keys())
2061
+ max_frame = max(k[0] for k in points_grouped.keys())
2062
+
2063
+ if hasattr(self, 'chk_limit_range') and self.chk_limit_range.isChecked():
2064
+ start_f = max(self.spin_start.value(), 0)
2065
+ end_f = min(self.spin_end.value() + 1, self.total_frames)
2066
+ # Ensure the range covers the annotated points
2067
+ start_f = min(start_f, min_frame)
2068
+ end_f = max(end_f, max_frame + 1)
2069
+ else:
2070
+ start_f = 0
2071
+ end_f = self.total_frames
2072
+
2073
+ # If resuming, allow forcing the start a bit earlier than the drift point
2074
+ if self.resume_from_frame is not None:
2075
+ # Resume from the chosen frame (or later), do not jump back to frame 0
2076
+ start_f = max(self.resume_from_frame, start_f, 0)
2077
+ self.resume_from_frame = None
2078
+
2079
+ self.btn_track.setEnabled(False)
2080
+ self.btn_track_all.setEnabled(False)
2081
+ self.progress_bar.setVisible(True)
2082
+ self.progress_bar.setRange(0, end_f - start_f)
2083
+ self.progress_bar.setValue(0)
2084
+
2085
+ # Get initial masks for resume (if any)
2086
+ initial_masks = getattr(self, 'resume_initial_masks', {}) or {}
2087
+
2088
+ # Clear resume flags
2089
+ self.resume_from_frame = None
2090
+ self.resume_initial_masks = {}
2091
+ self.tracking_paused = False
2092
+
2093
+ self.tracking_worker = TrackingWorker(
2094
+ self.predictor,
2095
+ self.video_path,
2096
+ points_grouped,
2097
+ start_f,
2098
+ end_f,
2099
+ self.mask_threshold,
2100
+ self.offload_video,
2101
+ self.offload_state,
2102
+ enable_memory_management=self.enable_memory_management,
2103
+ reseed_between_chunks=getattr(self, "reseed_between_chunks", False),
2104
+ initial_masks=initial_masks,
2105
+ enable_motion_tracking=getattr(self, "enable_motion_tracking", False),
2106
+ motion_score_threshold=getattr(self, "motion_score_threshold", 0.3),
2107
+ motion_consecutive_low=getattr(self, "motion_consecutive_low", 3),
2108
+ motion_area_threshold=getattr(self, "motion_area_threshold", 0.5),
2109
+ enable_ocsort=getattr(self, "enable_ocsort", False),
2110
+ ocsort_inertia=getattr(self, "ocsort_inertia", 0.2),
2111
+ use_cuda_bf16_autocast=getattr(self, "use_cuda_bf16_autocast", True),
2112
+ )
2113
+ self.tracking_worker.progress_signal.connect(lambda x: self.progress_bar.setValue(x - start_f) if x >= start_f else None)
2114
+ self.tracking_worker.frame_result_signal.connect(self._on_frame_result)
2115
+ self.tracking_worker.finished_signal.connect(self._on_tracking_finished)
2116
+ self.tracking_worker.error_signal.connect(self._on_tracking_error)
2117
+ self.tracking_worker.log_message.connect(lambda msg: self.log_text.setText(msg))
2118
+ self.tracking_worker.start()
2119
+
2120
+ # Enable pause, disable resume while running
2121
+ self.btn_pause_tracking.setEnabled(True)
2122
+ self.btn_resume_tracking.setEnabled(False)
2123
+
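+ # Shape of points_grouped handed to the TrackingWorker (one prompt set per
+ # (frame, object) pair, built by the grouping loop above), e.g.:
+ #   {
+ #       (0, 1):  {'points': [[120, 240], [130, 250]], 'labels': [1, 1]},
+ #       (50, 2): {'points': [[400, 300]],             'labels': [0]},
+ #   }
+ # where label 1 is a positive (foreground) click and 0 a negative one.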
2124
+ def _pause_tracking(self):
2125
+ """Pause the current tracking run to allow adding new prompts."""
2126
+ if hasattr(self, "tracking_worker") and self.tracking_worker and self.tracking_worker.isRunning():
2127
+ self.tracking_paused = True
2128
+ self.tracking_worker.stop()
2129
+ self.log_text.setText("Stopping tracking... you can add points and resume.")
2130
+ self.btn_pause_tracking.setEnabled(False)
2131
+ self.btn_resume_tracking.setEnabled(False)
2132
+
2133
+ def _resume_tracking(self):
2134
+ """Resume tracking from the current frame (or last processed) after adding prompts."""
2135
+ if self.tracking_worker and self.tracking_worker.isRunning():
2136
+ return # already running
2137
+
2138
+ # Choose a resume frame: prefer slider position, fallback to last processed
2139
+ resume_frame = self.current_frame_idx if hasattr(self, "current_frame_idx") else None
2140
+ if resume_frame is None and self.last_processed_frame is not None:
2141
+ resume_frame = self.last_processed_frame
2142
+ if resume_frame is None:
2143
+ resume_frame = 0
2144
+
2145
+ # Collect refined masks at the resume frame to use as conditioning
2146
+ # This allows the user to refine the mask before resuming
2147
+ initial_masks = {}
2148
+ if resume_frame in self.masks:
2149
+ for obj_id, mask in self.masks[resume_frame].items():
2150
+ if mask is not None and mask.max() > 0:
2151
+ initial_masks[(resume_frame, obj_id)] = mask
2152
+ self.log_text.setText(f"Will use refined mask for object {obj_id} at frame {resume_frame}")
2153
+
2154
+ # Store initial masks for the worker
2155
+ self.resume_initial_masks = initial_masks
2156
+
2157
+ # Resume exactly at the refined frame; starting earlier would overwrite the user's corrections
2158
+ self.resume_from_frame = resume_frame # Start exactly from where user refined
2159
+ self.tracking_paused = False
2160
+
2161
+ # Re-run tracking; it will honor resume_from_frame and initial_masks
2162
+ self._run_tracking()
2163
+
2164
+ def _run_tracking_all(self):
2165
+ """Run tracking sequentially for all loaded videos."""
2166
+ if not self.videos:
2167
+ QMessageBox.warning(self, "No Videos", "Please load videos first.")
2168
+ return
2169
+ if not self.sam2_available:
2170
+ QMessageBox.warning(
2171
+ self,
2172
+ "SAM2 not available",
2173
+ "SAM2 is not installed in this environment.\n\nRun bash install.sh and reopen the app.",
2174
+ )
2175
+ return
2176
+
2177
+ # Save current state
2178
+ self._save_current_video_state()
2179
+
2180
+ # Build queue of indices
2181
+ self.batch_queue = list(range(len(self.videos)))
2182
+ self.batch_mode = True
2183
+
2184
+ # Start with the first video in the queue
2185
+ next_idx = self.batch_queue.pop(0)
2186
+ self._apply_video_state(next_idx)
2187
+ self._run_tracking(from_batch=True)
2188
+
2189
+ def _on_frame_result(self, frame_idx, frame_masks):
2190
+ """Handle real-time mask updates."""
2191
+ if frame_idx not in self.masks:
2192
+ self.masks[frame_idx] = {}
2193
+
2194
+ for obj_id, mask in frame_masks.items():
2195
+ self.masks[frame_idx][obj_id] = mask
2196
+
2197
+ # Track last processed frame for potential resume
2198
+ self.last_processed_frame = frame_idx
2199
+
2200
+ # Incremental save: periodically flush masks to disk to free RAM
2201
+ # Trigger when we exceed threshold (not just at exact multiples)
2202
+ if len(self.masks) >= 500:
2203
+ self._incremental_save_masks()
2204
+
2205
+ if frame_idx == self.current_frame_idx:
2206
+ self._update_frame()
2207
+
2208
+ if self.chk_auto_follow.isChecked():
2209
+ self.slider.blockSignals(True)
2210
+ self.slider.setValue(frame_idx)
2211
+ self.slider.blockSignals(False)
2212
+ self.current_frame_idx = frame_idx
2213
+ self._update_frame()
2214
+
2215
+ def _on_tracking_finished(self, masks):
2216
+ """Handle tracking completion."""
2217
+ for frame_idx, frame_masks in masks.items():
2218
+ if frame_idx not in self.masks:
2219
+ self.masks[frame_idx] = {}
2220
+ for obj_id, mask in frame_masks.items():
2221
+ self.masks[frame_idx][obj_id] = mask
2222
+
2223
+ self.btn_track.setEnabled(True)
2224
+ self.btn_track_all.setEnabled(self.sam2_available and len(self.videos) > 0)
2225
+ self.btn_pause_tracking.setEnabled(False)
2226
+ self.btn_resume_tracking.setEnabled(self.tracking_paused)
2227
+ self.progress_bar.setVisible(False)
2228
+ self._update_frame()
2229
+
2230
+ # Persist state for current video
2231
+ self._save_current_video_state()
2232
+ overlay_path = None
2233
+ if self.save_overlay_video:
2234
+ overlay_path = self._save_overlay_video(paused=self.tracking_paused)
2235
+
2236
+ # If user paused manually, do not save or show completion popups yet
2237
+ if self.tracking_paused:
2238
+ if overlay_path:
2239
+ self.log_text.setText(
2240
+ "Tracking paused. Partial overlay video saved to:\n"
2241
+ f"{overlay_path}\n"
2242
+ "Add points on current frame, click 'Preview frame', then 'Resume tracking from here'."
2243
+ )
2244
+ else:
2245
+ self.log_text.setText(
2246
+ "Tracking paused. Add points on current frame, click 'Preview frame', then 'Resume tracking from here'."
2247
+ )
2248
+ self.btn_resume_tracking.setEnabled(True)
2249
+ return
2250
+
2251
+ # Save masks automatically
2252
+ mask_path = self._save_masks()
2253
+
2254
+ # In batch mode, skip completion popups so processing can continue
+ if not self.batch_mode:
2258
+ if mask_path and self.video_path:
2259
+ # Show completion message with option to go to registration
2260
+ overlay_text = f"Overlay video saved to: {overlay_path}\n\n" if overlay_path else ""
2261
+ reply = QMessageBox.question(
2262
+ self,
2263
+ "Tracking Completed",
2264
+ f"Tracking completed successfully!\n\n"
2265
+ f"Masks saved to: {mask_path}\n\n"
2266
+ f"{overlay_text}"
2267
+ "Would you like to proceed to the Registration tab to process this video?",
2268
+ QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No,
2269
+ QMessageBox.StandardButton.Yes
2270
+ )
2271
+
2272
+ if reply == QMessageBox.StandardButton.Yes:
2273
+ # Emit signal to switch to registration tab
2274
+ self.tracking_completed.emit(self.video_path, mask_path)
2275
+ else:
2276
+ QMessageBox.information(self, "Success", "Tracking completed!")
2277
+
2278
+ self.inference_state = None
2279
+ self.tracking_paused = False
2280
+ self.resume_from_frame = None
2281
+
2282
+ # Continue batch if pending
2283
+ if self.batch_mode and self.batch_queue:
2284
+ next_idx = self.batch_queue.pop(0)
2285
+ self._apply_video_state(next_idx)
2286
+ self._run_tracking(from_batch=True)
2287
+ return
2288
+ # End batch
2289
+ self.batch_mode = False
2290
+ self.batch_queue = []
2291
+
2292
+ def _incremental_save_masks(self):
2293
+ """Save old masks to temp file and clear from memory to prevent RAM exhaustion."""
2294
+ if not self.video_path or len(self.masks) < 500:
2295
+ return
2296
+
2297
+ import pickle
2298
+ import tempfile
2299
+
2300
+ # Initialize temp file on first call
2301
+ if not hasattr(self, '_incremental_mask_file') or self._incremental_mask_file is None:
2302
+ video_basename = os.path.splitext(os.path.basename(self.video_path))[0]
2303
+ self._incremental_mask_file = tempfile.NamedTemporaryFile(
2304
+ mode='wb', suffix=f'_{video_basename}_masks.pkl', delete=False
2305
+ )
2306
+ self._incremental_frame_indices = []
2307
+
2308
+ # Get frames to save (oldest 400 frames, keep 100 for display)
2309
+ sorted_frames = sorted(self.masks.keys())
2310
+ frames_to_save = sorted_frames[:400]
2311
+
2312
+ # Save to temp file
2313
+ chunk_data = {idx: self.masks[idx] for idx in frames_to_save}
2314
+ pickle.dump(chunk_data, self._incremental_mask_file)
2315
+ self._incremental_frame_indices.extend(frames_to_save)
2316
+
2317
+ # Clear from memory
2318
+ for idx in frames_to_save:
2319
+ del self.masks[idx]
2320
+
2321
+ gc.collect()
2322
+
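+ # On-disk layout of the incremental spill file: consecutive pickled dicts
+ # of {frame_idx: {obj_id: mask}} appended to a single stream. Reading it
+ # back is a loop of pickle.load() until EOFError, exactly as
+ # _get_masks_snapshot_for_export and _save_masks do below:
+ #   with open(tmp_path, "rb") as f:
+ #       while True:
+ #           try:
+ #               chunk = pickle.load(f)
+ #           except EOFError:
+ #               break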
2323
+ def _get_masks_snapshot_for_export(self):
2324
+ """Get a merged mask snapshot including incremental chunks without consuming them."""
2325
+ snapshot = {}
2326
+ for frame_idx, frame_data in self.masks.items():
2327
+ snapshot[frame_idx] = dict(frame_data)
2328
+
2329
+ if hasattr(self, "_incremental_mask_file") and self._incremental_mask_file is not None:
2330
+ import pickle
2331
+ try:
2332
+ self._incremental_mask_file.flush()
2333
+ except Exception:
2334
+ pass
2335
+ try:
2336
+ with open(self._incremental_mask_file.name, "rb") as f:
2337
+ while True:
2338
+ try:
2339
+ chunk = pickle.load(f)
2340
+ for frame_idx, frame_data in chunk.items():
2341
+ if frame_idx not in snapshot:
2342
+ snapshot[frame_idx] = dict(frame_data)
2343
+ else:
2344
+ snapshot[frame_idx].update(frame_data)
2345
+ except EOFError:
2346
+ break
2347
+ except Exception as e:
2348
+ logger.warning("Could not read incremental masks for overlay export: %s", e)
2349
+
2350
+ return snapshot
2351
+
2352
+ def _save_overlay_video(self, paused=False):
2353
+ """Save overlay video (original frame + colored masks) for inspection."""
2354
+ if not self.video_path:
2355
+ return None
2356
+
2357
+ all_masks = self._get_masks_snapshot_for_export()
2358
+ if not all_masks:
2359
+ return None
2360
+
2361
+ cap = cv2.VideoCapture(self.video_path)
2362
+ if not cap.isOpened():
2363
+ return None
2364
+
2365
+ video_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
2366
+ video_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
2367
+ fps = cap.get(cv2.CAP_PROP_FPS)
2368
+ if fps <= 0:
2369
+ fps = 30.0
2370
+
2371
+ save_start = 0
2372
+ save_end = max(all_masks.keys()) if all_masks else 0
2373
+ if hasattr(self, "chk_limit_range") and self.chk_limit_range.isChecked():
2374
+ save_start = self.spin_start.value()
2375
+ save_end = min(self.spin_end.value(), save_end)
2376
+ save_start = max(save_start, 0)
2377
+ save_end = max(save_end, save_start)
2378
+
2379
+ experiment_path = self.config.get("experiment_path")
2380
+ if experiment_path and os.path.exists(experiment_path):
2381
+ out_dir = os.path.join(experiment_path, "overlays")
2382
+ else:
2383
+ from singlebehaviorlab._paths import USER_DATA_DIR
2384
+ out_dir = str(USER_DATA_DIR / "data" / "overlays")
2385
+ os.makedirs(out_dir, exist_ok=True)
2386
+
2387
+ video_basename = os.path.splitext(os.path.basename(self.video_path))[0]
2388
+ if video_basename.endswith("_masks"):
2389
+ video_basename = video_basename[:-6]
2390
+ suffix = "_tracking_overlay_paused.mp4" if paused else "_tracking_overlay.mp4"
2391
+ output_path = os.path.join(out_dir, f"{video_basename}{suffix}")
2392
+
2393
+ writer = cv2.VideoWriter(
2394
+ output_path,
2395
+ cv2.VideoWriter_fourcc(*"mp4v"),
2396
+ fps,
2397
+ (video_width, video_height),
2398
+ )
2399
+ if not writer.isOpened():
2400
+ cap.release()
2401
+ return None
2402
+
2403
+ cap.set(cv2.CAP_PROP_POS_FRAMES, save_start)
2404
+ alpha = 0.35
2405
+ total_to_write = max(0, save_end - save_start + 1)
2406
+ progress = QProgressDialog(
2407
+ "Saving overlay video...",
2408
+ "",
2409
+ 0,
2410
+ total_to_write,
2411
+ self
2412
+ )
2413
+ progress.setWindowTitle("Exporting Overlay Video")
2414
+ progress.setWindowModality(Qt.WindowModality.ApplicationModal)
2415
+ progress.setAutoClose(True)
2416
+ progress.setAutoReset(True)
2417
+ progress.setMinimumDuration(0)
2418
+ progress.setCancelButton(None)
2419
+ progress.setValue(0)
2420
+ QApplication.processEvents()
2421
+
2422
+ try:
2423
+ written = 0
2424
+ for frame_idx in range(save_start, save_end + 1):
2425
+ ret, frame = cap.read()
2426
+ if not ret:
2427
+ break
2428
+
2429
+ frame_masks = all_masks.get(frame_idx, {})
2430
+ for obj_id, mask in frame_masks.items():
2431
+ if mask is None or mask.max() == 0:
2432
+ continue
2433
+
2434
+ if mask.shape[0] != video_height or mask.shape[1] != video_width:
2435
+ mask = cv2.resize(
2436
+ mask.astype(np.float32),
2437
+ (video_width, video_height),
2438
+ interpolation=cv2.INTER_NEAREST
2439
+ ).astype(np.uint8)
2440
+
2441
+ idx = mask > 0
2442
+ if not np.any(idx):
2443
+ continue
2444
+
2445
+ color_rgb = get_obj_color(obj_id)
2446
+ color_bgr = np.array([color_rgb[2], color_rgb[1], color_rgb[0]], dtype=np.float32)
2447
+ frame_float = frame.astype(np.float32)
2448
+ frame_float[idx] = frame_float[idx] * (1.0 - alpha) + color_bgr * alpha
2449
+ frame = frame_float.astype(np.uint8)
2450
+
2451
+ ys, xs = np.where(idx)
2452
+ if len(xs) > 0 and len(ys) > 0:
2453
+ cx = int(np.mean(xs))
2454
+ cy = int(np.mean(ys))
2455
+ cv2.putText(
2456
+ frame,
2457
+ f"id:{obj_id}",
2458
+ (cx, cy),
2459
+ cv2.FONT_HERSHEY_SIMPLEX,
2460
+ 0.45,
2461
+ (255, 255, 255),
2462
+ 1,
2463
+ cv2.LINE_AA,
2464
+ )
2465
+
2466
+ writer.write(frame)
2467
+ written += 1
2468
+ if written % 5 == 0 or written == total_to_write:
2469
+ progress.setValue(written)
2470
+ QApplication.processEvents()
2471
+ finally:
2472
+ progress.setValue(total_to_write)
2473
+ writer.release()
2474
+ cap.release()
2475
+
2476
+ return output_path
2477
+
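+ # The overlay blend above in one formula: for every masked pixel,
+ #   out = (1 - alpha) * frame + alpha * color, with alpha = 0.35,
+ # computed in float32 and cast back to uint8, so about 65% of the original
+ # pixel shows through the object color.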
2478
+ def _save_masks(self):
2479
+ """Save masks in format compatible with animal_registration app."""
2480
+ if not self.video_path:
2481
+ return None
2482
+
2483
+ import cv2
2484
+ import pickle
2485
+
2486
+ # Merge incremental saves back into self.masks
2487
+ if hasattr(self, '_incremental_mask_file') and self._incremental_mask_file is not None:
2488
+ self._incremental_mask_file.close()
2489
+ try:
2490
+ with open(self._incremental_mask_file.name, 'rb') as f:
2491
+ while True:
2492
+ try:
2493
+ chunk = pickle.load(f)
2494
+ for frame_idx, frame_data in chunk.items():
2495
+ if frame_idx not in self.masks:
2496
+ self.masks[frame_idx] = frame_data
2497
+ except EOFError:
2498
+ break
2499
+ # Clean up temp file
2500
+ os.unlink(self._incremental_mask_file.name)
2501
+ except Exception as e:
2502
+ logger.warning("Could not load incremental masks: %s", e)
2503
+ finally:
2504
+ self._incremental_mask_file = None
2505
+ self._incremental_frame_indices = []
2506
+
2507
+ if not self.masks:
2508
+ return None
2509
+
2510
+ # Get video dimensions
2511
+ cap = cv2.VideoCapture(self.video_path)
2512
+ video_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
2513
+ video_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
2514
+ fps = cap.get(cv2.CAP_PROP_FPS)
2515
+ cap.release()
2516
+
2517
+ # Determine range to save (respect limit range if set)
2518
+ save_start = 0
2519
+ save_end = max(self.masks.keys()) if self.masks else 0
2520
+ if hasattr(self, "chk_limit_range") and self.chk_limit_range.isChecked():
2521
+ save_start = self.spin_start.value()
2522
+ save_end = self.spin_end.value()
2523
+ save_start = max(save_start, 0)
2524
+ save_end = max(save_end, save_start)
2525
+
2526
+ frame_objects = []
2527
+ num_frames = (save_end - save_start + 1) if self.masks else 0
2528
+
2529
+ for frame_idx_global in range(save_start, save_end + 1):
2530
+ frame_objs = []
2531
+ if frame_idx_global in self.masks:
2532
+ for obj_id, mask in self.masks[frame_idx_global].items():
2533
+ if mask is not None and mask.max() > 0:
2534
+ # Resize mask to video dimensions if needed
2535
+ if mask.shape[0] != video_height or mask.shape[1] != video_width:
2536
+ mask_resized = cv2.resize(
2537
+ mask.astype(np.float32),
2538
+ (video_width, video_height),
2539
+ interpolation=cv2.INTER_NEAREST
2540
+ ).astype(np.uint8)
2541
+ else:
2542
+ mask_resized = mask
2543
+
2544
+ # Find bounding box
2545
+ rows, cols = np.where(mask_resized > 0)
2546
+ if len(rows) > 0 and len(cols) > 0:
2547
+ y_min, y_max = np.min(rows), np.max(rows)
2548
+ x_min, x_max = np.min(cols), np.max(cols)
2549
+
2550
+ # Extract mask within bbox
2551
+ bbox_mask = mask_resized[y_min:y_max+1, x_min:x_max+1]
2552
+
2553
+ obj = {
2554
+ 'bbox': (int(x_min), int(y_min), int(x_max), int(y_max)),
2555
+ 'mask': bbox_mask.astype(bool),
2556
+ 'obj_id': int(obj_id)
2557
+ }
2558
+ frame_objs.append(obj)
2559
+ frame_objects.append(frame_objs)
2560
+
2561
+ # Create mask data dictionary
2562
+ mask_data = {
2563
+ 'video_path': self.video_path,
2564
+ 'total_frames': num_frames,
2565
+ 'height': video_height,
2566
+ 'width': video_width,
2567
+ 'fps': fps,
2568
+ 'frame_objects': frame_objects,
2569
+ 'objects_per_frame': [len(objs) for objs in frame_objects],
2570
+ 'tracker': {},
2571
+ 'format': 'new',
2572
+ 'start_offset': save_start,
2573
+ 'original_total_frames': self.total_frames
2574
+ }
2575
+
2576
+ # Save to HDF5 file - use experiment folder if available
2577
+ experiment_path = self.config.get("experiment_path")
2578
+ if experiment_path and os.path.exists(experiment_path):
2579
+ masks_dir = os.path.join(experiment_path, "masks")
2580
+ else:
2581
+ from singlebehaviorlab._paths import USER_DATA_DIR
2582
+ masks_dir = str(USER_DATA_DIR / "data" / "masks")
2583
+ os.makedirs(masks_dir, exist_ok=True)
2584
+
2585
+ video_basename = os.path.splitext(os.path.basename(self.video_path))[0]
2586
+ # Remove "_masks" suffix if present to avoid duplication
2587
+ if video_basename.endswith("_masks"):
2588
+ video_basename = video_basename[:-6]
2589
+ mask_path = os.path.join(masks_dir, f"{video_basename}.h5")
2590
+
2591
+ from singlebehaviorlab.backend.video_processor import save_segmentation_data
2592
+ save_segmentation_data(mask_path, mask_data)
2593
+ return mask_path
2594
+
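+ # Reading the saved format back (a sketch; the loader counterpart of
+ # save_segmentation_data lives in singlebehaviorlab.backend.video_processor):
+ # each object stores a bbox plus a bbox-cropped boolean mask, so the
+ # full-frame mask is reconstructed as:
+ #   full = np.zeros((height, width), dtype=bool)
+ #   x_min, y_min, x_max, y_max = obj['bbox']
+ #   full[y_min:y_max + 1, x_min:x_max + 1] = obj['mask']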
2595
+ def _on_tracking_error(self, err):
2596
+ """Handle tracking error."""
2597
+ self.btn_track.setEnabled(True)
2598
+ self.btn_track_all.setEnabled(self.sam2_available and len(self.videos) > 0)
2599
+ self.btn_pause_tracking.setEnabled(False)
2600
+ self.btn_resume_tracking.setEnabled(False)
2601
+ self.progress_bar.setVisible(False)
2602
+ QMessageBox.critical(self, "Error", f"Tracking failed:\n{err}")
2603
+ self.inference_state = None
2604
+ # Stop batch mode on error
2605
+ self.batch_mode = False
2606
+ self.batch_queue = []
2607
+
2608
+ def _open_settings(self):
2609
+ """Open settings dialog."""
2610
+ from PyQt6.QtWidgets import QDialog, QFormLayout
2611
+
2612
+ dialog = QDialog(self)
2613
+ dialog.setWindowTitle("SAM2 Settings")
2614
+ dialog.resize(450, 550)
2615
+ layout = QFormLayout(dialog)
2616
+
2617
+ spin_threshold = QDoubleSpinBox()
2618
+ spin_threshold.setRange(-10.0, 10.0)
2619
+ spin_threshold.setSingleStep(0.1)
2620
+ spin_threshold.setValue(self.mask_threshold)
2621
+ layout.addRow("Mask Threshold:", spin_threshold)
2622
+
2623
+ spin_fill_hole = QSpinBox()
2624
+ spin_fill_hole.setRange(0, 10000)
2625
+ spin_fill_hole.setValue(self.fill_hole_area)
2626
+ layout.addRow("Fill Hole Area:", spin_fill_hole)
2627
+
2628
+ chk_offload_video = QCheckBox()
2629
+ chk_offload_video.setChecked(self.offload_video)
2630
+ layout.addRow("Offload Video to CPU:", chk_offload_video)
2631
+
2632
+ chk_offload_state = QCheckBox()
2633
+ chk_offload_state.setChecked(self.offload_state)
2634
+ layout.addRow("Offload State to CPU:", chk_offload_state)
2635
+
2636
+ chk_bf16_autocast = QCheckBox()
2637
+ chk_bf16_autocast.setChecked(self.use_cuda_bf16_autocast)
2638
+ chk_bf16_autocast.setToolTip(
2639
+ "Use CUDA bfloat16 autocast for SAM2 inference.\n"
2640
+ "Usually speeds up segmentation on newer NVIDIA GPUs.\n"
2641
+ "Ignored on CPU."
2642
+ )
2643
+ layout.addRow("Use CUDA bf16 autocast:", chk_bf16_autocast)
2644
+
2645
+ chk_memory_management = QCheckBox()
2646
+ chk_memory_management.setChecked(self.enable_memory_management)
2647
+ layout.addRow("Enable Memory Management:", chk_memory_management)
2648
+
2649
+ chk_reseed_chunks = QCheckBox()
2650
+ chk_reseed_chunks.setToolTip("When processing in chunks, re-seed the next chunk with the last frame mask as a mask prompt.")
2651
+ chk_reseed_chunks.setChecked(getattr(self, "reseed_between_chunks", False))
2652
+ layout.addRow("Re-seed each chunk with last mask:", chk_reseed_chunks)
2653
+
2654
+ layout.addRow(QLabel("<b>Motion-Aware Tracking</b>"))
2655
+
2656
+ chk_motion_tracking = QCheckBox()
2657
+ chk_motion_tracking.setToolTip(
2658
+ "Enable motion-aware tracking:\n"
2659
+ "- Uses Kalman filter to predict object motion\n"
2660
+ "- Scores each frame by mask quality and motion consistency\n"
2661
+ "- Filters low-quality frames from memory to prevent drift\n"
2662
+ "Requires: pip install filterpy"
2663
+ )
2664
+ chk_motion_tracking.setChecked(getattr(self, "enable_motion_tracking", False))
2665
+ layout.addRow("Enable motion-aware tracking:", chk_motion_tracking)
2666
+
2667
+ spin_motion_threshold = QDoubleSpinBox()
2668
+ spin_motion_threshold.setRange(0.0, 1.0)
2669
+ spin_motion_threshold.setSingleStep(0.05)
2670
+ spin_motion_threshold.setValue(getattr(self, "motion_score_threshold", 0.3))
2671
+ spin_motion_threshold.setToolTip(
2672
+ "Minimum score for a frame to be used in memory.\n"
2673
+ "Lower = more permissive, Higher = stricter filtering.\n"
2674
+ "Score combines mask confidence and motion IoU."
2675
+ )
2676
+ layout.addRow("Motion score threshold:", spin_motion_threshold)
2677
+
2678
+ spin_consecutive_low = QSpinBox()
2679
+ spin_consecutive_low.setRange(1, 20)
2680
+ spin_consecutive_low.setValue(getattr(self, "motion_consecutive_low", 3))
2681
+ spin_consecutive_low.setToolTip(
2682
+ "Number of consecutive low-score frames before auto-correction.\n"
2683
+ "Lower = faster correction but more sensitive.\n"
2684
+ "Higher = more tolerant but slower to correct drift."
2685
+ )
2686
+ layout.addRow("Frames before auto-correct:", spin_consecutive_low)
2687
+
2688
+ spin_area_threshold = QDoubleSpinBox()
2689
+ spin_area_threshold.setRange(0.1, 2.0)
2690
+ spin_area_threshold.setSingleStep(0.1)
2691
+ spin_area_threshold.setValue(getattr(self, "motion_area_threshold", 0.5))
2692
+ spin_area_threshold.setToolTip(
2693
+ "Max allowed mask area change ratio.\n"
2694
+ "0.5 = mask can shrink/grow by 50% max.\n"
2695
+ "Lower = stricter, Higher = more permissive."
2696
+ )
2697
+ layout.addRow("Area change tolerance:", spin_area_threshold)
2698
+
2699
+ layout.addRow(QLabel("<b>OC-SORT Drift Correction</b>"))
2700
+
2701
+ chk_ocsort = QCheckBox()
2702
+ chk_ocsort.setToolTip(
2703
+ "Enable OC-SORT enhancements for drift correction:\n\n"
2704
+ "-Virtual Trajectory: During occlusions, maintains tracking\n"
2705
+ " using predicted motion (prevents state collapse)\n\n"
2706
+ "-ORU (Observation-Centric Re-Update): When object reappears,\n"
2707
+ " corrects accumulated drift by re-estimating past states\n\n"
2708
+ "Based on: 'Observation-Centric SORT' (arXiv:2203.14360)"
2709
+ )
2710
+ chk_ocsort.setChecked(getattr(self, "enable_ocsort", False))
2711
+ layout.addRow("Enable OC-SORT drift correction:", chk_ocsort)
2712
+
2713
+ spin_ocsort_inertia = QDoubleSpinBox()
2714
+ spin_ocsort_inertia.setRange(0.0, 1.0)
2715
+ spin_ocsort_inertia.setSingleStep(0.05)
2716
+ spin_ocsort_inertia.setValue(getattr(self, "ocsort_inertia", 0.2))
2717
+ spin_ocsort_inertia.setToolTip(
2718
+ "Velocity smoothing factor for ORU (paper default: 0.2).\n\n"
2719
+ "When object reappears after occlusion, this blends\n"
2720
+ "old velocity with newly computed velocity:\n"
2721
+ " smoothed = inertia * old_vel + (1-inertia) * new_vel\n\n"
2722
+ "Higher = more momentum, smoother but slower correction.\n"
2723
+ "Lower = faster correction but may be jerky."
2724
+ )
2725
+ layout.addRow("ORU inertia (velocity smoothing):", spin_ocsort_inertia)
2726
+
2727
+ btn_ok = QPushButton("OK")
2728
+ btn_ok.clicked.connect(dialog.accept)
2729
+ layout.addRow(btn_ok)
2730
+
2731
+ if dialog.exec():
2732
+ self.mask_threshold = spin_threshold.value()
2733
+ self.fill_hole_area = spin_fill_hole.value()
2734
+ self.offload_video = chk_offload_video.isChecked()
2735
+ self.offload_state = chk_offload_state.isChecked()
2736
+ self.use_cuda_bf16_autocast = chk_bf16_autocast.isChecked()
2737
+ self.enable_memory_management = chk_memory_management.isChecked()
2738
+ self.reseed_between_chunks = chk_reseed_chunks.isChecked()
2739
+ self.enable_motion_tracking = chk_motion_tracking.isChecked()
2740
+ self.motion_score_threshold = spin_motion_threshold.value()
2741
+ self.motion_consecutive_low = spin_consecutive_low.value()
2742
+ self.motion_area_threshold = spin_area_threshold.value()
2743
+ self.enable_ocsort = chk_ocsort.isChecked()
2744
+ self.ocsort_inertia = spin_ocsort_inertia.value()
2745
+
2746
+ if self.predictor:
2747
+ self.predictor.fill_hole_area = self.fill_hole_area
2748
+
2749
+ def update_config(self, config: dict):
2750
+ """Update configuration."""
2751
+ self.config = config
2752
+