opensportslib 0.1.2.dev1__tar.gz → 0.1.2.dev2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. {opensportslib-0.1.2.dev1/opensportslib.egg-info → opensportslib-0.1.2.dev2}/PKG-INFO +41 -11
  2. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/README.md +40 -10
  3. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/examples/quickstart/basic_classification.py +12 -0
  4. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/examples/quickstart/basic_localization.py +12 -0
  5. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/apis/base_task_model.py +35 -0
  6. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/apis/classification.py +16 -10
  7. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/apis/localization.py +29 -8
  8. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/core/trainer/classification_trainer.py +4 -4
  9. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/core/trainer/localization_trainer.py +2 -2
  10. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/core/utils/config.py +4 -4
  11. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/metrics/localization_metric.py +1 -1
  12. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2/opensportslib.egg-info}/PKG-INFO +41 -11
  13. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/pyproject.toml +1 -1
  14. opensportslib-0.1.2.dev2/tests/test_task_model_api_contract.py +375 -0
  15. opensportslib-0.1.2.dev1/tests/test_task_model_api_contract.py +0 -73
  16. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/LICENSE +0 -0
  17. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/LICENSE-COMMERCIAL +0 -0
  18. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/MANIFEST.in +0 -0
  19. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/__init__.py +0 -0
  20. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/apis/__init__.py +0 -0
  21. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/cli.py +0 -0
  22. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/config/classification.yaml +0 -0
  23. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/config/localization-e2e-ocv.yaml +0 -0
  24. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/config/localization-json_calf_resnetpca512.yaml +0 -0
  25. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/config/localization-json_netvlad++_resnetpca512.yaml +0 -0
  26. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/config/localization.yaml +0 -0
  27. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/config/sngar-frames.yaml +0 -0
  28. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/config/sngar-tracking.yaml +0 -0
  29. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/core/__init__.py +0 -0
  30. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/core/loss/__init__.py +0 -0
  31. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/core/loss/builder.py +0 -0
  32. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/core/loss/calf.py +0 -0
  33. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/core/loss/ce.py +0 -0
  34. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/core/loss/combine.py +0 -0
  35. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/core/loss/nll.py +0 -0
  36. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/core/optimizer/__init__.py +0 -0
  37. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/core/optimizer/builder.py +0 -0
  38. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/core/sampler/weighted_sampler.py +0 -0
  39. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/core/scheduler/__init__.py +0 -0
  40. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/core/scheduler/builder.py +0 -0
  41. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/core/trainer/__init__.py +0 -0
  42. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/core/utils/checkpoint.py +0 -0
  43. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/core/utils/data.py +0 -0
  44. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/core/utils/ddp.py +0 -0
  45. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/core/utils/default_args.py +0 -0
  46. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/core/utils/lightning.py +0 -0
  47. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/core/utils/load_annotations.py +0 -0
  48. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/core/utils/seed.py +0 -0
  49. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/core/utils/video_processing.py +0 -0
  50. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/core/utils/wandb.py +0 -0
  51. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/datasets/__init__.py +0 -0
  52. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/datasets/builder.py +0 -0
  53. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/datasets/classification_dataset.py +0 -0
  54. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/datasets/localization_dataset.py +0 -0
  55. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/datasets/utils/__init__.py +0 -0
  56. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/datasets/utils/tracking.py +0 -0
  57. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/metrics/classification_metric.py +0 -0
  58. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/models/__init__.py +0 -0
  59. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/models/backbones/builder.py +0 -0
  60. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/models/base/contextaware.py +0 -0
  61. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/models/base/e2e.py +0 -0
  62. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/models/base/learnablepooling.py +0 -0
  63. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/models/base/tracking.py +0 -0
  64. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/models/base/vars.py +0 -0
  65. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/models/base/video.py +0 -0
  66. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/models/base/video_mae.py +0 -0
  67. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/models/builder.py +0 -0
  68. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/models/heads/builder.py +0 -0
  69. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/models/neck/builder.py +0 -0
  70. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/models/utils/common.py +0 -0
  71. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/models/utils/impl/__init__.py +0 -0
  72. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/models/utils/impl/asformer.py +0 -0
  73. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/models/utils/impl/calf.py +0 -0
  74. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/models/utils/impl/gsm.py +0 -0
  75. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/models/utils/impl/gtad.py +0 -0
  76. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/models/utils/impl/tsm.py +0 -0
  77. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/models/utils/litebase.py +0 -0
  78. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/models/utils/modules.py +0 -0
  79. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/models/utils/shift.py +0 -0
  80. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/models/utils/utils.py +0 -0
  81. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib/setup/setup.py +0 -0
  82. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib.egg-info/SOURCES.txt +0 -0
  83. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib.egg-info/dependency_links.txt +0 -0
  84. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib.egg-info/entry_points.txt +0 -0
  85. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib.egg-info/requires.txt +0 -0
  86. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/opensportslib.egg-info/top_level.txt +0 -0
  87. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/setup.cfg +0 -0
  88. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/tests/conftest.py +0 -0
  89. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/tests/test_config_utils_smoke.py +0 -0
  90. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/tests/test_package_smoke.py +0 -0
  91. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/tests/test_public_apis_smoke.py +0 -0
  92. {opensportslib-0.1.2.dev1 → opensportslib-0.1.2.dev2}/tests/test_subset_train_infer_integration.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: opensportslib
