opensportslib 0.0.1.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73) hide show
  1. opensportslib/__init__.py +18 -0
  2. opensportslib/apis/__init__.py +21 -0
  3. opensportslib/apis/classification.py +361 -0
  4. opensportslib/apis/localization.py +228 -0
  5. opensportslib/config/classification.yaml +104 -0
  6. opensportslib/config/classification_tracking.yaml +103 -0
  7. opensportslib/config/graph_tracking_classification/avgpool.yaml +79 -0
  8. opensportslib/config/graph_tracking_classification/gin.yaml +79 -0
  9. opensportslib/config/graph_tracking_classification/graphconv.yaml +79 -0
  10. opensportslib/config/graph_tracking_classification/graphsage.yaml +79 -0
  11. opensportslib/config/graph_tracking_classification/maxpool.yaml +79 -0
  12. opensportslib/config/graph_tracking_classification/noedges.yaml +79 -0
  13. opensportslib/config/localization.yaml +132 -0
  14. opensportslib/config/sngar_frames.yaml +98 -0
  15. opensportslib/core/__init__.py +0 -0
  16. opensportslib/core/loss/__init__.py +0 -0
  17. opensportslib/core/loss/builder.py +40 -0
  18. opensportslib/core/loss/calf.py +258 -0
  19. opensportslib/core/loss/ce.py +23 -0
  20. opensportslib/core/loss/combine.py +42 -0
  21. opensportslib/core/loss/nll.py +25 -0
  22. opensportslib/core/optimizer/__init__.py +0 -0
  23. opensportslib/core/optimizer/builder.py +38 -0
  24. opensportslib/core/sampler/weighted_sampler.py +104 -0
  25. opensportslib/core/scheduler/__init__.py +0 -0
  26. opensportslib/core/scheduler/builder.py +77 -0
  27. opensportslib/core/trainer/__init__.py +0 -0
  28. opensportslib/core/trainer/classification_trainer.py +1131 -0
  29. opensportslib/core/trainer/localization_trainer.py +1009 -0
  30. opensportslib/core/utils/checkpoint.py +238 -0
  31. opensportslib/core/utils/config.py +199 -0
  32. opensportslib/core/utils/data.py +85 -0
  33. opensportslib/core/utils/ddp.py +77 -0
  34. opensportslib/core/utils/default_args.py +110 -0
  35. opensportslib/core/utils/load_annotations.py +485 -0
  36. opensportslib/core/utils/seed.py +26 -0
  37. opensportslib/core/utils/video_processing.py +389 -0
  38. opensportslib/core/utils/wandb.py +110 -0
  39. opensportslib/datasets/__init__.py +0 -0
  40. opensportslib/datasets/builder.py +42 -0
  41. opensportslib/datasets/classification_dataset.py +582 -0
  42. opensportslib/datasets/localization_dataset.py +813 -0
  43. opensportslib/datasets/utils/__init__.py +15 -0
  44. opensportslib/datasets/utils/tracking.py +615 -0
  45. opensportslib/metrics/classification_metric.py +176 -0
  46. opensportslib/metrics/localization_metric.py +1482 -0
  47. opensportslib/models/__init__.py +0 -0
  48. opensportslib/models/backbones/builder.py +590 -0
  49. opensportslib/models/base/e2e.py +252 -0
  50. opensportslib/models/base/tracking.py +73 -0
  51. opensportslib/models/base/vars.py +29 -0
  52. opensportslib/models/base/video.py +130 -0
  53. opensportslib/models/base/video_mae.py +60 -0
  54. opensportslib/models/builder.py +43 -0
  55. opensportslib/models/heads/builder.py +266 -0
  56. opensportslib/models/neck/builder.py +210 -0
  57. opensportslib/models/utils/common.py +176 -0
  58. opensportslib/models/utils/impl/__init__.py +0 -0
  59. opensportslib/models/utils/impl/asformer.py +390 -0
  60. opensportslib/models/utils/impl/calf.py +74 -0
  61. opensportslib/models/utils/impl/gsm.py +112 -0
  62. opensportslib/models/utils/impl/gtad.py +347 -0
  63. opensportslib/models/utils/impl/tsm.py +123 -0
  64. opensportslib/models/utils/litebase.py +59 -0
  65. opensportslib/models/utils/modules.py +120 -0
  66. opensportslib/models/utils/shift.py +135 -0
  67. opensportslib/models/utils/utils.py +276 -0
  68. opensportslib-0.0.1.dev2.dist-info/METADATA +566 -0
  69. opensportslib-0.0.1.dev2.dist-info/RECORD +73 -0
  70. opensportslib-0.0.1.dev2.dist-info/WHEEL +5 -0
  71. opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE +661 -0
  72. opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE-COMMERCIAL +5 -0
  73. opensportslib-0.0.1.dev2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1131 @@
