opensportslib 0.0.1.dev15__tar.gz → 0.0.1.dev17__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86) hide show
  1. {opensportslib-0.0.1.dev15/opensportslib.egg-info → opensportslib-0.0.1.dev17}/PKG-INFO +2 -1
  2. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/apis/localization.py +17 -6
  3. opensportslib-0.0.1.dev17/opensportslib/config/localization-json_netvlad++_resnetpca512.yaml +145 -0
  4. opensportslib-0.0.1.dev15/opensportslib/config/sngar_frames.yaml → opensportslib-0.0.1.dev17/opensportslib/config/sngar-frames.yaml +22 -12
  5. opensportslib-0.0.1.dev15/opensportslib/config/classification_tracking.yaml → opensportslib-0.0.1.dev17/opensportslib/config/sngar-tracking.yaml +12 -3
  6. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/core/trainer/classification_trainer.py +1 -1
  7. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/core/trainer/localization_trainer.py +37 -20
  8. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/core/utils/load_annotations.py +72 -2
  9. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/core/utils/wandb.py +10 -0
  10. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/datasets/classification_dataset.py +24 -3
  11. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/datasets/localization_dataset.py +33 -29
  12. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/metrics/classification_metric.py +40 -23
  13. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/metrics/localization_metric.py +1 -1
  14. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/models/base/contextaware.py +12 -12
  15. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/models/base/learnablepooling.py +26 -26
  16. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/models/builder.py +4 -3
  17. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/models/utils/impl/gsm.py +10 -3
  18. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/models/utils/utils.py +27 -9
  19. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17/opensportslib.egg-info}/PKG-INFO +2 -1
  20. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib.egg-info/SOURCES.txt +3 -8
  21. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib.egg-info/requires.txt +1 -0
  22. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/pyproject.toml +2 -2
  23. opensportslib-0.0.1.dev15/opensportslib/config/graph_tracking_classification/avgpool.yaml +0 -79
  24. opensportslib-0.0.1.dev15/opensportslib/config/graph_tracking_classification/gin.yaml +0 -79
  25. opensportslib-0.0.1.dev15/opensportslib/config/graph_tracking_classification/graphconv.yaml +0 -79
  26. opensportslib-0.0.1.dev15/opensportslib/config/graph_tracking_classification/graphsage.yaml +0 -79
  27. opensportslib-0.0.1.dev15/opensportslib/config/graph_tracking_classification/maxpool.yaml +0 -79
  28. opensportslib-0.0.1.dev15/opensportslib/config/graph_tracking_classification/noedges.yaml +0 -79
  29. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/LICENSE +0 -0
  30. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/LICENSE-COMMERCIAL +0 -0
  31. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/MANIFEST.in +0 -0
  32. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/README.md +0 -0
  33. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/examples/quickstart/basic_classification.py +0 -0
  34. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/examples/quickstart/basic_localization.py +0 -0
  35. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/__init__.py +0 -0
  36. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/apis/__init__.py +0 -0
  37. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/apis/classification.py +0 -0
  38. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/config/classification.yaml +0 -0
  39. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/config/localization-e2e-ocv.yaml +0 -0
  40. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/config/localization.yaml +0 -0
  41. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/core/__init__.py +0 -0
  42. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/core/loss/__init__.py +0 -0
  43. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/core/loss/builder.py +0 -0
  44. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/core/loss/calf.py +0 -0
  45. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/core/loss/ce.py +0 -0
  46. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/core/loss/combine.py +0 -0
  47. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/core/loss/nll.py +0 -0
  48. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/core/optimizer/__init__.py +0 -0
  49. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/core/optimizer/builder.py +0 -0
  50. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/core/sampler/weighted_sampler.py +0 -0
  51. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/core/scheduler/__init__.py +0 -0
  52. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/core/scheduler/builder.py +0 -0
  53. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/core/trainer/__init__.py +0 -0
  54. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/core/utils/checkpoint.py +0 -0
  55. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/core/utils/config.py +0 -0
  56. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/core/utils/data.py +0 -0
  57. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/core/utils/ddp.py +0 -0
  58. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/core/utils/default_args.py +0 -0
  59. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/core/utils/lightning.py +0 -0
  60. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/core/utils/seed.py +0 -0
  61. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/core/utils/video_processing.py +0 -0
  62. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/datasets/__init__.py +0 -0
  63. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/datasets/builder.py +0 -0
  64. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/datasets/utils/__init__.py +0 -0
  65. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/datasets/utils/tracking.py +0 -0
  66. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/models/__init__.py +0 -0
  67. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/models/backbones/builder.py +0 -0
  68. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/models/base/e2e.py +0 -0
  69. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/models/base/tracking.py +0 -0
  70. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/models/base/vars.py +0 -0
  71. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/models/base/video.py +0 -0
  72. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/models/base/video_mae.py +0 -0
  73. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/models/heads/builder.py +0 -0
  74. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/models/neck/builder.py +0 -0
  75. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/models/utils/common.py +0 -0
  76. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/models/utils/impl/__init__.py +0 -0
  77. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/models/utils/impl/asformer.py +0 -0
  78. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/models/utils/impl/calf.py +0 -0
  79. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/models/utils/impl/gtad.py +0 -0
  80. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/models/utils/impl/tsm.py +0 -0
  81. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/models/utils/litebase.py +0 -0
  82. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/models/utils/modules.py +0 -0
  83. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib/models/utils/shift.py +0 -0
  84. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib.egg-info/dependency_links.txt +0 -0
  85. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/opensportslib.egg-info/top_level.txt +0 -0
  86. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev17}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: opensportslib