3
- Version: 0.1.2.dev1
3
+ Version: 0.1.2.dev2
4
4
  Summary: OpenSportsLib is the professional library, designed for advanced video understanding in sports. It provides state-of-the-art tools for action recognition, spotting, retrieval, and captioning, making it ideal for researchers, analysts, and developers working with sports video data.
5
5
  Author: Jeet Vora
6
6
  Requires-Python: >=3.12
@@ -123,14 +123,14 @@ print("OpenSportsLib imported successfully")
123
123
  ### Train a classification model
124
124
 
125
125
  ```python
126
- from opensportslib import model
126
+ from opensportslib.apis import ClassificationModel
127
127
 
128
- myModel = model.ClassificationModel(
128
+ my_model = ClassificationModel(
129
129
  config="/path/to/classification.yaml",
130
130
  weights="/path/to/weights.pt", # optional
131
131
  )
132
132
 
133
- myModel.train(
133
+ my_model.train(
134
134
  train_set="/path/to/train_annotations.json",
135
135
  valid_set="/path/to/valid_annotations.json",
136
136
  )
@@ -139,19 +139,29 @@ myModel.train(
139
139
  ### Run inference
140
140
 
141
141
  ```python
142
- from opensportslib import model
142
+ from opensportslib.apis import ClassificationModel
143
143
 