1
+ # opensportslib/core/trainer/classification_trainer.py
2
+
3
+ """classification trainers for video and tracking modalities.
4
+
5
+ provides a base trainer with modality-agnostic training, validation,
6
+ and test loops, plus three modality-specific subclasses that implement
7
+ the forward pass. Trainer_Classification is the top-level dispatcher
8
+ consumed by the API layer.
9
+
10
+ """
11
+
12
+ import os
13
+ import gc
14
+ import json
15
+ import time
16
+ import logging
17
+
18
+ import torch
19
+ import tqdm
20
+ import wandb
21
+ import numpy as np
22
+
23
+ from torch.utils.data import (
24
+ DataLoader,
25
+ WeightedRandomSampler,
26
+ )
27
+
28
+ from transformers import Trainer as HFTrainer, TrainingArguments
29
+ from opensportslib.core.utils.ddp import DistributedWeightedSampler
30
+
31
+ from opensportslib.core.utils.wandb import log_confusion_matrix_wandb
32
+ from opensportslib.core.utils.checkpoint import *
33
+
34
+ from opensportslib.core.utils.config import select_device
35
+ from opensportslib.core.utils.data import mixup_data
36
+ import torch.distributed as dist
37
+ from datetime import datetime
38
+ from opensportslib.core.utils.seed import seed_worker
39
+ from opensportslib.metrics.classification_metric import (
40
+ compute_classification_metrics,
41
+ process_preds_labels
42
+ )
43
+
44
+ # -------------------------------------------------------------------
45
+ # base classification trainer
46
+ # -------------------------------------------------------------------
47
+
48
class BaseTrainerClassification:
    """modality-agnostic training loop for classification.

    handles epoch iteration, gradient updates, DDP gather, metric
    computation, W&B logging, checkpoint saving, and JSON prediction
    export. subclasses only need to override _forward_batch() with
    modality-specific tensor preparation.

    Args:
        train_loader: DataLoader for training set.
        val_loader: DataLoader for validation set.
        test_loader: DataLoader for test set (may be None during training).
        model: the classification model; it is used as given and is not
            moved to a device here.
        optimizer: PyTorch optimizer.
        scheduler: learning-rate scheduler.
        criterion: loss callable, invoked with the keywords output=,
            labels= and, when class weights are set, weight=.
        class_weights: optional per-class weight tensor for the loss.
        class_names: dict mapping class indices to names.
        save_dir: root directory for checkpoint and prediction output.
        model_name: stored as self.model_name.
        max_epochs: maximum number of training epochs.
        device: torch.device or device string.
        top_k: k value for top-k accuracy computation.
        patience: stored as self.patience; the loops in this class do
            not read it.
        monitor: metric name to monitor for checkpointing ("loss" selects
            the validation loss).
        mode: "max" or "min" depending on the monitored metric.
        revert_on_lr_reduction: if True, reload the best weights seen so
            far whenever a scheduler step lowers the learning rate.
        config: optional configuration object, stored as self.config.
    """

    def __init__(
        self,
        train_loader,
        val_loader,
        test_loader,
        model,
        optimizer,
        scheduler,
        criterion,
        class_weights,
        class_names,
        save_dir,
        model_name,
        max_epochs=1000,
        device="cuda",
        top_k=2,
        patience=10,
        monitor="balanced_accuracy",
        mode="max",
        revert_on_lr_reduction=False,
        config=None,
    ):
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader

        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.criterion = criterion
        self.class_weights = class_weights
        self.class_names = class_names

        self.model_name = model_name
        self.max_epochs = max_epochs
        self.device = device
        self.top_k = top_k
        self.patience = patience

        self.monitor = monitor
        self.mode = mode
        self.config = config

        self.best_checkpoint_path = None
        self.best_metric = None
        self.revert_on_lr_reduction = revert_on_lr_reduction
        self._best_model_state = None

        self.rank = dist.get_rank() if dist.is_initialized() else 0

        self.save_dir = save_dir
        os.makedirs(self.save_dir, exist_ok=True)

        # gradient watching is best effort: a failure must not prevent
        # training, but it is logged instead of being silently dropped.
        if self.rank == 0:
            try:
                wandb.watch(self.model, log="gradients", log_freq=100)
            except Exception as exc:
                logging.debug(f"wandb.watch skipped: {exc}")

    # -- abstract forward pass --------------------------------------

    def _forward_batch(self, batch):
        """run the modality-specific forward pass.

        must be overridden by every subclass.

        Args:
            batch: a dict produced by the DataLoader.

        Returns:
            a tuple (logits, labels); labels may be None when the batch
            carries no labels.
        """
        raise NotImplementedError

    # -- process batch ----------------------------------------------

    def _process_batch(self, batch, train):
        """run forward pass, compute loss, and optionally update weights.

        the default implementation calls _forward_batch() for the
        modality-specific forward pass, then computes the loss and
        runs the backward step. subclasses may override this entirely
        to inject AMP, mixup, or other training-time modifications
        without touching the base training loop.

        Args:
            batch: a dict produced by the DataLoader.
            train: if True, compute gradients and update weights.

        Returns:
            a tuple (logits, labels, loss, has_labels). loss is None and
            has_labels is False when the batch has no labels.
        """
        has_labels = "labels" in batch or "label" in batch
        loss = None

        with torch.set_grad_enabled(train):
            logits, labels = self._forward_batch(batch)
            if labels is None:
                has_labels = False

            if has_labels:
                if self.class_weights is not None:
                    loss = self.criterion(
                        output=logits, labels=labels,
                        weight=self.class_weights.to(self.device)
                    )
                else:
                    loss = self.criterion(output=logits, labels=labels)

        if train and loss is not None:
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer.step()

        return logits, labels, loss, has_labels

    # -- DDP helper -------------------------------------------------

    def _broadcast_from_rank0(self, payload):
        """return rank 0's picklable payload on every rank (no-op without DDP)."""
        if not dist.is_initialized():
            return payload
        box = [payload if self.rank == 0 else None]
        dist.broadcast_object_list(box, src=0)
        return box[0]

    # -- training loop ----------------------------------------------

    def train(self, epoch_start=0, save_every=3):
        """run the full training loop with validation after each epoch.

        Args:
            epoch_start: the epoch number to start from (0-based).
            save_every: currently unused; reserved for periodic
                checkpoint saving.
        """
        logging.info("Starting training")
        monitor = self.monitor
        mode = self.mode
        best_metric = -float("inf") if mode == "max" else float("inf")
        is_plateau = isinstance(
            self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau
        )
        disable = self.rank != 0

        for epoch in range(epoch_start, self.max_epochs):
            logging.info(f"\nEpoch {epoch+1}/{self.max_epochs}")

            # --- train ---
            if hasattr(self.train_loader.sampler, "set_epoch"):
                self.train_loader.sampler.set_epoch(epoch)

            pbar = tqdm.tqdm(
                total=len(self.train_loader), desc="Training",
                position=0, leave=True, disable=disable
            )
            _, _, train_loss, train_metrics = self._run_epoch(
                self.train_loader, epoch + 1,
                train=True, set_name="train", pbar=pbar
            )
            pbar.close()

            # --- validation ---
            pbar = tqdm.tqdm(
                total=len(self.val_loader), desc="Valid",
                position=1, leave=True, disable=disable
            )
            _, _, val_loss, val_metrics = self._run_epoch(
                self.val_loader, epoch + 1,
                train=False, set_name="valid", pbar=pbar
            )
            pbar.close()

            # _run_epoch returns placeholder values (0.0, {}) on non-zero
            # ranks. share rank 0's results so that the scheduler step,
            # early stopping and best-model tracking take the same
            # decision on every rank.
            val_loss, val_metrics = self._broadcast_from_rank0(
                (val_loss, val_metrics)
            )

            val_metric = val_metrics.get(
                "balanced_accuracy", val_metrics.get("accuracy", 0)
            )
            train_metric = train_metrics.get(
                "balanced_accuracy", train_metrics.get("accuracy", 0)
            )

            # capture LR before the scheduler step so we can detect
            # plateau-triggered reductions.
            prev_lr = self.optimizer.param_groups[0]["lr"]

            # ReduceLROnPlateau needs the monitored value
            if is_plateau:
                self.scheduler.step(val_loss)
            else:
                self.scheduler.step()

            current_lr = self.optimizer.param_groups[0]["lr"]

            # early stopping: mirror pixels_vs_positions behavior
            if is_plateau:
                min_lr = self.scheduler.min_lrs[0]
                if current_lr <= 2 * min_lr:
                    if self.rank == 0:
                        logging.info(
                            f"Early stopping at epoch {epoch+1}: "
                            f"lr {current_lr:.2e} <= 2 * min_lr {min_lr:.2e}"
                        )
                    break

            # When the LR drops, revert weights to the best snapshot so
            # training continues from the strongest point rather than
            # from a potentially overfit state. Only a decrease triggers
            # the revert; an LR increase (e.g. warm-up) does not.
            if (
                self.revert_on_lr_reduction
                and current_lr < prev_lr
                and self._best_model_state is not None
            ):
                self.model.load_state_dict(self._best_model_state)
                if self.rank == 0:
                    logging.info(
                        f"LR reduced from {prev_lr:.2e} to {current_lr:.2e} "
                        f"-- reverted to best model"
                    )

            if self.rank == 0:
                # ---------------- W&B LOG ----------------
                wandb.log({
                    "epoch": epoch + 1,
                    "lr": current_lr,
                    "train/loss": train_loss,
                    "valid/loss": val_loss,
                    **{f"train/{k}": v for k, v in train_metrics.items()},
                    **{f"valid/{k}": v for k, v in val_metrics.items()},
                })

                logging.info(f"Train Loss: {train_loss:.4f} | Train Bal Acc: {train_metric:.4f}")
                logging.info(f"Val Loss: {val_loss:.4f} | Val Bal Acc: {val_metric:.4f}")

            # ---------------- CHECKPOINT ----------------
            current = val_loss if monitor == "loss" else val_metrics.get(monitor, 0)
            is_better = current > best_metric if mode == "max" else current < best_metric

            if is_better:
                best_metric = current
                self.best_metric = best_metric

                # every rank keeps its own snapshot so that a later
                # revert restores identical weights everywhere.
                if self.revert_on_lr_reduction:
                    self._best_model_state = {
                        k: v.cpu().clone()
                        for k, v in self.model.state_dict().items()
                    }

                # only rank 0 writes to disk and uploads the artifact.
                if self.rank == 0:
                    best_path = self._save_checkpoint("best", epoch + 1, tag="best")
                    self.best_checkpoint_path = best_path

                    artifact = wandb.Artifact("model-checkpoint", type="model")
                    artifact.add_file(best_path)
                    wandb.log_artifact(artifact)

        if self.rank == 0:
            logging.info(f"Best checkpoint : {self.best_checkpoint_path}")
        logging.info("Training finished.")

    # -- TEST evaluation ------------------------------------------

    def test(self, epoch=None, detailed_results=False):
        """run the test set evaluation.

        Args:
            epoch: the epoch number to evaluate (if None, uses "final").
            detailed_results: whether to compute detailed classification metrics.

        Returns:
            a tuple (test_loss, test_metrics).

        Raises:
            ValueError: if self.test_loader has not been set.
        """
        if self.test_loader is None:
            raise ValueError(
                "test_loader is not set; assign a DataLoader before calling test()."
            )

        logging.info("\nRunning TEST evaluation")
        pbar = tqdm.tqdm(
            total=len(self.test_loader), desc="Test", position=0,
            leave=True, disable=self.rank != 0
        )
        all_logits, all_labels, test_loss, test_metrics = self._run_epoch(
            self.test_loader,
            epoch if epoch is not None else "final",
            train=False, set_name="test", pbar=pbar
        )
        pbar.close()

        # logits/labels are only available on rank 0, so everything that
        # consumes them stays inside this guard.
        if self.rank == 0:
            wandb.log({
                "test/loss": test_loss,
                **{f"test/{k}": v for k, v in test_metrics.items()},
            })

            if detailed_results:
                from opensportslib.metrics.classification_metric import (
                    compute_detailed_classification_metrics
                )
                compute_detailed_classification_metrics(
                    all_logits=all_logits, all_labels=all_labels,
                    class_names=self.class_names, save_dir=self.save_dir,
                    set_name="test"
                )

        logging.info(f"TEST METRICS : {test_metrics}")
        return test_loss, test_metrics

    # -- single epoch logic -------------------------------------

    def _run_epoch(self, dataloader, epoch, train=False, set_name="train", pbar=None):
        """execute one pass over a dataloader.

        handles forward/backward, per-batch bookkeeping, DDP gather, metric
        computation, confusion-matrix logging, and JSON prediction export.

        Args:
            dataloader: the DataLoader to iterate over.
            epoch: the epoch number or label (used for folder and file naming).
            train: if True, compute gradients and update weights.
            set_name: "train", "valid", or "test" (for logging and JSON).
            pbar: optional tqdm progress bar.

        Returns:
            a tuple (all_logits, all_labels, avg_loss, metrics).
            avg_loss is the mean over the batches seen by this process.
            on non-rank-0 DDP workers the result is (None, None, 0.0, {}).
        """
        if train:
            self.model.train()
        else:
            self.model.eval()

        total_loss = 0.0
        total_batches = 0

        all_logits = []
        all_labels = []
        results = []

        # -------- Create epoch folder --------
        epoch_dir = os.path.join(self.save_dir, str(epoch))
        os.makedirs(epoch_dir, exist_ok=True)
        save_path = os.path.join(
            epoch_dir, f"predictions_{set_name}_epoch_{epoch}.json"
        )

        # --- batch loop ---
        for batch in dataloader:
            if pbar is not None:
                pbar.update()

            logits, labels, loss, has_labels = self._process_batch(batch, train)

            if loss is not None:
                total_loss += loss.item()
                total_batches += 1

            logits_cpu = logits.detach().cpu()
            all_logits.append(logits_cpu)

            if has_labels:
                all_labels.append(labels.detach().cpu())

            # per-sample predictions for JSON export.
            probs = torch.softmax(logits_cpu, dim=1)
            confs, preds = probs.max(dim=1)

            for sample_id, pred, conf in zip(batch["id"], preds.tolist(), confs.tolist()):
                # a single-element tensor id is turned into a plain Python
                # scalar so json.dump can serialise it.
                if torch.is_tensor(sample_id) and sample_id.numel() == 1:
                    sample_id = sample_id.item()
                results.append({
                    "id": sample_id,
                    "pred_label": self.class_names[pred],
                    "confidence": conf,
                    "pred_class_idx": pred,
                })

        # --- concatenate local predictions ---
        if len(all_logits) > 0:
            all_logits = torch.cat(all_logits).numpy()
        else:
            all_logits = np.zeros((0, 1))

        if len(all_labels) > 0:
            all_labels = torch.cat(all_labels).numpy()
        else:
            all_labels = np.zeros((0,))

        # --- DDP gather (handles uneven shard sizes) ---
        if dist.is_initialized():
            gathered = [None for _ in range(dist.get_world_size())]
            dist.all_gather_object(gathered, (all_logits, all_labels, results))

            if self.rank == 0:
                all_logits = np.concatenate([g[0] for g in gathered])
                all_labels = np.concatenate([g[1] for g in gathered])
                results = [r for g in gathered for r in g[2]]
            else:
                return None, None, 0.0, {}

        # --- metrics (rank-0 only in DDP) ---
        if len(all_labels) > 0:
            metrics = compute_classification_metrics(
                (all_logits, all_labels), top_k=self.top_k,
            )
        else:
            metrics = {}

        # --- confusion matrix (validation and test only) ---
        if self.rank == 0 and set_name in ["valid", "test"] and len(all_labels) > 0:
            preds_all, labels_all, _ = process_preds_labels(
                (all_logits, all_labels)
            )
            class_names = [
                self.class_names[i] for i in sorted(self.class_names.keys())
            ]

            log_confusion_matrix_wandb(
                y_true=labels_all.tolist(),
                y_pred=preds_all.tolist(),
                class_names=class_names,
                split_name=set_name,
            )

        # --- save JSON (rank-0 only) ---
        if self.rank == 0:
            submission = {
                "version": "2.0",
                "task": "action_classification",
                "date": datetime.now().strftime("%Y-%m-%d"),
                "metadata": {"type": "predictions"},
                "data": [
                    {
                        "id": r["id"],
                        "labels": {
                            "action": {
                                "label": r["pred_label"],
                                "confidence": r["confidence"],
                            }
                        },
                    }
                    for r in results
                ],
            }

            logging.info(f"RESULTS Length: {len(results)}")
            logging.info(f"Predicitions are stored at : {save_path}")
            with open(save_path, "w") as f:
                json.dump(submission, f, indent=2)

        return all_logits, all_labels, total_loss / max(1, total_batches), metrics

    # -- checkpoint saving ---------------------------------------

    def _save_checkpoint(self, filename, epoch, tag=None):
        """write model, optimizer and scheduler state to disk.

        Args:
            filename: sub-directory of save_dir that receives the file.
            epoch: epoch number stored in the checkpoint and in its name.
            tag: optional prefix for the file name.

        Returns:
            the path of the written checkpoint file.
        """
        epoch_dir = os.path.join(self.save_dir, str(filename))
        os.makedirs(epoch_dir, exist_ok=True)

        # save the wrapped module's weights when the model exposes
        # .module (e.g. a DDP wrapper), so keys carry no wrapper prefix.
        bare_model = self.model.module if hasattr(self.model, "module") else self.model

        state = {
            "epoch": epoch,
            "state_dict": bare_model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "scheduler": self.scheduler.state_dict(),
            "monitor": self.monitor,
            "mode": self.mode,
            "best_metric": self.best_metric,
        }

        # subclasses that train with AMP define self.scaler.
        if hasattr(self, "scaler"):
            state["scaler"] = self.scaler.state_dict()

        name = f"{tag}_epoch_{epoch}.pt" if tag else f"epoch_{epoch}.pt"

        path_aux = os.path.join(epoch_dir, name)
        torch.save(state, path_aux)
        logging.info(f"Saved checkpoint: {path_aux}")
        return path_aux
