opensportslib 0.0.1.dev8__tar.gz → 0.0.1.dev10__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85) hide show
  1. {opensportslib-0.0.1.dev8/opensportslib.egg-info → opensportslib-0.0.1.dev10}/PKG-INFO +2 -1
  2. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/apis/localization.py +3 -3
  3. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/core/trainer/localization_trainer.py +34 -0
  4. opensportslib-0.0.1.dev10/opensportslib/core/utils/lightning.py +52 -0
  5. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/core/utils/video_processing.py +129 -6
  6. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/datasets/localization_dataset.py +9 -9
  7. opensportslib-0.0.1.dev10/opensportslib/models/base/contextaware.py +394 -0
  8. opensportslib-0.0.1.dev10/opensportslib/models/base/learnablepooling.py +360 -0
  9. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/models/builder.py +23 -1
  10. opensportslib-0.0.1.dev10/opensportslib/models/neck/builder.py +637 -0
  11. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/models/utils/litebase.py +3 -3
  12. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10/opensportslib.egg-info}/PKG-INFO +2 -1
  13. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib.egg-info/SOURCES.txt +3 -0
  14. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib.egg-info/requires.txt +1 -0
  15. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/pyproject.toml +2 -2
  16. opensportslib-0.0.1.dev8/opensportslib/models/neck/builder.py +0 -210
  17. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/LICENSE +0 -0
  18. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/LICENSE-COMMERCIAL +0 -0
  19. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/MANIFEST.in +0 -0
  20. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/README.md +0 -0
  21. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/examples/quickstart/basic_classification.py +0 -0
  22. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/examples/quickstart/basic_localization.py +0 -0
  23. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/__init__.py +0 -0
  24. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/apis/__init__.py +0 -0
  25. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/apis/classification.py +0 -0
  26. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/config/classification.yaml +0 -0
  27. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/config/classification_tracking.yaml +0 -0
  28. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/config/graph_tracking_classification/avgpool.yaml +0 -0
  29. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/config/graph_tracking_classification/gin.yaml +0 -0
  30. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/config/graph_tracking_classification/graphconv.yaml +0 -0
  31. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/config/graph_tracking_classification/graphsage.yaml +0 -0
  32. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/config/graph_tracking_classification/maxpool.yaml +0 -0
  33. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/config/graph_tracking_classification/noedges.yaml +0 -0
  34. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/config/localization.yaml +0 -0
  35. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/config/sngar_frames.yaml +0 -0
  36. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/core/__init__.py +0 -0
  37. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/core/loss/__init__.py +0 -0
  38. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/core/loss/builder.py +0 -0
  39. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/core/loss/calf.py +0 -0
  40. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/core/loss/ce.py +0 -0
  41. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/core/loss/combine.py +0 -0
  42. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/core/loss/nll.py +0 -0
  43. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/core/optimizer/__init__.py +0 -0
  44. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/core/optimizer/builder.py +0 -0
  45. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/core/sampler/weighted_sampler.py +0 -0
  46. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/core/scheduler/__init__.py +0 -0
  47. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/core/scheduler/builder.py +0 -0
  48. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/core/trainer/__init__.py +0 -0
  49. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/core/trainer/classification_trainer.py +0 -0
  50. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/core/utils/checkpoint.py +0 -0
  51. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/core/utils/config.py +0 -0
  52. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/core/utils/data.py +0 -0
  53. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/core/utils/ddp.py +0 -0
  54. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/core/utils/default_args.py +0 -0
  55. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/core/utils/load_annotations.py +0 -0
  56. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/core/utils/seed.py +0 -0
  57. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/core/utils/wandb.py +0 -0
  58. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/datasets/__init__.py +0 -0
  59. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/datasets/builder.py +0 -0
  60. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/datasets/classification_dataset.py +0 -0
  61. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/datasets/utils/__init__.py +0 -0
  62. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/datasets/utils/tracking.py +0 -0
  63. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/metrics/classification_metric.py +0 -0
  64. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/metrics/localization_metric.py +0 -0
  65. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/models/__init__.py +0 -0
  66. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/models/backbones/builder.py +0 -0
  67. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/models/base/e2e.py +0 -0
  68. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/models/base/tracking.py +0 -0
  69. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/models/base/vars.py +0 -0
  70. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/models/base/video.py +0 -0
  71. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/models/base/video_mae.py +0 -0
  72. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/models/heads/builder.py +0 -0
  73. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/models/utils/common.py +0 -0
  74. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/models/utils/impl/__init__.py +0 -0
  75. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/models/utils/impl/asformer.py +0 -0
  76. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/models/utils/impl/calf.py +0 -0
  77. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/models/utils/impl/gsm.py +0 -0
  78. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/models/utils/impl/gtad.py +0 -0
  79. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/models/utils/impl/tsm.py +0 -0
  80. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/models/utils/modules.py +0 -0
  81. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/models/utils/shift.py +0 -0
  82. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib/models/utils/utils.py +0 -0
  83. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib.egg-info/dependency_links.txt +0 -0
  84. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/opensportslib.egg-info/top_level.txt +0 -0
  85. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev10}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: opensportslib
