singlebehaviorlab-2.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sam2/__init__.py +11 -0
- sam2/automatic_mask_generator.py +454 -0
- sam2/benchmark.py +92 -0
- sam2/build_sam.py +174 -0
- sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
- sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
- sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
- sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
- sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
- sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
- sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
- sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
- sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
- sam2/modeling/__init__.py +5 -0
- sam2/modeling/backbones/__init__.py +5 -0
- sam2/modeling/backbones/hieradet.py +317 -0
- sam2/modeling/backbones/image_encoder.py +134 -0
- sam2/modeling/backbones/utils.py +93 -0
- sam2/modeling/memory_attention.py +169 -0
- sam2/modeling/memory_encoder.py +181 -0
- sam2/modeling/position_encoding.py +239 -0
- sam2/modeling/sam/__init__.py +5 -0
- sam2/modeling/sam/mask_decoder.py +295 -0
- sam2/modeling/sam/prompt_encoder.py +202 -0
- sam2/modeling/sam/transformer.py +311 -0
- sam2/modeling/sam2_base.py +913 -0
- sam2/modeling/sam2_utils.py +323 -0
- sam2/sam2_hiera_b+.yaml +113 -0
- sam2/sam2_hiera_l.yaml +117 -0
- sam2/sam2_hiera_s.yaml +116 -0
- sam2/sam2_hiera_t.yaml +118 -0
- sam2/sam2_image_predictor.py +466 -0
- sam2/sam2_video_predictor.py +1388 -0
- sam2/sam2_video_predictor_legacy.py +1172 -0
- sam2/utils/__init__.py +5 -0
- sam2/utils/amg.py +348 -0
- sam2/utils/misc.py +349 -0
- sam2/utils/transforms.py +118 -0
- singlebehaviorlab/__init__.py +4 -0
- singlebehaviorlab/__main__.py +130 -0
- singlebehaviorlab/_paths.py +100 -0
- singlebehaviorlab/backend/__init__.py +2 -0
- singlebehaviorlab/backend/augmentations.py +320 -0
- singlebehaviorlab/backend/data_store.py +420 -0
- singlebehaviorlab/backend/model.py +1290 -0
- singlebehaviorlab/backend/train.py +4667 -0
- singlebehaviorlab/backend/uncertainty.py +578 -0
- singlebehaviorlab/backend/video_processor.py +688 -0
- singlebehaviorlab/backend/video_utils.py +139 -0
- singlebehaviorlab/data/config/config.yaml +85 -0
- singlebehaviorlab/data/training_profiles.json +334 -0
- singlebehaviorlab/gui/__init__.py +4 -0
- singlebehaviorlab/gui/analysis_widget.py +2291 -0
- singlebehaviorlab/gui/attention_export.py +311 -0
- singlebehaviorlab/gui/clip_extraction_widget.py +481 -0
- singlebehaviorlab/gui/clustering_widget.py +3187 -0
- singlebehaviorlab/gui/inference_popups.py +1138 -0
- singlebehaviorlab/gui/inference_widget.py +4550 -0
- singlebehaviorlab/gui/inference_worker.py +651 -0
- singlebehaviorlab/gui/labeling_widget.py +2324 -0
- singlebehaviorlab/gui/main_window.py +754 -0
- singlebehaviorlab/gui/metadata_management_widget.py +1119 -0
- singlebehaviorlab/gui/motion_tracking.py +764 -0
- singlebehaviorlab/gui/overlay_export.py +1234 -0
- singlebehaviorlab/gui/plot_integration.py +729 -0
- singlebehaviorlab/gui/qt_helpers.py +29 -0
- singlebehaviorlab/gui/registration_widget.py +1485 -0
- singlebehaviorlab/gui/review_widget.py +1330 -0
- singlebehaviorlab/gui/segmentation_tracking_widget.py +2752 -0
- singlebehaviorlab/gui/tab_tutorial_dialog.py +312 -0
- singlebehaviorlab/gui/timeline_themes.py +131 -0
- singlebehaviorlab/gui/training_profiles.py +418 -0
- singlebehaviorlab/gui/training_widget.py +3719 -0
- singlebehaviorlab/gui/video_utils.py +233 -0
- singlebehaviorlab/licenses/SAM2-LICENSE +201 -0
- singlebehaviorlab/licenses/VideoPrism-LICENSE +202 -0
- singlebehaviorlab-2.0.0.dist-info/METADATA +447 -0
- singlebehaviorlab-2.0.0.dist-info/RECORD +88 -0
- singlebehaviorlab-2.0.0.dist-info/WHEEL +5 -0
- singlebehaviorlab-2.0.0.dist-info/entry_points.txt +2 -0
- singlebehaviorlab-2.0.0.dist-info/licenses/LICENSE +21 -0
- singlebehaviorlab-2.0.0.dist-info/top_level.txt +3 -0
- videoprism/__init__.py +0 -0
- videoprism/encoders.py +910 -0
- videoprism/layers.py +1136 -0
- videoprism/models.py +407 -0
- videoprism/tokenizers.py +167 -0
- videoprism/utils.py +168 -0
singlebehaviorlab/gui/segmentation_tracking_widget.py
@@ -0,0 +1,2752 @@
+import sys
+import os
+import gc
+import logging
+import cv2
+import numpy as np
+import torch
+
+logger = logging.getLogger(__name__)
+from PyQt6.QtWidgets import (
+    QWidget, QVBoxLayout, QHBoxLayout, QPushButton, QLabel, QFileDialog,
+    QRadioButton, QSlider, QButtonGroup, QMessageBox, QProgressBar,
+    QComboBox, QDoubleSpinBox, QSpinBox, QFormLayout, QCheckBox, QGroupBox,
+    QSizePolicy, QListWidget, QScrollArea, QProgressDialog, QApplication
+)
+from PyQt6.QtCore import Qt, QTimer, QThread, pyqtSignal, QPointF, QEvent
+from PyQt6.QtGui import QImage, QPixmap, QPainter, QColor, QPen, QBrush
+import shutil
+from pathlib import Path
+from contextlib import nullcontext
+from importlib import metadata as importlib_metadata
+
+# Motion tracking (Kalman filter, OC-SORT) in separate module
+from .motion_tracking import MultiObjectMotionTracker
+
+
+# Colors for different objects (R, G, B)
+OBJ_COLORS = [
+    (0, 255, 0),     # 1: Green
+    (255, 0, 0),     # 2: Red
+    (0, 0, 255),     # 3: Blue
+    (255, 255, 0),   # 4: Yellow
+    (0, 255, 255),   # 5: Cyan
+    (255, 0, 255),   # 6: Magenta
+    (255, 128, 0),   # 7: Orange
+    (128, 0, 255),   # 8: Purple
+    (128, 128, 0),   # 9: Olive
+    (0, 128, 128),   # 10: Teal
+]
+
+def get_obj_color(obj_id):
+    idx = (obj_id - 1) % len(OBJ_COLORS)
+    return OBJ_COLORS[idx]
+
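+# Colors cycle when there are more objects than entries:
+# get_obj_color(1) == get_obj_color(11) == (0, 255, 0).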
+
+class CheckpointDownloadWorker(QThread):
+    """Worker thread for downloading SAM2 checkpoints."""
+    progress = pyqtSignal(str)
+    finished = pyqtSignal(bool, str)
+
+    # Model URLs (SAM 2.1)
+    MODEL_URLS = {
+        "sam2.1_hiera_tiny.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt",
+        "sam2.1_hiera_small.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt",
+        "sam2.1_hiera_base_plus.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt",
+        "sam2.1_hiera_large.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt",
+        # SAM 2.0 (older versions)
+        "sam2_hiera_tiny.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt",
+        "sam2_hiera_small.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt",
+        "sam2_hiera_base_plus.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt",
+        "sam2_hiera_large.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt",
+    }
+
+    def __init__(self, checkpoint_name, checkpoint_path, checkpoint_url):
+        super().__init__()
+        self.checkpoint_name = checkpoint_name
+        self.checkpoint_path = checkpoint_path
+        self.checkpoint_url = checkpoint_url
+
+    def run(self):
+        try:
+            # Check if already downloaded
+            if os.path.exists(self.checkpoint_path):
+                file_size = os.path.getsize(self.checkpoint_path) / (1024**2)  # MB
+                if file_size > 10:  # Reasonable size check (should be >100MB)
+                    self.finished.emit(True, f"Checkpoint already exists ({file_size:.1f} MB)")
+                    return
+
+            self.progress.emit(f"Downloading {self.checkpoint_name}...")
+            self.progress.emit(f"URL: {self.checkpoint_url}")
+
+            # Download via urllib with a progress callback
+            import urllib.request
+
+            def show_progress(block_num, block_size, total_size):
+                if total_size > 0:
+                    percent = min(100, (block_num * block_size * 100) / total_size)
+                    self.progress.emit(f"Downloading {self.checkpoint_name}: {percent:.1f}%")
+
+            # Download with progress
+            urllib.request.urlretrieve(
+                self.checkpoint_url,
+                self.checkpoint_path,
+                reporthook=show_progress
+            )
+
+            # Verify download
+            if os.path.exists(self.checkpoint_path):
+                file_size = os.path.getsize(self.checkpoint_path) / (1024**2)
+                if file_size < 10:  # Suspiciously small
+                    os.remove(self.checkpoint_path)
+                    raise Exception(f"Downloaded file seems too small ({file_size:.1f} MB). Download may have failed.")
+                self.finished.emit(True, f"Downloaded successfully ({file_size:.1f} MB)")
+            else:
+                raise Exception("Download completed but file not found")
+
+        except Exception as e:
+            if os.path.exists(self.checkpoint_path):
+                try:
+                    os.remove(self.checkpoint_path)
+                except Exception:
+                    pass
+            self.finished.emit(False, f"Download failed: {str(e)}")
+
+
+class TrackingWorker(QThread):
+    """Worker thread for running tracking."""
+    progress_signal = pyqtSignal(int)
+    frame_result_signal = pyqtSignal(int, dict)  # frame_idx, {obj_id: mask}
+    finished_signal = pyqtSignal(dict)
+    error_signal = pyqtSignal(str)
+    log_message = pyqtSignal(str)
+
+    def __init__(self, predictor, video_path, user_points, start_frame, end_frame,
+                 mask_threshold=0.0, offload_video=True, offload_state=True,
+                 enable_memory_management=True, reseed_between_chunks=False,
+                 initial_masks=None, enable_motion_tracking=False,
+                 motion_score_threshold=0.3, motion_consecutive_low=3,
+                 motion_area_threshold=0.5, enable_ocsort=False,
+                 ocsort_inertia=0.2, use_cuda_bf16_autocast=True):
+        super().__init__()
+        self.predictor = predictor
+        self.video_path = video_path
+        self.user_points = user_points  # (frame_idx, obj_id) -> {'points': [], 'labels': []}
+        self.start_frame = start_frame
+        self.end_frame = end_frame
+        self.mask_threshold = mask_threshold
+        self.offload_video = offload_video
+        self.offload_state = offload_state
+        self.enable_memory_management = enable_memory_management
+        self.reseed_between_chunks = reseed_between_chunks
+        self.initial_masks = initial_masks or {}  # {(frame_idx, obj_id): mask_array} for resume conditioning
+        self.chunk_size = 200
+        self.should_stop = False
+        self.use_cuda_bf16_autocast = bool(use_cuda_bf16_autocast)
+
+        # Motion-aware tracking
+        self.enable_motion_tracking = enable_motion_tracking
+        self.motion_score_threshold = motion_score_threshold
+        self.motion_tracker = None
+        if enable_motion_tracking:
+            self.motion_tracker = MultiObjectMotionTracker(
+                motion_score_threshold=motion_score_threshold,
+                use_kalman=True,
+                consecutive_low_threshold=motion_consecutive_low,
+                area_change_threshold=motion_area_threshold,
+                use_ocsort=enable_ocsort,
+                ocsort_inertia=ocsort_inertia
+            )
+
+    def _use_cuda_bf16(self):
+        """Use bf16 autocast only when SAM2 runs on CUDA."""
+        dev = getattr(self.predictor, "device", None)
+        dev_type = getattr(dev, "type", str(dev))
+        return bool(
+            self.use_cuda_bf16_autocast
+            and torch.cuda.is_available()
+            and dev_type == "cuda"
+        )
+
+    def _sam2_autocast(self):
+        if self._use_cuda_bf16():
+            return torch.autocast(device_type="cuda", dtype=torch.bfloat16)
+        return nullcontext()
+
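+    # On non-CUDA devices _sam2_autocast() resolves to nullcontext(), so the
+    # same call sites run unchanged on CPU or MPS.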
+    def _sam2_call(self, fn, *args, **kwargs):
+        with self._sam2_autocast():
+            return fn(*args, **kwargs)
+
+    def stop(self):
+        """Request tracking stop."""
+        self.should_stop = True
+
+    def run(self):
+        """Run tracking with incremental processing."""
+        try:
+            all_video_segments = {}  # global_frame_idx -> {obj_id: mask}
+            MAX_MASKS_IN_MEMORY = 500  # Keep only recent masks, older ones are already emitted via signal
+            last_masks_for_reseed = None  # Store last frame masks of previous chunk for optional reseed
+
+            try:
+                import decord
+            except ImportError:
+                raise ImportError("decord not found. Please install it: pip install eva-decord")
+
+            from collections import OrderedDict
+
+            def load_frames(start, end):
+                decord.bridge.set_bridge("torch")
+                image_size = self.predictor.image_size
+                vr = decord.VideoReader(self.video_path, width=image_size, height=image_size)
+                target_dtype = getattr(self.predictor, "dtype", torch.float32)
+                if self._use_cuda_bf16() and not self.offload_video:
+                    target_dtype = torch.bfloat16
+
+                if end > len(vr):
+                    end = len(vr)
+                indices = list(range(start, end))
+                frames = vr.get_batch(indices)
+                del vr  # Free VideoReader memory after loading frames
+                images = frames.permute(0, 3, 1, 2).float() / 255.0
+                del frames  # Free original frame tensor after processing
+
+                img_mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32)[:, None, None]
+                img_std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32)[:, None, None]
+
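+                # (0.485, 0.456, 0.406) and (0.229, 0.224, 0.225) are the
+                # standard ImageNet channel statistics, the same normalization
+                # SAM2's own frame loader applies.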
+                if not self.offload_video:
+                    images = images.to(self.predictor.device, dtype=target_dtype)
+                    img_mean = img_mean.to(self.predictor.device, dtype=target_dtype)
+                    img_std = img_std.to(self.predictor.device, dtype=target_dtype)
+                else:
+                    # Keep on CPU but ensure dtype matches model expectations
+                    images = images.to(dtype=target_dtype)
+                    img_mean = img_mean.to(dtype=target_dtype)
+                    img_std = img_std.to(dtype=target_dtype)
+
+                images -= img_mean
+                images /= img_std
+                return images
+
+            def is_cuda_alloc_error(exc):
+                msg = str(exc).lower()
+                return (
+                    "cuda out of memory" in msg
+                    or "cublas_status_alloc_failed" in msg
+                    or ("cuda error" in msg and "alloc" in msg)
+                )
+
+            def run_with_cuda_retry(op_name, fn):
+                try:
+                    return fn()
+                except RuntimeError as e:
+                    if not is_cuda_alloc_error(e):
+                        raise
+                    self.log_message.emit(
+                        f"[GPU] Memory error during {op_name}. Clearing cache and retrying once..."
+                    )
+                    if torch.cuda.is_available():
+                        torch.cuda.empty_cache()
+                    gc.collect()
+                    try:
+                        return fn()
+                    except RuntimeError as e2:
+                        if is_cuda_alloc_error(e2):
+                            raise RuntimeError(
+                                "GPU memory allocation failed while running SAM2. "
+                                "Try one or more: enable Offload Video to CPU, enable Offload State to CPU, "
+                                "use a smaller SAM2 model, or track a shorter range."
+                            ) from e2
+                        raise
+
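+            # Note: run_with_cuda_retry invokes its fn argument immediately,
+            # so the lambdas passed to it below may safely capture loop
+            # variables; no call is deferred past the current iteration.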
+            # Initialize with first chunk
+            current_end = min(self.start_frame + self.chunk_size, self.end_frame)
+            self.log_message.emit(f"Initializing with chunk: {self.start_frame} to {current_end}")
+
+            images = load_frames(self.start_frame, current_end)
+
+            # Get original dimensions
+            vr_meta = decord.VideoReader(self.video_path)
+            vh, vw, _ = vr_meta[0].shape
+            del vr_meta  # Free memory immediately after getting dimensions
+
+            images_list = [images[i] for i in range(len(images))]
+
+            inference_state = {}
+            inference_state["images"] = images_list
+            inference_state["num_frames"] = len(images_list)
+            inference_state["offload_video_to_cpu"] = self.offload_video
+            inference_state["offload_state_to_cpu"] = self.offload_state
+            inference_state["video_height"] = vh
+            inference_state["video_width"] = vw
+            inference_state["device"] = self.predictor.device
+            inference_state["storage_device"] = torch.device("cpu") if self.offload_state else self.predictor.device
+            inference_state["point_inputs_per_obj"] = {}
+            inference_state["mask_inputs_per_obj"] = {}
+            inference_state["cached_features"] = {}
+            inference_state["constants"] = {}
+            inference_state["obj_id_to_idx"] = OrderedDict()
+            inference_state["obj_idx_to_id"] = OrderedDict()
+            inference_state["obj_ids"] = []
+            inference_state["output_dict_per_obj"] = {}
+            inference_state["temp_output_dict_per_obj"] = {}
+            inference_state["frames_tracked_per_obj"] = {}
+
+            # Warm up
+            try:
+                self._sam2_call(self.predictor._get_image_feature, inference_state, frame_idx=0, batch_size=1)
+            except Exception:
+                pass
+
+            self.predictor.reset_state(inference_state)
+
+            # Loop through chunks with sliding window memory management.
+            # Track the global offset (how many frames we've trimmed from the start);
+            # it is only used when memory management is enabled.
+            global_offset = 0  # Tracks how many frames we've dropped from the front
+            MAX_FRAMES_IN_MEMORY = 800  # Keep ~800 frames in memory
+
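+            # Rough budget: at the default 1024x1024 tracking resolution one
+            # float32 frame is 1024*1024*3*4 bytes (~12 MB), so an 800-frame
+            # buffer is on the order of 9-10 GB (half that in bf16), hence the
+            # CPU offload options.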
+            processed_up_to = self.start_frame
+
+            while processed_up_to < self.end_frame:
+                if self.should_stop:
+                    break
+
+                # inference_state["images"] grows with each processed frame;
+                # frames correspond to self.start_frame + index. When memory
+                # management is disabled, global_offset stays at 0.
+                buffer_start = self.start_frame + (global_offset if self.enable_memory_management else 0)
+
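+                # Example: start_frame=0 with global_offset=600 gives
+                # buffer_start=600, so global frame 650 maps to local index 50
+                # in the trimmed buffer.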
+                chunk_start = processed_up_to
+                chunk_end = buffer_start + inference_state["num_frames"]  # End of current buffer
+
+                self.log_message.emit(f"Processing range: {chunk_start} to {chunk_end} (buffer: {buffer_start} to {chunk_end})")
+
+                for (frame_idx, obj_id), data in self.user_points.items():
+                    if chunk_start <= frame_idx < chunk_end:
+                        # Local index relative to current buffer (after trimming)
+                        local_idx = frame_idx - buffer_start
+                        if local_idx < 0 or local_idx >= inference_state["num_frames"]:
+                            # Frame was trimmed, skip (shouldn't happen if logic is correct)
+                            continue
+                        pts = np.array(data['points'], dtype=np.float32)
+                        lbls = np.array(data['labels'], dtype=np.int32)
+
+                        run_with_cuda_retry(
+                            "add_new_points_or_box",
+                            lambda: self._sam2_call(self.predictor.add_new_points_or_box,
+                                                    inference_state=inference_state,
+                                                    frame_idx=local_idx,
+                                                    obj_id=obj_id,
+                                                    points=pts,
+                                                    labels=lbls,
+                                                    normalize_coords=True,
+                                                    ),
+                        )
+
+                # Inject initial masks (e.g., from pause/resume refinement)
+                for (frame_idx, obj_id), mask in self.initial_masks.items():
+                    if chunk_start <= frame_idx < chunk_end:
+                        local_idx = frame_idx - buffer_start
+                        if local_idx < 0 or local_idx >= inference_state["num_frames"]:
+                            continue
+                        try:
+                            # Resize mask to video dimensions if needed
+                            vh = inference_state["video_height"]
+                            vw = inference_state["video_width"]
+                            if mask.shape[0] != vh or mask.shape[1] != vw:
+                                mask_resized = cv2.resize(mask.astype(np.float32), (vw, vh), interpolation=cv2.INTER_NEAREST)
+                                mask = (mask_resized > 0.5).astype(np.uint8)
+
+                            run_with_cuda_retry(
+                                "add_new_mask",
+                                lambda: self._sam2_call(self.predictor.add_new_mask,
+                                                        inference_state=inference_state,
+                                                        frame_idx=local_idx,
+                                                        obj_id=obj_id,
+                                                        mask=mask.astype(bool),
+                                                        ),
+                            )
+                            self.log_message.emit(f"Injected refined mask for obj {obj_id} at frame {frame_idx}")
+                        except Exception as e:
+                            self.log_message.emit(f"Warning: Could not inject mask: {e}")
+
+                # Propagate from chunk_start.
+                # We need the local index for the propagation start (relative to the current buffer).
+                prop_start_local = chunk_start - buffer_start
+                if prop_start_local < 0:
+                    prop_start_local = 0  # Can't propagate from before buffer start
+
+                # Memory trimming may drop the initial conditioning frame (the
+                # user's first click). The bundled SAM2 fork modifies
+                # propagate_in_video_preflight to allow propagation when only
+                # tracking history (non_cond_frame_outputs) is present, so
+                # explicit mask re-injection is not required here.
+
+                with self._sam2_autocast():
+                    for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(
+                        inference_state,
+                        start_frame_idx=prop_start_local
+                    ):
+                        if self.should_stop:
+                            break
+
+                        # Convert local buffer index to global frame index
+                        global_idx = out_frame_idx + buffer_start
+
+                        if global_idx not in all_video_segments:
+                            all_video_segments[global_idx] = {}
+
+                        frame_scores = {}  # Track scores for this frame per object
+                        low_quality_objects = []  # Objects with sustained low scores this frame
+
+                        for i, o_id in enumerate(out_obj_ids):
+                            mask_logit = out_mask_logits[i]
+                            if mask_logit.ndim == 3:
+                                mask_logit = mask_logit[0]
+                            mask = (mask_logit > self.mask_threshold).cpu().numpy().astype(np.uint8).squeeze()
+                            all_video_segments[global_idx][o_id] = mask
+
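+                            # (With the default mask_threshold of 0.0, the
+                            # binarization above keeps pixels with positive
+                            # logits, i.e. predicted foreground probability
+                            # above 0.5.)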
+                            # Motion-aware scoring
+                            if self.motion_tracker is not None:
+                                score, should_use = self.motion_tracker.update(
+                                    o_id, mask, mask_logit, global_idx
+                                )
+                                frame_scores[o_id] = score
+
+                                # Only filter SAM2 memory after sustained low quality.
+                                # This avoids dropping useful memory on one-frame glitches.
+                                if (not should_use) and self.motion_tracker.check_needs_correction(o_id):
+                                    low_quality_objects.append(o_id)
+
+                        # Log motion tracking info periodically
+                        if self.motion_tracker is not None and global_idx % 50 == 0:
+                            score_str = ", ".join([f"obj{k}:{v:.2f}" for k, v in frame_scores.items()])
+                            self.log_message.emit(f"Frame {global_idx} scores: {score_str}")
+
+                        # Memory filtering: remove low-quality frames from memory
+                        if self.motion_tracker is not None and low_quality_objects:
+                            for obj_id in low_quality_objects:
+                                obj_idx = inference_state.get("obj_id_to_idx", {}).get(obj_id)
+                                if obj_idx is not None:
+                                    # Recency-weighted memory filtering:
+                                    # keep recent memory frames and remove older low-score frames first.
+                                    obj_output = inference_state.get("output_dict_per_obj", {}).get(obj_idx, {})
+                                    non_cond = obj_output.get("non_cond_frame_outputs", {})
+                                    if not non_cond:
+                                        continue
+                                    threshold = self.motion_tracker.get_effective_threshold(obj_id)
+                                    recent_keep = 6
+                                    max_remove = 2
+                                    removed = 0
+                                    old_keys = sorted(
+                                        k for k in non_cond.keys() if k < (out_frame_idx - recent_keep)
+                                    )
+                                    for mem_local_idx in old_keys:
+                                        if removed >= max_remove:
+                                            break
+                                        mem_global_idx = mem_local_idx + buffer_start
+                                        mem_score = self.motion_tracker.get_frame_score(obj_id, mem_global_idx)
+                                        if mem_score is not None and mem_score < threshold:
+                                            del non_cond[mem_local_idx]
+                                            removed += 1
+
+                                    # Fallback if no older candidate was removable.
+                                    if removed == 0 and out_frame_idx in non_cond:
+                                        del non_cond[out_frame_idx]
+
+                        # Appearance memory re-seed: when an object recovers from long occlusion,
+                        # inject the golden mask so SAM2 remembers what the object looked like.
+                        if self.motion_tracker is not None and self.motion_tracker.appearance_memory is not None:
+                            for o_id in out_obj_ids:
+                                amem = self.motion_tracker.appearance_memory
+                                if amem.is_recovery_pending(o_id):
+                                    golden_mask = amem.pop_reseed_mask(o_id)
+                                    if golden_mask is not None:
+                                        try:
+                                            # Inject at the current frame so SAM2 uses the golden
+                                            # mask immediately (not delayed by one frame).
+                                            reseed_local = out_frame_idx
+                                            i_vh = inference_state["video_height"]
+                                            i_vw = inference_state["video_width"]
+                                            gm = golden_mask
+                                            if gm.shape[0] != i_vh or gm.shape[1] != i_vw:
+                                                gm = cv2.resize(
+                                                    gm.astype(np.float32), (i_vw, i_vh),
+                                                    interpolation=cv2.INTER_NEAREST
+                                                )
+                                                gm = (gm > 0.5).astype(np.uint8)
+                                            run_with_cuda_retry(
+                                                "appearance_reseed_add_new_mask",
+                                                lambda: self._sam2_call(
+                                                    self.predictor.add_new_mask,
+                                                    inference_state=inference_state,
+                                                    frame_idx=reseed_local,
+                                                    obj_id=o_id,
+                                                    mask=gm.astype(bool),
+                                                ),
+                                            )
+                                            self.log_message.emit(
+                                                f"[AppearanceMemory] Re-seeded obj {o_id} at frame {global_idx} with golden mask"
+                                            )
+                                        except Exception as e:
+                                            self.log_message.emit(f"[AppearanceMemory] Re-seed failed for obj {o_id}: {e}")
+
+                        # Automatic prompt injection: when drift is detected, inject the predicted bbox
+                        if self.motion_tracker is not None:
+                            for o_id in out_obj_ids:
+                                if self.motion_tracker.check_needs_correction(o_id):
+                                    # Get Kalman-predicted bbox
+                                    pred_bbox = self.motion_tracker.get_predicted_bbox_for_correction(o_id)
+                                    if pred_bbox is not None:
+                                        if not self.motion_tracker.is_correction_bbox_sane(o_id, pred_bbox):
+                                            self.log_message.emit(
+                                                f"[Motion] Skipped correction for obj {o_id}: jump/scale too large"
+                                            )
+                                            self.motion_tracker.reset_correction_flag(o_id)
+                                            continue
+                                        try:
+                                            # Inject predicted bbox as a new prompt for the next frame
+                                            next_local_idx = out_frame_idx + 1
+                                            if next_local_idx < inference_state["num_frames"]:
+                                                run_with_cuda_retry(
+                                                    "motion_correction_add_new_points_or_box",
+                                                    lambda: self._sam2_call(
+                                                        self.predictor.add_new_points_or_box,
+                                                        inference_state,
+                                                        frame_idx=next_local_idx,
+                                                        obj_id=o_id,
+                                                        box=pred_bbox,
+                                                        clear_old_points=True,
+                                                        normalize_coords=False,
+                                                    ),
+                                                )
+                                                self.log_message.emit(
+                                                    f"[Motion] Injected correction for obj {o_id} at frame {global_idx+1}"
+                                                )
+                                                self.motion_tracker.reset_correction_flag(o_id)
+                                        except Exception as e:
+                                            self.log_message.emit(f"Correction failed: {e}")
+
+                        self.progress_signal.emit(global_idx)
+                        # Emit real-time result for this frame
+                        self.frame_result_signal.emit(global_idx, all_video_segments[global_idx])
+
+                        # Clear old masks from memory (they're already emitted to the main thread).
+                        # Keep only the most recent MAX_MASKS_IN_MEMORY masks for the final emit.
+                        if len(all_video_segments) > MAX_MASKS_IN_MEMORY:
+                            oldest_frame = min(all_video_segments.keys())
+                            del all_video_segments[oldest_frame]
+
+                        # Periodically clear the CUDA cache (every 100 frames) to prevent accumulation
+                        if global_idx % 100 == 0 and torch.cuda.is_available():
+                            torch.cuda.empty_cache()
+
+                processed_up_to = chunk_end
+
+                # Clear CUDA cache and run garbage collection after each chunk
+                if torch.cuda.is_available():
+                    torch.cuda.empty_cache()
+                gc.collect()
+
+                # MEMORY MANAGEMENT: trim old frames if we exceed MAX_FRAMES_IN_MEMORY.
+                # Only applied when memory management is enabled.
+                if self.enable_memory_management and inference_state["num_frames"] > MAX_FRAMES_IN_MEMORY:
+                    frames_to_trim = inference_state["num_frames"] - MAX_FRAMES_IN_MEMORY
+                    self.log_message.emit(f"Trimming {frames_to_trim} old frames from memory (keeping last {MAX_FRAMES_IN_MEMORY} frames)...")
+
+                    # 1. Trim images list (keep last MAX_FRAMES_IN_MEMORY frames).
+                    # Since images is a list, slicing only copies references,
+                    # not the frame tensors themselves.
+                    inference_state["images"] = inference_state["images"][-MAX_FRAMES_IN_MEMORY:]
+                    inference_state["num_frames"] = len(inference_state["images"])
+
+                    # 2. Update global offset
+                    global_offset += frames_to_trim
+
+                    # 3. Shift all indices in inference_state dictionaries
+                    def shift_dict_keys(d, shift):
+                        """Shift dictionary keys by subtracting shift, removing negative keys."""
+                        new_d = {}
+                        for k, v in d.items():
+                            new_k = k - shift
+                            if new_k >= 0:  # Only keep non-negative keys (frames still in buffer)
+                                new_d[new_k] = v
+                        return new_d
+
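+                    # e.g. shift_dict_keys({2: a, 150: b}, 100) -> {50: b};
+                    # entries shifted below zero have left the buffer and are
+                    # dropped, releasing their tensors.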
+                    # Shift cached features
+                    inference_state["cached_features"] = shift_dict_keys(inference_state["cached_features"], frames_to_trim)
+
+                    # Shift per-object dictionaries
+                    for obj_idx in list(inference_state["point_inputs_per_obj"].keys()):
+                        inference_state["point_inputs_per_obj"][obj_idx] = shift_dict_keys(
+                            inference_state["point_inputs_per_obj"][obj_idx], frames_to_trim
+                        )
+                        inference_state["mask_inputs_per_obj"][obj_idx] = shift_dict_keys(
+                            inference_state["mask_inputs_per_obj"][obj_idx], frames_to_trim
+                        )
+
+                        # Shift output dicts (keep conditioning frames if they're still in range)
+                        obj_output = inference_state["output_dict_per_obj"][obj_idx]
+                        obj_output["cond_frame_outputs"] = shift_dict_keys(
+                            obj_output["cond_frame_outputs"], frames_to_trim
+                        )
+                        # SAM2's memory bank only requires the last num_maskmem
+                        # non_cond frames, but all non_cond frames still inside
+                        # the (already trimmed) buffer are retained.
+                        obj_output["non_cond_frame_outputs"] = shift_dict_keys(
+                            obj_output["non_cond_frame_outputs"], frames_to_trim
+                        )
+
+                        obj_temp = inference_state["temp_output_dict_per_obj"][obj_idx]
+                        obj_temp["cond_frame_outputs"] = shift_dict_keys(
+                            obj_temp["cond_frame_outputs"], frames_to_trim
+                        )
+                        obj_temp["non_cond_frame_outputs"] = shift_dict_keys(
+                            obj_temp["non_cond_frame_outputs"], frames_to_trim
+                        )
+
+                        # Shift frames_tracked metadata
+                        inference_state["frames_tracked_per_obj"][obj_idx] = shift_dict_keys(
+                            inference_state["frames_tracked_per_obj"][obj_idx], frames_to_trim
+                        )
+
+                    self.log_message.emit(f"Memory trimmed. Global offset now: {global_offset}")
+
+                    # Clear CUDA cache after memory trimming
+                    if torch.cuda.is_available():
+                        torch.cuda.empty_cache()
+
+                # Load NEXT chunk if needed
+                if processed_up_to < self.end_frame:
+                    next_end = min(processed_up_to + self.chunk_size, self.end_frame)
+                    self.log_message.emit(f"Loading next chunk: {processed_up_to} to {next_end}")
+
+                    # Capture last-frame masks of the current chunk for optional reseed
+                    if self.reseed_between_chunks:
+                        last_frame_idx = processed_up_to - 1
+                        last_masks_for_reseed = all_video_segments.get(last_frame_idx, None)
+
+                    new_images = load_frames(processed_up_to, next_end)
+
+                    # Append to inference_state.
+                    # OPTIMIZATION: since images is a list, we can extend it directly
+                    # (O(1) per frame) instead of torch.cat, which would copy all
+                    # existing frames (O(N)).
+                    new_images_list = [new_images[i] for i in range(len(new_images))]
+                    inference_state["images"].extend(new_images_list)
+                    inference_state["num_frames"] = len(inference_state["images"])
+
+                    # Optional reseed: add the mask from the last frame of the previous chunk
+                    if self.reseed_between_chunks and last_masks_for_reseed:
+                        try:
+                            seed_frame_local = processed_up_to - buffer_start  # first frame of new chunk in buffer coords
+                            vh = inference_state["video_height"]
+                            vw = inference_state["video_width"]
+
+                            for obj_id, mask in last_masks_for_reseed.items():
+                                if mask is None or mask.max() == 0:
+                                    continue
+
+                                # Resize mask to video dimensions if needed
+                                if mask.shape[0] != vh or mask.shape[1] != vw:
+                                    mask_resized = cv2.resize(mask.astype(np.float32), (vw, vh), interpolation=cv2.INTER_NEAREST)
+                                    mask = (mask_resized > 0.5).astype(np.uint8)
+
+                                run_with_cuda_retry(
+                                    "chunk_reseed_add_new_mask",
+                                    lambda: self._sam2_call(self.predictor.add_new_mask,
+                                                            inference_state=inference_state,
+                                                            frame_idx=seed_frame_local,
+                                                            obj_id=obj_id,
+                                                            mask=mask.astype(bool),
+                                                            ),
+                                )
+                                self.log_message.emit(f"Reseeded obj {obj_id} with mask at frame {processed_up_to}")
+                        except Exception as e:
+                            self.log_message.emit(f"Warning: reseed between chunks failed: {e}")
+
+            self.finished_signal.emit(all_video_segments)
+
+        except Exception as e:
+            import traceback
+            self.log_message.emit(f"ERROR: {str(e)}\n{traceback.format_exc()}")
+            self.error_signal.emit(str(e))
+
+
+class VideoLabel(QLabel):
+    """Custom label for video display with click handling."""
+    click_signal = pyqtSignal(int, int)
+
+    def mousePressEvent(self, event):
+        if event.button() == Qt.MouseButton.LeftButton:
+            self.click_signal.emit(event.pos().x(), event.pos().y())
+
+
+class SegmentationTrackingWidget(QWidget):
+    """Widget for segmentation and multi-object tracking using SAM2."""
+
+    # Signal emitted when tracking completes with video and mask paths
+    tracking_completed = pyqtSignal(str, str)  # video_path, mask_path
+
+    def __init__(self, config: dict):
+        super().__init__()
+        self.config = config
+        self.video_path = None
+        self.cap = None
+        self.total_frames = 0
+        self.current_frame_idx = 0
+        self.frame = None
+        self._frame_rgb = None  # Keep RGB frame reference for QImage memory safety
+        self.points = []  # List of (x, y, label, frame_idx, obj_id)
+        self.masks = {}  # frame_idx -> {obj_id: mask_array}
+        self.last_processed_frame = None
+        self.tracking_paused = False
+        self.resume_from_frame = None
+        self.resume_initial_masks = {}
+        self._base_display_pixmap = None
+        self.zoom_factor = 1.0
+        self.zoom_min = 0.5
+        self.zoom_max = 4.0
+        self.zoom_step = 0.2
+
+        self.obj_ids = [1]
+        self.current_obj_id = 1
+
+        self.predictor = None
+        self.inference_state = None
+        self.state_start_frame = 0
+
+        # Multi-video support
+        self.videos = []  # list of per-video state dicts
+        self.current_video_idx = None
+        self.batch_queue = []
+        self.batch_mode = False
+
+        # Settings
+        self.mask_threshold = 0.0
+        self.fill_hole_area = 0
+        self.offload_video = True
+        self.offload_state = True
+        self.use_cuda_bf16_autocast = True
+        self.enable_memory_management = True
+        self.max_frames_per_load = 200  # Limit frames loaded at once to prevent RAM issues
+        self.save_overlay_video = True
+
+        # Motion-aware tracking settings
+        self.enable_motion_tracking = False  # Off by default
+        self.motion_score_threshold = 0.3
+        self.motion_consecutive_low = 3  # Frames before auto-correction
+        self.motion_area_threshold = 0.5  # Max allowed area change ratio
+
+        # OC-SORT drift correction settings
+        self.enable_ocsort = False  # Off by default
+        self.ocsort_inertia = 0.2  # Paper default: 0.2
+
+        # SAM2 paths - resolved via _paths so this works both from source and pip install
+        from singlebehaviorlab._paths import get_sam2_backend_dir, get_sam2_checkpoints_dir
+        self.sam2_dir = str(get_sam2_checkpoints_dir())
+        self.sam2_backend_dir = str(get_sam2_backend_dir())
+        self.checkpoint_path = os.path.join(self.sam2_dir, "checkpoints", "sam2.1_hiera_large.pt")
+        self.model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
+
+        self.tracking_worker = None
+        self.download_worker = None
+
+        self._setup_ui()
+        self._check_sam2_availability()
+
+    def _use_cuda_bf16(self):
+        """Use bf16 autocast only for CUDA SAM2 inference."""
+        dev = getattr(self.predictor, "device", None)
+        dev_type = getattr(dev, "type", str(dev))
+        return bool(
+            self.use_cuda_bf16_autocast
+            and torch.cuda.is_available()
+            and dev_type == "cuda"
+        )
+
+    def _sam2_autocast(self):
+        if self._use_cuda_bf16():
+            return torch.autocast(device_type="cuda", dtype=torch.bfloat16)
+        return nullcontext()
+
+    def _sam2_call(self, fn, *args, **kwargs):
+        with self._sam2_autocast():
+            return fn(*args, **kwargs)
+
+    def _ensure_hydra_initialized(self):
+        """Ensure Hydra is initialized before using SAM2."""
+        try:
+            from hydra.core.global_hydra import GlobalHydra
+            from hydra import initialize_config_dir
+
+            # Check if Hydra is already initialized
+            if GlobalHydra.instance().is_initialized():
+                return True
+
+            # Find the sam2 configs directory
+            sam2_configs_dir = None
+
+            # Try to find it from the installed package first (most reliable)
+            try:
+                import sam2
+                if hasattr(sam2, '__file__') and sam2.__file__:
+                    sam2_path = os.path.dirname(sam2.__file__)
+                    configs_path = os.path.join(sam2_path, "configs")
+                    if os.path.exists(configs_path):
+                        sam2_configs_dir = configs_path
+                elif hasattr(sam2, '__path__'):
+                    # Handle namespace packages
+                    for path in sam2.__path__:
+                        configs_path = os.path.join(path, "configs")
+                        if os.path.exists(configs_path):
+                            sam2_configs_dir = configs_path
+                            break
+            except ImportError:
+                pass
+
+            # Fall back to the bundled sam2_backend configs directory.
+            if not sam2_configs_dir:
+                sam2_backend_configs = os.path.join(self.sam2_backend_dir, "sam2", "configs")
+                if os.path.exists(sam2_backend_configs):
+                    sam2_configs_dir = sam2_backend_configs
+
+            if sam2_configs_dir:
+                # Initialize Hydra with the config directory
+                initialize_config_dir(config_dir=sam2_configs_dir, version_base="1.2")
+                return True
+
+            # Fallback: try initialize_config_module (may work if SAM2 is properly installed)
+            try:
+                from hydra import initialize_config_module
+                initialize_config_module("sam2", version_base="1.2")
+                return True
+            except Exception:
+                pass
+
+            return False
+
+        except ImportError:
+            # Hydra not installed
+            return False
+        except Exception:
+            # Hydra initialization failed
+            return False
+
+    def _has_installed_sam2_distribution(self):
+        """Return True when SAM2 is importable as a Python package."""
+        for dist_name in ("SAM-2", "sam2"):
+            try:
+                importlib_metadata.distribution(dist_name)
+                return True
+            except importlib_metadata.PackageNotFoundError:
+                continue
+            except Exception:
+                continue
+        try:
+            import importlib.util
+            return importlib.util.find_spec("sam2") is not None
+        except Exception:
+            return False
+
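+    # Distribution metadata is checked before find_spec so that a bare source
+    # checkout on sys.path is not mistaken for an installed SAM2 package.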
856
|
+
def _check_sam2_availability(self):
|
|
857
|
+
"""Check if SAM2 is available."""
|
|
858
|
+
has_installed_pkg = self._has_installed_sam2_distribution()
|
|
859
|
+
|
|
860
|
+
# Only report SAM2 as installed when it exists as an actual Python package
|
|
861
|
+
try:
|
|
862
|
+
if has_installed_pkg:
|
|
863
|
+
# Initialize Hydra before importing SAM2
|
|
864
|
+
self._ensure_hydra_initialized()
|
|
865
|
+
from sam2.build_sam import build_sam2_video_predictor
|
|
866
|
+
self.sam2_available = True
|
|
867
|
+
self.setup_status_label.setText("SAM2 is installed")
|
|
868
|
+
self.setup_status_label.setStyleSheet("color: green;")
|
|
869
|
+
self._populate_checkpoints()
|
|
870
|
+
# Set default model selection
|
|
871
|
+
for i in range(self.combo_model.count()):
|
|
872
|
+
if self.combo_model.itemData(i) == "sam2.1_hiera_large.pt":
|
|
873
|
+
self.combo_model.setCurrentIndex(i)
|
|
874
|
+
break
|
|
875
|
+
self._check_model_availability()
|
|
876
|
+
return
|
|
877
|
+
except (ImportError, RuntimeError) as e:
|
|
878
|
+
# If RuntimeError about parent directory, SAM2 needs to be properly installed
|
|
879
|
+
if isinstance(e, RuntimeError) and ("parent directory" in str(e) or "shadowed" in str(e)):
|
|
880
|
+
# Check if sam2_backend directory exists - if so, it needs to be reinstalled
|
|
881
|
+
if os.path.exists(self.sam2_backend_dir):
|
|
882
|
+
sam2_package = os.path.join(self.sam2_backend_dir, "sam2")
|
|
883
|
+
if os.path.exists(sam2_package):
|
|
884
|
+
# SAM2 exists but not properly installed - needs pip install -e
|
|
885
|
+
self.sam2_available = False
|
|
886
|
+
self.setup_status_label.setText("SAM2 needs reinstallation")
|
|
887
|
+
self.setup_status_label.setStyleSheet("color: orange;")
|
|
888
|
+
return
|
|
889
|
+
|
|
890
|
+
# If the source tree exists but the package is not installed, report that clearly.
|
|
891
|
+
sam2_folder = self.sam2_backend_dir
|
|
892
|
+
if os.path.exists(sam2_folder):
|
|
893
|
+
sam2_package = os.path.join(sam2_folder, "sam2")
|
|
894
|
+
if os.path.exists(sam2_package) and os.path.exists(os.path.join(sam2_package, "__init__.py")):
|
|
895
|
+
self.sam2_available = False
|
|
896
|
+
self.setup_status_label.setText("SAM2 source found, but not installed")
|
|
897
|
+
self.setup_status_label.setStyleSheet("color: orange;")
|
|
898
|
+
return
|
|
899
|
+
|
|
900
|
+
self.sam2_available = False
|
|
901
|
+
self.setup_status_label.setText("SAM2 not installed")
|
|
902
|
+
self.setup_status_label.setStyleSheet("color: red;")
|
|
903
|
+
|
|
904
|
+
def _setup_ui(self):
|
|
905
|
+
"""Setup UI components."""
|
|
906
|
+
layout = QVBoxLayout(self)
|
|
907
|
+
|
|
908
|
+
# Top row: SAM2 Setup and Model Selection side by side
|
|
909
|
+
top_row_layout = QHBoxLayout()
|
|
910
|
+
|
|
911
|
+
# SAM2 Setup Section (left side)
|
|
912
|
+
setup_group = QGroupBox("SAM2 Setup")
|
|
913
|
+
setup_layout = QVBoxLayout()
|
|
914
|
+
|
|
915
|
+
setup_info_layout = QHBoxLayout()
|
|
916
|
+
setup_info_layout.addWidget(QLabel("Status:"))
|
|
917
|
+
self.setup_status_label = QLabel("Checking...")
|
|
918
|
+
setup_info_layout.addWidget(self.setup_status_label)
|
|
919
|
+
setup_info_layout.addStretch()
|
|
920
|
+
setup_layout.addLayout(setup_info_layout)
|
|
921
|
+
|
|
922
|
+
setup_path_layout = QHBoxLayout()
|
|
923
|
+
setup_path_layout.addWidget(QLabel("Package:"))
|
|
924
|
+
self.setup_path_label = QLabel(self.sam2_backend_dir)
|
|
925
|
+
self.setup_path_label.setWordWrap(True)
|
|
926
|
+
self.setup_path_label.setStyleSheet("color: gray;")
|
|
927
|
+
setup_path_layout.addWidget(self.setup_path_label, stretch=1)
|
|
928
|
+
setup_layout.addLayout(setup_path_layout)
|
|
929
|
+
|
|
930
|
+
ckpt_path_layout = QHBoxLayout()
|
|
931
|
+
ckpt_path_layout.addWidget(QLabel("Checkpoints:"))
|
|
932
|
+
self.ckpt_path_label = QLabel(self.sam2_dir)
|
|
933
|
+
self.ckpt_path_label.setWordWrap(True)
|
|
934
|
+
self.ckpt_path_label.setStyleSheet("color: gray;")
|
|
935
|
+
ckpt_path_layout.addWidget(self.ckpt_path_label, stretch=1)
|
|
936
|
+
setup_layout.addLayout(ckpt_path_layout)
|
|
937
|
+
|
|
938
|
+
setup_group.setLayout(setup_layout)
|
|
939
|
+
top_row_layout.addWidget(setup_group)
|
|
940
|
+
|
|
941
|
+
# Model Selection (right side) - matching Video Settings width
|
|
942
|
+
model_group = QGroupBox("Model selection")
|
|
943
|
+
model_group.setFixedWidth(380) # Match Video Settings & Controls width
|
|
944
|
+
model_layout = QVBoxLayout()
|
|
945
|
+
|
|
946
|
+
model_select_layout = QHBoxLayout()
|
|
947
|
+
model_select_layout.addWidget(QLabel("Model:"))
|
|
948
|
+
self.combo_model = QComboBox()
|
|
949
|
+
# Add all available models with user-friendly names
|
|
950
|
+
self.model_names = {
|
|
951
|
+
"sam2.1_hiera_tiny.pt": "SAM2.1 Tiny (38.9M, Fastest)",
|
|
952
|
+
"sam2.1_hiera_small.pt": "SAM2.1 Small (46M, Fast)",
|
|
953
|
+
"sam2.1_hiera_base_plus.pt": "SAM2.1 Base+ (80.8M, Balanced)",
|
|
954
|
+
"sam2.1_hiera_large.pt": "SAM2.1 Large (224.4M, Best Quality)",
|
|
955
|
+
"sam2_hiera_tiny.pt": "SAM2.0 Tiny (38.9M, Legacy)",
|
|
956
|
+
"sam2_hiera_small.pt": "SAM2.0 Small (46M, Legacy)",
|
|
957
|
+
"sam2_hiera_base_plus.pt": "SAM2.0 Base+ (80.8M, Legacy)",
|
|
958
|
+
"sam2_hiera_large.pt": "SAM2.0 Large (224.4M, Legacy)",
|
|
959
|
+
}
|
|
960
|
+
for model_file, display_name in self.model_names.items():
|
|
961
|
+
self.combo_model.addItem(display_name, model_file)
|
|
962
|
+
self.combo_model.currentIndexChanged.connect(self._on_model_selected)
|
|
963
|
+
model_select_layout.addWidget(self.combo_model)
|
|
964
|
+
model_layout.addLayout(model_select_layout)
|
|
965
|
+
|
|
966
|
+
self.model_status_label = QLabel("Select a model to check/download")
|
|
967
|
+
self.model_status_label.setWordWrap(True)
|
|
968
|
+
self.model_status_label.setStyleSheet("color: gray;")
|
|
969
|
+
model_layout.addWidget(self.model_status_label)
|
|
970
|
+
|
|
971
|
+
self.download_progress = QProgressBar()
|
|
972
|
+
self.download_progress.setVisible(False)
|
|
973
|
+
model_layout.addWidget(self.download_progress)
|
|
974
|
+
|
|
975
|
+
model_group.setLayout(model_layout)
|
|
976
|
+
top_row_layout.addWidget(model_group)
|
|
977
|
+
|
|
978
|
+
layout.addLayout(top_row_layout)
|
|
979
|
+
|
|
980
|
+
# Legacy checkpoint combo (hidden, kept for compatibility)
|
|
981
|
+
self.combo_ckpt = QComboBox()
|
|
982
|
+
self.combo_ckpt.currentTextChanged.connect(self._on_checkpoint_changed)
|
|
983
|
+
|
|
984
|
+
# Video Range Controls
|
|
985
|
+
range_group = QGroupBox("Processing range")
|
|
986
|
+
range_layout = QHBoxLayout()
|
|
987
|
+
|
|
988
|
+
self.chk_limit_range = QCheckBox("Limit Range")
|
|
989
|
+
self.chk_limit_range.toggled.connect(self._toggle_range_inputs)
|
|
990
|
+
range_layout.addWidget(self.chk_limit_range)
|
|
991
|
+
|
|
992
|
+
range_layout.addWidget(QLabel("Start:"))
|
|
993
|
+
self.spin_start = QSpinBox()
|
|
994
|
+
self.spin_start.setRange(0, 999999)
|
|
995
|
+
self.spin_start.setEnabled(False)
|
|
996
|
+
range_layout.addWidget(self.spin_start)
|
|
997
|
+
|
|
998
|
+
self.btn_set_start = QPushButton("Set")
|
|
999
|
+
self.btn_set_start.clicked.connect(self._set_range_start)
|
|
1000
|
+
self.btn_set_start.setEnabled(False)
|
|
1001
|
+
range_layout.addWidget(self.btn_set_start)
|
|
1002
|
+
|
|
1003
|
+
range_layout.addWidget(QLabel("End:"))
|
|
1004
|
+
self.spin_end = QSpinBox()
|
|
1005
|
+
self.spin_end.setRange(0, 999999)
|
|
1006
|
+
self.spin_end.setEnabled(False)
|
|
1007
|
+
range_layout.addWidget(self.spin_end)
|
|
1008
|
+
|
|
1009
|
+
self.btn_set_end = QPushButton("Set")
|
|
1010
|
+
self.btn_set_end.clicked.connect(self._set_range_end)
|
|
1011
|
+
self.btn_set_end.setEnabled(False)
|
|
1012
|
+
range_layout.addWidget(self.btn_set_end)
|
|
1013
|
+
|
|
1014
|
+
range_group.setLayout(range_layout)
|
|
1015
|
+
layout.addWidget(range_group)
|
|
1016
|
+
|
|
1017
|
+
# Video Display and Controls side by side
|
|
1018
|
+
video_row_layout = QHBoxLayout()
|
|
1019
|
+
|
|
1020
|
+
# Video Display (left side)
|
|
1021
|
+
self.video_scroll = QScrollArea()
|
|
1022
|
+
self.video_scroll.setStyleSheet("background-color: black; border: none;")
|
|
1023
|
+
self.video_scroll.setWidgetResizable(False)
|
|
1024
|
+
self.video_scroll.setAlignment(Qt.AlignmentFlag.AlignCenter)
|
|
1025
|
+
|
|
1026
|
+
self.video_label = VideoLabel()
|
|
1027
|
+
self.video_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
|
|
1028
|
+
self.video_label.setMinimumSize(1, 1)
|
|
1029
|
+
self.video_label.setStyleSheet("background-color: black;")
|
|
1030
|
+
self.video_label.click_signal.connect(self._handle_click)
|
|
1031
|
+
self.video_scroll.setWidget(self.video_label)
|
|
1032
|
+
self.video_scroll.viewport().installEventFilter(self)
|
|
1033
|
+
|
|
1034
|
+
self.btn_zoom_in = QPushButton("+", self.video_scroll.viewport())
|
|
1035
|
+
self.btn_zoom_out = QPushButton("-", self.video_scroll.viewport())
|
|
1036
|
+
self._style_zoom_button(self.btn_zoom_in)
|
|
1037
|
+
self._style_zoom_button(self.btn_zoom_out)
|
|
1038
|
+
self.btn_zoom_in.clicked.connect(self._zoom_in)
|
|
1039
|
+
self.btn_zoom_out.clicked.connect(self._zoom_out)
|
|
1040
|
+
self._position_zoom_buttons()
|
|
1041
|
+
|
|
1042
|
+
video_row_layout.addWidget(self.video_scroll, stretch=2)
|
|
1043
|
+
|
|
1044
|
+
# Controls Container (right side) - matching Model Selection width
|
|
1045
|
+
controls_group = QGroupBox("Video settings & controls")
|
|
1046
|
+
# Set width to match Model Selection container (approximately 350-400px)
|
|
1047
|
+
controls_group.setFixedWidth(380)
|
|
1048
|
+
controls_layout = QVBoxLayout()
|
|
1049
|
+
|
|
1050
|
+
# Load Video button
|
|
1051
|
+
self.btn_load = QPushButton("Load videos")
|
|
1052
|
+
self.btn_load.clicked.connect(self._load_video)
|
|
1053
|
+
controls_layout.addWidget(self.btn_load)
|
|
1054
|
+
|
|
1055
|
+
# Video list
|
|
1056
|
+
self.video_list_widget = QListWidget()
|
|
1057
|
+
self.video_list_widget.setSelectionMode(QListWidget.SelectionMode.SingleSelection)
|
|
1058
|
+
self.video_list_widget.currentRowChanged.connect(self._on_video_selected)
|
|
1059
|
+
controls_layout.addWidget(self.video_list_widget)
|
|
1060
|
+
|
|
1061
|
+
        # Object Controls
        obj_layout = QHBoxLayout()
        obj_layout.addWidget(QLabel("Object:"))

        self.combo_obj = QComboBox()
        self.combo_obj.addItem("Object 1", 1)
        self.combo_obj.currentIndexChanged.connect(self._on_object_changed)
        obj_layout.addWidget(self.combo_obj)

        self.btn_add_obj = QPushButton("+")
        self.btn_add_obj.setFixedWidth(30)
        self.btn_add_obj.clicked.connect(self._add_object)
        obj_layout.addWidget(self.btn_add_obj)

        controls_layout.addLayout(obj_layout)

        # Point type
        point_type_layout = QHBoxLayout()
        self.radio_pos = QRadioButton("Positive (+)")
        self.radio_neg = QRadioButton("Negative (-)")
        self.radio_pos.setChecked(True)
        self.btn_group = QButtonGroup()
        self.btn_group.addButton(self.radio_pos)
        self.btn_group.addButton(self.radio_neg)
        point_type_layout.addWidget(self.radio_pos)
        point_type_layout.addWidget(self.radio_neg)
        controls_layout.addLayout(point_type_layout)

        self.btn_clear_points = QPushButton("Clear points")
        self.btn_clear_points.clicked.connect(self._clear_points)
        controls_layout.addWidget(self.btn_clear_points)

        controls_layout.addSpacing(10)

        self.chk_auto_follow = QCheckBox("Auto-follow")
        self.chk_auto_follow.setChecked(True)
        self.chk_auto_follow.setToolTip("Automatically move slider to the frame being processed.")
        controls_layout.addWidget(self.chk_auto_follow)

        self.btn_track = QPushButton("Run tracking (Current)")
        self.btn_track.clicked.connect(self._run_tracking)
        self.btn_track.setEnabled(False)
        controls_layout.addWidget(self.btn_track)

        # Pause / Resume tracking controls
        pause_resume_layout = QHBoxLayout()
        self.btn_pause_tracking = QPushButton("Pause tracking")
        self.btn_pause_tracking.setEnabled(False)
        self.btn_pause_tracking.clicked.connect(self._pause_tracking)
        pause_resume_layout.addWidget(self.btn_pause_tracking)

        self.btn_resume_tracking = QPushButton("Resume tracking from here")
        self.btn_resume_tracking.setEnabled(False)
        self.btn_resume_tracking.clicked.connect(self._resume_tracking)
        pause_resume_layout.addWidget(self.btn_resume_tracking)

        controls_layout.addLayout(pause_resume_layout)

        self.btn_track_all = QPushButton("Run tracking (All videos)")
        self.btn_track_all.clicked.connect(self._run_tracking_all)
        self.btn_track_all.setEnabled(False)
        controls_layout.addWidget(self.btn_track_all)

        self.chk_save_overlay = QCheckBox("Save overlay video after tracking")
        self.chk_save_overlay.setChecked(self.save_overlay_video)
        self.chk_save_overlay.setToolTip(
            "Save an MP4 with colored mask overlays for later inspection.\n"
            "Also applies when tracking is paused."
        )
        self.chk_save_overlay.toggled.connect(lambda v: setattr(self, "save_overlay_video", bool(v)))
        controls_layout.addWidget(self.chk_save_overlay)

        # SAM2 tracking resolution
        res_layout = QHBoxLayout()
        res_layout.addWidget(QLabel("Tracking resolution:"))
        self.tracking_res_combo = QComboBox()
        self.tracking_res_combo.addItem("256 (fastest, low quality)", 256)
        self.tracking_res_combo.addItem("384 (fast)", 384)
        self.tracking_res_combo.addItem("512 (balanced)", 512)
        self.tracking_res_combo.addItem("1024 (best quality, default)", 1024)
        self.tracking_res_combo.setCurrentIndex(3)
        self.tracking_res_combo.setToolTip(
            "Resolution at which SAM2 processes frames.\n"
            "Lower = faster tracking but less precise masks.\n"
            "512 is good for centroid/bbox extraction."
        )
        res_layout.addWidget(self.tracking_res_combo)
        controls_layout.addLayout(res_layout)

        self.btn_preview = QPushButton("Preview frame")
        self.btn_preview.clicked.connect(self._preview_frame)
        self.btn_preview.setEnabled(False)
        controls_layout.addWidget(self.btn_preview)

        self.btn_settings = QPushButton("Settings")
        self.btn_settings.clicked.connect(self._open_settings)
        controls_layout.addWidget(self.btn_settings)

        controls_layout.addStretch()

        controls_group.setLayout(controls_layout)
        video_row_layout.addWidget(controls_group)

        layout.addLayout(video_row_layout)

        # Slider
        self.slider = QSlider(Qt.Orientation.Horizontal)
        self.slider.sliderMoved.connect(self._set_frame)
        self.slider.setEnabled(False)
        layout.addWidget(self.slider)

        # Frame navigation row
        nav_layout = QHBoxLayout()

        self.btn_prev_frame = QPushButton("<")
        self.btn_prev_frame.setFixedWidth(40)
        self.btn_prev_frame.clicked.connect(self._prev_frame)
        self.btn_prev_frame.setEnabled(False)
        nav_layout.addWidget(self.btn_prev_frame)

        self.lbl_frame = QLabel("Frame: 0 / 0")
        nav_layout.addWidget(self.lbl_frame, stretch=1)

        self.btn_next_frame = QPushButton(">")
        self.btn_next_frame.setFixedWidth(40)
        self.btn_next_frame.clicked.connect(self._next_frame)
        self.btn_next_frame.setEnabled(False)
        nav_layout.addWidget(self.btn_next_frame)

        layout.addLayout(nav_layout)

        # Progress bar
        self.progress_bar = QProgressBar()
        self.progress_bar.setVisible(False)
        layout.addWidget(self.progress_bar)

        # Log area
        self.log_text = QLabel("")
        self.log_text.setWordWrap(True)
        self.log_text.setMaximumHeight(80)
        layout.addWidget(self.log_text)

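    # Note on the combo boxes above: QComboBox.addItem(text, userData) stores an
    # arbitrary Python object alongside the display text, and currentData() /
    # itemData(idx) return it later. A minimal sketch of the pattern (names here
    # are illustrative, not part of this file):
    #
    #     combo = QComboBox()
    #     combo.addItem("512 (balanced)", 512)   # text shown, int stored
    #     res = combo.currentData()              # -> 512, no string parsing
    #
    # This is why the handlers below read itemData()/currentData() instead of
    # parsing the visible label.
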
    def _download_checkpoints(self):
        """Prompt user about checkpoint download (now handled automatically)."""
        QMessageBox.information(
            self,
            "Checkpoint Download",
            "Checkpoints are now downloaded automatically when you select a model.\n\n"
            "Simply select a model from the dropdown above, and it will be downloaded if not already present."
        )

|
|
1213
|
+
"""Check if selected model checkpoint exists."""
|
|
1214
|
+
if not self.sam2_available:
|
|
1215
|
+
self.model_status_label.setText("SAM2 not installed. Run install.sh and reopen the app.")
|
|
1216
|
+
self.model_status_label.setStyleSheet("color: red;")
|
|
1217
|
+
return
|
|
1218
|
+
|
|
1219
|
+
idx = self.combo_model.currentIndex()
|
|
1220
|
+
if idx < 0:
|
|
1221
|
+
return
|
|
1222
|
+
|
|
1223
|
+
model_name = self.combo_model.itemData(idx)
|
|
1224
|
+
if not model_name:
|
|
1225
|
+
return
|
|
1226
|
+
|
|
1227
|
+
ckpt_path = os.path.join(self.sam2_dir, "checkpoints", model_name)
|
|
1228
|
+
|
|
1229
|
+
if os.path.exists(ckpt_path):
|
|
1230
|
+
file_size = os.path.getsize(ckpt_path) / (1024**2)
|
|
1231
|
+
if file_size > 10: # Reasonable size check
|
|
1232
|
+
self.model_status_label.setText(f"{model_name} available ({file_size:.1f} MB)")
|
|
1233
|
+
self.model_status_label.setStyleSheet("color: green;")
|
|
1234
|
+
else:
|
|
1235
|
+
self.model_status_label.setText(f"{model_name} file seems corrupted ({file_size:.1f} MB). Will re-download.")
|
|
1236
|
+
self.model_status_label.setStyleSheet("color: orange;")
|
|
1237
|
+
else:
|
|
1238
|
+
self.model_status_label.setText(f"{model_name} not found. Will download automatically when selected.")
|
|
1239
|
+
self.model_status_label.setStyleSheet("color: orange;")
|
|
1240
|
+
|
|
1241
|
+
    def _on_model_selected(self):
        """Handle model selection change."""
        idx = self.combo_model.currentIndex()
        if idx < 0:
            return

        model_name = self.combo_model.itemData(idx)
        if not model_name:
            return

        ckpt_path = os.path.join(self.sam2_dir, "checkpoints", model_name)

        # Check if checkpoint exists and is valid
        if os.path.exists(ckpt_path):
            file_size = os.path.getsize(ckpt_path) / (1024**2)
            if file_size > 10:  # Reasonable size check (should be >100MB typically)
                self.model_status_label.setText(f"{model_name} ready ({file_size:.1f} MB)")
                self.model_status_label.setStyleSheet("color: green;")
                self.checkpoint_path = ckpt_path
                self._on_checkpoint_changed(model_name)
                return

        # Checkpoint doesn't exist or is invalid, download it
        if model_name not in CheckpointDownloadWorker.MODEL_URLS:
            self.model_status_label.setText(f"Unknown model: {model_name}")
            self.model_status_label.setStyleSheet("color: red;")
            return

        # Start download
        self._download_checkpoint(model_name, ckpt_path, CheckpointDownloadWorker.MODEL_URLS[model_name])

    def _download_checkpoint(self, model_name, ckpt_path, model_url):
        """Download a checkpoint file."""
        if self.download_worker and self.download_worker.isRunning():
            QMessageBox.warning(self, "Download in progress", "A checkpoint download is already in progress.")
            return

        # Ensure checkpoints directory exists
        ckpt_dir = os.path.dirname(ckpt_path)
        os.makedirs(ckpt_dir, exist_ok=True)

        self.model_status_label.setText(f"Downloading {model_name}...")
        self.model_status_label.setStyleSheet("color: blue;")
        self.download_progress.setVisible(True)
        self.download_progress.setRange(0, 0)  # Indeterminate

        self.download_worker = CheckpointDownloadWorker(model_name, ckpt_path, model_url)
        self.download_worker.progress.connect(self._on_download_progress)
        self.download_worker.finished.connect(self._on_download_finished)
        self.download_worker.start()

    def _on_download_progress(self, message):
        """Handle download progress updates."""
        self.model_status_label.setText(message)

    def _on_download_finished(self, success, message):
        """Handle download completion."""
        self.download_progress.setVisible(False)

        if success:
            self.model_status_label.setText(message)
            self.model_status_label.setStyleSheet("color: green;")
            model_name = self.combo_model.currentData()
            self.checkpoint_path = os.path.join(self.sam2_dir, "checkpoints", model_name)
            self._on_checkpoint_changed(model_name)
        else:
            self.model_status_label.setText(message)
            self.model_status_label.setStyleSheet("color: red;")
            QMessageBox.critical(self, "Download failed", message)

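    # The code above only assumes this much about CheckpointDownloadWorker
    # (inferred from usage; the actual class is defined elsewhere in this
    # package): a QThread subclass exposing
    #
    #     class CheckpointDownloadWorker(QThread):
    #         MODEL_URLS: dict            # checkpoint file name -> download URL
    #         progress = pyqtSignal(str)  # human-readable status updates
    #         finished = pyqtSignal(bool, str)  # (success, message)
    #
    # so the GUI thread never blocks on the network; all UI updates happen in
    # the two slots above via queued signal delivery.
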
    def _populate_checkpoints(self):
        """Populate checkpoint combo box (legacy, for compatibility)."""
        self.combo_ckpt.clear()
        ckpt_dir = os.path.join(self.sam2_dir, "checkpoints")
        if os.path.exists(ckpt_dir):
            checkpoints = [f for f in os.listdir(ckpt_dir) if f.endswith(".pt")]
            checkpoints.sort()
            self.combo_ckpt.addItems(checkpoints)

            # Default selection stays inside the exists-check so `checkpoints`
            # is always defined when it is used.
            default = "sam2.1_hiera_large.pt"
            if default in checkpoints:
                self.combo_ckpt.setCurrentText(default)
            elif checkpoints:
                self.combo_ckpt.setCurrentIndex(0)

    def _on_checkpoint_changed(self, ckpt_name):
        """Handle checkpoint selection change."""
        if not ckpt_name:
            return

        # Update checkpoint path
        self.checkpoint_path = os.path.join(self.sam2_dir, "checkpoints", ckpt_name)
        self.model_cfg = self._get_model_config(ckpt_name)

        self.predictor = None
        self.inference_state = None
        self.masks = {}
        self.points = []
        # Clean up any incremental mask temp file
        if getattr(self, '_incremental_mask_file', None) is not None:
            try:
                self._incremental_mask_file.close()
                os.unlink(self._incremental_mask_file.name)
            except Exception:
                pass
            self._incremental_mask_file = None
        self._update_frame()

    def _get_model_config(self, ckpt_name):
        """Map checkpoint name to config."""
        if "sam2.1" in ckpt_name:
            prefix = "configs/sam2.1/sam2.1_hiera_"
        else:
            prefix = "configs/sam2/sam2_hiera_"

        if "large" in ckpt_name.lower():
            return prefix + "l.yaml"
        elif "base_plus" in ckpt_name.lower() or "b+" in ckpt_name.lower():
            return prefix + "b+.yaml"
        elif "small" in ckpt_name.lower():
            return prefix + "s.yaml"
        elif "tiny" in ckpt_name.lower():
            return prefix + "t.yaml"

        # Fall back to the large config if the name matches no known variant
        return prefix + "l.yaml"

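    # Worked examples of the mapping above (derived directly from the rules,
    # not an extra lookup table):
    #   "sam2.1_hiera_large.pt"     -> "configs/sam2.1/sam2.1_hiera_l.yaml"
    #   "sam2.1_hiera_base_plus.pt" -> "configs/sam2.1/sam2.1_hiera_b+.yaml"
    #   "sam2_hiera_tiny.pt"        -> "configs/sam2/sam2_hiera_t.yaml"
    # These config paths correspond to the YAML files bundled under
    # sam2/configs/ in this package.
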
    def _toggle_range_inputs(self, checked):
        """Toggle range input widgets."""
        self.spin_start.setEnabled(checked)
        self.spin_end.setEnabled(checked)
        self.btn_set_start.setEnabled(checked)
        self.btn_set_end.setEnabled(checked)

    def _set_range_start(self):
        """Apply the start frame chosen in the spin box (clamped to video length)."""
        if not self.cap:
            return

        start_val = max(0, min(self.spin_start.value(), self.total_frames - 1))

        # Clamp and update start value
        self.spin_start.blockSignals(True)
        self.spin_start.setValue(start_val)
        self.spin_start.blockSignals(False)

        # Ensure start is not beyond end
        if start_val > self.spin_end.value():
            self.spin_end.setValue(start_val)

        # Jump preview to start frame so the user sees what was set
        self.slider.setValue(start_val)
        self._set_frame(start_val)

    def _set_range_end(self):
        """Apply the end frame chosen in the spin box (clamped to video length)."""
        if not self.cap:
            return

        end_val = max(0, min(self.spin_end.value(), self.total_frames - 1))

        # Clamp and update end value
        self.spin_end.blockSignals(True)
        self.spin_end.setValue(end_val)
        self.spin_end.blockSignals(False)

        # Ensure end is not before start
        if end_val < self.spin_start.value():
            self.spin_start.setValue(end_val)

        # Jump preview to end frame so the user sees what was set
        self.slider.setValue(end_val)
        self._set_frame(end_val)

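    # Clamping example: for a 300-frame video (total_frames == 300), valid
    # indices are 0..299, so max(0, min(v, 299)) maps a typed value of 500 to
    # 299 and -5 to 0. blockSignals(True) around setValue() keeps the
    # programmatic correction from re-triggering valueChanged handlers.
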
    def _add_object(self):
        """Add a new object ID."""
        new_id = max(self.obj_ids) + 1
        self.obj_ids.append(new_id)
        self.combo_obj.addItem(f"Object {new_id}", new_id)
        self.combo_obj.setCurrentIndex(self.combo_obj.count() - 1)

    def _on_object_changed(self):
        """Handle object selection change."""
        self.current_obj_id = self.combo_obj.currentData()

    def _create_video_state(self, path):
        """Create a per-video state dict."""
        cap = cv2.VideoCapture(path)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        cap.release()
        return {
            "path": path,
            "total_frames": total_frames,
            "points": [],
            "masks": {},
            "obj_ids": [1],
            "current_obj_id": 1,
            "current_frame_idx": 0,
            "spin_start": 0,
            "spin_end": max(total_frames - 1, 0),
            "inference_state": None,
            "state_start_frame": 0,
        }

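    # Shape of the per-video state, for reference: "points" holds
    # (x, y, label, frame_idx, obj_id) tuples as appended by _handle_click(),
    # and "masks" maps frame_idx -> {obj_id: binary uint8 mask} as written by
    # _on_frame_result(). Everything the UI mutates is snapshotted into this
    # dict when switching videos.
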
    def _save_current_video_state(self):
        """Persist current UI state into the active video entry."""
        if self.current_video_idx is None or self.current_video_idx >= len(self.videos):
            return
        try:
            v = self.videos[self.current_video_idx]
            v["points"] = list(self.points)
            v["masks"] = dict(self.masks)
            v["obj_ids"] = list(self.obj_ids)
            v["current_obj_id"] = self.current_obj_id
            v["current_frame_idx"] = self.current_frame_idx
            v["spin_start"] = self.spin_start.value()
            v["spin_end"] = self.spin_end.value()
            v["inference_state"] = self.inference_state
            v["state_start_frame"] = self.state_start_frame
        except Exception:
            pass

    def _apply_video_state(self, idx: int):
        """Load a video's state into the UI and current attributes."""
        if idx < 0 or idx >= len(self.videos):
            return
        self.current_video_idx = idx
        v = self.videos[idx]
        self.video_path = v["path"]
        self.cap = cv2.VideoCapture(self.video_path)
        self.total_frames = v["total_frames"]

        # Ranges
        self.slider.setRange(0, max(self.total_frames - 1, 0))
        self.spin_start.setRange(0, max(self.total_frames - 1, 0))
        self.spin_end.setRange(0, max(self.total_frames - 1, 0))
        self.spin_start.setValue(min(v["spin_start"], max(self.total_frames - 1, 0)))
        self.spin_end.setValue(min(v["spin_end"], max(self.total_frames - 1, 0)))

        # Restore points/masks/objects
        self.points = list(v["points"])
        self.masks = dict(v["masks"])
        self.obj_ids = list(v["obj_ids"])
        self.current_obj_id = v["current_obj_id"]
        self.current_frame_idx = min(v["current_frame_idx"], max(self.total_frames - 1, 0))
        self.inference_state = v["inference_state"]
        self.state_start_frame = v.get("state_start_frame", 0)

        # Rebuild object combo
        self.combo_obj.blockSignals(True)
        self.combo_obj.clear()
        for oid in self.obj_ids:
            self.combo_obj.addItem(f"Object {oid}", oid)
        idx_obj = self.combo_obj.findData(self.current_obj_id)
        if idx_obj >= 0:
            self.combo_obj.setCurrentIndex(idx_obj)
        self.combo_obj.blockSignals(False)

        # Enable controls
        self.slider.setEnabled(True)
        self.btn_prev_frame.setEnabled(True)
        self.btn_next_frame.setEnabled(True)
        self.btn_track.setEnabled(self.sam2_available)
        self.btn_track_all.setEnabled(len(self.videos) > 0 and self.sam2_available)
        self.btn_preview.setEnabled(self.sam2_available)

        # Move slider to current frame and refresh
        self.slider.blockSignals(True)
        self.slider.setValue(self.current_frame_idx)
        self.slider.blockSignals(False)
        self.zoom_factor = 1.0
        self._update_frame()

    def _on_video_selected(self, row: int):
        """Handle selection change from the video list."""
        self._save_current_video_state()
        if row >= 0:
            self._apply_video_state(row)

    def _load_video(self):
        """Load one or more video files."""
        video_dir = self.config.get("raw_videos_dir", self.config.get("data_dir", "data/raw_videos"))
        paths, _ = QFileDialog.getOpenFileNames(
            self, "Open Videos", video_dir, "Video Files (*.mp4 *.avi *.mov *.mkv)"
        )
        if not paths:
            return

        from .video_utils import ensure_video_in_experiment

        added_any = False
        for path in paths:
            path = ensure_video_in_experiment(path, self.config, self)
            # Avoid duplicates
            if any(v["path"] == path for v in self.videos):
                continue
            state = self._create_video_state(path)
            self.videos.append(state)
            self.video_list_widget.addItem(os.path.basename(path))
            added_any = True

        if not added_any:
            return

        # Auto-select first added video if none active
        if self.current_video_idx is None and self.videos:
            self.video_list_widget.setCurrentRow(0)
            self._apply_video_state(0)
        else:
            # Keep current selection, just enable batch controls
            self.btn_track_all.setEnabled(self.sam2_available and len(self.videos) > 0)

    def _ensure_predictor(self):
        """Ensure SAM2 model is loaded (rebuilds if resolution changed)."""
        tracking_res = self.tracking_res_combo.currentData() or 1024
        if self.predictor is not None:
            if getattr(self.predictor, "image_size", 1024) == tracking_res:
                return True
            # Resolution changed; need to rebuild
            del self.predictor
            self.predictor = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        if not self.sam2_available:
            QMessageBox.warning(
                self,
                "SAM2 not available",
                "SAM2 is not installed in this environment.\n\nRun bash install.sh and reopen the app.",
            )
            return False

        idx = self.combo_model.currentIndex()
        if idx < 0:
            QMessageBox.warning(self, "No model selected", "Please select a model.")
            return False

        model_name = self.combo_model.itemData(idx)
        if not model_name:
            QMessageBox.warning(self, "No model selected", "Please select a model.")
            return False

        self.checkpoint_path = os.path.join(self.sam2_dir, "checkpoints", model_name)
        if not os.path.exists(self.checkpoint_path):
            QMessageBox.warning(
                self,
                "Checkpoint Missing",
                f"Checkpoint not found:\n{self.checkpoint_path}\n\n"
                "The checkpoint should download automatically when you select a model.\n"
                "Please wait for the download to complete or select the model again."
            )
            return False

        self.model_cfg = self._get_model_config(model_name)

        try:
            # Ensure Hydra is initialized before importing SAM2
            if not self._ensure_hydra_initialized():
                QMessageBox.critical(
                    self,
                    "Hydra Initialization Failed",
                    "Failed to initialize Hydra configuration system.\n\n"
                    "Please ensure hydra-core is installed:\n"
                    "pip install hydra-core>=1.3.2"
                )
                return False

            # Import sam2 first to trigger its Hydra initialization
            import sam2
            # Then import the build function
            from sam2.build_sam import build_sam2_video_predictor

            device = "cuda" if torch.cuda.is_available() else "cpu"
            if device == "cpu":
                QMessageBox.warning(self, "CPU mode", "Running on CPU. This will be very slow.")

            hydra_extra = [f"++model.image_size={tracking_res}"]
            self.predictor = build_sam2_video_predictor(
                self.model_cfg, self.checkpoint_path, device=device,
                hydra_overrides_extra=hydra_extra,
            )
            self.predictor.fill_hole_area = self.fill_hole_area
            return True
        except RuntimeError as e:
            if "parent directory" in str(e) or "shadowed" in str(e):
                QMessageBox.critical(
                    self,
                    "SAM2 Import Error",
                    f"SAM2 import conflict:\n{e}\n\n"
                    "Solution: Please install SAM2 to a location outside the behavior_labeling_app directory.\n"
                    "Use the 'Change...' button to select a different installation location."
                )
            else:
                QMessageBox.critical(self, "Error", f"Failed to init SAM2 model:\n{e}")
            return False
        except Exception as e:
            QMessageBox.critical(self, "Error", f"Failed to init SAM2 model:\n{e}")
            return False

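    # About the override above: "++model.image_size=..." uses Hydra's override
    # grammar, where the "++" prefix adds the key or overrides it if it already
    # exists. Sketch of the effect (values illustrative): selecting
    # "512 (balanced)" yields hydra_overrides_extra=["++model.image_size=512"],
    # so build_sam2_video_predictor composes the YAML config with image_size
    # 512 and the rebuilt predictor reports predictor.image_size == 512, which
    # is exactly what the cache check at the top of _ensure_predictor compares.
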
    def _set_frame(self, frame_idx):
        """Set current frame index."""
        self.current_frame_idx = frame_idx
        self._update_frame()

    def eventFilter(self, source, event):
        if source is self.video_scroll.viewport() and event.type() == QEvent.Type.Resize:
            self._position_zoom_buttons()
        return super().eventFilter(source, event)

    def _style_zoom_button(self, btn):
        btn.setFixedSize(34, 34)
        btn.setCursor(Qt.CursorShape.PointingHandCursor)
        btn.setStyleSheet(
            "QPushButton {"
            "background-color: rgba(20, 20, 20, 190);"
            "color: white;"
            "border: 1px solid rgba(255, 255, 255, 120);"
            "border-radius: 17px;"
            "font-size: 18px;"
            "font-weight: bold;"
            "}"
            "QPushButton:hover {"
            "background-color: rgba(45, 45, 45, 220);"
            "}"
        )

    def _position_zoom_buttons(self):
        if not hasattr(self, "video_scroll") or not hasattr(self, "btn_zoom_in"):
            return
        viewport = self.video_scroll.viewport()
        margin = 10
        spacing = 8
        x = viewport.width() - self.btn_zoom_in.width() - margin
        y = margin
        self.btn_zoom_in.move(x, y)
        self.btn_zoom_out.move(x, y + self.btn_zoom_in.height() + spacing)
        self.btn_zoom_in.raise_()
        self.btn_zoom_out.raise_()

    def _zoom_in(self):
        self.zoom_factor = min(self.zoom_max, self.zoom_factor + self.zoom_step)
        self._apply_zoom()

    def _zoom_out(self):
        self.zoom_factor = max(self.zoom_min, self.zoom_factor - self.zoom_step)
        self._apply_zoom()

    def _apply_zoom(self):
        if self._base_display_pixmap is None:
            return
        w = max(1, int(self._base_display_pixmap.width() * self.zoom_factor))
        h = max(1, int(self._base_display_pixmap.height() * self.zoom_factor))
        scaled = self._base_display_pixmap.scaled(
            w,
            h,
            Qt.AspectRatioMode.KeepAspectRatio,
            Qt.TransformationMode.SmoothTransformation,
        )
        self.video_label.setPixmap(scaled)
        self.video_label.resize(scaled.size())

    def _prev_frame(self):
        """Go to previous frame."""
        if self.current_frame_idx > 0:
            self.current_frame_idx -= 1
            self.slider.blockSignals(True)
            self.slider.setValue(self.current_frame_idx)
            self.slider.blockSignals(False)
            self._update_frame()

    def _next_frame(self):
        """Go to next frame."""
        if self.current_frame_idx < self.total_frames - 1:
            self.current_frame_idx += 1
            self.slider.blockSignals(True)
            self.slider.setValue(self.current_frame_idx)
            self.slider.blockSignals(False)
            self._update_frame()

    def _update_frame(self):
        """Update video display with current frame and overlays."""
        if not self.cap:
            return

        self.cap.set(cv2.CAP_PROP_POS_FRAMES, self.current_frame_idx)
        ret, frame = self.cap.read()
        if ret:
            self.frame = frame
            # Keep RGB frame as instance var to prevent QImage memory issues
            self._frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            h, w, ch = self._frame_rgb.shape

            bytes_per_line = ch * w
            q_img = QImage(self._frame_rgb.data, w, h, bytes_per_line, QImage.Format.Format_RGB888)

            pixmap = QPixmap.fromImage(q_img)
            painter = QPainter(pixmap)

            # Draw masks
            if self.current_frame_idx in self.masks:
                frame_masks = self.masks[self.current_frame_idx]
                for obj_id, mask in frame_masks.items():
                    if mask is not None and mask.max() > 0:
                        mask_h, mask_w = mask.shape[:2]
                        if mask_h != h or mask_w != w:
                            mask = cv2.resize(mask.astype(np.float32), (w, h), interpolation=cv2.INTER_NEAREST).astype(np.uint8)

                        overlay = np.zeros((h, w, 4), dtype=np.uint8)
                        color = get_obj_color(obj_id)
                        overlay[mask > 0] = [color[0], color[1], color[2], 100]

                        overlay_img = QImage(overlay.data, w, h, w * 4, QImage.Format.Format_RGBA8888)
                        painter.drawImage(0, 0, overlay_img)

            # Draw points
            for p in self.points:
                if p[3] == self.current_frame_idx:
                    x, y, label, obj_id = p[0], p[1], p[2], p[4]
                    obj_color_rgb = get_obj_color(obj_id)

                    if label == 1:
                        painter.setPen(QPen(QColor(*obj_color_rgb), 5))
                        painter.drawPoint(x, y)
                    else:
                        painter.setPen(QPen(QColor(255, 0, 0), 5))
                        painter.drawPoint(x, y)

                    painter.setPen(QPen(QColor(255, 255, 255), 1))
                    painter.drawText(x + 5, y - 5, str(obj_id))

            painter.end()
            self._base_display_pixmap = pixmap
            self._apply_zoom()

        self.lbl_frame.setText(f"Frame: {self.current_frame_idx} / {self.total_frames}")

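    # Rendering detail for _update_frame: each mask is painted as an RGBA8888
    # overlay whose alpha channel is 100, i.e. roughly 39% opacity (100/255),
    # so the animal stays visible under its colored mask. Positive prompts are
    # drawn in the object's color and negative prompts in red, each labeled
    # with its object ID.
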
    def _handle_click(self, x, y):
        """Handle click on video label."""
        if self.frame is None:
            return

        pixmap = self.video_label.pixmap()
        if not pixmap:
            return

        label_w = self.video_label.width()
        label_h = self.video_label.height()
        pix_w = pixmap.width()
        pix_h = pixmap.height()

        x_offset = (label_w - pix_w) / 2
        y_offset = (label_h - pix_h) / 2

        img_x = x - x_offset
        img_y = y - y_offset

        if 0 <= img_x < pix_w and 0 <= img_y < pix_h:
            orig_h, orig_w = self.frame.shape[:2]
            scale_x = orig_w / pix_w
            scale_y = orig_h / pix_h

            final_x = int(img_x * scale_x)
            final_y = int(img_y * scale_y)

            label = 1 if self.radio_pos.isChecked() else 0
            self.points.append((final_x, final_y, label, self.current_frame_idx, self.current_obj_id))
            self._preview_frame()

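    # Coordinate mapping in _handle_click, worked through: with a 1000x800
    # label showing a 960x540 pixmap of a 1920x1080 video, a click at
    # (500, 300) first subtracts the centering offsets ((1000-960)/2,
    # (800-540)/2) = (20, 130), giving pixmap coords (480, 170), then scales
    # by (1920/960, 1080/540) = (2, 2) to land on frame pixel (960, 340).
    # The numbers are illustrative; the arithmetic is exactly the code above.
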
    def _clear_points(self):
        """Clear all points."""
        self.points = []
        self.masks = {}
        # Clean up any incremental mask temp file
        if getattr(self, '_incremental_mask_file', None) is not None:
            try:
                self._incremental_mask_file.close()
                os.unlink(self._incremental_mask_file.name)
            except Exception:
                pass
            self._incremental_mask_file = None
        # Reset inference state to clear any cached predictions
        if self.inference_state and self.predictor:
            try:
                self.predictor.reset_state(self.inference_state)
            except Exception:
                pass
        self._update_frame()

    def _preview_frame(self):
        """Preview segmentation on current frame."""
        if not self.video_path:
            return

        if not self._ensure_predictor():
            return

        # Gather points for current frame and object
        current_points = []
        current_labels = []
        for x, y, label, frame_idx, obj_id in self.points:
            if frame_idx == self.current_frame_idx and obj_id == self.current_obj_id:
                current_points.append([x, y])
                current_labels.append(label)

        if not current_points:
            return

        # Force a fresh state load for the preview to avoid any stale cached data.
        # This is slower but ensures an accurate preview.
        if not self._load_state(self.current_frame_idx, self.current_frame_idx + 1):
            return

        # Reset any previous tracking state to get a clean prediction
        try:
            self.predictor.reset_state(self.inference_state)
        except Exception:
            pass

        try:
            pts = np.array(current_points, dtype=np.float32)
            lbls = np.array(current_labels, dtype=np.int32)

            local_frame_idx = 0  # We just loaded a single frame, so the local index is 0

            _, out_obj_ids, out_mask_logits = self._sam2_call(
                self.predictor.add_new_points_or_box,
                inference_state=self.inference_state,
                frame_idx=local_frame_idx,
                obj_id=self.current_obj_id,
                points=pts,
                labels=lbls,
                clear_old_points=True,
                normalize_coords=True,
            )

            if self.current_obj_id in out_obj_ids:
                idx = out_obj_ids.index(self.current_obj_id)
                mask_logit = out_mask_logits[idx]
                if mask_logit.ndim == 3:
                    mask_logit = mask_logit[0]
                mask = (mask_logit > self.mask_threshold).cpu().numpy().astype(np.uint8).squeeze()

                if self.current_frame_idx not in self.masks:
                    self.masks[self.current_frame_idx] = {}
                self.masks[self.current_frame_idx][self.current_obj_id] = mask
                self._update_frame()
            else:
                # If the preview did not return this object, clear its stale overlay.
                if self.current_frame_idx in self.masks and self.current_obj_id in self.masks[self.current_frame_idx]:
                    del self.masks[self.current_frame_idx][self.current_obj_id]
                    self._update_frame()
        except Exception as e:
            QMessageBox.critical(self, "Error", f"Preview failed:\n{e}")

    def _ensure_state_for_frame(self, frame_idx):
        """Ensure state is loaded for specific frame."""
        if self.inference_state:
            local_idx = frame_idx - self.state_start_frame
            if 0 <= local_idx < self.inference_state["num_frames"]:
                return True

        return self._load_state(frame_idx, frame_idx + 1)

    def _load_state(self, start_frame, end_frame):
        """Load video frames into SAM2 state."""
        if not self._ensure_predictor():
            return False

        if not self.video_path:
            return False

        try:
            try:
                import decord
            except ImportError:
                QMessageBox.warning(
                    self,
                    "Missing Dependency",
                    "decord not found. Please install it:\n\n"
                    "pip install eva-decord\n\n"
                    "Or: conda install -c conda-forge decord"
                )
                return False

            from collections import OrderedDict

            decord.bridge.set_bridge("torch")
            compute_device = self.predictor.device
            image_size = self.predictor.image_size

            # Get original video dimensions (needed for coordinate normalization)
            vr_meta = decord.VideoReader(self.video_path)
            video_height, video_width, _ = vr_meta[0].shape
            total_frames = len(vr_meta)
            del vr_meta  # Free memory immediately after getting dimensions

            # Load frames at SAM2's internal image_size (square).
            # SAM2 uses original video_height/video_width for coordinate normalization.
            vr = decord.VideoReader(self.video_path, width=image_size, height=image_size)
            target_dtype = getattr(self.predictor, "dtype", torch.float32)
            if self._use_cuda_bf16() and not self.offload_video:
                target_dtype = torch.bfloat16

            if end_frame is None or end_frame > total_frames:
                end_frame = total_frames
            start_frame = max(0, start_frame)

            if start_frame >= end_frame:
                return False

            # Limit number of frames loaded at once to prevent RAM issues
            num_frames = end_frame - start_frame
            if num_frames > self.max_frames_per_load:
                # Only load the last max_frames_per_load frames to keep memory usage reasonable
                start_frame = max(start_frame, end_frame - self.max_frames_per_load)
                if start_frame < 0:
                    start_frame = 0
                QMessageBox.warning(
                    self,
                    "Frame Range Limited",
                    f"Requested {num_frames} frames, but limiting to {self.max_frames_per_load} frames "
                    f"to prevent RAM issues.\n\nLoading frames {start_frame} to {end_frame}."
                )

            indices = list(range(start_frame, end_frame))
            frames = vr.get_batch(indices)
            del vr  # Free VideoReader memory after loading frames
            images = frames.permute(0, 3, 1, 2).float() / 255.0
            del frames  # Free original frame tensor after processing

            # Standard ImageNet normalization constants
            img_mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32)[:, None, None]
            img_std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32)[:, None, None]

            if not self.offload_video:
                images = images.to(compute_device, dtype=target_dtype)
                img_mean = img_mean.to(compute_device, dtype=target_dtype)
                img_std = img_std.to(compute_device, dtype=target_dtype)
            else:
                images = images.to(dtype=target_dtype)
                img_mean = img_mean.to(dtype=target_dtype)
                img_std = img_std.to(dtype=target_dtype)

            images -= img_mean
            images /= img_std

            # Convert to list format expected by SAM2
            images_list = [images[i] for i in range(len(images))]

            inference_state = {}
            inference_state["images"] = images_list
            inference_state["num_frames"] = len(images_list)
            inference_state["offload_video_to_cpu"] = self.offload_video
            inference_state["offload_state_to_cpu"] = self.offload_state
            inference_state["video_height"] = video_height
            inference_state["video_width"] = video_width
            inference_state["device"] = compute_device
            if self.offload_state:
                inference_state["storage_device"] = torch.device("cpu")
            else:
                inference_state["storage_device"] = compute_device

            inference_state["point_inputs_per_obj"] = {}
            inference_state["mask_inputs_per_obj"] = {}
            inference_state["cached_features"] = {}
            inference_state["constants"] = {}
            inference_state["obj_id_to_idx"] = OrderedDict()
            inference_state["obj_idx_to_id"] = OrderedDict()
            inference_state["obj_ids"] = []
            inference_state["output_dict_per_obj"] = {}
            inference_state["temp_output_dict_per_obj"] = {}
            inference_state["frames_tracked_per_obj"] = {}

            # Warm up the feature cache for the first frame; failures here are non-fatal
            try:
                self._sam2_call(self.predictor._get_image_feature, inference_state, frame_idx=0, batch_size=1)
            except Exception:
                pass

            self.inference_state = inference_state
            self.state_start_frame = start_frame
            return True
        except Exception as e:
            QMessageBox.critical(self, "Error", f"Failed to load video state:\n{e}")
            return False

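    # Normalization check for _load_state: the mean/std above are the standard
    # ImageNet statistics, applied per channel after scaling pixels to [0, 1].
    # For example, a red-channel value of 128 becomes
    # (128/255 - 0.485) / 0.229 ~= 0.074. To the best of our knowledge these
    # are the same statistics SAM2's own video loader uses, so states built
    # here should match states built through the library's init_state path.
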
    def _run_tracking(self, from_batch: bool = False):
        """Run tracking on current video."""
        if not self.video_path:
            QMessageBox.warning(self, "No Video", "Please load a video first.")
            return
        if not self.points:
            if from_batch and self.batch_mode and self.batch_queue:
                next_idx = self.batch_queue.pop(0)
                self._apply_video_state(next_idx)
                self._run_tracking(from_batch=True)
                return
            QMessageBox.warning(self, "No Points", "Please add some points first.")
            return

        if not self._ensure_predictor():
            return

        # Group points by frame and object
        points_grouped = {}
        for x, y, label, frame_idx, obj_id in self.points:
            key = (frame_idx, obj_id)
            if key not in points_grouped:
                points_grouped[key] = {'points': [], 'labels': []}
            points_grouped[key]['points'].append([x, y])
            points_grouped[key]['labels'].append(label)

        if not points_grouped:
            if from_batch and self.batch_mode and self.batch_queue:
                next_idx = self.batch_queue.pop(0)
                self._apply_video_state(next_idx)
                self._run_tracking(from_batch=True)
                return
            QMessageBox.warning(self, "Warning", "No points in the selected range.")
            return

        # Determine processing range to include user points
        min_frame = min(k[0] for k in points_grouped.keys())
        max_frame = max(k[0] for k in points_grouped.keys())

        if hasattr(self, 'chk_limit_range') and self.chk_limit_range.isChecked():
            start_f = max(self.spin_start.value(), 0)
            end_f = min(self.spin_end.value() + 1, self.total_frames)
            # Ensure the range covers the annotated points
            start_f = min(start_f, min_frame)
            end_f = max(end_f, max_frame + 1)
        else:
            start_f = 0
            end_f = self.total_frames

        # If resuming, honor the requested resume frame
        if self.resume_from_frame is not None:
            # Resume from the chosen frame (or later), do not jump back to frame 0
            start_f = max(self.resume_from_frame, start_f, 0)
            self.resume_from_frame = None

        self.btn_track.setEnabled(False)
        self.btn_track_all.setEnabled(False)
        self.progress_bar.setVisible(True)
        self.progress_bar.setRange(0, end_f - start_f)
        self.progress_bar.setValue(0)

        # Get initial masks for resume (if any)
        initial_masks = getattr(self, 'resume_initial_masks', {}) or {}

        # Clear resume flags
        self.resume_from_frame = None
        self.resume_initial_masks = {}
        self.tracking_paused = False

        self.tracking_worker = TrackingWorker(
            self.predictor,
            self.video_path,
            points_grouped,
            start_f,
            end_f,
            self.mask_threshold,
            self.offload_video,
            self.offload_state,
            enable_memory_management=self.enable_memory_management,
            reseed_between_chunks=getattr(self, "reseed_between_chunks", False),
            initial_masks=initial_masks,
            enable_motion_tracking=getattr(self, "enable_motion_tracking", False),
            motion_score_threshold=getattr(self, "motion_score_threshold", 0.3),
            motion_consecutive_low=getattr(self, "motion_consecutive_low", 3),
            motion_area_threshold=getattr(self, "motion_area_threshold", 0.5),
            enable_ocsort=getattr(self, "enable_ocsort", False),
            ocsort_inertia=getattr(self, "ocsort_inertia", 0.2),
            use_cuda_bf16_autocast=getattr(self, "use_cuda_bf16_autocast", True),
        )
        self.tracking_worker.progress_signal.connect(lambda x: self.progress_bar.setValue(x - start_f) if x >= start_f else None)
        self.tracking_worker.frame_result_signal.connect(self._on_frame_result)
        self.tracking_worker.finished_signal.connect(self._on_tracking_finished)
        self.tracking_worker.error_signal.connect(self._on_tracking_error)
        self.tracking_worker.log_message.connect(lambda msg: self.log_text.setText(msg))
        self.tracking_worker.start()

        # Enable pause, disable resume while running
        self.btn_pause_tracking.setEnabled(True)
        self.btn_resume_tracking.setEnabled(False)

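    # Signal interface assumed for TrackingWorker (inferred from the
    # connections above, not a definition): progress_signal(int) carries the
    # absolute frame index, frame_result_signal(int, dict) delivers
    # {obj_id: mask} per frame as tracking advances, finished_signal(dict)
    # returns the full masks mapping, error_signal reports failures, and
    # log_message(str) streams status text. Keeping all mask bookkeeping in
    # these slots means the GUI thread owns self.masks exclusively.
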
    def _pause_tracking(self):
        """Pause the current tracking run to allow adding new prompts."""
        if hasattr(self, "tracking_worker") and self.tracking_worker and self.tracking_worker.isRunning():
            self.tracking_paused = True
            self.tracking_worker.stop()
            self.log_text.setText("Stopping tracking... you can add points and resume.")
            self.btn_pause_tracking.setEnabled(False)
            self.btn_resume_tracking.setEnabled(False)

    def _resume_tracking(self):
        """Resume tracking from the current frame (or last processed) after adding prompts."""
        if self.tracking_worker and self.tracking_worker.isRunning():
            return  # already running

        # Choose a resume frame: prefer slider position, fallback to last processed
        resume_frame = self.current_frame_idx if hasattr(self, "current_frame_idx") else None
        if resume_frame is None and self.last_processed_frame is not None:
            resume_frame = self.last_processed_frame
        if resume_frame is None:
            resume_frame = 0

        # Collect refined masks at the resume frame to use as conditioning.
        # This allows the user to refine the mask before resuming.
        initial_masks = {}
        if resume_frame in self.masks:
            for obj_id, mask in self.masks[resume_frame].items():
                if mask is not None and mask.max() > 0:
                    initial_masks[(resume_frame, obj_id)] = mask
                    self.log_text.setText(f"Will use refined mask for object {obj_id} at frame {resume_frame}")

        # Store initial masks for the worker
        self.resume_initial_masks = initial_masks

        self.resume_from_frame = resume_frame  # Start exactly from where user refined
        self.tracking_paused = False

        # Re-run tracking; it will honor resume_from_frame and initial_masks
        self._run_tracking()

    def _run_tracking_all(self):
        """Run tracking sequentially for all loaded videos."""
        if not self.videos:
            QMessageBox.warning(self, "No Videos", "Please load videos first.")
            return
        if not self.sam2_available:
            QMessageBox.warning(
                self,
                "SAM2 not available",
                "SAM2 is not installed in this environment.\n\nRun bash install.sh and reopen the app.",
            )
            return

        # Save current state
        self._save_current_video_state()

        # Build queue of indices
        self.batch_queue = list(range(len(self.videos)))
        self.batch_mode = True

        # Start with the first video in the queue
        next_idx = self.batch_queue.pop(0)
        self._apply_video_state(next_idx)
        self._run_tracking(from_batch=True)

    def _on_frame_result(self, frame_idx, frame_masks):
        """Handle real-time mask updates."""
        if frame_idx not in self.masks:
            self.masks[frame_idx] = {}

        for obj_id, mask in frame_masks.items():
            self.masks[frame_idx][obj_id] = mask

        # Track last processed frame for potential resume
        self.last_processed_frame = frame_idx

        # Incremental save: periodically flush masks to disk to free RAM.
        # Triggered whenever we exceed the threshold, not just at exact multiples.
        if len(self.masks) >= 500:
            self._incremental_save_masks()

        if frame_idx == self.current_frame_idx:
            self._update_frame()

        if self.chk_auto_follow.isChecked():
            self.slider.blockSignals(True)
            self.slider.setValue(frame_idx)
            self.slider.blockSignals(False)
            self.current_frame_idx = frame_idx
            self._update_frame()

    def _on_tracking_finished(self, masks):
        """Handle tracking completion."""
        for frame_idx, frame_masks in masks.items():
            if frame_idx not in self.masks:
                self.masks[frame_idx] = {}
            for obj_id, mask in frame_masks.items():
                self.masks[frame_idx][obj_id] = mask

        self.btn_track.setEnabled(True)
        self.btn_track_all.setEnabled(self.sam2_available and len(self.videos) > 0)
        self.btn_pause_tracking.setEnabled(False)
        self.btn_resume_tracking.setEnabled(self.tracking_paused)
        self.progress_bar.setVisible(False)
        self._update_frame()

        # Persist state for current video
        self._save_current_video_state()
        overlay_path = None
        if self.save_overlay_video:
            overlay_path = self._save_overlay_video(paused=self.tracking_paused)

        # If user paused manually, do not save or show completion popups yet
        if self.tracking_paused:
            if overlay_path:
                self.log_text.setText(
                    "Tracking paused. Partial overlay video saved to:\n"
                    f"{overlay_path}\n"
                    "Add points on current frame, click 'Preview frame', then 'Resume tracking from here'."
                )
            else:
                self.log_text.setText(
                    "Tracking paused. Add points on current frame, click 'Preview frame', then 'Resume tracking from here'."
                )
            self.btn_resume_tracking.setEnabled(True)
            return

        # Save masks automatically
        mask_path = self._save_masks()

        # In batch mode, skip popups to continue processing
        if not self.batch_mode:
            if mask_path and self.video_path:
                # Show completion message with option to go to registration
                overlay_text = f"Overlay video saved to: {overlay_path}\n\n" if overlay_path else ""
                reply = QMessageBox.question(
                    self,
                    "Tracking Completed",
                    "Tracking completed successfully!\n\n"
                    f"Masks saved to: {mask_path}\n\n"
                    f"{overlay_text}"
                    "Would you like to proceed to the Registration tab to process this video?",
                    QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No,
                    QMessageBox.StandardButton.Yes
                )

                if reply == QMessageBox.StandardButton.Yes:
                    # Emit signal to switch to registration tab
                    self.tracking_completed.emit(self.video_path, mask_path)
            else:
                QMessageBox.information(self, "Success", "Tracking completed!")

        self.inference_state = None
        self.tracking_paused = False
        self.resume_from_frame = None

        # Continue batch if pending
        if self.batch_mode and self.batch_queue:
            next_idx = self.batch_queue.pop(0)
            self._apply_video_state(next_idx)
            self._run_tracking(from_batch=True)
            return
        # End batch
        self.batch_mode = False
        self.batch_queue = []

    def _incremental_save_masks(self):
        """Save old masks to temp file and clear from memory to prevent RAM exhaustion."""
        if not self.video_path or len(self.masks) < 500:
            return

        import pickle
        import tempfile

        # Initialize temp file on first call
        if getattr(self, '_incremental_mask_file', None) is None:
            video_basename = os.path.splitext(os.path.basename(self.video_path))[0]
            self._incremental_mask_file = tempfile.NamedTemporaryFile(
                mode='wb', suffix=f'_{video_basename}_masks.pkl', delete=False
            )
            self._incremental_frame_indices = []

        # Get frames to save (oldest 400 frames, keep 100 for display)
        sorted_frames = sorted(self.masks.keys())
        frames_to_save = sorted_frames[:400]

        # Append the chunk to the temp file
        chunk_data = {idx: self.masks[idx] for idx in frames_to_save}
        pickle.dump(chunk_data, self._incremental_mask_file)
        self._incremental_frame_indices.extend(frames_to_save)

        # Clear from memory
        for idx in frames_to_save:
            del self.masks[idx]

        gc.collect()

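    # Persistence pattern used above and read back in _save_masks /
    # _get_masks_snapshot_for_export: repeated pickle.dump() calls append
    # independent pickle streams to one file, and readers loop pickle.load()
    # until EOFError. A minimal standalone sketch of the round trip
    # (illustrative file name):
    #
    #     with open("chunks.pkl", "wb") as f:
    #         pickle.dump({0: "a"}, f)
    #         pickle.dump({1: "b"}, f)
    #     chunks = []
    #     with open("chunks.pkl", "rb") as f:
    #         while True:
    #             try:
    #                 chunks.append(pickle.load(f))
    #             except EOFError:
    #                 break
    #     # chunks == [{0: "a"}, {1: "b"}]
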
    def _get_masks_snapshot_for_export(self):
        """Get a merged mask snapshot including incremental chunks without consuming them."""
        snapshot = {}
        for frame_idx, frame_data in self.masks.items():
            snapshot[frame_idx] = dict(frame_data)

        if getattr(self, "_incremental_mask_file", None) is not None:
            import pickle
            try:
                self._incremental_mask_file.flush()
            except Exception:
                pass
            try:
                with open(self._incremental_mask_file.name, "rb") as f:
                    while True:
                        try:
                            chunk = pickle.load(f)
                            for frame_idx, frame_data in chunk.items():
                                if frame_idx not in snapshot:
                                    snapshot[frame_idx] = dict(frame_data)
                                else:
                                    snapshot[frame_idx].update(frame_data)
                        except EOFError:
                            break
            except Exception as e:
                logger.warning("Could not read incremental masks for overlay export: %s", e)

        return snapshot

def _save_overlay_video(self, paused=False):
|
|
2353
|
+
"""Save overlay video (original frame + colored masks) for inspection."""
|
|
2354
|
+
if not self.video_path:
|
|
2355
|
+
return None
|
|
2356
|
+
|
|
2357
|
+
all_masks = self._get_masks_snapshot_for_export()
|
|
2358
|
+
if not all_masks:
|
|
2359
|
+
return None
|
|
2360
|
+
|
|
2361
|
+
cap = cv2.VideoCapture(self.video_path)
|
|
2362
|
+
if not cap.isOpened():
|
|
2363
|
+
return None
|
|
2364
|
+
|
|
2365
|
+
video_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
|
2366
|
+
video_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
|
2367
|
+
fps = cap.get(cv2.CAP_PROP_FPS)
|
|
2368
|
+
if fps <= 0:
|
|
2369
|
+
fps = 30.0
|
|
2370
|
+
|
|
2371
|
+
save_start = 0
|
|
2372
|
+
save_end = max(all_masks.keys()) if all_masks else 0
|
|
2373
|
+
if hasattr(self, "chk_limit_range") and self.chk_limit_range.isChecked():
|
|
2374
|
+
save_start = self.spin_start.value()
|
|
2375
|
+
save_end = min(self.spin_end.value(), save_end)
|
|
2376
|
+
save_start = max(save_start, 0)
|
|
2377
|
+
save_end = max(save_end, save_start)
|
|
2378
|
+
|
|
2379
|
+
experiment_path = self.config.get("experiment_path")
|
|
2380
|
+
if experiment_path and os.path.exists(experiment_path):
|
|
2381
|
+
out_dir = os.path.join(experiment_path, "overlays")
|
|
2382
|
+
else:
|
|
2383
|
+
from singlebehaviorlab._paths import USER_DATA_DIR
|
|
2384
|
+
out_dir = str(USER_DATA_DIR / "data" / "overlays")
|
|
2385
|
+
os.makedirs(out_dir, exist_ok=True)
|
|
2386
|
+
|
|
2387
|
+
video_basename = os.path.splitext(os.path.basename(self.video_path))[0]
|
|
2388
|
+
if video_basename.endswith("_masks"):
|
|
2389
|
+
video_basename = video_basename[:-6]
|
|
2390
|
+
suffix = "_tracking_overlay_paused.mp4" if paused else "_tracking_overlay.mp4"
|
|
2391
|
+
output_path = os.path.join(out_dir, f"{video_basename}{suffix}")
|
|
2392
|
+
|
|
2393
|
+
writer = cv2.VideoWriter(
|
|
2394
|
+
output_path,
|
|
2395
|
+
cv2.VideoWriter_fourcc(*"mp4v"),
|
|
2396
|
+
fps,
|
|
2397
|
+
(video_width, video_height),
|
|
2398
|
+
)
|
|
2399
|
+
if not writer.isOpened():
|
|
2400
|
+
cap.release()
|
|
2401
|
+
return None
|
|
2402
|
+
|
|
2403
|
+
cap.set(cv2.CAP_PROP_POS_FRAMES, save_start)
|
|
2404
|
+
alpha = 0.35
|
|
2405
|
+
total_to_write = max(0, save_end - save_start + 1)
|
|
2406
|
+
progress = QProgressDialog(
|
|
2407
|
+
"Saving overlay video...",
|
|
2408
|
+
"",
|
|
2409
|
+
0,
|
|
2410
|
+
total_to_write,
|
|
2411
|
+
self
|
|
2412
|
+
)
|
|
2413
|
+
progress.setWindowTitle("Exporting Overlay Video")
|
|
2414
|
+
progress.setWindowModality(Qt.WindowModality.ApplicationModal)
|
|
2415
|
+
progress.setAutoClose(True)
|
|
2416
|
+
progress.setAutoReset(True)
|
|
2417
|
+
progress.setMinimumDuration(0)
|
|
2418
|
+
progress.setCancelButton(None)
|
|
2419
|
+
progress.setValue(0)
|
|
2420
|
+
QApplication.processEvents()
|
|
2421
|
+
|
|
2422
|
+
try:
|
|
2423
|
+
written = 0
|
|
2424
|
+
for frame_idx in range(save_start, save_end + 1):
|
|
2425
|
+
ret, frame = cap.read()
|
|
2426
|
+
if not ret:
|
|
2427
|
+
break
|
|
2428
|
+
|
|
2429
|
+
frame_masks = all_masks.get(frame_idx, {})
|
|
2430
|
+
for obj_id, mask in frame_masks.items():
|
|
2431
|
+
if mask is None or mask.max() == 0:
|
|
2432
|
+
continue
|
|
2433
|
+
|
|
2434
|
+
if mask.shape[0] != video_height or mask.shape[1] != video_width:
|
|
2435
|
+
mask = cv2.resize(
|
|
2436
|
+
mask.astype(np.float32),
|
|
2437
|
+
(video_width, video_height),
|
|
2438
|
+
interpolation=cv2.INTER_NEAREST
|
|
2439
|
+
).astype(np.uint8)
|
|
2440
|
+
|
|
2441
|
+
idx = mask > 0
|
|
2442
|
+
if not np.any(idx):
|
|
2443
|
+
continue
|
|
2444
|
+
|
|
2445
|
+
color_rgb = get_obj_color(obj_id)
|
|
2446
|
+
color_bgr = np.array([color_rgb[2], color_rgb[1], color_rgb[0]], dtype=np.float32)
|
|
2447
|
+
frame_float = frame.astype(np.float32)
|
|
2448
|
+
frame_float[idx] = frame_float[idx] * (1.0 - alpha) + color_bgr * alpha
|
|
2449
|
+
frame = frame_float.astype(np.uint8)
|
|
2450
|
+
|
|
2451
|
+
ys, xs = np.where(idx)
|
|
2452
|
+
if len(xs) > 0 and len(ys) > 0:
|
|
2453
|
+
cx = int(np.mean(xs))
|
|
2454
|
+
cy = int(np.mean(ys))
|
|
2455
|
+
cv2.putText(
|
|
2456
|
+
frame,
|
|
2457
|
+
f"id:{obj_id}",
|
|
2458
|
+
(cx, cy),
|
|
2459
|
+
cv2.FONT_HERSHEY_SIMPLEX,
|
|
2460
|
+
0.45,
|
|
2461
|
+
(255, 255, 255),
|
|
2462
|
+
1,
|
|
2463
|
+
cv2.LINE_AA,
|
|
2464
|
+
)
|
|
2465
|
+
|
|
2466
|
+
writer.write(frame)
|
|
2467
|
+
written += 1
|
|
2468
|
+
if written % 5 == 0 or written == total_to_write:
|
|
2469
|
+
progress.setValue(written)
|
|
2470
|
+
QApplication.processEvents()
|
|
2471
|
+
finally:
|
|
2472
|
+
progress.setValue(total_to_write)
|
|
2473
|
+
writer.release()
|
|
2474
|
+
cap.release()
|
|
2475
|
+
|
|
2476
|
+
return output_path
|
|
2477
|
+
|
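The overlay loop above tints masked pixels with a fixed-alpha blend, out = frame * (1 - alpha) + color * alpha, applied only where the mask is nonzero. A toy NumPy sketch of that per-pixel blend, with illustrative values:

    import numpy as np

    alpha = 0.35
    frame = np.zeros((4, 4, 3), dtype=np.uint8)          # stand-in BGR frame
    mask = np.zeros((4, 4), dtype=np.uint8)
    mask[1:3, 1:3] = 1                                   # 2x2 masked region
    color_bgr = np.array([0, 0, 255], dtype=np.float32)  # red, in BGR order

    idx = mask > 0
    blended = frame.astype(np.float32)
    blended[idx] = blended[idx] * (1.0 - alpha) + color_bgr * alpha
    frame = blended.astype(np.uint8)                     # masked pixels now ~(0, 0, 89)

Blending in float32 and converting back to uint8 at the end avoids the wrap-around you would get from mixing directly in uint8.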
+    def _save_masks(self):
+        """Save masks in format compatible with animal_registration app."""
+        if not self.video_path:
+            return None
+
+        import cv2
+        import pickle
+
+        # Merge incremental saves back into self.masks
+        if hasattr(self, '_incremental_mask_file') and self._incremental_mask_file is not None:
+            self._incremental_mask_file.close()
+            try:
+                with open(self._incremental_mask_file.name, 'rb') as f:
+                    while True:
+                        try:
+                            chunk = pickle.load(f)
+                            for frame_idx, frame_data in chunk.items():
+                                if frame_idx not in self.masks:
+                                    self.masks[frame_idx] = frame_data
+                        except EOFError:
+                            break
+                # Clean up temp file
+                os.unlink(self._incremental_mask_file.name)
+            except Exception as e:
+                logger.warning("Could not load incremental masks: %s", e)
+            finally:
+                self._incremental_mask_file = None
+                self._incremental_frame_indices = []
+
+        if not self.masks:
+            return None
+
+        # Get video dimensions
+        cap = cv2.VideoCapture(self.video_path)
+        video_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+        video_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+        fps = cap.get(cv2.CAP_PROP_FPS)
+        cap.release()
+
+        # Determine range to save (respect limit range if set)
+        save_start = 0
+        save_end = max(self.masks.keys()) if self.masks else 0
+        if hasattr(self, "chk_limit_range") and self.chk_limit_range.isChecked():
+            save_start = self.spin_start.value()
+            save_end = self.spin_end.value()
+        save_start = max(save_start, 0)
+        save_end = max(save_end, save_start)
+
+        frame_objects = []
+        num_frames = (save_end - save_start + 1) if self.masks else 0
+
+        for frame_idx_global in range(save_start, save_end + 1):
+            frame_objs = []
+            if frame_idx_global in self.masks:
+                for obj_id, mask in self.masks[frame_idx_global].items():
+                    if mask is not None and mask.max() > 0:
+                        # Resize mask to video dimensions if needed
+                        if mask.shape[0] != video_height or mask.shape[1] != video_width:
+                            mask_resized = cv2.resize(
+                                mask.astype(np.float32),
+                                (video_width, video_height),
+                                interpolation=cv2.INTER_NEAREST
+                            ).astype(np.uint8)
+                        else:
+                            mask_resized = mask
+
+                        # Find bounding box
+                        rows, cols = np.where(mask_resized > 0)
+                        if len(rows) > 0 and len(cols) > 0:
+                            y_min, y_max = np.min(rows), np.max(rows)
+                            x_min, x_max = np.min(cols), np.max(cols)
+
+                            # Extract mask within bbox
+                            bbox_mask = mask_resized[y_min:y_max+1, x_min:x_max+1]
+
+                            obj = {
+                                'bbox': (int(x_min), int(y_min), int(x_max), int(y_max)),
+                                'mask': bbox_mask.astype(bool),
+                                'obj_id': int(obj_id)
+                            }
+                            frame_objs.append(obj)
+            frame_objects.append(frame_objs)
+
+        # Create mask data dictionary
+        mask_data = {
+            'video_path': self.video_path,
+            'total_frames': num_frames,
+            'height': video_height,
+            'width': video_width,
+            'fps': fps,
+            'frame_objects': frame_objects,
+            'objects_per_frame': [len(objs) for objs in frame_objects],
+            'tracker': {},
+            'format': 'new',
+            'start_offset': save_start,
+            'original_total_frames': self.total_frames
+        }
+
+        # Save to HDF5 file - use experiment folder if available
+        experiment_path = self.config.get("experiment_path")
+        if experiment_path and os.path.exists(experiment_path):
+            masks_dir = os.path.join(experiment_path, "masks")
+        else:
+            from singlebehaviorlab._paths import USER_DATA_DIR
+            masks_dir = str(USER_DATA_DIR / "data" / "masks")
+        os.makedirs(masks_dir, exist_ok=True)
+
+        video_basename = os.path.splitext(os.path.basename(self.video_path))[0]
+        # Remove "_masks" suffix if present to avoid duplication
+        if video_basename.endswith("_masks"):
+            video_basename = video_basename[:-6]
+        mask_path = os.path.join(masks_dir, f"{video_basename}.h5")
+
+        from singlebehaviorlab.backend.video_processor import save_segmentation_data
+        save_segmentation_data(mask_path, mask_data)
+        return mask_path
+
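Each entry in frame_objects above stores only a bounding box plus a bbox-cropped boolean mask, which keeps the saved file small compared to full-frame masks. Judging by the `y_max+1` / `x_max+1` slicing, the bbox coordinates are inclusive. A hypothetical consumer-side helper (expand_object_mask is not part of the package) that pastes one crop back into a full-size frame:

    import numpy as np

    def expand_object_mask(obj, height, width):
        """Rebuild a full-frame boolean mask from one frame_objects entry."""
        full = np.zeros((height, width), dtype=bool)
        x_min, y_min, x_max, y_max = obj['bbox']   # inclusive coordinates
        full[y_min:y_max + 1, x_min:x_max + 1] = obj['mask']
        return full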
+    def _on_tracking_error(self, err):
+        """Handle tracking error."""
+        self.btn_track.setEnabled(True)
+        self.btn_track_all.setEnabled(self.sam2_available and len(self.videos) > 0)
+        self.btn_pause_tracking.setEnabled(False)
+        self.btn_resume_tracking.setEnabled(False)
+        self.progress_bar.setVisible(False)
+        QMessageBox.critical(self, "Error", f"Tracking failed:\n{err}")
+        self.inference_state = None
+        # Stop batch mode on error
+        self.batch_mode = False
+        self.batch_queue = []
+
+    def _open_settings(self):
+        """Open settings dialog."""
+        from PyQt6.QtWidgets import QDialog, QFormLayout
+
+        dialog = QDialog(self)
+        dialog.setWindowTitle("SAM2 Settings")
+        dialog.resize(450, 550)
+        layout = QFormLayout(dialog)
+
+        spin_threshold = QDoubleSpinBox()
+        spin_threshold.setRange(-10.0, 10.0)
+        spin_threshold.setSingleStep(0.1)
+        spin_threshold.setValue(self.mask_threshold)
+        layout.addRow("Mask Threshold:", spin_threshold)
+
+        spin_fill_hole = QSpinBox()
+        spin_fill_hole.setRange(0, 10000)
+        spin_fill_hole.setValue(self.fill_hole_area)
+        layout.addRow("Fill Hole Area:", spin_fill_hole)
+
+        chk_offload_video = QCheckBox()
+        chk_offload_video.setChecked(self.offload_video)
+        layout.addRow("Offload Video to CPU:", chk_offload_video)
+
+        chk_offload_state = QCheckBox()
+        chk_offload_state.setChecked(self.offload_state)
+        layout.addRow("Offload State to CPU:", chk_offload_state)
+
+        chk_bf16_autocast = QCheckBox()
+        chk_bf16_autocast.setChecked(self.use_cuda_bf16_autocast)
+        chk_bf16_autocast.setToolTip(
+            "Use CUDA bfloat16 autocast for SAM2 inference.\n"
+            "Usually speeds up segmentation on newer NVIDIA GPUs.\n"
+            "Ignored on CPU."
+        )
+        layout.addRow("Use CUDA bf16 autocast:", chk_bf16_autocast)
+
+        chk_memory_management = QCheckBox()
+        chk_memory_management.setChecked(self.enable_memory_management)
+        layout.addRow("Enable Memory Management:", chk_memory_management)
+
+        chk_reseed_chunks = QCheckBox()
+        chk_reseed_chunks.setToolTip(
+            "When processing in chunks, re-seed the next chunk with the "
+            "last frame mask as a mask prompt."
+        )
+        chk_reseed_chunks.setChecked(getattr(self, "reseed_between_chunks", False))
+        layout.addRow("Re-seed each chunk with last mask:", chk_reseed_chunks)
+
+        layout.addRow(QLabel("<b>Motion-Aware Tracking</b>"))
+
+        chk_motion_tracking = QCheckBox()
+        chk_motion_tracking.setToolTip(
+            "Enable motion-aware tracking:\n"
+            "- Uses Kalman filter to predict object motion\n"
+            "- Scores each frame by mask quality and motion consistency\n"
+            "- Filters low-quality frames from memory to prevent drift\n"
+            "Requires: pip install filterpy"
+        )
+        chk_motion_tracking.setChecked(getattr(self, "enable_motion_tracking", False))
+        layout.addRow("Enable motion-aware tracking:", chk_motion_tracking)
+
+        spin_motion_threshold = QDoubleSpinBox()
+        spin_motion_threshold.setRange(0.0, 1.0)
+        spin_motion_threshold.setSingleStep(0.05)
+        spin_motion_threshold.setValue(getattr(self, "motion_score_threshold", 0.3))
+        spin_motion_threshold.setToolTip(
+            "Minimum score for a frame to be used in memory.\n"
+            "Lower = more permissive, higher = stricter filtering.\n"
+            "Score combines mask confidence and motion IoU."
+        )
+        layout.addRow("Motion score threshold:", spin_motion_threshold)
+
+        spin_consecutive_low = QSpinBox()
+        spin_consecutive_low.setRange(1, 20)
+        spin_consecutive_low.setValue(getattr(self, "motion_consecutive_low", 3))
+        spin_consecutive_low.setToolTip(
+            "Number of consecutive low-score frames before auto-correction.\n"
+            "Lower = faster correction but more sensitive.\n"
+            "Higher = more tolerant but slower to correct drift."
+        )
+        layout.addRow("Frames before auto-correct:", spin_consecutive_low)
+
+        spin_area_threshold = QDoubleSpinBox()
+        spin_area_threshold.setRange(0.1, 2.0)
+        spin_area_threshold.setSingleStep(0.1)
+        spin_area_threshold.setValue(getattr(self, "motion_area_threshold", 0.5))
+        spin_area_threshold.setToolTip(
+            "Max allowed mask area change ratio.\n"
+            "0.5 = mask can shrink/grow by 50% max.\n"
+            "Lower = stricter, higher = more permissive."
+        )
+        layout.addRow("Area change tolerance:", spin_area_threshold)
+
+        layout.addRow(QLabel("<b>OC-SORT Drift Correction</b>"))
+
+        chk_ocsort = QCheckBox()
+        chk_ocsort.setToolTip(
+            "Enable OC-SORT enhancements for drift correction:\n\n"
+            "- Virtual Trajectory: During occlusions, maintains tracking\n"
+            "  using predicted motion (prevents state collapse)\n\n"
+            "- ORU (Observation-Centric Re-Update): When object reappears,\n"
+            "  corrects accumulated drift by re-estimating past states\n\n"
+            "Based on: 'Observation-Centric SORT' (arXiv:2203.14360)"
+        )
+        chk_ocsort.setChecked(getattr(self, "enable_ocsort", False))
+        layout.addRow("Enable OC-SORT drift correction:", chk_ocsort)
+
+        spin_ocsort_inertia = QDoubleSpinBox()
+        spin_ocsort_inertia.setRange(0.0, 1.0)
+        spin_ocsort_inertia.setSingleStep(0.05)
+        spin_ocsort_inertia.setValue(getattr(self, "ocsort_inertia", 0.2))
+        spin_ocsort_inertia.setToolTip(
+            "Velocity smoothing factor for ORU (paper default: 0.2).\n\n"
+            "When an object reappears after occlusion, this blends the\n"
+            "old velocity with the newly computed velocity:\n"
+            "  smoothed = inertia * old_vel + (1 - inertia) * new_vel\n\n"
+            "Higher = more momentum, smoother but slower correction.\n"
+            "Lower = faster correction but may be jerky."
+        )
+        layout.addRow("ORU inertia (velocity smoothing):", spin_ocsort_inertia)
+
+        btn_ok = QPushButton("OK")
+        btn_ok.clicked.connect(dialog.accept)
+        layout.addRow(btn_ok)
+
+        if dialog.exec():
+            self.mask_threshold = spin_threshold.value()
+            self.fill_hole_area = spin_fill_hole.value()
+            self.offload_video = chk_offload_video.isChecked()
+            self.offload_state = chk_offload_state.isChecked()
+            self.use_cuda_bf16_autocast = chk_bf16_autocast.isChecked()
+            self.enable_memory_management = chk_memory_management.isChecked()
+            self.reseed_between_chunks = chk_reseed_chunks.isChecked()
+            self.enable_motion_tracking = chk_motion_tracking.isChecked()
+            self.motion_score_threshold = spin_motion_threshold.value()
+            self.motion_consecutive_low = spin_consecutive_low.value()
+            self.motion_area_threshold = spin_area_threshold.value()
+            self.enable_ocsort = chk_ocsort.isChecked()
+            self.ocsort_inertia = spin_ocsort_inertia.value()
+
+            if self.predictor:
+                self.predictor.fill_hole_area = self.fill_hole_area
+
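The ORU inertia tooltip above describes a simple exponential blend of the stale pre-occlusion velocity and the velocity re-estimated at reappearance. A minimal sketch of that formula (smooth_velocity is illustrative, not the package's implementation):

    # smoothed = inertia * old_vel + (1 - inertia) * new_vel, per component
    def smooth_velocity(old_vel, new_vel, inertia=0.2):
        return tuple(inertia * o + (1.0 - inertia) * n
                     for o, n in zip(old_vel, new_vel))

    smooth_velocity((4.0, 0.0), (0.0, 2.0))  # -> (0.8, 1.6)

With the paper-default inertia of 0.2, the re-estimated velocity dominates, so drift is corrected quickly at the cost of some jitter; raising inertia trades correction speed for smoothness, exactly as the tooltip says.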
+    def update_config(self, config: dict):
+        """Update configuration."""
+        self.config = config
+