singlebehaviorlab-2.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. sam2/__init__.py +11 -0
  2. sam2/automatic_mask_generator.py +454 -0
  3. sam2/benchmark.py +92 -0
  4. sam2/build_sam.py +174 -0
  5. sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
  6. sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
  7. sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
  8. sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
  9. sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
  10. sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
  11. sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
  12. sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
  13. sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
  14. sam2/modeling/__init__.py +5 -0
  15. sam2/modeling/backbones/__init__.py +5 -0
  16. sam2/modeling/backbones/hieradet.py +317 -0
  17. sam2/modeling/backbones/image_encoder.py +134 -0
  18. sam2/modeling/backbones/utils.py +93 -0
  19. sam2/modeling/memory_attention.py +169 -0
  20. sam2/modeling/memory_encoder.py +181 -0
  21. sam2/modeling/position_encoding.py +239 -0
  22. sam2/modeling/sam/__init__.py +5 -0
  23. sam2/modeling/sam/mask_decoder.py +295 -0
  24. sam2/modeling/sam/prompt_encoder.py +202 -0
  25. sam2/modeling/sam/transformer.py +311 -0
  26. sam2/modeling/sam2_base.py +913 -0
  27. sam2/modeling/sam2_utils.py +323 -0
  28. sam2/sam2_hiera_b+.yaml +113 -0
  29. sam2/sam2_hiera_l.yaml +117 -0
  30. sam2/sam2_hiera_s.yaml +116 -0
  31. sam2/sam2_hiera_t.yaml +118 -0
  32. sam2/sam2_image_predictor.py +466 -0
  33. sam2/sam2_video_predictor.py +1388 -0
  34. sam2/sam2_video_predictor_legacy.py +1172 -0
  35. sam2/utils/__init__.py +5 -0
  36. sam2/utils/amg.py +348 -0
  37. sam2/utils/misc.py +349 -0
  38. sam2/utils/transforms.py +118 -0
  39. singlebehaviorlab/__init__.py +4 -0
  40. singlebehaviorlab/__main__.py +130 -0
  41. singlebehaviorlab/_paths.py +100 -0
  42. singlebehaviorlab/backend/__init__.py +2 -0
  43. singlebehaviorlab/backend/augmentations.py +320 -0
  44. singlebehaviorlab/backend/data_store.py +420 -0
  45. singlebehaviorlab/backend/model.py +1290 -0
  46. singlebehaviorlab/backend/train.py +4667 -0
  47. singlebehaviorlab/backend/uncertainty.py +578 -0
  48. singlebehaviorlab/backend/video_processor.py +688 -0
  49. singlebehaviorlab/backend/video_utils.py +139 -0
  50. singlebehaviorlab/data/config/config.yaml +85 -0
  51. singlebehaviorlab/data/training_profiles.json +334 -0
  52. singlebehaviorlab/gui/__init__.py +4 -0
  53. singlebehaviorlab/gui/analysis_widget.py +2291 -0
  54. singlebehaviorlab/gui/attention_export.py +311 -0
  55. singlebehaviorlab/gui/clip_extraction_widget.py +481 -0
  56. singlebehaviorlab/gui/clustering_widget.py +3187 -0
  57. singlebehaviorlab/gui/inference_popups.py +1138 -0
  58. singlebehaviorlab/gui/inference_widget.py +4550 -0
  59. singlebehaviorlab/gui/inference_worker.py +651 -0
  60. singlebehaviorlab/gui/labeling_widget.py +2324 -0
  61. singlebehaviorlab/gui/main_window.py +754 -0
  62. singlebehaviorlab/gui/metadata_management_widget.py +1119 -0
  63. singlebehaviorlab/gui/motion_tracking.py +764 -0
  64. singlebehaviorlab/gui/overlay_export.py +1234 -0
  65. singlebehaviorlab/gui/plot_integration.py +729 -0
  66. singlebehaviorlab/gui/qt_helpers.py +29 -0
  67. singlebehaviorlab/gui/registration_widget.py +1485 -0
  68. singlebehaviorlab/gui/review_widget.py +1330 -0
  69. singlebehaviorlab/gui/segmentation_tracking_widget.py +2752 -0
  70. singlebehaviorlab/gui/tab_tutorial_dialog.py +312 -0
  71. singlebehaviorlab/gui/timeline_themes.py +131 -0
  72. singlebehaviorlab/gui/training_profiles.py +418 -0
  73. singlebehaviorlab/gui/training_widget.py +3719 -0
  74. singlebehaviorlab/gui/video_utils.py +233 -0
  75. singlebehaviorlab/licenses/SAM2-LICENSE +201 -0
  76. singlebehaviorlab/licenses/VideoPrism-LICENSE +202 -0
  77. singlebehaviorlab-2.0.0.dist-info/METADATA +447 -0
  78. singlebehaviorlab-2.0.0.dist-info/RECORD +88 -0
  79. singlebehaviorlab-2.0.0.dist-info/WHEEL +5 -0
  80. singlebehaviorlab-2.0.0.dist-info/entry_points.txt +2 -0
  81. singlebehaviorlab-2.0.0.dist-info/licenses/LICENSE +21 -0
  82. singlebehaviorlab-2.0.0.dist-info/top_level.txt +3 -0
  83. videoprism/__init__.py +0 -0
  84. videoprism/encoders.py +910 -0
  85. videoprism/layers.py +1136 -0
  86. videoprism/models.py +407 -0
  87. videoprism/tokenizers.py +167 -0
  88. videoprism/utils.py +168 -0
singlebehaviorlab/backend/model.py
@@ -0,0 +1,1290 @@
1
+ import logging
2
+ import os
3
+ import gc
4
+
5
+ logger = logging.getLogger(__name__)
6
+
7
+ # JAX memory config MUST be set before importing jax.
8
+ # Without this, JAX grabs 75-90% of GPU memory upfront, leaving PyTorch
9
+ # starved and eventually causing CUDA_ERROR_ILLEGAL_ADDRESS under sustained
10
+ # inference workloads where both frameworks share the same GPU.
11
+ os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")
12
+ os.environ.setdefault("XLA_PYTHON_CLIENT_MEM_FRACTION", "0.45")
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from typing import Optional, Callable
18
+ import numpy as np
19
+ import jax
20
+ import jax.numpy as jnp
21
+ from videoprism import models as vp
22
+
23
+
24
+ def interpolate_pos_embed_2d(pos_embed: np.ndarray, orig_grid: int, new_grid: int) -> np.ndarray:
25
+ """
26
+ Bicubic interpolation of 2D spatial position embeddings.
27
+
28
+ Args:
29
+ pos_embed: [N, D] where N = orig_grid * orig_grid
30
+ orig_grid: original spatial grid size (e.g., 16 for 288px)
31
+ new_grid: target spatial grid size (e.g., 19 for 342px)
32
+
33
+ Returns:
34
+ Interpolated pos_embed: [new_grid*new_grid, D]
35
+ """
36
+ if orig_grid == new_grid:
37
+ return pos_embed
38
+
39
+ D = pos_embed.shape[-1]
40
+ pos_2d = pos_embed.reshape(orig_grid, orig_grid, D)
41
+ pos_torch = torch.from_numpy(pos_2d).permute(2, 0, 1).unsqueeze(0).float()
42
+
43
+ pos_interp = F.interpolate(
44
+ pos_torch,
45
+ size=(new_grid, new_grid),
46
+ mode='bicubic',
47
+ align_corners=False
48
+ )
49
+
50
+ pos_interp = pos_interp.squeeze(0).permute(1, 2, 0).numpy()
51
+ return pos_interp.reshape(new_grid * new_grid, D)
52
+
53
+
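A minimal usage sketch of the grid arithmetic above (not part of the wheel), assuming the module is importable as singlebehaviorlab.backend.model and an illustrative embedding width of 768: 288 px with 18 px patches gives a 16x16 grid, 342 px gives 19x19.

    import numpy as np
    from singlebehaviorlab.backend.model import interpolate_pos_embed_2d

    D = 768                                       # illustrative embedding width
    orig_grid, new_grid = 288 // 18, 342 // 18    # 16 -> 19
    pos = np.random.randn(orig_grid * orig_grid, D).astype(np.float32)   # [256, 768]
    pos_new = interpolate_pos_embed_2d(pos, orig_grid, new_grid)         # [361, 768]
    assert pos_new.shape == (new_grid * new_grid, D)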
54
+ class VideoPrismBackbone(nn.Module):
55
+ """VideoPrism backbone wrapper for PyTorch compatibility."""
56
+
57
+ DEFAULT_RESOLUTION = 288
58
+ PATCH_SIZE = 18
59
+
60
+ def __init__(
61
+ self,
62
+ model_name: str = 'videoprism_public_v1_base',
63
+ resolution: int = 288,
64
+ log_fn: Optional[Callable[[str], None]] = None
65
+ ):
66
+ super().__init__()
67
+ self.model_name = model_name
68
+ self.resolution = resolution
69
+ self.flax_model = None
70
+ self.params = None
71
+ self.original_params = None
72
+ self.jax_device = None
73
+ self._forward_fn = None
74
+ self.log_fn = log_fn or print
75
+ self._current_grid_size = None
76
+ self.enable_dlpack = os.environ.get("BEHAVIOR_APP_USE_DLPACK", "0").strip().lower() in ("1", "true", "yes", "on")
77
+ self._load_model()
78
+
79
+ def _is_jax_gpu(self) -> bool:
80
+ d = self.jax_device
81
+ return (
82
+ d.device_kind == 'gpu' or
83
+ 'gpu' in d.device_kind.lower() or
84
+ 'cuda' in str(d).lower() or
85
+ d.platform in ['gpu', 'cuda']
86
+ )
87
+
88
+ def _load_model(self):
89
+ """Load VideoPrism model (JAX/Flax) and configure for GPU if available."""
90
+ try:
91
+ import copy
92
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
93
+ os.environ['XLA_FLAGS'] = '--xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found'
94
+
95
+ gc.collect()
96
+ if torch.cuda.is_available():
97
+ torch.cuda.empty_cache()
98
+ torch.cuda.synchronize()
99
+ self.log_fn("Cleared PyTorch GPU cache before loading VideoPrism")
100
+
101
+ devices = jax.devices()
102
+ self.log_fn(f"JAX devices found: {len(devices)}")
103
+ for i, d in enumerate(devices):
104
+ self.log_fn(f" Device {i}: {d} (kind: {d.device_kind}, platform: {d.platform})")
105
+
106
+ gpu_devices = [
107
+ d for d in devices
108
+ if (d.device_kind == 'gpu' or
109
+ 'gpu' in d.device_kind.lower() or
110
+ 'cuda' in str(d).lower() or
111
+ d.platform in ['gpu', 'cuda'])
112
+ ]
113
+
114
+ if gpu_devices:
115
+ self.jax_device = gpu_devices[0]
116
+ self.log_fn(f"JAX GPU device selected: {self.jax_device}")
117
+ else:
118
+ self.jax_device = jax.devices('cpu')[0]
119
+ self.log_fn(f"JAX using CPU device: {self.jax_device} (GPU not available)")
120
+
121
+ self.log_fn("Loading VideoPrism model...")
122
+ self.flax_model = vp.get_model(self.model_name)
123
+ self.params = vp.load_pretrained_weights(self.model_name)
124
+ self.original_params = copy.deepcopy(self.params)
125
+
126
+ orig_grid = self.DEFAULT_RESOLUTION // self.PATCH_SIZE
127
+ new_grid = self.resolution // self.PATCH_SIZE
128
+ self._current_grid_size = new_grid
129
+
130
+ if new_grid != orig_grid:
131
+ self.log_fn(f"Resolution {self.resolution}x{self.resolution} -> {new_grid}x{new_grid} spatial grid")
132
+ self.params = self._interpolate_spatial_pos_embed(self.params, orig_grid, new_grid)
133
+ else:
134
+ self.log_fn(f"Using default resolution {self.resolution}x{self.resolution} ({new_grid}x{new_grid} grid)")
135
+
136
+ if self._is_jax_gpu():
137
+ self.log_fn("Moving VideoPrism parameters to GPU...")
138
+ self.params = jax.device_put(self.params, self.jax_device)
139
+
140
+ @jax.jit
141
+ def forward_fn(params, videos_bthwc: jnp.ndarray) -> jnp.ndarray:
142
+ embeddings, _ = self.flax_model.apply(
143
+ params, videos_bthwc, train=False, return_intermediate=False
144
+ )
145
+ return embeddings
146
+
147
+ self._forward_fn = forward_fn
148
+ self.log_fn(f"Loaded VideoPrism model: {self.model_name} (device: {self.jax_device.device_kind})")
149
+
150
+ self.log_fn(f"Warming up JIT compilation at {self.resolution}x{self.resolution}...")
151
+ dummy_batch = jnp.zeros((1, 16, self.resolution, self.resolution, 3), dtype=jnp.float32)
152
+ dummy_batch = jax.device_put(dummy_batch, self.jax_device)
153
+ _ = self._forward_fn(self.params, dummy_batch)
154
+ self.log_fn("JIT compilation complete")
155
+
156
+ except Exception as e:
157
+ error_msg = f"Error loading VideoPrism model: {e}\n"
158
+ error_msg += "This may be due to CuDNN version mismatch or missing dependencies."
159
+ if "DNN library initialization failed" in str(e) or "FAILED_PRECONDITION" in str(e):
160
+ error_msg += (
161
+ "\n\nTry:\n"
162
+ "conda activate singlebehaviorlab\n"
163
+ "pip install --upgrade torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/cu124\n"
164
+ "pip install --upgrade \"jax[cuda12]==0.6.2\" flax==0.10.7\n"
165
+ "pip install --upgrade \"nvidia-cudnn-cu12==9.20.0.48\""
166
+ )
167
+ self.log_fn(error_msg)
168
+ import traceback
169
+ self.log_fn(traceback.format_exc())
170
+ raise RuntimeError(error_msg) from e
171
+
172
+ def _interpolate_spatial_pos_embed(self, params: dict, orig_grid: int, new_grid: int) -> dict:
173
+ import copy
174
+ params = copy.deepcopy(params)
175
+
176
+ def find_and_interpolate(d, path=""):
177
+ if isinstance(d, dict):
178
+ for k, v in d.items():
179
+ new_path = f"{path}.{k}" if path else k
180
+ d[k] = find_and_interpolate(v, new_path)
181
+ elif isinstance(d, (np.ndarray, jnp.ndarray)):
182
+ arr = np.asarray(d)
183
+ expected_spatial = orig_grid * orig_grid
184
+ is_pos_embed = any(name in path.lower() for name in ['pos_embed', 'posembed', 'position'])
185
+
186
+ if is_pos_embed and arr.ndim >= 2:
187
+ if arr.shape[-2] == expected_spatial:
188
+ D = arr.shape[-1]
189
+ self.log_fn(f" Interpolating {path}: [{expected_spatial}, {D}] -> [{new_grid*new_grid}, {D}]")
190
+ arr_interp = interpolate_pos_embed_2d(arr, orig_grid, new_grid)
191
+ return jnp.asarray(arr_interp)
192
+ elif arr.ndim == 3 and arr.shape[1] == expected_spatial:
193
+ D = arr.shape[-1]
194
+ self.log_fn(f" Interpolating {path}: [1, {expected_spatial}, {D}] -> [1, {new_grid*new_grid}, {D}]")
195
+ arr_interp = interpolate_pos_embed_2d(arr[0], orig_grid, new_grid)
196
+ return jnp.asarray(arr_interp[np.newaxis, :, :])
197
+ elif arr.ndim == 3 and arr.shape[1] == expected_spatial + 1:
198
+ D = arr.shape[-1]
199
+ self.log_fn(f" Interpolating {path}: [1, 1+{expected_spatial}, {D}] -> [1, 1+{new_grid*new_grid}, {D}]")
200
+ cls_token = arr[:, :1, :]
201
+ spatial = arr[0, 1:, :]
202
+ spatial_interp = interpolate_pos_embed_2d(spatial, orig_grid, new_grid)
203
+ return jnp.concatenate([cls_token, spatial_interp[np.newaxis, :, :]], axis=1)
204
+ return d
205
+ return d
206
+
207
+ params = find_and_interpolate(params)
208
+ return params
209
+
210
+ def set_resolution(self, resolution: int):
211
+ if resolution == self.resolution:
212
+ return
213
+ self.resolution = resolution
214
+ orig_grid = self.DEFAULT_RESOLUTION // self.PATCH_SIZE
215
+ new_grid = resolution // self.PATCH_SIZE
216
+ self._current_grid_size = new_grid
217
+
218
+ if self.original_params is not None:
219
+ self.log_fn(f"Re-interpolating spatial pos embeddings for resolution {resolution}...")
220
+ import copy
221
+ self.params = copy.deepcopy(self.original_params)
222
+ if new_grid != orig_grid:
223
+ self.params = self._interpolate_spatial_pos_embed(self.params, orig_grid, new_grid)
224
+ if self._is_jax_gpu():
225
+ self.params = jax.device_put(self.params, self.jax_device)
226
+ self.log_fn(f"Resolution updated to {resolution}x{resolution}")
227
+
228
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
229
+ """
230
+ Args:
231
+ x: Input tensor [B, T, C, H, W] in [0, 1] float32
232
+ Returns:
233
+ Token embeddings [B, N, D] where N = T * S spatial tokens
234
+ """
235
+ if self.flax_model is None or self.params is None or self._forward_fn is None:
236
+ raise RuntimeError("Model not loaded")
237
+
238
+ device = x.device
239
+ x_bthwc = x.permute(0, 1, 3, 4, 2).contiguous()
240
+
241
+ use_dlpack = (
242
+ self.enable_dlpack and
243
+ self.resolution <= 288 and
244
+ x_bthwc.is_cuda and
245
+ self.jax_device is not None and
246
+ self.jax_device.platform in ["gpu", "cuda"]
247
+ )
248
+
249
+ def _to_jax_tensor(prefer_dlpack: bool):
250
+ if prefer_dlpack:
251
+ try:
252
+ x_jax_local = jax.dlpack.from_dlpack(x_bthwc)
253
+ except Exception:
254
+ import torch.utils.dlpack as torch_dlpack
255
+ x_dlpack = torch_dlpack.to_dlpack(x_bthwc)
256
+ x_jax_local = jax.dlpack.from_dlpack(x_dlpack)
257
+ if x_jax_local.device != self.jax_device:
258
+ x_jax_local = jax.device_put(x_jax_local, self.jax_device)
259
+ return x_jax_local, True
260
+ x_np = x_bthwc.detach().cpu().contiguous().numpy()
261
+ x_jax_local = jnp.asarray(x_np)
262
+ x_jax_local = jax.device_put(x_jax_local, self.jax_device)
263
+ return x_jax_local, False
264
+
265
+ try:
266
+ x_jax, used_dlpack = _to_jax_tensor(use_dlpack)
267
+ embeddings_jax = self._forward_fn(self.params, x_jax)
268
+ embeddings_jax.block_until_ready()
269
+ del x_jax
270
+ except Exception as e:
271
+ if use_dlpack:
272
+ self.log_fn(f"Warning: DLPack path failed ({e}). Retrying with safe host-copy transfer.")
273
+ x_jax, used_dlpack = _to_jax_tensor(False)
274
+ embeddings_jax = self._forward_fn(self.params, x_jax)
275
+ embeddings_jax.block_until_ready()
276
+ del x_jax
277
+ else:
278
+ raise
279
+
280
+ embeddings_np = np.asarray(embeddings_jax)
281
+ del embeddings_jax
282
+ embeddings_torch = torch.from_numpy(embeddings_np.copy()).to(device)
283
+ del embeddings_np
284
+ return embeddings_torch
285
+
286
+ def get_embed_dim(self) -> int:
287
+ dummy_input = torch.zeros(1, 16, 3, self.resolution, self.resolution)
288
+ with torch.no_grad():
289
+ tokens = self.forward(dummy_input)
290
+ return tokens.shape[-1]
291
+
292
+ def get_num_tokens(self) -> int:
293
+ grid_size = self.resolution // self.PATCH_SIZE
294
+ return grid_size * grid_size
295
+
296
+
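A shape-contract sketch for the backbone wrapper, following the docstrings above: input is [B, T, C, H, W] float32 in [0, 1], output is [B, T*S, D] with S = (resolution // 18) ** 2. Constructing the wrapper loads the pretrained VideoPrism weights, so this is illustrative only.

    import torch
    from singlebehaviorlab.backend.model import VideoPrismBackbone

    backbone = VideoPrismBackbone(resolution=288)
    S = backbone.get_num_tokens()              # (288 // 18) ** 2 = 256 spatial tokens per frame
    clip = torch.rand(1, 16, 3, 288, 288)      # [B, T, C, H, W], values in [0, 1]
    tokens = backbone(clip)                    # [B, T * S, D]
    assert tokens.shape[1] == 16 * S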
297
+ # Stage B: frame decoder and localization heads.
298
+
299
+ class SpatialAttentionPool(nn.Module):
300
+ """Trainable per-frame spatial attention pooling.
301
+
302
+ For each frame, a learned query attends over S spatial tokens
303
+ and an MLP bottleneck projects the result to proj_dim.
304
+ """
305
+
306
+ def __init__(self, embed_dim: int, proj_dim: int = 256,
307
+ num_heads: int = 4, dropout: float = 0.1):
308
+ super().__init__()
309
+ self.embed_dim = embed_dim
310
+ self.proj_dim = proj_dim
311
+ self.ln = nn.LayerNorm(embed_dim)
312
+ self.query = nn.Parameter(torch.randn(1, 1, embed_dim) * 0.02)
313
+ self.attn = nn.MultiheadAttention(
314
+ embed_dim, num_heads=num_heads, dropout=dropout, batch_first=True,
315
+ )
316
+ self.mlp = nn.Sequential(
317
+ nn.Linear(embed_dim, proj_dim),
318
+ nn.GELU(),
319
+ nn.Dropout(dropout),
320
+ nn.Linear(proj_dim, proj_dim),
321
+ nn.Dropout(dropout),
322
+ )
323
+
324
+ def forward(self, tokens: torch.Tensor, num_frames: int,
325
+ return_attn_weights: bool = False):
326
+ """
327
+ Args:
328
+ tokens: [B, T*S, D]
329
+ num_frames: T
330
+ return_attn_weights: if True, also return spatial attention maps
331
+ Returns:
332
+ [B, T, proj_dim] or ([B, T, proj_dim], [B, T, num_heads, S])
333
+ """
334
+ B, N, D = tokens.shape
335
+ S = N // num_frames
336
+ T = num_frames
337
+
338
+ flat = tokens.view(B, T, S, D).reshape(B * T, S, D)
339
+ flat = self.ln(flat)
340
+
341
+ q = self.query.expand(B * T, -1, -1)
342
+ pooled, attn_w = self.attn(q, flat, flat, need_weights=return_attn_weights,
343
+ average_attn_weights=False)
344
+ pooled = pooled.squeeze(1) # [B*T, D]
345
+
346
+ out = self.mlp(pooled) # [B*T, proj_dim]
347
+
348
+ if return_attn_weights and attn_w is not None:
349
+ # attn_w: [B*T, num_heads, 1, S] -> [B, T, num_heads, S]
350
+ attn_w = attn_w.squeeze(2).view(B, T, -1, S)
351
+ return out.view(B, T, self.proj_dim), attn_w
352
+
353
+ return out.view(B, T, self.proj_dim)
354
+
355
+
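A toy shape check for the pooling module (sizes are illustrative; 768 stands in for the backbone embedding width):

    import torch
    from singlebehaviorlab.backend.model import SpatialAttentionPool

    pool = SpatialAttentionPool(embed_dim=768, proj_dim=256, num_heads=4)
    tokens = torch.randn(2, 16 * 256, 768)                      # [B, T*S, D] with T=16, S=256
    frames = pool(tokens, num_frames=16)                        # [2, 16, 256]
    frames, attn = pool(tokens, num_frames=16, return_attn_weights=True)
    print(frames.shape, attn.shape)                             # (2, 16, 256) (2, 16, 4, 256)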
356
+ class _DilatedResidualBlock(nn.Module):
357
+ def __init__(self, channels: int, dilation: int, k: int, p: float):
358
+ super().__init__()
359
+ self.block = nn.Sequential(
360
+ nn.Conv1d(channels, channels, kernel_size=k,
361
+ padding=dilation, dilation=dilation, bias=False),
362
+ nn.GELU(),
363
+ nn.Dropout(p),
364
+ nn.Conv1d(channels, channels, kernel_size=1, bias=False),
365
+ nn.Dropout(p),
366
+ )
367
+
368
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
369
+ return x + self.block(x)
370
+
371
+
372
+ class _MSRefineStage(nn.Module):
373
+ def __init__(self, in_channels: int, hidden_channels: int,
374
+ out_channels: int, layers: int, k: int, p: float):
375
+ super().__init__()
376
+ self.in_proj = nn.Conv1d(in_channels, hidden_channels,
377
+ kernel_size=1, bias=False)
378
+ self.blocks = nn.ModuleList([
379
+ _DilatedResidualBlock(hidden_channels, dilation=2 ** i, k=k, p=p)
380
+ for i in range(max(1, int(layers)))
381
+ ])
382
+ self.out_proj = nn.Conv1d(hidden_channels, out_channels, kernel_size=1)
383
+
384
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
385
+ y = self.in_proj(x)
386
+ for blk in self.blocks:
387
+ y = blk(y)
388
+ return self.out_proj(y)
389
+
390
+
391
+ class DilatedTemporalHead(nn.Module):
392
+ """Frame decoder head (Stage B).
393
+
394
+ Trainable spatial attention pooling -> local TCN -> state + boundary heads.
395
+ Produces per-frame state logits, boundary logits, and frame embeddings.
396
+ """
397
+
398
+ def __init__(
399
+ self,
400
+ embed_dim: int,
401
+ num_classes: int,
402
+ num_layers: int = 4,
403
+ kernel_size: int = 3,
404
+ dropout: float = 0.2,
405
+ num_stages: int = 3,
406
+ hidden_dim: Optional[int] = None,
407
+ temporal_pool: int = 1,
408
+ proj_dim: int = 256,
409
+ spatial_pool_heads: int = 4,
410
+ multi_scale: bool = False,
411
+ use_temporal_decoder: bool = True,
412
+ ):
413
+ super().__init__()
414
+ self.embed_dim = int(embed_dim)
415
+ self.num_classes = int(num_classes)
416
+ self.num_layers = max(1, int(num_layers))
417
+ self.num_stages = max(1, int(num_stages))
418
+ self.dropout = float(dropout)
419
+ self.temporal_pool = max(1, int(temporal_pool))
420
+ self.use_ovr = False
421
+ self.proj_dim = int(proj_dim)
422
+ self.multi_scale = bool(multi_scale)
423
+ self.use_temporal_decoder = bool(use_temporal_decoder)
424
+ auto_hidden = max(128, self.proj_dim)
425
+ self.hidden_dim = int(hidden_dim) if hidden_dim is not None else auto_hidden
426
+
427
+ self.spatial_pool = SpatialAttentionPool(
428
+ embed_dim=self.embed_dim,
429
+ proj_dim=self.proj_dim,
430
+ num_heads=int(spatial_pool_heads),
431
+ dropout=self.dropout,
432
+ )
433
+
434
+ # Raw per-frame projection: simple mean-pool of spatial tokens → proj_dim.
435
+ # Not filtered by learned attention, so it retains broader spatial context
436
+ # that SpatialAttentionPool may suppress in favour of discriminative patches.
437
+ self.raw_proj = nn.Linear(self.embed_dim, self.proj_dim)
438
+
439
+ # When multi_scale=True, long-scale and short-scale features are concatenated,
440
+ # so the TCN input is 2*proj_dim instead of proj_dim.
441
+ tcn_in = self.proj_dim * 2 if self.multi_scale else self.proj_dim
442
+ if self.use_temporal_decoder:
443
+ self.stage1 = _MSRefineStage(
444
+ in_channels=tcn_in,
445
+ hidden_channels=self.hidden_dim,
446
+ out_channels=self.num_classes,
447
+ layers=self.num_layers,
448
+ k=kernel_size,
449
+ p=self.dropout,
450
+ )
451
+ self.refine_stages = nn.ModuleList()
452
+ for _ in range(self.num_stages - 1):
453
+ self.refine_stages.append(
454
+ _MSRefineStage(
455
+ in_channels=self.num_classes,
456
+ hidden_channels=self.hidden_dim,
457
+ out_channels=self.num_classes,
458
+ layers=self.num_layers,
459
+ k=kernel_size,
460
+ p=self.dropout,
461
+ )
462
+ )
463
+
464
+ self.boundary_tcn = _MSRefineStage(
465
+ in_channels=tcn_in,
466
+ hidden_channels=self.hidden_dim,
467
+ out_channels=1,
468
+ layers=max(1, self.num_layers // 2),
469
+ k=kernel_size,
470
+ p=self.dropout,
471
+ )
472
+ self.frame_classifier = None
473
+ else:
474
+ self.stage1 = None
475
+ self.refine_stages = nn.ModuleList()
476
+ self.boundary_tcn = None
477
+ self.frame_classifier = nn.Sequential(
478
+ nn.LayerNorm(tcn_in),
479
+ nn.Linear(tcn_in, self.num_classes),
480
+ )
481
+
482
+ def forward(self, tokens: torch.Tensor, num_frames: int,
483
+ tokens_short: Optional[torch.Tensor] = None,
484
+ num_frames_short: Optional[int] = None,
485
+ return_attn_weights: bool = False):
486
+ """
487
+ Args:
488
+ tokens: [B, T*S, D] long-scale backbone tokens
489
+ num_frames: T (long scale)
490
+ tokens_short: [B, T_s*S, D] short-scale tokens (half fps, T_s = T//2).
491
+ Required when multi_scale=True.
492
+ num_frames_short: T_s
493
+ return_attn_weights: if True, include spatial attention maps in output
494
+ Returns:
495
+ frame_logits: [B, T, C]
496
+ clip_logits: [B, C]
497
+ temporal_weights: [B, T]
498
+ frame_logits_pooled: [B, T_pooled, C]
499
+ boundary_logits: [B, T, 1]
500
+ frame_embeddings: [B, T, proj_dim] (long-scale attention pool)
501
+ frame_embeddings_combined: [B, T, 2*proj_dim] (attn || raw_mean, long scale)
502
+ attn_weights: [B, T, num_heads, S] or None
503
+ """
504
+ B, N, D = tokens.shape
505
+ if num_frames <= 0 or (N % num_frames) != 0:
506
+ raise ValueError(
507
+ f"DilatedTemporalHead expects T*S tokens. "
508
+ f"Got N={N}, num_frames={num_frames}."
509
+ )
510
+ T = num_frames
511
+
512
+ # 1. Spatial attention pooling (long scale): [B, T*S, D] -> [B, T, proj_dim]
513
+ attn_weights = None
514
+ pool_out = self.spatial_pool(tokens, T, return_attn_weights=return_attn_weights)
515
+ if return_attn_weights and isinstance(pool_out, tuple):
516
+ x_long, attn_weights = pool_out
517
+ else:
518
+ x_long = pool_out
519
+ # frame_embeddings is the long-scale per-frame feature
520
+ frame_embeddings = x_long # [B, T, proj_dim]
521
+
522
+ # Multi-scale: pool short-scale tokens and upsample to align with T
523
+ if self.multi_scale:
524
+ if tokens_short is None or num_frames_short is None:
525
+ raise ValueError(
526
+ "DilatedTemporalHead has multi_scale=True but tokens_short / "
527
+ "num_frames_short were not provided. Ensure the dataset has "
528
+ "_emb_multi_scale=True and short-scale embeddings are cached."
529
+ )
530
+ T_short = num_frames_short
531
+ x_short = self.spatial_pool(tokens_short, T_short) # [B, T_short, proj_dim]
532
+ scale_factor = T // T_short
533
+ # Nearest-neighbour upsample: each short frame covers scale_factor long frames
534
+ x_short_up = x_short.repeat_interleave(scale_factor, dim=1)[:, :T, :]
535
+ x = torch.cat([x_long, x_short_up], dim=-1) # [B, T, 2*proj_dim]
536
+ else:
537
+ x = x_long # [B, T, proj_dim]
538
+
539
+ if self.use_temporal_decoder:
540
+ # 2. Boundary prediction at full T (before temporal pooling reduces resolution)
541
+ emb_conv = x.transpose(1, 2) # [B, tcn_in, T]
542
+ boundary_logits = self.boundary_tcn(emb_conv).transpose(1, 2) # [B, T, 1]
543
+
544
+ T_pooled = T
545
+ if self.temporal_pool > 1:
546
+ p = self.temporal_pool
547
+ pad = (p - T % p) % p
548
+ if pad > 0:
549
+ x = torch.cat([x, x[:, -1:, :].expand(-1, pad, -1)], dim=1)
550
+ T_padded = T + pad
551
+ feat_dim = x.shape[-1]
552
+ x = x.view(B, T_padded // p, p, feat_dim).mean(dim=2)
553
+ T_pooled = T_padded // p
554
+
555
+ x_conv = x.transpose(1, 2) # [B, tcn_in, T_pooled]
556
+ stage_logits = self.stage1(x_conv)
557
+ for refine in self.refine_stages:
558
+ if self.use_ovr:
559
+ refine_in = torch.sigmoid(stage_logits)
560
+ else:
561
+ refine_in = torch.softmax(stage_logits, dim=1)
562
+ stage_logits = stage_logits + refine(refine_in)
563
+
564
+ stage_logits_pooled = stage_logits
565
+ clip_logits = stage_logits_pooled.mean(dim=2)
566
+
567
+ if self.temporal_pool > 1:
568
+ stage_logits = stage_logits.repeat_interleave(
569
+ self.temporal_pool, dim=2
570
+ )[:, :, :T]
571
+
572
+ frame_logits = stage_logits.transpose(1, 2)
573
+ frame_logits_pooled = stage_logits_pooled.transpose(1, 2)
574
+ else:
575
+ boundary_logits = None
576
+ frame_logits = self.frame_classifier(x)
577
+ frame_logits_pooled = frame_logits
578
+ clip_logits = frame_logits.mean(dim=1)
579
+ temporal_weights = torch.ones(B, T, device=tokens.device) / T
580
+
581
+ # Raw mean-pool per frame from long-scale tokens: bypasses spatial attention
582
+ # so downstream analysis can access a less task-biased view as well.
583
+ S = N // T
584
+ raw_mean = tokens.view(B, T, S, self.embed_dim).mean(dim=2)
585
+ raw_emb = self.raw_proj(raw_mean) # [B, T, proj_dim]
586
+ # Combined embedding: [attention_emb || raw_emb], shape [B, T, 2*proj_dim]
587
+ frame_embeddings_combined = torch.cat([frame_embeddings, raw_emb], dim=-1)
588
+
589
+ return (frame_logits, clip_logits, temporal_weights,
590
+ frame_logits_pooled, boundary_logits, frame_embeddings,
591
+ frame_embeddings_combined, attn_weights)
592
+
593
+
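A sketch of how the head's eight-tuple output unpacks, with multi_scale=False and illustrative sizes (768-dim tokens, 5 classes):

    import torch
    from singlebehaviorlab.backend.model import DilatedTemporalHead

    head = DilatedTemporalHead(embed_dim=768, num_classes=5)
    tokens = torch.randn(2, 16 * 256, 768)                      # [B, T*S, D]
    (frame_logits, clip_logits, temporal_weights, frame_logits_pooled,
     boundary_logits, frame_emb, frame_emb_combined, attn_w) = head(tokens, num_frames=16)
    # frame_logits [2, 16, 5]; clip_logits [2, 5]; boundary_logits [2, 16, 1];
    # frame_emb [2, 16, 256]; frame_emb_combined [2, 16, 512]; attn_w is None here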
594
+ # Spatial localization head.
595
+
596
+ class SpatialLocalizationHead(nn.Module):
597
+ """Frame-level bbox regressor over spatial tokens.
598
+
599
+ Design:
600
+ 1) Dense objectness map over spatial patches per frame.
601
+ 2) Upsampled objectness for sub-patch precision.
602
+ 3) Temperature-scaled soft-argmax for coarse center.
603
+ 4) Local token gathering for context-aware refinement.
604
+ 5) Temporal residual refiner over center trajectories.
605
+ 6) Fixed box size from dataset stats for stable crops.
606
+ """
607
+
608
+ UPSAMPLE_FACTOR = 4
609
+
610
+ def __init__(self, embed_dim: int, hidden_dim: int = 256, dropout: float = 0.1):
611
+ super().__init__()
612
+ self.norm = nn.LayerNorm(embed_dim)
613
+ self.query = nn.Parameter(torch.randn(1, 1, embed_dim) * 0.02)
614
+ self.attn = nn.MultiheadAttention(
615
+ embed_dim, num_heads=4, dropout=dropout, batch_first=True,
616
+ )
617
+ self.attn_norm = nn.LayerNorm(embed_dim)
618
+ self.objectness = nn.Linear(embed_dim, 1)
619
+ self.temperature = nn.Parameter(torch.tensor(8.0))
620
+
621
+ self.refine = nn.Sequential(
622
+ nn.Linear(embed_dim * 2, hidden_dim),
623
+ nn.GELU(),
624
+ nn.Dropout(dropout),
625
+ nn.Linear(hidden_dim, hidden_dim // 2),
626
+ nn.GELU(),
627
+ nn.Dropout(dropout),
628
+ nn.Linear(hidden_dim // 2, 2),
629
+ )
630
+
631
+ temporal_hidden = max(16, hidden_dim // 8)
632
+ self.temporal_refine = nn.Sequential(
633
+ nn.Conv1d(2, temporal_hidden, kernel_size=5, padding=2),
634
+ nn.GELU(),
635
+ nn.Dropout(dropout),
636
+ nn.Conv1d(temporal_hidden, 2, kernel_size=5, padding=2),
637
+ )
638
+ final_temporal = self.temporal_refine[-1]
639
+ if isinstance(final_temporal, nn.Conv1d):
640
+ nn.init.zeros_(final_temporal.weight)
641
+ if final_temporal.bias is not None:
642
+ nn.init.zeros_(final_temporal.bias)
643
+
644
+ proj_dim = hidden_dim // 2
645
+ self.contrastive_proj = nn.Sequential(
646
+ nn.Linear(embed_dim, hidden_dim),
647
+ nn.GELU(),
648
+ nn.Linear(hidden_dim, proj_dim),
649
+ )
650
+
651
+ self.register_buffer(
652
+ "fixed_box_wh", torch.tensor([0.2, 0.2], dtype=torch.float32)
653
+ )
654
+
655
+ def set_fixed_box_size(self, width: float, height: float):
656
+ w = float(max(1e-4, min(1.0, width)))
657
+ h = float(max(1e-4, min(1.0, height)))
658
+ self.fixed_box_wh.data = torch.tensor(
659
+ [w, h], dtype=self.fixed_box_wh.dtype, device=self.fixed_box_wh.device,
660
+ )
661
+
662
+ def get_contrastive_tokens(
663
+ self, tokens: torch.Tensor, num_frames: Optional[int] = None
664
+ ) -> torch.Tensor:
665
+ B, N, D = tokens.shape
666
+ normed = self.norm(tokens)
667
+ S = N
668
+ if num_frames is not None and int(num_frames) > 1 and N % int(num_frames) == 0:
669
+ S = N // int(num_frames)
670
+ first_frame = normed[:, :S, :]
671
+ return self.contrastive_proj(first_frame)
672
+
673
+ def get_objectness_logits(
674
+ self, tokens: torch.Tensor, num_frames: Optional[int] = None,
675
+ all_frames: bool = False,
676
+ ) -> torch.Tensor:
677
+ B, N, D = tokens.shape
678
+ normed = self.norm(tokens)
679
+ S = N
680
+ T = 1
681
+ if num_frames is not None and int(num_frames) > 1 and N % int(num_frames) == 0:
682
+ T = int(num_frames)
683
+ S = N // T
684
+ if all_frames and T > 1:
685
+ frames = normed.view(B, T, S, D).reshape(B * T, S, D)
686
+ return self.objectness(frames).squeeze(-1)
687
+ else:
688
+ first_frame = normed[:, :S, :]
689
+ return self.objectness(first_frame).squeeze(-1)
690
+
691
+ def forward(
692
+ self,
693
+ tokens: torch.Tensor,
694
+ num_frames: Optional[int] = None,
695
+ fixed_box_wh: Optional[torch.Tensor] = None,
696
+ ) -> torch.Tensor:
697
+ B, N, D = tokens.shape
698
+ normed = self.norm(tokens)
699
+
700
+ T = 1
701
+ S = N
702
+ if num_frames is not None and int(num_frames) > 1 and N % int(num_frames) == 0:
703
+ T = int(num_frames)
704
+ S = N // T
705
+ frames_tokens = normed.view(B, T, S, D)
706
+ else:
707
+ frames_tokens = normed.view(B, 1, N, D)
708
+
709
+ flat_tokens = frames_tokens.reshape(B * T, S, D)
710
+ q = self.query.expand(B * T, -1, -1)
711
+ pooled, _ = self.attn(q, flat_tokens, flat_tokens)
712
+ pooled = self.attn_norm(pooled.squeeze(1))
713
+
714
+ obj_logits = self.objectness(flat_tokens).squeeze(-1)
715
+
716
+ g = int(round(S ** 0.5))
717
+ is_square = g * g == S
718
+ min_size = 1.0 / float(g) if is_square else 1.0 / float(max(1, S))
719
+
720
+ if is_square:
721
+ g_up = g * self.UPSAMPLE_FACTOR
722
+ obj_2d = obj_logits.view(B * T, 1, g, g)
723
+ obj_up = F.interpolate(obj_2d, size=(g_up, g_up), mode="bilinear", align_corners=False)
724
+ obj_up = obj_up.view(B * T, g_up * g_up)
725
+
726
+ temp = self.temperature.clamp(min=1.0)
727
+ obj_probs = torch.softmax(obj_up * temp, dim=1)
728
+
729
+ ys = (torch.arange(g_up, device=normed.device, dtype=normed.dtype) + 0.5) / float(g_up)
730
+ xs = (torch.arange(g_up, device=normed.device, dtype=normed.dtype) + 0.5) / float(g_up)
731
+ yy, xx = torch.meshgrid(ys, xs, indexing="ij")
732
+ x_coords = xx.reshape(-1).unsqueeze(0)
733
+ y_coords = yy.reshape(-1).unsqueeze(0)
734
+ else:
735
+ temp = self.temperature.clamp(min=1.0)
736
+ obj_probs = torch.softmax(obj_logits * temp, dim=1)
737
+ x_coords = ((torch.arange(S, device=normed.device, dtype=normed.dtype) + 0.5) / float(S)).unsqueeze(0)
738
+ y_coords = torch.full_like(x_coords, 0.5)
739
+
740
+ cx0 = (obj_probs * x_coords).sum(dim=1)
741
+ cy0 = (obj_probs * y_coords).sum(dim=1)
742
+
743
+ if is_square:
744
+ cx0_grid = (cx0 * g).clamp(0, g - 1).long()
745
+ cy0_grid = (cy0 * g).clamp(0, g - 1).long()
746
+ tokens_2d = flat_tokens.view(B * T, g, g, D)
747
+ local_feats = []
748
+ for dy in range(-1, 2):
749
+ for dx in range(-1, 2):
750
+ gy = (cy0_grid + dy).clamp(0, g - 1)
751
+ gx = (cx0_grid + dx).clamp(0, g - 1)
752
+ idx_bt = torch.arange(B * T, device=normed.device)
753
+ local_feats.append(tokens_2d[idx_bt, gy, gx, :])
754
+ local_pool = torch.stack(local_feats, dim=1).mean(dim=1)
755
+ else:
756
+ local_pool = pooled
757
+
758
+ refine_input = torch.cat([pooled, local_pool], dim=1)
759
+ delta = self.refine(refine_input)
760
+ dx = 0.30 * torch.tanh(delta[:, 0])
761
+ dy = 0.30 * torch.tanh(delta[:, 1])
762
+
763
+ cx = (cx0 + dx).clamp(0.0, 1.0)
764
+ cy = (cy0 + dy).clamp(0.0, 1.0)
765
+
766
+ if T > 1:
767
+ center_seq = torch.stack([cx, cy], dim=1).view(B, T, 2).transpose(1, 2)
768
+ temporal_delta = self.temporal_refine(center_seq).transpose(1, 2).reshape(B * T, 2)
769
+ cx = (cx + 0.12 * torch.tanh(temporal_delta[:, 0])).clamp(0.0, 1.0)
770
+ cy = (cy + 0.12 * torch.tanh(temporal_delta[:, 1])).clamp(0.0, 1.0)
771
+
772
+ if fixed_box_wh is not None:
773
+ wh = fixed_box_wh.to(device=normed.device, dtype=normed.dtype)
774
+ if wh.dim() == 1:
775
+ wh = wh.view(1, 2).expand(B, -1)
776
+ if wh.size(0) != B:
777
+ wh = wh[:B]
778
+ else:
779
+ wh = self.fixed_box_wh.to(device=normed.device, dtype=normed.dtype).view(1, 2).expand(B, -1)
780
+ wh = wh.clamp(min=min_size, max=1.0)
781
+ wh_bt = wh.unsqueeze(1).expand(B, T, 2).reshape(B * T, 2) if T > 1 else wh
782
+ w = wh_bt[:, 0]
783
+ h = wh_bt[:, 1]
784
+
785
+ x1 = (cx - 0.5 * w).clamp(0.0, 1.0)
786
+ y1 = (cy - 0.5 * h).clamp(0.0, 1.0)
787
+ x2 = (cx + 0.5 * w).clamp(0.0, 1.0)
788
+ y2 = (cy + 0.5 * h).clamp(0.0, 1.0)
789
+ x2 = torch.maximum(x2, x1 + min_size).clamp(0.0, 1.0)
790
+ y2 = torch.maximum(y2, y1 + min_size).clamp(0.0, 1.0)
791
+
792
+ boxes = torch.stack([x1, y1, x2, y2], dim=1).view(B, T, 4)
793
+ if T == 1:
794
+ return boxes[:, 0, :]
795
+ return boxes
796
+
797
+
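The temperature-scaled soft-argmax used by the head above is an expectation over normalized patch centers; a standalone sketch with an illustrative 16x16 grid:

    import torch

    g = 16                                             # illustrative spatial grid
    logits = torch.randn(1, g * g)                     # per-patch objectness logits
    probs = torch.softmax(logits * 8.0, dim=1)         # temperature-scaled, as in the head
    ys = (torch.arange(g).float() + 0.5) / g
    xs = (torch.arange(g).float() + 0.5) / g
    yy, xx = torch.meshgrid(ys, xs, indexing="ij")
    cx = (probs * xx.reshape(1, -1)).sum(dim=1)        # expected x-center in [0, 1]
    cy = (probs * yy.reshape(1, -1)).sum(dim=1)        # expected y-center in [0, 1]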
798
+ # BehaviorClassifier.
799
+
800
+ class BehaviorClassifier(nn.Module):
801
+ """Complete model: VideoPrism backbone + frame decoder head (Stage B)."""
802
+
803
+ def __init__(
804
+ self,
805
+ backbone: VideoPrismBackbone,
806
+ num_classes: int,
807
+ class_names: list = None,
808
+ dropout: float = 0.1,
809
+ freeze_backbone: bool = True,
810
+ head_kwargs: Optional[dict] = None,
811
+ use_localization: bool = False,
812
+ localization_hidden_dim: int = 256,
813
+ localization_dropout: float = 0.0,
814
+ frame_head_temporal_layers: int = 1,
815
+ temporal_pool_frames: int = 1,
816
+ proj_dim: int = 256,
817
+ num_stages: int = 3,
818
+ multi_scale: bool = False,
819
+ use_temporal_decoder: bool = True,
820
+ use_frame_head: bool = True,
821
+ **kwargs,
822
+ ):
823
+ super().__init__()
824
+ self.backbone = backbone
825
+ self.multi_scale = bool(multi_scale)
826
+ self.use_temporal_decoder = bool(use_temporal_decoder)
827
+
828
+ if freeze_backbone:
829
+ for p in self.backbone.parameters():
830
+ p.requires_grad = False
831
+
832
+ embed_dim = backbone.get_embed_dim()
833
+ self.temporal_pool_frames = max(1, int(temporal_pool_frames))
834
+ self.use_frame_head = True # always on
835
+
836
+ head_kwargs = head_kwargs or {}
837
+ spatial_pool_heads = head_kwargs.get("num_heads", 4)
838
+
839
+ self.frame_head = DilatedTemporalHead(
840
+ embed_dim=embed_dim,
841
+ num_classes=num_classes,
842
+ dropout=dropout,
843
+ num_layers=max(1, int(frame_head_temporal_layers)),
844
+ num_stages=max(1, int(num_stages)),
845
+ temporal_pool=self.temporal_pool_frames,
846
+ proj_dim=int(proj_dim),
847
+ spatial_pool_heads=int(spatial_pool_heads),
848
+ multi_scale=self.multi_scale,
849
+ use_temporal_decoder=self.use_temporal_decoder,
850
+ )
851
+
852
+ self.use_localization = bool(use_localization)
853
+ self.localization_head = None
854
+ if self.use_localization:
855
+ self.localization_head = SpatialLocalizationHead(
856
+ embed_dim=embed_dim,
857
+ hidden_dim=int(localization_hidden_dim),
858
+ dropout=float(localization_dropout),
859
+ )
860
+
861
+ def forward(
862
+ self,
863
+ video: Optional[torch.Tensor],
864
+ return_localization: bool = False,
865
+ return_frame_logits: bool = False,
866
+ cache_backbone_tokens: bool = False,
867
+ localization_box_wh: Optional[torch.Tensor] = None,
868
+ return_attn_weights: bool = False,
869
+ backbone_tokens: Optional[torch.Tensor] = None,
870
+ num_frames: Optional[int] = None,
871
+ backbone_tokens_short: Optional[torch.Tensor] = None,
872
+ num_frames_short: Optional[int] = None,
873
+ return_features: bool = False,
874
+ ):
875
+ """
876
+ Returns:
877
+ clip_logits [B, C] by default.
878
+ Stores full frame outputs in self._frame_output when
879
+ return_frame_logits=True.
880
+
881
+ backbone_tokens: pre-computed tokens [B, T*S, D]. If provided,
882
+ the backbone call is skipped. num_frames must also be provided.
883
+ backbone_tokens_short: pre-computed short-scale tokens [B, T_s*S, D]
884
+ (T_s = T//2, half-fps clip). Used when multi_scale=True.
885
+ """
886
+ if backbone_tokens is not None:
887
+ # Embedding-space stitch path: bypass backbone entirely
888
+ tokens_full = backbone_tokens
889
+ if num_frames is None:
890
+ raise ValueError("num_frames must be provided when backbone_tokens is given")
891
+ T = num_frames
892
+ tokens_short = backbone_tokens_short
893
+ T_short = num_frames_short
894
+ else:
895
+ tokens_full = self.backbone(video)
896
+ T = int(video.shape[1])
897
+ # Multi-scale at inference: subsample video by 2 and run backbone again
898
+ if self.multi_scale:
899
+ video_short = video[:, ::2, :, :, :] # [B, T//2, C, H, W]
900
+ T_short = int(video_short.shape[1])
901
+ tokens_short = self.backbone(video_short) # [B, T_s*S, D]
902
+ else:
903
+ tokens_short = None
904
+ T_short = None
905
+
906
+ if cache_backbone_tokens:
907
+ self._backbone_tokens = tokens_full
908
+
909
+ self._frame_output = None
910
+ frame_device = next(self.frame_head.parameters()).device
911
+ tokens_in = tokens_full.to(frame_device) if tokens_full.device != frame_device else tokens_full
912
+ tokens_short_in = (
913
+ tokens_short.to(frame_device)
914
+ if tokens_short is not None and tokens_short.device != frame_device
915
+ else tokens_short
916
+ )
917
+ frame_out = self.frame_head(
918
+ tokens_in, num_frames=T,
919
+ tokens_short=tokens_short_in,
920
+ num_frames_short=T_short,
921
+ return_attn_weights=return_attn_weights,
922
+ )
923
+
924
+ frame_logits = frame_out[0]
925
+ frame_clip_logits = frame_out[1]
926
+ temporal_weights = frame_out[2]
927
+ frame_logits_pooled = frame_out[3]
928
+ boundary_logits = frame_out[4]
929
+ frame_embeddings = frame_out[5]
930
+ frame_embeddings_combined = frame_out[6] if len(frame_out) > 6 else None
931
+ attn_weights = frame_out[7] if len(frame_out) > 7 else None
932
+
933
+ if return_frame_logits:
934
+ self._frame_output = (
935
+ frame_logits, # 0 [B, T, C]
936
+ frame_clip_logits, # 1 [B, C]
937
+ temporal_weights, # 2 [B, T]
938
+ frame_logits_pooled, # 3 [B, T_pooled, C]
939
+ int(self.temporal_pool_frames), # 4
940
+ boundary_logits, # 5 [B, T, 1]
941
+ frame_embeddings, # 6 [B, T, proj_dim] (attention-pooled)
942
+ frame_embeddings_combined, # 7 [B, T, 2*proj_dim] (attn || raw_mean)
943
+ attn_weights, # 8 [B, T, num_heads, S] or None
944
+ )
945
+
946
+ head_output = frame_clip_logits
947
+
948
+ if return_localization and self.use_localization and self.localization_head is not None:
949
+ loc_device = next(self.localization_head.parameters()).device
950
+ loc_out = self.localization_head(
951
+ tokens_full.to(loc_device) if tokens_full.device != loc_device else tokens_full,
952
+ num_frames=T,
953
+ fixed_box_wh=localization_box_wh,
954
+ )
955
+ return head_output, loc_out
956
+
957
+ return head_output
958
+
959
+ def save_head(self, path: str, metadata: Optional[dict] = None):
960
+ """Save frame head (and localization head) parameters."""
961
+ payload = {
962
+ "frame_head_state_dict": self.frame_head.state_dict(),
963
+ "use_localization": self.use_localization,
964
+ "use_frame_head": True,
965
+ "multi_scale": self.multi_scale,
966
+ "use_temporal_decoder": self.use_temporal_decoder,
967
+ }
968
+ if self.use_localization and self.localization_head is not None:
969
+ payload["localization_state_dict"] = self.localization_head.state_dict()
970
+ torch.save(payload, path)
971
+
972
+ if metadata:
973
+ import json
974
+ meta_path = path + ".meta.json"
975
+ try:
976
+ with open(meta_path, "w", encoding="utf-8") as f:
977
+ json.dump(metadata, f, indent=2)
978
+ except Exception as exc:
979
+ logger.warning("Failed to write metadata to %s: %s", meta_path, exc)
980
+
981
+ def load_head(self, path: str):
982
+ """Load head parameters."""
983
+ state = torch.load(path, map_location='cpu')
984
+ if not isinstance(state, dict):
985
+ return
986
+
987
+ if "frame_head_state_dict" in state:
988
+ self.frame_head.load_state_dict(
989
+ state["frame_head_state_dict"], strict=False,
990
+ )
991
+ if (
992
+ self.use_localization
993
+ and self.localization_head is not None
994
+ and isinstance(state.get("localization_state_dict"), dict)
995
+ ):
996
+ self.localization_head.load_state_dict(
997
+ state["localization_state_dict"], strict=False,
998
+ )
999
+
1000
+
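A sketch of the full model plus the head checkpoint round trip, assuming a backbone built as in the earlier sketch; the path and class count are illustrative.

    import torch
    from singlebehaviorlab.backend.model import BehaviorClassifier

    model = BehaviorClassifier(backbone, num_classes=5)          # backbone: VideoPrismBackbone
    clip_logits = model(torch.rand(1, 16, 3, 288, 288))          # [1, 5]
    model.save_head("frame_head.pt", metadata={"num_classes": 5})
    model.load_head("frame_head.pt")                             # restores frame-head weights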
1001
+ # Loss functions.
1002
+
1003
+ def _giou(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
1004
+ """Generalized IoU between xyxy boxes. Returns [N] in [-1, 1]."""
1005
+ eps = 1e-7
1006
+ ix1 = torch.maximum(pred[:, 0], target[:, 0])
1007
+ iy1 = torch.maximum(pred[:, 1], target[:, 1])
1008
+ ix2 = torch.minimum(pred[:, 2], target[:, 2])
1009
+ iy2 = torch.minimum(pred[:, 3], target[:, 3])
1010
+ inter = (ix2 - ix1).clamp(min=0) * (iy2 - iy1).clamp(min=0)
1011
+
1012
+ area_p = (pred[:, 2] - pred[:, 0]).clamp(min=0) * (pred[:, 3] - pred[:, 1]).clamp(min=0)
1013
+ area_t = (target[:, 2] - target[:, 0]).clamp(min=0) * (target[:, 3] - target[:, 1]).clamp(min=0)
1014
+ union = area_p + area_t - inter
1015
+ iou = inter / (union + eps)
1016
+
1017
+ ex1 = torch.minimum(pred[:, 0], target[:, 0])
1018
+ ey1 = torch.minimum(pred[:, 1], target[:, 1])
1019
+ ex2 = torch.maximum(pred[:, 2], target[:, 2])
1020
+ ey2 = torch.maximum(pred[:, 3], target[:, 3])
1021
+ area_enclosing = (ex2 - ex1).clamp(min=0) * (ey2 - ey1).clamp(min=0)
1022
+
1023
+ return iou - (area_enclosing - union) / (area_enclosing + eps)
1024
+
1025
+
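A tiny worked check of _giou (a sketch, not a test shipped in the package): identical boxes give ~1.0, and disjoint boxes go negative once the enclosing-box penalty applies.

    import torch
    from singlebehaviorlab.backend.model import _giou

    a = torch.tensor([[0.1, 0.1, 0.4, 0.4]])
    b = torch.tensor([[0.6, 0.6, 0.9, 0.9]])
    print(_giou(a, a))   # ~1.0: IoU 1, enclosing box equals the union
    print(_giou(a, b))   # ~-0.72: IoU 0 minus (enclosing - union) / enclosing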
1026
+ def localization_bbox_loss(
1027
+ pred_bboxes: torch.Tensor,
1028
+ target_bboxes: torch.Tensor,
1029
+ valid_mask: torch.Tensor,
1030
+ smooth_l1_weight: float = 0.5,
1031
+ giou_weight: float = 0.5,
1032
+ temporal_smoothness_weight: float = 0.0,
1033
+ ) -> torch.Tensor:
1034
+ """Combined Smooth-L1 + GIoU bbox loss over valid samples."""
1035
+ if pred_bboxes is None or target_bboxes is None or valid_mask is None:
1036
+ return torch.tensor(0.0, device=pred_bboxes.device if pred_bboxes is not None else "cpu", requires_grad=True)
1037
+
1038
+ valid = valid_mask > 0.5
1039
+ if not valid.any():
1040
+ return torch.tensor(0.0, device=pred_bboxes.device, requires_grad=True)
1041
+
1042
+ pred_first = pred_bboxes[:, 0, :] if pred_bboxes.dim() == 3 else pred_bboxes
1043
+ pred = pred_first[valid]
1044
+ tgt = target_bboxes[valid]
1045
+
1046
+ loss_l1 = F.smooth_l1_loss(pred, tgt, reduction="mean")
1047
+ loss_giou = (1.0 - _giou(pred, tgt)).mean()
1048
+ loss = smooth_l1_weight * loss_l1 + giou_weight * loss_giou
1049
+
1050
+ if pred_bboxes.dim() == 3 and pred_bboxes.size(1) > 1:
1051
+ pred_valid_seq = pred_bboxes[valid]
1052
+ if pred_valid_seq.numel() > 0:
1053
+ smooth = F.smooth_l1_loss(
1054
+ pred_valid_seq[:, 1:, :], pred_valid_seq[:, :-1, :],
1055
+ reduction="mean",
1056
+ )
1057
+ loss = loss + temporal_smoothness_weight * smooth
1058
+
1059
+ return loss
1060
+
1061
+
1062
+ def objectness_spatial_contrastive_loss(
1063
+ projected_tokens: torch.Tensor,
1064
+ spatial_masks: torch.Tensor,
1065
+ temperature: float = 0.1,
1066
+ ) -> torch.Tensor:
1067
+ """Per-sample spatial contrastive loss for localization training."""
1068
+ device = projected_tokens.device
1069
+ B, S, D = projected_tokens.shape
1070
+
1071
+ if spatial_masks.size(1) != S:
1072
+ return torch.tensor(0.0, device=device, requires_grad=True)
1073
+
1074
+ n_inside = spatial_masks.sum(dim=1)
1075
+ n_outside = S - n_inside
1076
+ valid = (n_inside >= 2) & (n_outside >= 1)
1077
+ if not valid.any():
1078
+ return torch.tensor(0.0, device=device, requires_grad=True)
1079
+
1080
+ tokens_sub = F.normalize(projected_tokens[valid], p=2, dim=-1)
1081
+ masks_sub = spatial_masks[valid]
1082
+
1083
+ total_loss = torch.tensor(0.0, device=device)
1084
+ count = 0
1085
+
1086
+ for i in range(tokens_sub.size(0)):
1087
+ tok = tokens_sub[i]
1088
+ mask = masks_sub[i]
1089
+ inside_idx = (mask > 0.5).nonzero(as_tuple=True)[0]
1090
+ outside_idx = (mask <= 0.5).nonzero(as_tuple=True)[0]
1091
+ K = inside_idx.size(0)
1092
+ if K < 2 or outside_idx.size(0) < 1:
1093
+ continue
1094
+
1095
+ inside_tok = tok[inside_idx]
1096
+ outside_tok = tok[outside_idx]
1097
+ sim_pos = torch.matmul(inside_tok, inside_tok.T) / temperature
1098
+ sim_neg = torch.matmul(inside_tok, outside_tok.T) / temperature
1099
+ sim_pos = sim_pos.masked_fill(
1100
+ torch.eye(K, device=device, dtype=torch.bool), float("-inf"),
1101
+ )
1102
+ log_numer = torch.logsumexp(sim_pos, dim=1)
1103
+ all_sim = torch.cat([sim_pos, sim_neg], dim=1)
1104
+ log_denom = torch.logsumexp(all_sim, dim=1)
1105
+ total_loss = total_loss + (-(log_numer - log_denom)).mean()
1106
+ count += 1
1107
+
1108
+ if count == 0:
1109
+ return torch.tensor(0.0, device=device, requires_grad=True)
1110
+ return total_loss / count
1111
+
1112
+
1113
+ def objectness_mask_loss(
1114
+ objectness_logits: torch.Tensor,
1115
+ spatial_masks: torch.Tensor,
1116
+ min_pos_tokens: int = 2,
1117
+ ) -> torch.Tensor:
1118
+ """BCE logits loss for localization objectness supervision."""
1119
+ device = objectness_logits.device
1120
+ if spatial_masks.shape != objectness_logits.shape:
1121
+ return torch.tensor(0.0, device=device, requires_grad=True)
1122
+
1123
+ pos_count = spatial_masks.sum(dim=1)
1124
+ valid = (pos_count >= float(min_pos_tokens)) & (pos_count < float(spatial_masks.size(1)))
1125
+ if not valid.any():
1126
+ return torch.tensor(0.0, device=device, requires_grad=True)
1127
+
1128
+ logits = objectness_logits[valid]
1129
+ targets = spatial_masks[valid].float()
1130
+ pos = targets.sum()
1131
+ neg = torch.tensor(float(targets.numel()), device=device) - pos
1132
+ pos_weight = (neg / (pos + 1e-6)).clamp(min=1.0, max=20.0)
1133
+ return F.binary_cross_entropy_with_logits(logits, targets, pos_weight=pos_weight)
1134
+
1135
+
1136
+ def gaussian_focal_loss(
1137
+ pred_logits: torch.Tensor,
1138
+ target_heatmap: torch.Tensor,
1139
+ alpha: float = 2.0,
1140
+ beta: float = 4.0,
1141
+ ) -> torch.Tensor:
1142
+ """CenterNet-style Gaussian focal loss for heatmap supervision."""
1143
+ pred = torch.sigmoid(pred_logits)
1144
+ pos_mask = target_heatmap.eq(1.0)
1145
+ neg_mask = ~pos_mask
1146
+
1147
+ pos_loss = -((1 - pred).pow(alpha) * torch.log(pred.clamp(min=1e-6))) * pos_mask
1148
+ neg_loss = -(
1149
+ (1 - target_heatmap).pow(beta) * pred.pow(alpha)
1150
+ * torch.log((1 - pred).clamp(min=1e-6))
1151
+ ) * neg_mask
1152
+
1153
+ num_pos = pos_mask.float().sum().clamp(min=1.0)
1154
+ return (pos_loss.sum() + neg_loss.sum()) / num_pos
1155
+
1156
+
1157
+ def center_heatmap_loss(
1158
+ objectness_logits: torch.Tensor,
1159
+ target_bboxes: torch.Tensor,
1160
+ valid_mask: torch.Tensor,
1161
+ sigma_in_patches: float = 1.5,
1162
+ ) -> torch.Tensor:
1163
+ """Gaussian center heatmap supervision for localization objectness."""
1164
+ device = objectness_logits.device
1165
+ B, S = objectness_logits.shape
1166
+ if target_bboxes is None or valid_mask is None:
1167
+ return torch.tensor(0.0, device=device, requires_grad=True)
1168
+ if target_bboxes.dim() != 2 or target_bboxes.size(0) != B or target_bboxes.size(1) != 4:
1169
+ return torch.tensor(0.0, device=device, requires_grad=True)
1170
+
1171
+ valid = valid_mask > 0.5
1172
+ if not valid.any():
1173
+ return torch.tensor(0.0, device=device, requires_grad=True)
1174
+
1175
+ g = int(round(S ** 0.5))
1176
+ if g * g != S:
1177
+ return torch.tensor(0.0, device=device, requires_grad=True)
1178
+
1179
+ ys = (torch.arange(g, device=device, dtype=objectness_logits.dtype) + 0.5) / float(g)
1180
+ xs = (torch.arange(g, device=device, dtype=objectness_logits.dtype) + 0.5) / float(g)
1181
+ yy, xx = torch.meshgrid(ys, xs, indexing="ij")
1182
+
1183
+ tgt = target_bboxes[valid]
1184
+ cx = (0.5 * (tgt[:, 0] + tgt[:, 2])).view(-1, 1, 1)
1185
+ cy = (0.5 * (tgt[:, 1] + tgt[:, 3])).view(-1, 1, 1)
1186
+
1187
+ sigma = max(1e-4, float(sigma_in_patches) / float(g))
1188
+ d2 = (xx.view(1, g, g) - cx).pow(2) + (yy.view(1, g, g) - cy).pow(2)
1189
+ target_heat = torch.exp(-d2 / (2.0 * sigma * sigma)).clamp(0.0, 1.0)
1190
+
1191
+ Bv = target_heat.size(0)
1192
+ flat = target_heat.view(Bv, -1)
1193
+ peak_idx = flat.argmax(dim=1, keepdim=True)
1194
+ flat.scatter_(1, peak_idx, 1.0)
1195
+ target_heat = flat.view(Bv, g, g)
1196
+
1197
+ pred_logits_v = objectness_logits[valid].view(-1, g, g)
1198
+ return gaussian_focal_loss(pred_logits_v, target_heat)
1199
+
1200
+
1201
+ def direct_center_loss(
1202
+ pred_bboxes: torch.Tensor,
1203
+ target_bboxes: torch.Tensor,
1204
+ valid_mask: torch.Tensor,
1205
+ ) -> torch.Tensor:
1206
+ """Center-to-center SmoothL1 loss for localization."""
1207
+ if pred_bboxes is None or target_bboxes is None or valid_mask is None:
1208
+ dev = pred_bboxes.device if pred_bboxes is not None else "cpu"
1209
+ return torch.tensor(0.0, device=dev, requires_grad=True)
1210
+
1211
+ valid = valid_mask > 0.5
1212
+ if not valid.any():
1213
+ return torch.tensor(0.0, device=pred_bboxes.device, requires_grad=True)
1214
+
1215
+ pred_first = pred_bboxes[:, 0, :] if pred_bboxes.dim() == 3 else pred_bboxes
1216
+ pred = pred_first[valid]
1217
+ tgt = target_bboxes[valid]
1218
+
1219
+ pred_center = torch.stack([
1220
+ 0.5 * (pred[:, 0] + pred[:, 2]),
1221
+ 0.5 * (pred[:, 1] + pred[:, 3]),
1222
+ ], dim=1)
1223
+ tgt_center = torch.stack([
1224
+ 0.5 * (tgt[:, 0] + tgt[:, 2]),
1225
+ 0.5 * (tgt[:, 1] + tgt[:, 3]),
1226
+ ], dim=1)
1227
+
1228
+ return F.smooth_l1_loss(pred_center, tgt_center, reduction="mean")
1229
+
1230
+
1231
+ def frame_classification_loss(
1232
+ frame_logits: torch.Tensor,
1233
+ frame_labels: torch.Tensor,
1234
+ ) -> torch.Tensor:
1235
+ """Per-frame CE loss, ignoring frames with label -1."""
1236
+ B, T, C = frame_logits.shape
1237
+ device = frame_logits.device
1238
+ logits_flat = frame_logits.reshape(B * T, C)
1239
+ labels_flat = frame_labels.reshape(B * T)
1240
+ valid = labels_flat >= 0
1241
+ if not valid.any():
1242
+ return torch.tensor(0.0, device=device, requires_grad=True)
1243
+ return F.cross_entropy(logits_flat[valid], labels_flat[valid])
1244
+
1245
+
1246
+ def boundary_detection_loss(
1247
+ boundary_logits: torch.Tensor,
1248
+ boundary_labels: torch.Tensor,
1249
+ ) -> torch.Tensor:
1250
+ """BCE loss for boundary/change-point detection.
1251
+
1252
+ Args:
1253
+ boundary_logits: [B, T, 1] raw logits
1254
+ boundary_labels: [B, T] binary (1=transition, 0=no, -1=ignore)
1255
+ """
1256
+ B, T, _ = boundary_logits.shape
1257
+ device = boundary_logits.device
1258
+ logits_flat = boundary_logits.squeeze(-1) # [B, T]
1259
+ valid = boundary_labels >= 0
1260
+ if not valid.any():
1261
+ return torch.tensor(0.0, device=device, requires_grad=True)
1262
+
1263
+ n_pos = (boundary_labels[valid] > 0.5).float().sum().clamp(min=1.0)
1264
+ n_neg = (boundary_labels[valid] <= 0.5).float().sum().clamp(min=1.0)
1265
+ pos_weight = (n_neg / n_pos).clamp(max=20.0)
1266
+
1267
+ raw = F.binary_cross_entropy_with_logits(
1268
+ logits_flat[valid], boundary_labels[valid].float(),
1269
+ pos_weight=pos_weight, reduction='none',
1270
+ )
1271
+
1272
+ return raw.mean()
1273
+
1274
+
1275
+ def temporal_smoothness_loss(
1276
+ frame_logits: torch.Tensor,
1277
+ frame_labels: torch.Tensor,
1278
+ ) -> torch.Tensor:
1279
+ """L2 smoothness regularizer on consecutive frame logits."""
1280
+ B, T, C = frame_logits.shape
1281
+ if T < 2:
1282
+ return torch.tensor(0.0, device=frame_logits.device, requires_grad=True)
1283
+
1284
+ valid = frame_labels >= 0
1285
+ both_valid = valid[:, :-1] & valid[:, 1:]
1286
+ if not both_valid.any():
1287
+ return torch.tensor(0.0, device=frame_logits.device, requires_grad=True)
1288
+
1289
+ diff = (frame_logits[:, 1:, :] - frame_logits[:, :-1, :]) ** 2
1290
+ return diff.mean(dim=-1)[both_valid].mean()
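A sketch of how the frame-level losses above might be combined during a training step; the weights are illustrative, and the actual weighting is handled by the training code (this module only defines the individual terms).

    # frame_logits [B, T, C] and boundary_logits [B, T, 1] come from the model's frame output;
    # frame_labels [B, T] is int64 with -1 for unlabeled frames;
    # boundary_labels [B, T] uses 1 / 0 / -1 (transition / none / ignore).
    loss = (
        frame_classification_loss(frame_logits, frame_labels)
        + 0.5 * boundary_detection_loss(boundary_logits, boundary_labels)
        + 0.1 * temporal_smoothness_loss(frame_logits, frame_labels)
    )
    loss.backward()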