144
- myModel = model.classification(
144
+ my_model = ClassificationModel(
145
145
  config="/path/to/classification.yaml",
146
146
  weights="/path/to/weights.pt", # optional
147
147
  )
148
148
 
149
- predictions = myModel.infer(
149
+ predictions = my_model.infer(
150
150
  test_set="/path/to/test_annotations.json",
151
151
  )
152
152
 
153
- metrics = myModel.evaluate(
153
+ saved_predictions = my_model.save_predictions(
154
+ output_path="/path/to/predictions.json",
155
+ predictions=predictions,
156
+ )
157
+
158
+ metrics = my_model.evaluate(
159
+ test_set="/path/to/test_annotations.json",
160
+ )
161
+
162
+ metrics_from_file = my_model.evaluate(
154
163
  test_set="/path/to/test_annotations.json",
164
+ predictions=saved_predictions,
155
165
  )
156
166
 
157
167
  print(metrics)
@@ -160,10 +170,29 @@ print(metrics)
160
170
  ### Localization example
161
171
 
162
172
  ```python
163
- from opensportslib import model
173
+ from opensportslib.apis import LocalizationModel
164
174
 
165
- myModel = model.localization(
166
- config="/path/to/localization.yaml"
175
+ my_model = LocalizationModel(
176
+ config="/path/to/localization.yaml",
177
+ weights="/path/to/weights.pt", # optional
178
+ )
179
+
180
+ predictions = my_model.infer(
181
+ test_set="/path/to/test_annotations.json",
182
+ )
183
+
184
+ saved_predictions = my_model.save_predictions(
185
+ output_path="/path/to/predictions.json",
186
+ predictions=predictions,
187
+ )
188
+
189
+ metrics = my_model.evaluate(
190
+ test_set="/path/to/test_annotations.json",
191
+ )
192
+
193
+ metrics_from_file = my_model.evaluate(
194
+ test_set="/path/to/test_annotations.json",
195
+ predictions=saved_predictions,
167
196
  )
168
197
  ```
169
198
 
@@ -201,6 +230,7 @@ Generate text descriptions for sports events and temporal segments.
201
230
  Use the README for the fast start, then go deeper through:
202
231
 
203
232
  - Full documentation: https://opensportslab.github.io/opensportslib/
233
+ - High-level API guide: [opensportslib/apis/README.md](opensportslib/apis/README.md)
204
234
  - Configuration guide: https://opensportslab.github.io/opensportslib/tni/config-guide/
205
235
  - Example configs: [examples/configs/](examples/configs/)
206
236
  - Quickstart scripts: [examples/quickstart/](examples/quickstart/)
@@ -92,14 +92,14 @@ print("OpenSportsLib imported successfully")
92
92
  ### Train a classification model
93
93
 
94
94
  ```python
95
- from opensportslib import model
95
+ from opensportslib.apis import ClassificationModel
96
96
 
97
- myModel = model.ClassificationModel(
97
+ my_model = ClassificationModel(
98
98
  config="/path/to/classification.yaml",
99
99
  weights="/path/to/weights.pt", # optional
100
100
  )
101
101
 
102
- myModel.train(
102
+ my_model.train(
103
103
  train_set="/path/to/train_annotations.json",
104
104
  valid_set="/path/to/valid_annotations.json",
105
105
  )
@@ -108,19 +108,29 @@ myModel.train(
108
108
  ### Run inference
109
109
 
110
110
  ```python
111
- from opensportslib import model
111
+ from opensportslib.apis import ClassificationModel
112
112
 
113
- myModel = model.classification(
113
+ my_model = ClassificationModel(
114
114
  config="/path/to/classification.yaml",
115
115
  weights="/path/to/weights.pt", # optional
116
116
  )
117
117
 
118
- predictions = myModel.infer(
118
+ predictions = my_model.infer(
119
119
  test_set="/path/to/test_annotations.json",
120
120
  )
121
121
 
122
- metrics = myModel.evaluate(
122
+ saved_predictions = my_model.save_predictions(
123
+ output_path="/path/to/predictions.json",
124
+ predictions=predictions,
125
+ )
126
+
127
+ metrics = my_model.evaluate(
128
+ test_set="/path/to/test_annotations.json",
129
+ )
130
+
131
+ metrics_from_file = my_model.evaluate(
123
132
  test_set="/path/to/test_annotations.json",
133
+ predictions=saved_predictions,
124
134
  )
125
135
 
126
136
  print(metrics)
@@ -129,10 +139,29 @@ print(metrics)
129
139
  ### Localization example
130
140
 
131
141
  ```python
132
- from opensportslib import model
142
+ from opensportslib.apis import LocalizationModel
133
143
 
134
- myModel = model.localization(
135
- config="/path/to/localization.yaml"
144
+ my_model = LocalizationModel(
145
+ config="/path/to/localization.yaml",
146
+ weights="/path/to/weights.pt", # optional
147
+ )
148
+
149
+ predictions = my_model.infer(
150
+ test_set="/path/to/test_annotations.json",
151
+ )
152
+
153
+ saved_predictions = my_model.save_predictions(
154
+ output_path="/path/to/predictions.json",
155
+ predictions=predictions,
156
+ )
157
+
158
+ metrics = my_model.evaluate(
159
+ test_set="/path/to/test_annotations.json",
160
+ )
161
+
162
+ metrics_from_file = my_model.evaluate(
163
+ test_set="/path/to/test_annotations.json",
164
+ predictions=saved_predictions,
136
165
  )
137
166
  ```
138
167
 
@@ -170,6 +199,7 @@ Generate text descriptions for sports events and temporal segments.
170
199
  Use the README for the fast start, then go deeper through:
171
200
 
172
201
  - Full documentation: https://opensportslab.github.io/opensportslib/
202
+ - High-level API guide: [opensportslib/apis/README.md](opensportslib/apis/README.md)
173
203
  - Configuration guide: https://opensportslab.github.io/opensportslib/tni/config-guide/
174
204
  - Example configs: [examples/configs/](examples/configs/)
175
205
  - Quickstart scripts: [examples/quickstart/](examples/quickstart/)
@@ -29,6 +29,18 @@ def main():
29
29
 
30
30
  print(metrics)
31
31
 
32
+ saved_predictions = my_model.save_predictions(
33
+ output_path="/path/to/predictions.json",
34
+ predictions=predictions,
35
+ )
36
+
37
+ metrics_from_file = my_model.evaluate(
38
+ test_set="/path/to/test_annotations.json",
39
+ predictions=saved_predictions,
40
+ )
41
+
42
+ print(metrics_from_file)
43
+
32
44
 
33
45
  if __name__ == "__main__":
34
46
  main()
@@ -29,6 +29,18 @@ def main():
29
29
 
30
30
  print(metrics)
31
31
 
32
+ saved_predictions = my_model.save_predictions(
33
+ output_path="/path/to/predictions.json",
34
+ predictions=predictions,
35
+ )
36
+
37
+ metrics_from_file = my_model.evaluate(
38
+ test_set="/path/to/test_annotations.json",
39
+ predictions=saved_predictions,
40
+ )
41
+
42
+ print(metrics_from_file)
43
+
32
44
 
33
45
  if __name__ == "__main__":
34
46
  main()
@@ -3,9 +3,11 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  import json
6
+ import logging
6
7
  import os
7
8
  import uuid
8
9
  from abc import ABC, abstractmethod
10
+ from typing import Any
9
11
 
10
12
  from opensportslib.core.utils.config import expand, load_config_omega
11
13
 
@@ -14,6 +16,8 @@ class BaseTaskModel(ABC):
14
16
  """Thin shared contract for task-level OpenSportsLib wrappers."""
15
17
 
16
18
  def __init__(self, config=None, weights=None):
19
+ self._configure_logging()
20
+
17
21
  if config is None:
18
22
  raise ValueError("config path is required")
19
23
 
@@ -23,10 +27,29 @@ class BaseTaskModel(ABC):
23
27
  data_cfg = getattr(self.config, "DATA", None)
24
28
  if data_cfg is not None and hasattr(data_cfg, "data_dir"):
25
29
  data_cfg.data_dir = expand(data_cfg.data_dir)
30
+ logging.info(f"Data directory: {data_cfg.data_dir}")
26
31
 
27
32
  self.run_id = os.environ.get("RUN_ID") or str(uuid.uuid4())[:8]
28
33
  os.environ["RUN_ID"] = self.run_id
29
34
 
35
+ system_cfg = getattr(self.config, "SYSTEM", None)
36
+ if system_cfg is not None:
37
+ base_save_dir = expand(getattr(system_cfg, "save_dir", None) or "./checkpoints")
38
+ model_cfg = getattr(self.config, "MODEL", None)
39
+ backbone_cfg = getattr(model_cfg, "backbone", None)
40
+ model_name = getattr(backbone_cfg, "type", None) or "model"
41
+ run_save_dir = os.path.join(base_save_dir, model_name, self.run_id)
42
+ self.save_dir = run_save_dir
43
+ system_cfg.save_dir = run_save_dir
44
+ if hasattr(system_cfg, "work_dir"):
45
+ system_cfg.work_dir = run_save_dir
46
+ os.makedirs(run_save_dir, exist_ok=True)
47
+ else:
48
+ self.save_dir = expand("./checkpoints")
49
+ os.makedirs(self.save_dir, exist_ok=True)
50
+
51
+ logging.info(f"Save directory: {self.save_dir}")
52
+
30
53
  self.model = None
31
54
  self.processor = None
32
55
  self.trainer = None
@@ -36,6 +59,17 @@ class BaseTaskModel(ABC):
36
59
  if weights is not None:
37
60
  self.load_weights(weights=weights)
38
61
 
62
+ @staticmethod
63
+ def _configure_logging() -> None:
64
+ root_logger = logging.getLogger()
65
+ if not root_logger.handlers:
66
+ logging.basicConfig(
67
+ level=logging.INFO,
68
+ format="%(asctime)s | %(levelname)s | %(message)s",
69
+ )
70
+ elif root_logger.level > logging.INFO:
71
+ root_logger.setLevel(logging.INFO)
72
+
39
73
  @abstractmethod
40
74
  def load_weights(
41
75
  self,
@@ -70,6 +104,7 @@ class BaseTaskModel(ABC):
70
104
  self,
71
105
  test_set: str | None = None,
72
106
  weights: str | None = None,
107
+ predictions: str | dict[str, Any] | None = None,
73
108
  use_wandb: bool = True,
74
109
  **kwargs,
75
110
  ) -> dict | str | None:
@@ -179,6 +179,8 @@ class ClassificationModel(BaseTaskModel):
179
179
 
180
180
  del kwargs
181
181
 
182
+ effective_weights = weights if weights is not None else self.last_loaded_weights
183
+
182
184
  world_size = torch.cuda.device_count() or self.config.SYSTEM.GPU
183
185
  use_ddp = use_ddp and world_size > 1
184
186
 
@@ -198,7 +200,7 @@ class ClassificationModel(BaseTaskModel):
198
200
  train_set,
199
201
  valid_set,
200
202
  None,
201
- weights,
203
+ effective_weights,
202
204
  use_wandb,
203
205
  ),
204
206
  nprocs=world_size,
@@ -214,7 +216,7 @@ class ClassificationModel(BaseTaskModel):
214
216
  return_queue=queue,
215
217
  train_set=train_set,
216
218
  valid_set=valid_set,
217
- weights=weights,
219
+ weights=effective_weights,
218
220
  use_wandb=use_wandb,
219
221
  )
220
222
 
@@ -243,6 +245,8 @@ class ClassificationModel(BaseTaskModel):
243
245
  logging.info("Configuration:")
244
246
  logging.info(self.config)
245
247
 
248
+ effective_weights = weights if weights is not None else self.last_loaded_weights
249
+
246
250
  world_size = torch.cuda.device_count()
247
251
  use_ddp = use_ddp and world_size > 1
248
252
 
@@ -261,7 +265,7 @@ class ClassificationModel(BaseTaskModel):
261
265
  None,
262
266
  None,
263
267
  test_set,
264
- weights,
268
+ effective_weights,
265
269
  use_wandb,
266
270
  ),
267
271
  nprocs=world_size,
@@ -275,7 +279,7 @@ class ClassificationModel(BaseTaskModel):
275
279
  config=self.config,
276
280
  return_queue=queue,
277
281
  test_set=test_set,
278
- weights=weights,
282
+ weights=effective_weights,
279
283
  use_wandb=use_wandb,
280
284
  )
281
285
 
@@ -286,6 +290,7 @@ class ClassificationModel(BaseTaskModel):
286
290
  self,
287
291
  test_set=None,
288
292
  weights=None,
293
+ predictions=None,
289
294
  use_ddp=False,
290
295
  use_wandb=True,
291
296
  **kwargs,
@@ -302,12 +307,13 @@ class ClassificationModel(BaseTaskModel):
302
307
  self.config = resolve_config_omega(self.config)
303
308
  logging.info("Configuration:")
304
309
  logging.info(self.config)
305
- predictions = self.infer(
306
- test_set=test_set,
307
- weights=weights,
308
- use_ddp=use_ddp,
309
- use_wandb=use_wandb,
310
- )
310
+ if predictions is None:
311
+ predictions = self.infer(
312
+ test_set=test_set,
313
+ weights=weights,
314
+ use_ddp=use_ddp,
315
+ use_wandb=use_wandb,
316
+ )
311
317
 
312
318
  self.trainer = self.trainer or Trainer_Classification(self.config)
313
319
  test_data = build_dataset(self.config, test_set, None, split="test")
@@ -9,6 +9,12 @@ from opensportslib.core.utils.config import expand
9
9
  class LocalizationModel(BaseTaskModel):
10
10
  """Top-level task wrapper for localization / spotting."""
11
11
 
12
+ def __init__(self, config=None, weights=None):
13
+ super().__init__(config=config, weights=None)
14
+ if weights is not None:
15
+ self.last_loaded_weights = weights
16
+ self.best_checkpoint = weights
17
+
12
18
  def _resolve_split_path(self, split: str, override: str | None = None) -> str:
13
19
  if override is not None:
14
20
  return expand(override)
@@ -68,6 +74,11 @@ class LocalizationModel(BaseTaskModel):
68
74
  if weights is None:
69
75
  raise ValueError("`weights` must be provided to load_weights().")
70
76
 
77
+ model_cfg = getattr(self.config, "MODEL", None)
78
+ original_multi_gpu = getattr(model_cfg, "multi_gpu", None)
79
+ if model_cfg is not None and original_multi_gpu is not None:
80
+ model_cfg.multi_gpu = False
81
+
71
82
  device = select_device(self.config.SYSTEM)
72
83
  if self.model is None:
73
84
  self.model = build_model(self.config, device=device)
@@ -96,6 +107,9 @@ class LocalizationModel(BaseTaskModel):
96
107
  self.last_loaded_weights = weights
97
108
  self.best_checkpoint = weights
98
109
 
110
+ if model_cfg is not None and original_multi_gpu is not None:
111
+ model_cfg.multi_gpu = original_multi_gpu
112
+
99
113
  def train(
100
114
  self,
101
115
  train_set=None,
@@ -137,6 +151,8 @@ class LocalizationModel(BaseTaskModel):
137
151
  logging.info("Configuration:")
138
152
  logging.info(self.config)
139
153
 
154
+ effective_weights = weights if weights is not None else self.last_loaded_weights
155
+
140
156
  def set_seed(seed):
141
157
  random.seed(seed)
142
158
  np.random.seed(seed)
@@ -184,7 +200,7 @@ class LocalizationModel(BaseTaskModel):
184
200
  cfg=self.config,
185
201
  model=self.model,
186
202
  default_args=get_default_args_trainer(self.config, len(train_loader)),
187
- resume_from=weights,
203
+ resume_from=effective_weights,
188
204
  )
189
205
 
190
206
  logging.info("Start training")
@@ -245,8 +261,11 @@ class LocalizationModel(BaseTaskModel):
245
261
 
246
262
  start = time.time()
247
263
 
248
- if weights is not None:
249
- self.load_weights(weights=weights)
264
+ effective_weights = weights if weights is not None else self.last_loaded_weights
265
+
266
+ if effective_weights is not None:
267
+ if self.model is None or self.last_loaded_weights != effective_weights:
268
+ self.load_weights(weights=effective_weights)
250
269
  elif self.model is None:
251
270
  device = select_device(self.config.SYSTEM)
252
271
  self.model = build_model(self.config, device=device)
@@ -278,6 +297,7 @@ class LocalizationModel(BaseTaskModel):
278
297
  self,
279
298
  test_set=None,
280
299
  weights=None,
300
+ predictions=None,
281
301
  use_wandb=True,
282
302
  **kwargs,
283
303
  ):
@@ -307,11 +327,12 @@ class LocalizationModel(BaseTaskModel):
307
327
  use_wandb=use_wandb,
308
328
  )
