opensportslib 0.1.2 (tar.gz) → 0.1.2.dev2 (tar.gz)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {opensportslib-0.1.2/opensportslib.egg-info → opensportslib-0.1.2.dev2}/PKG-INFO +48 -15
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/README.md +47 -14
- opensportslib-0.1.2.dev2/examples/quickstart/basic_classification.py +46 -0
- opensportslib-0.1.2.dev2/examples/quickstart/basic_localization.py +46 -0
- opensportslib-0.1.2.dev2/opensportslib/apis/__init__.py +15 -0
- opensportslib-0.1.2.dev2/opensportslib/apis/base_task_model.py +131 -0
- opensportslib-0.1.2.dev2/opensportslib/apis/classification.py +328 -0
- opensportslib-0.1.2.dev2/opensportslib/apis/localization.py +354 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/core/trainer/classification_trainer.py +57 -14
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/core/trainer/localization_trainer.py +18 -20
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/core/utils/config.py +4 -4
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/metrics/localization_metric.py +1 -1
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2/opensportslib.egg-info}/PKG-INFO +48 -15
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib.egg-info/SOURCES.txt +3 -1
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/pyproject.toml +4 -1
- opensportslib-0.1.2.dev2/tests/conftest.py +359 -0
- opensportslib-0.1.2.dev2/tests/test_public_apis_smoke.py +38 -0
- opensportslib-0.1.2.dev2/tests/test_subset_train_infer_integration.py +292 -0
- opensportslib-0.1.2.dev2/tests/test_task_model_api_contract.py +375 -0
- opensportslib-0.1.2/examples/quickstart/basic_classification.py +0 -30
- opensportslib-0.1.2/examples/quickstart/basic_localization.py +0 -30
- opensportslib-0.1.2/opensportslib/apis/__init__.py +0 -21
- opensportslib-0.1.2/opensportslib/apis/classification.py +0 -364
- opensportslib-0.1.2/opensportslib/apis/localization.py +0 -239
- opensportslib-0.1.2/tests/conftest.py +0 -59
- opensportslib-0.1.2/tests/test_public_apis_smoke.py +0 -29
- opensportslib-0.1.2/tests/test_subset_train_infer_integration.py +0 -172
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/LICENSE +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/LICENSE-COMMERCIAL +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/MANIFEST.in +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/__init__.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/cli.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/config/classification.yaml +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/config/localization-e2e-ocv.yaml +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/config/localization-json_calf_resnetpca512.yaml +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/config/localization-json_netvlad++_resnetpca512.yaml +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/config/localization.yaml +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/config/sngar-frames.yaml +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/config/sngar-tracking.yaml +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/core/__init__.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/core/loss/__init__.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/core/loss/builder.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/core/loss/calf.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/core/loss/ce.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/core/loss/combine.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/core/loss/nll.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/core/optimizer/__init__.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/core/optimizer/builder.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/core/sampler/weighted_sampler.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/core/scheduler/__init__.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/core/scheduler/builder.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/core/trainer/__init__.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/core/utils/checkpoint.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/core/utils/data.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/core/utils/ddp.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/core/utils/default_args.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/core/utils/lightning.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/core/utils/load_annotations.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/core/utils/seed.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/core/utils/video_processing.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/core/utils/wandb.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/datasets/__init__.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/datasets/builder.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/datasets/classification_dataset.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/datasets/localization_dataset.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/datasets/utils/__init__.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/datasets/utils/tracking.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/metrics/classification_metric.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/models/__init__.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/models/backbones/builder.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/models/base/contextaware.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/models/base/e2e.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/models/base/learnablepooling.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/models/base/tracking.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/models/base/vars.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/models/base/video.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/models/base/video_mae.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/models/builder.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/models/heads/builder.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/models/neck/builder.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/models/utils/common.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/models/utils/impl/__init__.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/models/utils/impl/asformer.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/models/utils/impl/calf.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/models/utils/impl/gsm.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/models/utils/impl/gtad.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/models/utils/impl/tsm.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/models/utils/litebase.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/models/utils/modules.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/models/utils/shift.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/models/utils/utils.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/setup/setup.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib.egg-info/dependency_links.txt +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib.egg-info/entry_points.txt +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib.egg-info/requires.txt +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib.egg-info/top_level.txt +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/setup.cfg +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/tests/test_config_utils_smoke.py +0 -0
- {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/tests/test_package_smoke.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: opensportslib
|
|
3
|
-
Version: 0.1.2
|
|
3
|
+
Version: 0.1.2.dev2
|
|
4
4
|
Summary: OpenSportsLib is the professional library, designed for advanced video understanding in sports. It provides state-of-the-art tools for action recognition, spotting, retrieval, and captioning, making it ideal for researchers, analysts, and developers working with sports video data.
|
|
5
5
|
Author: Jeet Vora
|
|
6
6
|
Requires-Python: >=3.12
|
|
@@ -123,32 +123,45 @@ print("OpenSportsLib imported successfully")
|
|
|
123
123
|
### Train a classification model
|
|
124
124
|
|
|
125
125
|
```python
|
|
126
|
-
from opensportslib import
|
|
126
|
+
from opensportslib.apis import ClassificationModel
|
|
127
127
|
|
|
128
|
-
|
|
129
|
-
config="/path/to/classification.yaml"
|
|
128
|
+
my_model = ClassificationModel(
|
|
129
|
+
config="/path/to/classification.yaml",
|
|
130
|
+
weights="/path/to/weights.pt", # optional
|
|
130
131
|
)
|
|
131
132
|
|
|
132
|
-
|
|
133
|
+
my_model.train(
|
|
133
134
|
train_set="/path/to/train_annotations.json",
|
|
134
135
|
valid_set="/path/to/valid_annotations.json",
|
|
135
|
-
pretrained="/path/to/pretrained.pt", # optional
|
|
136
136
|
)
|
|
137
137
|
```
|
|
138
138
|
|
|
139
139
|
### Run inference
|
|
140
140
|
|
|
141
141
|
```python
|
|
142
|
-
from opensportslib import
|
|
142
|
+
from opensportslib.apis import ClassificationModel
|
|
143
143
|
|
|
144
|
-
|
|
145
|
-
config="/path/to/classification.yaml"
|
|
144
|
+
my_model = ClassificationModel(
|
|
145
|
+
config="/path/to/classification.yaml",
|
|
146
|
+
weights="/path/to/weights.pt", # optional
|
|
146
147
|
)
|
|
147
148
|
|
|
148
|
-
|
|
149
|
+
predictions = my_model.infer(
|
|
149
150
|
test_set="/path/to/test_annotations.json",
|
|
150
|
-
|
|
151
|
-
|
|
151
|
+
)
|
|
152
|
+
|
|
153
|
+
saved_predictions = my_model.save_predictions(
|
|
154
|
+
output_path="/path/to/predictions.json",
|
|
155
|
+
predictions=predictions,
|
|
156
|
+
)
|
|
157
|
+
|
|
158
|
+
metrics = my_model.evaluate(
|
|
159
|
+
test_set="/path/to/test_annotations.json",
|
|
160
|
+
)
|
|
161
|
+
|
|
162
|
+
metrics_from_file = my_model.evaluate(
|
|
163
|
+
test_set="/path/to/test_annotations.json",
|
|
164
|
+
predictions=saved_predictions,
|
|
152
165
|
)
|
|
153
166
|
|
|
154
167
|
print(metrics)
|
|
@@ -157,10 +170,29 @@ print(metrics)
|
|
|
157
170
|
### Localization example
|
|
158
171
|
|
|
159
172
|
```python
|
|
160
|
-
from opensportslib import
|
|
173
|
+
from opensportslib.apis import LocalizationModel
|
|
174
|
+
|
|
175
|
+
my_model = LocalizationModel(
|
|
176
|
+
config="/path/to/localization.yaml",
|
|
177
|
+
weights="/path/to/weights.pt", # optional
|
|
178
|
+
)
|
|
161
179
|
|
|
162
|
-
|
|
163
|
-
|
|
180
|
+
predictions = my_model.infer(
|
|
181
|
+
test_set="/path/to/test_annotations.json",
|
|
182
|
+
)
|
|
183
|
+
|
|
184
|
+
saved_predictions = my_model.save_predictions(
|
|
185
|
+
output_path="/path/to/predictions.json",
|
|
186
|
+
predictions=predictions,
|
|
187
|
+
)
|
|
188
|
+
|
|
189
|
+
metrics = my_model.evaluate(
|
|
190
|
+
test_set="/path/to/test_annotations.json",
|
|
191
|
+
)
|
|
192
|
+
|
|
193
|
+
metrics_from_file = my_model.evaluate(
|
|
194
|
+
test_set="/path/to/test_annotations.json",
|
|
195
|
+
predictions=saved_predictions,
|
|
164
196
|
)
|
|
165
197
|
```
|
|
166
198
|
|
|
@@ -198,6 +230,7 @@ Generate text descriptions for sports events and temporal segments.
|
|
|
198
230
|
Use the README for the fast start, then go deeper through:
|
|
199
231
|
|
|
200
232
|
- Full documentation: https://opensportslab.github.io/opensportslib/
|
|
233
|
+
- High-level API guide: [opensportslib/apis/README.md](opensportslib/apis/README.md)
|
|
201
234
|
- Configuration guide: https://opensportslab.github.io/opensportslib/tni/config-guide/
|
|
202
235
|
- Example configs: [examples/configs/](examples/configs/)
|
|
203
236
|
- Quickstart scripts: [examples/quickstart/](examples/quickstart/)
|
|
@@ -92,32 +92,45 @@ print("OpenSportsLib imported successfully")
|
|
|
92
92
|
### Train a classification model
|
|
93
93
|
|
|
94
94
|
```python
|
|
95
|
-
from opensportslib import
|
|
95
|
+
from opensportslib.apis import ClassificationModel
|
|
96
96
|
|
|
97
|
-
|
|
98
|
-
config="/path/to/classification.yaml"
|
|
97
|
+
my_model = ClassificationModel(
|
|
98
|
+
config="/path/to/classification.yaml",
|
|
99
|
+
weights="/path/to/weights.pt", # optional
|
|
99
100
|
)
|
|
100
101
|
|
|
101
|
-
|
|
102
|
+
my_model.train(
|
|
102
103
|
train_set="/path/to/train_annotations.json",
|
|
103
104
|
valid_set="/path/to/valid_annotations.json",
|
|
104
|
-
pretrained="/path/to/pretrained.pt", # optional
|
|
105
105
|
)
|
|
106
106
|
```
|
|
107
107
|
|
|
108
108
|
### Run inference
|
|
109
109
|
|
|
110
110
|
```python
|
|
111
|
-
from opensportslib import
|
|
111
|
+
from opensportslib.apis import ClassificationModel
|
|
112
112
|
|
|
113
|
-
|
|
114
|
-
config="/path/to/classification.yaml"
|
|
113
|
+
my_model = ClassificationModel(
|
|
114
|
+
config="/path/to/classification.yaml",
|
|
115
|
+
weights="/path/to/weights.pt", # optional
|
|
115
116
|
)
|
|
116
117
|
|
|
117
|
-
|
|
118
|
+
predictions = my_model.infer(
|
|
118
119
|
test_set="/path/to/test_annotations.json",
|
|
119
|
-
|
|
120
|
-
|
|
120
|
+
)
|
|
121
|
+
|
|
122
|
+
saved_predictions = my_model.save_predictions(
|
|
123
|
+
output_path="/path/to/predictions.json",
|
|
124
|
+
predictions=predictions,
|
|
125
|
+
)
|
|
126
|
+
|
|
127
|
+
metrics = my_model.evaluate(
|
|
128
|
+
test_set="/path/to/test_annotations.json",
|
|
129
|
+
)
|
|
130
|
+
|
|
131
|
+
metrics_from_file = my_model.evaluate(
|
|
132
|
+
test_set="/path/to/test_annotations.json",
|
|
133
|
+
predictions=saved_predictions,
|
|
121
134
|
)
|
|
122
135
|
|
|
123
136
|
print(metrics)
|
|
@@ -126,10 +139,29 @@ print(metrics)
|
|
|
126
139
|
### Localization example
|
|
127
140
|
|
|
128
141
|
```python
|
|
129
|
-
from opensportslib import
|
|
142
|
+
from opensportslib.apis import LocalizationModel
|
|
143
|
+
|
|
144
|
+
my_model = LocalizationModel(
|
|
145
|
+
config="/path/to/localization.yaml",
|
|
146
|
+
weights="/path/to/weights.pt", # optional
|
|
147
|
+
)
|
|
130
148
|
|
|
131
|
-
|
|
132
|
-
|
|
149
|
+
predictions = my_model.infer(
|
|
150
|
+
test_set="/path/to/test_annotations.json",
|
|
151
|
+
)
|
|
152
|
+
|
|
153
|
+
saved_predictions = my_model.save_predictions(
|
|
154
|
+
output_path="/path/to/predictions.json",
|
|
155
|
+
predictions=predictions,
|
|
156
|
+
)
|
|
157
|
+
|
|
158
|
+
metrics = my_model.evaluate(
|
|
159
|
+
test_set="/path/to/test_annotations.json",
|
|
160
|
+
)
|
|
161
|
+
|
|
162
|
+
metrics_from_file = my_model.evaluate(
|
|
163
|
+
test_set="/path/to/test_annotations.json",
|
|
164
|
+
predictions=saved_predictions,
|
|
133
165
|
)
|
|
134
166
|
```
|
|
135
167
|
|
|
@@ -167,6 +199,7 @@ Generate text descriptions for sports events and temporal segments.
|
|
|
167
199
|
Use the README for the fast start, then go deeper through:
|
|
168
200
|
|
|
169
201
|
- Full documentation: https://opensportslab.github.io/opensportslib/
|
|
202
|
+
- High-level API guide: [opensportslib/apis/README.md](opensportslib/apis/README.md)
|
|
170
203
|
- Configuration guide: https://opensportslab.github.io/opensportslib/tni/config-guide/
|
|
171
204
|
- Example configs: [examples/configs/](examples/configs/)
|
|
172
205
|
- Quickstart scripts: [examples/quickstart/](examples/quickstart/)
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
from opensportslib.apis import ClassificationModel
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def main():
    """Walk through the classification API end to end.

    Builds a ``ClassificationModel`` from a config file, trains it, runs
    inference, evaluates it, writes the predictions to disk and finally
    re-evaluates using the saved predictions file. Every path below is a
    placeholder -- edit config and dataset paths before running.
    """
    test_annotations = "/path/to/test_annotations.json"

    model = ClassificationModel(
        config="examples/configs/classification_video.yaml",
        weights="/path/to/weights.pt",  # optional
    )

    model.train(
        train_set="/path/to/train_annotations.json",
        valid_set="/path/to/valid_annotations.json",
    )

    preds = model.infer(test_set=test_annotations)
    print(preds)

    print(model.evaluate(test_set=test_annotations))

    preds_file = model.save_predictions(
        output_path="/path/to/predictions.json",
        predictions=preds,
    )

    # Evaluate again, this time from the predictions persisted above.
    print(model.evaluate(test_set=test_annotations, predictions=preds_file))


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
from opensportslib.apis import LocalizationModel
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def main():
    """Walk through the localization API end to end.

    Builds a ``LocalizationModel`` from a config file, trains it, runs
    inference, evaluates it, writes the predictions to disk and finally
    re-evaluates using the saved predictions file. Every path below is a
    placeholder -- edit config and dataset paths before running.
    """
    test_annotations = "/path/to/test_annotations.json"

    model = LocalizationModel(
        config="examples/configs/localization.yaml",
        weights="/path/to/weights.pt",  # optional
    )

    model.train(
        train_set="/path/to/train_annotations.json",
        valid_set="/path/to/valid_annotations.json",
    )

    preds = model.infer(test_set=test_annotations)
    print(preds)

    print(model.evaluate(test_set=test_annotations))

    preds_file = model.save_predictions(
        output_path="/path/to/predictions.json",
        predictions=preds,
    )

    # Evaluate again, this time from the predictions persisted above.
    print(model.evaluate(test_set=test_annotations, predictions=preds_file))


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
# opensportslib/apis/__init__.py
#
# Public entry point for the task-level APIs: re-exports the shared base
# class and the concrete classification / localization wrappers.
import warnings

from opensportslib.apis.base_task_model import BaseTaskModel
from opensportslib.apis.classification import ClassificationModel
from opensportslib.apis.localization import LocalizationModel

# NOTE(review): this installs a process-wide "ignore" filter for every
# warning category as soon as the package is imported, which also hides
# warnings from user code and unrelated libraries. Consider scoping it
# (warnings.catch_warnings() around the noisy calls) or restricting it to
# the specific categories/modules that are actually noisy.
warnings.simplefilter("ignore")

# Names exported by ``from opensportslib.apis import *``.
__all__ = [
    "BaseTaskModel",
    "ClassificationModel",
    "LocalizationModel",
]
|
|
@@ -0,0 +1,131 @@
|
|
|
1
|
+
"""Shared task-level wrapper base for OpenSportsLib APIs."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import logging
|
|
7
|
+
import os
|
|
8
|
+
import uuid
|
|
9
|
+
from abc import ABC, abstractmethod
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
from opensportslib.core.utils.config import expand, load_config_omega
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class BaseTaskModel(ABC):
|
|
16
|
+
"""Thin shared contract for task-level OpenSportsLib wrappers."""
|
|
17
|
+
|
|
18
|
+
def __init__(self, config=None, weights=None):
|
|
19
|
+
self._configure_logging()
|
|
20
|
+
|
|
21
|
+
if config is None:
|
|
22
|
+
raise ValueError("config path is required")
|
|
23
|
+
|
|
24
|
+
self.config_path = expand(config)
|
|
25
|
+
self.config = load_config_omega(self.config_path)
|
|
26
|
+
|
|
27
|
+
data_cfg = getattr(self.config, "DATA", None)
|
|
28
|
+
if data_cfg is not None and hasattr(data_cfg, "data_dir"):
|
|
29
|
+
data_cfg.data_dir = expand(data_cfg.data_dir)
|
|
30
|
+
logging.info(f"Data directory: {data_cfg.data_dir}")
|
|
31
|
+
|
|
32
|
+
self.run_id = os.environ.get("RUN_ID") or str(uuid.uuid4())[:8]
|
|
33
|
+
os.environ["RUN_ID"] = self.run_id
|
|
34
|
+
|
|
35
|
+
system_cfg = getattr(self.config, "SYSTEM", None)
|
|
36
|
+
if system_cfg is not None:
|
|
37
|
+
base_save_dir = expand(getattr(system_cfg, "save_dir", None) or "./checkpoints")
|
|
38
|
+
model_cfg = getattr(self.config, "MODEL", None)
|
|
39
|
+
backbone_cfg = getattr(model_cfg, "backbone", None)
|
|
40
|
+
model_name = getattr(backbone_cfg, "type", None) or "model"
|
|
41
|
+
run_save_dir = os.path.join(base_save_dir, model_name, self.run_id)
|
|
42
|
+
self.save_dir = run_save_dir
|
|
43
|
+
system_cfg.save_dir = run_save_dir
|
|
44
|
+
if hasattr(system_cfg, "work_dir"):
|
|
45
|
+
system_cfg.work_dir = run_save_dir
|
|
46
|
+
os.makedirs(run_save_dir, exist_ok=True)
|
|
47
|
+
else:
|
|
48
|
+
self.save_dir = expand("./checkpoints")
|
|
49
|
+
os.makedirs(self.save_dir, exist_ok=True)
|
|
50
|
+
|
|
51
|
+
logging.info(f"Save directory: {self.save_dir}")
|
|
52
|
+
|
|
53
|
+
self.model = None
|
|
54
|
+
self.processor = None
|
|
55
|
+
self.trainer = None
|
|
56
|
+
self.best_checkpoint = None
|
|
57
|
+
self.last_loaded_weights = None
|
|
58
|
+
|
|
59
|
+
if weights is not None:
|
|
60
|
+
self.load_weights(weights=weights)
|
|
61
|
+
|
|
62
|
+
@staticmethod
|
|
63
|
+
def _configure_logging() -> None:
|
|
64
|
+
root_logger = logging.getLogger()
|
|
65
|
+
if not root_logger.handlers:
|
|
66
|
+
logging.basicConfig(
|
|
67
|
+
level=logging.INFO,
|
|
68
|
+
format="%(asctime)s | %(levelname)s | %(message)s",
|
|
69
|
+
)
|
|
70
|
+
elif root_logger.level > logging.INFO:
|
|
71
|
+
root_logger.setLevel(logging.INFO)
|
|
72
|
+
|
|
73
|
+
@abstractmethod
|
|
74
|
+
def load_weights(
|
|
75
|
+
self,
|
|
76
|
+
weights: str | None = None,
|
|
77
|
+
**kwargs,
|
|
78
|
+
) -> None:
|
|
79
|
+
raise NotImplementedError
|
|
80
|
+
|
|
81
|
+
@abstractmethod
|
|
82
|
+
def train(
|
|
83
|
+
self,
|
|
84
|
+
train_set: str | None = None,
|
|
85
|
+
valid_set: str | None = None,
|
|
86
|
+
weights: str | None = None,
|
|
87
|
+
use_wandb: bool = True,
|
|
88
|
+
**kwargs,
|
|
89
|
+
) -> str | None:
|
|
90
|
+
raise NotImplementedError
|
|
91
|
+
|
|
92
|
+
@abstractmethod
|
|
93
|
+
def infer(
|
|
94
|
+
self,
|
|
95
|
+
test_set: str | None = None,
|
|
96
|
+
weights: str | None = None,
|
|
97
|
+
use_wandb: bool = True,
|
|
98
|
+
**kwargs,
|
|
99
|
+
) -> dict:
|
|
100
|
+
raise NotImplementedError
|
|
101
|
+
|
|
102
|
+
@abstractmethod
|
|
103
|
+
def evaluate(
|
|
104
|
+
self,
|
|
105
|
+
test_set: str | None = None,
|
|
106
|
+
weights: str | None = None,
|
|
107
|
+
predictions: str | dict[str, Any] | None = None,
|
|
108
|
+
use_wandb: bool = True,
|
|
109
|
+
**kwargs,
|
|
110
|
+
) -> dict | str | None:
|
|
111
|
+
raise NotImplementedError
|
|
112
|
+
|
|
113
|
+
def save_predictions(
|
|
114
|
+
self,
|
|
115
|
+
output_path: str,
|
|
116
|
+
predictions: dict,
|
|
117
|
+
) -> str:
|
|
118
|
+
"""Persist in-memory prediction JSON payload to a target file path."""
|
|
119
|
+
|
|
120
|
+
dst = expand(output_path)
|
|
121
|
+
os.makedirs(os.path.dirname(dst) or ".", exist_ok=True)
|
|
122
|
+
|
|
123
|
+
if not isinstance(predictions, dict):
|
|
124
|
+
raise TypeError(
|
|
125
|
+
f"Unsupported predictions type: {type(predictions).__name__}. "
|
|
126
|
+
"Expected dict."
|
|
127
|
+
)
|
|
128
|
+
|
|
129
|
+
with open(dst, "w", encoding="utf-8") as f:
|
|
130
|
+
json.dump(predictions, f)
|
|
131
|
+
return dst
|