opensportslib 0.0.1.dev8__tar.gz → 0.0.1.dev9__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {opensportslib-0.0.1.dev8/opensportslib.egg-info → opensportslib-0.0.1.dev9}/PKG-INFO +2 -1
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/core/trainer/localization_trainer.py +34 -0
- opensportslib-0.0.1.dev9/opensportslib/core/utils/lightning.py +52 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/core/utils/video_processing.py +125 -0
- opensportslib-0.0.1.dev9/opensportslib/models/base/contextaware.py +394 -0
- opensportslib-0.0.1.dev9/opensportslib/models/base/learnablepooling.py +360 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/models/builder.py +23 -1
- opensportslib-0.0.1.dev9/opensportslib/models/neck/builder.py +637 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9/opensportslib.egg-info}/PKG-INFO +2 -1
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib.egg-info/SOURCES.txt +3 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib.egg-info/requires.txt +1 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/pyproject.toml +2 -2
- opensportslib-0.0.1.dev8/opensportslib/models/neck/builder.py +0 -210
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/LICENSE +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/LICENSE-COMMERCIAL +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/MANIFEST.in +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/README.md +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/examples/quickstart/basic_classification.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/examples/quickstart/basic_localization.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/__init__.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/apis/__init__.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/apis/classification.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/apis/localization.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/config/classification.yaml +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/config/classification_tracking.yaml +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/config/graph_tracking_classification/avgpool.yaml +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/config/graph_tracking_classification/gin.yaml +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/config/graph_tracking_classification/graphconv.yaml +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/config/graph_tracking_classification/graphsage.yaml +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/config/graph_tracking_classification/maxpool.yaml +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/config/graph_tracking_classification/noedges.yaml +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/config/localization.yaml +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/config/sngar_frames.yaml +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/core/__init__.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/core/loss/__init__.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/core/loss/builder.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/core/loss/calf.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/core/loss/ce.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/core/loss/combine.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/core/loss/nll.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/core/optimizer/__init__.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/core/optimizer/builder.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/core/sampler/weighted_sampler.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/core/scheduler/__init__.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/core/scheduler/builder.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/core/trainer/__init__.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/core/trainer/classification_trainer.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/core/utils/checkpoint.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/core/utils/config.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/core/utils/data.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/core/utils/ddp.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/core/utils/default_args.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/core/utils/load_annotations.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/core/utils/seed.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/core/utils/wandb.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/datasets/__init__.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/datasets/builder.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/datasets/classification_dataset.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/datasets/localization_dataset.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/datasets/utils/__init__.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/datasets/utils/tracking.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/metrics/classification_metric.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/metrics/localization_metric.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/models/__init__.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/models/backbones/builder.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/models/base/e2e.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/models/base/tracking.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/models/base/vars.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/models/base/video.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/models/base/video_mae.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/models/heads/builder.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/models/utils/common.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/models/utils/impl/__init__.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/models/utils/impl/asformer.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/models/utils/impl/calf.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/models/utils/impl/gsm.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/models/utils/impl/gtad.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/models/utils/impl/tsm.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/models/utils/litebase.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/models/utils/modules.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/models/utils/shift.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/models/utils/utils.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib.egg-info/dependency_links.txt +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib.egg-info/top_level.txt +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/setup.cfg +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: opensportslib
|
|
3
|
-
Version: 0.0.1.dev8
|
|
3
|
+
Version: 0.0.1.dev9
|
|
4
4
|
Summary: OpenSportsLib is the professional library, designed for advanced video understanding in sports. It provides state-of-the-art tools for action recognition, spotting, retrieval, and captioning, making it ideal for researchers, analysts, and developers working with sports video data.
|
|
5
5
|
Author: Jeet Vora
|
|
6
6
|
Requires-Python: >=3.12
|
|
@@ -25,6 +25,7 @@ Provides-Extra: localization
|
|
|
25
25
|
Requires-Dist: nvidia-dali-cuda120; extra == "localization"
|
|
26
26
|
Requires-Dist: cupy-cuda12x; extra == "localization"
|
|
27
27
|
Requires-Dist: tabulate; extra == "localization"
|
|
28
|
+
Requires-Dist: pytorch-lightning; extra == "localization"
|
|
28
29
|
Provides-Extra: py-geometric
|
|
29
30
|
Requires-Dist: torch-geometric; extra == "py-geometric"
|
|
30
31
|
Requires-Dist: torch-scatter; extra == "py-geometric"
|
|
@@ -135,6 +135,9 @@ def build_trainer(cfg, model=None, default_args=None, resume_from=None):
|
|
|
135
135
|
trainer.best_criterion_valid = checkpoint.get('best_criterion_valid',
|
|
136
136
|
0 if cfg.TRAIN.criterion_valid == "map" else float("inf"))
|
|
137
137
|
logging.info(f"Restored best epoch: {trainer.best_epoch}")
|
|
138
|
+
|
|
139
|
+
else:
|
|
140
|
+
trainer = Trainer_pl(cfg, default_args["work_dir"])
|
|
138
141
|
|
|
139
142
|
|
|
140
143
|
return trainer
|
|
@@ -147,6 +150,37 @@ class Trainer(ABC):
|
|
|
147
150
|
def train(self):
|
|
148
151
|
pass
|
|
149
152
|
|
|
153
|
+
class Trainer_pl(Trainer):
    """Trainer that delegates the training loop to a PyTorch Lightning trainer.

    Args:
        cfg: Configuration object; ``cfg.max_epochs`` and ``cfg.GPU`` are read.
        work_dir: Directory in which ``model.pth.tar`` is written once
            training has finished.
    """

    def __init__(self, cfg, work_dir):
        # Both imports stay local to the constructor, as in the original
        # (presumably so Lightning is only needed when this trainer is built).
        import pytorch_lightning as pl
        from opensportslib.core.utils.lightning import CustomProgressBar, MyCallback

        self.work_dir = work_dir
        callbacks = [MyCallback(), CustomProgressBar(refresh_rate=1)]
        self.trainer = pl.Trainer(
            max_epochs=cfg.max_epochs,
            devices=[cfg.GPU],
            callbacks=callbacks,
            num_sanity_val_steps=0,
        )

    def train(self, **kwargs):
        """Fit the model, then save the best state recorded on it.

        Every keyword argument is forwarded to the Lightning trainer's
        ``fit``. A ``model`` keyword is required: its ``best_state``
        attribute is read after fitting and written to
        ``<work_dir>/model.pth.tar``.
        """
        self.trainer.fit(**kwargs)

        best_state = kwargs["model"].best_state
        checkpoint_path = os.path.join(self.work_dir, "model.pth.tar")

        logging.info("Done training")
        logging.info("Best epoch: {}".format(best_state.get("epoch")))
        torch.save(best_state, checkpoint_path)

        logging.info("Model saved")
        logging.info(checkpoint_path)