309
329
 
310
- predictions = self.infer(
311
- test_set=test_set,
312
- weights=weights,
313
- use_wandb=use_wandb,
314
- )
330
+ if predictions is None:
331
+ predictions = self.infer(
332
+ test_set=test_set,
333
+ weights=weights,
334
+ use_wandb=use_wandb,
335
+ )
315
336
 
316
337
  metrics = None
317
338
 
@@ -517,7 +517,7 @@ class BaseTrainerClassification:
517
517
 
518
518
  logging.info(f"RESULTS Length: {len(results)}")
519
519
  logging.info(f"Predictions are stored at : {save_path}")
520
- with open(save_path, "w") as f:
520
+ with open(save_path, "w", encoding="utf-8") as f:
521
521
  json.dump(submission, f, indent=2)
522
522
  self.predictions_payload = submission
523
523
 
@@ -1018,7 +1018,7 @@ class Trainer_Classification:
1018
1018
  out_dir = os.path.join(self.config.SYSTEM.save_dir, "final")
1019
1019
  os.makedirs(out_dir, exist_ok=True)
1020
1020
  out_path = os.path.join(out_dir, "predictions_test_epoch_final.json")
1021
- with open(out_path, "w") as f:
1021
+ with open(out_path, "w", encoding="utf-8") as f:
1022
1022
  json.dump(submission, f, indent=2)
