opensportslib 0.0.1.dev15__tar.gz → 0.0.1.dev16__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85) hide show
  1. {opensportslib-0.0.1.dev15/opensportslib.egg-info → opensportslib-0.0.1.dev16}/PKG-INFO +2 -1
  2. opensportslib-0.0.1.dev15/opensportslib/config/sngar_frames.yaml → opensportslib-0.0.1.dev16/opensportslib/config/sngar-frames.yaml +22 -12
  3. opensportslib-0.0.1.dev15/opensportslib/config/classification_tracking.yaml → opensportslib-0.0.1.dev16/opensportslib/config/sngar-tracking.yaml +12 -3
  4. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/core/trainer/classification_trainer.py +1 -1
  5. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/core/utils/load_annotations.py +41 -1
  6. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/datasets/classification_dataset.py +24 -3
  7. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/metrics/classification_metric.py +40 -23
  8. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16/opensportslib.egg-info}/PKG-INFO +2 -1
  9. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib.egg-info/SOURCES.txt +2 -8
  10. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib.egg-info/requires.txt +1 -0
  11. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/pyproject.toml +2 -2
  12. opensportslib-0.0.1.dev15/opensportslib/config/graph_tracking_classification/avgpool.yaml +0 -79
  13. opensportslib-0.0.1.dev15/opensportslib/config/graph_tracking_classification/gin.yaml +0 -79
  14. opensportslib-0.0.1.dev15/opensportslib/config/graph_tracking_classification/graphconv.yaml +0 -79
  15. opensportslib-0.0.1.dev15/opensportslib/config/graph_tracking_classification/graphsage.yaml +0 -79
  16. opensportslib-0.0.1.dev15/opensportslib/config/graph_tracking_classification/maxpool.yaml +0 -79
  17. opensportslib-0.0.1.dev15/opensportslib/config/graph_tracking_classification/noedges.yaml +0 -79
  18. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/LICENSE +0 -0
  19. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/LICENSE-COMMERCIAL +0 -0
  20. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/MANIFEST.in +0 -0
  21. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/README.md +0 -0
  22. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/examples/quickstart/basic_classification.py +0 -0
  23. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/examples/quickstart/basic_localization.py +0 -0
  24. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/__init__.py +0 -0
  25. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/apis/__init__.py +0 -0
  26. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/apis/classification.py +0 -0
  27. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/apis/localization.py +0 -0
  28. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/config/classification.yaml +0 -0
  29. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/config/localization-e2e-ocv.yaml +0 -0
  30. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/config/localization.yaml +0 -0
  31. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/core/__init__.py +0 -0
  32. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/core/loss/__init__.py +0 -0
  33. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/core/loss/builder.py +0 -0
  34. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/core/loss/calf.py +0 -0
  35. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/core/loss/ce.py +0 -0
  36. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/core/loss/combine.py +0 -0
  37. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/core/loss/nll.py +0 -0
  38. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/core/optimizer/__init__.py +0 -0
  39. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/core/optimizer/builder.py +0 -0
  40. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/core/sampler/weighted_sampler.py +0 -0
  41. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/core/scheduler/__init__.py +0 -0
  42. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/core/scheduler/builder.py +0 -0
  43. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/core/trainer/__init__.py +0 -0
  44. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/core/trainer/localization_trainer.py +0 -0
  45. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/core/utils/checkpoint.py +0 -0
  46. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/core/utils/config.py +0 -0
  47. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/core/utils/data.py +0 -0
  48. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/core/utils/ddp.py +0 -0
  49. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/core/utils/default_args.py +0 -0
  50. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/core/utils/lightning.py +0 -0
  51. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/core/utils/seed.py +0 -0
  52. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/core/utils/video_processing.py +0 -0
  53. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/core/utils/wandb.py +0 -0
  54. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/datasets/__init__.py +0 -0
  55. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/datasets/builder.py +0 -0
  56. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/datasets/localization_dataset.py +0 -0
  57. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/datasets/utils/__init__.py +0 -0
  58. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/datasets/utils/tracking.py +0 -0
  59. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/metrics/localization_metric.py +0 -0
  60. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/models/__init__.py +0 -0
  61. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/models/backbones/builder.py +0 -0
  62. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/models/base/contextaware.py +0 -0
  63. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/models/base/e2e.py +0 -0
  64. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/models/base/learnablepooling.py +0 -0
  65. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/models/base/tracking.py +0 -0
  66. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/models/base/vars.py +0 -0
  67. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/models/base/video.py +0 -0
  68. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/models/base/video_mae.py +0 -0
  69. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/models/builder.py +0 -0
  70. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/models/heads/builder.py +0 -0
  71. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/models/neck/builder.py +0 -0
  72. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/models/utils/common.py +0 -0
  73. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/models/utils/impl/__init__.py +0 -0
  74. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/models/utils/impl/asformer.py +0 -0
  75. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/models/utils/impl/calf.py +0 -0
  76. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/models/utils/impl/gsm.py +0 -0
  77. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/models/utils/impl/gtad.py +0 -0
  78. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/models/utils/impl/tsm.py +0 -0
  79. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/models/utils/litebase.py +0 -0
  80. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/models/utils/modules.py +0 -0
  81. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/models/utils/shift.py +0 -0
  82. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib/models/utils/utils.py +0 -0
  83. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib.egg-info/dependency_links.txt +0 -0
  84. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/opensportslib.egg-info/top_level.txt +0 -0
  85. {opensportslib-0.0.1.dev15 → opensportslib-0.0.1.dev16}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: opensportslib