548
+
549
+
550
+ # --------------------------------------------------------------
551
+ # modality-specific trainers
552
+ # --------------------------------------------------------------
553
+
554
class MVTrainerClassification(BaseTrainerClassification):
    """multi-view video trainer: supplies only the forward pass.

    per the data contract of this trainer, batches carry pixel_values of
    shape (B, V, C, T, H, W) and integer labels of shape (B,).
    """

    def _forward_batch(self, batch):
        """send the clips through the model and normalise the output.

        Args:
            batch: dict with key "pixel_values" and, optionally, "labels".

        Returns:
            a tuple (logits, labels); labels is None when the batch has
            no "labels" entry. logits always has a leading batch axis.
        """
        clips = batch["pixel_values"].to(self.device).float()

        targets = batch.get("labels", None)
        if targets is not None:
            targets = targets.to(self.device)

        result = self.model(clips)

        # models may return (logits, extras...); keep only the logits.
        logits = result[0] if isinstance(result, tuple) else result

        # a 1-D output is treated as a single sample without batch axis.
        if logits.dim() == 1:
            logits = logits.unsqueeze(0)

        return logits, targets
586
+
587
+
588
+ # ============================================================
589
+ # Tracking Trainer
590
+ # ============================================================
591
+
592
class TrackingTrainerClassification(BaseTrainerClassification):
    """tracking-data trainer: supplies only the forward pass.

    per the data contract of this trainer, batches carry x of shape
    (B, N, 2), edge_index of shape (2, E), batch of shape (B,),
    batch_size, seq_len, and integer labels of shape (B,).
    """

    def _forward_batch(self, batch):
        """build the model input dict on self.device and run the model.

        Args:
            batch: dict with keys "x", "edge_index", "batch", "batch_size",
                "seq_len", and, optionally, "labels".

        Returns:
            a tuple (logits, labels); labels is None when the batch has
            no "labels" entry.
        """
        # tensors are moved to the device; the two size fields are
        # forwarded unchanged.
        graph_inputs = {
            key: batch[key].to(self.device)
            for key in ("x", "edge_index", "batch")
        }
        graph_inputs["batch_size"] = batch["batch_size"]
        graph_inputs["seq_len"] = batch["seq_len"]

        targets = batch.get("labels", None)
        if targets is not None:
            targets = targets.to(self.device)

        return self.model(graph_inputs), targets
