opensportslib 0.0.1.dev7__tar.gz → 0.0.1.dev9__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87) hide show
  1. {opensportslib-0.0.1.dev7/opensportslib.egg-info → opensportslib-0.0.1.dev9}/PKG-INFO +2 -1
  2. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/core/trainer/localization_trainer.py +34 -0
  3. opensportslib-0.0.1.dev9/opensportslib/core/utils/lightning.py +52 -0
  4. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/core/utils/load_annotations.py +14 -0
  5. opensportslib-0.0.1.dev9/opensportslib/core/utils/video_processing.py +905 -0
  6. opensportslib-0.0.1.dev9/opensportslib/datasets/localization_dataset.py +2702 -0
  7. opensportslib-0.0.1.dev9/opensportslib/models/base/contextaware.py +394 -0
  8. opensportslib-0.0.1.dev9/opensportslib/models/base/learnablepooling.py +360 -0
  9. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/models/builder.py +23 -1
  10. opensportslib-0.0.1.dev9/opensportslib/models/neck/builder.py +637 -0
  11. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9/opensportslib.egg-info}/PKG-INFO +2 -1
  12. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib.egg-info/SOURCES.txt +3 -0
  13. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib.egg-info/requires.txt +1 -0
  14. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/pyproject.toml +2 -2
  15. opensportslib-0.0.1.dev7/opensportslib/core/utils/video_processing.py +0 -389
  16. opensportslib-0.0.1.dev7/opensportslib/datasets/localization_dataset.py +0 -813
  17. opensportslib-0.0.1.dev7/opensportslib/models/neck/builder.py +0 -210
  18. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/LICENSE +0 -0
  19. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/LICENSE-COMMERCIAL +0 -0
  20. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/MANIFEST.in +0 -0
  21. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/README.md +0 -0
  22. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/examples/quickstart/basic_classification.py +0 -0
  23. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/examples/quickstart/basic_localization.py +0 -0
  24. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/__init__.py +0 -0
  25. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/apis/__init__.py +0 -0
  26. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/apis/classification.py +0 -0
  27. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/apis/localization.py +0 -0
  28. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/config/classification.yaml +0 -0
  29. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/config/classification_tracking.yaml +0 -0
  30. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/config/graph_tracking_classification/avgpool.yaml +0 -0
  31. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/config/graph_tracking_classification/gin.yaml +0 -0
  32. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/config/graph_tracking_classification/graphconv.yaml +0 -0
  33. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/config/graph_tracking_classification/graphsage.yaml +0 -0
  34. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/config/graph_tracking_classification/maxpool.yaml +0 -0
  35. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/config/graph_tracking_classification/noedges.yaml +0 -0
  36. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/config/localization.yaml +0 -0
  37. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/config/sngar_frames.yaml +0 -0
  38. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/core/__init__.py +0 -0
  39. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/core/loss/__init__.py +0 -0
  40. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/core/loss/builder.py +0 -0
  41. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/core/loss/calf.py +0 -0
  42. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/core/loss/ce.py +0 -0
  43. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/core/loss/combine.py +0 -0
  44. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/core/loss/nll.py +0 -0
  45. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/core/optimizer/__init__.py +0 -0
  46. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/core/optimizer/builder.py +0 -0
  47. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/core/sampler/weighted_sampler.py +0 -0
  48. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/core/scheduler/__init__.py +0 -0
  49. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/core/scheduler/builder.py +0 -0
  50. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/core/trainer/__init__.py +0 -0
  51. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/core/trainer/classification_trainer.py +0 -0
  52. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/core/utils/checkpoint.py +0 -0
  53. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/core/utils/config.py +0 -0
  54. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/core/utils/data.py +0 -0
  55. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/core/utils/ddp.py +0 -0
  56. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/core/utils/default_args.py +0 -0
  57. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/core/utils/seed.py +0 -0
  58. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/core/utils/wandb.py +0 -0
  59. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/datasets/__init__.py +0 -0
  60. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/datasets/builder.py +0 -0
  61. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/datasets/classification_dataset.py +0 -0
  62. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/datasets/utils/__init__.py +0 -0
  63. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/datasets/utils/tracking.py +0 -0
  64. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/metrics/classification_metric.py +0 -0
  65. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/metrics/localization_metric.py +0 -0
  66. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/models/__init__.py +0 -0
  67. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/models/backbones/builder.py +0 -0
  68. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/models/base/e2e.py +0 -0
  69. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/models/base/tracking.py +0 -0
  70. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/models/base/vars.py +0 -0
  71. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/models/base/video.py +0 -0
  72. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/models/base/video_mae.py +0 -0
  73. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/models/heads/builder.py +0 -0
  74. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/models/utils/common.py +0 -0
  75. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/models/utils/impl/__init__.py +0 -0
  76. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/models/utils/impl/asformer.py +0 -0
  77. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/models/utils/impl/calf.py +0 -0
  78. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/models/utils/impl/gsm.py +0 -0
  79. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/models/utils/impl/gtad.py +0 -0
  80. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/models/utils/impl/tsm.py +0 -0
  81. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/models/utils/litebase.py +0 -0
  82. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/models/utils/modules.py +0 -0
  83. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/models/utils/shift.py +0 -0
  84. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/models/utils/utils.py +0 -0
  85. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib.egg-info/dependency_links.txt +0 -0
  86. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib.egg-info/top_level.txt +0 -0
  87. {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: opensportslib
3
- Version: 0.0.1.dev7
3
+ Version: 0.0.1.dev9
4
4
  Summary: OpenSportsLib is the professional library, designed for advanced video understanding in sports. It provides state-of-the-art tools for action recognition, spotting, retrieval, and captioning, making it ideal for researchers, analysts, and developers working with sports video data.
5
5
  Author: Jeet Vora
6
6
  Requires-Python: >=3.12
@@ -25,6 +25,7 @@ Provides-Extra: localization
25
25
  Requires-Dist: nvidia-dali-cuda120; extra == "localization"
26
26
  Requires-Dist: cupy-cuda12x; extra == "localization"
27
27
  Requires-Dist: tabulate; extra == "localization"
28
+ Requires-Dist: pytorch-lightning; extra == "localization"
28
29
  Provides-Extra: py-geometric
29
30
  Requires-Dist: torch-geometric; extra == "py-geometric"
30
31
  Requires-Dist: torch-scatter; extra == "py-geometric"
@@ -135,6 +135,9 @@ def build_trainer(cfg, model=None, default_args=None, resume_from=None):
135
135
  trainer.best_criterion_valid = checkpoint.get('best_criterion_valid',
136
136
  0 if cfg.TRAIN.criterion_valid == "map" else float("inf"))
137
137
  logging.info(f"Restored best epoch: {trainer.best_epoch}")
138
+
139
+ else:
140
+ trainer = Trainer_pl(cfg, default_args["work_dir"])
138
141
 
139
142
 
140
143
  return trainer
@@ -147,6 +150,37 @@ class Trainer(ABC):
147
150
  def train(self):
148
151
  pass
149
152
 
153
class Trainer_pl(Trainer):
    """Trainer for models implemented as PyTorch Lightning modules.

    Wraps a ``pytorch_lightning.Trainer`` configured with the project's
    ``MyCallback`` and ``CustomProgressBar`` callbacks and, once fitting is
    done, writes the model's recorded best state to ``work_dir``.

    Args:
        cfg: Config object exposing ``max_epochs`` and ``GPU`` as attributes
            (read as ``cfg.max_epochs`` and ``cfg.GPU``).
        work_dir (str): Directory where ``model.pth.tar`` is written.
    """

    def __init__(self, cfg, work_dir):
        # pytorch-lightning is only required by the "localization" extra, so it
        # is imported lazily, when this trainer is actually constructed.
        from opensportslib.core.utils.lightning import CustomProgressBar, MyCallback
        import pytorch_lightning as pl

        self.work_dir = work_dir
        self.trainer = pl.Trainer(
            max_epochs=cfg.max_epochs,
            devices=[cfg.GPU],
            callbacks=[MyCallback(), CustomProgressBar(refresh_rate=1)],
            num_sanity_val_steps=0,
        )

    def train(self, **kwargs):
        """Fit the model, then save its best recorded state.

        Args:
            **kwargs: Forwarded unchanged to ``pytorch_lightning.Trainer.fit``.
                Must contain ``model``: the LightningModule whose
                ``best_state`` attribute holds the checkpoint dict to save.

        If the model has no ``best_state`` (missing or ``None``), a warning is
        logged and nothing is written instead of raising ``AttributeError``.
        """
        self.trainer.fit(**kwargs)

        model_path = os.path.join(self.work_dir, "model.pth.tar")
        best_model = getattr(kwargs["model"], "best_state", None)

        logging.info("Done training")
        if best_model is None:
            # Presumably validation never ran or never improved on the
            # initial best loss, so no snapshot was ever recorded.
            logging.warning(
                "No best model state was recorded; nothing saved to %s", model_path
            )
            return

        logging.info("Best epoch: {}".format(best_model.get("epoch")))
        torch.save(best_model, model_path)

        logging.info("Model saved")
        logging.info(model_path)
150
184
 
151
185
 
152
186
  class Trainer_e2e(Trainer):
@@ -0,0 +1,52 @@
1
+ import pytorch_lightning as pl
2
+ from pytorch_lightning.callbacks.progress import TQDMProgressBar
3
+ import logging
4
+
5
+
6
class CustomProgressBar(TQDMProgressBar):
    """TQDM progress bar that hides Lightning's ``v_num`` (run version) entry."""

    def get_metrics(self, trainer, pl_module):
        """Return the parent's progress-bar metrics with ``v_num`` removed.

        The dict returned by the parent is modified in place and returned.
        """
        metrics = super().get_metrics(trainer, pl_module)
        if "v_num" in metrics:
            del metrics["v_num"]
        return metrics
14
+
15
+
16
class MyCallback(pl.Callback):
    """Lightning callback that tracks the best model and drives LR-on-plateau.

    At the end of each validation epoch it:
      * compares ``pl_module.losses.avg`` (presumably the averaged validation
        loss) against ``pl_module.best_loss`` and, on improvement, stores a
        checkpoint dict in ``pl_module.best_state``;
      * steps ``pl_module.scheduler`` with that loss and logs when the learning
        rate changed;
      * asks the trainer to stop once the learning rate is already below twice
        the scheduler's ``eps`` and its patience is exhausted.
    """

    def __init__(self):
        super().__init__()

    def on_validation_epoch_end(self, trainer, pl_module):
        # Local import keeps this module's top-level imports unchanged.
        import copy

        loss_validation = pl_module.losses.avg

        is_better = loss_validation < pl_module.best_loss
        pl_module.best_loss = min(loss_validation, pl_module.best_loss)

        if is_better:
            # state_dict() returns references to the live parameter and
            # optimizer-state tensors, which keep being updated in place as
            # training continues. Deep-copy so best_state really holds this
            # epoch's weights rather than silently tracking the latest ones.
            # "best_loss" is recorded after the update, i.e. the new best.
            pl_module.best_state = copy.deepcopy(
                {
                    "epoch": trainer.current_epoch + 1,
                    "state_dict": pl_module.model.state_dict(),
                    "best_loss": pl_module.best_loss,
                    "optimizer": pl_module.optimizer.state_dict(),
                }
            )

        # Let the plateau scheduler reduce the LR once its patience runs out.
        prev_lr = pl_module.optimizer.param_groups[0]["lr"]
        pl_module.scheduler.step(loss_validation)
        curr_lr = pl_module.optimizer.param_groups[0]["lr"]

        # Compare LR values, not object identity.
        if curr_lr != prev_lr and pl_module.scheduler.num_bad_epochs == 0:
            logging.info("\nPlateau Reached!")
        if (
            prev_lr < 2 * pl_module.scheduler.eps
            and pl_module.scheduler.num_bad_epochs >= pl_module.scheduler.patience
        ):
            logging.info("\nPlateau Reached and no more reduction -> Exiting Loop")
            trainer.should_stop = True
@@ -386,6 +386,20 @@ def construct_labels(path, extract_fps):
386
386
  num_frames, fps, wanted_sample_fps if wanted_sample_fps < fps else fps
387
387
  )
388
388
 
389
+ return [
390
+ {
391
+ "video": path,
392
+ "path": path,
393
+ "num_frames": num_frames_after,
394
+ "num_frames_base": num_frames,
395
+ "num_events": 0,
396
+ "events": [],
397
+ "fps": sample_fps,
398
+ "width": 398,
399
+ "height": 224,
400
+ }
401
+ ], get_stride(fps, wanted_sample_fps if wanted_sample_fps < fps else fps)
402
+
389
403
 
390
404
  # def get_repartition_gpu():
391
405
  # """Returns the distribution of gpus that will be used by pipelines for dali."""