singlebehaviorlab-2.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. sam2/__init__.py +11 -0
  2. sam2/automatic_mask_generator.py +454 -0
  3. sam2/benchmark.py +92 -0
  4. sam2/build_sam.py +174 -0
  5. sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
  6. sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
  7. sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
  8. sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
  9. sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
  10. sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
  11. sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
  12. sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
  13. sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
  14. sam2/modeling/__init__.py +5 -0
  15. sam2/modeling/backbones/__init__.py +5 -0
  16. sam2/modeling/backbones/hieradet.py +317 -0
  17. sam2/modeling/backbones/image_encoder.py +134 -0
  18. sam2/modeling/backbones/utils.py +93 -0
  19. sam2/modeling/memory_attention.py +169 -0
  20. sam2/modeling/memory_encoder.py +181 -0
  21. sam2/modeling/position_encoding.py +239 -0
  22. sam2/modeling/sam/__init__.py +5 -0
  23. sam2/modeling/sam/mask_decoder.py +295 -0
  24. sam2/modeling/sam/prompt_encoder.py +202 -0
  25. sam2/modeling/sam/transformer.py +311 -0
  26. sam2/modeling/sam2_base.py +913 -0
  27. sam2/modeling/sam2_utils.py +323 -0
  28. sam2/sam2_hiera_b+.yaml +113 -0
  29. sam2/sam2_hiera_l.yaml +117 -0
  30. sam2/sam2_hiera_s.yaml +116 -0
  31. sam2/sam2_hiera_t.yaml +118 -0
  32. sam2/sam2_image_predictor.py +466 -0
  33. sam2/sam2_video_predictor.py +1388 -0
  34. sam2/sam2_video_predictor_legacy.py +1172 -0
  35. sam2/utils/__init__.py +5 -0
  36. sam2/utils/amg.py +348 -0
  37. sam2/utils/misc.py +349 -0
  38. sam2/utils/transforms.py +118 -0
  39. singlebehaviorlab/__init__.py +4 -0
  40. singlebehaviorlab/__main__.py +130 -0
  41. singlebehaviorlab/_paths.py +100 -0
  42. singlebehaviorlab/backend/__init__.py +2 -0
  43. singlebehaviorlab/backend/augmentations.py +320 -0
  44. singlebehaviorlab/backend/data_store.py +420 -0
  45. singlebehaviorlab/backend/model.py +1290 -0
  46. singlebehaviorlab/backend/train.py +4667 -0
  47. singlebehaviorlab/backend/uncertainty.py +578 -0
  48. singlebehaviorlab/backend/video_processor.py +688 -0
  49. singlebehaviorlab/backend/video_utils.py +139 -0
  50. singlebehaviorlab/data/config/config.yaml +85 -0
  51. singlebehaviorlab/data/training_profiles.json +334 -0
  52. singlebehaviorlab/gui/__init__.py +4 -0
  53. singlebehaviorlab/gui/analysis_widget.py +2291 -0
  54. singlebehaviorlab/gui/attention_export.py +311 -0
  55. singlebehaviorlab/gui/clip_extraction_widget.py +481 -0
  56. singlebehaviorlab/gui/clustering_widget.py +3187 -0
  57. singlebehaviorlab/gui/inference_popups.py +1138 -0
  58. singlebehaviorlab/gui/inference_widget.py +4550 -0
  59. singlebehaviorlab/gui/inference_worker.py +651 -0
  60. singlebehaviorlab/gui/labeling_widget.py +2324 -0
  61. singlebehaviorlab/gui/main_window.py +754 -0
  62. singlebehaviorlab/gui/metadata_management_widget.py +1119 -0
  63. singlebehaviorlab/gui/motion_tracking.py +764 -0
  64. singlebehaviorlab/gui/overlay_export.py +1234 -0
  65. singlebehaviorlab/gui/plot_integration.py +729 -0
  66. singlebehaviorlab/gui/qt_helpers.py +29 -0
  67. singlebehaviorlab/gui/registration_widget.py +1485 -0
  68. singlebehaviorlab/gui/review_widget.py +1330 -0
  69. singlebehaviorlab/gui/segmentation_tracking_widget.py +2752 -0
  70. singlebehaviorlab/gui/tab_tutorial_dialog.py +312 -0
  71. singlebehaviorlab/gui/timeline_themes.py +131 -0
  72. singlebehaviorlab/gui/training_profiles.py +418 -0
  73. singlebehaviorlab/gui/training_widget.py +3719 -0
  74. singlebehaviorlab/gui/video_utils.py +233 -0
  75. singlebehaviorlab/licenses/SAM2-LICENSE +201 -0
  76. singlebehaviorlab/licenses/VideoPrism-LICENSE +202 -0
  77. singlebehaviorlab-2.0.0.dist-info/METADATA +447 -0
  78. singlebehaviorlab-2.0.0.dist-info/RECORD +88 -0
  79. singlebehaviorlab-2.0.0.dist-info/WHEEL +5 -0
  80. singlebehaviorlab-2.0.0.dist-info/entry_points.txt +2 -0
  81. singlebehaviorlab-2.0.0.dist-info/licenses/LICENSE +21 -0
  82. singlebehaviorlab-2.0.0.dist-info/top_level.txt +3 -0
  83. videoprism/__init__.py +0 -0
  84. videoprism/encoders.py +910 -0
  85. videoprism/layers.py +1136 -0
  86. videoprism/models.py +407 -0
  87. videoprism/tokenizers.py +167 -0
  88. videoprism/utils.py +168 -0
@@ -0,0 +1,651 @@
+ """Inference worker thread: runs model on video(s) and emits results."""
+ from PyQt6.QtCore import QThread, pyqtSignal
+ import cv2
+ import os
+ import torch
+ import numpy as np
+
+
+ def _sanitize_bbox_coords(x1, y1, x2, y2, crop_padding=0.35, crop_min_size=0.04):
+     """Pad bbox proportionally to its own size (matches training _sanitize_bboxes)."""
+     x1, y1 = max(0.0, min(1.0, x1)), max(0.0, min(1.0, y1))
+     x2, y2 = max(0.0, min(1.0, x2)), max(0.0, min(1.0, y2))
+     cx, cy = 0.5 * (x1 + x2), 0.5 * (y1 + y2)
+     w = min(max(abs(x2 - x1), crop_min_size) * (1.0 + 2.0 * crop_padding), 1.0)
+     h = min(max(abs(y2 - y1), crop_min_size) * (1.0 + 2.0 * crop_padding), 1.0)
+     x1, y1 = max(0.0, cx - 0.5 * w), max(0.0, cy - 0.5 * h)
+     x2, y2 = min(1.0, cx + 0.5 * w), min(1.0, cy + 0.5 * h)
+     x2, y2 = max(x2, x1 + crop_min_size), max(y2, y1 + crop_min_size)
+     return (max(0.0, min(1.0, x1)), max(0.0, min(1.0, y1)),
+             max(0.0, min(1.0, x2)), max(0.0, min(1.0, y2)))
+
+
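For intuition, the padding math on a concrete (made-up) box: a 0.2 x 0.2 box centered at (0.5, 0.5) grows by `crop_padding` on each side while staying centered and clamped to [0, 1].

```python
# Assuming _sanitize_bbox_coords as defined above.
x1, y1, x2, y2 = _sanitize_bbox_coords(0.4, 0.4, 0.6, 0.6)
# width: max(0.2, 0.04) * (1 + 2 * 0.35) = 0.34, re-centered at (0.5, 0.5)
print(round(x1, 2), round(y1, 2), round(x2, 2), round(y2, 2))  # 0.33 0.33 0.67 0.67
```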
+ class InferenceWorker(QThread):
+     """Worker thread for running inference."""
+     progress = pyqtSignal(int, int)     # (clips processed, estimated total)
+     finished = pyqtSignal(dict)         # results per video; note: shadows QThread.finished
+     error = pyqtSignal(str)
+     log_message = pyqtSignal(str)
+     video_done = pyqtSignal(str, dict)  # (video_path, per-video result entry)
+
+     def __init__(
+         self,
+         model,
+         video_paths,
+         target_fps,
+         clip_length,
+         step_frames,
+         resolution=288,
+         classes=None,
+         use_localization_pipeline=False,
+         crop_padding=0.35,
+         crop_min_size=0.04,
+         use_ovr=False,
+         ovr_temperatures=None,
+         collect_attention=False,
+         sample_ranges_by_video=None,
+     ):
+         super().__init__()
+         self.model = model
+         self.video_paths = video_paths if isinstance(video_paths, list) else [video_paths]
+         self.target_fps = target_fps
+         self.clip_length = clip_length
+         self.step_frames = step_frames
+         self.resolution = int(resolution)
+         self.classes = classes or []
+         self.use_localization_pipeline = bool(use_localization_pipeline)
+         self.crop_padding = float(crop_padding)
+         self.crop_min_size = float(crop_min_size)
+         self.use_ovr = bool(use_ovr)
+         self._ovr_temperatures = dict(ovr_temperatures or {})
+         self.collect_attention = bool(collect_attention)
+         self.sample_ranges_by_video = dict(sample_ranges_by_video or {})
+         self.should_stop = False
+
+     def _apply_ovr_temperature(self, logits: torch.Tensor) -> torch.Tensor:
+         """Scale logits by per-class calibrated temperature before sigmoid.
+
+         logits: [..., C] tensor. Each class dimension is divided by its
+         calibrated temperature (default 1.0 = no change).
+         """
+         if not self._ovr_temperatures or not self.classes:
+             return logits
+         C = logits.shape[-1]
+         temps = torch.ones(C, device=logits.device, dtype=logits.dtype)
+         for ci, cls_name in enumerate(self.classes):
+             if ci < C and cls_name in self._ovr_temperatures:
+                 t = float(self._ovr_temperatures[cls_name])
+                 temps[ci] = max(t, 0.01)  # guard against zero/negative temperatures
+         return logits / temps
+
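The division is standard post-hoc temperature scaling: a calibrated temperature above 1 softens an over-confident class before the sigmoid, below 1 sharpens it. A minimal standalone sketch with made-up temperatures:

```python
import torch

logits = torch.tensor([2.0, 2.0])     # identical raw logits for two classes
temps = torch.tensor([1.0, 2.0])      # class 1 calibrated to t = 2.0
print(torch.sigmoid(logits / temps))  # tensor([0.8808, 0.7311]): class 1 softened
```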
+     def stop(self):
+         self.should_stop = True
+
+     def _build_center_merge_weights(self, length: int) -> np.ndarray:
+         """Center-heavy temporal weights for overlap aggregation.
+
+         Clip-edge predictions tend to be less reliable because they have less
+         temporal context. Use a Hann-shaped profile with a small floor so edge
+         frames still contribute when only one clip covers them.
+         """
+         if length <= 1:
+             return np.ones((max(1, length),), dtype=np.float32)
+         w = np.hanning(length).astype(np.float32)
+         if not np.any(w > 0):
+             return np.ones((length,), dtype=np.float32)
+         return np.clip(0.1 + 0.9 * w, 1e-3, None).astype(np.float32)
+
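For example, with an 8-frame clip the floor keeps the endpoint weights at 0.1 while center frames approach 1.0:

```python
import numpy as np

w = np.clip(0.1 + 0.9 * np.hanning(8), 1e-3, None)
print(np.round(w, 2))  # [0.1  0.27 0.65 0.96 0.96 0.65 0.27 0.1 ]
```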
+     def _iter_clips_prefetched(self, video_path, target_fps, clip_length, step_frames, chunk_size=64):
+         """Decode chunks on a background thread so GPU work overlaps video I/O."""
+         from threading import Thread
+         from queue import Queue
+         q = Queue(maxsize=2)  # bounded: at most 2 chunks decoded ahead
+
+         def _producer():
+             try:
+                 for chunk in self._iter_clips_in_chunks(video_path, target_fps, clip_length, step_frames, chunk_size):
+                     q.put(chunk)
+                     if self.should_stop:
+                         break
+             finally:
+                 q.put(None)  # sentinel: producer finished (or was stopped)
+
+         t = Thread(target=_producer, daemon=True)
+         t.start()
+         while True:
+             item = q.get()
+             if item is None:
+                 break
+             yield item
+         t.join(timeout=2.0)
+
+     def _normalize_sample_ranges(self, sample_ranges, total_frames):
+         """Clamp and sort sample ranges as [(start, end), ...] with end-exclusive bounds."""
+         if not sample_ranges:
+             return []
+         normalized = []
+         for rng in sample_ranges:
+             if not isinstance(rng, (list, tuple)) or len(rng) != 2:
+                 continue
+             start, end = int(rng[0]), int(rng[1])
+             start = max(0, min(total_frames, start))
+             end = max(start, min(total_frames, end))
+             if end > start:
+                 normalized.append((start, end))
+         normalized.sort()
+         merged = []
+         for start, end in normalized:
+             if not merged or start > merged[-1][1]:
+                 merged.append([start, end])
+             else:
+                 merged[-1][1] = max(merged[-1][1], end)
+         return [(start, end) for start, end in merged]
+
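For example, unordered and overlapping ranges collapse to disjoint, sorted spans (illustrative values):

```python
# Assuming `worker` is an InferenceWorker instance; ranges are (start, end), end-exclusive.
print(worker._normalize_sample_ranges([(300, 400), (50, 150), (120, 200)], 1000))
# [(50, 200), (300, 400)]
```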
+     def _estimate_subsampled_frames_in_ranges(self, sample_ranges, frame_interval):
+         """Count frames on the subsample grid (frame_idx % frame_interval == 0) inside the ranges."""
+         total = 0
+         for start, end in sample_ranges:
+             # first grid-aligned frame at or after `start`
+             first = start if start % frame_interval == 0 else start + (frame_interval - (start % frame_interval))
+             if first >= end:
+                 continue
+             total += 1 + ((end - 1 - first) // frame_interval)
+         return total
+
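The count is just the number of multiples of frame_interval inside [start, end); e.g. for the range [10, 50) with frame_interval=6 the sampled frames are 12, 18, ..., 48:

```python
start, end, frame_interval = 10, 50, 6
first = start + (frame_interval - start % frame_interval) % frame_interval
print(first, 1 + (end - 1 - first) // frame_interval)  # 12 7
```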
+     def _iter_clips_in_chunks(self, video_path, target_fps, clip_length, step_frames, chunk_size=64):
+         cap = cv2.VideoCapture(video_path)
+         if not cap.isOpened():
+             raise ValueError(f"Could not open video: {video_path}")
+
+         orig_fps = cap.get(cv2.CAP_PROP_FPS)
+         if orig_fps <= 0:
+             orig_fps = 30.0
+
+         frame_interval = max(1, int(round(orig_fps / target_fps)))
+         total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+         sample_ranges = self._normalize_sample_ranges(
+             self.sample_ranges_by_video.get(video_path),
+             total_frames,
+         )
+         if sample_ranges:
+             total_subsampled = self._estimate_subsampled_frames_in_ranges(sample_ranges, frame_interval)
+         else:
+             total_subsampled = total_frames // frame_interval
+         if total_subsampled < clip_length:
+             total_clips_estimate = 0
+         elif step_frames <= 0:
+             total_clips_estimate = max(1, total_subsampled // max(1, clip_length))
+         else:
+             total_clips_estimate = 1 + max(0, (total_subsampled - clip_length) // step_frames)
+         total_clips_estimate = max(1, total_clips_estimate)
+
+         clip_starts = []
+         frames_buffer = []
+         frames_buffer_fullres = []
+         keep_fullres_cache = self.use_localization_pipeline and getattr(self.model, "use_localization", False)
+         chunk_clips = []
+         chunk_starts = []
+         chunk_fullres = []
+         clip_idx = 0
+
+         def _iter_ranges():
+             if sample_ranges:
+                 for start, end in sample_ranges:
+                     yield start, end
+             else:
+                 yield 0, total_frames
+
+         for range_start, range_end in _iter_ranges():
+             if self.should_stop:
+                 break
+             cap.set(cv2.CAP_PROP_POS_FRAMES, range_start)
+             frames_buffer = []
+             frames_buffer_fullres = []
+             skip_remaining = 0
+             frame_idx = range_start
+
+             while frame_idx < range_end:
+                 if self.should_stop:
+                     break
+                 ret, frame = cap.read()
+                 if not ret:
+                     break
+                 if frame_idx % frame_interval == 0:
+                     if skip_remaining > 0:
+                         skip_remaining -= 1
+                         frame_idx += 1
+                         continue
+                     if keep_fullres_cache:
+                         frames_buffer_fullres.append(frame.copy())
+                     h_src, w_src = frame.shape[:2]
+                     is_upscale = (w_src < self.resolution) or (h_src < self.resolution)
+                     interp = cv2.INTER_LANCZOS4 if is_upscale else cv2.INTER_AREA
+                     frame_resized = cv2.resize(frame, (self.resolution, self.resolution), interpolation=interp)
+                     if is_upscale:
+                         # Unsharp mask to recover some detail lost to upscaling
+                         blurred = cv2.GaussianBlur(frame_resized, (0, 0), sigmaX=1.0)
+                         frame_resized = cv2.addWeighted(frame_resized, 1.5, blurred, -0.5, 0)
+                     frame_resized = cv2.cvtColor(frame_resized, cv2.COLOR_BGR2RGB)
+                     frames_buffer.append(frame_resized)
+                     if len(frames_buffer) == clip_length:
+                         clip_array = np.stack(frames_buffer).astype(np.float32) / 255.0
+                         clip_start = frame_idx - (clip_length - 1) * frame_interval
+                         clip_starts.append(clip_start)
+                         chunk_clips.append(clip_array)
+                         chunk_starts.append(clip_start)
+                         if keep_fullres_cache:
+                             chunk_fullres.append(np.stack(frames_buffer_fullres))
+                         clip_idx += 1
+                         if self.progress:
+                             self.progress.emit(clip_idx, total_clips_estimate)
+                         if step_frames < clip_length:
+                             # Overlapping windows: slide the buffer by step_frames
+                             frames_buffer = frames_buffer[step_frames:]
+                             if keep_fullres_cache:
+                                 frames_buffer_fullres = frames_buffer_fullres[step_frames:]
+                         else:
+                             # Non-overlapping windows: drop the buffer and skip the gap
+                             frames_buffer = []
+                             frames_buffer_fullres = []
+                             skip_remaining = max(0, step_frames - clip_length)
+                         if len(chunk_clips) >= max(1, int(chunk_size)):
+                             yield chunk_clips, chunk_starts, chunk_fullres
+                             chunk_clips = []
+                             chunk_starts = []
+                             chunk_fullres = []
+                 frame_idx += 1
+         cap.release()
+         if chunk_clips:
+             yield chunk_clips, chunk_starts, chunk_fullres
+
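The clip-count estimate above is the usual sliding-window count: with N subsampled frames, window length L, and stride S, you get 1 + (N - L) // S windows. A quick numeric check:

```python
N, L, S = 100, 16, 8  # subsampled frames, clip_length, step_frames
print(1 + max(0, (N - L) // S))  # 11 clips, starting at subsampled frames 0, 8, ..., 80
```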
+     def _predict_localization(self, clips, device, batch_size=4):
+         loc_bboxes = []
+         for i in range(0, len(clips), batch_size):
+             if self.should_stop:
+                 break
+             batch_clips = clips[i:i + batch_size]
+             batch_tensors = []
+             for clip in batch_clips:
+                 clip_tensor = torch.from_numpy(clip).permute(0, 3, 1, 2)  # [T,H,W,C] -> [T,C,H,W]
+                 batch_tensors.append(clip_tensor)
+             batch = torch.stack(batch_tensors).to(device)
+             with torch.no_grad():
+                 out = self.model(batch, return_localization=True)
+             del batch
+             loc_part = None
+             if isinstance(out, tuple) and len(out) >= 2 and torch.is_tensor(out[-1]) and out[-1].shape[-1] == 4:
+                 loc_part = out[-1]
+             if loc_part is None:
+                 # Fall back to the full frame when no localization output is returned
+                 loc_part = torch.tensor([[0.0, 0.0, 1.0, 1.0]] * len(batch_clips), device=device)
+             loc_np = loc_part.detach().cpu().numpy()
+             # tolist() yields one entry per clip for both per-clip [B, 4] and
+             # per-frame [B, T, 4] outputs, so no dim-based branching is needed
+             loc_bboxes.extend(loc_np.tolist())
+         return loc_bboxes
+
280
+
281
+ def _build_refined_clips(self, fullres_clip_frames, loc_bboxes):
282
+ refined = []
283
+ for idx, frames_bgr in enumerate(fullres_clip_frames):
284
+ bbox = loc_bboxes[idx] if idx < len(loc_bboxes) else [0.0, 0.0, 1.0, 1.0]
285
+ has_temporal_boxes = (
286
+ isinstance(bbox, (list, tuple))
287
+ and len(bbox) > 0
288
+ and isinstance(bbox[0], (list, tuple))
289
+ and len(bbox[0]) == 4
290
+ )
291
+ if has_temporal_boxes:
292
+ b0 = [float(v) for v in bbox[0]]
293
+ x1_fixed, y1_fixed, x2_fixed, y2_fixed = _sanitize_bbox_coords(*b0, self.crop_padding, self.crop_min_size)
294
+ clip_frames = []
295
+ for fi, frame_bgr in enumerate(frames_bgr):
296
+ if has_temporal_boxes:
297
+ x1, y1, x2, y2 = x1_fixed, y1_fixed, x2_fixed, y2_fixed
298
+ else:
299
+ x1, y1, x2, y2 = _sanitize_bbox_coords(*[float(v) for v in bbox], self.crop_padding, self.crop_min_size)
300
+ h, w = frame_bgr.shape[:2]
301
+ fx1 = int(round(x1 * w))
302
+ fy1 = int(round(y1 * h))
303
+ fx2 = int(round(x2 * w))
304
+ fy2 = int(round(y2 * h))
305
+ fx1 = max(0, min(fx1, w - 1))
306
+ fy1 = max(0, min(fy1, h - 1))
307
+ fx2 = max(fx1 + 1, min(fx2, w))
308
+ fy2 = max(fy1 + 1, min(fy2, h))
309
+ crop = frame_bgr[fy1:fy2, fx1:fx2]
310
+ if crop.size == 0:
311
+ crop = frame_bgr
312
+ # Match training: BGR→RGB, [0,1] float, then PyTorch bilinear resize
313
+ crop_rgb = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
314
+ crop_f = torch.from_numpy(crop_rgb.astype(np.float32) / 255.0)
315
+ crop_f = crop_f.permute(2, 0, 1).unsqueeze(0) # [1,C,H,W]
316
+ crop_f = torch.nn.functional.interpolate(
317
+ crop_f, size=(self.resolution, self.resolution),
318
+ mode='bilinear', align_corners=False,
319
+ )
320
+ clip_frames.append(crop_f.squeeze(0)) # [C,H,W]
321
+ if not clip_frames:
322
+ clip_frames = [torch.zeros(3, self.resolution, self.resolution)] * self.clip_length
323
+ while len(clip_frames) < self.clip_length:
324
+ clip_frames.append(clip_frames[-1].clone())
325
+ clip_frames = clip_frames[:self.clip_length]
326
+ # Stack to [T,C,H,W] then back to [T,H,W,C] numpy for the existing pipeline
327
+ clip_tensor = torch.stack(clip_frames) # [T,C,H,W]
328
+ clip_np = clip_tensor.permute(0, 2, 3, 1).numpy() # [T,H,W,C]
329
+ refined.append(clip_np)
330
+ return refined
331
+
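The normalized-to-pixel mapping inside the crop loop is plain scaling, with clamps that guarantee a non-empty crop; e.g. on a hypothetical 640x480 frame:

```python
w, h = 640, 480
x1, y1, x2, y2 = 0.33, 0.33, 0.67, 0.67  # sanitized, normalized coords
print(round(x1 * w), round(y1 * h), round(x2 * w), round(y2 * h))
# 211 158 429 322, then clamped so fx2 > fx1 and fy2 > fy1
```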
+     def _is_cuda_oom(self, err: Exception) -> bool:
+         if not isinstance(err, RuntimeError):
+             return False
+         msg = str(err).lower()
+         return ("out of memory" in msg) and ("cuda" in msg or "cublas" in msg or "cudnn" in msg)
+
+     def _cleanup_memory(self, device=None):
+         import gc
+         gc.collect()
+         if device is not None and getattr(device, "type", None) == "cuda" and torch.cuda.is_available():
+             try:
+                 torch.cuda.empty_cache()
+             except Exception:
+                 pass
+         try:
+             import jax  # optional; only present when the JAX (VideoPrism) backbone is installed
+             jax.clear_caches()
+         except Exception:
+             pass
+
+     def run(self):
+         import traceback
+         try:
+             def log_fn(msg):
+                 self.log_message.emit(msg)
+
+             results = {}
+             total_videos = len(self.video_paths)
+             for v_idx, video_path in enumerate(self.video_paths):
+                 if self.should_stop:
+                     break
+                 log_fn(f"Processing video {v_idx+1}/{total_videos}: {os.path.basename(video_path)}")
+                 log_fn("Extracting clips and running inference in memory-safe chunks...")
+                 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+                 clip_frame_embeddings = None
+                 aggregated_frame_embs = None
+                 try:
+                     self.model.to(device)
+                     self.model.eval()
+                     predictions = []
+                     confidences = []
+                     clip_probabilities = []
+                     clip_frame_probabilities = []
+                     clip_frame_logits = []
+                     clip_frame_embeddings = []
+                     clip_attention_maps = []
+                     clip_starts = []
+                     localization_bboxes = []
+                     has_frame_head = getattr(self.model, "use_frame_head", True)
+                     if self.use_localization_pipeline and getattr(self.model, "use_localization", False):
+                         log_fn("Localization pipeline enabled")
+                     resolution = getattr(self.model.backbone, 'resolution', 288)
+                     if device.type == "cuda":
+                         try:
+                             free_mb = torch.cuda.mem_get_info(device)[0] / (1024 * 1024)
+                         except Exception:
+                             free_mb = 4000  # assume ~4 GB free if the query fails
+                         pixels = resolution * resolution
+                         # Scale relative to 288x288 baseline (~250 MB/clip including activations)
+                         scale = pixels / (288 * 288)
+                         est_per_clip_mb = 250 * scale
+                         batch_size = max(1, int(free_mb * 0.5 / est_per_clip_mb))  # use at most ~50% of free VRAM
+                         batch_size = min(batch_size, 32)
+                         if resolution > 300:
+                             batch_size = min(batch_size, 16)
+                         if resolution > 400:
+                             batch_size = min(batch_size, 4)
+                     else:
+                         batch_size = 2
+                     log_fn(f"Inference batch size: {batch_size} (resolution={resolution}, device={device})")
+                     cap = cv2.VideoCapture(video_path)
+                     orig_fps = cap.get(cv2.CAP_PROP_FPS)
+                     frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+                     cap.release()
+                     if orig_fps <= 0:
+                         orig_fps = 30.0
+                     frame_interval = max(1, int(round(orig_fps / self.target_fps)))
+                     sample_ranges = self._normalize_sample_ranges(
+                         self.sample_ranges_by_video.get(video_path),
+                         frame_count,
+                     )
+                     if sample_ranges:
+                         sample_desc = ", ".join(
+                             f"{start/orig_fps:.1f}s-{end/orig_fps:.1f}s" for start, end in sample_ranges[:6]
+                         )
+                         if len(sample_ranges) > 6:
+                             sample_desc += ", ..."
+                         log_fn(
+                             f"Quick-check inference on {len(sample_ranges)} sample range(s): {sample_desc}"
+                         )
+                     chunk_idx = 0
+                     for chunk_clips, chunk_starts, chunk_fullres in self._iter_clips_prefetched(
+                         video_path, self.target_fps, self.clip_length, self.step_frames, chunk_size=64,
+                     ):
+                         if self.should_stop:
+                             break
+                         chunk_idx += 1
+                         clip_starts.extend(chunk_starts)
+                         work_clips = chunk_clips
+                         used_refined_roi = False
+                         if self.use_localization_pipeline and getattr(self.model, "use_localization", False):
+                             loc_input = chunk_clips
+                             loc_bboxes = self._predict_localization(loc_input, device, batch_size=batch_size)
+                             if len(loc_bboxes) == len(chunk_clips):
+                                 localization_bboxes.extend(loc_bboxes)
+                             else:
+                                 localization_bboxes.extend([[0.0, 0.0, 1.0, 1.0]] * len(chunk_clips))
+                             if len(loc_bboxes) == len(chunk_clips) and chunk_fullres:
+                                 work_clips = self._build_refined_clips(chunk_fullres, loc_bboxes)
+                                 used_refined_roi = True
+                         for i in range(0, len(work_clips), batch_size):
+                             if self.should_stop:
+                                 break
+                             batch_clips = work_clips[i:i + batch_size]
+                             batch_tensors = []
+                             for clip in batch_clips:
+                                 clip_tensor = torch.from_numpy(clip)
+                                 clip_tensor = clip_tensor.permute(0, 3, 1, 2)
+                                 batch_tensors.append(clip_tensor)
+                             batch = torch.stack(batch_tensors).to(device)
+                             with torch.no_grad():
+                                 logits = self.model(batch, return_frame_logits=has_frame_head,
+                                                     return_attn_weights=self.collect_attention)
+                             batch_frame_probs = None
+                             batch_frame_logits = None
+                             batch_frame_embs = None
+                             batch_attn_maps = None
+                             _fo = getattr(self.model, '_frame_output', None)
+                             if _fo is not None:
+                                 f_logits = _fo[0]
+                                 batch_frame_logits = f_logits.detach().cpu().numpy()
+                                 if self.use_ovr:
+                                     batch_frame_probs = torch.sigmoid(self._apply_ovr_temperature(f_logits)).detach().cpu().numpy()
+                                 else:
+                                     batch_frame_probs = torch.softmax(f_logits, dim=-1).detach().cpu().numpy()
+                                 # Prefer the later embedding slot when present
+                                 _emb_src = (_fo[7] if len(_fo) > 7 and _fo[7] is not None else
+                                             (_fo[6] if len(_fo) > 6 and _fo[6] is not None else None))
+                                 if _emb_src is not None:
+                                     batch_frame_embs = _emb_src.detach().cpu().numpy()
+                                 if self.collect_attention and len(_fo) > 8 and _fo[8] is not None:
+                                     batch_attn_maps = _fo[8].detach().cpu().numpy()
+                             if batch_frame_probs is not None:
+                                 for b_i in range(batch_frame_probs.shape[0]):
+                                     clip_frame_probabilities.append(batch_frame_probs[b_i].tolist())
+                                     if batch_frame_logits is not None:
+                                         clip_frame_logits.append(batch_frame_logits[b_i])
+                                     else:
+                                         clip_frame_logits.append(None)
+                                     if batch_frame_embs is not None:
+                                         clip_frame_embeddings.append(batch_frame_embs[b_i])
+                                     else:
+                                         clip_frame_embeddings.append(None)
+                                     if batch_attn_maps is not None:
+                                         clip_attention_maps.append(batch_attn_maps[b_i])
+                                     elif self.collect_attention:
+                                         clip_attention_maps.append(None)
+                             else:
+                                 for b_i in range(len(batch_clips)):
+                                     clip_frame_probabilities.append(None)
+                                     clip_frame_logits.append(None)
+                                     clip_frame_embeddings.append(None)
+                                     if self.collect_attention:
+                                         clip_attention_maps.append(None)
+                             del batch
+                             if isinstance(logits, tuple):
+                                 logits = logits[0]
+                             if self.use_ovr:
+                                 probs = torch.sigmoid(self._apply_ovr_temperature(logits))
+                             else:
+                                 probs = torch.softmax(logits, dim=1)
+                             preds = torch.argmax(probs, dim=1)
+                             confs = torch.max(probs, dim=1)[0]
+                             predictions.extend(preds.cpu().numpy().tolist())
+                             confidences.extend(confs.cpu().numpy().tolist())
+                             clip_probabilities.extend(probs.detach().cpu().numpy().tolist())
+                             del logits, probs, preds, confs
+                         self._cleanup_memory(device)
+                         if chunk_idx % 5 == 0:
+                             log_fn(f"Processed {len(predictions)} clips so far...")
+                     if not predictions:
+                         log_fn(f"Warning: No clips extracted from {os.path.basename(video_path)}")
+                         continue
+                     if clip_frame_probabilities:
+                         sample = clip_frame_probabilities[:min(3, len(clip_frame_probabilities))]
+                         for ci, fp in enumerate(sample):
+                             arr = np.asarray(fp, dtype=np.float32)
+                             if arr.ndim == 2:
+                                 means = arr.mean(axis=0)
+                                 maxs = arr.max(axis=0)
+                                 log_fn(f"  Clip {ci} frame probs — mean: {np.array2string(means, precision=3)}, max: {np.array2string(maxs, precision=3)}")
+                     cap_info = cv2.VideoCapture(video_path)
+                     video_total_frames = int(cap_info.get(cv2.CAP_PROP_FRAME_COUNT))
+                     video_orig_fps = cap_info.get(cv2.CAP_PROP_FPS)
+                     if video_orig_fps <= 0:
+                         video_orig_fps = 30.0
+                     cap_info.release()
+                     video_frame_interval = max(1, int(round(video_orig_fps / max(1e-6, float(self.target_fps)))))
+                     aggregated_frame_probs = None
+                     aggregated_frame_logits = None
+                     aggregated_frame_embs = None
+                     has_embeddings = any(e is not None for e in clip_frame_embeddings)
+                     has_frame_logits = any(fl is not None for fl in clip_frame_logits)
+                     embed_dim = clip_frame_embeddings[0].shape[-1] if has_embeddings and clip_frame_embeddings[0] is not None else 0
+                     if clip_frame_probabilities and any(p is not None for p in clip_frame_probabilities):
+                         log_fn("Aggregating per-frame outputs (center/confidence-weighted overlap merge)...")
+                         num_classes = len(self.classes)
+                         agg_probs = np.zeros((frame_count, num_classes), dtype=np.float32)
+                         agg_logits = np.zeros((frame_count, num_classes), dtype=np.float32) if has_frame_logits else None
+                         agg_counts = np.zeros((frame_count, 1), dtype=np.float32)
+                         if has_embeddings and embed_dim > 0:
+                             agg_embs = np.zeros((frame_count, embed_dim), dtype=np.float32)
+                         else:
+                             agg_embs = None
+                         for i, probs in enumerate(clip_frame_probabilities):
+                             if probs is None:
+                                 continue
+                             if i >= len(clip_starts):
+                                 break
+                             start_f = clip_starts[i]
+                             probs_arr = np.clip(np.array(probs, dtype=np.float32), 0.0, None)
+                             if probs_arr.ndim != 2 or probs_arr.shape[1] != num_classes:
+                                 continue
+                             T = int(probs_arr.shape[0])
+                             merge_w = self._build_center_merge_weights(T)
+                             # Confidence-weight each frame's contribution, but keep a
+                             # floor so uncertain clips still add a little evidence.
+                             frame_conf = np.clip(np.max(probs_arr, axis=1), 0.0, 1.0)
+                             conf_w = np.clip(0.1 + 0.9 * frame_conf, 0.1, 1.0).astype(np.float32)
+                             logits_arr = None
+                             emb_arr = None
+                             if agg_logits is not None and i < len(clip_frame_logits) and clip_frame_logits[i] is not None:
+                                 logits_arr = np.array(clip_frame_logits[i], dtype=np.float32)
+                             if agg_embs is not None and clip_frame_embeddings[i] is not None:
+                                 emb_arr = np.array(clip_frame_embeddings[i], dtype=np.float32)
+                             for t in range(T):
+                                 f_start = start_f + t * frame_interval
+                                 f_end = min(f_start + frame_interval, frame_count)
+                                 if f_start >= frame_count:
+                                     break
+                                 if f_end <= f_start:
+                                     continue
+                                 w = float(merge_w[t] * conf_w[t])
+                                 agg_probs[f_start:f_end] += probs_arr[t][np.newaxis, :] * w
+                                 agg_counts[f_start:f_end] += w
+                                 if logits_arr is not None and t < logits_arr.shape[0]:
+                                     agg_logits[f_start:f_end] += logits_arr[t][np.newaxis, :] * w
+                                 if emb_arr is not None and t < emb_arr.shape[0]:
+                                     agg_embs[f_start:f_end] += emb_arr[t][np.newaxis, :] * w
+                         # Weighted average; the weight floor of 1.0 also damps frames
+                         # with very low total coverage, and uncovered frames stay zero
+                         agg_probs = agg_probs / np.maximum(agg_counts, 1.0)
+                         if agg_logits is not None:
+                             agg_logits = agg_logits / np.maximum(agg_counts, 1.0)
+                             aggregated_frame_logits = agg_logits
+                         if agg_embs is not None:
+                             agg_embs = agg_embs / np.maximum(agg_counts, 1.0)
+                             aggregated_frame_embs = agg_embs
+                         if not self.use_ovr:
+                             # Softmax probabilities should still sum to 1 per covered frame
+                             covered_mask = agg_counts.squeeze(-1) > 0
+                             row_sums = agg_probs[covered_mask].sum(axis=1, keepdims=True)
+                             safe_sums = np.maximum(row_sums, 1e-8)
+                             agg_probs[covered_mask] = agg_probs[covered_mask] / safe_sums
+                         aggregated_frame_probs = agg_probs
+                     clip_frame_embeddings = None
+                     res_entry = {
+                         "predictions": predictions,
+                         "confidences": confidences,
+                         "clip_probabilities": clip_probabilities,
+                         "clip_starts": clip_starts,
+                         "total_frames": video_total_frames,
+                         "orig_fps": video_orig_fps,
+                         "frame_interval": video_frame_interval,
+                         "aggregated_frame_probs": aggregated_frame_probs,
+                     }
+                     if sample_ranges:
+                         res_entry["sample_ranges"] = sample_ranges
+                     if clip_frame_probabilities:
+                         res_entry["clip_frame_probabilities"] = clip_frame_probabilities
+                     if localization_bboxes:
+                         res_entry["localization_bboxes"] = localization_bboxes
+                     if self.collect_attention and clip_attention_maps:
+                         res_entry["clip_attention_maps"] = clip_attention_maps
+                     results[video_path] = res_entry
+                     self.video_done.emit(video_path, res_entry)
+                 except RuntimeError as video_err:
+                     if self._is_cuda_oom(video_err):
+                         log_fn(
+                             f"CUDA OOM while processing {os.path.basename(video_path)}. "
+                             f"Skipping this video and continuing."
+                         )
+                     else:
+                         log_fn(
+                             f"Error processing {os.path.basename(video_path)}: {video_err}\n"
+                             f"{traceback.format_exc()}"
+                         )
+                     continue
+                 except Exception as video_err:
+                     log_fn(
+                         f"Error processing {os.path.basename(video_path)}: {video_err}\n"
+                         f"{traceback.format_exc()}"
+                     )
+                     continue
+                 finally:
+                     clip_frame_embeddings = None
+                     aggregated_frame_embs = None
+                     self._cleanup_memory(device)
+             if self.should_stop:
+                 if results:
+                     log_fn("Inference stopped by user. Keeping results for completed videos.")
+                 else:
+                     log_fn("Inference stopped by user. No videos were completed.")
+             if not results:
+                 if not self.should_stop:
+                     self.error.emit("No results generated.")
+                 else:
+                     self.finished.emit({})
+                 return
+             if not self.should_stop:
+                 log_fn("Inference complete!")
+             self.finished.emit(results)
+         except Exception as e:
+             error_msg = f"Inference failed: {str(e)}\n{traceback.format_exc()}"
+             self.log_message.emit(f"ERROR: {error_msg}")
+             self.error.emit(error_msg)
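For context, a sketch of how a caller on the GUI thread might drive this worker (hypothetical wiring; the real setup lives in singlebehaviorlab/gui/inference_widget.py):

```python
# Hypothetical usage: `model` is a loaded classifier; paths and classes are examples.
worker = InferenceWorker(
    model, ["session1.mp4"], target_fps=10.0,
    clip_length=16, step_frames=8, classes=["groom", "rear", "walk"],
)
worker.progress.connect(lambda done, total: print(f"{done}/{total} clips"))
worker.log_message.connect(print)
worker.error.connect(print)
worker.finished.connect(lambda results: print(list(results)))
worker.start()  # QThread.start() invokes run() on the worker thread
```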