opensportslib 0.0.1.dev8__tar.gz → 0.0.1.dev9__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85) hide show
  1. {opensportslib-0.0.1.dev8/opensportslib.egg-info → opensportslib-0.0.1.dev9}/PKG-INFO +2 -1
  2. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/core/trainer/localization_trainer.py +34 -0
  3. opensportslib-0.0.1.dev9/opensportslib/core/utils/lightning.py +52 -0
  4. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/core/utils/video_processing.py +125 -0
  5. opensportslib-0.0.1.dev9/opensportslib/models/base/contextaware.py +394 -0
  6. opensportslib-0.0.1.dev9/opensportslib/models/base/learnablepooling.py +360 -0
  7. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/models/builder.py +23 -1
  8. opensportslib-0.0.1.dev9/opensportslib/models/neck/builder.py +637 -0
  9. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9/opensportslib.egg-info}/PKG-INFO +2 -1
  10. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib.egg-info/SOURCES.txt +3 -0
  11. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib.egg-info/requires.txt +1 -0
  12. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/pyproject.toml +2 -2
  13. opensportslib-0.0.1.dev8/opensportslib/models/neck/builder.py +0 -210
  14. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/LICENSE +0 -0
  15. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/LICENSE-COMMERCIAL +0 -0
  16. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/MANIFEST.in +0 -0
  17. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/README.md +0 -0
  18. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/examples/quickstart/basic_classification.py +0 -0
  19. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/examples/quickstart/basic_localization.py +0 -0
  20. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/__init__.py +0 -0
  21. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/apis/__init__.py +0 -0
  22. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/apis/classification.py +0 -0
  23. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/apis/localization.py +0 -0
  24. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/config/classification.yaml +0 -0
  25. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/config/classification_tracking.yaml +0 -0
  26. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/config/graph_tracking_classification/avgpool.yaml +0 -0
  27. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/config/graph_tracking_classification/gin.yaml +0 -0
  28. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/config/graph_tracking_classification/graphconv.yaml +0 -0
  29. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/config/graph_tracking_classification/graphsage.yaml +0 -0
  30. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/config/graph_tracking_classification/maxpool.yaml +0 -0
  31. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/config/graph_tracking_classification/noedges.yaml +0 -0
  32. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/config/localization.yaml +0 -0
  33. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/config/sngar_frames.yaml +0 -0
  34. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/core/__init__.py +0 -0
  35. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/core/loss/__init__.py +0 -0
  36. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/core/loss/builder.py +0 -0
  37. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/core/loss/calf.py +0 -0
  38. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/core/loss/ce.py +0 -0
  39. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/core/loss/combine.py +0 -0
  40. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/core/loss/nll.py +0 -0
  41. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/core/optimizer/__init__.py +0 -0
  42. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/core/optimizer/builder.py +0 -0
  43. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/core/sampler/weighted_sampler.py +0 -0
  44. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/core/scheduler/__init__.py +0 -0
  45. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/core/scheduler/builder.py +0 -0
  46. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/core/trainer/__init__.py +0 -0
  47. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/core/trainer/classification_trainer.py +0 -0
  48. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/core/utils/checkpoint.py +0 -0
  49. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/core/utils/config.py +0 -0
  50. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/core/utils/data.py +0 -0
  51. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/core/utils/ddp.py +0 -0
  52. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/core/utils/default_args.py +0 -0
  53. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/core/utils/load_annotations.py +0 -0
  54. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/core/utils/seed.py +0 -0
  55. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/core/utils/wandb.py +0 -0
  56. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/datasets/__init__.py +0 -0
  57. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/datasets/builder.py +0 -0
  58. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/datasets/classification_dataset.py +0 -0
  59. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/datasets/localization_dataset.py +0 -0
  60. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/datasets/utils/__init__.py +0 -0
  61. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/datasets/utils/tracking.py +0 -0
  62. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/metrics/classification_metric.py +0 -0
  63. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/metrics/localization_metric.py +0 -0
  64. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/models/__init__.py +0 -0
  65. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/models/backbones/builder.py +0 -0
  66. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/models/base/e2e.py +0 -0
  67. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/models/base/tracking.py +0 -0
  68. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/models/base/vars.py +0 -0
  69. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/models/base/video.py +0 -0
  70. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/models/base/video_mae.py +0 -0
  71. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/models/heads/builder.py +0 -0
  72. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/models/utils/common.py +0 -0
  73. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/models/utils/impl/__init__.py +0 -0
  74. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/models/utils/impl/asformer.py +0 -0
  75. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/models/utils/impl/calf.py +0 -0
  76. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/models/utils/impl/gsm.py +0 -0
  77. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/models/utils/impl/gtad.py +0 -0
  78. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/models/utils/impl/tsm.py +0 -0
  79. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/models/utils/litebase.py +0 -0
  80. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/models/utils/modules.py +0 -0
  81. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/models/utils/shift.py +0 -0
  82. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib/models/utils/utils.py +0 -0
  83. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib.egg-info/dependency_links.txt +0 -0
  84. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/opensportslib.egg-info/top_level.txt +0 -0
  85. {opensportslib-0.0.1.dev8 → opensportslib-0.0.1.dev9}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: opensportslib