623
+
624
class FramesTrainerClassification(BaseTrainerClassification):
    """forward pass for frames_npy video classification.

    supports optional mixed-precision training (AMP) and mixup
    augmentation, controlled via config.TRAIN.use_amp and
    config.TRAIN.mixup_alpha respectively.

    expects batches with pixel_values of shape (B, T, H, W, C)
    and integer labels of shape (B,).
    """

    def __init__(self, *args, **kwargs):
        """forward all arguments to the base trainer, then read AMP/mixup settings."""
        super().__init__(*args, **kwargs)
        cfg = self.config
        self.use_amp = getattr(cfg.TRAIN, "use_amp", False) if cfg else False
        self.mixup_alpha = getattr(cfg.TRAIN, "mixup_alpha", 0.0) if cfg else 0.0
        # with enabled=False the scaler calls used below are pass-throughs,
        # so the same update code serves both AMP and full precision.
        self.scaler = torch.amp.GradScaler("cuda", enabled=self.use_amp)

    def _forward_batch(self, batch):
        """move frames to device and run the model.

        Args:
            batch: dict with key "pixel_values" and, optionally, "labels".

        Returns:
            a tuple (logits, labels); labels is None when the batch has
            no "labels" entry.
        """
        pixel_values = batch["pixel_values"].to(self.device).float()
        labels = batch.get("labels", None)
        if labels is not None:
            labels = labels.to(self.device)
        logits = self.model({"pixel_values": pixel_values})
        return logits, labels

    def _process_batch(self, batch, train):
        """forward pass with optional AMP and mixup, plus the weight update.

        Args:
            batch: dict with key "pixel_values" and, optionally, "labels".
            train: if True, compute gradients and update weights.

        Returns:
            a tuple (logits, labels, loss, has_labels), the same arity the
            base training loop unpacks. when mixup was applied, labels
            is the first of the two mixed label sets. loss is None and
            has_labels is False for unlabeled batches.
        """
        pixel_values = batch["pixel_values"].to(self.device).float()
        labels = batch.get("labels", None)
        has_labels = labels is not None
        if has_labels:
            labels = labels.to(self.device)

        loss = None

        with torch.set_grad_enabled(train):
            # mixup is applied to roughly half of the labeled training
            # batches when mixup_alpha is positive.
            use_mixup = (
                train
                and has_labels
                and self.mixup_alpha > 0
                and np.random.random() > 0.5
            )

            with torch.amp.autocast("cuda", enabled=self.use_amp):
                if use_mixup:
                    pixel_values, labels_a, labels_b, lam = mixup_data(
                        pixel_values, labels, self.mixup_alpha
                    )
                    logits = self.model({"pixel_values": pixel_values})
                    loss = (
                        lam * self.criterion(output=logits, labels=labels_a)
                        + (1 - lam) * self.criterion(output=logits, labels=labels_b)
                    )
                    labels = labels_a
                else:
                    logits = self.model({"pixel_values": pixel_values})
                    if has_labels:
                        if self.class_weights is not None:
                            loss = self.criterion(
                                output=logits, labels=labels,
                                weight=self.class_weights.to(self.device),
                            )
                        else:
                            loss = self.criterion(output=logits, labels=labels)

        if train and loss is not None:
            self.optimizer.zero_grad(set_to_none=True)
            self.scaler.scale(loss).backward()
            # gradients must be unscaled before clipping so that the
            # threshold applies to their true magnitude.
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.scaler.step(self.optimizer)
            self.scaler.update()

        return logits, labels, loss, has_labels