3
- Version: 0.0.1.dev8
3
+ Version: 0.0.1.dev10
4
4
  Summary: OpenSportsLib is the professional library, designed for advanced video understanding in sports. It provides state-of-the-art tools for action recognition, spotting, retrieval, and captioning, making it ideal for researchers, analysts, and developers working with sports video data.
5
5
  Author: Jeet Vora
6
6
  Requires-Python: >=3.12
@@ -25,6 +25,7 @@ Provides-Extra: localization
25
25
  Requires-Dist: nvidia-dali-cuda120; extra == "localization"
26
26
  Requires-Dist: cupy-cuda12x; extra == "localization"
27
27
  Requires-Dist: tabulate; extra == "localization"
28
+ Requires-Dist: pytorch-lightning; extra == "localization"
28
29
  Provides-Extra: py-geometric
29
30
  Requires-Dist: torch-geometric; extra == "py-geometric"
30
31
  Requires-Dist: torch-scatter; extra == "py-geometric"
@@ -112,7 +112,7 @@ class LocalizationAPI:
112
112
  gpu=self.config.SYSTEM.GPU,
113
113
  default_args=data_obj_train.default_args,
114
114
  )
115
- train_loader = data_obj_train.building_dataloader(dataset_Train, cfg=data_obj_train.cfg.dataloader, gpu=self.config.SYSTEM.GPU, dali=True)
115
+ train_loader = data_obj_train.building_dataloader(dataset_Train, cfg=data_obj_train.cfg.dataloader, gpu=self.config.SYSTEM.GPU, dali=self.config.dali)
116
116
  print(len(train_loader))
117
117
  # Valid
118
118
  data_obj_valid = build_dataset(self.config,split="valid")
@@ -121,7 +121,7 @@ class LocalizationAPI:
121
121
  gpu= self.config.SYSTEM.GPU,
122
122
  default_args=data_obj_valid.default_args,
123
123
  )
124
- valid_loader = data_obj_valid.building_dataloader(dataset_Valid, cfg=data_obj_valid.cfg.dataloader, gpu=self.config.SYSTEM.GPU, dali=True)
124
+ valid_loader = data_obj_valid.building_dataloader(dataset_Valid, cfg=data_obj_valid.cfg.dataloader, gpu=self.config.SYSTEM.GPU, dali=self.config.dali)
125
125
  print(len(valid_loader))
126
126
 
127
127
  # Trainer
@@ -200,7 +200,7 @@ class LocalizationAPI:
200
200
  gpu=self.config.SYSTEM.GPU,
201
201
  default_args=data_obj_test.default_args,
202
202
  )
203
- test_loader = data_obj_test.building_dataloader(dataset_Test, cfg=data_obj_test.cfg.dataloader, gpu=self.config.SYSTEM.GPU, dali=True)
203
+ test_loader = data_obj_test.building_dataloader(dataset_Test, cfg=data_obj_test.cfg.dataloader, gpu=self.config.SYSTEM.GPU, dali=self.config.dali)
204
204
  print(len(test_loader))
205
205
 
206
206
  # # Inference
@@ -135,6 +135,9 @@ def build_trainer(cfg, model=None, default_args=None, resume_from=None):
135
135
  trainer.best_criterion_valid = checkpoint.get('best_criterion_valid',
136
136
  0 if cfg.TRAIN.criterion_valid == "map" else float("inf"))
137
137
  logging.info(f"Restored best epoch: {trainer.best_epoch}")