3
- Version: 0.0.1.dev8
3
+ Version: 0.0.1.dev9
4
4
  Summary: OpenSportsLib is the professional library, designed for advanced video understanding in sports. It provides state-of-the-art tools for action recognition, spotting, retrieval, and captioning, making it ideal for researchers, analysts, and developers working with sports video data.
5
5
  Author: Jeet Vora
6
6
  Requires-Python: >=3.12
@@ -25,6 +25,7 @@ Provides-Extra: localization
25
25
  Requires-Dist: nvidia-dali-cuda120; extra == "localization"
26
26
  Requires-Dist: cupy-cuda12x; extra == "localization"
27
27
  Requires-Dist: tabulate; extra == "localization"
28
+ Requires-Dist: pytorch-lightning; extra == "localization"
28
29
  Provides-Extra: py-geometric
29
30
  Requires-Dist: torch-geometric; extra == "py-geometric"
30
31
  Requires-Dist: torch-scatter; extra == "py-geometric"
@@ -135,6 +135,9 @@ def build_trainer(cfg, model=None, default_args=None, resume_from=None):
135
135
  trainer.best_criterion_valid = checkpoint.get('best_criterion_valid',
136
136
  0 if cfg.TRAIN.criterion_valid == "map" else float("inf"))
137
137
  logging.info(f"Restored best epoch: {trainer.best_epoch}")
138
+
139
+ else:
140
+ trainer = Trainer_pl(cfg, default_args["work_dir"])
138
141
 
139
142
 
140
143
  return trainer
@@ -147,6 +150,37 @@ class Trainer(ABC):
147
150
  def train(self):
148
151
  pass
149
152
 
153
class Trainer_pl(Trainer):
    """Trainer for models implemented as PyTorch Lightning modules.

    Wraps a ``pytorch_lightning.Trainer`` configured with the project's custom
    progress bar and validation callback, then saves the best checkpoint
    recorded on the model to ``<work_dir>/model.pth.tar``.

    Args:
        cfg: Config object exposing the attributes ``max_epochs`` (maximum
            number of epochs) and ``GPU`` (device index passed to Lightning).
        work_dir (str): Directory where the best checkpoint is written.
    """

    def __init__(self, cfg, work_dir):
        # Imported lazily so pytorch-lightning is only needed when this
        # trainer is actually built.
        from opensportslib.core.utils.lightning import CustomProgressBar, MyCallback
        import pytorch_lightning as pl

        self.work_dir = work_dir
        self.trainer = pl.Trainer(
            max_epochs=cfg.max_epochs,
            devices=[cfg.GPU],
            callbacks=[MyCallback(), CustomProgressBar(refresh_rate=1)],
            num_sanity_val_steps=0,
        )

    def train(self, **kwargs):
        """Fit the model, then save its best recorded state.

        Args:
            **kwargs: Forwarded verbatim to ``pytorch_lightning.Trainer.fit``.
                Must contain ``model``, the Lightning module being trained;
                its ``best_state`` attribute (set by ``MyCallback`` when the
                validation loss improves) is what gets saved.
        """
        self.trainer.fit(**kwargs)
        logging.info("Done training")

        best_model = getattr(kwargs["model"], "best_state", None)
        if best_model is None:
            # best_state is only set at the end of a validation epoch that
            # improved the loss; without one there is nothing to save.
            logging.warning(
                "No best model state was recorded (no validation epoch improved "
                "the loss); no checkpoint saved"
            )
            return

        logging.info("Best epoch: {}".format(best_model.get("epoch")))
        model_path = os.path.join(self.work_dir, "model.pth.tar")
        torch.save(best_model, model_path)

        logging.info("Model saved")
        logging.info(model_path)
