opensportslib 0.0.1.dev7__tar.gz → 0.0.1.dev9__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {opensportslib-0.0.1.dev7/opensportslib.egg-info → opensportslib-0.0.1.dev9}/PKG-INFO +2 -1
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/core/trainer/localization_trainer.py +34 -0
- opensportslib-0.0.1.dev9/opensportslib/core/utils/lightning.py +52 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/core/utils/load_annotations.py +14 -0
- opensportslib-0.0.1.dev9/opensportslib/core/utils/video_processing.py +905 -0
- opensportslib-0.0.1.dev9/opensportslib/datasets/localization_dataset.py +2702 -0
- opensportslib-0.0.1.dev9/opensportslib/models/base/contextaware.py +394 -0
- opensportslib-0.0.1.dev9/opensportslib/models/base/learnablepooling.py +360 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/models/builder.py +23 -1
- opensportslib-0.0.1.dev9/opensportslib/models/neck/builder.py +637 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9/opensportslib.egg-info}/PKG-INFO +2 -1
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib.egg-info/SOURCES.txt +3 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib.egg-info/requires.txt +1 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/pyproject.toml +2 -2
- opensportslib-0.0.1.dev7/opensportslib/core/utils/video_processing.py +0 -389
- opensportslib-0.0.1.dev7/opensportslib/datasets/localization_dataset.py +0 -813
- opensportslib-0.0.1.dev7/opensportslib/models/neck/builder.py +0 -210
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/LICENSE +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/LICENSE-COMMERCIAL +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/MANIFEST.in +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/README.md +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/examples/quickstart/basic_classification.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/examples/quickstart/basic_localization.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/__init__.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/apis/__init__.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/apis/classification.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/apis/localization.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/config/classification.yaml +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/config/classification_tracking.yaml +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/config/graph_tracking_classification/avgpool.yaml +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/config/graph_tracking_classification/gin.yaml +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/config/graph_tracking_classification/graphconv.yaml +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/config/graph_tracking_classification/graphsage.yaml +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/config/graph_tracking_classification/maxpool.yaml +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/config/graph_tracking_classification/noedges.yaml +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/config/localization.yaml +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/config/sngar_frames.yaml +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/core/__init__.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/core/loss/__init__.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/core/loss/builder.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/core/loss/calf.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/core/loss/ce.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/core/loss/combine.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/core/loss/nll.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/core/optimizer/__init__.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/core/optimizer/builder.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/core/sampler/weighted_sampler.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/core/scheduler/__init__.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/core/scheduler/builder.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/core/trainer/__init__.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/core/trainer/classification_trainer.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/core/utils/checkpoint.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/core/utils/config.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/core/utils/data.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/core/utils/ddp.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/core/utils/default_args.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/core/utils/seed.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/core/utils/wandb.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/datasets/__init__.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/datasets/builder.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/datasets/classification_dataset.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/datasets/utils/__init__.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/datasets/utils/tracking.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/metrics/classification_metric.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/metrics/localization_metric.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/models/__init__.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/models/backbones/builder.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/models/base/e2e.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/models/base/tracking.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/models/base/vars.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/models/base/video.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/models/base/video_mae.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/models/heads/builder.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/models/utils/common.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/models/utils/impl/__init__.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/models/utils/impl/asformer.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/models/utils/impl/calf.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/models/utils/impl/gsm.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/models/utils/impl/gtad.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/models/utils/impl/tsm.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/models/utils/litebase.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/models/utils/modules.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/models/utils/shift.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/models/utils/utils.py +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib.egg-info/dependency_links.txt +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib.egg-info/top_level.txt +0 -0
- {opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/setup.cfg +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: opensportslib
|
|
3
|
-
Version: 0.0.1.dev7
|
|
3
|
+
Version: 0.0.1.dev9
|
|
4
4
|
Summary: OpenSportsLib is the professional library, designed for advanced video understanding in sports. It provides state-of-the-art tools for action recognition, spotting, retrieval, and captioning, making it ideal for researchers, analysts, and developers working with sports video data.
|
|
5
5
|
Author: Jeet Vora
|
|
6
6
|
Requires-Python: >=3.12
|
|
@@ -25,6 +25,7 @@ Provides-Extra: localization
|
|
|
25
25
|
Requires-Dist: nvidia-dali-cuda120; extra == "localization"
|
|
26
26
|
Requires-Dist: cupy-cuda12x; extra == "localization"
|
|
27
27
|
Requires-Dist: tabulate; extra == "localization"
|
|
28
|
+
Requires-Dist: pytorch-lightning; extra == "localization"
|
|
28
29
|
Provides-Extra: py-geometric
|
|
29
30
|
Requires-Dist: torch-geometric; extra == "py-geometric"
|
|
30
31
|
Requires-Dist: torch-scatter; extra == "py-geometric"
|
|
@@ -135,6 +135,9 @@ def build_trainer(cfg, model=None, default_args=None, resume_from=None):
|
|
|
135
135
|
trainer.best_criterion_valid = checkpoint.get('best_criterion_valid',
|
|
136
136
|
0 if cfg.TRAIN.criterion_valid == "map" else float("inf"))
|
|
137
137
|
logging.info(f"Restored best epoch: {trainer.best_epoch}")
|
|
138
|
+
|
|
139
|
+
else:
|
|
140
|
+
trainer = Trainer_pl(cfg, default_args["work_dir"])
|
|
138
141
|
|
|
139
142
|
|
|
140
143
|
return trainer
|
|
@@ -147,6 +150,37 @@ class Trainer(ABC):
|
|
|
147
150
|
def train(self):
|
|
148
151
|
pass
|
|
149
152
|
|
|
153
|
+
class Trainer_pl(Trainer):
|
|
154
|
+
"""Trainer class used for models that rely on lightning modules.
|
|
155
|
+
|
|
156
|
+
Args:
|
|
157
|
+
cfg (dict): Dict config. It should contain the key 'max_epochs' and the key 'GPU'.
|
|
158
|
+
"""
|
|
159
|
+
|
|
160
|
+
def __init__(self, cfg, work_dir):
|
|
161
|
+
from opensportslib.core.utils.lightning import CustomProgressBar, MyCallback
|
|
162
|
+
import pytorch_lightning as pl
|
|
163
|
+
|
|
164
|
+
self.work_dir = work_dir
|
|
165
|
+
call = MyCallback()
|
|
166
|
+
self.trainer = pl.Trainer(
|
|
167
|
+
max_epochs=cfg.max_epochs,
|
|
168
|
+
devices=[cfg.GPU],
|
|
169
|
+
callbacks=[call, CustomProgressBar(refresh_rate=1)],
|
|
170
|
+
num_sanity_val_steps=0,
|
|
171
|
+
)
|
|
172
|
+
|
|
173
|
+
def train(self, **kwargs):
|
|
174
|
+
self.trainer.fit(**kwargs)
|
|
175
|
+
|
|
176
|
+
best_model = kwargs["model"].best_state
|
|
177
|
+
|
|
178
|
+
logging.info("Done training")
|
|
179
|
+
logging.info("Best epoch: {}".format(best_model.get("epoch")))
|
|
180
|
+
torch.save(best_model, os.path.join(self.work_dir, "model.pth.tar"))
|
|
181
|
+
|
|
182
|
+
logging.info("Model saved")
|
|
183
|
+
logging.info(os.path.join(self.work_dir, "model.pth.tar"))
|
|
150
184
|
|
|
151
185
|
|
|
152
186
|
class Trainer_e2e(Trainer):
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
import pytorch_lightning as pl
|
|
2
|
+
from pytorch_lightning.callbacks.progress import TQDMProgressBar
|
|
3
|
+
import logging
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class CustomProgressBar(TQDMProgressBar):
|
|
7
|
+
"""Override the custom progress bar used by pytorch lightning to change some attributes."""
|
|
8
|
+
|
|
9
|
+
def get_metrics(self, trainer, pl_module):
|
|
10
|
+
"""Override the method to don't show the version number in the progress bar."""
|
|
11
|
+
items = super().get_metrics(trainer, pl_module)
|
|
12
|
+
items.pop("v_num", None)
|
|
13
|
+
return items
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class MyCallback(pl.Callback):
|
|
17
|
+
"""Override the Callback class of pl to change the behaviour on validation epoch end."""
|
|
18
|
+
|
|
19
|
+
def __init__(self):
|
|
20
|
+
super().__init__()
|
|
21
|
+
|
|
22
|
+
def on_validation_epoch_end(self, trainer, pl_module):
|
|
23
|
+
loss_validation = pl_module.losses.avg
|
|
24
|
+
state = {
|
|
25
|
+
"epoch": trainer.current_epoch + 1,
|
|
26
|
+
"state_dict": pl_module.model.state_dict(),
|
|
27
|
+
"best_loss": pl_module.best_loss,
|
|
28
|
+
"optimizer": pl_module.optimizer.state_dict(),
|
|
29
|
+
}
|
|
30
|
+
|
|
31
|
+
# remember best prec@1 and save checkpoint
|
|
32
|
+
is_better = loss_validation < pl_module.best_loss
|
|
33
|
+
pl_module.best_loss = min(loss_validation, pl_module.best_loss)
|
|
34
|
+
|
|
35
|
+
# Save the best model based on loss only if the evaluation frequency too long
|
|
36
|
+
if is_better:
|
|
37
|
+
pl_module.best_state = state
|
|
38
|
+
# torch.save(state, best_model_path)
|
|
39
|
+
|
|
40
|
+
# Reduce LR on Plateau after patience reached
|
|
41
|
+
prevLR = pl_module.optimizer.param_groups[0]["lr"]
|
|
42
|
+
pl_module.scheduler.step(loss_validation)
|
|
43
|
+
currLR = pl_module.optimizer.param_groups[0]["lr"]
|
|
44
|
+
|
|
45
|
+
if currLR is not prevLR and pl_module.scheduler.num_bad_epochs == 0:
|
|
46
|
+
logging.info("\nPlateau Reached!")
|
|
47
|
+
if (
|
|
48
|
+
prevLR < 2 * pl_module.scheduler.eps
|
|
49
|
+
and pl_module.scheduler.num_bad_epochs >= pl_module.scheduler.patience
|
|
50
|
+
):
|
|
51
|
+
logging.info("\nPlateau Reached and no more reduction -> Exiting Loop")
|
|
52
|
+
trainer.should_stop = True
|
{opensportslib-0.0.1.dev7 → opensportslib-0.0.1.dev9}/opensportslib/core/utils/load_annotations.py
RENAMED
|
@@ -386,6 +386,20 @@ def construct_labels(path, extract_fps):
|
|
|
386
386
|
num_frames, fps, wanted_sample_fps if wanted_sample_fps < fps else fps
|
|
387
387
|
)
|
|
388
388
|
|
|
389
|
+
return [
|
|
390
|
+
{
|
|
391
|
+
"video": path,
|
|
392
|
+
"path": path,
|
|
393
|
+
"num_frames": num_frames_after,
|
|
394
|
+
"num_frames_base": num_frames,
|
|
395
|
+
"num_events": 0,
|
|
396
|
+
"events": [],
|
|
397
|
+
"fps": sample_fps,
|
|
398
|
+
"width": 398,
|
|
399
|
+
"height": 224,
|
|
400
|
+
}
|
|
401
|
+
], get_stride(fps, wanted_sample_fps if wanted_sample_fps < fps else fps)
|
|
402
|
+
|
|
389
403
|
|
|
390
404
|
# def get_repartition_gpu():
|
|
391
405
|
# """Returns the distribution of gpus that will be used by pipelines for dali."""
|