138
+
139
+ else:
140
+ trainer = Trainer_pl(cfg, default_args["work_dir"])
138
141
 
139
142
 
140
143
  return trainer
@@ -147,6 +150,37 @@ class Trainer(ABC):
147
150
  def train(self):
148
151
  pass
149
152
 
153
+ class Trainer_pl(Trainer):
154
+ """Trainer class used for models that rely on lightning modules.
155
+
156
+ Args:
157
+ cfg (dict): Dict config. It should contain the key 'max_epochs' and the key 'GPU'.
158
+ """
159
+
160
+ def __init__(self, cfg, work_dir):
161
+ from opensportslib.core.utils.lightning import CustomProgressBar, MyCallback
162
+ import pytorch_lightning as pl
163
+
164
+ self.work_dir = work_dir
165
+ call = MyCallback()
166
+ self.trainer = pl.Trainer(
167
+ max_epochs=cfg.max_epochs,
168
+ devices=[cfg.GPU],
169
+ callbacks=[call, CustomProgressBar(refresh_rate=1)],
170
+ num_sanity_val_steps=0,
171
+ )
172
+
173
+ def train(self, **kwargs):
174
+ self.trainer.fit(**kwargs)
175
+
176
+ best_model = kwargs["model"].best_state
177
+
178
+ logging.info("Done training")
179
+ logging.info("Best epoch: {}".format(best_model.get("epoch")))
180
+ torch.save(best_model, os.path.join(self.work_dir, "model.pth.tar"))
181
+
182
+ logging.info("Model saved")
183
+ logging.info(os.path.join(self.work_dir, "model.pth.tar"))
150
184
 
151
185
 
152
186
  class Trainer_e2e(Trainer):
@@ -0,0 +1,52 @@
1
+ import pytorch_lightning as pl
2
+ from pytorch_lightning.callbacks.progress import TQDMProgressBar
3
+ import logging
4
+
5
+
6
+ class CustomProgressBar(TQDMProgressBar):
7
+ """Override the custom progress bar used by pytorch lightning to change some attributes."""
8
+
9
+ def get_metrics(self, trainer, pl_module):
10
+ """Override the method to don't show the version number in the progress bar."""
11
+ items = super().get_metrics(trainer, pl_module)
12
+ items.pop("v_num", None)
13
+ return items
14
+
15
+
16
+ class MyCallback(pl.Callback):
17
+ """Override the Callback class of pl to change the behaviour on validation epoch end."""
18
+
19
+ def __init__(self):
20
+ super().__init__()
21
+
22
+ def on_validation_epoch_end(self, trainer, pl_module):
23
+ loss_validation = pl_module.losses.avg
24
+ state = {
25
+ "epoch": trainer.current_epoch + 1,
26
+ "state_dict": pl_module.model.state_dict(),
27
+ "best_loss": pl_module.best_loss,
28
+ "optimizer": pl_module.optimizer.state_dict(),
29
+ }
30
+
31
+ # remember best prec@1 and save checkpoint
32
+ is_better = loss_validation < pl_module.best_loss
33
+ pl_module.best_loss = min(loss_validation, pl_module.best_loss)
34
+
35
+ # Save the best model based on loss only if the evaluation frequency too long
36
+ if is_better:
37
+ pl_module.best_state = state
38
+ # torch.save(state, best_model_path)
39
+
40
+ # Reduce LR on Plateau after patience reached
41
+ prevLR = pl_module.optimizer.param_groups[0]["lr"]
42
+ pl_module.scheduler.step(loss_validation)
43
+ currLR = pl_module.optimizer.param_groups[0]["lr"]
44
+
45
+ if currLR is not prevLR and pl_module.scheduler.num_bad_epochs == 0:
46
+ logging.info("\nPlateau Reached!")
47
+ if (
48
+ prevLR < 2 * pl_module.scheduler.eps
49
+ and pl_module.scheduler.num_bad_epochs >= pl_module.scheduler.patience
50
+ ):
51
+ logging.info("\nPlateau Reached and no more reduction -> Exiting Loop")
52
+ trainer.should_stop = True
@@ -1,6 +1,10 @@
1
1
  import torch
2
2
  import numpy as np
3
3
  import math
4
+ import random
5
+ import torch.nn as nn
6
+ import torchvision.transforms as T
7
+ import torchvision.transforms.functional as F
4
8
 
5
9
  try:
6
10
  import decord