150
184
 
151
185
 
152
186
  class Trainer_e2e(Trainer):
@@ -0,0 +1,52 @@
1
+ import pytorch_lightning as pl
2
+ from pytorch_lightning.callbacks.progress import TQDMProgressBar
3
+ import logging
4
+
5
+
6
class CustomProgressBar(TQDMProgressBar):
    """TQDM progress bar that hides Lightning's ``v_num`` (logger version) entry."""

    def get_metrics(self, trainer, pl_module):
        """Return the parent's progress-bar metrics with ``v_num`` removed."""
        metrics = super().get_metrics(trainer, pl_module)
        if "v_num" in metrics:
            del metrics["v_num"]
        return metrics
14
+
15
+
16
class MyCallback(pl.Callback):
    """Validation-end callback: tracks the best loss and drives LR-on-plateau.

    At the end of each validation epoch it reads ``pl_module.losses.avg``,
    updates ``pl_module.best_loss``, stores a snapshot of the checkpoint in
    ``pl_module.best_state`` when the loss improves, steps
    ``pl_module.scheduler`` with the loss, and asks the trainer to stop once
    the learning rate cannot be reduced any further.
    """

    def __init__(self):
        super().__init__()

    def on_validation_epoch_end(self, trainer, pl_module):
        """Update best-model bookkeeping and the plateau scheduler.

        Args:
            trainer: The running ``pytorch_lightning.Trainer``.
            pl_module: The Lightning module; must expose ``losses`` (with
                ``avg``), ``best_loss``, ``model``, ``optimizer`` and
                ``scheduler``.
        """
        # Local import keeps this module's top-level imports unchanged.
        import copy

        # NOTE(review): training_step also updates `losses`; confirm the meter
        # is reset before validation so this is a validation-only average.
        loss_validation = pl_module.losses.avg

        # Update only on a strict improvement: min(loss, best) would let a NaN
        # loss replace best_loss and block every later improvement.
        if loss_validation < pl_module.best_loss:
            pl_module.best_loss = loss_validation
            # state_dict() returns references to the live parameter and
            # optimizer-state tensors, which later epochs update in place.
            # Deep-copy so best_state really holds this epoch's weights.
            pl_module.best_state = copy.deepcopy(
                {
                    "epoch": trainer.current_epoch + 1,
                    "state_dict": pl_module.model.state_dict(),
                    "best_loss": pl_module.best_loss,
                    "optimizer": pl_module.optimizer.state_dict(),
                }
            )

        # Reduce LR on plateau after patience is reached
        # (assumes a ReduceLROnPlateau-like scheduler: eps, patience, num_bad_epochs).
        optimizer = pl_module.optimizer
        scheduler = pl_module.scheduler
        prev_lr = optimizer.param_groups[0]["lr"]
        scheduler.step(loss_validation)
        curr_lr = optimizer.param_groups[0]["lr"]

        # Compare by value; `is not` on floats only tests object identity.
        if curr_lr != prev_lr and scheduler.num_bad_epochs == 0:
            logging.info("\nPlateau Reached!")
        if (
            prev_lr < 2 * scheduler.eps
            and scheduler.num_bad_epochs >= scheduler.patience
        ):
            logging.info("\nPlateau Reached and no more reduction -> Exiting Loop")
            trainer.should_stop = True
@@ -719,6 +719,131 @@ def oneHotToShifts(onehot, params):
719
719
  Shifts[:, i] = shifts
720
720
 
721
721
  return Shifts