689
+
690
+ # --------------------------------------------------------------
691
+ # unified trainer dispatcher
692
+ # --------------------------------------------------------------
693
+
694
+ class Trainer_Classification:
695
+ """high-level trainer that dispatches to the right modality trainer.
696
+
697
+ consumed by ClassificationAPI. Responsible for building data
698
+ loaders, optimizers, schedulers, and samplers, then delegating the
699
+ actual loop to MVTrainerClassification, TrackingTrainerClassification, or FramesTrainerClassification.
700
+
701
+ Args:
702
+ config: the configuration object.
703
+ """
704
+
705
def __init__(self, config):
    """store the config, pick the device, and reset all lazily built parts.

    Args:
        config: the configuration object; its SYSTEM section is passed
            to select_device().
    """
    self.config = config
    self.device = select_device(config.SYSTEM)

    # populated later, once a model and trainer have been built.
    self.model = None
    self.optimizer = None
    self.scheduler = None
    self.trainer = None

    self.epoch = 0
713
+
714
def compute_metrics(self, pred, mode="logits"):
    """delegate metric computation to compute_classification_metrics.

    Args:
        pred: a tuple (logits, labels).
        mode: "logits" or "labels" (default: "logits").

    Returns:
        a dictionary of classification metrics, computed with top_k=2.
    """
    metrics = compute_classification_metrics(pred, top_k=2, mode=mode)
    return metrics