1023
1023
  self.predictions_payload = submission
1024
1024
  return submission
@@ -1107,14 +1107,14 @@ class Trainer_Classification:
1107
1107
  if isinstance(pred_path, dict):
1108
1108
  pred_data = pred_path
1109
1109
  elif isinstance(pred_path, str):
1110
- with open(pred_path) as f:
1110
+ with open(pred_path, encoding="utf-8") as f:
1111
1111
  pred_data = json.load(f)
1112
1112
  else:
1113
1113
  raise TypeError(
1114
1114
  f"Unsupported predictions type: {type(pred_path).__name__}. Expected dict or str."
1115
1115
  )
1116
1116
 
1117
- with open(gt_path) as f:
1117
+ with open(gt_path, encoding="utf-8") as f:
1118
1118
  gt_data = json.load(f)
1119
1119
 
1120
1120
  gt_dict = {}
@@ -791,7 +791,7 @@ class Evaluator:
791
791
  # --------------------------------------------------
792
792
  # LOAD GT
793
793
  # --------------------------------------------------
794
- with open(cfg.path) as f:
794
+ with open(cfg.path, encoding="utf-8") as f:
795
795
  GT_data = json.load(f)
796
796
 
797
797
  # --------------------------------------------------
@@ -895,7 +895,7 @@ class Evaluator:
895
895
  if not os.path.exists(pred_file):
