singlebehaviorlab-2.0.0-py3-none-any.whl
- sam2/__init__.py +11 -0
- sam2/automatic_mask_generator.py +454 -0
- sam2/benchmark.py +92 -0
- sam2/build_sam.py +174 -0
- sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
- sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
- sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
- sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
- sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
- sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
- sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
- sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
- sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
- sam2/modeling/__init__.py +5 -0
- sam2/modeling/backbones/__init__.py +5 -0
- sam2/modeling/backbones/hieradet.py +317 -0
- sam2/modeling/backbones/image_encoder.py +134 -0
- sam2/modeling/backbones/utils.py +93 -0
- sam2/modeling/memory_attention.py +169 -0
- sam2/modeling/memory_encoder.py +181 -0
- sam2/modeling/position_encoding.py +239 -0
- sam2/modeling/sam/__init__.py +5 -0
- sam2/modeling/sam/mask_decoder.py +295 -0
- sam2/modeling/sam/prompt_encoder.py +202 -0
- sam2/modeling/sam/transformer.py +311 -0
- sam2/modeling/sam2_base.py +913 -0
- sam2/modeling/sam2_utils.py +323 -0
- sam2/sam2_hiera_b+.yaml +113 -0
- sam2/sam2_hiera_l.yaml +117 -0
- sam2/sam2_hiera_s.yaml +116 -0
- sam2/sam2_hiera_t.yaml +118 -0
- sam2/sam2_image_predictor.py +466 -0
- sam2/sam2_video_predictor.py +1388 -0
- sam2/sam2_video_predictor_legacy.py +1172 -0
- sam2/utils/__init__.py +5 -0
- sam2/utils/amg.py +348 -0
- sam2/utils/misc.py +349 -0
- sam2/utils/transforms.py +118 -0
- singlebehaviorlab/__init__.py +4 -0
- singlebehaviorlab/__main__.py +130 -0
- singlebehaviorlab/_paths.py +100 -0
- singlebehaviorlab/backend/__init__.py +2 -0
- singlebehaviorlab/backend/augmentations.py +320 -0
- singlebehaviorlab/backend/data_store.py +420 -0
- singlebehaviorlab/backend/model.py +1290 -0
- singlebehaviorlab/backend/train.py +4667 -0
- singlebehaviorlab/backend/uncertainty.py +578 -0
- singlebehaviorlab/backend/video_processor.py +688 -0
- singlebehaviorlab/backend/video_utils.py +139 -0
- singlebehaviorlab/data/config/config.yaml +85 -0
- singlebehaviorlab/data/training_profiles.json +334 -0
- singlebehaviorlab/gui/__init__.py +4 -0
- singlebehaviorlab/gui/analysis_widget.py +2291 -0
- singlebehaviorlab/gui/attention_export.py +311 -0
- singlebehaviorlab/gui/clip_extraction_widget.py +481 -0
- singlebehaviorlab/gui/clustering_widget.py +3187 -0
- singlebehaviorlab/gui/inference_popups.py +1138 -0
- singlebehaviorlab/gui/inference_widget.py +4550 -0
- singlebehaviorlab/gui/inference_worker.py +651 -0
- singlebehaviorlab/gui/labeling_widget.py +2324 -0
- singlebehaviorlab/gui/main_window.py +754 -0
- singlebehaviorlab/gui/metadata_management_widget.py +1119 -0
- singlebehaviorlab/gui/motion_tracking.py +764 -0
- singlebehaviorlab/gui/overlay_export.py +1234 -0
- singlebehaviorlab/gui/plot_integration.py +729 -0
- singlebehaviorlab/gui/qt_helpers.py +29 -0
- singlebehaviorlab/gui/registration_widget.py +1485 -0
- singlebehaviorlab/gui/review_widget.py +1330 -0
- singlebehaviorlab/gui/segmentation_tracking_widget.py +2752 -0
- singlebehaviorlab/gui/tab_tutorial_dialog.py +312 -0
- singlebehaviorlab/gui/timeline_themes.py +131 -0
- singlebehaviorlab/gui/training_profiles.py +418 -0
- singlebehaviorlab/gui/training_widget.py +3719 -0
- singlebehaviorlab/gui/video_utils.py +233 -0
- singlebehaviorlab/licenses/SAM2-LICENSE +201 -0
- singlebehaviorlab/licenses/VideoPrism-LICENSE +202 -0
- singlebehaviorlab-2.0.0.dist-info/METADATA +447 -0
- singlebehaviorlab-2.0.0.dist-info/RECORD +88 -0
- singlebehaviorlab-2.0.0.dist-info/WHEEL +5 -0
- singlebehaviorlab-2.0.0.dist-info/entry_points.txt +2 -0
- singlebehaviorlab-2.0.0.dist-info/licenses/LICENSE +21 -0
- singlebehaviorlab-2.0.0.dist-info/top_level.txt +3 -0
- videoprism/__init__.py +0 -0
- videoprism/encoders.py +910 -0
- videoprism/layers.py +1136 -0
- videoprism/models.py +407 -0
- videoprism/tokenizers.py +167 -0
- videoprism/utils.py +168 -0

singlebehaviorlab/backend/model.py
@@ -0,0 +1,1290 @@

import logging
import os
import gc

logger = logging.getLogger(__name__)

# JAX memory config MUST be set before importing jax.
# Without this, JAX grabs 75-90% of GPU memory upfront, leaving PyTorch
# starved and eventually causing CUDA_ERROR_ILLEGAL_ADDRESS under sustained
# inference workloads where both frameworks share the same GPU.
os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")
os.environ.setdefault("XLA_PYTHON_CLIENT_MEM_FRACTION", "0.45")

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Callable
import numpy as np
import jax
import jax.numpy as jnp
from videoprism import models as vp


def interpolate_pos_embed_2d(pos_embed: np.ndarray, orig_grid: int, new_grid: int) -> np.ndarray:
    """
    Bicubic interpolation of 2D spatial position embeddings.

    Args:
        pos_embed: [N, D] where N = orig_grid * orig_grid
        orig_grid: original spatial grid size (e.g., 16 for 288px)
        new_grid: target spatial grid size (e.g., 19 for 342px)

    Returns:
        Interpolated pos_embed: [new_grid*new_grid, D]
    """
    if orig_grid == new_grid:
        return pos_embed

    D = pos_embed.shape[-1]
    pos_2d = pos_embed.reshape(orig_grid, orig_grid, D)
    pos_torch = torch.from_numpy(pos_2d).permute(2, 0, 1).unsqueeze(0).float()

    pos_interp = F.interpolate(
        pos_torch,
        size=(new_grid, new_grid),
        mode='bicubic',
        align_corners=False
    )

    pos_interp = pos_interp.squeeze(0).permute(1, 2, 0).numpy()
    return pos_interp.reshape(new_grid * new_grid, D)
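

# Illustrative sketch (not in the original wheel): resizing a 16x16 grid of
# position embeddings to 19x19, as done when running the backbone at 342px
# instead of the default 288px (patch size 18). Shapes only; values are random.
def _demo_interpolate_pos_embed():
    rng = np.random.default_rng(0)
    pos = rng.standard_normal((16 * 16, 768)).astype(np.float32)  # [256, 768]
    pos_342 = interpolate_pos_embed_2d(pos, orig_grid=16, new_grid=19)
    assert pos_342.shape == (19 * 19, 768)  # [361, 768]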


class VideoPrismBackbone(nn.Module):
    """VideoPrism backbone wrapper for PyTorch compatibility."""

    DEFAULT_RESOLUTION = 288
    PATCH_SIZE = 18

    def __init__(
        self,
        model_name: str = 'videoprism_public_v1_base',
        resolution: int = 288,
        log_fn: Optional[Callable[[str], None]] = None
    ):
        super().__init__()
        self.model_name = model_name
        self.resolution = resolution
        self.flax_model = None
        self.params = None
        self.original_params = None
        self.jax_device = None
        self._forward_fn = None
        self.log_fn = log_fn or print
        self._current_grid_size = None
        self.enable_dlpack = os.environ.get("BEHAVIOR_APP_USE_DLPACK", "0").strip().lower() in ("1", "true", "yes", "on")
        self._load_model()

    def _is_jax_gpu(self) -> bool:
        d = self.jax_device
        return (
            d.device_kind == 'gpu' or
            'gpu' in d.device_kind.lower() or
            'cuda' in str(d).lower() or
            d.platform in ['gpu', 'cuda']
        )

    def _load_model(self):
        """Load VideoPrism model (JAX/Flax) and configure for GPU if available."""
        try:
            import copy
            os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
            os.environ['XLA_FLAGS'] = '--xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found'

            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
                self.log_fn("Cleared PyTorch GPU cache before loading VideoPrism")

            devices = jax.devices()
            self.log_fn(f"JAX devices found: {len(devices)}")
            for i, d in enumerate(devices):
                self.log_fn(f"  Device {i}: {d} (kind: {d.device_kind}, platform: {d.platform})")

            gpu_devices = [
                d for d in devices
                if (d.device_kind == 'gpu' or
                    'gpu' in d.device_kind.lower() or
                    'cuda' in str(d).lower() or
                    d.platform in ['gpu', 'cuda'])
            ]

            if gpu_devices:
                self.jax_device = gpu_devices[0]
                self.log_fn(f"JAX GPU device selected: {self.jax_device}")
            else:
                self.jax_device = jax.devices('cpu')[0]
                self.log_fn(f"JAX using CPU device: {self.jax_device} (GPU not available)")

            self.log_fn("Loading VideoPrism model...")
            self.flax_model = vp.get_model(self.model_name)
            self.params = vp.load_pretrained_weights(self.model_name)
            self.original_params = copy.deepcopy(self.params)

            orig_grid = self.DEFAULT_RESOLUTION // self.PATCH_SIZE
            new_grid = self.resolution // self.PATCH_SIZE
            self._current_grid_size = new_grid

            if new_grid != orig_grid:
                self.log_fn(f"Resolution {self.resolution}x{self.resolution} -> {new_grid}x{new_grid} spatial grid")
                self.params = self._interpolate_spatial_pos_embed(self.params, orig_grid, new_grid)
            else:
                self.log_fn(f"Using default resolution {self.resolution}x{self.resolution} ({new_grid}x{new_grid} grid)")

            if self._is_jax_gpu():
                self.log_fn("Moving VideoPrism parameters to GPU...")
                self.params = jax.device_put(self.params, self.jax_device)

            @jax.jit
            def forward_fn(params, videos_bthwc: jnp.ndarray) -> jnp.ndarray:
                embeddings, _ = self.flax_model.apply(
                    params, videos_bthwc, train=False, return_intermediate=False
                )
                return embeddings

            self._forward_fn = forward_fn
            self.log_fn(f"Loaded VideoPrism model: {self.model_name} (device: {self.jax_device.device_kind})")

            self.log_fn(f"Warming up JIT compilation at {self.resolution}x{self.resolution}...")
            dummy_batch = jnp.zeros((1, 16, self.resolution, self.resolution, 3), dtype=jnp.float32)
            dummy_batch = jax.device_put(dummy_batch, self.jax_device)
            _ = self._forward_fn(self.params, dummy_batch)
            self.log_fn("JIT compilation complete")

        except Exception as e:
            error_msg = f"Error loading VideoPrism model: {e}\n"
            error_msg += "This may be due to a CuDNN version mismatch or missing dependencies."
            if "DNN library initialization failed" in str(e) or "FAILED_PRECONDITION" in str(e):
                error_msg += (
                    "\n\nTry:\n"
                    "conda activate singlebehaviorlab\n"
                    "pip install --upgrade torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/cu124\n"
                    "pip install --upgrade \"jax[cuda12]==0.6.2\" flax==0.10.7\n"
                    "pip install --upgrade \"nvidia-cudnn-cu12==9.20.0.48\""
                )
            self.log_fn(error_msg)
            import traceback
            self.log_fn(traceback.format_exc())
            raise RuntimeError(error_msg) from e

    def _interpolate_spatial_pos_embed(self, params: dict, orig_grid: int, new_grid: int) -> dict:
        import copy
        params = copy.deepcopy(params)

        def find_and_interpolate(d, path=""):
            if isinstance(d, dict):
                for k, v in d.items():
                    new_path = f"{path}.{k}" if path else k
                    d[k] = find_and_interpolate(v, new_path)
            elif isinstance(d, (np.ndarray, jnp.ndarray)):
                arr = np.asarray(d)
                expected_spatial = orig_grid * orig_grid
                is_pos_embed = any(name in path.lower() for name in ['pos_embed', 'posembed', 'position'])

                if is_pos_embed and arr.ndim >= 2:
                    # 2-D [N, D] case; 3-D variants are handled by the branches below.
                    if arr.ndim == 2 and arr.shape[-2] == expected_spatial:
                        D = arr.shape[-1]
                        self.log_fn(f"  Interpolating {path}: [{expected_spatial}, {D}] -> [{new_grid*new_grid}, {D}]")
                        arr_interp = interpolate_pos_embed_2d(arr, orig_grid, new_grid)
                        return jnp.asarray(arr_interp)
                    elif arr.ndim == 3 and arr.shape[1] == expected_spatial:
                        D = arr.shape[-1]
                        self.log_fn(f"  Interpolating {path}: [1, {expected_spatial}, {D}] -> [1, {new_grid*new_grid}, {D}]")
                        arr_interp = interpolate_pos_embed_2d(arr[0], orig_grid, new_grid)
                        return jnp.asarray(arr_interp[np.newaxis, :, :])
                    elif arr.ndim == 3 and arr.shape[1] == expected_spatial + 1:
                        D = arr.shape[-1]
                        self.log_fn(f"  Interpolating {path}: [1, 1+{expected_spatial}, {D}] -> [1, 1+{new_grid*new_grid}, {D}]")
                        cls_token = arr[:, :1, :]
                        spatial = arr[0, 1:, :]
                        spatial_interp = interpolate_pos_embed_2d(spatial, orig_grid, new_grid)
                        return jnp.concatenate([cls_token, spatial_interp[np.newaxis, :, :]], axis=1)
                    return d
            return d

        params = find_and_interpolate(params)
        return params

    def set_resolution(self, resolution: int):
        if resolution == self.resolution:
            return
        self.resolution = resolution
        orig_grid = self.DEFAULT_RESOLUTION // self.PATCH_SIZE
        new_grid = resolution // self.PATCH_SIZE
        self._current_grid_size = new_grid

        if self.original_params is not None:
            self.log_fn(f"Re-interpolating spatial pos embeddings for resolution {resolution}...")
            import copy
            self.params = copy.deepcopy(self.original_params)
            if new_grid != orig_grid:
                self.params = self._interpolate_spatial_pos_embed(self.params, orig_grid, new_grid)
            if self._is_jax_gpu():
                self.params = jax.device_put(self.params, self.jax_device)
            self.log_fn(f"Resolution updated to {resolution}x{resolution}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input tensor [B, T, C, H, W] in [0, 1] float32
        Returns:
            Token embeddings [B, N, D] where N = T * S spatial tokens
        """
        if self.flax_model is None or self.params is None or self._forward_fn is None:
            raise RuntimeError("Model not loaded")

        device = x.device
        x_bthwc = x.permute(0, 1, 3, 4, 2).contiguous()

        use_dlpack = (
            self.enable_dlpack and
            self.resolution <= 288 and
            x_bthwc.is_cuda and
            self.jax_device is not None and
            self.jax_device.platform in ["gpu", "cuda"]
        )

        def _to_jax_tensor(prefer_dlpack: bool):
            if prefer_dlpack:
                try:
                    x_jax_local = jax.dlpack.from_dlpack(x_bthwc)
                except Exception:
                    import torch.utils.dlpack as torch_dlpack
                    x_dlpack = torch_dlpack.to_dlpack(x_bthwc)
                    x_jax_local = jax.dlpack.from_dlpack(x_dlpack)
                if x_jax_local.device != self.jax_device:
                    x_jax_local = jax.device_put(x_jax_local, self.jax_device)
                return x_jax_local, True
            x_np = x_bthwc.detach().cpu().contiguous().numpy()
            x_jax_local = jnp.asarray(x_np)
            x_jax_local = jax.device_put(x_jax_local, self.jax_device)
            return x_jax_local, False

        try:
            x_jax, used_dlpack = _to_jax_tensor(use_dlpack)
            embeddings_jax = self._forward_fn(self.params, x_jax)
            embeddings_jax.block_until_ready()
            del x_jax
        except Exception as e:
            if use_dlpack:
                self.log_fn(f"Warning: DLPack path failed ({e}). Retrying with safe host-copy transfer.")
                x_jax, used_dlpack = _to_jax_tensor(False)
                embeddings_jax = self._forward_fn(self.params, x_jax)
                embeddings_jax.block_until_ready()
                del x_jax
            else:
                raise

        embeddings_np = np.asarray(embeddings_jax)
        del embeddings_jax
        embeddings_torch = torch.from_numpy(embeddings_np.copy()).to(device)
        del embeddings_np
        return embeddings_torch

    def get_embed_dim(self) -> int:
        dummy_input = torch.zeros(1, 16, 3, self.resolution, self.resolution)
        with torch.no_grad():
            tokens = self.forward(dummy_input)
        return tokens.shape[-1]

    def get_num_tokens(self) -> int:
        grid_size = self.resolution // self.PATCH_SIZE
        return grid_size * grid_size
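

# Illustrative sketch (not in the original wheel): the zero-copy hand-off that
# VideoPrismBackbone.forward() attempts. A torch tensor is passed to JAX via
# DLPack, with a host-copy fallback when the DLPack path is unavailable.
def _demo_torch_to_jax(x: "torch.Tensor"):
    try:
        x_jax = jax.dlpack.from_dlpack(x)  # zero-copy when both sides share the GPU
    except Exception:
        x_jax = jnp.asarray(x.detach().cpu().numpy())  # safe host round-trip
    return x_jax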


# Stage B: frame decoder and localization heads.

class SpatialAttentionPool(nn.Module):
    """Trainable per-frame spatial attention pooling.

    For each frame, a learned query attends over S spatial tokens
    and an MLP bottleneck projects the result to proj_dim.
    """

    def __init__(self, embed_dim: int, proj_dim: int = 256,
                 num_heads: int = 4, dropout: float = 0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.proj_dim = proj_dim
        self.ln = nn.LayerNorm(embed_dim)
        self.query = nn.Parameter(torch.randn(1, 1, embed_dim) * 0.02)
        self.attn = nn.MultiheadAttention(
            embed_dim, num_heads=num_heads, dropout=dropout, batch_first=True,
        )
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, proj_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(proj_dim, proj_dim),
            nn.Dropout(dropout),
        )

    def forward(self, tokens: torch.Tensor, num_frames: int,
                return_attn_weights: bool = False):
        """
        Args:
            tokens: [B, T*S, D]
            num_frames: T
            return_attn_weights: if True, also return spatial attention maps
        Returns:
            [B, T, proj_dim] or ([B, T, proj_dim], [B, T, num_heads, S])
        """
        B, N, D = tokens.shape
        S = N // num_frames
        T = num_frames

        flat = tokens.view(B, T, S, D).reshape(B * T, S, D)
        flat = self.ln(flat)

        q = self.query.expand(B * T, -1, -1)
        pooled, attn_w = self.attn(q, flat, flat, need_weights=return_attn_weights,
                                   average_attn_weights=False)
        pooled = pooled.squeeze(1)  # [B*T, D]

        out = self.mlp(pooled)  # [B*T, proj_dim]

        if return_attn_weights and attn_w is not None:
            # attn_w: [B*T, num_heads, 1, S] -> [B, T, num_heads, S]
            attn_w = attn_w.squeeze(2).view(B, T, -1, S)
            return out.view(B, T, self.proj_dim), attn_w

        return out.view(B, T, self.proj_dim)
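

# Illustrative sketch (not in the original wheel): shape check for
# SpatialAttentionPool with a 2-clip batch, 16 frames, 256 spatial tokens.
def _demo_spatial_attention_pool():
    pool = SpatialAttentionPool(embed_dim=768, proj_dim=256, num_heads=4)
    tokens = torch.randn(2, 16 * 256, 768)    # [B, T*S, D]
    feats, attn = pool(tokens, num_frames=16, return_attn_weights=True)
    assert feats.shape == (2, 16, 256)        # [B, T, proj_dim]
    assert attn.shape == (2, 16, 4, 256)      # [B, T, num_heads, S]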


class _DilatedResidualBlock(nn.Module):
    def __init__(self, channels: int, dilation: int, k: int, p: float):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv1d(channels, channels, kernel_size=k,
                      padding=dilation, dilation=dilation, bias=False),
            nn.GELU(),
            nn.Dropout(p),
            nn.Conv1d(channels, channels, kernel_size=1, bias=False),
            nn.Dropout(p),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.block(x)


class _MSRefineStage(nn.Module):
    def __init__(self, in_channels: int, hidden_channels: int,
                 out_channels: int, layers: int, k: int, p: float):
        super().__init__()
        self.in_proj = nn.Conv1d(in_channels, hidden_channels,
                                 kernel_size=1, bias=False)
        self.blocks = nn.ModuleList([
            _DilatedResidualBlock(hidden_channels, dilation=2 ** i, k=k, p=p)
            for i in range(max(1, int(layers)))
        ])
        self.out_proj = nn.Conv1d(hidden_channels, out_channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.in_proj(x)
        for blk in self.blocks:
            y = blk(y)
        return self.out_proj(y)
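

# Illustrative sketch (not in the original wheel): with kernel_size 3 and
# dilations 1, 2, 4, ..., 2**(L-1), each residual block widens the receptive
# field by 2 * dilation frames, so L blocks see 1 + 2 * (2**L - 1) frames.
# (Note that padding=dilation keeps "same" length only for kernel_size=3.)
def _demo_tcn_receptive_field(layers: int = 4, k: int = 3) -> int:
    rf = 1
    for i in range(layers):
        rf += (k - 1) * (2 ** i)
    return rf  # layers=4, k=3 -> 31 frames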


class DilatedTemporalHead(nn.Module):
    """Frame decoder head (Stage B).

    Trainable spatial attention pooling -> local TCN -> state + boundary heads.
    Produces per-frame state logits, boundary logits, and frame embeddings.
    """

    def __init__(
        self,
        embed_dim: int,
        num_classes: int,
        num_layers: int = 4,
        kernel_size: int = 3,
        dropout: float = 0.2,
        num_stages: int = 3,
        hidden_dim: Optional[int] = None,
        temporal_pool: int = 1,
        proj_dim: int = 256,
        spatial_pool_heads: int = 4,
        multi_scale: bool = False,
        use_temporal_decoder: bool = True,
    ):
        super().__init__()
        self.embed_dim = int(embed_dim)
        self.num_classes = int(num_classes)
        self.num_layers = max(1, int(num_layers))
        self.num_stages = max(1, int(num_stages))
        self.dropout = float(dropout)
        self.temporal_pool = max(1, int(temporal_pool))
        self.use_ovr = False
        self.proj_dim = int(proj_dim)
        self.multi_scale = bool(multi_scale)
        self.use_temporal_decoder = bool(use_temporal_decoder)
        auto_hidden = max(128, self.proj_dim)
        self.hidden_dim = int(hidden_dim) if hidden_dim is not None else auto_hidden

        self.spatial_pool = SpatialAttentionPool(
            embed_dim=self.embed_dim,
            proj_dim=self.proj_dim,
            num_heads=int(spatial_pool_heads),
            dropout=self.dropout,
        )

        # Raw per-frame projection: simple mean-pool of spatial tokens -> proj_dim.
        # Not filtered by learned attention, so it retains broader spatial context
        # that SpatialAttentionPool may suppress in favour of discriminative patches.
        self.raw_proj = nn.Linear(self.embed_dim, self.proj_dim)

        # When multi_scale=True, long-scale and short-scale features are concatenated,
        # so the TCN input is 2*proj_dim instead of proj_dim.
        tcn_in = self.proj_dim * 2 if self.multi_scale else self.proj_dim
        if self.use_temporal_decoder:
            self.stage1 = _MSRefineStage(
                in_channels=tcn_in,
                hidden_channels=self.hidden_dim,
                out_channels=self.num_classes,
                layers=self.num_layers,
                k=kernel_size,
                p=self.dropout,
            )
            self.refine_stages = nn.ModuleList()
            for _ in range(self.num_stages - 1):
                self.refine_stages.append(
                    _MSRefineStage(
                        in_channels=self.num_classes,
                        hidden_channels=self.hidden_dim,
                        out_channels=self.num_classes,
                        layers=self.num_layers,
                        k=kernel_size,
                        p=self.dropout,
                    )
                )

            self.boundary_tcn = _MSRefineStage(
                in_channels=tcn_in,
                hidden_channels=self.hidden_dim,
                out_channels=1,
                layers=max(1, self.num_layers // 2),
                k=kernel_size,
                p=self.dropout,
            )
            self.frame_classifier = None
        else:
            self.stage1 = None
            self.refine_stages = nn.ModuleList()
            self.boundary_tcn = None
            self.frame_classifier = nn.Sequential(
                nn.LayerNorm(tcn_in),
                nn.Linear(tcn_in, self.num_classes),
            )

    def forward(self, tokens: torch.Tensor, num_frames: int,
                tokens_short: Optional[torch.Tensor] = None,
                num_frames_short: Optional[int] = None,
                return_attn_weights: bool = False):
        """
        Args:
            tokens: [B, T*S, D] long-scale backbone tokens
            num_frames: T (long scale)
            tokens_short: [B, T_s*S, D] short-scale tokens (half fps, T_s = T//2).
                Required when multi_scale=True.
            num_frames_short: T_s
            return_attn_weights: if True, include spatial attention maps in output
        Returns:
            frame_logits: [B, T, C]
            clip_logits: [B, C]
            temporal_weights: [B, T]
            frame_logits_pooled: [B, T_pooled, C]
            boundary_logits: [B, T, 1]
            frame_embeddings: [B, T, proj_dim] (long-scale attention pool)
            frame_embeddings_combined: [B, T, 2*proj_dim] (attn || raw_mean, long scale)
            attn_weights: [B, T, num_heads, S] or None
        """
        B, N, D = tokens.shape
        if num_frames <= 0 or (N % num_frames) != 0:
            raise ValueError(
                f"DilatedTemporalHead expects T*S tokens. "
                f"Got N={N}, num_frames={num_frames}."
            )
        T = num_frames

        # 1. Spatial attention pooling (long scale): [B, T*S, D] -> [B, T, proj_dim]
        attn_weights = None
        pool_out = self.spatial_pool(tokens, T, return_attn_weights=return_attn_weights)
        if return_attn_weights and isinstance(pool_out, tuple):
            x_long, attn_weights = pool_out
        else:
            x_long = pool_out
        # frame_embeddings is the long-scale per-frame feature
        frame_embeddings = x_long  # [B, T, proj_dim]

        # Multi-scale: pool short-scale tokens and upsample to align with T
        if self.multi_scale:
            if tokens_short is None or num_frames_short is None:
                raise ValueError(
                    "DilatedTemporalHead has multi_scale=True but tokens_short / "
                    "num_frames_short were not provided. Ensure the dataset has "
                    "_emb_multi_scale=True and short-scale embeddings are cached."
                )
            T_short = num_frames_short
            x_short = self.spatial_pool(tokens_short, T_short)  # [B, T_short, proj_dim]
            scale_factor = T // T_short
            # Nearest-neighbour upsample: each short frame covers scale_factor long frames
            x_short_up = x_short.repeat_interleave(scale_factor, dim=1)[:, :T, :]
            x = torch.cat([x_long, x_short_up], dim=-1)  # [B, T, 2*proj_dim]
        else:
            x = x_long  # [B, T, proj_dim]

        if self.use_temporal_decoder:
            # 2. Boundary prediction at full T (before temporal pooling reduces resolution)
            emb_conv = x.transpose(1, 2)  # [B, tcn_in, T]
            boundary_logits = self.boundary_tcn(emb_conv).transpose(1, 2)  # [B, T, 1]

            T_pooled = T
            if self.temporal_pool > 1:
                p = self.temporal_pool
                pad = (p - T % p) % p
                if pad > 0:
                    x = torch.cat([x, x[:, -1:, :].expand(-1, pad, -1)], dim=1)
                T_padded = T + pad
                feat_dim = x.shape[-1]
                x = x.view(B, T_padded // p, p, feat_dim).mean(dim=2)
                T_pooled = T_padded // p

            x_conv = x.transpose(1, 2)  # [B, tcn_in, T_pooled]
            stage_logits = self.stage1(x_conv)
            for refine in self.refine_stages:
                if self.use_ovr:
                    refine_in = torch.sigmoid(stage_logits)
                else:
                    refine_in = torch.softmax(stage_logits, dim=1)
                stage_logits = stage_logits + refine(refine_in)

            stage_logits_pooled = stage_logits
            clip_logits = stage_logits_pooled.mean(dim=2)

            if self.temporal_pool > 1:
                stage_logits = stage_logits.repeat_interleave(
                    self.temporal_pool, dim=2
                )[:, :, :T]

            frame_logits = stage_logits.transpose(1, 2)
            frame_logits_pooled = stage_logits_pooled.transpose(1, 2)
        else:
            boundary_logits = None
            frame_logits = self.frame_classifier(x)
            frame_logits_pooled = frame_logits
            clip_logits = frame_logits.mean(dim=1)
        temporal_weights = torch.ones(B, T, device=tokens.device) / T

        # Raw mean-pool per frame from long-scale tokens: bypasses spatial attention
        # so downstream analysis can access a less task-biased view as well.
        S = N // T
        raw_mean = tokens.view(B, T, S, self.embed_dim).mean(dim=2)
        raw_emb = self.raw_proj(raw_mean)  # [B, T, proj_dim]
        # Combined embedding: [attention_emb || raw_emb], shape [B, T, 2*proj_dim]
        frame_embeddings_combined = torch.cat([frame_embeddings, raw_emb], dim=-1)

        return (frame_logits, clip_logits, temporal_weights,
                frame_logits_pooled, boundary_logits, frame_embeddings,
                frame_embeddings_combined, attn_weights)
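

# Illustrative sketch (not in the original wheel): single-scale smoke test for
# DilatedTemporalHead. Token and class counts are arbitrary.
def _demo_frame_head():
    head = DilatedTemporalHead(embed_dim=768, num_classes=3, temporal_pool=2)
    tokens = torch.randn(2, 16 * 256, 768)       # [B, T*S, D]
    out = head(tokens, num_frames=16)
    frame_logits, clip_logits = out[0], out[1]
    assert frame_logits.shape == (2, 16, 3)      # per-frame state logits
    assert clip_logits.shape == (2, 3)           # clip-level logits
    assert out[4].shape == (2, 16, 1)            # boundary logits at full T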


# Spatial localization head.

class SpatialLocalizationHead(nn.Module):
    """Frame-level bbox regressor over spatial tokens.

    Design:
    1) Dense objectness map over spatial patches per frame.
    2) Upsampled objectness for sub-patch precision.
    3) Temperature-scaled soft-argmax for coarse center.
    4) Local token gathering for context-aware refinement.
    5) Temporal residual refiner over center trajectories.
    6) Fixed box size from dataset stats for stable crops.
    """

    UPSAMPLE_FACTOR = 4

    def __init__(self, embed_dim: int, hidden_dim: int = 256, dropout: float = 0.1):
        super().__init__()
        self.norm = nn.LayerNorm(embed_dim)
        self.query = nn.Parameter(torch.randn(1, 1, embed_dim) * 0.02)
        self.attn = nn.MultiheadAttention(
            embed_dim, num_heads=4, dropout=dropout, batch_first=True,
        )
        self.attn_norm = nn.LayerNorm(embed_dim)
        self.objectness = nn.Linear(embed_dim, 1)
        self.temperature = nn.Parameter(torch.tensor(8.0))

        self.refine = nn.Sequential(
            nn.Linear(embed_dim * 2, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, 2),
        )

        temporal_hidden = max(16, hidden_dim // 8)
        self.temporal_refine = nn.Sequential(
            nn.Conv1d(2, temporal_hidden, kernel_size=5, padding=2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Conv1d(temporal_hidden, 2, kernel_size=5, padding=2),
        )
        final_temporal = self.temporal_refine[-1]
        if isinstance(final_temporal, nn.Conv1d):
            nn.init.zeros_(final_temporal.weight)
            if final_temporal.bias is not None:
                nn.init.zeros_(final_temporal.bias)

        proj_dim = hidden_dim // 2
        self.contrastive_proj = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, proj_dim),
        )

        self.register_buffer(
            "fixed_box_wh", torch.tensor([0.2, 0.2], dtype=torch.float32)
        )

    def set_fixed_box_size(self, width: float, height: float):
        w = float(max(1e-4, min(1.0, width)))
        h = float(max(1e-4, min(1.0, height)))
        self.fixed_box_wh.data = torch.tensor(
            [w, h], dtype=self.fixed_box_wh.dtype, device=self.fixed_box_wh.device,
        )

    def get_contrastive_tokens(
        self, tokens: torch.Tensor, num_frames: Optional[int] = None
    ) -> torch.Tensor:
        B, N, D = tokens.shape
        normed = self.norm(tokens)
        S = N
        if num_frames is not None and int(num_frames) > 1 and N % int(num_frames) == 0:
            S = N // int(num_frames)
        first_frame = normed[:, :S, :]
        return self.contrastive_proj(first_frame)

    def get_objectness_logits(
        self, tokens: torch.Tensor, num_frames: Optional[int] = None,
        all_frames: bool = False,
    ) -> torch.Tensor:
        B, N, D = tokens.shape
        normed = self.norm(tokens)
        S = N
        T = 1
        if num_frames is not None and int(num_frames) > 1 and N % int(num_frames) == 0:
            T = int(num_frames)
            S = N // T
        if all_frames and T > 1:
            frames = normed.view(B, T, S, D).reshape(B * T, S, D)
            return self.objectness(frames).squeeze(-1)
        else:
            first_frame = normed[:, :S, :]
            return self.objectness(first_frame).squeeze(-1)

    def forward(
        self,
        tokens: torch.Tensor,
        num_frames: Optional[int] = None,
        fixed_box_wh: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        B, N, D = tokens.shape
        normed = self.norm(tokens)

        T = 1
        S = N
        if num_frames is not None and int(num_frames) > 1 and N % int(num_frames) == 0:
            T = int(num_frames)
            S = N // T
            frames_tokens = normed.view(B, T, S, D)
        else:
            frames_tokens = normed.view(B, 1, N, D)

        flat_tokens = frames_tokens.reshape(B * T, S, D)
        q = self.query.expand(B * T, -1, -1)
        pooled, _ = self.attn(q, flat_tokens, flat_tokens)
        pooled = self.attn_norm(pooled.squeeze(1))

        obj_logits = self.objectness(flat_tokens).squeeze(-1)

        g = int(round(S ** 0.5))
        is_square = g * g == S
        min_size = 1.0 / float(g) if is_square else 1.0 / float(max(1, S))

        if is_square:
            g_up = g * self.UPSAMPLE_FACTOR
            obj_2d = obj_logits.view(B * T, 1, g, g)
            obj_up = F.interpolate(obj_2d, size=(g_up, g_up), mode="bilinear", align_corners=False)
            obj_up = obj_up.view(B * T, g_up * g_up)

            temp = self.temperature.clamp(min=1.0)
            obj_probs = torch.softmax(obj_up * temp, dim=1)

            ys = (torch.arange(g_up, device=normed.device, dtype=normed.dtype) + 0.5) / float(g_up)
            xs = (torch.arange(g_up, device=normed.device, dtype=normed.dtype) + 0.5) / float(g_up)
            yy, xx = torch.meshgrid(ys, xs, indexing="ij")
            x_coords = xx.reshape(-1).unsqueeze(0)
            y_coords = yy.reshape(-1).unsqueeze(0)
        else:
            temp = self.temperature.clamp(min=1.0)
            obj_probs = torch.softmax(obj_logits * temp, dim=1)
            x_coords = ((torch.arange(S, device=normed.device, dtype=normed.dtype) + 0.5) / float(S)).unsqueeze(0)
            y_coords = torch.full_like(x_coords, 0.5)

        cx0 = (obj_probs * x_coords).sum(dim=1)
        cy0 = (obj_probs * y_coords).sum(dim=1)

        if is_square:
            cx0_grid = (cx0 * g).clamp(0, g - 1).long()
            cy0_grid = (cy0 * g).clamp(0, g - 1).long()
            tokens_2d = flat_tokens.view(B * T, g, g, D)
            local_feats = []
            for dy in range(-1, 2):
                for dx in range(-1, 2):
                    gy = (cy0_grid + dy).clamp(0, g - 1)
                    gx = (cx0_grid + dx).clamp(0, g - 1)
                    idx_bt = torch.arange(B * T, device=normed.device)
                    local_feats.append(tokens_2d[idx_bt, gy, gx, :])
            local_pool = torch.stack(local_feats, dim=1).mean(dim=1)
        else:
            local_pool = pooled

        refine_input = torch.cat([pooled, local_pool], dim=1)
        delta = self.refine(refine_input)
        dx = 0.30 * torch.tanh(delta[:, 0])
        dy = 0.30 * torch.tanh(delta[:, 1])

        cx = (cx0 + dx).clamp(0.0, 1.0)
        cy = (cy0 + dy).clamp(0.0, 1.0)

        if T > 1:
            center_seq = torch.stack([cx, cy], dim=1).view(B, T, 2).transpose(1, 2)
            temporal_delta = self.temporal_refine(center_seq).transpose(1, 2).reshape(B * T, 2)
            cx = (cx + 0.12 * torch.tanh(temporal_delta[:, 0])).clamp(0.0, 1.0)
            cy = (cy + 0.12 * torch.tanh(temporal_delta[:, 1])).clamp(0.0, 1.0)

        if fixed_box_wh is not None:
            wh = fixed_box_wh.to(device=normed.device, dtype=normed.dtype)
            if wh.dim() == 1:
                wh = wh.view(1, 2).expand(B, -1)
            if wh.size(0) != B:
                wh = wh[:B]
        else:
            wh = self.fixed_box_wh.to(device=normed.device, dtype=normed.dtype).view(1, 2).expand(B, -1)
        wh = wh.clamp(min=min_size, max=1.0)
        wh_bt = wh.unsqueeze(1).expand(B, T, 2).reshape(B * T, 2) if T > 1 else wh
        w = wh_bt[:, 0]
        h = wh_bt[:, 1]

        x1 = (cx - 0.5 * w).clamp(0.0, 1.0)
        y1 = (cy - 0.5 * h).clamp(0.0, 1.0)
        x2 = (cx + 0.5 * w).clamp(0.0, 1.0)
        y2 = (cy + 0.5 * h).clamp(0.0, 1.0)
        x2 = torch.maximum(x2, x1 + min_size).clamp(0.0, 1.0)
        y2 = torch.maximum(y2, y1 + min_size).clamp(0.0, 1.0)

        boxes = torch.stack([x1, y1, x2, y2], dim=1).view(B, T, 4)
        if T == 1:
            return boxes[:, 0, :]
        return boxes
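

# Illustrative sketch (not in the original wheel): the temperature-scaled
# soft-argmax at the core of SpatialLocalizationHead. A sharper temperature
# pushes the expected coordinate toward the true argmax of the map.
def _demo_soft_argmax(obj_logits: "torch.Tensor", g: int, temp: float = 8.0):
    # obj_logits: [B, g*g] objectness map over a g x g patch grid
    probs = torch.softmax(obj_logits * temp, dim=1)
    coords = (torch.arange(g, dtype=probs.dtype) + 0.5) / g
    yy, xx = torch.meshgrid(coords, coords, indexing="ij")
    cx = (probs * xx.reshape(1, -1)).sum(dim=1)  # expected x in [0, 1]
    cy = (probs * yy.reshape(1, -1)).sum(dim=1)  # expected y in [0, 1]
    return cx, cy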


# BehaviorClassifier.

class BehaviorClassifier(nn.Module):
    """Complete model: VideoPrism backbone + frame decoder head (Stage B)."""

    def __init__(
        self,
        backbone: VideoPrismBackbone,
        num_classes: int,
        class_names: list = None,
        dropout: float = 0.1,
        freeze_backbone: bool = True,
        head_kwargs: Optional[dict] = None,
        use_localization: bool = False,
        localization_hidden_dim: int = 256,
        localization_dropout: float = 0.0,
        frame_head_temporal_layers: int = 1,
        temporal_pool_frames: int = 1,
        proj_dim: int = 256,
        num_stages: int = 3,
        multi_scale: bool = False,
        use_temporal_decoder: bool = True,
        use_frame_head: bool = True,
        **kwargs,
    ):
        super().__init__()
        self.backbone = backbone
        self.multi_scale = bool(multi_scale)
        self.use_temporal_decoder = bool(use_temporal_decoder)

        if freeze_backbone:
            for p in self.backbone.parameters():
                p.requires_grad = False

        embed_dim = backbone.get_embed_dim()
        self.temporal_pool_frames = max(1, int(temporal_pool_frames))
        self.use_frame_head = True  # always on

        head_kwargs = head_kwargs or {}
        spatial_pool_heads = head_kwargs.get("num_heads", 4)

        self.frame_head = DilatedTemporalHead(
            embed_dim=embed_dim,
            num_classes=num_classes,
            dropout=dropout,
            num_layers=max(1, int(frame_head_temporal_layers)),
            num_stages=max(1, int(num_stages)),
            temporal_pool=self.temporal_pool_frames,
            proj_dim=int(proj_dim),
            spatial_pool_heads=int(spatial_pool_heads),
            multi_scale=self.multi_scale,
            use_temporal_decoder=self.use_temporal_decoder,
        )

        self.use_localization = bool(use_localization)
        self.localization_head = None
        if self.use_localization:
            self.localization_head = SpatialLocalizationHead(
                embed_dim=embed_dim,
                hidden_dim=int(localization_hidden_dim),
                dropout=float(localization_dropout),
            )

    def forward(
        self,
        video: Optional[torch.Tensor],
        return_localization: bool = False,
        return_frame_logits: bool = False,
        cache_backbone_tokens: bool = False,
        localization_box_wh: Optional[torch.Tensor] = None,
        return_attn_weights: bool = False,
        backbone_tokens: Optional[torch.Tensor] = None,
        num_frames: Optional[int] = None,
        backbone_tokens_short: Optional[torch.Tensor] = None,
        num_frames_short: Optional[int] = None,
        return_features: bool = False,
    ):
        """
        Returns:
            clip_logits [B, C] by default.
            Stores full frame outputs in self._frame_output when
            return_frame_logits=True.

            backbone_tokens: pre-computed tokens [B, T*S, D]. If provided,
                the backbone call is skipped. num_frames must also be provided.
            backbone_tokens_short: pre-computed short-scale tokens [B, T_s*S, D]
                (T_s = T//2, half-fps clip). Used when multi_scale=True.
        """
        if backbone_tokens is not None:
            # Embedding-space stitch path: bypass backbone entirely
            tokens_full = backbone_tokens
            if num_frames is None:
                raise ValueError("num_frames must be provided when backbone_tokens is given")
            T = num_frames
            tokens_short = backbone_tokens_short
            T_short = num_frames_short
        else:
            tokens_full = self.backbone(video)
            T = int(video.shape[1])
            # Multi-scale at inference: subsample video by 2 and run backbone again
            if self.multi_scale:
                video_short = video[:, ::2, :, :, :]  # [B, T//2, C, H, W]
                T_short = int(video_short.shape[1])
                tokens_short = self.backbone(video_short)  # [B, T_s*S, D]
            else:
                tokens_short = None
                T_short = None

        if cache_backbone_tokens:
            self._backbone_tokens = tokens_full

        self._frame_output = None
        frame_device = next(self.frame_head.parameters()).device
        tokens_in = tokens_full.to(frame_device) if tokens_full.device != frame_device else tokens_full
        tokens_short_in = (
            tokens_short.to(frame_device)
            if tokens_short is not None and tokens_short.device != frame_device
            else tokens_short
        )
        frame_out = self.frame_head(
            tokens_in, num_frames=T,
            tokens_short=tokens_short_in,
            num_frames_short=T_short,
            return_attn_weights=return_attn_weights,
        )

        frame_logits = frame_out[0]
        frame_clip_logits = frame_out[1]
        temporal_weights = frame_out[2]
        frame_logits_pooled = frame_out[3]
        boundary_logits = frame_out[4]
        frame_embeddings = frame_out[5]
        frame_embeddings_combined = frame_out[6] if len(frame_out) > 6 else None
        attn_weights = frame_out[7] if len(frame_out) > 7 else None

        if return_frame_logits:
            self._frame_output = (
                frame_logits,                    # 0 [B, T, C]
                frame_clip_logits,               # 1 [B, C]
                temporal_weights,                # 2 [B, T]
                frame_logits_pooled,             # 3 [B, T_pooled, C]
                int(self.temporal_pool_frames),  # 4
                boundary_logits,                 # 5 [B, T, 1]
                frame_embeddings,                # 6 [B, T, proj_dim] (attention-pooled)
                frame_embeddings_combined,       # 7 [B, T, 2*proj_dim] (attn || raw_mean)
                attn_weights,                    # 8 [B, T, num_heads, S] or None
            )

        head_output = frame_clip_logits

        if return_localization and self.use_localization and self.localization_head is not None:
            loc_device = next(self.localization_head.parameters()).device
            loc_out = self.localization_head(
                tokens_full.to(loc_device) if tokens_full.device != loc_device else tokens_full,
                num_frames=T,
                fixed_box_wh=localization_box_wh,
            )
            return head_output, loc_out

        return head_output

    def save_head(self, path: str, metadata: Optional[dict] = None):
        """Save frame head (and localization head) parameters."""
        payload = {
            "frame_head_state_dict": self.frame_head.state_dict(),
            "use_localization": self.use_localization,
            "use_frame_head": True,
            "multi_scale": self.multi_scale,
            "use_temporal_decoder": self.use_temporal_decoder,
        }
        if self.use_localization and self.localization_head is not None:
            payload["localization_state_dict"] = self.localization_head.state_dict()
        torch.save(payload, path)

        if metadata:
            import json
            meta_path = path + ".meta.json"
            try:
                with open(meta_path, "w", encoding="utf-8") as f:
                    json.dump(metadata, f, indent=2)
            except Exception as exc:
                logger.warning("Failed to write metadata to %s: %s", meta_path, exc)

    def load_head(self, path: str):
        """Load head parameters."""
        state = torch.load(path, map_location='cpu')
        if not isinstance(state, dict):
            return

        if "frame_head_state_dict" in state:
            self.frame_head.load_state_dict(
                state["frame_head_state_dict"], strict=False,
            )
        if (
            self.use_localization
            and self.localization_head is not None
            and isinstance(state.get("localization_state_dict"), dict)
        ):
            self.localization_head.load_state_dict(
                state["localization_state_dict"], strict=False,
            )
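

# Illustrative sketch (not in the original wheel): a smoke test that exercises
# BehaviorClassifier on the precomputed-token path, using a stub in place of a
# real VideoPrismBackbone so no JAX weights are needed. Dimensions are arbitrary.
class _StubBackbone(nn.Module):
    def __init__(self, embed_dim: int = 768):
        super().__init__()
        self._embed_dim = embed_dim

    def get_embed_dim(self) -> int:
        return self._embed_dim


def _demo_behavior_classifier():
    model = BehaviorClassifier(_StubBackbone(), num_classes=3, use_localization=False)
    tokens = torch.randn(2, 16 * 256, 768)             # cached [B, T*S, D] tokens
    clip_logits = model(None, backbone_tokens=tokens, num_frames=16,
                        return_frame_logits=True)
    assert clip_logits.shape == (2, 3)
    assert model._frame_output[0].shape == (2, 16, 3)  # per-frame logits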


# Loss functions.

def _giou(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Generalized IoU between xyxy boxes. Returns [N] in [-1, 1]."""
    eps = 1e-7
    ix1 = torch.maximum(pred[:, 0], target[:, 0])
    iy1 = torch.maximum(pred[:, 1], target[:, 1])
    ix2 = torch.minimum(pred[:, 2], target[:, 2])
    iy2 = torch.minimum(pred[:, 3], target[:, 3])
    inter = (ix2 - ix1).clamp(min=0) * (iy2 - iy1).clamp(min=0)

    area_p = (pred[:, 2] - pred[:, 0]).clamp(min=0) * (pred[:, 3] - pred[:, 1]).clamp(min=0)
    area_t = (target[:, 2] - target[:, 0]).clamp(min=0) * (target[:, 3] - target[:, 1]).clamp(min=0)
    union = area_p + area_t - inter
    iou = inter / (union + eps)

    ex1 = torch.minimum(pred[:, 0], target[:, 0])
    ey1 = torch.minimum(pred[:, 1], target[:, 1])
    ex2 = torch.maximum(pred[:, 2], target[:, 2])
    ey2 = torch.maximum(pred[:, 3], target[:, 3])
    area_enclosing = (ex2 - ex1).clamp(min=0) * (ey2 - ey1).clamp(min=0)

    return iou - (area_enclosing - union) / (area_enclosing + eps)
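

# Illustrative sketch (not in the original wheel): numeric sanity checks for
# _giou. Identical boxes score ~1; disjoint boxes go negative because the
# enclosing-area penalty kicks in.
def _demo_giou():
    a = torch.tensor([[0.0, 0.0, 0.5, 0.5]])
    b = torch.tensor([[0.5, 0.5, 1.0, 1.0]])
    assert torch.allclose(_giou(a, a), torch.ones(1), atol=1e-4)
    assert _giou(a, b).item() < 0.0  # IoU 0, enclosing box 1.0 -> GIoU = -0.5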


def localization_bbox_loss(
    pred_bboxes: torch.Tensor,
    target_bboxes: torch.Tensor,
    valid_mask: torch.Tensor,
    smooth_l1_weight: float = 0.5,
    giou_weight: float = 0.5,
    temporal_smoothness_weight: float = 0.0,
) -> torch.Tensor:
    """Combined Smooth-L1 + GIoU bbox loss over valid samples."""
    if pred_bboxes is None or target_bboxes is None or valid_mask is None:
        return torch.tensor(0.0, device=pred_bboxes.device if pred_bboxes is not None else "cpu", requires_grad=True)

    valid = valid_mask > 0.5
    if not valid.any():
        return torch.tensor(0.0, device=pred_bboxes.device, requires_grad=True)

    pred_first = pred_bboxes[:, 0, :] if pred_bboxes.dim() == 3 else pred_bboxes
    pred = pred_first[valid]
    tgt = target_bboxes[valid]

    loss_l1 = F.smooth_l1_loss(pred, tgt, reduction="mean")
    loss_giou = (1.0 - _giou(pred, tgt)).mean()
    loss = smooth_l1_weight * loss_l1 + giou_weight * loss_giou

    if pred_bboxes.dim() == 3 and pred_bboxes.size(1) > 1:
        pred_valid_seq = pred_bboxes[valid]
        if pred_valid_seq.numel() > 0:
            smooth = F.smooth_l1_loss(
                pred_valid_seq[:, 1:, :], pred_valid_seq[:, :-1, :],
                reduction="mean",
            )
            loss = loss + temporal_smoothness_weight * smooth

    return loss


def objectness_spatial_contrastive_loss(
    projected_tokens: torch.Tensor,
    spatial_masks: torch.Tensor,
    temperature: float = 0.1,
) -> torch.Tensor:
    """Per-sample spatial contrastive loss for localization training."""
    device = projected_tokens.device
    B, S, D = projected_tokens.shape

    if spatial_masks.size(1) != S:
        return torch.tensor(0.0, device=device, requires_grad=True)

    n_inside = spatial_masks.sum(dim=1)
    n_outside = S - n_inside
    valid = (n_inside >= 2) & (n_outside >= 1)
    if not valid.any():
        return torch.tensor(0.0, device=device, requires_grad=True)

    tokens_sub = F.normalize(projected_tokens[valid], p=2, dim=-1)
    masks_sub = spatial_masks[valid]

    total_loss = torch.tensor(0.0, device=device)
    count = 0

    for i in range(tokens_sub.size(0)):
        tok = tokens_sub[i]
        mask = masks_sub[i]
        inside_idx = (mask > 0.5).nonzero(as_tuple=True)[0]
        outside_idx = (mask <= 0.5).nonzero(as_tuple=True)[0]
        K = inside_idx.size(0)
        if K < 2 or outside_idx.size(0) < 1:
            continue

        inside_tok = tok[inside_idx]
        outside_tok = tok[outside_idx]
        sim_pos = torch.matmul(inside_tok, inside_tok.T) / temperature
        sim_neg = torch.matmul(inside_tok, outside_tok.T) / temperature
        sim_pos = sim_pos.masked_fill(
            torch.eye(K, device=device, dtype=torch.bool), float("-inf"),
        )
        log_numer = torch.logsumexp(sim_pos, dim=1)
        all_sim = torch.cat([sim_pos, sim_neg], dim=1)
        log_denom = torch.logsumexp(all_sim, dim=1)
        total_loss = total_loss + (-(log_numer - log_denom)).mean()
        count += 1

    if count == 0:
        return torch.tensor(0.0, device=device, requires_grad=True)
    return total_loss / count
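

# Illustrative sketch (not in the original wheel): invoking the spatial
# contrastive loss with a toy 4-token grid where the first two tokens fall
# inside the target box. Token values are random.
def _demo_spatial_contrastive():
    proj = torch.randn(1, 4, 8)                   # [B, S, D]
    masks = torch.tensor([[1.0, 1.0, 0.0, 0.0]])  # inside / outside flags
    loss = objectness_spatial_contrastive_loss(proj, masks, temperature=0.1)
    assert loss.ndim == 0 and torch.isfinite(loss)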


def objectness_mask_loss(
    objectness_logits: torch.Tensor,
    spatial_masks: torch.Tensor,
    min_pos_tokens: int = 2,
) -> torch.Tensor:
    """BCE logits loss for localization objectness supervision."""
    device = objectness_logits.device
    if spatial_masks.shape != objectness_logits.shape:
        return torch.tensor(0.0, device=device, requires_grad=True)

    pos_count = spatial_masks.sum(dim=1)
    valid = (pos_count >= float(min_pos_tokens)) & (pos_count < float(spatial_masks.size(1)))
    if not valid.any():
        return torch.tensor(0.0, device=device, requires_grad=True)

    logits = objectness_logits[valid]
    targets = spatial_masks[valid].float()
    pos = targets.sum()
    neg = torch.tensor(float(targets.numel()), device=device) - pos
    pos_weight = (neg / (pos + 1e-6)).clamp(min=1.0, max=20.0)
    return F.binary_cross_entropy_with_logits(logits, targets, pos_weight=pos_weight)


def gaussian_focal_loss(
    pred_logits: torch.Tensor,
    target_heatmap: torch.Tensor,
    alpha: float = 2.0,
    beta: float = 4.0,
) -> torch.Tensor:
    """CenterNet-style Gaussian focal loss for heatmap supervision."""
    pred = torch.sigmoid(pred_logits)
    pos_mask = target_heatmap.eq(1.0)
    neg_mask = ~pos_mask

    pos_loss = -((1 - pred).pow(alpha) * torch.log(pred.clamp(min=1e-6))) * pos_mask
    neg_loss = -(
        (1 - target_heatmap).pow(beta) * pred.pow(alpha)
        * torch.log((1 - pred).clamp(min=1e-6))
    ) * neg_mask

    num_pos = pos_mask.float().sum().clamp(min=1.0)
    return (pos_loss.sum() + neg_loss.sum()) / num_pos
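

# Illustrative sketch (not in the original wheel): a confident, correct
# prediction should drive the Gaussian focal loss toward zero, while a
# confident miss is penalized heavily.
def _demo_gaussian_focal():
    target = torch.zeros(1, 3, 3)
    target[0, 1, 1] = 1.0               # single peak at the center
    good = torch.full((1, 3, 3), -8.0)  # near-zero sigmoid everywhere...
    good[0, 1, 1] = 8.0                 # ...except a sharp peak
    bad = -good                         # peak in the wrong place
    assert gaussian_focal_loss(good, target) < gaussian_focal_loss(bad, target)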


def center_heatmap_loss(
    objectness_logits: torch.Tensor,
    target_bboxes: torch.Tensor,
    valid_mask: torch.Tensor,
    sigma_in_patches: float = 1.5,
) -> torch.Tensor:
    """Gaussian center heatmap supervision for localization objectness."""
    device = objectness_logits.device
    B, S = objectness_logits.shape
    if target_bboxes is None or valid_mask is None:
        return torch.tensor(0.0, device=device, requires_grad=True)
    if target_bboxes.dim() != 2 or target_bboxes.size(0) != B or target_bboxes.size(1) != 4:
        return torch.tensor(0.0, device=device, requires_grad=True)

    valid = valid_mask > 0.5
    if not valid.any():
        return torch.tensor(0.0, device=device, requires_grad=True)

    g = int(round(S ** 0.5))
    if g * g != S:
        return torch.tensor(0.0, device=device, requires_grad=True)

    ys = (torch.arange(g, device=device, dtype=objectness_logits.dtype) + 0.5) / float(g)
    xs = (torch.arange(g, device=device, dtype=objectness_logits.dtype) + 0.5) / float(g)
    yy, xx = torch.meshgrid(ys, xs, indexing="ij")

    tgt = target_bboxes[valid]
    cx = (0.5 * (tgt[:, 0] + tgt[:, 2])).view(-1, 1, 1)
    cy = (0.5 * (tgt[:, 1] + tgt[:, 3])).view(-1, 1, 1)

    sigma = max(1e-4, float(sigma_in_patches) / float(g))
    d2 = (xx.view(1, g, g) - cx).pow(2) + (yy.view(1, g, g) - cy).pow(2)
    target_heat = torch.exp(-d2 / (2.0 * sigma * sigma)).clamp(0.0, 1.0)

    Bv = target_heat.size(0)
    flat = target_heat.view(Bv, -1)
    peak_idx = flat.argmax(dim=1, keepdim=True)
    flat.scatter_(1, peak_idx, 1.0)
    target_heat = flat.view(Bv, g, g)

    pred_logits_v = objectness_logits[valid].view(-1, g, g)
    return gaussian_focal_loss(pred_logits_v, target_heat)


def direct_center_loss(
    pred_bboxes: torch.Tensor,
    target_bboxes: torch.Tensor,
    valid_mask: torch.Tensor,
) -> torch.Tensor:
    """Center-to-center SmoothL1 loss for localization."""
    if pred_bboxes is None or target_bboxes is None or valid_mask is None:
        dev = pred_bboxes.device if pred_bboxes is not None else "cpu"
        return torch.tensor(0.0, device=dev, requires_grad=True)

    valid = valid_mask > 0.5
    if not valid.any():
        return torch.tensor(0.0, device=pred_bboxes.device, requires_grad=True)

    pred_first = pred_bboxes[:, 0, :] if pred_bboxes.dim() == 3 else pred_bboxes
    pred = pred_first[valid]
    tgt = target_bboxes[valid]

    pred_center = torch.stack([
        0.5 * (pred[:, 0] + pred[:, 2]),
        0.5 * (pred[:, 1] + pred[:, 3]),
    ], dim=1)
    tgt_center = torch.stack([
        0.5 * (tgt[:, 0] + tgt[:, 2]),
        0.5 * (tgt[:, 1] + tgt[:, 3]),
    ], dim=1)

    return F.smooth_l1_loss(pred_center, tgt_center, reduction="mean")


def frame_classification_loss(
    frame_logits: torch.Tensor,
    frame_labels: torch.Tensor,
) -> torch.Tensor:
    """Per-frame CE loss, ignoring frames with label -1."""
    B, T, C = frame_logits.shape
    device = frame_logits.device
    logits_flat = frame_logits.reshape(B * T, C)
    labels_flat = frame_labels.reshape(B * T)
    valid = labels_flat >= 0
    if not valid.any():
        return torch.tensor(0.0, device=device, requires_grad=True)
    return F.cross_entropy(logits_flat[valid], labels_flat[valid])


def boundary_detection_loss(
    boundary_logits: torch.Tensor,
    boundary_labels: torch.Tensor,
) -> torch.Tensor:
    """BCE loss for boundary/change-point detection.

    Args:
        boundary_logits: [B, T, 1] raw logits
        boundary_labels: [B, T] binary (1=transition, 0=no, -1=ignore)
    """
    B, T, _ = boundary_logits.shape
    device = boundary_logits.device
    logits_flat = boundary_logits.squeeze(-1)  # [B, T]
    valid = boundary_labels >= 0
    if not valid.any():
        return torch.tensor(0.0, device=device, requires_grad=True)

    n_pos = (boundary_labels[valid] > 0.5).float().sum().clamp(min=1.0)
    n_neg = (boundary_labels[valid] <= 0.5).float().sum().clamp(min=1.0)
    pos_weight = (n_neg / n_pos).clamp(max=20.0)

    raw = F.binary_cross_entropy_with_logits(
        logits_flat[valid], boundary_labels[valid].float(),
        pos_weight=pos_weight, reduction='none',
    )

    return raw.mean()


def temporal_smoothness_loss(
    frame_logits: torch.Tensor,
    frame_labels: torch.Tensor,
) -> torch.Tensor:
    """L2 smoothness regularizer on consecutive frame logits."""
    B, T, C = frame_logits.shape
    if T < 2:
        return torch.tensor(0.0, device=frame_logits.device, requires_grad=True)

    valid = frame_labels >= 0
    both_valid = valid[:, :-1] & valid[:, 1:]
    if not both_valid.any():
        return torch.tensor(0.0, device=frame_logits.device, requires_grad=True)

    diff = (frame_logits[:, 1:, :] - frame_logits[:, :-1, :]) ** 2
    return diff.mean(dim=-1)[both_valid].mean()
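

# Illustrative sketch (not in the original wheel): combining the per-frame
# losses above into one training objective. The weights here are arbitrary
# placeholders, not the values used by singlebehaviorlab's training code.
def _demo_total_frame_loss(frame_logits, boundary_logits, frame_labels, boundary_labels):
    loss_cls = frame_classification_loss(frame_logits, frame_labels)
    loss_bnd = boundary_detection_loss(boundary_logits, boundary_labels)
    loss_smooth = temporal_smoothness_loss(frame_logits, frame_labels)
    return loss_cls + 0.5 * loss_bnd + 0.1 * loss_smooth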