@@ -301,12 +305,6 @@ def get_remaining(data_len, batch_size):
301
305
  return (math.ceil(data_len / batch_size) * batch_size) - data_len
302
306
 
303
307
 
304
- import random
305
- import numpy as np
306
- import torch
307
- import torchvision.transforms as T
308
- import torchvision.transforms.functional as F
309
-
310
308
  class RandomHorizontalFlipFLow(nn.Module):
311
309
 
312
310
  def __init__(self, p=0.5):
@@ -719,6 +717,131 @@ def oneHotToShifts(onehot, params):
719
717
  Shifts[:, i] = shifts
720
718
 
721
719
  return Shifts
720
+
721
+ def timestamps2long(output_spotting, video_size, chunk_size, receptive_field):
722
+ """Method to transform the timestamps to vectors"""
723
+ start = 0
724
+ last = False
725
+ receptive_field = receptive_field // 2
726
+
727
+ timestamps_long = (
728
+ torch.zeros(
729
+ [video_size, output_spotting.size()[-1] - 2],
730
+ dtype=torch.float,
731
+ device=output_spotting.device,
732
+ )
733
+ - 1
734
+ )
735
+
736
+ for batch in np.arange(output_spotting.size()[0]):
737
+
738
+ tmp_timestamps = (
739
+ torch.zeros(
740
+ [chunk_size, output_spotting.size()[-1] - 2],
741
+ dtype=torch.float,
742
+ device=output_spotting.device,
743
+ )
744
+ - 1
745
+ )
746
+
747
+ for i in np.arange(output_spotting.size()[1]):
748
+ tmp_timestamps[
749
+ torch.floor(output_spotting[batch, i, 1] * (chunk_size - 1)).type(
750
+ torch.int
751
+ ),
752
+ torch.argmax(output_spotting[batch, i, 2:]).type(torch.int),
753
+ ] = output_spotting[batch, i, 0]
754
+
755
+ # ------------------------------------------
756
+ # Store the result of the chunk in the video
757
+ # ------------------------------------------
758
+
759
+ # For the first chunk
760
+ if start == 0:
761
+ timestamps_long[0 : chunk_size - receptive_field] = tmp_timestamps[
762
+ 0 : chunk_size - receptive_field
763
+ ]
764
+
765
+ # For the last chunk
766
+ elif last:
767
+ timestamps_long[start + receptive_field : start + chunk_size] = (
768
+ tmp_timestamps[receptive_field:]
769
+ )
770
+ break
771
+
772
+ # For every other chunk
773
+ else:
774
+ timestamps_long[
775
+ start + receptive_field : start + chunk_size - receptive_field
776
+ ] = tmp_timestamps[receptive_field : chunk_size - receptive_field]
777
+
778
+ # ---------------
779
+ # Loop Management
780
+ # ---------------
781
+
782
+ # Update the index
783
+ start += chunk_size - 2 * receptive_field
784
+ # Check if we are at the last index of the game
785
+ if start + chunk_size >= video_size:
786
+ start = video_size - chunk_size
787
+ last = True
788
+ return timestamps_long
789
+
790
+
791
+ def batch2long(output_segmentation, video_size, chunk_size, receptive_field):
792
+ """Method to transform the batches to vectors."""
793
+ start = 0
794
+ last = False
795
+ receptive_field = receptive_field // 2
796
+
797
+ segmentation_long = torch.zeros(
798
+ [video_size, output_segmentation.size()[-1]],
799
+ dtype=torch.float,
800
+ device=output_segmentation.device,
801
+ )
802
+
803
+ for batch in np.arange(output_segmentation.size()[0]):
804
+
805
+ tmp_segmentation = torch.nn.functional.one_hot(
806
+ torch.argmax(output_segmentation[batch], dim=-1),
807
+ num_classes=output_segmentation.size()[-1],
808
+ )
809
+
810
+ # ------------------------------------------
811
+ # Store the result of the chunk in the video
812
+ # ------------------------------------------
813
+
814
+ # For the first chunk
815
+ if start == 0:
816
+ segmentation_long[0 : chunk_size - receptive_field] = tmp_segmentation[
817
+ 0 : chunk_size - receptive_field
818
+ ]
819
+
820
+ # For the last chunk
821
+ elif last:
822
+ segmentation_long[start + receptive_field : start + chunk_size] = (
823
+ tmp_segmentation[receptive_field:]
824
+ )
825
+ break
826
+
827
+ # For every other chunk
828
+ else:
829
+ segmentation_long[
830
+ start + receptive_field : start + chunk_size - receptive_field
831
+ ] = tmp_segmentation[receptive_field : chunk_size - receptive_field]
832
+
833
+ # ---------------
834
+ # Loop Management
835
+ # ---------------
836
+
837
+ # Update the index
838
+ start += chunk_size - 2 * receptive_field
839
+ # Check if we are at the last index of the game
840
+ if start + chunk_size >= video_size:
841
+ start = video_size - chunk_size
842
+ last = True
843
+ return segmentation_long
844
+
722
845
  # import torch
