opensportslib 0.0.1.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opensportslib/__init__.py +18 -0
- opensportslib/apis/__init__.py +21 -0
- opensportslib/apis/classification.py +361 -0
- opensportslib/apis/localization.py +228 -0
- opensportslib/config/classification.yaml +104 -0
- opensportslib/config/classification_tracking.yaml +103 -0
- opensportslib/config/graph_tracking_classification/avgpool.yaml +79 -0
- opensportslib/config/graph_tracking_classification/gin.yaml +79 -0
- opensportslib/config/graph_tracking_classification/graphconv.yaml +79 -0
- opensportslib/config/graph_tracking_classification/graphsage.yaml +79 -0
- opensportslib/config/graph_tracking_classification/maxpool.yaml +79 -0
- opensportslib/config/graph_tracking_classification/noedges.yaml +79 -0
- opensportslib/config/localization.yaml +132 -0
- opensportslib/config/sngar_frames.yaml +98 -0
- opensportslib/core/__init__.py +0 -0
- opensportslib/core/loss/__init__.py +0 -0
- opensportslib/core/loss/builder.py +40 -0
- opensportslib/core/loss/calf.py +258 -0
- opensportslib/core/loss/ce.py +23 -0
- opensportslib/core/loss/combine.py +42 -0
- opensportslib/core/loss/nll.py +25 -0
- opensportslib/core/optimizer/__init__.py +0 -0
- opensportslib/core/optimizer/builder.py +38 -0
- opensportslib/core/sampler/weighted_sampler.py +104 -0
- opensportslib/core/scheduler/__init__.py +0 -0
- opensportslib/core/scheduler/builder.py +77 -0
- opensportslib/core/trainer/__init__.py +0 -0
- opensportslib/core/trainer/classification_trainer.py +1131 -0
- opensportslib/core/trainer/localization_trainer.py +1009 -0
- opensportslib/core/utils/checkpoint.py +238 -0
- opensportslib/core/utils/config.py +199 -0
- opensportslib/core/utils/data.py +85 -0
- opensportslib/core/utils/ddp.py +77 -0
- opensportslib/core/utils/default_args.py +110 -0
- opensportslib/core/utils/load_annotations.py +485 -0
- opensportslib/core/utils/seed.py +26 -0
- opensportslib/core/utils/video_processing.py +389 -0
- opensportslib/core/utils/wandb.py +110 -0
- opensportslib/datasets/__init__.py +0 -0
- opensportslib/datasets/builder.py +42 -0
- opensportslib/datasets/classification_dataset.py +582 -0
- opensportslib/datasets/localization_dataset.py +813 -0
- opensportslib/datasets/utils/__init__.py +15 -0
- opensportslib/datasets/utils/tracking.py +615 -0
- opensportslib/metrics/classification_metric.py +176 -0
- opensportslib/metrics/localization_metric.py +1482 -0
- opensportslib/models/__init__.py +0 -0
- opensportslib/models/backbones/builder.py +590 -0
- opensportslib/models/base/e2e.py +252 -0
- opensportslib/models/base/tracking.py +73 -0
- opensportslib/models/base/vars.py +29 -0
- opensportslib/models/base/video.py +130 -0
- opensportslib/models/base/video_mae.py +60 -0
- opensportslib/models/builder.py +43 -0
- opensportslib/models/heads/builder.py +266 -0
- opensportslib/models/neck/builder.py +210 -0
- opensportslib/models/utils/common.py +176 -0
- opensportslib/models/utils/impl/__init__.py +0 -0
- opensportslib/models/utils/impl/asformer.py +390 -0
- opensportslib/models/utils/impl/calf.py +74 -0
- opensportslib/models/utils/impl/gsm.py +112 -0
- opensportslib/models/utils/impl/gtad.py +347 -0
- opensportslib/models/utils/impl/tsm.py +123 -0
- opensportslib/models/utils/litebase.py +59 -0
- opensportslib/models/utils/modules.py +120 -0
- opensportslib/models/utils/shift.py +135 -0
- opensportslib/models/utils/utils.py +276 -0
- opensportslib-0.0.1.dev2.dist-info/METADATA +566 -0
- opensportslib-0.0.1.dev2.dist-info/RECORD +73 -0
- opensportslib-0.0.1.dev2.dist-info/WHEEL +5 -0
- opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE +661 -0
- opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE-COMMERCIAL +5 -0
- opensportslib-0.0.1.dev2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1131 @@
|
|
|
1
|
+
# opensportslib/core/trainer/classification_trainer.py
|
|
2
|
+
|
|
3
|
+
"""classification trainers for video and tracking modalities.

provides a base trainer with modality-agnostic training, validation,
and test loops, plus three modality-specific subclasses (multi-view
video, tracking, and frames) that implement the forward pass.
Trainer_Classification is the top-level dispatcher consumed by the
API layer.

"""
|
|
11
|
+
|
|
12
|
+
import os
|
|
13
|
+
import gc
|
|
14
|
+
import json
|
|
15
|
+
import time
|
|
16
|
+
import logging
|
|
17
|
+
|
|
18
|
+
import torch
|
|
19
|
+
import tqdm
|
|
20
|
+
import wandb
|
|
21
|
+
import numpy as np
|
|
22
|
+
|
|
23
|
+
from torch.utils.data import (
|
|
24
|
+
DataLoader,
|
|
25
|
+
WeightedRandomSampler,
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
from transformers import Trainer as HFTrainer, TrainingArguments
|
|
29
|
+
from opensportslib.core.utils.ddp import DistributedWeightedSampler
|
|
30
|
+
|
|
31
|
+
from opensportslib.core.utils.wandb import log_confusion_matrix_wandb
|
|
32
|
+
from opensportslib.core.utils.checkpoint import *
|
|
33
|
+
|
|
34
|
+
from opensportslib.core.utils.config import select_device
|
|
35
|
+
from opensportslib.core.utils.data import mixup_data
|
|
36
|
+
import torch.distributed as dist
|
|
37
|
+
from datetime import datetime
|
|
38
|
+
from opensportslib.core.utils.seed import seed_worker
|
|
39
|
+
from opensportslib.metrics.classification_metric import (
|
|
40
|
+
compute_classification_metrics,
|
|
41
|
+
process_preds_labels
|
|
42
|
+
)
|
|
43
|
+
|
|
44
|
+
# -------------------------------------------------------------------
|
|
45
|
+
# base classification trainer
|
|
46
|
+
# -------------------------------------------------------------------
|
|
47
|
+
|
|
48
|
+
class BaseTrainerClassification:
    """modality-agnostic training loop for classification.

    handles epoch iteration, gradient updates, DDP gather,
    metric computation, W&B logging, checkpoint saving, and JSON
    prediction export. subclasses only need to override _forward_batch()
    with modality-specific tensor preparation.

    Args:
        train_loader: DataLoader for training set.
        val_loader: DataLoader for validation set.
        test_loader: DataLoader for test set (may be None during training).
        model: the classification model (expected to already be on device).
        optimizer: PyTorch optimizer.
        scheduler: learning-rate scheduler (may be None).
        criterion: loss callable, invoked as
            criterion(output=..., labels=...[, weight=...]).
        class_weights: optional per-class weight tensor for the loss.
        class_names: dict mapping class indices to names.
        save_dir: root directory for checkpoint and prediction output.
        model_name: stored as self.model_name (not used by this class's
            own methods).
        max_epochs: maximum number of training epochs.
        device: torch.device or device string.
        top_k: k value for top-k accuracy computation.
        patience: stored as self.patience; early stopping in train() is
            driven by the ReduceLROnPlateau learning rate instead.
        monitor: metric name to monitor for checkpointing ("loss" selects
            the validation loss).
        mode: "max" or "min" depending on the monitored metric.
        revert_on_lr_reduction: if True, reload the best weights whenever
            the learning rate changes between epochs.
        config: optional configuration object, stored as self.config.
    """

    def __init__(
        self,
        train_loader,
        val_loader,
        test_loader,
        model,
        optimizer,
        scheduler,
        criterion,
        class_weights,
        class_names,
        save_dir,
        model_name,
        max_epochs=1000,
        device="cuda",
        top_k=2,
        patience=10,
        monitor="balanced_accuracy",
        mode="max",
        revert_on_lr_reduction=False,
        config=None,
    ):
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader

        # device placement / DDP wrapping is done by the caller.
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.criterion = criterion
        self.class_weights = class_weights
        self.class_names = class_names

        self.model_name = model_name
        self.max_epochs = max_epochs
        self.device = device
        self.top_k = top_k
        self.patience = patience

        self.monitor = monitor
        self.mode = mode
        self.config = config

        self.best_checkpoint_path = None
        self.best_metric = None
        self.revert_on_lr_reduction = revert_on_lr_reduction
        # CPU copy of the best weights; only filled when
        # revert_on_lr_reduction is enabled.
        self._best_model_state = None

        self.rank = dist.get_rank() if dist.is_initialized() else 0

        self.save_dir = save_dir
        os.makedirs(self.save_dir, exist_ok=True)

        # gradient logging is best-effort (e.g. W&B may not be initialised),
        # but record why it was skipped instead of hiding it completely.
        if self.rank == 0:
            try:
                wandb.watch(self.model, log="gradients", log_freq=100)
            except Exception as exc:
                logging.debug(f"wandb.watch skipped: {exc}")

    # -- abstract forward pass --------------------------------------

    def _forward_batch(self, batch):
        """run the modality-specific forward pass.

        must be overridden by every subclass.

        Args:
            batch: a dict produced by the DataLoader.

        Returns:
            a tuple (logits, labels) where both are tensors on
            self.device (labels may be None for unlabelled batches).
        """
        raise NotImplementedError

    # -- process batch ----------------------------------------------

    def _process_batch(self, batch, train):
        """run forward pass, compute loss, and optionally update weights.

        the default implementation calls _forward_batch() for the
        modality-specific forward pass, then computes the loss and
        runs the backward step. subclasses may override this entirely
        to inject AMP, mixup, or other training-time modifications
        without touching the base training loop.

        Args:
            batch: a dict produced by the DataLoader.
            train: if True, compute gradients and update weights.

        Returns:
            a tuple (logits, labels, loss, has_labels). loss is None and
            has_labels is False when the batch carries no labels.
        """
        has_labels = "labels" in batch or "label" in batch
        with torch.set_grad_enabled(train):
            logits, labels = self._forward_batch(batch)
            if labels is None:
                has_labels = False

            loss = None
            if has_labels:
                if self.class_weights is not None:
                    loss = self.criterion(
                        output=logits, labels=labels,
                        weight=self.class_weights.to(self.device)
                    )
                else:
                    loss = self.criterion(output=logits, labels=labels)

        if train and loss is not None:
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer.step()

        return logits, labels, loss, has_labels

    # -- training loop ----------------------------------------------

    def train(self, epoch_start=0, save_every=3):
        """run the full training loop with validation after each epoch.

        under DDP every rank runs this loop; rank 0 owns logging and
        checkpoint files, while scheduling / early-stopping / best-model
        decisions are kept identical on all ranks.

        Args:
            epoch_start: the epoch number to start from (0-based).
            save_every: currently unused; reserved for periodic
                checkpoint saving.
        """
        logging.info("Starting training")
        monitor = self.monitor
        mode = self.mode
        best_metric = -float("inf") if mode == "max" else float("inf")
        is_plateau = isinstance(
            self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau
        )
        disable = self.rank != 0

        for epoch in range(epoch_start, self.max_epochs):
            logging.info(f"\nEpoch {epoch+1}/{self.max_epochs}")

            # --- train ---
            if hasattr(self.train_loader.sampler, "set_epoch"):
                self.train_loader.sampler.set_epoch(epoch)

            pbar = tqdm.tqdm(
                total=len(self.train_loader), desc="Training",
                position=0, leave=True, disable=disable
            )
            _, _, train_loss, train_metrics = self._run_epoch(
                self.train_loader, epoch + 1,
                train=True, set_name="train", pbar=pbar
            )
            pbar.close()

            # --- validation ---
            pbar = tqdm.tqdm(
                total=len(self.val_loader), desc="Valid",
                position=1, leave=True, disable=disable
            )
            _, _, val_loss, val_metrics = self._run_epoch(
                self.val_loader, epoch + 1,
                train=False, set_name="valid", pbar=pbar
            )
            pbar.close()

            val_metric = val_metrics.get(
                "balanced_accuracy", val_metrics.get("accuracy", 0)
            )
            train_metric = train_metrics.get(
                "balanced_accuracy", train_metrics.get("accuracy", 0)
            )

            # capture LR before the scheduler step so we can detect
            # plateau-triggered reductions.
            prev_lr = self.optimizer.param_groups[0]["lr"]

            # ReduceLROnPlateau is stepped on the validation loss, which
            # _run_epoch averages across all DDP ranks, so every replica
            # makes the same scheduling decision.
            if self.scheduler is not None:
                if is_plateau:
                    self.scheduler.step(val_loss)
                else:
                    self.scheduler.step()

            current_lr = self.optimizer.param_groups[0]["lr"]

            # early stopping (mirrors pixels_vs_positions): stop once the
            # plateau scheduler has pushed the LR down to 2x its floor.
            if is_plateau:
                min_lr = self.scheduler.min_lrs[0]
                if current_lr <= 2 * min_lr:
                    if self.rank == 0:
                        logging.info(
                            f"Early stopping at epoch {epoch+1}: "
                            f"lr {current_lr:.2e} <= 2 * min_lr {min_lr:.2e}"
                        )
                    break

            # When the LR drops, revert weights to the best snapshot so
            # training continues from the strongest point rather than from
            # a potentially overfit state. The snapshot is taken on every
            # rank, so DDP replicas stay in sync after the reload.
            if (
                self.revert_on_lr_reduction
                and current_lr != prev_lr
                and self._best_model_state is not None
            ):
                self.model.load_state_dict(self._best_model_state)
                if self.rank == 0:
                    logging.info(
                        f"LR reduced from {prev_lr:.2e} to {current_lr:.2e} "
                        f"-- reverted to best model"
                    )

            if self.rank == 0:
                # ---------------- W&B LOG ----------------
                wandb.log({
                    "epoch": epoch + 1,
                    "lr": current_lr,
                    "train/loss": train_loss,
                    "valid/loss": val_loss,
                    **{f"train/{k}": v for k, v in train_metrics.items()},
                    **{f"valid/{k}": v for k, v in val_metrics.items()},
                })

                logging.info(f"Train Loss: {train_loss:.4f} | Train Bal Acc: {train_metric:.4f}")
                logging.info(f"Val Loss: {val_loss:.4f} | Val Bal Acc: {val_metric:.4f}")

            # ---------------- CHECKPOINT ----------------
            current = val_loss if monitor == "loss" else val_metrics.get(monitor, 0)
            is_better = current > best_metric if mode == "max" else current < best_metric

            # metrics only exist on rank 0 under DDP: share its decision so
            # every replica agrees on what the best model is.
            if dist.is_initialized():
                decision = [is_better, current]
                dist.broadcast_object_list(decision, src=0)
                is_better, current = decision

            if is_better:
                best_metric = current
                self.best_metric = best_metric

                if self.revert_on_lr_reduction:
                    self._best_model_state = {
                        k: v.cpu().clone()
                        for k, v in self.model.state_dict().items()
                    }

                if self.rank == 0:
                    best_path = self._save_checkpoint("best", epoch + 1, tag="best")
                    self.best_checkpoint_path = best_path

                    artifact = wandb.Artifact("model-checkpoint", type="model")
                    artifact.add_file(best_path)
                    wandb.log_artifact(artifact)

        if self.rank == 0:
            logging.info(f"Best checkpoint : {self.best_checkpoint_path}")
            logging.info("Training finished.")

    # -- TEST evaluation ------------------------------------------

    def test(self, epoch=None, detailed_results=False):
        """run the test set evaluation.

        Args:
            epoch: the epoch number to evaluate (if None, uses "final").
            detailed_results: whether to compute detailed classification
                metrics (rank 0 only, where the gathered predictions live).

        Returns:
            a tuple (test_loss, test_metrics). test_metrics is an empty
            dict on non-rank-0 DDP workers.
        """
        logging.info("\nRunning TEST evaluation")
        pbar = tqdm.tqdm(
            total=len(self.test_loader), desc="Test", position=0,
            leave=True, disable=self.rank != 0
        )
        all_logits, all_labels, test_loss, test_metrics = self._run_epoch(
            self.test_loader,
            epoch if epoch is not None else "final",
            train=False, set_name="test", pbar=pbar
        )
        pbar.close()

        # other ranks receive None logits/labels from _run_epoch.
        if self.rank == 0:
            wandb.log({
                "test/loss": test_loss,
                **{f"test/{k}": v for k, v in test_metrics.items()},
            })

            if detailed_results:
                from opensportslib.metrics.classification_metric import (
                    compute_detailed_classification_metrics
                )
                compute_detailed_classification_metrics(
                    all_logits=all_logits, all_labels=all_labels,
                    class_names=self.class_names, save_dir=self.save_dir,
                    set_name="test"
                )

        logging.info(f"TEST METRICS : {test_metrics}")
        return test_loss, test_metrics

    # -- single epoch logic -------------------------------------

    @staticmethod
    def _concat_nonempty(parts, empty):
        """concatenate per-rank arrays, skipping ranks with no samples.

        an empty rank contributes a placeholder whose trailing shape may
        differ from the others (e.g. (0, 1) vs (N, C)), which would make
        a plain np.concatenate fail. returns `empty` if all parts are empty.
        """
        non_empty = [p for p in parts if len(p) > 0]
        return np.concatenate(non_empty) if non_empty else empty

    def _run_epoch(self, dataloader, epoch, train=False, set_name="train", pbar=None):
        """execute one pass over a dataloader.

        handles forward/backward, per-batch bookkeeping, DDP gather, metric
        computation, confusion-matrix logging, and JSON prediction export.

        Args:
            dataloader: the DataLoader to iterate over.
            epoch: the epoch number (for checkpointing and folder naming).
            train: if True, compute gradients and update weights.
            set_name: "train", "valid", or "test" (for logging and JSON).
            pbar: optional tqdm progress bar.

        Returns:
            a tuple (all_logits, all_labels, avg_loss, metrics).
            avg_loss is the mean over every batch that produced a loss,
            across all DDP ranks, so it is identical on every rank.
            on non-rank-0 DDP workers the first two are None and metrics
            is an empty dict.
        """
        if train:
            self.model.train()
        else:
            self.model.eval()

        total_loss = 0.0
        total_batches = 0

        all_logits = []
        all_labels = []
        results = []

        # -------- Create epoch folder --------
        epoch_dir = os.path.join(self.save_dir, str(epoch))
        os.makedirs(epoch_dir, exist_ok=True)
        save_path = os.path.join(
            epoch_dir, f"predictions_{set_name}_epoch_{epoch}.json"
        )

        # --- batch loop ---
        for batch in dataloader:
            if pbar:
                pbar.update()

            logits, labels, loss, has_labels = self._process_batch(batch, train)

            if loss is not None:
                total_loss += loss.item()
                total_batches += 1

            logits_cpu = logits.detach().cpu()
            all_logits.append(logits_cpu)

            if has_labels:
                all_labels.append(labels.detach().cpu())

            # per-sample predictions for JSON export.
            probs = torch.softmax(logits_cpu, dim=1)
            preds = torch.argmax(probs, dim=1)
            confs = probs.max(dim=1).values
            ids = batch["id"]

            for i, (pred, conf) in enumerate(zip(preds.tolist(), confs.tolist())):
                results.append({
                    "id": ids[i],
                    "pred_label": self.class_names[pred],
                    "confidence": float(conf),
                    "pred_class_idx": pred,
                })

        # --- concatenate local predictions ---
        all_logits = torch.cat(all_logits).numpy() if all_logits else np.zeros((0, 1))
        all_labels = torch.cat(all_labels).numpy() if all_labels else np.zeros((0,))

        # --- DDP gather (handles uneven shard sizes) ---
        if dist.is_initialized():
            gathered = [None for _ in range(dist.get_world_size())]
            dist.all_gather_object(
                gathered,
                (all_logits, all_labels, results, total_loss, total_batches),
            )
            # global loss average, identical on every rank so the LR
            # scheduler and early stopping stay in lock-step.
            total_loss = sum(g[3] for g in gathered)
            total_batches = sum(g[4] for g in gathered)

            if self.rank != 0:
                return None, None, total_loss / max(1, total_batches), {}

            all_logits = self._concat_nonempty([g[0] for g in gathered], all_logits)
            all_labels = self._concat_nonempty([g[1] for g in gathered], all_labels)
            results = [r for g in gathered for r in g[2]]

        avg_loss = total_loss / max(1, total_batches)

        # --- metrics (rank-0 only in DDP) ---
        if len(all_labels) > 0:
            metrics = compute_classification_metrics(
                (all_logits, all_labels), top_k=self.top_k,
            )
        else:
            metrics = {}

        # --- confusion matrix (validation and test only) ---
        if self.rank == 0 and set_name in ["valid", "test"] and len(all_labels) > 0:
            preds_all, labels_all, _ = process_preds_labels(
                (all_logits, all_labels)
            )
            class_names = [
                self.class_names[i] for i in sorted(self.class_names.keys())
            ]

            log_confusion_matrix_wandb(
                y_true=labels_all.tolist(),
                y_pred=preds_all.tolist(),
                class_names=class_names,
                split_name=set_name,
            )

        # --- save JSON (rank-0 only) ---
        if self.rank == 0:
            submission = {
                "version": "2.0",
                "task": "action_classification",
                "date": datetime.now().strftime("%Y-%m-%d"),
                "metadata": {"type": "predictions"},
                "data": [
                    {
                        "id": r["id"],
                        "labels": {
                            "action": {
                                "label": r["pred_label"],
                                "confidence": r["confidence"],
                            }
                        },
                    }
                    for r in results
                ],
            }

            logging.info(f"RESULTS Length: {len(results)}")
            logging.info(f"Predicitions are stored at : {save_path}")
            with open(save_path, "w", encoding="utf-8") as f:
                json.dump(submission, f, indent=2)

        return all_logits, all_labels, avg_loss, metrics

    # -- checkpoint saving ---------------------------------------

    def _save_checkpoint(self, filename, epoch, tag=None):
        """write a training checkpoint under save_dir/<filename>/.

        Args:
            filename: sub-directory name inside save_dir.
            epoch: epoch number stored in the checkpoint and file name.
            tag: optional prefix for the file name ("<tag>_epoch_<n>.pt").

        Returns:
            the path of the written .pt file.
        """
        ckpt_dir = os.path.join(self.save_dir, str(filename))
        os.makedirs(ckpt_dir, exist_ok=True)

        # store the unwrapped module's weights when the model has a
        # `.module` attribute (as DDP-wrapped models do).
        model = self.model.module if hasattr(self.model, "module") else self.model
        state = {
            "epoch": epoch,
            "state_dict": model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "scheduler": (
                self.scheduler.state_dict() if self.scheduler is not None else None
            ),
            "monitor": self.monitor,
            "mode": self.mode,
            "best_metric": self.best_metric,
        }

        if hasattr(self, "scaler"):
            state["scaler"] = self.scaler.state_dict()

        name = f"{tag}_epoch_{epoch}.pt" if tag else f"epoch_{epoch}.pt"

        path_aux = os.path.join(ckpt_dir, name)
        torch.save(state, path_aux)
        logging.info(f"Saved checkpoint: {path_aux}")
        return path_aux
|
|
548
|
+
|
|
549
|
+
|
|
550
|
+
# --------------------------------------------------------------
|
|
551
|
+
# modality-specific trainers
|
|
552
|
+
# --------------------------------------------------------------
|
|
553
|
+
|
|
554
|
+
class MVTrainerClassification(BaseTrainerClassification):
    """forward pass for multi-view video classification.

    batches carry "pixel_values" (documented as (B, V, C, T, H, W))
    and, optionally, integer "labels" of shape (B,).
    """

    def _forward_batch(self, batch):
        """push a multi-view clip batch through the model.

        Args:
            batch: dict holding "pixel_values" and, optionally, "labels".

        Returns:
            a tuple (logits, labels) on self.device; labels is None when
            the batch has no "labels" entry.
        """
        clips = batch["pixel_values"].to(self.device).float()

        targets = batch.get("labels", None)
        if targets is not None:
            targets = targets.to(self.device)

        raw_output = self.model(clips)
        # tuple outputs carry the logits in their first slot.
        logits = raw_output[0] if isinstance(raw_output, tuple) else raw_output

        # give 1-D outputs a leading batch axis.
        if logits.dim() == 1:
            logits = logits.unsqueeze(0)

        return logits, targets
|
|
586
|
+
|
|
587
|
+
|
|
588
|
+
# ============================================================
|
|
589
|
+
# Tracking Trainer
|
|
590
|
+
# ============================================================
|
|
591
|
+
|
|
592
|
+
class TrackingTrainerClassification(BaseTrainerClassification):
    """forward pass for tracking-based classification.

    batches carry graph tensors "x" (documented as (B, N, 2)),
    "edge_index" (documented as (2, E)) and "batch" (documented as (B,)),
    the scalars "batch_size" and "seq_len", and optional integer "labels".
    """

    def _forward_batch(self, batch):
        """assemble the graph input on self.device and run the model.

        Args:
            batch: dict with "x", "edge_index", "batch", "batch_size",
                "seq_len", and optionally "labels".

        Returns:
            a tuple (logits, labels) on self.device; labels is None when
            the batch has no "labels" entry.
        """
        # only the tensor entries are moved; the size fields pass through.
        tensor_keys = ("x", "edge_index", "batch")
        graph_input = {key: batch[key].to(self.device) for key in tensor_keys}
        graph_input["batch_size"] = batch["batch_size"]
        graph_input["seq_len"] = batch["seq_len"]

        targets = batch.get("labels", None)
        if targets is not None:
            targets = targets.to(self.device)

        return self.model(graph_input), targets
|
|
623
|
+
|
|
624
|
+
class FramesTrainerClassification(BaseTrainerClassification):
    """forward pass for frames_npy video classification.

    supports optional mixed-precision training (AMP) and mixup
    augmentation, controlled via config.TRAIN.use_amp and
    config.TRAIN.mixup_alpha respectively.

    expects batches with "pixel_values" (documented as (B, T, H, W, C))
    and, when available, integer "labels" of shape (B,).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # a config without a TRAIN section falls back to the defaults.
        train_cfg = getattr(self.config, "TRAIN", None) if self.config else None
        self.use_amp = getattr(train_cfg, "use_amp", False)
        self.mixup_alpha = getattr(train_cfg, "mixup_alpha", 0.0)
        # with enabled=False the scaler is a pass-through, so the same
        # update path below works with and without AMP.
        self.scaler = torch.amp.GradScaler("cuda", enabled=self.use_amp)

    def _load_inputs(self, batch):
        """move pixel values (as float) and optional labels to self.device."""
        pixel_values = batch["pixel_values"].to(self.device).float()
        labels = batch.get("labels", None)
        if labels is not None:
            labels = labels.to(self.device)
        return pixel_values, labels

    def _supervised_loss(self, logits, labels):
        """criterion on (logits, labels), applying class weights if set."""
        if self.class_weights is not None:
            return self.criterion(
                output=logits, labels=labels,
                weight=self.class_weights.to(self.device),
            )
        return self.criterion(output=logits, labels=labels)

    def _forward_batch(self, batch):
        """run the model on a frames batch.

        Returns:
            a tuple (logits, labels); labels is None for unlabelled batches.
        """
        pixel_values, labels = self._load_inputs(batch)
        logits = self.model({"pixel_values": pixel_values})
        return logits, labels

    def _process_batch(self, batch, train):
        """forward pass with optional AMP/mixup, loss, and weight update.

        Args:
            batch: a dict produced by the DataLoader.
            train: if True, compute gradients and update weights.

        Returns:
            a tuple (logits, labels, loss, has_labels), matching the
            base-class contract consumed by _run_epoch(). loss is None
            and has_labels is False when the batch carries no labels.
        """
        pixel_values, labels = self._load_inputs(batch)
        has_labels = labels is not None
        loss = None

        with torch.set_grad_enabled(train):
            # mixup needs targets, so it only applies to labelled training
            # batches, and then to roughly half of them.
            use_mixup = (
                train
                and has_labels
                and self.mixup_alpha > 0
                and np.random.random() > 0.5
            )

            with torch.amp.autocast("cuda", enabled=self.use_amp):
                if use_mixup:
                    pixel_values, labels_a, labels_b, lam = mixup_data(
                        pixel_values, labels, self.mixup_alpha
                    )
                    logits = self.model({"pixel_values": pixel_values})
                    loss = (
                        lam * self.criterion(output=logits, labels=labels_a)
                        + (1 - lam) * self.criterion(output=logits, labels=labels_b)
                    )
                    # metrics are computed against the primary mixup targets.
                    labels = labels_a
                else:
                    logits = self.model({"pixel_values": pixel_values})
                    if has_labels:
                        loss = self._supervised_loss(logits, labels)

        if train and loss is not None:
            self.optimizer.zero_grad(set_to_none=True)
            self.scaler.scale(loss).backward()
            # unscale first so gradient clipping sees the true gradients.
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.scaler.step(self.optimizer)
            self.scaler.update()

        return logits, labels, loss, has_labels
|
|
689
|
+
|
|
690
|
+
# --------------------------------------------------------------
|
|
691
|
+
# unified trainer dispatcher
|
|
692
|
+
# --------------------------------------------------------------
|
|
693
|
+
|
|
694
|
+
class Trainer_Classification:
    """High-level trainer that dispatches to the right modality trainer.

    Consumed by ClassificationAPI. Responsible for building data loaders,
    optimizers, schedulers, and samplers, then delegating the actual loop to
    the modality-specific trainer chosen from ``config.DATA.data_modality``:
    TrackingTrainerClassification for ``"tracking_parquet"``,
    FramesTrainerClassification for ``"frames_npy"`` and
    MVTrainerClassification otherwise. Models with
    ``config.MODEL.type == "huggingface"`` are trained with a HuggingFace
    Trainer subclass instead.

    Args:
        config: the configuration object.
    """

    def __init__(self, config):
        self.config = config
        self.device = select_device(self.config.SYSTEM)
        self.model = None
        self.optimizer = None
        self.scheduler = None
        # Epoch passed to the modality trainer as ``epoch_start``; load()
        # overwrites it with the checkpoint's epoch.
        self.epoch = 0
        self.trainer = None

    def compute_metrics(self, pred, mode="logits"):
        """Thin wrapper around the metric module.

        Args:
            pred: a tuple ``(logits, labels)`` when ``mode="logits"``, or
                ``(predicted_labels, labels)`` when ``mode="labels"``.
            mode: "logits" or "labels" (default: "logits").

        Returns:
            a dictionary of classification metrics, computed with top_k=2.
        """
        return compute_classification_metrics(pred, top_k=2, mode=mode)

    # -- internal helpers -----------------------------------------

    def _setup_device(self, rank, is_ddp):
        """Select and store ``self.device`` for this process.

        Under DDP the process is pinned to ``cuda:<rank>``; otherwise the
        device comes from ``select_device(config.SYSTEM)``.

        Returns:
            the selected device (also stored on ``self.device``).
        """
        if is_ddp:
            torch.cuda.set_device(rank)
            self.device = torch.device(f"cuda:{rank}")
        else:
            self.device = select_device(self.config.SYSTEM)
        return self.device

    @staticmethod
    def _select_trainer_class(modality):
        """Return the modality-specific trainer class for ``modality``."""
        if modality == "tracking_parquet":
            return TrackingTrainerClassification
        if modality == "frames_npy":
            return FramesTrainerClassification
        return MVTrainerClassification

    def _build_train_sampler(self, train_dataset, is_ddp, rank, world_size, seed, generator):
        """Build the training sampler, or return None to let the loader shuffle.

        With ``TRAIN.use_weighted_sampler`` the sampler draws with replacement
        using ``train_dataset.get_sample_weights()``; the number of draws is
        ``TRAIN.samples_per_class * num_classes`` when ``samples_per_class``
        is set, else ``len(sample_weights)``. Without weighting, DDP uses a
        shuffling DistributedSampler that drops the last uneven shard.

        Args:
            train_dataset: training ClassificationDataset.
            is_ddp: whether DistributedDataParallel is in use.
            rank: process rank.
            world_size: number of processes.
            seed: seed for the distributed weighted sampler.
            generator: seeded ``torch.Generator`` for the single-process
                weighted sampler.

        Returns:
            a sampler instance, or None for single-process unweighted training.
        """
        from torch.utils.data.distributed import DistributedSampler

        if self.config.TRAIN.use_weighted_sampler:
            sample_weights = train_dataset.get_sample_weights()

            samples_per_class = getattr(self.config.TRAIN, "samples_per_class", None)
            if samples_per_class:
                num_samples = samples_per_class * train_dataset.num_classes()
            else:
                num_samples = len(sample_weights)

            if is_ddp:
                return DistributedWeightedSampler(
                    weights=sample_weights,
                    num_replicas=world_size,
                    rank=rank,
                    replacement=True,
                    num_samples=num_samples,
                    seed=seed,
                )
            return WeightedRandomSampler(
                weights=sample_weights,
                num_samples=num_samples,
                replacement=True,
                generator=generator,
            )

        if is_ddp:
            return DistributedSampler(
                train_dataset,
                num_replicas=world_size,
                rank=rank,
                shuffle=True,
                drop_last=True,
            )
        return None

    @staticmethod
    def _collect_eval_pairs(pred_data, gt_data, class_names, exclude_labels=None):
        """Align predictions with ground truth by sample id.

        Ground-truth samples whose action label is in ``exclude_labels`` are
        dropped; predictions whose id has no (kept) ground truth are skipped.
        Output order follows ``pred_data["data"]``.

        Args:
            pred_data: parsed prediction JSON with a ``"data"`` list of items
                shaped ``{"id": ..., "labels": {"action": {"label": ...}}}``.
            gt_data: parsed ground-truth JSON with the same item shape.
            class_names: mapping class index -> label name.
            exclude_labels: iterable of ground-truth label names to ignore.

        Returns:
            ``(preds, labels)``: two equally long lists of class indices.

        Raises:
            KeyError: if a kept label is not a value of ``class_names``.
        """
        label_to_idx = {name: idx for idx, name in class_names.items()}
        excluded = tuple(exclude_labels or ())

        gt_by_id = {}
        for item in gt_data["data"]:
            sid = item["id"]
            gt_label = item["labels"]["action"]["label"]
            if gt_label not in excluded:
                gt_by_id[sid] = label_to_idx[gt_label]

        preds, labels = [], []
        for item in pred_data["data"]:
            sid = item["id"]
            if sid not in gt_by_id:
                continue
            preds.append(label_to_idx[item["labels"]["action"]["label"]])
            labels.append(gt_by_id[sid])
        return preds, labels

    # -- training -----------------------------------------------

    def train(self, model, train_dataset, val_dataset=None, rank=0, world_size=1):
        """Build all training components and run the loop.

        Detects the model type (HuggingFace vs. custom) and the data modality
        to select the right trainer class, sampler, and collate function.

        Args:
            model: the classification model.
            train_dataset: training ClassificationDataset.
            val_dataset: validation ClassificationDataset (optional). When
                None, no validation loader is built and ``val_loader=None``
                is passed to the modality trainer.
            rank: GPU rank (0-indexed).
            world_size: total number of GPUs; values > 1 enable DDP.

        Returns:
            the modality trainer's ``best_checkpoint_path`` attribute if it
            has one, else None. Always None for HuggingFace models.
        """
        from opensportslib.core.loss.builder import build_criterion
        from opensportslib.core.optimizer.builder import build_optimizer
        from opensportslib.core.scheduler.builder import build_scheduler
        from opensportslib.core.utils.data import tracking_collate_fn
        from torch.nn.parallel import DistributedDataParallel as DDP
        from torch.utils.data.distributed import DistributedSampler

        is_ddp = world_size > 1
        modality = getattr(self.config.DATA, "data_modality", "video")
        seed = self.config.SYSTEM.seed

        # One seeded generator shared by the weighted sampler and both
        # loaders so sampling/shuffling is reproducible.
        g = torch.Generator()
        g.manual_seed(seed)

        # HuggingFace models (e.g. VideoMAE) use the HF Trainer.
        if self.config.MODEL.type == "huggingface":
            self._train_huggingface(model, train_dataset, val_dataset)
            return None

        self._setup_device(rank, is_ddp)

        self.model = model.to(self.device)
        if is_ddp:
            self.model = DDP(self.model, device_ids=[rank])

        # NOTE(review): a fresh optimizer/scheduler is always built here, so
        # any optimizer/scheduler restored by load() is not reused (only
        # self.epoch is) — confirm this is intended for resumed training.
        optimizer = build_optimizer(self.model.parameters(), cfg=self.config.TRAIN.optimizer)
        scheduler = build_scheduler(optimizer, cfg=self.config.TRAIN.scheduler)
        criterion = build_criterion(self.config.TRAIN.criterion)

        # Optional class weights for the loss.
        if self.config.TRAIN.use_weighted_loss:
            class_weights = train_dataset.get_class_weights(
                num_classes=train_dataset.num_classes(), sqrt=True
            ).to(self.device)
        else:
            class_weights = None

        # The tracking modality needs a custom collate that merges PyG Data
        # objects into a single batched graph per timestamp.
        collate_fn = tracking_collate_fn if modality == "tracking_parquet" else None

        train_sampler = self._build_train_sampler(
            train_dataset,
            is_ddp=is_ddp,
            rank=rank,
            world_size=world_size,
            seed=seed,
            generator=g,
        )

        num_train_workers = self.config.DATA.train.dataloader.num_workers
        train_loader = DataLoader(
            train_dataset,
            batch_size=self.config.DATA.train.dataloader.batch_size,
            # A sampler already fixes the order; DataLoader forbids combining
            # it with shuffle, so only shuffle when there is none.
            shuffle=train_sampler is None,
            sampler=train_sampler,
            num_workers=num_train_workers,
            pin_memory=True,
            collate_fn=collate_fn,
            worker_init_fn=seed_worker,
            generator=g,
            drop_last=True,
            persistent_workers=num_train_workers > 0,
            prefetch_factor=4 if num_train_workers > 0 else None,
        )

        if val_dataset is None:
            # Wrapping None in a DistributedSampler/DataLoader fails (at once
            # under DDP, on first iteration otherwise), so skip it.
            val_loader = None
        else:
            if is_ddp:
                val_sampler = DistributedSampler(
                    val_dataset,
                    num_replicas=world_size,
                    rank=rank,
                    shuffle=False,
                    drop_last=False,
                )
            else:
                val_sampler = None

            num_val_workers = self.config.DATA.valid.dataloader.num_workers
            val_loader = DataLoader(
                val_dataset,
                batch_size=self.config.DATA.valid.dataloader.batch_size,
                shuffle=False,
                sampler=val_sampler,
                num_workers=num_val_workers,
                pin_memory=True,
                collate_fn=collate_fn,
                worker_init_fn=seed_worker,
                generator=g,
                persistent_workers=num_val_workers > 0,
                prefetch_factor=4 if num_val_workers > 0 else None,
            )

        TrainerClass = self._select_trainer_class(modality)
        self.trainer = TrainerClass(
            train_loader=train_loader,
            val_loader=val_loader,
            test_loader=None,
            model=self.model,
            optimizer=optimizer,
            scheduler=scheduler,
            criterion=criterion,
            class_weights=class_weights,
            class_names=train_dataset.label_map,
            save_dir=self.config.SYSTEM.save_dir,
            model_name=self.config.MODEL.backbone.type,
            max_epochs=self.config.TRAIN.epochs,
            device=self.device,
            top_k=2,
            patience=getattr(self.config.TRAIN, "patience", 0),
            monitor=getattr(self.config.TRAIN, "monitor", "balanced_accuracy"),
            mode=getattr(self.config.TRAIN, "mode", "max"),
            revert_on_lr_reduction=(modality in ("tracking_parquet", "frames_npy")),
            config=self.config,
        )

        self.trainer.train(epoch_start=self.epoch, save_every=self.config.TRAIN.save_every)
        return getattr(self.trainer, "best_checkpoint_path", None)

    def _train_huggingface(self, model, train_dataset, val_dataset):
        """Train a HuggingFace model (e.g. VideoMAE) with an HF Trainer.

        Uses WeightedTrainer when ``TRAIN.use_weighted_sampler`` is set, else
        VideoMAETrainer; the trainer is stored on ``self.trainer``. After
        training, metrics evaluated on the training set are logged.
        """
        from opensportslib.core.sampler.weighted_sampler import WeightedTrainer, VideoMAETrainer

        self.model = model
        has_val = bool(val_dataset)

        args = TrainingArguments(
            label_names=["labels"],
            output_dir=self.config.SYSTEM.save_dir,
            per_device_train_batch_size=self.config.DATA.train.dataloader.batch_size,
            per_device_eval_batch_size=self.config.DATA.valid.dataloader.batch_size,
            num_train_epochs=self.config.TRAIN.epochs,
            eval_strategy="epoch" if has_val else "no",
            save_strategy="epoch",
            logging_strategy="steps",
            logging_steps=5,
            save_total_limit=10,
            # HF requires the eval and save strategies to match when this is
            # enabled, so it is only possible with a validation set.
            load_best_model_at_end=has_val,
            fp16=True,
            warmup_ratio=0.1,
        )

        trainer_cls = WeightedTrainer if self.config.TRAIN.use_weighted_sampler else VideoMAETrainer
        self.trainer = trainer_cls(
            model=self.model,
            args=args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            compute_metrics=self.compute_metrics,
            config=self.config,
        )

        self.trainer.train()

        # Evaluate with the trainer built above (self.hf_trainer is only set
        # by infer(), so it does not exist at this point).
        train_metrics = self.trainer.evaluate(train_dataset, metric_key_prefix="train")
        logging.info(f"TRAIN METRICS: {train_metrics}")

    def infer(self, test_dataset, rank=0, world_size=1):
        """Run the current model on ``test_dataset`` and return its metrics.

        Args:
            test_dataset: test ClassificationDataset.
            rank: GPU rank (0-indexed).
            world_size: total number of GPUs; values > 1 enable DDP.

        Returns:
            the metrics dictionary: from ``compute_metrics`` for HuggingFace
            models, otherwise the metrics returned by the modality trainer's
            ``test()``.
        """
        if self.config.MODEL.type == "huggingface":
            args = TrainingArguments(
                output_dir=self.config.SYSTEM.save_dir,  # any directory, not used here
                per_device_eval_batch_size=1,
            )

            self.hf_trainer = HFTrainer(
                model=self.model,
                args=args,
                compute_metrics=self.compute_metrics,
            )

            preds_output = self.hf_trainer.predict(test_dataset)
            metrics = self.compute_metrics((preds_output.predictions, preds_output.label_ids))

        else:
            from opensportslib.core.loss.builder import build_criterion
            from opensportslib.core.optimizer.builder import build_optimizer
            from opensportslib.core.scheduler.builder import build_scheduler
            from opensportslib.core.utils.data import tracking_collate_fn
            from torch.nn.parallel import DistributedDataParallel as DDP
            from torch.utils.data.distributed import DistributedSampler

            is_ddp = world_size > 1
            self._setup_device(rank, is_ddp)

            self.model = self.model.to(self.device)
            if is_ddp:
                # train() may already have wrapped the model; avoid DDP(DDP(...)).
                if not isinstance(self.model, DDP):
                    self.model = DDP(self.model, device_ids=[rank])
                # Sequential split so evaluation order is deterministic,
                # matching the validation sampler used in train().
                test_sampler = DistributedSampler(
                    test_dataset, rank=rank, num_replicas=world_size, shuffle=False
                )
            else:
                test_sampler = None

            modality = getattr(self.config.DATA, "data_modality", "video")
            collate_fn = tracking_collate_fn if modality == "tracking_parquet" else None

            test_loader = DataLoader(
                test_dataset,
                batch_size=self.config.DATA.test.dataloader.batch_size,
                shuffle=False,
                sampler=test_sampler,
                num_workers=self.config.DATA.test.dataloader.num_workers,
                pin_memory=True,
                collate_fn=collate_fn,
            )

            # Reuse optimizer/scheduler restored by load() when available;
            # the modality trainer constructor requires them.
            optimizer = (
                self.optimizer
                if self.optimizer is not None
                else build_optimizer(self.model.parameters(), cfg=self.config.TRAIN.optimizer)
            )
            scheduler = (
                self.scheduler
                if self.scheduler is not None
                else build_scheduler(optimizer, cfg=self.config.TRAIN.scheduler)
            )
            criterion = build_criterion(self.config.TRAIN.criterion)

            TrainerClass = self._select_trainer_class(modality)
            self.test_trainer = TrainerClass(
                train_loader=None,
                val_loader=None,
                test_loader=test_loader,
                model=self.model,
                optimizer=optimizer,
                scheduler=scheduler,
                criterion=criterion,
                class_weights=None,
                class_names=test_dataset.label_map,
                save_dir=self.config.SYSTEM.save_dir,
                model_name=self.config.MODEL.backbone.type,
                max_epochs=self.config.TRAIN.epochs,
                device=self.device,
                top_k=2,
                monitor=getattr(self.config.TRAIN, "monitor", "balanced_accuracy"),
                mode=getattr(self.config.TRAIN, "mode", "max"),
                revert_on_lr_reduction=(modality in ("tracking_parquet", "frames_npy")),
                config=self.config,
            )
            _loss, metrics = self.test_trainer.test(
                detailed_results=getattr(self.config.TRAIN, "detailed_results", False)
            )

        return metrics

    def evaluate(self, pred_path, gt_path, class_names, exclude_labels=None):
        """Score a prediction JSON file against a ground-truth JSON file.

        Args:
            pred_path: path to the predictions JSON.
            gt_path: path to the ground-truth JSON.
            class_names: mapping class index -> label name.
            exclude_labels: ground-truth label names to ignore (default: none).

        Returns:
            the metrics dictionary from ``compute_metrics(..., mode="labels")``.
        """
        with open(pred_path, encoding="utf-8") as f:
            pred_data = json.load(f)

        with open(gt_path, encoding="utf-8") as f:
            gt_data = json.load(f)

        preds, labels = self._collect_eval_pairs(pred_data, gt_data, class_names, exclude_labels)
        return self.compute_metrics((preds, labels), mode="labels")

    def demo(self, model, video_paths):
        """Not implemented: currently does nothing and returns None."""
        pass

    def save(self, model, path, processor=None, tokenizer=None, optimizer=None, epoch=None):
        """Save a model checkpoint with ``save_checkpoint`` and log the path."""
        save_checkpoint(model, path, processor, tokenizer, optimizer, epoch)
        logging.info(f"Model saved at {path}")

    def load(self, path, optimizer=None, scheduler=None):
        """Load a model checkpoint into ``self.model``.

        Returns:
            for HuggingFace models ``(model, processor, scheduler, None)``
            where ``scheduler`` is the argument passed in; otherwise
            ``(model, optimizer, scheduler, epoch)`` as returned by
            ``load_checkpoint``, which are also stored on ``self``.
        """
        if self.config.MODEL.type == "huggingface":
            epoch = None
            self.model, processor = load_huggingface_checkpoint(self.config, path=path, device=self.device)
            logging.info(f"Model loaded from {path}")
            return self.model, processor, scheduler, epoch

        from opensportslib.models.builder import build_model

        # Build an empty model to load the weights into if none exists yet.
        if self.model is None:
            self.model, _ = build_model(self.config, self.device)
        self.model, optimizer, scheduler, epoch = load_checkpoint(
            self.model, path, optimizer, scheduler, device=self.device
        )
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.epoch = epoch
        logging.info(f"Model loaded from {path}, epoch: {epoch}")
        return self.model, self.optimizer, self.scheduler, self.epoch