|
|
150
184
|
|
|
151
185
|
|
|
152
186
|
class Trainer_e2e(Trainer):
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
import pytorch_lightning as pl
|
|
2
|
+
from pytorch_lightning.callbacks.progress import TQDMProgressBar
|
|
3
|
+
import logging
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class CustomProgressBar(TQDMProgressBar):
    """TQDM progress bar that hides the ``v_num`` entry from the displayed metrics."""

    def get_metrics(self, trainer, pl_module):
        """Return the parent's metrics with the ``v_num`` key removed when present."""
        metrics = super().get_metrics(trainer, pl_module)
        if "v_num" in metrics:
            del metrics["v_num"]
        return metrics
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class MyCallback(pl.Callback):
    """Callback that tracks the best validation loss and steps the LR scheduler.

    The hook reads the following attributes from the Lightning module:
    ``losses.avg``, ``best_loss``, ``model``, ``optimizer`` and ``scheduler``.
    The scheduler must expose ``step(metric)``, ``num_bad_epochs``, ``eps``
    and ``patience`` (this looks like ``ReduceLROnPlateau`` - confirm against
    the module that builds it).
    """

    def __init__(self):
        super().__init__()

    def on_validation_epoch_end(self, trainer, pl_module):
        """Record the best state so far and stop training once the LR is exhausted.

        Args:
            trainer: Running Lightning trainer; ``current_epoch`` is read and
                ``should_stop`` is set when no further LR reduction is possible.
            pl_module: Lightning module being trained (see class docstring for
                the attributes used). ``best_loss`` and ``best_state`` are
                updated on it.
        """
        import copy  # local import: only this hook needs it

        loss_validation = pl_module.losses.avg
        previous_best = pl_module.best_loss

        # Keep the best model based on the validation loss.
        if loss_validation < previous_best:
            # ``state_dict()`` returns references to the live tensors. The state
            # is only written to disk after training ends, so it has to be
            # snapshotted here; otherwise later epochs would overwrite the
            # "best" weights in place.
            pl_module.best_state = {
                "epoch": trainer.current_epoch + 1,
                "state_dict": copy.deepcopy(pl_module.model.state_dict()),
                # Value of the best loss *before* this epoch, as in the
                # original checkpoint layout.
                "best_loss": previous_best,
                "optimizer": copy.deepcopy(pl_module.optimizer.state_dict()),
            }
            # Updating only on improvement keeps a NaN loss from replacing
            # the recorded best value.
            pl_module.best_loss = loss_validation

        # Reduce LR on plateau after patience reached.
        prevLR = pl_module.optimizer.param_groups[0]["lr"]
        pl_module.scheduler.step(loss_validation)
        currLR = pl_module.optimizer.param_groups[0]["lr"]

        # Compare learning rates by value, not by object identity.
        if currLR != prevLR and pl_module.scheduler.num_bad_epochs == 0:
            logging.info("\nPlateau Reached!")
        if (
            prevLR < 2 * pl_module.scheduler.eps
            and pl_module.scheduler.num_bad_epochs >= pl_module.scheduler.patience
        ):
            logging.info("\nPlateau Reached and no more reduction -> Exiting Loop")
            trainer.should_stop = True