723
846
  # import numpy as np
724
847
  # import decord
@@ -16,7 +16,7 @@ import logging
16
16
  import tqdm
17
17
  from opensportslib.core.utils.default_args import get_default_args_dataset
18
18
  from opensportslib.core.utils.load_annotations import get_repartition_gpu
19
- from opensportslib.opensportslib.core.utils.video_processing import feats2clip, getChunks_anchors, getTimestampTargets, oneHotToShifts
19
+ from opensportslib.core.utils.video_processing import feats2clip, getChunks_anchors, getTimestampTargets, oneHotToShifts
20
20
  from SoccerNet.Downloader import getListGames
21
21
  from SoccerNet.Downloader import SoccerNetDownloader
22
22
  from SoccerNet.Evaluation.utils import (
@@ -246,7 +246,7 @@ class LocalizationDataset(Dataset):
246
246
  num_workers=cfg.num_workers if gpu >= 0 else 0,
247
247
  pin_memory=cfg.pin_memory if gpu >= 0 else False,
248
248
  prefetch_factor=(
249
- cfg.prefetch_factor if "prefetch_factor" in cfg.keys() else None
249
+ getattr(cfg, "prefetch_factor", None)
250
250
  ),
251
251
  worker_init_fn=worker_init_fn
252
252
  )
@@ -457,7 +457,7 @@ class ActionSpotDataset(Dataset):
457
457
  from opensportslib.core.utils.video_processing import _get_deferred_rgb_transform, _get_img_transforms
458
458
 
459
459
  self._src_file = label_file
460
- self._labels = annotationstoe2eformat(
460
+ self._labels, self.task_name = annotationstoe2eformat(
461
461
  label_file, video_dir, input_fps, extract_fps, False
462
462
  )
463
463
  # self._labels = load_json(label_file)
@@ -491,17 +491,17 @@ class ActionSpotDataset(Dataset):
491
491
 
492
492
  self._mixup = mixup
493
493
 
494
+ self.IMAGENET_MEAN = IMAGENET_MEAN
495
+ self.IMAGENET_STD = IMAGENET_STD
496
+ self.TARGET_HEIGHT = TARGET_HEIGHT
497
+ self.TARGET_WIDTH = TARGET_WIDTH
494
498
  # Try to do defer the latter half of the transforms to the GPU
495
499
  self._gpu_transform = None
496
500
  if not is_eval and same_transform:
497
501
  if modality == "rgb":
498
502
  print("=> Deferring some RGB transforms to the GPU!")
499
- self._gpu_transform = _get_deferred_rgb_transform()
503
+ self._gpu_transform = _get_deferred_rgb_transform(self.IMAGENET_MEAN, self.IMAGENET_STD)
500
504
 
501
- self.IMAGENET_MEAN = IMAGENET_MEAN
502
- self.IMAGENET_STD = IMAGENET_STD
503
- self.TARGET_HEIGHT = TARGET_HEIGHT
504
- self.TARGET_WIDTH = TARGET_WIDTH
505
505
 
506
506
  crop_transform, img_transform = _get_img_transforms(
507
507
  self.IMAGENET_MEAN,
@@ -771,7 +771,7 @@ class ActionSpotVideoDataset(Dataset, DatasetVideoSharedMethods):
771
771
 
772
772
  self._src_file = label_file
773
773
  if label_file.endswith(".json"):
774
- self._labels = annotationstoe2eformat(
774
+ self._labels, self.task_name = annotationstoe2eformat(
775
775
  label_file, video_dir, input_fps, extract_fps, False
776
776
  )
777
777
  # self._labels = load_json(label_file)