3
- Version: 0.0.1.dev15
3
+ Version: 0.0.1.dev17
4
4
  Summary: OpenSportsLib is the professional library, designed for advanced video understanding in sports. It provides state-of-the-art tools for action recognition, spotting, retrieval, and captioning, making it ideal for researchers, analysts, and developers working with sports video data.
5
5
  Author: Jeet Vora
6
6
  Requires-Python: >=3.12
@@ -21,6 +21,7 @@ Requires-Dist: wandb
21
21
  Requires-Dist: opencv-python
22
22
  Requires-Dist: omegaconf
23
23
  Requires-Dist: timm
24
+ Requires-Dist: seaborn
24
25
  Provides-Extra: localization
25
26
  Requires-Dist: nvidia-dali-cuda120; extra == "localization"
26
27
  Requires-Dist: cupy-cuda12x; extra == "localization"
@@ -101,7 +101,7 @@ class LocalizationAPI:
101
101
 
102
102
  device = select_device(self.config.SYSTEM)
103
103
  self.model = build_model(self.config, device=device)
104
- print(self.model)
104
+ print(f"model: {self.model}")
105
105
 
106
106
 
107
107
  # Datasets
@@ -155,7 +155,7 @@ class LocalizationAPI:
155
155
  from opensportslib.core.trainer.localization_trainer import build_inferer, build_evaluator
156
156
  from opensportslib.core.utils.config import select_device, resolve_config_omega, is_local_path
157
157
  from opensportslib.core.utils.checkpoint import load_checkpoint, localization_remap
158
- from opensportslib.core.utils.load_annotations import check_config, has_localization_events
158
+ from opensportslib.core.utils.load_annotations import check_config, has_localization_events, whether_infer_split
159
159
  from opensportslib.core.utils.wandb import init_wandb
160
160
  import time
161
161
 
@@ -163,6 +163,7 @@ class LocalizationAPI:
163
163
  self.config.MODEL.multi_gpu = False
164
164
  self.config = resolve_config_omega(self.config)
165
165
  check_config(self.config, split="test")
166
+ self.config.infer_split = whether_infer_split(self.config.DATA.test)
166
167
  init_wandb(self.config, run_id=os.environ["RUN_ID"], use_wandb=use_wandb)
167
168
  logging.info("Configuration:")
168
169
  logging.info(self.config)
@@ -179,19 +180,29 @@ class LocalizationAPI:
179
180
  logging.info("No predictions provided, running inference.")
180
181
  device = select_device(self.config.SYSTEM)
181
182
  self.model = build_model(self.config, device=device)
183
+ inner_model = getattr(self.model, "_model", None)
184
+ if inner_model is None:
185
+ inner_model = getattr(self.model, "model", self.model)
182
186
  print("Model type:", type(self.model))