3
- Version: 0.0.1.dev15
3
+ Version: 0.0.1.dev16
4
4
  Summary: OpenSportsLib is the professional library, designed for advanced video understanding in sports. It provides state-of-the-art tools for action recognition, spotting, retrieval, and captioning, making it ideal for researchers, analysts, and developers working with sports video data.
5
5
  Author: Jeet Vora
6
6
  Requires-Python: >=3.12
@@ -21,6 +21,7 @@ Requires-Dist: wandb
21
21
  Requires-Dist: opencv-python
22
22
  Requires-Dist: omegaconf
23
23
  Requires-Dist: timm
24
+ Requires-Dist: seaborn
24
25
  Provides-Extra: localization
25
26
  Requires-Dist: nvidia-dali-cuda120; extra == "localization"
26
27
  Requires-Dist: cupy-cuda12x; extra == "localization"
@@ -1,29 +1,36 @@
1
1
  TASK: classification
2
+ # this is the config for the classification task on the sngar-frames dataset.
3
+ # this config is used for the main experiments reported in the paper:
4
+ # "Pixels or Positions? Benchmarking Modalities in Group Activity Recognition"
5
+ # https://arxiv.org/abs/2511.12606
6
+ # videomaev2 - fully finetuned on the sngar-frames dataset.
7
+ # it contains all the hyperparameter values used to reproduce the results reported in the paper.
2
8
 
3
9
  DATA:
4
10
  dataset_name: sngar
5
11
  data_dir: /home/spark_user1/opensportslib/sngar-frames
6
12
  data_modality: frames_npy
7
- max_samples: 100
13
+ # max_samples: 100 # only used for quick testing
8
14
  num_frames: 16
9
15
  frame_size: [224, 224]
10
16
  train:
11
17
  path: ${DATA.data_dir}/annotations_train.json
12
18
  dataloader:
13
- batch_size: 64
19
+ batch_size: 8 # for frozen backbone, use 64
20
+ # for unfrozen backbone, use 32-16-8 depending on the memory available
14
21
  shuffle: true
15
22
  num_workers: 8
16
23
  pin_memory: true
17
24
  valid:
18
25
  path: ${DATA.data_dir}/annotations_valid.json
19
26
  dataloader:
20
- batch_size: 64
27
+ batch_size: 8
21
28
  num_workers: 8
22
29
  shuffle: false
23
30
  test:
24
31
  path: ${DATA.data_dir}/annotations_test.json
25
32
  dataloader:
26
- batch_size: 64
33
+ batch_size: 8
27
34
  num_workers: 8
28
35
  shuffle: false
29
36
  augmentations:
@@ -32,15 +39,18 @@ DATA:
32
39
  color_jitter: true
33
40
  jitter_prob: 0.5
34
41
  jitter_params: [0.2, 0.2, 0.2, 0.1]
42
+ data_slicing: # only used for data scaling experiments
43
+ enabled: false
44
+ training_matches: 45 # default: all 45 training matches
35
45
 
36
46
  MODEL:
37
47
  type: custom
38
48
  backbone:
39
- type: dinov3 # dinov3 | clip | videomae | videomae2
40
- pretrained_model: facebook/dinov3-vitb16-pretrain-lvd1689m
41
- # facebook/dinov3-vitb16-pretrain-lvd1689m | openai/clip-vit-base-patch16 | MCG-NJU/videomae-base | OpenGVLab/VideoMAEv2-Base
49
+ type: videomae2 # dinov3 | videomae | videomae2
50
+ pretrained_model: OpenGVLab/VideoMAEv2-Base
51
+ # facebook/dinov3-vitb16-pretrain-lvd1689m | MCG-NJU/videomae-base | OpenGVLab/VideoMAEv2-Base
42
52
  hidden_dim: 768
43
- freeze: true
53
+ freeze: false # true for frozen backbone, false for unfrozen backbone i.e. full-finetuning
44
54
  unfreeze_last_n_layers: 0 # 0 = frozen backbone, >0 = unfreeze last N layers
45
55
  neck:
46
56
  type: TemporalAggregation
@@ -56,12 +66,12 @@ MODEL:
56
66
  dropout: 0.1
57
67
 
58
68
  TRAIN:
59
- monitor: balanced_accuracy
60
- mode: max
69
+ monitor: loss # balanced_accuracy, loss
70
+ mode: min # max or min
61
71
  enabled: true
62
72
  use_amp: true
63
73
  mixup_alpha: 0.2
64
- use_weighted_sampler: false
74
+ use_weighted_sampler: true
65
75
  samples_per_class: 4000
66
76
  use_weighted_loss: false
67
77
  epochs: 100
@@ -75,7 +85,7 @@ TRAIN:
75
85
 
76
86
  optimizer:
77
87
  type: AdamW
78
- lr: 0.0001
88
+ lr: 0.00005 # tune lr based on the backbone
79
89
  betas: [0.9, 0.999]
80
90
  eps: 0.0000001
81
91
  weight_decay: 0.0001
@@ -1,9 +1,15 @@
1
1
  TASK: classification
2
+ # this is the config for the classification task on the sngar-tracking dataset.
3
+ # this config is used for the main experiments reported in the paper:
4
+ # "Pixels or Positions? Benchmarking Modalities in Group Activity Recognition"
5
+ # https://arxiv.org/abs/2511.12606
6
+ # this is used to train our baseline model: a GIN backbone + max-pool temporal aggregation + positional edges.
7
+ # it contains all the hyperparameter values used to reproduce the results reported in the paper.
2
8
 