727
+
728
+ # -- training -----------------------------------------------
729
+
730
+ def train(self, model, train_dataset, val_dataset=None, rank=0, world_size=1):
731
+ """build all training components and run the loop.
732
+
733
+ detects the model type (HuggingFace vs. custom) and the data
734
+ modality (video vs. tracking) to select the right trainer class,
735
+ sampler, and collate function.
736
+
737
+ Args:
738
+ model: the classification model.
739
+ train_dataset: training ClassificationDataset.
740
+ val_dataset: validation ClassificationDataset (optional).
741
+ rank: GPU rank (0-indexed).
742
+ world_size: total number of GPUs.
743
+ """
744
+ from opensportslib.core.loss.builder import build_criterion
745
+ from opensportslib.core.optimizer.builder import build_optimizer
746
+ from opensportslib.core.scheduler.builder import build_scheduler
747
+ from opensportslib.core.utils.data import tracking_collate_fn
748
+ from torch.nn.parallel import DistributedDataParallel as DDP
749
+ from torch.utils.data.distributed import DistributedSampler
750
+
751
+ is_ddp = world_size > 1
752
+ modality = getattr(self.config.DATA, 'data_modality', 'video')
753
+ seed = self.config.SYSTEM.seed
754
+
755
+ g = torch.Generator()
756
+ g.manual_seed(seed)
757
+
758
+ # HuggingFace models (e.g. VideoMAE) use the HF Trainer.
759
+ if self.config.MODEL.type == "huggingface":
760
+ self._train_huggingface(model, train_dataset, val_dataset)
761
+ return
762
+
763
+ if is_ddp:
764
+ torch.cuda.set_device(rank)
765
+ self.device = torch.device(f"cuda:{rank}")
766
+ else:
767
+ self.device = select_device(self.config.SYSTEM)
768
+
769
+ self.model = model.to(self.device)
770
+
771
+ if is_ddp:
772
+ self.model = DDP(self.model, device_ids=[rank])
773
+
774
+ # Build components
775
+ optimizer = build_optimizer(
776
+ self.model.parameters(), cfg=self.config.TRAIN.optimizer
777
+ )
778
+ scheduler = build_scheduler(
779
+ optimizer, cfg=self.config.TRAIN.scheduler
780
+ )
781
+ criterion = build_criterion(self.config.TRAIN.criterion)
782
+
783
+ # --- class weights for the loss ---
784
+ if self.config.TRAIN.use_weighted_loss:
785
+ class_weights = train_dataset.get_class_weights(
786
+ num_classes=train_dataset.num_classes(), sqrt=True
787
+ ).to(self.device)
788
+ else:
789
+ class_weights = None
790
+
791
+ # tracking modality needs a custom collate that merges PyG
792
+ # Data objects into a single batched graph per timestamp.
793
+ collate_fn = tracking_collate_fn if modality == "tracking_parquet" else None
794
+
795
+ # --- train sampler ---
796
+ if self.config.TRAIN.use_weighted_sampler:
797
+ sample_weights = train_dataset.get_sample_weights()
798
+
799
+ samples_per_class = getattr(
800
+ self.config.TRAIN, 'samples_per_class', None
801
+ )
802
+ if samples_per_class:
803
+ num_classes = train_dataset.num_classes()
804
+ num_samples = samples_per_class * num_classes
805
+ else:
806
+ num_samples = len(sample_weights)
807
+
808
+ if is_ddp:
809
+ train_sampler = DistributedWeightedSampler(
810
+ weights=sample_weights,
811
+ num_replicas=world_size,
812
+ rank=rank,
813
+ replacement=True,
814
+ num_samples=num_samples,
815
+ seed=self.config.SYSTEM.seed
816
+ )
817
+ else:
818
+ train_sampler = WeightedRandomSampler(
819
+ weights=sample_weights,
820
+ num_samples=num_samples,
821
+ replacement=True,
822
+ generator=g
823
+ )
824
+
825
+ shuffle = False
826
+
827
+ else:
828
+ if is_ddp:
829
+ train_sampler = DistributedSampler(
830
+ train_dataset,
831
+ num_replicas=world_size,
832
+ rank=rank,
833
+ shuffle=True,
834
+ drop_last=True
835
+ )
836
+ else:
837
+ train_sampler = None
838
+
839
+ shuffle = not is_ddp
840
+
841
+
842
+ # --- validation sampler ---
843
+ if is_ddp:
844
+ val_sampler = DistributedSampler(
845
+ val_dataset,
846
+ num_replicas=world_size,
847
+ rank=rank,
848
+ shuffle=False,
849
+ drop_last=False
850
+ )
851
+ else:
852
+ val_sampler = None
853
+
854
+ num_train_workers = self.config.DATA.train.dataloader.num_workers
855
+ num_val_workers = self.config.DATA.valid.dataloader.num_workers
856
+
857
+ train_loader = DataLoader(
858
+ train_dataset,
859
+ batch_size=self.config.DATA.train.dataloader.batch_size,
860
+ shuffle=(train_sampler is None and shuffle),
861
+ sampler=train_sampler,
862
+ num_workers=num_train_workers,
863
+ pin_memory=True,
864
+ collate_fn=collate_fn,
865
+ worker_init_fn=seed_worker,
866
+ generator=g,
867
+ drop_last=True,
868
+ persistent_workers=num_train_workers > 0,
869
+ prefetch_factor=4 if num_train_workers > 0 else None,
870
+ )
871
+
872
+ val_loader = DataLoader(
873
+ val_dataset,
874
+ batch_size=self.config.DATA.valid.dataloader.batch_size,
875
+ shuffle=False,
876
+ sampler=val_sampler,
877
+ num_workers=num_val_workers,
878
+ pin_memory=True,
879
+ collate_fn=collate_fn,
880
+ worker_init_fn=seed_worker,
881
+ generator=g,
882
+ persistent_workers=num_val_workers > 0,
883
+ prefetch_factor=4 if num_val_workers > 0 else None,
884
+ )
885
+
886
+ # select the modality-specific trainer.
887
+ if modality == "tracking_parquet":
888
+ TrainerClass = TrackingTrainerClassification
889
+ elif modality == "frames_npy":
890
+ TrainerClass = FramesTrainerClassification
891
+ else:
892
+ TrainerClass = MVTrainerClassification
893
+
894
+ self.trainer = TrainerClass(
895
+ train_loader=train_loader,
896
+ val_loader=val_loader,
897
+ test_loader=None,
898
+ model=self.model,
899
+ optimizer=optimizer,
900
+ scheduler=scheduler,
901
+ criterion=criterion,
902
+ class_weights=class_weights,
903
+ class_names=train_dataset.label_map,
904
+ save_dir=self.config.SYSTEM.save_dir,
905
+ model_name=self.config.MODEL.backbone.type,
906
+ max_epochs=self.config.TRAIN.epochs,
907
+ device=self.device,
908
+ top_k=2,
909
+ patience=getattr(self.config.TRAIN, "patience", 0),
910
+ monitor=getattr(self.config.TRAIN, "monitor", "balanced_accuracy"),
911
+ mode=getattr(self.config.TRAIN, "mode", "max"),
912
+ revert_on_lr_reduction=(modality in ("tracking_parquet", "frames_npy")),
913
+ config=self.config,
914
+ )
915
+
916
+ self.trainer.train(epoch_start=self.epoch, save_every=self.config.TRAIN.save_every)
917
+ return getattr(self.trainer, "best_checkpoint_path", None)
918
+
919
def _train_huggingface(self, model, train_dataset, val_dataset):
    """Train a HuggingFace model (VideoMAE) with the HF ``Trainer`` API.

    Builds ``TrainingArguments`` from the config, instantiates
    ``WeightedTrainer`` when ``TRAIN.use_weighted_sampler`` is set
    (``VideoMAETrainer`` otherwise), runs training and finally logs the
    metrics obtained by evaluating on the training set.

    Args:
        model: Model to train; stored on ``self.model``.
        train_dataset: Dataset used for training and for the final
            train-set evaluation.
        val_dataset: Dataset evaluated after every epoch. When falsy,
            per-epoch evaluation and best-model reloading are disabled.
    """
    from opensportslib.core.sampler.weighted_sampler import WeightedTrainer, VideoMAETrainer

    self.model = model

    has_validation = bool(val_dataset)

    args = TrainingArguments(
        label_names=["labels"],
        output_dir=self.config.SYSTEM.save_dir,
        per_device_train_batch_size=self.config.DATA.train.dataloader.batch_size,
        per_device_eval_batch_size=self.config.DATA.valid.dataloader.batch_size,
        num_train_epochs=self.config.TRAIN.epochs,
        eval_strategy="epoch" if has_validation else "no",
        save_strategy="epoch",
        logging_strategy="steps",
        logging_steps=5,
        save_total_limit=10,
        # TrainingArguments rejects load_best_model_at_end unless evaluation
        # runs on the same schedule as saving, so only request it when a
        # validation set is actually evaluated every epoch.
        load_best_model_at_end=has_validation,
        fp16=True,
        warmup_ratio=0.1,
    )

    # Both trainers take the same constructor arguments; only the class differs.
    if self.config.TRAIN.use_weighted_sampler:
        trainer_cls = WeightedTrainer
    else:
        trainer_cls = VideoMAETrainer

    self.trainer = trainer_cls(
        model=self.model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=self.compute_metrics,
        config=self.config,
    )

    self.trainer.train()

    # Evaluate with the trainer that was just fitted. The previous code used
    # ``self.hf_trainer``, which is only created by ``infer`` and therefore
    # does not exist (or is stale) at this point.
    train_metrics = self.trainer.evaluate(train_dataset, metric_key_prefix="train")
    logging.info(f"TRAIN METRICS: {train_metrics}")