896
896
  continue
897
897
 
898
- with open(pred_file) as f:
898
+ with open(pred_file, encoding="utf-8") as f:
899
899
  pred_data_local = json.load(f)
900
900
 
901
901
  if "data" in pred_data_local:
@@ -105,11 +105,11 @@ def expand(path):
105
105
 
106
106
 
107
107
  def load_json(fpath):
108
- with open(fpath) as fp:
108
+ with open(fpath, encoding="utf-8") as fp:
109
109
  return json.load(fp)
110
110
 
111
111
  def load_gz_json(fpath):
112
- with gzip.open(fpath, "rt", encoding="ascii") as fp:
112
+ with gzip.open(fpath, "rt", encoding="utf-8") as fp:
113
113
  return json.load(fp)
114
114
 
115
115
 
@@ -118,12 +118,12 @@ def store_json(fpath, obj, pretty=False):
118
118
  if pretty:
119
119
  kwargs["indent"] = 4
120
120
  kwargs["sort_keys"] = False
121
- with open(fpath, "w") as fp:
121
+ with open(fpath, "w", encoding="utf-8") as fp:
122
122
  json.dump(obj, fp, **kwargs)
123
123
 
124
124
 
125
125
  def store_gz_json(fpath, obj):
126
- with gzip.open(fpath, "wt", encoding="ascii") as fp:
126
+ with gzip.open(fpath, "wt", encoding="utf-8") as fp:
127
127
  json.dump(obj, fp)
128
128
 
129
129
 
@@ -796,7 +796,7 @@ def store_eval_files_json(raw_pred, eval_dir, save_v2=True):
796
796
  }
797
797
 
798
798
  out_path = os.path.join(video_out_dir, "results_spotting.json")
799
- with open(out_path, "w") as f:
799
+ with open(out_path, "w", encoding="utf-8") as f:
800
800
  json.dump(out, f, indent=2)
801
801
 
802
802
  logging.info(f"Stored V2 predictions → {eval_dir}")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: opensportslib
3
- Version: 0.1.2.dev1
3
+ Version: 0.1.2.dev2
4
4
  Summary: OpenSportsLib is the professional library, designed for advanced video understanding in sports. It provides state-of-the-art tools for action recognition, spotting, retrieval, and captioning, making it ideal for researchers, analysts, and developers working with sports video data.
5
5
  Author: Jeet Vora
6
6
  Requires-Python: >=3.12
@@ -123,14 +123,14 @@ print("OpenSportsLib imported successfully")
123
123
  ### Train a classification model
124
124
 
125
125
  ```python
126
- from opensportslib import model
126
+ from opensportslib.apis import ClassificationModel
127
127
 
128
- myModel = model.ClassificationModel(
128
+ my_model = ClassificationModel(
129
129
  config="/path/to/classification.yaml",
130
130
  weights="/path/to/weights.pt", # optional
131
131
  )
132
132
 
133
- myModel.train(
133
+ my_model.train(
134
134
  train_set="/path/to/train_annotations.json",
135
135
  valid_set="/path/to/valid_annotations.json",
136
136
  )
@@ -139,19 +139,29 @@ myModel.train(
139
139
  ### Run inference
140
140
 
141
141
  ```python
142
- from opensportslib import model
142
+ from opensportslib.apis import ClassificationModel
143
143
 
144
- myModel = model.classification(
144
+ my_model = ClassificationModel(
145
145
  config="/path/to/classification.yaml",
146
146
  weights="/path/to/weights.pt", # optional
147
147
  )
148
148
 
149
- predictions = myModel.infer(
149
+ predictions = my_model.infer(
150
150
  test_set="/path/to/test_annotations.json",
151
151
  )
152
152
 
153
- metrics = myModel.evaluate(
153
+ saved_predictions = my_model.save_predictions(
154
+ output_path="/path/to/predictions.json",
155
+ predictions=predictions,
156
+ )
157
+
158
+ metrics = my_model.evaluate(
159
+ test_set="/path/to/test_annotations.json",
160
+ )
161
+
162
+ metrics_from_file = my_model.evaluate(
154
163
  test_set="/path/to/test_annotations.json",
164
+ predictions=saved_predictions,
155
165
  )
156
166
 
157
167
  print(metrics)
@@ -160,10 +170,29 @@ print(metrics)
160
170
  ### Localization example
161
171
 
162
172
  ```python
163
- from opensportslib import model
173
+ from opensportslib.apis import LocalizationModel
164
174
 
165
- myModel = model.localization(
166
- config="/path/to/localization.yaml"
175
+ my_model = LocalizationModel(
176
+ config="/path/to/localization.yaml",
177
+ weights="/path/to/weights.pt", # optional
178
+ )
179
+
180
+ predictions = my_model.infer(
181
+ test_set="/path/to/test_annotations.json",
182
+ )
183
+
184
+ saved_predictions = my_model.save_predictions(
185
+ output_path="/path/to/predictions.json",
186
+ predictions=predictions,
187
+ )
188
+
189
+ metrics = my_model.evaluate(
190
+ test_set="/path/to/test_annotations.json",
191
+ )
192
+
193
+ metrics_from_file = my_model.evaluate(
194
+ test_set="/path/to/test_annotations.json",
195
+ predictions=saved_predictions,
167
196
  )
168
197
  ```
169
198
 
@@ -201,6 +230,7 @@ Generate text descriptions for sports events and temporal segments.
201
230
  Use the README for the fast start, then go deeper through:
202
231
 
203
232
  - Full documentation: https://opensportslab.github.io/opensportslib/
233
+ - High-level API guide: [opensportslib/apis/README.md](opensportslib/apis/README.md)
204
234
  - Configuration guide: https://opensportslab.github.io/opensportslib/tni/config-guide/
205
235
  - Example configs: [examples/configs/](examples/configs/)
206
236
  - Quickstart scripts: [examples/quickstart/](examples/quickstart/)
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "opensportslib"
7
- version = "0.1.2.dev1"
7
+ version = "0.1.2.dev2"
8
8
  description = "OpenSportsLib is the professional library, designed for advanced video understanding in sports. It provides state-of-the-art tools for action recognition, spotting, retrieval, and captioning, making it ideal for researchers, analysts, and developers working with sports video data."
9
9
  readme = "README.md"
10
10
  requires-python = ">=3.12"