3
9
  DATA:
4
10
  dataset_name: sngar
5
11
  data_modality: tracking_parquet
6
- data_dir: /home/karkid/opensportslib/sngar-tracking
12
+ data_dir: /home/karkid/opensportslib/tracking-dataset
7
13
  preload_data: false
8
14
  train:
9
15
  type: annotations_train.json
@@ -43,12 +49,15 @@ DATA:
43
49
  pitch_half_width: 50.0
44
50
  max_displacement: 110.0
45
51
  max_ball_height: 30.0
52
+ data_slicing: # only used for data scaling experiments
53
+ enabled: false
54
+ training_matches: 45 # default: all 45 training matches
46
55
 
47
56
  MODEL:
48
57
  type: custom
49
58
  backbone:
50
59
  type: graph_conv
51
- encoder: graphconv
60
+ encoder: gin
52
61
  hidden_dim: 64
53
62
  num_layers: 20
54
63
  dropout: 0.1
@@ -74,7 +83,7 @@ TRAIN:
74
83
  use_weighted_sampler: true
75
84
  use_weighted_loss: false
76
85
  samples_per_class: 4000
77
- epochs: 10
86
+ epochs: 100
78
87
  patience: 10
79
88
  save_every: 20
80
89
  detailed_results: true
@@ -685,7 +685,7 @@ class FramesTrainerClassification(BaseTrainerClassification):
685
685
  self.scaler.step(self.optimizer)
686
686
  self.scaler.update()
687
687
 
688
- return logits, labels, loss
688
+ return logits, labels, loss, True
689
689
 
690
690
  # --------------------------------------------------------------
691
691
  # unified trainer dispatcher
@@ -9,11 +9,51 @@ from opensportslib.core.utils.video_processing import get_stride, read_fps, get_
9
9
  from opensportslib.core.utils.config import load_json
10
10
  from collections import defaultdict
11
11
 
12
- def load_annotations(annotations_path, task_key="action", exclude_labels=[""], multiview=False, input_type="video", allow_missing_labels=False):
12
+ def load_annotations(
13
+ annotations_path,
14
+ task_key="action",
15
+ exclude_labels=None,
16
+ multiview=False,
17
+ input_type="video",
18
+ allow_missing_labels=False,
19
+ max_games=None
20
+ ):
13
21
 
14
22
  with open(annotations_path, "r") as f:
15
23
  data = json.load(f)
16
24
 
25
+ # this is used for data slicing experiments.
26
+ # and doesn't affect the validation and test sets.
27
+ # if "data_slicing" is not enabled in the config (max_games is None), this block is skipped.
28
+ if max_games is not None:
29
+ all_game_ids = sorted(set(
30
+ item.get("metadata", {}).get("game_id", "")
31
+ for item in data["data"]
32
+ ))
33
+
34
+ # remove empty string if any items lack game_id
35
+ all_game_ids = [g for g in all_game_ids if g]
36
+
37
+ # warn if any items lack game_id
38
+ items_without_game_id = sum(
39
+ 1 for item in data["data"]
40
+ if not item.get("metadata", {}).get("game_id", "")
41
+ )
42
+ if items_without_game_id > 0:
43
+ print(f"warning: {items_without_game_id}/{len(data['data'])} items have "
44
+ f"no game_id, these will not be affected by data slicing")
45
+
46
+ if max_games < len(all_game_ids):
47
+ keep_ids = set(all_game_ids[:max_games])
48
+ original_count = len(data["data"])
49
+ data["data"] = [
50
+ item for item in data["data"]
51
+ if item.get("metadata", {}).get("game_id", "") in keep_ids
52
+ ]
53
+ print(f"data slicing: {max_games}/{len(all_game_ids)} games, "
54
+ f"{len(data['data'])}/{original_count} samples retained")
55
+ # ----- data slicing ends here -----
56
+
17
57
  exclude_labels = set(exclude_labels or [""])
18
58
 
19
59
  # Label list for the selected task
@@ -81,14 +81,35 @@ class ClassificationDataset(Dataset):
81
81
 
82
82
  allow_missing_labels = split in ["test", "infer"]
83
83
 
84
+ # these lines of code are used for data scaling experiments.
85
+ # if you want to check how the model performance changes with a different
86
+ # number of games in the training set, you can use this code.
87
+ # to use this, you need to add the following to the config:
88
+ # DATA:
89
+ # data_slicing:
90
+ # enabled: true
91
+ # training_matches: <number of games to include in the training set>
92
+ # we will refer to this as "data slicing" in the rest of the code.
93
+ max_games = None
94
+ slicing_cfg = getattr(config.DATA, "data_slicing", None)
95
+ if slicing_cfg and getattr(slicing_cfg, "enabled", False) and split == "train":
96
+ max_games = getattr(slicing_cfg, "training_matches", None)
97
+
84
98
  self.samples, self.label_map = load_annotations(
85
99
  annotations_path,
86
100
  exclude_labels=self.exclude_labels,
87
101
  multiview=is_multiview,
88
102
  input_type=config.DATA.data_modality,
89
- allow_missing_labels=allow_missing_labels
103
+ allow_missing_labels=allow_missing_labels,
104
+ max_games=max_games
90
105
  )
91
106
 