965
+
966
def infer(self, test_dataset, rank=0, world_size=1):
    """Run the model on ``test_dataset`` and return the computed metrics.

    For HuggingFace models a fresh ``HFTrainer`` is created and its
    predictions are passed to ``self.compute_metrics``. For all other models
    a modality-specific trainer is built around a test ``DataLoader`` and its
    ``test`` method is called.

    Args:
        test_dataset: Dataset to evaluate. In the non-HuggingFace path it
            must expose ``label_map``.
        rank: Process rank; used as the CUDA device index when
            ``world_size > 1``.
        world_size: Number of processes; values above 1 enable DDP.

    Returns:
        The metrics produced by ``compute_metrics`` (HuggingFace path) or by
        the trainer's ``test`` method (all other models).
    """
    if self.config.MODEL.type == "huggingface":
        args = TrainingArguments(
            output_dir=self.config.SYSTEM.save_dir,  # required by the API, not used here
            per_device_eval_batch_size=1,
        )

        self.hf_trainer = HFTrainer(
            model=self.model,
            args=args,
            compute_metrics=self.compute_metrics,
        )

        preds_output = self.hf_trainer.predict(test_dataset)
        # compute_metrics receives the raw predictions together with the labels.
        logits = preds_output.predictions
        labels = preds_output.label_ids
        metrics = self.compute_metrics((logits, labels))

    else:
        from opensportslib.core.loss.builder import build_criterion
        from opensportslib.core.optimizer.builder import build_optimizer
        from opensportslib.core.scheduler.builder import build_scheduler
        from opensportslib.core.utils.data import tracking_collate_fn
        from torch.nn.parallel import DistributedDataParallel as DDP
        from torch.utils.data.distributed import DistributedSampler

        is_ddp = world_size > 1

        if is_ddp:
            torch.cuda.set_device(rank)
            self.device = torch.device(f"cuda:{rank}")
        else:
            self.device = select_device(self.config.SYSTEM)

        # model
        self.model = self.model.to(self.device)
        if is_ddp:
            self.model = DDP(self.model, device_ids=[rank])
            # DistributedSampler shuffles by default; evaluation must keep the
            # dataset order, exactly like the validation sampler does.
            test_sampler = DistributedSampler(
                test_dataset,
                num_replicas=world_size,
                rank=rank,
                shuffle=False,
                drop_last=False,
            )
        else:
            test_sampler = None

        modality = getattr(self.config.DATA, 'data_modality', 'video')
        # Only the tracking modality uses the custom collate function.
        collate_fn = tracking_collate_fn if modality == "tracking_parquet" else None

        test_loader = DataLoader(
            test_dataset,
            batch_size=self.config.DATA.test.dataloader.batch_size,
            shuffle=False,
            sampler=test_sampler,
            num_workers=self.config.DATA.test.dataloader.num_workers,
            pin_memory=True,
            collate_fn=collate_fn,
        )

        # Reuse optimizer/scheduler restored by ``load`` when available;
        # otherwise build them from the config because the trainer
        # constructor is called with them.
        if self.optimizer is not None:
            optimizer = self.optimizer
        else:
            optimizer = build_optimizer(self.model.parameters(), cfg=self.config.TRAIN.optimizer)
        if self.scheduler is not None:
            scheduler = self.scheduler
        else:
            scheduler = build_scheduler(optimizer, cfg=self.config.TRAIN.scheduler)
        criterion = build_criterion(self.config.TRAIN.criterion)

        # Select trainer class based on modality
        if modality == "tracking_parquet":
            TrainerClass = TrackingTrainerClassification
        elif modality == "frames_npy":
            TrainerClass = FramesTrainerClassification
        else:
            TrainerClass = MVTrainerClassification

        self.test_trainer = TrainerClass(
            train_loader=None,
            val_loader=None,
            test_loader=test_loader,
            model=self.model,
            optimizer=optimizer,
            scheduler=scheduler,
            criterion=criterion,
            class_weights=None,
            class_names=test_dataset.label_map,
            save_dir=self.config.SYSTEM.save_dir,
            model_name=self.config.MODEL.backbone.type,
            max_epochs=self.config.TRAIN.epochs,
            device=self.device,
            top_k=2,
            monitor=getattr(self.config.TRAIN, "monitor", "balanced_accuracy"),
            mode=getattr(self.config.TRAIN, "mode", "max"),
            revert_on_lr_reduction=(modality in ("tracking_parquet", "frames_npy")),
            config=self.config,
        )
        # ``test`` returns (loss, metrics); only the metrics are reported.
        _, metrics = self.test_trainer.test(
            detailed_results=getattr(self.config.TRAIN, 'detailed_results', False)
        )

    return metrics
