opensportslib 0.0.1.dev8__tar.gz → 0.0.1.dev10__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {opensportslib-0.0.1.dev8/opensportslib.egg-info → opensportslib-0.0.1.dev10}/PKG-INFO +2 -1
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/apis/localization.py +3 -3
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/core/trainer/localization_trainer.py +34 -0
- opensportslib-0.0.1.dev10/opensportslib/core/utils/lightning.py +52 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/core/utils/video_processing.py +129 -6
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/datasets/localization_dataset.py +9 -9
- opensportslib-0.0.1.dev10/opensportslib/models/base/contextaware.py +394 -0
- opensportslib-0.0.1.dev10/opensportslib/models/base/learnablepooling.py +360 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/models/builder.py +23 -1
- opensportslib-0.0.1.dev10/opensportslib/models/neck/builder.py +637 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/models/utils/litebase.py +3 -3
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10/opensportslib.egg-info}/PKG-INFO +2 -1
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib.egg-info/SOURCES.txt +3 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib.egg-info/requires.txt +1 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/pyproject.toml +2 -2
- opensportslib-0.0.1.dev8/opensportslib/models/neck/builder.py +0 -210
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/LICENSE +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/LICENSE-COMMERCIAL +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/MANIFEST.in +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/README.md +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/examples/quickstart/basic_classification.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/examples/quickstart/basic_localization.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/__init__.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/apis/__init__.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/apis/classification.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/config/classification.yaml +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/config/classification_tracking.yaml +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/config/graph_tracking_classification/avgpool.yaml +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/config/graph_tracking_classification/gin.yaml +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/config/graph_tracking_classification/graphconv.yaml +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/config/graph_tracking_classification/graphsage.yaml +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/config/graph_tracking_classification/maxpool.yaml +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/config/graph_tracking_classification/noedges.yaml +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/config/localization.yaml +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/config/sngar_frames.yaml +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/core/__init__.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/core/loss/__init__.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/core/loss/builder.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/core/loss/calf.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/core/loss/ce.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/core/loss/combine.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/core/loss/nll.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/core/optimizer/__init__.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/core/optimizer/builder.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/core/sampler/weighted_sampler.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/core/scheduler/__init__.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/core/scheduler/builder.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/core/trainer/__init__.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/core/trainer/classification_trainer.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/core/utils/checkpoint.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/core/utils/config.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/core/utils/data.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/core/utils/ddp.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/core/utils/default_args.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/core/utils/load_annotations.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/core/utils/seed.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/core/utils/wandb.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/datasets/__init__.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/datasets/builder.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/datasets/classification_dataset.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/datasets/utils/__init__.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/datasets/utils/tracking.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/metrics/classification_metric.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/metrics/localization_metric.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/models/__init__.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/models/backbones/builder.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/models/base/e2e.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/models/base/tracking.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/models/base/vars.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/models/base/video.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/models/base/video_mae.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/models/heads/builder.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/models/utils/common.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/models/utils/impl/__init__.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/models/utils/impl/asformer.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/models/utils/impl/calf.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/models/utils/impl/gsm.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/models/utils/impl/gtad.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/models/utils/impl/tsm.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/models/utils/modules.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/models/utils/shift.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/models/utils/utils.py +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib.egg-info/dependency_links.txt +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib.egg-info/top_level.txt +0 -0
- {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/setup.cfg +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: opensportslib
|
|
3
|
-
Version: 0.0.1.dev8
|
|
3
|
+
Version: 0.0.1.dev10
|
|
4
4
|
Summary: OpenSportsLib is the professional library, designed for advanced video understanding in sports. It provides state-of-the-art tools for action recognition, spotting, retrieval, and captioning, making it ideal for researchers, analysts, and developers working with sports video data.
|
|
5
5
|
Author: Jeet Vora
|
|
6
6
|
Requires-Python: >=3.12
|
|
@@ -25,6 +25,7 @@ Provides-Extra: localization
|
|
|
25
25
|
Requires-Dist: nvidia-dali-cuda120; extra == "localization"
|
|
26
26
|
Requires-Dist: cupy-cuda12x; extra == "localization"
|
|
27
27
|
Requires-Dist: tabulate; extra == "localization"
|
|
28
|
+
Requires-Dist: pytorch-lightning; extra == "localization"
|
|
28
29
|
Provides-Extra: py-geometric
|
|
29
30
|
Requires-Dist: torch-geometric; extra == "py-geometric"
|
|
30
31
|
Requires-Dist: torch-scatter; extra == "py-geometric"
|
|
@@ -112,7 +112,7 @@ class LocalizationAPI:
|
|
|
112
112
|
gpu=self.config.SYSTEM.GPU,
|
|
113
113
|
default_args=data_obj_train.default_args,
|
|
114
114
|
)
|
|
115
|
-
train_loader = data_obj_train.building_dataloader(dataset_Train, cfg=data_obj_train.cfg.dataloader, gpu=self.config.SYSTEM.GPU, dali=
|
|
115
|
+
train_loader = data_obj_train.building_dataloader(dataset_Train, cfg=data_obj_train.cfg.dataloader, gpu=self.config.SYSTEM.GPU, dali=self.config.dali)
|
|
116
116
|
print(len(train_loader))
|
|
117
117
|
# Valid
|
|
118
118
|
data_obj_valid = build_dataset(self.config,split="valid")
|
|
@@ -121,7 +121,7 @@ class LocalizationAPI:
|
|
|
121
121
|
gpu= self.config.SYSTEM.GPU,
|
|
122
122
|
default_args=data_obj_valid.default_args,
|
|
123
123
|
)
|
|
124
|
-
valid_loader = data_obj_valid.building_dataloader(dataset_Valid, cfg=data_obj_valid.cfg.dataloader, gpu=self.config.SYSTEM.GPU, dali=
|
|
124
|
+
valid_loader = data_obj_valid.building_dataloader(dataset_Valid, cfg=data_obj_valid.cfg.dataloader, gpu=self.config.SYSTEM.GPU, dali=self.config.dali)
|
|
125
125
|
print(len(valid_loader))
|
|
126
126
|
|
|
127
127
|
# Trainer
|
|
@@ -200,7 +200,7 @@ class LocalizationAPI:
|
|
|
200
200
|
gpu=self.config.SYSTEM.GPU,
|
|
201
201
|
default_args=data_obj_test.default_args,
|
|
202
202
|
)
|
|
203
|
-
test_loader = data_obj_test.building_dataloader(dataset_Test, cfg=data_obj_test.cfg.dataloader, gpu=self.config.SYSTEM.GPU, dali=
|
|
203
|
+
test_loader = data_obj_test.building_dataloader(dataset_Test, cfg=data_obj_test.cfg.dataloader, gpu=self.config.SYSTEM.GPU, dali=self.config.dali)
|
|
204
204
|
print(len(test_loader))
|
|
205
205
|
|
|
206
206
|
# # Inference
|
|
@@ -135,6 +135,9 @@ def build_trainer(cfg, model=None, default_args=None, resume_from=None):
|
|
|
135
135
|
trainer.best_criterion_valid = checkpoint.get('best_criterion_valid',
|
|
136
136
|
0 if cfg.TRAIN.criterion_valid == "map" else float("inf"))
|
|
137
137
|
logging.info(f"Restored best epoch: {trainer.best_epoch}")
|
|
138
|
+
|
|
139
|
+
else:
|
|
140
|
+
trainer = Trainer_pl(cfg, default_args["work_dir"])
|
|
138
141
|
|
|
139
142
|
|
|
140
143
|
return trainer
|
|
@@ -147,6 +150,37 @@ class Trainer(ABC):
|
|
|
147
150
|
def train(self):
    """Train the model; this base implementation is a no-op returning None."""
|
|
149
152
|
|
|
153
|
+
class Trainer_pl(Trainer):
    """Trainer for models implemented as PyTorch Lightning modules.

    Wraps a ``pytorch_lightning.Trainer`` configured with ``MyCallback``
    (which records ``best_state`` on the module) and ``CustomProgressBar``.
    After training, it saves that best state to ``<work_dir>/model.pth.tar``.

    Args:
        cfg: Config object read through attribute access. It must expose
            ``max_epochs`` and ``GPU``, the device index. A negative (or None)
            ``GPU`` runs on CPU, following the ``gpu >= 0`` convention used
            by the package's dataloaders.
        work_dir (str): Directory where ``model.pth.tar`` is written.
    """

    def __init__(self, cfg, work_dir):
        # Imported lazily: pytorch-lightning is an optional dependency
        # (part of the "localization" extra).
        from opensportslib.core.utils.lightning import CustomProgressBar, MyCallback
        import pytorch_lightning as pl

        self.work_dir = work_dir

        gpu = cfg.GPU
        if gpu is not None and gpu >= 0:
            device_kwargs = {"devices": [gpu]}
        else:
            # A negative index is not a valid device id for Lightning; fall
            # back to a single CPU device instead of failing at construction.
            device_kwargs = {"accelerator": "cpu", "devices": 1}

        self.trainer = pl.Trainer(
            max_epochs=cfg.max_epochs,
            callbacks=[MyCallback(), CustomProgressBar(refresh_rate=1)],
            num_sanity_val_steps=0,
            **device_kwargs,
        )

    def train(self, **kwargs):
        """Run ``Trainer.fit`` and save the best state recorded during training.

        Args:
            **kwargs: Forwarded unchanged to ``pytorch_lightning.Trainer.fit``.
                Must contain ``model``, the LightningModule whose
                ``best_state`` attribute is filled by ``MyCallback``.
        """
        self.trainer.fit(**kwargs)

        # MyCallback only sets best_state when a validation epoch improves on
        # the module's best loss, so it may be missing (e.g. no validation data).
        best_model = getattr(kwargs["model"], "best_state", None)

        logging.info("Done training")
        if best_model is None:
            logging.warning("No best model state was recorded; nothing saved")
            return

        logging.info("Best epoch: {}".format(best_model.get("epoch")))
        save_path = os.path.join(self.work_dir, "model.pth.tar")
        torch.save(best_model, save_path)

        logging.info("Model saved")
        logging.info(save_path)
|
|
150
184
|
|
|
151
185
|
|
|
152
186
|
class Trainer_e2e(Trainer):
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
import pytorch_lightning as pl
|
|
2
|
+
from pytorch_lightning.callbacks.progress import TQDMProgressBar
|
|
3
|
+
import logging
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class CustomProgressBar(TQDMProgressBar):
    """TQDM progress bar for PyTorch Lightning without the run-version entry."""

    def get_metrics(self, trainer, pl_module):
        """Return the default progress-bar metrics with the ``v_num`` key removed."""
        metrics = super().get_metrics(trainer, pl_module)
        if "v_num" in metrics:
            del metrics["v_num"]
        return metrics
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class MyCallback(pl.Callback):
    """Lightning callback that tracks the best validation loss and drives the LR scheduler.

    At the end of each validation epoch it:
    - reads ``pl_module.losses.avg``;
    - snapshots the model and optimizer state into ``pl_module.best_state``
      whenever that loss is lower than ``pl_module.best_loss``;
    - steps ``pl_module.scheduler`` with the loss;
    - asks the trainer to stop once the learning rate can no longer be reduced.
    """

    def __init__(self):
        super().__init__()

    def on_validation_epoch_end(self, trainer, pl_module):
        """Update the best state and the LR scheduler from the validation loss.

        Args:
            trainer: The running ``pytorch_lightning.Trainer``.
            pl_module: The LightningModule. It must expose ``losses`` (with an
                ``avg`` attribute), ``best_loss``, ``model``, ``optimizer`` and
                ``scheduler``. The scheduler is assumed to be
                ReduceLROnPlateau-like (``eps``, ``patience``,
                ``num_bad_epochs``); TODO confirm.
        """
        import copy

        loss_validation = pl_module.losses.avg

        # Remember the best state (lowest validation loss).
        is_better = loss_validation < pl_module.best_loss
        if is_better:
            # state_dict() returns references to the live tensors; deep-copy
            # them so later training steps do not overwrite the "best" snapshot.
            pl_module.best_state = {
                "epoch": trainer.current_epoch + 1,
                "state_dict": copy.deepcopy(pl_module.model.state_dict()),
                # Best loss as it was before this epoch's update.
                "best_loss": pl_module.best_loss,
                "optimizer": copy.deepcopy(pl_module.optimizer.state_dict()),
            }
        pl_module.best_loss = min(loss_validation, pl_module.best_loss)

        # Reduce LR on plateau after patience is reached.
        scheduler = pl_module.scheduler
        prev_lr = pl_module.optimizer.param_groups[0]["lr"]
        scheduler.step(loss_validation)
        curr_lr = pl_module.optimizer.param_groups[0]["lr"]

        # Compare by value: `is not` on floats tests object identity only.
        if curr_lr != prev_lr and scheduler.num_bad_epochs == 0:
            logging.info("\nPlateau Reached!")
        if prev_lr < 2 * scheduler.eps and scheduler.num_bad_epochs >= scheduler.patience:
            logging.info("\nPlateau Reached and no more reduction -> Exiting Loop")
            trainer.should_stop = True
|
{opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/core/utils/video_processing.py
RENAMED
|
@@ -1,6 +1,10 @@
|
|
|
1
1
|
import torch
|
|
2
2
|
import numpy as np
|
|
3
3
|
import math
|
|
4
|
+
import random
|
|
5
|
+
import torch.nn as nn
|
|
6
|
+
import torchvision.transforms as T
|
|
7
|
+
import torchvision.transforms.functional as F
|
|
4
8
|
|
|
5
9
|
try:
|
|
6
10
|
import decord
|
|
@@ -301,12 +305,6 @@ def get_remaining(data_len, batch_size):
|
|
|
301
305
|
return (math.ceil(data_len / batch_size) * batch_size) - data_len
|
|
302
306
|
|
|
303
307
|
|
|
304
|
-
import random
|
|
305
|
-
import numpy as np
|
|
306
|
-
import torch
|
|
307
|
-
import torchvision.transforms as T
|
|
308
|
-
import torchvision.transforms.functional as F
|
|
309
|
-
|
|
310
308
|
class RandomHorizontalFlipFLow(nn.Module):
|
|
311
309
|
|
|
312
310
|
def __init__(self, p=0.5):
|
|
@@ -719,6 +717,131 @@ def oneHotToShifts(onehot, params):
|
|
|
719
717
|
Shifts[:, i] = shifts
|
|
720
718
|
|
|
721
719
|
return Shifts
|
|
720
|
+
|
|
721
|
+
def _stitch_chunks(chunks, target, video_size, chunk_size, receptive_field):
    """Write overlapping per-chunk predictions into ``target`` in place and return it.

    Consecutive chunks are assumed to start ``chunk_size - 2 * (receptive_field // 2)``
    frames apart, with the final chunk aligned to the end of the video. This
    mirrors how ``start`` is advanced below. Each chunk contributes only its
    central part; the first chunk also keeps its leading border, and the last
    chunk keeps its trailing border.

    Args:
        chunks: Iterable of per-chunk tensors whose first dimension is ``chunk_size``.
        target: Tensor with ``video_size`` rows along dim 0 to fill.
        video_size (int): Number of frames in the video.
        chunk_size (int): Number of frames per chunk.
        receptive_field (int): Border width; half of it is discarded on each
            inner side of a chunk.

    Returns:
        ``target``, filled in place.
    """
    half_field = receptive_field // 2
    start = 0
    last = False

    for chunk in chunks:
        # Handle the last chunk before the first-chunk case. When the video is
        # exactly one chunk long, the last chunk also starts at 0 and would
        # otherwise re-enter the first-chunk branch, leaving its tail unwritten.
        if last:
            target[start + half_field : start + chunk_size] = chunk[half_field:]
            break

        if start == 0:
            target[0 : chunk_size - half_field] = chunk[0 : chunk_size - half_field]
        else:
            target[start + half_field : start + chunk_size - half_field] = chunk[
                half_field : chunk_size - half_field
            ]

        # Advance to the next chunk; clamp the final one to the end of the video.
        start += chunk_size - 2 * half_field
        if start + chunk_size >= video_size:
            start = video_size - chunk_size
            last = True

    return target


def timestamps2long(output_spotting, video_size, chunk_size, receptive_field):
    """Transform per-chunk spotting predictions into one per-frame tensor.

    Args:
        output_spotting: Tensor indexed as ``[chunk, detection, k]``:
            - ``k == 0`` is the value written (presumably a confidence);
            - ``k == 1`` is the relative position within the chunk, scaled
              by ``chunk_size - 1`` to a frame index;
            - ``k >= 2`` are class scores, reduced with argmax.
        video_size (int): Number of frames in the video.
        chunk_size (int): Number of frames per chunk.
        receptive_field (int): Border width trimmed when stitching chunks.

    Returns:
        Float tensor of shape ``(video_size, num_classes)`` with each
        detection's value at its frame/class and ``-1`` elsewhere.
    """
    num_classes = output_spotting.size()[-1] - 2
    device = output_spotting.device
    timestamps_long = (
        torch.zeros([video_size, num_classes], dtype=torch.float, device=device) - 1
    )

    def chunk_predictions():
        for batch in range(output_spotting.size()[0]):
            tmp_timestamps = (
                torch.zeros([chunk_size, num_classes], dtype=torch.float, device=device) - 1
            )
            # Sequential on purpose: if two detections land on the same cell,
            # the later one wins deterministically.
            for i in range(output_spotting.size()[1]):
                frame = torch.floor(output_spotting[batch, i, 1] * (chunk_size - 1)).type(
                    torch.int
                )
                label = torch.argmax(output_spotting[batch, i, 2:]).type(torch.int)
                tmp_timestamps[frame, label] = output_spotting[batch, i, 0]
            yield tmp_timestamps

    return _stitch_chunks(
        chunk_predictions(), timestamps_long, video_size, chunk_size, receptive_field
    )


def batch2long(output_segmentation, video_size, chunk_size, receptive_field):
    """Transform per-chunk segmentation scores into one per-frame one-hot tensor.

    Args:
        output_segmentation: Tensor indexed as ``[chunk, frame, class]``; each
            frame is reduced to the one-hot encoding of its argmax class.
        video_size (int): Number of frames in the video.
        chunk_size (int): Number of frames per chunk.
        receptive_field (int): Border width trimmed when stitching chunks.

    Returns:
        Float tensor of shape ``(video_size, num_classes)``. Frames not
        covered by any chunk stay all-zero.
    """
    num_classes = output_segmentation.size()[-1]
    segmentation_long = torch.zeros(
        [video_size, num_classes],
        dtype=torch.float,
        device=output_segmentation.device,
    )

    chunk_predictions = (
        torch.nn.functional.one_hot(
            torch.argmax(output_segmentation[batch], dim=-1),
            num_classes=num_classes,
        )
        for batch in range(output_segmentation.size()[0])
    )

    return _stitch_chunks(
        chunk_predictions, segmentation_long, video_size, chunk_size, receptive_field
    )
|
|
844
|
+
|
|
722
845
|
# import torch
|
|
723
846
|
# import numpy as np
|
|
724
847
|
# import decord
|
|
@@ -16,7 +16,7 @@ import logging
|
|
|
16
16
|
import tqdm
|
|
17
17
|
from opensportslib.core.utils.default_args import get_default_args_dataset
|
|
18
18
|
from opensportslib.core.utils.load_annotations import get_repartition_gpu
|
|
19
|
-
from opensportslib.
|
|
19
|
+
from opensportslib.core.utils.video_processing import feats2clip, getChunks_anchors, getTimestampTargets, oneHotToShifts
|
|
20
20
|
from SoccerNet.Downloader import getListGames
|
|
21
21
|
from SoccerNet.Downloader import SoccerNetDownloader
|
|
22
22
|
from SoccerNet.Evaluation.utils import (
|
|
@@ -246,7 +246,7 @@ class LocalizationDataset(Dataset):
|
|
|
246
246
|
num_workers=cfg.num_workers if gpu >= 0 else 0,
|
|
247
247
|
pin_memory=cfg.pin_memory if gpu >= 0 else False,
|
|
248
248
|
prefetch_factor=(
|
|
249
|
-
cfg
|
|
249
|
+
getattr(cfg, "prefetch_factor", None)
|
|
250
250
|
),
|
|
251
251
|
worker_init_fn=worker_init_fn
|
|
252
252
|
)
|
|
@@ -457,7 +457,7 @@ class ActionSpotDataset(Dataset):
|
|
|
457
457
|
from opensportslib.core.utils.video_processing import _get_deferred_rgb_transform, _get_img_transforms
|
|
458
458
|
|
|
459
459
|
self._src_file = label_file
|
|
460
|
-
self._labels = annotationstoe2eformat(
|
|
460
|
+
self._labels, self.task_name = annotationstoe2eformat(
|
|
461
461
|
label_file, video_dir, input_fps, extract_fps, False
|
|
462
462
|
)
|
|
463
463
|
# self._labels = load_json(label_file)
|
|
@@ -491,17 +491,17 @@ class ActionSpotDataset(Dataset):
|
|
|
491
491
|
|
|
492
492
|
self._mixup = mixup
|
|
493
493
|
|
|
494
|
+
self.IMAGENET_MEAN = IMAGENET_MEAN
|
|
495
|
+
self.IMAGENET_STD = IMAGENET_STD
|
|
496
|
+
self.TARGET_HEIGHT = TARGET_HEIGHT
|
|
497
|
+
self.TARGET_WIDTH = TARGET_WIDTH
|
|
494
498
|
# Try to do defer the latter half of the transforms to the GPU
|
|
495
499
|
self._gpu_transform = None
|
|
496
500
|
if not is_eval and same_transform:
|
|
497
501
|
if modality == "rgb":
|
|
498
502
|
print("=> Deferring some RGB transforms to the GPU!")
|
|
499
|
-
self._gpu_transform = _get_deferred_rgb_transform()
|
|
503
|
+
self._gpu_transform = _get_deferred_rgb_transform(self.IMAGENET_MEAN, self.IMAGENET_STD)
|
|
500
504
|
|
|
501
|
-
self.IMAGENET_MEAN = IMAGENET_MEAN
|
|
502
|
-
self.IMAGENET_STD = IMAGENET_STD
|
|
503
|
-
self.TARGET_HEIGHT = TARGET_HEIGHT
|
|
504
|
-
self.TARGET_WIDTH = TARGET_WIDTH
|
|
505
505
|
|
|
506
506
|
crop_transform, img_transform = _get_img_transforms(
|
|
507
507
|
self.IMAGENET_MEAN,
|
|
@@ -771,7 +771,7 @@ class ActionSpotVideoDataset(Dataset, DatasetVideoSharedMethods):
|
|
|
771
771
|
|
|
772
772
|
self._src_file = label_file
|
|
773
773
|
if label_file.endswith(".json"):
|
|
774
|
-
self._labels = annotationstoe2eformat(
|
|
774
|
+
self._labels, self.task_name = annotationstoe2eformat(
|
|
775
775
|
label_file, video_dir, input_fps, extract_fps, False
|
|
776
776
|
)
|
|
777
777
|
# self._labels = load_json(label_file)
|