722
+
723
def timestamps2long(output_spotting, video_size, chunk_size, receptive_field):
    """Stitch chunk-level spotting outputs into one frame-level score matrix.

    Chunks are visited in order. Half of ``receptive_field`` is dropped at every
    inner chunk border, and once the next chunk would reach the end of the
    video it is re-aligned to end exactly at ``video_size``.

    Args:
        output_spotting: Tensor indexed ``[chunk, detection, field]``. Field 0
            is the value written to the output (presumably the confidence).
            Field 1 is scaled by ``chunk_size - 1`` and floored to give the
            frame inside the chunk. The argmax over fields ``2:`` gives the
            class column.
        video_size (int): Number of rows (frames) of the result.
        chunk_size (int): Number of frames covered by one chunk.
        receptive_field (int): Full receptive field, in frames.

    Returns:
        Float tensor of shape ``(video_size, num_fields - 2)`` on the device of
        ``output_spotting``. Entries without a detection are -1.
    """
    half_rf = receptive_field // 2
    num_classes = output_spotting.size()[-1] - 2
    num_chunks = output_spotting.size()[0]

    def _minus_ones(rows):
        # -1 marks "no detection" for a frame/class cell.
        return (
            torch.zeros(
                [rows, num_classes],
                dtype=torch.float,
                device=output_spotting.device,
            )
            - 1
        )

    timestamps_long = _minus_ones(video_size)

    start, last = 0, False
    for chunk in range(num_chunks):
        chunk_scores = _minus_ones(chunk_size)
        for det in range(output_spotting.size()[1]):
            detection = output_spotting[chunk, det]
            frame = torch.floor(detection[1] * (chunk_size - 1)).type(torch.int)
            label = torch.argmax(detection[2:]).type(torch.int)
            chunk_scores[frame, label] = detection[0]

        if start == 0:
            # First chunk: keep everything but the trailing half receptive field.
            target = slice(0, chunk_size - half_rf)
            source = slice(0, chunk_size - half_rf)
        elif last:
            # Final chunk (aligned on the video end): drop only the leading half.
            target = slice(start + half_rf, start + chunk_size)
            source = slice(half_rf, None)
        else:
            # Inner chunk: drop half the receptive field on both sides.
            target = slice(start + half_rf, start + chunk_size - half_rf)
            source = slice(half_rf, chunk_size - half_rf)
        timestamps_long[target] = chunk_scores[source]

        if last and start != 0:
            break

        start += chunk_size - 2 * half_rf
        if start + chunk_size >= video_size:
            # Re-align the next chunk so that it ends at the last frame.
            start = video_size - chunk_size
            last = True
    return timestamps_long
791
+
792
+
793
def batch2long(output_segmentation, video_size, chunk_size, receptive_field):
    """Stitch chunk-level segmentation scores into one frame-level one-hot matrix.

    Each chunk is first converted to one-hot rows (argmax over the last axis).
    The chunks are then merged with the same overlap scheme as
    ``timestamps2long``: half of ``receptive_field`` is dropped at every inner
    chunk border, and the final chunk is aligned on the end of the video.

    Args:
        output_segmentation: Tensor indexed ``[chunk, frame, class]``.
        video_size (int): Number of rows (frames) of the result.
        chunk_size (int): Number of frames covered by one chunk.
        receptive_field (int): Full receptive field, in frames.

    Returns:
        Float tensor of shape ``(video_size, num_classes)`` on the device of
        ``output_segmentation``. Rows never covered stay all-zero.
    """
    half_rf = receptive_field // 2
    num_classes = output_segmentation.size()[-1]

    segmentation_long = torch.zeros(
        [video_size, num_classes],
        dtype=torch.float,
        device=output_segmentation.device,
    )

    start, last = 0, False
    for chunk in range(output_segmentation.size()[0]):
        best_class = torch.argmax(output_segmentation[chunk], dim=-1)
        chunk_one_hot = torch.nn.functional.one_hot(best_class, num_classes=num_classes)

        if start == 0:
            # First chunk: keep everything but the trailing half receptive field.
            target = slice(0, chunk_size - half_rf)
            source = slice(0, chunk_size - half_rf)
        elif last:
            # Final chunk (aligned on the video end): drop only the leading half.
            target = slice(start + half_rf, start + chunk_size)
            source = slice(half_rf, None)
        else:
            # Inner chunk: drop half the receptive field on both sides.
            target = slice(start + half_rf, start + chunk_size - half_rf)
            source = slice(half_rf, chunk_size - half_rf)
        segmentation_long[target] = chunk_one_hot[source]

        if last and start != 0:
            break

        start += chunk_size - 2 * half_rf
        if start + chunk_size >= video_size:
            # Re-align the next chunk so that it ends at the last frame.
            start = video_size - chunk_size
            last = True
    return segmentation_long
846
+
722
847
  # import torch
723
848
  # import numpy as np
724
849
  # import decord
