singlebehaviorlab-2.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. sam2/__init__.py +11 -0
  2. sam2/automatic_mask_generator.py +454 -0
  3. sam2/benchmark.py +92 -0
  4. sam2/build_sam.py +174 -0
  5. sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
  6. sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
  7. sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
  8. sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
  9. sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
  10. sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
  11. sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
  12. sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
  13. sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
  14. sam2/modeling/__init__.py +5 -0
  15. sam2/modeling/backbones/__init__.py +5 -0
  16. sam2/modeling/backbones/hieradet.py +317 -0
  17. sam2/modeling/backbones/image_encoder.py +134 -0
  18. sam2/modeling/backbones/utils.py +93 -0
  19. sam2/modeling/memory_attention.py +169 -0
  20. sam2/modeling/memory_encoder.py +181 -0
  21. sam2/modeling/position_encoding.py +239 -0
  22. sam2/modeling/sam/__init__.py +5 -0
  23. sam2/modeling/sam/mask_decoder.py +295 -0
  24. sam2/modeling/sam/prompt_encoder.py +202 -0
  25. sam2/modeling/sam/transformer.py +311 -0
  26. sam2/modeling/sam2_base.py +913 -0
  27. sam2/modeling/sam2_utils.py +323 -0
  28. sam2/sam2_hiera_b+.yaml +113 -0
  29. sam2/sam2_hiera_l.yaml +117 -0
  30. sam2/sam2_hiera_s.yaml +116 -0
  31. sam2/sam2_hiera_t.yaml +118 -0
  32. sam2/sam2_image_predictor.py +466 -0
  33. sam2/sam2_video_predictor.py +1388 -0
  34. sam2/sam2_video_predictor_legacy.py +1172 -0
  35. sam2/utils/__init__.py +5 -0
  36. sam2/utils/amg.py +348 -0
  37. sam2/utils/misc.py +349 -0
  38. sam2/utils/transforms.py +118 -0
  39. singlebehaviorlab/__init__.py +4 -0
  40. singlebehaviorlab/__main__.py +130 -0
  41. singlebehaviorlab/_paths.py +100 -0
  42. singlebehaviorlab/backend/__init__.py +2 -0
  43. singlebehaviorlab/backend/augmentations.py +320 -0
  44. singlebehaviorlab/backend/data_store.py +420 -0
  45. singlebehaviorlab/backend/model.py +1290 -0
  46. singlebehaviorlab/backend/train.py +4667 -0
  47. singlebehaviorlab/backend/uncertainty.py +578 -0
  48. singlebehaviorlab/backend/video_processor.py +688 -0
  49. singlebehaviorlab/backend/video_utils.py +139 -0
  50. singlebehaviorlab/data/config/config.yaml +85 -0
  51. singlebehaviorlab/data/training_profiles.json +334 -0
  52. singlebehaviorlab/gui/__init__.py +4 -0
  53. singlebehaviorlab/gui/analysis_widget.py +2291 -0
  54. singlebehaviorlab/gui/attention_export.py +311 -0
  55. singlebehaviorlab/gui/clip_extraction_widget.py +481 -0
  56. singlebehaviorlab/gui/clustering_widget.py +3187 -0
  57. singlebehaviorlab/gui/inference_popups.py +1138 -0
  58. singlebehaviorlab/gui/inference_widget.py +4550 -0
  59. singlebehaviorlab/gui/inference_worker.py +651 -0
  60. singlebehaviorlab/gui/labeling_widget.py +2324 -0
  61. singlebehaviorlab/gui/main_window.py +754 -0
  62. singlebehaviorlab/gui/metadata_management_widget.py +1119 -0
  63. singlebehaviorlab/gui/motion_tracking.py +764 -0
  64. singlebehaviorlab/gui/overlay_export.py +1234 -0
  65. singlebehaviorlab/gui/plot_integration.py +729 -0
  66. singlebehaviorlab/gui/qt_helpers.py +29 -0
  67. singlebehaviorlab/gui/registration_widget.py +1485 -0
  68. singlebehaviorlab/gui/review_widget.py +1330 -0
  69. singlebehaviorlab/gui/segmentation_tracking_widget.py +2752 -0
  70. singlebehaviorlab/gui/tab_tutorial_dialog.py +312 -0
  71. singlebehaviorlab/gui/timeline_themes.py +131 -0
  72. singlebehaviorlab/gui/training_profiles.py +418 -0
  73. singlebehaviorlab/gui/training_widget.py +3719 -0
  74. singlebehaviorlab/gui/video_utils.py +233 -0
  75. singlebehaviorlab/licenses/SAM2-LICENSE +201 -0
  76. singlebehaviorlab/licenses/VideoPrism-LICENSE +202 -0
  77. singlebehaviorlab-2.0.0.dist-info/METADATA +447 -0
  78. singlebehaviorlab-2.0.0.dist-info/RECORD +88 -0
  79. singlebehaviorlab-2.0.0.dist-info/WHEEL +5 -0
  80. singlebehaviorlab-2.0.0.dist-info/entry_points.txt +2 -0
  81. singlebehaviorlab-2.0.0.dist-info/licenses/LICENSE +21 -0
  82. singlebehaviorlab-2.0.0.dist-info/top_level.txt +3 -0
  83. videoprism/__init__.py +0 -0
  84. videoprism/encoders.py +910 -0
  85. videoprism/layers.py +1136 -0
  86. videoprism/models.py +407 -0
  87. videoprism/tokenizers.py +167 -0
  88. videoprism/utils.py +168 -0
singlebehaviorlab/backend/train.py
@@ -0,0 +1,4667 @@
+ import logging
+ import os
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
+ from torch.optim import AdamW
+ from typing import Callable, Optional, Dict, Any
+ import numpy as np
+ from sklearn.model_selection import train_test_split
+ from sklearn.metrics import f1_score
+ from collections import Counter
+ import random
+ import math
+ import re
+
+ logger = logging.getLogger(__name__)
+
+
+ def _slugify_class_name(name: str) -> str:
+     """Create a filesystem/column-safe slug for class names."""
+     slug = re.sub(r'[^0-9a-zA-Z]+', '_', name).strip('_').lower()
+     return slug or "class"
+
+
+ class FocalLoss(nn.Module):
+     """
+     Focal Loss for dense object detection and classification.
+     Focuses training on hard examples and down-weights easy ones.
+     Loss = -alpha * (1 - pt)^gamma * log(pt)
+     """
+     def __init__(self, gamma: float = 2.0, alpha: Optional[torch.Tensor] = None, reduction: str = 'mean'):
+         super().__init__()
+         self.gamma = gamma
+         self.alpha = alpha
+         self.reduction = reduction
+
+     def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+         """
+         Args:
+             inputs: [B, C] logits (not probabilities)
+             targets: [B] labels (negative values such as -100 or -1 are ignored)
+         """
+         valid = targets >= 0
+         ce_loss = F.cross_entropy(inputs, targets.clamp(min=0), reduction='none', weight=self.alpha)
+         pt = torch.exp(-ce_loss)
+         focal_loss = ((1 - pt) ** self.gamma) * ce_loss
+         focal_loss = focal_loss * valid.float()
+
+         if self.reduction == 'mean':
+             n = valid.sum().clamp(min=1)
+             return focal_loss.sum() / n
+         elif self.reduction == 'sum':
+             return focal_loss.sum()
+         else:
+             return focal_loss
+
+
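+ # Illustrative usage (editorial sketch, not part of the packaged module):
+ # FocalLoss consumes raw logits and integer labels; negative labels are
+ # masked out of the reduction. With gamma=0 and alpha=None it reduces to
+ # plain cross-entropy over the valid entries.
+ #
+ #     logits = torch.randn(4, 3)                    # [B=4, C=3]
+ #     targets = torch.tensor([0, 2, 1, -100])       # last entry is ignored
+ #     loss = FocalLoss(gamma=2.0)(logits, targets)  # scalar mean over 3 valid rows
+
+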
+ class BinaryFocalLoss(nn.Module):
+     """Per-element binary focal loss for OvR heads."""
+     def __init__(self, gamma: float = 2.0, reduction: str = 'mean'):
+         super().__init__()
+         self.gamma = gamma
+         self.reduction = reduction
+
+     def forward(self, inputs: torch.Tensor, targets: torch.Tensor,
+                 weight: Optional[torch.Tensor] = None) -> torch.Tensor:
+         """
+         Args:
+             inputs: [B, C] logits
+             targets: [B, C] binary targets (0 or 1)
+             weight: optional [B, C] per-element weight
+         """
+         bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
+         pt = torch.exp(-bce)
+         focal = ((1 - pt) ** self.gamma) * bce
+         if weight is not None:
+             focal = focal * weight
+         if self.reduction == 'mean':
+             return focal.mean()
+         elif self.reduction == 'sum':
+             return focal.sum()
+         return focal
+
+
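+ # Illustrative usage (editorial sketch, not part of the packaged module):
+ # for one-vs-rest heads the targets are a multi-hot matrix, and the optional
+ # weight matrix can zero out entries that should not be supervised.
+ #
+ #     logits = torch.randn(4, 5)                     # [B=4, C=5] OvR heads
+ #     targets = torch.zeros(4, 5); targets[0, 2] = 1.0
+ #     weight = torch.ones(4, 5); weight[1] = 0.0     # ignore sample 1 entirely
+ #     loss = BinaryFocalLoss(gamma=2.0)(logits, targets, weight=weight)
+
+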
+ class AsymmetricLoss(nn.Module):
+     """Asymmetric Loss for multi-label / OvR classification.
+
+     Applies focal-style down-weighting only to *negatives* (gamma_neg),
+     while keeping full gradient signal from positives (gamma_pos, default 0).
+     Optionally applies probability shifting (clip) so that very easy
+     negatives are discarded outright, further reducing their contribution.
+
+     Reference: Ridnik et al., "Asymmetric Loss For Multi-Label Classification"
+     (https://arxiv.org/abs/2009.14119)
+     """
+
+     def __init__(
+         self,
+         gamma_neg: float = 4.0,
+         gamma_pos: float = 0.0,
+         clip: float = 0.05,
+         reduction: str = "mean",
+     ):
+         super().__init__()
+         self.gamma_neg = gamma_neg
+         self.gamma_pos = gamma_pos
+         self.clip = clip
+         self.reduction = reduction
+
+     def forward(
+         self,
+         inputs: torch.Tensor,
+         targets: torch.Tensor,
+         weight: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         """
+         Args:
+             inputs: [*, C] raw logits
+             targets: [*, C] binary targets in [0, 1]
+             weight: optional [*, C] per-element weight
+         """
+         p = torch.sigmoid(inputs)
+
+         # Derive hard targets to prevent ASL from over-penalizing easy negatives under label smoothing.
+         hard_targets = (targets >= 0.5).float()
+
+         pos_part = hard_targets * torch.log(p.clamp(min=1e-8))
+         neg_p = 1.0 - p
+
+         # Probability shifting: suppress easy negatives to reduce their gradient contribution
+         if self.clip > 0:
+             neg_p = (neg_p + self.clip).clamp(max=1.0)
+
+         neg_part = (1.0 - hard_targets) * torch.log(neg_p.clamp(min=1e-8))
+
+         # Asymmetric focusing
+         if self.gamma_pos > 0:
+             pos_part = pos_part * ((1.0 - p) ** self.gamma_pos)
+         if self.gamma_neg > 0:
+             neg_part = neg_part * (p ** self.gamma_neg)
+
+         loss = -(pos_part + neg_part)
+
+         if weight is not None:
+             loss = loss * weight
+         if self.reduction == "mean":
+             return loss.mean()
+         elif self.reduction == "sum":
+             return loss.sum()
+         return loss
+
+
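+ # Worked example (editorial sketch, not part of the packaged module): with
+ # clip=0.05, any negative predicted below p ≈ 0.05 contributes zero loss,
+ # because the shifted probability (1 - p + 0.05) clamps to 1 and log(1) = 0.
+ #
+ #     asl = AsymmetricLoss(gamma_neg=4.0, gamma_pos=0.0, clip=0.05)
+ #     logits = torch.tensor([[-4.0, 4.0]])   # easy negative, confident positive
+ #     targets = torch.tensor([[0.0, 1.0]])
+ #     loss = asl(logits, targets)            # near zero for both entries
+
+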
+ class BehaviorDataset(Dataset):
+     """PyTorch Dataset for behavior clips."""
+
+     def __init__(
+         self,
+         clips: list,
+         annotation_manager,
+         classes: list,
+         clip_base_dir: str,
+         transform: Optional[Callable] = None,
+         target_size: tuple[int, int] = (288, 288),
+         clip_length: int = 16,
+         virtual_size_multiplier: int = 1,
+         grid_size: int = 0,
+         stitch_prob: float = 0.0,
+         crop_jitter: bool = False,
+         crop_jitter_strength: float = 0.15,
+         ovr_background_classes: Optional[list[str]] = None,
+     ):
+         self.clips = clips
+         self.annotation_manager = annotation_manager
+         self.raw_classes = classes  # Original labels
+         self.clip_base_dir = clip_base_dir
+         self.transform = transform
+         self.target_size = target_size
+         self.clip_length = clip_length
+         self.virtual_size_multiplier = virtual_size_multiplier
+         self.grid_size = grid_size  # Spatial grid size (e.g. 16 for 288px). 0 = disabled.
+         # Masks stored at 16×16 reference grid; rescaled to training grid_size if different.
+         STORED_GRID = 16
+         self.spatial_masks = []
+         self.spatial_bboxes = []
+         self.spatial_bbox_valid = []
+         num_patches = grid_size * grid_size if grid_size > 0 else 0
+         for clip in clips:
+             clip_id = clip["id"]
+             mask_indices = clip.get("spatial_mask")
+             if mask_indices and grid_size > 0:
+                 # Build mask at stored 16×16 resolution
+                 ref_mask = torch.zeros(STORED_GRID, STORED_GRID, dtype=torch.float32)
+                 for idx in mask_indices:
+                     if 0 <= idx < STORED_GRID * STORED_GRID:
+                         row, col = divmod(idx, STORED_GRID)
+                         ref_mask[row, col] = 1.0
+
+                 if grid_size != STORED_GRID:
+                     # Rescale to training grid size via bilinear interpolation
+                     ref_4d = ref_mask.unsqueeze(0).unsqueeze(0)  # [1,1,16,16]
+                     scaled = torch.nn.functional.interpolate(
+                         ref_4d, size=(grid_size, grid_size), mode='bilinear', align_corners=False
+                     )
+                     mask = (scaled.squeeze() > 0.25).float()  # threshold to keep binary
+                 else:
+                     mask = ref_mask
+
+                 self.spatial_masks.append(mask.reshape(-1))  # flatten to [G*G]
+             else:
+                 self.spatial_masks.append(torch.zeros(max(1, num_patches), dtype=torch.float32))
+
+             # Per-frame bboxes: [T, 4] and [T] validity
+             T = self.clip_length
+             bbox_frames_data = clip.get("spatial_bbox_frames")
+             if bbox_frames_data and isinstance(bbox_frames_data, (list, tuple)):
+                 frame_boxes = torch.zeros(T, 4, dtype=torch.float32)
+                 frame_valid = torch.zeros(T, dtype=torch.float32)
+                 for fi in range(min(T, len(bbox_frames_data))):
+                     b = bbox_frames_data[fi]
+                     if b and isinstance(b, (list, tuple)) and len(b) == 4:
+                         x1, y1, x2, y2 = [max(0.0, min(1.0, float(v))) for v in b]
+                         if x2 > x1 and y2 > y1:
+                             frame_boxes[fi] = torch.tensor([x1, y1, x2, y2])
+                             frame_valid[fi] = 1.0
+                 self.spatial_bboxes.append(frame_boxes)
+                 self.spatial_bbox_valid.append(frame_valid)
+             else:
+                 # Legacy single bbox: replicate to all frames
+                 bbox = clip.get("spatial_bbox")
+                 if bbox and isinstance(bbox, (list, tuple)) and len(bbox) == 4:
+                     x1, y1, x2, y2 = [max(0.0, min(1.0, float(v))) for v in bbox]
+                     if x2 > x1 and y2 > y1:
+                         single = torch.tensor([x1, y1, x2, y2], dtype=torch.float32)
+                         self.spatial_bboxes.append(single.unsqueeze(0).expand(T, -1).clone())
+                         # Mark only frame 0 valid: legacy single-frame annotations shouldn't penalize per-frame tracking.
+                         legacy_valid = torch.zeros(T, dtype=torch.float32)
+                         legacy_valid[0] = 1.0
+                         self.spatial_bbox_valid.append(legacy_valid)
+                     else:
+                         self.spatial_bboxes.append(torch.zeros(T, 4, dtype=torch.float32))
+                         self.spatial_bbox_valid.append(torch.zeros(T, dtype=torch.float32))
+                 else:
+                     self.spatial_bboxes.append(torch.zeros(T, 4, dtype=torch.float32))
+                     self.spatial_bbox_valid.append(torch.zeros(T, dtype=torch.float32))
+
+         self.classes = classes
+         self.attributes = []
+         self.class_to_idx = {c: i for i, c in enumerate(classes)}
+         self.attr_to_idx = {}
+         self.ovr_background_classes = set(ovr_background_classes or [])
+         # Primary label index per clip (-1 if not in class list, e.g. near_negative_*).
+         self.labels = []
+         # Multi-label: list of all label indices per clip (for OvR multi-hot targets).
+         self.multi_labels = []
+         for clip in clips:
+             clip_labels = clip.get("labels")
+             if not isinstance(clip_labels, list) or not clip_labels:
+                 clip_labels = [clip.get("label", "")]
+             indices = [self.class_to_idx[l] for l in clip_labels if l in self.class_to_idx]
+             if indices:
+                 self.labels.append(indices[0])
+                 self.multi_labels.append(indices)
+             else:
+                 self.labels.append(-1)
+                 self.multi_labels.append([])
+         self.class_labels = self.labels
+         self.attr_labels = []
+
+         # OvR: for each clip, store the matched suppression class index (or -1).
+         # near_negative_X → suppress class X; also check hard_negative_for_class metadata.
+         self.ovr_suppress_idx = []
+         for clip in clips:
+             suppress = -1
+             label = clip.get("label", "")
+             meta = clip.get("meta") or {}
+             hn_for = meta.get("hard_negative_for_class")
+             if hn_for and hn_for in self.class_to_idx:
+                 suppress = self.class_to_idx[hn_for]
+             elif label.startswith("near_negative_"):
+                 suffix = label[len("near_negative_"):]
+                 for cls_name, cls_idx in self.class_to_idx.items():
+                     if cls_name == suffix or cls_name.replace(" ", "_") == suffix:
+                         suppress = cls_idx
+                         break
+             self.ovr_suppress_idx.append(suppress)
+
+         # Per-frame labels: [T] tensor per clip.
+         # - If explicit frame labels exist, use them.
+         # - Else if a valid clip class exists, supervise all frames with that class.
+         # - Else use -1 (ignored by frame loss).
+         # has_real_frame_labels is used by training to route samples to frame
+         # supervision (and exclude them from clip-loss metrics/path).
+         self.frame_labels = []
+         self.ovr_background_frame_mask = []
+         self.ovr_background_clip = []
+         self.has_real_frame_labels = []
+         for i, clip in enumerate(clips):
+             clip_fl = clip.get("frame_labels")
+             T = self.clip_length
+             primary = self.labels[i]
+             raw_label = clip.get("label", "")
+             clip_is_background = raw_label in self.ovr_background_classes
+             if clip_fl and isinstance(clip_fl, (list, tuple)) and len(clip_fl) > 0:
+                 fl = []
+                 bg = []
+                 for lbl_name in clip_fl:
+                     if lbl_name is None:
+                         fl.append(-1)
+                         bg.append(False)
+                     elif isinstance(lbl_name, str):
+                         if lbl_name in self.ovr_background_classes:
+                             fl.append(-1)
+                             bg.append(True)
+                         else:
+                             fl.append(self.class_to_idx.get(lbl_name, -1))
+                             bg.append(False)
+                     else:
+                         v = int(lbl_name)
+                         if 0 <= v < len(self.class_to_idx):
+                             fl.append(v)
+                             bg.append(False)
+                         else:
+                             fl.append(-1)
+                             bg.append(False)
+                 fl_tensor = torch.tensor(fl, dtype=torch.long)
+                 bg_tensor = torch.tensor(bg, dtype=torch.bool)
+                 if len(fl_tensor) < T:
+                     fl_tensor = torch.cat([fl_tensor, fl_tensor[-1:].expand(T - len(fl_tensor))])
+                     bg_tensor = torch.cat([bg_tensor, bg_tensor[-1:].expand(T - len(bg_tensor))])
+                 elif len(fl_tensor) > T:
+                     start = (len(fl_tensor) - T) // 2
+                     fl_tensor = fl_tensor[start:start + T]
+                     bg_tensor = bg_tensor[start:start + T]
+                 # Safety: convert any out-of-range numeric label to ignore (-1).
+                 fl_tensor[(fl_tensor < 0) | (fl_tensor >= len(self.class_to_idx))] = -1
+                 self.frame_labels.append(fl_tensor)
+                 self.ovr_background_frame_mask.append(bg_tensor)
+                 self.ovr_background_clip.append(bool(bg_tensor.any().item()) or clip_is_background)
+                 self.has_real_frame_labels.append(True)
+             else:
+                 if primary >= 0:
+                     self.frame_labels.append(torch.full((T,), int(primary), dtype=torch.long))
+                     self.ovr_background_frame_mask.append(torch.zeros(T, dtype=torch.bool))
+                     self.ovr_background_clip.append(False)
+                     self.has_real_frame_labels.append(True)
+                 else:
+                     bg_tensor = torch.full((T,), bool(clip_is_background), dtype=torch.bool)
+                     self.frame_labels.append(torch.full((T,), -1, dtype=torch.long))
+                     self.ovr_background_frame_mask.append(bg_tensor)
+                     self.ovr_background_clip.append(bool(clip_is_background))
+                     # Background clips have real negative supervision in hybrid OvR mode.
+                     self.has_real_frame_labels.append(bool(clip_is_background))
+
+         # Clip-stitching augmentation params.
+         self.stitch_prob = float(stitch_prob)
+         self.stitch_exclude_classes: set[int] = set()
+         # Map class index → list of clip indices for fast same/different-class sampling.
+         self._label_to_clip_indices: dict[int, list[int]] = {}
+         for i, lbl in enumerate(self.labels):
+             if lbl >= 0:
+                 self._label_to_clip_indices.setdefault(lbl, []).append(i)
+
+         # Crop jitter augmentation (only used with ROI cache).
+         self.crop_jitter = bool(crop_jitter)
+         self.crop_jitter_strength = float(crop_jitter_strength)
+
+         # Runtime toggles used by training curriculum.
+         self._roi_cache_mode = False
+         self._roi_cache_dir = None
+         # Embedding-space stitch cache (backbone tokens pre-computed).
+         self._emb_cache_mode = False
+         self._emb_cache_dir = None
+         self._emb_clip_length = self.clip_length
+
+     def __len__(self):
+         return len(self.clips) * self.virtual_size_multiplier
+
+     def _resolve_clip_path(self, clip_id: str) -> str:
+         clip_path = os.path.join(self.clip_base_dir, clip_id)
+         found = False
+         if os.path.exists(clip_path):
+             found = True
+         else:
+             base_name, ext = os.path.splitext(clip_id)
+             clip_basename = os.path.basename(clip_id)
+             clip_dir_part = os.path.dirname(clip_id) if os.path.dirname(clip_id) else None
+             video_extensions = ['.mp4', '.avi', '.mov', '.mkv', '.MP4', '.AVI', '.MOV', '.MKV']
+
+             if not ext:
+                 for video_ext in video_extensions:
+                     test_path = os.path.join(self.clip_base_dir, clip_id + video_ext)
+                     if os.path.exists(test_path):
+                         clip_path = test_path
+                         found = True
+                         break
+             else:
+                 base_name_only = os.path.basename(base_name)
+                 ext_lower = ext.lower()
+
+                 for video_ext in video_extensions:
+                     if video_ext.lower() == ext_lower:
+                         continue
+                     test_path = os.path.join(self.clip_base_dir, base_name + video_ext)
+                     if os.path.exists(test_path):
+                         clip_path = test_path
+                         found = True
+                         break
+
+                 if not found:
+                     for video_ext in video_extensions:
+                         test_path = os.path.join(self.clip_base_dir, base_name_only + video_ext)
+                         if os.path.exists(test_path):
+                             clip_path = test_path
+                             found = True
+                             break
+
+                 if not found:
+                     for root, dirs, files in os.walk(self.clip_base_dir):
+                         for file in files:
+                             file_base, file_ext = os.path.splitext(file)
+                             if file_base == base_name_only or file_base == base_name:
+                                 if file_ext.lower() in [e.lower() for e in video_extensions]:
+                                     clip_path = os.path.join(root, file)
+                                     found = True
+                                     break
+                         if found:
+                             break
+
+                 if not found and clip_dir_part:
+                     subdir_path = os.path.join(self.clip_base_dir, clip_dir_part)
+                     if os.path.exists(subdir_path):
+                         # Check the exact basename once, then extension variants.
+                         test_path = os.path.join(subdir_path, clip_basename)
+                         if os.path.exists(test_path):
+                             clip_path = test_path
+                             found = True
+                         else:
+                             for video_ext in video_extensions:
+                                 test_path = os.path.join(subdir_path, base_name_only + video_ext)
+                                 if os.path.exists(test_path):
+                                     clip_path = test_path
+                                     found = True
+                                     break
+
+         if not found:
+             error_msg = f"Clip not found: {clip_path}\n"
+             error_msg += f"Clip ID from annotation: {clip_id}\n"
+             error_msg += f"Base directory: {self.clip_base_dir}\n"
+             error_msg += "Please check if the file exists or update the annotation."
+             raise FileNotFoundError(error_msg)
+         return clip_path
+
+     _NATIVE_RES = object()  # Sentinel: load at native resolution (no resize)
+
+     def _load_clip(self, clip_path: str, target_size=None, apply_transform: bool = True,
+                    apply_aoi: bool = True) -> torch.Tensor:
+         """Load and preprocess a video clip.
+
+         target_size: (w,h) to resize to, _NATIVE_RES for no resize, None for default self.target_size.
+         apply_aoi: unused (kept for API compatibility).
+         """
+         import cv2
+
+         cap = cv2.VideoCapture(clip_path)
+         frames = []
+         if target_size is self._NATIVE_RES:
+             resize_to = None
+         elif target_size is not None:
+             resize_to = target_size
+         else:
+             resize_to = self.target_size
+
+         while True:
+             ret, frame = cap.read()
+             if not ret:
+                 break
+             frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+             if resize_to is not None:
+                 h_src, w_src = frame.shape[:2]
+                 is_upscale = (w_src < resize_to[0]) or (h_src < resize_to[1])
+                 interp = cv2.INTER_LANCZOS4 if is_upscale else cv2.INTER_AREA
+                 frame = cv2.resize(frame, resize_to, interpolation=interp)
+                 if is_upscale:
+                     blurred = cv2.GaussianBlur(frame, (0, 0), sigmaX=1.0)
+                     frame = cv2.addWeighted(frame, 1.5, blurred, -0.5, 0)
+             frames.append(frame)
+
+         cap.release()
+
+         while len(frames) < self.clip_length:
+             if frames:
+                 frames.append(frames[-1])
+             else:
+                 fallback_size = resize_to if resize_to is not None else self.target_size
+                 frames.append(np.zeros((*fallback_size[::-1], 3), dtype=np.uint8))
+
+         if len(frames) > self.clip_length:
+             # Center-crop temporally: pick the middle clip_length frames
+             start = (len(frames) - self.clip_length) // 2
+             frames = frames[start:start + self.clip_length]
+
+         clip_array = np.stack(frames).astype(np.float32) / 255.0
+         clip_tensor = torch.from_numpy(clip_array)
+         clip_tensor = clip_tensor.permute(0, 3, 1, 2)
+
+         if apply_transform and self.transform:
+             clip_tensor = self.transform(clip_tensor)
+
+         return clip_tensor
+
+     def load_fullres_clip_by_index(self, actual_idx: int, apply_aoi: bool = False) -> torch.Tensor:
+         """Load clip at native resolution without augmentation.
+
+         By default skip AOI crop (localization needs the original frame).
+         """
+         clip_info = self.clips[actual_idx]
+         clip_id = clip_info["id"]
+         clip_path = self._resolve_clip_path(clip_id)
+         return self._load_clip(clip_path, target_size=self._NATIVE_RES, apply_transform=False, apply_aoi=apply_aoi)
+
+     def load_modelres_clip_by_index(self, actual_idx: int) -> torch.Tensor:
+         """Load clip at model input resolution without augmentation."""
+         clip_info = self.clips[actual_idx]
+         clip_id = clip_info["id"]
+         clip_path = self._resolve_clip_path(clip_id)
+         return self._load_clip(clip_path, target_size=self.target_size, apply_transform=False)
+
+     def _apply_crop_jitter(self, x: torch.Tensor) -> torch.Tensor:
+         """Randomly shift the crop to vary background context.
+
+         x: [T, C, H, W] in [0,1]. Returns the same shape with a random
+         translation applied (pixels that shift out are filled with edge values).
+         """
+         _, _, H, W = x.shape
+         max_dx = int(self.crop_jitter_strength * W)
+         max_dy = int(self.crop_jitter_strength * H)
+         if max_dx < 1 and max_dy < 1:
+             return x
+         dx = random.randint(-max_dx, max_dx)
+         dy = random.randint(-max_dy, max_dy)
+         if dx == 0 and dy == 0:
+             return x
+         # Use affine_grid + grid_sample for smooth sub-pixel shifting
+         theta = torch.tensor([
+             [1.0, 0.0, -2.0 * dx / W],
+             [0.0, 1.0, -2.0 * dy / H],
+         ], dtype=x.dtype).unsqueeze(0).expand(x.size(0), -1, -1)
+         grid = F.affine_grid(theta, x.shape, align_corners=False)
+         return F.grid_sample(x, grid, mode='bilinear', padding_mode='border', align_corners=False)
+
+     def _apply_spatial_label_augment(
+         self,
+         spatial_mask: torch.Tensor,
+         spatial_bbox: torch.Tensor,
+         spatial_bbox_valid: torch.Tensor,
+         aug_params: dict,
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         """Apply spatial label transforms using the same params as clip augmentation."""
+         hflip = bool(aug_params.get("hflip", aug_params.get("flip", False))) if aug_params else False
+         vflip = bool(aug_params.get("vflip", False)) if aug_params else False
+         if not hflip and not vflip:
+             return spatial_mask, spatial_bbox
+
+         out_mask = spatial_mask
+         out_bbox = spatial_bbox.clone()
+
+         # Flip mask on width/height axes when mask grid is enabled.
+         if self.grid_size > 0 and out_mask.numel() == self.grid_size * self.grid_size:
+             out_mask_grid = out_mask.reshape(self.grid_size, self.grid_size)
+             if hflip:
+                 out_mask_grid = out_mask_grid.flip(1)
+             if vflip:
+                 out_mask_grid = out_mask_grid.flip(0)
+             out_mask = out_mask_grid.reshape(-1)
+
+         # Flip per-frame bboxes in normalized xyxy coordinates.
+         # spatial_bbox is [T, 4], spatial_bbox_valid is [T]
+         if out_bbox.dim() == 2:
+             valid = spatial_bbox_valid > 0.5  # [T]
+             if valid.any():
+                 if hflip:
+                     new_x1 = (1.0 - out_bbox[:, 2]).clamp(0.0, 1.0)
+                     new_x2 = (1.0 - out_bbox[:, 0]).clamp(0.0, 1.0)
+                     out_bbox[:, 0] = new_x1
+                     out_bbox[:, 2] = new_x2
+                 if vflip:
+                     new_y1 = (1.0 - out_bbox[:, 3]).clamp(0.0, 1.0)
+                     new_y2 = (1.0 - out_bbox[:, 1]).clamp(0.0, 1.0)
+                     out_bbox[:, 1] = new_y1
+                     out_bbox[:, 3] = new_y2
+                 # Zero out invalid frames to avoid corrupted coords
+                 out_bbox[~valid] = 0.0
+         elif out_bbox.dim() == 1 and out_bbox.numel() == 4:
+             # Legacy fallback (single bbox)
+             if float(spatial_bbox_valid.sum().item()) > 0.5:
+                 x1, y1, x2, y2 = [float(v) for v in out_bbox]
+                 if hflip:
+                     x1, x2 = max(0.0, min(1.0, 1.0 - x2)), max(0.0, min(1.0, 1.0 - x1))
+                 if vflip:
+                     y1, y2 = max(0.0, min(1.0, 1.0 - y2)), max(0.0, min(1.0, 1.0 - y1))
+                 out_bbox = torch.tensor([x1, y1, x2, y2], dtype=out_bbox.dtype)
+
+         return out_mask, out_bbox
+
+     def _do_stitch(self, actual_idx: int, x_a: torch.Tensor):
+         """Splice clip A with a clip from a different class at a fixed 50/50 boundary.
+
+         Returns (x_stitched, fl_stitched, y_stitched, spatial_mask, bbox, bbox_valid, bg_mask).
+         The clip-level label is set to -1 because the mixed clip has no single ground
+         truth; the per-frame labels carry all supervision.
+         """
+         T = self.clip_length
+         label_a = self.labels[actual_idx]
+
+         # Don't stitch if the source clip is an excluded class (e.g. "Other")
+         if label_a in self.stitch_exclude_classes:
+             return (x_a, self.frame_labels[actual_idx],
+                     self.labels[actual_idx],
+                     self.spatial_masks[actual_idx].clone(),
+                     self.spatial_bboxes[actual_idx].clone(),
+                     self.spatial_bbox_valid[actual_idx].clone(),
+                     self.ovr_background_frame_mask[actual_idx].clone())
+
+         # Gather candidate indices from any class other than clip A's class,
+         # excluding F1-excluded classes (e.g. "Other") to avoid contaminating
+         # real behavior clips with catch-all content.
+         other_indices: list[int] = []
+         for cls_idx, idxs in self._label_to_clip_indices.items():
+             if cls_idx != label_a and cls_idx not in self.stitch_exclude_classes:
+                 other_indices.extend(idxs)
+         if not other_indices:
+             # Fall back to a different clip from the same class.
+             same = [i for i in self._label_to_clip_indices.get(label_a, []) if i != actual_idx]
+             if not same:
+                 # Only one clip in the whole dataset; skip stitching.
+                 return (x_a, self.frame_labels[actual_idx],
+                         self.labels[actual_idx],
+                         self.spatial_masks[actual_idx].clone(),
+                         self.spatial_bboxes[actual_idx].clone(),
+                         self.spatial_bbox_valid[actual_idx].clone(),
+                         self.ovr_background_frame_mask[actual_idx].clone())
+             other_indices = same
+
+         b_idx = random.choice(other_indices)
+         clip_b_info = self.clips[b_idx]
+         clip_b_path = self._resolve_clip_path(clip_b_info["id"])
+         x_b = self._load_clip(clip_b_path, apply_transform=False)
+
+         # 50/50 split: each half gets T//2 frames for maximum temporal context.
+         stitch_t = T // 2
+
+         x_stitched = torch.cat([x_a[:stitch_t], x_b[stitch_t:]], dim=0)
+
+         fl_a = self.frame_labels[actual_idx].clone()
+         fl_b = self.frame_labels[b_idx].clone()
+         fl_stitched = torch.cat([fl_a[:stitch_t], fl_b[stitch_t:]], dim=0)
+         bg_a = self.ovr_background_frame_mask[actual_idx].clone()
+         bg_b = self.ovr_background_frame_mask[b_idx].clone()
+         bg_stitched = torch.cat([bg_a[:stitch_t], bg_b[stitch_t:]], dim=0)
+
+         # Splice per-frame bboxes; use clip A's spatial grid mask (no better option
+         # for a synthetic composite clip).
+         bbox_a = self.spatial_bboxes[actual_idx].clone()
+         bbox_b = self.spatial_bboxes[b_idx].clone()
+         bbox_stitched = torch.cat([bbox_a[:stitch_t], bbox_b[stitch_t:]], dim=0)
+
+         bv_a = self.spatial_bbox_valid[actual_idx].clone()
+         bv_b = self.spatial_bbox_valid[b_idx].clone()
+         bv_stitched = torch.cat([bv_a[:stitch_t], bv_b[stitch_t:]], dim=0)
+
+         spatial_mask = self.spatial_masks[actual_idx].clone()
+         return x_stitched, fl_stitched, -1, spatial_mask, bbox_stitched, bv_stitched, bg_stitched
+
+     def __getitem__(self, idx):
+         # Map virtual index to actual clip index
+         actual_idx = idx % len(self.clips)
+         clip_info = self.clips[actual_idx]
+
+         # Embedding-space stitch: backbone tokens are pre-cached.
+         # Stitch happens on token tensors so VideoPrism always sees clean clips.
+         if getattr(self, '_emb_cache_mode', False):
+             cache_dir = getattr(self, '_emb_cache_dir', None)
+             num_versions = getattr(self, '_emb_num_versions', 1)
+             use_multi_scale = getattr(self, '_emb_multi_scale', False)
+
+             def _pick_version() -> int:
+                 return random.randint(0, num_versions - 1) if num_versions > 1 else 0
+
+             def _load_emb(clip_idx: int, v: int, short: bool = False) -> torch.Tensor:
+                 suffix = f"_{v}_s.pt" if short else f"_{v}.pt"
+                 path = os.path.join(cache_dir, f"{clip_idx}{suffix}")
+                 return torch.load(path, map_location='cpu', weights_only=True).float()
+
+             v_a = _pick_version()
+             emb_a = _load_emb(actual_idx, v_a) if cache_dir else torch.zeros(self.clip_length * 256, 768)
+             emb_a_s = _load_emb(actual_idx, v_a, short=True) if (cache_dir and use_multi_scale) else None
+
+             label_a = self.labels[actual_idx]
+             do_stitch = (
+                 self.stitch_prob > 0.0
+                 and label_a not in self.stitch_exclude_classes
+                 and torch.rand(1).item() < self.stitch_prob
+             )
+             other_indices = [
+                 i for cls, idxs in self._label_to_clip_indices.items()
+                 if cls != label_a and cls not in self.stitch_exclude_classes
+                 for i in idxs
+             ]
+
+             if do_stitch and other_indices:
+                 b_idx = random.choice(other_indices)
+                 v_b = _pick_version()
+                 emb_b = _load_emb(b_idx, v_b)
+                 emb_b_s = _load_emb(b_idx, v_b, short=True) if use_multi_scale else None
+
+                 T = self._emb_clip_length
+                 S = emb_a.shape[0] // T
+                 stitch_t = T // 2
+
+                 x_out = torch.cat([emb_a[:stitch_t * S], emb_b[stitch_t * S:]], dim=0)
+
+                 if use_multi_scale and emb_a_s is not None and emb_b_s is not None:
+                     T_s = T // 2
+                     S_s = emb_a_s.shape[0] // T_s
+                     stitch_t_s = T_s // 2
+                     x_short = torch.cat([emb_a_s[:stitch_t_s * S_s], emb_b_s[stitch_t_s * S_s:]], dim=0)
+                 else:
+                     x_short = torch.empty(0)
+
+                 fl_a = self.frame_labels[actual_idx].clone()
+                 fl_b = self.frame_labels[b_idx].clone()
+                 fl = torch.cat([fl_a[:stitch_t], fl_b[stitch_t:]], dim=0)
+                 bg_a = self.ovr_background_frame_mask[actual_idx].clone()
+                 bg_b = self.ovr_background_frame_mask[b_idx].clone()
+                 bg_mask = torch.cat([bg_a[:stitch_t], bg_b[stitch_t:]], dim=0)
+                 bbox_a = self.spatial_bboxes[actual_idx].clone()
+                 bbox_b = self.spatial_bboxes[b_idx].clone()
+                 spatial_bbox = torch.cat([bbox_a[:stitch_t], bbox_b[stitch_t:]], dim=0)
+                 bv_a = self.spatial_bbox_valid[actual_idx].clone()
+                 bv_b = self.spatial_bbox_valid[b_idx].clone()
+                 spatial_bbox_valid = torch.cat([bv_a[:stitch_t], bv_b[stitch_t:]], dim=0)
+                 spatial_mask = self.spatial_masks[actual_idx].clone()
+                 y = -1
+             else:
+                 x_out = emb_a
+                 x_short = emb_a_s if (use_multi_scale and emb_a_s is not None) else torch.empty(0)
+                 y = self.labels[actual_idx]
+                 fl = self.frame_labels[actual_idx].clone()
+                 bg_mask = self.ovr_background_frame_mask[actual_idx].clone()
+                 spatial_mask = self.spatial_masks[actual_idx].clone()
+                 spatial_bbox = self.spatial_bboxes[actual_idx].clone()
+                 spatial_bbox_valid = self.spatial_bbox_valid[actual_idx]
+
+             return x_out, y, spatial_mask, spatial_bbox, spatial_bbox_valid, actual_idx, fl, x_short, bg_mask
+
+         # When ROI cache is active, load the precomputed crop from disk instead
+         # of decoding the original video. This lets DataLoader workers operate
+         # normally (parallel prefetch) so classification training runs at the
+         # same speed as standard (no-localization) training.
+         if getattr(self, '_roi_cache_mode', False):
+             cache_dir = getattr(self, '_roi_cache_dir', None)
+             x_full = None
+             if cache_dir:
+                 pt_path = os.path.join(cache_dir, f"{actual_idx}.pt")
+                 cached = torch.load(pt_path, map_location='cpu', weights_only=True)
+                 if isinstance(cached, dict) and torch.is_tensor(cached.get("roi")):
+                     x = cached["roi"].float()
+                     if torch.is_tensor(cached.get("full")):
+                         x_full = cached["full"].float()
+                 elif torch.is_tensor(cached):
+                     x = cached.float()
+                 else:
+                     x = torch.zeros(self.clip_length, 3, 1, 1, dtype=torch.float32)
+             else:
+                 x = torch.zeros(self.clip_length, 3, 1, 1, dtype=torch.float32)
+             if x_full is None:
+                 x_full = x.clone()
+
+             if self.crop_jitter and self.crop_jitter_strength > 0:
+                 x = self._apply_crop_jitter(x)
+
+             spatial_mask = self.spatial_masks[actual_idx].clone()
+             spatial_bbox = self.spatial_bboxes[actual_idx].clone()
+             spatial_bbox_valid = self.spatial_bbox_valid[actual_idx]
+
+             if self.transform:
+                 if hasattr(self.transform, "augment_with_params"):
+                     x, aug_params = self.transform.augment_with_params(x)
+                     spatial_mask, spatial_bbox = self._apply_spatial_label_augment(
+                         spatial_mask, spatial_bbox, spatial_bbox_valid, aug_params
+                     )
+                 else:
+                     x = self.transform(x)
+
+             y = self.labels[actual_idx]
+             fl = self.frame_labels[actual_idx]
+             bg_mask = self.ovr_background_frame_mask[actual_idx].clone()
+             return x, y, spatial_mask, spatial_bbox, spatial_bbox_valid, actual_idx, fl, torch.empty(0), bg_mask
+
+         clip_id = clip_info["id"]
+         clip_path = self._resolve_clip_path(clip_id)
+
+         x = self._load_clip(clip_path, apply_transform=False)
+
+         # Clip-stitching augmentation: splice two clips from different classes so
+         # the model learns per-frame features independent of clip-level context.
+         if self.stitch_prob > 0.0 and torch.rand(1).item() < self.stitch_prob:
+             x, fl, y, spatial_mask, spatial_bbox, spatial_bbox_valid, bg_mask = self._do_stitch(actual_idx, x)
+         else:
+             spatial_mask = self.spatial_masks[actual_idx].clone()
+             spatial_bbox = self.spatial_bboxes[actual_idx].clone()
+             spatial_bbox_valid = self.spatial_bbox_valid[actual_idx]
+             y = self.labels[actual_idx]
+             fl = self.frame_labels[actual_idx]
+             bg_mask = self.ovr_background_frame_mask[actual_idx].clone()
+
+         if self.transform:
+             if hasattr(self.transform, "augment_with_params"):
+                 x, aug_params = self.transform.augment_with_params(x)
+                 spatial_mask, spatial_bbox = self._apply_spatial_label_augment(
+                     spatial_mask, spatial_bbox, spatial_bbox_valid, aug_params
+                 )
+             else:
+                 x = self.transform(x)
+
+         return x, y, spatial_mask, spatial_bbox, spatial_bbox_valid, actual_idx, fl, torch.empty(0), bg_mask
+
+
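+ # Illustrative construction (editorial sketch, not part of the packaged
+ # module). Clip dicts come from the project's annotation store; the keys
+ # shown here are the ones __init__ above actually reads. annotation_manager
+ # is stored but not consumed in this excerpt, so None suffices for a demo,
+ # and the paths below are hypothetical.
+ #
+ #     clips = [
+ #         {"id": "session1/rearing_0001.mp4", "label": "rearing"},
+ #         {"id": "session1/grooming_0002.mp4", "label": "grooming",
+ #          "spatial_bbox": [0.2, 0.3, 0.6, 0.8]},
+ #     ]
+ #     ds = BehaviorDataset(clips, annotation_manager=None,
+ #                          classes=["rearing", "grooming"],
+ #                          clip_base_dir="/data/clips", clip_length=16)
+ #     x, y, *rest = ds[0]   # x: [16, 3, 288, 288] float tensor in [0, 1]
+
+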
+ def compute_class_weights(labels: list, num_classes: int) -> torch.Tensor:
+     """Compute class weights based on inverse frequency."""
+     if not labels:
+         return torch.ones(num_classes)
+
+     counter = Counter(labels)
+     total = len(labels)
+
+     weights = torch.ones(num_classes)
+     for class_idx, count in counter.items():
+         # Skip ignore labels (-1) and anything outside the class range.
+         if count > 0 and 0 <= class_idx < num_classes:
+             weights[class_idx] = total / (num_classes * count)
+
+     weights = weights / weights.sum() * num_classes
+     return weights
+
+
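+ # Worked example (editorial sketch): with labels [0, 0, 0, 1] and 2 classes,
+ # raw inverse-frequency weights are 4/(2*3) ≈ 0.667 for class 0 and
+ # 4/(2*1) = 2.0 for class 1; after renormalizing to sum to num_classes they
+ # become 0.5 and 1.5.
+ #
+ #     w = compute_class_weights([0, 0, 0, 1], num_classes=2)
+ #     # tensor([0.5000, 1.5000])
+
+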
+ class BalancedBatchSampler:
+     """
+     Batch sampler that ensures at least `min_samples_per_class` samples per selected class inside each batch.
+
+     Strategy:
+     - Pick K = min(num_eligible_classes, batch_size // min_samples_per_class) classes per batch
+     - Draw `min_samples_per_class` samples for each selected class (with cycling/shuffle per class)
+     - Fill any remaining slots from the selected classes
+
+     Notes:
+     - Classes with fewer than `min_samples_per_class` examples are ignored for balancing
+     - When no eligible classes exist, the sampler yields nothing and the caller should fall back
+     """
+     def __init__(
+         self,
+         labels: list[int],
+         batch_size: int,
+         min_samples_per_class: int = 2,
+         drop_last: bool = False,
+         seed: Optional[int] = None,
+         excluded_classes: Optional[list[int]] = None,
+         virtual_size_multiplier: int = 1,
+         background_indices: Optional[list[int]] = None,
+         background_per_batch: int = 0,
+     ):
+         self.labels = list(labels)
+         self.batch_size = int(batch_size)
+         self.min_samples_per_class = int(min_samples_per_class)
+         self.drop_last = bool(drop_last)
+         self.excluded_classes = set(excluded_classes or [])
+         self.virtual_size_multiplier = max(1, int(virtual_size_multiplier))
+         self.background_indices = list(background_indices or [])
+         self.background_per_batch = max(0, int(background_per_batch))
+         self._effective_length = len(self.labels) * self.virtual_size_multiplier
+
+         # Use the global random state if seed is None (respects random.seed());
+         # otherwise use a private generator.
+         if seed is None:
+             self.rng = random
+         else:
+             self.rng = random.Random(seed)
+
+         # Build index lists per class
+         self.class_to_indices: Dict[int, list[int]] = {}
+         for idx, y in enumerate(self.labels):
+             self.class_to_indices.setdefault(int(y), []).append(idx)
+
+         # Eligible classes must have at least `min_samples_per_class` items and NOT be excluded
+         self.eligible_classes = [
+             c for c, idxs in self.class_to_indices.items()
+             if len(idxs) >= self.min_samples_per_class and c not in self.excluded_classes
+         ]
+         self.enabled = len(self.eligible_classes) > 0 and self.batch_size > 0
+
+         # Identify dropped classes
+         self.dropped_classes = [c for c, idxs in self.class_to_indices.items() if len(idxs) < self.min_samples_per_class]
+
+         # Prepare per-class shuffled pools with cursors
+         self._pools: Dict[int, list[int]] = {}
+         self._cursors: Dict[int, int] = {}
+         for c in self.eligible_classes:
+             pool = self.class_to_indices[c][:]
+             self.rng.shuffle(pool)
+             self._pools[c] = pool
+             self._cursors[c] = 0
+         self._bg_pool = self.background_indices[:]
+         if self._bg_pool:
+             self.rng.shuffle(self._bg_pool)
+         self._bg_cursor = 0
+
+     def __len__(self) -> int:
+         if self._effective_length <= 0:
+             return 0
+         if self.drop_last:
+             return self._effective_length // max(1, self.batch_size)
+         return math.ceil(self._effective_length / max(1, self.batch_size))
+
+     def _draw_from_class(self, cls: int) -> int:
+         pool = self._pools[cls]
+         cur = self._cursors[cls]
+         if cur >= len(pool):
+             # Reshuffle and reset
+             pool = self.class_to_indices[cls][:]
+             self.rng.shuffle(pool)
+             self._pools[cls] = pool
+             cur = 0
+         idx = pool[cur]
+         self._cursors[cls] = cur + 1
+         return idx
+
+     def _draw_background(self) -> Optional[int]:
+         if not self._bg_pool:
+             return None
+         cur = self._bg_cursor
+         if cur >= len(self._bg_pool):
+             self._bg_pool = self.background_indices[:]
+             self.rng.shuffle(self._bg_pool)
+             cur = 0
+         idx = self._bg_pool[cur]
+         self._bg_cursor = cur + 1
+         return idx
+
+     def __iter__(self):
+         if not self.enabled:
+             # Yield nothing; the caller should fall back to a plain sampler
+             return
+         num_batches = len(self)
+         for _ in range(num_batches):
+             batch: list[int] = []
+             if self.batch_size <= 0:
+                 yield batch
+                 continue
+
+             # Number of classes to include this batch
+             k = max(1, min(len(self.eligible_classes), self.batch_size // self.min_samples_per_class))
+             if k <= len(self.eligible_classes):
+                 selected = self.rng.sample(self.eligible_classes, k)
+             else:
+                 # Not enough distinct classes; sample with replacement to fill
+                 selected = self.eligible_classes[:] + [self.rng.choice(self.eligible_classes) for _ in range(k - len(self.eligible_classes))]
+
+             # Ensure min_samples_per_class per selected class
+             for cls in selected:
+                 needed = min(self.min_samples_per_class, max(0, self.batch_size - len(batch)))
+                 for _ in range(needed):
+                     batch.append(self._draw_from_class(cls))
+                     if len(batch) >= self.batch_size:
+                         break
+                 if len(batch) >= self.batch_size:
+                     break
+
+             # Optional hybrid-OvR negatives: reserve a few slots for background clips
+             # that teach all target heads to stay low without becoming their own class.
+             bg_slots = min(
+                 self.background_per_batch,
+                 max(0, self.batch_size - len(batch)),
+             )
+             for _ in range(bg_slots):
+                 bg_idx = self._draw_background()
+                 if bg_idx is None:
+                     break
+                 batch.append(bg_idx)
+                 if len(batch) >= self.batch_size:
+                     break
+
+             # Fill remaining slots from the selected classes (round-robin)
+             si = 0
+             while len(batch) < self.batch_size:
+                 cls = selected[si % len(selected)]
+                 batch.append(self._draw_from_class(cls))
+                 si += 1
+
+             yield batch
+
+
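+ # Illustrative wiring (editorial sketch): the sampler yields lists of dataset
+ # indices, so it plugs into DataLoader via batch_sampler. With batch_size=6
+ # and min_samples_per_class=2, each batch covers up to 3 classes with at
+ # least 2 samples each. `ds` refers to the hypothetical BehaviorDataset
+ # sketched above.
+ #
+ #     labels = list(ds.labels)
+ #     sampler = BalancedBatchSampler(labels, batch_size=6, min_samples_per_class=2)
+ #     loader = DataLoader(ds, batch_sampler=sampler, num_workers=0)
+
+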
+ class ConfusionAwareSampler(BalancedBatchSampler):
+     """BalancedBatchSampler extended with per-sample confusion-based weights for OvR hard mining.
+
+     After each training epoch, call update_weights() with per-clip confusion scores.
+     Clips with high confusion scores (model fires wrong heads) are sampled more often
+     within their class pool. Weights are EMA-blended to avoid sudden jumps.
+
+     The confusion score for a clip with true class c is a blend of:
+     - strongest rival-head activation
+     - low activation of the true head
+     - rival-over-true margin violation
+     This keeps the score pair-aware while still producing a single scalar weight.
+     """
+
+     def __init__(self, *args, weight_temperature: float = 2.0, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.weight_temperature = max(0.1, float(weight_temperature))
+         n = len(self.labels)
+         self._confusion_scores = np.zeros(n, dtype=np.float32)
+         self._top_rival = np.full(n, -1, dtype=np.int32)
+         # Rebuild per-class weighted pools (initially uniform)
+         self._class_weights: Dict[int, tuple] = {}
+         self._rebuild_class_weights()
+
+     def update_weights(
+         self,
+         confusion_scores: np.ndarray,
+         top_rivals: Optional[np.ndarray] = None,
+         ema_alpha: float = 0.4,
+     ) -> None:
+         """Blend new confusion scores into running weights and rebuild class pools."""
+         if len(confusion_scores) != len(self._confusion_scores):
+             return
+         self._confusion_scores = (
+             ema_alpha * confusion_scores.astype(np.float32)
+             + (1.0 - ema_alpha) * self._confusion_scores
+         )
+         if top_rivals is not None and len(top_rivals) == len(self._top_rival):
+             valid = top_rivals.astype(np.int32) >= 0
+             self._top_rival[valid] = top_rivals.astype(np.int32)[valid]
+         self._rebuild_class_weights()
+
+     def _rebuild_class_weights(self) -> None:
+         """Compute normalised sampling probability per class from confusion scores."""
+         self._class_weights = {}
+         for c, indices in self.class_to_indices.items():
+             scores = self._confusion_scores[indices]
+             # Base weight 1.0 + confusion boost; apply temperature sharpening
+             w = (1.0 + scores) ** self.weight_temperature
+             total = w.sum()
+             if total <= 0:
+                 w = np.ones_like(w, dtype=np.float32)
+                 total = w.sum()
+             self._class_weights[c] = (np.asarray(indices), w / total)
+
+     def _draw_from_class(self, cls: int) -> int:
+         """Weighted sample from class pool; falls back to parent if weights unavailable."""
+         if cls in self._class_weights:
+             indices, probs = self._class_weights[cls]
+             return int(np.random.choice(indices, p=probs))
+         return super()._draw_from_class(cls)
+
+     def log_top_confused(self, class_names: list, dataset_clips: list, n: int = 5) -> list[str]:
+         """Return log lines describing the most-confused clip per class for transparency."""
+         lines = []
+         for ci, cname in enumerate(class_names):
+             indices = self.class_to_indices.get(ci, [])
+             if not indices:
+                 continue
+             scores = self._confusion_scores[indices]
+             top_local = int(np.argmax(scores))
+             top_global = indices[top_local]
+             top_score = float(scores[top_local])
+             if top_score < 0.05:
+                 continue
+             clip_id = dataset_clips[top_global % len(dataset_clips)].get("id", "?") if dataset_clips else str(top_global)
+             rival_idx = int(self._top_rival[top_global]) if 0 <= top_global < len(self._top_rival) else -1
+             rival_txt = ""
+             if 0 <= rival_idx < len(class_names):
+                 rival_txt = f", rival={class_names[rival_idx]}"
+             lines.append(
+                 f"  [{cname}] hardest: {os.path.basename(clip_id)} "
+                 f"(confusion={top_score:.3f}{rival_txt})"
+             )
+         return lines[:n]
+
+
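+ # Illustrative epoch hook (editorial sketch): after validation, feed per-clip
+ # confusion scores (any [0, 1] scalar per clip; higher = harder) back into the
+ # sampler so the next epoch oversamples confused clips within each class.
+ # The random scores below are stand-ins for real model-derived values.
+ #
+ #     sampler = ConfusionAwareSampler(labels, batch_size=6, weight_temperature=2.0)
+ #     scores = np.random.rand(len(labels)).astype(np.float32)
+ #     sampler.update_weights(scores, ema_alpha=0.4)
+
+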
+ def _run_augmentation_ablation_eval(
+     model: nn.Module,
+     dataset,
+     config: Dict[str, Any],
+     device: torch.device,
+     log_fn: Optional[Callable] = None,
+ ):
+     """Post-training evaluation: measure per-augmentation impact on accuracy.
+
+     For each enabled augmentation, runs all clips through the model with ONLY that
+     augmentation applied (3 random trials averaged) and compares to the clean baseline.
+     Reports which augmentations help/hurt and the worst-affected clips per augmentation.
+     """
+     from .augmentations import ClipAugment
+
+     aug_opts = config.get("augmentation_options") or {}
+     use_ovr = config.get("use_ovr", False)
+
+     # Map of augmentation toggle names → ClipAugment kwargs that isolate that aug
+     aug_toggles = {
+         "horizontal_flip": "use_horizontal_flip",
+         "vertical_flip": "use_vertical_flip",
+         "color_jitter": "use_color_jitter",
+         "gaussian_blur": "use_gaussian_blur",
+         "random_noise": "use_random_noise",
+         "small_rotation": "use_small_rotation",
+         "speed_perturb": "use_speed_perturb",
+         "random_shapes": "use_random_shapes",
+         "grayscale": "use_grayscale",
+     }
+
+     # Only evaluate augmentations that were actually enabled during training
+     enabled = []
+     for name, kwarg in aug_toggles.items():
+         if aug_opts.get(kwarg, False):
+             enabled.append((name, kwarg))
+
+     if not enabled:
+         if log_fn:
+             log_fn("Augmentation ablation: no augmentations were enabled; skipping.")
+         return
+
+     if log_fn:
+         log_fn(f"\n{'='*60}")
+         log_fn("AUGMENTATION ABLATION EVALUATION")
+         log_fn(f"{'='*60}")
+         log_fn(f"Evaluating {len(enabled)} enabled augmentation(s)...")
+
+     n_clips = len(dataset.clips)
+     class_names = dataset.classes
+     n_classes = len(class_names)
+     n_trials = 3  # average over multiple random augmentation rolls
+
+     saved_transform = dataset.transform
+     saved_stitch = getattr(dataset, "stitch_prob", 0.0)
+     saved_mult = getattr(dataset, "virtual_size_multiplier", 1)
+     dataset.transform = None
+     dataset.stitch_prob = 0.0
+     dataset.virtual_size_multiplier = 1
+
+     model.eval()
+
+     def _eval_clips(transform_fn) -> np.ndarray:
+         """Run all clips, return per-clip frame accuracy array."""
+         dataset.transform = transform_fn
+         clip_accs = np.zeros(n_clips, dtype=np.float32)
+         loader = DataLoader(dataset, batch_size=config.get("batch_size", 8),
+                             shuffle=False, num_workers=0, pin_memory=False)
+         with torch.no_grad():
+             for batch_data in loader:
+                 if not isinstance(batch_data, (list, tuple)) or len(batch_data) < 7:
+                     continue
+                 clips_t = batch_data[0].to(device)
+                 frame_labels_t = batch_data[6].to(device) if batch_data[6] is not None else None
+                 indices_t = batch_data[5]
+
+                 _emb_mode = getattr(dataset, '_emb_cache_mode', False)
+                 _cl = config.get("clip_length", 8)
+                 if _emb_mode:
+                     clips_short_t = None
+                     if len(batch_data) >= 8:
+                         _cs_abl = batch_data[7]
+                         if isinstance(_cs_abl, torch.Tensor) and _cs_abl.numel() > 0:
+                             clips_short_t = _cs_abl.to(device)
+                     out = model(
+                         None, backbone_tokens=clips_t, num_frames=_cl,
+                         backbone_tokens_short=clips_short_t,
+                         num_frames_short=_cl // 2 if clips_short_t is not None else None,
+                         return_frame_logits=True,
+                     )
+                 else:
+                     out = model(clips_t, return_frame_logits=True)
+
+                 fo = getattr(model, '_frame_output', None)
+                 if fo is None:
+                     continue
+                 f_logits = fo[0]  # [B, T, C]
+
+                 if use_ovr:
+                     preds = torch.argmax(torch.sigmoid(f_logits), dim=-1)
+                 else:
+                     preds = torch.argmax(f_logits, dim=-1)
+
+                 B = preds.shape[0]
+                 for bi in range(B):
+                     idx = int(indices_t[bi].item()) % n_clips
+                     if frame_labels_t is not None:
+                         valid = frame_labels_t[bi] >= 0
+                         if valid.any():
+                             acc = float((preds[bi][valid] == frame_labels_t[bi][valid]).float().mean().item())
+                         else:
+                             acc = 1.0
+                     else:
+                         acc = 1.0
+                     clip_accs[idx] = acc
+         return clip_accs
+
+     # 1. Clean baseline (no augmentation)
+     if log_fn:
+         log_fn("Running clean baseline (no augmentation)...")
+     baseline_accs = _eval_clips(None)
+     baseline_mean = float(baseline_accs.mean()) * 100
+
+     # 2. Per-augmentation evaluation
+     results = {}
+     for aug_name, aug_kwarg in enabled:
+         if log_fn:
+             log_fn(f"Evaluating: {aug_name}...")
+         trial_accs = []
+         for trial in range(n_trials):
+             # Build a ClipAugment with ONLY this one augmentation on
+             kwargs = {k: False for _, k in aug_toggles.items()}
+             kwargs[aug_kwarg] = True
+             # Pass through augmentation-specific params
+             for pkey in ["color_jitter_brightness", "color_jitter_contrast",
+                          "color_jitter_saturation", "color_jitter_hue",
+                          "noise_std", "rotation_degrees"]:
+                 if pkey in aug_opts:
+                     kwargs[pkey] = aug_opts[pkey]
+             aug_fn = ClipAugment(**kwargs)
+             trial_accs.append(_eval_clips(aug_fn))
+         avg_accs = np.mean(trial_accs, axis=0)
+         delta = avg_accs - baseline_accs  # per-clip change
+         results[aug_name] = {
+             "mean_acc": float(avg_accs.mean()) * 100,
+             "delta_mean": float(delta.mean()) * 100,
+             "n_hurt": int((delta < -0.05).sum()),
+             "n_helped": int((delta > 0.05).sum()),
+             "per_clip_delta": delta,
+         }
+
+     # Restore dataset state
+     dataset.transform = saved_transform
+     dataset.stitch_prob = saved_stitch
+     dataset.virtual_size_multiplier = saved_mult
+
+     # 3. Log results
+     if log_fn:
+         log_fn(f"\nClean baseline accuracy: {baseline_mean:.1f}%\n")
+         log_fn(f"{'Augmentation':<20} {'Acc':>7} {'Δ':>7} {'Hurt':>6} {'Helped':>8}")
+         log_fn("-" * 52)
+         for aug_name, r in sorted(results.items(), key=lambda x: x[1]["delta_mean"]):
+             log_fn(
+                 f"{aug_name:<20} {r['mean_acc']:>6.1f}% {r['delta_mean']:>+6.1f}% "
+                 f"{r['n_hurt']:>5} {r['n_helped']:>7}"
+             )
+
+         # Worst clips per augmentation
+         log_fn("\nWorst-affected clips per augmentation:")
+         for aug_name, r in results.items():
+             deltas = r["per_clip_delta"]
+             worst_idx = int(np.argmin(deltas))
+             worst_delta = float(deltas[worst_idx]) * 100
+             if worst_delta > -1.0:
+                 continue
+             clip_id = dataset.clips[worst_idx].get("id", "?")
+             clip_label_idx = dataset.labels[worst_idx]
+             clip_cls = class_names[clip_label_idx] if 0 <= clip_label_idx < n_classes else "?"
+             log_fn(
+                 f"  {aug_name}: {os.path.basename(clip_id)} [{clip_cls}] "
+                 f"(clean={baseline_accs[worst_idx]*100:.0f}% → aug={float(r['per_clip_delta'][worst_idx] + baseline_accs[worst_idx])*100:.0f}%, "
+                 f"Δ={worst_delta:+.1f}%)"
+             )
+
+         log_fn(f"\n{'='*60}")
+
+
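+ # Illustrative post-training call (editorial sketch): the helper mutates and
+ # restores dataset.transform itself, so it can be invoked on the raw training
+ # dataset after the final epoch. The config key used as a guard here is
+ # hypothetical, not an attested flag of this package.
+ #
+ #     if config.get("run_augmentation_ablation", False):
+ #         _run_augmentation_ablation_eval(model, train_dataset, config,
+ #                                         device, log_fn=log_fn)
+
+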
1275
+ def train_model(
+ model: nn.Module,
+ train_dataset: Dataset,
+ val_dataset: Optional[Dataset],
+ config: Dict[str, Any],
+ log_fn: Optional[Callable[[str], None]] = None,
+ progress_callback: Optional[Callable[[int, int], None]] = None,
+ stop_callback: Optional[Callable[[], bool]] = None,
+ metrics_callback: Optional[Callable[[Dict[str, Any]], None]] = None
+ ):
+ """Training loop for behavior classifier.
+
+ Args:
+ model: BehaviorClassifier model
+ train_dataset: Training dataset
+ val_dataset: Optional validation dataset
+ config: Training configuration dict
+ log_fn: Optional callback for logging messages
+ progress_callback: Optional callback(epoch, total_epochs) for progress
+ stop_callback: Optional callback returning True if training should stop
+ metrics_callback: Optional callback(metrics_dict) called after each epoch
+ """
+ import traceback
+
+ try:
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ if log_fn:
+ log_fn(f"Using device: {device}")
+ if device.type == "cuda":
+ log_fn(f"CUDA device: {torch.cuda.get_device_name(0)}")
+ log_fn(f"CUDA available: {torch.cuda.is_available()}")
+ log_fn(f"CUDA device count: {torch.cuda.device_count()}")
+ log_fn(f"Current CUDA device: {torch.cuda.current_device()}")
+ log_fn(f"CUDA memory allocated: {torch.cuda.memory_allocated(0) / 1024**2:.2f} MB")
+
+ # Move model to device (head will be on GPU, backbone stays on CPU for JAX)
+ model.to(device)
+
+ # Load pretrained weights if specified (fine-tuning)
+ pretrained_path = config.get("pretrained_path")
+ if pretrained_path and os.path.exists(pretrained_path):
+ if log_fn:
+ log_fn(f"Loading pretrained weights from {pretrained_path}...")
+ try:
+ payload = torch.load(pretrained_path, map_location=device)
+ if isinstance(payload, dict):
+ frame_head_state_dict = payload.get("frame_head_state_dict", {})
+ localization_state_dict = payload.get("localization_state_dict", {})
+ else:
+ frame_head_state_dict = {}
+ localization_state_dict = {}
+
+ if (
+ getattr(model, "frame_head", None) is not None
+ and isinstance(frame_head_state_dict, dict)
+ and frame_head_state_dict
+ ):
+ frame_model_state = model.frame_head.state_dict()
+ filtered_frame = {}
+ mismatched_frame = []
+ for k, v in frame_head_state_dict.items():
+ if k in frame_model_state:
+ if v.shape == frame_model_state[k].shape:
+ filtered_frame[k] = v
+ else:
+ mismatched_frame.append(k)
+ model.frame_head.load_state_dict(filtered_frame, strict=False)
+ if log_fn:
+ log_fn(
+ f"Loaded frame head weights: {len(filtered_frame)} tensors"
+ + (f" (skipped mismatched: {mismatched_frame})" if mismatched_frame else "")
+ )
+
+ # Also restore localization head when available and compatible.
+ if (
+ getattr(model, "use_localization", False)
+ and getattr(model, "localization_head", None) is not None
+ and isinstance(localization_state_dict, dict)
+ and localization_state_dict
+ ):
+ loc_model_state = model.localization_head.state_dict()
+ filtered_loc = {}
+ mismatched_loc = []
+ for k, v in localization_state_dict.items():
+ if k in loc_model_state:
+ if v.shape == loc_model_state[k].shape:
+ filtered_loc[k] = v
+ else:
+ mismatched_loc.append(k)
+ model.localization_head.load_state_dict(filtered_loc, strict=False)
+ if log_fn:
+ log_fn(
+ f"Loaded localization head weights: {len(filtered_loc)} tensors"
+ + (f" (skipped mismatched: {mismatched_loc})" if mismatched_loc else "")
+ )
+ elif log_fn and getattr(model, "use_localization", False):
+ log_fn("No localization weights found in pretrained checkpoint; localization head will train from initialization.")
+
+ if log_fn:
+ log_fn("Pretrained weights loaded successfully (partial load if class count changed).")
+ except Exception as e:
+ if log_fn:
+ log_fn(f"WARNING: Failed to load pretrained weights: {e}")
+
+ # Ensure frame head is on GPU
+ if device.type == "cuda":
+ model.frame_head.to(device)
+ fh_device = next(model.frame_head.parameters()).device
+ if log_fn:
+ log_fn(f"Frame head device: {fh_device}")
+ if fh_device.type != "cuda":
+ log_fn("ERROR: Frame head is not on GPU! Training will be very slow.")
+ else:
+ log_fn(f"GPU memory after moving head: {torch.cuda.memory_allocated(0) / 1024**2:.2f} MB")
+
+ if log_fn:
+ log_fn(f"Model moved to {device}")
+ log_fn("Note: VideoPrism backbone runs on GPU (JAX) if available, classification heads run on GPU (PyTorch)")
+
+ if log_fn:
+ log_fn(f"Creating data loaders (batch_size={config['batch_size']})...")
+
+ use_weighted_sampler = config.get("use_weighted_sampler", False)
+ use_ovr = config.get("use_ovr", False)
+ _confusion_warmup_pct = float(config.get("confusion_sampler_warmup_pct", 0.2))
+ sampler = None
+ batch_sampler = None
+ shuffle = True
+
+ # OvR needs balanced batches so each binary head sees positives every batch.
+ if use_ovr and hasattr(train_dataset, 'labels'):
+ labels_for_balance = train_dataset.labels
+ ovr_min_samples = 1
+ _confusion_temperature = float(config.get("confusion_sampler_temperature", 2.0))
+ _use_confusion_sampler = bool(config.get("use_confusion_sampler", True)) and use_ovr
+ _hybrid_bg = bool(config.get("ovr_background_as_negative", False))
+ _bg_indices = []
+ _bg_per_batch = 0
+ if _hybrid_bg and hasattr(train_dataset, "ovr_background_clip"):
+ _bg_indices = [
+ i for i, is_bg in enumerate(train_dataset.ovr_background_clip)
+ if bool(is_bg)
+ ]
+ if _bg_indices:
+ _bg_per_batch = max(1, config["batch_size"] // 4)
+ _SamplerClass = ConfusionAwareSampler if _use_confusion_sampler else BalancedBatchSampler
+ _sampler_kwargs = dict(weight_temperature=_confusion_temperature) if _use_confusion_sampler else {}
+ bbs = _SamplerClass(
+ labels=labels_for_balance,
+ batch_size=config["batch_size"],
+ min_samples_per_class=ovr_min_samples,
+ drop_last=False,
+ excluded_classes=[-1],
+ virtual_size_multiplier=getattr(train_dataset, "virtual_size_multiplier", 1),
+ background_indices=_bg_indices,
+ background_per_batch=_bg_per_batch,
+ **_sampler_kwargs,
+ )
+ if bbs and getattr(bbs, "enabled", False):
+ batch_sampler = bbs
+ if log_fn:
+ reason = "OvR"  # this sampler block only runs when use_ovr is set
+ sampler_type = "ConfusionAwareSampler" if _use_confusion_sampler else "BalancedBatchSampler"
+ log_fn(f"Using {sampler_type} (>={ovr_min_samples} per class) for training ({reason} enabled)")
+ if _bg_indices:
+ log_fn(
+ f"Hybrid OvR background negatives: {len(_bg_indices)} clips "
+ f"(up to {_bg_per_batch} per batch)"
+ )
+ if _use_confusion_sampler:
+ warmup_ep = int(config["epochs"] * _confusion_warmup_pct)
+ log_fn(f"Confusion sampler warmup: {int(_confusion_warmup_pct * 100)}% ({warmup_ep} epochs) — uniform sampling until epoch {warmup_ep}")
+ if hasattr(bbs, "dropped_classes") and bbs.dropped_classes:
+ # Map indices back to names safely.
+ dropped_names = []
+ for idx in bbs.dropped_classes:
+ name = None
+ if hasattr(train_dataset, "classes") and 0 <= idx < len(train_dataset.classes):
+ name = train_dataset.classes[idx]
+ elif hasattr(train_dataset, "raw_classes") and 0 <= idx < len(train_dataset.raw_classes):
+ name = train_dataset.raw_classes[idx]
+ else:
+ name = str(idx)
+ dropped_names.append(name)
+ log_fn(
+ f"WARNING: The following classes have < {ovr_min_samples} samples and will be SKIPPED by BalancedBatchSampler: "
+ f"{dropped_names}"
+ )
+ else:
+ if log_fn:
+ log_fn("BalancedBatchSampler disabled (insufficient class counts); falling back to standard batching")
+
+ # If not using balanced batches, optionally use weighted sampler
+ if batch_sampler is None and use_weighted_sampler and hasattr(train_dataset, 'labels'):
+ if log_fn:
+ log_fn("Creating weighted random sampler...")
+
+ labels_for_sampling = list(train_dataset.labels)
+
+ virtual_mult = int(getattr(train_dataset, "virtual_size_multiplier", 1) or 1)
+ if virtual_mult > 1:
+ labels_for_sampling = labels_for_sampling * virtual_mult
+
+ class_counts = Counter(labels_for_sampling)
+ if log_fn:
+ log_fn(f"Class counts: {dict(class_counts)}")
+ num_samples = len(labels_for_sampling)
+ weights = [1.0 / class_counts[label] for label in labels_for_sampling]
+ sampler = WeightedRandomSampler(weights=weights, num_samples=num_samples, replacement=True)
+ shuffle = False
+ if log_fn:
+ log_fn("Using weighted random sampler for training")
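The weighting above is plain inverse class frequency; with replacement sampling this equalizes the expected per-class draw rate. A worked example with assumed toy labels:

```python
from collections import Counter

labels = [0, 0, 0, 0, 1]                      # 4 clips of class 0, 1 of class 1
counts = Counter(labels)                      # {0: 4, 1: 1}
weights = [1.0 / counts[l] for l in labels]   # [0.25, 0.25, 0.25, 0.25, 1.0]
# Total weight is 2.0, so class 1 is drawn with probability 1.0 / 2.0 = 0.5:
# each class contributes roughly half of every epoch despite the 4:1 imbalance.
```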
+
+ num_workers = config.get("num_workers", 4)
+
+ if batch_sampler is not None:
+ train_loader = DataLoader(
+ train_dataset,
+ batch_sampler=batch_sampler,
+ num_workers=num_workers,
+ pin_memory=True if device.type == "cuda" else False,
+ persistent_workers=True if num_workers > 0 else False
+ )
+ else:
+ train_loader = DataLoader(
+ train_dataset,
+ batch_size=config["batch_size"],
+ shuffle=shuffle,
+ sampler=sampler,
+ num_workers=num_workers,
+ pin_memory=True if device.type == "cuda" else False,
+ persistent_workers=True if num_workers > 0 else False
+ )
+
+ if log_fn:
+ log_fn(f"Train loader created: {len(train_loader)} batches (workers={num_workers})")
+
+ val_loader = None
+ if val_dataset:
+ val_loader = DataLoader(
+ val_dataset,
+ batch_size=config["batch_size"],
+ shuffle=False,
+ num_workers=num_workers,
+ pin_memory=True if device.type == "cuda" else False,
+ persistent_workers=True if num_workers > 0 else False
+ )
+ if log_fn:
+ log_fn(f"Val loader created: {len(val_loader)} batches")
+
+ if log_fn:
+ log_fn("Creating optimizer and loss function...")
+
+ base_lr = config["lr"]
+ localization_lr = float(config.get("localization_lr", base_lr))
+ classification_lr = float(config.get("classification_lr", base_lr))
+ wd = config.get("weight_decay", 0.001)
+ param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad}
+ raw_model = model.module if hasattr(model, "module") else model
+ use_separate_lr = (
+ getattr(raw_model, "use_localization", False)
+ and getattr(raw_model, "localization_head", None) is not None
+ )
+ if use_separate_lr:
+ loc_params = [p for n, p in param_dict.items() if "localization_head" in n]
+ other_params = [p for n, p in param_dict.items() if "localization_head" not in n]
+ loc_decay = [p for p in loc_params if p.dim() >= 2]
+ loc_nodecay = [p for p in loc_params if p.dim() < 2]
+ cls_decay = [p for p in other_params if p.dim() >= 2]
+ cls_nodecay = [p for p in other_params if p.dim() < 2]
+ optim_groups = []
+ if loc_decay:
+ optim_groups.append({"params": loc_decay, "weight_decay": wd, "lr": localization_lr})
+ if loc_nodecay:
+ optim_groups.append({"params": loc_nodecay, "weight_decay": 0.0, "lr": localization_lr})
+ if cls_decay:
+ optim_groups.append({"params": cls_decay, "weight_decay": wd, "lr": classification_lr})
+ if cls_nodecay:
+ optim_groups.append({"params": cls_nodecay, "weight_decay": 0.0, "lr": classification_lr})
+ else:
+ decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
+ nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
+ optim_groups = [
+ {"params": decay_params, "weight_decay": wd, "lr": classification_lr},
+ {"params": nodecay_params, "weight_decay": 0.0, "lr": classification_lr},
+ ]
+ optimizer = AdamW(optim_groups, lr=classification_lr)
+
+ # --- Scheduler: Cosine annealing (single decay) + optional linear warmup ---
+ use_scheduler = bool(config.get('use_scheduler', True))
+ total_epochs = config['epochs']
+ warmup_epochs = 0
+ scheduler = None
+ warmup_scheduler = None
+ restart_period = None
+ eta_min = None
+ if use_scheduler:
+ eta_min = 0.2 * classification_lr
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+ optimizer, T_max=max(1, total_epochs - warmup_epochs), eta_min=eta_min
+ )
+ warmup_scheduler = None
+ if warmup_epochs > 0:
+ warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
+ optimizer, start_factor=0.01, end_factor=1.0,
+ total_iters=warmup_epochs,
+ )
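With no warmup (warmup_epochs is fixed to 0 above) and eta_min pinned at 0.2 × the classification LR, the schedule is a single cosine arc. A sketch of CosineAnnealingLR's closed form under an assumed base LR of 1e-3 over 20 epochs:

```python
import math

def cosine_lr(epoch, base_lr=1e-3, total_epochs=20, eta_min_frac=0.2):
    """Closed form of the single-decay cosine schedule configured above."""
    eta_min = eta_min_frac * base_lr
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / total_epochs))

print(cosine_lr(0))    # 0.001   (starts at the base LR)
print(cosine_lr(10))   # 0.0006  (midpoint of the arc)
print(cosine_lr(20))   # 0.0002  (ends at eta_min = 0.2 * base LR)
```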
+
+ # --- EMA (Exponential Moving Average) for classification head only ---
+ use_ema = bool(config.get("use_ema", True))
+ ema_decay = 0.99
+ ema_state: dict[str, torch.Tensor] = {}
+ ema_active = False # enabled later when ema_start_epoch is reached
+ def _init_ema():
+ raw = model.module if hasattr(model, "module") else model
+ ema_state.clear()
+ for name, param in raw.named_parameters():
+ if param.requires_grad and not name.startswith("localization_head."):
+ ema_state[name] = param.data.clone()
+ def _update_ema():
+ if not ema_active:
+ return
+ raw = model.module if hasattr(model, "module") else model
+ for name, param in raw.named_parameters():
+ if name in ema_state:
+ ema_state[name].mul_(ema_decay).add_(param.data, alpha=1.0 - ema_decay)
+ def _apply_ema():
+ """Swap model weights with EMA weights. Call again to restore."""
+ if not ema_active:
+ return
+ raw = model.module if hasattr(model, "module") else model
+ for name, param in raw.named_parameters():
+ if name in ema_state:
+ param.data, ema_state[name] = ema_state[name].clone(), param.data.clone()
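The update in `_update_ema` is the standard exponential moving average, θ_ema ← d·θ_ema + (1−d)·θ, and `_apply_ema` swaps rather than copies, so calling it twice restores the live weights. A toy trace of the same update on a plain tensor (sketch only, not package code):

```python
import torch

decay = 0.99
theta = torch.tensor([1.0])   # stand-in for a live parameter
ema = theta.clone()
for _ in range(100):
    theta += 0.1                                     # stand-in optimizer step
    ema.mul_(decay).add_(theta, alpha=1.0 - decay)   # same op as _update_ema
print(theta.item(), ema.item())  # ema trails theta, smoothing recent steps
```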
+
+ if log_fn:
+ log_fn("Using AdamW with separate weight decay for biases/norm")
+ if use_separate_lr:
+ log_fn(f"Localization LR: {localization_lr:.2e}, Classification LR: {classification_lr:.2e}")
+ else:
+ log_fn(f"Learning rate: {classification_lr:.2e}")
+ log_fn("Gradient clipping: max_norm=1.0 (NaN-guarded)")
+ if use_scheduler:
+ if warmup_epochs > 0:
+ log_fn(f"LR schedule: {warmup_epochs}-epoch linear warmup → CosineAnnealingLR (single decay, {total_epochs - warmup_epochs} epochs)")
+ else:
+ log_fn(f"LR schedule: CosineAnnealingLR (single cosine decay, {total_epochs} epochs, no warmup)")
+ else:
+ log_fn("LR scheduler disabled (fixed learning rate)")
+ if use_ema:
+ if use_separate_lr or getattr(raw_model, "use_localization", False):
+ log_fn(f"EMA: classification head only (decay={ema_decay}); localization head unchanged")
+ else:
+ log_fn(f"EMA model averaging: enabled (decay={ema_decay})")
+ else:
+ log_fn("EMA model averaging: disabled")
+
+ use_class_weights = config.get("use_class_weights", False)
+
+ # Focal Loss settings
+ use_focal_loss = config.get("use_focal_loss", False)
+ focal_gamma = config.get("focal_gamma", 2.0)
+ use_supcon_loss = bool(config.get("use_supcon_loss", False))
+ supcon_weight = float(config.get("supcon_weight", 0.2)) if use_supcon_loss else 0.0
+ supcon_temperature = float(config.get("supcon_temperature", 0.1))
+
+ # Asymmetric Loss settings
+ use_asl = config.get("use_asl", True) # Asymmetric Loss on by default for OvR
+ asl_gamma_neg = float(config.get("asl_gamma_neg", 4.0))
+ asl_gamma_pos = float(config.get("asl_gamma_pos", 0.0))
+ asl_clip = float(config.get("asl_clip", 0.05))
+
+ # ASL's 'clip' parameter already applies negative smoothing, so default
+ # to 0.0 when ASL is active to avoid contradictory loss signals.
+ default_smoothing = 0.0 if (use_ovr and use_asl) else (0.05 if use_ovr else 0.0)
+ ovr_label_smoothing = float(config.get("ovr_label_smoothing", default_smoothing))
+ ovr_background_as_negative = bool(config.get("ovr_background_as_negative", False))
+ allowed_cooccurrence = []
+ cooccur_lookup: dict[int, set[int]] = {}
+ hard_pair_mining = bool(config.get("use_hard_pair_mining", False) and use_ovr)
+ hard_pair_margin = float(config.get("hard_pair_margin", 0.5))
+ hard_pair_loss_weight = float(config.get("hard_pair_loss_weight", 0.2)) if hard_pair_mining else 0.0
+ hard_pair_confusion_boost = max(1.0, float(config.get("hard_pair_confusion_boost", 1.5)))
+ hard_pair_index_pairs: list[tuple[int, int]] = []
+ hard_pair_name_pairs: list[list[str]] = []
+
+ if use_ovr and getattr(model, "frame_head", None) is not None:
+ model.frame_head.use_ovr = True
+ if ovr_background_as_negative and log_fn:
+ bg_names = config.get("ovr_background_class_names", [])
+ bg_txt = ", ".join(bg_names) if bg_names else "background"
+ log_fn(f"OvR hybrid background negatives: {bg_txt} kept as all-zero targets")
+ if hard_pair_mining:
+ class_to_idx = {c: i for i, c in enumerate(getattr(train_dataset, "classes", []))}
+ seen_pairs = set()
+ skipped_pairs = []
+ for pair in (config.get("hard_pairs", []) or []):
+ if not isinstance(pair, (list, tuple)) or len(pair) != 2:
+ continue
+ a_name = str(pair[0]).strip()
+ b_name = str(pair[1]).strip()
+ if not a_name or not b_name or a_name == b_name:
+ continue
+ a_idx = class_to_idx.get(a_name, -1)
+ b_idx = class_to_idx.get(b_name, -1)
+ if a_idx < 0 or b_idx < 0:
+ skipped_pairs.append((a_name, b_name))
+ continue
+ key = (min(a_idx, b_idx), max(a_idx, b_idx))
+ if key in seen_pairs:
+ continue
+ seen_pairs.add(key)
+ hard_pair_index_pairs.append(key)
+ hard_pair_name_pairs.append([
+ train_dataset.classes[key[0]],
+ train_dataset.classes[key[1]],
+ ])
+ if skipped_pairs and log_fn:
+ skipped_txt = ", ".join(f"{a}<->{b}" for a, b in skipped_pairs[:8])
+ if len(skipped_pairs) > 8:
+ skipped_txt += ", ..."
+ log_fn(f"Hard-pair mining: skipped unknown pair(s): {skipped_txt}")
+ if hard_pair_index_pairs:
+ if log_fn:
+ pair_txt = ", ".join(f"{a}<->{b}" for a, b in hard_pair_name_pairs)
+ log_fn(
+ "Hard-pair mining: "
+ f"{pair_txt} | margin={hard_pair_margin:.2f}, "
+ f"loss_weight={hard_pair_loss_weight:.3f}, "
+ f"confusion_boost={hard_pair_confusion_boost:.2f}x"
+ )
+ else:
+ hard_pair_mining = False
+ hard_pair_loss_weight = 0.0
+ if log_fn:
+ log_fn("Hard-pair mining enabled, but no valid configured pairs matched the active classes.")
+
+ # Frame-level loss is always the sole objective
+ use_frame_loss = True
+ if log_fn:
+ log_fn("Frame-level classification: enabled (sole classification loss)")
+
+ # Boundary and smoothness loss settings
+ use_temporal_decoder = bool(config.get("use_temporal_decoder", True))
+ boundary_loss_weight = float(config.get("boundary_loss_weight", 0.3)) if use_temporal_decoder else 0.0
+ smoothness_loss_weight = float(config.get("smoothness_loss_weight", 0.05))
+ boundary_tolerance = int(config.get("boundary_tolerance", 1))
+ if log_fn:
+ if use_temporal_decoder:
+ log_fn(f"Boundary loss: weight={boundary_loss_weight}, tolerance={boundary_tolerance}")
+ else:
+ log_fn("Temporal decoder: disabled (direct per-frame classifier after spatial pooling)")
+ log_fn("Boundary loss: disabled (no boundary branch)")
+ log_fn(f"Smoothness loss: weight={smoothness_loss_weight}")
+ if use_supcon_loss and supcon_weight > 0:
+ log_fn(
+ f"SupCon on MAP embeddings: enabled "
+ f"(weight={supcon_weight:.3f}, temperature={supcon_temperature:.3f})"
+ )
+ else:
+ log_fn("SupCon on MAP embeddings: disabled")
+
+ use_frame_bout_balance = bool(config.get("use_frame_bout_balance", True))
+ frame_bout_balance_power = float(config.get("frame_bout_balance_power", 1.0))
+ if use_frame_bout_balance and log_fn:
+ log_fn(
+ f"Frame bout balancing: enabled (power={frame_bout_balance_power:.2f}) "
+ f"[reduces dominance of long behavior bouts]"
+ )
+
+ def _generate_boundary_labels(frame_labels: torch.Tensor, tolerance: int = 1) -> torch.Tensor:
+ """Generate boundary labels from frame labels.
+
+ boundary[t] = 1 if there is a label transition within ±tolerance frames.
+ Frames with label=-1 get boundary label=-1 (ignored).
+ """
+ B, T = frame_labels.shape
+ boundaries = torch.zeros(B, T, dtype=torch.float32, device=frame_labels.device)
+ for bi in range(B):
+ for t in range(1, T):
+ cur = int(frame_labels[bi, t].item())
+ prev = int(frame_labels[bi, t - 1].item())
+ if cur < 0 or prev < 0:
+ continue
+ if cur != prev:
+ lo = max(0, t - tolerance)
+ hi = min(T, t + tolerance + 1)
+ boundaries[bi, lo:hi] = 1.0
+ # Mark invalid frames as -1
+ boundaries[frame_labels < 0] = -1.0
+ return boundaries
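A quick trace of the rule: with tolerance=1, every frame within one step of a label change is flagged. `_generate_boundary_labels` is nested inside `train_model`, so the call below is illustrative only:

```python
import torch

labels = torch.tensor([[0, 0, 1, 1, 2, 2]])
# Transitions at t=2 (0->1) and t=4 (1->2); tolerance=1 flags t-1, t, t+1:
# _generate_boundary_labels(labels, tolerance=1)
# -> tensor([[0., 1., 1., 1., 1., 1.]])
```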
+
+ def _pool_frame_labels(frame_labels: Optional[torch.Tensor], pool: int) -> Optional[torch.Tensor]:
+ """Downsample [B, T] labels to pooled timeline by majority vote (ignore=-1)."""
+ if frame_labels is None or pool <= 1:
+ return frame_labels
+ if frame_labels.dim() != 2:
+ return frame_labels
+ B, T = frame_labels.shape
+ pad = (pool - (T % pool)) % pool
+ lbl = frame_labels
+ if pad > 0:
+ pad_vals = lbl[:, -1:].repeat(1, pad)
+ lbl = torch.cat([lbl, pad_vals], dim=1)
+ Tp = lbl.shape[1] // pool
+ chunks = lbl.view(B, Tp, pool)
+ pooled = torch.full((B, Tp), -1, dtype=lbl.dtype, device=lbl.device)
+ for bi in range(B):
+ for ti in range(Tp):
+ vals = chunks[bi, ti]
+ valid = vals[vals >= 0]
+ if valid.numel() == 0:
+ continue
+ uniq, counts = torch.unique(valid, return_counts=True)
+ pooled[bi, ti] = uniq[torch.argmax(counts)]
+ return pooled
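A standalone sketch of the same majority vote, simplified by assuming T divides evenly by pool (the helper above additionally pads with the last label):

```python
import torch

def majority_pool(labels: torch.Tensor, pool: int) -> torch.Tensor:
    """Majority-vote downsampling with -1 treated as 'ignore' (sketch)."""
    B, T = labels.shape
    Tp = T // pool
    out = torch.full((B, Tp), -1, dtype=labels.dtype)
    for bi in range(B):
        for ti in range(Tp):
            vals = labels[bi, ti * pool:(ti + 1) * pool]
            valid = vals[vals >= 0]
            if valid.numel():
                uniq, counts = torch.unique(valid, return_counts=True)
                out[bi, ti] = uniq[torch.argmax(counts)]
    return out

print(majority_pool(torch.tensor([[0, 0, 1, 1, 1, -1, -1, -1, -1]]), 3))
# tensor([[ 0,  1, -1]])
```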
+
+ def _pool_binary_mask(mask: Optional[torch.Tensor], pool: int) -> Optional[torch.Tensor]:
+ """Downsample a boolean [B, T] mask by any() over each pooled window."""
+ if mask is None or pool <= 1:
+ return mask
+ if mask.dim() != 2:
+ return mask
+ B, T = mask.shape
+ pad = (pool - (T % pool)) % pool
+ m = mask
+ if pad > 0:
+ pad_vals = m[:, -1:].repeat(1, pad)
+ m = torch.cat([m, pad_vals], dim=1)
+ Tp = m.shape[1] // pool
+ return m.view(B, Tp, pool).any(dim=2)
+
+ def _pool_frame_embeddings(frame_embeddings: Optional[torch.Tensor], pool: int) -> Optional[torch.Tensor]:
+ """Average [B, T, D] embeddings over pooled windows to match pooled labels."""
+ if frame_embeddings is None or pool <= 1:
+ return frame_embeddings
+ if frame_embeddings.dim() != 3:
+ return frame_embeddings
+ B, T, D = frame_embeddings.shape
+ pad = (pool - (T % pool)) % pool
+ emb = frame_embeddings
+ if pad > 0:
+ pad_vals = emb[:, -1:, :].repeat(1, pad, 1)
+ emb = torch.cat([emb, pad_vals], dim=1)
+ Tp = emb.shape[1] // pool
+ return emb.view(B, Tp, pool, D).mean(dim=2)
+
+ def _supervised_contrastive_loss(
+ frame_embeddings: Optional[torch.Tensor],
+ frame_labels: Optional[torch.Tensor],
+ temperature: float = 0.1,
+ max_samples: int = 512,
+ ) -> torch.Tensor:
+ """SupCon over pooled frame embeddings using valid frame labels only."""
+ if frame_embeddings is None or frame_labels is None:
+ if frame_embeddings is not None:
+ return frame_embeddings.sum() * 0.0
+ return torch.tensor(0.0, device=device)
+ if frame_embeddings.dim() != 3 or frame_labels.dim() != 2:
+ return frame_embeddings.sum() * 0.0
+ feats = frame_embeddings.reshape(-1, frame_embeddings.shape[-1])
+ labels_flat = frame_labels.reshape(-1)
+ valid = labels_flat >= 0
+ if valid.sum().item() < 2:
+ return feats.sum() * 0.0
+ feats = feats[valid]
+ labels_flat = labels_flat[valid]
+ if feats.shape[0] > max_samples:
+ perm = torch.randperm(feats.shape[0], device=feats.device)[:max_samples]
+ feats = feats[perm]
+ labels_flat = labels_flat[perm]
+ uniq_labels, counts = torch.unique(labels_flat, return_counts=True)
+ keep_labels = uniq_labels[counts >= 2]
+ if keep_labels.numel() == 0:
+ return feats.sum() * 0.0
+ keep_mask = torch.zeros_like(labels_flat, dtype=torch.bool)
+ for cls_id in keep_labels:
+ keep_mask |= labels_flat == cls_id
+ feats = feats[keep_mask]
+ labels_flat = labels_flat[keep_mask]
+ if feats.shape[0] < 2:
+ return feats.sum() * 0.0
+
+ feats = F.normalize(feats, p=2, dim=1)
+ logits = torch.matmul(feats, feats.T) / max(float(temperature), 1e-6)
+ logits = logits - logits.max(dim=1, keepdim=True).values.detach()
+
+ same_class = labels_flat.unsqueeze(0) == labels_flat.unsqueeze(1)
+ self_mask = torch.eye(feats.shape[0], device=feats.device, dtype=torch.bool)
+ pos_mask = same_class & (~self_mask)
+ valid_anchor = pos_mask.any(dim=1)
+ if not valid_anchor.any():
+ return feats.sum() * 0.0
+
+ exp_logits = torch.exp(logits) * (~self_mask)
+ log_prob = logits - torch.log(exp_logits.sum(dim=1, keepdim=True).clamp(min=1e-12))
+ mean_log_prob_pos = (pos_mask.float() * log_prob).sum(dim=1) / pos_mask.sum(dim=1).clamp(min=1)
+ return -mean_log_prob_pos[valid_anchor].mean()
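The code above matches the standard supervised contrastive objective (SupCon, Khosla et al. 2020) over L2-normalized pooled frame embeddings. In the code's notation, with A the set of anchors that have at least one same-label positive, P(i) the positives of anchor i, and τ the temperature:

```latex
\mathcal{L}_{\mathrm{SupCon}}
  = -\frac{1}{|A|} \sum_{i \in A} \frac{1}{|P(i)|} \sum_{p \in P(i)}
    \log \frac{\exp(z_i \cdot z_p / \tau)}
              {\sum_{a \neq i} \exp(z_i \cdot z_a / \tau)}
```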
+
+ def _lookup_ovr_suppress_for_batch(sample_indices, dataset, device) -> Optional[torch.Tensor]:
+ """Map batch sample indices back to per-clip OvR suppression targets."""
+ suppress_list = getattr(dataset, "ovr_suppress_idx", None)
+ if not suppress_list:
+ return None
+ if isinstance(sample_indices, torch.Tensor):
+ indices = sample_indices.detach().cpu().tolist()
+ else:
+ indices = list(sample_indices)
+ mapped = []
+ for idx in indices:
+ try:
+ clip_idx = int(idx)
+ except Exception as e:
+ logger.debug("Could not convert index to int: %s", e)
+ clip_idx = -1
+ if 0 <= clip_idx < len(suppress_list):
+ mapped.append(int(suppress_list[clip_idx]))
+ else:
+ mapped.append(-1)
+ return torch.tensor(mapped, dtype=torch.long, device=device)
+
+ def _build_bout_weights(frame_labels: torch.Tensor, power: float = 1.0) -> torch.Tensor:
+ """Per-frame weights inversely proportional to contiguous bout length."""
+ B, T = frame_labels.shape
+ weights = torch.zeros((B, T), dtype=torch.float32, device=frame_labels.device)
+ for bi in range(B):
+ t = 0
+ while t < T:
+ lbl = int(frame_labels[bi, t].item())
+ if lbl < 0:
+ t += 1
+ continue
+ t_end = t + 1
+ while t_end < T and int(frame_labels[bi, t_end].item()) == lbl:
+ t_end += 1
+ seg_len = max(1, t_end - t)
+ w = float(seg_len) ** (-float(power))
+ weights[bi, t:t_end] = w
+ t = t_end
+ valid = frame_labels >= 0
+ if valid.any():
+ mean_w = weights[valid].mean().clamp(min=1e-8)
+ weights = weights / mean_w
+ return weights
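Worked arithmetic for the weighting above with power=1: one clip containing a length-4 bout and a length-2 bout.

```python
seg_lens = [4, 4, 4, 4, 2, 2]            # bout length behind each frame
raw = [1.0 / n for n in seg_lens]        # [0.25]*4 + [0.5]*2
mean_w = sum(raw) / len(raw)             # 1/3
norm = [w / mean_w for w in raw]         # [0.75]*4 + [1.5]*2
# Frames in the short bout weigh 2x frames in the long bout, and the mean
# weight is renormalized to 1 so the overall loss scale is unchanged.
assert abs(sum(norm) / len(norm) - 1.0) < 1e-9
```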
+
+ def _hard_pair_margin_loss(
+ frame_logits: torch.Tensor,
+ frame_labels: torch.Tensor,
+ pair_indices: list[tuple[int, int]],
+ margin: float,
+ use_bout_balance: bool = False,
+ bout_power: float = 1.0,
+ ) -> torch.Tensor:
+ """Extra pairwise margin pressure for configured confusing class pairs."""
+ if not pair_indices or margin <= 0:
+ return frame_logits.sum() * 0.0
+ B, T, _ = frame_logits.shape
+ frame_w = torch.ones((B, T), dtype=torch.float32, device=frame_logits.device)
+ if use_bout_balance:
+ frame_w = _build_bout_weights(frame_labels, power=bout_power)
+ total_loss = frame_logits.new_tensor(0.0)
+ total_weight = frame_logits.new_tensor(0.0)
+ for a_idx, b_idx in pair_indices:
+ mask_a = frame_labels == a_idx
+ if mask_a.any():
+ loss_a = F.relu(margin - (frame_logits[..., a_idx] - frame_logits[..., b_idx]))
+ total_loss = total_loss + (loss_a[mask_a] * frame_w[mask_a]).sum()
+ total_weight = total_weight + frame_w[mask_a].sum()
+ mask_b = frame_labels == b_idx
+ if mask_b.any():
+ loss_b = F.relu(margin - (frame_logits[..., b_idx] - frame_logits[..., a_idx]))
+ total_loss = total_loss + (loss_b[mask_b] * frame_w[mask_b]).sum()
+ total_weight = total_weight + frame_w[mask_b].sum()
+ if float(total_weight.item()) <= 0:
+ return frame_logits.sum() * 0.0
+ return total_loss / total_weight.clamp(min=1e-6)
+
+ def _frame_loss_balanced(
+ frame_logits: torch.Tensor,
+ frame_labels: torch.Tensor,
+ use_ovr_local: bool = False,
+ ovr_targets: Optional[torch.Tensor] = None,
+ ovr_weight: Optional[torch.Tensor] = None,
+ use_bout_balance: bool = False,
+ bout_power: float = 1.0,
+ valid_mask_override: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """Frame classification loss with optional inverse-bout-length weighting."""
+ B, T, C = frame_logits.shape
+ valid = valid_mask_override if valid_mask_override is not None else (frame_labels >= 0)
+ if not valid.any():
+ return frame_logits.sum() * 0.0
+
+ frame_w = torch.ones((B, T), dtype=torch.float32, device=frame_logits.device)
+ if use_bout_balance:
+ base_w = _build_bout_weights(frame_labels, power=bout_power)
+ if valid_mask_override is not None:
+ frame_w = torch.where(
+ valid,
+ torch.where(frame_labels >= 0, base_w, torch.ones_like(base_w)),
+ torch.zeros_like(base_w),
+ )
+ else:
+ frame_w = base_w
+
+ if use_ovr_local and ovr_targets is not None:
+ if use_asl and isinstance(criterion, AsymmetricLoss):
+ per_elem = criterion(frame_logits, ovr_targets)
+ else:
+ per_elem = F.binary_cross_entropy_with_logits(
+ frame_logits, ovr_targets, reduction='none'
+ )
+ if use_focal_loss:
+ pt = torch.exp(-per_elem)
+ per_elem = ((1 - pt) ** focal_gamma) * per_elem
+
+ elem_w = valid.unsqueeze(-1).float() * frame_w.unsqueeze(-1)
+ if ovr_weight is not None:
+ elem_w = elem_w * ovr_weight
+ denom = elem_w.sum().clamp(min=1.0)
+ return (per_elem * elem_w).sum() / denom
+
+ logits_flat = frame_logits.reshape(B * T, C)
+ labels_flat = frame_labels.reshape(B * T)
+ valid_flat = labels_flat >= 0
+ if not valid_flat.any():
+ return frame_logits.sum() * 0.0
+ logits_valid = logits_flat[valid_flat]
+ labels_valid = labels_flat[valid_flat]
+ raw_ce = F.cross_entropy(logits_valid, labels_valid, reduction='none')
+ w_flat = frame_w.reshape(B * T)[valid_flat]
+ denom = w_flat.sum().clamp(min=1.0)
+ return (raw_ce * w_flat).sum() / denom
+
+ # Localization supervision settings (autonomous staged curriculum)
+ use_localization = bool(config.get("use_localization", False) and getattr(model, "use_localization", False))
+ has_any_localization = use_localization and any(
+ float(v.sum().item()) > 0.5 for v in getattr(train_dataset, "spatial_bbox_valid", [])
+ )
+ # Default localization cap follows total training epochs from GUI unless explicitly overridden.
+ loc_max_stage_epochs = int(config.get("localization_stage_max_epochs", config.get("epochs", 20)))
+ use_manual_loc_switch = bool(config.get("use_manual_localization_switch", False))
+ manual_loc_switch_epoch = int(config.get("manual_localization_switch_epoch", 20))
+ loc_gate_patience = int(config.get("localization_gate_patience", 2))
+ loc_gate_iou_threshold = float(config.get("localization_gate_iou", 0.55))
+ loc_gate_center_error = float(config.get("localization_gate_center_error", 0.15))
+ loc_gate_valid_rate = float(config.get("localization_gate_valid_rate", 0.9))
+ crop_mix_start_gt = float(config.get("classification_crop_gt_prob_start", 1.0))
+ crop_mix_end_gt = float(config.get("classification_crop_gt_prob_end", 0.0))
+ crop_padding = float(config.get("classification_crop_padding", 0.35))
+ crop_min_size = float(config.get("classification_crop_min_size_norm", 0.04))
+ enable_roi_cache = True # Always use precomputed crops for classification stage
+ center_heatmap_weight = float(config.get("center_heatmap_weight", 1.0))
+ center_heatmap_sigma = float(config.get("center_heatmap_sigma", 2.5))
+ direct_center_weight = float(config.get("direct_center_weight", 2.0))
+ # Use the largest bbox across all classes so every crop has the same
+ # extent. This prevents the classifier from using background size as a
+ # class cue and keeps training/inference consistent (inference doesn't
+ # know the class label).
+ global_fixed_wh = (0.2, 0.2)
+ if use_localization and has_any_localization:
+ all_w = []
+ all_h = []
+ max_n = min(
+ len(getattr(train_dataset, "spatial_bboxes", [])),
+ len(getattr(train_dataset, "spatial_bbox_valid", [])),
+ )
+ for i in range(max_n):
+ bboxes_i = train_dataset.spatial_bboxes[i] # [T, 4]
+ valid_i = train_dataset.spatial_bbox_valid[i] # [T]
+ for t in range(bboxes_i.size(0)):
+ if float(valid_i[t].item()) <= 0.5:
+ continue
+ x1, y1, x2, y2 = [float(v) for v in bboxes_i[t].tolist()]
+ w = max(1e-4, min(1.0, x2 - x1))
+ h = max(1e-4, min(1.0, y2 - y1))
+ all_w.append(w)
+ all_h.append(h)
+
+ if all_w and all_h:
+ gw = float(max(all_w))
+ gh = float(max(all_h))
+ global_fixed_wh = (max(1e-4, min(1.0, gw)), max(1e-4, min(1.0, gh)))
+
+ raw_model = model.module if hasattr(model, "module") else model
+ if getattr(raw_model, "use_localization", False) and getattr(raw_model, "localization_head", None) is not None:
+ raw_model.localization_head.set_fixed_box_size(global_fixed_wh[0], global_fixed_wh[1])
+
+ def _fixed_wh_for_labels(label_tensor: torch.Tensor, device: torch.device, dtype: torch.dtype) -> Optional[torch.Tensor]:
+ """Return the single global fixed box size for each sample in the batch."""
+ if label_tensor is None:
+ return None
+ if not torch.is_tensor(label_tensor):
+ return None
+ if label_tensor.dim() == 0:
+ label_tensor = label_tensor.view(1)
+ B = label_tensor.size(0)
+ wh = [list(global_fixed_wh)] * B
+ return torch.tensor(wh, device=device, dtype=dtype)
+ if use_localization and log_fn:
+ if has_any_localization:
+ n_with_bbox = sum(1 for v in train_dataset.spatial_bbox_valid if float(v.sum().item()) > 0.5)
+ log_fn(
+ f"Localization Supervision: enabled ({n_with_bbox} clips with bbox)"
+ )
+ if enable_roi_cache:
+ log_fn("Classification ROI cache: enabled (precompute crops once, then augment in-memory).")
+ if center_heatmap_weight > 0:
+ log_fn(
+ f"Center Heatmap Loss (Gaussian Focal): enabled (weight={center_heatmap_weight}, "
+ f"sigma_patches={center_heatmap_sigma})"
+ )
+ if direct_center_weight > 0:
+ log_fn(
+ f"Direct Center Loss: enabled (weight={direct_center_weight})"
+ )
+ log_fn(
+ f"Localization fixed box size: max across all classes, "
+ f"w={global_fixed_wh[0]:.4f}, h={global_fixed_wh[1]:.4f}"
+ )
+ if use_manual_loc_switch:
+ log_fn(
+ f"Manual localization switch enabled: phase will switch at epoch {manual_loc_switch_epoch}"
+ )
+ else:
+ log_fn("Localization Supervision: enabled but no clips have bbox annotations — will be skipped")
+
+ def _split_localization_output(output):
+ if (
+ use_localization
+ and isinstance(output, tuple)
+ and len(output) >= 2
+ and torch.is_tensor(output[-1])
+ and output[-1].dim() in (2, 3)
+ and output[-1].size(-1) == 4
+ ):
+ head_out = output[:-1]
+ if len(head_out) == 1:
+ head_out = head_out[0]
+ return head_out, output[-1]
+ return output, None
+
+ def _bbox_iou(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+ x1 = torch.maximum(pred[:, 0], target[:, 0])
+ y1 = torch.maximum(pred[:, 1], target[:, 1])
+ x2 = torch.minimum(pred[:, 2], target[:, 2])
+ y2 = torch.minimum(pred[:, 3], target[:, 3])
+ inter = (x2 - x1).clamp(min=0) * (y2 - y1).clamp(min=0)
+ area_p = (pred[:, 2] - pred[:, 0]).clamp(min=0) * (pred[:, 3] - pred[:, 1]).clamp(min=0)
+ area_t = (target[:, 2] - target[:, 0]).clamp(min=0) * (target[:, 3] - target[:, 1]).clamp(min=0)
+ union = area_p + area_t - inter
+ return inter / (union + 1e-6)
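A quick numeric check of `_bbox_iou` (nested in `train_model`, so the call is illustrative): two unit-height boxes offset by half a width.

```python
import torch

pred = torch.tensor([[0.0, 0.0, 1.0, 1.0]])
tgt  = torch.tensor([[0.5, 0.0, 1.5, 1.0]])
# intersection = 0.5 * 1.0 = 0.5; union = 1.0 + 1.0 - 0.5 = 1.5
# _bbox_iou(pred, tgt) -> tensor([0.3333])
```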
+
+ def _localization_metrics(pred_bboxes: torch.Tensor, target_bboxes: torch.Tensor, valid_mask: torch.Tensor):
+ if pred_bboxes is None:
+ return 0.0, 1.0, 0.0
+ if pred_bboxes.dim() == 3:
+ pred_bboxes = pred_bboxes[:, 0, :]
+ valid = valid_mask > 0.5
+ if int(valid.sum().item()) == 0:
+ return 0.0, 1.0, 0.0
+ pred = pred_bboxes[valid]
+ tgt = target_bboxes[valid]
+ iou = _bbox_iou(pred, tgt).mean().item()
+ pred_cx = 0.5 * (pred[:, 0] + pred[:, 2])
+ pred_cy = 0.5 * (pred[:, 1] + pred[:, 3])
+ tgt_cx = 0.5 * (tgt[:, 0] + tgt[:, 2])
+ tgt_cy = 0.5 * (tgt[:, 1] + tgt[:, 3])
+ center_err = torch.sqrt((pred_cx - tgt_cx) ** 2 + (pred_cy - tgt_cy) ** 2).mean().item()
+ valid_pred_rate = float(
+ ((pred[:, 2] > pred[:, 0]) & (pred[:, 3] > pred[:, 1])).float().mean().item()
+ )
+ return float(iou), float(center_err), valid_pred_rate
+
+ def _sanitize_bboxes(bboxes: torch.Tensor) -> torch.Tensor:
+ """Pad bboxes proportionally to their own size, not the full frame.
+ crop_padding is now a fraction of the bbox dimension (e.g. 0.20 = 20% of bbox w/h).
+ """
+ boxes = bboxes.clone()
+ orig_shape = boxes.shape
+ boxes = boxes.view(-1, 4)
+ x1 = boxes[:, 0].clamp(0.0, 1.0)
+ y1 = boxes[:, 1].clamp(0.0, 1.0)
+ x2 = boxes[:, 2].clamp(0.0, 1.0)
+ y2 = boxes[:, 3].clamp(0.0, 1.0)
+ cx = 0.5 * (x1 + x2)
+ cy = 0.5 * (y1 + y2)
+ w = (x2 - x1).abs().clamp(min=crop_min_size)
+ h = (y2 - y1).abs().clamp(min=crop_min_size)
+ # Proportional padding: add crop_padding * bbox_size on each side
+ w = torch.clamp(w * (1.0 + 2.0 * crop_padding), max=1.0)
+ h = torch.clamp(h * (1.0 + 2.0 * crop_padding), max=1.0)
+ x1 = (cx - 0.5 * w).clamp(0.0, 1.0)
+ y1 = (cy - 0.5 * h).clamp(0.0, 1.0)
+ x2 = (cx + 0.5 * w).clamp(0.0, 1.0)
+ y2 = (cy + 0.5 * h).clamp(0.0, 1.0)
+ boxes[:, 0] = x1
+ boxes[:, 1] = y1
+ boxes[:, 2] = torch.maximum(x2, x1 + crop_min_size).clamp(0.0, 1.0)
+ boxes[:, 3] = torch.maximum(y2, y1 + crop_min_size).clamp(0.0, 1.0)
+ return boxes.view(orig_shape)
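Worked numbers for the proportional padding: each side grows by crop_padding × the box's own dimension (in normalized frame units), with the center preserved before clamping.

```python
w, crop_padding = 0.20, 0.35
padded = w * (1.0 + 2.0 * crop_padding)   # 0.34
# Each side gains 0.35 * 0.20 = 0.07 of the frame width; a box narrower than
# crop_min_size is first expanded to that floor before padding is applied.
print(padded)
```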
+
+ def _clamp_bboxes_no_expand(bboxes: torch.Tensor) -> torch.Tensor:
+ """Clamp/reorder bbox coordinates without padding or min-size expansion."""
+ boxes = bboxes.clone()
+ orig_shape = boxes.shape
+ boxes = boxes.view(-1, 4)
+ x1 = boxes[:, 0].clamp(0.0, 1.0)
+ y1 = boxes[:, 1].clamp(0.0, 1.0)
+ x2 = boxes[:, 2].clamp(0.0, 1.0)
+ y2 = boxes[:, 3].clamp(0.0, 1.0)
+ lo_x = torch.minimum(x1, x2)
+ hi_x = torch.maximum(x1, x2)
+ lo_y = torch.minimum(y1, y2)
+ hi_y = torch.maximum(y1, y2)
+ boxes[:, 0] = lo_x
+ boxes[:, 1] = lo_y
+ boxes[:, 2] = torch.maximum(hi_x, lo_x + 1e-4).clamp(0.0, 1.0)
+ boxes[:, 3] = torch.maximum(hi_y, lo_y + 1e-4).clamp(0.0, 1.0)
+ return boxes.view(orig_shape)
+
+ def _crop_clips_with_bboxes(clips: torch.Tensor, bboxes: torch.Tensor) -> torch.Tensor:
+ # clips: [B, T, C, H, W], bboxes: [B, 4] normalized.
+ B, T, C, H, W = clips.shape
+ out = torch.empty_like(clips)
+ boxes = _sanitize_bboxes(bboxes)
+ for i in range(B):
+ x1 = int(round(float(boxes[i, 0].item()) * (W - 1)))
+ y1 = int(round(float(boxes[i, 1].item()) * (H - 1)))
+ x2 = int(round(float(boxes[i, 2].item()) * W))
+ y2 = int(round(float(boxes[i, 3].item()) * H))
+ x1 = max(0, min(x1, W - 1))
+ y1 = max(0, min(y1, H - 1))
+ x2 = max(x1 + 1, min(x2, W))
+ y2 = max(y1 + 1, min(y2, H))
+ sample = clips[i] # [T, C, H, W]
+ cropped = sample[:, :, y1:y2, x1:x2]
+ if cropped.size(-1) < 2 or cropped.size(-2) < 2:
+ out[i] = sample
+ continue
+ out[i] = F.interpolate(cropped, size=(H, W), mode="bilinear", align_corners=False)
+ return out
+
+ def _crop_single_clip_to_target(clip_tchw: torch.Tensor, bbox_xyxy: torch.Tensor, out_h: int, out_w: int) -> torch.Tensor:
+ # clip_tchw: [T,C,H,W], bbox_xyxy: [4] or [T,4] normalized
+ T, C, H, W = clip_tchw.shape
+ box = _sanitize_bboxes(bbox_xyxy)
+ if box.dim() == 1:
+ box = box.view(1, 4).repeat(T, 1)
+ elif box.dim() == 2 and box.size(0) != T:
+ if box.size(0) < T:
+ box = torch.cat([box, box[-1:].repeat(T - box.size(0), 1)], dim=0)
+ else:
+ box = box[:T]
+
+ out_frames = []
+ for ti in range(T):
+ bt = box[ti]
+ x1 = int(round(float(bt[0].item()) * (W - 1)))
+ y1 = int(round(float(bt[1].item()) * (H - 1)))
+ x2 = int(round(float(bt[2].item()) * W))
+ y2 = int(round(float(bt[3].item()) * H))
+ x1 = max(0, min(x1, W - 1))
+ y1 = max(0, min(y1, H - 1))
+ x2 = max(x1 + 1, min(x2, W))
+ y2 = max(y1 + 1, min(y2, H))
+ frame = clip_tchw[ti : ti + 1] # [1,C,H,W]
+ cropped = frame[:, :, y1:y2, x1:x2]
+ if cropped.size(-2) < 2 or cropped.size(-1) < 2:
+ resized = F.interpolate(frame, size=(out_h, out_w), mode="bilinear", align_corners=False)
+ else:
+ resized = F.interpolate(cropped, size=(out_h, out_w), mode="bilinear", align_corners=False)
+ out_frames.append(resized[0])
+ return torch.stack(out_frames, dim=0)
+
+ def _lock_temporal_box_size_from_first(bboxes: torch.Tensor) -> torch.Tensor:
+ """For [B,T,4], use the frame-0 box for all frames (fully fixed crop)."""
+ if bboxes.dim() != 3 or bboxes.size(1) < 1:
+ return bboxes
+ first = bboxes[:, 0:1, :].clone() # [B,1,4]
+ return first.expand(-1, bboxes.size(1), -1).contiguous()
+
+ def _precompute_roi_cache(dataset_obj, split_name: str = "train") -> str:
+ """Precompute localized crops and save to disk as .pt files.
+
+ Returns the directory path containing the saved crops.
+ The dataset __getitem__ can then load these directly with workers.
+ """
+ cache_dir = os.path.join(output_dir_base, f"{basename}_roi_crops_{split_name}")
+ os.makedirs(cache_dir, exist_ok=True)
+ was_training = model.training
+ model.eval()
+ try:
+ total_items = len(dataset_obj.clips)
+ for ds_idx in range(total_items):
+ loc_clip_noaoi = dataset_obj.load_modelres_clip_by_index(ds_idx).float().unsqueeze(0).to(device)
+ loc_wh = torch.tensor([[float(global_fixed_wh[0]), float(global_fixed_wh[1])]], device=device, dtype=loc_clip_noaoi.dtype)
+ with torch.no_grad():
+ loc_out = model(loc_clip_noaoi, return_localization=True, localization_box_wh=loc_wh)
+ _, pred = _split_localization_output(loc_out)
+ if pred is None:
+ pred = torch.tensor([[0.0, 0.0, 1.0, 1.0]], device=device, dtype=loc_clip_noaoi.dtype)
+ pred = _sanitize_bboxes(pred)
+ if pred.dim() == 3:
+ pred = _lock_temporal_box_size_from_first(pred)
+ bbox_for_crop = pred[0].detach().cpu()  # [T,4] when pred is [1,T,4], else [4]
+
+ # Safety fallback: use GT when predicted geometry is invalid.
+ if pred.dim() == 3:
+ invalid = bool(((bbox_for_crop[:, 2] <= bbox_for_crop[:, 0]) | (bbox_for_crop[:, 3] <= bbox_for_crop[:, 1])).any().item())
+ else:
+ invalid = bool(((bbox_for_crop[2] <= bbox_for_crop[0]) or (bbox_for_crop[3] <= bbox_for_crop[1])))
+ ds_valid = dataset_obj.spatial_bbox_valid[ds_idx]
+ if invalid and float(ds_valid[0].item() if ds_valid.dim() > 0 else ds_valid.item()) > 0.5:
+ ds_bbox = dataset_obj.spatial_bboxes[ds_idx]
+ ds_bbox_f0 = ds_bbox[0] if ds_bbox.dim() == 2 else ds_bbox
+ gt_box = _sanitize_bboxes(ds_bbox_f0.view(1, 4))[0].detach().cpu()
+ if pred.dim() == 3:
+ bbox_for_crop = gt_box.unsqueeze(0).expand(pred.size(1), -1).contiguous()
+ else:
+ bbox_for_crop = gt_box
+
+ # Localization crops come from the original frame (no AOI)
+ raw_clip = dataset_obj.load_fullres_clip_by_index(ds_idx, apply_aoi=False).float()
+ out_h = int(loc_clip_noaoi.shape[-2])
+ out_w = int(loc_clip_noaoi.shape[-1])
+ cropped = _crop_single_clip_to_target(raw_clip, bbox_for_crop, out_h, out_w).detach().cpu()
+ torch.save(cropped, os.path.join(cache_dir, f"{ds_idx}.pt"))
+
+ if log_fn and ((ds_idx + 1) % 50 == 0 or (ds_idx + 1) == total_items):
+ log_fn(f"Saving cropped clips [{split_name}]: {ds_idx+1}/{total_items}")
+ finally:
+ if was_training:
+ model.train()
+ if log_fn:
+ log_fn(f"Saved {total_items} cropped clips to {cache_dir}")
+ return cache_dir
+
+ def _precompute_embedding_cache(
+ dataset_obj,
+ split_name: str = "train",
+ num_aug_versions: int = 1,
+ use_augmentation: bool = False,
+ multi_scale: bool = False,
+ ) -> str:
+ """Run each clip through the frozen backbone and save token tensors.
+
+ When num_aug_versions > 1 and use_augmentation=True, each clip is
+ passed through the backbone num_aug_versions times with different
+ random augmentations applied, giving the temporal head diverse
+ inputs each epoch without re-running the backbone at training time.
+ Version 0 is always unaugmented (clean reference).
+ Embeddings are stored as float16 to halve disk usage.
+
+ When multi_scale=True, each clip is also passed through the backbone
+ at half fps (every-other frame subsampled), and saved as {idx}_{v}_s.pt.
+ This gives the temporal head both fine-grained and broader temporal context.
+
+ Returns the cache directory path.
+ """
+ cache_dir = os.path.join(output_dir_base, f"{basename}_emb_cache_{split_name}")
+ os.makedirs(cache_dir, exist_ok=True)
+ was_training = model.training
+ model.eval()
+ raw_model = model.module if hasattr(model, "module") else model
+ total_items = len(dataset_obj.clips)
+ # Multi-scale doubles the number of backbone passes
+ passes_per_item = (2 if multi_scale else 1) * num_aug_versions
+ total_ops = total_items * passes_per_item
+ ops_done = 0
+ try:
+ for ds_idx in range(total_items):
+ clip_info = dataset_obj.clips[ds_idx]
+ clip_id = clip_info["id"]
+ clip_path = dataset_obj._resolve_clip_path(clip_id)
+
+ # If the dataset has no transform, augmentation is a no-op regardless
+ can_augment = use_augmentation and dataset_obj.transform is not None
+ for v in range(num_aug_versions):
+ # Version 0: always clean (no augmentation)
+ apply_aug = can_augment and v > 0
+ clip_t = dataset_obj._load_clip(
+ clip_path,
+ target_size=dataset_obj.target_size,
+ apply_transform=apply_aug,
+ ).float().unsqueeze(0).to(device) # [1, T, C, H, W]
+ with torch.no_grad():
+ tokens = raw_model.backbone(clip_t) # [1, T*S, D]
+ # Save as float16 — halves disk usage with negligible precision loss
+ torch.save(tokens[0].cpu().half(), os.path.join(cache_dir, f"{ds_idx}_{v}.pt"))
+ ops_done += 1
+ if log_fn and (ops_done % 50 == 0 or ops_done == total_ops):
+ log_fn(
+ f"Caching backbone embeddings [{split_name}]: "
+ f"{ops_done}/{total_ops} "
+ f"(clip {ds_idx+1}/{total_items}, aug v{v})"
+ )
+
+ if multi_scale:
+ # Short scale: subsample every-other frame (half fps, same duration)
+ clip_short = clip_t[:, ::2, :, :, :] # [1, T//2, C, H, W]
+ with torch.no_grad():
+ tokens_s = raw_model.backbone(clip_short) # [1, T_s*S, D]
+ torch.save(tokens_s[0].cpu().half(), os.path.join(cache_dir, f"{ds_idx}_{v}_s.pt"))
+ ops_done += 1
+ if log_fn and (ops_done % 50 == 0 or ops_done == total_ops):
+ log_fn(
+ f"Caching backbone embeddings [{split_name}]: "
+ f"{ops_done}/{total_ops} "
+ f"(clip {ds_idx+1}/{total_items}, aug v{v}, short-scale)"
+ )
+ finally:
+ if was_training:
+ model.train()
+ scale_note = " + short-scale" if multi_scale else ""
+ if log_fn:
+ log_fn(
+ f"Saved {total_ops} embedding tensors to {cache_dir} "
+ f"({num_aug_versions} version(s) per clip{scale_note}, float16)"
+ )
+ return cache_dir
+
+ if log_fn:
+ if use_focal_loss:
+ log_fn(f"Using Focal Loss (Active Learning): gamma={focal_gamma} (replaces CrossEntropy)")
+
+ # Determine class-only loss function
+ criterion = None
+ if use_class_weights and hasattr(train_dataset, 'labels') and not use_ovr:
+ if log_fn:
+ log_fn("Computing class weights for loss...")
+ class_weights = compute_class_weights(
+ [l for l in train_dataset.labels if l >= 0],
+ len(train_dataset.classes)
+ ).to(device)
+ if log_fn:
+ log_fn(f"Class weights: {class_weights.tolist()}")
+ if use_ovr:
+ if use_asl:
+ criterion = AsymmetricLoss(
+ gamma_neg=asl_gamma_neg,
+ gamma_pos=asl_gamma_pos,
+ clip=asl_clip,
+ reduction="none",
+ )
+ elif use_focal_loss:
+ criterion = BinaryFocalLoss(gamma=focal_gamma)
+ else:
+ criterion = None # handled inline with F.binary_cross_entropy_with_logits
+ n_hn = sum(1 for s in train_dataset.ovr_suppress_idx if s >= 0)
+ n_real = sum(1 for lbl in train_dataset.labels if lbl >= 0)
+
+ # Per-head pos_weight: upweight positives for minority classes so each
+ # binary head sees balanced effective pos/neg counts.
+ num_c = len(train_dataset.classes)
+ pos_counts = torch.zeros(num_c, dtype=torch.float32)
+
+ if hasattr(train_dataset, 'frame_labels') and getattr(train_dataset, 'frame_labels', None) is not None:
+ if log_fn:
+ log_fn("Computing OvR pos_weight using frame labels...")
+ total_frames = 0
+ for fl in train_dataset.frame_labels:
+ if fl is not None:
+ for lbl in fl:
+ if 0 <= lbl < num_c:
+ pos_counts[lbl] += 1.0
+ total_frames += 1
+ neg_counts = float(total_frames) - pos_counts
+ else:
+ if log_fn:
+ log_fn("Computing OvR pos_weight using clip labels...")
+ total_real = 0
+ for ml in train_dataset.multi_labels:
+ if ml:
+ total_real += 1
+ for lbl in ml:
+ if 0 <= lbl < num_c:
+ pos_counts[lbl] += 1.0
+ neg_counts = float(total_real) - pos_counts
+
+ ovr_pos_weight = torch.ones(num_c, device=device)
+ for ci in range(num_c):
+ if pos_counts[ci] > 0:
+ raw_ratio = neg_counts[ci].item() / pos_counts[ci].item()
+ # ASL already suppresses easy negatives via its focusing term,
+ # so the full neg/pos ratio double-corrects for imbalance.
+ # Use sqrt when ASL is active for a softer complementary weight.
+ if use_asl:
+ ovr_pos_weight[ci] = max(1.0, raw_ratio ** 0.5)
+ else:
+ ovr_pos_weight[ci] = max(1.0, raw_ratio)
+ ovr_pos_weight = ovr_pos_weight.clamp(max=50.0)
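Example of the complementary weighting above for a head with 10 positive and 250 negative frames (neg/pos = 25):

```python
raw_ratio = 250 / 10                 # 25.0
print(max(1.0, raw_ratio))           # 25.0 (plain BCE path)
print(max(1.0, raw_ratio ** 0.5))    # 5.0  (ASL path: sqrt-damped, since ASL
                                     #       already suppresses easy negatives)
# Either value is then clamped to at most 50.0.
```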
+
2468
+ # Helper/background classes (e.g. "Other"): set pos_weight to a fixed value
2469
+ # (default 1.0) so the head trains gently. Use ovr_pos_weight_f1_excluded=1.5
2470
+ # to slightly upweight them if desired.
2471
+ _f1_exclude_names_early = set(config.get("f1_exclude_classes", []))
2472
+ _ovr_pw_excluded = float(config.get("ovr_pos_weight_f1_excluded", 1.5))
2473
+ for ci in range(num_c):
2474
+ if train_dataset.classes[ci] in _f1_exclude_names_early:
2475
+ ovr_pos_weight[ci] = _ovr_pw_excluded
2476
+
2477
+ # Detect co-occurrence pairs from multi-label annotations
2478
+ cooccur_set = set()
2479
+ for ml in train_dataset.multi_labels:
2480
+ if len(ml) >= 2:
2481
+ for a in ml:
2482
+ for b in ml:
2483
+ if a < b:
2484
+ cooccur_set.add((a, b))
2485
+ allowed_cooccurrence = [[train_dataset.classes[a], train_dataset.classes[b]] for a, b in sorted(cooccur_set)]
2486
+ cooccur_lookup = {i: set() for i in range(num_c)}
2487
+ for a, b in cooccur_set:
2488
+ cooccur_lookup[a].add(b)
2489
+ cooccur_lookup[b].add(a)
2490
+
2491
+ if log_fn:
2492
+ if use_asl:
2493
+ loss_msg = f" + ASL(γ-={asl_gamma_neg}, γ+={asl_gamma_pos}, clip={asl_clip})"
2494
+ elif use_focal_loss:
2495
+ loss_msg = f" + BinaryFocal(gamma={focal_gamma})"
2496
+ else:
2497
+ loss_msg = ""
2498
+ log_fn(f"OvR mode: {num_c} heads, {n_real} real clips, "
2499
+ f"{n_hn} near-negative clips{loss_msg}")
2500
+ pw_str = ", ".join(f"{train_dataset.classes[i]}={ovr_pos_weight[i]:.1f}" for i in range(num_c))
2501
+ log_fn(f"OvR per-head pos_weight: {pw_str}")
2502
+ if ovr_label_smoothing > 0:
2503
+ log_fn(f"OvR label smoothing: {ovr_label_smoothing} (targets [{ovr_label_smoothing:.2f}, {1-ovr_label_smoothing:.2f}])")
2504
+ bs = config["batch_size"]
2505
+ n_distinct = num_c + (1 if n_hn > 0 else 0)
2506
+ if bs < num_c:
2507
+ log_fn(f"WARNING: batch_size ({bs}) < num_classes ({num_c}). "
2508
+ f"Some OvR heads will see no positives per batch. "
2509
+ f"Recommend batch_size >= {num_c * 2} for stable OvR training.")
2510
+ elif bs < num_c * 2:
2511
+ log_fn(f"Note: batch_size ({bs}) is small for {num_c}-class OvR. "
2512
+ f"Consider increasing to {num_c * 2}+ for better per-head gradient signal.")
2513
+ if allowed_cooccurrence:
2514
+ pairs_str = ", ".join(f"{a}+{b}" for a, b in allowed_cooccurrence)
2515
+ log_fn(f"OvR allowed co-occurrence pairs: {pairs_str}")
2516
+ elif use_class_weights:
2517
+ if use_focal_loss:
2518
+ criterion = FocalLoss(gamma=focal_gamma, alpha=class_weights)
2519
+ else:
2520
+ criterion = nn.CrossEntropyLoss(weight=class_weights)
2521
+ else:
2522
+ if use_focal_loss:
2523
+ criterion = FocalLoss(gamma=focal_gamma)
2524
+ else:
2525
+ criterion = nn.CrossEntropyLoss()
2526
+
2527
+ if log_fn and not use_ovr:
2528
+ if use_focal_loss:
2529
+ w_msg = "weighted" if use_class_weights else "unweighted"
2530
+ log_fn(f"Using {w_msg} FocalLoss(gamma={focal_gamma})")
2531
+ else:
2532
+ if use_class_weights:
2533
+ log_fn(f"Using Class-Weighted CrossEntropyLoss")
2534
+ else:
2535
+ log_fn(f"Using standard CrossEntropyLoss (no class weights)")
2536
+
2537
+ # Use the ACTUAL classes from the dataset for metadata to ensure correct mapping order
2538
+ class_names = train_dataset.classes
2539
+
2540
+ # Classes excluded from F1 metrics (e.g., "Other"/"Background" helper classes).
2541
+ # They still train normally but don't affect best-model selection.
2542
+ _f1_exclude_names = set(config.get("f1_exclude_classes", []))
2543
+ _f1_include_indices = [i for i, c in enumerate(class_names) if c not in _f1_exclude_names]
2544
+ _f1_exclude_indices = {i for i, c in enumerate(class_names) if c in _f1_exclude_names}
2545
+ train_dataset.stitch_exclude_classes = _f1_exclude_indices
2546
+ slug_counts = {}
2547
+ class_key_map = {}
2548
+ class_label_map = {}
2549
+ for idx, cls_name in enumerate(class_names):
2550
+ base_slug = _slugify_class_name(cls_name)
2551
+ slug = base_slug
2552
+ counter = 1
2553
+ while slug in slug_counts:
2554
+ slug = f"{base_slug}_{counter}"
2555
+ counter += 1
2556
+ slug_counts[slug] = True
2557
+ key = f"val_f1_{slug}"
2558
+ class_key_map[idx] = key
2559
+ class_label_map[idx] = cls_name
2560
+
2561
+ class_counts = Counter(train_dataset.labels)
2562
+
2563
+ class_counts_named = {
2564
+ train_dataset.classes[idx]: count
2565
+ for idx, count in class_counts.items()
2566
+ if 0 <= idx < len(train_dataset.classes)
2567
+ }
2568
+
2569
+ clip_length_value = config.get("clip_length")
2570
+ if clip_length_value is None and hasattr(train_dataset, "clip_length"):
2571
+ clip_length_value = train_dataset.clip_length
2572
+
2573
+ resolution_value = config.get("resolution")
2574
+ if resolution_value is None and hasattr(train_dataset, "target_size"):
2575
+ target_size = train_dataset.target_size
2576
+ if isinstance(target_size, (tuple, list)) and target_size:
2577
+ resolution_value = int(target_size[0])
2578
+ elif isinstance(target_size, int):
2579
+ resolution_value = int(target_size)
2580
+
2581
+ def _json_safe(value):
2582
+ if value is None or isinstance(value, (bool, int, float, str)):
2583
+ return value
2584
+ if isinstance(value, dict):
2585
+ return {str(k): _json_safe(v) for k, v in value.items()}
2586
+ if isinstance(value, (list, tuple, set)):
2587
+ return [_json_safe(v) for v in value]
2588
+ return str(value)
2589
+
2590
+ def _calibrate_ignore_thresholds_from_validation(
2591
+ score_chunks_by_class: dict[int, list[np.ndarray]],
2592
+ target_chunks_by_class: dict[int, list[np.ndarray]],
2593
+ class_names_local: list[str],
2594
+ ) -> Optional[dict]:
2595
+ """Pick per-class ignore thresholds from validation scores by F1."""
2596
+ per_class_thresholds = {}
2597
+ per_class_stats = {}
2598
+ weighted_thresholds = []
2599
+ weighted_supports = []
2600
+ for cls_idx, cls_name in enumerate(class_names_local):
2601
+ score_chunks = score_chunks_by_class.get(cls_idx, [])
2602
+ target_chunks = target_chunks_by_class.get(cls_idx, [])
2603
+ if not score_chunks or not target_chunks:
2604
+ continue
2605
+ scores = np.concatenate(score_chunks).astype(np.float32, copy=False)
2606
+ targets = np.concatenate(target_chunks).astype(np.uint8, copy=False)
2607
+ if scores.size == 0 or targets.size != scores.size:
2608
+ continue
2609
+ pos_support = int(targets.sum())
2610
+ neg_support = int(targets.size - pos_support)
2611
+ if pos_support < 3 or neg_support < 3:
2612
+ continue
2613
+
2614
+ base_grid = np.linspace(0.35, 0.90, 56, dtype=np.float32)
2615
+ quantiles = np.quantile(scores, np.linspace(0.05, 0.95, 19)).astype(np.float32)
2616
+ candidates = np.unique(np.clip(np.concatenate([base_grid, quantiles]), 0.35, 0.90))
2617
+
2618
+ best_tau = 0.60
2619
+ best_f1 = -1.0
2620
+ best_precision = -1.0
2621
+ best_recall = -1.0
2622
+ for tau in candidates:
2623
+ pred_pos = scores >= float(tau)
2624
+ tp = float(np.sum(pred_pos & (targets == 1)))
2625
+ fp = float(np.sum(pred_pos & (targets == 0)))
2626
+ fn = float(np.sum((~pred_pos) & (targets == 1)))
2627
+ precision = tp / max(1.0, tp + fp)
2628
+ recall = tp / max(1.0, tp + fn)
2629
+ f1 = (2.0 * precision * recall) / max(1e-8, precision + recall)
2630
+ if (
2631
+ (f1 > best_f1 + 1e-8)
2632
+ or (abs(f1 - best_f1) <= 1e-8 and precision > best_precision + 1e-8)
2633
+ or (
2634
+ abs(f1 - best_f1) <= 1e-8
2635
+ and abs(precision - best_precision) <= 1e-8
2636
+ and float(tau) > best_tau
2637
+ )
2638
+ ):
2639
+ best_tau = float(tau)
2640
+ best_f1 = float(f1)
2641
+ best_precision = float(precision)
2642
+ best_recall = float(recall)
2643
+
2644
+ per_class_thresholds[cls_name] = float(best_tau)
2645
+ per_class_stats[cls_name] = {
2646
+ "positive_support": pos_support,
2647
+ "negative_support": neg_support,
2648
+ "best_f1": float(best_f1),
2649
+ "best_precision": float(best_precision),
2650
+ "best_recall": float(best_recall),
2651
+ }
2652
+ weighted_thresholds.append(float(best_tau))
2653
+ weighted_supports.append(max(1, pos_support))
2654
+
2655
+ if not per_class_thresholds:
2656
+ return None
2657
+
2658
+ global_threshold = float(
2659
+ np.average(np.asarray(weighted_thresholds, dtype=np.float32), weights=np.asarray(weighted_supports, dtype=np.float32))
2660
+ )
2661
+ global_threshold = max(0.35, min(0.90, global_threshold))
2662
+ return {
2663
+ "source": "validation_f1_calibration",
2664
+ "global_threshold": global_threshold,
2665
+ "per_class_thresholds": per_class_thresholds,
2666
+ "per_class_stats": per_class_stats,
2667
+ }
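+ # Usage sketch with toy inputs (real runs pass thousands of validation frames):
+ # _calibrate_ignore_thresholds_from_validation(
+ #     {0: [np.array([0.9, 0.2, 0.7], dtype=np.float32)]},
+ #     {0: [np.array([1, 0, 1], dtype=np.uint8)]},
+ #     ["walk"],
+ # )
+ # returns None here because classes with fewer than 3 positives or 3 negatives
+ # are skipped; with enough support it returns per-class taus in [0.35, 0.90]
+ # plus a support-weighted global threshold.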
2668
+
2669
+ ovr_pos_weight_named = None
2670
+ if use_ovr:
2671
+ ovr_pos_weight_named = {
2672
+ train_dataset.classes[i]: float(ovr_pos_weight[i].item())
2673
+ for i in range(len(train_dataset.classes))
2674
+ }
2675
+ stitch_excluded_names = []
2676
+ if hasattr(train_dataset, "stitch_exclude_classes"):
2677
+ stitch_excluded_names = [
2678
+ class_names[i] for i in sorted(train_dataset.stitch_exclude_classes)
2679
+ if 0 <= int(i) < len(class_names)
2680
+ ]
2681
+ sampler_mode = "balanced_batch" if batch_sampler is not None else ("weighted_random" if sampler is not None else "shuffle")
2682
+
2683
+ head_metadata = {
2684
+ "classes": class_names,
2685
+ "num_classes": len(class_names),
2686
+ "clip_length": clip_length_value,
2687
+ "target_fps": int(config.get("target_fps", 16)),
2688
+ "resolution": resolution_value,
2689
+ "training_samples": class_counts_named,
2690
+ "backbone_model": config.get("backbone_model", "videoprism_public_v1_base"),
2691
+ "head": {
2692
+ "type": "DilatedTemporalHead" if use_temporal_decoder else "SpatialPoolLinearHead",
2693
+ "dropout": config.get("dropout", 0.1),
2694
+ "use_localization": use_localization,
2695
+ "localization_hidden_dim": int(config.get("localization_hidden_dim", 256)),
2696
+ "localization_dropout": float(config.get("localization_dropout", 0.0)),
2697
+ "use_temporal_decoder": use_temporal_decoder,
2698
+ "frame_head_temporal_layers": int(config.get("frame_head_temporal_layers", 1)),
2699
+ "temporal_pool_frames": int(config.get("temporal_pool_frames", 1)),
2700
+ "num_stages": int(config.get("num_stages", 3)),
2701
+ "proj_dim": int(config.get("proj_dim", 256)),
2702
+ },
2703
+ "training_config": {
2704
+ "batch_size": config["batch_size"],
2705
+ "epochs": config["epochs"],
2706
+ "lr": config["lr"],
2707
+ "weight_decay": config.get("weight_decay", 0.001),
2708
+ "use_scheduler": use_scheduler,
2709
+ "scheduler_name": "CosineAnnealingLR" if use_scheduler else None,
2710
+ "warmup_epochs": int(warmup_epochs),
2711
+ "t_max_epochs": int(total_epochs - warmup_epochs),
2712
+ "eta_min": float(eta_min) if eta_min is not None else None,
2713
+ "use_ovr": use_ovr,
2714
+ "ovr_background_as_negative": bool(config.get("ovr_background_as_negative", False)) if use_ovr else False,
2715
+ "ovr_background_class_names": config.get("ovr_background_class_names", []) if use_ovr else [],
2716
+ "allowed_cooccurrence": allowed_cooccurrence if use_ovr else [],
2717
+ "cooccurrence_loss_mode": "ignore_negative_pairs" if use_ovr else None,
2718
+ "ovr_pos_weight": ovr_pos_weight_named if use_ovr else None,
2719
+ "use_hard_pair_mining": hard_pair_mining if use_ovr else False,
2720
+ "hard_pair_mode": "pairwise_margin" if (use_ovr and hard_pair_mining) else None,
2721
+ "hard_pairs": hard_pair_name_pairs if (use_ovr and hard_pair_mining) else [],
2722
+ "hard_pair_margin": hard_pair_margin if (use_ovr and hard_pair_mining) else None,
2723
+ "hard_pair_loss_weight": hard_pair_loss_weight if (use_ovr and hard_pair_mining) else None,
2724
+ "hard_pair_confusion_boost": hard_pair_confusion_boost if (use_ovr and hard_pair_mining) else None,
2725
+ "use_class_weights": use_class_weights,
2726
+ "resolution": resolution_value,
2727
+ "use_weighted_sampler": config.get("use_weighted_sampler", False),
2728
+ "use_balanced_sampler": (batch_sampler is not None),
2729
+ "sampler_mode": sampler_mode,
2730
+ "use_augmentation": config.get("use_augmentation", False),
2731
+ "augmentation_options": config.get("augmentation_options") or {},
2732
+ "stitch_augmentation_prob": float(getattr(train_dataset, "stitch_prob", 0.0)),
2733
+ "emb_cache": bool(config.get("emb_cache", False)),
2734
+ "emb_aug_versions": int(config.get("emb_aug_versions", 1)),
2735
+ "stitch_excluded_classes": stitch_excluded_names,
2736
+ "use_focal_loss": use_focal_loss,
2737
+ "focal_gamma": focal_gamma if use_focal_loss else None,
2738
+ "use_supcon_loss": use_supcon_loss,
2739
+ "supcon_weight": supcon_weight if use_supcon_loss else None,
2740
+ "supcon_temperature": supcon_temperature if use_supcon_loss else None,
2741
+ "use_asl": use_asl if use_ovr else False,
2742
+ "asl_gamma_neg": asl_gamma_neg if (use_ovr and use_asl) else None,
2743
+ "asl_gamma_pos": asl_gamma_pos if (use_ovr and use_asl) else None,
2744
+ "asl_clip": asl_clip if (use_ovr and use_asl) else None,
2745
+ "use_ema": use_ema,
2746
+ "ema_decay": float(ema_decay) if use_ema else None,
2747
+ "ema_start_epoch": int(max(warmup_epochs, 3)) if use_ema else None,
2748
+ "frame_loss_weight": 1.0,
2749
+ "use_temporal_decoder": use_temporal_decoder,
2750
+ "use_frame_bout_balance": use_frame_bout_balance,
2751
+ "frame_bout_balance_power": frame_bout_balance_power if use_frame_bout_balance else None,
2752
+ "temporal_pool_frames": int(config.get("temporal_pool_frames", 1)),
2753
+ "frame_head_temporal_layers": int(config.get("frame_head_temporal_layers", 1)),
2754
+ "num_stages": int(config.get("num_stages", 3)),
2755
+ "boundary_loss_weight": boundary_loss_weight,
2756
+ "smoothness_loss_weight": smoothness_loss_weight,
2757
+ "boundary_tolerance": boundary_tolerance,
2758
+ "proj_dim": int(config.get("proj_dim", 256)),
2759
+ "use_localization": use_localization,
2760
+ "use_manual_localization_switch": use_manual_loc_switch if use_localization else None,
2761
+ "manual_localization_switch_epoch": manual_loc_switch_epoch if (use_localization and use_manual_loc_switch) else None,
2762
+ "localization_stage_max_epochs": loc_max_stage_epochs if use_localization else None,
2763
+ "localization_gate_patience": loc_gate_patience if use_localization else None,
2764
+ "localization_gate_iou": loc_gate_iou_threshold if use_localization else None,
2765
+ "localization_gate_center_error": loc_gate_center_error if use_localization else None,
2766
+ "localization_gate_valid_rate": loc_gate_valid_rate if use_localization else None,
2767
+ "classification_crop_gt_prob_start": crop_mix_start_gt if use_localization else None,
2768
+ "classification_crop_gt_prob_end": crop_mix_end_gt if use_localization else None,
2769
+ "classification_crop_padding": crop_padding if use_localization else None,
2770
+ "crop_jitter": config.get("crop_jitter", False),
2771
+ "crop_jitter_strength": config.get("crop_jitter_strength", 0.15),
2772
+ "classification_crop_min_size_norm": crop_min_size if use_localization else None,
2773
+ "center_heatmap_weight": center_heatmap_weight if use_localization else None,
2774
+ "center_heatmap_sigma": center_heatmap_sigma if use_localization else None,
2775
+ "direct_center_weight": direct_center_weight if use_localization else None,
2776
+ "localization_fixed_size_stat": "max" if use_localization else None,
2777
+ "localization_fixed_box_global_w": global_fixed_wh[0] if use_localization else None,
2778
+ "localization_fixed_box_global_h": global_fixed_wh[1] if use_localization else None,
2779
+ "val_split": config.get("val_split", 0.2),
2780
+ "use_all_for_training": config.get("use_all_for_training", False),
2781
+ "config_snapshot": _json_safe(config),
2782
+ }
2783
+ }
2784
+
2785
+ if log_fn:
2786
+ log_fn(f"Training with classes: {class_names}")
2787
+ log_fn(f"Training samples per class (primary label): {class_counts_named}")
2788
+ # Multi-class breakdown
2789
+ mc_count = sum(1 for ml in train_dataset.multi_labels if len(ml) > 1)
2790
+ if mc_count > 0:
2791
+ from collections import Counter as _Counter
2792
+ mc_combos = _Counter(
2793
+ tuple(sorted(train_dataset.classes[i] for i in ml))
2794
+ for ml in train_dataset.multi_labels if len(ml) > 1
2795
+ )
2796
+ log_fn(f"Multi-class clips in training set: {mc_count} of {len(train_dataset.labels)}")
2797
+ for combo, cnt in mc_combos.most_common():
2798
+ log_fn(f" {' + '.join(combo)}: {cnt}")
2799
+
2800
+ best_val_frame_acc = 0.0
2801
+ best_val_f1 = -1.0
2802
+
2803
+ if log_fn:
2804
+ log_fn(f"Starting training for {config['epochs']} epochs...")
2805
+
2806
+ history = {
2807
+ "epoch": [],
2808
+ "train_loss": [],
2809
+ "train_loss_class": [], # primary classification loss only
2810
+ "train_acc": [],
2811
+ "train_frame_acc": [],
2812
+ "val_loss": [],
2813
+ "val_acc": [],
2814
+ "val_frame_acc": [],
2815
+ "val_f1": [],
2816
+ "loc_val_iou": [],
2817
+ "loc_val_center_error": [],
2818
+ "loc_val_valid_rate": [],
2819
+ }
2820
+ for key in class_key_map.values():
2821
+ history[key] = []
2822
+
2823
+
2824
+ # Curriculum state: stage 1 localization -> stage 2 classification on crops.
2825
+ in_localization_stage = bool(use_localization and has_any_localization)
2826
+ localization_gate_streak = 0
2827
+ classification_stage_start_epoch = None
2828
+ output_dir_base = os.path.dirname(config["output_path"])
2829
+ basename = os.path.splitext(os.path.basename(config["output_path"]))[0]
2830
+ crop_progress_num_samples = int(config.get("crop_progress_num_samples", 5))
2831
+ crop_progress_dir = None
2832
+ train_roi_cache = None
2833
+ val_roi_cache = None
2834
+ train_emb_cache = None
2835
+ val_emb_cache_dir = None
2836
+
2837
+ # Save sample clips exactly as VideoPrism sees them (with augmentation)
2838
+ try:
2839
+ import matplotlib
2840
+ matplotlib.use('Agg')
2841
+ import matplotlib.pyplot as plt
2842
+ import matplotlib.patches as mpatches
2843
+ import cv2
2844
+
2845
+ samples_dir = os.path.join(output_dir_base, f"{basename}_input_samples")
2846
+ os.makedirs(samples_dir, exist_ok=True)
2847
+ loc_samples_dir = os.path.join(output_dir_base, f"{basename}_localized_input_samples")
2848
+ crop_progress_dir = os.path.join(output_dir_base, f"{basename}_crop_progress")
2849
+ if use_localization:
2850
+ os.makedirs(loc_samples_dir, exist_ok=True)
2851
+ os.makedirs(crop_progress_dir, exist_ok=True)
2852
+
2853
+ resolution = config.get("resolution", 288)
2854
+ grid_g = resolution // 18
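+ # e.g. the default 288-px resolution with 18-px patches gives a 16×16 overlay grid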
2855
+ sample_indices = np.random.choice(len(train_dataset), size=min(5, len(train_dataset)), replace=False)
2856
+
2857
+ for si, idx in enumerate(sample_indices):
2858
+ batch = train_dataset[idx]
2859
+ clip_tensor = batch[0] # [T, C, H, W]
2860
+ T = clip_tensor.shape[0]
2861
+ frames = clip_tensor.permute(0, 2, 3, 1).numpy() # [T, H, W, C]
2862
+ H, W = frames.shape[1], frames.shape[2]
2863
+
2864
+ label_idx = batch[1] if isinstance(batch[1], int) else batch[1].item()
2865
+ if 0 <= label_idx < len(class_names):
2866
+ label_name = class_names[label_idx]
2867
+ elif label_idx < 0:
2868
+ label_name = "mixed/stitched"
2869
+ else:
2870
+ label_name = f"Class {label_idx}"
2871
+
2872
+ actual_idx = idx % len(train_dataset.clips)
2873
+ clip_id = train_dataset.clips[actual_idx].get("id", "?")
2874
+ orig_label = train_dataset.clips[actual_idx].get("label", "?")
2875
+ # Count valid frame labels
2876
+ fl_tensor = train_dataset.frame_labels[actual_idx]
2877
+ num_labeled_frames = (fl_tensor >= 0).sum().item()
2878
+ total_frames = fl_tensor.numel()
2879
+
2880
+ if log_fn:
2881
+ log_fn(f" Sample {si+1}: idx={idx} actual={actual_idx} clip_id={clip_id} "
2882
+ f"dataset_label={label_name} orig_label={orig_label} "
2883
+ f"({num_labeled_frames}/{total_frames} frames labeled)")
2884
+
2885
+ num_show = min(T, 10)
2886
+ frame_indices_show = np.linspace(0, T - 1, num_show, dtype=int)
2887
+
2888
+ # Scale figure to true pixel size: each frame at 1:1 pixels
2889
+ dpi = 150
2890
+ fig_w = (W * num_show + 40) / dpi # 40px padding between frames
2891
+ fig_h = (H + 80) / dpi # 80px for two-line title
2892
+ fig, axes = plt.subplots(1, num_show, figsize=(fig_w, fig_h))
2893
+ if num_show == 1:
2894
+ axes = [axes]
2895
+
2896
+ for j, fi in enumerate(frame_indices_show):
2897
+ frame_rgb = (np.clip(frames[fi], 0, 1) * 255).astype(np.uint8)
2898
+ axes[j].imshow(frame_rgb)
2899
+ # Draw patch grid
2900
+ for g in range(1, grid_g):
2901
+ axes[j].axhline(y=g * H / grid_g, color='white', linewidth=0.3, alpha=0.5)
2902
+ axes[j].axvline(x=g * W / grid_g, color='white', linewidth=0.3, alpha=0.5)
2903
+ axes[j].axis('off')
2904
+
2905
+ # Show per-frame label if available
2906
+ f_lbl_idx = int(fl_tensor[fi].item())
2907
+ f_lbl_text = f'f{fi}'
2908
+ if 0 <= f_lbl_idx < len(class_names):
2909
+ f_lbl_text += f'\n{class_names[f_lbl_idx]}'
2910
+ elif f_lbl_idx >= 0:
2911
+ f_lbl_text += f'\n{f_lbl_idx}'
2912
+
2913
+ axes[j].set_title(f_lbl_text, fontsize=7)
2914
+
2915
+ clip_id_short = os.path.basename(clip_id)
2916
+ fig.suptitle(
2917
+ f'"{label_name}" — clip {idx} (actual={actual_idx}) | '
2918
+ f'res {H}×{W}, grid {grid_g}×{grid_g}, {T} frames, patch 18px\n'
2919
+ f'file: {clip_id_short} | orig_label: {orig_label} | '
2920
+ f'Frames Labeled: {num_labeled_frames}/{total_frames}',
2921
+ fontsize=8, fontweight='bold'
2922
+ )
2923
+ plt.tight_layout(rect=[0, 0, 1, 0.90])
2924
+ safe_label = label_name.replace(" ", "_").replace("/", "-")
2925
+ plt.savefig(os.path.join(samples_dir, f'sample_{si+1}_{safe_label}.png'), dpi=150, bbox_inches='tight')
2926
+ plt.close()
2927
+
2928
+ # Optional localization preview: crop from native-res video
2929
+ # so the preview matches what the classification head actually receives.
2930
+ if use_localization:
2931
+ bbox_target = batch[3] if len(batch) >= 4 else None
2932
+ bbox_valid = batch[4] if len(batch) >= 5 else None
2933
+ if (
2934
+ torch.is_tensor(bbox_target)
2935
+ and torch.is_tensor(bbox_valid)
2936
+ and float(bbox_valid.sum().item()) > 0.5
2937
+ ):
2938
+ # Use first valid frame bbox for the preview
2939
+ if bbox_target.dim() == 2:
2940
+ valid_mask = bbox_valid > 0.5
2941
+ first_valid_t = int(valid_mask.float().argmax().item()) if valid_mask.any() else 0
2942
+ bbox_target = bbox_target[first_valid_t]
2943
+ x1, y1, x2, y2 = [float(v) for v in bbox_target.tolist()]
2944
+ x1 = max(0.0, min(1.0, x1))
2945
+ y1 = max(0.0, min(1.0, y1))
2946
+ x2 = max(0.0, min(1.0, x2))
2947
+ y2 = max(0.0, min(1.0, y2))
2948
+ if x2 > x1 and y2 > y1:
2949
+ # Load full-resolution clip for high-quality crop
2950
+ actual_idx = idx % len(train_dataset.clips)
2951
+ try:
2952
+ raw_clip = train_dataset.load_fullres_clip_by_index(actual_idx)
2953
+ raw_frames = raw_clip.permute(0, 2, 3, 1).numpy() # [T, Hraw, Wraw, C]
2954
+ Hraw, Wraw = raw_frames.shape[1], raw_frames.shape[2]
2955
+ except Exception as e:
2956
+ logger.debug("Could not load full-res clip by index: %s", e)
2957
+ raw_frames = frames
2958
+ Hraw, Wraw = H, W
2959
+
2960
+ fig2, axes2 = plt.subplots(2, num_show, figsize=(fig_w, (H * 2 + 80) / dpi))
2961
+ if num_show == 1:
2962
+ axes2 = axes2.reshape(2, 1)
2963
+
2964
+ # Bbox in pixel coords on the low-res frames (for display row 1)
2965
+ ix1 = int(round(x1 * W))
2966
+ iy1 = int(round(y1 * H))
2967
+ ix2 = int(round(x2 * W))
2968
+ iy2 = int(round(y2 * H))
2969
+ ix1 = max(0, min(ix1, W - 1))
2970
+ iy1 = max(0, min(iy1, H - 1))
2971
+ ix2 = max(ix1 + 1, min(ix2, W))
2972
+ iy2 = max(iy1 + 1, min(iy2, H))
2973
+
2974
+ # Bbox in pixel coords on the full-res frames (for crop)
2975
+ rx1 = int(round(x1 * Wraw))
2976
+ ry1 = int(round(y1 * Hraw))
2977
+ rx2 = int(round(x2 * Wraw))
2978
+ ry2 = int(round(y2 * Hraw))
2979
+ rx1 = max(0, min(rx1, Wraw - 1))
2980
+ ry1 = max(0, min(ry1, Hraw - 1))
2981
+ rx2 = max(rx1 + 1, min(rx2, Wraw))
2982
+ ry2 = max(ry1 + 1, min(ry2, Hraw))
2983
+
2984
+ for j, fi in enumerate(frame_indices_show):
2985
+ frame_rgb = (np.clip(frames[fi], 0, 1) * 255).astype(np.uint8)
2986
+
2987
+ # Row 1: model-res frame with target bbox
2988
+ axes2[0, j].imshow(frame_rgb)
2989
+ bbox_rect = mpatches.Rectangle(
2990
+ (ix1, iy1), ix2 - ix1, iy2 - iy1,
2991
+ fill=False, edgecolor='orange', linewidth=1.5
2992
+ )
2993
+ axes2[0, j].add_patch(bbox_rect)
2994
+ axes2[0, j].axis('off')
2995
+ axes2[0, j].set_title(f'f{fi}', fontsize=7)
2996
+
2997
+ # Row 2: crop from full-res, resize to model resolution
2998
+ raw_fi = min(fi, len(raw_frames) - 1)
2999
+ raw_rgb = (np.clip(raw_frames[raw_fi], 0, 1) * 255).astype(np.uint8)
3000
+ crop = raw_rgb[ry1:ry2, rx1:rx2]
3001
+ if crop.size == 0:
3002
+ crop = raw_rgb
3003
+ crop = cv2.resize(crop, (W, H), interpolation=cv2.INTER_AREA)
3004
+ axes2[1, j].imshow(crop)
3005
+ axes2[1, j].axis('off')
3006
+
3007
+ fig2.suptitle(
3008
+ f'"{label_name}" — localization target crop preview (from native {Hraw}×{Wraw})',
3009
+ fontsize=9, fontweight='bold'
3010
+ )
3011
+ # set_ylabel is invisible after axis('off'); draw row labels as text artists instead
+ axes2[0, 0].text(-0.02, 0.5, 'Original+bbox', transform=axes2[0, 0].transAxes, rotation=90, va='center', ha='right', fontsize=9)
3012
+ axes2[1, 0].text(-0.02, 0.5, 'Crop preview', transform=axes2[1, 0].transAxes, rotation=90, va='center', ha='right', fontsize=9)
3013
+ plt.tight_layout(rect=[0, 0, 1, 0.92])
3014
+ plt.savefig(
3015
+ os.path.join(loc_samples_dir, f'loc_sample_{si+1}_{safe_label}.png'),
3016
+ dpi=150,
3017
+ bbox_inches='tight'
3018
+ )
3019
+ plt.close()
3020
+
3021
+ if log_fn:
3022
+ log_fn(f"Saved {len(sample_indices)} input sample visualizations to {samples_dir}")
3023
+ if use_localization:
3024
+ log_fn(f"Saved localization crop previews to {loc_samples_dir}")
3025
+ log_fn(f"Crop progress visualizations will be saved every 2 epochs to {crop_progress_dir}")
3026
+
3027
+ def _save_epoch_crop_progress(epoch_idx: int, phase_name: str):
3028
+ if not (use_localization and crop_progress_dir and len(train_dataset) > 0):
3029
+ return
3030
+ if (epoch_idx + 1) % 2 != 0:
3031
+ return
3032
+ was_training = model.training
3033
+ model.eval()
3034
+ try:
3035
+ sample_count = min(max(1, crop_progress_num_samples), len(train_dataset))
3036
+ sample_indices_epoch = np.random.choice(len(train_dataset), size=sample_count, replace=False)
3037
+
3038
+ if classification_stage_start_epoch is None:
3039
+ gt_crop_prob = crop_mix_end_gt
3040
+ else:
3041
+ cls_stage_steps = max(1, config["epochs"] - classification_stage_start_epoch)
3042
+ cls_epoch_idx = max(0, epoch_idx - classification_stage_start_epoch)
3043
+ alpha = min(1.0, cls_epoch_idx / cls_stage_steps)
3044
+ gt_crop_prob = crop_mix_start_gt + (crop_mix_end_gt - crop_mix_start_gt) * alpha
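+ # e.g. crop_mix_start_gt=0.9, crop_mix_end_gt=0.2 over a 10-epoch classification
+ # stage gives gt_crop_prob 0.9 at stage start, 0.55 halfway, 0.2 at the end
+ # (hypothetical values; the linear anneal above is what matters).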
3045
+
3046
+ for si, idx in enumerate(sample_indices_epoch):
3047
+ actual_idx = idx % len(train_dataset.clips)
3048
+ # Mirror classification crop construction: non-augmented model-res clip.
3049
+ try:
3050
+ clip_tensor = train_dataset.load_modelres_clip_by_index(actual_idx).unsqueeze(0).to(device)
3051
+ except Exception as e:
3052
+ logger.debug("Could not load model-res clip by index: %s", e)
3053
+ batch = train_dataset[idx]
3054
+ clip_tensor = batch[0].unsqueeze(0).to(device) # fallback
3055
+
3056
+ cp_valid = train_dataset.spatial_bbox_valid[actual_idx]
3057
+ valid_gt = bool(float(cp_valid[0].item() if cp_valid.dim() > 0 else cp_valid.item()) > 0.5)
3058
+ gt_bbox_raw = None
3059
+ if valid_gt:
3060
+ cp_bbox = train_dataset.spatial_bboxes[actual_idx]
3061
+ cp_bbox_f0 = cp_bbox[0] if cp_bbox.dim() == 2 else cp_bbox
3062
+ gt_bbox_raw = _clamp_bboxes_no_expand(
3063
+ cp_bbox_f0.view(1, 4).to(device)
3064
+ )[0]
3065
+
3066
+ with torch.no_grad():
3067
+ loc_wh = torch.tensor([[float(global_fixed_wh[0]), float(global_fixed_wh[1])]], device=device, dtype=clip_tensor.dtype)
3068
+ loc_out = model(clip_tensor, return_localization=True, localization_box_wh=loc_wh)
3069
+ _, pred_bbox = _split_localization_output(loc_out)
3070
+ if pred_bbox is None:
3071
+ pred_bbox = torch.zeros((1, 4), device=device)
3072
+ pred_bbox_raw_all = _clamp_bboxes_no_expand(pred_bbox)
3073
+ if pred_bbox_raw_all.dim() == 3:
3074
+ pred_bbox_raw = pred_bbox_raw_all[0, 0]
3075
+ else:
3076
+ pred_bbox_raw = pred_bbox_raw_all[0]
3077
+
3078
+ cls_bbox_raw = pred_bbox_raw.clone()
3079
+ if (phase_name == "classification") and (gt_bbox_raw is not None) and (gt_crop_prob >= 0.5):
3080
+ cls_bbox_raw = gt_bbox_raw.clone()
3081
+
3082
+ # Actual crop boxes used by the pipeline (with padding/min-size sanitization).
3083
+ pred_bbox_used = _sanitize_bboxes(pred_bbox_raw.view(1, 4))[0]
3084
+ cls_bbox_used = _sanitize_bboxes(cls_bbox_raw.view(1, 4))[0]
3085
+
3086
+ # Crop from native-resolution video
3087
+ full_clip = clip_tensor[0].detach().cpu()
3088
+ T = int(full_clip.shape[0])
3089
+ H = int(full_clip.shape[2]) # model resolution
3090
+ W = int(full_clip.shape[3])
3091
+
3092
+ try:
3093
+ raw_clip = train_dataset.load_fullres_clip_by_index(actual_idx).float()
3094
+ Hraw, Wraw = int(raw_clip.shape[2]), int(raw_clip.shape[3])
3095
+ except Exception as e:
3096
+ logger.debug("Could not load full-res clip for crop preview: %s", e)
3097
+ raw_clip = full_clip.clone()
3098
+ Hraw, Wraw = H, W
3099
+
3100
+ pred_crop = _crop_single_clip_to_target(raw_clip, pred_bbox_used, H, W)
3101
+ cls_crop = _crop_single_clip_to_target(raw_clip, cls_bbox_used, H, W)
3102
+
3103
+ num_show = min(T, 8)
3104
+ frame_indices_show = np.linspace(0, T - 1, num_show, dtype=int)
3105
+ fig_w = max(8.0, 2.4 * num_show)
3106
+ fig, axes = plt.subplots(3, num_show, figsize=(fig_w, 6.8))
3107
+ if num_show == 1:
3108
+ axes = axes.reshape(3, 1)
3109
+
3110
+ def _bbox_px(b):
3111
+ x1 = int(round(float(b[0].item()) * (W - 1)))
3112
+ y1 = int(round(float(b[1].item()) * (H - 1)))
3113
+ x2 = int(round(float(b[2].item()) * W))
3114
+ y2 = int(round(float(b[3].item()) * H))
3115
+ x1 = max(0, min(x1, W - 1))
3116
+ y1 = max(0, min(y1, H - 1))
3117
+ x2 = max(x1 + 1, min(x2, W))
3118
+ y2 = max(y1 + 1, min(y2, H))
3119
+ return x1, y1, x2, y2
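+ # e.g. with W = H = 288, _bbox_px(torch.tensor([0.25, 0.25, 0.75, 0.75]))
+ # -> (72, 72, 216, 216): corners snap to pixels and are clamped to a
+ # non-empty box inside the frame.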
3120
+
3121
+ px_pred_used = _bbox_px(pred_bbox_used)
3122
+ px_pred_raw = _bbox_px(pred_bbox_raw)
3123
+ px_cls_used = _bbox_px(cls_bbox_used)
3124
+
3125
+ for j, fi in enumerate(frame_indices_show):
3126
+ frame_full = (full_clip[fi].permute(1, 2, 0).numpy().clip(0, 1) * 255).astype(np.uint8)
3127
+ frame_pred = (pred_crop[fi].permute(1, 2, 0).numpy().clip(0, 1) * 255).astype(np.uint8)
3128
+ frame_cls = (cls_crop[fi].permute(1, 2, 0).numpy().clip(0, 1) * 255).astype(np.uint8)
3129
+
3130
+ axes[0, j].imshow(frame_full)
3131
+ pred_raw_rect = mpatches.Rectangle(
3132
+ (px_pred_raw[0], px_pred_raw[1]),
3133
+ px_pred_raw[2] - px_pred_raw[0],
3134
+ px_pred_raw[3] - px_pred_raw[1],
3135
+ fill=False,
3136
+ edgecolor='magenta',
3137
+ linewidth=1.2,
3138
+ linestyle='--'
3139
+ )
3140
+ axes[0, j].add_patch(pred_raw_rect)
3141
+ pred_rect = mpatches.Rectangle(
3142
+ (px_pred_used[0], px_pred_used[1]),
3143
+ px_pred_used[2] - px_pred_used[0],
3144
+ px_pred_used[3] - px_pred_used[1],
3145
+ fill=False,
3146
+ edgecolor='cyan',
3147
+ linewidth=1.4
3148
+ )
3149
+ axes[0, j].add_patch(pred_rect)
3150
+ cls_rect = mpatches.Rectangle(
3151
+ (px_cls_used[0], px_cls_used[1]),
3152
+ px_cls_used[2] - px_cls_used[0],
3153
+ px_cls_used[3] - px_cls_used[1],
3154
+ fill=False,
3155
+ edgecolor='lime',
3156
+ linewidth=1.4
3157
+ )
3158
+ axes[0, j].add_patch(cls_rect)
3159
+ axes[0, j].axis('off')
3160
+ axes[0, j].set_title(f"f{fi}", fontsize=7)
3161
+
3162
+ axes[1, j].imshow(frame_pred)
3163
+ axes[1, j].axis('off')
3164
+ axes[2, j].imshow(frame_cls)
3165
+ axes[2, j].axis('off')
3166
+
3167
+ # set_ylabel is invisible after axis('off'); draw row labels as text artists instead
+ axes[0, 0].text(-0.02, 0.5, "Full+boxes (clip-level)", transform=axes[0, 0].transAxes, rotation=90, va="center", ha="right", fontsize=9)
3168
+ axes[1, 0].text(-0.02, 0.5, "Pred crop", transform=axes[1, 0].transAxes, rotation=90, va="center", ha="right", fontsize=9)
3169
+ axes[2, 0].text(-0.02, 0.5, "Cls input", transform=axes[2, 0].transAxes, rotation=90, va="center", ha="right", fontsize=9)
3170
+
3171
+ cp_clip_id = train_dataset.clips[actual_idx].get("id", "?")
3172
+ cp_label_name = class_names[train_dataset.labels[actual_idx]] if train_dataset.labels[actual_idx] >= 0 else "?"
3173
+ cp_clip_short = os.path.basename(str(cp_clip_id))
3174
+ if phase_name == "localization":
3175
+ title_note = (
3176
+ f"Epoch {epoch_idx+1} | LOCALIZATION STAGE | \"{cp_label_name}\" | {cp_clip_short}\n"
3177
+ f"row0: model-res {H}×{W}, crops from native {Hraw}×{Wraw} | crop_padding={crop_padding} | "
3178
+ "magenta=pred raw, cyan=pred+padding"
3179
+ )
3180
+ else:
3181
+ title_note = (
3182
+ f"Epoch {epoch_idx+1} | CLASSIFICATION STAGE | \"{cp_label_name}\" | {cp_clip_short} | gt_mix={gt_crop_prob:.2f}\n"
3183
+ f"row0: model-res {H}×{W}, crops from native {Hraw}×{Wraw} | crop_padding={crop_padding} | "
3184
+ "magenta=pred raw, cyan=pred+padding, lime=cls input"
3185
+ )
3186
+ fig.suptitle(title_note, fontsize=9, fontweight='bold')
3187
+ plt.tight_layout(rect=[0, 0, 1, 0.93])
3188
+ out_name = f"epoch_{epoch_idx+1:03d}_sample_{si+1}.png"
3189
+ plt.savefig(os.path.join(crop_progress_dir, out_name), dpi=140, bbox_inches='tight')
3190
+ plt.close(fig)
3191
+
3192
+ if log_fn:
3193
+ log_fn(
3194
+ f"Saved crop progress previews for epoch {epoch_idx+1} to {crop_progress_dir} "
3195
+ f"(visualization only; classification uses precomputed ROI cache)"
3196
+ )
3197
+ except Exception as e_crop:
3198
+ if log_fn:
3199
+ log_fn(f"Note: Could not save crop progress previews for epoch {epoch_idx+1}: {e_crop}")
3200
+ finally:
3201
+ if was_training:
3202
+ model.train()
3203
+ except Exception as e:
3204
+ if log_fn:
3205
+ log_fn(f"Note: Could not save input samples: {e}")
3206
+ # Sample visualization failed above; fall back to a no-op so the training
+ # loop can still call the per-epoch preview hook.
+ def _save_epoch_crop_progress(epoch_idx: int, phase_name: str):
3207
+ return
3208
+
3209
+ for epoch in range(config["epochs"]):
3210
+ if stop_callback and stop_callback():
3211
+ if log_fn:
3212
+ log_fn("Training stopped by user.")
3213
+ break
3214
+
3215
+ if progress_callback:
3216
+ progress_callback(epoch + 1, config["epochs"])
3217
+
3218
+ current_lr = optimizer.param_groups[0]['lr']
3219
+ epoch_phase = "localization" if in_localization_stage else "classification"
3220
+ if log_fn:
3221
+ log_fn(f"\n=== Epoch {epoch+1}/{config['epochs']} | phase={epoch_phase} (LR: {current_lr:.2e}) ===")
3222
+
3223
+ # On first classification epoch: crop all clips using the trained
3224
+ # localization head, save to disk, and switch datasets to load from
3225
+ # those .pt files. DataLoader uses normal num_workers so classification
3226
+ # training runs at the same speed as standard (no-localization) training.
3227
+ if not in_localization_stage and use_localization and has_any_localization and enable_roi_cache:
3228
+ if train_roi_cache is None:
3229
+ if log_fn:
3230
+ log_fn("Cropping clips using localization model and saving to disk (train split)...")
3231
+ train_roi_cache = _precompute_roi_cache(train_dataset, split_name="train")
3232
+ train_dataset._roi_cache_mode = True
3233
+ train_dataset._roi_cache_dir = train_roi_cache
3234
+ if log_fn:
3235
+ log_fn("Dataset switched to disk-cached crops. Training now identical to standard classifier.")
3236
+ # Recreate loader with normal workers (loading .pt files is fast)
3237
+ if batch_sampler is not None:
3238
+ train_loader = DataLoader(
3239
+ train_dataset, batch_sampler=batch_sampler,
3240
+ num_workers=num_workers,
3241
+ pin_memory=True if device.type == "cuda" else False,
3242
+ persistent_workers=True if num_workers > 0 else False,
3243
+ )
3244
+ else:
3245
+ train_loader = DataLoader(
3246
+ train_dataset, batch_size=config["batch_size"],
3247
+ shuffle=shuffle, sampler=sampler, num_workers=num_workers,
3248
+ pin_memory=True if device.type == "cuda" else False,
3249
+ persistent_workers=True if num_workers > 0 else False,
3250
+ )
3251
+ if val_roi_cache is None and val_dataset is not None:
3252
+ if log_fn:
3253
+ log_fn("Cropping clips using localization model and saving to disk (val split)...")
3254
+ val_roi_cache = _precompute_roi_cache(val_dataset, split_name="val")
3255
+ val_dataset._roi_cache_mode = True
3256
+ val_dataset._roi_cache_dir = val_roi_cache
3257
+ val_loader = DataLoader(
3258
+ val_dataset, batch_size=config["batch_size"],
3259
+ shuffle=False, num_workers=num_workers,
3260
+ pin_memory=True if device.type == "cuda" else False,
3261
+ persistent_workers=True if num_workers > 0 else False,
3262
+ )
3263
+
3264
+ # Backbone embeddings are pre-computed once for classification training.
3265
+ # If clip-stitch is enabled, stitching then happens on cached embeddings.
3266
+ # Backbone is frozen during classification — caching is always equivalent
3267
+ # to running it live and avoids 10-50x redundant computation per epoch.
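+ # Back-of-envelope: 30 epochs of live backbone passes means 30 forwards per
+ # clip, vs. a one-time emb_aug_versions cached forwards (say 2), i.e. a 15x
+ # saving, which is where the 10-50x range above comes from.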
3268
+ emb_aug_versions = max(1, int(config.get("emb_aug_versions", 1)))
3269
+ use_multi_scale = bool(config.get("multi_scale", False))
3270
+ if not in_localization_stage and not use_localization:
3271
+ if not getattr(train_dataset, '_emb_cache_mode', False):
3272
+ aug_note = f" × {emb_aug_versions} augmented versions" if emb_aug_versions > 1 else ""
3273
+ ms_note = " + short-scale (multi-scale)" if use_multi_scale else ""
3274
+ if log_fn:
3275
+ log_fn(
3276
+ f"Pre-computing backbone embeddings "
3277
+ f"(train split{aug_note}{ms_note})..."
3278
+ )
3279
+ emb_cache_dir = _precompute_embedding_cache(
3280
+ train_dataset, split_name="train",
3281
+ num_aug_versions=emb_aug_versions,
3282
+ use_augmentation=emb_aug_versions > 1,
3283
+ multi_scale=use_multi_scale,
3284
+ )
3285
+ train_emb_cache = emb_cache_dir
3286
+ train_dataset._emb_cache_mode = True
3287
+ train_dataset._emb_cache_dir = emb_cache_dir
3288
+ train_dataset._emb_clip_length = config.get("clip_length", 8)
3289
+ train_dataset._emb_num_versions = emb_aug_versions
3290
+ train_dataset._emb_multi_scale = use_multi_scale
3291
+ # batch_sampler is mutually exclusive with batch_size/shuffle/sampler
3292
+ _pin = device.type == "cuda"
3293
+ _pw = num_workers > 0
3294
+ if batch_sampler is not None:
3295
+ train_loader = DataLoader(
3296
+ train_dataset, batch_sampler=batch_sampler,
3297
+ num_workers=num_workers, pin_memory=_pin, persistent_workers=_pw,
3298
+ )
3299
+ else:
3300
+ train_loader = DataLoader(
3301
+ train_dataset, batch_size=config["batch_size"],
3302
+ shuffle=shuffle, sampler=sampler,
3303
+ num_workers=num_workers, pin_memory=_pin, persistent_workers=_pw,
3304
+ )
3305
+ if val_dataset is not None and not getattr(val_dataset, '_emb_cache_mode', False):
3306
+ if log_fn:
3307
+ log_fn("Pre-computing backbone embeddings (val split, no aug)...")
3308
+ # Validation always uses a single clean (unaugmented) version
3309
+ val_emb_cache = _precompute_embedding_cache(
3310
+ val_dataset, split_name="val",
3311
+ num_aug_versions=1, use_augmentation=False,
3312
+ multi_scale=use_multi_scale,
3313
+ )
3314
+ val_emb_cache_dir = val_emb_cache
3315
+ val_dataset._emb_cache_mode = True
3316
+ val_dataset._emb_cache_dir = val_emb_cache
3317
+ val_dataset._emb_clip_length = config.get("clip_length", 8)
3318
+ val_dataset._emb_num_versions = 1
3319
+ val_dataset._emb_multi_scale = use_multi_scale
3320
+ # No stitching during validation — we want clean per-clip evaluation
3321
+ val_dataset.stitch_prob = 0.0
3322
+ val_loader = DataLoader(
3323
+ val_dataset, batch_size=config["batch_size"],
3324
+ shuffle=False, num_workers=num_workers,
3325
+ pin_memory=device.type == "cuda",
3326
+ persistent_workers=num_workers > 0,
3327
+ )
3328
+
3329
+ model.train()
3330
+ total_loss = 0.0
3331
+ total_loss_class = 0.0
3332
+ correct = 0
3333
+ total = 0
3334
+ frame_correct = 0
3335
+ frame_total = 0
3336
+ train_targets_all = []
3337
+ train_preds_all = []
3338
+ # Per-clip confusion scores accumulated over this epoch (OvR only).
3339
+ # Use real clip count, not the virtual-expanded dataset length.
3340
+ _n_clips = len(train_dataset.clips) if hasattr(train_dataset, 'clips') else len(train_dataset)
3341
+ _epoch_confusion = np.zeros(_n_clips, dtype=np.float32)
3342
+ _epoch_confusion_count = np.zeros(_n_clips, dtype=np.int32)
3343
+ _epoch_top_rival = np.full(_n_clips, -1, dtype=np.int32)
3344
+
3345
+ try:
3346
+ for batch_idx, batch_data in enumerate(train_loader):
3347
+ if stop_callback and stop_callback():
3348
+ break
3349
+
3350
+ try:
3351
+ frame_labels_batch = None
3352
+ clips_short_batch = None
3353
+ bg_mask_batch = None
3354
+ suppress_batch = None
3355
+ if isinstance(batch_data, (list, tuple)) and len(batch_data) >= 7:
3356
+ clips, labels, spatial_masks_batch, spatial_bboxes_batch, spatial_bbox_valid_batch, sample_indices_batch, frame_labels_batch = batch_data[:7]
3357
+ else:
3358
+ clips, labels, spatial_masks_batch, spatial_bboxes_batch, spatial_bbox_valid_batch, sample_indices_batch = batch_data[:6]
3359
+ if isinstance(batch_data, (list, tuple)) and len(batch_data) >= 8:
3360
+ _cs = batch_data[7]
3361
+ if isinstance(_cs, torch.Tensor) and _cs.numel() > 0:
3362
+ clips_short_batch = _cs
3363
+ if isinstance(batch_data, (list, tuple)) and len(batch_data) >= 9:
3364
+ _bg = batch_data[8]
3365
+ if isinstance(_bg, torch.Tensor) and _bg.numel() > 0:
3366
+ bg_mask_batch = _bg.to(device=device, dtype=torch.bool)
3367
+ clips = clips.to(device)
3368
+ labels = labels.to(device)
3369
+ if frame_labels_batch is not None:
3370
+ frame_labels_batch = frame_labels_batch.to(device)
3371
+ if use_ovr:
3372
+ suppress_batch = _lookup_ovr_suppress_for_batch(sample_indices_batch, train_dataset, device)
3373
+ spatial_bboxes_batch = spatial_bboxes_batch.to(device)
3374
+ spatial_bbox_valid_batch = spatial_bbox_valid_batch.to(device)
3375
+
3376
+ optimizer.zero_grad()
3377
+ attn_w = None
3378
+
3379
+ # First-frame bbox slices for classification stage crop decisions
3380
+ spatial_bboxes_f0 = spatial_bboxes_batch[:, 0, :] if spatial_bboxes_batch.dim() == 3 else spatial_bboxes_batch
3381
+ spatial_bbox_valid_f0 = spatial_bbox_valid_batch[:, 0] if spatial_bbox_valid_batch.dim() == 2 else spatial_bbox_valid_batch
3382
+
3383
+ # Stage 1: train localization on full frames (all frames supervised).
3384
+ # spatial_bboxes_batch: [B, T, 4], spatial_bbox_valid_batch: [B, T]
3385
+ if in_localization_stage:
3386
+ loc_fixed_wh = _fixed_wh_for_labels(labels, device=clips.device, dtype=clips.dtype)
3387
+ loc_out = model(
3388
+ clips,
3389
+ return_localization=True,
3390
+ cache_backbone_tokens=True,
3391
+ localization_box_wh=loc_fixed_wh,
3392
+ )
3393
+ _, loc_pred_bboxes = _split_localization_output(loc_out)
3394
+ if loc_pred_bboxes is None or not has_any_localization:
3395
+ raise RuntimeError("Localization stage enabled but localization predictions/targets are unavailable.")
3396
+
3397
+ raw_model = model.module if hasattr(model, "module") else model
3398
+ backbone_tokens = raw_model._backbone_tokens
3399
+
3400
+ # Flatten temporal dimension for per-frame supervision:
3401
+ # loc_pred_bboxes: [B, T, 4] or [B, 4]
3402
+ B_loc = spatial_bboxes_batch.size(0)
3403
+ T_loc = spatial_bboxes_batch.size(1) if spatial_bboxes_batch.dim() == 3 else 1
3404
+ tgt_flat = spatial_bboxes_batch.view(B_loc * T_loc, 4) # [B*T, 4]
3405
+ valid_flat = spatial_bbox_valid_batch.view(B_loc * T_loc) # [B*T]
3406
+ if loc_pred_bboxes.dim() == 2:
3407
+ # [B, 4] → repeat for T frames
3408
+ pred_flat = loc_pred_bboxes.unsqueeze(1).expand(-1, T_loc, -1).reshape(B_loc * T_loc, 4)
3409
+ else:
3410
+ pred_flat = loc_pred_bboxes.view(B_loc * T_loc, 4) # [B*T, 4]
3411
+
3412
+ # Primary: center heatmap with Gaussian focal loss (all frames)
3413
+ obj_logits = raw_model.localization_head.get_objectness_logits(
3414
+ backbone_tokens, num_frames=train_dataset.clip_length,
3415
+ all_frames=(T_loc > 1),
3416
+ ) # [B*T, S] when all_frames=True, else [B, S]
3417
+ from .model import center_heatmap_loss, direct_center_loss
3418
+ chm_loss = center_heatmap_loss(
3419
+ obj_logits,
3420
+ tgt_flat,
3421
+ valid_flat,
3422
+ sigma_in_patches=center_heatmap_sigma,
3423
+ )
3424
+ loss = center_heatmap_weight * chm_loss
3425
+ chm_loss_val = chm_loss.item()
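+ # Assumed CenterNet-style construction (the exact form lives in
+ # .model.center_heatmap_loss): each valid frame's bbox center is rendered as a
+ # Gaussian exp(-(dx^2 + dy^2) / (2 * sigma^2)) over the S patch cells, in patch
+ # units, and a penalty-reduced focal loss compares obj_logits to that soft target.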
3426
+
3427
+ # Secondary: direct center-to-center regression (all frames)
3428
+ dc_loss = direct_center_loss(
3429
+ pred_flat,
3430
+ tgt_flat,
3431
+ valid_flat,
3432
+ )
3433
+ loss = loss + direct_center_weight * dc_loss
3434
+ dc_loss_val = dc_loss.item()
3435
+
3436
+ loss.backward()
3437
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
3438
+ if any(torch.isnan(p.grad).any() for p in model.parameters() if p.grad is not None):
3439
+ optimizer.zero_grad()
3440
+ if log_fn and batch_idx % 10 == 0:
3441
+ log_fn(f"[loc] Skipped batch {batch_idx} (NaN grad)")
3442
+ else:
3443
+ optimizer.step()
3444
+ _update_ema()
3445
+
3446
+ batch_size = clips.size(0)
3447
+ total_loss += loss.item() * batch_size
3448
+ total_loss_class += 0.0
3449
+ total += batch_size
3450
+
3451
+ if log_fn and batch_idx % 10 == 0:
3452
+ iou_val, center_err, valid_rate = _localization_metrics(
3453
+ pred_flat.detach(),
3454
+ tgt_flat,
3455
+ valid_flat,
3456
+ )
3457
+ log_fn(
3458
+ f"Epoch {epoch+1}/{config['epochs']}, Batch {batch_idx}/{len(train_loader)}, "
3459
+ f"Loc Loss: {loss.item():.4f}, CHM: {chm_loss_val:.4f}, DC: {dc_loss_val:.4f}, "
3460
+ f"CErr: {center_err:.3f}, IoU: {iou_val:.3f}, VRate: {valid_rate:.3f}"
3461
+ )
3462
+ continue
3463
+
3464
+ # Classification forward: if embedding cache is active, clips
3465
+ # are pre-computed tokens [B, T*S, D]; otherwise (localization
3466
+ # pipeline) clips are raw pixels [B, T, C, H, W].
3467
+ _clip_len = config.get("clip_length", 8)
3468
+ _emb_active = getattr(train_dataset, '_emb_cache_mode', False)
3469
+ if _emb_active:
3470
+ _cs = clips_short_batch.to(device) if clips_short_batch is not None else None
3471
+ logits = model(
3472
+ None,
3473
+ backbone_tokens=clips,
3474
+ num_frames=_clip_len,
3475
+ backbone_tokens_short=_cs,
3476
+ num_frames_short=_clip_len // 2 if _cs is not None else None,
3477
+ return_localization=False,
3478
+ return_frame_logits=True,
3479
+ )
3480
+ else:
3481
+ logits = model(
3482
+ clips,
3483
+ return_localization=False,
3484
+ return_frame_logits=True,
3485
+ )
3486
+
3487
+ # Frame-level multi-task loss
3488
+ _fo = getattr(model, '_frame_output', None)
3489
+ if _fo is not None and frame_labels_batch is not None and not in_localization_stage:
3490
+ f_logits = _fo[0]
3491
+ f_logits_pooled = _fo[3] if len(_fo) > 3 else f_logits
3492
+ pool_n = int(_fo[4]) if len(_fo) > 4 else 1
3493
+ boundary_logits_out = _fo[5] if len(_fo) > 5 else None
3494
+ frame_embeddings_out = _fo[6] if len(_fo) > 6 else None
3495
+ labels_for_loss = _pool_frame_labels(frame_labels_batch, pool_n)
3496
+ logits_for_loss = f_logits_pooled if pool_n > 1 else f_logits
3497
+ embeddings_for_loss = _pool_frame_embeddings(frame_embeddings_out, pool_n)
3498
+ bg_mask_for_loss = None
3499
+ if use_ovr and ovr_background_as_negative and bg_mask_batch is not None:
3500
+ bg_mask_for_loss = _pool_binary_mask(bg_mask_batch, pool_n)
3501
+
3502
+ # L_state: frame classification loss
3503
+ if use_ovr:
3504
+ B_f, T_f, C_f = logits_for_loss.shape
3505
+ fl_ovr_targets = torch.full((B_f, T_f, C_f), ovr_label_smoothing, device=device)
3506
+ fl_ovr_weight = torch.ones(B_f, T_f, C_f, device=device)
3507
+ for bi in range(B_f):
3508
+ for ti in range(T_f):
3509
+ lbl = int(labels_for_loss[bi, ti].item())
3510
+ if 0 <= lbl < C_f:
3511
+ fl_ovr_targets[bi, ti, lbl] = 1.0 - ovr_label_smoothing
3512
+ fl_ovr_weight[bi, ti, lbl] = ovr_pos_weight[lbl]
3513
+ for cj in cooccur_lookup.get(lbl, ()):
3514
+ if cj != lbl and 0 <= cj < C_f:
3515
+ fl_ovr_weight[bi, ti, cj] = 0.0
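+ # e.g. C_f=3, true label 1, ovr_label_smoothing=0.05, allowed pair (1, 2):
+ # targets row -> [0.05, 0.95, 0.05]; weight row -> [1.0, pos_weight[1], 0.0],
+ # so the co-occurring head 2 is neither pushed up nor down for this frame.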
3516
+ suppress_mask_for_loss = None
3517
+ if suppress_batch is not None:
3518
+ suppress_mask_for_loss = (labels_for_loss < 0) & (suppress_batch.view(-1, 1) >= 0)
3519
+ if suppress_mask_for_loss.any():
3520
+ fl_ovr_weight = fl_ovr_weight.masked_fill(
3521
+ suppress_mask_for_loss.unsqueeze(-1), 0.0
3522
+ )
3523
+ hn_b, hn_t = torch.nonzero(suppress_mask_for_loss, as_tuple=True)
3524
+ hn_c = suppress_batch[hn_b]
3525
+ fl_ovr_targets[hn_b, hn_t, hn_c] = 0.0
3526
+ fl_ovr_weight[hn_b, hn_t, hn_c] = 1.0
3527
+ if bg_mask_for_loss is not None and bg_mask_for_loss.any():
3528
+ fl_ovr_targets = fl_ovr_targets.masked_fill(
3529
+ bg_mask_for_loss.unsqueeze(-1), 0.0
3530
+ )
3531
+ valid_ovr_mask = (
3532
+ (labels_for_loss >= 0) |
3533
+ (suppress_mask_for_loss if suppress_mask_for_loss is not None else torch.zeros_like(labels_for_loss, dtype=torch.bool)) |
3534
+ (bg_mask_for_loss if bg_mask_for_loss is not None else torch.zeros_like(labels_for_loss, dtype=torch.bool))
3535
+ )
3536
+ loss = _frame_loss_balanced(
3537
+ logits_for_loss, labels_for_loss,
3538
+ use_ovr_local=True, ovr_targets=fl_ovr_targets,
3539
+ ovr_weight=fl_ovr_weight,
3540
+ use_bout_balance=use_frame_bout_balance,
3541
+ bout_power=frame_bout_balance_power,
3542
+ valid_mask_override=valid_ovr_mask,
3543
+ )
3544
+ else:
3545
+ loss = _frame_loss_balanced(
3546
+ logits_for_loss, labels_for_loss,
3547
+ use_ovr_local=False,
3548
+ use_bout_balance=use_frame_bout_balance,
3549
+ bout_power=frame_bout_balance_power,
3550
+ )
3551
+ if hard_pair_mining and hard_pair_loss_weight > 0:
3552
+ pair_loss = _hard_pair_margin_loss(
3553
+ logits_for_loss,
3554
+ labels_for_loss,
3555
+ hard_pair_index_pairs,
3556
+ hard_pair_margin,
3557
+ use_bout_balance=use_frame_bout_balance,
3558
+ bout_power=frame_bout_balance_power,
3559
+ )
3560
+ loss = loss + hard_pair_loss_weight * pair_loss
3561
+
3562
+ # L_boundary: boundary detection loss
3563
+ if boundary_logits_out is not None and boundary_loss_weight > 0:
3564
+ boundary_labels_batch = _generate_boundary_labels(
3565
+ frame_labels_batch, tolerance=boundary_tolerance,
3566
+ )
3567
+ from .model import boundary_detection_loss
3568
+ b_loss = boundary_detection_loss(
3569
+ boundary_logits_out, boundary_labels_batch,
3570
+ )
3571
+ loss = loss + boundary_loss_weight * b_loss
3572
+
3573
+ # L_smooth: temporal smoothness regularizer
3574
+ if smoothness_loss_weight > 0:
3575
+ from .model import temporal_smoothness_loss
3576
+ s_loss = temporal_smoothness_loss(logits_for_loss, labels_for_loss)
3577
+ loss = loss + smoothness_loss_weight * s_loss
3578
+
3579
+ if use_supcon_loss and supcon_weight > 0:
3580
+ sc_loss = _supervised_contrastive_loss(
3581
+ embeddings_for_loss,
3582
+ labels_for_loss,
3583
+ temperature=supcon_temperature,
3584
+ )
3585
+ loss = loss + supcon_weight * sc_loss
3586
+
3587
+ with torch.no_grad():
3588
+ valid_fl = labels_for_loss >= 0
3589
+ if valid_fl.any():
3590
+ if use_ovr:
3591
+ f_pred = torch.argmax(torch.sigmoid(logits_for_loss.detach()), dim=-1)
3592
+ else:
3593
+ f_pred = torch.argmax(logits_for_loss.detach(), dim=-1)
3594
+ frame_total += int(valid_fl.sum().item())
3595
+ frame_correct += int((f_pred[valid_fl] == labels_for_loss[valid_fl]).sum().item())
3596
+ train_targets_all.extend(labels_for_loss[valid_fl].detach().cpu().tolist())
3597
+ train_preds_all.extend(f_pred[valid_fl].detach().cpu().tolist())
3598
+
3599
+ # Accumulate per-clip confusion scores for ConfusionAwareSampler
3600
+ # Skip during warmup: model is still learning basics, scores are noise
3601
+ _confusion_warmup_epoch = int(config["epochs"] * _confusion_warmup_pct)
3602
+ if (use_ovr and isinstance(batch_sampler, ConfusionAwareSampler)
3603
+ and not in_localization_stage
3604
+ and (epoch + 1) > _confusion_warmup_epoch):
3605
+ probs = torch.sigmoid(logits_for_loss.detach()) # [B, T, C]
3606
+ B_cs = probs.shape[0]
3607
+ C_cs = probs.shape[-1]
3608
+ for bi in range(B_cs):
3609
+ raw_idx = int(sample_indices_batch[bi].item())
3610
+ clip_idx = raw_idx % _n_clips
3611
+ # Skip stitched clips (y=-1): mixed classes corrupt the score
3612
+ clip_label = int(labels[bi].item())
3613
+ if clip_label < 0 or clip_label >= C_cs:
3614
+ continue
3615
+ valid_t = labels_for_loss[bi] >= 0
3616
+ if not valid_t.any() or C_cs < 2:
3617
+ continue
3618
+ avg_p = probs[bi][valid_t].mean(0) # [C]
3619
+ # Pair-aware confusion:
3620
+ # favor clips where a specific rival head stays high,
3621
+ # the true head stays weak, or the rival overtakes the true head.
3622
+ wrong_mask = torch.ones(C_cs, dtype=torch.bool, device=avg_p.device)
3623
+ wrong_mask[clip_label] = False
3624
+ if not wrong_mask.any():
3625
+ continue
3626
+ wrong_idx = torch.where(wrong_mask)[0]
3627
+ wrong_probs = avg_p[wrong_mask]
3628
+ top_local = int(torch.argmax(wrong_probs).item())
3629
+ top_rival = int(wrong_idx[top_local].item())
3630
+ top_wrong = float(wrong_probs[top_local].item())
3631
+ true_p = float(avg_p[clip_label].item())
3632
+ rival_margin = max(0.0, top_wrong - true_p)
3633
+ true_deficit = max(0.0, 1.0 - true_p)
3634
+ confusion = 0.55 * top_wrong + 0.30 * rival_margin + 0.15 * true_deficit
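+ # e.g. true_p=0.40 and top_wrong=0.70 give rival_margin=0.30, true_deficit=0.60,
+ # so confusion = 0.55*0.70 + 0.30*0.30 + 0.15*0.60 = 0.565 before any pair boost.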
3635
+ if hard_pair_mining and hard_pair_confusion_boost > 1.0:
3636
+ pair_key = (min(clip_label, top_rival), max(clip_label, top_rival))
3637
+ if pair_key in hard_pair_index_pairs:
3638
+ confusion *= hard_pair_confusion_boost
3639
+ confusion = min(2.0, confusion)
3640
+ _epoch_confusion[clip_idx] += confusion
3641
+ _epoch_confusion_count[clip_idx] += 1
3642
+ _epoch_top_rival[clip_idx] = top_rival
3643
+ else:
3644
+ # No usable frame supervision in this batch: emit a zero loss that still
+ # touches the graph so backward()/step() below stay well-defined.
+ loss = logits.sum() * 0.0
3645
+
3646
+ loss.backward()
3647
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
3648
+ if any(torch.isnan(p.grad).any() for p in model.parameters() if p.grad is not None):
3649
+ optimizer.zero_grad()
3650
+ if log_fn and batch_idx % 10 == 0:
3651
+ log_fn(f"[cls] Skipped batch {batch_idx} (NaN grad)")
3652
+ continue
3653
+ optimizer.step()
3654
+ _update_ema()
3655
+
3656
+ total_loss += loss.item() * clips.size(0)
3657
+ total_loss_class += loss.item() * clips.size(0)
3658
+
3659
+ with torch.no_grad():
3660
+ if use_ovr:
3661
+ predicted = torch.argmax(torch.sigmoid(logits), dim=1)
3662
+ else:
3663
+ _, predicted = torch.max(logits, 1)
3664
+ valid_mask = labels >= 0
3665
+ total += int(valid_mask.sum().item())
3666
+ correct += int((predicted[valid_mask] == labels[valid_mask]).sum().item())
3667
+
3668
+ if log_fn and batch_idx % 10 == 0:
3669
+ log_fn(f"Epoch {epoch+1}/{config['epochs']}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}")
3670
+ except Exception as e:
3671
+ error_msg = f"Error in training batch {batch_idx}: {str(e)}\n{traceback.format_exc()}"
3672
+ if log_fn:
3673
+ log_fn(f"ERROR: {error_msg}")
3674
+ raise
3675
+
3676
+ if len(train_dataset) > 0:
3677
+ avg_loss = total_loss / len(train_dataset)
3678
+ avg_loss_class = total_loss_class / len(train_dataset)
3679
+ else:
3680
+ avg_loss = 0.0
3681
+ avg_loss_class = 0.0
3682
+
3683
+ # Update ConfusionAwareSampler weights from this epoch's scores
3684
+ _confusion_warmup_epoch = int(config["epochs"] * _confusion_warmup_pct)
3685
+ if (use_ovr and isinstance(batch_sampler, ConfusionAwareSampler)
3686
+ and not in_localization_stage
3687
+ and (epoch + 1) > _confusion_warmup_epoch):
3688
+ nonzero = _epoch_confusion_count > 0
3689
+ avg_epoch_confusion = np.where(
3690
+ nonzero,
3691
+ _epoch_confusion / np.maximum(_epoch_confusion_count, 1),
3692
+ batch_sampler._confusion_scores, # keep last known score for unseen clips
3693
+ )
3694
+ avg_epoch_top_rival = np.where(
3695
+ nonzero,
3696
+ _epoch_top_rival,
3697
+ batch_sampler._top_rival,
3698
+ )
3699
+ batch_sampler.update_weights(avg_epoch_confusion, top_rivals=avg_epoch_top_rival)
3700
+ if log_fn and (epoch + 1) % 5 == 0:
3701
+ log_fn("Confusion sampler — hardest clips per class:")
3702
+ for line in batch_sampler.log_top_confused(
3703
+ class_names, getattr(train_dataset, "clips", [])
3704
+ ):
3705
+ log_fn(line)
3706
+
3707
+ train_acc = 100.0 * correct / total if total > 0 else 0.0
3708
+ train_frame_acc = 100.0 * frame_correct / frame_total if frame_total > 0 else 0.0
3709
+
3710
+ if log_fn:
3711
+ if use_frame_loss and not in_localization_stage:
3712
+ log_fn(
3713
+ f"Epoch {epoch+1}/{config['epochs']} - Train Loss: {avg_loss:.4f}, "
3714
+ f"Train Acc (clip): {train_acc:.2f}%, Train Acc (frame): {train_frame_acc:.2f}%"
3715
+ )
3716
+ else:
3717
+ log_fn(f"Epoch {epoch+1}/{config['epochs']} - Train Loss: {avg_loss:.4f}, Train Acc: {train_acc:.2f}%")
3718
+
3719
+ # Record train metrics
3720
+ history["epoch"].append(epoch + 1)
3721
+ history["train_loss"].append(avg_loss)
3722
+ history["train_loss_class"].append(avg_loss_class)
3723
+ history["train_acc"].append(train_acc)
3724
+ history["train_frame_acc"].append(train_frame_acc)
3725
+ history["val_loss"].append(0.0)
3726
+ history["val_acc"].append(0.0)
3727
+ history["val_frame_acc"].append(0.0)
3728
+ history["val_f1"].append(0.0)
3729
+ history["loc_val_iou"].append(0.0)
3730
+ history["loc_val_center_error"].append(1.0 if in_localization_stage else 0.0)
3731
+ history["loc_val_valid_rate"].append(0.0)
3732
+ for key in class_key_map.values():
3733
+ history[key].append(0.0)
3734
+
3735
+ val_acc = 0.0
3736
+ avg_val_loss = 0.0
3737
+ per_attr_f1 = {}
3738
+
3739
+ if val_loader:
3740
+ _apply_ema() # swap to EMA weights for validation
3741
+ model.eval()
3742
+ val_correct = 0
3743
+ val_total = 0
3744
+ val_frame_correct = 0
3745
+ val_frame_total = 0
3746
+ val_loss = 0.0
3747
+ val_targets_all = []
3748
+ val_preds_all = []
3749
+ val_score_chunks_by_class = {i: [] for i in range(len(class_names))}
3750
+ val_target_chunks_by_class = {i: [] for i in range(len(class_names))}
3751
+ val_loc_iou_sum = 0.0
3752
+ val_loc_center_sum = 0.0
3753
+ val_loc_valid_sum = 0.0
3754
+ val_loc_batches = 0
3755
+
3756
+ try:
3757
+ with torch.no_grad():
3758
+ for batch_data in val_loader:
3759
+ frame_labels_batch = None
3760
+ clips_short_val = None
3761
+ bg_mask_batch = None
3762
+ suppress_batch = None
3763
+ if isinstance(batch_data, (list, tuple)) and len(batch_data) >= 7:
3764
+ clips, labels, _, spatial_bboxes_batch, spatial_bbox_valid_batch, sample_indices_batch, frame_labels_batch = batch_data[:7]
3765
+ else:
3766
+ clips, labels, _, spatial_bboxes_batch, spatial_bbox_valid_batch, sample_indices_batch = batch_data[:6]
3767
+ if isinstance(batch_data, (list, tuple)) and len(batch_data) >= 8:
3768
+ _cs_v = batch_data[7]
3769
+ if isinstance(_cs_v, torch.Tensor) and _cs_v.numel() > 0:
3770
+ clips_short_val = _cs_v.to(device)
3771
+ if isinstance(batch_data, (list, tuple)) and len(batch_data) >= 9:
3772
+ _bg_v = batch_data[8]
3773
+ if isinstance(_bg_v, torch.Tensor) and _bg_v.numel() > 0:
3774
+ bg_mask_batch = _bg_v.to(device=device, dtype=torch.bool)
3775
+ clips = clips.to(device)
3776
+ labels = labels.to(device)
3777
+ if frame_labels_batch is not None:
3778
+ frame_labels_batch = frame_labels_batch.to(device)
3779
+ if use_ovr:
3780
+ suppress_batch = _lookup_ovr_suppress_for_batch(sample_indices_batch, val_dataset, device)
3781
+ spatial_bboxes_batch = spatial_bboxes_batch.to(device)
3782
+ spatial_bbox_valid_batch = spatial_bbox_valid_batch.to(device)
3783
+
3784
+ spatial_bboxes_f0 = spatial_bboxes_batch[:, 0, :] if spatial_bboxes_batch.dim() == 3 else spatial_bboxes_batch
3785
+ spatial_bbox_valid_f0 = spatial_bbox_valid_batch[:, 0] if spatial_bbox_valid_batch.dim() == 2 else spatial_bbox_valid_batch
3786
+
3787
+ if in_localization_stage:
3788
+ val_wh = _fixed_wh_for_labels(labels, device=clips.device, dtype=clips.dtype)
3789
+ val_out = model(clips, return_localization=True, localization_box_wh=val_wh)
3790
+ _, loc_pred_bboxes = _split_localization_output(val_out)
3791
+ if loc_pred_bboxes is None:
3792
+ continue
3793
+
3794
+ # Flatten temporal dim for per-frame eval
3795
+ B_v = spatial_bboxes_batch.size(0)
3796
+ T_v = spatial_bboxes_batch.size(1) if spatial_bboxes_batch.dim() == 3 else 1
3797
+ tgt_flat_v = spatial_bboxes_batch.view(B_v * T_v, 4)
3798
+ valid_flat_v = spatial_bbox_valid_batch.view(B_v * T_v)
3799
+ if loc_pred_bboxes.dim() == 2:
3800
+ pred_flat_v = loc_pred_bboxes.unsqueeze(1).expand(-1, T_v, -1).reshape(B_v * T_v, 4)
3801
+ else:
3802
+ pred_flat_v = loc_pred_bboxes.view(B_v * T_v, 4)
3803
+
3804
+ from .model import direct_center_loss
3805
+ loss = direct_center_loss(
3806
+ pred_flat_v,
3807
+ tgt_flat_v,
3808
+ valid_flat_v,
3809
+ )
3810
+ val_loss += loss.item() * clips.size(0)
3811
+ iou_val, center_err, valid_rate = _localization_metrics(
3812
+ pred_flat_v,
3813
+ tgt_flat_v,
3814
+ valid_flat_v,
3815
+ )
3816
+ val_loc_iou_sum += iou_val
3817
+ val_loc_center_sum += center_err
3818
+ val_loc_valid_sum += valid_rate
3819
+ val_loc_batches += 1
3820
+ continue
3821
+
3822
+                                _val_clip_len = config.get("clip_length", 8)
+                                _val_emb_active = getattr(val_dataset, '_emb_cache_mode', False) if val_dataset is not None else False
+                                if _val_emb_active:
+                                    logits = model(
+                                        None,
+                                        backbone_tokens=clips,
+                                        num_frames=_val_clip_len,
+                                        backbone_tokens_short=clips_short_val,
+                                        num_frames_short=_val_clip_len // 2 if clips_short_val is not None else None,
+                                        return_localization=False,
+                                        return_frame_logits=True,
+                                    )
+                                else:
+                                    logits = model(
+                                        clips,
+                                        return_localization=False,
+                                        return_frame_logits=True,
+                                    )
+
+                                # Frame-level validation metrics
+                                if not in_localization_stage and frame_labels_batch is not None:
+                                    _fo_val = getattr(model, "_frame_output", None)
+                                    if _fo_val is not None:
+                                        f_logits_val = _fo_val[0]
+                                        f_logits_val_pooled = _fo_val[3] if len(_fo_val) > 3 else f_logits_val
+                                        pool_n_val = int(_fo_val[4]) if len(_fo_val) > 4 else 1
+                                        frame_embeddings_val = _fo_val[6] if len(_fo_val) > 6 else None
+                                        labels_for_val = _pool_frame_labels(frame_labels_batch, pool_n_val)
+                                        logits_for_val = f_logits_val_pooled if pool_n_val > 1 else f_logits_val
+                                        embeddings_for_val = _pool_frame_embeddings(frame_embeddings_val, pool_n_val)
+                                        bg_mask_for_val = None
+                                        if use_ovr and ovr_background_as_negative and bg_mask_batch is not None:
+                                            bg_mask_for_val = _pool_binary_mask(bg_mask_batch, pool_n_val)
+                                        valid_fl_val = labels_for_val >= 0
+                                        if valid_fl_val.any():
+                                            if use_ovr:
+                                                score_tensor_val = torch.sigmoid(logits_for_val)
+                                            else:
+                                                score_tensor_val = torch.softmax(logits_for_val, dim=-1)
+                                            valid_scores_np = score_tensor_val[valid_fl_val].detach().cpu().numpy()
+                                            valid_targets_np = labels_for_val[valid_fl_val].detach().cpu().numpy()
+                                            for cls_idx in range(len(class_names)):
+                                                val_score_chunks_by_class[cls_idx].append(valid_scores_np[:, cls_idx].astype(np.float32, copy=False))
+                                                val_target_chunks_by_class[cls_idx].append((valid_targets_np == cls_idx).astype(np.uint8, copy=False))
+                                            if use_ovr:
+                                                f_pred_val = torch.argmax(torch.sigmoid(logits_for_val), dim=-1)
+                                            else:
+                                                f_pred_val = torch.argmax(logits_for_val, dim=-1)
+                                            val_frame_total += int(valid_fl_val.sum().item())
+                                            val_frame_correct += int((f_pred_val[valid_fl_val] == labels_for_val[valid_fl_val]).sum().item())
+
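+                                        # Build one-vs-rest targets/weights mirroring the training
+                                        # loss: smoothed positives with per-class positive weights,
+                                        # and zero weight on classes that co-occur with the label.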
3873
+                                        if use_ovr:
+                                            B_fv, T_fv, C_fv = logits_for_val.shape
+                                            fl_ovr_targets_v = torch.full((B_fv, T_fv, C_fv), ovr_label_smoothing, device=device)
+                                            fl_ovr_weight_v = torch.ones(B_fv, T_fv, C_fv, device=device)
+                                            for bi in range(B_fv):
+                                                for ti in range(T_fv):
+                                                    lbl = int(labels_for_val[bi, ti].item())
+                                                    if 0 <= lbl < C_fv:
+                                                        fl_ovr_targets_v[bi, ti, lbl] = 1.0 - ovr_label_smoothing
+                                                        fl_ovr_weight_v[bi, ti, lbl] = ovr_pos_weight[lbl]
+                                                        for cj in cooccur_lookup.get(lbl, ()):
+                                                            if cj != lbl and 0 <= cj < C_fv:
+                                                                fl_ovr_weight_v[bi, ti, cj] = 0.0
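+                                            # Frames flagged in suppress_batch act as hard negatives:
+                                            # every head is masked out except the suppressed class,
+                                            # which gets an explicit 0 target with weight 1.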
3886
+                                            suppress_mask_for_val = None
+                                            if suppress_batch is not None:
+                                                suppress_mask_for_val = (labels_for_val < 0) & (suppress_batch.view(-1, 1) >= 0)
+                                                if suppress_mask_for_val.any():
+                                                    fl_ovr_weight_v = fl_ovr_weight_v.masked_fill(
+                                                        suppress_mask_for_val.unsqueeze(-1), 0.0
+                                                    )
+                                                    hn_bv, hn_tv = torch.nonzero(suppress_mask_for_val, as_tuple=True)
+                                                    hn_cv = suppress_batch[hn_bv]
+                                                    fl_ovr_targets_v[hn_bv, hn_tv, hn_cv] = 0.0
+                                                    fl_ovr_weight_v[hn_bv, hn_tv, hn_cv] = 1.0
+                                            if bg_mask_for_val is not None and bg_mask_for_val.any():
+                                                fl_ovr_targets_v = fl_ovr_targets_v.masked_fill(
+                                                    bg_mask_for_val.unsqueeze(-1), 0.0
+                                                )
+                                            valid_ovr_mask_v = (
+                                                (labels_for_val >= 0) |
+                                                (suppress_mask_for_val if suppress_mask_for_val is not None else torch.zeros_like(labels_for_val, dtype=torch.bool)) |
+                                                (bg_mask_for_val if bg_mask_for_val is not None else torch.zeros_like(labels_for_val, dtype=torch.bool))
+                                            )
+                                            loss = _frame_loss_balanced(
+                                                logits_for_val, labels_for_val,
+                                                use_ovr_local=True, ovr_targets=fl_ovr_targets_v,
+                                                ovr_weight=fl_ovr_weight_v,
+                                                use_bout_balance=use_frame_bout_balance,
+                                                bout_power=frame_bout_balance_power,
+                                                valid_mask_override=valid_ovr_mask_v,
+                                            )
+                                        else:
+                                            loss = _frame_loss_balanced(
+                                                logits_for_val, labels_for_val,
+                                                use_ovr_local=False,
+                                                use_bout_balance=use_frame_bout_balance,
+                                                bout_power=frame_bout_balance_power,
+                                            )
+                                        if hard_pair_mining and hard_pair_loss_weight > 0:
+                                            pair_loss_v = _hard_pair_margin_loss(
+                                                logits_for_val,
+                                                labels_for_val,
+                                                hard_pair_index_pairs,
+                                                hard_pair_margin,
+                                                use_bout_balance=use_frame_bout_balance,
+                                                bout_power=frame_bout_balance_power,
+                                            )
+                                            loss = loss + hard_pair_loss_weight * pair_loss_v
+                                        if use_supcon_loss and supcon_weight > 0:
+                                            sc_loss_v = _supervised_contrastive_loss(
+                                                embeddings_for_val,
+                                                labels_for_val,
+                                                temperature=supcon_temperature,
+                                            )
+                                            loss = loss + supcon_weight * sc_loss_v
+                                        val_loss += loss.item() * clips.size(0)
+
+                                        if valid_fl_val.any():
+                                            val_targets_all.extend(labels_for_val[valid_fl_val].detach().cpu().tolist())
+                                            val_preds_all.extend(f_pred_val[valid_fl_val].detach().cpu().tolist())
+
+                        avg_val_loss = val_loss / len(val_dataset)
+                        val_acc = 100.0 * val_correct / val_total if val_total > 0 else 0.0
+                        val_frame_acc = 100.0 * val_frame_correct / val_frame_total if val_frame_total > 0 else 0.0
+                        val_macro_f1 = 0.0
+                        per_class_f1 = np.zeros(len(class_names), dtype=float)
+                        per_class_support = np.zeros(len(class_names), dtype=int)
+                        conf_matrix = np.zeros((len(class_names), len(class_names)), dtype=int)
+
+                        per_attr_f1 = {}
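+                        # Calibrate per-class "ignore" score thresholds from the validation
+                        # score/target chunks pooled during the loop above.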
3953
+                        val_ignore_thresholds = _calibrate_ignore_thresholds_from_validation(
+                            val_score_chunks_by_class,
+                            val_target_chunks_by_class,
+                            class_names,
+                        )
+
+                        # val_targets_all / val_preds_all are from frame-level validation when frame_labels_batch was used.
+                        if val_targets_all:
+                            per_class_f1 = f1_score(
+                                val_targets_all,
+                                val_preds_all,
+                                labels=list(range(len(class_names))),
+                                average=None,
+                                zero_division=0
+                            ) * 100.0
+                            if _f1_include_indices and len(_f1_include_indices) < len(class_names):
+                                val_macro_f1 = float(per_class_f1[_f1_include_indices].mean())
+                            else:
+                                val_macro_f1 = f1_score(
+                                    val_targets_all,
+                                    val_preds_all,
+                                    average='macro',
+                                    zero_division=0
+                                ) * 100.0
+                            per_class_support = np.bincount(
+                                np.asarray(val_targets_all, dtype=np.int64),
+                                minlength=len(class_names),
+                            ).astype(int)
+                            for t, p in zip(val_targets_all, val_preds_all):
+                                if 0 <= int(t) < len(class_names) and 0 <= int(p) < len(class_names):
+                                    conf_matrix[int(t), int(p)] += 1
+
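+                        # Localization metrics are per-batch means; the defaults (IoU 0,
+                        # center error 1, valid rate 0) apply when no batches were scored.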
3985
+                        loc_val_iou = (val_loc_iou_sum / val_loc_batches) if val_loc_batches > 0 else 0.0
+                        loc_val_center = (val_loc_center_sum / val_loc_batches) if val_loc_batches > 0 else 1.0
+                        loc_val_valid = (val_loc_valid_sum / val_loc_batches) if val_loc_batches > 0 else 0.0
+
+                        # Update history: when validation is frame-level, val_acc and val_f1 are per-frame.
+                        history["val_loss"][-1] = avg_val_loss
+                        if val_frame_total > 0:
+                            history["val_acc"][-1] = val_frame_acc
+                        else:
+                            history["val_acc"][-1] = val_acc
+                        history["val_frame_acc"][-1] = val_frame_acc
+                        history["val_f1"][-1] = val_macro_f1
+                        history["loc_val_iou"][-1] = loc_val_iou
+                        history["loc_val_center_error"][-1] = loc_val_center
+                        history["loc_val_valid_rate"][-1] = loc_val_valid
+                        for idx, key in class_key_map.items():
+                            if idx < len(per_class_f1):
+                                history[key][-1] = per_class_f1[idx]
+
+                        if log_fn:
+                            if in_localization_stage:
+                                log_fn(
+                                    f"Epoch {epoch+1}/{config['epochs']} - Val Loc Loss: {avg_val_loss:.4f}, "
+                                    f"IoU: {loc_val_iou:.3f}, CErr: {loc_val_center:.3f}, VRate: {loc_val_valid:.3f}"
+                                )
+                            else:
+                                log_fn(
+                                    f"Epoch {epoch+1}/{config['epochs']} - Val Loss: {avg_val_loss:.4f}, "
+                                    f"Val Acc (frame): {val_frame_acc:.2f}%, "
+                                    f"Val Macro F1: {val_macro_f1:.2f}%"
+                                )
+                            if val_targets_all:
+                                metric_scope = "frame-labeled" if (use_frame_loss and not in_localization_stage) else "clip"
+                                log_fn(f"Val class diagnostics ({metric_scope}):")
+                                for ci, cname in enumerate(class_names):
+                                    excl_tag = " [excluded from F1]" if cname in _f1_exclude_names else ""
+                                    log_fn(
+                                        f" - {cname}: support={int(per_class_support[ci])}, "
+                                        f"F1={float(per_class_f1[ci]):.2f}%{excl_tag}"
+                                    )
+                                if len(class_names) <= 12:
+                                    log_fn("Val confusion matrix rows=true, cols=pred:")
+                                    for ci, cname in enumerate(class_names):
+                                        row_vals = " ".join(str(int(v)) for v in conf_matrix[ci].tolist())
+                                        log_fn(f" {ci}:{cname} | {row_vals}")
+                            if val_ignore_thresholds:
+                                log_fn(
+                                    "Validation-calibrated ignore thresholds: "
+                                    + ", ".join(
+                                        f"{cls}={float(tau):.2f}"
+                                        for cls, tau in sorted(
+                                            val_ignore_thresholds.get("per_class_thresholds", {}).items(),
+                                            key=lambda item: item[0],
+                                        )
+                                    )
+                                )
+
+                        if head_metadata:
+                            head_metadata["training_config"]["validation_calibrated_ignore_thresholds"] = _json_safe(val_ignore_thresholds)
+
+                        metric_improved = False if in_localization_stage else (
+                            val_macro_f1 > best_val_f1 + 1e-6
+                        )
+
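+                        # Stage hand-off: leave the localization stage once the IoU /
+                        # center-error / valid-rate gate holds for loc_gate_patience
+                        # consecutive epochs, the manual switch epoch arrives, or the
+                        # localization epoch cap is hit.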
4049
+                        if in_localization_stage:
+                            gate_pass = (
+                                (loc_val_iou >= loc_gate_iou_threshold)
+                                and (loc_val_center <= loc_gate_center_error)
+                                and (loc_val_valid >= loc_gate_valid_rate)
+                            )
+                            localization_gate_streak = localization_gate_streak + 1 if gate_pass else 0
+                            reached_epoch_cap = (epoch + 1) >= max(1, loc_max_stage_epochs)
+                            reached_manual_switch = use_manual_loc_switch and ((epoch + 1) >= max(1, manual_loc_switch_epoch))
+                            if gate_pass and log_fn:
+                                log_fn(
+                                    f"Localization gate check passed ({localization_gate_streak}/{loc_gate_patience})"
+                                )
+                            if reached_manual_switch or (localization_gate_streak >= max(1, loc_gate_patience)) or reached_epoch_cap:
+                                in_localization_stage = False
+                                classification_stage_start_epoch = epoch + 1
+                                if log_fn:
+                                    if reached_manual_switch:
+                                        reason = f"manual switch epoch reached ({manual_loc_switch_epoch})"
+                                    elif localization_gate_streak >= max(1, loc_gate_patience):
+                                        reason = "metrics gate reached"
+                                    else:
+                                        reason = "max localization epochs reached"
+                                    log_fn(
+                                        f"Switching to classification stage at epoch {epoch+1} ({reason}). "
+                                        "Classifier will train on localized crops."
+                                    )
+                                # Reset LR schedule, EMA, and optimizer state for classification
+                                cls_remaining = total_epochs - (epoch + 1)
+                                cls_warmup = 0
+                                if use_scheduler:
+                                    eta_min = 0.2 * classification_lr
+                                    for pg in optimizer.param_groups:
+                                        pg['lr'] = classification_lr
+                                    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+                                        optimizer, T_max=max(1, cls_remaining - cls_warmup), eta_min=eta_min
+                                    )
+                                    warmup_scheduler = None
+                                    warmup_epochs = cls_warmup
+                                    if log_fn:
+                                        log_fn(f"Reset LR schedule: CosineAnnealingLR (single decay, {cls_remaining} epochs)")
+                                optimizer.state.clear()
+                                if log_fn:
+                                    log_fn("Reset optimizer momentum/adaptive state for classification")
+                                ema_active = False
+                                ema_state.clear()
+
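+                        # Best-model tracking keys off validation macro F1; each improvement
+                        # writes a timestamped checkpoint folder and refreshes the rolling
+                        # *_best.pt file next to the configured output path.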
4096
+                        if metric_improved:
+                            best_val_f1 = val_macro_f1
+                            best_val_frame_acc = val_frame_acc
+                            if head_metadata:
+                                head_metadata["training_config"]["best_val_f1"] = best_val_f1
+                            if config.get("save_best", True):
+                                # Create folder for this best epoch
+                                output_dir = os.path.dirname(config["output_path"])
+                                basename = os.path.splitext(os.path.basename(config["output_path"]))[0]
+                                from datetime import datetime
+                                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+
+                                best_folder = os.path.join(
+                                    output_dir,
+                                    f"{basename}_checkpoints",
+                                    f"epoch_{epoch+1}_f1_{val_macro_f1:.1f}_frameacc_{val_frame_acc:.1f}_{timestamp}"
+                                )
+                                os.makedirs(best_folder, exist_ok=True)
+
+                                # Save Model
+                                best_path = os.path.join(best_folder, "model.pt")
+                                if head_metadata:
+                                    model.save_head(best_path, metadata=head_metadata)
+                                else:
+                                    model.save_head(best_path)
+
+                                # Also update main best file
+                                best_main_path = config["output_path"].replace(".pt", "_best.pt")
+                                if head_metadata:
+                                    model.save_head(best_main_path, metadata=head_metadata)
+                                else:
+                                    model.save_head(best_main_path)
+
+                                # Save Logs & Plots
+                                import pandas as pd
+                                import matplotlib
+                                matplotlib.use('Agg')
+                                import matplotlib.pyplot as plt
+
+                                # Construct temp history including current epoch
+                                curr_hist = {k: v.copy() for k, v in history.items()}
+                                # Train metrics and epoch are ALREADY in history (updated before validation)
+                                # And val metrics were just updated in-place at index -1
+
+                                pd.DataFrame(curr_hist).to_csv(os.path.join(best_folder, "history.csv"), index=False)
+
+                                plt.style.use('ggplot')
+                                fig, axes = plt.subplots(4, 1, figsize=(10, 18))
+                                ax1, ax2, ax3, ax4 = axes
+                                epochs_hist = curr_hist['epoch']
+
+                                ax1.plot(epochs_hist, curr_hist['train_acc'], label='Train Acc', marker='o')
+                                ax1.plot(epochs_hist, curr_hist['val_acc'], label='Val Acc (frame)', marker='s')
+                                ax1.set_title(f'Accuracy - Epoch {epoch+1}')
+                                ax1.set_ylabel('Accuracy (%)')
+                                ax1.legend()
+                                ax1.grid(True)
+
+                                ax2.plot(epochs_hist, curr_hist['train_loss'], label='Train Loss', marker='o')
+                                ax2.plot(epochs_hist, curr_hist['val_loss'], label='Val Loss', marker='s')
+                                ax2.set_ylabel('Loss')
+                                ax2.legend()
+                                ax2.grid(True)
+
+                                ax3.plot(epochs_hist, curr_hist['val_f1'], label='Val Macro F1 (frame)', linewidth=2, color='tab:purple')
+                                for idx in range(len(class_names)):
+                                    class_key = class_key_map.get(idx)
+                                    if class_key in curr_hist:
+                                        ax3.plot(
+                                            epochs_hist,
+                                            curr_hist[class_key],
+                                            label=f"{class_names[idx]}",
+                                            linestyle='--',
+                                            alpha=0.6
+                                        )
+                                ax3.set_ylabel('F1 (%)')
+                                ax3.legend(ncol=2, fontsize=8)
+                                ax3.grid(True)
+
+                                per_class_keys_ordered = [class_key_map[idx] for idx in range(len(class_names))]
+                                if per_class_keys_ordered:
+                                    per_class_matrix = np.array([
+                                        curr_hist[key] for key in per_class_keys_ordered
+                                    ])
+                                else:
+                                    per_class_matrix = np.zeros((0, len(epochs_hist)))
+
+                                im = ax4.imshow(
+                                    per_class_matrix,
+                                    aspect='auto',
+                                    cmap='magma',
+                                    vmin=0,
+                                    vmax=100
+                                )
+                                ax4.set_yticks(range(len(class_names)))
+                                ax4.set_yticklabels(class_names)
+                                ax4.set_xlabel('Epoch')
+                                ax4.set_ylabel('Class')
+                                ax4.set_title('Validation F1 Heatmap (%)')
+                                if epochs_hist:
+                                    max_ticks = min(len(epochs_hist), 12)
+                                    tick_positions = np.linspace(0, len(epochs_hist) - 1, max_ticks, dtype=int)
+                                    ax4.set_xticks(tick_positions)
+                                    ax4.set_xticklabels([str(epochs_hist[i]) for i in tick_positions])
+                                cbar = fig.colorbar(im, ax=ax4, orientation='vertical', pad=0.01)
+                                cbar.set_label('F1 (%)')
+
+                                plt.tight_layout()
+                                plt.savefig(os.path.join(best_folder, "training_plot.pdf"))
+                                plt.close()
+
+                                if log_fn:
+                                    log_fn(f"Saved best model checkpoint to {best_folder}")
+
+                    except Exception as e:
+                        error_msg = f"Error in validation: {str(e)}\n{traceback.format_exc()}"
+                        if log_fn:
+                            log_fn(f"ERROR: {error_msg}")
+                        raise
+
4216
+                # Record val metrics (use train metrics if no validation)
+                if not val_loader:
+                    history["val_loss"][-1] = avg_loss
+                    history["val_acc"][-1] = train_acc
+                    history["val_frame_acc"][-1] = train_frame_acc
+                    history["val_f1"][-1] = 0.0
+
+                    # Print train confusion matrix every 5 epochs when there is no val set
+                    if (log_fn and train_targets_all and not in_localization_stage
+                            and ((epoch + 1) % 5 == 0 or epoch == 0)):
+                        train_macro_f1 = f1_score(
+                            train_targets_all, train_preds_all,
+                            average='macro', zero_division=0,
+                        ) * 100.0
+                        per_class_f1_tr = f1_score(
+                            train_targets_all, train_preds_all,
+                            labels=list(range(len(class_names))),
+                            average=None, zero_division=0,
+                        ) * 100.0
+                        per_class_support_tr = np.bincount(
+                            np.asarray(train_targets_all, dtype=np.int64),
+                            minlength=len(class_names),
+                        ).astype(int)
+                        conf_matrix_tr = np.zeros((len(class_names), len(class_names)), dtype=np.int64)
+                        for t, p in zip(train_targets_all, train_preds_all):
+                            if 0 <= int(t) < len(class_names) and 0 <= int(p) < len(class_names):
+                                conf_matrix_tr[int(t), int(p)] += 1
+                        history["val_f1"][-1] = train_macro_f1
+                        log_fn(f"Train Macro F1: {train_macro_f1:.2f}%")
+                        log_fn("Train class diagnostics (frame-labeled):")
+                        for ci, cname in enumerate(class_names):
+                            log_fn(
+                                f" - {cname}: support={int(per_class_support_tr[ci])}, "
+                                f"F1={float(per_class_f1_tr[ci]):.2f}%"
+                            )
+                        if len(class_names) <= 12:
+                            log_fn("Train confusion matrix rows=true, cols=pred:")
+                            for ci, cname in enumerate(class_names):
+                                row_vals = " ".join(str(int(v)) for v in conf_matrix_tr[ci].tolist())
+                                log_fn(f" {ci}:{cname} | {row_vals}")
+                    reached_manual_switch = use_manual_loc_switch and ((epoch + 1) >= max(1, manual_loc_switch_epoch))
+                    reached_epoch_cap = (epoch + 1) >= max(1, loc_max_stage_epochs)
+                    if in_localization_stage and (reached_manual_switch or reached_epoch_cap):
+                        in_localization_stage = False
+                        classification_stage_start_epoch = epoch + 1
+                        if log_fn:
+                            if reached_manual_switch:
+                                reason = f"manual switch epoch reached ({manual_loc_switch_epoch})"
+                            else:
+                                reason = "max localization epochs reached"
+                            log_fn(
+                                f"Switching to classification stage at epoch {epoch+1} ({reason}, no validation set)."
+                            )
+                        # Reset LR schedule, EMA, and optimizer state for classification
+                        cls_remaining = total_epochs - (epoch + 1)
+                        cls_warmup = 0
+                        if use_scheduler:
+                            eta_min = 0.2 * classification_lr
+                            for pg in optimizer.param_groups:
+                                pg['lr'] = classification_lr
+                            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+                                optimizer, T_max=max(1, cls_remaining - cls_warmup), eta_min=eta_min
+                            )
+                            warmup_scheduler = None
+                            warmup_epochs = cls_warmup
+                            if log_fn:
+                                log_fn(f"Reset LR schedule: CosineAnnealingLR (single decay, {cls_remaining} epochs)")
+                        optimizer.state.clear()
+                        if log_fn:
+                            log_fn("Reset optimizer momentum/adaptive state for classification")
+                        ema_active = False
+                        ema_state.clear()
+
4289
+                # Save checkpoints only after classification stage starts.
+                # During localization stage we skip periodic model saves to
+                # avoid generating unusable intermediate checkpoints.
+                _apply_ema()  # swap to EMA weights for saving
+                if config.get("save_best", True) and (not in_localization_stage):
+                    output_dir = os.path.dirname(config["output_path"])
+                    basename = os.path.splitext(os.path.basename(config["output_path"]))[0]
+                    from datetime import datetime
+                    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+
+                    epoch_folder = os.path.join(
+                        output_dir,
+                        f"{basename}_checkpoints",
+                        f"epoch_{epoch+1}_trainloss_{avg_loss:.4f}_acc_{train_acc:.1f}_{timestamp}"
+                    )
+                    os.makedirs(epoch_folder, exist_ok=True)
+
+                    epoch_path = os.path.join(epoch_folder, "model.pt")
+                    if head_metadata:
+                        model.save_head(epoch_path, metadata=head_metadata)
+                    else:
+                        model.save_head(epoch_path)
+
+                    if log_fn:
+                        log_fn(f"Saved epoch checkpoint to {epoch_folder}")
+                elif config.get("save_best", True) and in_localization_stage and log_fn:
+                    log_fn("Checkpoint save skipped (still in localization phase; classification not started yet).")
+
+                # Restore training weights after EMA-based validation/saving
+                _apply_ema()
+
+                # Save crop-progress visualization every 2nd epoch.
+                _save_epoch_crop_progress(epoch, epoch_phase)
+
+                # Incremental history CSV every 2 epochs for offline plotting
+                if (epoch + 1) % 2 == 0:
+                    try:
+                        import pandas as pd
+                        inc_csv_dir = os.path.join(output_dir_base, f"{basename}_training_history")
+                        os.makedirs(inc_csv_dir, exist_ok=True)
+                        inc_csv_path = os.path.join(inc_csv_dir, "history.csv")
+                        pd.DataFrame(history).to_csv(inc_csv_path, index=False)
+                    except Exception as e:
+                        logger.debug("Could not save incremental history CSV: %s", e)
+
+                # Step LR scheduler once per epoch.
+                sched_epoch_base = classification_stage_start_epoch if classification_stage_start_epoch is not None else 0
+                sched_epoch = epoch - sched_epoch_base
+                if use_scheduler and sched_epoch >= 0:
+                    if warmup_scheduler is not None and warmup_epochs > 0 and sched_epoch < warmup_epochs:
+                        warmup_scheduler.step()
+                        if use_ema and sched_epoch == warmup_epochs - 1:
+                            ema_active = True
+                            _init_ema()
+                            if log_fn:
+                                log_fn(f"EMA activated after {warmup_epochs}-epoch warmup")
+                    elif scheduler is not None:
+                        scheduler.step()
+                        ema_start_epoch = max(warmup_epochs, 3)
+                        if use_ema and (not ema_active) and sched_epoch >= ema_start_epoch - 1:
+                            ema_active = True
+                            _init_ema()
+                            if log_fn:
+                                log_fn(f"EMA activated at epoch {epoch + 1}")
+                current_lr = max(pg['lr'] for pg in optimizer.param_groups)
+                if log_fn and (epoch + 1) % 5 == 0:
+                    log_fn(f"Current Learning Rate: {current_lr:.8f}")
+
4357
+                if metrics_callback:
+                    current_metrics = {
+                        "epoch": epoch + 1,
+                        "train_loss": history["train_loss"][-1],
+                        "train_loss_class": history["train_loss_class"][-1],
+                        "train_acc": history["train_acc"][-1],
+                        "train_frame_acc": history["train_frame_acc"][-1],
+                        "val_loss": history["val_loss"][-1],
+                        "val_acc": history["val_acc"][-1],
+                        "val_frame_acc": history["val_frame_acc"][-1],
+                        "val_f1": history["val_f1"][-1],
+                        "training_phase": epoch_phase,
+                        "loc_val_iou": history["loc_val_iou"][-1],
+                        "loc_val_center_error": history["loc_val_center_error"][-1],
+                        "loc_val_valid_rate": history["loc_val_valid_rate"][-1],
+                        "per_class_f1": {
+                            class_names[idx]: history[class_key_map[idx]][-1]
+                            for idx in range(len(class_names))
+                            if class_key_map.get(idx) in history
+                            and class_names[idx] not in _f1_exclude_names
+                        },
+                        "per_attr_f1": per_attr_f1,
+                        "crop_progress_dir": crop_progress_dir,
+                    }
+                    metrics_callback(current_metrics)
+
+            except Exception as e:
+                error_msg = f"Error in epoch {epoch+1}: {str(e)}\n{traceback.format_exc()}"
+                if log_fn:
+                    log_fn(f"ERROR: {error_msg}")
+                raise
+
+        if log_fn:
+            log_fn("Training complete!")
+
+        # --- Augmentation ablation evaluation ---
+        if (config.get("use_augmentation", False)
+                and not (stop_callback and stop_callback())):
+            try:
+                # Load best model for evaluation
+                best_main_path = config["output_path"].replace(".pt", "_best.pt")
+                if os.path.exists(best_main_path):
+                    model.load_head(best_main_path)
+                _run_augmentation_ablation_eval(
+                    model, train_dataset, config, device, log_fn=log_fn
+                )
+            except Exception as abl_err:
+                if log_fn:
+                    log_fn(f"Augmentation ablation eval failed (non-fatal): {abl_err}")
+
4407
+        # --- Per-head temperature calibration (OvR only) ---
+        if use_ovr and val_loader is not None and not (stop_callback and stop_callback()):
+            try:
+                # Load best checkpoint for calibration
+                best_main_path = config["output_path"].replace(".pt", "_best.pt")
+                if os.path.exists(best_main_path):
+                    model.load_head(best_main_path)
+                model.eval()
+                all_logits = []
+                all_labels = []
+                _val_emb_mode = getattr(val_dataset, '_emb_cache_mode', False)
+                _cal_clip_len = config.get("clip_length", 8)
+                with torch.no_grad():
+                    for batch in val_loader:
+                        # Batch is a tuple: (clips, labels, spatial_mask, bboxes, bbox_valid,
+                        # indices, frame_labels), plus an optional 8th element with short
+                        # backbone tokens when the embedding cache is active.
+                        if not isinstance(batch, (list, tuple)) or len(batch) < 7:
+                            continue
+                        clips = batch[0].to(device)
+                        labels = batch[1].to(device)
+                        frame_labels_cal = batch[6].to(device)
+                        _cs_cal = batch[7] if len(batch) > 7 and isinstance(batch[7], torch.Tensor) and batch[7].numel() > 0 else None
+                        if _cs_cal is not None:
+                            _cs_cal = _cs_cal.to(device)
+
+                        if _val_emb_mode:
+                            out = model(None, backbone_tokens=clips,
+                                        num_frames=_cal_clip_len,
+                                        backbone_tokens_short=_cs_cal,
+                                        num_frames_short=_cal_clip_len // 2 if _cs_cal is not None else None,
+                                        return_frame_logits=True)
+                        else:
+                            out = model(clips, return_frame_logits=True)
+
+                        fo = getattr(model, '_frame_output', None)
+                        if fo is not None:
+                            f_logits = fo[0]  # [B, T, C]
+                            all_logits.append(f_logits.cpu())
+                            all_labels.append(frame_labels_cal.cpu())
+
+                if all_logits:
+                    cat_logits = torch.cat(all_logits, dim=0)  # [N, T, C]
+                    cat_labels = torch.cat(all_labels, dim=0)  # [N, T]
+                    B_cal, T_cal, C_cal = cat_logits.shape
+                    flat_logits = cat_logits.reshape(-1, C_cal)  # [N*T, C]
+                    flat_labels = cat_labels.reshape(-1)  # [N*T]
+                    valid_mask = flat_labels >= 0
+
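+                    # Require a reasonable pool of labeled frames overall, and at least
+                    # five positives and five negatives per class, before fitting a
+                    # per-head sigmoid temperature.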
4454
+                    if valid_mask.sum() > 50:
+                        temperatures = torch.ones(C_cal)
+                        for ci in range(C_cal):
+                            ci_logits = flat_logits[valid_mask, ci]
+                            ci_targets = (flat_labels[valid_mask] == ci).float()
+                            if ci_targets.sum() < 5 or (1 - ci_targets).sum() < 5:
+                                continue
+                            # Grid search for temperature that maximizes F1
+                            best_f1 = -1.0
+                            best_t = 1.0
+                            for t_cand in [0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 2.0, 2.5, 3.0]:
+                                probs = torch.sigmoid(ci_logits / t_cand)
+                                preds = (probs >= 0.5).float()
+                                tp = (preds * ci_targets).sum()
+                                fp = (preds * (1 - ci_targets)).sum()
+                                fn = ((1 - preds) * ci_targets).sum()
+                                prec = tp / (tp + fp + 1e-8)
+                                rec = tp / (tp + fn + 1e-8)
+                                f1 = 2 * prec * rec / (prec + rec + 1e-8)
+                                if f1.item() > best_f1:
+                                    best_f1 = f1.item()
+                                    best_t = t_cand
+                            temperatures[ci] = best_t
+
+                        temp_dict = {
+                            class_names[ci]: round(float(temperatures[ci]), 3)
+                            for ci in range(C_cal)
+                        }
+                        if head_metadata:
+                            head_metadata["ovr_temperatures"] = temp_dict
+                        if log_fn:
+                            t_str = ", ".join(f"{k}={v}" for k, v in temp_dict.items())
+                            log_fn(f"OvR per-head calibrated temperatures: {t_str}")
+                    elif log_fn:
+                        log_fn("Skipping OvR temperature calibration: too few valid validation frames.")
+                elif log_fn:
+                    log_fn("Skipping OvR temperature calibration: no frame logits from validation.")
+            except Exception as cal_err:
+                if log_fn:
+                    log_fn(f"OvR temperature calibration failed (non-fatal): {cal_err}")
+
+        if head_metadata:
+            head_metadata["training_config"]["best_val_f1"] = best_val_f1 if best_val_f1 >= 0 else None
+            head_metadata["training_config"]["best_val_frame_acc"] = best_val_frame_acc
+            # Legacy field retained for compatibility with existing readers.
+            head_metadata["training_config"]["best_val_acc"] = best_val_frame_acc
+            per_class_final = {}
+            for idx in range(len(class_names)):
+                class_key = class_key_map.get(idx)
+                if class_key in history and history[class_key]:
+                    per_class_final[class_label_map[idx]] = history[class_key][-1]
+            head_metadata["training_config"]["final_epoch_val_f1_per_class"] = per_class_final
+
+        # Apply EMA weights for final model save (if EMA was activated)
+        if ema_active:
+            _apply_ema()
+
4511
+        try:
+            if head_metadata:
+                model.save_head(config["output_path"], metadata=head_metadata)
+            else:
+                model.save_head(config["output_path"])
+            if log_fn:
+                tag = " (EMA-averaged)" if ema_active else ""
+                log_fn(f"Saved final model{tag} to {config['output_path']}")
+
+            # --- Save Training Logs and Plots ---
+            import pandas as pd
+            from datetime import datetime
+            import matplotlib
+            matplotlib.use('Agg')
+            import matplotlib.pyplot as plt
+            import json
+
+            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+            output_dir = os.path.dirname(config["output_path"])
+            output_basename = os.path.splitext(os.path.basename(config["output_path"]))[0]
+
+            # Save training config snapshot for inference resolution fallback
+            training_config_path = os.path.join(output_dir, f"{output_basename}_training_config.json")
+            training_snapshot = {
+                "classes": class_names,
+                "attributes": getattr(train_dataset, "attributes", []),
+                "clip_length": clip_length_value,
+                "resolution": resolution_value,
+                "backbone_model": config.get("backbone_model", "videoprism_public_v1_base"),
+                "training_config": head_metadata.get("training_config", {}) if head_metadata else {},
+            }
+            with open(training_config_path, "w", encoding="utf-8") as cfg_file:
+                json.dump(training_snapshot, cfg_file, indent=2)
+            if log_fn:
+                log_fn(f"Saved training config to {training_config_path}")
+
+            # Save CSV log
+            csv_path = os.path.join(output_dir, f"{output_basename}_training_log_{timestamp}.csv")
+            df = pd.DataFrame(history)
+            df.to_csv(csv_path, index=False)
+            if log_fn:
+                log_fn(f"Saved training log to {csv_path}")
+
+            # Generate Plots
+            plt.style.use('ggplot')
+            fig, axes = plt.subplots(4, 1, figsize=(10, 18))
+            ax1, ax2, ax3, ax4 = axes
+            epochs_hist = history['epoch']
+
+            ax1.plot(epochs_hist, history['train_acc'], label='Train Accuracy', marker='o')
+            ax1.plot(epochs_hist, history['val_acc'], label='Val Acc (frame)', marker='s')
+            ax1.set_title(f'Training Accuracy - {output_basename}')
+            ax1.set_ylabel('Accuracy (%)')
+            ax1.legend()
+            ax1.grid(True)
+
+            ax2.plot(epochs_hist, history['train_loss'], label='Train Loss', marker='o')
+            ax2.plot(epochs_hist, history['val_loss'], label='Val Loss', marker='s')
+            ax2.set_title(f'Training Loss - {output_basename}')
+            ax2.set_ylabel('Loss')
+            ax2.legend()
+            ax2.grid(True)
+
+            ax3.plot(epochs_hist, history['val_f1'], label='Val Macro F1 (frame)', linewidth=2, color='tab:purple')
+            for idx in range(len(class_names)):
+                class_key = class_key_map.get(idx)
+                if class_key in history:
+                    ax3.plot(
+                        epochs_hist,
+                        history[class_key],
+                        label=f"{class_names[idx]}",
+                        linestyle='--',
+                        alpha=0.6
+                    )
+            ax3.set_ylabel('F1 (%)')
+            ax3.legend(ncol=2, fontsize=8)
+            ax3.grid(True)
+
+            per_class_keys_ordered = [class_key_map[idx] for idx in range(len(class_names))]
+            if per_class_keys_ordered:
+                per_class_matrix = np.array([
+                    history[key] for key in per_class_keys_ordered
+                ])
+            else:
+                per_class_matrix = np.zeros((0, len(epochs_hist)))
+
+            im = ax4.imshow(
+                per_class_matrix,
+                aspect='auto',
+                cmap='magma',
+                vmin=0,
+                vmax=100
+            )
+            ax4.set_yticks(range(len(class_names)))
+            ax4.set_yticklabels(class_names)
+            ax4.set_xlabel('Epoch')
+            ax4.set_ylabel('Class')
+            ax4.set_title('Validation F1 Heatmap (%)')
+            if epochs_hist:
+                max_ticks = min(len(epochs_hist), 12)
+                tick_positions = np.linspace(0, len(epochs_hist) - 1, max_ticks, dtype=int)
+                ax4.set_xticks(tick_positions)
+                ax4.set_xticklabels([str(epochs_hist[i]) for i in tick_positions])
+            cbar = fig.colorbar(im, ax=ax4, orientation='vertical', pad=0.01)
+            cbar.set_label('F1 (%)')
+
+            # Save Plot as PDF
+            plot_path = os.path.join(output_dir, f"{output_basename}_training_plot_{timestamp}.pdf")
+            plt.tight_layout()
+            plt.savefig(plot_path)
+            plt.close()
+
+            if log_fn:
+                log_fn(f"Saved training plot to {plot_path}")
+
+        except Exception as e:
+            error_msg = f"Error saving model/logs: {str(e)}\n{traceback.format_exc()}"
+            if log_fn:
+                log_fn(f"ERROR: {error_msg}")
+
4631
+        final_train_acc = history["train_acc"][-1] if history["train_acc"] else 0.0
+        best_val_f1_out = best_val_f1 if best_val_f1 >= 0 else 0.0
+
+        final_per_class_f1 = {}
+        for idx in range(len(class_names)):
+            if class_names[idx] in _f1_exclude_names:
+                continue
+            class_key = class_key_map.get(idx)
+            if class_key in history and history[class_key]:
+                final_per_class_f1[class_label_map[idx]] = history[class_key][-1]
+
+        return {
+            "best_val_acc": best_val_frame_acc,
+            "best_val_frame_acc": best_val_frame_acc,
+            "best_val_f1": best_val_f1_out,
+            "final_train_acc": final_train_acc,
+            "per_class_f1": final_per_class_f1
+        }
+
+    except Exception as e:
+        error_msg = f"Training failed: {str(e)}\n{traceback.format_exc()}"
+        if log_fn:
+            log_fn(f"FATAL ERROR: {error_msg}")
+        raise RuntimeError(error_msg) from e
+
+    finally:
+        # Clean up embedding caches: they are only valid for this training run
+        # (specific backbone weights, augmentation settings, dataset split).
+        import shutil
+        for _cache_path in [train_emb_cache, val_emb_cache_dir]:
+            if _cache_path and os.path.isdir(_cache_path):
+                try:
+                    shutil.rmtree(_cache_path)
+                    if log_fn:
+                        log_fn(f"Cleaned up embedding cache: {_cache_path}")
+                except Exception as e:
+                    logger.debug("Could not remove embedding cache %s: %s", _cache_path, e)