opensportslib 0.1.1__tar.gz → 0.1.1.dev2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87) hide show
  1. {opensportslib-0.1.1/opensportslib.egg-info → opensportslib-0.1.1.dev2}/PKG-INFO +5 -1
  2. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/README.md +1 -0
  3. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/apis/classification.py +8 -5
  4. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/apis/localization.py +5 -5
  5. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/core/trainer/classification_trainer.py +19 -16
  6. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/core/trainer/localization_trainer.py +14 -13
  7. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/core/utils/config.py +17 -2
  8. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/core/utils/data.py +11 -2
  9. opensportslib-0.1.1.dev2/opensportslib/core/utils/wandb.py +280 -0
  10. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/datasets/classification_dataset.py +21 -3
  11. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/models/backbones/builder.py +54 -16
  12. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/models/base/tracking.py +1 -2
  13. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/models/utils/utils.py +8 -1
  14. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2/opensportslib.egg-info}/PKG-INFO +5 -1
  15. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib.egg-info/SOURCES.txt +6 -1
  16. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib.egg-info/requires.txt +4 -0
  17. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib.egg-info/top_level.txt +2 -0
  18. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/pyproject.toml +2 -1
  19. opensportslib-0.1.1.dev2/tests/conftest.py +59 -0
  20. opensportslib-0.1.1.dev2/tests/test_config_utils_smoke.py +46 -0
  21. opensportslib-0.1.1.dev2/tests/test_package_smoke.py +23 -0
  22. opensportslib-0.1.1.dev2/tests/test_public_apis_smoke.py +29 -0
  23. opensportslib-0.1.1.dev2/tests/test_subset_train_infer_integration.py +172 -0
  24. opensportslib-0.1.1/opensportslib/core/utils/wandb.py +0 -120
  25. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/LICENSE +0 -0
  26. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/LICENSE-COMMERCIAL +0 -0
  27. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/MANIFEST.in +0 -0
  28. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/examples/quickstart/basic_classification.py +0 -0
  29. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/examples/quickstart/basic_localization.py +0 -0
  30. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/__init__.py +0 -0
  31. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/apis/__init__.py +0 -0
  32. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/config/classification.yaml +0 -0
  33. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/config/localization-e2e-ocv.yaml +0 -0
  34. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/config/localization-json_calf_resnetpca512.yaml +0 -0
  35. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/config/localization-json_netvlad++_resnetpca512.yaml +0 -0
  36. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/config/localization.yaml +0 -0
  37. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/config/sngar-frames.yaml +0 -0
  38. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/config/sngar-tracking.yaml +0 -0
  39. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/core/__init__.py +0 -0
  40. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/core/loss/__init__.py +0 -0
  41. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/core/loss/builder.py +0 -0
  42. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/core/loss/calf.py +0 -0
  43. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/core/loss/ce.py +0 -0
  44. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/core/loss/combine.py +0 -0
  45. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/core/loss/nll.py +0 -0
  46. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/core/optimizer/__init__.py +0 -0
  47. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/core/optimizer/builder.py +0 -0
  48. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/core/sampler/weighted_sampler.py +0 -0
  49. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/core/scheduler/__init__.py +0 -0
  50. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/core/scheduler/builder.py +0 -0
  51. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/core/trainer/__init__.py +0 -0
  52. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/core/utils/checkpoint.py +0 -0
  53. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/core/utils/ddp.py +0 -0
  54. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/core/utils/default_args.py +0 -0
  55. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/core/utils/lightning.py +0 -0
  56. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/core/utils/load_annotations.py +0 -0
  57. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/core/utils/seed.py +0 -0
  58. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/core/utils/video_processing.py +0 -0
  59. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/datasets/__init__.py +0 -0
  60. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/datasets/builder.py +0 -0
  61. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/datasets/localization_dataset.py +0 -0
  62. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/datasets/utils/__init__.py +0 -0
  63. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/datasets/utils/tracking.py +0 -0
  64. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/metrics/classification_metric.py +0 -0
  65. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/metrics/localization_metric.py +0 -0
  66. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/models/__init__.py +0 -0
  67. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/models/base/contextaware.py +0 -0
  68. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/models/base/e2e.py +0 -0
  69. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/models/base/learnablepooling.py +0 -0
  70. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/models/base/vars.py +0 -0
  71. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/models/base/video.py +0 -0
  72. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/models/base/video_mae.py +0 -0
  73. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/models/builder.py +0 -0
  74. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/models/heads/builder.py +0 -0
  75. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/models/neck/builder.py +0 -0
  76. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/models/utils/common.py +0 -0
  77. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/models/utils/impl/__init__.py +0 -0
  78. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/models/utils/impl/asformer.py +0 -0
  79. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/models/utils/impl/calf.py +0 -0
  80. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/models/utils/impl/gsm.py +0 -0
  81. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/models/utils/impl/gtad.py +0 -0
  82. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/models/utils/impl/tsm.py +0 -0
  83. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/models/utils/litebase.py +0 -0
  84. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/models/utils/modules.py +0 -0
  85. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/models/utils/shift.py +0 -0
  86. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib.egg-info/dependency_links.txt +0 -0
  87. {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: opensportslib
3
- Version: 0.1.1
3
+ Version: 0.1.1.dev2
4
4
  Summary: OpenSportsLib is the professional library, designed for advanced video understanding in sports. It provides state-of-the-art tools for action recognition, spotting, retrieval, and captioning, making it ideal for researchers, analysts, and developers working with sports video data.
5
5
  Author: Jeet Vora
6
6
  Requires-Python: >=3.12
@@ -33,6 +33,9 @@ Requires-Dist: torch-scatter; extra == "py-geometric"
33
33
  Requires-Dist: torch-sparse; extra == "py-geometric"
34
34
  Requires-Dist: torch-cluster; extra == "py-geometric"
35
35
  Requires-Dist: torch-spline-conv; extra == "py-geometric"
36
+ Provides-Extra: test
37
+ Requires-Dist: pytest; extra == "test"
38
+ Requires-Dist: pytest-cov; extra == "test"
36
39
  Dynamic: license-file
37
40
 
38
41
  # OpenSportsLib
@@ -192,6 +195,7 @@ Generate text descriptions for sports events and temporal segments.
192
195
  Use the README for the fast start, then go deeper through:
193
196
 
194
197
  - Full documentation: https://opensportslab.github.io/opensportslib/
198
+ - Configuration guide: https://opensportslab.github.io/opensportslib/tni/config-guide/
195
199
  - Example configs: [examples/configs/](examples/configs/)
196
200
  - Quickstart scripts: [examples/quickstart/](examples/quickstart/)
197
201
  - Contribution guide: [CONTRIBUTING.md](CONTRIBUTING.md)
@@ -155,6 +155,7 @@ Generate text descriptions for sports events and temporal segments.
155
155
  Use the README for the fast start, then go deeper through:
156
156
 
157
157
  - Full documentation: https://opensportslab.github.io/opensportslib/
158
+ - Configuration guide: https://opensportslab.github.io/opensportslib/tni/config-guide/
158
159
  - Example configs: [examples/configs/](examples/configs/)
159
160
  - Quickstart scripts: [examples/quickstart/](examples/quickstart/)
160
161
  - Contribution guide: [CONTRIBUTING.md](CONTRIBUTING.md)
@@ -40,8 +40,8 @@ class ClassificationAPI:
40
40
  if config is None:
41
41
  raise ValueError("config path is required")
42
42
 
43
- config_path = expand(config)
44
- self.config = load_config_omega(config_path)
43
+ self.config_path = expand(config)
44
+ self.config = load_config_omega(self.config_path)
45
45
 
46
46
  # let the caller override the dataset root directory.
47
47
  self.config.DATA.data_dir = expand(
@@ -88,6 +88,7 @@ class ClassificationAPI:
88
88
  rank,
89
89
  world_size,
90
90
  mode,
91
+ config_path,
91
92
  config,
92
93
  return_queue=None,
93
94
  train_set=None,
@@ -134,7 +135,7 @@ class ClassificationAPI:
134
135
  logging.getLogger().setLevel(logging.ERROR)
135
136
 
136
137
  if rank == 0:
137
- init_wandb(config, run_id=os.environ["RUN_ID"], use_wandb=use_wandb)
138
+ init_wandb(config_path, config, run_id=os.environ["RUN_ID"], use_wandb=use_wandb)
138
139
 
139
140
  # reproducibility:
140
141
  # we default to reproducible training, but allow the user to
@@ -242,7 +243,7 @@ class ClassificationAPI:
242
243
  mp.spawn(
243
244
  ClassificationAPI._worker_ddp,
244
245
  args=(
245
- world_size, "train", self.config, queue,
246
+ world_size, "train", self.config_path, self.config, queue,
246
247
  train_set, valid_set, None, pretrained, use_wandb
247
248
  ),
248
249
  nprocs=world_size,
@@ -253,6 +254,7 @@ class ClassificationAPI:
253
254
  rank=0,
254
255
  world_size=1,
255
256
  mode="train",
257
+ config_path=self.config_path,
256
258
  config=self.config,
257
259
  return_queue=queue,
258
260
  train_set=train_set,
@@ -321,7 +323,7 @@ class ClassificationAPI:
321
323
  mp.spawn(
322
324
  ClassificationAPI._worker_ddp,
323
325
  args=(
324
- world_size, "infer", self.config, queue,
326
+ world_size, "infer",self.config_path, self.config, queue,
325
327
  None, None, test_set, pretrained, use_wandb
326
328
  ),
327
329
  nprocs=world_size,
@@ -331,6 +333,7 @@ class ClassificationAPI:
331
333
  rank=0,
332
334
  world_size=1,
333
335
  mode="infer",
336
+ config_path=self.config_path,
334
337
  config=self.config,
335
338
  return_queue=queue,
336
339
  test_set=test_set,
@@ -14,8 +14,8 @@ class LocalizationAPI:
14
14
 
15
15
  # Load config
16
16
  ### load data_dor first then do load config with omega to resolve $paths
17
- config_path = expand(config)
18
- self.config = load_config_omega(config_path)
17
+ self.config_path = expand(config)
18
+ self.config = load_config_omega(self.config_path)
19
19
  # User must control dataset folder
20
20
  self.config.DATA.data_dir = expand(data_dir or self.config.DATA.data_dir)
21
21
  print(self.config.DATA.classes)
@@ -46,7 +46,7 @@ class LocalizationAPI:
46
46
 
47
47
  logger = logging.getLogger(__name__)
48
48
 
49
- print("CONFIG PATH :", config_path)
49
+ print("CONFIG PATH :", self.config_path)
50
50
  print("DATA DIR :", self.config.DATA.data_dir)
51
51
  print("SAVEDIR:", self.config.SYSTEM.save_dir)
52
52
  print("Classes :", self.config.DATA.classes)
@@ -76,7 +76,7 @@ class LocalizationAPI:
76
76
 
77
77
  self.config = resolve_config_omega(self.config)
78
78
  check_config(self.config, split="train")
79
- init_wandb(self.config, run_id=os.environ["RUN_ID"], use_wandb=use_wandb)
79
+ init_wandb(self.config_path, self.config, run_id=os.environ["RUN_ID"], use_wandb=use_wandb)
80
80
  logging.info("Configuration:")
81
81
  logging.info(self.config)
82
82
  #print(self.config)
@@ -164,7 +164,7 @@ class LocalizationAPI:
164
164
  self.config = resolve_config_omega(self.config)
165
165
  check_config(self.config, split="test")
166
166
  self.config.infer_split = whether_infer_split(self.config.DATA.test)
167
- init_wandb(self.config, run_id=os.environ["RUN_ID"], use_wandb=use_wandb)
167
+ init_wandb(self.config_path, self.config, run_id=os.environ["RUN_ID"], use_wandb=use_wandb)
168
168
  logging.info("Configuration:")
169
169
  logging.info(self.config)
170
170
  # Start Timing
@@ -286,14 +286,15 @@ class BaseTrainerClassification:
286
286
 
287
287
  if self.rank == 0:
288
288
  # ---------------- W&B LOG ----------------
289
- wandb.log({
290
- "epoch": epoch + 1,
291
- "lr": current_lr,
292
- "train/loss": train_loss,
293
- "valid/loss": val_loss,
294
- **{f"train/{k}": v for k, v in train_metrics.items()},
295
- **{f"valid/{k}": v for k, v in val_metrics.items()},
296
- })
289
+ if wandb.run is not None:
290
+ wandb.log({
291
+ "epoch": epoch + 1,
292
+ "lr": current_lr,
293
+ "train/loss": train_loss,
294
+ "valid/loss": val_loss,
295
+ **{f"train/{k}": v for k, v in train_metrics.items()},
296
+ **{f"valid/{k}": v for k, v in val_metrics.items()},
297
+ })
297
298
 
298
299
  logging.info(f"Train Loss: {train_loss:.4f} | Train Bal Acc: {train_metric:.4f}")
299
300
  logging.info(f"Val Loss: {val_loss:.4f} | Val Bal Acc: {val_metric:.4f}")
@@ -316,9 +317,10 @@ class BaseTrainerClassification:
316
317
  best_path = self._save_checkpoint("best", epoch + 1, tag="best")
317
318
  self.best_checkpoint_path = best_path
318
319
 
319
- artifact = wandb.Artifact("model-checkpoint", type="model")
320
- artifact.add_file(best_path)
321
- wandb.log_artifact(artifact)
320
+ if wandb.run is not None:
321
+ artifact = wandb.Artifact("model-checkpoint", type="model")
322
+ artifact.add_file(best_path)
323
+ wandb.log_artifact(artifact)
322
324
 
323
325
  if self.rank == 0:
324
326
  logging.info(f"Best checkpoint : {self.best_checkpoint_path}")
@@ -350,10 +352,11 @@ class BaseTrainerClassification:
350
352
  pbar.close()
351
353
 
352
354
  if self.rank==0:
353
- wandb.log({
354
- "test/loss": test_loss,
355
- **{f"test/{k}": v for k, v in test_metrics.items()},
356
- })
355
+ if wandb.run is not None:
356
+ wandb.log({
357
+ "test/loss": test_loss,
358
+ **{f"test/{k}": v for k, v in test_metrics.items()},
359
+ })
357
360
 
358
361
  if detailed_results:
359
362
  from opensportslib.metrics.classification_metric import (
@@ -1128,4 +1131,4 @@ class Trainer_Classification:
1128
1131
  self.scheduler = scheduler
1129
1132
  self.epoch = epoch
1130
1133
  logging.info(f"Model loaded from {path}, epoch: {epoch}")
1131
- return self.model, self.optimizer, self.scheduler, self.epoch
1134
+ return self.model, self.optimizer, self.scheduler, self.epoch
@@ -403,15 +403,16 @@ class Trainer_e2e(Trainer):
403
403
  )
404
404
 
405
405
  # ---------------- W&B LOG ----------------
406
- wandb.log({
407
- "epoch": epoch + 1,
408
- "train/loss": train_loss,
409
- "valid/loss": valid_loss,
410
- "valid/mAP": valid_mAP,
411
- "lr": self.optimizer.param_groups[0]["lr"],
412
- "best/mAP": self.best_criterion_valid if self.criterion_valid == "map" else None,
413
- "best/loss": self.best_criterion_valid if self.criterion_valid == "loss" else None,
414
- })
406
+ if wandb.run is not None:
407
+ wandb.log({
408
+ "epoch": epoch + 1,
409
+ "train/loss": train_loss,
410
+ "valid/loss": valid_loss,
411
+ "valid/mAP": valid_mAP,
412
+ "lr": self.optimizer.param_groups[0]["lr"],
413
+ "best/mAP": self.best_criterion_valid if self.criterion_valid == "map" else None,
414
+ "best/loss": self.best_criterion_valid if self.criterion_valid == "loss" else None,
415
+ })
415
416
 
416
417
  if self.save_dir is not None:
417
418
  os.makedirs(self.save_dir, exist_ok=True)
@@ -630,9 +631,10 @@ class Inferer:
630
631
  cfg.DATA.test.dataloader,
631
632
  return_pred=False,
632
633
  )
633
- wandb.log({
634
- "test/Avg_mAP": mAP,
635
- })
634
+ if wandb.run is not None:
635
+ wandb.log({
636
+ "test/Avg_mAP": mAP,
637
+ })
636
638
  pred_json_file = os.path.join(pred_file + ".json")
637
639
  pred_recall_file = os.path.join(pred_file + ".recall.json.gz")
638
640
  logging.info("Predictions saved")
@@ -1092,4 +1094,3 @@ class Evaluator:
1092
1094
 
1093
1095
  return results
1094
1096
 
1095
-
@@ -26,7 +26,22 @@ def dict_to_namespace(d, skip_keys=("classes",)):
26
26
  return d
27
27
 
28
28
  def namespace_to_dict(ns):
29
- return {k: vars(v) if hasattr(v, "__dict__") else v for k, v in vars(ns).items()}
29
+ """
30
+ Recursively convert namespace/dict/list containers into plain Python types.
31
+ """
32
+ if ns is None or isinstance(ns, (str, int, float, bool)):
33
+ return ns
34
+
35
+ if isinstance(ns, dict):
36
+ return {str(k): namespace_to_dict(v) for k, v in ns.items()}
37
+
38
+ if isinstance(ns, (list, tuple, set)):
39
+ return [namespace_to_dict(v) for v in ns]
40
+
41
+ if hasattr(ns, "__dict__"):
42
+ return {str(k): namespace_to_dict(v) for k, v in vars(ns).items()}
43
+
44
+ return ns
30
45
 
31
46
  def namespace_to_omegaconf(ns):
32
47
  """
@@ -196,4 +211,4 @@ def is_local_path(p):
196
211
  return p and (
197
212
  os.path.exists(p) or
198
213
  p.endswith((".pt", ".pth", ".tar"))
199
- )
214
+ )
@@ -52,7 +52,16 @@ def tracking_collate_fn(batch):
52
52
  Custom collate function for tracking data.
53
53
  Uses PyG Batch.from_data_list for efficient C++ batching.
54
54
  """
55
- from torch_geometric.data import Batch
55
+ try:
56
+ from torch_geometric.data import Batch
57
+ except ImportError as exc:
58
+ raise ImportError(
59
+ "torch-geometric is required for tracking_collate_fn. "
60
+ "Install with: pip install \"opensportslib[py-geometric]\" "
61
+ "-f https://pytorch-geometric.com/whl/torch-2.10.0+cu128.html "
62
+ "or (editable): pip install -e \".[py-geometric]\" "
63
+ "-f https://pytorch-geometric.com/whl/torch-2.10.0+cu128.html"
64
+ ) from exc
56
65
 
57
66
  batch_size = len(batch)
58
67
  seq_len = batch[0]['seq_len']
@@ -82,4 +91,4 @@ def mixup_data(x, y, alpha=0.2):
82
91
  index = torch.randperm(x.size(0)).to(x.device)
83
92
  mixed_x = lam * x + (1 - lam) * x[index]
84
93
  return mixed_x, y, y[index], lam
85
-
94
+
@@ -0,0 +1,280 @@
1
+ import wandb
2
+ import matplotlib.pyplot as plt
3
+ import numpy as np
4
+ import logging
5
+ import os
6
+ from opensportslib.core.utils.config import namespace_to_dict
7
+
8
+ def build_wandb_config(cfg):
9
+ """
10
+ Extract minimal + useful config for W&B dashboard.
11
+ Returns a FLAT dict (ready for wandb.init(config=...))
12
+ """
13
+
14
+ from opensportslib.core.utils.config import namespace_to_dict
15
+
16
+ cfg_dict = namespace_to_dict(cfg)
17
+
18
+ def get(d, path, default=None):
19
+ """Safe nested get using dot notation"""
20
+ keys = path.split(".")
21
+ for k in keys:
22
+ if not isinstance(d, dict) or k not in d:
23
+ return default
24
+ d = d[k]
25
+ return d
26
+
27
+ def pick(paths):
28
+ """Pick only selected keys"""
29
+ out = {}
30
+ for p in paths:
31
+ v = get(cfg_dict, p)
32
+ if v is not None:
33
+ out[p] = v
34
+ return out
35
+
36
+ # -------------------------
37
+ # REQUIRED (core columns)
38
+ # -------------------------
39
+ REQUIRED_KEYS = [
40
+ "TASK",
41
+
42
+ "DATA.dataset_name",
43
+ "DATA.data_modality",
44
+
45
+ "MODEL.type",
46
+ "MODEL.backbone.type",
47
+ "MODEL.neck.type",
48
+ "MODEL.head.type",
49
+
50
+ "TRAIN.optimizer.type",
51
+ "TRAIN.optimizer.lr",
52
+ "TRAIN.scheduler.type",
53
+
54
+ "TRAIN.monitor",
55
+ "TRAIN.mode",
56
+
57
+ "SYSTEM.device",
58
+ ]
59
+
60
+ # -------------------------
61
+ # OPTIONAL (useful knobs)
62
+ # -------------------------
63
+ OPTIONAL_KEYS = [
64
+ # DATA
65
+ "DATA.num_frames",
66
+ "DATA.clip_len",
67
+ "DATA.input_fps",
68
+ "DATA.extract_fps",
69
+ "DATA.frame_size",
70
+ "DATA.view_type",
71
+ "DATA.num_classes",
72
+
73
+ # MODEL
74
+ "MODEL.backbone.encoder",
75
+ "MODEL.backbone.hidden_dim",
76
+ "MODEL.backbone.freeze",
77
+ "MODEL.unfreeze_last_n_layers",
78
+ "MODEL.neck.agr_type",
79
+ "MODEL.edge",
80
+
81
+ # TRAIN
82
+ "TRAIN.epochs",
83
+ "TRAIN.num_epochs",
84
+ "TRAIN.max_epochs",
85
+ "TRAIN.use_amp",
86
+ "TRAIN.use_weighted_loss",
87
+ "TRAIN.use_weighted_sampler",
88
+ "TRAIN.mixup",
89
+ "TRAIN.mixup_alpha",
90
+
91
+ # SYSTEM
92
+ "SYSTEM.GPU",
93
+ "SYSTEM.seed",
94
+ ]
95
+
96
+ config = {}
97
+
98
+ # pick required
99
+ config.update(pick(REQUIRED_KEYS))
100
+
101
+ # pick optional
102
+ config.update(pick(OPTIONAL_KEYS))
103
+
104
+ # -------------------------
105
+ # SPECIAL HANDLING
106
+ # -------------------------
107
+
108
+ # Normalize batch_size (from nested dataloader)
109
+ batch_size = get(cfg_dict, "DATA.train.dataloader.batch_size")
110
+ if batch_size is not None:
111
+ config["TRAIN.batch_size"] = batch_size
112
+
113
+ # Normalize epochs (different configs use different names)
114
+ epochs = (
115
+ get(cfg_dict, "TRAIN.epochs")
116
+ or get(cfg_dict, "TRAIN.num_epochs")
117
+ or get(cfg_dict, "TRAIN.max_epochs")
118
+ )
119
+ if epochs is not None:
120
+ config["TRAIN.total_epochs"] = epochs
121
+
122
+ return config
123
+
124
+ def _flatten_config(data, parent_key="", sep="."):
125
+ """Flatten nested dict/list config for W&B table-friendly columns."""
126
+ items = {}
127
+
128
+ if isinstance(data, dict):
129
+ for k, v in data.items():
130
+ key = f"{parent_key}{sep}{k}" if parent_key else str(k)
131
+ items.update(_flatten_config(v, key, sep=sep))
132
+ return items
133
+
134
+ if isinstance(data, list):
135
+ for i, v in enumerate(data):
136
+ key = f"{parent_key}{sep}{i}" if parent_key else str(i)
137
+ items.update(_flatten_config(v, key, sep=sep))
138
+ return items
139
+
140
+ if parent_key:
141
+ items[parent_key] = data
142
+
143
+ return items
144
+
145
+
146
+ def _wandb_ready():
147
+ return getattr(wandb, "run", None) is not None
148
+
149
+ def init_wandb(cfg_path, cfg, run_id, use_wandb=False):
150
+ """
151
+ Initialize Weights & Biases if enabled.
152
+
153
+ Args:
154
+ cfg_path: Path to the configuration file.
155
+ cfg: config object with attributes:
156
+ - use_wandb (bool)
157
+ - project_name (str)
158
+ - run_name (str)
159
+ """
160
+
161
+ if not use_wandb:
162
+ logging.info("W&B disabled.")
163
+ return None
164
+
165
+ try:
166
+ import wandb
167
+ except ImportError:
168
+ logging.warning("wandb not installed. Install with `pip install wandb`.")
169
+ return None
170
+
171
+ # Prevent multiple processes from initializing wandb
172
+ rank = int(os.environ.get("RANK", os.environ.get("LOCAL_RANK", 0)))
173
+ if rank != 0:
174
+ return None
175
+
176
+ # Prevent re-initialization
177
+ if wandb.run is not None:
178
+ return wandb
179
+
180
+ if getattr(cfg.DATA, "data_modality", None):
181
+ run_name = f"{cfg.MODEL.backbone.type}_{cfg.DATA.data_modality}"
182
+ else:
183
+ run_name = f"{cfg.MODEL.backbone.type}"
184
+
185
+ config_flat = build_wandb_config(cfg)
186
+
187
+ wandb.init(
188
+ project=cfg.TASK,
189
+ name=run_name,
190
+ id=run_id,
191
+ resume="allow",
192
+ config=config_flat,
193
+ )
194
+
195
+ artifact = wandb.Artifact(
196
+ name=f"{cfg.TASK}-config",
197
+ type="config",
198
+ description="configuration (YAML)"
199
+ )
200
+
201
+ artifact.add_file(cfg_path)
202
+ wandb.log_artifact(artifact)
203
+
204
+ logging.info(f"Wandb initialised")
205
+ return wandb
206
+
207
+ def log_table_wandb(name, rows, headers):
208
+ """
209
+ Log a table to Weights & Biases.
210
+
211
+ Args:
212
+ name (str): Name of the table in wandb.
213
+ rows (list[list]): Table rows.
214
+ headers (list[str]): Column headers.
215
+ """
216
+ if not _wandb_ready():
217
+ return
218
+
219
+ table = wandb.Table(columns=headers)
220
+
221
+ for row in rows:
222
+ table.add_data(*row)
223
+
224
+ wandb.log({name: table})
225
+
226
+ def log_attention_wandb(attention, split_name):
227
+ if not _wandb_ready():
228
+ return
229
+
230
+ attn = attention.detach().cpu().numpy()
231
+
232
+ fig, ax = plt.subplots(figsize=(6, 3))
233
+ ax.imshow(attn, aspect="auto", cmap="viridis")
234
+ ax.set_title(f"{split_name} Attention Map")
235
+ ax.set_xlabel("Views / Time")
236
+ ax.set_ylabel("Batch")
237
+
238
+ wandb.log({
239
+ f"{split_name}/attention_map": wandb.Image(fig)
240
+ })
241
+
242
+ plt.close(fig)
243
+
244
+
245
+ def log_sample_videos_wandb(mvclips, preds, labels, split_name, max_samples=2, fps=5):
246
+ if not _wandb_ready():
247
+ return
248
+
249
+
250
+ # mvclips: (B, V, C, T, H, W)
251
+ mvclips = mvclips.detach().cpu().numpy()
252
+
253
+ for i in range(min(len(mvclips), max_samples)):
254
+ views = mvclips[i] # (V, C, T, H, W)
255
+
256
+ # Log each view separately
257
+ for v in range(views.shape[0]):
258
+ video = views[v].transpose(1, 2, 3, 0) # (T, H, W, C)
259
+ video = (video * 255).astype(np.uint8) if video.max() <= 1.0 else video
260
+
261
+ wandb.log({
262
+ f"{split_name}/sample_{i}_view_{v}": wandb.Video(
263
+ video,
264
+ fps=fps,
265
+ caption=f"Pred: {preds[i]}, GT: {labels[i]}"
266
+ )
267
+ })
268
+
269
+
270
+ def log_confusion_matrix_wandb(y_true, y_pred, class_names, split_name):
271
+ if not _wandb_ready():
272
+ return
273
+ wandb.log({
274
+ f"{split_name}/confusion_matrix": wandb.plot.confusion_matrix(
275
+ probs=None,
276
+ y_true=y_true,
277
+ preds=y_pred,
278
+ class_names=class_names
279
+ )
280
+ })
@@ -496,7 +496,16 @@ class TrackingDataset(ClassificationDataset):
496
496
  a copy of the feature array is made before augmentation and
497
497
  normalization so the cached data is never mutated.
498
498
  """
499
- from torch_geometric.data import Data
499
+ try:
500
+ from torch_geometric.data import Data
501
+ except ImportError as exc:
502
+ raise ImportError(
503
+ "torch-geometric is required for tracking_parquet datasets. "
504
+ "Install with: pip install \"opensportslib[py-geometric]\" "
505
+ "-f https://pytorch-geometric.com/whl/torch-2.10.0+cu128.html "
506
+ "or (editable): pip install -e \".[py-geometric]\" "
507
+ "-f https://pytorch-geometric.com/whl/torch-2.10.0+cu128.html"
508
+ ) from exc
500
509
 
501
510
  from opensportslib.datasets.utils.tracking import normalize_features
502
511
 
@@ -533,7 +542,16 @@ class TrackingDataset(ClassificationDataset):
533
542
 
534
543
  def _getitem_on_the_fly(self, idx):
535
544
  """load, parse, and process a single sample from disk."""
536
- from torch_geometric.data import Data
545
+ try:
546
+ from torch_geometric.data import Data
547
+ except ImportError as exc:
548
+ raise ImportError(
549
+ "torch-geometric is required for tracking_parquet datasets. "
550
+ "Install with: pip install \"opensportslib[py-geometric]\" "
551
+ "-f https://pytorch-geometric.com/whl/torch-2.10.0+cu128.html "
552
+ "or (editable): pip install -e \".[py-geometric]\" "
553
+ "-f https://pytorch-geometric.com/whl/torch-2.10.0+cu128.html"
554
+ ) from exc
537
555
 
538
556
  from opensportslib.datasets.utils.tracking import (
539
557
  compute_deltas,
@@ -600,4 +618,4 @@ class TrackingDataset(ClassificationDataset):
600
618
  if label is not None:
601
619
  out["label"] = label
602
620
  return out
603
-
621
+