@@ -0,0 +1,394 @@
1
+ import logging
2
+ import numpy as np
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from opensportslib.models.utils.litebase import LiteBaseModel
9
+
10
+ import os
11
+
12
+ from opensportslib.core.utils.video_processing import timestamps2long, batch2long
13
+
14
+ from opensportslib.models.utils.utils import (
15
+ NMS,
16
+ check_if_should_predict,
17
+ get_json_data,
18
+ predictions2json,
19
+ predictions2json_runnerjson,
20
+ zipResults,
21
+ )
22
+
23
+ from opensportslib.models.heads.builder import build_head
24
+ from opensportslib.models.backbones.builder import build_backbone
25
+ from opensportslib.models.neck.builder import build_neck
26
+
27
+
28
class ContextAwareModel(nn.Module):
    """CALF model composed of a backbone, a neck and a head.

    Args:
        weights (str, optional): Path of a checkpoint to load. It must contain
            the keys ``"state_dict"`` and ``"epoch"``.
        backbone: Backbone specification passed to ``build_backbone``.
        neck: Neck specification passed to ``build_neck``.
        head: Head specification passed to ``build_head``.
        post_proc (str): Accepted for interface compatibility; not used here.

    The model takes as input a Tensor of the form
    (batch_size, 1, chunk_size, input_size) and returns:
        1. The segmentation of the form (batch_size, chunk_size, num_classes).
        2. The action spotting of the form
           (batch_size, num_detections, 2 + num_classes).
    """

    def __init__(
        self,
        weights=None,
        backbone="PreExtracted",
        neck="CNN++",
        head="SpottingCALF",
        post_proc="NMS",
    ):
        super(ContextAwareModel, self).__init__()

        self.backbone = build_backbone(backbone)
        self.neck = build_neck(neck)
        self.head = build_head(head)

        # Load weights if a checkpoint path was given.
        self.load_weights(weights=weights)

    def load_weights(self, weights=None):
        """Load ``checkpoint["state_dict"]`` from the file ``weights``, if given.

        Args:
            weights (str, optional): Checkpoint path. ``None`` does nothing.
        """
        if weights is None:
            return
        logging.info("=> loading checkpoint '{}'".format(weights))
        # Load onto the CPU so checkpoints saved from a GPU also load on
        # CPU-only machines. load_state_dict then copies the values into the
        # existing parameters, whatever device they live on.
        checkpoint = torch.load(weights, map_location="cpu")
        self.load_state_dict(checkpoint["state_dict"])
        logging.info(
            "=> loaded checkpoint '{}' (epoch {})".format(weights, checkpoint["epoch"])
        )

    def forward(self, inputs):
        """Run backbone -> neck -> head.

        INPUT: a Tensor of the form (batch_size, 1, chunk_size, input_size).
        OUTPUTS: 1. The segmentation of the form (batch_size, chunk_size, num_classes).
                 2. The action spotting of the form
                    (batch_size, num_detections, 2 + num_classes).
        """
        features = self.backbone(inputs)
        conv_seg, output_segmentation = self.neck(features)
        output_spotting = self.head(conv_seg, output_segmentation)
        return output_segmentation, output_spotting
