singlebehaviorlab 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sam2/__init__.py +11 -0
- sam2/automatic_mask_generator.py +454 -0
- sam2/benchmark.py +92 -0
- sam2/build_sam.py +174 -0
- sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
- sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
- sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
- sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
- sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
- sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
- sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
- sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
- sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
- sam2/modeling/__init__.py +5 -0
- sam2/modeling/backbones/__init__.py +5 -0
- sam2/modeling/backbones/hieradet.py +317 -0
- sam2/modeling/backbones/image_encoder.py +134 -0
- sam2/modeling/backbones/utils.py +93 -0
- sam2/modeling/memory_attention.py +169 -0
- sam2/modeling/memory_encoder.py +181 -0
- sam2/modeling/position_encoding.py +239 -0
- sam2/modeling/sam/__init__.py +5 -0
- sam2/modeling/sam/mask_decoder.py +295 -0
- sam2/modeling/sam/prompt_encoder.py +202 -0
- sam2/modeling/sam/transformer.py +311 -0
- sam2/modeling/sam2_base.py +913 -0
- sam2/modeling/sam2_utils.py +323 -0
- sam2/sam2_hiera_b+.yaml +113 -0
- sam2/sam2_hiera_l.yaml +117 -0
- sam2/sam2_hiera_s.yaml +116 -0
- sam2/sam2_hiera_t.yaml +118 -0
- sam2/sam2_image_predictor.py +466 -0
- sam2/sam2_video_predictor.py +1388 -0
- sam2/sam2_video_predictor_legacy.py +1172 -0
- sam2/utils/__init__.py +5 -0
- sam2/utils/amg.py +348 -0
- sam2/utils/misc.py +349 -0
- sam2/utils/transforms.py +118 -0
- singlebehaviorlab/__init__.py +4 -0
- singlebehaviorlab/__main__.py +130 -0
- singlebehaviorlab/_paths.py +100 -0
- singlebehaviorlab/backend/__init__.py +2 -0
- singlebehaviorlab/backend/augmentations.py +320 -0
- singlebehaviorlab/backend/data_store.py +420 -0
- singlebehaviorlab/backend/model.py +1290 -0
- singlebehaviorlab/backend/train.py +4667 -0
- singlebehaviorlab/backend/uncertainty.py +578 -0
- singlebehaviorlab/backend/video_processor.py +688 -0
- singlebehaviorlab/backend/video_utils.py +139 -0
- singlebehaviorlab/data/config/config.yaml +85 -0
- singlebehaviorlab/data/training_profiles.json +334 -0
- singlebehaviorlab/gui/__init__.py +4 -0
- singlebehaviorlab/gui/analysis_widget.py +2291 -0
- singlebehaviorlab/gui/attention_export.py +311 -0
- singlebehaviorlab/gui/clip_extraction_widget.py +481 -0
- singlebehaviorlab/gui/clustering_widget.py +3187 -0
- singlebehaviorlab/gui/inference_popups.py +1138 -0
- singlebehaviorlab/gui/inference_widget.py +4550 -0
- singlebehaviorlab/gui/inference_worker.py +651 -0
- singlebehaviorlab/gui/labeling_widget.py +2324 -0
- singlebehaviorlab/gui/main_window.py +754 -0
- singlebehaviorlab/gui/metadata_management_widget.py +1119 -0
- singlebehaviorlab/gui/motion_tracking.py +764 -0
- singlebehaviorlab/gui/overlay_export.py +1234 -0
- singlebehaviorlab/gui/plot_integration.py +729 -0
- singlebehaviorlab/gui/qt_helpers.py +29 -0
- singlebehaviorlab/gui/registration_widget.py +1485 -0
- singlebehaviorlab/gui/review_widget.py +1330 -0
- singlebehaviorlab/gui/segmentation_tracking_widget.py +2752 -0
- singlebehaviorlab/gui/tab_tutorial_dialog.py +312 -0
- singlebehaviorlab/gui/timeline_themes.py +131 -0
- singlebehaviorlab/gui/training_profiles.py +418 -0
- singlebehaviorlab/gui/training_widget.py +3719 -0
- singlebehaviorlab/gui/video_utils.py +233 -0
- singlebehaviorlab/licenses/SAM2-LICENSE +201 -0
- singlebehaviorlab/licenses/VideoPrism-LICENSE +202 -0
- singlebehaviorlab-2.0.0.dist-info/METADATA +447 -0
- singlebehaviorlab-2.0.0.dist-info/RECORD +88 -0
- singlebehaviorlab-2.0.0.dist-info/WHEEL +5 -0
- singlebehaviorlab-2.0.0.dist-info/entry_points.txt +2 -0
- singlebehaviorlab-2.0.0.dist-info/licenses/LICENSE +21 -0
- singlebehaviorlab-2.0.0.dist-info/top_level.txt +3 -0
- videoprism/__init__.py +0 -0
- videoprism/encoders.py +910 -0
- videoprism/layers.py +1136 -0
- videoprism/models.py +407 -0
- videoprism/tokenizers.py +167 -0
- videoprism/utils.py +168 -0
|
@@ -0,0 +1,4667 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
import torch
|
|
4
|
+
import torch.nn as nn
|
|
5
|
+
import torch.nn.functional as F
|
|
6
|
+
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
|
|
7
|
+
from torch.optim import AdamW
|
|
8
|
+
from typing import Callable, Optional, Dict, Any
|
|
9
|
+
import numpy as np
|
|
10
|
+
from sklearn.model_selection import train_test_split
|
|
11
|
+
from sklearn.metrics import f1_score
|
|
12
|
+
from collections import Counter
|
|
13
|
+
import random
|
|
14
|
+
import math
|
|
15
|
+
import re
|
|
16
|
+
|
|
17
|
+
logger = logging.getLogger(__name__)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _slugify_class_name(name: str) -> str:
|
|
21
|
+
"""Create a filesystem/column-safe slug for class names."""
|
|
22
|
+
slug = re.sub(r'[^0-9a-zA-Z]+', '_', name).strip('_').lower()
|
|
23
|
+
return slug or "class"
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class FocalLoss(nn.Module):
    """
    Focal Loss for dense object detection and classification.

    Focuses training on hard examples and down-weights easy ones:

        Loss = -alpha[y] * (1 - pt)^gamma * log(pt)

    where ``pt`` is the *unweighted* softmax probability of the target class ``y``.

    Args:
        gamma: focusing exponent; ``gamma=0`` reduces to (class-weighted) cross-entropy.
        alpha: optional per-class weight vector of length C.
        reduction: ``'mean'`` (sum divided by the number of non-ignored targets),
            ``'sum'``, or anything else for the per-sample [B] loss.
    """
    def __init__(self, gamma: float = 2.0, alpha: Optional[torch.Tensor] = None, reduction: str = 'mean'):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Args:
            inputs: [B, C] logits (not probabilities)
            targets: [B] labels (any negative value, e.g. -100, is ignored)

        Returns:
            Scalar for 'mean'/'sum'; otherwise the [B] per-sample loss with
            ignored positions set to 0.
        """
        valid = targets >= 0
        # Ignored targets are remapped to class 0 so indexing stays legal; they are
        # masked out below.
        safe_targets = targets.clamp(min=0)
        # Unweighted CE: pt = exp(-ce) must be the true target probability. Passing
        # alpha as cross_entropy's `weight` would make pt = p ** alpha and distort
        # the (1 - pt)^gamma modulating factor.
        ce_loss = F.cross_entropy(inputs, safe_targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss
        if self.alpha is not None:
            alpha = torch.as_tensor(self.alpha, dtype=focal_loss.dtype, device=focal_loss.device)
            focal_loss = focal_loss * alpha[safe_targets]
        focal_loss = focal_loss * valid.float()

        if self.reduction == 'mean':
            n = valid.sum().clamp(min=1)
            return focal_loss.sum() / n
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class BinaryFocalLoss(nn.Module):
    """Elementwise binary focal loss for one-vs-rest (OvR) heads.

    Each element contributes ``(1 - exp(-BCE)) ** gamma * BCE``, where BCE is
    the per-element binary cross-entropy with logits.
    """

    def __init__(self, gamma: float = 2.0, reduction: str = 'mean'):
        super().__init__()
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor,
                weight: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            inputs: [B, C] logits
            targets: [B, C] binary targets (0 or 1)
            weight: optional [B, C] per-element weight, multiplied in before reduction

        Returns:
            The mean over all elements for ``'mean'``, the total for ``'sum'``,
            otherwise the unreduced [B, C] tensor.
        """
        per_elem_bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        modulator = (1 - torch.exp(-per_elem_bce)) ** self.gamma
        loss = modulator * per_elem_bce
        if weight is not None:
            loss = loss * weight
        reducer = {'mean': torch.mean, 'sum': torch.sum}.get(self.reduction)
        return loss if reducer is None else reducer(loss)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class AsymmetricLoss(nn.Module):
    """Asymmetric Loss (ASL) for multi-label / OvR classification.

    Applies focal-style down-weighting only to *negatives* (gamma_neg),
    while keeping full gradient signal from positives (gamma_pos, default 0).

    Probability shifting with margin ``clip`` replaces the predicted probability
    of negatives by ``p_m = max(p - clip, 0)``, in both the log term and the
    focusing weight. Very easy negatives (``p <= clip``) therefore contribute
    exactly zero loss.

    Per element, with the hard target ``y = (target >= 0.5)``:

        L+ = (1 - p)^gamma_pos * log(p)
        L- = p_m^gamma_neg * log(1 - p_m)
        loss = -(y * L+ + (1 - y) * L-)

    Reference: Ridnik et al., "Asymmetric Loss For Multi-Label Classification"
    (https://arxiv.org/abs/2009.14119)
    """

    def __init__(
        self,
        gamma_neg: float = 4.0,
        gamma_pos: float = 0.0,
        clip: float = 0.05,
        reduction: str = "mean",
    ):
        super().__init__()
        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.reduction = reduction

    def forward(
        self,
        inputs: torch.Tensor,
        targets: torch.Tensor,
        weight: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            inputs: [*, C] raw logits
            targets: [*, C] binary targets in [0, 1] (soft targets are binarized at 0.5)
            weight: optional [*, C] per-element weight

        Returns:
            Mean / sum over all elements, or the unreduced [*, C] loss.
        """
        p = torch.sigmoid(inputs)

        # Derive hard targets to prevent ASL from over-penalizing easy negatives under label smoothing.
        hard_targets = (targets >= 0.5).float()

        pos_part = hard_targets * torch.log(p.clamp(min=1e-8))

        # Probability shifting (paper: p_m = max(p - m, 0)): suppresses easy negatives.
        if self.clip > 0:
            p_neg = (p - self.clip).clamp(min=0.0)
        else:
            p_neg = p

        neg_part = (1.0 - hard_targets) * torch.log((1.0 - p_neg).clamp(min=1e-8))

        # Asymmetric focusing
        if self.gamma_pos > 0:
            pos_part = pos_part * ((1.0 - p) ** self.gamma_pos)
        if self.gamma_neg > 0:
            # Focus with the *shifted* probability, as in the reference. The tiny clamp
            # keeps pow's gradient finite at p_m == 0 when gamma_neg < 1.
            neg_part = neg_part * (p_neg.clamp(min=1e-8) ** self.gamma_neg)

        loss = -(pos_part + neg_part)

        if weight is not None:
            loss = loss * weight
        if self.reduction == "mean":
            return loss.mean()
        elif self.reduction == "sum":
            return loss.sum()
        return loss
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
class BehaviorDataset(Dataset):
|
|
156
|
+
"""PyTorch Dataset for behavior clips."""
|
|
157
|
+
|
|
158
|
+
def __init__(
    self,
    clips: list,
    annotation_manager,
    classes: list,
    clip_base_dir: str,
    transform: Optional[Callable] = None,
    target_size: tuple[int, int] = (288, 288),
    clip_length: int = 16,
    virtual_size_multiplier: int = 1,
    grid_size: int = 0,
    stitch_prob: float = 0.0,
    crop_jitter: bool = False,
    crop_jitter_strength: float = 0.15,
    ovr_background_classes: Optional[list[str]] = None,
):
    """Precompute all per-clip supervision targets from the clip metadata.

    Everything derived from ``clips`` is built once here as parallel lists
    indexed like ``clips``:

    * ``spatial_masks``: flattened [grid_size * grid_size] float masks built
      from ``clip["spatial_mask"]`` patch indices. A zero [1] tensor is used
      when ``grid_size`` is 0.
    * ``spatial_bboxes`` / ``spatial_bbox_valid``: [T, 4] clamped, normalized
      xyxy boxes and a [T] 0/1 validity vector, with T = ``clip_length``.
    * ``labels`` / ``multi_labels``: primary class index (-1 if no label is in
      ``classes``) and every matching class index.
    * ``ovr_suppress_idx``: the class a hard/near-negative clip suppresses, or -1.
    * ``frame_labels`` ([T] long, -1 = ignored), ``ovr_background_frame_mask``
      ([T] bool), ``ovr_background_clip`` and ``has_real_frame_labels`` (bools).

    Args:
        clips: list of clip dicts. Keys read here: "id" (required), "spatial_mask",
            "spatial_bbox_frames", "spatial_bbox", "labels", "label", "meta",
            "frame_labels".
        annotation_manager: stored as-is; not used by this constructor.
        classes: ordered class names; the position defines the class index.
        clip_base_dir: root directory used to resolve clip video files.
        transform: optional callable applied to loaded clip tensors.
        target_size: (width, height) that frames are resized to when loading.
        clip_length: number of frames T per clip. Frame targets are padded or
            center-cropped to T.
        virtual_size_multiplier: factor applied to ``len(self)``.
        grid_size: side of the spatial patch grid; 0 disables spatial masks.
        stitch_prob: stored for the clip-stitching augmentation. Presumably read
            by the sampling code (not visible here) — confirm.
        crop_jitter: enables random translation jitter (only used with the ROI cache).
        crop_jitter_strength: maximum jitter as a fraction of width/height.
        ovr_background_classes: label names treated as OvR background. Such frames
            get frame label -1 and are flagged in the background masks.
    """
    self.clips = clips
    self.annotation_manager = annotation_manager
    self.raw_classes = classes  # Original labels
    self.clip_base_dir = clip_base_dir
    self.transform = transform
    self.target_size = target_size
    self.clip_length = clip_length
    self.virtual_size_multiplier = virtual_size_multiplier
    self.grid_size = grid_size  # Spatial grid size (e.g. 16 for 288px). 0 = disabled.
    # Masks stored at 16×16 reference grid; rescaled to training grid_size if different.
    STORED_GRID = 16
    self.spatial_masks = []
    self.spatial_bboxes = []
    self.spatial_bbox_valid = []
    num_patches = grid_size * grid_size if grid_size > 0 else 0
    for clip in clips:
        # Not used in this loop, but raises KeyError if a clip dict has no "id".
        clip_id = clip["id"]
        mask_indices = clip.get("spatial_mask")
        if mask_indices and grid_size > 0:
            # Build mask at stored 16×16 resolution
            ref_mask = torch.zeros(STORED_GRID, STORED_GRID, dtype=torch.float32)
            for idx in mask_indices:
                # Out-of-range patch indices are silently dropped.
                if 0 <= idx < STORED_GRID * STORED_GRID:
                    row, col = divmod(idx, STORED_GRID)
                    ref_mask[row, col] = 1.0

            if grid_size != STORED_GRID:
                # Rescale to training grid size via bilinear interpolation
                ref_4d = ref_mask.unsqueeze(0).unsqueeze(0)  # [1,1,16,16]
                scaled = torch.nn.functional.interpolate(
                    ref_4d, size=(grid_size, grid_size), mode='bilinear', align_corners=False
                )
                # Bilinear resampling yields fractional coverage; keep patches above 25%.
                mask = (scaled.squeeze() > 0.25).float()  # threshold to keep binary
            else:
                mask = ref_mask

            self.spatial_masks.append(mask.reshape(-1))  # flatten to [G*G]
        else:
            self.spatial_masks.append(torch.zeros(max(1, num_patches), dtype=torch.float32))

        # Per-frame bboxes: [T, 4] and [T] validity
        T = self.clip_length
        bbox_frames_data = clip.get("spatial_bbox_frames")
        if bbox_frames_data and isinstance(bbox_frames_data, (list, tuple)):
            frame_boxes = torch.zeros(T, 4, dtype=torch.float32)
            frame_valid = torch.zeros(T, dtype=torch.float32)
            # Frames beyond T are ignored; missing or degenerate boxes stay zero/invalid.
            for fi in range(min(T, len(bbox_frames_data))):
                b = bbox_frames_data[fi]
                if b and isinstance(b, (list, tuple)) and len(b) == 4:
                    # Clamp each coordinate into [0, 1].
                    x1, y1, x2, y2 = [max(0.0, min(1.0, float(v))) for v in b]
                    if x2 > x1 and y2 > y1:
                        frame_boxes[fi] = torch.tensor([x1, y1, x2, y2])
                        frame_valid[fi] = 1.0
            self.spatial_bboxes.append(frame_boxes)
            self.spatial_bbox_valid.append(frame_valid)
        else:
            # Legacy single bbox: replicate to all frames
            bbox = clip.get("spatial_bbox")
            if bbox and isinstance(bbox, (list, tuple)) and len(bbox) == 4:
                x1, y1, x2, y2 = [max(0.0, min(1.0, float(v))) for v in bbox]
                if x2 > x1 and y2 > y1:
                    single = torch.tensor([x1, y1, x2, y2], dtype=torch.float32)
                    self.spatial_bboxes.append(single.unsqueeze(0).expand(T, -1).clone())
                    # Mark only frame 0 valid: legacy single-frame annotations shouldn't penalize per-frame tracking.
                    legacy_valid = torch.zeros(T, dtype=torch.float32)
                    legacy_valid[0] = 1.0
                    self.spatial_bbox_valid.append(legacy_valid)
                else:
                    self.spatial_bboxes.append(torch.zeros(T, 4, dtype=torch.float32))
                    self.spatial_bbox_valid.append(torch.zeros(T, dtype=torch.float32))
            else:
                self.spatial_bboxes.append(torch.zeros(T, 4, dtype=torch.float32))
                self.spatial_bbox_valid.append(torch.zeros(T, dtype=torch.float32))

    self.classes = classes
    self.attributes = []
    self.class_to_idx = {c: i for i, c in enumerate(classes)}
    self.attr_to_idx = {}
    self.ovr_background_classes = set(ovr_background_classes or [])
    # Primary label index per clip (-1 if not in class list, e.g. near_negative_*).
    self.labels = []
    # Multi-label: list of all label indices per clip (for OvR multi-hot targets).
    self.multi_labels = []
    for clip in clips:
        # Prefer a non-empty "labels" list; otherwise fall back to the single "label".
        clip_labels = clip.get("labels")
        if not isinstance(clip_labels, list) or not clip_labels:
            clip_labels = [clip.get("label", "")]
        indices = [self.class_to_idx[l] for l in clip_labels if l in self.class_to_idx]
        if indices:
            self.labels.append(indices[0])
            self.multi_labels.append(indices)
        else:
            self.labels.append(-1)
            self.multi_labels.append([])
    # Alias (same list object as self.labels).
    self.class_labels = self.labels
    self.attr_labels = []

    # OvR: for each clip, store the matched suppression class index (or -1).
    # near_negative_X → suppress class X; also check hard_negative_for_class metadata.
    self.ovr_suppress_idx = []
    for clip in clips:
        suppress = -1
        # NOTE(review): a clip stored with "label": None would raise on .startswith below.
        label = clip.get("label", "")
        meta = clip.get("meta") or {}
        hn_for = meta.get("hard_negative_for_class")
        # Explicit metadata wins over the near_negative_ naming convention.
        if hn_for and hn_for in self.class_to_idx:
            suppress = self.class_to_idx[hn_for]
        elif label.startswith("near_negative_"):
            suffix = label[len("near_negative_"):]
            # Match the suffix against class names, also allowing spaces written as "_".
            for cls_name, cls_idx in self.class_to_idx.items():
                if cls_name == suffix or cls_name.replace(" ", "_") == suffix:
                    suppress = cls_idx
                    break
        self.ovr_suppress_idx.append(suppress)

    # Per-frame labels: [T] tensor per clip.
    # - If explicit frame labels exist, use them.
    # - Else if a valid clip class exists, supervise all frames with that class.
    # - Else use -1 (ignored by frame loss).
    # has_real_frame_labels is used by training to route samples to frame
    # supervision (and exclude them from clip-loss metrics/path).
    self.frame_labels = []
    self.ovr_background_frame_mask = []
    self.ovr_background_clip = []
    self.has_real_frame_labels = []
    for i, clip in enumerate(clips):
        clip_fl = clip.get("frame_labels")
        T = self.clip_length
        primary = self.labels[i]
        raw_label = clip.get("label", "")
        clip_is_background = raw_label in self.ovr_background_classes
        if clip_fl and isinstance(clip_fl, (list, tuple)) and len(clip_fl) > 0:
            fl = []
            bg = []
            # Each entry is None (ignored), a class-name string, or an integer class index.
            for lbl_name in clip_fl:
                if lbl_name is None:
                    fl.append(-1)
                    bg.append(False)
                elif isinstance(lbl_name, str):
                    if lbl_name in self.ovr_background_classes:
                        # Background frames are ignored by the class loss but flagged in bg.
                        fl.append(-1)
                        bg.append(True)
                    else:
                        fl.append(self.class_to_idx.get(lbl_name, -1))
                        bg.append(False)
                else:
                    v = int(lbl_name)
                    if 0 <= v < len(self.class_to_idx):
                        fl.append(v)
                        bg.append(False)
                    else:
                        fl.append(-1)
                        bg.append(False)
            fl_tensor = torch.tensor(fl, dtype=torch.long)
            bg_tensor = torch.tensor(bg, dtype=torch.bool)
            if len(fl_tensor) < T:
                # Too short: repeat the last frame's label/flag up to T.
                fl_tensor = torch.cat([fl_tensor, fl_tensor[-1:].expand(T - len(fl_tensor))])
                bg_tensor = torch.cat([bg_tensor, bg_tensor[-1:].expand(T - len(bg_tensor))])
            elif len(fl_tensor) > T:
                # Too long: take the centered window (same offset formula as the
                # temporal center crop in _load_clip).
                start = (len(fl_tensor) - T) // 2
                fl_tensor = fl_tensor[start:start + T]
                bg_tensor = bg_tensor[start:start + T]
            # Safety: convert any out-of-range numeric label to ignore (-1).
            fl_tensor[(fl_tensor < 0) | (fl_tensor >= len(self.class_to_idx))] = -1
            self.frame_labels.append(fl_tensor)
            self.ovr_background_frame_mask.append(bg_tensor)
            self.ovr_background_clip.append(bool(bg_tensor.any().item()) or clip_is_background)
            self.has_real_frame_labels.append(True)
        else:
            if primary >= 0:
                # No frame labels: broadcast the clip's primary class to every frame.
                self.frame_labels.append(torch.full((T,), int(primary), dtype=torch.long))
                self.ovr_background_frame_mask.append(torch.zeros(T, dtype=torch.bool))
                self.ovr_background_clip.append(False)
                self.has_real_frame_labels.append(True)
            else:
                bg_tensor = torch.full((T,), bool(clip_is_background), dtype=torch.bool)
                self.frame_labels.append(torch.full((T,), -1, dtype=torch.long))
                self.ovr_background_frame_mask.append(bg_tensor)
                self.ovr_background_clip.append(bool(clip_is_background))
                # Background clips have real negative supervision in hybrid OvR mode.
                self.has_real_frame_labels.append(bool(clip_is_background))

    # Clip-stitching augmentation params.
    self.stitch_prob = float(stitch_prob)
    self.stitch_exclude_classes: set[int] = set()
    # Map class index → list of clip indices for fast same/different-class sampling.
    self._label_to_clip_indices: dict[int, list[int]] = {}
    for i, lbl in enumerate(self.labels):
        if lbl >= 0:
            self._label_to_clip_indices.setdefault(lbl, []).append(i)

    # Crop jitter augmentation (only used with ROI cache).
    self.crop_jitter = bool(crop_jitter)
    self.crop_jitter_strength = float(crop_jitter_strength)

    # Runtime toggles used by training curriculum.
    self._roi_cache_mode = False
    self._roi_cache_dir = None
    # Embedding-space stitch cache (backbone tokens pre-computed).
    self._emb_cache_mode = False
    self._emb_cache_dir = None
    self._emb_clip_length = self.clip_length
|
|
376
|
+
|
|
377
|
+
def __len__(self):
|
|
378
|
+
return len(self.clips) * self.virtual_size_multiplier
|
|
379
|
+
|
|
380
|
+
def _resolve_clip_path(self, clip_id: str) -> str:
|
|
381
|
+
clip_path = os.path.join(self.clip_base_dir, clip_id)
|
|
382
|
+
found = False
|
|
383
|
+
if os.path.exists(clip_path):
|
|
384
|
+
found = True
|
|
385
|
+
else:
|
|
386
|
+
base_name, ext = os.path.splitext(clip_id)
|
|
387
|
+
clip_basename = os.path.basename(clip_id)
|
|
388
|
+
clip_dir_part = os.path.dirname(clip_id) if os.path.dirname(clip_id) else None
|
|
389
|
+
video_extensions = ['.mp4', '.avi', '.mov', '.mkv', '.MP4', '.AVI', '.MOV', '.MKV']
|
|
390
|
+
|
|
391
|
+
if not ext:
|
|
392
|
+
for video_ext in video_extensions:
|
|
393
|
+
test_path = os.path.join(self.clip_base_dir, clip_id + video_ext)
|
|
394
|
+
if os.path.exists(test_path):
|
|
395
|
+
clip_path = test_path
|
|
396
|
+
found = True
|
|
397
|
+
break
|
|
398
|
+
else:
|
|
399
|
+
base_name_only = os.path.basename(base_name)
|
|
400
|
+
ext_lower = ext.lower()
|
|
401
|
+
|
|
402
|
+
for video_ext in video_extensions:
|
|
403
|
+
if video_ext.lower() == ext_lower:
|
|
404
|
+
continue
|
|
405
|
+
test_path = os.path.join(self.clip_base_dir, base_name + video_ext)
|
|
406
|
+
if os.path.exists(test_path):
|
|
407
|
+
clip_path = test_path
|
|
408
|
+
found = True
|
|
409
|
+
break
|
|
410
|
+
|
|
411
|
+
if not found:
|
|
412
|
+
for video_ext in video_extensions:
|
|
413
|
+
test_path = os.path.join(self.clip_base_dir, base_name_only + video_ext)
|
|
414
|
+
if os.path.exists(test_path):
|
|
415
|
+
clip_path = test_path
|
|
416
|
+
found = True
|
|
417
|
+
break
|
|
418
|
+
|
|
419
|
+
if not found:
|
|
420
|
+
for root, dirs, files in os.walk(self.clip_base_dir):
|
|
421
|
+
for file in files:
|
|
422
|
+
file_base, file_ext = os.path.splitext(file)
|
|
423
|
+
if file_base == base_name_only or file_base == base_name:
|
|
424
|
+
if file_ext.lower() in [e.lower() for e in video_extensions]:
|
|
425
|
+
clip_path = os.path.join(root, file)
|
|
426
|
+
found = True
|
|
427
|
+
break
|
|
428
|
+
if found:
|
|
429
|
+
break
|
|
430
|
+
|
|
431
|
+
if not found and clip_dir_part:
|
|
432
|
+
subdir_path = os.path.join(self.clip_base_dir, clip_dir_part)
|
|
433
|
+
if os.path.exists(subdir_path):
|
|
434
|
+
for video_ext in video_extensions:
|
|
435
|
+
test_path = os.path.join(subdir_path, clip_basename)
|
|
436
|
+
if os.path.exists(test_path):
|
|
437
|
+
clip_path = test_path
|
|
438
|
+
found = True
|
|
439
|
+
break
|
|
440
|
+
test_path = os.path.join(subdir_path, base_name_only + video_ext)
|
|
441
|
+
if os.path.exists(test_path):
|
|
442
|
+
clip_path = test_path
|
|
443
|
+
found = True
|
|
444
|
+
break
|
|
445
|
+
|
|
446
|
+
if not found:
|
|
447
|
+
error_msg = f"Clip not found: {clip_path}\n"
|
|
448
|
+
error_msg += f"Clip ID from annotation: {clip_id}\n"
|
|
449
|
+
error_msg += f"Base directory: {self.clip_base_dir}\n"
|
|
450
|
+
error_msg += "Please check if the file exists or update the annotation."
|
|
451
|
+
raise FileNotFoundError(error_msg)
|
|
452
|
+
return clip_path
|
|
453
|
+
|
|
454
|
+
# Identity sentinel for ``_load_clip(target_size=...)``. It is compared with ``is``
# and selects native-resolution decoding (no resize), whereas ``None`` selects
# the default ``self.target_size``.
_NATIVE_RES = object()  # Sentinel: load at native resolution (no resize)
|
|
455
|
+
|
|
456
|
+
def _load_clip(self, clip_path: str, target_size=None, apply_transform: bool = True,
               apply_aoi: bool = True) -> torch.Tensor:
    """Decode a video file into a float clip tensor of exactly ``self.clip_length`` frames.

    Args:
        clip_path: path to a video file readable by OpenCV.
        target_size: ``(w, h)`` to resize every frame to. Pass ``self._NATIVE_RES``
            for no resize, or ``None`` for the default ``self.target_size``.
        apply_transform: apply ``self.transform`` (if set) to the result.
        apply_aoi: unused (kept for API compatibility).

    Returns:
        Float tensor [T, 3, H, W] of RGB values in [0, 1] (before any transform),
        with T = ``self.clip_length``. Short videos are padded by repeating the
        last frame, or with black frames if nothing was decoded. Long videos are
        center-cropped in time.
    """
    import cv2

    if target_size is self._NATIVE_RES:
        resize_to = None
    elif target_size is not None:
        resize_to = target_size
    else:
        resize_to = self.target_size

    frames = []
    cap = cv2.VideoCapture(clip_path)
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            if resize_to is not None:
                h_src, w_src = frame.shape[:2]
                # Upscaling uses Lanczos plus a light unsharp mask against blur;
                # downscaling uses area averaging.
                is_upscale = (w_src < resize_to[0]) or (h_src < resize_to[1])
                interp = cv2.INTER_LANCZOS4 if is_upscale else cv2.INTER_AREA
                frame = cv2.resize(frame, resize_to, interpolation=interp)
                if is_upscale:
                    blurred = cv2.GaussianBlur(frame, (0, 0), sigmaX=1.0)
                    frame = cv2.addWeighted(frame, 1.5, blurred, -0.5, 0)
            frames.append(frame)
    finally:
        # Release the decoder even if decoding, colour conversion or resize raises mid-clip.
        cap.release()

    while len(frames) < self.clip_length:
        if frames:
            frames.append(frames[-1])
        else:
            # Nothing decoded: emit black frames of the target size; (w, h) becomes (h, w, 3).
            fallback_size = resize_to if resize_to is not None else self.target_size
            frames.append(np.zeros((*fallback_size[::-1], 3), dtype=np.uint8))

    if len(frames) > self.clip_length:
        # Center-crop temporally: pick the middle clip_length frames
        start = (len(frames) - self.clip_length) // 2
        frames = frames[start:start + self.clip_length]

    clip_array = np.stack(frames).astype(np.float32) / 255.0
    clip_tensor = torch.from_numpy(clip_array)
    clip_tensor = clip_tensor.permute(0, 3, 1, 2)  # [T, H, W, C] -> [T, C, H, W]

    if apply_transform and self.transform:
        clip_tensor = self.transform(clip_tensor)

    return clip_tensor
|
|
510
|
+
|
|
511
|
+
def load_fullres_clip_by_index(self, actual_idx: int, apply_aoi: bool = False) -> torch.Tensor:
    """Return clip ``actual_idx`` decoded at its native resolution, un-augmented.

    ``apply_aoi`` defaults to False because localization needs the original
    frame. It is forwarded to ``_load_clip`` unchanged, and ``self.transform``
    is bypassed.
    """
    video_path = self._resolve_clip_path(self.clips[actual_idx]["id"])
    return self._load_clip(
        video_path,
        target_size=self._NATIVE_RES,
        apply_transform=False,
        apply_aoi=apply_aoi,
    )
|
|
520
|
+
|
|
521
|
+
def load_modelres_clip_by_index(self, actual_idx: int) -> torch.Tensor:
    """Return clip ``actual_idx`` resized to ``self.target_size``, with no augmentation.

    ``apply_transform=False`` bypasses ``self.transform``.
    """
    video_path = self._resolve_clip_path(self.clips[actual_idx]["id"])
    return self._load_clip(video_path, target_size=self.target_size, apply_transform=False)
|
|
527
|
+
|
|
528
|
+
def _apply_crop_jitter(self, x: torch.Tensor) -> torch.Tensor:
|
|
529
|
+
"""Randomly shift the crop to vary background context.
|
|
530
|
+
|
|
531
|
+
x: [T, C, H, W] in [0,1]. Returns same shape with a random
|
|
532
|
+
translation applied (pixels that shift out are filled with edge values).
|
|
533
|
+
"""
|
|
534
|
+
_, _, H, W = x.shape
|
|
535
|
+
max_dx = int(self.crop_jitter_strength * W)
|
|
536
|
+
max_dy = int(self.crop_jitter_strength * H)
|
|
537
|
+
if max_dx < 1 and max_dy < 1:
|
|
538
|
+
return x
|
|
539
|
+
dx = random.randint(-max_dx, max_dx)
|
|
540
|
+
dy = random.randint(-max_dy, max_dy)
|
|
541
|
+
if dx == 0 and dy == 0:
|
|
542
|
+
return x
|
|
543
|
+
# Use affine_grid + grid_sample for smooth sub-pixel shifting
|
|
544
|
+
theta = torch.tensor([
|
|
545
|
+
[1.0, 0.0, -2.0 * dx / W],
|
|
546
|
+
[0.0, 1.0, -2.0 * dy / H],
|
|
547
|
+
], dtype=x.dtype).unsqueeze(0).expand(x.size(0), -1, -1)
|
|
548
|
+
grid = F.affine_grid(theta, x.shape, align_corners=False)
|
|
549
|
+
return F.grid_sample(x, grid, mode='bilinear', padding_mode='border', align_corners=False)
|
|
550
|
+
|
|
551
|
+
def _apply_spatial_label_augment(
|
|
552
|
+
self,
|
|
553
|
+
spatial_mask: torch.Tensor,
|
|
554
|
+
spatial_bbox: torch.Tensor,
|
|
555
|
+
spatial_bbox_valid: torch.Tensor,
|
|
556
|
+
aug_params: dict,
|
|
557
|
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
558
|
+
"""Apply spatial label transforms using the same params as clip augmentation."""
|
|
559
|
+
hflip = bool(aug_params.get("hflip", aug_params.get("flip", False))) if aug_params else False
|
|
560
|
+
vflip = bool(aug_params.get("vflip", False)) if aug_params else False
|
|
561
|
+
if not hflip and not vflip:
|
|
562
|
+
return spatial_mask, spatial_bbox
|
|
563
|
+
|
|
564
|
+
out_mask = spatial_mask
|
|
565
|
+
out_bbox = spatial_bbox.clone()
|
|
566
|
+
|
|
567
|
+
# Flip mask on width/height axes when mask grid is enabled.
|
|
568
|
+
if self.grid_size > 0 and out_mask.numel() == self.grid_size * self.grid_size:
|
|
569
|
+
out_mask_grid = out_mask.reshape(self.grid_size, self.grid_size)
|
|
570
|
+
if hflip:
|
|
571
|
+
out_mask_grid = out_mask_grid.flip(1)
|
|
572
|
+
if vflip:
|
|
573
|
+
out_mask_grid = out_mask_grid.flip(0)
|
|
574
|
+
out_mask = out_mask_grid.reshape(-1)
|
|
575
|
+
|
|
576
|
+
# Flip per-frame bboxes in normalized xyxy coordinates.
|
|
577
|
+
# spatial_bbox is [T, 4], spatial_bbox_valid is [T]
|
|
578
|
+
if out_bbox.dim() == 2:
|
|
579
|
+
valid = spatial_bbox_valid > 0.5 # [T]
|
|
580
|
+
if valid.any():
|
|
581
|
+
if hflip:
|
|
582
|
+
new_x1 = (1.0 - out_bbox[:, 2]).clamp(0.0, 1.0)
|
|
583
|
+
new_x2 = (1.0 - out_bbox[:, 0]).clamp(0.0, 1.0)
|
|
584
|
+
out_bbox[:, 0] = new_x1
|
|
585
|
+
out_bbox[:, 2] = new_x2
|
|
586
|
+
if vflip:
|
|
587
|
+
new_y1 = (1.0 - out_bbox[:, 3]).clamp(0.0, 1.0)
|
|
588
|
+
new_y2 = (1.0 - out_bbox[:, 1]).clamp(0.0, 1.0)
|
|
589
|
+
out_bbox[:, 1] = new_y1
|
|
590
|
+
out_bbox[:, 3] = new_y2
|
|
591
|
+
# Zero out invalid frames to avoid corrupted coords
|
|
592
|
+
out_bbox[~valid] = 0.0
|
|
593
|
+
elif out_bbox.dim() == 1 and out_bbox.numel() == 4:
|
|
594
|
+
# Legacy fallback (single bbox)
|
|
595
|
+
if float(spatial_bbox_valid.sum().item()) > 0.5:
|
|
596
|
+
x1, y1, x2, y2 = [float(v) for v in out_bbox]
|
|
597
|
+
if hflip:
|
|
598
|
+
x1, x2 = max(0.0, min(1.0, 1.0 - x2)), max(0.0, min(1.0, 1.0 - x1))
|
|
599
|
+
if vflip:
|
|
600
|
+
y1, y2 = max(0.0, min(1.0, 1.0 - y2)), max(0.0, min(1.0, 1.0 - y1))
|
|
601
|
+
out_bbox = torch.tensor([x1, y1, x2, y2], dtype=out_bbox.dtype)
|
|
602
|
+
|
|
603
|
+
return out_mask, out_bbox
|
|
604
|
+
|
|
605
|
+
def _do_stitch(self, actual_idx: int, x_a: torch.Tensor):
    """Build a synthetic clip: first half from clip A, second half from a partner clip B.

    Clip B is drawn at random from every clip whose class differs from clip A's
    and is not listed in ``self.stitch_exclude_classes``. If there is no such
    clip, another clip of A's own class is used instead. Clip A is returned
    unchanged when its own class is excluded or when it is the only clip.

    Returns:
        ``(x, frame_labels, clip_label, spatial_mask, bbox, bbox_valid, bg_mask)``.
        For a stitched clip ``clip_label`` is -1 because the mixture has no single
        clip-level ground truth; the per-frame labels carry the supervision. The
        spatial grid mask is always clip A's (there is no better choice for a
        synthetic composite).
    """

    def _unchanged():
        # Clip A with its own annotations; annotation tensors are cloned.
        return (
            x_a,
            self.frame_labels[actual_idx],
            self.labels[actual_idx],
            self.spatial_masks[actual_idx].clone(),
            self.spatial_bboxes[actual_idx].clone(),
            self.spatial_bbox_valid[actual_idx].clone(),
            self.ovr_background_frame_mask[actual_idx].clone(),
        )

    own_class = self.labels[actual_idx]

    # Never stitch a clip whose class is excluded (e.g. a catch-all "Other").
    if own_class in self.stitch_exclude_classes:
        return _unchanged()

    # Partner candidates: clips of any other, non-excluded class, so real
    # behavior clips are not contaminated with catch-all content.
    candidates: list[int] = [
        clip_idx
        for cls_idx, clip_idxs in self._label_to_clip_indices.items()
        if cls_idx != own_class and cls_idx not in self.stitch_exclude_classes
        for clip_idx in clip_idxs
    ]
    if not candidates:
        # Fall back to a different clip of the same class.
        candidates = [
            clip_idx
            for clip_idx in self._label_to_clip_indices.get(own_class, [])
            if clip_idx != actual_idx
        ]
        if not candidates:
            # Only one clip in the whole dataset; nothing to stitch with.
            return _unchanged()

    partner = random.choice(candidates)
    partner_path = self._resolve_clip_path(self.clips[partner]["id"])
    x_b = self._load_clip(partner_path, apply_transform=False)

    # Fixed 50/50 boundary: each half keeps clip_length // 2 frames of context.
    half = self.clip_length // 2

    def _splice(first: torch.Tensor, second: torch.Tensor) -> torch.Tensor:
        return torch.cat([first[:half], second[half:]], dim=0)

    return (
        _splice(x_a, x_b),
        _splice(self.frame_labels[actual_idx], self.frame_labels[partner]),
        -1,
        self.spatial_masks[actual_idx].clone(),
        _splice(self.spatial_bboxes[actual_idx], self.spatial_bboxes[partner]),
        _splice(self.spatial_bbox_valid[actual_idx], self.spatial_bbox_valid[partner]),
        _splice(self.ovr_background_frame_mask[actual_idx], self.ovr_background_frame_mask[partner]),
    )
|
|
673
|
+
|
|
674
|
+
def __getitem__(self, idx):
    """Return one training sample as a 9-tuple.

    ``idx`` is a virtual index; it is wrapped onto the real clip list with a
    modulo. The sample is read from one of three sources, checked in order:

    1. Embedding cache (``self._emb_cache_mode``): pre-computed backbone tokens
       are loaded from ``self._emb_cache_dir``. Stitching, when it happens, is
       done on the token tensors so the backbone only ever sees clean clips.
       Stitching requires a cache directory (clip B's tokens come from it).
    2. ROI cache (``self._roi_cache_mode``): a pre-computed crop is loaded from
       ``self._roi_cache_dir`` (no video decoding, so DataLoader workers can
       prefetch in parallel). No stitching.
    3. Otherwise the clip is decoded via ``self._load_clip`` and may be stitched
       with ``self._do_stitch``.

    Returns:
        ``(x, y, spatial_mask, spatial_bbox, spatial_bbox_valid, actual_idx,
        frame_labels, x_short, bg_mask)``. ``y`` is -1 for stitched samples.
        ``x_short`` holds the short-scale tokens in multi-scale embedding mode
        and is an empty tensor otherwise.
    """
    # Map virtual index to actual clip index
    actual_idx = idx % len(self.clips)

    # Embedding-space stitch: backbone tokens are pre-cached.
    # Stitch happens on token tensors so VideoPrism always sees clean clips.
    if getattr(self, '_emb_cache_mode', False):
        cache_dir = getattr(self, '_emb_cache_dir', None)
        num_versions = getattr(self, '_emb_num_versions', 1)
        use_multi_scale = getattr(self, '_emb_multi_scale', False)

        def _pick_version() -> int:
            # Several cached versions may exist per clip; pick one at random.
            return random.randint(0, num_versions - 1) if num_versions > 1 else 0

        def _load_emb(clip_idx: int, v: int, short: bool = False) -> torch.Tensor:
            # File names: "<clip_idx>_<v>.pt", or "<clip_idx>_<v>_s.pt" for the short scale.
            suffix = f"_{v}_s.pt" if short else f"_{v}.pt"
            path = os.path.join(cache_dir, f"{clip_idx}{suffix}")
            return torch.load(path, map_location='cpu', weights_only=True).float()

        v_a = _pick_version()
        emb_a = _load_emb(actual_idx, v_a) if cache_dir else torch.zeros(self.clip_length * 256, 768)
        emb_a_s = _load_emb(actual_idx, v_a, short=True) if (cache_dir and use_multi_scale) else None

        label_a = self.labels[actual_idx]
        do_stitch = (
            self.stitch_prob > 0.0
            and label_a not in self.stitch_exclude_classes
            and torch.rand(1).item() < self.stitch_prob
        )
        # Clip B's tokens can only come from the cache, so stitching is
        # impossible without a cache dir (os.path.join(None, ...) would raise).
        # Partner candidates are gathered only when actually needed, avoiding an
        # O(num_clips) scan on every non-stitched sample.
        other_indices: list[int] = []
        if do_stitch and cache_dir:
            other_indices = [
                i for cls, idxs in self._label_to_clip_indices.items()
                if cls != label_a and cls not in self.stitch_exclude_classes
                for i in idxs
            ]

        if other_indices:
            b_idx = random.choice(other_indices)
            v_b = _pick_version()
            emb_b = _load_emb(b_idx, v_b)
            emb_b_s = _load_emb(b_idx, v_b, short=True) if use_multi_scale else None

            # The first dim is treated as T frames of S tokens each, frame-major
            # (presumably [T * S, D] — inferred from the slicing below).
            T = self._emb_clip_length
            S = emb_a.shape[0] // T
            stitch_t = T // 2

            x_out = torch.cat([emb_a[:stitch_t * S], emb_b[stitch_t * S:]], dim=0)

            if use_multi_scale and emb_a_s is not None and emb_b_s is not None:
                # The short-scale tokens are assumed to cover T // 2 frames.
                T_s = T // 2
                S_s = emb_a_s.shape[0] // T_s
                stitch_t_s = T_s // 2
                x_short = torch.cat([emb_a_s[:stitch_t_s * S_s], emb_b_s[stitch_t_s * S_s:]], dim=0)
            else:
                x_short = torch.empty(0)

            fl_a = self.frame_labels[actual_idx].clone()
            fl_b = self.frame_labels[b_idx].clone()
            fl = torch.cat([fl_a[:stitch_t], fl_b[stitch_t:]], dim=0)
            bg_a = self.ovr_background_frame_mask[actual_idx].clone()
            bg_b = self.ovr_background_frame_mask[b_idx].clone()
            bg_mask = torch.cat([bg_a[:stitch_t], bg_b[stitch_t:]], dim=0)
            bbox_a = self.spatial_bboxes[actual_idx].clone()
            bbox_b = self.spatial_bboxes[b_idx].clone()
            spatial_bbox = torch.cat([bbox_a[:stitch_t], bbox_b[stitch_t:]], dim=0)
            bv_a = self.spatial_bbox_valid[actual_idx].clone()
            bv_b = self.spatial_bbox_valid[b_idx].clone()
            spatial_bbox_valid = torch.cat([bv_a[:stitch_t], bv_b[stitch_t:]], dim=0)
            spatial_mask = self.spatial_masks[actual_idx].clone()
            y = -1
        else:
            x_out = emb_a
            x_short = emb_a_s if (use_multi_scale and emb_a_s is not None) else torch.empty(0)
            y = self.labels[actual_idx]
            fl = self.frame_labels[actual_idx].clone()
            bg_mask = self.ovr_background_frame_mask[actual_idx].clone()
            spatial_mask = self.spatial_masks[actual_idx].clone()
            spatial_bbox = self.spatial_bboxes[actual_idx].clone()
            spatial_bbox_valid = self.spatial_bbox_valid[actual_idx]

        return x_out, y, spatial_mask, spatial_bbox, spatial_bbox_valid, actual_idx, fl, x_short, bg_mask

    # When ROI cache is active, load the precomputed crop from disk instead
    # of decoding the original video. This lets DataLoader workers operate
    # normally (parallel prefetch) so classification training runs at the
    # same speed as standard (no-localization) training.
    if getattr(self, '_roi_cache_mode', False):
        cache_dir = getattr(self, '_roi_cache_dir', None)
        x = None
        if cache_dir:
            pt_path = os.path.join(cache_dir, f"{actual_idx}.pt")
            cached = torch.load(pt_path, map_location='cpu', weights_only=True)
            # Accept either a dict holding an "roi" tensor or a bare tensor. Any
            # other entries in the dict (e.g. "full") are not used here.
            if isinstance(cached, dict) and torch.is_tensor(cached.get("roi")):
                x = cached["roi"].float()
            elif torch.is_tensor(cached):
                x = cached.float()
        if x is None:
            # No cache dir or unrecognised cache content: placeholder clip.
            x = torch.zeros(self.clip_length, 3, 1, 1, dtype=torch.float32)

        if self.crop_jitter and self.crop_jitter_strength > 0:
            x = self._apply_crop_jitter(x)

        spatial_mask = self.spatial_masks[actual_idx].clone()
        spatial_bbox = self.spatial_bboxes[actual_idx].clone()
        spatial_bbox_valid = self.spatial_bbox_valid[actual_idx]

        if self.transform:
            if hasattr(self.transform, "augment_with_params"):
                # Keep spatial labels consistent with the geometric augmentation.
                x, aug_params = self.transform.augment_with_params(x)
                spatial_mask, spatial_bbox = self._apply_spatial_label_augment(
                    spatial_mask, spatial_bbox, spatial_bbox_valid, aug_params
                )
            else:
                x = self.transform(x)

        y = self.labels[actual_idx]
        fl = self.frame_labels[actual_idx]
        bg_mask = self.ovr_background_frame_mask[actual_idx].clone()
        return x, y, spatial_mask, spatial_bbox, spatial_bbox_valid, actual_idx, fl, torch.empty(0), bg_mask

    clip_id = self.clips[actual_idx]["id"]
    clip_path = self._resolve_clip_path(clip_id)

    x = self._load_clip(clip_path, apply_transform=False)

    # Clip-stitching augmentation: splice two clips from different classes so
    # the model learns per-frame features independent of clip-level context.
    if self.stitch_prob > 0.0 and torch.rand(1).item() < self.stitch_prob:
        x, fl, y, spatial_mask, spatial_bbox, spatial_bbox_valid, bg_mask = self._do_stitch(actual_idx, x)
    else:
        spatial_mask = self.spatial_masks[actual_idx].clone()
        spatial_bbox = self.spatial_bboxes[actual_idx].clone()
        spatial_bbox_valid = self.spatial_bbox_valid[actual_idx]
        y = self.labels[actual_idx]
        fl = self.frame_labels[actual_idx]
        bg_mask = self.ovr_background_frame_mask[actual_idx].clone()

    if self.transform:
        if hasattr(self.transform, "augment_with_params"):
            # Keep spatial labels consistent with the geometric augmentation.
            x, aug_params = self.transform.augment_with_params(x)
            spatial_mask, spatial_bbox = self._apply_spatial_label_augment(
                spatial_mask, spatial_bbox, spatial_bbox_valid, aug_params
            )
        else:
            x = self.transform(x)

    return x, y, spatial_mask, spatial_bbox, spatial_bbox_valid, actual_idx, fl, torch.empty(0), bg_mask
|
|
827
|
+
|
|
828
|
+
|
|
829
|
+
def compute_class_weights(labels: list, num_classes: int) -> torch.Tensor:
    """Compute inverse-frequency class weights.

    Each class present in ``labels`` gets ``total / (num_classes * count)``;
    classes that never occur keep a weight of 1.0. The vector is then rescaled
    so that it sums to ``num_classes``.

    Labels outside ``[0, num_classes)`` (e.g. ``-1`` used as "no label") are
    ignored and do not count toward ``total``. Previously a negative label
    silently overwrote the last class's weight via negative indexing, and a
    label ``>= num_classes`` raised ``IndexError``.

    Args:
        labels: Integer class label per sample (may be ``None`` or empty).
        num_classes: Length of the returned weight vector.

    Returns:
        Float tensor of shape ``[num_classes]``; all ones when there are no
        valid labels.
    """
    weights = torch.ones(num_classes)
    if labels is None:
        return weights

    counter = Counter(
        lbl for lbl in (int(raw) for raw in labels) if 0 <= lbl < num_classes
    )
    total = sum(counter.values())
    if total == 0:
        return weights

    for class_idx, count in counter.items():
        weights[class_idx] = total / (num_classes * count)

    return weights / weights.sum() * num_classes
|
|
844
|
+
|
|
845
|
+
|
|
846
|
+
class BalancedBatchSampler:
    """Batch sampler that guarantees per-class representation in every batch.

    Strategy, per batch:
      - If a background pool is given and ``background_per_batch > 0``, reserve
        up to ``background_per_batch`` slots for background clips first.
      - Pick K = min(num_eligible_classes, class_slots // min_samples_per_class)
        classes (at least one), where ``class_slots`` is the batch size minus the
        reserved background slots.
      - Draw ``min_samples_per_class`` samples for each selected class, cycling
        through a per-class shuffled pool that is reshuffled when exhausted.
      - Fill the reserved slots with background clips (cycled the same way).
      - Fill any remaining slots round-robin from the selected classes.

    Notes:
      - Classes with fewer than ``min_samples_per_class`` examples, or listed in
        ``excluded_classes``, are never selected for balancing.
      - ``min_samples_per_class`` is clamped to at least 1.
      - When no eligible class exists the sampler yields nothing; the caller
        should fall back to another sampler.
      - The number of batches is ``len(labels) * virtual_size_multiplier``
        divided by ``batch_size`` (rounded up unless ``drop_last``).
      - With ``seed=None`` the global ``random`` module is used (so
        ``random.seed()`` is respected); otherwise a private
        ``random.Random(seed)`` is used.
    """
    def __init__(
        self,
        labels: list[int],
        batch_size: int,
        min_samples_per_class: int = 2,
        drop_last: bool = False,
        seed: Optional[int] = None,
        excluded_classes: Optional[list[int]] = None,
        virtual_size_multiplier: int = 1,
        background_indices: Optional[list[int]] = None,
        background_per_batch: int = 0,
    ):
        self.labels = list(labels)
        self.batch_size = int(batch_size)
        # At least 1: a value of 0 would divide by zero when choosing K.
        self.min_samples_per_class = max(1, int(min_samples_per_class))
        self.drop_last = bool(drop_last)
        self.excluded_classes = set(excluded_classes or [])
        self.virtual_size_multiplier = max(1, int(virtual_size_multiplier))
        self.background_indices = list(background_indices or [])
        self.background_per_batch = max(0, int(background_per_batch))
        self._effective_length = len(self.labels) * self.virtual_size_multiplier

        # Use global random state if seed is None (respects random.seed())
        # Otherwise use private generator
        if seed is None:
            self.rng = random
        else:
            self.rng = random.Random(seed)

        # Build index lists per class
        self.class_to_indices: Dict[int, list[int]] = {}
        for idx, y in enumerate(self.labels):
            self.class_to_indices.setdefault(int(y), []).append(idx)

        # Eligible classes must have at least `min_samples_per_class` items and NOT be excluded
        self.eligible_classes = [
            c for c, idxs in self.class_to_indices.items()
            if len(idxs) >= self.min_samples_per_class and c not in self.excluded_classes
        ]
        self.enabled = len(self.eligible_classes) > 0 and self.batch_size > 0

        # Classes dropped for having too few samples (excluded classes are not listed)
        self.dropped_classes = [c for c, idxs in self.class_to_indices.items() if len(idxs) < self.min_samples_per_class]

        # Prepare per-class shuffled pools with cursors
        self._pools: Dict[int, list[int]] = {}
        self._cursors: Dict[int, int] = {}
        for c in self.eligible_classes:
            pool = self.class_to_indices[c][:]
            self.rng.shuffle(pool)
            self._pools[c] = pool
            self._cursors[c] = 0
        self._bg_pool = self.background_indices[:]
        if self._bg_pool:
            self.rng.shuffle(self._bg_pool)
        self._bg_cursor = 0

    def __len__(self) -> int:
        """Number of batches per epoch."""
        if self._effective_length <= 0:
            return 0
        if self.drop_last:
            return self._effective_length // max(1, self.batch_size)
        return math.ceil(self._effective_length / max(1, self.batch_size))

    def _draw_from_class(self, cls: int) -> int:
        """Return the next index from ``cls``'s shuffled pool, reshuffling when exhausted."""
        pool = self._pools[cls]
        cur = self._cursors[cls]
        if cur >= len(pool):
            # Reshuffle and reset
            pool = self.class_to_indices[cls][:]
            self.rng.shuffle(pool)
            self._pools[cls] = pool
            cur = 0
        idx = pool[cur]
        self._cursors[cls] = cur + 1
        return idx

    def _draw_background(self) -> Optional[int]:
        """Return the next background index, or ``None`` when there is no background pool."""
        if not self._bg_pool:
            return None
        cur = self._bg_cursor
        if cur >= len(self._bg_pool):
            self._bg_pool = self.background_indices[:]
            self.rng.shuffle(self._bg_pool)
            cur = 0
        idx = self._bg_pool[cur]
        self._bg_cursor = cur + 1
        return idx

    def __iter__(self):
        if not self.enabled:
            # Yield nothing - caller should fall back
            return

        m = self.min_samples_per_class
        # Optional hybrid-OvR negatives: reserve slots for background clips that
        # teach all target heads to stay low without becoming their own class.
        # The reservation must happen BEFORE classes are chosen; otherwise a
        # batch fully covered by K * min_samples_per_class class samples leaves
        # no room and background clips are never drawn.
        bg_reserved = min(self.background_per_batch, self.batch_size) if self._bg_pool else 0
        class_slots = self.batch_size - bg_reserved
        # Number of classes to include per batch (K <= number of eligible classes).
        k = max(1, min(len(self.eligible_classes), class_slots // m))

        for _ in range(len(self)):
            batch: list[int] = []
            selected = self.rng.sample(self.eligible_classes, k)

            # Ensure min_samples_per_class per selected class (within the class budget)
            for cls in selected:
                needed = min(m, max(0, class_slots - len(batch)))
                batch.extend(self._draw_from_class(cls) for _ in range(needed))
                if len(batch) >= class_slots:
                    break

            bg_slots = min(self.background_per_batch, max(0, self.batch_size - len(batch)))
            for _ in range(bg_slots):
                bg_idx = self._draw_background()
                if bg_idx is None:
                    break
                batch.append(bg_idx)

            # Fill remaining slots from the selected classes (round-robin)
            si = 0
            while len(batch) < self.batch_size:
                batch.append(self._draw_from_class(selected[si % len(selected)]))
                si += 1

            yield batch
|
|
997
|
+
|
|
998
|
+
|
|
999
|
+
class ConfusionAwareSampler(BalancedBatchSampler):
    """BalancedBatchSampler with per-clip confusion-weighted draws for OvR hard mining.

    After each training epoch, call ``update_weights()`` with per-clip confusion
    scores. Within a class, clips are drawn with probability proportional to
    ``(1 + score) ** weight_temperature``, so clips on which the model fires the
    wrong heads are sampled more often. Scores are EMA-blended across updates to
    avoid sudden jumps. Which classes go into a batch is still decided by the
    parent sampler. Draws use the sampler's RNG, so a seeded sampler is
    reproducible.

    The confusion score for a clip with true class c is computed by the caller;
    it is intended to blend:
        - strongest rival-head activation
        - low activation of the true head
        - rival-over-true margin violation
    This keeps the score pair-aware while still producing a single scalar weight.
    """

    def __init__(self, *args, weight_temperature: float = 2.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.weight_temperature = max(0.1, float(weight_temperature))
        n = len(self.labels)
        self._confusion_scores = np.zeros(n, dtype=np.float32)
        # Per-clip rival class index reported to log_top_confused (-1 = unknown).
        self._top_rival = np.full(n, -1, dtype=np.int32)
        # Per-class (indices, probabilities) used by _draw_from_class (initially uniform).
        self._class_weights: Dict[int, tuple] = {}
        self._rebuild_class_weights()

    def update_weights(
        self,
        confusion_scores: np.ndarray,
        top_rivals: Optional[np.ndarray] = None,
        ema_alpha: float = 0.4,
    ) -> None:
        """Blend new confusion scores into running weights and rebuild class pools.

        Args:
            confusion_scores: One score per clip (array-like). Input whose shape
                does not match the number of clips is ignored.
            top_rivals: Optional per-clip rival class index; negative entries
                keep the previously stored rival.
            ema_alpha: Weight of the new scores in the EMA blend.
        """
        new_scores = np.asarray(confusion_scores, dtype=np.float32)
        if new_scores.shape != self._confusion_scores.shape:
            return
        self._confusion_scores = (
            ema_alpha * new_scores
            + (1.0 - ema_alpha) * self._confusion_scores
        )
        if top_rivals is not None:
            rivals = np.asarray(top_rivals).astype(np.int32)
            if rivals.shape == self._top_rival.shape:
                known = rivals >= 0
                self._top_rival[known] = rivals[known]
        self._rebuild_class_weights()

    def _rebuild_class_weights(self) -> None:
        """Compute normalised sampling probability per clip, grouped by class."""
        self._class_weights = {}
        for c, indices in self.class_to_indices.items():
            idx_arr = np.asarray(indices)
            scores = self._confusion_scores[idx_arr]
            # Non-finite scores get no boost, and the base is clamped at 0 so a
            # score below -1 cannot yield NaN under a fractional temperature.
            scores = np.where(np.isfinite(scores), scores, 0.0)
            # Base weight 1.0 + confusion boost; apply temperature sharpening
            w = np.maximum(1.0 + scores, 0.0) ** self.weight_temperature
            total = float(w.sum())
            if not np.isfinite(total) or total <= 0:
                w = np.ones_like(w, dtype=np.float32)
                total = float(w.sum())
            self._class_weights[c] = (idx_arr, w / total)

    def _draw_from_class(self, cls: int) -> int:
        """Weighted sample from class pool; falls back to parent if weights unavailable."""
        entry = self._class_weights.get(cls)
        if entry is None:
            return super()._draw_from_class(cls)
        indices, probs = entry
        # Use the sampler's own RNG (private Random(seed) or the global `random`
        # module) so seeding is honoured; np.random ignores the `seed` argument.
        return int(self.rng.choices(indices, weights=probs, k=1)[0])

    def log_top_confused(self, class_names: list, dataset_clips: list, n: int = 5) -> list[str]:
        """Return up to ``n`` log lines naming the most-confused clip per class.

        Classes whose hardest clip scores below 0.05 are skipped.
        """
        lines = []
        for ci, cname in enumerate(class_names):
            indices = self.class_to_indices.get(ci, [])
            if not indices:
                continue
            scores = self._confusion_scores[indices]
            top_local = int(np.argmax(scores))
            top_global = indices[top_local]
            top_score = float(scores[top_local])
            if top_score < 0.05:
                continue
            clip_id = dataset_clips[top_global % len(dataset_clips)].get("id", "?") if dataset_clips else str(top_global)
            rival_idx = int(self._top_rival[top_global]) if 0 <= top_global < len(self._top_rival) else -1
            rival_txt = ""
            if 0 <= rival_idx < len(class_names):
                rival_txt = f", rival={class_names[rival_idx]}"
            lines.append(
                f"  [{cname}] hardest: {os.path.basename(clip_id)} "
                f"(confusion={top_score:.3f}{rival_txt})"
            )
        return lines[:n]
|
|
1085
|
+
|
|
1086
|
+
|
|
1087
|
+
def _run_augmentation_ablation_eval(
    model: nn.Module,
    dataset,
    config: Dict[str, Any],
    device: torch.device,
    log_fn: Optional[Callable] = None,
):
    """Post-training evaluation: measure per-augmentation impact on accuracy.

    For each augmentation enabled in ``config["augmentation_options"]``, runs all
    clips through ``model`` with ONLY that augmentation applied (averaged over 3
    random trials) and compares per-clip frame accuracy with a clean baseline.
    Frames labelled < 0 are ignored. Reports which augmentations help or hurt and
    the worst-affected clip per augmentation via ``log_fn``; returns nothing.

    ``dataset.transform``, ``dataset.stitch_prob`` and
    ``dataset.virtual_size_multiplier`` are overridden during evaluation and
    restored afterwards, even if evaluation raises. ``model`` is left in eval mode.
    """
    aug_opts = config.get("augmentation_options") or {}
    use_ovr = config.get("use_ovr", False)

    # Map of augmentation display names -> ClipAugment toggle kwarg that isolates it
    aug_toggles = {
        "horizontal_flip": "use_horizontal_flip",
        "vertical_flip": "use_vertical_flip",
        "color_jitter": "use_color_jitter",
        "gaussian_blur": "use_gaussian_blur",
        "random_noise": "use_random_noise",
        "small_rotation": "use_small_rotation",
        "speed_perturb": "use_speed_perturb",
        "random_shapes": "use_random_shapes",
        "grayscale": "use_grayscale",
    }
    # Augmentation-specific parameters passed through to ClipAugment when configured.
    param_keys = (
        "color_jitter_brightness", "color_jitter_contrast",
        "color_jitter_saturation", "color_jitter_hue",
        "noise_std", "rotation_degrees",
    )

    # Only evaluate augmentations that were actually enabled during training
    enabled = [(name, kwarg) for name, kwarg in aug_toggles.items() if aug_opts.get(kwarg, False)]
    if not enabled:
        if log_fn:
            log_fn("Augmentation ablation: no augmentations were enabled — skipping.")
        return

    n_clips = len(dataset.clips)
    if n_clips == 0:
        # Per-clip statistics (np.argmin below) are undefined for an empty dataset.
        if log_fn:
            log_fn("Augmentation ablation: dataset has no clips — skipping.")
        return

    # Imported only once there is real work to do.
    from .augmentations import ClipAugment

    emb_mode = getattr(dataset, '_emb_cache_mode', False)
    clip_len = config.get("clip_length", 8)
    batch_size = config.get("batch_size", 8)

    if log_fn:
        log_fn(f"\n{'='*60}")
        log_fn("AUGMENTATION ABLATION EVALUATION")
        log_fn(f"{'='*60}")
        log_fn(f"Evaluating {len(enabled)} enabled augmentation(s)...")
        if emb_mode:
            log_fn("Warning: dataset is in embedding-cache mode; its data path does not "
                   "apply dataset.transform, so augmentation deltas are not meaningful.")

    class_names = dataset.classes
    n_classes = len(class_names)
    n_trials = 3  # average over multiple random augmentation rolls

    def _eval_clips(transform_fn) -> np.ndarray:
        """Run all clips with ``transform_fn``; return per-clip frame accuracy array."""
        dataset.transform = transform_fn
        clip_accs = np.zeros(n_clips, dtype=np.float32)
        loader = DataLoader(dataset, batch_size=batch_size,
                            shuffle=False, num_workers=0, pin_memory=False)
        with torch.no_grad():
            for batch_data in loader:
                if not isinstance(batch_data, (list, tuple)) or len(batch_data) < 7:
                    continue
                clips_t = batch_data[0].to(device)
                frame_labels_t = batch_data[6].to(device) if batch_data[6] is not None else None
                indices_t = batch_data[5]

                if emb_mode:
                    clips_short_t = None
                    if len(batch_data) >= 8:
                        cs = batch_data[7]
                        if isinstance(cs, torch.Tensor) and cs.numel() > 0:
                            clips_short_t = cs.to(device)
                    model(
                        None, backbone_tokens=clips_t, num_frames=clip_len,
                        backbone_tokens_short=clips_short_t,
                        num_frames_short=clip_len // 2 if clips_short_t is not None else None,
                        return_frame_logits=True,
                    )
                else:
                    model(clips_t, return_frame_logits=True)

                # Frame logits are read from model._frame_output after the forward
                # call rather than from its return value.
                fo = getattr(model, '_frame_output', None)
                if fo is None:
                    continue
                f_logits = fo[0]  # [B, T, C]
                preds = torch.argmax(torch.sigmoid(f_logits) if use_ovr else f_logits, dim=-1)

                for bi in range(preds.shape[0]):
                    clip_idx = int(indices_t[bi].item()) % n_clips
                    acc = 1.0
                    if frame_labels_t is not None:
                        valid = frame_labels_t[bi] >= 0
                        if valid.any():
                            acc = float((preds[bi][valid] == frame_labels_t[bi][valid]).float().mean().item())
                    clip_accs[clip_idx] = acc
        return clip_accs

    saved_transform = dataset.transform
    saved_stitch = getattr(dataset, "stitch_prob", 0.0)
    saved_mult = getattr(dataset, "virtual_size_multiplier", 1)
    dataset.transform = None
    dataset.stitch_prob = 0.0
    dataset.virtual_size_multiplier = 1
    model.eval()

    results = {}
    try:
        # 1. Clean baseline (no augmentation)
        if log_fn:
            log_fn("Running clean baseline (no augmentation)...")
        baseline_accs = _eval_clips(None)

        # 2. Per-augmentation evaluation
        for aug_name, aug_kwarg in enabled:
            if log_fn:
                log_fn(f"Evaluating: {aug_name}...")
            # ClipAugment kwargs with ONLY this one augmentation on
            kwargs = {k: False for k in aug_toggles.values()}
            kwargs[aug_kwarg] = True
            kwargs.update({k: aug_opts[k] for k in param_keys if k in aug_opts})
            trial_accs = [_eval_clips(ClipAugment(**kwargs)) for _ in range(n_trials)]
            avg_accs = np.mean(trial_accs, axis=0)
            delta = avg_accs - baseline_accs  # per-clip change
            results[aug_name] = {
                "mean_acc": float(avg_accs.mean()) * 100,
                "delta_mean": float(delta.mean()) * 100,
                "n_hurt": int((delta < -0.05).sum()),
                "n_helped": int((delta > 0.05).sum()),
                "per_clip_delta": delta,
            }
    finally:
        # Restore dataset state even if evaluation failed part-way.
        dataset.transform = saved_transform
        dataset.stitch_prob = saved_stitch
        dataset.virtual_size_multiplier = saved_mult

    # 3. Log results
    if not log_fn:
        return
    baseline_mean = float(baseline_accs.mean()) * 100
    log_fn(f"\nClean baseline accuracy: {baseline_mean:.1f}%\n")
    log_fn(f"{'Augmentation':<20} {'Acc':>7} {'Δ':>7} {'Hurt':>6} {'Helped':>8}")
    log_fn("-" * 52)
    for aug_name, r in sorted(results.items(), key=lambda x: x[1]["delta_mean"]):
        log_fn(
            f"{aug_name:<20} {r['mean_acc']:>6.1f}% {r['delta_mean']:>+6.1f}% "
            f"{r['n_hurt']:>5} {r['n_helped']:>7}"
        )

    # Worst clips per augmentation
    log_fn("\nWorst-affected clips per augmentation:")
    for aug_name, r in results.items():
        deltas = r["per_clip_delta"]
        worst_idx = int(np.argmin(deltas))
        worst_delta = float(deltas[worst_idx]) * 100
        if worst_delta > -1.0:
            continue
        clip_id = dataset.clips[worst_idx].get("id", "?")
        clip_label_idx = dataset.labels[worst_idx]
        clip_cls = class_names[clip_label_idx] if 0 <= clip_label_idx < n_classes else "?"
        log_fn(
            f"  {aug_name}: {os.path.basename(clip_id)} [{clip_cls}] "
            f"(clean={baseline_accs[worst_idx]*100:.0f}% → aug={float(r['per_clip_delta'][worst_idx] + baseline_accs[worst_idx])*100:.0f}%, "
            f"Δ={worst_delta:+.1f}%)"
        )

    log_fn(f"\n{'='*60}")
|
|
1273
|
+
|
|
1274
|
+
|
|
1275
|
+
def train_model(
|
|
1276
|
+
model: nn.Module,
|
|
1277
|
+
train_dataset: Dataset,
|
|
1278
|
+
val_dataset: Optional[Dataset],
|
|
1279
|
+
config: Dict[str, Any],
|
|
1280
|
+
log_fn: Optional[Callable[[str], None]] = None,
|
|
1281
|
+
progress_callback: Optional[Callable[[int, int], None]] = None,
|
|
1282
|
+
stop_callback: Optional[Callable[[], bool]] = None,
|
|
1283
|
+
metrics_callback: Optional[Callable[[Dict[str, Any]], None]] = None
|
|
1284
|
+
):
|
|
1285
|
+
"""Training loop for behavior classifier.
|
|
1286
|
+
|
|
1287
|
+
Args:
|
|
1288
|
+
model: BehaviorClassifier model
|
|
1289
|
+
train_dataset: Training dataset
|
|
1290
|
+
val_dataset: Optional validation dataset
|
|
1291
|
+
config: Training configuration dict
|
|
1292
|
+
log_fn: Optional callback for logging messages
|
|
1293
|
+
progress_callback: Optional callback(epoch, total_epochs) for progress
|
|
1294
|
+
stop_callback: Optional callback returning True if training should stop
|
|
1295
|
+
metrics_callback: Optional callback(metrics_dict) called after each epoch
|
|
1296
|
+
"""
|
|
1297
|
+
import traceback
|
|
1298
|
+
|
|
1299
|
+
try:
|
|
1300
|
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
1301
|
+
if log_fn:
|
|
1302
|
+
log_fn(f"Using device: {device}")
|
|
1303
|
+
if device.type == "cuda":
|
|
1304
|
+
log_fn(f"CUDA device: {torch.cuda.get_device_name(0)}")
|
|
1305
|
+
log_fn(f"CUDA available: {torch.cuda.is_available()}")
|
|
1306
|
+
log_fn(f"CUDA device count: {torch.cuda.device_count()}")
|
|
1307
|
+
log_fn(f"Current CUDA device: {torch.cuda.current_device()}")
|
|
1308
|
+
log_fn(f"CUDA memory allocated: {torch.cuda.memory_allocated(0) / 1024**2:.2f} MB")
|
|
1309
|
+
|
|
1310
|
+
# Move model to device (head will be on GPU, backbone stays on CPU for JAX)
|
|
1311
|
+
model.to(device)
|
|
1312
|
+
|
|
1313
|
+
# Load pretrained weights if specified (fine-tuning)
|
|
1314
|
+
pretrained_path = config.get("pretrained_path")
|
|
1315
|
+
if pretrained_path and os.path.exists(pretrained_path):
|
|
1316
|
+
if log_fn:
|
|
1317
|
+
log_fn(f"Loading pretrained weights from {pretrained_path}...")
|
|
1318
|
+
try:
|
|
1319
|
+
payload = torch.load(pretrained_path, map_location=device)
|
|
1320
|
+
if isinstance(payload, dict):
|
|
1321
|
+
frame_head_state_dict = payload.get("frame_head_state_dict", {})
|
|
1322
|
+
localization_state_dict = payload.get("localization_state_dict", {})
|
|
1323
|
+
else:
|
|
1324
|
+
frame_head_state_dict = {}
|
|
1325
|
+
localization_state_dict = {}
|
|
1326
|
+
|
|
1327
|
+
if (
|
|
1328
|
+
getattr(model, "frame_head", None) is not None
|
|
1329
|
+
and isinstance(frame_head_state_dict, dict)
|
|
1330
|
+
and frame_head_state_dict
|
|
1331
|
+
):
|
|
1332
|
+
frame_model_state = model.frame_head.state_dict()
|
|
1333
|
+
filtered_frame = {}
|
|
1334
|
+
mismatched_frame = []
|
|
1335
|
+
for k, v in frame_head_state_dict.items():
|
|
1336
|
+
if k in frame_model_state:
|
|
1337
|
+
if v.shape == frame_model_state[k].shape:
|
|
1338
|
+
filtered_frame[k] = v
|
|
1339
|
+
else:
|
|
1340
|
+
mismatched_frame.append(k)
|
|
1341
|
+
model.frame_head.load_state_dict(filtered_frame, strict=False)
|
|
1342
|
+
if log_fn:
|
|
1343
|
+
log_fn(
|
|
1344
|
+
f"Loaded frame head weights: {len(filtered_frame)} tensors"
|
|
1345
|
+
+ (f" (skipped mismatched: {mismatched_frame})" if mismatched_frame else "")
|
|
1346
|
+
)
|
|
1347
|
+
|
|
1348
|
+
# Also restore localization head when available and compatible.
|
|
1349
|
+
if (
|
|
1350
|
+
getattr(model, "use_localization", False)
|
|
1351
|
+
and getattr(model, "localization_head", None) is not None
|
|
1352
|
+
and isinstance(localization_state_dict, dict)
|
|
1353
|
+
and localization_state_dict
|
|
1354
|
+
):
|
|
1355
|
+
loc_model_state = model.localization_head.state_dict()
|
|
1356
|
+
filtered_loc = {}
|
|
1357
|
+
mismatched_loc = []
|
|
1358
|
+
for k, v in localization_state_dict.items():
|
|
1359
|
+
if k in loc_model_state:
|
|
1360
|
+
if v.shape == loc_model_state[k].shape:
|
|
1361
|
+
filtered_loc[k] = v
|
|
1362
|
+
else:
|
|
1363
|
+
mismatched_loc.append(k)
|
|
1364
|
+
model.localization_head.load_state_dict(filtered_loc, strict=False)
|
|
1365
|
+
if log_fn:
|
|
1366
|
+
log_fn(
|
|
1367
|
+
f"Loaded localization head weights: {len(filtered_loc)} tensors"
|
|
1368
|
+
+ (f" (skipped mismatched: {mismatched_loc})" if mismatched_loc else "")
|
|
1369
|
+
)
|
|
1370
|
+
elif log_fn and getattr(model, "use_localization", False):
|
|
1371
|
+
log_fn("No localization weights found in pretrained checkpoint; localization head will train from initialization.")
|
|
1372
|
+
|
|
1373
|
+
if log_fn:
|
|
1374
|
+
log_fn("Pretrained weights loaded successfully (partial load if class count changed).")
|
|
1375
|
+
except Exception as e:
|
|
1376
|
+
if log_fn:
|
|
1377
|
+
log_fn(f"WARNING: Failed to load pretrained weights: {e}")
|
|
1378
|
+
|
|
1379
|
+
# Ensure frame head is on GPU
|
|
1380
|
+
if device.type == "cuda":
|
|
1381
|
+
model.frame_head.to(device)
|
|
1382
|
+
fh_device = next(model.frame_head.parameters()).device
|
|
1383
|
+
if log_fn:
|
|
1384
|
+
log_fn(f"Frame head device: {fh_device}")
|
|
1385
|
+
if fh_device.type != "cuda":
|
|
1386
|
+
log_fn("ERROR: Frame head is not on GPU! Training will be very slow.")
|
|
1387
|
+
else:
|
|
1388
|
+
log_fn(f"GPU memory after moving head: {torch.cuda.memory_allocated(0) / 1024**2:.2f} MB")
|
|
1389
|
+
|
|
1390
|
+
if log_fn:
|
|
1391
|
+
log_fn(f"Model moved to {device}")
|
|
1392
|
+
log_fn("Note: VideoPrism backbone runs on GPU (JAX) if available, classification heads run on GPU (PyTorch)")
|
|
1393
|
+
|
|
1394
|
+
if log_fn:
|
|
1395
|
+
log_fn(f"Creating data loaders (batch_size={config['batch_size']})...")
|
|
1396
|
+
|
|
1397
|
+
use_weighted_sampler = config.get("use_weighted_sampler", False)
|
|
1398
|
+
use_ovr = config.get("use_ovr", False)
|
|
1399
|
+
_confusion_warmup_pct = float(config.get("confusion_sampler_warmup_pct", 0.2))
|
|
1400
|
+
sampler = None
|
|
1401
|
+
batch_sampler = None
|
|
1402
|
+
shuffle = True
|
|
1403
|
+
|
|
1404
|
+
# OvR needs balanced batches so each binary head sees positives every batch.
|
|
1405
|
+
if use_ovr and hasattr(train_dataset, 'labels'):
|
|
1406
|
+
labels_for_balance = train_dataset.labels
|
|
1407
|
+
ovr_min_samples = 1
|
|
1408
|
+
_confusion_temperature = float(config.get("confusion_sampler_temperature", 2.0))
|
|
1409
|
+
_use_confusion_sampler = bool(config.get("use_confusion_sampler", True)) and use_ovr
|
|
1410
|
+
_hybrid_bg = bool(config.get("ovr_background_as_negative", False))
|
|
1411
|
+
_bg_indices = []
|
|
1412
|
+
_bg_per_batch = 0
|
|
1413
|
+
if _hybrid_bg and hasattr(train_dataset, "ovr_background_clip"):
|
|
1414
|
+
_bg_indices = [
|
|
1415
|
+
i for i, is_bg in enumerate(train_dataset.ovr_background_clip)
|
|
1416
|
+
if bool(is_bg)
|
|
1417
|
+
]
|
|
1418
|
+
if _bg_indices:
|
|
1419
|
+
_bg_per_batch = max(1, config["batch_size"] // 4)
|
|
1420
|
+
_SamplerClass = ConfusionAwareSampler if _use_confusion_sampler else BalancedBatchSampler
|
|
1421
|
+
_sampler_kwargs = dict(weight_temperature=_confusion_temperature) if _use_confusion_sampler else {}
|
|
1422
|
+
bbs = _SamplerClass(
|
|
1423
|
+
labels=labels_for_balance,
|
|
1424
|
+
batch_size=config["batch_size"],
|
|
1425
|
+
min_samples_per_class=ovr_min_samples,
|
|
1426
|
+
drop_last=False,
|
|
1427
|
+
excluded_classes=[-1],
|
|
1428
|
+
virtual_size_multiplier=getattr(train_dataset, "virtual_size_multiplier", 1),
|
|
1429
|
+
background_indices=_bg_indices,
|
|
1430
|
+
background_per_batch=_bg_per_batch,
|
|
1431
|
+
**_sampler_kwargs,
|
|
1432
|
+
)
|
|
1433
|
+
if bbs and getattr(bbs, "enabled", False):
|
|
1434
|
+
batch_sampler = bbs
|
|
1435
|
+
if log_fn:
|
|
1436
|
+
reason = "OvR" if use_ovr else "contrastive loss"
|
|
1437
|
+
sampler_type = "ConfusionAwareSampler" if _use_confusion_sampler else "BalancedBatchSampler"
|
|
1438
|
+
log_fn(f"Using {sampler_type} (>={ovr_min_samples} per class) for training ({reason} enabled)")
|
|
1439
|
+
if _bg_indices:
|
|
1440
|
+
log_fn(
|
|
1441
|
+
f"Hybrid OvR background negatives: {len(_bg_indices)} clips "
|
|
1442
|
+
f"(up to {_bg_per_batch} per batch)"
|
|
1443
|
+
)
|
|
1444
|
+
if _use_confusion_sampler:
|
|
1445
|
+
warmup_ep = int(config["epochs"] * _confusion_warmup_pct)
|
|
1446
|
+
log_fn(f"Confusion sampler warmup: {int(_confusion_warmup_pct * 100)}% ({warmup_ep} epochs) — uniform sampling until epoch {warmup_ep}")
|
|
1447
|
+
if hasattr(bbs, "dropped_classes") and bbs.dropped_classes:
|
|
1448
|
+
# Map indices back to names safely.
|
|
1449
|
+
dropped_names = []
|
|
1450
|
+
for idx in bbs.dropped_classes:
|
|
1451
|
+
name = None
|
|
1452
|
+
if hasattr(train_dataset, "classes") and 0 <= idx < len(train_dataset.classes):
|
|
1453
|
+
name = train_dataset.classes[idx]
|
|
1454
|
+
elif hasattr(train_dataset, "raw_classes") and 0 <= idx < len(train_dataset.raw_classes):
|
|
1455
|
+
name = train_dataset.raw_classes[idx]
|
|
1456
|
+
else:
|
|
1457
|
+
name = str(idx)
|
|
1458
|
+
dropped_names.append(name)
|
|
1459
|
+
log_fn(
|
|
1460
|
+
f"WARNING: The following classes have < {ovr_min_samples} samples and will be SKIPPED by BalancedBatchSampler: "
|
|
1461
|
+
f"{dropped_names}"
|
|
1462
|
+
)
|
|
1463
|
+
else:
|
|
1464
|
+
if log_fn:
|
|
1465
|
+
log_fn("BalancedBatchSampler disabled (insufficient class counts); falling back to standard batching")
|
|
1466
|
+
|
|
1467
|
+
# If not using balanced batches, optionally use weighted sampler
|
|
1468
|
+
if batch_sampler is None and use_weighted_sampler and hasattr(train_dataset, 'labels'):
|
|
1469
|
+
if log_fn:
|
|
1470
|
+
log_fn("Creating weighted random sampler...")
|
|
1471
|
+
|
|
1472
|
+
labels_for_sampling = list(train_dataset.labels)
|
|
1473
|
+
|
|
1474
|
+
virtual_mult = int(getattr(train_dataset, "virtual_size_multiplier", 1) or 1)
|
|
1475
|
+
if virtual_mult > 1:
|
|
1476
|
+
labels_for_sampling = labels_for_sampling * virtual_mult
|
|
1477
|
+
|
|
1478
|
+
class_counts = Counter(labels_for_sampling)
|
|
1479
|
+
if log_fn:
|
|
1480
|
+
log_fn(f"Class counts: {dict(class_counts)}")
|
|
1481
|
+
num_samples = len(labels_for_sampling)
|
|
1482
|
+
weights = [1.0 / class_counts[label] for label in labels_for_sampling]
|
|
1483
|
+
sampler = WeightedRandomSampler(weights=weights, num_samples=num_samples, replacement=True)
|
|
1484
|
+
shuffle = False
|
|
1485
|
+
if log_fn:
|
|
1486
|
+
log_fn("Using weighted random sampler for training")
|
|
1487
|
+
|
|
1488
|
+
num_workers = config.get("num_workers", 4)
|
|
1489
|
+
|
|
1490
|
+
if batch_sampler is not None:
|
|
1491
|
+
train_loader = DataLoader(
|
|
1492
|
+
train_dataset,
|
|
1493
|
+
batch_sampler=batch_sampler,
|
|
1494
|
+
num_workers=num_workers,
|
|
1495
|
+
pin_memory=True if device.type == "cuda" else False,
|
|
1496
|
+
persistent_workers=True if num_workers > 0 else False
|
|
1497
|
+
)
|
|
1498
|
+
else:
|
|
1499
|
+
train_loader = DataLoader(
|
|
1500
|
+
train_dataset,
|
|
1501
|
+
batch_size=config["batch_size"],
|
|
1502
|
+
shuffle=shuffle,
|
|
1503
|
+
sampler=sampler,
|
|
1504
|
+
num_workers=num_workers,
|
|
1505
|
+
pin_memory=True if device.type == "cuda" else False,
|
|
1506
|
+
persistent_workers=True if num_workers > 0 else False
|
|
1507
|
+
)
|
|
1508
|
+
|
|
1509
|
+
if log_fn:
|
|
1510
|
+
log_fn(f"Train loader created: {len(train_loader)} batches (workers={num_workers})")
|
|
1511
|
+
|
|
1512
|
+
val_loader = None
|
|
1513
|
+
if val_dataset:
|
|
1514
|
+
val_loader = DataLoader(
|
|
1515
|
+
val_dataset,
|
|
1516
|
+
batch_size=config["batch_size"],
|
|
1517
|
+
shuffle=False,
|
|
1518
|
+
num_workers=num_workers,
|
|
1519
|
+
pin_memory=True if device.type == "cuda" else False,
|
|
1520
|
+
persistent_workers=True if num_workers > 0 else False
|
|
1521
|
+
)
|
|
1522
|
+
if log_fn:
|
|
1523
|
+
log_fn(f"Val loader created: {len(val_loader)} batches")
|
|
1524
|
+
|
|
1525
|
+
if log_fn:
|
|
1526
|
+
log_fn("Creating optimizer and loss function...")
|
|
1527
|
+
|
|
1528
|
+
base_lr = config["lr"]
|
|
1529
|
+
localization_lr = float(config.get("localization_lr", base_lr))
|
|
1530
|
+
classification_lr = float(config.get("classification_lr", base_lr))
|
|
1531
|
+
wd = config.get("weight_decay", 0.001)
|
|
1532
|
+
param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad}
|
|
1533
|
+
raw_model = model.module if hasattr(model, "module") else model
|
|
1534
|
+
use_separate_lr = (
|
|
1535
|
+
getattr(raw_model, "use_localization", False)
|
|
1536
|
+
and getattr(raw_model, "localization_head", None) is not None
|
|
1537
|
+
)
|
|
1538
|
+
if use_separate_lr:
|
|
1539
|
+
loc_params = [p for n, p in param_dict.items() if "localization_head" in n]
|
|
1540
|
+
other_params = [p for n, p in param_dict.items() if "localization_head" not in n]
|
|
1541
|
+
loc_decay = [p for p in loc_params if p.dim() >= 2]
|
|
1542
|
+
loc_nodecay = [p for p in loc_params if p.dim() < 2]
|
|
1543
|
+
cls_decay = [p for p in other_params if p.dim() >= 2]
|
|
1544
|
+
cls_nodecay = [p for p in other_params if p.dim() < 2]
|
|
1545
|
+
optim_groups = []
|
|
1546
|
+
if loc_decay:
|
|
1547
|
+
optim_groups.append({"params": loc_decay, "weight_decay": wd, "lr": localization_lr})
|
|
1548
|
+
if loc_nodecay:
|
|
1549
|
+
optim_groups.append({"params": loc_nodecay, "weight_decay": 0.0, "lr": localization_lr})
|
|
1550
|
+
if cls_decay:
|
|
1551
|
+
optim_groups.append({"params": cls_decay, "weight_decay": wd, "lr": classification_lr})
|
|
1552
|
+
if cls_nodecay:
|
|
1553
|
+
optim_groups.append({"params": cls_nodecay, "weight_decay": 0.0, "lr": classification_lr})
|
|
1554
|
+
else:
|
|
1555
|
+
decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
|
|
1556
|
+
nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
|
|
1557
|
+
optim_groups = [
|
|
1558
|
+
{"params": decay_params, "weight_decay": wd, "lr": classification_lr},
|
|
1559
|
+
{"params": nodecay_params, "weight_decay": 0.0, "lr": classification_lr},
|
|
1560
|
+
]
|
|
1561
|
+
optimizer = AdamW(optim_groups, lr=classification_lr)
|
|
1562
|
+
|
|
1563
|
+
# --- Scheduler: Cosine Annealing with Warm Restarts + Linear Warmup ---
|
|
1564
|
+
use_scheduler = bool(config.get('use_scheduler', True))
|
|
1565
|
+
total_epochs = config['epochs']
|
|
1566
|
+
warmup_epochs = 0
|
|
1567
|
+
scheduler = None
|
|
1568
|
+
warmup_scheduler = None
|
|
1569
|
+
restart_period = None
|
|
1570
|
+
eta_min = None
|
|
1571
|
+
if use_scheduler:
|
|
1572
|
+
eta_min = 0.2 * classification_lr
|
|
1573
|
+
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
|
|
1574
|
+
optimizer, T_max=max(1, total_epochs - warmup_epochs), eta_min=eta_min
|
|
1575
|
+
)
|
|
1576
|
+
warmup_scheduler = None
|
|
1577
|
+
if warmup_epochs > 0:
|
|
1578
|
+
warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
|
|
1579
|
+
optimizer, start_factor=0.01, end_factor=1.0,
|
|
1580
|
+
total_iters=warmup_epochs,
|
|
1581
|
+
)
|
|
1582
|
+
|
|
1583
|
+
# --- EMA (Exponential Moving Average) for classification head only ---
|
|
1584
|
+
use_ema = bool(config.get("use_ema", True))
|
|
1585
|
+
ema_decay = 0.99
|
|
1586
|
+
ema_state: dict[str, torch.Tensor] = {}
|
|
1587
|
+
ema_active = False # enabled later when ema_start_epoch is reached
|
|
1588
|
+
def _init_ema():
    """Reset ``ema_state`` to a fresh copy of the current trainable weights.

    The model is unwrapped via ``model.module`` when that attribute exists
    (presumably a DataParallel-style wrapper — TODO confirm). Only parameters
    with ``requires_grad`` are tracked, and anything under
    ``localization_head.`` is excluded. The dict is cleared and refilled in
    place rather than rebound, so every function holding a reference to
    ``ema_state`` sees the new snapshot.
    """
    target = model.module if hasattr(model, "module") else model
    ema_state.clear()
    ema_state.update(
        (pname, p.data.clone())
        for pname, p in target.named_parameters()
        if p.requires_grad and not pname.startswith("localization_head.")
    )
|
|
1594
|
+
def _update_ema():
    """Blend the live weights into ``ema_state``; no-op while EMA is inactive.

    For each parameter name present in ``ema_state`` the stored tensor is
    updated in place as ``ema = ema * ema_decay + param * (1 - ema_decay)``.
    """
    if ema_active:
        target = model.module if hasattr(model, "module") else model
        for pname, p in target.named_parameters():
            shadow = ema_state.get(pname)
            if shadow is not None:
                shadow.mul_(ema_decay).add_(p.data, alpha=1.0 - ema_decay)
|
|
1601
|
+
def _apply_ema():
    """Exchange live weights with the EMA weights; a second call swaps back.

    Does nothing while EMA is inactive. Only parameters whose names appear in
    ``ema_state`` are swapped; both sides receive fresh clones.
    """
    if not ema_active:
        return
    target = model.module if hasattr(model, "module") else model
    for pname, p in target.named_parameters():
        if pname not in ema_state:
            continue
        live_copy = p.data.clone()
        p.data = ema_state[pname].clone()
        ema_state[pname] = live_copy
|
|
1609
|
+
|
|
1610
|
+
if log_fn:
|
|
1611
|
+
log_fn("Using AdamW with separate weight decay for biases/norm")
|
|
1612
|
+
if use_separate_lr:
|
|
1613
|
+
log_fn(f"Localization LR: {localization_lr:.2e}, Classification LR: {classification_lr:.2e}")
|
|
1614
|
+
else:
|
|
1615
|
+
log_fn(f"Learning rate: {classification_lr:.2e}")
|
|
1616
|
+
log_fn("Gradient clipping: max_norm=1.0 (NaN-guarded)")
|
|
1617
|
+
if use_scheduler:
|
|
1618
|
+
if warmup_epochs > 0:
|
|
1619
|
+
log_fn(f"LR schedule: {warmup_epochs}-epoch linear warmup → CosineAnnealingLR (single decay, {total_epochs - warmup_epochs} epochs)")
|
|
1620
|
+
else:
|
|
1621
|
+
log_fn(f"LR schedule: CosineAnnealingLR (single cosine decay, {total_epochs} epochs, no warmup)")
|
|
1622
|
+
else:
|
|
1623
|
+
log_fn("LR scheduler disabled (fixed learning rate)")
|
|
1624
|
+
if use_ema:
|
|
1625
|
+
if use_separate_lr or getattr(raw_model, "use_localization", False):
|
|
1626
|
+
log_fn(f"EMA: classification head only (decay={ema_decay}); localization head unchanged")
|
|
1627
|
+
else:
|
|
1628
|
+
log_fn(f"EMA model averaging: enabled (decay={ema_decay})")
|
|
1629
|
+
else:
|
|
1630
|
+
log_fn("EMA model averaging: disabled")
|
|
1631
|
+
|
|
1632
|
+
use_class_weights = config.get("use_class_weights", False)
|
|
1633
|
+
|
|
1634
|
+
# Focal Loss settings
|
|
1635
|
+
use_focal_loss = config.get("use_focal_loss", False)
|
|
1636
|
+
focal_gamma = config.get("focal_gamma", 2.0)
|
|
1637
|
+
use_supcon_loss = bool(config.get("use_supcon_loss", False))
|
|
1638
|
+
supcon_weight = float(config.get("supcon_weight", 0.2)) if use_supcon_loss else 0.0
|
|
1639
|
+
supcon_temperature = float(config.get("supcon_temperature", 0.1))
|
|
1640
|
+
|
|
1641
|
+
# Asymmetric Loss settings
|
|
1642
|
+
use_asl = config.get("use_asl", True) # Asymmetric Loss on by default for OvR
|
|
1643
|
+
asl_gamma_neg = float(config.get("asl_gamma_neg", 4.0))
|
|
1644
|
+
asl_gamma_pos = float(config.get("asl_gamma_pos", 0.0))
|
|
1645
|
+
asl_clip = float(config.get("asl_clip", 0.05))
|
|
1646
|
+
|
|
1647
|
+
# ASL's 'clip' parameter already applies negative smoothing, so default
|
|
1648
|
+
# to 0.0 when ASL is active to avoid contradictory loss signals.
|
|
1649
|
+
default_smoothing = 0.0 if (use_ovr and use_asl) else (0.05 if use_ovr else 0.0)
|
|
1650
|
+
ovr_label_smoothing = float(config.get("ovr_label_smoothing", default_smoothing))
|
|
1651
|
+
ovr_background_as_negative = bool(config.get("ovr_background_as_negative", False))
|
|
1652
|
+
allowed_cooccurrence = []
|
|
1653
|
+
cooccur_lookup: dict[int, set[int]] = {}
|
|
1654
|
+
hard_pair_mining = bool(config.get("use_hard_pair_mining", False) and use_ovr)
|
|
1655
|
+
hard_pair_margin = float(config.get("hard_pair_margin", 0.5))
|
|
1656
|
+
hard_pair_loss_weight = float(config.get("hard_pair_loss_weight", 0.2)) if hard_pair_mining else 0.0
|
|
1657
|
+
hard_pair_confusion_boost = max(1.0, float(config.get("hard_pair_confusion_boost", 1.5)))
|
|
1658
|
+
hard_pair_index_pairs: list[tuple[int, int]] = []
|
|
1659
|
+
hard_pair_name_pairs: list[list[str]] = []
|
|
1660
|
+
|
|
1661
|
+
if use_ovr and getattr(model, "frame_head", None) is not None:
|
|
1662
|
+
model.frame_head.use_ovr = True
|
|
1663
|
+
if ovr_background_as_negative and log_fn:
|
|
1664
|
+
bg_names = config.get("ovr_background_class_names", [])
|
|
1665
|
+
bg_txt = ", ".join(bg_names) if bg_names else "background"
|
|
1666
|
+
log_fn(f"OvR hybrid background negatives: {bg_txt} kept as all-zero targets")
|
|
1667
|
+
if hard_pair_mining:
|
|
1668
|
+
class_to_idx = {c: i for i, c in enumerate(getattr(train_dataset, "classes", []))}
|
|
1669
|
+
seen_pairs = set()
|
|
1670
|
+
skipped_pairs = []
|
|
1671
|
+
for pair in (config.get("hard_pairs", []) or []):
|
|
1672
|
+
if not isinstance(pair, (list, tuple)) or len(pair) != 2:
|
|
1673
|
+
continue
|
|
1674
|
+
a_name = str(pair[0]).strip()
|
|
1675
|
+
b_name = str(pair[1]).strip()
|
|
1676
|
+
if not a_name or not b_name or a_name == b_name:
|
|
1677
|
+
continue
|
|
1678
|
+
a_idx = class_to_idx.get(a_name, -1)
|
|
1679
|
+
b_idx = class_to_idx.get(b_name, -1)
|
|
1680
|
+
if a_idx < 0 or b_idx < 0:
|
|
1681
|
+
skipped_pairs.append((a_name, b_name))
|
|
1682
|
+
continue
|
|
1683
|
+
key = (min(a_idx, b_idx), max(a_idx, b_idx))
|
|
1684
|
+
if key in seen_pairs:
|
|
1685
|
+
continue
|
|
1686
|
+
seen_pairs.add(key)
|
|
1687
|
+
hard_pair_index_pairs.append(key)
|
|
1688
|
+
hard_pair_name_pairs.append([
|
|
1689
|
+
train_dataset.classes[key[0]],
|
|
1690
|
+
train_dataset.classes[key[1]],
|
|
1691
|
+
])
|
|
1692
|
+
if skipped_pairs and log_fn:
|
|
1693
|
+
skipped_txt = ", ".join(f"{a}<->{b}" for a, b in skipped_pairs[:8])
|
|
1694
|
+
if len(skipped_pairs) > 8:
|
|
1695
|
+
skipped_txt += ", ..."
|
|
1696
|
+
log_fn(f"Hard-pair mining: skipped unknown pair(s): {skipped_txt}")
|
|
1697
|
+
if hard_pair_index_pairs:
|
|
1698
|
+
if log_fn:
|
|
1699
|
+
pair_txt = ", ".join(f"{a}<->{b}" for a, b in hard_pair_name_pairs)
|
|
1700
|
+
log_fn(
|
|
1701
|
+
"Hard-pair mining: "
|
|
1702
|
+
f"{pair_txt} | margin={hard_pair_margin:.2f}, "
|
|
1703
|
+
f"loss_weight={hard_pair_loss_weight:.3f}, "
|
|
1704
|
+
f"confusion_boost={hard_pair_confusion_boost:.2f}x"
|
|
1705
|
+
)
|
|
1706
|
+
else:
|
|
1707
|
+
hard_pair_mining = False
|
|
1708
|
+
hard_pair_loss_weight = 0.0
|
|
1709
|
+
if log_fn:
|
|
1710
|
+
log_fn("Hard-pair mining enabled, but no valid configured pairs matched the active classes.")
|
|
1711
|
+
|
|
1712
|
+
# Frame-level loss is always the sole objective
|
|
1713
|
+
use_frame_loss = True
|
|
1714
|
+
if log_fn:
|
|
1715
|
+
log_fn("Frame-level classification: enabled (sole classification loss)")
|
|
1716
|
+
|
|
1717
|
+
# Boundary and smoothness loss settings
|
|
1718
|
+
use_temporal_decoder = bool(config.get("use_temporal_decoder", True))
|
|
1719
|
+
boundary_loss_weight = float(config.get("boundary_loss_weight", 0.3)) if use_temporal_decoder else 0.0
|
|
1720
|
+
smoothness_loss_weight = float(config.get("smoothness_loss_weight", 0.05))
|
|
1721
|
+
boundary_tolerance = int(config.get("boundary_tolerance", 1))
|
|
1722
|
+
if log_fn:
|
|
1723
|
+
if use_temporal_decoder:
|
|
1724
|
+
log_fn(f"Boundary loss: weight={boundary_loss_weight}, tolerance={boundary_tolerance}")
|
|
1725
|
+
else:
|
|
1726
|
+
log_fn("Temporal decoder: disabled (direct per-frame classifier after spatial pooling)")
|
|
1727
|
+
log_fn("Boundary loss: disabled (no boundary branch)")
|
|
1728
|
+
log_fn(f"Smoothness loss: weight={smoothness_loss_weight}")
|
|
1729
|
+
if use_supcon_loss and supcon_weight > 0:
|
|
1730
|
+
log_fn(
|
|
1731
|
+
f"SupCon on MAP embeddings: enabled "
|
|
1732
|
+
f"(weight={supcon_weight:.3f}, temperature={supcon_temperature:.3f})"
|
|
1733
|
+
)
|
|
1734
|
+
else:
|
|
1735
|
+
log_fn("SupCon on MAP embeddings: disabled")
|
|
1736
|
+
|
|
1737
|
+
use_frame_bout_balance = bool(config.get("use_frame_bout_balance", True))
|
|
1738
|
+
frame_bout_balance_power = float(config.get("frame_bout_balance_power", 1.0))
|
|
1739
|
+
if use_frame_bout_balance and log_fn:
|
|
1740
|
+
log_fn(
|
|
1741
|
+
f"Frame bout balancing: enabled (power={frame_bout_balance_power:.2f}) "
|
|
1742
|
+
f"[reduces dominance of long behavior bouts]"
|
|
1743
|
+
)
|
|
1744
|
+
|
|
1745
|
+
def _generate_boundary_labels(frame_labels: torch.Tensor, tolerance: int = 1) -> torch.Tensor:
|
|
1746
|
+
"""Generate boundary labels from frame labels.
|
|
1747
|
+
|
|
1748
|
+
boundary[t] = 1 if there is a label transition within ±tolerance frames.
|
|
1749
|
+
Frames with label=-1 get boundary label=-1 (ignored).
|
|
1750
|
+
"""
|
|
1751
|
+
B, T = frame_labels.shape
|
|
1752
|
+
boundaries = torch.zeros(B, T, dtype=torch.float32, device=frame_labels.device)
|
|
1753
|
+
for bi in range(B):
|
|
1754
|
+
for t in range(1, T):
|
|
1755
|
+
cur = int(frame_labels[bi, t].item())
|
|
1756
|
+
prev = int(frame_labels[bi, t - 1].item())
|
|
1757
|
+
if cur < 0 or prev < 0:
|
|
1758
|
+
continue
|
|
1759
|
+
if cur != prev:
|
|
1760
|
+
lo = max(0, t - tolerance)
|
|
1761
|
+
hi = min(T, t + tolerance + 1)
|
|
1762
|
+
boundaries[bi, lo:hi] = 1.0
|
|
1763
|
+
# Mark invalid frames as -1
|
|
1764
|
+
boundaries[frame_labels < 0] = -1.0
|
|
1765
|
+
return boundaries
|
|
1766
|
+
|
|
1767
|
+
def _pool_frame_labels(frame_labels: Optional[torch.Tensor], pool: int) -> Optional[torch.Tensor]:
|
|
1768
|
+
"""Downsample [B, T] labels to pooled timeline by majority vote (ignore=-1)."""
|
|
1769
|
+
if frame_labels is None or pool <= 1:
|
|
1770
|
+
return frame_labels
|
|
1771
|
+
if frame_labels.dim() != 2:
|
|
1772
|
+
return frame_labels
|
|
1773
|
+
B, T = frame_labels.shape
|
|
1774
|
+
pad = (pool - (T % pool)) % pool
|
|
1775
|
+
lbl = frame_labels
|
|
1776
|
+
if pad > 0:
|
|
1777
|
+
pad_vals = lbl[:, -1:].repeat(1, pad)
|
|
1778
|
+
lbl = torch.cat([lbl, pad_vals], dim=1)
|
|
1779
|
+
Tp = lbl.shape[1] // pool
|
|
1780
|
+
chunks = lbl.view(B, Tp, pool)
|
|
1781
|
+
pooled = torch.full((B, Tp), -1, dtype=lbl.dtype, device=lbl.device)
|
|
1782
|
+
for bi in range(B):
|
|
1783
|
+
for ti in range(Tp):
|
|
1784
|
+
vals = chunks[bi, ti]
|
|
1785
|
+
valid = vals[vals >= 0]
|
|
1786
|
+
if valid.numel() == 0:
|
|
1787
|
+
continue
|
|
1788
|
+
uniq, counts = torch.unique(valid, return_counts=True)
|
|
1789
|
+
pooled[bi, ti] = uniq[torch.argmax(counts)]
|
|
1790
|
+
return pooled
|
|
1791
|
+
|
|
1792
|
+
def _pool_binary_mask(mask: Optional[torch.Tensor], pool: int) -> Optional[torch.Tensor]:
|
|
1793
|
+
"""Downsample a boolean [B, T] mask by any() over each pooled window."""
|
|
1794
|
+
if mask is None or pool <= 1:
|
|
1795
|
+
return mask
|
|
1796
|
+
if mask.dim() != 2:
|
|
1797
|
+
return mask
|
|
1798
|
+
B, T = mask.shape
|
|
1799
|
+
pad = (pool - (T % pool)) % pool
|
|
1800
|
+
m = mask
|
|
1801
|
+
if pad > 0:
|
|
1802
|
+
pad_vals = m[:, -1:].repeat(1, pad)
|
|
1803
|
+
m = torch.cat([m, pad_vals], dim=1)
|
|
1804
|
+
Tp = m.shape[1] // pool
|
|
1805
|
+
return m.view(B, Tp, pool).any(dim=2)
|
|
1806
|
+
|
|
1807
|
+
def _pool_frame_embeddings(frame_embeddings: Optional[torch.Tensor], pool: int) -> Optional[torch.Tensor]:
|
|
1808
|
+
"""Average [B, T, D] embeddings over pooled windows to match pooled labels."""
|
|
1809
|
+
if frame_embeddings is None or pool <= 1:
|
|
1810
|
+
return frame_embeddings
|
|
1811
|
+
if frame_embeddings.dim() != 3:
|
|
1812
|
+
return frame_embeddings
|
|
1813
|
+
B, T, D = frame_embeddings.shape
|
|
1814
|
+
pad = (pool - (T % pool)) % pool
|
|
1815
|
+
emb = frame_embeddings
|
|
1816
|
+
if pad > 0:
|
|
1817
|
+
pad_vals = emb[:, -1:, :].repeat(1, pad, 1)
|
|
1818
|
+
emb = torch.cat([emb, pad_vals], dim=1)
|
|
1819
|
+
Tp = emb.shape[1] // pool
|
|
1820
|
+
return emb.view(B, Tp, pool, D).mean(dim=2)
|
|
1821
|
+
|
|
1822
|
+
def _supervised_contrastive_loss(
|
|
1823
|
+
frame_embeddings: Optional[torch.Tensor],
|
|
1824
|
+
frame_labels: Optional[torch.Tensor],
|
|
1825
|
+
temperature: float = 0.1,
|
|
1826
|
+
max_samples: int = 512,
|
|
1827
|
+
) -> torch.Tensor:
|
|
1828
|
+
"""SupCon over pooled frame embeddings using valid frame labels only."""
|
|
1829
|
+
if frame_embeddings is None or frame_labels is None:
|
|
1830
|
+
if frame_embeddings is not None:
|
|
1831
|
+
return frame_embeddings.sum() * 0.0
|
|
1832
|
+
return torch.tensor(0.0, device=device)
|
|
1833
|
+
if frame_embeddings.dim() != 3 or frame_labels.dim() != 2:
|
|
1834
|
+
return frame_embeddings.sum() * 0.0
|
|
1835
|
+
feats = frame_embeddings.reshape(-1, frame_embeddings.shape[-1])
|
|
1836
|
+
labels_flat = frame_labels.reshape(-1)
|
|
1837
|
+
valid = labels_flat >= 0
|
|
1838
|
+
if valid.sum().item() < 2:
|
|
1839
|
+
return feats.sum() * 0.0
|
|
1840
|
+
feats = feats[valid]
|
|
1841
|
+
labels_flat = labels_flat[valid]
|
|
1842
|
+
if feats.shape[0] > max_samples:
|
|
1843
|
+
perm = torch.randperm(feats.shape[0], device=feats.device)[:max_samples]
|
|
1844
|
+
feats = feats[perm]
|
|
1845
|
+
labels_flat = labels_flat[perm]
|
|
1846
|
+
uniq_labels, counts = torch.unique(labels_flat, return_counts=True)
|
|
1847
|
+
keep_labels = uniq_labels[counts >= 2]
|
|
1848
|
+
if keep_labels.numel() == 0:
|
|
1849
|
+
return feats.sum() * 0.0
|
|
1850
|
+
keep_mask = torch.zeros_like(labels_flat, dtype=torch.bool)
|
|
1851
|
+
for cls_id in keep_labels:
|
|
1852
|
+
keep_mask |= labels_flat == cls_id
|
|
1853
|
+
feats = feats[keep_mask]
|
|
1854
|
+
labels_flat = labels_flat[keep_mask]
|
|
1855
|
+
if feats.shape[0] < 2:
|
|
1856
|
+
return feats.sum() * 0.0
|
|
1857
|
+
|
|
1858
|
+
feats = F.normalize(feats, p=2, dim=1)
|
|
1859
|
+
logits = torch.matmul(feats, feats.T) / max(float(temperature), 1e-6)
|
|
1860
|
+
logits = logits - logits.max(dim=1, keepdim=True).values.detach()
|
|
1861
|
+
|
|
1862
|
+
same_class = labels_flat.unsqueeze(0) == labels_flat.unsqueeze(1)
|
|
1863
|
+
self_mask = torch.eye(feats.shape[0], device=feats.device, dtype=torch.bool)
|
|
1864
|
+
pos_mask = same_class & (~self_mask)
|
|
1865
|
+
valid_anchor = pos_mask.any(dim=1)
|
|
1866
|
+
if not valid_anchor.any():
|
|
1867
|
+
return feats.sum() * 0.0
|
|
1868
|
+
|
|
1869
|
+
exp_logits = torch.exp(logits) * (~self_mask)
|
|
1870
|
+
log_prob = logits - torch.log(exp_logits.sum(dim=1, keepdim=True).clamp(min=1e-12))
|
|
1871
|
+
mean_log_prob_pos = (pos_mask.float() * log_prob).sum(dim=1) / pos_mask.sum(dim=1).clamp(min=1)
|
|
1872
|
+
return -mean_log_prob_pos[valid_anchor].mean()
|
|
1873
|
+
|
|
1874
|
+
def _lookup_ovr_suppress_for_batch(sample_indices, dataset, device) -> Optional[torch.Tensor]:
|
|
1875
|
+
"""Map batch sample indices back to per-clip OvR suppression targets."""
|
|
1876
|
+
suppress_list = getattr(dataset, "ovr_suppress_idx", None)
|
|
1877
|
+
if not suppress_list:
|
|
1878
|
+
return None
|
|
1879
|
+
if isinstance(sample_indices, torch.Tensor):
|
|
1880
|
+
indices = sample_indices.detach().cpu().tolist()
|
|
1881
|
+
else:
|
|
1882
|
+
indices = list(sample_indices)
|
|
1883
|
+
mapped = []
|
|
1884
|
+
for idx in indices:
|
|
1885
|
+
try:
|
|
1886
|
+
clip_idx = int(idx)
|
|
1887
|
+
except Exception as e:
|
|
1888
|
+
logger.debug("Could not convert index to int: %s", e)
|
|
1889
|
+
clip_idx = -1
|
|
1890
|
+
if 0 <= clip_idx < len(suppress_list):
|
|
1891
|
+
mapped.append(int(suppress_list[clip_idx]))
|
|
1892
|
+
else:
|
|
1893
|
+
mapped.append(-1)
|
|
1894
|
+
return torch.tensor(mapped, dtype=torch.long, device=device)
|
|
1895
|
+
|
|
1896
|
+
def _build_bout_weights(frame_labels: torch.Tensor, power: float = 1.0) -> torch.Tensor:
|
|
1897
|
+
"""Per-frame weights inversely proportional to contiguous bout length."""
|
|
1898
|
+
B, T = frame_labels.shape
|
|
1899
|
+
weights = torch.zeros((B, T), dtype=torch.float32, device=frame_labels.device)
|
|
1900
|
+
for bi in range(B):
|
|
1901
|
+
t = 0
|
|
1902
|
+
while t < T:
|
|
1903
|
+
lbl = int(frame_labels[bi, t].item())
|
|
1904
|
+
if lbl < 0:
|
|
1905
|
+
t += 1
|
|
1906
|
+
continue
|
|
1907
|
+
t_end = t + 1
|
|
1908
|
+
while t_end < T and int(frame_labels[bi, t_end].item()) == lbl:
|
|
1909
|
+
t_end += 1
|
|
1910
|
+
seg_len = max(1, t_end - t)
|
|
1911
|
+
w = float(seg_len) ** (-float(power))
|
|
1912
|
+
weights[bi, t:t_end] = w
|
|
1913
|
+
t = t_end
|
|
1914
|
+
valid = frame_labels >= 0
|
|
1915
|
+
if valid.any():
|
|
1916
|
+
mean_w = weights[valid].mean().clamp(min=1e-8)
|
|
1917
|
+
weights = weights / mean_w
|
|
1918
|
+
return weights
|
|
1919
|
+
|
|
1920
|
+
def _hard_pair_margin_loss(
|
|
1921
|
+
frame_logits: torch.Tensor,
|
|
1922
|
+
frame_labels: torch.Tensor,
|
|
1923
|
+
pair_indices: list[tuple[int, int]],
|
|
1924
|
+
margin: float,
|
|
1925
|
+
use_bout_balance: bool = False,
|
|
1926
|
+
bout_power: float = 1.0,
|
|
1927
|
+
) -> torch.Tensor:
|
|
1928
|
+
"""Extra pairwise margin pressure for configured confusing class pairs."""
|
|
1929
|
+
if not pair_indices or margin <= 0:
|
|
1930
|
+
return frame_logits.sum() * 0.0
|
|
1931
|
+
B, T, _ = frame_logits.shape
|
|
1932
|
+
frame_w = torch.ones((B, T), dtype=torch.float32, device=frame_logits.device)
|
|
1933
|
+
if use_bout_balance:
|
|
1934
|
+
frame_w = _build_bout_weights(frame_labels, power=bout_power)
|
|
1935
|
+
total_loss = frame_logits.new_tensor(0.0)
|
|
1936
|
+
total_weight = frame_logits.new_tensor(0.0)
|
|
1937
|
+
for a_idx, b_idx in pair_indices:
|
|
1938
|
+
mask_a = frame_labels == a_idx
|
|
1939
|
+
if mask_a.any():
|
|
1940
|
+
loss_a = F.relu(margin - (frame_logits[..., a_idx] - frame_logits[..., b_idx]))
|
|
1941
|
+
total_loss = total_loss + (loss_a[mask_a] * frame_w[mask_a]).sum()
|
|
1942
|
+
total_weight = total_weight + frame_w[mask_a].sum()
|
|
1943
|
+
mask_b = frame_labels == b_idx
|
|
1944
|
+
if mask_b.any():
|
|
1945
|
+
loss_b = F.relu(margin - (frame_logits[..., b_idx] - frame_logits[..., a_idx]))
|
|
1946
|
+
total_loss = total_loss + (loss_b[mask_b] * frame_w[mask_b]).sum()
|
|
1947
|
+
total_weight = total_weight + frame_w[mask_b].sum()
|
|
1948
|
+
if float(total_weight.item()) <= 0:
|
|
1949
|
+
return frame_logits.sum() * 0.0
|
|
1950
|
+
return total_loss / total_weight.clamp(min=1e-6)
|
|
1951
|
+
|
|
1952
|
+
def _frame_loss_balanced(
|
|
1953
|
+
frame_logits: torch.Tensor,
|
|
1954
|
+
frame_labels: torch.Tensor,
|
|
1955
|
+
use_ovr_local: bool = False,
|
|
1956
|
+
ovr_targets: Optional[torch.Tensor] = None,
|
|
1957
|
+
ovr_weight: Optional[torch.Tensor] = None,
|
|
1958
|
+
use_bout_balance: bool = False,
|
|
1959
|
+
bout_power: float = 1.0,
|
|
1960
|
+
valid_mask_override: Optional[torch.Tensor] = None,
|
|
1961
|
+
) -> torch.Tensor:
|
|
1962
|
+
"""Frame classification loss with optional inverse-bout-length weighting."""
|
|
1963
|
+
B, T, C = frame_logits.shape
|
|
1964
|
+
valid = valid_mask_override if valid_mask_override is not None else (frame_labels >= 0)
|
|
1965
|
+
if not valid.any():
|
|
1966
|
+
return frame_logits.sum() * 0.0
|
|
1967
|
+
|
|
1968
|
+
frame_w = torch.ones((B, T), dtype=torch.float32, device=frame_logits.device)
|
|
1969
|
+
if use_bout_balance:
|
|
1970
|
+
base_w = _build_bout_weights(frame_labels, power=bout_power)
|
|
1971
|
+
if valid_mask_override is not None:
|
|
1972
|
+
frame_w = torch.where(
|
|
1973
|
+
valid,
|
|
1974
|
+
torch.where(frame_labels >= 0, base_w, torch.ones_like(base_w)),
|
|
1975
|
+
torch.zeros_like(base_w),
|
|
1976
|
+
)
|
|
1977
|
+
else:
|
|
1978
|
+
frame_w = base_w
|
|
1979
|
+
|
|
1980
|
+
if use_ovr_local and ovr_targets is not None:
|
|
1981
|
+
if use_asl and isinstance(criterion, AsymmetricLoss):
|
|
1982
|
+
per_elem = criterion(frame_logits, ovr_targets)
|
|
1983
|
+
else:
|
|
1984
|
+
per_elem = F.binary_cross_entropy_with_logits(
|
|
1985
|
+
frame_logits, ovr_targets, reduction='none'
|
|
1986
|
+
)
|
|
1987
|
+
if use_focal_loss:
|
|
1988
|
+
pt = torch.exp(-per_elem)
|
|
1989
|
+
per_elem = ((1 - pt) ** focal_gamma) * per_elem
|
|
1990
|
+
|
|
1991
|
+
elem_w = valid.unsqueeze(-1).float() * frame_w.unsqueeze(-1)
|
|
1992
|
+
if ovr_weight is not None:
|
|
1993
|
+
elem_w = elem_w * ovr_weight
|
|
1994
|
+
denom = elem_w.sum().clamp(min=1.0)
|
|
1995
|
+
return (per_elem * elem_w).sum() / denom
|
|
1996
|
+
|
|
1997
|
+
logits_flat = frame_logits.reshape(B * T, C)
|
|
1998
|
+
labels_flat = frame_labels.reshape(B * T)
|
|
1999
|
+
valid_flat = labels_flat >= 0
|
|
2000
|
+
if not valid_flat.any():
|
|
2001
|
+
return frame_logits.sum() * 0.0
|
|
2002
|
+
logits_valid = logits_flat[valid_flat]
|
|
2003
|
+
labels_valid = labels_flat[valid_flat]
|
|
2004
|
+
raw_ce = F.cross_entropy(logits_valid, labels_valid, reduction='none')
|
|
2005
|
+
w_flat = frame_w.reshape(B * T)[valid_flat]
|
|
2006
|
+
denom = w_flat.sum().clamp(min=1.0)
|
|
2007
|
+
return (raw_ce * w_flat).sum() / denom
|
|
2008
|
+
|
|
2009
|
+
# Localization supervision settings (autonomous staged curriculum)
|
|
2010
|
+
use_localization = bool(config.get("use_localization", False) and getattr(model, "use_localization", False))
|
|
2011
|
+
has_any_localization = use_localization and any(
|
|
2012
|
+
float(v.sum().item()) > 0.5 for v in getattr(train_dataset, "spatial_bbox_valid", [])
|
|
2013
|
+
)
|
|
2014
|
+
# Default localization cap follows total training epochs from GUI unless explicitly overridden.
|
|
2015
|
+
loc_max_stage_epochs = int(config.get("localization_stage_max_epochs", config.get("epochs", 20)))
|
|
2016
|
+
use_manual_loc_switch = bool(config.get("use_manual_localization_switch", False))
|
|
2017
|
+
manual_loc_switch_epoch = int(config.get("manual_localization_switch_epoch", 20))
|
|
2018
|
+
loc_gate_patience = int(config.get("localization_gate_patience", 2))
|
|
2019
|
+
loc_gate_iou_threshold = float(config.get("localization_gate_iou", 0.55))
|
|
2020
|
+
loc_gate_center_error = float(config.get("localization_gate_center_error", 0.15))
|
|
2021
|
+
loc_gate_valid_rate = float(config.get("localization_gate_valid_rate", 0.9))
|
|
2022
|
+
crop_mix_start_gt = float(config.get("classification_crop_gt_prob_start", 1.0))
|
|
2023
|
+
crop_mix_end_gt = float(config.get("classification_crop_gt_prob_end", 0.0))
|
|
2024
|
+
crop_padding = float(config.get("classification_crop_padding", 0.35))
|
|
2025
|
+
crop_min_size = float(config.get("classification_crop_min_size_norm", 0.04))
|
|
2026
|
+
enable_roi_cache = True # Always use precomputed crops for classification stage
|
|
2027
|
+
center_heatmap_weight = float(config.get("center_heatmap_weight", 1.0))
|
|
2028
|
+
center_heatmap_sigma = float(config.get("center_heatmap_sigma", 2.5))
|
|
2029
|
+
direct_center_weight = float(config.get("direct_center_weight", 2.0))
|
|
2030
|
+
# Use the largest bbox across all classes so every crop has the same
|
|
2031
|
+
# extent. This prevents the classifier from using background size as a
|
|
2032
|
+
# class cue and keeps training/inference consistent (inference doesn't
|
|
2033
|
+
# know the class label).
|
|
2034
|
+
global_fixed_wh = (0.2, 0.2)
|
|
2035
|
+
if use_localization and has_any_localization:
|
|
2036
|
+
all_w = []
|
|
2037
|
+
all_h = []
|
|
2038
|
+
max_n = min(
|
|
2039
|
+
len(getattr(train_dataset, "spatial_bboxes", [])),
|
|
2040
|
+
len(getattr(train_dataset, "spatial_bbox_valid", [])),
|
|
2041
|
+
)
|
|
2042
|
+
for i in range(max_n):
|
|
2043
|
+
bboxes_i = train_dataset.spatial_bboxes[i] # [T, 4]
|
|
2044
|
+
valid_i = train_dataset.spatial_bbox_valid[i] # [T]
|
|
2045
|
+
for t in range(bboxes_i.size(0)):
|
|
2046
|
+
if float(valid_i[t].item()) <= 0.5:
|
|
2047
|
+
continue
|
|
2048
|
+
x1, y1, x2, y2 = [float(v) for v in bboxes_i[t].tolist()]
|
|
2049
|
+
w = max(1e-4, min(1.0, x2 - x1))
|
|
2050
|
+
h = max(1e-4, min(1.0, y2 - y1))
|
|
2051
|
+
all_w.append(w)
|
|
2052
|
+
all_h.append(h)
|
|
2053
|
+
|
|
2054
|
+
if all_w and all_h:
|
|
2055
|
+
gw = float(max(all_w))
|
|
2056
|
+
gh = float(max(all_h))
|
|
2057
|
+
global_fixed_wh = (max(1e-4, min(1.0, gw)), max(1e-4, min(1.0, gh)))
|
|
2058
|
+
|
|
2059
|
+
raw_model = model.module if hasattr(model, "module") else model
|
|
2060
|
+
if getattr(raw_model, "use_localization", False) and getattr(raw_model, "localization_head", None) is not None:
|
|
2061
|
+
raw_model.localization_head.set_fixed_box_size(global_fixed_wh[0], global_fixed_wh[1])
|
|
2062
|
+
|
|
2063
|
+
def _fixed_wh_for_labels(label_tensor: torch.Tensor, device: torch.device, dtype: torch.dtype) -> Optional[torch.Tensor]:
|
|
2064
|
+
"""Return the single global fixed box size for each sample in the batch."""
|
|
2065
|
+
if label_tensor is None:
|
|
2066
|
+
return None
|
|
2067
|
+
if not torch.is_tensor(label_tensor):
|
|
2068
|
+
return None
|
|
2069
|
+
if label_tensor.dim() == 0:
|
|
2070
|
+
label_tensor = label_tensor.view(1)
|
|
2071
|
+
B = label_tensor.size(0)
|
|
2072
|
+
wh = [list(global_fixed_wh)] * B
|
|
2073
|
+
return torch.tensor(wh, device=device, dtype=dtype)
|
|
2074
|
+
if use_localization and log_fn:
|
|
2075
|
+
if has_any_localization:
|
|
2076
|
+
n_with_bbox = sum(1 for v in train_dataset.spatial_bbox_valid if float(v.sum().item()) > 0.5)
|
|
2077
|
+
log_fn(
|
|
2078
|
+
f"Localization Supervision: enabled ({n_with_bbox} clips with bbox)"
|
|
2079
|
+
)
|
|
2080
|
+
if enable_roi_cache:
|
|
2081
|
+
log_fn("Classification ROI cache: enabled (precompute crops once, then augment in-memory).")
|
|
2082
|
+
if center_heatmap_weight > 0:
|
|
2083
|
+
log_fn(
|
|
2084
|
+
f"Center Heatmap Loss (Gaussian Focal): enabled (weight={center_heatmap_weight}, "
|
|
2085
|
+
f"sigma_patches={center_heatmap_sigma})"
|
|
2086
|
+
)
|
|
2087
|
+
if direct_center_weight > 0:
|
|
2088
|
+
log_fn(
|
|
2089
|
+
f"Direct Center Loss: enabled (weight={direct_center_weight})"
|
|
2090
|
+
)
|
|
2091
|
+
log_fn(
|
|
2092
|
+
f"Localization fixed box size: max across all classes, "
|
|
2093
|
+
f"w={global_fixed_wh[0]:.4f}, h={global_fixed_wh[1]:.4f}"
|
|
2094
|
+
)
|
|
2095
|
+
if use_manual_loc_switch:
|
|
2096
|
+
log_fn(
|
|
2097
|
+
f"Manual localization switch enabled: phase will switch at epoch {manual_loc_switch_epoch}"
|
|
2098
|
+
)
|
|
2099
|
+
else:
|
|
2100
|
+
log_fn("Localization Supervision: enabled but no clips have bbox annotations — will be skipped")
|
|
2101
|
+
|
|
2102
|
+
def _split_localization_output(output):
|
|
2103
|
+
if (
|
|
2104
|
+
use_localization
|
|
2105
|
+
and isinstance(output, tuple)
|
|
2106
|
+
and len(output) >= 2
|
|
2107
|
+
and torch.is_tensor(output[-1])
|
|
2108
|
+
and output[-1].dim() in (2, 3)
|
|
2109
|
+
and output[-1].size(-1) == 4
|
|
2110
|
+
):
|
|
2111
|
+
head_out = output[:-1]
|
|
2112
|
+
if len(head_out) == 1:
|
|
2113
|
+
head_out = head_out[0]
|
|
2114
|
+
return head_out, output[-1]
|
|
2115
|
+
return output, None
|
|
2116
|
+
|
|
2117
|
+
def _bbox_iou(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
|
2118
|
+
x1 = torch.maximum(pred[:, 0], target[:, 0])
|
|
2119
|
+
y1 = torch.maximum(pred[:, 1], target[:, 1])
|
|
2120
|
+
x2 = torch.minimum(pred[:, 2], target[:, 2])
|
|
2121
|
+
y2 = torch.minimum(pred[:, 3], target[:, 3])
|
|
2122
|
+
inter = (x2 - x1).clamp(min=0) * (y2 - y1).clamp(min=0)
|
|
2123
|
+
area_p = (pred[:, 2] - pred[:, 0]).clamp(min=0) * (pred[:, 3] - pred[:, 1]).clamp(min=0)
|
|
2124
|
+
area_t = (target[:, 2] - target[:, 0]).clamp(min=0) * (target[:, 3] - target[:, 1]).clamp(min=0)
|
|
2125
|
+
union = area_p + area_t - inter
|
|
2126
|
+
return inter / (union + 1e-6)
|
|
2127
|
+
|
|
2128
|
+
def _localization_metrics(pred_bboxes: torch.Tensor, target_bboxes: torch.Tensor, valid_mask: torch.Tensor):
|
|
2129
|
+
if pred_bboxes is None:
|
|
2130
|
+
return 0.0, 1.0, 0.0
|
|
2131
|
+
if pred_bboxes.dim() == 3:
|
|
2132
|
+
pred_bboxes = pred_bboxes[:, 0, :]
|
|
2133
|
+
valid = valid_mask > 0.5
|
|
2134
|
+
if int(valid.sum().item()) == 0:
|
|
2135
|
+
return 0.0, 1.0, 0.0
|
|
2136
|
+
pred = pred_bboxes[valid]
|
|
2137
|
+
tgt = target_bboxes[valid]
|
|
2138
|
+
iou = _bbox_iou(pred, tgt).mean().item()
|
|
2139
|
+
pred_cx = 0.5 * (pred[:, 0] + pred[:, 2])
|
|
2140
|
+
pred_cy = 0.5 * (pred[:, 1] + pred[:, 3])
|
|
2141
|
+
tgt_cx = 0.5 * (tgt[:, 0] + tgt[:, 2])
|
|
2142
|
+
tgt_cy = 0.5 * (tgt[:, 1] + tgt[:, 3])
|
|
2143
|
+
center_err = torch.sqrt((pred_cx - tgt_cx) ** 2 + (pred_cy - tgt_cy) ** 2).mean().item()
|
|
2144
|
+
valid_pred_rate = float(
|
|
2145
|
+
((pred[:, 2] > pred[:, 0]) & (pred[:, 3] > pred[:, 1])).float().mean().item()
|
|
2146
|
+
)
|
|
2147
|
+
return float(iou), float(center_err), valid_pred_rate
|
|
2148
|
+
|
|
2149
|
+
def _sanitize_bboxes(bboxes: torch.Tensor) -> torch.Tensor:
|
|
2150
|
+
"""Pad bboxes proportionally to their own size, not the full frame.
|
|
2151
|
+
crop_padding is now a fraction of the bbox dimension (e.g. 0.20 = 20% of bbox w/h).
|
|
2152
|
+
"""
|
|
2153
|
+
boxes = bboxes.clone()
|
|
2154
|
+
orig_shape = boxes.shape
|
|
2155
|
+
boxes = boxes.view(-1, 4)
|
|
2156
|
+
x1 = boxes[:, 0].clamp(0.0, 1.0)
|
|
2157
|
+
y1 = boxes[:, 1].clamp(0.0, 1.0)
|
|
2158
|
+
x2 = boxes[:, 2].clamp(0.0, 1.0)
|
|
2159
|
+
y2 = boxes[:, 3].clamp(0.0, 1.0)
|
|
2160
|
+
cx = 0.5 * (x1 + x2)
|
|
2161
|
+
cy = 0.5 * (y1 + y2)
|
|
2162
|
+
w = (x2 - x1).abs().clamp(min=crop_min_size)
|
|
2163
|
+
h = (y2 - y1).abs().clamp(min=crop_min_size)
|
|
2164
|
+
# Proportional padding: add crop_padding * bbox_size on each side
|
|
2165
|
+
w = torch.clamp(w * (1.0 + 2.0 * crop_padding), max=1.0)
|
|
2166
|
+
h = torch.clamp(h * (1.0 + 2.0 * crop_padding), max=1.0)
|
|
2167
|
+
x1 = (cx - 0.5 * w).clamp(0.0, 1.0)
|
|
2168
|
+
y1 = (cy - 0.5 * h).clamp(0.0, 1.0)
|
|
2169
|
+
x2 = (cx + 0.5 * w).clamp(0.0, 1.0)
|
|
2170
|
+
y2 = (cy + 0.5 * h).clamp(0.0, 1.0)
|
|
2171
|
+
boxes[:, 0] = x1
|
|
2172
|
+
boxes[:, 1] = y1
|
|
2173
|
+
boxes[:, 2] = torch.maximum(x2, x1 + crop_min_size).clamp(0.0, 1.0)
|
|
2174
|
+
boxes[:, 3] = torch.maximum(y2, y1 + crop_min_size).clamp(0.0, 1.0)
|
|
2175
|
+
return boxes.view(orig_shape)
|
|
2176
|
+
|
|
2177
|
+
def _clamp_bboxes_no_expand(bboxes: torch.Tensor) -> torch.Tensor:
|
|
2178
|
+
"""Clamp/reorder bbox coordinates without padding or min-size expansion."""
|
|
2179
|
+
boxes = bboxes.clone()
|
|
2180
|
+
orig_shape = boxes.shape
|
|
2181
|
+
boxes = boxes.view(-1, 4)
|
|
2182
|
+
x1 = boxes[:, 0].clamp(0.0, 1.0)
|
|
2183
|
+
y1 = boxes[:, 1].clamp(0.0, 1.0)
|
|
2184
|
+
x2 = boxes[:, 2].clamp(0.0, 1.0)
|
|
2185
|
+
y2 = boxes[:, 3].clamp(0.0, 1.0)
|
|
2186
|
+
lo_x = torch.minimum(x1, x2)
|
|
2187
|
+
hi_x = torch.maximum(x1, x2)
|
|
2188
|
+
lo_y = torch.minimum(y1, y2)
|
|
2189
|
+
hi_y = torch.maximum(y1, y2)
|
|
2190
|
+
boxes[:, 0] = lo_x
|
|
2191
|
+
boxes[:, 1] = lo_y
|
|
2192
|
+
boxes[:, 2] = torch.maximum(hi_x, lo_x + 1e-4).clamp(0.0, 1.0)
|
|
2193
|
+
boxes[:, 3] = torch.maximum(hi_y, lo_y + 1e-4).clamp(0.0, 1.0)
|
|
2194
|
+
return boxes.view(orig_shape)
|
|
2195
|
+
|
|
2196
|
+
def _crop_clips_with_bboxes(clips: torch.Tensor, bboxes: torch.Tensor) -> torch.Tensor:
|
|
2197
|
+
# clips: [B, T, C, H, W], bboxes: [B, 4] normalized.
|
|
2198
|
+
B, T, C, H, W = clips.shape
|
|
2199
|
+
out = torch.empty_like(clips)
|
|
2200
|
+
boxes = _sanitize_bboxes(bboxes)
|
|
2201
|
+
for i in range(B):
|
|
2202
|
+
x1 = int(round(float(boxes[i, 0].item()) * (W - 1)))
|
|
2203
|
+
y1 = int(round(float(boxes[i, 1].item()) * (H - 1)))
|
|
2204
|
+
x2 = int(round(float(boxes[i, 2].item()) * W))
|
|
2205
|
+
y2 = int(round(float(boxes[i, 3].item()) * H))
|
|
2206
|
+
x1 = max(0, min(x1, W - 1))
|
|
2207
|
+
y1 = max(0, min(y1, H - 1))
|
|
2208
|
+
x2 = max(x1 + 1, min(x2, W))
|
|
2209
|
+
y2 = max(y1 + 1, min(y2, H))
|
|
2210
|
+
sample = clips[i] # [T, C, H, W]
|
|
2211
|
+
cropped = sample[:, :, y1:y2, x1:x2]
|
|
2212
|
+
if cropped.size(-1) < 2 or cropped.size(-2) < 2:
|
|
2213
|
+
out[i] = sample
|
|
2214
|
+
continue
|
|
2215
|
+
out[i] = F.interpolate(cropped, size=(H, W), mode="bilinear", align_corners=False)
|
|
2216
|
+
return out
|
|
2217
|
+
|
|
2218
|
+
def _crop_single_clip_to_target(clip_tchw: torch.Tensor, bbox_xyxy: torch.Tensor, out_h: int, out_w: int) -> torch.Tensor:
|
|
2219
|
+
# clip_tchw: [T,C,H,W], bbox_xyxy: [4] or [T,4] normalized
|
|
2220
|
+
T, C, H, W = clip_tchw.shape
|
|
2221
|
+
box = _sanitize_bboxes(bbox_xyxy)
|
|
2222
|
+
if box.dim() == 1:
|
|
2223
|
+
box = box.view(1, 4).repeat(T, 1)
|
|
2224
|
+
elif box.dim() == 2 and box.size(0) != T:
|
|
2225
|
+
if box.size(0) < T:
|
|
2226
|
+
box = torch.cat([box, box[-1:].repeat(T - box.size(0), 1)], dim=0)
|
|
2227
|
+
else:
|
|
2228
|
+
box = box[:T]
|
|
2229
|
+
|
|
2230
|
+
out_frames = []
|
|
2231
|
+
for ti in range(T):
|
|
2232
|
+
bt = box[ti]
|
|
2233
|
+
x1 = int(round(float(bt[0].item()) * (W - 1)))
|
|
2234
|
+
y1 = int(round(float(bt[1].item()) * (H - 1)))
|
|
2235
|
+
x2 = int(round(float(bt[2].item()) * W))
|
|
2236
|
+
y2 = int(round(float(bt[3].item()) * H))
|
|
2237
|
+
x1 = max(0, min(x1, W - 1))
|
|
2238
|
+
y1 = max(0, min(y1, H - 1))
|
|
2239
|
+
x2 = max(x1 + 1, min(x2, W))
|
|
2240
|
+
y2 = max(y1 + 1, min(y2, H))
|
|
2241
|
+
frame = clip_tchw[ti : ti + 1] # [1,C,H,W]
|
|
2242
|
+
cropped = frame[:, :, y1:y2, x1:x2]
|
|
2243
|
+
if cropped.size(-2) < 2 or cropped.size(-1) < 2:
|
|
2244
|
+
resized = F.interpolate(frame, size=(out_h, out_w), mode="bilinear", align_corners=False)
|
|
2245
|
+
else:
|
|
2246
|
+
resized = F.interpolate(cropped, size=(out_h, out_w), mode="bilinear", align_corners=False)
|
|
2247
|
+
out_frames.append(resized[0])
|
|
2248
|
+
return torch.stack(out_frames, dim=0)
|
|
2249
|
+
|
|
2250
|
+
def _lock_temporal_box_size_from_first(bboxes: torch.Tensor) -> torch.Tensor:
|
|
2251
|
+
"""For [B,T,4], use the frame-0 box for all frames (fully fixed crop)."""
|
|
2252
|
+
if bboxes.dim() != 3 or bboxes.size(1) < 1:
|
|
2253
|
+
return bboxes
|
|
2254
|
+
first = bboxes[:, 0:1, :].clone() # [B,1,4]
|
|
2255
|
+
return first.expand(-1, bboxes.size(1), -1).contiguous()
|
|
2256
|
+
|
|
2257
|
+
def _precompute_roi_cache(dataset_obj, split_name: str = "train") -> str:
|
|
2258
|
+
"""Precompute localized crops and save to disk as .pt files.
|
|
2259
|
+
|
|
2260
|
+
Returns the directory path containing the saved crops.
|
|
2261
|
+
The dataset __getitem__ can then load these directly with workers.
|
|
2262
|
+
"""
|
|
2263
|
+
cache_dir = os.path.join(output_dir_base, f"{basename}_roi_crops_{split_name}")
|
|
2264
|
+
os.makedirs(cache_dir, exist_ok=True)
|
|
2265
|
+
was_training = model.training
|
|
2266
|
+
model.eval()
|
|
2267
|
+
try:
|
|
2268
|
+
total_items = len(dataset_obj.clips)
|
|
2269
|
+
for ds_idx in range(total_items):
|
|
2270
|
+
loc_clip_noaoi = dataset_obj.load_modelres_clip_by_index(ds_idx).float().unsqueeze(0).to(device)
|
|
2271
|
+
loc_wh = torch.tensor([[float(global_fixed_wh[0]), float(global_fixed_wh[1])]], device=device, dtype=loc_clip_noaoi.dtype)
|
|
2272
|
+
with torch.no_grad():
|
|
2273
|
+
loc_out = model(loc_clip_noaoi, return_localization=True, localization_box_wh=loc_wh)
|
|
2274
|
+
_, pred = _split_localization_output(loc_out)
|
|
2275
|
+
if pred is None:
|
|
2276
|
+
pred = torch.tensor([[0.0, 0.0, 1.0, 1.0]], device=device, dtype=loc_clip_noaoi.dtype)
|
|
2277
|
+
pred = _sanitize_bboxes(pred)
|
|
2278
|
+
if pred.dim() == 3:
|
|
2279
|
+
pred = _lock_temporal_box_size_from_first(pred)
|
|
2280
|
+
bbox_for_crop = pred[0].detach().cpu() if pred.dim() == 3 else pred[0].detach().cpu()
|
|
2281
|
+
|
|
2282
|
+
# Safety fallback: use GT when predicted geometry is invalid.
|
|
2283
|
+
if pred.dim() == 3:
|
|
2284
|
+
invalid = bool(((bbox_for_crop[:, 2] <= bbox_for_crop[:, 0]) | (bbox_for_crop[:, 3] <= bbox_for_crop[:, 1])).any().item())
|
|
2285
|
+
else:
|
|
2286
|
+
invalid = bool(((bbox_for_crop[2] <= bbox_for_crop[0]) or (bbox_for_crop[3] <= bbox_for_crop[1])))
|
|
2287
|
+
ds_valid = dataset_obj.spatial_bbox_valid[ds_idx]
|
|
2288
|
+
if invalid and float(ds_valid[0].item() if ds_valid.dim() > 0 else ds_valid.item()) > 0.5:
|
|
2289
|
+
ds_bbox = dataset_obj.spatial_bboxes[ds_idx]
|
|
2290
|
+
ds_bbox_f0 = ds_bbox[0] if ds_bbox.dim() == 2 else ds_bbox
|
|
2291
|
+
gt_box = _sanitize_bboxes(ds_bbox_f0.view(1, 4))[0].detach().cpu()
|
|
2292
|
+
if pred.dim() == 3:
|
|
2293
|
+
bbox_for_crop = gt_box.unsqueeze(0).expand(pred.size(1), -1).contiguous()
|
|
2294
|
+
else:
|
|
2295
|
+
bbox_for_crop = gt_box
|
|
2296
|
+
|
|
2297
|
+
# Localization crops come from the original frame (no AOI)
|
|
2298
|
+
raw_clip = dataset_obj.load_fullres_clip_by_index(ds_idx, apply_aoi=False).float()
|
|
2299
|
+
out_h = int(loc_clip_noaoi.shape[-2])
|
|
2300
|
+
out_w = int(loc_clip_noaoi.shape[-1])
|
|
2301
|
+
cropped = _crop_single_clip_to_target(raw_clip, bbox_for_crop, out_h, out_w).detach().cpu()
|
|
2302
|
+
torch.save(cropped, os.path.join(cache_dir, f"{ds_idx}.pt"))
|
|
2303
|
+
|
|
2304
|
+
if log_fn and ((ds_idx + 1) % 50 == 0 or (ds_idx + 1) == total_items):
|
|
2305
|
+
log_fn(f"Saving cropped clips [{split_name}]: {ds_idx+1}/{total_items}")
|
|
2306
|
+
finally:
|
|
2307
|
+
if was_training:
|
|
2308
|
+
model.train()
|
|
2309
|
+
if log_fn:
|
|
2310
|
+
log_fn(f"Saved {total_items} cropped clips to {cache_dir}")
|
|
2311
|
+
return cache_dir
|
|
2312
|
+
|
|
2313
|
+
def _precompute_embedding_cache(
|
|
2314
|
+
dataset_obj,
|
|
2315
|
+
split_name: str = "train",
|
|
2316
|
+
num_aug_versions: int = 1,
|
|
2317
|
+
use_augmentation: bool = False,
|
|
2318
|
+
multi_scale: bool = False,
|
|
2319
|
+
) -> str:
|
|
2320
|
+
"""Run each clip through the frozen backbone and save token tensors.
|
|
2321
|
+
|
|
2322
|
+
When num_aug_versions > 1 and use_augmentation=True, each clip is
|
|
2323
|
+
passed through the backbone num_aug_versions times with different
|
|
2324
|
+
random augmentations applied, giving the temporal head diverse
|
|
2325
|
+
inputs each epoch without re-running the backbone at training time.
|
|
2326
|
+
Version 0 is always unaugmented (clean reference).
|
|
2327
|
+
Embeddings are stored as float16 to halve disk usage.
|
|
2328
|
+
|
|
2329
|
+
When multi_scale=True, each clip is also passed through the backbone
|
|
2330
|
+
at half fps (every-other frame subsampled), and saved as {idx}_{v}_s.pt.
|
|
2331
|
+
This gives the temporal head both fine-grained and broader temporal context.
|
|
2332
|
+
|
|
2333
|
+
Returns the cache directory path.
|
|
2334
|
+
"""
|
|
2335
|
+
cache_dir = os.path.join(output_dir_base, f"{basename}_emb_cache_{split_name}")
|
|
2336
|
+
os.makedirs(cache_dir, exist_ok=True)
|
|
2337
|
+
was_training = model.training
|
|
2338
|
+
model.eval()
|
|
2339
|
+
raw_model = model.module if hasattr(model, "module") else model
|
|
2340
|
+
total_items = len(dataset_obj.clips)
|
|
2341
|
+
# Multi-scale doubles the number of backbone passes
|
|
2342
|
+
passes_per_item = (2 if multi_scale else 1) * num_aug_versions
|
|
2343
|
+
total_ops = total_items * passes_per_item
|
|
2344
|
+
ops_done = 0
|
|
2345
|
+
try:
|
|
2346
|
+
for ds_idx in range(total_items):
|
|
2347
|
+
clip_info = dataset_obj.clips[ds_idx]
|
|
2348
|
+
clip_id = clip_info["id"]
|
|
2349
|
+
clip_path = dataset_obj._resolve_clip_path(clip_id)
|
|
2350
|
+
|
|
2351
|
+
# If the dataset has no transform, augmentation is a no-op regardless
|
|
2352
|
+
can_augment = use_augmentation and dataset_obj.transform is not None
|
|
2353
|
+
for v in range(num_aug_versions):
|
|
2354
|
+
# Version 0: always clean (no augmentation)
|
|
2355
|
+
apply_aug = can_augment and v > 0
|
|
2356
|
+
clip_t = dataset_obj._load_clip(
|
|
2357
|
+
clip_path,
|
|
2358
|
+
target_size=dataset_obj.target_size,
|
|
2359
|
+
apply_transform=apply_aug,
|
|
2360
|
+
).float().unsqueeze(0).to(device) # [1, T, C, H, W]
|
|
2361
|
+
with torch.no_grad():
|
|
2362
|
+
tokens = raw_model.backbone(clip_t) # [1, T*S, D]
|
|
2363
|
+
# Save as float16 — halves disk usage with negligible precision loss
|
|
2364
|
+
torch.save(tokens[0].cpu().half(), os.path.join(cache_dir, f"{ds_idx}_{v}.pt"))
|
|
2365
|
+
ops_done += 1
|
|
2366
|
+
if log_fn and (ops_done % 50 == 0 or ops_done == total_ops):
|
|
2367
|
+
log_fn(
|
|
2368
|
+
f"Caching backbone embeddings [{split_name}]: "
|
|
2369
|
+
f"{ops_done}/{total_ops} "
|
|
2370
|
+
f"(clip {ds_idx+1}/{total_items}, aug v{v})"
|
|
2371
|
+
)
|
|
2372
|
+
|
|
2373
|
+
if multi_scale:
|
|
2374
|
+
# Short scale: subsample every-other frame (half fps, same duration)
|
|
2375
|
+
clip_short = clip_t[:, ::2, :, :, :] # [1, T//2, C, H, W]
|
|
2376
|
+
with torch.no_grad():
|
|
2377
|
+
tokens_s = raw_model.backbone(clip_short) # [1, T_s*S, D]
|
|
2378
|
+
torch.save(tokens_s[0].cpu().half(), os.path.join(cache_dir, f"{ds_idx}_{v}_s.pt"))
|
|
2379
|
+
ops_done += 1
|
|
2380
|
+
if log_fn and (ops_done % 50 == 0 or ops_done == total_ops):
|
|
2381
|
+
log_fn(
|
|
2382
|
+
f"Caching backbone embeddings [{split_name}]: "
|
|
2383
|
+
f"{ops_done}/{total_ops} "
|
|
2384
|
+
f"(clip {ds_idx+1}/{total_items}, aug v{v}, short-scale)"
|
|
2385
|
+
)
|
|
2386
|
+
finally:
|
|
2387
|
+
if was_training:
|
|
2388
|
+
model.train()
|
|
2389
|
+
scale_note = " + short-scale" if multi_scale else ""
|
|
2390
|
+
if log_fn:
|
|
2391
|
+
log_fn(
|
|
2392
|
+
f"Saved {total_ops} embedding tensors to {cache_dir} "
|
|
2393
|
+
f"({num_aug_versions} version(s) per clip{scale_note}, float16)"
|
|
2394
|
+
)
|
|
2395
|
+
return cache_dir
|
|
2396
|
+
|
|
2397
|
+
if log_fn:
|
|
2398
|
+
if use_focal_loss:
|
|
2399
|
+
log_fn(f"Using Focal Loss (Active Learning): gamma={focal_gamma} (replaces CrossEntropy)")
|
|
2400
|
+
|
|
2401
|
+
# Determine class-only loss function
|
|
2402
|
+
criterion = None
|
|
2403
|
+
if use_class_weights and hasattr(train_dataset, 'labels') and not use_ovr:
|
|
2404
|
+
if log_fn:
|
|
2405
|
+
log_fn("Computing class weights for loss...")
|
|
2406
|
+
class_weights = compute_class_weights(
|
|
2407
|
+
[l for l in train_dataset.labels if l >= 0],
|
|
2408
|
+
len(train_dataset.classes)
|
|
2409
|
+
).to(device)
|
|
2410
|
+
if log_fn:
|
|
2411
|
+
log_fn(f"Class weights: {class_weights.tolist()}")
|
|
2412
|
+
if use_ovr:
|
|
2413
|
+
if use_asl:
|
|
2414
|
+
criterion = AsymmetricLoss(
|
|
2415
|
+
gamma_neg=asl_gamma_neg,
|
|
2416
|
+
gamma_pos=asl_gamma_pos,
|
|
2417
|
+
clip=asl_clip,
|
|
2418
|
+
reduction="none",
|
|
2419
|
+
)
|
|
2420
|
+
elif use_focal_loss:
|
|
2421
|
+
criterion = BinaryFocalLoss(gamma=focal_gamma)
|
|
2422
|
+
else:
|
|
2423
|
+
criterion = None # handled inline with F.binary_cross_entropy_with_logits
|
|
2424
|
+
n_hn = sum(1 for s in train_dataset.ovr_suppress_idx if s >= 0)
|
|
2425
|
+
n_real = sum(1 for lbl in train_dataset.labels if lbl >= 0)
|
|
2426
|
+
|
|
2427
|
+
# Per-head pos_weight: upweight positives for minority classes so each
|
|
2428
|
+
# binary head sees balanced effective pos/neg counts.
|
|
2429
|
+
num_c = len(train_dataset.classes)
|
|
2430
|
+
pos_counts = torch.zeros(num_c, dtype=torch.float32)
|
|
2431
|
+
|
|
2432
|
+
if hasattr(train_dataset, 'frame_labels') and getattr(train_dataset, 'frame_labels', None) is not None:
|
|
2433
|
+
if log_fn:
|
|
2434
|
+
log_fn("Computing OvR pos_weight using frame labels...")
|
|
2435
|
+
total_frames = 0
|
|
2436
|
+
for fl in train_dataset.frame_labels:
|
|
2437
|
+
if fl is not None:
|
|
2438
|
+
for lbl in fl:
|
|
2439
|
+
if 0 <= lbl < num_c:
|
|
2440
|
+
pos_counts[lbl] += 1.0
|
|
2441
|
+
total_frames += 1
|
|
2442
|
+
neg_counts = float(total_frames) - pos_counts
|
|
2443
|
+
else:
|
|
2444
|
+
if log_fn:
|
|
2445
|
+
log_fn("Computing OvR pos_weight using clip labels...")
|
|
2446
|
+
total_real = 0
|
|
2447
|
+
for ml in train_dataset.multi_labels:
|
|
2448
|
+
if ml:
|
|
2449
|
+
total_real += 1
|
|
2450
|
+
for lbl in ml:
|
|
2451
|
+
if 0 <= lbl < num_c:
|
|
2452
|
+
pos_counts[lbl] += 1.0
|
|
2453
|
+
neg_counts = float(total_real) - pos_counts
|
|
2454
|
+
|
|
2455
|
+
ovr_pos_weight = torch.ones(num_c, device=device)
|
|
2456
|
+
for ci in range(num_c):
|
|
2457
|
+
if pos_counts[ci] > 0:
|
|
2458
|
+
raw_ratio = neg_counts[ci].item() / pos_counts[ci].item()
|
|
2459
|
+
# ASL already suppresses easy negatives via its focusing term,
|
|
2460
|
+
# so the full neg/pos ratio double-corrects for imbalance.
|
|
2461
|
+
# Use sqrt when ASL is active for a softer complementary weight.
|
|
2462
|
+
if use_asl:
|
|
2463
|
+
ovr_pos_weight[ci] = max(1.0, raw_ratio ** 0.5)
|
|
2464
|
+
else:
|
|
2465
|
+
ovr_pos_weight[ci] = max(1.0, raw_ratio)
|
|
2466
|
+
ovr_pos_weight = ovr_pos_weight.clamp(max=50.0)
|
|
2467
|
+
|
|
2468
|
+
# Helper/background classes (e.g. "Other"): set pos_weight to a fixed value
|
|
2469
|
+
# (default 1.0) so the head trains gently. Use ovr_pos_weight_f1_excluded=1.5
|
|
2470
|
+
# to slightly upweight them if desired.
|
|
2471
|
+
_f1_exclude_names_early = set(config.get("f1_exclude_classes", []))
|
|
2472
|
+
_ovr_pw_excluded = float(config.get("ovr_pos_weight_f1_excluded", 1.5))
|
|
2473
|
+
for ci in range(num_c):
|
|
2474
|
+
if train_dataset.classes[ci] in _f1_exclude_names_early:
|
|
2475
|
+
ovr_pos_weight[ci] = _ovr_pw_excluded
|
|
2476
|
+
|
|
2477
|
+
# Detect co-occurrence pairs from multi-label annotations
|
|
2478
|
+
cooccur_set = set()
|
|
2479
|
+
for ml in train_dataset.multi_labels:
|
|
2480
|
+
if len(ml) >= 2:
|
|
2481
|
+
for a in ml:
|
|
2482
|
+
for b in ml:
|
|
2483
|
+
if a < b:
|
|
2484
|
+
cooccur_set.add((a, b))
|
|
2485
|
+
allowed_cooccurrence = [[train_dataset.classes[a], train_dataset.classes[b]] for a, b in sorted(cooccur_set)]
|
|
2486
|
+
cooccur_lookup = {i: set() for i in range(num_c)}
|
|
2487
|
+
for a, b in cooccur_set:
|
|
2488
|
+
cooccur_lookup[a].add(b)
|
|
2489
|
+
cooccur_lookup[b].add(a)
|
|
2490
|
+
|
|
2491
|
+
if log_fn:
|
|
2492
|
+
if use_asl:
|
|
2493
|
+
loss_msg = f" + ASL(γ-={asl_gamma_neg}, γ+={asl_gamma_pos}, clip={asl_clip})"
|
|
2494
|
+
elif use_focal_loss:
|
|
2495
|
+
loss_msg = f" + BinaryFocal(gamma={focal_gamma})"
|
|
2496
|
+
else:
|
|
2497
|
+
loss_msg = ""
|
|
2498
|
+
log_fn(f"OvR mode: {num_c} heads, {n_real} real clips, "
|
|
2499
|
+
f"{n_hn} near-negative clips{loss_msg}")
|
|
2500
|
+
pw_str = ", ".join(f"{train_dataset.classes[i]}={ovr_pos_weight[i]:.1f}" for i in range(num_c))
|
|
2501
|
+
log_fn(f"OvR per-head pos_weight: {pw_str}")
|
|
2502
|
+
if ovr_label_smoothing > 0:
|
|
2503
|
+
log_fn(f"OvR label smoothing: {ovr_label_smoothing} (targets [{ovr_label_smoothing:.2f}, {1-ovr_label_smoothing:.2f}])")
|
|
2504
|
+
bs = config["batch_size"]
|
|
2505
|
+
n_distinct = num_c + (1 if n_hn > 0 else 0)
|
|
2506
|
+
if bs < num_c:
|
|
2507
|
+
log_fn(f"WARNING: batch_size ({bs}) < num_classes ({num_c}). "
|
|
2508
|
+
f"Some OvR heads will see no positives per batch. "
|
|
2509
|
+
f"Recommend batch_size >= {num_c * 2} for stable OvR training.")
|
|
2510
|
+
elif bs < num_c * 2:
|
|
2511
|
+
log_fn(f"Note: batch_size ({bs}) is small for {num_c}-class OvR. "
|
|
2512
|
+
f"Consider increasing to {num_c * 2}+ for better per-head gradient signal.")
|
|
2513
|
+
if allowed_cooccurrence:
|
|
2514
|
+
pairs_str = ", ".join(f"{a}+{b}" for a, b in allowed_cooccurrence)
|
|
2515
|
+
log_fn(f"OvR allowed co-occurrence pairs: {pairs_str}")
|
|
2516
|
+
elif use_class_weights:
|
|
2517
|
+
if use_focal_loss:
|
|
2518
|
+
criterion = FocalLoss(gamma=focal_gamma, alpha=class_weights)
|
|
2519
|
+
else:
|
|
2520
|
+
criterion = nn.CrossEntropyLoss(weight=class_weights)
|
|
2521
|
+
else:
|
|
2522
|
+
if use_focal_loss:
|
|
2523
|
+
criterion = FocalLoss(gamma=focal_gamma)
|
|
2524
|
+
else:
|
|
2525
|
+
criterion = nn.CrossEntropyLoss()
|
|
2526
|
+
|
|
2527
|
+
if log_fn and not use_ovr:
|
|
2528
|
+
if use_focal_loss:
|
|
2529
|
+
w_msg = "weighted" if use_class_weights else "unweighted"
|
|
2530
|
+
log_fn(f"Using {w_msg} FocalLoss(gamma={focal_gamma})")
|
|
2531
|
+
else:
|
|
2532
|
+
if use_class_weights:
|
|
2533
|
+
log_fn(f"Using Class-Weighted CrossEntropyLoss")
|
|
2534
|
+
else:
|
|
2535
|
+
log_fn(f"Using standard CrossEntropyLoss (no class weights)")
|
|
2536
|
+
|
|
2537
|
+
# Use the ACTUAL classes from the dataset for metadata to ensure correct mapping order
|
|
2538
|
+
class_names = train_dataset.classes
|
|
2539
|
+
|
|
2540
|
+
# Classes excluded from F1 metrics (e.g., "Other"/"Background" helper classes).
|
|
2541
|
+
# They still train normally but don't affect best-model selection.
|
|
2542
|
+
_f1_exclude_names = set(config.get("f1_exclude_classes", []))
|
|
2543
|
+
_f1_include_indices = [i for i, c in enumerate(class_names) if c not in _f1_exclude_names]
|
|
2544
|
+
_f1_exclude_indices = {i for i, c in enumerate(class_names) if c in _f1_exclude_names}
|
|
2545
|
+
train_dataset.stitch_exclude_classes = _f1_exclude_indices
|
|
2546
|
+
slug_counts = {}
|
|
2547
|
+
class_key_map = {}
|
|
2548
|
+
class_label_map = {}
|
|
2549
|
+
for idx, cls_name in enumerate(class_names):
|
|
2550
|
+
base_slug = _slugify_class_name(cls_name)
|
|
2551
|
+
slug = base_slug
|
|
2552
|
+
counter = 1
|
|
2553
|
+
while slug in slug_counts:
|
|
2554
|
+
slug = f"{base_slug}_{counter}"
|
|
2555
|
+
counter += 1
|
|
2556
|
+
slug_counts[slug] = True
|
|
2557
|
+
key = f"val_f1_{slug}"
|
|
2558
|
+
class_key_map[idx] = key
|
|
2559
|
+
class_label_map[idx] = cls_name
|
|
2560
|
+
|
|
2561
|
+
class_counts = Counter(train_dataset.labels)
|
|
2562
|
+
|
|
2563
|
+
class_counts_named = {
|
|
2564
|
+
train_dataset.classes[idx]: count
|
|
2565
|
+
for idx, count in class_counts.items()
|
|
2566
|
+
if 0 <= idx < len(train_dataset.classes)
|
|
2567
|
+
}
|
|
2568
|
+
|
|
2569
|
+
clip_length_value = config.get("clip_length")
|
|
2570
|
+
if clip_length_value is None and hasattr(train_dataset, "clip_length"):
|
|
2571
|
+
clip_length_value = train_dataset.clip_length
|
|
2572
|
+
|
|
2573
|
+
resolution_value = config.get("resolution")
|
|
2574
|
+
if resolution_value is None and hasattr(train_dataset, "target_size"):
|
|
2575
|
+
target_size = train_dataset.target_size
|
|
2576
|
+
if isinstance(target_size, (tuple, list)) and target_size:
|
|
2577
|
+
resolution_value = int(target_size[0])
|
|
2578
|
+
elif isinstance(target_size, int):
|
|
2579
|
+
resolution_value = int(target_size)
|
|
2580
|
+
|
|
2581
|
+
def _json_safe(value):
|
|
2582
|
+
if value is None or isinstance(value, (bool, int, float, str)):
|
|
2583
|
+
return value
|
|
2584
|
+
if isinstance(value, dict):
|
|
2585
|
+
return {str(k): _json_safe(v) for k, v in value.items()}
|
|
2586
|
+
if isinstance(value, (list, tuple, set)):
|
|
2587
|
+
return [_json_safe(v) for v in value]
|
|
2588
|
+
return str(value)
|
|
2589
|
+
|
|
2590
|
+
def _calibrate_ignore_thresholds_from_validation(
|
|
2591
|
+
score_chunks_by_class: dict[int, list[np.ndarray]],
|
|
2592
|
+
target_chunks_by_class: dict[int, list[np.ndarray]],
|
|
2593
|
+
class_names_local: list[str],
|
|
2594
|
+
) -> Optional[dict]:
|
|
2595
|
+
"""Pick per-class ignore thresholds from validation scores by F1."""
|
|
2596
|
+
per_class_thresholds = {}
|
|
2597
|
+
per_class_stats = {}
|
|
2598
|
+
weighted_thresholds = []
|
|
2599
|
+
weighted_supports = []
|
|
2600
|
+
for cls_idx, cls_name in enumerate(class_names_local):
|
|
2601
|
+
score_chunks = score_chunks_by_class.get(cls_idx, [])
|
|
2602
|
+
target_chunks = target_chunks_by_class.get(cls_idx, [])
|
|
2603
|
+
if not score_chunks or not target_chunks:
|
|
2604
|
+
continue
|
|
2605
|
+
scores = np.concatenate(score_chunks).astype(np.float32, copy=False)
|
|
2606
|
+
targets = np.concatenate(target_chunks).astype(np.uint8, copy=False)
|
|
2607
|
+
if scores.size == 0 or targets.size != scores.size:
|
|
2608
|
+
continue
|
|
2609
|
+
pos_support = int(targets.sum())
|
|
2610
|
+
neg_support = int(targets.size - pos_support)
|
|
2611
|
+
if pos_support < 3 or neg_support < 3:
|
|
2612
|
+
continue
|
|
2613
|
+
|
|
2614
|
+
base_grid = np.linspace(0.35, 0.90, 56, dtype=np.float32)
|
|
2615
|
+
quantiles = np.quantile(scores, np.linspace(0.05, 0.95, 19)).astype(np.float32)
|
|
2616
|
+
candidates = np.unique(np.clip(np.concatenate([base_grid, quantiles]), 0.35, 0.90))
|
|
2617
|
+
|
|
2618
|
+
best_tau = 0.60
|
|
2619
|
+
best_f1 = -1.0
|
|
2620
|
+
best_precision = -1.0
|
|
2621
|
+
best_recall = -1.0
|
|
2622
|
+
for tau in candidates:
|
|
2623
|
+
pred_pos = scores >= float(tau)
|
|
2624
|
+
tp = float(np.sum(pred_pos & (targets == 1)))
|
|
2625
|
+
fp = float(np.sum(pred_pos & (targets == 0)))
|
|
2626
|
+
fn = float(np.sum((~pred_pos) & (targets == 1)))
|
|
2627
|
+
precision = tp / max(1.0, tp + fp)
|
|
2628
|
+
recall = tp / max(1.0, tp + fn)
|
|
2629
|
+
f1 = (2.0 * precision * recall) / max(1e-8, precision + recall)
|
|
2630
|
+
if (
|
|
2631
|
+
(f1 > best_f1 + 1e-8)
|
|
2632
|
+
or (abs(f1 - best_f1) <= 1e-8 and precision > best_precision + 1e-8)
|
|
2633
|
+
or (
|
|
2634
|
+
abs(f1 - best_f1) <= 1e-8
|
|
2635
|
+
and abs(precision - best_precision) <= 1e-8
|
|
2636
|
+
and float(tau) > best_tau
|
|
2637
|
+
)
|
|
2638
|
+
):
|
|
2639
|
+
best_tau = float(tau)
|
|
2640
|
+
best_f1 = float(f1)
|
|
2641
|
+
best_precision = float(precision)
|
|
2642
|
+
best_recall = float(recall)
|
|
2643
|
+
|
|
2644
|
+
per_class_thresholds[cls_name] = float(best_tau)
|
|
2645
|
+
per_class_stats[cls_name] = {
|
|
2646
|
+
"positive_support": pos_support,
|
|
2647
|
+
"negative_support": neg_support,
|
|
2648
|
+
"best_f1": float(best_f1),
|
|
2649
|
+
"best_precision": float(best_precision),
|
|
2650
|
+
"best_recall": float(best_recall),
|
|
2651
|
+
}
|
|
2652
|
+
weighted_thresholds.append(float(best_tau))
|
|
2653
|
+
weighted_supports.append(max(1, pos_support))
|
|
2654
|
+
|
|
2655
|
+
if not per_class_thresholds:
|
|
2656
|
+
return None
|
|
2657
|
+
|
|
2658
|
+
global_threshold = float(
|
|
2659
|
+
np.average(np.asarray(weighted_thresholds, dtype=np.float32), weights=np.asarray(weighted_supports, dtype=np.float32))
|
|
2660
|
+
)
|
|
2661
|
+
global_threshold = max(0.35, min(0.90, global_threshold))
|
|
2662
|
+
return {
|
|
2663
|
+
"source": "validation_f1_calibration",
|
|
2664
|
+
"global_threshold": global_threshold,
|
|
2665
|
+
"per_class_thresholds": per_class_thresholds,
|
|
2666
|
+
"per_class_stats": per_class_stats,
|
|
2667
|
+
}
|
|
2668
|
+
|
|
2669
|
+
ovr_pos_weight_named = None
|
|
2670
|
+
if use_ovr:
|
|
2671
|
+
ovr_pos_weight_named = {
|
|
2672
|
+
train_dataset.classes[i]: float(ovr_pos_weight[i].item())
|
|
2673
|
+
for i in range(len(train_dataset.classes))
|
|
2674
|
+
}
|
|
2675
|
+
stitch_excluded_names = []
|
|
2676
|
+
if hasattr(train_dataset, "stitch_exclude_classes"):
|
|
2677
|
+
stitch_excluded_names = [
|
|
2678
|
+
class_names[i] for i in sorted(train_dataset.stitch_exclude_classes)
|
|
2679
|
+
if 0 <= int(i) < len(class_names)
|
|
2680
|
+
]
|
|
2681
|
+
sampler_mode = "balanced_batch" if batch_sampler is not None else ("weighted_random" if sampler is not None else "shuffle")
|
|
2682
|
+
|
|
2683
|
+
head_metadata = {
|
|
2684
|
+
"classes": class_names,
|
|
2685
|
+
"num_classes": len(class_names),
|
|
2686
|
+
"clip_length": clip_length_value,
|
|
2687
|
+
"target_fps": int(config.get("target_fps", 16)),
|
|
2688
|
+
"resolution": resolution_value,
|
|
2689
|
+
"training_samples": class_counts_named,
|
|
2690
|
+
"backbone_model": config.get("backbone_model", "videoprism_public_v1_base"),
|
|
2691
|
+
"head": {
|
|
2692
|
+
"type": "DilatedTemporalHead" if use_temporal_decoder else "SpatialPoolLinearHead",
|
|
2693
|
+
"dropout": config.get("dropout", 0.1),
|
|
2694
|
+
"use_localization": use_localization,
|
|
2695
|
+
"localization_hidden_dim": int(config.get("localization_hidden_dim", 256)),
|
|
2696
|
+
"localization_dropout": float(config.get("localization_dropout", 0.0)),
|
|
2697
|
+
"use_temporal_decoder": use_temporal_decoder,
|
|
2698
|
+
"frame_head_temporal_layers": int(config.get("frame_head_temporal_layers", 1)),
|
|
2699
|
+
"temporal_pool_frames": int(config.get("temporal_pool_frames", 1)),
|
|
2700
|
+
"num_stages": int(config.get("num_stages", 3)),
|
|
2701
|
+
"proj_dim": int(config.get("proj_dim", 256)),
|
|
2702
|
+
},
|
|
2703
|
+
"training_config": {
|
|
2704
|
+
"batch_size": config["batch_size"],
|
|
2705
|
+
"epochs": config["epochs"],
|
|
2706
|
+
"lr": config["lr"],
|
|
2707
|
+
"weight_decay": config.get("weight_decay", 0.001),
|
|
2708
|
+
"use_scheduler": use_scheduler,
|
|
2709
|
+
"scheduler_name": "CosineAnnealingLR" if use_scheduler else None,
|
|
2710
|
+
"warmup_epochs": int(warmup_epochs),
|
|
2711
|
+
"t_max_epochs": int(total_epochs - warmup_epochs),
|
|
2712
|
+
"eta_min": float(eta_min) if eta_min is not None else None,
|
|
2713
|
+
"use_ovr": use_ovr,
|
|
2714
|
+
"ovr_background_as_negative": bool(config.get("ovr_background_as_negative", False)) if use_ovr else False,
|
|
2715
|
+
"ovr_background_class_names": config.get("ovr_background_class_names", []) if use_ovr else [],
|
|
2716
|
+
"allowed_cooccurrence": allowed_cooccurrence if use_ovr else [],
|
|
2717
|
+
"cooccurrence_loss_mode": "ignore_negative_pairs" if use_ovr else None,
|
|
2718
|
+
"ovr_pos_weight": ovr_pos_weight_named if use_ovr else None,
|
|
2719
|
+
"use_hard_pair_mining": hard_pair_mining if use_ovr else False,
|
|
2720
|
+
"hard_pair_mode": "pairwise_margin" if (use_ovr and hard_pair_mining) else None,
|
|
2721
|
+
"hard_pairs": hard_pair_name_pairs if (use_ovr and hard_pair_mining) else [],
|
|
2722
|
+
"hard_pair_margin": hard_pair_margin if (use_ovr and hard_pair_mining) else None,
|
|
2723
|
+
"hard_pair_loss_weight": hard_pair_loss_weight if (use_ovr and hard_pair_mining) else None,
|
|
2724
|
+
"hard_pair_confusion_boost": hard_pair_confusion_boost if (use_ovr and hard_pair_mining) else None,
|
|
2725
|
+
"use_class_weights": use_class_weights,
|
|
2726
|
+
"resolution": resolution_value,
|
|
2727
|
+
"use_weighted_sampler": config.get("use_weighted_sampler", False),
|
|
2728
|
+
"use_balanced_sampler": (batch_sampler is not None),
|
|
2729
|
+
"sampler_mode": sampler_mode,
|
|
2730
|
+
"use_augmentation": config.get("use_augmentation", False),
|
|
2731
|
+
"augmentation_options": config.get("augmentation_options") or {},
|
|
2732
|
+
"stitch_augmentation_prob": float(getattr(train_dataset, "stitch_prob", 0.0)),
|
|
2733
|
+
"emb_cache": bool(config.get("emb_cache", False)),
|
|
2734
|
+
"emb_aug_versions": int(config.get("emb_aug_versions", 1)),
|
|
2735
|
+
"stitch_excluded_classes": stitch_excluded_names,
|
|
2736
|
+
"use_focal_loss": use_focal_loss,
|
|
2737
|
+
"focal_gamma": focal_gamma if use_focal_loss else None,
|
|
2738
|
+
"use_supcon_loss": use_supcon_loss,
|
|
2739
|
+
"supcon_weight": supcon_weight if use_supcon_loss else None,
|
|
2740
|
+
"supcon_temperature": supcon_temperature if use_supcon_loss else None,
|
|
2741
|
+
"use_asl": use_asl if use_ovr else False,
|
|
2742
|
+
"asl_gamma_neg": asl_gamma_neg if (use_ovr and use_asl) else None,
|
|
2743
|
+
"asl_gamma_pos": asl_gamma_pos if (use_ovr and use_asl) else None,
|
|
2744
|
+
"asl_clip": asl_clip if (use_ovr and use_asl) else None,
|
|
2745
|
+
"use_ema": use_ema,
|
|
2746
|
+
"ema_decay": float(ema_decay) if use_ema else None,
|
|
2747
|
+
"ema_start_epoch": int(max(warmup_epochs, 3)) if use_ema else None,
|
|
2748
|
+
"frame_loss_weight": 1.0,
|
|
2749
|
+
"use_temporal_decoder": use_temporal_decoder,
|
|
2750
|
+
"use_frame_bout_balance": use_frame_bout_balance,
|
|
2751
|
+
"frame_bout_balance_power": frame_bout_balance_power if use_frame_bout_balance else None,
|
|
2752
|
+
"temporal_pool_frames": int(config.get("temporal_pool_frames", 1)),
|
|
2753
|
+
"frame_head_temporal_layers": int(config.get("frame_head_temporal_layers", 1)),
|
|
2754
|
+
"num_stages": int(config.get("num_stages", 3)),
|
|
2755
|
+
"boundary_loss_weight": boundary_loss_weight,
|
|
2756
|
+
"smoothness_loss_weight": smoothness_loss_weight,
|
|
2757
|
+
"boundary_tolerance": boundary_tolerance,
|
|
2758
|
+
"proj_dim": int(config.get("proj_dim", 256)),
|
|
2759
|
+
"use_localization": use_localization,
|
|
2760
|
+
"use_manual_localization_switch": use_manual_loc_switch if use_localization else None,
|
|
2761
|
+
"manual_localization_switch_epoch": manual_loc_switch_epoch if (use_localization and use_manual_loc_switch) else None,
|
|
2762
|
+
"localization_stage_max_epochs": loc_max_stage_epochs if use_localization else None,
|
|
2763
|
+
"localization_gate_patience": loc_gate_patience if use_localization else None,
|
|
2764
|
+
"localization_gate_iou": loc_gate_iou_threshold if use_localization else None,
|
|
2765
|
+
"localization_gate_center_error": loc_gate_center_error if use_localization else None,
|
|
2766
|
+
"localization_gate_valid_rate": loc_gate_valid_rate if use_localization else None,
|
|
2767
|
+
"classification_crop_gt_prob_start": crop_mix_start_gt if use_localization else None,
|
|
2768
|
+
"classification_crop_gt_prob_end": crop_mix_end_gt if use_localization else None,
|
|
2769
|
+
"classification_crop_padding": crop_padding if use_localization else None,
|
|
2770
|
+
"crop_jitter": config.get("crop_jitter", False),
|
|
2771
|
+
"crop_jitter_strength": config.get("crop_jitter_strength", 0.15),
|
|
2772
|
+
"classification_crop_min_size_norm": crop_min_size if use_localization else None,
|
|
2773
|
+
"center_heatmap_weight": center_heatmap_weight if use_localization else None,
|
|
2774
|
+
"center_heatmap_sigma": center_heatmap_sigma if use_localization else None,
|
|
2775
|
+
"direct_center_weight": direct_center_weight if use_localization else None,
|
|
2776
|
+
"localization_fixed_size_stat": "max" if use_localization else None,
|
|
2777
|
+
"localization_fixed_box_global_w": global_fixed_wh[0] if use_localization else None,
|
|
2778
|
+
"localization_fixed_box_global_h": global_fixed_wh[1] if use_localization else None,
|
|
2779
|
+
"val_split": config.get("val_split", 0.2),
|
|
2780
|
+
"use_all_for_training": config.get("use_all_for_training", False),
|
|
2781
|
+
"config_snapshot": _json_safe(config),
|
|
2782
|
+
}
|
|
2783
|
+
}
|
|
2784
|
+
|
|
2785
|
+
if log_fn:
|
|
2786
|
+
log_fn(f"Training with classes: {class_names}")
|
|
2787
|
+
log_fn(f"Training samples per class (primary label): {class_counts_named}")
|
|
2788
|
+
# Multi-class breakdown
|
|
2789
|
+
mc_count = sum(1 for ml in train_dataset.multi_labels if len(ml) > 1)
|
|
2790
|
+
if mc_count > 0:
|
|
2791
|
+
from collections import Counter as _Counter
|
|
2792
|
+
mc_combos = _Counter(
|
|
2793
|
+
tuple(sorted(train_dataset.classes[i] for i in ml))
|
|
2794
|
+
for ml in train_dataset.multi_labels if len(ml) > 1
|
|
2795
|
+
)
|
|
2796
|
+
log_fn(f"Multi-class clips in training set: {mc_count} of {len(train_dataset.labels)}")
|
|
2797
|
+
for combo, cnt in mc_combos.most_common():
|
|
2798
|
+
log_fn(f" {' + '.join(combo)}: {cnt}")
|
|
2799
|
+
|
|
2800
|
+
best_val_frame_acc = 0.0
|
|
2801
|
+
best_val_f1 = -1.0
|
|
2802
|
+
|
|
2803
|
+
if log_fn:
|
|
2804
|
+
log_fn(f"Starting training for {config['epochs']} epochs...")
|
|
2805
|
+
|
|
2806
|
+
history = {
|
|
2807
|
+
"epoch": [],
|
|
2808
|
+
"train_loss": [],
|
|
2809
|
+
"train_loss_class": [], # primary classification loss only
|
|
2810
|
+
"train_acc": [],
|
|
2811
|
+
"train_frame_acc": [],
|
|
2812
|
+
"val_loss": [],
|
|
2813
|
+
"val_acc": [],
|
|
2814
|
+
"val_frame_acc": [],
|
|
2815
|
+
"val_f1": [],
|
|
2816
|
+
"loc_val_iou": [],
|
|
2817
|
+
"loc_val_center_error": [],
|
|
2818
|
+
"loc_val_valid_rate": [],
|
|
2819
|
+
}
|
|
2820
|
+
for key in class_key_map.values():
|
|
2821
|
+
history[key] = []
|
|
2822
|
+
|
|
2823
|
+
|
|
2824
|
+
# Curriculum state: stage 1 localization -> stage 2 classification on crops.
|
|
2825
|
+
in_localization_stage = bool(use_localization and has_any_localization)
|
|
2826
|
+
localization_gate_streak = 0
|
|
2827
|
+
classification_stage_start_epoch = None
|
|
2828
|
+
output_dir_base = os.path.dirname(config["output_path"])
|
|
2829
|
+
basename = os.path.splitext(os.path.basename(config["output_path"]))[0]
|
|
2830
|
+
crop_progress_num_samples = int(config.get("crop_progress_num_samples", 5))
|
|
2831
|
+
crop_progress_dir = None
|
|
2832
|
+
train_roi_cache = None
|
|
2833
|
+
val_roi_cache = None
|
|
2834
|
+
train_emb_cache = None
|
|
2835
|
+
val_emb_cache_dir = None
|
|
2836
|
+
|
|
2837
|
+
# Save sample clips exactly as VideoPrism sees them (with augmentation)
|
|
2838
|
+
try:
|
|
2839
|
+
import matplotlib
|
|
2840
|
+
matplotlib.use('Agg')
|
|
2841
|
+
import matplotlib.pyplot as plt
|
|
2842
|
+
import matplotlib.patches as mpatches
|
|
2843
|
+
import cv2
|
|
2844
|
+
|
|
2845
|
+
samples_dir = os.path.join(output_dir_base, f"{basename}_input_samples")
|
|
2846
|
+
os.makedirs(samples_dir, exist_ok=True)
|
|
2847
|
+
loc_samples_dir = os.path.join(output_dir_base, f"{basename}_localized_input_samples")
|
|
2848
|
+
crop_progress_dir = os.path.join(output_dir_base, f"{basename}_crop_progress")
|
|
2849
|
+
if use_localization:
|
|
2850
|
+
os.makedirs(loc_samples_dir, exist_ok=True)
|
|
2851
|
+
os.makedirs(crop_progress_dir, exist_ok=True)
|
|
2852
|
+
|
|
2853
|
+
resolution = config.get("resolution", 288)
|
|
2854
|
+
grid_g = resolution // 18
|
|
2855
|
+
sample_indices = np.random.choice(len(train_dataset), size=min(5, len(train_dataset)), replace=False)
|
|
2856
|
+
|
|
2857
|
+
for si, idx in enumerate(sample_indices):
|
|
2858
|
+
batch = train_dataset[idx]
|
|
2859
|
+
clip_tensor = batch[0] # [T, C, H, W]
|
|
2860
|
+
T = clip_tensor.shape[0]
|
|
2861
|
+
frames = clip_tensor.permute(0, 2, 3, 1).numpy() # [T, H, W, C]
|
|
2862
|
+
H, W = frames.shape[1], frames.shape[2]
|
|
2863
|
+
|
|
2864
|
+
label_idx = batch[1] if isinstance(batch[1], int) else batch[1].item()
|
|
2865
|
+
if 0 <= label_idx < len(class_names):
|
|
2866
|
+
label_name = class_names[label_idx]
|
|
2867
|
+
elif label_idx < 0:
|
|
2868
|
+
label_name = "mixed/stitched"
|
|
2869
|
+
else:
|
|
2870
|
+
label_name = f"Class {label_idx}"
|
|
2871
|
+
|
|
2872
|
+
actual_idx = idx % len(train_dataset.clips)
|
|
2873
|
+
clip_id = train_dataset.clips[actual_idx].get("id", "?")
|
|
2874
|
+
orig_label = train_dataset.clips[actual_idx].get("label", "?")
|
|
2875
|
+
# Count valid frame labels
|
|
2876
|
+
fl_tensor = train_dataset.frame_labels[actual_idx]
|
|
2877
|
+
num_labeled_frames = (fl_tensor >= 0).sum().item()
|
|
2878
|
+
total_frames = fl_tensor.numel()
|
|
2879
|
+
|
|
2880
|
+
if log_fn:
|
|
2881
|
+
log_fn(f" Sample {si+1}: idx={idx} actual={actual_idx} clip_id={clip_id} "
|
|
2882
|
+
f"dataset_label={label_name} orig_label={orig_label} "
|
|
2883
|
+
f"({num_labeled_frames}/{total_frames} frames labeled)")
|
|
2884
|
+
|
|
2885
|
+
num_show = min(T, 10)
|
|
2886
|
+
frame_indices_show = np.linspace(0, T - 1, num_show, dtype=int)
|
|
2887
|
+
|
|
2888
|
+
# Scale figure to true pixel size: each frame at 1:1 pixels
|
|
2889
|
+
dpi = 150
|
|
2890
|
+
fig_w = (W * num_show + 40) / dpi # 40px padding between frames
|
|
2891
|
+
fig_h = (H + 80) / dpi # 80px for two-line title
|
|
2892
|
+
fig, axes = plt.subplots(1, num_show, figsize=(fig_w, fig_h))
|
|
2893
|
+
if num_show == 1:
|
|
2894
|
+
axes = [axes]
|
|
2895
|
+
|
|
2896
|
+
for j, fi in enumerate(frame_indices_show):
|
|
2897
|
+
frame_rgb = (np.clip(frames[fi], 0, 1) * 255).astype(np.uint8)
|
|
2898
|
+
axes[j].imshow(frame_rgb)
|
|
2899
|
+
# Draw patch grid
|
|
2900
|
+
for g in range(1, grid_g):
|
|
2901
|
+
axes[j].axhline(y=g * H / grid_g, color='white', linewidth=0.3, alpha=0.5)
|
|
2902
|
+
axes[j].axvline(x=g * W / grid_g, color='white', linewidth=0.3, alpha=0.5)
|
|
2903
|
+
axes[j].axis('off')
|
|
2904
|
+
|
|
2905
|
+
# Show per-frame label if available
|
|
2906
|
+
f_lbl_idx = int(fl_tensor[fi].item())
|
|
2907
|
+
f_lbl_text = f'f{fi}'
|
|
2908
|
+
if f_lbl_idx >= 0 and f_lbl_idx < len(class_names):
|
|
2909
|
+
f_lbl_text += f'\n{class_names[f_lbl_idx]}'
|
|
2910
|
+
elif f_lbl_idx >= 0:
|
|
2911
|
+
f_lbl_text += f'\n{f_lbl_idx}'
|
|
2912
|
+
|
|
2913
|
+
axes[j].set_title(f_lbl_text, fontsize=7)
|
|
2914
|
+
|
|
2915
|
+
clip_id_short = os.path.basename(clip_id)
|
|
2916
|
+
fig.suptitle(
|
|
2917
|
+
f'"{label_name}" — clip {idx} (actual={actual_idx}) | '
|
|
2918
|
+
f'res {H}×{W}, grid {grid_g}×{grid_g}, {T} frames, patch 18px\n'
|
|
2919
|
+
f'file: {clip_id_short} | orig_label: {orig_label} | '
|
|
2920
|
+
f'Frames Labeled: {num_labeled_frames}/{total_frames}',
|
|
2921
|
+
fontsize=8, fontweight='bold'
|
|
2922
|
+
)
|
|
2923
|
+
plt.tight_layout(rect=[0, 0, 1, 0.90])
|
|
2924
|
+
safe_label = label_name.replace(" ", "_").replace("/", "-")
|
|
2925
|
+
plt.savefig(os.path.join(samples_dir, f'sample_{si+1}_{safe_label}.png'), dpi=150, bbox_inches='tight')
|
|
2926
|
+
plt.close()
|
|
2927
|
+
|
|
2928
|
+
# Optional localization preview: crop from native-res video
|
|
2929
|
+
# so the preview matches what the classification head actually receives.
|
|
2930
|
+
if use_localization:
|
|
2931
|
+
bbox_target = batch[3] if len(batch) >= 4 else None
|
|
2932
|
+
bbox_valid = batch[4] if len(batch) >= 5 else None
|
|
2933
|
+
if (
|
|
2934
|
+
torch.is_tensor(bbox_target)
|
|
2935
|
+
and torch.is_tensor(bbox_valid)
|
|
2936
|
+
and float(bbox_valid.sum().item()) > 0.5
|
|
2937
|
+
):
|
|
2938
|
+
# Use first valid frame bbox for the preview
|
|
2939
|
+
if bbox_target.dim() == 2:
|
|
2940
|
+
valid_mask = bbox_valid > 0.5
|
|
2941
|
+
first_valid_t = int(valid_mask.float().argmax().item()) if valid_mask.any() else 0
|
|
2942
|
+
bbox_target = bbox_target[first_valid_t]
|
|
2943
|
+
x1, y1, x2, y2 = [float(v) for v in bbox_target.tolist()]
|
|
2944
|
+
x1 = max(0.0, min(1.0, x1))
|
|
2945
|
+
y1 = max(0.0, min(1.0, y1))
|
|
2946
|
+
x2 = max(0.0, min(1.0, x2))
|
|
2947
|
+
y2 = max(0.0, min(1.0, y2))
|
|
2948
|
+
if x2 > x1 and y2 > y1:
|
|
2949
|
+
# Load full-resolution clip for high-quality crop
|
|
2950
|
+
actual_idx = idx % len(train_dataset.clips)
|
|
2951
|
+
try:
|
|
2952
|
+
raw_clip = train_dataset.load_fullres_clip_by_index(actual_idx)
|
|
2953
|
+
raw_frames = raw_clip.permute(0, 2, 3, 1).numpy() # [T, Hraw, Wraw, C]
|
|
2954
|
+
Hraw, Wraw = raw_frames.shape[1], raw_frames.shape[2]
|
|
2955
|
+
except Exception as e:
|
|
2956
|
+
logger.debug("Could not load full-res clip by index: %s", e)
|
|
2957
|
+
raw_frames = frames
|
|
2958
|
+
Hraw, Wraw = H, W
|
|
2959
|
+
|
|
2960
|
+
fig2, axes2 = plt.subplots(2, num_show, figsize=(fig_w, (H * 2 + 80) / dpi))
|
|
2961
|
+
if num_show == 1:
|
|
2962
|
+
axes2 = axes2.reshape(2, 1)
|
|
2963
|
+
|
|
2964
|
+
# Bbox in pixel coords on the low-res frames (for display row 1)
|
|
2965
|
+
ix1 = int(round(x1 * W))
|
|
2966
|
+
iy1 = int(round(y1 * H))
|
|
2967
|
+
ix2 = int(round(x2 * W))
|
|
2968
|
+
iy2 = int(round(y2 * H))
|
|
2969
|
+
ix1 = max(0, min(ix1, W - 1))
|
|
2970
|
+
iy1 = max(0, min(iy1, H - 1))
|
|
2971
|
+
ix2 = max(ix1 + 1, min(ix2, W))
|
|
2972
|
+
iy2 = max(iy1 + 1, min(iy2, H))
|
|
2973
|
+
|
|
2974
|
+
# Bbox in pixel coords on the full-res frames (for crop)
|
|
2975
|
+
rx1 = int(round(x1 * Wraw))
|
|
2976
|
+
ry1 = int(round(y1 * Hraw))
|
|
2977
|
+
rx2 = int(round(x2 * Wraw))
|
|
2978
|
+
ry2 = int(round(y2 * Hraw))
|
|
2979
|
+
rx1 = max(0, min(rx1, Wraw - 1))
|
|
2980
|
+
ry1 = max(0, min(ry1, Hraw - 1))
|
|
2981
|
+
rx2 = max(rx1 + 1, min(rx2, Wraw))
|
|
2982
|
+
ry2 = max(ry1 + 1, min(ry2, Hraw))
|
|
2983
|
+
|
|
2984
|
+
for j, fi in enumerate(frame_indices_show):
|
|
2985
|
+
frame_rgb = (np.clip(frames[fi], 0, 1) * 255).astype(np.uint8)
|
|
2986
|
+
|
|
2987
|
+
# Row 1: model-res frame with target bbox
|
|
2988
|
+
axes2[0, j].imshow(frame_rgb)
|
|
2989
|
+
bbox_rect = mpatches.Rectangle(
|
|
2990
|
+
(ix1, iy1), ix2 - ix1, iy2 - iy1,
|
|
2991
|
+
fill=False, edgecolor='orange', linewidth=1.5
|
|
2992
|
+
)
|
|
2993
|
+
axes2[0, j].add_patch(bbox_rect)
|
|
2994
|
+
axes2[0, j].axis('off')
|
|
2995
|
+
axes2[0, j].set_title(f'f{fi}', fontsize=7)
|
|
2996
|
+
|
|
2997
|
+
# Row 2: crop from full-res, resize to model resolution
|
|
2998
|
+
raw_fi = min(fi, len(raw_frames) - 1)
|
|
2999
|
+
raw_rgb = (np.clip(raw_frames[raw_fi], 0, 1) * 255).astype(np.uint8)
|
|
3000
|
+
crop = raw_rgb[ry1:ry2, rx1:rx2]
|
|
3001
|
+
if crop.size == 0:
|
|
3002
|
+
crop = raw_rgb
|
|
3003
|
+
crop = cv2.resize(crop, (W, H), interpolation=cv2.INTER_AREA)
|
|
3004
|
+
axes2[1, j].imshow(crop)
|
|
3005
|
+
axes2[1, j].axis('off')
|
|
3006
|
+
|
|
3007
|
+
fig2.suptitle(
|
|
3008
|
+
f'"{label_name}" — localization target crop preview (from native {Hraw}×{Wraw})',
|
|
3009
|
+
fontsize=9, fontweight='bold'
|
|
3010
|
+
)
|
|
3011
|
+
axes2[0, 0].set_ylabel('Original+bbox', fontsize=9)
|
|
3012
|
+
axes2[1, 0].set_ylabel('Crop preview', fontsize=9)
|
|
3013
|
+
plt.tight_layout(rect=[0, 0, 1, 0.92])
|
|
3014
|
+
plt.savefig(
|
|
3015
|
+
os.path.join(loc_samples_dir, f'loc_sample_{si+1}_{label_name.replace(" ", "_")}.png'),
|
|
3016
|
+
dpi=150,
|
|
3017
|
+
bbox_inches='tight'
|
|
3018
|
+
)
|
|
3019
|
+
plt.close()
|
|
3020
|
+
|
|
3021
|
+
if log_fn:
|
|
3022
|
+
log_fn(f"Saved {len(sample_indices)} input sample visualizations to {samples_dir}")
|
|
3023
|
+
if use_localization:
|
|
3024
|
+
log_fn(f"Saved localization crop previews to {loc_samples_dir}")
|
|
3025
|
+
log_fn(f"Crop progress visualizations will be saved every 2 epochs to {crop_progress_dir}")
|
|
3026
|
+
|
|
3027
|
+
def _save_epoch_crop_progress(epoch_idx: int, phase_name: str):
|
|
3028
|
+
if not (use_localization and crop_progress_dir and len(train_dataset) > 0):
|
|
3029
|
+
return
|
|
3030
|
+
if (epoch_idx + 1) % 2 != 0:
|
|
3031
|
+
return
|
|
3032
|
+
was_training = model.training
|
|
3033
|
+
model.eval()
|
|
3034
|
+
try:
|
|
3035
|
+
sample_count = min(max(1, crop_progress_num_samples), len(train_dataset))
|
|
3036
|
+
sample_indices_epoch = np.random.choice(len(train_dataset), size=sample_count, replace=False)
|
|
3037
|
+
|
|
3038
|
+
if classification_stage_start_epoch is None:
|
|
3039
|
+
gt_crop_prob = crop_mix_end_gt
|
|
3040
|
+
else:
|
|
3041
|
+
cls_stage_steps = max(1, config["epochs"] - classification_stage_start_epoch)
|
|
3042
|
+
cls_epoch_idx = max(0, epoch_idx - classification_stage_start_epoch)
|
|
3043
|
+
alpha = min(1.0, cls_epoch_idx / cls_stage_steps)
|
|
3044
|
+
gt_crop_prob = crop_mix_start_gt + (crop_mix_end_gt - crop_mix_start_gt) * alpha
|
|
3045
|
+
|
|
3046
|
+
for si, idx in enumerate(sample_indices_epoch):
|
|
3047
|
+
actual_idx = idx % len(train_dataset.clips)
|
|
3048
|
+
# Mirror classification crop construction: non-augmented model-res clip.
|
|
3049
|
+
try:
|
|
3050
|
+
clip_tensor = train_dataset.load_modelres_clip_by_index(actual_idx).unsqueeze(0).to(device)
|
|
3051
|
+
except Exception as e:
|
|
3052
|
+
logger.debug("Could not load model-res clip by index: %s", e)
|
|
3053
|
+
batch = train_dataset[idx]
|
|
3054
|
+
clip_tensor = batch[0].unsqueeze(0).to(device) # fallback
|
|
3055
|
+
|
|
3056
|
+
cp_valid = train_dataset.spatial_bbox_valid[actual_idx]
|
|
3057
|
+
valid_gt = bool(float(cp_valid[0].item() if cp_valid.dim() > 0 else cp_valid.item()) > 0.5)
|
|
3058
|
+
gt_bbox_raw = None
|
|
3059
|
+
if valid_gt:
|
|
3060
|
+
cp_bbox = train_dataset.spatial_bboxes[actual_idx]
|
|
3061
|
+
cp_bbox_f0 = cp_bbox[0] if cp_bbox.dim() == 2 else cp_bbox
|
|
3062
|
+
gt_bbox_raw = _clamp_bboxes_no_expand(
|
|
3063
|
+
cp_bbox_f0.view(1, 4).to(device)
|
|
3064
|
+
)[0]
|
|
3065
|
+
|
|
3066
|
+
with torch.no_grad():
|
|
3067
|
+
loc_wh = torch.tensor([[float(global_fixed_wh[0]), float(global_fixed_wh[1])]], device=device, dtype=clip_tensor.dtype)
|
|
3068
|
+
loc_out = model(clip_tensor, return_localization=True, localization_box_wh=loc_wh)
|
|
3069
|
+
_, pred_bbox = _split_localization_output(loc_out)
|
|
3070
|
+
if pred_bbox is None:
|
|
3071
|
+
pred_bbox = torch.zeros((1, 4), device=device)
|
|
3072
|
+
pred_bbox_raw_all = _clamp_bboxes_no_expand(pred_bbox)
|
|
3073
|
+
if pred_bbox_raw_all.dim() == 3:
|
|
3074
|
+
pred_bbox_raw = pred_bbox_raw_all[0, 0]
|
|
3075
|
+
else:
|
|
3076
|
+
pred_bbox_raw = pred_bbox_raw_all[0]
|
|
3077
|
+
|
|
3078
|
+
cls_bbox_raw = pred_bbox_raw.clone()
|
|
3079
|
+
if (phase_name == "classification") and (gt_bbox_raw is not None) and (gt_crop_prob >= 0.5):
|
|
3080
|
+
cls_bbox_raw = gt_bbox_raw.clone()
|
|
3081
|
+
|
|
3082
|
+
# Actual crop boxes used by the pipeline (with padding/min-size sanitization).
|
|
3083
|
+
pred_bbox_used = _sanitize_bboxes(pred_bbox_raw.view(1, 4))[0]
|
|
3084
|
+
cls_bbox_used = _sanitize_bboxes(cls_bbox_raw.view(1, 4))[0]
|
|
3085
|
+
|
|
3086
|
+
# Crop from native-resolution video
|
|
3087
|
+
full_clip = clip_tensor[0].detach().cpu()
|
|
3088
|
+
T = int(full_clip.shape[0])
|
|
3089
|
+
H = int(full_clip.shape[2]) # model resolution
|
|
3090
|
+
W = int(full_clip.shape[3])
|
|
3091
|
+
|
|
3092
|
+
try:
|
|
3093
|
+
raw_clip = train_dataset.load_fullres_clip_by_index(actual_idx).float()
|
|
3094
|
+
Hraw, Wraw = int(raw_clip.shape[2]), int(raw_clip.shape[3])
|
|
3095
|
+
except Exception as e:
|
|
3096
|
+
logger.debug("Could not load full-res clip for crop preview: %s", e)
|
|
3097
|
+
raw_clip = full_clip.clone()
|
|
3098
|
+
Hraw, Wraw = H, W
|
|
3099
|
+
|
|
3100
|
+
pred_crop = _crop_single_clip_to_target(raw_clip, pred_bbox_used, H, W)
|
|
3101
|
+
cls_crop = _crop_single_clip_to_target(raw_clip, cls_bbox_used, H, W)
|
|
3102
|
+
|
|
3103
|
+
num_show = min(T, 8)
|
|
3104
|
+
frame_indices_show = np.linspace(0, T - 1, num_show, dtype=int)
|
|
3105
|
+
fig_w = max(8.0, 2.4 * num_show)
|
|
3106
|
+
fig, axes = plt.subplots(3, num_show, figsize=(fig_w, 6.8))
|
|
3107
|
+
if num_show == 1:
|
|
3108
|
+
axes = axes.reshape(3, 1)
|
|
3109
|
+
|
|
3110
|
+
def _bbox_px(b):
|
|
3111
|
+
x1 = int(round(float(b[0].item()) * (W - 1)))
|
|
3112
|
+
y1 = int(round(float(b[1].item()) * (H - 1)))
|
|
3113
|
+
x2 = int(round(float(b[2].item()) * W))
|
|
3114
|
+
y2 = int(round(float(b[3].item()) * H))
|
|
3115
|
+
x1 = max(0, min(x1, W - 1))
|
|
3116
|
+
y1 = max(0, min(y1, H - 1))
|
|
3117
|
+
x2 = max(x1 + 1, min(x2, W))
|
|
3118
|
+
y2 = max(y1 + 1, min(y2, H))
|
|
3119
|
+
return x1, y1, x2, y2
|
|
3120
|
+
|
|
3121
|
+
px_pred_used = _bbox_px(pred_bbox_used)
|
|
3122
|
+
px_pred_raw = _bbox_px(pred_bbox_raw)
|
|
3123
|
+
px_cls_used = _bbox_px(cls_bbox_used)
|
|
3124
|
+
|
|
3125
|
+
for j, fi in enumerate(frame_indices_show):
|
|
3126
|
+
frame_full = (full_clip[fi].permute(1, 2, 0).numpy().clip(0, 1) * 255).astype(np.uint8)
|
|
3127
|
+
frame_pred = (pred_crop[fi].permute(1, 2, 0).numpy().clip(0, 1) * 255).astype(np.uint8)
|
|
3128
|
+
frame_cls = (cls_crop[fi].permute(1, 2, 0).numpy().clip(0, 1) * 255).astype(np.uint8)
|
|
3129
|
+
|
|
3130
|
+
axes[0, j].imshow(frame_full)
|
|
3131
|
+
pred_raw_rect = mpatches.Rectangle(
|
|
3132
|
+
(px_pred_raw[0], px_pred_raw[1]),
|
|
3133
|
+
px_pred_raw[2] - px_pred_raw[0],
|
|
3134
|
+
px_pred_raw[3] - px_pred_raw[1],
|
|
3135
|
+
fill=False,
|
|
3136
|
+
edgecolor='magenta',
|
|
3137
|
+
linewidth=1.2,
|
|
3138
|
+
linestyle='--'
|
|
3139
|
+
)
|
|
3140
|
+
axes[0, j].add_patch(pred_raw_rect)
|
|
3141
|
+
pred_rect = mpatches.Rectangle(
|
|
3142
|
+
(px_pred_used[0], px_pred_used[1]),
|
|
3143
|
+
px_pred_used[2] - px_pred_used[0],
|
|
3144
|
+
px_pred_used[3] - px_pred_used[1],
|
|
3145
|
+
fill=False,
|
|
3146
|
+
edgecolor='cyan',
|
|
3147
|
+
linewidth=1.4
|
|
3148
|
+
)
|
|
3149
|
+
axes[0, j].add_patch(pred_rect)
|
|
3150
|
+
cls_rect = mpatches.Rectangle(
|
|
3151
|
+
(px_cls_used[0], px_cls_used[1]),
|
|
3152
|
+
px_cls_used[2] - px_cls_used[0],
|
|
3153
|
+
px_cls_used[3] - px_cls_used[1],
|
|
3154
|
+
fill=False,
|
|
3155
|
+
edgecolor='lime',
|
|
3156
|
+
linewidth=1.4
|
|
3157
|
+
)
|
|
3158
|
+
axes[0, j].add_patch(cls_rect)
|
|
3159
|
+
axes[0, j].axis('off')
|
|
3160
|
+
axes[0, j].set_title(f"f{fi}", fontsize=7)
|
|
3161
|
+
|
|
3162
|
+
axes[1, j].imshow(frame_pred)
|
|
3163
|
+
axes[1, j].axis('off')
|
|
3164
|
+
axes[2, j].imshow(frame_cls)
|
|
3165
|
+
axes[2, j].axis('off')
|
|
3166
|
+
|
|
3167
|
+
axes[0, 0].set_ylabel("Full+boxes (clip-level)", fontsize=9)
|
|
3168
|
+
axes[1, 0].set_ylabel("Pred crop", fontsize=9)
|
|
3169
|
+
axes[2, 0].set_ylabel("Cls input", fontsize=9)
|
|
3170
|
+
|
|
3171
|
+
cp_clip_id = train_dataset.clips[actual_idx].get("id", "?")
|
|
3172
|
+
cp_label_name = class_names[train_dataset.labels[actual_idx]] if train_dataset.labels[actual_idx] >= 0 else "?"
|
|
3173
|
+
cp_clip_short = os.path.basename(str(cp_clip_id))
|
|
3174
|
+
if phase_name == "localization":
|
|
3175
|
+
title_note = (
|
|
3176
|
+
f"Epoch {epoch_idx+1} | LOCALIZATION STAGE | \"{cp_label_name}\" | {cp_clip_short}\n"
|
|
3177
|
+
f"row0: model-res {H}×{W}, crops from native {Hraw}×{Wraw} | crop_padding={crop_padding} | "
|
|
3178
|
+
"magenta=pred raw, cyan=pred+padding"
|
|
3179
|
+
)
|
|
3180
|
+
else:
|
|
3181
|
+
title_note = (
|
|
3182
|
+
f"Epoch {epoch_idx+1} | CLASSIFICATION STAGE | \"{cp_label_name}\" | {cp_clip_short} | gt_mix={gt_crop_prob:.2f}\n"
|
|
3183
|
+
f"row0: model-res {H}×{W}, crops from native {Hraw}×{Wraw} | crop_padding={crop_padding} | "
|
|
3184
|
+
"magenta=pred raw, cyan=pred+padding, lime=cls input"
|
|
3185
|
+
)
|
|
3186
|
+
fig.suptitle(title_note, fontsize=9, fontweight='bold')
|
|
3187
|
+
plt.tight_layout(rect=[0, 0, 1, 0.93])
|
|
3188
|
+
out_name = f"epoch_{epoch_idx+1:03d}_sample_{si+1}.png"
|
|
3189
|
+
plt.savefig(os.path.join(crop_progress_dir, out_name), dpi=140, bbox_inches='tight')
|
|
3190
|
+
plt.close(fig)
|
|
3191
|
+
|
|
3192
|
+
if log_fn:
|
|
3193
|
+
log_fn(
|
|
3194
|
+
f"Saved crop progress previews for epoch {epoch_idx+1} to {crop_progress_dir} "
|
|
3195
|
+
f"(visualization only; classification uses precomputed ROI cache)"
|
|
3196
|
+
)
|
|
3197
|
+
except Exception as e_crop:
|
|
3198
|
+
if log_fn:
|
|
3199
|
+
log_fn(f"Note: Could not save crop progress previews for epoch {epoch_idx+1}: {e_crop}")
|
|
3200
|
+
finally:
|
|
3201
|
+
if was_training:
|
|
3202
|
+
model.train()
|
|
3203
|
+
except Exception as e:
|
|
3204
|
+
if log_fn:
|
|
3205
|
+
log_fn(f"Note: Could not save input samples: {e}")
|
|
3206
|
+
def _save_epoch_crop_progress(epoch_idx: int, phase_name: str):
|
|
3207
|
+
return
|
|
3208
|
+
|
|
3209
|
+
for epoch in range(config["epochs"]):
|
|
3210
|
+
if stop_callback and stop_callback():
|
|
3211
|
+
if log_fn:
|
|
3212
|
+
log_fn("Training stopped by user.")
|
|
3213
|
+
break
|
|
3214
|
+
|
|
3215
|
+
if progress_callback:
|
|
3216
|
+
progress_callback(epoch + 1, config["epochs"])
|
|
3217
|
+
|
|
3218
|
+
current_lr = optimizer.param_groups[0]['lr']
|
|
3219
|
+
epoch_phase = "localization" if in_localization_stage else "classification"
|
|
3220
|
+
if log_fn:
|
|
3221
|
+
log_fn(f"\n=== Epoch {epoch+1}/{config['epochs']} | phase={epoch_phase} (LR: {current_lr:.2e}) ===")
|
|
3222
|
+
|
|
3223
|
+
# On first classification epoch: crop all clips using the trained
|
|
3224
|
+
# localization head, save to disk, and switch datasets to load from
|
|
3225
|
+
# those .pt files. DataLoader uses normal num_workers so classification
|
|
3226
|
+
# training runs at the same speed as standard (no-localization) training.
|
|
3227
|
+
if not in_localization_stage and use_localization and has_any_localization and enable_roi_cache:
|
|
3228
|
+
if train_roi_cache is None:
|
|
3229
|
+
if log_fn:
|
|
3230
|
+
log_fn("Cropping clips using localization model and saving to disk (train split)...")
|
|
3231
|
+
train_roi_cache = _precompute_roi_cache(train_dataset, split_name="train")
|
|
3232
|
+
train_dataset._roi_cache_mode = True
|
|
3233
|
+
train_dataset._roi_cache_dir = train_roi_cache
|
|
3234
|
+
if log_fn:
|
|
3235
|
+
log_fn("Dataset switched to disk-cached crops. Training now identical to standard classifier.")
|
|
3236
|
+
# Recreate loader with normal workers (loading .pt files is fast)
|
|
3237
|
+
if batch_sampler is not None:
|
|
3238
|
+
train_loader = DataLoader(
|
|
3239
|
+
train_dataset, batch_sampler=batch_sampler,
|
|
3240
|
+
num_workers=num_workers,
|
|
3241
|
+
pin_memory=True if device.type == "cuda" else False,
|
|
3242
|
+
persistent_workers=True if num_workers > 0 else False,
|
|
3243
|
+
)
|
|
3244
|
+
else:
|
|
3245
|
+
train_loader = DataLoader(
|
|
3246
|
+
train_dataset, batch_size=config["batch_size"],
|
|
3247
|
+
shuffle=shuffle, sampler=sampler, num_workers=num_workers,
|
|
3248
|
+
pin_memory=True if device.type == "cuda" else False,
|
|
3249
|
+
persistent_workers=True if num_workers > 0 else False,
|
|
3250
|
+
)
|
|
3251
|
+
if val_roi_cache is None and val_dataset is not None:
|
|
3252
|
+
if log_fn:
|
|
3253
|
+
log_fn("Cropping clips using localization model and saving to disk (val split)...")
|
|
3254
|
+
val_roi_cache = _precompute_roi_cache(val_dataset, split_name="val")
|
|
3255
|
+
val_dataset._roi_cache_mode = True
|
|
3256
|
+
val_dataset._roi_cache_dir = val_roi_cache
|
|
3257
|
+
val_loader = DataLoader(
|
|
3258
|
+
val_dataset, batch_size=config["batch_size"],
|
|
3259
|
+
shuffle=False, num_workers=num_workers,
|
|
3260
|
+
pin_memory=True if device.type == "cuda" else False,
|
|
3261
|
+
persistent_workers=True if num_workers > 0 else False,
|
|
3262
|
+
)
|
|
3263
|
+
|
|
3264
|
+
# Backbone embeddings are pre-computed once for classification training.
|
|
3265
|
+
# If clip-stitch is enabled, stitching then happens on cached embeddings.
|
|
3266
|
+
# Backbone is frozen during classification — caching is always equivalent
|
|
3267
|
+
# to running it live and avoids 10-50x redundant computation per epoch.
|
|
3268
|
+
emb_aug_versions = max(1, int(config.get("emb_aug_versions", 1)))
|
|
3269
|
+
use_multi_scale = bool(config.get("multi_scale", False))
|
|
3270
|
+
if not in_localization_stage and not use_localization:
|
|
3271
|
+
if not getattr(train_dataset, '_emb_cache_mode', False):
|
|
3272
|
+
aug_note = f" × {emb_aug_versions} augmented versions" if emb_aug_versions > 1 else ""
|
|
3273
|
+
ms_note = " + short-scale (multi-scale)" if use_multi_scale else ""
|
|
3274
|
+
if log_fn:
|
|
3275
|
+
log_fn(
|
|
3276
|
+
f"Pre-computing backbone embeddings "
|
|
3277
|
+
f"(train split{aug_note}{ms_note})..."
|
|
3278
|
+
)
|
|
3279
|
+
emb_cache_dir = _precompute_embedding_cache(
|
|
3280
|
+
train_dataset, split_name="train",
|
|
3281
|
+
num_aug_versions=emb_aug_versions,
|
|
3282
|
+
use_augmentation=emb_aug_versions > 1,
|
|
3283
|
+
multi_scale=use_multi_scale,
|
|
3284
|
+
)
|
|
3285
|
+
train_emb_cache = emb_cache_dir
|
|
3286
|
+
train_dataset._emb_cache_mode = True
|
|
3287
|
+
train_dataset._emb_cache_dir = emb_cache_dir
|
|
3288
|
+
train_dataset._emb_clip_length = config.get("clip_length", 8)
|
|
3289
|
+
train_dataset._emb_num_versions = emb_aug_versions
|
|
3290
|
+
train_dataset._emb_multi_scale = use_multi_scale
|
|
3291
|
+
# batch_sampler is mutually exclusive with batch_size/shuffle/sampler
|
|
3292
|
+
_pin = device.type == "cuda"
|
|
3293
|
+
_pw = num_workers > 0
|
|
3294
|
+
if batch_sampler is not None:
|
|
3295
|
+
train_loader = DataLoader(
|
|
3296
|
+
train_dataset, batch_sampler=batch_sampler,
|
|
3297
|
+
num_workers=num_workers, pin_memory=_pin, persistent_workers=_pw,
|
|
3298
|
+
)
|
|
3299
|
+
else:
|
|
3300
|
+
train_loader = DataLoader(
|
|
3301
|
+
train_dataset, batch_size=config["batch_size"],
|
|
3302
|
+
shuffle=shuffle, sampler=sampler,
|
|
3303
|
+
num_workers=num_workers, pin_memory=_pin, persistent_workers=_pw,
|
|
3304
|
+
)
|
|
3305
|
+
if val_dataset is not None and not getattr(val_dataset, '_emb_cache_mode', False):
|
|
3306
|
+
if log_fn:
|
|
3307
|
+
log_fn("Pre-computing backbone embeddings (val split, no aug)...")
|
|
3308
|
+
# Validation always uses a single clean (unaugmented) version
|
|
3309
|
+
val_emb_cache = _precompute_embedding_cache(
|
|
3310
|
+
val_dataset, split_name="val",
|
|
3311
|
+
num_aug_versions=1, use_augmentation=False,
|
|
3312
|
+
multi_scale=use_multi_scale,
|
|
3313
|
+
)
|
|
3314
|
+
val_emb_cache_dir = val_emb_cache
|
|
3315
|
+
val_dataset._emb_cache_mode = True
|
|
3316
|
+
val_dataset._emb_cache_dir = val_emb_cache
|
|
3317
|
+
val_dataset._emb_clip_length = config.get("clip_length", 8)
|
|
3318
|
+
val_dataset._emb_num_versions = 1
|
|
3319
|
+
val_dataset._emb_multi_scale = use_multi_scale
|
|
3320
|
+
# No stitching during validation — we want clean per-clip evaluation
|
|
3321
|
+
val_dataset.stitch_prob = 0.0
|
|
3322
|
+
val_loader = DataLoader(
|
|
3323
|
+
val_dataset, batch_size=config["batch_size"],
|
|
3324
|
+
shuffle=False, num_workers=num_workers,
|
|
3325
|
+
pin_memory=device.type == "cuda",
|
|
3326
|
+
persistent_workers=num_workers > 0,
|
|
3327
|
+
)
|
|
3328
|
+
|
|
3329
|
+
model.train()
|
|
3330
|
+
total_loss = 0.0
|
|
3331
|
+
total_loss_class = 0.0
|
|
3332
|
+
correct = 0
|
|
3333
|
+
total = 0
|
|
3334
|
+
frame_correct = 0
|
|
3335
|
+
frame_total = 0
|
|
3336
|
+
train_targets_all = []
|
|
3337
|
+
train_preds_all = []
|
|
3338
|
+
# Per-clip confusion scores accumulated over this epoch (OvR only).
|
|
3339
|
+
# Use real clip count, not the virtual-expanded dataset length.
|
|
3340
|
+
_n_clips = len(train_dataset.clips) if hasattr(train_dataset, 'clips') else len(train_dataset)
|
|
3341
|
+
_epoch_confusion = np.zeros(_n_clips, dtype=np.float32)
|
|
3342
|
+
_epoch_confusion_count = np.zeros(_n_clips, dtype=np.int32)
|
|
3343
|
+
_epoch_top_rival = np.full(_n_clips, -1, dtype=np.int32)
|
|
3344
|
+
|
|
3345
|
+
try:
|
|
3346
|
+
for batch_idx, batch_data in enumerate(train_loader):
|
|
3347
|
+
if stop_callback and stop_callback():
|
|
3348
|
+
break
|
|
3349
|
+
|
|
3350
|
+
try:
|
|
3351
|
+
frame_labels_batch = None
|
|
3352
|
+
clips_short_batch = None
|
|
3353
|
+
bg_mask_batch = None
|
|
3354
|
+
suppress_batch = None
|
|
3355
|
+
if isinstance(batch_data, (list, tuple)) and len(batch_data) >= 7:
|
|
3356
|
+
clips, labels, spatial_masks_batch, spatial_bboxes_batch, spatial_bbox_valid_batch, sample_indices_batch, frame_labels_batch = batch_data[:7]
|
|
3357
|
+
else:
|
|
3358
|
+
clips, labels, spatial_masks_batch, spatial_bboxes_batch, spatial_bbox_valid_batch, sample_indices_batch = batch_data[:6]
|
|
3359
|
+
if isinstance(batch_data, (list, tuple)) and len(batch_data) >= 8:
|
|
3360
|
+
_cs = batch_data[7]
|
|
3361
|
+
if isinstance(_cs, torch.Tensor) and _cs.numel() > 0:
|
|
3362
|
+
clips_short_batch = _cs
|
|
3363
|
+
if isinstance(batch_data, (list, tuple)) and len(batch_data) >= 9:
|
|
3364
|
+
_bg = batch_data[8]
|
|
3365
|
+
if isinstance(_bg, torch.Tensor) and _bg.numel() > 0:
|
|
3366
|
+
bg_mask_batch = _bg.to(device=device, dtype=torch.bool)
|
|
3367
|
+
clips = clips.to(device)
|
|
3368
|
+
labels = labels.to(device)
|
|
3369
|
+
if frame_labels_batch is not None:
|
|
3370
|
+
frame_labels_batch = frame_labels_batch.to(device)
|
|
3371
|
+
if use_ovr:
|
|
3372
|
+
suppress_batch = _lookup_ovr_suppress_for_batch(sample_indices_batch, train_dataset, device)
|
|
3373
|
+
spatial_bboxes_batch = spatial_bboxes_batch.to(device)
|
|
3374
|
+
spatial_bbox_valid_batch = spatial_bbox_valid_batch.to(device)
|
|
3375
|
+
|
|
3376
|
+
optimizer.zero_grad()
|
|
3377
|
+
attn_w = None
|
|
3378
|
+
|
|
3379
|
+
# First-frame bbox slices for classification stage crop decisions
|
|
3380
|
+
spatial_bboxes_f0 = spatial_bboxes_batch[:, 0, :] if spatial_bboxes_batch.dim() == 3 else spatial_bboxes_batch
|
|
3381
|
+
spatial_bbox_valid_f0 = spatial_bbox_valid_batch[:, 0] if spatial_bbox_valid_batch.dim() == 2 else spatial_bbox_valid_batch
|
|
3382
|
+
|
|
3383
|
+
# Stage 1: train localization on full frames (all frames supervised).
|
|
3384
|
+
# spatial_bboxes_batch: [B, T, 4], spatial_bbox_valid_batch: [B, T]
|
|
3385
|
+
if in_localization_stage:
|
|
3386
|
+
loc_fixed_wh = _fixed_wh_for_labels(labels, device=clips.device, dtype=clips.dtype)
|
|
3387
|
+
loc_out = model(
|
|
3388
|
+
clips,
|
|
3389
|
+
return_localization=True,
|
|
3390
|
+
cache_backbone_tokens=True,
|
|
3391
|
+
localization_box_wh=loc_fixed_wh,
|
|
3392
|
+
)
|
|
3393
|
+
_, loc_pred_bboxes = _split_localization_output(loc_out)
|
|
3394
|
+
if loc_pred_bboxes is None or not has_any_localization:
|
|
3395
|
+
raise RuntimeError("Localization stage enabled but localization predictions/targets are unavailable.")
|
|
3396
|
+
|
|
3397
|
+
raw_model = model.module if hasattr(model, "module") else model
|
|
3398
|
+
backbone_tokens = raw_model._backbone_tokens
|
|
3399
|
+
|
|
3400
|
+
# Flatten temporal dimension for per-frame supervision:
|
|
3401
|
+
# loc_pred_bboxes: [B, T, 4] or [B, 4]
|
|
3402
|
+
B_loc = spatial_bboxes_batch.size(0)
|
|
3403
|
+
T_loc = spatial_bboxes_batch.size(1) if spatial_bboxes_batch.dim() == 3 else 1
|
|
3404
|
+
tgt_flat = spatial_bboxes_batch.view(B_loc * T_loc, 4) # [B*T, 4]
|
|
3405
|
+
valid_flat = spatial_bbox_valid_batch.view(B_loc * T_loc) # [B*T]
|
|
3406
|
+
if loc_pred_bboxes.dim() == 2:
|
|
3407
|
+
# [B, 4] → repeat for T frames
|
|
3408
|
+
pred_flat = loc_pred_bboxes.unsqueeze(1).expand(-1, T_loc, -1).reshape(B_loc * T_loc, 4)
|
|
3409
|
+
else:
|
|
3410
|
+
pred_flat = loc_pred_bboxes.view(B_loc * T_loc, 4) # [B*T, 4]
|
|
3411
|
+
|
|
3412
|
+
# Primary: center heatmap with Gaussian focal loss (all frames)
|
|
3413
|
+
obj_logits = raw_model.localization_head.get_objectness_logits(
|
|
3414
|
+
backbone_tokens, num_frames=train_dataset.clip_length,
|
|
3415
|
+
all_frames=(T_loc > 1),
|
|
3416
|
+
) # [B*T, S] when all_frames=True, else [B, S]
|
|
3417
|
+
from .model import center_heatmap_loss, direct_center_loss
|
|
3418
|
+
chm_loss = center_heatmap_loss(
|
|
3419
|
+
obj_logits,
|
|
3420
|
+
tgt_flat,
|
|
3421
|
+
valid_flat,
|
|
3422
|
+
sigma_in_patches=center_heatmap_sigma,
|
|
3423
|
+
)
|
|
3424
|
+
loss = center_heatmap_weight * chm_loss
|
|
3425
|
+
chm_loss_val = chm_loss.item()
|
|
3426
|
+
|
|
3427
|
+
# Secondary: direct center-to-center regression (all frames)
|
|
3428
|
+
dc_loss = direct_center_loss(
|
|
3429
|
+
pred_flat,
|
|
3430
|
+
tgt_flat,
|
|
3431
|
+
valid_flat,
|
|
3432
|
+
)
|
|
3433
|
+
loss = loss + direct_center_weight * dc_loss
|
|
3434
|
+
dc_loss_val = dc_loss.item()
|
|
3435
|
+
|
|
3436
|
+
loss.backward()
|
|
3437
|
+
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
|
|
3438
|
+
if any(torch.isnan(p.grad).any() for p in model.parameters() if p.grad is not None):
|
|
3439
|
+
optimizer.zero_grad()
|
|
3440
|
+
if log_fn and batch_idx % 10 == 0:
|
|
3441
|
+
log_fn(f"[loc] Skipped batch {batch_idx} (NaN grad)")
|
|
3442
|
+
else:
|
|
3443
|
+
optimizer.step()
|
|
3444
|
+
_update_ema()
|
|
3445
|
+
|
|
3446
|
+
batch_size = clips.size(0)
|
|
3447
|
+
total_loss += loss.item() * batch_size
|
|
3448
|
+
total_loss_class += 0.0
|
|
3449
|
+
total += batch_size
|
|
3450
|
+
|
|
3451
|
+
if log_fn and batch_idx % 10 == 0:
|
|
3452
|
+
iou_val, center_err, valid_rate = _localization_metrics(
|
|
3453
|
+
pred_flat.detach(),
|
|
3454
|
+
tgt_flat,
|
|
3455
|
+
valid_flat,
|
|
3456
|
+
)
|
|
3457
|
+
log_fn(
|
|
3458
|
+
f"Epoch {epoch+1}/{config['epochs']}, Batch {batch_idx}/{len(train_loader)}, "
|
|
3459
|
+
f"Loc Loss: {loss.item():.4f}, CHM: {chm_loss_val:.4f}, DC: {dc_loss_val:.4f}, "
|
|
3460
|
+
f"CErr: {center_err:.3f}, IoU: {iou_val:.3f}, VRate: {valid_rate:.3f}"
|
|
3461
|
+
)
|
|
3462
|
+
continue
|
|
3463
|
+
|
|
3464
|
+
# Classification forward: if embedding cache is active, clips
|
|
3465
|
+
# are pre-computed tokens [B, T*S, D]; otherwise (localization
|
|
3466
|
+
# pipeline) clips are raw pixels [B, T, C, H, W].
|
|
3467
|
+
_clip_len = config.get("clip_length", 8)
|
|
3468
|
+
_emb_active = getattr(train_dataset, '_emb_cache_mode', False)
|
|
3469
|
+
if _emb_active:
|
|
3470
|
+
_cs = clips_short_batch.to(device) if clips_short_batch is not None else None
|
|
3471
|
+
logits = model(
|
|
3472
|
+
None,
|
|
3473
|
+
backbone_tokens=clips,
|
|
3474
|
+
num_frames=_clip_len,
|
|
3475
|
+
backbone_tokens_short=_cs,
|
|
3476
|
+
num_frames_short=_clip_len // 2 if _cs is not None else None,
|
|
3477
|
+
return_localization=False,
|
|
3478
|
+
return_frame_logits=True,
|
|
3479
|
+
)
|
|
3480
|
+
else:
|
|
3481
|
+
logits = model(
|
|
3482
|
+
clips,
|
|
3483
|
+
return_localization=False,
|
|
3484
|
+
return_frame_logits=True,
|
|
3485
|
+
)
|
|
3486
|
+
|
|
3487
|
+
# Frame-level multi-task loss
|
|
3488
|
+
_fo = getattr(model, '_frame_output', None)
|
|
3489
|
+
if _fo is not None and frame_labels_batch is not None and not in_localization_stage:
|
|
3490
|
+
f_logits = _fo[0]
|
|
3491
|
+
f_logits_pooled = _fo[3] if len(_fo) > 3 else f_logits
|
|
3492
|
+
pool_n = int(_fo[4]) if len(_fo) > 4 else 1
|
|
3493
|
+
boundary_logits_out = _fo[5] if len(_fo) > 5 else None
|
|
3494
|
+
frame_embeddings_out = _fo[6] if len(_fo) > 6 else None
|
|
3495
|
+
labels_for_loss = _pool_frame_labels(frame_labels_batch, pool_n)
|
|
3496
|
+
logits_for_loss = f_logits_pooled if pool_n > 1 else f_logits
|
|
3497
|
+
embeddings_for_loss = _pool_frame_embeddings(frame_embeddings_out, pool_n)
|
|
3498
|
+
bg_mask_for_loss = None
|
|
3499
|
+
if use_ovr and ovr_background_as_negative and bg_mask_batch is not None:
|
|
3500
|
+
bg_mask_for_loss = _pool_binary_mask(bg_mask_batch, pool_n)
|
|
3501
|
+
|
|
3502
|
+
# L_state: frame classification loss
|
|
3503
|
+
if use_ovr:
|
|
3504
|
+
B_f, T_f, C_f = logits_for_loss.shape
|
|
3505
|
+
fl_ovr_targets = torch.full((B_f, T_f, C_f), ovr_label_smoothing, device=device)
|
|
3506
|
+
fl_ovr_weight = torch.ones(B_f, T_f, C_f, device=device)
|
|
3507
|
+
for bi in range(B_f):
|
|
3508
|
+
for ti in range(T_f):
|
|
3509
|
+
lbl = int(labels_for_loss[bi, ti].item())
|
|
3510
|
+
if 0 <= lbl < C_f:
|
|
3511
|
+
fl_ovr_targets[bi, ti, lbl] = 1.0 - ovr_label_smoothing
|
|
3512
|
+
fl_ovr_weight[bi, ti, lbl] = ovr_pos_weight[lbl]
|
|
3513
|
+
for cj in cooccur_lookup.get(lbl, ()):
|
|
3514
|
+
if cj != lbl and 0 <= cj < C_f:
|
|
3515
|
+
fl_ovr_weight[bi, ti, cj] = 0.0
|
|
3516
|
+
suppress_mask_for_loss = None
|
|
3517
|
+
if suppress_batch is not None:
|
|
3518
|
+
suppress_mask_for_loss = (labels_for_loss < 0) & (suppress_batch.view(-1, 1) >= 0)
|
|
3519
|
+
if suppress_mask_for_loss.any():
|
|
3520
|
+
fl_ovr_weight = fl_ovr_weight.masked_fill(
|
|
3521
|
+
suppress_mask_for_loss.unsqueeze(-1), 0.0
|
|
3522
|
+
)
|
|
3523
|
+
hn_b, hn_t = torch.nonzero(suppress_mask_for_loss, as_tuple=True)
|
|
3524
|
+
hn_c = suppress_batch[hn_b]
|
|
3525
|
+
fl_ovr_targets[hn_b, hn_t, hn_c] = 0.0
|
|
3526
|
+
fl_ovr_weight[hn_b, hn_t, hn_c] = 1.0
|
|
3527
|
+
if bg_mask_for_loss is not None and bg_mask_for_loss.any():
|
|
3528
|
+
fl_ovr_targets = fl_ovr_targets.masked_fill(
|
|
3529
|
+
bg_mask_for_loss.unsqueeze(-1), 0.0
|
|
3530
|
+
)
|
|
3531
|
+
valid_ovr_mask = (
|
|
3532
|
+
(labels_for_loss >= 0) |
|
|
3533
|
+
(suppress_mask_for_loss if suppress_mask_for_loss is not None else torch.zeros_like(labels_for_loss, dtype=torch.bool)) |
|
|
3534
|
+
(bg_mask_for_loss if bg_mask_for_loss is not None else torch.zeros_like(labels_for_loss, dtype=torch.bool))
|
|
3535
|
+
)
|
|
3536
|
+
loss = _frame_loss_balanced(
|
|
3537
|
+
logits_for_loss, labels_for_loss,
|
|
3538
|
+
use_ovr_local=True, ovr_targets=fl_ovr_targets,
|
|
3539
|
+
ovr_weight=fl_ovr_weight,
|
|
3540
|
+
use_bout_balance=use_frame_bout_balance,
|
|
3541
|
+
bout_power=frame_bout_balance_power,
|
|
3542
|
+
valid_mask_override=valid_ovr_mask,
|
|
3543
|
+
)
|
|
3544
|
+
else:
|
|
3545
|
+
loss = _frame_loss_balanced(
|
|
3546
|
+
logits_for_loss, labels_for_loss,
|
|
3547
|
+
use_ovr_local=False,
|
|
3548
|
+
use_bout_balance=use_frame_bout_balance,
|
|
3549
|
+
bout_power=frame_bout_balance_power,
|
|
3550
|
+
)
|
|
3551
|
+
if hard_pair_mining and hard_pair_loss_weight > 0:
|
|
3552
|
+
pair_loss = _hard_pair_margin_loss(
|
|
3553
|
+
logits_for_loss,
|
|
3554
|
+
labels_for_loss,
|
|
3555
|
+
hard_pair_index_pairs,
|
|
3556
|
+
hard_pair_margin,
|
|
3557
|
+
use_bout_balance=use_frame_bout_balance,
|
|
3558
|
+
bout_power=frame_bout_balance_power,
|
|
3559
|
+
)
|
|
3560
|
+
loss = loss + hard_pair_loss_weight * pair_loss
|
|
3561
|
+
|
|
3562
|
+
# L_boundary: boundary detection loss
|
|
3563
|
+
if boundary_logits_out is not None and boundary_loss_weight > 0:
|
|
3564
|
+
boundary_labels_batch = _generate_boundary_labels(
|
|
3565
|
+
frame_labels_batch, tolerance=boundary_tolerance,
|
|
3566
|
+
)
|
|
3567
|
+
from .model import boundary_detection_loss
|
|
3568
|
+
b_loss = boundary_detection_loss(
|
|
3569
|
+
boundary_logits_out, boundary_labels_batch,
|
|
3570
|
+
)
|
|
3571
|
+
loss = loss + boundary_loss_weight * b_loss
|
|
3572
|
+
|
|
3573
|
+
# L_smooth: temporal smoothness regularizer
|
|
3574
|
+
if smoothness_loss_weight > 0:
|
|
3575
|
+
from .model import temporal_smoothness_loss
|
|
3576
|
+
s_loss = temporal_smoothness_loss(logits_for_loss, labels_for_loss)
|
|
3577
|
+
loss = loss + smoothness_loss_weight * s_loss
|
|
3578
|
+
|
|
3579
|
+
if use_supcon_loss and supcon_weight > 0:
|
|
3580
|
+
sc_loss = _supervised_contrastive_loss(
|
|
3581
|
+
embeddings_for_loss,
|
|
3582
|
+
labels_for_loss,
|
|
3583
|
+
temperature=supcon_temperature,
|
|
3584
|
+
)
|
|
3585
|
+
loss = loss + supcon_weight * sc_loss
|
|
3586
|
+
|
|
3587
|
+
with torch.no_grad():
|
|
3588
|
+
valid_fl = labels_for_loss >= 0
|
|
3589
|
+
if valid_fl.any():
|
|
3590
|
+
if use_ovr:
|
|
3591
|
+
f_pred = torch.argmax(torch.sigmoid(logits_for_loss.detach()), dim=-1)
|
|
3592
|
+
else:
|
|
3593
|
+
f_pred = torch.argmax(logits_for_loss.detach(), dim=-1)
|
|
3594
|
+
frame_total += int(valid_fl.sum().item())
|
|
3595
|
+
frame_correct += int((f_pred[valid_fl] == labels_for_loss[valid_fl]).sum().item())
|
|
3596
|
+
train_targets_all.extend(labels_for_loss[valid_fl].detach().cpu().tolist())
|
|
3597
|
+
train_preds_all.extend(f_pred[valid_fl].detach().cpu().tolist())
|
|
3598
|
+
|
|
3599
|
+
# Accumulate per-clip confusion scores for ConfusionAwareSampler
|
|
3600
|
+
# Skip during warmup: model is still learning basics, scores are noise
|
|
3601
|
+
_confusion_warmup_epoch = int(config["epochs"] * _confusion_warmup_pct)
|
|
3602
|
+
if (use_ovr and isinstance(batch_sampler, ConfusionAwareSampler)
|
|
3603
|
+
and not in_localization_stage
|
|
3604
|
+
and (epoch + 1) > _confusion_warmup_epoch):
|
|
3605
|
+
probs = torch.sigmoid(logits_for_loss.detach()) # [B, T, C]
|
|
3606
|
+
B_cs = probs.shape[0]
|
|
3607
|
+
C_cs = probs.shape[-1]
|
|
3608
|
+
for bi in range(B_cs):
|
|
3609
|
+
raw_idx = int(sample_indices_batch[bi].item())
|
|
3610
|
+
clip_idx = raw_idx % _n_clips
|
|
3611
|
+
# Skip stitched clips (y=-1): mixed classes corrupt the score
|
|
3612
|
+
clip_label = int(labels[bi].item())
|
|
3613
|
+
if clip_label < 0 or clip_label >= C_cs:
|
|
3614
|
+
continue
|
|
3615
|
+
valid_t = labels_for_loss[bi] >= 0
|
|
3616
|
+
if not valid_t.any() or C_cs < 2:
|
|
3617
|
+
continue
|
|
3618
|
+
avg_p = probs[bi][valid_t].mean(0) # [C]
|
|
3619
|
+
# Pair-aware confusion:
|
|
3620
|
+
# favor clips where a specific rival head stays high,
|
|
3621
|
+
# the true head stays weak, or the rival overtakes the true head.
|
|
3622
|
+
wrong_mask = torch.ones(C_cs, dtype=torch.bool, device=avg_p.device)
|
|
3623
|
+
wrong_mask[clip_label] = False
|
|
3624
|
+
if not wrong_mask.any():
|
|
3625
|
+
continue
|
|
3626
|
+
wrong_idx = torch.where(wrong_mask)[0]
|
|
3627
|
+
wrong_probs = avg_p[wrong_mask]
|
|
3628
|
+
top_local = int(torch.argmax(wrong_probs).item())
|
|
3629
|
+
top_rival = int(wrong_idx[top_local].item())
|
|
3630
|
+
top_wrong = float(wrong_probs[top_local].item())
|
|
3631
|
+
true_p = float(avg_p[clip_label].item())
|
|
3632
|
+
rival_margin = max(0.0, top_wrong - true_p)
|
|
3633
|
+
true_deficit = max(0.0, 1.0 - true_p)
|
|
3634
|
+
confusion = 0.55 * top_wrong + 0.30 * rival_margin + 0.15 * true_deficit
|
|
3635
|
+
if hard_pair_mining and hard_pair_confusion_boost > 1.0:
|
|
3636
|
+
pair_key = (min(clip_label, top_rival), max(clip_label, top_rival))
|
|
3637
|
+
if pair_key in hard_pair_index_pairs:
|
|
3638
|
+
confusion *= hard_pair_confusion_boost
|
|
3639
|
+
confusion = min(2.0, confusion)
|
|
3640
|
+
_epoch_confusion[clip_idx] += confusion
|
|
3641
|
+
_epoch_confusion_count[clip_idx] += 1
|
|
3642
|
+
_epoch_top_rival[clip_idx] = top_rival
|
|
3643
|
+
else:
|
|
3644
|
+
loss = logits.sum() * 0.0
|
|
3645
|
+
|
|
3646
|
+
loss.backward()
|
|
3647
|
+
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
|
|
3648
|
+
if any(torch.isnan(p.grad).any() for p in model.parameters() if p.grad is not None):
|
|
3649
|
+
optimizer.zero_grad()
|
|
3650
|
+
if log_fn and batch_idx % 10 == 0:
|
|
3651
|
+
log_fn(f"[cls] Skipped batch {batch_idx} (NaN grad)")
|
|
3652
|
+
continue
|
|
3653
|
+
optimizer.step()
|
|
3654
|
+
_update_ema()
|
|
3655
|
+
|
|
3656
|
+
total_loss += loss.item() * clips.size(0)
|
|
3657
|
+
total_loss_class += loss.item() * clips.size(0)
|
|
3658
|
+
|
|
3659
|
+
with torch.no_grad():
|
|
3660
|
+
if use_ovr:
|
|
3661
|
+
predicted = torch.argmax(torch.sigmoid(logits.data), dim=1)
|
|
3662
|
+
else:
|
|
3663
|
+
_, predicted = torch.max(logits.data, 1)
|
|
3664
|
+
valid_mask = labels >= 0
|
|
3665
|
+
total += int(valid_mask.sum().item())
|
|
3666
|
+
correct += int((predicted[valid_mask] == labels[valid_mask]).sum().item())
|
|
3667
|
+
|
|
3668
|
+
if log_fn and batch_idx % 10 == 0:
|
|
3669
|
+
log_fn(f"Epoch {epoch+1}/{config['epochs']}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}")
|
|
3670
|
+
except Exception as e:
|
|
3671
|
+
error_msg = f"Error in training batch {batch_idx}: {str(e)}\n{traceback.format_exc()}"
|
|
3672
|
+
if log_fn:
|
|
3673
|
+
log_fn(f"ERROR: {error_msg}")
|
|
3674
|
+
raise
|
|
3675
|
+
|
|
3676
|
+
if len(train_dataset) > 0:
|
|
3677
|
+
avg_loss = total_loss / len(train_dataset)
|
|
3678
|
+
avg_loss_class = total_loss_class / len(train_dataset)
|
|
3679
|
+
else:
|
|
3680
|
+
avg_loss = 0.0
|
|
3681
|
+
avg_loss_class = 0.0
|
|
3682
|
+
|
|
3683
|
+
# Update ConfusionAwareSampler weights from this epoch's scores
|
|
3684
|
+
_confusion_warmup_epoch = int(config["epochs"] * _confusion_warmup_pct)
|
|
3685
|
+
if (use_ovr and isinstance(batch_sampler, ConfusionAwareSampler)
|
|
3686
|
+
and not in_localization_stage
|
|
3687
|
+
and (epoch + 1) > _confusion_warmup_epoch):
|
|
3688
|
+
nonzero = _epoch_confusion_count > 0
|
|
3689
|
+
avg_epoch_confusion = np.where(
|
|
3690
|
+
nonzero,
|
|
3691
|
+
_epoch_confusion / np.maximum(_epoch_confusion_count, 1),
|
|
3692
|
+
batch_sampler._confusion_scores, # keep last known score for unseen clips
|
|
3693
|
+
)
|
|
3694
|
+
avg_epoch_top_rival = np.where(
|
|
3695
|
+
nonzero,
|
|
3696
|
+
_epoch_top_rival,
|
|
3697
|
+
batch_sampler._top_rival,
|
|
3698
|
+
)
|
|
3699
|
+
batch_sampler.update_weights(avg_epoch_confusion, top_rivals=avg_epoch_top_rival)
|
|
3700
|
+
if log_fn and (epoch + 1) % 5 == 0:
|
|
3701
|
+
log_fn("Confusion sampler — hardest clips per class:")
|
|
3702
|
+
for line in batch_sampler.log_top_confused(
|
|
3703
|
+
class_names, getattr(train_dataset, "clips", [])
|
|
3704
|
+
):
|
|
3705
|
+
log_fn(line)
|
|
3706
|
+
|
|
3707
|
+
train_acc = 100.0 * correct / total if total > 0 else 0.0
|
|
3708
|
+
train_frame_acc = 100.0 * frame_correct / frame_total if frame_total > 0 else 0.0
|
|
3709
|
+
|
|
3710
|
+
if log_fn:
|
|
3711
|
+
if use_frame_loss and not in_localization_stage:
|
|
3712
|
+
log_fn(
|
|
3713
|
+
f"Epoch {epoch+1}/{config['epochs']} - Train Loss: {avg_loss:.4f}, "
|
|
3714
|
+
f"Train Acc (clip): {train_acc:.2f}%, Train Acc (frame): {train_frame_acc:.2f}%"
|
|
3715
|
+
)
|
|
3716
|
+
else:
|
|
3717
|
+
log_fn(f"Epoch {epoch+1}/{config['epochs']} - Train Loss: {avg_loss:.4f}, Train Acc: {train_acc:.2f}%")
|
|
3718
|
+
|
|
3719
|
+
# Record train metrics
|
|
3720
|
+
history["epoch"].append(epoch + 1)
|
|
3721
|
+
history["train_loss"].append(avg_loss)
|
|
3722
|
+
history["train_loss_class"].append(avg_loss_class)
|
|
3723
|
+
history["train_acc"].append(train_acc)
|
|
3724
|
+
history["train_frame_acc"].append(train_frame_acc)
|
|
3725
|
+
history["val_loss"].append(0.0)
|
|
3726
|
+
history["val_acc"].append(0.0)
|
|
3727
|
+
history["val_frame_acc"].append(0.0)
|
|
3728
|
+
history["val_f1"].append(0.0)
|
|
3729
|
+
history["loc_val_iou"].append(0.0)
|
|
3730
|
+
history["loc_val_center_error"].append(1.0 if in_localization_stage else 0.0)
|
|
3731
|
+
history["loc_val_valid_rate"].append(0.0)
|
|
3732
|
+
for key in class_key_map.values():
|
|
3733
|
+
history[key].append(0.0)
|
|
3734
|
+
|
|
3735
|
+
val_acc = 0.0
|
|
3736
|
+
avg_val_loss = 0.0
|
|
3737
|
+
per_attr_f1 = {}
|
|
3738
|
+
|
|
3739
|
+
if val_loader:
|
|
3740
|
+
_apply_ema() # swap to EMA weights for validation
|
|
3741
|
+
model.eval()
|
|
3742
|
+
val_correct = 0
|
|
3743
|
+
val_total = 0
|
|
3744
|
+
val_frame_correct = 0
|
|
3745
|
+
val_frame_total = 0
|
|
3746
|
+
val_loss = 0.0
|
|
3747
|
+
val_targets_all = []
|
|
3748
|
+
val_preds_all = []
|
|
3749
|
+
val_score_chunks_by_class = {i: [] for i in range(len(class_names))}
|
|
3750
|
+
val_target_chunks_by_class = {i: [] for i in range(len(class_names))}
|
|
3751
|
+
val_loc_iou_sum = 0.0
|
|
3752
|
+
val_loc_center_sum = 0.0
|
|
3753
|
+
val_loc_valid_sum = 0.0
|
|
3754
|
+
val_loc_batches = 0
|
|
3755
|
+
|
|
3756
|
+
try:
|
|
3757
|
+
with torch.no_grad():
|
|
3758
|
+
for batch_data in val_loader:
|
|
3759
|
+
frame_labels_batch = None
|
|
3760
|
+
clips_short_val = None
|
|
3761
|
+
bg_mask_batch = None
|
|
3762
|
+
suppress_batch = None
|
|
3763
|
+
if isinstance(batch_data, (list, tuple)) and len(batch_data) >= 7:
|
|
3764
|
+
clips, labels, _, spatial_bboxes_batch, spatial_bbox_valid_batch, sample_indices_batch, frame_labels_batch = batch_data[:7]
|
|
3765
|
+
else:
|
|
3766
|
+
clips, labels, _, spatial_bboxes_batch, spatial_bbox_valid_batch, sample_indices_batch = batch_data[:6]
|
|
3767
|
+
if isinstance(batch_data, (list, tuple)) and len(batch_data) >= 8:
|
|
3768
|
+
_cs_v = batch_data[7]
|
|
3769
|
+
if isinstance(_cs_v, torch.Tensor) and _cs_v.numel() > 0:
|
|
3770
|
+
clips_short_val = _cs_v.to(device)
|
|
3771
|
+
if isinstance(batch_data, (list, tuple)) and len(batch_data) >= 9:
|
|
3772
|
+
_bg_v = batch_data[8]
|
|
3773
|
+
if isinstance(_bg_v, torch.Tensor) and _bg_v.numel() > 0:
|
|
3774
|
+
bg_mask_batch = _bg_v.to(device=device, dtype=torch.bool)
|
|
3775
|
+
clips = clips.to(device)
|
|
3776
|
+
labels = labels.to(device)
|
|
3777
|
+
if frame_labels_batch is not None:
|
|
3778
|
+
frame_labels_batch = frame_labels_batch.to(device)
|
|
3779
|
+
if use_ovr:
|
|
3780
|
+
suppress_batch = _lookup_ovr_suppress_for_batch(sample_indices_batch, val_dataset, device)
|
|
3781
|
+
spatial_bboxes_batch = spatial_bboxes_batch.to(device)
|
|
3782
|
+
spatial_bbox_valid_batch = spatial_bbox_valid_batch.to(device)
|
|
3783
|
+
|
|
3784
|
+
spatial_bboxes_f0 = spatial_bboxes_batch[:, 0, :] if spatial_bboxes_batch.dim() == 3 else spatial_bboxes_batch
|
|
3785
|
+
spatial_bbox_valid_f0 = spatial_bbox_valid_batch[:, 0] if spatial_bbox_valid_batch.dim() == 2 else spatial_bbox_valid_batch
|
|
3786
|
+
|
|
3787
|
+
if in_localization_stage:
|
|
3788
|
+
val_wh = _fixed_wh_for_labels(labels, device=clips.device, dtype=clips.dtype)
|
|
3789
|
+
val_out = model(clips, return_localization=True, localization_box_wh=val_wh)
|
|
3790
|
+
_, loc_pred_bboxes = _split_localization_output(val_out)
|
|
3791
|
+
if loc_pred_bboxes is None:
|
|
3792
|
+
continue
|
|
3793
|
+
|
|
3794
|
+
# Flatten temporal dim for per-frame eval
|
|
3795
|
+
B_v = spatial_bboxes_batch.size(0)
|
|
3796
|
+
T_v = spatial_bboxes_batch.size(1) if spatial_bboxes_batch.dim() == 3 else 1
|
|
3797
|
+
tgt_flat_v = spatial_bboxes_batch.view(B_v * T_v, 4)
|
|
3798
|
+
valid_flat_v = spatial_bbox_valid_batch.view(B_v * T_v)
|
|
3799
|
+
if loc_pred_bboxes.dim() == 2:
|
|
3800
|
+
pred_flat_v = loc_pred_bboxes.unsqueeze(1).expand(-1, T_v, -1).reshape(B_v * T_v, 4)
|
|
3801
|
+
else:
|
|
3802
|
+
pred_flat_v = loc_pred_bboxes.view(B_v * T_v, 4)
|
|
3803
|
+
|
|
3804
|
+
from .model import direct_center_loss
|
|
3805
|
+
loss = direct_center_loss(
|
|
3806
|
+
pred_flat_v,
|
|
3807
|
+
tgt_flat_v,
|
|
3808
|
+
valid_flat_v,
|
|
3809
|
+
)
|
|
3810
|
+
val_loss += loss.item() * clips.size(0)
|
|
3811
|
+
iou_val, center_err, valid_rate = _localization_metrics(
|
|
3812
|
+
pred_flat_v,
|
|
3813
|
+
tgt_flat_v,
|
|
3814
|
+
valid_flat_v,
|
|
3815
|
+
)
|
|
3816
|
+
val_loc_iou_sum += iou_val
|
|
3817
|
+
val_loc_center_sum += center_err
|
|
3818
|
+
val_loc_valid_sum += valid_rate
|
|
3819
|
+
val_loc_batches += 1
|
|
3820
|
+
continue
|
|
3821
|
+
|
|
3822
|
+
_val_clip_len = config.get("clip_length", 8)
|
|
3823
|
+
_val_emb_active = getattr(val_dataset, '_emb_cache_mode', False) if val_dataset is not None else False
|
|
3824
|
+
if _val_emb_active:
|
|
3825
|
+
logits = model(
|
|
3826
|
+
None,
|
|
3827
|
+
backbone_tokens=clips,
|
|
3828
|
+
num_frames=_val_clip_len,
|
|
3829
|
+
backbone_tokens_short=clips_short_val,
|
|
3830
|
+
num_frames_short=_val_clip_len // 2 if clips_short_val is not None else None,
|
|
3831
|
+
return_localization=False,
|
|
3832
|
+
return_frame_logits=True,
|
|
3833
|
+
)
|
|
3834
|
+
else:
|
|
3835
|
+
logits = model(
|
|
3836
|
+
clips,
|
|
3837
|
+
return_localization=False,
|
|
3838
|
+
return_frame_logits=True,
|
|
3839
|
+
)
|
|
3840
|
+
|
|
3841
|
+
# Frame-level validation metrics
|
|
3842
|
+
if not in_localization_stage and frame_labels_batch is not None:
|
|
3843
|
+
_fo_val = getattr(model, "_frame_output", None)
|
|
3844
|
+
if _fo_val is not None:
|
|
3845
|
+
f_logits_val = _fo_val[0]
|
|
3846
|
+
f_logits_val_pooled = _fo_val[3] if len(_fo_val) > 3 else f_logits_val
|
|
3847
|
+
pool_n_val = int(_fo_val[4]) if len(_fo_val) > 4 else 1
|
|
3848
|
+
frame_embeddings_val = _fo_val[6] if len(_fo_val) > 6 else None
|
|
3849
|
+
labels_for_val = _pool_frame_labels(frame_labels_batch, pool_n_val)
|
|
3850
|
+
logits_for_val = f_logits_val_pooled if pool_n_val > 1 else f_logits_val
|
|
3851
|
+
embeddings_for_val = _pool_frame_embeddings(frame_embeddings_val, pool_n_val)
|
|
3852
|
+
bg_mask_for_val = None
|
|
3853
|
+
if use_ovr and ovr_background_as_negative and bg_mask_batch is not None:
|
|
3854
|
+
bg_mask_for_val = _pool_binary_mask(bg_mask_batch, pool_n_val)
|
|
3855
|
+
valid_fl_val = labels_for_val >= 0
|
|
3856
|
+
if valid_fl_val.any():
|
|
3857
|
+
if use_ovr:
|
|
3858
|
+
score_tensor_val = torch.sigmoid(logits_for_val)
|
|
3859
|
+
else:
|
|
3860
|
+
score_tensor_val = torch.softmax(logits_for_val, dim=-1)
|
|
3861
|
+
valid_scores_np = score_tensor_val[valid_fl_val].detach().cpu().numpy()
|
|
3862
|
+
valid_targets_np = labels_for_val[valid_fl_val].detach().cpu().numpy()
|
|
3863
|
+
for cls_idx in range(len(class_names)):
|
|
3864
|
+
val_score_chunks_by_class[cls_idx].append(valid_scores_np[:, cls_idx].astype(np.float32, copy=False))
|
|
3865
|
+
val_target_chunks_by_class[cls_idx].append((valid_targets_np == cls_idx).astype(np.uint8, copy=False))
|
|
3866
|
+
if use_ovr:
|
|
3867
|
+
f_pred_val = torch.argmax(torch.sigmoid(logits_for_val), dim=-1)
|
|
3868
|
+
else:
|
|
3869
|
+
f_pred_val = torch.argmax(logits_for_val, dim=-1)
|
|
3870
|
+
val_frame_total += int(valid_fl_val.sum().item())
|
|
3871
|
+
val_frame_correct += int((f_pred_val[valid_fl_val] == labels_for_val[valid_fl_val]).sum().item())
|
|
3872
|
+
|
|
3873
|
+
if use_ovr:
|
|
3874
|
+
B_fv, T_fv, C_fv = logits_for_val.shape
|
|
3875
|
+
fl_ovr_targets_v = torch.full((B_fv, T_fv, C_fv), ovr_label_smoothing, device=device)
|
|
3876
|
+
fl_ovr_weight_v = torch.ones(B_fv, T_fv, C_fv, device=device)
|
|
3877
|
+
for bi in range(B_fv):
|
|
3878
|
+
for ti in range(T_fv):
|
|
3879
|
+
lbl = int(labels_for_val[bi, ti].item())
|
|
3880
|
+
if 0 <= lbl < C_fv:
|
|
3881
|
+
fl_ovr_targets_v[bi, ti, lbl] = 1.0 - ovr_label_smoothing
|
|
3882
|
+
fl_ovr_weight_v[bi, ti, lbl] = ovr_pos_weight[lbl]
|
|
3883
|
+
for cj in cooccur_lookup.get(lbl, ()):
|
|
3884
|
+
if cj != lbl and 0 <= cj < C_fv:
|
|
3885
|
+
fl_ovr_weight_v[bi, ti, cj] = 0.0
|
|
3886
|
+
suppress_mask_for_val = None
|
|
3887
|
+
if suppress_batch is not None:
|
|
3888
|
+
suppress_mask_for_val = (labels_for_val < 0) & (suppress_batch.view(-1, 1) >= 0)
|
|
3889
|
+
if suppress_mask_for_val.any():
|
|
3890
|
+
fl_ovr_weight_v = fl_ovr_weight_v.masked_fill(
|
|
3891
|
+
suppress_mask_for_val.unsqueeze(-1), 0.0
|
|
3892
|
+
)
|
|
3893
|
+
hn_bv, hn_tv = torch.nonzero(suppress_mask_for_val, as_tuple=True)
|
|
3894
|
+
hn_cv = suppress_batch[hn_bv]
|
|
3895
|
+
fl_ovr_targets_v[hn_bv, hn_tv, hn_cv] = 0.0
|
|
3896
|
+
fl_ovr_weight_v[hn_bv, hn_tv, hn_cv] = 1.0
|
|
3897
|
+
if bg_mask_for_val is not None and bg_mask_for_val.any():
|
|
3898
|
+
fl_ovr_targets_v = fl_ovr_targets_v.masked_fill(
|
|
3899
|
+
bg_mask_for_val.unsqueeze(-1), 0.0
|
|
3900
|
+
)
|
|
3901
|
+
valid_ovr_mask_v = (
|
|
3902
|
+
(labels_for_val >= 0) |
|
|
3903
|
+
(suppress_mask_for_val if suppress_mask_for_val is not None else torch.zeros_like(labels_for_val, dtype=torch.bool)) |
|
|
3904
|
+
(bg_mask_for_val if bg_mask_for_val is not None else torch.zeros_like(labels_for_val, dtype=torch.bool))
|
|
3905
|
+
)
|
|
3906
|
+
loss = _frame_loss_balanced(
|
|
3907
|
+
logits_for_val, labels_for_val,
|
|
3908
|
+
use_ovr_local=True, ovr_targets=fl_ovr_targets_v,
|
|
3909
|
+
ovr_weight=fl_ovr_weight_v,
|
|
3910
|
+
use_bout_balance=use_frame_bout_balance,
|
|
3911
|
+
bout_power=frame_bout_balance_power,
|
|
3912
|
+
valid_mask_override=valid_ovr_mask_v,
|
|
3913
|
+
)
|
|
3914
|
+
else:
|
|
3915
|
+
loss = _frame_loss_balanced(
|
|
3916
|
+
logits_for_val, labels_for_val,
|
|
3917
|
+
use_ovr_local=False,
|
|
3918
|
+
use_bout_balance=use_frame_bout_balance,
|
|
3919
|
+
bout_power=frame_bout_balance_power,
|
|
3920
|
+
)
|
|
3921
|
+
if hard_pair_mining and hard_pair_loss_weight > 0:
|
|
3922
|
+
pair_loss_v = _hard_pair_margin_loss(
|
|
3923
|
+
logits_for_val,
|
|
3924
|
+
labels_for_val,
|
|
3925
|
+
hard_pair_index_pairs,
|
|
3926
|
+
hard_pair_margin,
|
|
3927
|
+
use_bout_balance=use_frame_bout_balance,
|
|
3928
|
+
bout_power=frame_bout_balance_power,
|
|
3929
|
+
)
|
|
3930
|
+
loss = loss + hard_pair_loss_weight * pair_loss_v
|
|
3931
|
+
if use_supcon_loss and supcon_weight > 0:
|
|
3932
|
+
sc_loss_v = _supervised_contrastive_loss(
|
|
3933
|
+
embeddings_for_val,
|
|
3934
|
+
labels_for_val,
|
|
3935
|
+
temperature=supcon_temperature,
|
|
3936
|
+
)
|
|
3937
|
+
loss = loss + supcon_weight * sc_loss_v
|
|
3938
|
+
val_loss += loss.item() * clips.size(0)
|
|
3939
|
+
|
|
3940
|
+
if valid_fl_val.any():
|
|
3941
|
+
val_targets_all.extend(labels_for_val[valid_fl_val].detach().cpu().tolist())
|
|
3942
|
+
val_preds_all.extend(f_pred_val[valid_fl_val].detach().cpu().tolist())
|
|
3943
|
+
|
|
3944
|
+
avg_val_loss = val_loss / len(val_dataset)
|
|
3945
|
+
val_acc = 100.0 * val_correct / val_total if val_total > 0 else 0.0
|
|
3946
|
+
val_frame_acc = 100.0 * val_frame_correct / val_frame_total if val_frame_total > 0 else 0.0
|
|
3947
|
+
val_macro_f1 = 0.0
|
|
3948
|
+
per_class_f1 = np.zeros(len(class_names), dtype=float)
|
|
3949
|
+
per_class_support = np.zeros(len(class_names), dtype=int)
|
|
3950
|
+
conf_matrix = np.zeros((len(class_names), len(class_names)), dtype=int)
|
|
3951
|
+
|
|
3952
|
+
per_attr_f1 = {}
|
|
3953
|
+
val_ignore_thresholds = _calibrate_ignore_thresholds_from_validation(
|
|
3954
|
+
val_score_chunks_by_class,
|
|
3955
|
+
val_target_chunks_by_class,
|
|
3956
|
+
class_names,
|
|
3957
|
+
)
|
|
3958
|
+
|
|
3959
|
+
# val_targets_all / val_preds_all are from frame-level validation when frame_labels_batch was used.
|
|
3960
|
+
if val_targets_all:
|
|
3961
|
+
per_class_f1 = f1_score(
|
|
3962
|
+
val_targets_all,
|
|
3963
|
+
val_preds_all,
|
|
3964
|
+
labels=list(range(len(class_names))),
|
|
3965
|
+
average=None,
|
|
3966
|
+
zero_division=0
|
|
3967
|
+
) * 100.0
|
|
3968
|
+
if _f1_include_indices and len(_f1_include_indices) < len(class_names):
|
|
3969
|
+
val_macro_f1 = float(per_class_f1[_f1_include_indices].mean())
|
|
3970
|
+
else:
|
|
3971
|
+
val_macro_f1 = f1_score(
|
|
3972
|
+
val_targets_all,
|
|
3973
|
+
val_preds_all,
|
|
3974
|
+
average='macro',
|
|
3975
|
+
zero_division=0
|
|
3976
|
+
) * 100.0
|
|
3977
|
+
per_class_support = np.bincount(
|
|
3978
|
+
np.asarray(val_targets_all, dtype=np.int64),
|
|
3979
|
+
minlength=len(class_names),
|
|
3980
|
+
).astype(int)
|
|
3981
|
+
for t, p in zip(val_targets_all, val_preds_all):
|
|
3982
|
+
if 0 <= int(t) < len(class_names) and 0 <= int(p) < len(class_names):
|
|
3983
|
+
conf_matrix[int(t), int(p)] += 1
|
|
3984
|
+
|
|
3985
|
+
loc_val_iou = (val_loc_iou_sum / val_loc_batches) if val_loc_batches > 0 else 0.0
|
|
3986
|
+
loc_val_center = (val_loc_center_sum / val_loc_batches) if val_loc_batches > 0 else 1.0
|
|
3987
|
+
loc_val_valid = (val_loc_valid_sum / val_loc_batches) if val_loc_batches > 0 else 0.0
|
|
3988
|
+
|
|
3989
|
+
# Update history: when validation is frame-level, val_acc and val_f1 are per-frame.
|
|
3990
|
+
history["val_loss"][-1] = avg_val_loss
|
|
3991
|
+
if val_frame_total > 0:
|
|
3992
|
+
history["val_acc"][-1] = val_frame_acc
|
|
3993
|
+
else:
|
|
3994
|
+
history["val_acc"][-1] = val_acc
|
|
3995
|
+
history["val_frame_acc"][-1] = val_frame_acc
|
|
3996
|
+
history["val_f1"][-1] = val_macro_f1
|
|
3997
|
+
history["loc_val_iou"][-1] = loc_val_iou
|
|
3998
|
+
history["loc_val_center_error"][-1] = loc_val_center
|
|
3999
|
+
history["loc_val_valid_rate"][-1] = loc_val_valid
|
|
4000
|
+
for idx, key in class_key_map.items():
|
|
4001
|
+
if idx < len(per_class_f1):
|
|
4002
|
+
history[key][-1] = per_class_f1[idx]
|
|
4003
|
+
|
|
4004
|
+
if log_fn:
|
|
4005
|
+
if in_localization_stage:
|
|
4006
|
+
log_fn(
|
|
4007
|
+
f"Epoch {epoch+1}/{config['epochs']} - Val Loc Loss: {avg_val_loss:.4f}, "
|
|
4008
|
+
f"IoU: {loc_val_iou:.3f}, CErr: {loc_val_center:.3f}, VRate: {loc_val_valid:.3f}"
|
|
4009
|
+
)
|
|
4010
|
+
else:
|
|
4011
|
+
log_fn(
|
|
4012
|
+
f"Epoch {epoch+1}/{config['epochs']} - Val Loss: {avg_val_loss:.4f}, "
|
|
4013
|
+
f"Val Acc (frame): {val_frame_acc:.2f}%, "
|
|
4014
|
+
f"Val Macro F1: {val_macro_f1:.2f}%"
|
|
4015
|
+
)
|
|
4016
|
+
if val_targets_all:
|
|
4017
|
+
metric_scope = "frame-labeled" if (use_frame_loss and not in_localization_stage) else "clip"
|
|
4018
|
+
log_fn(f"Val class diagnostics ({metric_scope}):")
|
|
4019
|
+
for ci, cname in enumerate(class_names):
|
|
4020
|
+
excl_tag = " [excluded from F1]" if cname in _f1_exclude_names else ""
|
|
4021
|
+
log_fn(
|
|
4022
|
+
f" - {cname}: support={int(per_class_support[ci])}, "
|
|
4023
|
+
f"F1={float(per_class_f1[ci]):.2f}%{excl_tag}"
|
|
4024
|
+
)
|
|
4025
|
+
if len(class_names) <= 12:
|
|
4026
|
+
log_fn("Val confusion matrix rows=true, cols=pred:")
|
|
4027
|
+
for ci, cname in enumerate(class_names):
|
|
4028
|
+
row_vals = " ".join(str(int(v)) for v in conf_matrix[ci].tolist())
|
|
4029
|
+
log_fn(f" {ci}:{cname} | {row_vals}")
|
|
4030
|
+
if val_ignore_thresholds:
|
|
4031
|
+
log_fn(
|
|
4032
|
+
"Validation-calibrated ignore thresholds: "
|
|
4033
|
+
+ ", ".join(
|
|
4034
|
+
f"{cls}={float(tau):.2f}"
|
|
4035
|
+
for cls, tau in sorted(
|
|
4036
|
+
val_ignore_thresholds.get("per_class_thresholds", {}).items(),
|
|
4037
|
+
key=lambda item: item[0],
|
|
4038
|
+
)
|
|
4039
|
+
)
|
|
4040
|
+
)
|
|
4041
|
+
|
|
4042
|
+
if head_metadata:
|
|
4043
|
+
head_metadata["training_config"]["validation_calibrated_ignore_thresholds"] = _json_safe(val_ignore_thresholds)
|
|
4044
|
+
|
|
4045
|
+
metric_improved = False if in_localization_stage else (
|
|
4046
|
+
val_macro_f1 > best_val_f1 + 1e-6
|
|
4047
|
+
)
|
|
4048
|
+
|
|
4049
|
+
if in_localization_stage:
|
|
4050
|
+
gate_pass = (
|
|
4051
|
+
(loc_val_iou >= loc_gate_iou_threshold)
|
|
4052
|
+
and (loc_val_center <= loc_gate_center_error)
|
|
4053
|
+
and (loc_val_valid >= loc_gate_valid_rate)
|
|
4054
|
+
)
|
|
4055
|
+
localization_gate_streak = localization_gate_streak + 1 if gate_pass else 0
|
|
4056
|
+
reached_epoch_cap = (epoch + 1) >= max(1, loc_max_stage_epochs)
|
|
4057
|
+
reached_manual_switch = use_manual_loc_switch and ((epoch + 1) >= max(1, manual_loc_switch_epoch))
|
|
4058
|
+
if gate_pass and log_fn:
|
|
4059
|
+
log_fn(
|
|
4060
|
+
f"Localization gate check passed ({localization_gate_streak}/{loc_gate_patience})"
|
|
4061
|
+
)
|
|
4062
|
+
if reached_manual_switch or (localization_gate_streak >= max(1, loc_gate_patience)) or reached_epoch_cap:
|
|
4063
|
+
in_localization_stage = False
|
|
4064
|
+
classification_stage_start_epoch = epoch + 1
|
|
4065
|
+
if log_fn:
|
|
4066
|
+
if reached_manual_switch:
|
|
4067
|
+
reason = f"manual switch epoch reached ({manual_loc_switch_epoch})"
|
|
4068
|
+
elif localization_gate_streak >= max(1, loc_gate_patience):
|
|
4069
|
+
reason = "metrics gate reached"
|
|
4070
|
+
else:
|
|
4071
|
+
reason = "max localization epochs reached"
|
|
4072
|
+
log_fn(
|
|
4073
|
+
f"Switching to classification stage at epoch {epoch+1} ({reason}). "
|
|
4074
|
+
"Classifier will train on localized crops."
|
|
4075
|
+
)
|
|
4076
|
+
# Reset LR schedule, EMA, and optimizer state for classification
|
|
4077
|
+
cls_remaining = total_epochs - (epoch + 1)
|
|
4078
|
+
cls_warmup = 0
|
|
4079
|
+
if use_scheduler:
|
|
4080
|
+
eta_min = 0.2 * classification_lr
|
|
4081
|
+
for pg in optimizer.param_groups:
|
|
4082
|
+
pg['lr'] = classification_lr
|
|
4083
|
+
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
|
|
4084
|
+
optimizer, T_max=max(1, cls_remaining - cls_warmup), eta_min=eta_min
|
|
4085
|
+
)
|
|
4086
|
+
warmup_scheduler = None
|
|
4087
|
+
warmup_epochs = cls_warmup
|
|
4088
|
+
if log_fn:
|
|
4089
|
+
log_fn(f"Reset LR schedule: CosineAnnealingLR (single decay, {cls_remaining} epochs)")
|
|
4090
|
+
optimizer.state.clear()
|
|
4091
|
+
if log_fn:
|
|
4092
|
+
log_fn("Reset optimizer momentum/adaptive state for classification")
|
|
4093
|
+
ema_active = False
|
|
4094
|
+
ema_state.clear()
|
|
4095
|
+
|
|
4096
|
+
if metric_improved:
|
|
4097
|
+
best_val_f1 = val_macro_f1
|
|
4098
|
+
best_val_frame_acc = val_frame_acc
|
|
4099
|
+
if head_metadata:
|
|
4100
|
+
head_metadata["training_config"]["best_val_f1"] = best_val_f1
|
|
4101
|
+
if config.get("save_best", True):
|
|
4102
|
+
# Create folder for this best epoch
|
|
4103
|
+
output_dir = os.path.dirname(config["output_path"])
|
|
4104
|
+
basename = os.path.splitext(os.path.basename(config["output_path"]))[0]
|
|
4105
|
+
from datetime import datetime
|
|
4106
|
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
4107
|
+
|
|
4108
|
+
best_folder = os.path.join(
|
|
4109
|
+
output_dir,
|
|
4110
|
+
f"{basename}_checkpoints",
|
|
4111
|
+
f"epoch_{epoch+1}_f1_{val_macro_f1:.1f}_frameacc_{val_frame_acc:.1f}_{timestamp}"
|
|
4112
|
+
)
|
|
4113
|
+
os.makedirs(best_folder, exist_ok=True)
|
|
4114
|
+
|
|
4115
|
+
# Save Model
|
|
4116
|
+
best_path = os.path.join(best_folder, "model.pt")
|
|
4117
|
+
if head_metadata:
|
|
4118
|
+
model.save_head(best_path, metadata=head_metadata)
|
|
4119
|
+
else:
|
|
4120
|
+
model.save_head(best_path)
|
|
4121
|
+
|
|
4122
|
+
# Also update main best file
|
|
4123
|
+
best_main_path = config["output_path"].replace(".pt", "_best.pt")
|
|
4124
|
+
if head_metadata:
|
|
4125
|
+
model.save_head(best_main_path, metadata=head_metadata)
|
|
4126
|
+
else:
|
|
4127
|
+
model.save_head(best_main_path)
|
|
4128
|
+
|
|
4129
|
+
# Save Logs & Plots
|
|
4130
|
+
import pandas as pd
|
|
4131
|
+
import matplotlib
|
|
4132
|
+
matplotlib.use('Agg')
|
|
4133
|
+
import matplotlib.pyplot as plt
|
|
4134
|
+
|
|
4135
|
+
# Construct temp history including current epoch
|
|
4136
|
+
curr_hist = {k: v.copy() for k, v in history.items()}
|
|
4137
|
+
# Train metrics and epoch are ALREADY in history (updated before validation)
|
|
4138
|
+
# And val metrics were just updated in-place at index -1
|
|
4139
|
+
|
|
4140
|
+
pd.DataFrame(curr_hist).to_csv(os.path.join(best_folder, "history.csv"), index=False)
|
|
4141
|
+
|
|
4142
|
+
plt.style.use('ggplot')
|
|
4143
|
+
fig, axes = plt.subplots(4, 1, figsize=(10, 18))
|
|
4144
|
+
ax1, ax2, ax3, ax4 = axes
|
|
4145
|
+
epochs_hist = curr_hist['epoch']
|
|
4146
|
+
|
|
4147
|
+
ax1.plot(epochs_hist, curr_hist['train_acc'], label='Train Acc', marker='o')
|
|
4148
|
+
ax1.plot(epochs_hist, curr_hist['val_acc'], label='Val Acc (frame)', marker='s')
|
|
4149
|
+
ax1.set_title(f'Accuracy - Epoch {epoch+1}')
|
|
4150
|
+
ax1.set_ylabel('Accuracy (%)')
|
|
4151
|
+
ax1.legend()
|
|
4152
|
+
ax1.grid(True)
|
|
4153
|
+
|
|
4154
|
+
ax2.plot(epochs_hist, curr_hist['train_loss'], label='Train Loss', marker='o')
|
|
4155
|
+
ax2.plot(epochs_hist, curr_hist['val_loss'], label='Val Loss', marker='s')
|
|
4156
|
+
ax2.set_ylabel('Loss')
|
|
4157
|
+
ax2.legend()
|
|
4158
|
+
ax2.grid(True)
|
|
4159
|
+
|
|
4160
|
+
ax3.plot(epochs_hist, curr_hist['val_f1'], label='Val Macro F1 (frame)', linewidth=2, color='tab:purple')
|
|
4161
|
+
for idx in range(len(class_names)):
|
|
4162
|
+
class_key = class_key_map.get(idx)
|
|
4163
|
+
if class_key in curr_hist:
|
|
4164
|
+
ax3.plot(
|
|
4165
|
+
epochs_hist,
|
|
4166
|
+
curr_hist[class_key],
|
|
4167
|
+
label=f"{class_names[idx]}",
|
|
4168
|
+
linestyle='--',
|
|
4169
|
+
alpha=0.6
|
|
4170
|
+
)
|
|
4171
|
+
ax3.set_ylabel('F1 (%)')
|
|
4172
|
+
ax3.legend(ncol=2, fontsize=8)
|
|
4173
|
+
ax3.grid(True)
|
|
4174
|
+
|
|
4175
|
+
per_class_keys_ordered = [class_key_map[idx] for idx in range(len(class_names))]
|
|
4176
|
+
if per_class_keys_ordered:
|
|
4177
|
+
per_class_matrix = np.array([
|
|
4178
|
+
curr_hist[key] for key in per_class_keys_ordered
|
|
4179
|
+
])
|
|
4180
|
+
else:
|
|
4181
|
+
per_class_matrix = np.zeros((0, len(epochs_hist)))
|
|
4182
|
+
|
|
4183
|
+
im = ax4.imshow(
|
|
4184
|
+
per_class_matrix,
|
|
4185
|
+
aspect='auto',
|
|
4186
|
+
cmap='magma',
|
|
4187
|
+
vmin=0,
|
|
4188
|
+
vmax=100
|
|
4189
|
+
)
|
|
4190
|
+
ax4.set_yticks(range(len(class_names)))
|
|
4191
|
+
ax4.set_yticklabels(class_names)
|
|
4192
|
+
ax4.set_xlabel('Epoch')
|
|
4193
|
+
ax4.set_ylabel('Class')
|
|
4194
|
+
ax4.set_title('Validation F1 Heatmap (%)')
|
|
4195
|
+
if epochs_hist:
|
|
4196
|
+
max_ticks = min(len(epochs_hist), 12)
|
|
4197
|
+
tick_positions = np.linspace(0, len(epochs_hist) - 1, max_ticks, dtype=int)
|
|
4198
|
+
ax4.set_xticks(tick_positions)
|
|
4199
|
+
ax4.set_xticklabels([str(epochs_hist[i]) for i in tick_positions])
|
|
4200
|
+
cbar = fig.colorbar(im, ax=ax4, orientation='vertical', pad=0.01)
|
|
4201
|
+
cbar.set_label('F1 (%)')
|
|
4202
|
+
|
|
4203
|
+
plt.tight_layout()
|
|
4204
|
+
plt.savefig(os.path.join(best_folder, "training_plot.pdf"))
|
|
4205
|
+
plt.close()
|
|
4206
|
+
|
|
4207
|
+
if log_fn:
|
|
4208
|
+
log_fn(f"Saved best model checkpoint to {best_folder}")
|
|
4209
|
+
|
|
4210
|
+
except Exception as e:
|
|
4211
|
+
error_msg = f"Error in validation: {str(e)}\n{traceback.format_exc()}"
|
|
4212
|
+
if log_fn:
|
|
4213
|
+
log_fn(f"ERROR: {error_msg}")
|
|
4214
|
+
raise
|
|
4215
|
+
|
|
4216
|
+
# Record val metrics (use train metrics if no validation)
|
|
4217
|
+
if not val_loader:
|
|
4218
|
+
history["val_loss"][-1] = avg_loss
|
|
4219
|
+
history["val_acc"][-1] = train_acc
|
|
4220
|
+
history["val_frame_acc"][-1] = train_frame_acc
|
|
4221
|
+
history["val_f1"][-1] = 0.0
|
|
4222
|
+
|
|
4223
|
+
# Print train confusion matrix every 5 epochs when there is no val set
|
|
4224
|
+
if (log_fn and train_targets_all and not in_localization_stage
|
|
4225
|
+
and ((epoch + 1) % 5 == 0 or epoch == 0)):
|
|
4226
|
+
train_macro_f1 = f1_score(
|
|
4227
|
+
train_targets_all, train_preds_all,
|
|
4228
|
+
average='macro', zero_division=0,
|
|
4229
|
+
) * 100.0
|
|
4230
|
+
per_class_f1_tr = f1_score(
|
|
4231
|
+
train_targets_all, train_preds_all,
|
|
4232
|
+
labels=list(range(len(class_names))),
|
|
4233
|
+
average=None, zero_division=0,
|
|
4234
|
+
) * 100.0
|
|
4235
|
+
per_class_support_tr = np.bincount(
|
|
4236
|
+
np.asarray(train_targets_all, dtype=np.int64),
|
|
4237
|
+
minlength=len(class_names),
|
|
4238
|
+
).astype(int)
|
|
4239
|
+
conf_matrix_tr = np.zeros((len(class_names), len(class_names)), dtype=np.int64)
|
|
4240
|
+
for t, p in zip(train_targets_all, train_preds_all):
|
|
4241
|
+
if 0 <= int(t) < len(class_names) and 0 <= int(p) < len(class_names):
|
|
4242
|
+
conf_matrix_tr[int(t), int(p)] += 1
|
|
4243
|
+
history["val_f1"][-1] = train_macro_f1
|
|
4244
|
+
log_fn(f"Train Macro F1: {train_macro_f1:.2f}%")
|
|
4245
|
+
log_fn("Train class diagnostics (frame-labeled):")
|
|
4246
|
+
for ci, cname in enumerate(class_names):
|
|
4247
|
+
log_fn(
|
|
4248
|
+
f" - {cname}: support={int(per_class_support_tr[ci])}, "
|
|
4249
|
+
f"F1={float(per_class_f1_tr[ci]):.2f}%"
|
|
4250
|
+
)
|
|
4251
|
+
if len(class_names) <= 12:
|
|
4252
|
+
log_fn("Train confusion matrix rows=true, cols=pred:")
|
|
4253
|
+
for ci, cname in enumerate(class_names):
|
|
4254
|
+
row_vals = " ".join(str(int(v)) for v in conf_matrix_tr[ci].tolist())
|
|
4255
|
+
log_fn(f" {ci}:{cname} | {row_vals}")
|
|
4256
|
+
reached_manual_switch = use_manual_loc_switch and ((epoch + 1) >= max(1, manual_loc_switch_epoch))
|
|
4257
|
+
reached_epoch_cap = (epoch + 1) >= max(1, loc_max_stage_epochs)
|
|
4258
|
+
if in_localization_stage and (reached_manual_switch or reached_epoch_cap):
|
|
4259
|
+
in_localization_stage = False
|
|
4260
|
+
classification_stage_start_epoch = epoch + 1
|
|
4261
|
+
if log_fn:
|
|
4262
|
+
if reached_manual_switch:
|
|
4263
|
+
reason = f"manual switch epoch reached ({manual_loc_switch_epoch})"
|
|
4264
|
+
else:
|
|
4265
|
+
reason = "max localization epochs reached"
|
|
4266
|
+
log_fn(
|
|
4267
|
+
f"Switching to classification stage at epoch {epoch+1} ({reason}, no validation set)."
|
|
4268
|
+
)
|
|
4269
|
+
# Reset LR schedule, EMA, and optimizer state for classification
|
|
4270
|
+
cls_remaining = total_epochs - (epoch + 1)
|
|
4271
|
+
cls_warmup = 0
|
|
4272
|
+
if use_scheduler:
|
|
4273
|
+
eta_min = 0.2 * classification_lr
|
|
4274
|
+
for pg in optimizer.param_groups:
|
|
4275
|
+
pg['lr'] = classification_lr
|
|
4276
|
+
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
|
|
4277
|
+
optimizer, T_max=max(1, cls_remaining - cls_warmup), eta_min=eta_min
|
|
4278
|
+
)
|
|
4279
|
+
warmup_scheduler = None
|
|
4280
|
+
warmup_epochs = cls_warmup
|
|
4281
|
+
if log_fn:
|
|
4282
|
+
log_fn(f"Reset LR schedule: CosineAnnealingLR (single decay, {cls_remaining} epochs)")
|
|
4283
|
+
optimizer.state.clear()
|
|
4284
|
+
if log_fn:
|
|
4285
|
+
log_fn("Reset optimizer momentum/adaptive state for classification")
|
|
4286
|
+
ema_active = False
|
|
4287
|
+
ema_state.clear()
|
|
4288
|
+
|
|
4289
|
+
# Save checkpoints only after classification stage starts.
|
|
4290
|
+
# During localization stage we skip periodic model saves to
|
|
4291
|
+
# avoid generating unusable intermediate checkpoints.
|
|
4292
|
+
_apply_ema() # swap to EMA weights for saving
|
|
4293
|
+
if config.get("save_best", True) and (not in_localization_stage):
|
|
4294
|
+
output_dir = os.path.dirname(config["output_path"])
|
|
4295
|
+
basename = os.path.splitext(os.path.basename(config["output_path"]))[0]
|
|
4296
|
+
from datetime import datetime
|
|
4297
|
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
4298
|
+
|
|
4299
|
+
epoch_folder = os.path.join(
|
|
4300
|
+
output_dir,
|
|
4301
|
+
f"{basename}_checkpoints",
|
|
4302
|
+
f"epoch_{epoch+1}_trainloss_{avg_loss:.4f}_acc_{train_acc:.1f}_{timestamp}"
|
|
4303
|
+
)
|
|
4304
|
+
os.makedirs(epoch_folder, exist_ok=True)
|
|
4305
|
+
|
|
4306
|
+
epoch_path = os.path.join(epoch_folder, "model.pt")
|
|
4307
|
+
if head_metadata:
|
|
4308
|
+
model.save_head(epoch_path, metadata=head_metadata)
|
|
4309
|
+
else:
|
|
4310
|
+
model.save_head(epoch_path)
|
|
4311
|
+
|
|
4312
|
+
if log_fn:
|
|
4313
|
+
log_fn(f"Saved epoch checkpoint to {epoch_folder}")
|
|
4314
|
+
elif config.get("save_best", True) and in_localization_stage and log_fn:
|
|
4315
|
+
log_fn("Checkpoint save skipped (still in localization phase; classification not started yet).")
|
|
4316
|
+
|
|
4317
|
+
# Restore training weights after EMA-based validation/saving
|
|
4318
|
+
_apply_ema()
|
|
4319
|
+
|
|
4320
|
+
# Save crop-progress visualization every 2nd epoch.
|
|
4321
|
+
_save_epoch_crop_progress(epoch, epoch_phase)
|
|
4322
|
+
|
|
4323
|
+
# Incremental history CSV every 2 epochs for offline plotting
|
|
4324
|
+
if (epoch + 1) % 2 == 0:
|
|
4325
|
+
try:
|
|
4326
|
+
import pandas as pd
|
|
4327
|
+
inc_csv_dir = os.path.join(output_dir_base, f"{basename}_training_history")
|
|
4328
|
+
os.makedirs(inc_csv_dir, exist_ok=True)
|
|
4329
|
+
inc_csv_path = os.path.join(inc_csv_dir, "history.csv")
|
|
4330
|
+
pd.DataFrame(history).to_csv(inc_csv_path, index=False)
|
|
4331
|
+
except Exception as e:
|
|
4332
|
+
logger.debug("Could not save incremental history CSV: %s", e)
|
|
4333
|
+
|
|
4334
|
+
# Step LR scheduler once per epoch.
|
|
4335
|
+
sched_epoch_base = classification_stage_start_epoch if classification_stage_start_epoch is not None else 0
|
|
4336
|
+
sched_epoch = epoch - sched_epoch_base
|
|
4337
|
+
if use_scheduler and sched_epoch >= 0:
|
|
4338
|
+
if warmup_scheduler is not None and warmup_epochs > 0 and sched_epoch < warmup_epochs:
|
|
4339
|
+
warmup_scheduler.step()
|
|
4340
|
+
if use_ema and sched_epoch == warmup_epochs - 1:
|
|
4341
|
+
ema_active = True
|
|
4342
|
+
_init_ema()
|
|
4343
|
+
if log_fn:
|
|
4344
|
+
log_fn(f"EMA activated after {warmup_epochs}-epoch warmup")
|
|
4345
|
+
elif scheduler is not None:
|
|
4346
|
+
scheduler.step()
|
|
4347
|
+
ema_start_epoch = max(warmup_epochs, 3)
|
|
4348
|
+
if use_ema and (not ema_active) and sched_epoch >= ema_start_epoch - 1:
|
|
4349
|
+
ema_active = True
|
|
4350
|
+
_init_ema()
|
|
4351
|
+
if log_fn:
|
|
4352
|
+
log_fn(f"EMA activated at epoch {epoch + 1}")
|
|
4353
|
+
current_lr = max(pg['lr'] for pg in optimizer.param_groups)
|
|
4354
|
+
if log_fn and (epoch + 1) % 5 == 0:
|
|
4355
|
+
log_fn(f"Current Learning Rate: {current_lr:.8f}")
|
|
4356
|
+
|
|
4357
|
+
if metrics_callback:
|
|
4358
|
+
current_metrics = {
|
|
4359
|
+
"epoch": epoch + 1,
|
|
4360
|
+
"train_loss": history["train_loss"][-1],
|
|
4361
|
+
"train_loss_class": history["train_loss_class"][-1],
|
|
4362
|
+
"train_acc": history["train_acc"][-1],
|
|
4363
|
+
"train_frame_acc": history["train_frame_acc"][-1],
|
|
4364
|
+
"val_loss": history["val_loss"][-1],
|
|
4365
|
+
"val_acc": history["val_acc"][-1],
|
|
4366
|
+
"val_frame_acc": history["val_frame_acc"][-1],
|
|
4367
|
+
"val_f1": history["val_f1"][-1],
|
|
4368
|
+
"training_phase": epoch_phase,
|
|
4369
|
+
"loc_val_iou": history["loc_val_iou"][-1],
|
|
4370
|
+
"loc_val_center_error": history["loc_val_center_error"][-1],
|
|
4371
|
+
"loc_val_valid_rate": history["loc_val_valid_rate"][-1],
|
|
4372
|
+
"per_class_f1": {
|
|
4373
|
+
class_names[idx]: history[class_key_map[idx]][-1]
|
|
4374
|
+
for idx in range(len(class_names))
|
|
4375
|
+
if class_key_map.get(idx) in history
|
|
4376
|
+
and class_names[idx] not in _f1_exclude_names
|
|
4377
|
+
},
|
|
4378
|
+
"per_attr_f1": per_attr_f1,
|
|
4379
|
+
"crop_progress_dir": crop_progress_dir,
|
|
4380
|
+
}
|
|
4381
|
+
metrics_callback(current_metrics)
|
|
4382
|
+
|
|
4383
|
+
except Exception as e:
|
|
4384
|
+
error_msg = f"Error in epoch {epoch+1}: {str(e)}\n{traceback.format_exc()}"
|
|
4385
|
+
if log_fn:
|
|
4386
|
+
log_fn(f"ERROR: {error_msg}")
|
|
4387
|
+
raise
|
|
4388
|
+
|
|
4389
|
+
if log_fn:
|
|
4390
|
+
log_fn("Training complete!")
|
|
4391
|
+
|
|
4392
|
+
# --- Augmentation ablation evaluation ---
|
|
4393
|
+
if (config.get("use_augmentation", False)
|
|
4394
|
+
and not (stop_callback and stop_callback())):
|
|
4395
|
+
try:
|
|
4396
|
+
# Load best model for evaluation
|
|
4397
|
+
best_main_path = config["output_path"].replace(".pt", "_best.pt")
|
|
4398
|
+
if os.path.exists(best_main_path):
|
|
4399
|
+
model.load_head(best_main_path)
|
|
4400
|
+
_run_augmentation_ablation_eval(
|
|
4401
|
+
model, train_dataset, config, device, log_fn=log_fn
|
|
4402
|
+
)
|
|
4403
|
+
except Exception as abl_err:
|
|
4404
|
+
if log_fn:
|
|
4405
|
+
log_fn(f"Augmentation ablation eval failed (non-fatal): {abl_err}")
|
|
4406
|
+
|
|
4407
|
+
# --- Per-head temperature calibration (OvR only) ---
|
|
4408
|
+
if use_ovr and val_loader is not None and not (stop_callback and stop_callback()):
|
|
4409
|
+
try:
|
|
4410
|
+
# Load best checkpoint for calibration
|
|
4411
|
+
best_main_path = config["output_path"].replace(".pt", "_best.pt")
|
|
4412
|
+
if os.path.exists(best_main_path):
|
|
4413
|
+
model.load_head(best_main_path)
|
|
4414
|
+
model.eval()
|
|
4415
|
+
all_logits = []
|
|
4416
|
+
all_labels = []
|
|
4417
|
+
_val_emb_mode = getattr(val_dataset, '_emb_cache_mode', False)
|
|
4418
|
+
_cal_clip_len = config.get("clip_length", 8)
|
|
4419
|
+
with torch.no_grad():
|
|
4420
|
+
for batch in val_loader:
|
|
4421
|
+
# Batch is a tuple: (clips, labels, spatial_mask, bboxes, bbox_valid, indices, frame_labels)
|
|
4422
|
+
if not isinstance(batch, (list, tuple)) or len(batch) < 7:
|
|
4423
|
+
continue
|
|
4424
|
+
clips = batch[0].to(device)
|
|
4425
|
+
labels = batch[1].to(device)
|
|
4426
|
+
frame_labels_cal = batch[6].to(device)
|
|
4427
|
+
_cs_cal = batch[7] if len(batch) > 7 and isinstance(batch[7], torch.Tensor) and batch[7].numel() > 0 else None
|
|
4428
|
+
if _cs_cal is not None:
|
|
4429
|
+
_cs_cal = _cs_cal.to(device)
|
|
4430
|
+
|
|
4431
|
+
if _val_emb_mode:
|
|
4432
|
+
out = model(None, backbone_tokens=clips,
|
|
4433
|
+
num_frames=_cal_clip_len,
|
|
4434
|
+
backbone_tokens_short=_cs_cal,
|
|
4435
|
+
num_frames_short=_cal_clip_len // 2 if _cs_cal is not None else None,
|
|
4436
|
+
return_frame_logits=True)
|
|
4437
|
+
else:
|
|
4438
|
+
out = model(clips, return_frame_logits=True)
|
|
4439
|
+
|
|
4440
|
+
fo = getattr(model, '_frame_output', None)
|
|
4441
|
+
if fo is not None:
|
|
4442
|
+
f_logits = fo[0] # [B, T, C]
|
|
4443
|
+
all_logits.append(f_logits.cpu())
|
|
4444
|
+
all_labels.append(frame_labels_cal.cpu())
|
|
4445
|
+
|
|
4446
|
+
if all_logits:
|
|
4447
|
+
cat_logits = torch.cat(all_logits, dim=0) # [N, T, C]
|
|
4448
|
+
cat_labels = torch.cat(all_labels, dim=0) # [N, T]
|
|
4449
|
+
B_cal, T_cal, C_cal = cat_logits.shape
|
|
4450
|
+
flat_logits = cat_logits.reshape(-1, C_cal) # [N*T, C]
|
|
4451
|
+
flat_labels = cat_labels.reshape(-1) # [N*T]
|
|
4452
|
+
valid_mask = flat_labels >= 0
|
|
4453
|
+
|
|
4454
|
+
if valid_mask.sum() > 50:
|
|
4455
|
+
temperatures = torch.ones(C_cal)
|
|
4456
|
+
for ci in range(C_cal):
|
|
4457
|
+
ci_logits = flat_logits[valid_mask, ci]
|
|
4458
|
+
ci_targets = (flat_labels[valid_mask] == ci).float()
|
|
4459
|
+
if ci_targets.sum() < 5 or (1 - ci_targets).sum() < 5:
|
|
4460
|
+
continue
|
|
4461
|
+
# Grid search for temperature that maximizes F1
|
|
4462
|
+
best_f1 = -1.0
|
|
4463
|
+
best_t = 1.0
|
|
4464
|
+
for t_cand in [0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 2.0, 2.5, 3.0]:
|
|
4465
|
+
probs = torch.sigmoid(ci_logits / t_cand)
|
|
4466
|
+
preds = (probs >= 0.5).float()
|
|
4467
|
+
tp = (preds * ci_targets).sum()
|
|
4468
|
+
fp = (preds * (1 - ci_targets)).sum()
|
|
4469
|
+
fn = ((1 - preds) * ci_targets).sum()
|
|
4470
|
+
prec = tp / (tp + fp + 1e-8)
|
|
4471
|
+
rec = tp / (tp + fn + 1e-8)
|
|
4472
|
+
f1 = 2 * prec * rec / (prec + rec + 1e-8)
|
|
4473
|
+
if f1.item() > best_f1:
|
|
4474
|
+
best_f1 = f1.item()
|
|
4475
|
+
best_t = t_cand
|
|
4476
|
+
temperatures[ci] = best_t
|
|
4477
|
+
|
|
4478
|
+
temp_dict = {
|
|
4479
|
+
class_names[ci]: round(float(temperatures[ci]), 3)
|
|
4480
|
+
for ci in range(C_cal)
|
|
4481
|
+
}
|
|
4482
|
+
if head_metadata:
|
|
4483
|
+
head_metadata["ovr_temperatures"] = temp_dict
|
|
4484
|
+
if log_fn:
|
|
4485
|
+
t_str = ", ".join(f"{k}={v}" for k, v in temp_dict.items())
|
|
4486
|
+
log_fn(f"OvR per-head calibrated temperatures: {t_str}")
|
|
4487
|
+
elif log_fn:
|
|
4488
|
+
log_fn("Skipping OvR temperature calibration: too few valid validation frames.")
|
|
4489
|
+
elif log_fn:
|
|
4490
|
+
log_fn("Skipping OvR temperature calibration: no frame logits from validation.")
|
|
4491
|
+
except Exception as cal_err:
|
|
4492
|
+
if log_fn:
|
|
4493
|
+
log_fn(f"OvR temperature calibration failed (non-fatal): {cal_err}")
|
|
4494
|
+
|
|
4495
|
+
if head_metadata:
|
|
4496
|
+
head_metadata["training_config"]["best_val_f1"] = best_val_f1 if best_val_f1 >= 0 else None
|
|
4497
|
+
head_metadata["training_config"]["best_val_frame_acc"] = best_val_frame_acc
|
|
4498
|
+
# Legacy field retained for compatibility with existing readers.
|
|
4499
|
+
head_metadata["training_config"]["best_val_acc"] = best_val_frame_acc
|
|
4500
|
+
per_class_final = {}
|
|
4501
|
+
for idx in range(len(class_names)):
|
|
4502
|
+
class_key = class_key_map.get(idx)
|
|
4503
|
+
if class_key in history and history[class_key]:
|
|
4504
|
+
per_class_final[class_label_map[idx]] = history[class_key][-1]
|
|
4505
|
+
head_metadata["training_config"]["final_epoch_val_f1_per_class"] = per_class_final
|
|
4506
|
+
|
|
4507
|
+
# Apply EMA weights for final model save (if EMA was activated)
|
|
4508
|
+
if ema_active:
|
|
4509
|
+
_apply_ema()
|
|
4510
|
+
|
|
4511
|
+
try:
|
|
4512
|
+
if head_metadata:
|
|
4513
|
+
model.save_head(config["output_path"], metadata=head_metadata)
|
|
4514
|
+
else:
|
|
4515
|
+
model.save_head(config["output_path"])
|
|
4516
|
+
if log_fn:
|
|
4517
|
+
tag = " (EMA-averaged)" if ema_active else ""
|
|
4518
|
+
log_fn(f"Saved final model{tag} to {config['output_path']}")
|
|
4519
|
+
|
|
4520
|
+
# --- Save Training Logs and Plots ---
|
|
4521
|
+
import pandas as pd
|
|
4522
|
+
from datetime import datetime
|
|
4523
|
+
import matplotlib
|
|
4524
|
+
matplotlib.use('Agg')
|
|
4525
|
+
import matplotlib.pyplot as plt
|
|
4526
|
+
import json
|
|
4527
|
+
|
|
4528
|
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
4529
|
+
output_dir = os.path.dirname(config["output_path"])
|
|
4530
|
+
output_basename = os.path.splitext(os.path.basename(config["output_path"]))[0]
|
|
4531
|
+
|
|
4532
|
+
# Save training config snapshot for inference resolution fallback
|
|
4533
|
+
training_config_path = os.path.join(output_dir, f"{output_basename}_training_config.json")
|
|
4534
|
+
training_snapshot = {
|
|
4535
|
+
"classes": class_names,
|
|
4536
|
+
"attributes": getattr(train_dataset, "attributes", []),
|
|
4537
|
+
"clip_length": clip_length_value,
|
|
4538
|
+
"resolution": resolution_value,
|
|
4539
|
+
"backbone_model": config.get("backbone_model", "videoprism_public_v1_base"),
|
|
4540
|
+
"training_config": head_metadata.get("training_config", {}) if head_metadata else {},
|
|
4541
|
+
}
|
|
4542
|
+
with open(training_config_path, "w", encoding="utf-8") as cfg_file:
|
|
4543
|
+
json.dump(training_snapshot, cfg_file, indent=2)
|
|
4544
|
+
if log_fn:
|
|
4545
|
+
log_fn(f"Saved training config to {training_config_path}")
|
|
4546
|
+
|
|
4547
|
+
# Save CSV log
|
|
4548
|
+
csv_path = os.path.join(output_dir, f"{output_basename}_training_log_{timestamp}.csv")
|
|
4549
|
+
df = pd.DataFrame(history)
|
|
4550
|
+
df.to_csv(csv_path, index=False)
|
|
4551
|
+
if log_fn:
|
|
4552
|
+
log_fn(f"Saved training log to {csv_path}")
|
|
4553
|
+
|
|
4554
|
+
# Generate Plots
|
|
4555
|
+
plt.style.use('ggplot')
|
|
4556
|
+
fig, axes = plt.subplots(4, 1, figsize=(10, 18))
|
|
4557
|
+
ax1, ax2, ax3, ax4 = axes
|
|
4558
|
+
epochs_hist = history['epoch']
|
|
4559
|
+
|
|
4560
|
+
ax1.plot(epochs_hist, history['train_acc'], label='Train Accuracy', marker='o')
|
|
4561
|
+
ax1.plot(epochs_hist, history['val_acc'], label='Val Acc (frame)', marker='s')
|
|
4562
|
+
ax1.set_title(f'Training Accuracy - {output_basename}')
|
|
4563
|
+
ax1.set_ylabel('Accuracy (%)')
|
|
4564
|
+
ax1.legend()
|
|
4565
|
+
ax1.grid(True)
|
|
4566
|
+
|
|
4567
|
+
ax2.plot(epochs_hist, history['train_loss'], label='Train Loss', marker='o')
|
|
4568
|
+
ax2.plot(epochs_hist, history['val_loss'], label='Val Loss', marker='s')
|
|
4569
|
+
ax2.set_title(f'Training Loss - {output_basename}')
|
|
4570
|
+
ax2.set_ylabel('Loss')
|
|
4571
|
+
ax2.legend()
|
|
4572
|
+
ax2.grid(True)
|
|
4573
|
+
|
|
4574
|
+
ax3.plot(epochs_hist, history['val_f1'], label='Val Macro F1 (frame)', linewidth=2, color='tab:purple')
|
|
4575
|
+
for idx in range(len(class_names)):
|
|
4576
|
+
class_key = class_key_map.get(idx)
|
|
4577
|
+
if class_key in history:
|
|
4578
|
+
ax3.plot(
|
|
4579
|
+
epochs_hist,
|
|
4580
|
+
history[class_key],
|
|
4581
|
+
label=f"{class_names[idx]}",
|
|
4582
|
+
linestyle='--',
|
|
4583
|
+
alpha=0.6
|
|
4584
|
+
)
|
|
4585
|
+
ax3.set_ylabel('F1 (%)')
|
|
4586
|
+
ax3.legend(ncol=2, fontsize=8)
|
|
4587
|
+
ax3.grid(True)
|
|
4588
|
+
|
|
4589
|
+
per_class_keys_ordered = [class_key_map[idx] for idx in range(len(class_names))]
|
|
4590
|
+
if per_class_keys_ordered:
|
|
4591
|
+
per_class_matrix = np.array([
|
|
4592
|
+
history[key] for key in per_class_keys_ordered
|
|
4593
|
+
])
|
|
4594
|
+
else:
|
|
4595
|
+
per_class_matrix = np.zeros((0, len(epochs_hist)))
|
|
4596
|
+
|
|
4597
|
+
im = ax4.imshow(
|
|
4598
|
+
per_class_matrix,
|
|
4599
|
+
aspect='auto',
|
|
4600
|
+
cmap='magma',
|
|
4601
|
+
vmin=0,
|
|
4602
|
+
vmax=100
|
|
4603
|
+
)
|
|
4604
|
+
ax4.set_yticks(range(len(class_names)))
|
|
4605
|
+
ax4.set_yticklabels(class_names)
|
|
4606
|
+
ax4.set_xlabel('Epoch')
|
|
4607
|
+
ax4.set_ylabel('Class')
|
|
4608
|
+
ax4.set_title('Validation F1 Heatmap (%)')
|
|
4609
|
+
if epochs_hist:
|
|
4610
|
+
max_ticks = min(len(epochs_hist), 12)
|
|
4611
|
+
tick_positions = np.linspace(0, len(epochs_hist) - 1, max_ticks, dtype=int)
|
|
4612
|
+
ax4.set_xticks(tick_positions)
|
|
4613
|
+
ax4.set_xticklabels([str(epochs_hist[i]) for i in tick_positions])
|
|
4614
|
+
cbar = fig.colorbar(im, ax=ax4, orientation='vertical', pad=0.01)
|
|
4615
|
+
cbar.set_label('F1 (%)')
|
|
4616
|
+
|
|
4617
|
+
# Save Plot as PDF
|
|
4618
|
+
plot_path = os.path.join(output_dir, f"{output_basename}_training_plot_{timestamp}.pdf")
|
|
4619
|
+
plt.tight_layout()
|
|
4620
|
+
plt.savefig(plot_path)
|
|
4621
|
+
plt.close()
|
|
4622
|
+
|
|
4623
|
+
if log_fn:
|
|
4624
|
+
log_fn(f"Saved training plot to {plot_path}")
|
|
4625
|
+
|
|
4626
|
+
except Exception as e:
|
|
4627
|
+
error_msg = f"Error saving model/logs: {str(e)}\n{traceback.format_exc()}"
|
|
4628
|
+
if log_fn:
|
|
4629
|
+
log_fn(f"ERROR: {error_msg}")
|
|
4630
|
+
|
|
4631
|
+
final_train_acc = history["train_acc"][-1] if history["train_acc"] else 0.0
|
|
4632
|
+
best_val_f1_out = best_val_f1 if best_val_f1 >= 0 else 0.0
|
|
4633
|
+
|
|
4634
|
+
final_per_class_f1 = {}
|
|
4635
|
+
for idx in range(len(class_names)):
|
|
4636
|
+
if class_names[idx] in _f1_exclude_names:
|
|
4637
|
+
continue
|
|
4638
|
+
class_key = class_key_map.get(idx)
|
|
4639
|
+
if class_key in history and history[class_key]:
|
|
4640
|
+
final_per_class_f1[class_label_map[idx]] = history[class_key][-1]
|
|
4641
|
+
|
|
4642
|
+
return {
|
|
4643
|
+
"best_val_acc": best_val_frame_acc,
|
|
4644
|
+
"best_val_frame_acc": best_val_frame_acc,
|
|
4645
|
+
"best_val_f1": best_val_f1_out,
|
|
4646
|
+
"final_train_acc": final_train_acc,
|
|
4647
|
+
"per_class_f1": final_per_class_f1
|
|
4648
|
+
}
|
|
4649
|
+
|
|
4650
|
+
except Exception as e:
|
|
4651
|
+
error_msg = f"Training failed: {str(e)}\n{traceback.format_exc()}"
|
|
4652
|
+
if log_fn:
|
|
4653
|
+
log_fn(f"FATAL ERROR: {error_msg}")
|
|
4654
|
+
raise RuntimeError(error_msg) from e
|
|
4655
|
+
|
|
4656
|
+
finally:
|
|
4657
|
+
# Clean up embedding caches — they are only valid for this training run
|
|
4658
|
+
# (specific backbone weights, augmentation settings, dataset split).
|
|
4659
|
+
import shutil
|
|
4660
|
+
for _cache_path in [train_emb_cache, val_emb_cache_dir]:
|
|
4661
|
+
if _cache_path and os.path.isdir(_cache_path):
|
|
4662
|
+
try:
|
|
4663
|
+
shutil.rmtree(_cache_path)
|
|
4664
|
+
if log_fn:
|
|
4665
|
+
log_fn(f"Cleaned up embedding cache: {_cache_path}")
|
|
4666
|
+
except Exception as e:
|
|
4667
|
+
logger.debug("Could not remove embedding cache %s: %s", _cache_path, e)
|