|
{opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/core/utils/video_processing.py
RENAMED
|
@@ -719,6 +719,131 @@ def oneHotToShifts(onehot, params):
|
|
|
719
719
|
Shifts[:, i] = shifts
|
|
720
720
|
|
|
721
721
|
return Shifts
|
|
722
|
+
|
|
723
|
+
def timestamps2long(output_spotting, video_size, chunk_size, receptive_field):
    """Stitch per-chunk spotting detections into one video-length matrix.

    Args:
        output_spotting: Tensor indexed as ``[chunk, detection, field]``.
            Field 0 is the value stored in the result, field 1 is multiplied
            by ``chunk_size - 1`` and floored to obtain the row inside the
            chunk, and fields ``2:`` are arg-maxed to obtain the column.
        video_size: Number of rows of the returned matrix.
        chunk_size: Number of rows covered by each chunk.
        receptive_field: Halved (integer division) to obtain the number of
            rows dropped at the inner borders of each chunk.

    Returns:
        Float tensor of shape ``(video_size, output_spotting.size(-1) - 2)``
        on the same device as ``output_spotting``; cells never written hold -1.
    """
    margin = receptive_field // 2
    stride = chunk_size - 2 * margin
    num_columns = output_spotting.size()[-1] - 2
    device = output_spotting.device

    def _minus_ones(rows):
        # Matrix of -1 used as the "nothing detected" background.
        return torch.zeros([rows, num_columns], dtype=torch.float, device=device) - 1

    timestamps_long = _minus_ones(video_size)
    offset = 0
    on_last_chunk = False

    for chunk_idx in range(output_spotting.size()[0]):
        chunk_scores = _minus_ones(chunk_size)

        # Detections are written one by one so that, on collisions, the last
        # detection wins exactly as in a sequential loop.
        for det_idx in range(output_spotting.size()[1]):
            detection = output_spotting[chunk_idx, det_idx]
            row = torch.floor(detection[1] * (chunk_size - 1)).type(torch.int)
            column = torch.argmax(detection[2:]).type(torch.int)
            chunk_scores[row, column] = detection[0]

        # Store the result of the chunk in the video.
        if offset == 0:
            # First chunk: keep everything except the trailing margin.
            kept = chunk_size - margin
            timestamps_long[:kept] = chunk_scores[:kept]
        elif on_last_chunk:
            # Last chunk: keep everything except the leading margin, then stop.
            timestamps_long[offset + margin : offset + chunk_size] = chunk_scores[margin:]
            break
        else:
            # Middle chunk: drop the margin on both sides.
            timestamps_long[offset + margin : offset + chunk_size - margin] = (
                chunk_scores[margin : chunk_size - margin]
            )

        # Advance; clamp the final chunk so that it ends on the last row.
        offset += stride
        if offset + chunk_size >= video_size:
            offset = video_size - chunk_size
            on_last_chunk = True

    return timestamps_long
|
|
791
|
+
|
|
792
|
+
|
|
793
|
+
def batch2long(output_segmentation, video_size, chunk_size, receptive_field):
    """Stitch per-chunk segmentation outputs into one video-length one-hot matrix.

    Args:
        output_segmentation: Tensor indexed as ``[chunk, row, class]``; each
            row is reduced to a one-hot vector of its arg-max class.
        video_size: Number of rows of the returned matrix.
        chunk_size: Number of rows covered by each chunk.
        receptive_field: Halved (integer division) to obtain the number of
            rows dropped at the inner borders of each chunk.

    Returns:
        Float tensor of shape ``(video_size, output_segmentation.size(-1))``
        on the same device as ``output_segmentation``; rows never written
        stay at zero.
    """
    margin = receptive_field // 2
    stride = chunk_size - 2 * margin
    num_classes = output_segmentation.size()[-1]

    segmentation_long = torch.zeros(
        [video_size, num_classes],
        dtype=torch.float,
        device=output_segmentation.device,
    )

    offset = 0
    on_last_chunk = False

    for chunk_idx in range(output_segmentation.size()[0]):
        winners = torch.argmax(output_segmentation[chunk_idx], dim=-1)
        chunk_one_hot = torch.nn.functional.one_hot(winners, num_classes=num_classes)

        # Store the result of the chunk in the video.
        if offset == 0:
            # First chunk: keep everything except the trailing margin.
            kept = chunk_size - margin
            segmentation_long[:kept] = chunk_one_hot[:kept]
        elif on_last_chunk:
            # Last chunk: keep everything except the leading margin, then stop.
            segmentation_long[offset + margin : offset + chunk_size] = chunk_one_hot[margin:]
            break
        else:
            # Middle chunk: drop the margin on both sides.
            segmentation_long[offset + margin : offset + chunk_size - margin] = (
                chunk_one_hot[margin : chunk_size - margin]
            )

        # Advance; clamp the final chunk so that it ends on the last row.
        offset += stride
        if offset + chunk_size >= video_size:
            offset = video_size - chunk_size
            on_last_chunk = True

    return segmentation_long
|
|
846
|
+
|
|
722
847
|
# import torch
|
|
723
848
|
# import numpy as np
|
|
724
849
|
# import decord
|
|
@@ -0,0 +1,394 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import numpy as np
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
import torch.nn as nn
|
|
6
|
+
import torch.nn.functional as F
|
|
7
|
+
|
|
8
|
+
from opensportslib.models.utils.litebase import LiteBaseModel
|
|
9
|
+
|
|
10
|
+
import os
|
|
11
|
+
|
|
12
|
+
from opensportslib.core.utils.video_processing import timestamps2long, batch2long
|
|
13
|
+
|
|
14
|
+
from opensportslib.models.utils.utils import (
|
|
15
|
+
NMS,
|
|
16
|
+
check_if_should_predict,
|
|
17
|
+
get_json_data,
|
|
18
|
+
predictions2json,
|
|
19
|
+
predictions2json_runnerjson,
|
|
20
|
+
zipResults,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
from opensportslib.models.heads.builder import build_head
|
|
24
|
+
from opensportslib.models.backbones.builder import build_backbone
|
|
25
|
+
from opensportslib.models.neck.builder import build_neck
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class ContextAwareModel(nn.Module):
    """CALF model composed of a backbone, a neck and a head.

    Args:
        weights: Optional path to a checkpoint file. When given, the file must
            contain the keys ``"state_dict"`` and ``"epoch"``.
        backbone: Backbone specification handed to ``build_backbone``.
        neck: Neck specification handed to ``build_neck``.
        head: Head specification handed to ``build_head``.
        post_proc: Accepted for interface compatibility; not used by this class.

    The model takes as input a Tensor of the form
    (batch_size,1,chunk_size,input_size) and returns:
        1. The segmentation of the form (batch_size,chunk_size,num_classes).
        2. The action spotting of the form (batch_size,num_detections,2+num_classes).
    """

    def __init__(
        self,
        weights=None,
        backbone="PreExtracted",
        neck="CNN++",
        head="SpottingCALF",
        post_proc="NMS",
    ):
        super().__init__()

        # Build the three stages of the network.
        self.backbone = build_backbone(backbone)
        self.neck = build_neck(neck)
        self.head = build_head(head)

        # Load weights if a checkpoint path was provided.
        self.load_weights(weights=weights)

    def load_weights(self, weights=None):
        """Load the ``state_dict`` stored in the checkpoint at ``weights``.

        Does nothing when ``weights`` is ``None``.
        """
        if weights is None:
            return

        print("=> loading checkpoint '{}'".format(weights))
        # Map to CPU so a checkpoint saved on a GPU can also be opened on a
        # machine without one; ``load_state_dict`` then copies the values into
        # the module's own parameters, wherever they live.
        checkpoint = torch.load(weights, map_location="cpu")
        self.load_state_dict(checkpoint["state_dict"])
        print(
            "=> loaded checkpoint '{}' (epoch {})".format(
                weights, checkpoint["epoch"]
            )
        )

    def forward(self, inputs):
        """
        INPUT: a Tensor of the form (batch_size,1,chunk_size,input_size)
        OUTPUTS: 1. The segmentation of the form (batch_size,chunk_size,num_classes)
                 2. The action spotting of the form (batch_size,num_detections,2+num_classes)
        """
        features = self.backbone(inputs)
        conv_seg, output_segmentation = self.neck(features)
        output_spotting = self.head(conv_seg, output_segmentation)
        return output_segmentation, output_spotting
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class LiteContextAwareModel(LiteBaseModel):
    """Lightning module for the CALF model.

    Args:
        cfg: Configuration object. ``cfg.training`` is passed to the base
            class; ``cfg.work_dir`` and ``cfg.dataset.test.results`` are read
            at inference time; the optional ``cfg.infer_split`` defaults to True.
        weights: Optional path of the weights file.
        backbone: Backbone configuration of the CALF model (``output_dim`` is read).
        neck: Neck configuration of the CALF model (``input_size``,
            ``num_classes``, ``dim_capsule``, ``num_detections``, ``chunk_size``,
            ``receptive_field`` and ``framerate`` are read).
        head: Head configuration of the CALF model (``num_classes``,
            ``dim_capsule``, ``num_detections`` and ``chunk_size`` are read).
        post_proc: Forwarded to ``ContextAwareModel``.
        runner: ``"runner_CALF"`` when each prediction batch holds the two
            halves of a game, ``"runner_JSON"`` when it holds a single video.
            This changes how predictions are processed and stored at inference.
    """

    def __init__(
        self,
        cfg=None,
        weights=None,
        backbone="PreExtracted",
        neck="CNN++",
        head="SpottingCALF",
        post_proc="NMS",
        runner="runner_CALF",
    ):
        super().__init__(cfg.training)

        # Check the compatibility of the dimensions Backbone - Neck - Head.
        assert backbone.output_dim == neck.input_size
        assert neck.num_classes == head.num_classes
        assert neck.dim_capsule == head.dim_capsule
        assert neck.num_detections == head.num_detections
        assert neck.chunk_size == head.chunk_size

        self.chunk_size = neck.chunk_size
        self.receptive_field = neck.receptive_field
        self.framerate = neck.framerate

        self.model = ContextAwareModel(weights, backbone, neck, head, post_proc)

        self.overwrite = True
        self.cfg = cfg
        self.runner = runner
        self.infer_split = getattr(cfg, "infer_split", True)

    def process(self, labels, targets, feats):
        """Cast labels and targets to float and add a channel axis to the features."""
        labels = labels.float()
        targets = targets.float()
        feats = feats.unsqueeze(1)
        return labels, targets, feats

    def _common_step(self, batch, batch_idx):
        """Operations in common for training and validation steps.

        The batch is unpacked as ``(feats, labels, targets)``, processed, and
        the features are passed through the model. Returns the loss computed
        by ``self.criterion`` together with the batch size.
        """
        feats, labels, targets = batch
        labels, targets, feats = self.process(labels, targets, feats)
        output_segmentation, output_spotting = self.forward(feats)
        loss = self.criterion(
            [labels, targets], [output_segmentation, output_spotting]
        )
        return loss, feats.size(0)

    def training_step(self, batch, batch_idx):
        """Training step that defines the train loop."""
        loss, size = self._common_step(batch, batch_idx)
        self.log_dict({"loss": loss}, on_step=True, on_epoch=True, prog_bar=True)
        self.losses.update(loss.item(), size)
        return loss

    def validation_step(self, batch, batch_idx):
        """Validation step that defines the val loop."""
        val_loss, size = self._common_step(batch, batch_idx)
        self.log_dict(
            {"valid_loss": val_loss}, on_step=False, on_epoch=True, prog_bar=True
        )
        self.losses.update(val_loss.item(), size)
        return val_loss

    def on_predict_start(self):
        """Prepare the output locations and the accumulators before inference."""
        self.stop_predict = False

        if self.infer_split:
            self.output_folder, self.output_results, self.stop_predict = (
                check_if_should_predict(
                    self.cfg.dataset.test.results, self.cfg.work_dir, self.overwrite
                )
            )
            if self.runner == "runner_JSON":
                self.target_dir = os.path.join(self.cfg.work_dir, self.output_folder)
            else:
                self.target_dir = self.output_results

        if not self.stop_predict:
            self.spotting_predictions = list()
            self.spotting_grountruth = list()
            self.spotting_grountruth_visibility = list()
            self.segmentation_predictions = list()

    def _prediction_output_file(self, item_name):
        """Return the JSON path for one game/video, creating its folder when splitting."""
        if self.infer_split:
            item_dir = os.path.join(self.cfg.work_dir, self.output_folder, item_name)
            os.makedirs(item_dir, exist_ok=True)
            return os.path.join(item_dir, "results_spotting.json")
        return os.path.join(
            self.cfg.work_dir, f"{self.cfg.dataset.test.results}.json"
        )

    def on_predict_end(self):
        """Write the accumulated predictions to JSON after inference.

        With ``runner_CALF`` two consecutive prediction entries (the two
        halves) are written per game; with ``runner_JSON`` one entry is
        written per video. When ``infer_split`` is set the per-item files are
        zipped afterwards.
        """
        if self.stop_predict:
            return

        # Only the NMS-filtered detections are needed to write the results.
        detections_numpy = [
            NMS(detection.numpy(), 20 * self.framerate)
            for detection in self.spotting_predictions
        ]

        # Save the predictions to the json format.
        if self.runner == "runner_CALF":
            list_game = self.trainer.predict_dataloaders.dataset.listGames
            for index, game in enumerate(list_game):
                json_data = get_json_data(game)
                output_file = self._prediction_output_file(game)
                json_data = predictions2json(
                    detections_numpy[index * 2],
                    detections_numpy[(index * 2) + 1],
                    json_data,
                    output_file,
                    self.framerate,
                )
                self.json_data = json_data
        elif self.runner == "runner_JSON":
            dataset = self.trainer.predict_dataloaders.dataset
            list_videos = dataset.data_json[0]["videos"]
            for index, entry in enumerate(list_videos):
                video = entry["path"]
                if self.infer_split:
                    # Per-video folders are named after the path without extension.
                    video = os.path.splitext(video)[0]
                output_file = self._prediction_output_file(video)

                json_data = get_json_data(video)
                json_data = predictions2json_runnerjson(
                    detections_numpy[index],
                    json_data,
                    output_file,
                    self.framerate,
                    inverse_event_dictionary=dataset.inverse_event_dictionary,
                )
                self.json_data = json_data

        if self.infer_split:
            results_dir = os.path.join(self.cfg.work_dir, self.output_folder)
            zipResults(
                zip_path=self.output_results,
                target_dir=results_dir,
                filename="results_spotting.json",
            )
            logging.info("Predictions saved")
            logging.info(results_dir)
            logging.info("Predictions saved")
            logging.info(self.output_results)
        else:
            logging.info("Predictions saved")
            logging.info(
                os.path.join(
                    self.cfg.work_dir, f"{self.cfg.dataset.test.results}.json"
                )
            )

    def _predict_sequence(self, feats, labels):
        """Run the model on one sequence of chunks and store its stitched outputs."""
        labels = labels.float().squeeze(0)
        feats = feats.squeeze(0).unsqueeze(1)

        # Compute the output.
        output_segmentation, output_spotting = self.forward(feats)

        video_size = labels.size()[0]
        timestamp_long = timestamps2long(
            output_spotting.cpu().detach(),
            video_size,
            self.chunk_size,
            self.receptive_field,
        )
        segmentation_long = batch2long(
            output_segmentation.cpu().detach(),
            video_size,
            self.chunk_size,
            self.receptive_field,
        )

        self.spotting_grountruth.append(torch.abs(labels))
        self.spotting_grountruth_visibility.append(labels)
        self.spotting_predictions.append(timestamp_long)
        self.segmentation_predictions.append(segmentation_long)

    def predict_step(self, batch):
        """Infer step.

        The process differs depending on where the data come from: a
        ``runner_JSON`` batch is ``(features, labels)`` for one video, while a
        ``runner_CALF`` batch is ``(feat_half1, feat_half2, label_half1,
        label_half2)`` for the two halves of a game.
        """
        if self.stop_predict:
            return

        if self.runner == "runner_CALF":
            feat_half1, feat_half2, label_half1, label_half2 = batch
            sequences = ((feat_half1, label_half1), (feat_half2, label_half2))
        elif self.runner == "runner_JSON":
            features, labels = batch
            sequences = ((features, labels),)
        else:
            return

        # Halves are processed in order, so every accumulator receives
        # half 1 before half 2.
        for feats, labels in sequences:
            self._predict_sequence(feats, labels)
|