86
+
87
+
88
class LiteContextAwareModel(LiteBaseModel):
    """Lightning module wrapping the CALF model (``ContextAwareModel``).

    Args:
        cfg: Config object. It is required despite the ``None`` default.
            ``cfg.training`` is forwarded to ``LiteBaseModel``. During
            inference the module reads ``cfg.work_dir``,
            ``cfg.dataset.test.results`` and the optional ``cfg.infer_split``
            (defaults to True).
        weights (str, optional): Path of a checkpoint loaded by ContextAwareModel.
        backbone: Backbone config; must expose ``output_dim``.
        neck: Neck config; must expose ``input_size``, ``num_classes``,
            ``dim_capsule``, ``num_detections``, ``chunk_size``,
            ``receptive_field`` and ``framerate``.
        head: Head config; must expose ``num_classes``, ``dim_capsule``,
            ``num_detections`` and ``chunk_size`` matching the neck.
        post_proc (str): Forwarded to ContextAwareModel.
        runner (str): Selects how predictions are processed and saved.
            Use "runner_CALF" with the SoccerNet dataset modules (each
            prediction batch holds the two halves of one game). Use
            "runner_JSON" with the json format (one video per batch).

    Note that the string defaults of ``backbone``/``neck``/``head`` cannot be
    used as-is, because attributes such as ``output_dim`` are read from them.
    """

    def __init__(
        self,
        cfg=None,
        weights=None,
        backbone="PreExtracted",
        neck="CNN++",
        head="SpottingCALF",
        post_proc="NMS",
        runner="runner_CALF",
    ):
        super().__init__(cfg.training)

        # Check that backbone, neck and head dimensions are compatible.
        assert backbone.output_dim == neck.input_size, (
            "backbone.output_dim must equal neck.input_size"
        )
        assert neck.num_classes == head.num_classes, (
            "neck.num_classes must equal head.num_classes"
        )
        assert neck.dim_capsule == head.dim_capsule, (
            "neck.dim_capsule must equal head.dim_capsule"
        )
        assert neck.num_detections == head.num_detections, (
            "neck.num_detections must equal head.num_detections"
        )
        assert neck.chunk_size == head.chunk_size, (
            "neck.chunk_size must equal head.chunk_size"
        )

        self.chunk_size = neck.chunk_size
        self.receptive_field = neck.receptive_field
        self.framerate = neck.framerate

        self.model = ContextAwareModel(weights, backbone, neck, head, post_proc)

        # Passed to check_if_should_predict as its overwrite flag.
        self.overwrite = True
        self.cfg = cfg
        self.runner = runner
        self.infer_split = getattr(cfg, "infer_split", True)

    def process(self, labels, targets, feats):
        """Cast labels/targets to float and insert a singleton axis 1 in the features."""
        labels = labels.float()
        targets = targets.float()
        feats = feats.unsqueeze(1)
        return labels, targets, feats

    def _common_step(self, batch, batch_idx):
        """Operations shared by the training and validation steps.

        Processes the features, labels and targets and runs the model on the
        features. Returns the criterion's loss on the outputs and the batch size.
        """
        feats, labels, targets = batch
        labels, targets, feats = self.process(labels, targets, feats)
        output_segmentation, output_spotting = self.forward(feats)
        return self.criterion(
            [labels, targets], [output_segmentation, output_spotting]
        ), feats.size(0)

    def training_step(self, batch, batch_idx):
        """Training step that defines the train loop."""
        loss, size = self._common_step(batch, batch_idx)
        self.log_dict({"loss": loss}, on_step=True, on_epoch=True, prog_bar=True)
        self.losses.update(loss.item(), size)
        return loss

    def validation_step(self, batch, batch_idx):
        """Validation step that defines the val loop."""
        val_loss, size = self._common_step(batch, batch_idx)
        self.log_dict(
            {"valid_loss": val_loss}, on_step=False, on_epoch=True, prog_bar=True
        )
        self.losses.update(val_loss.item(), size)
        return val_loss

    def on_predict_start(self):
        """Prepare the output locations and the accumulators before inference."""
        self.stop_predict = False

        if self.infer_split:
            self.output_folder, self.output_results, self.stop_predict = (
                check_if_should_predict(
                    self.cfg.dataset.test.results, self.cfg.work_dir, self.overwrite
                )
            )
            if self.runner == "runner_JSON":
                self.target_dir = os.path.join(self.cfg.work_dir, self.output_folder)
            else:
                self.target_dir = self.output_results

        if not self.stop_predict:
            self.spotting_predictions = list()
            self.spotting_grountruth = list()
            self.spotting_grountruth_visibility = list()
            self.segmentation_predictions = list()

    def on_predict_end(self):
        """Apply NMS to the stitched predictions and write them as json.

        SoccerNet data (``runner_CALF``) yields one json per game, built from
        the predictions of its two halves. json data (``runner_JSON``) yields
        one json per video. With ``infer_split`` the per-item files are also
        zipped.
        """
        if self.stop_predict:
            return

        # NMS window of 20 * framerate frames (presumably 20 seconds).
        detections_numpy = [
            NMS(detection.numpy(), 20 * self.framerate)
            for detection in self.spotting_predictions
        ]

        if self.runner == "runner_CALF":
            self._save_predictions_calf(detections_numpy)
        elif self.runner == "runner_JSON":
            self._save_predictions_json(detections_numpy)

        self._finalize_predictions()

    def _single_output_file(self):
        """Path of the single results file used when ``infer_split`` is False."""
        return os.path.join(self.cfg.work_dir, f"{self.cfg.dataset.test.results}.json")

    def _split_output_file(self, name):
        """Create ``<work_dir>/<output_folder>/<name>`` and return its results_spotting.json path."""
        folder = os.path.join(self.cfg.work_dir, self.output_folder, name)
        os.makedirs(folder, exist_ok=True)
        return os.path.join(folder, "results_spotting.json")

    def _save_predictions_calf(self, detections_numpy):
        """Write one json per SoccerNet game from its two halves' detections."""
        list_game = self.trainer.predict_dataloaders.dataset.listGames
        for index, game in enumerate(list_game):
            json_data = get_json_data(game)
            if self.infer_split:
                output_file = self._split_output_file(game)
            else:
                # NOTE(review): every game targets the same file here, so later
                # games presumably overwrite earlier ones; confirm this mode is
                # only used for single-game inference.
                output_file = self._single_output_file()
            # predict_step appends the halves in order:
            # [game0_half1, game0_half2, game1_half1, ...].
            self.json_data = predictions2json(
                detections_numpy[2 * index],
                detections_numpy[2 * index + 1],
                json_data,
                output_file,
                self.framerate,
            )

    def _save_predictions_json(self, detections_numpy):
        """Write one json per video listed in the json dataset."""
        dataset = self.trainer.predict_dataloaders.dataset
        list_videos = dataset.data_json[0]["videos"]
        for index, video_entry in enumerate(list_videos):
            video = video_entry["path"]
            if self.infer_split:
                video = os.path.splitext(video)[0]
                # NOTE(review): if `video` is an absolute path, os.path.join
                # discards the work_dir prefix; confirm paths are relative.
                output_file = self._split_output_file(video)
            else:
                output_file = self._single_output_file()

            json_data = get_json_data(video)
            self.json_data = predictions2json_runnerjson(
                detections_numpy[index],
                json_data,
                output_file,
                self.framerate,
                inverse_event_dictionary=dataset.inverse_event_dictionary,
            )

    def _finalize_predictions(self):
        """Zip the per-item results (``infer_split`` only) and log where they were saved."""
        if self.infer_split:
            output_dir = os.path.join(self.cfg.work_dir, self.output_folder)
            zipResults(
                zip_path=self.output_results,
                target_dir=output_dir,
                filename="results_spotting.json",
            )
            logging.info("Predictions saved")
            logging.info(output_dir)
            logging.info("Predictions saved")
            logging.info(self.output_results)
        else:
            logging.info("Predictions saved")
            logging.info(self._single_output_file())

    def _infer_sequence(self, features, labels):
        """Run the model on the chunks of one video (or half) and stitch the outputs.

        Returns:
            Tuple ``(labels, timestamps_long, segmentation_long)``. ``labels``
            is cast to float with its leading axis squeezed. The two other
            items are frame-level tensors covering ``labels.size()[0]`` frames.
        """
        labels = labels.float().squeeze(0)
        features = features.squeeze(0).unsqueeze(1)

        output_segmentation, output_spotting = self.forward(features)

        num_frames = labels.size()[0]
        timestamp_long = timestamps2long(
            output_spotting.cpu().detach(),
            num_frames,
            self.chunk_size,
            self.receptive_field,
        )
        segmentation_long = batch2long(
            output_segmentation.cpu().detach(),
            num_frames,
            self.chunk_size,
            self.receptive_field,
        )
        return labels, timestamp_long, segmentation_long

    def predict_step(self, batch):
        """Inference step.

        json data (``runner_JSON``) holds one video's features per batch.
        SoccerNet data (``runner_CALF``) holds the two halves of one game.
        Each sequence is inferred and stitched, and the results are stored for
        ``on_predict_end``.
        """
        if self.stop_predict:
            return

        if self.runner == "runner_CALF":
            feat_half1, feat_half2, label_half1, label_half2 = batch
            sequences = [(feat_half1, label_half1), (feat_half2, label_half2)]
        elif self.runner == "runner_JSON":
            features, labels = batch
            sequences = [(features, labels)]
        else:
            return

        # Infer everything before appending, so the accumulators stay aligned
        # if the model raises on a later sequence.
        results = [self._infer_sequence(feats, labs) for feats, labs in sequences]
        for labels, timestamp_long, segmentation_long in results:
            self.spotting_grountruth.append(torch.abs(labels))
            self.spotting_grountruth_visibility.append(labels)
            self.spotting_predictions.append(timestamp_long)
            self.segmentation_predictions.append(segmentation_long)