107
+ # this is used for quick testing of the model.
108
+ # we can only use a small subset of data (ideally 100 samples) to
109
+ # test the overall integration quickly.
110
+ # to use this, you need to add the following to the config:
111
+ # DATA:
112
+ # max_samples: <number of samples to keep from each split (train, valid and test)>
92
113
  max_samples = getattr(config.DATA, 'max_samples', None)
93
114
  if max_samples:
94
115
  self.samples = self.samples[:max_samples]
@@ -320,7 +341,7 @@ class TrackingDataset(ClassificationDataset):
320
341
  """
321
342
 
322
343
  def __init__(self, config, annotations_path, split="train"):
323
- super().__init__(config, annotations_path, split)
344
+ super().__init__(config, annotations_path, processor=None, split=split)
324
345
 
325
346
  from opensportslib.datasets.utils.tracking import (
326
347
  FEATURE_DIM,
@@ -576,7 +597,7 @@ class TrackingDataset(ClassificationDataset):
576
597
  "seq_len": len(graphs),
577
598
  "id": item["id"]
578
599
  }
579
- if "label" in label:
600
+ if label is not None:
580
601
  out["label"] = label
581
602
  return out
582
603
 
@@ -82,14 +82,26 @@ def compute_detailed_classification_metrics(all_logits, all_labels, class_names,
82
82
 
83
83
  preds = np.argmax(all_logits, axis=-1)
84
84
 
85
- sorted_class_names = sorted(class_names.values())
86
- name_to_sorted_idx = {name: i for i, name in enumerate(sorted_class_names)}
87
85
  idx_to_name = class_names
88
86
 
89
- sorted_labels = np.array([name_to_sorted_idx[idx_to_name[l]] for l in all_labels])
90
- sorted_preds = np.array([name_to_sorted_idx[idx_to_name[p]] for p in preds])
87
+ # count samples per class from true labels to order by frequency
88
+ from collections import Counter
89
+ label_counts = Counter(all_labels)
90
+ # map from class index to name, then sort by count descending
91
+ all_class_names = list(class_names.values())
92
+ name_to_original_idx = {v: k for k, v in class_names.items()}
93
+ ordered_class_names = sorted(
94
+ all_class_names,
95
+ key=lambda name: label_counts.get(name_to_original_idx[name], 0),
96
+ reverse=True
97
+ )
91
98
 
92
- all_class_labels = list(range(len(sorted_class_names)))
99
+ name_to_ordered_idx = {name: i for i, name in enumerate(ordered_class_names)}
100
+
101
+ sorted_labels = np.array([name_to_ordered_idx[idx_to_name[l]] for l in all_labels])
102
+ sorted_preds = np.array([name_to_ordered_idx[idx_to_name[p]] for p in preds])
103
+
104
+ all_class_labels = list(range(len(ordered_class_names)))
93
105
 
94
106
  cm = confusion_matrix(sorted_labels, sorted_preds, labels=all_class_labels)
95
107
  per_class_accuracy = np.diag(cm) / np.maximum(cm.sum(axis=1), 1) * 100
@@ -104,14 +116,14 @@ def compute_detailed_classification_metrics(all_logits, all_labels, class_names,
104
116
  import os
105
117
 
106
118
  plt.figure(figsize=(12, 10))
107
- sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
108
- xticklabels=sorted_class_names, yticklabels=sorted_class_names)
119
+ sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
120
+ xticklabels=ordered_class_names, yticklabels=ordered_class_names)
109
121
  plt.title(f'Confusion Matrix ({set_name})')
110
122
  plt.ylabel('True Label')
111
123
  plt.xlabel('Predicted Label')
112
124
  plt.xticks(rotation=45, ha='right')
113
125
  plt.tight_layout()
114
-
126
+
115
127
  plots_dir = os.path.join(save_dir, 'plots')
116
128
  os.makedirs(plots_dir, exist_ok=True)
117
129
  plt.savefig(os.path.join(plots_dir, f'confusion_matrix_{set_name}.png'), dpi=300, bbox_inches='tight')
@@ -119,14 +131,14 @@ def compute_detailed_classification_metrics(all_logits, all_labels, class_names,
119
131
 
120
132
  results_dir = os.path.join(save_dir, 'results')
121
133
  os.makedirs(results_dir, exist_ok=True)
122
-
134
+
123
135
  report_path = os.path.join(results_dir, f'{set_name}_detailed_metrics.txt')
124
136
  with open(report_path, 'w') as f:
125
137
  f.write(f"Balanced Accuracy: {balanced_acc:.2f}%\n")
126
138
  f.write(f"Macro F1: {macro_f1:.2f}%\n\n")
127
139
  f.write(f"{'Class':<30} {'Accuracy':>10} {'F1':>10} {'Samples':>10}\n")
128
140
  f.write("-" * 65 + "\n")
129
- for i, class_name in enumerate(sorted_class_names):
141
+ for i, class_name in enumerate(ordered_class_names):
130
142
  num_samples = int(cm[i].sum())
131
143
  f.write(f"{class_name:<30} {per_class_accuracy[i]:>9.2f}% {per_class_f1[i]:>9.2f}% {num_samples:>10}\n")
132
144
  f.write("-" * 65 + "\n\n")
@@ -134,43 +146,48 @@ def compute_detailed_classification_metrics(all_logits, all_labels, class_names,
134
146
  f.write(classification_report(
135
147
  sorted_labels, sorted_preds,
136
148
  labels=all_class_labels,
137
- target_names=sorted_class_names,
149
+ target_names=ordered_class_names,
138
150
  zero_division=0
139
151
  ))
140
152
  f.write("\n" + "-" * 65 + "\n\n")
141
153
  f.write("Confusion Matrix:\n\n")
142
154
  f.write(f"{cm}\n")
143
-
155
+
144
156
  tsv_path = os.path.join(results_dir, f'{set_name}_results.tsv')
145
157
  with open(tsv_path, 'w') as f:
146
- header = "metric\t" + "\t".join(sorted_class_names) + "\toverall"
158
+ header = "metric\toverall\t" + "\t".join(ordered_class_names)
147
159
  f.write(header + "\n")
148
160
 
149
- acc_row = "accuracy\t" + "\t".join(f"{per_class_accuracy[i]:.2f}" for i in range(len(sorted_class_names))) + f"\t{balanced_acc:.2f}"
161
+ acc_row = "accuracy\t" + f"{balanced_acc:.2f}\t" + "\t".join(f"{per_class_accuracy[i]:.2f}" for i in range(len(ordered_class_names)))
150
162
  f.write(acc_row + "\n")
151
163
 
152
- f1_row = "f1\t" + "\t".join(f"{per_class_f1[i]:.2f}" for i in range(len(sorted_class_names))) + f"\t{macro_f1:.2f}"
164
+ f1_row = "f1\t" + f"{macro_f1:.2f}\t" + "\t".join(f"{per_class_f1[i]:.2f}" for i in range(len(ordered_class_names)))
153
165
  f.write(f1_row + "\n")
154
166
 
155
- samples_row = "samples\t" + "\t".join(str(int(cm[i].sum())) for i in range(len(sorted_class_names))) + f"\t{int(cm.sum())}"
167
+ samples_row = "samples\t" + f"{int(cm.sum())}\t" + "\t".join(str(int(cm[i].sum())) for i in range(len(ordered_class_names)))
156
168
  f.write(samples_row + "\n")
157
169
 
170
+ f.write("\n\nConfusion Matrix\n")
171
+ f.write("True \\ Predicted\t" + "\t".join(ordered_class_names) + "\n")
172
+ for i, name in enumerate(ordered_class_names):
173
+ row = name + "\t" + "\t".join(str(cm[i][j]) for j in range(len(ordered_class_names)))
174
+ f.write(row + "\n")
175
+
158
176
  print(f"Saved TSV to {tsv_path}")
159
-
177
+
160
178
  print(f"\nSaved detailed metrics to {report_path}")
161
179
  print(f"\nBalanced Accuracy: {balanced_acc:.2f}%")
162
180
  print(f"Macro F1: {macro_f1:.2f}%\n")
163
181
  print(f"{'Class':<30} {'Accuracy':>10} {'F1':>10} {'Samples':>10}")
164
182
  print("-" * 65)
165
- for i, class_name in enumerate(sorted_class_names):
183
+ for i, class_name in enumerate(ordered_class_names):
166
184
  num_samples = int(cm[i].sum())
167
185
  print(f"{class_name:<30} {per_class_accuracy[i]:>9.2f}% {per_class_f1[i]:>9.2f}% {num_samples:>10}")
168
186
  print("-" * 65)
169
-
187
+
170
188
  return {
171
189
  "balanced_accuracy": balanced_acc,
172
190
  "macro_f1": macro_f1,
173
- "per_class_accuracy": {name: per_class_accuracy[i] for i, name in enumerate(sorted_class_names)},
174
- "per_class_f1": {name: per_class_f1[i] for i, name in enumerate(sorted_class_names)},
175
- }
176
-
191
+ "per_class_accuracy": {name: per_class_accuracy[i] for i, name in enumerate(ordered_class_names)},
192
+ "per_class_f1": {name: per_class_f1[i] for i, name in enumerate(ordered_class_names)},
193
+ }
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: opensportslib
3
- Version: 0.0.1.dev15
3
+ Version: 0.0.1.dev16
4
4
  Summary: OpenSportsLib is the professional library, designed for advanced video understanding in sports. It provides state-of-the-art tools for action recognition, spotting, retrieval, and captioning, making it ideal for researchers, analysts, and developers working with sports video data.
5
5
  Author: Jeet Vora
6
6
  Requires-Python: >=3.12
@@ -21,6 +21,7 @@ Requires-Dist: wandb
21
21
  Requires-Dist: opencv-python
22
22
  Requires-Dist: omegaconf
23
23
  Requires-Dist: timm
24
+ Requires-Dist: seaborn
24
25
  Provides-Extra: localization
25
26
  Requires-Dist: nvidia-dali-cuda120; extra == "localization"
26
27
  Requires-Dist: cupy-cuda12x; extra == "localization"
@@ -15,16 +15,10 @@ opensportslib/apis/__init__.py
15
15
  opensportslib/apis/classification.py
16
16
  opensportslib/apis/localization.py
17
17
  opensportslib/config/classification.yaml
18
- opensportslib/config/classification_tracking.yaml
19
18
  opensportslib/config/localization-e2e-ocv.yaml
20
19
  opensportslib/config/localization.yaml
21
- opensportslib/config/sngar_frames.yaml
22
- opensportslib/config/graph_tracking_classification/avgpool.yaml
23
- opensportslib/config/graph_tracking_classification/gin.yaml
24
- opensportslib/config/graph_tracking_classification/graphconv.yaml
25
- opensportslib/config/graph_tracking_classification/graphsage.yaml
26
- opensportslib/config/graph_tracking_classification/maxpool.yaml
27
- opensportslib/config/graph_tracking_classification/noedges.yaml
20
+ opensportslib/config/sngar-frames.yaml
21
+ opensportslib/config/sngar-tracking.yaml
28
22
  opensportslib/core/__init__.py
29
23
  opensportslib/core/loss/__init__.py
30
24
  opensportslib/core/loss/builder.py
@@ -11,6 +11,7 @@ wandb
11
11
  opencv-python
12
12
  omegaconf
13
13
  timm
14
+ seaborn
14
15
 
15
16
  [:platform_system != "Darwin" and platform_machine != "arm64"]
16
17
  decord
@@ -4,11 +4,11 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "opensportslib"
7
- version = "0.0.1.dev15"
7
+ version = "0.0.1.dev16"
8
8
  description = "OpenSportsLib is the professional library, designed for advanced video understanding in sports. It provides state-of-the-art tools for action recognition, spotting, retrieval, and captioning, making it ideal for researchers, analysts, and developers working with sports video data."
9
9
  readme = "README.md"
10
10
  requires-python = ">=3.12"
11
- dependencies = [ "SoccerNet", "av", "decord; platform_system != 'Darwin' and platform_machine != 'arm64'", "evaluate", "scikit-learn", "torch", "torchvision", "transformers==4.57.3", "tokenizers==0.22.1", "accelerate", "wandb", "opencv-python", "omegaconf", "timm",]
11
+ dependencies = [ "SoccerNet", "av", "decord; platform_system != 'Darwin' and platform_machine != 'arm64'", "evaluate", "scikit-learn", "torch", "torchvision", "transformers==4.57.3", "tokenizers==0.22.1", "accelerate", "wandb", "opencv-python", "omegaconf", "timm", "seaborn",]
12
12
  [[project.authors]]
13
13
  name = "Jeet Vora"
14
14
 
@@ -1,79 +0,0 @@
1
- TASK: classification
2
-
3
- DATA:
4
- dataset_name: sngar
5
- data_modality: tracking_parquet
6
- data_dir: /home/karkid/temporal-localization/data/tracking_dataset
7
- preload_data: false
8
- annotations:
9
- train: /home/karkid/temporal-localization/data/tracking_dataset/train/train.json
10
- valid: /home/karkid/temporal-localization/data/tracking_dataset/valid/valid.json
11
- test: /home/karkid/temporal-localization/data/tracking_dataset/test/test.json
12
- num_frames: 16
13
- frame_interval: 9
14
- augmentations:
15
- vertical_flip: true
16
- horizontal_flip: true
17
- team_flip: true
18
- normalize: true
19
- num_workers: 20
20
- train_batch_size: 32
21
- valid_batch_size: 32
22
- num_objects: 23
23
- feature_dim: 8
24
- pitch_half_length: 85.0
25
- pitch_half_width: 50.0
26
- max_displacement: 110.0
27
- max_ball_height: 30.0
28
-
29
- MODEL:
30
- type: custom
31
- backbone:
32
- type: graph_conv
33
- encoder: gin
34
- hidden_dim: 64
35
- num_layers: 20
36
- dropout: 0.1
37
- neck:
38
- type: TemporalAggregation
39
- agr_type: avgpool
40
- hidden_dim: 64
41
- dropout: 0.1
42
- head:
43
- type: TrackingClassifier
44
- hidden_dim: 64
45
- dropout: 0.1
46
- num_classes: 10
47
- edge: positional
48
- k: 8
49
-
50
- TRAIN:
51
- enabled: true
52
- use_weighted_sampler: true
53
- use_weighted_loss: false
54
- samples_per_class: 4000
55
- epochs: 100
56
- patience: 10
57
- save_every: 20
58
-
59
- optimizer:
60
- type: Adam
61
- lr: 0.001
62
-
63
- scheduler:
64
- type: ReduceLROnPlateau
65
- mode: min
66
- patience: 10
67
- factor: 0.1
68
- min_lr: 1e-8
69
-
70
- criterion:
71
- type: CrossEntropyLoss
72
-
73
- save_dir: ./checkpoints_tracking
74
-
75
- SYSTEM:
76
- log_dir: ./logs
77
- seed: 42
78
- device: cuda
79
- gpu_id: 0
@@ -1,79 +0,0 @@
1
- TASK: classification
2
-
3
- DATA:
4
- dataset_name: sngar
5
- data_modality: tracking_parquet
6
- data_dir: /home/karkid/temporal-localization/data/tracking_dataset
7
- preload_data: false
8
- annotations:
9
- train: /home/karkid/temporal-localization/data/tracking_dataset/train/train.json
10
- valid: /home/karkid/temporal-localization/data/tracking_dataset/valid/valid.json
11
- test: /home/karkid/temporal-localization/data/tracking_dataset/test/test.json
12
- num_frames: 16
13
- frame_interval: 9
14
- augmentations:
15
- vertical_flip: true
16
- horizontal_flip: true
17
- team_flip: true
18
- normalize: true
19
- num_workers: 20
20
- train_batch_size: 32
21
- valid_batch_size: 32
22
- num_objects: 23
23
- feature_dim: 8
24
- pitch_half_length: 85.0
25
- pitch_half_width: 50.0
26
- max_displacement: 110.0
27
- max_ball_height: 30.0
28
-
29
- MODEL:
30
- type: custom
31
- backbone:
32
- type: graph_conv
33
- encoder: gin
34
- hidden_dim: 64
35
- num_layers: 20
36
- dropout: 0.1
37
- neck:
38
- type: TemporalAggregation
39
- agr_type: maxpool
40
- hidden_dim: 64
41
- dropout: 0.1
42
- head:
43
- type: TrackingClassifier
44
- hidden_dim: 64
45
- dropout: 0.1
46
- num_classes: 10
47
- edge: positional
48
- k: 8
49
-
50
- TRAIN:
51
- enabled: true
52
- use_weighted_sampler: true
53
- use_weighted_loss: false
54
- samples_per_class: 4000
55
- epochs: 100
56
- patience: 10
57
- save_every: 20
58
-
59
- optimizer:
60
- type: Adam
61
- lr: 0.001
62
-
63
- scheduler:
64
- type: ReduceLROnPlateau
65
- mode: min
66
- patience: 10
67
- factor: 0.1
68
- min_lr: 1e-8
69
-
70
- criterion:
71
- type: CrossEntropyLoss
72
-
73
- save_dir: ./checkpoints_tracking
74
-
75
- SYSTEM:
76
- log_dir: ./logs
77
- seed: 42
78
- device: cuda
79
- gpu_id: 0
@@ -1,79 +0,0 @@
1
- TASK: classification
2
-
3
- DATA:
4
- dataset_name: sngar
5
- data_modality: tracking_parquet
6
- data_dir: /home/karkid/temporal-localization/data/tracking_dataset
7
- preload_data: false
8
- annotations:
9
- train: /home/karkid/temporal-localization/data/tracking_dataset/train/train.json
10
- valid: /home/karkid/temporal-localization/data/tracking_dataset/valid/valid.json
11
- test: /home/karkid/temporal-localization/data/tracking_dataset/test/test.json
12
- num_frames: 16
13
- frame_interval: 9
14
- augmentations:
15
- vertical_flip: true
16
- horizontal_flip: true
17
- team_flip: true
18
- normalize: true
19
- num_workers: 20
20
- train_batch_size: 32
21
- valid_batch_size: 32
22
- num_objects: 23
23
- feature_dim: 8
24
- pitch_half_length: 85.0
25
- pitch_half_width: 50.0
26
- max_displacement: 110.0
27
- max_ball_height: 30.0
28
-
29
- MODEL:
30
- type: custom
31
- backbone:
32
- type: graph_conv
33
- encoder: graphconv
34
- hidden_dim: 64
35
- num_layers: 20
36
- dropout: 0.1
37
- neck:
38
- type: TemporalAggregation
39
- agr_type: maxpool
40
- hidden_dim: 64
41
- dropout: 0.1
42
- head:
43
- type: TrackingClassifier
44
- hidden_dim: 64
45
- dropout: 0.1
46
- num_classes: 10
47
- edge: positional
48
- k: 8
49
-
50
- TRAIN:
51
- enabled: true
52
- use_weighted_sampler: true
53
- use_weighted_loss: false
54
- samples_per_class: 4000
55
- epochs: 100
56
- patience: 10
57
- save_every: 20
58
-
59
- optimizer:
60
- type: Adam
61
- lr: 0.001
62
-
63
- scheduler:
64
- type: ReduceLROnPlateau
65
- mode: min
66
- patience: 10
67
- factor: 0.1
68
- min_lr: 1e-8
69
-
70
- criterion:
71
- type: CrossEntropyLoss
72
-
73
- save_dir: ./checkpoints_tracking
74
-
75
- SYSTEM:
76
- log_dir: ./logs
77
- seed: 42
78
- device: cuda
79
- gpu_id: 0
@@ -1,79 +0,0 @@
1
- TASK: classification
2
-
3
- DATA:
4
- dataset_name: sngar
5
- data_modality: tracking_parquet
6
- data_dir: /home/karkid/temporal-localization/data/tracking_dataset
7
- preload_data: false
8
- annotations:
9
- train: /home/karkid/temporal-localization/data/tracking_dataset/train/train.json
10
- valid: /home/karkid/temporal-localization/data/tracking_dataset/valid/valid.json
11
- test: /home/karkid/temporal-localization/data/tracking_dataset/test/test.json
12
- num_frames: 16
13
- frame_interval: 9
14
- augmentations:
15
- vertical_flip: true
16
- horizontal_flip: true
17
- team_flip: true
18
- normalize: true
19
- num_workers: 20
20
- train_batch_size: 32
21
- valid_batch_size: 32
22
- num_objects: 23
23
- feature_dim: 8
24
- pitch_half_length: 85.0
25
- pitch_half_width: 50.0
26
- max_displacement: 110.0
27
- max_ball_height: 30.0
28
-
29
- MODEL:
30
- type: custom
31
- backbone:
32
- type: graph_conv
33
- encoder: sageconv
34
- hidden_dim: 64
35
- num_layers: 20
36
- dropout: 0.1
37
- neck:
38
- type: TemporalAggregation
39
- agr_type: maxpool
40
- hidden_dim: 64
41
- dropout: 0.1
42
- head:
43
- type: TrackingClassifier
44
- hidden_dim: 64
45
- dropout: 0.1
46
- num_classes: 10
47
- edge: positional
48
- k: 8
49
-
50
- TRAIN:
51
- enabled: true
52
- use_weighted_sampler: true
53
- use_weighted_loss: false
54
- samples_per_class: 4000
55
- epochs: 100
56
- patience: 10
57
- save_every: 20
58
-
59
- optimizer:
60
- type: Adam
61
- lr: 0.001
62
-
63
- scheduler:
64
- type: ReduceLROnPlateau
65
- mode: min
66
- patience: 10
67
- factor: 0.1
68
- min_lr: 1e-8
69
-
70
- criterion:
71
- type: CrossEntropyLoss
72
-
73
- save_dir: ./checkpoints_tracking
74
-
75
- SYSTEM:
76
- log_dir: ./logs
77
- seed: 42
78
- device: cuda
79
- gpu_id: 0
@@ -1,79 +0,0 @@
1
- TASK: classification
2
-
3
- DATA:
4
- dataset_name: sngar
5
- data_modality: tracking_parquet
6
- data_dir: /home/karkid/temporal-localization/data/tracking_dataset
7
- preload_data: false
8
- annotations:
9
- train: /home/karkid/temporal-localization/data/tracking_dataset/train/train.json
10
- valid: /home/karkid/temporal-localization/data/tracking_dataset/valid/valid.json
11
- test: /home/karkid/temporal-localization/data/tracking_dataset/test/test.json
12
- num_frames: 16
13
- frame_interval: 9
14
- augmentations:
15
- vertical_flip: true
16
- horizontal_flip: true
17
- team_flip: true
18
- normalize: true
19
- num_workers: 20
20
- train_batch_size: 32
21
- valid_batch_size: 32
22
- num_objects: 23
23
- feature_dim: 8
24
- pitch_half_length: 85.0
25
- pitch_half_width: 50.0
26
- max_displacement: 110.0
27
- max_ball_height: 30.0
28
-
29
- MODEL:
30
- type: custom
31
- backbone:
32
- type: graph_conv
33
- encoder: gin
34
- hidden_dim: 64
35
- num_layers: 20
36
- dropout: 0.1
37
- neck:
38
- type: TemporalAggregation
39
- agr_type: maxpool
40
- hidden_dim: 64
41
- dropout: 0.1
42
- head:
43
- type: TrackingClassifier
44
- hidden_dim: 64
45
- dropout: 0.1
46
- num_classes: 10
47
- edge: positional
48
- k: 8
49
-
50
- TRAIN:
51
- enabled: true
52
- use_weighted_sampler: true
53
- use_weighted_loss: false
54
- samples_per_class: 4000
55
- epochs: 100
56
- patience: 10
57
- save_every: 20
58
-
59
- optimizer:
60
- type: Adam
61
- lr: 0.001
62
-
63
- scheduler:
64
- type: ReduceLROnPlateau
65
- mode: min
66
- patience: 10
67
- factor: 0.1
68
- min_lr: 1e-8
69
-
70
- criterion:
71
- type: CrossEntropyLoss
72
-
73
- save_dir: ./checkpoints_tracking
74
-
75
- SYSTEM:
76
- log_dir: ./logs
77
- seed: 42
78
- device: cuda
79
- gpu_id: 0
@@ -1,79 +0,0 @@
1
- TASK: classification
2
-
3
- DATA:
4
- dataset_name: sngar
5
- data_modality: tracking_parquet
6
- data_dir: /home/karkid/temporal-localization/data/tracking_dataset
7
- preload_data: false
8
- annotations:
9
- train: /home/karkid/temporal-localization/data/tracking_dataset/train/train.json
10
- valid: /home/karkid/temporal-localization/data/tracking_dataset/valid/valid.json
11
- test: /home/karkid/temporal-localization/data/tracking_dataset/test/test.json
12
- num_frames: 16
13
- frame_interval: 9
14
- augmentations:
15
- vertical_flip: true
16
- horizontal_flip: true
17
- team_flip: true
18
- normalize: true
19
- num_workers: 20
20
- train_batch_size: 32
21
- valid_batch_size: 32
22
- num_objects: 23
23
- feature_dim: 8
24
- pitch_half_length: 85.0
25
- pitch_half_width: 50.0
26
- max_displacement: 110.0
27
- max_ball_height: 30.0
28
-
29
- MODEL:
30
- type: custom
31
- backbone:
32
- type: graph_conv
33
- encoder: gin
34
- hidden_dim: 64
35
- num_layers: 20
36
- dropout: 0.1
37
- neck:
38
- type: TemporalAggregation
39
- agr_type: maxpool
40
- hidden_dim: 64
41
- dropout: 0.1
42
- head:
43
- type: TrackingClassifier
44
- hidden_dim: 64
45
- dropout: 0.1
46
- num_classes: 10
47
- edge: none
48
- k: 8
49
-
50
- TRAIN:
51
- enabled: true
52
- use_weighted_sampler: true
53
- use_weighted_loss: false
54
- samples_per_class: 4000
55
- epochs: 100
56
- patience: 10
57
- save_every: 20
58
-
59
- optimizer:
60
- type: Adam
61
- lr: 0.001
62
-
63
- scheduler:
64
- type: ReduceLROnPlateau
65
- mode: min
66
- patience: 10
67
- factor: 0.1
68
- min_lr: 1e-8
69
-
70
- criterion:
71
- type: CrossEntropyLoss
72
-
73
- save_dir: ./checkpoints_tracking
74
-
75
- SYSTEM:
76
- log_dir: ./logs
77
- seed: 42
78
- device: cuda
79
- gpu_id: 0