1063
+
1064
def evaluate(self, pred_path, gt_path, class_names, exclude_labels=None):
    """Score a prediction file against a ground-truth file.

    Both files are JSON documents with a top-level ``"data"`` list whose
    items carry an ``"id"`` and a label at
    ``item["labels"]["action"]["label"]``. Predictions whose id has no
    (non-excluded) ground-truth entry are ignored.

    Args:
        pred_path: Path to the predictions JSON file.
        gt_path: Path to the ground-truth JSON file.
        class_names: Mapping from class index to label name; it is inverted
            to turn label names back into indices.
        exclude_labels: Ground-truth label names to leave out of the
            evaluation. ``None`` (the default) excludes nothing.

    Returns:
        Whatever ``self.compute_metrics((preds, labels), mode="labels")``
        returns.

    Raises:
        KeyError: If a label in either file is missing from ``class_names``.
    """
    # A ``None`` sentinel replaces the former mutable ``[]`` default.
    excluded = () if exclude_labels is None else exclude_labels

    label_to_idx = {name: idx for idx, name in class_names.items()}

    with open(pred_path, encoding="utf-8") as f:
        pred_data = json.load(f)

    with open(gt_path, encoding="utf-8") as f:
        gt_data = json.load(f)

    # sample id -> ground-truth class index, skipping excluded labels
    gt_dict = {}
    for item in gt_data["data"]:
        gt_label = item["labels"]["action"]["label"]
        if gt_label not in excluded:
            gt_dict[item["id"]] = label_to_idx[gt_label]

    preds = []
    labels = []
    for item in pred_data["data"]:
        sid = item["id"]
        if sid not in gt_dict:
            continue
        preds.append(label_to_idx[item["labels"]["action"]["label"]])
        labels.append(gt_dict[sid])

    return self.compute_metrics((preds, labels), mode="labels")
1099
+
1100
+
1101
def demo(self, model, video_paths):
    """Placeholder for a demo entry point; currently does nothing.

    Args:
        model: Unused.
        video_paths: Unused.

    Returns:
        None.
    """
    return None
1103
+
1104
def save(self, model, path, processor=None, tokenizer=None, optimizer=None, epoch=None):
    """Write a checkpoint for ``model`` to ``path`` and log the location.

    All arguments are forwarded, in order, to ``save_checkpoint``.

    Args:
        model: Model to persist.
        path: Destination of the checkpoint.
        processor: Optional processor forwarded to ``save_checkpoint``.
        tokenizer: Optional tokenizer forwarded to ``save_checkpoint``.
        optimizer: Optional optimizer forwarded to ``save_checkpoint``.
        epoch: Optional epoch forwarded to ``save_checkpoint``.
    """
    checkpoint_args = (model, path, processor, tokenizer, optimizer, epoch)
    save_checkpoint(*checkpoint_args)
    logging.info(f"Model saved at {path}")
1110
+
1111
def load(self, path, optimizer=None, scheduler=None):
    """Restore a model checkpoint from ``path``.

    HuggingFace models (``MODEL.type == "huggingface"``) are loaded through
    ``load_huggingface_checkpoint``. Any other model is first built with
    ``build_model`` if ``self.model`` is still ``None`` and then restored
    with ``load_checkpoint``; the restored optimizer, scheduler and epoch
    are also stored on ``self``.

    Args:
        path: Location of the checkpoint.
        optimizer: Optional optimizer handed to ``load_checkpoint``
            (ignored for HuggingFace models).
        scheduler: Optional scheduler handed to ``load_checkpoint``; for
            HuggingFace models it is returned untouched.

    Returns:
        A 4-tuple. For HuggingFace models:
        ``(model, processor, scheduler, None)`` -- the processor occupies
        the position where the other branch returns the optimizer.
        Otherwise: ``(model, optimizer, scheduler, epoch)``.
    """
    if self.config.MODEL.type == "huggingface":
        self.model, processor = load_huggingface_checkpoint(
            self.config, path=path, device=self.device
        )
        logging.info(f"Model loaded from {path}")
        # No epoch is reported for HuggingFace checkpoints.
        return self.model, processor, scheduler, None

    from opensportslib.models.builder import build_model

    # Build the architecture lazily so the weights have something to load into.
    if self.model is None:
        self.model, _ = build_model(self.config, self.device)

    restored = load_checkpoint(
        self.model, path, optimizer, scheduler, device=self.device
    )
    self.model, self.optimizer, self.scheduler, self.epoch = restored
    logging.info(f"Model loaded from {path}, epoch: {self.epoch}")
    return self.model, self.optimizer, self.scheduler, self.epoch