183
- print("Torch model type:", type(self.model._model))
187
+ print("Torch model type:", type(inner_model))
184
188
  # Load model
185
189
  if pretrained:
186
190
  #pretrained = expand(pretrained)
187
191
  if is_local_path(pretrained):
188
192
  self.config.SYSTEM.work_dir = os.path.dirname(os.path.abspath(pretrained))
189
193
 
190
- self.model._model, _, _, epoch = load_checkpoint(model=self.model._model,
194
+ inner_model, _, _, epoch = load_checkpoint(model=inner_model,
191
195
  path=pretrained,
192
196
  device=device,
193
197
  key_remap_fn=localization_remap)
194
198
 
199
+ if hasattr(self.model, "_model"):
200
+ self.model._model = inner_model
201
+ elif hasattr(self.model, "model"):
202
+ self.model.model = inner_model
203
+ else:
204
+ self.model = inner_model
205
+
195
206
  # Datasets
196
207
  # Test
197
208
  data_obj_test = build_dataset(self.config, split="test")
@@ -206,7 +217,7 @@ class LocalizationAPI:
206
217
  # # Inference
207
218
  inferer = build_inferer(cfg=self.config.MODEL,
208
219
  model=self.model)
209
- json_gz_file = inferer.infer(cfg=self.config, data=dataset_Test)
220
+ json_gz_file = inferer.infer(cfg=self.config, data=dataset_Test, dataloader=test_loader)
210
221
 
211
222
  #json_gz_file = self.config.DATA.test.results + ".recall.json.gz"
212
223
  json_gz_file = predictions if predictions else json_gz_file
@@ -219,7 +230,7 @@ class LocalizationAPI:
219
230
  evaluator = build_evaluator(cfg=self.config)
220
231
  metrics = evaluator.evaluate(
221
232
  cfg_testset=self.config.DATA.test,
222
- json_gz_file=json_gz_file
233
+ json_gz_file=self.config.DATA.test.results if isinstance(json_gz_file, dict) else json_gz_file
223
234
  )
224
235
  else:
225
236
  logging.info("No labels found in annotation file → skipping evaluation")
@@ -0,0 +1,145 @@
1
+ TASK: localization
2
+
3
+ dali: false
4
+
5
+ DATA:
6
+ dataset_name: SoccerNet
7
+ data_dir: /home/vorajv/opensportslib/SoccerNet/
8
+ classes:
9
+ - Penalty
10
+ - Kick-off
11
+ - Goal
12
+ - Substitution
13
+ - Offside
14
+ - Shots on target
15
+ - Shots off target
16
+ - Clearance
17
+ - Ball out of play
18
+ - Throw-in
19
+ - Foul
20
+ - Indirect free-kick
21
+ - Direct free-kick
22
+ - Corner
23
+ - Yellow card
24
+ - Red card
25
+ - Yellow->red card
26
+
27
+ epoch_num_frames: 500000
28
+ mixup: true
29
+ modality: rgb
30
+ crop_dim: -1
31
+ dilate_len: 0 # Dilate ground truth labels
32
+ clip_len: 100
33
+ input_fps: 25
34
+ extract_fps: 2
35
+ imagenet_mean: [0.485, 0.456, 0.406]
36
+ imagenet_std: [0.229, 0.224, 0.225]
37
+ target_height: 224
38
+ target_width: 398
39
+
40
+ train:
41
+ type: FeatureClipsfromJSON
42
+ classes: ${DATA.classes}
43
+ output_map: [data, label]
44
+ video_path: ${DATA.data_dir}
45
+ path: ${DATA.train.video_path}/annotations-2024-224p-train.json
46
+ framerate: 2
47
+ window_size: 20
48
+ dataloader:
49
+ batch_size: 256
50
+ shuffle: true
51
+ num_workers: 4
52
+ pin_memory: true
53
+
54
+ valid:
55
+ type: FeatureClipsfromJSON
56
+ classes: ${DATA.classes}
57
+ output_map: [data, label]
58
+ video_path: ${DATA.data_dir}
59
+ path: ${DATA.valid.video_path}/annotations-2024-224p-valid.json
60
+ framerate: 2
61
+ window_size: 20
62
+ dataloader:
63
+ batch_size: 256
64
+ shuffle: true
65
+ num_workers: 4
66
+ pin_memory: true
67
+
68
+ test:
69
+ type: FeatureVideosfromJSON
70
+ classes: ${DATA.classes}
71
+ output_map: [data, label]
72
+ video_path: ${DATA.data_dir}
73
+ path: ${DATA.test.video_path}/annotations-2024-224p-test.json
74
+ results: results_spotting_test_netvlad++_resnetpca512
75
+ framerate: 2
76
+ window_size: 20
77
+ metric: tight
78
+ dataloader:
79
+ batch_size: 1
80
+ shuffle: false
81
+ num_workers: 1
82
+ pin_memory: true
83
+
84
+ MODEL:
85
+ type: LearnablePooling
86
+ runner:
87
+ type: runner_JSON
88
+ backbone:
89
+ type: PreExtactedFeatures
90
+ encoder: ResNET_TF2_PCA512
91
+ feature_dim: 512
92
+ output_dim: 512
93
+ framerate: 2
94
+ window_size: 20
95
+ neck:
96
+ type: NetVLAD++
97
+ input_dim: 512
98
+ output_dim: 32768 # input_dim (512) * vocab_size (64)
99
+ vocab_size: 64
100
+ head:
101
+ type: LinearLayer
102
+ input_dim: 32768
103
+ num_classes: 17
104
+ post_proc:
105
+ type: NMS
106
+ NMS_window: 30
107
+ NMS_threshold: 0.0
108
+ load_weights: null
109
+
110
+ TRAIN:
111
+ type: trainer_pooling
112
+ max_epochs: 1000
113
+ evaluation_frequency: 1000
114
+ framerate: 2
115
+ batch_size: 256
116
+
117
+ criterion:
118
+ type: NLLLoss
119
+
120
+ optimizer:
121
+ type: Adam
122
+ lr: 0.001
123
+ betas: [0.9, 0.999]
124
+ eps: 1e-08
125
+ weight_decay: 0
126
+ amsgrad: false
127
+
128
+ scheduler:
129
+ type: ReduceLROnPlateau
130
+ mode: min
131
+ factor: 1e-03
132
+ min_lr: 1e-06
133
+ patience: 10
134
+ verbose: true
135
+
136
+
137
+ SYSTEM:
138
+ log_dir: ./logs
139
+ save_dir: ./checkpoints
140
+ work_dir: ${SYSTEM.save_dir}
141
+ seed: 42
142
+ GPU: 4 # number of gpus to use
143
+ device: cuda # auto | cuda | cpu
144
+ gpu_id: 0 # device id for single gpu training
145
+
@@ -1,29 +1,36 @@
1
1
  TASK: classification
2
+ # this is the config for classification task on sngar-frames dataset.
3
+ # this config is used for the main experiments reported in the paper:
4
+ # "Pixels or Positions? Benchmarking Modalities in Group Activity Recognition"
5
+ # https://arxiv.org/abs/2511.12606
6
+ # videomaev2 - fully finetuned on the sngar-frames dataset.
7
+ # it contains all the hyperparameter values used to reproduce the results reported in the paper.
2
8
 
3
9
  DATA:
4
10
  dataset_name: sngar
5
11
  data_dir: /home/spark_user1/opensportslib/sngar-frames
6
12
  data_modality: frames_npy
7
- max_samples: 100
13
+ # max_samples: 100 # only used for quick testing
8
14
  num_frames: 16
9
15
  frame_size: [224, 224]
10
16
  train:
11
17
  path: ${DATA.data_dir}/annotations_train.json
12
18
  dataloader:
13
- batch_size: 64
19
+ batch_size: 8 # for frozen backbone, use 64
20
+ # for unfrozen backbone, use 32-16-8 depending on the memory available
14
21
  shuffle: true
15
22
  num_workers: 8
16
23
  pin_memory: true
17
24
  valid:
18
25
  path: ${DATA.data_dir}/annotations_valid.json
19
26
  dataloader:
20
- batch_size: 64
27
+ batch_size: 8
21
28
  num_workers: 8
22
29
  shuffle: false
23
30
  test:
24
31
  path: ${DATA.data_dir}/annotations_test.json
25
32
  dataloader:
26
- batch_size: 64
33
+ batch_size: 8
27
34
  num_workers: 8
28
35
  shuffle: false
29
36
  augmentations:
@@ -32,15 +39,18 @@ DATA:
32
39
  color_jitter: true
33
40
  jitter_prob: 0.5
34
41
  jitter_params: [0.2, 0.2, 0.2, 0.1]
42
+ data_slicing: # only used for data scaling experiments
43
+ enabled: false
44
+ training_matches: 45 # default: all 45 training matches
35
45
 
36
46
  MODEL:
37
47
  type: custom
38
48
  backbone:
39
- type: dinov3 # dinov3 | clip | videomae | videomae2
40
- pretrained_model: facebook/dinov3-vitb16-pretrain-lvd1689m
41
- # facebook/dinov3-vitb16-pretrain-lvd1689m | openai/clip-vit-base-patch16 | MCG-NJU/videomae-base | OpenGVLab/VideoMAEv2-Base
49
+ type: videomae2 # dinov3 | videomae | videomae2
50
+ pretrained_model: OpenGVLab/VideoMAEv2-Base
51
+ # facebook/dinov3-vitb16-pretrain-lvd1689m | MCG-NJU/videomae-base | OpenGVLab/VideoMAEv2-Base
42
52
  hidden_dim: 768
43
- freeze: true
53
+ freeze: false # true for frozen backbone, false for unfrozen backbone i.e. full-finetuning
44
54
  unfreeze_last_n_layers: 0 # 0 = frozen backbone, >0 = unfreeze last N layers
45
55
  neck:
46
56
  type: TemporalAggregation
@@ -56,12 +66,12 @@ MODEL:
56
66
  dropout: 0.1
57
67
 
58
68
  TRAIN:
59
- monitor: balanced_accuracy
60
- mode: max
69
+ monitor: loss # balanced_accuracy, loss
70
+ mode: min # max or min
61
71
  enabled: true
62
72
  use_amp: true
63
73
  mixup_alpha: 0.2
64
- use_weighted_sampler: false
74
+ use_weighted_sampler: true
65
75
  samples_per_class: 4000
66
76
  use_weighted_loss: false
67
77
  epochs: 100
@@ -75,7 +85,7 @@ TRAIN:
75
85
 
76
86
  optimizer:
77
87
  type: AdamW
78
- lr: 0.0001
88
+ lr: 0.00005 # tune lr based on the backbone
79
89
  betas: [0.9, 0.999]
80
90
  eps: 0.0000001
81
91
  weight_decay: 0.0001
@@ -1,9 +1,15 @@
1
1
  TASK: classification
2
+ # this is the config for classification task on sngar-tracking dataset.
3
+ # this config is used for the main experiments reported in the paper:
4
+ # "Pixels or Positions? Benchmarking Modalities in Group Activity Recognition"
5
+ # https://arxiv.org/abs/2511.12606
6
+ # this is used to train our baseline model that is GIN backbone + Maxpool temporal aggregation + positional edges.
7
+ # it contains all the hyperparameter values used to reproduce the results reported in the paper.
2
8
 
3
9
  DATA:
4
10
  dataset_name: sngar
5
11
  data_modality: tracking_parquet
6
- data_dir: /home/karkid/opensportslib/sngar-tracking
12
+ data_dir: /home/karkid/opensportslib/tracking-dataset
7
13
  preload_data: false
8
14
  train:
9
15
  type: annotations_train.json
@@ -43,12 +49,15 @@ DATA:
43
49
  pitch_half_width: 50.0
44
50
  max_displacement: 110.0
45
51
  max_ball_height: 30.0
52
+ data_slicing: # only used for data scaling experiments
53
+ enabled: false
54
+ training_matches: 45 # default: all 45 training matches
46
55
 
47
56
  MODEL:
48
57
  type: custom
49
58
  backbone:
50
59
  type: graph_conv
51
- encoder: graphconv
60
+ encoder: gin
52
61
  hidden_dim: 64
53
62
  num_layers: 20
54
63
  dropout: 0.1
@@ -74,7 +83,7 @@ TRAIN:
74
83
  use_weighted_sampler: true
75
84
  use_weighted_loss: false
76
85
  samples_per_class: 4000
77
- epochs: 10
86
+ epochs: 100
78
87
  patience: 10
79
88
  save_every: 20
80
89
  detailed_results: true
@@ -685,7 +685,7 @@ class FramesTrainerClassification(BaseTrainerClassification):
685
685
  self.scaler.step(self.optimizer)
686
686
  self.scaler.update()
687
687
 
688
- return logits, labels, loss
688
+ return logits, labels, loss, True
689
689
 
690
690
  # --------------------------------------------------------------
691
691
  # unified trainer dispatcher
@@ -164,11 +164,12 @@ class Trainer_pl(Trainer):
164
164
  self.work_dir = work_dir
165
165
  call = MyCallback()
166
166
  self.trainer = pl.Trainer(
167
- max_epochs=cfg.max_epochs,
168
- devices=[cfg.GPU],
167
+ max_epochs=cfg.TRAIN.max_epochs,
168
+ devices=cfg.SYSTEM.GPU,
169
169
  callbacks=[call, CustomProgressBar(refresh_rate=1)],
170
170
  num_sanity_val_steps=0,
171
171
  )
172
+ self.best_checkpoint_path = None
172
173
 
173
174
  def train(self, **kwargs):
174
175
  self.trainer.fit(**kwargs)
@@ -177,10 +178,12 @@ class Trainer_pl(Trainer):
177
178
 
178
179
  logging.info("Done training")
179
180
  logging.info("Best epoch: {}".format(best_model.get("epoch")))
180
- torch.save(best_model, os.path.join(self.work_dir, "model.pth.tar"))
181
+ best_path = os.path.join(self.work_dir, "model.pth.tar")
182
+ self.best_checkpoint_path = best_path
183
+ torch.save(best_model, best_path)
181
184
 
182
185
  logging.info("Model saved")
183
- logging.info(os.path.join(self.work_dir, "model.pth.tar"))
186
+ logging.info(best_path)
184
187
 
185
188
 
186
189
  class Trainer_e2e(Trainer):
@@ -496,24 +499,25 @@ class Inferer:
496
499
  self.model = model
497
500
  self.infer_Spotting=infer_Spotting
498
501
 
499
- def infer(self, cfg, data):
502
+ def infer(self, cfg, data, dataloader=None):
500
503
  """Infer actions from data.
501
504
 
502
505
  Args:
503
506
  data : The data from which we will infer.
507
+ dataloader : The dataloader for the test data.
504
508
 
505
509
  Returns:
506
510
  Dict containing predictions
507
511
  """
508
512
  if self.infer_Spotting=="infer_JSON":
509
- return self.infer_JSON(cfg, self.model, data)
513
+ return self.infer_JSON(cfg, self.model, data, dataloader)
510
514
  elif self.infer_Spotting=="infer_SN":
511
- return self.infer_SN(cfg, self.model, data)
515
+ return self.infer_SN(cfg, self.model, data, dataloader)
512
516
  elif self.infer_Spotting=="infer_E2E":
513
- return self.infer_E2E(cfg, self.model, data)
517
+ return self.infer_E2E(cfg, self.model, data, dataloader)
514
518
 
515
519
 
516
- def infer_common(self, cfg, model, data):
520
+ def infer_common(self, cfg, model, data, dataloader=None):
517
521
  """Infer actions from data using a given model.
518
522
 
519
523
  Args:
@@ -525,10 +529,21 @@ class Inferer:
525
529
  Dict containing predictions
526
530
  """
527
531
  # Run Inference on Dataset
528
- pass
532
+ from opensportslib.core.utils.lightning import CustomProgressBar, MyCallback
533
+ import pytorch_lightning as pl
534
+
535
+ if cfg.SYSTEM.work_dir is not None and dataloader is not None:
536
+
537
+ evaluator = pl.Trainer(
538
+ callbacks=[CustomProgressBar()],
539
+ devices=cfg.SYSTEM.GPU,
540
+ num_sanity_val_steps=0,
541
+ )
542
+ evaluator.predict(model, dataloader)
543
+ return model.json_data
529
544
 
530
545
 
531
- def infer_JSON(self, cfg, model, data):
546
+ def infer_JSON(self, cfg, model, data, dataloader=None):
532
547
  """Infer actions from data using a given model for NetVlad/CALF methods
533
548
 
534
549
  Args:
@@ -539,10 +554,10 @@ class Inferer:
539
554
  Returns:
540
555
  Dict containing predictions
541
556
  """
542
- return self.infer_common(cfg, model, data)
557
+ return self.infer_common(cfg, model, data, dataloader)
543
558
 
544
559
 
545
- def infer_SN(self, cfg, model, data):
560
+ def infer_SN(self, cfg, model, data, dataloader=None):
546
561
  """Infer actions from data using a given model for the SNV2 data
547
562
 
548
563
  Args:
@@ -553,10 +568,10 @@ class Inferer:
553
568
  Returns:
554
569
  Dict containing predictions
555
570
  """
556
- return self.infer_common(cfg, model, data)
571
+ return self.infer_common(cfg, model, data, dataloader)
557
572
 
558
573
 
559
- def infer_E2E(self, cfg, model, data):
574
+ def infer_E2E(self, cfg, model, data, dataloader=None):
560
575
  """Infer actions from data using a given model for the e2espot method.
561
576
 
562
577
  Args:
@@ -735,7 +750,6 @@ class Evaluator:
735
750
 
736
751
 
737
752
  def evaluate_common_JSON(self, cfg, results, metric):
738
-
739
753
  if cfg.path is None:
740
754
  return
741
755
 
@@ -756,6 +770,7 @@ class Evaluator:
756
770
 
757
771
  # detect v2 prediction
758
772
  pred_is_v2 = isinstance(pred_data, dict) and pred_data is not None and "data" in pred_data
773
+ print("PRED V2 :", pred_is_v2)
759
774
  # --------------------------------------------------
760
775
  # CLASSES
761
776
  # --------------------------------------------------
@@ -800,10 +815,11 @@ class Evaluator:
800
815
 
801
816
  # ---------------- GT ----------------
802
817
  if gt_is_v2:
818
+ print("Game: ", game)
803
819
  video_path = game["inputs"][0]["path"]
804
820
  labels = [{"label": e.get("label"),
805
821
  "gameTime": e.get("gameTime"),
806
- "position": int(e.get("position_ms")),
822
+ "position": int(e.get("position_ms", e.get("position"))),
807
823
  } for e in game.get("events", [])]
808
824
  else:
809
825
  video_path = game["path"]
@@ -825,7 +841,7 @@ class Evaluator:
825
841
  "label": e.get("label"),
826
842
  "gameTime": e.get("gameTime"),
827
843
  "confidence": e.get("confidence"),
828
- "position": int(e.get("position_ms")),
844
+ "position": int(e.get("position_ms", e.get("position"))),
829
845
  "frame": e.get("frame")
830
846
  }
831
847
  for e in item.get("events", [])
@@ -859,7 +875,7 @@ class Evaluator:
859
875
  "label": e.get("label"),
860
876
  "gameTime": e.get("gameTime"),
861
877
  "confidence": e.get("confidence"),
862
- "position": int(e.get("position_ms")),
878
+ "position": int(e.get("position_ms", e.get("position"))),
863
879
  "frame": e.get("frame")
864
880
  }
865
881
  for e in item.get("events", [])
@@ -997,7 +1013,8 @@ class Evaluator:
997
1013
  Returns
998
1014
  The different mAPs computed.
999
1015
  """
1000
-
1016
+ from SoccerNet.Evaluation.utils import INVERSE_EVENT_DICTIONARY_V2
1017
+ from SoccerNet.Evaluation.ActionSpotting import evaluate
1001
1018
  # challenge sets to be tested on EvalAI
1002
1019
  if "challenge" in cfg.split:
1003
1020
  print("Visit eval.ai to evaluate performances on Challenge set")
@@ -9,11 +9,51 @@ from opensportslib.core.utils.video_processing import get_stride, read_fps, get_
9
9
  from opensportslib.core.utils.config import load_json
10
10
  from collections import defaultdict
11
11
 
12
- def load_annotations(annotations_path, task_key="action", exclude_labels=[""], multiview=False, input_type="video", allow_missing_labels=False):
12
+ def load_annotations(
13
+ annotations_path,
14
+ task_key="action",
15
+ exclude_labels=None,
16
+ multiview=False,
17
+ input_type="video",
18
+ allow_missing_labels=False,
19
+ max_games=None
20
+ ):
13
21
 
14
22
  with open(annotations_path, "r") as f:
15
23
  data = json.load(f)
16
24
 
25
+ # this is used for data slicing experiments.
26
+ # and doesn't affect the validation and test sets.
27
+ # if you haven't added "data_slicing" to the config, this will be ignored.
28
+ if max_games is not None:
29
+ all_game_ids = sorted(set(
30
+ item.get("metadata", {}).get("game_id", "")
31
+ for item in data["data"]
32
+ ))
33
+
34
+ # remove empty string if any items lack game_id
35
+ all_game_ids = [g for g in all_game_ids if g]
36
+
37
+ # warn if any items lack game_id
38
+ items_without_game_id = sum(
39
+ 1 for item in data["data"]
40
+ if not item.get("metadata", {}).get("game_id", "")
41
+ )
42
+ if items_without_game_id > 0:
43
+ print(f"warning: {items_without_game_id}/{len(data['data'])} items have "
44
+ f"no game_id, these will not be affected by data slicing")
45
+
46
+ if max_games < len(all_game_ids):
47
+ keep_ids = set(all_game_ids[:max_games])
48
+ original_count = len(data["data"])
49
+ data["data"] = [
50
+ item for item in data["data"]
51
+ if item.get("metadata", {}).get("game_id", "") in keep_ids
52
+ ]
53
+ print(f"data slicing: {max_games}/{len(all_game_ids)} games, "
54
+ f"{len(data['data'])}/{original_count} samples retained")
55
+ # ----- data slicing ends here -----
56
+
17
57
  exclude_labels = set(exclude_labels or [""])
18
58
 
19
59
  # Label list for the selected task
@@ -496,4 +536,34 @@ def check_config(cfg, split="train"):
496
536
  classes = cfg.DATA.classes
497
537
 
498
538
  #print(classes)
499
- cfg.DATA.classes = load_classes(classes)
539
+ cfg.DATA.classes = load_classes(classes)
540
+
541
+
542
def whether_infer_split(cfg):
    """Tell whether inference targets a whole split or a single element.

    A single element can be a game, a video or a feature file.

    Args:
        cfg: Dataset config node. It must expose ``type``, plus ``split``
            for the SoccerNet dataset types or ``path`` for the
            JSON/video-based dataset types.

    Returns:
        bool: True if a split is inferred, False if a single element is.

    Raises:
        ValueError: If ``cfg.type`` is not one of the supported dataset types.
    """
    dataset_type = cfg.type

    # SoccerNet-style datasets: a split is inferred as soon as one is set.
    if dataset_type in ("SoccerNetGames", "SoccerNetClipsTestingCALF"):
        # Identity test: `== None` can be hijacked by a custom __eq__.
        return cfg.split is not None

    # Path-driven datasets: a ".json" path means a split is inferred;
    # any other path means a single element.
    if dataset_type in (
        "FeatureVideosfromJSON",
        "FeatureVideosChunksfromJson",
        "VideoGameWithOpencvVideo",
        "VideoGameWithDaliVideo",
    ):
        return cfg.path.endswith(".json")

    raise ValueError(f"Unknown dataset type {cfg.type}")
@@ -2,6 +2,7 @@ import wandb
2
2
  import matplotlib.pyplot as plt
3
3
  import numpy as np
4
4
  import logging
5
+ import os
5
6
 
6
7
  def init_wandb(cfg, run_id, use_wandb=False):
7
8
  """
@@ -24,6 +25,15 @@ def init_wandb(cfg, run_id, use_wandb=False):
24
25
  logging.warning("wandb not installed. Install with `pip install wandb`.")
25
26
  return None
26
27
 
28
+ # Prevent multiple processes from initializing wandb
29
+ rank = int(os.environ.get("RANK", os.environ.get("LOCAL_RANK", 0)))
30
+ if rank != 0:
31
+ return None
32
+
33
+ # Prevent re-initialization
34
+ if wandb.run is not None:
35
+ return wandb
36
+
27
37
  if getattr(cfg.DATA, "data_modality", None):
28
38
  run_name = f"{cfg.MODEL.backbone.type}_{cfg.DATA.data_modality}"
29
39
  else: