opensportslib 0.1.1.dev6__tar.gz → 0.1.2.dev1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {opensportslib-0.1.1.dev6/opensportslib.egg-info → opensportslib-0.1.2.dev1}/PKG-INFO +11 -8
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/README.md +10 -7
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/examples/quickstart/basic_classification.py +11 -7
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/examples/quickstart/basic_localization.py +11 -7
- opensportslib-0.1.2.dev1/opensportslib/apis/__init__.py +15 -0
- opensportslib-0.1.2.dev1/opensportslib/apis/base_task_model.py +96 -0
- opensportslib-0.1.2.dev1/opensportslib/apis/classification.py +322 -0
- opensportslib-0.1.2.dev1/opensportslib/apis/localization.py +333 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/core/trainer/classification_trainer.py +55 -12
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/core/trainer/localization_trainer.py +16 -18
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1/opensportslib.egg-info}/PKG-INFO +11 -8
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib.egg-info/SOURCES.txt +3 -1
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/pyproject.toml +4 -1
- opensportslib-0.1.2.dev1/tests/conftest.py +359 -0
- opensportslib-0.1.2.dev1/tests/test_public_apis_smoke.py +38 -0
- opensportslib-0.1.2.dev1/tests/test_subset_train_infer_integration.py +292 -0
- opensportslib-0.1.2.dev1/tests/test_task_model_api_contract.py +73 -0
- opensportslib-0.1.1.dev6/opensportslib/apis/__init__.py +0 -21
- opensportslib-0.1.1.dev6/opensportslib/apis/classification.py +0 -364
- opensportslib-0.1.1.dev6/opensportslib/apis/localization.py +0 -239
- opensportslib-0.1.1.dev6/tests/conftest.py +0 -59
- opensportslib-0.1.1.dev6/tests/test_public_apis_smoke.py +0 -29
- opensportslib-0.1.1.dev6/tests/test_subset_train_infer_integration.py +0 -172
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/LICENSE +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/LICENSE-COMMERCIAL +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/MANIFEST.in +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/__init__.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/cli.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/config/classification.yaml +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/config/localization-e2e-ocv.yaml +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/config/localization-json_calf_resnetpca512.yaml +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/config/localization-json_netvlad++_resnetpca512.yaml +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/config/localization.yaml +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/config/sngar-frames.yaml +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/config/sngar-tracking.yaml +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/core/__init__.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/core/loss/__init__.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/core/loss/builder.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/core/loss/calf.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/core/loss/ce.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/core/loss/combine.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/core/loss/nll.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/core/optimizer/__init__.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/core/optimizer/builder.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/core/sampler/weighted_sampler.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/core/scheduler/__init__.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/core/scheduler/builder.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/core/trainer/__init__.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/core/utils/checkpoint.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/core/utils/config.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/core/utils/data.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/core/utils/ddp.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/core/utils/default_args.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/core/utils/lightning.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/core/utils/load_annotations.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/core/utils/seed.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/core/utils/video_processing.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/core/utils/wandb.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/datasets/__init__.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/datasets/builder.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/datasets/classification_dataset.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/datasets/localization_dataset.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/datasets/utils/__init__.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/datasets/utils/tracking.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/metrics/classification_metric.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/metrics/localization_metric.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/models/__init__.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/models/backbones/builder.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/models/base/contextaware.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/models/base/e2e.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/models/base/learnablepooling.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/models/base/tracking.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/models/base/vars.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/models/base/video.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/models/base/video_mae.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/models/builder.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/models/heads/builder.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/models/neck/builder.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/models/utils/common.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/models/utils/impl/__init__.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/models/utils/impl/asformer.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/models/utils/impl/calf.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/models/utils/impl/gsm.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/models/utils/impl/gtad.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/models/utils/impl/tsm.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/models/utils/litebase.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/models/utils/modules.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/models/utils/shift.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/models/utils/utils.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/setup/setup.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib.egg-info/dependency_links.txt +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib.egg-info/entry_points.txt +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib.egg-info/requires.txt +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib.egg-info/top_level.txt +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/setup.cfg +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/tests/test_config_utils_smoke.py +0 -0
- {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/tests/test_package_smoke.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: opensportslib
|
|
3
|
-
Version: 0.1.1.dev6
|
|
3
|
+
Version: 0.1.2.dev1
|
|
4
4
|
Summary: OpenSportsLib is the professional library, designed for advanced video understanding in sports. It provides state-of-the-art tools for action recognition, spotting, retrieval, and captioning, making it ideal for researchers, analysts, and developers working with sports video data.
|
|
5
5
|
Author: Jeet Vora
|
|
6
6
|
Requires-Python: >=3.12
|
|
@@ -125,14 +125,14 @@ print("OpenSportsLib imported successfully")
|
|
|
125
125
|
```python
|
|
126
126
|
from opensportslib import model
|
|
127
127
|
|
|
128
|
-
myModel = model.
|
|
129
|
-
config="/path/to/classification.yaml"
|
|
128
|
+
myModel = model.ClassificationModel(
|
|
129
|
+
config="/path/to/classification.yaml",
|
|
130
|
+
weights="/path/to/weights.pt", # optional
|
|
130
131
|
)
|
|
131
132
|
|
|
132
133
|
myModel.train(
|
|
133
134
|
train_set="/path/to/train_annotations.json",
|
|
134
135
|
valid_set="/path/to/valid_annotations.json",
|
|
135
|
-
pretrained="/path/to/pretrained.pt", # optional
|
|
136
136
|
)
|
|
137
137
|
```
|
|
138
138
|
|
|
@@ -142,13 +142,16 @@ myModel.train(
|
|
|
142
142
|
from opensportslib import model
|
|
143
143
|
|
|
144
144
|
myModel = model.classification(
|
|
145
|
-
config="/path/to/classification.yaml"
|
|
145
|
+
config="/path/to/classification.yaml",
|
|
146
|
+
weights="/path/to/weights.pt", # optional
|
|
147
|
+
)
|
|
148
|
+
|
|
149
|
+
predictions = myModel.infer(
|
|
150
|
+
test_set="/path/to/test_annotations.json",
|
|
146
151
|
)
|
|
147
152
|
|
|
148
|
-
metrics = myModel.
|
|
153
|
+
metrics = myModel.evaluate(
|
|
149
154
|
test_set="/path/to/test_annotations.json",
|
|
150
|
-
pretrained="/path/to/checkpoints/final_model",
|
|
151
|
-
predictions="/path/to/predictions.json"
|
|
152
155
|
)
|
|
153
156
|
|
|
154
157
|
print(metrics)
|
|
@@ -94,14 +94,14 @@ print("OpenSportsLib imported successfully")
|
|
|
94
94
|
```python
|
|
95
95
|
from opensportslib import model
|
|
96
96
|
|
|
97
|
-
myModel = model.
|
|
98
|
-
config="/path/to/classification.yaml"
|
|
97
|
+
myModel = model.ClassificationModel(
|
|
98
|
+
config="/path/to/classification.yaml",
|
|
99
|
+
weights="/path/to/weights.pt", # optional
|
|
99
100
|
)
|
|
100
101
|
|
|
101
102
|
myModel.train(
|
|
102
103
|
train_set="/path/to/train_annotations.json",
|
|
103
104
|
valid_set="/path/to/valid_annotations.json",
|
|
104
|
-
pretrained="/path/to/pretrained.pt", # optional
|
|
105
105
|
)
|
|
106
106
|
```
|
|
107
107
|
|
|
@@ -111,13 +111,16 @@ myModel.train(
|
|
|
111
111
|
from opensportslib import model
|
|
112
112
|
|
|
113
113
|
myModel = model.classification(
|
|
114
|
-
config="/path/to/classification.yaml"
|
|
114
|
+
config="/path/to/classification.yaml",
|
|
115
|
+
weights="/path/to/weights.pt", # optional
|
|
116
|
+
)
|
|
117
|
+
|
|
118
|
+
predictions = myModel.infer(
|
|
119
|
+
test_set="/path/to/test_annotations.json",
|
|
115
120
|
)
|
|
116
121
|
|
|
117
|
-
metrics = myModel.
|
|
122
|
+
metrics = myModel.evaluate(
|
|
118
123
|
test_set="/path/to/test_annotations.json",
|
|
119
|
-
pretrained="/path/to/checkpoints/final_model",
|
|
120
|
-
predictions="/path/to/predictions.json"
|
|
121
124
|
)
|
|
122
125
|
|
|
123
126
|
print(metrics)
|
{opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/examples/quickstart/basic_classification.py
RENAMED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from opensportslib import
|
|
1
|
+
from opensportslib.apis import ClassificationModel
|
|
2
2
|
|
|
3
3
|
|
|
4
4
|
def main():
|
|
@@ -7,20 +7,24 @@ def main():
|
|
|
7
7
|
Update config and dataset paths before running.
|
|
8
8
|
"""
|
|
9
9
|
|
|
10
|
-
my_model =
|
|
11
|
-
config="examples/configs/classification_video.yaml"
|
|
10
|
+
my_model = ClassificationModel(
|
|
11
|
+
config="examples/configs/classification_video.yaml",
|
|
12
|
+
weights="/path/to/weights.pt", # optional
|
|
12
13
|
)
|
|
13
14
|
|
|
14
15
|
my_model.train(
|
|
15
16
|
train_set="/path/to/train_annotations.json",
|
|
16
17
|
valid_set="/path/to/valid_annotations.json",
|
|
17
|
-
pretrained="/path/to/pretrained.pt", # optional
|
|
18
18
|
)
|
|
19
19
|
|
|
20
|
-
|
|
20
|
+
predictions = my_model.infer(
|
|
21
|
+
test_set="/path/to/test_annotations.json",
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
print(predictions)
|
|
25
|
+
|
|
26
|
+
metrics = my_model.evaluate(
|
|
21
27
|
test_set="/path/to/test_annotations.json",
|
|
22
|
-
pretrained="/path/to/checkpoints/best.pt",
|
|
23
|
-
predictions="/path/to/predictions.json",
|
|
24
28
|
)
|
|
25
29
|
|
|
26
30
|
print(metrics)
|
{opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/examples/quickstart/basic_localization.py
RENAMED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from opensportslib import
|
|
1
|
+
from opensportslib.apis import LocalizationModel
|
|
2
2
|
|
|
3
3
|
|
|
4
4
|
def main():
|
|
@@ -7,20 +7,24 @@ def main():
|
|
|
7
7
|
Update config and dataset paths before running.
|
|
8
8
|
"""
|
|
9
9
|
|
|
10
|
-
my_model =
|
|
11
|
-
config="examples/configs/localization.yaml"
|
|
10
|
+
my_model = LocalizationModel(
|
|
11
|
+
config="examples/configs/localization.yaml",
|
|
12
|
+
weights="/path/to/weights.pt", # optional
|
|
12
13
|
)
|
|
13
14
|
|
|
14
15
|
my_model.train(
|
|
15
16
|
train_set="/path/to/train_annotations.json",
|
|
16
17
|
valid_set="/path/to/valid_annotations.json",
|
|
17
|
-
pretrained="/path/to/pretrained.pt", # optional
|
|
18
18
|
)
|
|
19
19
|
|
|
20
|
-
|
|
20
|
+
predictions = my_model.infer(
|
|
21
|
+
test_set="/path/to/test_annotations.json",
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
print(predictions)
|
|
25
|
+
|
|
26
|
+
metrics = my_model.evaluate(
|
|
21
27
|
test_set="/path/to/test_annotations.json",
|
|
22
|
-
pretrained="/path/to/checkpoints/best.pt",
|
|
23
|
-
predictions="/path/to/predictions.json",
|
|
24
28
|
)
|
|
25
29
|
|
|
26
30
|
print(metrics)
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
# opensportslib/apis/__init__.py
#
# Package entry point for the task-level APIs: re-exports the shared base
# wrapper and the two concrete task wrappers.

# Import task APIs
from opensportslib.apis.base_task_model import BaseTaskModel
from opensportslib.apis.classification import ClassificationModel
from opensportslib.apis.localization import LocalizationModel
import warnings
# NOTE(review): installs an "ignore" filter for every warning category as a
# side effect of importing this package, which also hides warnings raised by
# the caller's own code and other libraries — confirm this is intended.
warnings.filterwarnings("ignore")

# Expose only these
# (`warnings` is imported above but deliberately left out of the public API.)
__all__ = [
    "BaseTaskModel",
    "ClassificationModel",
    "LocalizationModel",
]
|
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
"""Shared task-level wrapper base for OpenSportsLib APIs."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import os
|
|
7
|
+
import uuid
|
|
8
|
+
from abc import ABC, abstractmethod
|
|
9
|
+
|
|
10
|
+
from opensportslib.core.utils.config import expand, load_config_omega
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class BaseTaskModel(ABC):
    """Thin shared contract for task-level OpenSportsLib wrappers.

    The base class loads the configuration, fixes a run id, initialises the
    bookkeeping attributes shared by every task wrapper and offers
    ``save_predictions``. Subclasses must implement ``load_weights``,
    ``train``, ``infer`` and ``evaluate``.
    """

    def __init__(self, config=None, weights=None):
        """Load the configuration and, optionally, an initial set of weights.

        Args:
            config: Path to the configuration file. Required.
            weights: Optional weights reference. When given it is forwarded
                to ``load_weights`` after the instance state is initialised.

        Raises:
            ValueError: If ``config`` is ``None``.
        """
        if config is None:
            raise ValueError("config path is required")

        self.config_path = expand(config)
        self.config = load_config_omega(self.config_path)

        # Normalise DATA.data_dir in place when the config defines it.
        data_cfg = getattr(self.config, "DATA", None)
        if data_cfg is not None and hasattr(data_cfg, "data_dir"):
            data_cfg.data_dir = expand(data_cfg.data_dir)

        # Reuse RUN_ID from the environment when it is set, otherwise take the
        # first 8 characters of a fresh UUID4. The value is written back so
        # code that reads RUN_ID from the environment sees the same id.
        self.run_id = os.environ.get("RUN_ID") or str(uuid.uuid4())[:8]
        os.environ["RUN_ID"] = self.run_id

        self.model = None
        self.processor = None
        self.trainer = None
        self.best_checkpoint = None
        self.last_loaded_weights = None

        # Must stay last: load_weights() may overwrite the attributes above.
        if weights is not None:
            self.load_weights(weights=weights)

    @abstractmethod
    def load_weights(
        self,
        weights: str | None = None,
        **kwargs,
    ) -> None:
        """Load model weights into the wrapper (implemented by subclasses)."""
        raise NotImplementedError

    @abstractmethod
    def train(
        self,
        train_set: str | None = None,
        valid_set: str | None = None,
        weights: str | None = None,
        use_wandb: bool = True,
        **kwargs,
    ) -> str | None:
        """Train the task model (implemented by subclasses)."""
        raise NotImplementedError

    @abstractmethod
    def infer(
        self,
        test_set: str | None = None,
        weights: str | None = None,
        use_wandb: bool = True,
        **kwargs,
    ) -> dict:
        """Run inference and return predictions (implemented by subclasses)."""
        raise NotImplementedError

    @abstractmethod
    def evaluate(
        self,
        test_set: str | None = None,
        weights: str | None = None,
        use_wandb: bool = True,
        **kwargs,
    ) -> dict | str | None:
        """Evaluate the task model (implemented by subclasses)."""
        raise NotImplementedError

    def save_predictions(
        self,
        output_path: str,
        predictions: dict,
    ) -> str:
        """Persist in-memory prediction JSON payload to a target file path.

        Args:
            output_path: Destination file path; passed through ``expand``.
            predictions: Prediction payload; must be a ``dict``.

        Returns:
            The expanded destination path that was written.

        Raises:
            TypeError: If ``predictions`` is not a ``dict``.
        """
        # Validate before touching the filesystem so a rejected payload does
        # not leave a freshly created output directory behind.
        if not isinstance(predictions, dict):
            raise TypeError(
                f"Unsupported predictions type: {type(predictions).__name__}. "
                "Expected dict."
            )

        dst = expand(output_path)
        # A bare file name has an empty dirname; fall back to the current dir.
        os.makedirs(os.path.dirname(dst) or ".", exist_ok=True)

        with open(dst, "w", encoding="utf-8") as f:
            json.dump(predictions, f)
        return dst
|
|
@@ -0,0 +1,322 @@
|
|
|
1
|
+
# opensportslib/apis/classification.py
|
|
2
|
+
|
|
3
|
+
"""Public API for classification tasks."""
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
import os
|
|
7
|
+
|
|
8
|
+
from opensportslib.apis.base_task_model import BaseTaskModel
|
|
9
|
+
from opensportslib.core.utils.config import expand
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ClassificationModel(BaseTaskModel):
    """Top-level task wrapper for classification.

    ``train``, ``infer`` and ``evaluate`` run their work through
    ``_worker_ddp``, either in-process or via ``torch.multiprocessing.spawn``
    when DDP is requested and more than one device is counted.
    """

    def _resolve_split_path(self, split: str, override: str | None = None) -> str:
        """Return the annotation path for ``split``.

        Lookup order: the explicit ``override``, then ``DATA.<split>.path``,
        then ``DATA.annotations.<split>``. Every result goes through
        ``expand``.

        Raises:
            ValueError: If none of the three sources yields a path.
        """
        if override is not None:
            return expand(override)

        data_cfg = getattr(self.config, "DATA", None)
        split_cfg = getattr(data_cfg, split, None)
        path = getattr(split_cfg, "path", None) if split_cfg is not None else None
        if path:
            return expand(path)

        annotations_cfg = getattr(data_cfg, "annotations", None)
        path = (
            getattr(annotations_cfg, split, None)
            if annotations_cfg is not None
            else None
        )
        if path:
            return expand(path)

        raise ValueError(
            f"Could not resolve path for split '{split}'. "
            f"Expected DATA.{split}.path or DATA.annotations.{split}."
        )

    def _resolve_weights(self, weights):
        """Return the weights a job should start from.

        An explicit (non-``None``) ``weights`` always wins. Otherwise fall
        back to the weights most recently loaded through ``load_weights`` or
        produced by ``train``, so they are not silently replaced by a freshly
        built model. Passing a falsy non-``None`` value such as ``""`` still
        makes the worker build a fresh model.
        """
        if weights is not None:
            return weights
        return (
            getattr(self, "last_loaded_weights", None)
            or getattr(self, "best_checkpoint", None)
        )

    # -----------------------------------------------------------------
    # internal DDP worker
    # -----------------------------------------------------------------
    @staticmethod
    def _worker_ddp(
        rank,
        world_size,
        mode,
        config_path,
        config,
        return_queue=None,
        train_set=None,
        valid_set=None,
        test_set=None,
        weights=None,
        use_wandb=False,
    ):
        """Execute one training/inference job on a single process.

        Args:
            rank: Index of this process; also used as the CUDA device index
                when ``world_size > 1``.
            world_size: Number of cooperating processes; DDP is set up only
                when it is greater than 1.
            mode: ``"train"`` or ``"infer"``; any other value runs no job.
            config_path: Configuration path forwarded to ``init_wandb``.
            config: Resolved configuration object.
            return_queue: Optional queue; rank 0 puts the job result on it.
            train_set: Training annotations (``mode == "train"``).
            valid_set: Validation annotations (``mode == "train"``).
            test_set: Test annotations (``mode == "infer"``).
            weights: Checkpoint to load; when falsy a fresh model is built.
            use_wandb: Forwarded to ``init_wandb`` on rank 0.
        """
        import torch
        from opensportslib.core.trainer.classification_trainer import Trainer_Classification
        from opensportslib.core.utils.ddp import ddp_cleanup, ddp_setup
        from opensportslib.core.utils.wandb import init_wandb
        from opensportslib.core.utils.seed import set_reproducibility
        from opensportslib.datasets.builder import build_dataset
        from opensportslib.models.builder import build_model

        # force=True replaces any handlers already installed in this process.
        logging.basicConfig(
            level=logging.INFO,
            format=f"[RANK {rank}] %(asctime)s | %(levelname)s | %(message)s",
            force=True,
        )
        if rank != 0:
            # Keep non-zero ranks quiet apart from errors.
            logging.getLogger().setLevel(logging.ERROR)

        if rank == 0:
            init_wandb(
                config_path,
                config,
                run_id=os.environ["RUN_ID"],
                use_wandb=use_wandb,
            )

        if getattr(config.SYSTEM, "use_seed", False):
            set_reproducibility(config.SYSTEM.seed)

        is_ddp = world_size > 1
        if is_ddp:
            torch.cuda.set_device(rank)
            ddp_setup(rank, world_size)
            device = torch.device(f"cuda:{rank}")
        else:
            from opensportslib.core.utils.config import select_device

            device = select_device(config.SYSTEM)

        # Everything after a successful ddp_setup() runs under try/finally so
        # ddp_cleanup() is called even when the job raises.
        try:
            trainer = Trainer_Classification(config)
            trainer.device = device

            if weights:
                model, processor, _, _ = trainer.load(weights)
            else:
                model, processor = build_model(config, device)

            trainer.model = model

            if mode == "train":
                train_data = build_dataset(config, train_set, processor, split="train")
                valid_data = build_dataset(config, valid_set, processor, split="valid")
                best_ckpt = trainer.train(
                    model,
                    train_data,
                    valid_data,
                    rank=rank,
                    world_size=world_size,
                )
                if rank == 0 and return_queue is not None:
                    best_ckpt = best_ckpt or getattr(
                        trainer.trainer, "best_checkpoint_path", None
                    )
                    return_queue.put(best_ckpt)

            elif mode == "infer":
                test_data = build_dataset(config, test_set, processor, split="test")
                predictions = trainer.infer(
                    test_data,
                    rank=rank,
                    world_size=world_size,
                )
                if rank == 0 and return_queue is not None:
                    return_queue.put(predictions)
        finally:
            if is_ddp:
                ddp_cleanup()

    def _run_job(
        self,
        mode,
        world_size,
        use_ddp,
        train_set=None,
        valid_set=None,
        test_set=None,
        weights=None,
        use_wandb=False,
    ):
        """Run ``_worker_ddp`` in-process or via spawn and return its result.

        The result is whatever rank 0 put on the queue: the best checkpoint
        for ``mode == "train"``, the predictions for ``mode == "infer"``.
        """
        import torch.multiprocessing as mp

        ctx = mp.get_context("spawn")
        queue = ctx.Queue()

        if use_ddp:
            # NOTE(review): spawn() joins the workers before the queue is
            # drained below; the Python multiprocessing docs warn this can
            # block when a child has queued a large object — confirm with
            # realistic prediction sizes.
            mp.spawn(
                ClassificationModel._worker_ddp,
                # `rank` is supplied by spawn() as the first argument.
                args=(
                    world_size,
                    mode,
                    self.config_path,
                    self.config,
                    queue,
                    train_set,
                    valid_set,
                    test_set,
                    weights,
                    use_wandb,
                ),
                nprocs=world_size,
            )
        else:
            ClassificationModel._worker_ddp(
                rank=0,
                world_size=1,
                mode=mode,
                config_path=self.config_path,
                config=self.config,
                return_queue=queue,
                train_set=train_set,
                valid_set=valid_set,
                test_set=test_set,
                weights=weights,
                use_wandb=use_wandb,
            )

        return queue.get()

    def load_weights(
        self,
        weights: str | None = None,
        **kwargs,
    ) -> None:
        """Load ``weights`` through a new ``Trainer_Classification``.

        Sets ``trainer``, ``model``, ``last_loaded_weights`` and
        ``best_checkpoint``; ``processor`` is set only when
        ``MODEL.type`` is ``"huggingface"``.

        Raises:
            ValueError: If ``weights`` is ``None``.
        """
        from opensportslib.core.trainer.classification_trainer import Trainer_Classification

        del kwargs
        if weights is None:
            raise ValueError("`weights` must be provided to load_weights().")

        self.trainer = Trainer_Classification(self.config)
        loaded = self.trainer.load(weights)
        self.model = loaded[0]

        if getattr(self.config.MODEL, "type", "custom") == "huggingface":
            self.processor = loaded[1]

        self.last_loaded_weights = weights
        self.best_checkpoint = weights

    # -----------------------------------------------------------------
    # public training interface
    # -----------------------------------------------------------------

    def train(
        self,
        train_set=None,
        valid_set=None,
        test_set=None,
        weights=None,
        use_ddp=False,
        use_wandb=True,
        **kwargs,
    ):
        """Run full training and return best checkpoint path.

        Args:
            train_set: Training annotations; defaults to the config value.
            valid_set: Validation annotations; defaults to the config value.
            test_set: Ignored; retained for API compatibility.
            weights: Checkpoint to start from. When ``None`` the weights
                previously loaded on this instance (if any) are used.
            use_ddp: Request DDP; only honoured when more than one device
                is counted.
            use_wandb: Forwarded to the worker's ``init_wandb`` call.
            **kwargs: Ignored.

        Returns:
            The value reported by the rank-0 worker, also stored in
            ``best_checkpoint`` and ``last_loaded_weights``.
        """
        import torch
        from opensportslib.core.utils.config import resolve_config_omega

        del test_set  # retained for API compatibility

        train_set = self._resolve_split_path("train", train_set)
        valid_set = self._resolve_split_path("valid", valid_set)

        self.config = resolve_config_omega(self.config)
        logging.info("Configuration:")
        logging.info(self.config)

        del kwargs

        weights = self._resolve_weights(weights)

        world_size = torch.cuda.device_count() or self.config.SYSTEM.GPU
        use_ddp = use_ddp and world_size > 1

        if use_ddp:
            logging.info("Launching DDP on %s GPUs", world_size)
        else:
            logging.info("Single GPU training")

        self.best_checkpoint = self._run_job(
            "train",
            world_size,
            use_ddp,
            train_set=train_set,
            valid_set=valid_set,
            weights=weights,
            use_wandb=use_wandb,
        )
        self.last_loaded_weights = self.best_checkpoint
        return self.best_checkpoint

    def infer(
        self,
        test_set=None,
        weights=None,
        use_ddp=False,
        use_wandb=True,
        **kwargs,
    ):
        """Run model inference and return predictions in OSL JSON format.

        Args:
            test_set: Test annotations; defaults to the config value.
            weights: Checkpoint to load. When ``None`` the weights
                previously loaded or trained on this instance are used.
            use_ddp: Request DDP; only honoured when more than one CUDA
                device is counted.
            use_wandb: Forwarded to the worker's ``init_wandb`` call.
            **kwargs: Ignored.

        Returns:
            The predictions reported by the rank-0 worker.
        """
        del kwargs

        import torch
        from opensportslib.core.utils.config import resolve_config_omega

        test_set = self._resolve_split_path("test", test_set)

        self.config = resolve_config_omega(self.config)
        logging.info("Configuration:")
        logging.info(self.config)

        weights = self._resolve_weights(weights)

        world_size = torch.cuda.device_count()
        use_ddp = use_ddp and world_size > 1

        return self._run_job(
            "infer",
            world_size,
            use_ddp,
            test_set=test_set,
            weights=weights,
            use_wandb=use_wandb,
        )

    def evaluate(
        self,
        test_set=None,
        weights=None,
        use_ddp=False,
        use_wandb=True,
        **kwargs,
    ):
        """Run inference on test set and return evaluation metrics.

        Predictions come from ``infer`` (same ``weights`` fallback) and are
        scored by the trainer's ``evaluate`` against ``test_set``.

        Args:
            test_set: Test annotations; defaults to the config value.
            weights: Checkpoint to load; see ``infer``.
            use_ddp: Forwarded to ``infer``.
            use_wandb: Forwarded to ``infer``.
            **kwargs: Ignored.

        Returns:
            The metrics returned by the trainer's ``evaluate``.
        """
        del kwargs

        from opensportslib.datasets.builder import build_dataset
        from opensportslib.core.trainer.classification_trainer import Trainer_Classification
        from opensportslib.core.utils.config import resolve_config_omega

        test_set = self._resolve_split_path("test", test_set)

        self.config = resolve_config_omega(self.config)
        logging.info("Configuration:")
        logging.info(self.config)
        predictions = self.infer(
            test_set=test_set,
            weights=weights,
            use_ddp=use_ddp,
            use_wandb=use_wandb,
        )

        # Reuse the trainer created by load_weights() when there is one.
        self.trainer = self.trainer or Trainer_Classification(self.config)
        # This dataset is built only to read label_map and exclude_labels.
        test_data = build_dataset(self.config, test_set, None, split="test")
        metrics = self.trainer.evaluate(
            pred_path=predictions,
            gt_path=test_set,
            class_names=test_data.label_map,
            exclude_labels=test_data.exclude_labels,
        )

        logging.info("TEST METRICS : %s", metrics)
        return metrics
|