singlebehaviorlab-2.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sam2/__init__.py +11 -0
- sam2/automatic_mask_generator.py +454 -0
- sam2/benchmark.py +92 -0
- sam2/build_sam.py +174 -0
- sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
- sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
- sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
- sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
- sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
- sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
- sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
- sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
- sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
- sam2/modeling/__init__.py +5 -0
- sam2/modeling/backbones/__init__.py +5 -0
- sam2/modeling/backbones/hieradet.py +317 -0
- sam2/modeling/backbones/image_encoder.py +134 -0
- sam2/modeling/backbones/utils.py +93 -0
- sam2/modeling/memory_attention.py +169 -0
- sam2/modeling/memory_encoder.py +181 -0
- sam2/modeling/position_encoding.py +239 -0
- sam2/modeling/sam/__init__.py +5 -0
- sam2/modeling/sam/mask_decoder.py +295 -0
- sam2/modeling/sam/prompt_encoder.py +202 -0
- sam2/modeling/sam/transformer.py +311 -0
- sam2/modeling/sam2_base.py +913 -0
- sam2/modeling/sam2_utils.py +323 -0
- sam2/sam2_hiera_b+.yaml +113 -0
- sam2/sam2_hiera_l.yaml +117 -0
- sam2/sam2_hiera_s.yaml +116 -0
- sam2/sam2_hiera_t.yaml +118 -0
- sam2/sam2_image_predictor.py +466 -0
- sam2/sam2_video_predictor.py +1388 -0
- sam2/sam2_video_predictor_legacy.py +1172 -0
- sam2/utils/__init__.py +5 -0
- sam2/utils/amg.py +348 -0
- sam2/utils/misc.py +349 -0
- sam2/utils/transforms.py +118 -0
- singlebehaviorlab/__init__.py +4 -0
- singlebehaviorlab/__main__.py +130 -0
- singlebehaviorlab/_paths.py +100 -0
- singlebehaviorlab/backend/__init__.py +2 -0
- singlebehaviorlab/backend/augmentations.py +320 -0
- singlebehaviorlab/backend/data_store.py +420 -0
- singlebehaviorlab/backend/model.py +1290 -0
- singlebehaviorlab/backend/train.py +4667 -0
- singlebehaviorlab/backend/uncertainty.py +578 -0
- singlebehaviorlab/backend/video_processor.py +688 -0
- singlebehaviorlab/backend/video_utils.py +139 -0
- singlebehaviorlab/data/config/config.yaml +85 -0
- singlebehaviorlab/data/training_profiles.json +334 -0
- singlebehaviorlab/gui/__init__.py +4 -0
- singlebehaviorlab/gui/analysis_widget.py +2291 -0
- singlebehaviorlab/gui/attention_export.py +311 -0
- singlebehaviorlab/gui/clip_extraction_widget.py +481 -0
- singlebehaviorlab/gui/clustering_widget.py +3187 -0
- singlebehaviorlab/gui/inference_popups.py +1138 -0
- singlebehaviorlab/gui/inference_widget.py +4550 -0
- singlebehaviorlab/gui/inference_worker.py +651 -0
- singlebehaviorlab/gui/labeling_widget.py +2324 -0
- singlebehaviorlab/gui/main_window.py +754 -0
- singlebehaviorlab/gui/metadata_management_widget.py +1119 -0
- singlebehaviorlab/gui/motion_tracking.py +764 -0
- singlebehaviorlab/gui/overlay_export.py +1234 -0
- singlebehaviorlab/gui/plot_integration.py +729 -0
- singlebehaviorlab/gui/qt_helpers.py +29 -0
- singlebehaviorlab/gui/registration_widget.py +1485 -0
- singlebehaviorlab/gui/review_widget.py +1330 -0
- singlebehaviorlab/gui/segmentation_tracking_widget.py +2752 -0
- singlebehaviorlab/gui/tab_tutorial_dialog.py +312 -0
- singlebehaviorlab/gui/timeline_themes.py +131 -0
- singlebehaviorlab/gui/training_profiles.py +418 -0
- singlebehaviorlab/gui/training_widget.py +3719 -0
- singlebehaviorlab/gui/video_utils.py +233 -0
- singlebehaviorlab/licenses/SAM2-LICENSE +201 -0
- singlebehaviorlab/licenses/VideoPrism-LICENSE +202 -0
- singlebehaviorlab-2.0.0.dist-info/METADATA +447 -0
- singlebehaviorlab-2.0.0.dist-info/RECORD +88 -0
- singlebehaviorlab-2.0.0.dist-info/WHEEL +5 -0
- singlebehaviorlab-2.0.0.dist-info/entry_points.txt +2 -0
- singlebehaviorlab-2.0.0.dist-info/licenses/LICENSE +21 -0
- singlebehaviorlab-2.0.0.dist-info/top_level.txt +3 -0
- videoprism/__init__.py +0 -0
- videoprism/encoders.py +910 -0
- videoprism/layers.py +1136 -0
- videoprism/models.py +407 -0
- videoprism/tokenizers.py +167 -0
- videoprism/utils.py +168 -0
singlebehaviorlab/gui/inference_worker.py
@@ -0,0 +1,651 @@
"""Inference worker thread: runs model on video(s) and emits results."""
from PyQt6.QtCore import QThread, pyqtSignal
import cv2
import os
import torch
import numpy as np


def _sanitize_bbox_coords(x1, y1, x2, y2, crop_padding=0.35, crop_min_size=0.04):
    """Pad bbox proportionally to its own size (matches training _sanitize_bboxes)."""
    x1, y1 = max(0.0, min(1.0, x1)), max(0.0, min(1.0, y1))
    x2, y2 = max(0.0, min(1.0, x2)), max(0.0, min(1.0, y2))
    cx, cy = 0.5 * (x1 + x2), 0.5 * (y1 + y2)
    w = min(max(abs(x2 - x1), crop_min_size) * (1.0 + 2.0 * crop_padding), 1.0)
    h = min(max(abs(y2 - y1), crop_min_size) * (1.0 + 2.0 * crop_padding), 1.0)
    x1, y1 = max(0.0, cx - 0.5 * w), max(0.0, cy - 0.5 * h)
    x2, y2 = min(1.0, cx + 0.5 * w), min(1.0, cy + 0.5 * h)
    x2, y2 = max(x2, x1 + crop_min_size), max(y2, y1 + crop_min_size)
    return (max(0.0, min(1.0, x1)), max(0.0, min(1.0, y1)),
            max(0.0, min(1.0, x2)), max(0.0, min(1.0, y2)))

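# Worked example (illustrative only, not part of the shipped module): with the
# defaults crop_padding=0.35 and crop_min_size=0.04, each side of the box grows
# by a factor of 1 + 2 * 0.35 = 1.7 around its center. A normalized box
# (0.40, 0.30, 0.55, 0.42) gives
#   w = max(0.15, 0.04) * 1.7 = 0.255,  h = max(0.12, 0.04) * 1.7 = 0.204
#   -> sanitized box ~ (0.3475, 0.258, 0.6025, 0.462)
# so the crop stays centered on the detection with roughly 35% extra context
# on each side, and never collapses below crop_min_size.
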
class InferenceWorker(QThread):
    """Worker thread for running inference."""
    progress = pyqtSignal(int, int)
    finished = pyqtSignal(dict)
    error = pyqtSignal(str)
    log_message = pyqtSignal(str)
    video_done = pyqtSignal(str, dict)

    def __init__(
        self,
        model,
        video_paths,
        target_fps,
        clip_length,
        step_frames,
        resolution=288,
        classes=None,
        use_localization_pipeline=False,
        crop_padding=0.35,
        crop_min_size=0.04,
        use_ovr=False,
        ovr_temperatures=None,
        collect_attention=False,
        sample_ranges_by_video=None,
    ):
        super().__init__()
        self.model = model
        self.video_paths = video_paths if isinstance(video_paths, list) else [video_paths]
        self.target_fps = target_fps
        self.clip_length = clip_length
        self.step_frames = step_frames
        self.resolution = int(resolution)
        self.classes = classes or []
        self.use_localization_pipeline = bool(use_localization_pipeline)
        self.crop_padding = float(crop_padding)
        self.crop_min_size = float(crop_min_size)
        self.use_ovr = bool(use_ovr)
        self._ovr_temperatures = dict(ovr_temperatures or {})
        self.collect_attention = bool(collect_attention)
        self.sample_ranges_by_video = dict(sample_ranges_by_video or {})
        self.should_stop = False

    def _apply_ovr_temperature(self, logits: torch.Tensor) -> torch.Tensor:
        """Scale logits by per-class calibrated temperature before sigmoid.

        logits: [..., C] tensor. Divides each class dimension by its
        temperature (default 1.0 = no change).
        """
        if not self._ovr_temperatures or not self.classes:
            return logits
        C = logits.shape[-1]
        temps = torch.ones(C, device=logits.device, dtype=logits.dtype)
        for ci, cls_name in enumerate(self.classes):
            if ci < C and cls_name in self._ovr_temperatures:
                t = float(self._ovr_temperatures[cls_name])
                temps[ci] = max(t, 0.01)
        return logits / temps

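    # Illustrative numbers (not taken from the package): temperature scaling
    # divides a class's logit before the sigmoid, so with a calibrated
    # temperature of 2.0 a raw logit of 3.0 becomes 1.5, softening
    # sigmoid(3.0) ~= 0.95 to sigmoid(1.5) ~= 0.82; a temperature below 1.0
    # sharpens scores instead. Classes missing from _ovr_temperatures keep
    # the neutral temperature of 1.0.
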
    def stop(self):
        self.should_stop = True

    def _build_center_merge_weights(self, length: int) -> np.ndarray:
        """Center-heavy temporal weights for overlap aggregation.

        Clip-edge predictions tend to be less reliable because they have less
        temporal context. Use a Hann-shaped profile with a small floor so edge
        frames still contribute when only one clip covers them.
        """
        if length <= 1:
            return np.ones((max(1, length),), dtype=np.float32)
        w = np.hanning(length).astype(np.float32)
        if not np.any(w > 0):
            return np.ones((length,), dtype=np.float32)
        return np.clip(0.1 + 0.9 * w, 1e-3, None).astype(np.float32)

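    # Example of the resulting profile (illustrative): for length=8,
    # np.hanning(8) is about [0.00, 0.19, 0.61, 0.95, 0.95, 0.61, 0.19, 0.00],
    # and 0.1 + 0.9 * w maps that to roughly
    # [0.10, 0.27, 0.65, 0.96, 0.96, 0.65, 0.27, 0.10],
    # so center frames get nearly 10x the weight of edge frames while edge
    # frames still contribute a non-zero amount.
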
    def _iter_clips_prefetched(self, video_path, target_fps, clip_length, step_frames, chunk_size=64):
        from threading import Thread
        from queue import Queue
        q = Queue(maxsize=2)

        def _producer():
            try:
                for chunk in self._iter_clips_in_chunks(video_path, target_fps, clip_length, step_frames, chunk_size):
                    q.put(chunk)
                    if self.should_stop:
                        break
            finally:
                q.put(None)

        t = Thread(target=_producer, daemon=True)
        t.start()
        while True:
            item = q.get()
            if item is None:
                break
            yield item
        t.join(timeout=2.0)

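    # Design note (descriptive): the producer thread decodes and preprocesses
    # the next chunk with OpenCV while the caller runs the model on the current
    # one, and Queue(maxsize=2) bounds the backlog so at most two decoded
    # chunks (64 clips each by default) wait in RAM at any time. The None
    # sentinel put in the finally block guarantees the consumer loop terminates
    # even if the producer raises.
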
    def _normalize_sample_ranges(self, sample_ranges, total_frames):
        """Clamp and sort sample ranges as [(start, end), ...] with end-exclusive bounds."""
        if not sample_ranges:
            return []
        normalized = []
        for rng in sample_ranges:
            if not isinstance(rng, (list, tuple)) or len(rng) != 2:
                continue
            start, end = int(rng[0]), int(rng[1])
            start = max(0, min(total_frames, start))
            end = max(start, min(total_frames, end))
            if end > start:
                normalized.append((start, end))
        normalized.sort()
        merged = []
        for start, end in normalized:
            if not merged or start > merged[-1][1]:
                merged.append([start, end])
            else:
                merged[-1][1] = max(merged[-1][1], end)
        return [(start, end) for start, end in merged]

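    # Illustrative call (not part of the module): for a 1000-frame video,
    #   _normalize_sample_ranges([(900, 1200), (100, 50), (200, 400), (350, 500)], 1000)
    # drops the empty (100, 50) range, clamps (900, 1200) to (900, 1000), and
    # merges the overlapping (200, 400) and (350, 500) into (200, 500),
    # returning [(200, 500), (900, 1000)].
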
    def _estimate_subsampled_frames_in_ranges(self, sample_ranges, frame_interval):
        total = 0
        for start, end in sample_ranges:
            first = start if start % frame_interval == 0 else start + (frame_interval - (start % frame_interval))
            if first >= end:
                continue
            total += 1 + ((end - 1 - first) // frame_interval)
        return total

    def _iter_clips_in_chunks(self, video_path, target_fps, clip_length, step_frames, chunk_size=64):
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise ValueError(f"Could not open video: {video_path}")

        orig_fps = cap.get(cv2.CAP_PROP_FPS)
        if orig_fps <= 0:
            orig_fps = 30.0

        frame_interval = max(1, int(round(orig_fps / target_fps)))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        sample_ranges = self._normalize_sample_ranges(
            self.sample_ranges_by_video.get(video_path),
            total_frames,
        )
        if sample_ranges:
            total_subsampled = self._estimate_subsampled_frames_in_ranges(sample_ranges, frame_interval)
        else:
            total_subsampled = total_frames // frame_interval
        if total_subsampled < clip_length:
            total_clips_estimate = 0
        elif step_frames <= 0:
            total_clips_estimate = max(1, total_subsampled // max(1, clip_length))
        else:
            total_clips_estimate = 1 + max(0, (total_subsampled - clip_length) // step_frames)
        total_clips_estimate = max(1, total_clips_estimate)

        clip_starts = []
        frames_buffer = []
        frames_buffer_fullres = []
        keep_fullres_cache = self.use_localization_pipeline and getattr(self.model, "use_localization", False)
        chunk_clips = []
        chunk_starts = []
        chunk_fullres = []
        clip_idx = 0

        def _iter_ranges():
            if sample_ranges:
                for start, end in sample_ranges:
                    yield start, end
            else:
                yield 0, total_frames

        for range_start, range_end in _iter_ranges():
            if self.should_stop:
                break
            cap.set(cv2.CAP_PROP_POS_FRAMES, range_start)
            frames_buffer = []
            frames_buffer_fullres = []
            skip_remaining = 0
            frame_idx = range_start

            while frame_idx < range_end:
                if self.should_stop:
                    break
                ret, frame = cap.read()
                if not ret:
                    break
                if frame_idx % frame_interval == 0:
                    if skip_remaining > 0:
                        skip_remaining -= 1
                        frame_idx += 1
                        continue
                    if keep_fullres_cache:
                        frames_buffer_fullres.append(frame.copy())
                    h_src, w_src = frame.shape[:2]
                    is_upscale = (w_src < self.resolution) or (h_src < self.resolution)
                    interp = cv2.INTER_LANCZOS4 if is_upscale else cv2.INTER_AREA
                    frame_resized = cv2.resize(frame, (self.resolution, self.resolution), interpolation=interp)
                    if is_upscale:
                        blurred = cv2.GaussianBlur(frame_resized, (0, 0), sigmaX=1.0)
                        frame_resized = cv2.addWeighted(frame_resized, 1.5, blurred, -0.5, 0)
                    frame_resized = cv2.cvtColor(frame_resized, cv2.COLOR_BGR2RGB)
                    frames_buffer.append(frame_resized)
                    if len(frames_buffer) == clip_length:
                        clip_array = np.stack(frames_buffer).astype(np.float32) / 255.0
                        clip_start = frame_idx - (clip_length - 1) * frame_interval
                        clip_starts.append(clip_start)
                        chunk_clips.append(clip_array)
                        chunk_starts.append(clip_start)
                        if keep_fullres_cache:
                            chunk_fullres.append(np.stack(frames_buffer_fullres))
                        clip_idx += 1
                        if self.progress:
                            self.progress.emit(clip_idx, total_clips_estimate)
                        if step_frames < clip_length:
                            frames_buffer = frames_buffer[step_frames:]
                            if keep_fullres_cache:
                                frames_buffer_fullres = frames_buffer_fullres[step_frames:]
                        else:
                            frames_buffer = []
                            frames_buffer_fullres = []
                            skip_remaining = max(0, step_frames - clip_length)
                        if len(chunk_clips) >= max(1, int(chunk_size)):
                            yield chunk_clips, chunk_starts, chunk_fullres
                            chunk_clips = []
                            chunk_starts = []
                            chunk_fullres = []
                frame_idx += 1
        cap.release()
        if chunk_clips:
            yield chunk_clips, chunk_starts, chunk_fullres

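    # Worked example of the windowing arithmetic (illustrative): for a 30 fps
    # video with target_fps=10, frame_interval = 3, so a clip with
    # clip_length=16 covers 16 sampled frames spanning 46 original frames, and
    # clip_start = frame_idx - (16 - 1) * 3 recovers the original index of the
    # clip's first frame. With step_frames=8 the buffer keeps the last 8
    # sampled frames so consecutive clips overlap by half; with step_frames=16
    # clips are back to back; with step_frames>16 the surplus sampled frames
    # are dropped via skip_remaining.
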
    def _predict_localization(self, clips, device, batch_size=4):
        loc_bboxes = []
        for i in range(0, len(clips), batch_size):
            if self.should_stop:
                break
            batch_clips = clips[i:i+batch_size]
            batch_tensors = []
            for clip in batch_clips:
                clip_tensor = torch.from_numpy(clip).permute(0, 3, 1, 2)
                batch_tensors.append(clip_tensor)
            batch = torch.stack(batch_tensors).to(device)
            with torch.no_grad():
                out = self.model(batch, return_localization=True)
            del batch
            loc_part = None
            if isinstance(out, tuple) and len(out) >= 2 and torch.is_tensor(out[-1]) and out[-1].shape[-1] == 4:
                loc_part = out[-1]
            if loc_part is None:
                loc_part = torch.tensor([[0.0, 0.0, 1.0, 1.0]] * len(batch_clips), device=device)
            loc_np = loc_part.detach().cpu().numpy()
            if loc_part.dim() == 3:
                loc_bboxes.extend(loc_np.tolist())
            else:
                loc_bboxes.extend(loc_np.tolist())
        return loc_bboxes

    def _build_refined_clips(self, fullres_clip_frames, loc_bboxes):
        refined = []
        for idx, frames_bgr in enumerate(fullres_clip_frames):
            bbox = loc_bboxes[idx] if idx < len(loc_bboxes) else [0.0, 0.0, 1.0, 1.0]
            has_temporal_boxes = (
                isinstance(bbox, (list, tuple))
                and len(bbox) > 0
                and isinstance(bbox[0], (list, tuple))
                and len(bbox[0]) == 4
            )
            if has_temporal_boxes:
                b0 = [float(v) for v in bbox[0]]
                x1_fixed, y1_fixed, x2_fixed, y2_fixed = _sanitize_bbox_coords(*b0, self.crop_padding, self.crop_min_size)
            clip_frames = []
            for fi, frame_bgr in enumerate(frames_bgr):
                if has_temporal_boxes:
                    x1, y1, x2, y2 = x1_fixed, y1_fixed, x2_fixed, y2_fixed
                else:
                    x1, y1, x2, y2 = _sanitize_bbox_coords(*[float(v) for v in bbox], self.crop_padding, self.crop_min_size)
                h, w = frame_bgr.shape[:2]
                fx1 = int(round(x1 * w))
                fy1 = int(round(y1 * h))
                fx2 = int(round(x2 * w))
                fy2 = int(round(y2 * h))
                fx1 = max(0, min(fx1, w - 1))
                fy1 = max(0, min(fy1, h - 1))
                fx2 = max(fx1 + 1, min(fx2, w))
                fy2 = max(fy1 + 1, min(fy2, h))
                crop = frame_bgr[fy1:fy2, fx1:fx2]
                if crop.size == 0:
                    crop = frame_bgr
                # Match training: BGR→RGB, [0,1] float, then PyTorch bilinear resize
                crop_rgb = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
                crop_f = torch.from_numpy(crop_rgb.astype(np.float32) / 255.0)
                crop_f = crop_f.permute(2, 0, 1).unsqueeze(0)  # [1,C,H,W]
                crop_f = torch.nn.functional.interpolate(
                    crop_f, size=(self.resolution, self.resolution),
                    mode='bilinear', align_corners=False,
                )
                clip_frames.append(crop_f.squeeze(0))  # [C,H,W]
            if not clip_frames:
                clip_frames = [torch.zeros(3, self.resolution, self.resolution)] * self.clip_length
            while len(clip_frames) < self.clip_length:
                clip_frames.append(clip_frames[-1].clone())
            clip_frames = clip_frames[:self.clip_length]
            # Stack to [T,C,H,W] then back to [T,H,W,C] numpy for the existing pipeline
            clip_tensor = torch.stack(clip_frames)  # [T,C,H,W]
            clip_np = clip_tensor.permute(0, 2, 3, 1).numpy()  # [T,H,W,C]
            refined.append(clip_np)
        return refined

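    # Note (descriptive): the refinement pass crops the cached full-resolution
    # BGR frames rather than the already downscaled clips, so the second
    # classification pass sees a higher-detail view of the localized region.
    # When the localizer returns per-frame boxes, the first frame's sanitized
    # box is reused for the whole clip, which keeps the crop stable over time.
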
    def _is_cuda_oom(self, err: Exception) -> bool:
        if not isinstance(err, RuntimeError):
            return False
        msg = str(err).lower()
        return ("out of memory" in msg) and ("cuda" in msg or "cublas" in msg or "cudnn" in msg)

    def _cleanup_memory(self, device=None):
        import gc
        gc.collect()
        if device is not None and getattr(device, "type", None) == "cuda" and torch.cuda.is_available():
            try:
                torch.cuda.empty_cache()
            except Exception:
                pass
        try:
            import jax  # local import; only available when a JAX backbone is installed
            jax.clear_caches()
        except Exception:
            pass

    def run(self):
        import traceback
        try:
            def log_fn(msg):
                self.log_message.emit(msg)
            results = {}
            total_videos = len(self.video_paths)
            for v_idx, video_path in enumerate(self.video_paths):
                if self.should_stop:
                    break
                log_fn(f"Processing video {v_idx+1}/{total_videos}: {os.path.basename(video_path)}")
                log_fn("Extracting clips and running inference in memory-safe chunks...")
                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
                clip_frame_embeddings = None
                aggregated_frame_embs = None
                try:
                    self.model.to(device)
                    self.model.eval()
                    predictions = []
                    confidences = []
                    clip_probabilities = []
                    clip_frame_probabilities = []
                    clip_frame_logits = []
                    clip_frame_embeddings = []
                    clip_attention_maps = []
                    clip_starts = []
                    localization_bboxes = []
                    has_frame_head = getattr(self.model, "use_frame_head", True)
                    if self.use_localization_pipeline and getattr(self.model, "use_localization", False):
                        log_fn("Localization pipeline enabled")
                    resolution = getattr(self.model.backbone, 'resolution', 288)
                    if device.type == "cuda":
                        try:
                            free_mb = torch.cuda.mem_get_info(device)[0] / (1024 * 1024)
                        except Exception:
                            free_mb = 4000
                        pixels = resolution * resolution
                        # Scale relative to 288x288 baseline (~250 MB/clip including activations)
                        scale = pixels / (288 * 288)
                        est_per_clip_mb = 250 * scale
                        batch_size = max(1, int(free_mb * 0.5 / est_per_clip_mb))
                        batch_size = min(batch_size, 32)
                        if resolution > 300:
                            batch_size = min(batch_size, 16)
                        if resolution > 400:
                            batch_size = min(batch_size, 4)
                    else:
                        batch_size = 2
                    log_fn(f"Inference batch size: {batch_size} (resolution={resolution}, device={device})")
                    cap = cv2.VideoCapture(video_path)
                    orig_fps = cap.get(cv2.CAP_PROP_FPS)
                    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
                    cap.release()
                    if orig_fps <= 0:
                        orig_fps = 30.0
                    frame_interval = max(1, int(round(orig_fps / self.target_fps)))
                    sample_ranges = self._normalize_sample_ranges(
                        self.sample_ranges_by_video.get(video_path),
                        frame_count,
                    )
                    if sample_ranges:
                        sample_desc = ", ".join(
                            f"{start/orig_fps:.1f}s-{end/orig_fps:.1f}s" for start, end in sample_ranges[:6]
                        )
                        if len(sample_ranges) > 6:
                            sample_desc += ", ..."
                        log_fn(
                            f"Quick-check inference on {len(sample_ranges)} sample range(s): {sample_desc}"
                        )
                    chunk_idx = 0
                    for chunk_clips, chunk_starts, chunk_fullres in self._iter_clips_prefetched(
                        video_path, self.target_fps, self.clip_length, self.step_frames, chunk_size=64,
                    ):
                        if self.should_stop:
                            break
                        chunk_idx += 1
                        clip_starts.extend(chunk_starts)
                        work_clips = chunk_clips
                        used_refined_roi = False
                        if self.use_localization_pipeline and getattr(self.model, "use_localization", False):
                            loc_input = chunk_clips
                            loc_bboxes = self._predict_localization(loc_input, device, batch_size=batch_size)
                            if len(loc_bboxes) == len(chunk_clips):
                                localization_bboxes.extend(loc_bboxes)
                            else:
                                localization_bboxes.extend([[0.0, 0.0, 1.0, 1.0]] * len(chunk_clips))
                            if len(loc_bboxes) == len(chunk_clips) and chunk_fullres:
                                work_clips = self._build_refined_clips(chunk_fullres, loc_bboxes)
                                used_refined_roi = True
                        for i in range(0, len(work_clips), batch_size):
                            if self.should_stop:
                                break
                            batch_clips = work_clips[i:i+batch_size]
                            batch_tensors = []
                            for clip in batch_clips:
                                clip_tensor = torch.from_numpy(clip)
                                clip_tensor = clip_tensor.permute(0, 3, 1, 2)
                                batch_tensors.append(clip_tensor)
                            batch = torch.stack(batch_tensors).to(device)
                            with torch.no_grad():
                                logits = self.model(batch, return_frame_logits=has_frame_head,
                                                    return_attn_weights=self.collect_attention)
                            batch_frame_probs = None
                            batch_frame_logits = None
                            batch_frame_embs = None
                            batch_attn_maps = None
                            _fo = getattr(self.model, '_frame_output', None)
                            if _fo is not None:
                                f_logits = _fo[0]
                                batch_frame_logits = f_logits.detach().cpu().numpy()
                                if self.use_ovr:
                                    batch_frame_probs = torch.sigmoid(self._apply_ovr_temperature(f_logits)).detach().cpu().numpy()
                                else:
                                    batch_frame_probs = torch.softmax(f_logits, dim=-1).detach().cpu().numpy()
                                _emb_src = (_fo[7] if len(_fo) > 7 and _fo[7] is not None else
                                            (_fo[6] if len(_fo) > 6 and _fo[6] is not None else None))
                                if _emb_src is not None:
                                    batch_frame_embs = _emb_src.detach().cpu().numpy()
                                if self.collect_attention and len(_fo) > 8 and _fo[8] is not None:
                                    batch_attn_maps = _fo[8].detach().cpu().numpy()
                            if batch_frame_probs is not None:
                                for b_i in range(batch_frame_probs.shape[0]):
                                    clip_frame_probabilities.append(batch_frame_probs[b_i].tolist())
                                    if batch_frame_logits is not None:
                                        clip_frame_logits.append(batch_frame_logits[b_i])
                                    else:
                                        clip_frame_logits.append(None)
                                    if batch_frame_embs is not None:
                                        clip_frame_embeddings.append(batch_frame_embs[b_i])
                                    else:
                                        clip_frame_embeddings.append(None)
                                    if batch_attn_maps is not None:
                                        clip_attention_maps.append(batch_attn_maps[b_i])
                                    elif self.collect_attention:
                                        clip_attention_maps.append(None)
                            else:
                                for b_i in range(len(batch_clips)):
                                    clip_frame_probabilities.append(None)
                                    clip_frame_logits.append(None)
                                    clip_frame_embeddings.append(None)
                                    if self.collect_attention:
                                        clip_attention_maps.append(None)
                            del batch
                            if isinstance(logits, tuple):
                                logits = logits[0]
                            if self.use_ovr:
                                probs = torch.sigmoid(self._apply_ovr_temperature(logits))
                            else:
                                probs = torch.softmax(logits, dim=1)
                            preds = torch.argmax(probs, dim=1)
                            confs = torch.max(probs, dim=1)[0]
                            predictions.extend(preds.cpu().numpy().tolist())
                            confidences.extend(confs.cpu().numpy().tolist())
                            clip_probabilities.extend(probs.detach().cpu().numpy().tolist())
                            del logits, probs, preds, confs
                        self._cleanup_memory(device)
                        if chunk_idx % 5 == 0:
                            log_fn(f"Processed {len(predictions)} clips so far...")
                    if not predictions:
                        log_fn(f"Warning: No clips extracted from {os.path.basename(video_path)}")
                        continue
                    if clip_frame_probabilities:
                        sample = clip_frame_probabilities[:min(3, len(clip_frame_probabilities))]
                        for ci, fp in enumerate(sample):
                            arr = np.asarray(fp, dtype=np.float32)
                            if arr.ndim == 2:
                                means = arr.mean(axis=0)
                                maxs = arr.max(axis=0)
                                log_fn(f"  Clip {ci} frame probs — mean: {np.array2string(means, precision=3)}, max: {np.array2string(maxs, precision=3)}")
                    cap_info = cv2.VideoCapture(video_path)
                    video_total_frames = int(cap_info.get(cv2.CAP_PROP_FRAME_COUNT))
                    video_orig_fps = cap_info.get(cv2.CAP_PROP_FPS)
                    if video_orig_fps <= 0:
                        video_orig_fps = 30.0
                    cap_info.release()
                    video_frame_interval = max(1, int(round(video_orig_fps / max(1e-6, float(self.target_fps)))))
                    aggregated_frame_probs = None
                    aggregated_frame_logits = None
                    aggregated_frame_embs = None
                    has_embeddings = any(e is not None for e in clip_frame_embeddings)
                    has_frame_logits = any(fl is not None for fl in clip_frame_logits)
                    embed_dim = clip_frame_embeddings[0].shape[-1] if has_embeddings and clip_frame_embeddings[0] is not None else 0
                    if clip_frame_probabilities and any(p is not None for p in clip_frame_probabilities):
                        log_fn("Aggregating per-frame outputs (center/confidence-weighted overlap merge)...")
                        num_classes = len(self.classes)
                        agg_probs = np.zeros((frame_count, num_classes), dtype=np.float32)
                        agg_logits = np.zeros((frame_count, num_classes), dtype=np.float32) if has_frame_logits else None
                        agg_counts = np.zeros((frame_count, 1), dtype=np.float32)
                        if has_embeddings and embed_dim > 0:
                            agg_embs = np.zeros((frame_count, embed_dim), dtype=np.float32)
                        else:
                            agg_embs = None
                        for i, probs in enumerate(clip_frame_probabilities):
                            if probs is None:
                                continue
                            if i >= len(clip_starts):
                                break
                            start_f = clip_starts[i]
                            probs_arr = np.clip(np.array(probs, dtype=np.float32), 0.0, None)
                            if probs_arr.ndim != 2 or probs_arr.shape[1] != num_classes:
                                continue
                            T = int(probs_arr.shape[0])
                            merge_w = self._build_center_merge_weights(T)
                            # Confidence-weight each frame's contribution, but keep a
                            # floor so uncertain clips still add a little evidence.
                            frame_conf = np.clip(np.max(probs_arr, axis=1), 0.0, 1.0)
                            conf_w = np.clip(0.1 + 0.9 * frame_conf, 0.1, 1.0).astype(np.float32)
                            logits_arr = None
                            emb_arr = None
                            if agg_logits is not None and i < len(clip_frame_logits) and clip_frame_logits[i] is not None:
                                logits_arr = np.array(clip_frame_logits[i], dtype=np.float32)
                            if agg_embs is not None and clip_frame_embeddings[i] is not None:
                                emb_arr = np.array(clip_frame_embeddings[i], dtype=np.float32)
                            for t in range(T):
                                f_start = start_f + t * frame_interval
                                f_end = min(f_start + frame_interval, frame_count)
                                if f_start >= frame_count:
                                    break
                                if f_end <= f_start:
                                    continue
                                w = float(merge_w[t] * conf_w[t])
                                agg_probs[f_start:f_end] += probs_arr[t][np.newaxis, :] * w
                                agg_counts[f_start:f_end] += w
                                if logits_arr is not None and t < logits_arr.shape[0]:
                                    agg_logits[f_start:f_end] += logits_arr[t][np.newaxis, :] * w
                                if emb_arr is not None and t < emb_arr.shape[0]:
                                    agg_embs[f_start:f_end] += emb_arr[t][np.newaxis, :] * w
                        agg_probs = agg_probs / np.maximum(agg_counts, 1.0)
                        if agg_logits is not None:
                            agg_logits = agg_logits / np.maximum(agg_counts, 1.0)
                            aggregated_frame_logits = agg_logits
                        if agg_embs is not None:
                            agg_embs = agg_embs / np.maximum(agg_counts, 1.0)
                            aggregated_frame_embs = agg_embs
                        if not self.use_ovr:
                            covered_mask = agg_counts.squeeze(-1) > 0
                            row_sums = agg_probs[covered_mask].sum(axis=1, keepdims=True)
                            safe_sums = np.maximum(row_sums, 1e-8)
                            agg_probs[covered_mask] = agg_probs[covered_mask] / safe_sums
                        aggregated_frame_probs = agg_probs
                    clip_frame_embeddings = None
                    res_entry = {
                        "predictions": predictions,
                        "confidences": confidences,
                        "clip_probabilities": clip_probabilities,
                        "clip_starts": clip_starts,
                        "total_frames": video_total_frames,
                        "orig_fps": video_orig_fps,
                        "frame_interval": video_frame_interval,
                        "aggregated_frame_probs": aggregated_frame_probs,
                    }
                    if sample_ranges:
                        res_entry["sample_ranges"] = sample_ranges
                    if clip_frame_probabilities:
                        res_entry["clip_frame_probabilities"] = clip_frame_probabilities
                    if localization_bboxes:
                        res_entry["localization_bboxes"] = localization_bboxes
                    if self.collect_attention and clip_attention_maps:
                        res_entry["clip_attention_maps"] = clip_attention_maps
                    results[video_path] = res_entry
                    self.video_done.emit(video_path, res_entry)
                except RuntimeError as video_err:
                    if self._is_cuda_oom(video_err):
                        log_fn(
                            f"CUDA OOM while processing {os.path.basename(video_path)}. "
                            f"Skipping this video and continuing."
                        )
                    else:
                        log_fn(
                            f"Error processing {os.path.basename(video_path)}: {video_err}\n"
                            f"{traceback.format_exc()}"
                        )
                    continue
                except Exception as video_err:
                    log_fn(
                        f"Error processing {os.path.basename(video_path)}: {video_err}\n"
                        f"{traceback.format_exc()}"
                    )
                    continue
                finally:
                    clip_frame_embeddings = None
                    aggregated_frame_embs = None
                    self._cleanup_memory(device)
            if self.should_stop:
                if results:
                    log_fn("Inference stopped by user. Keeping results for completed videos.")
                else:
                    log_fn("Inference stopped by user. No videos were completed.")
            if not results:
                if not self.should_stop:
                    self.error.emit("No results generated.")
                else:
                    self.finished.emit({})
                return
            if not self.should_stop:
                log_fn("Inference complete!")
            self.finished.emit(results)
        except Exception as e:
            error_msg = f"Inference failed: {str(e)}\n{traceback.format_exc()}"
            self.log_message.emit(f"ERROR: {error_msg}")
            self.error.emit(error_msg)
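For orientation, a minimal usage sketch of the worker from calling code (hypothetical; the model object, video path, and class names below are placeholders, not values shipped in the wheel):

    # Hypothetical wiring of InferenceWorker inside a Qt application.
    worker = InferenceWorker(
        model=model,                          # a trained torch.nn.Module
        video_paths=["/data/session01.mp4"],  # placeholder path
        target_fps=10,
        clip_length=16,
        step_frames=8,
        resolution=288,
        classes=["behavior_a", "behavior_b", "other"],  # placeholder labels
    )
    worker.progress.connect(lambda done, total: print(f"{done}/{total} clips"))
    worker.log_message.connect(print)
    worker.video_done.connect(lambda path, res: print(path, len(res["predictions"])))
    worker.finished.connect(lambda results: print("finished", list(results)))
    worker.error.connect(print)
    worker.start()   # QThread.start() runs run() on the worker thread
    # worker.stop() sets should_stop so the loops exit at their next check.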