opensportslib 0.1.1__tar.gz → 0.1.1.dev2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {opensportslib-0.1.1/opensportslib.egg-info → opensportslib-0.1.1.dev2}/PKG-INFO +5 -1
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/README.md +1 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/apis/classification.py +8 -5
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/apis/localization.py +5 -5
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/core/trainer/classification_trainer.py +19 -16
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/core/trainer/localization_trainer.py +14 -13
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/core/utils/config.py +17 -2
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/core/utils/data.py +11 -2
- opensportslib-0.1.1.dev2/opensportslib/core/utils/wandb.py +280 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/datasets/classification_dataset.py +21 -3
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/models/backbones/builder.py +54 -16
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/models/base/tracking.py +1 -2
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/models/utils/utils.py +8 -1
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2/opensportslib.egg-info}/PKG-INFO +5 -1
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib.egg-info/SOURCES.txt +6 -1
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib.egg-info/requires.txt +4 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib.egg-info/top_level.txt +2 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/pyproject.toml +2 -1
- opensportslib-0.1.1.dev2/tests/conftest.py +59 -0
- opensportslib-0.1.1.dev2/tests/test_config_utils_smoke.py +46 -0
- opensportslib-0.1.1.dev2/tests/test_package_smoke.py +23 -0
- opensportslib-0.1.1.dev2/tests/test_public_apis_smoke.py +29 -0
- opensportslib-0.1.1.dev2/tests/test_subset_train_infer_integration.py +172 -0
- opensportslib-0.1.1/opensportslib/core/utils/wandb.py +0 -120
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/LICENSE +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/LICENSE-COMMERCIAL +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/MANIFEST.in +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/examples/quickstart/basic_classification.py +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/examples/quickstart/basic_localization.py +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/__init__.py +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/apis/__init__.py +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/config/classification.yaml +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/config/localization-e2e-ocv.yaml +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/config/localization-json_calf_resnetpca512.yaml +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/config/localization-json_netvlad++_resnetpca512.yaml +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/config/localization.yaml +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/config/sngar-frames.yaml +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/config/sngar-tracking.yaml +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/core/__init__.py +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/core/loss/__init__.py +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/core/loss/builder.py +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/core/loss/calf.py +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/core/loss/ce.py +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/core/loss/combine.py +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/core/loss/nll.py +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/core/optimizer/__init__.py +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/core/optimizer/builder.py +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/core/sampler/weighted_sampler.py +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/core/scheduler/__init__.py +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/core/scheduler/builder.py +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/core/trainer/__init__.py +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/core/utils/checkpoint.py +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/core/utils/ddp.py +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/core/utils/default_args.py +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/core/utils/lightning.py +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/core/utils/load_annotations.py +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/core/utils/seed.py +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/core/utils/video_processing.py +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/datasets/__init__.py +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/datasets/builder.py +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/datasets/localization_dataset.py +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/datasets/utils/__init__.py +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/datasets/utils/tracking.py +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/metrics/classification_metric.py +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/metrics/localization_metric.py +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/models/__init__.py +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/models/base/contextaware.py +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/models/base/e2e.py +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/models/base/learnablepooling.py +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/models/base/vars.py +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/models/base/video.py +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/models/base/video_mae.py +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/models/builder.py +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/models/heads/builder.py +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/models/neck/builder.py +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/models/utils/common.py +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/models/utils/impl/__init__.py +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/models/utils/impl/asformer.py +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/models/utils/impl/calf.py +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/models/utils/impl/gsm.py +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/models/utils/impl/gtad.py +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/models/utils/impl/tsm.py +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/models/utils/litebase.py +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/models/utils/modules.py +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/models/utils/shift.py +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib.egg-info/dependency_links.txt +0 -0
- {opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/setup.cfg +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: opensportslib
|
|
3
|
-
Version: 0.1.1
|
|
3
|
+
Version: 0.1.1.dev2
|
|
4
4
|
Summary: OpenSportsLib is the professional library, designed for advanced video understanding in sports. It provides state-of-the-art tools for action recognition, spotting, retrieval, and captioning, making it ideal for researchers, analysts, and developers working with sports video data.
|
|
5
5
|
Author: Jeet Vora
|
|
6
6
|
Requires-Python: >=3.12
|
|
@@ -33,6 +33,9 @@ Requires-Dist: torch-scatter; extra == "py-geometric"
|
|
|
33
33
|
Requires-Dist: torch-sparse; extra == "py-geometric"
|
|
34
34
|
Requires-Dist: torch-cluster; extra == "py-geometric"
|
|
35
35
|
Requires-Dist: torch-spline-conv; extra == "py-geometric"
|
|
36
|
+
Provides-Extra: test
|
|
37
|
+
Requires-Dist: pytest; extra == "test"
|
|
38
|
+
Requires-Dist: pytest-cov; extra == "test"
|
|
36
39
|
Dynamic: license-file
|
|
37
40
|
|
|
38
41
|
# OpenSportsLib
|
|
@@ -192,6 +195,7 @@ Generate text descriptions for sports events and temporal segments.
|
|
|
192
195
|
Use the README for the fast start, then go deeper through:
|
|
193
196
|
|
|
194
197
|
- Full documentation: https://opensportslab.github.io/opensportslib/
|
|
198
|
+
- Configuration guide: https://opensportslab.github.io/opensportslib/tni/config-guide/
|
|
195
199
|
- Example configs: [examples/configs/](examples/configs/)
|
|
196
200
|
- Quickstart scripts: [examples/quickstart/](examples/quickstart/)
|
|
197
201
|
- Contribution guide: [CONTRIBUTING.md](CONTRIBUTING.md)
|
|
@@ -155,6 +155,7 @@ Generate text descriptions for sports events and temporal segments.
|
|
|
155
155
|
Use the README for the fast start, then go deeper through:
|
|
156
156
|
|
|
157
157
|
- Full documentation: https://opensportslab.github.io/opensportslib/
|
|
158
|
+
- Configuration guide: https://opensportslab.github.io/opensportslib/tni/config-guide/
|
|
158
159
|
- Example configs: [examples/configs/](examples/configs/)
|
|
159
160
|
- Quickstart scripts: [examples/quickstart/](examples/quickstart/)
|
|
160
161
|
- Contribution guide: [CONTRIBUTING.md](CONTRIBUTING.md)
|
|
@@ -40,8 +40,8 @@ class ClassificationAPI:
|
|
|
40
40
|
if config is None:
|
|
41
41
|
raise ValueError("config path is required")
|
|
42
42
|
|
|
43
|
-
config_path = expand(config)
|
|
44
|
-
self.config = load_config_omega(config_path)
|
|
43
|
+
self.config_path = expand(config)
|
|
44
|
+
self.config = load_config_omega(self.config_path)
|
|
45
45
|
|
|
46
46
|
# let the caller override the dataset root directory.
|
|
47
47
|
self.config.DATA.data_dir = expand(
|
|
@@ -88,6 +88,7 @@ class ClassificationAPI:
|
|
|
88
88
|
rank,
|
|
89
89
|
world_size,
|
|
90
90
|
mode,
|
|
91
|
+
config_path,
|
|
91
92
|
config,
|
|
92
93
|
return_queue=None,
|
|
93
94
|
train_set=None,
|
|
@@ -134,7 +135,7 @@ class ClassificationAPI:
|
|
|
134
135
|
logging.getLogger().setLevel(logging.ERROR)
|
|
135
136
|
|
|
136
137
|
if rank == 0:
|
|
137
|
-
init_wandb(config, run_id=os.environ["RUN_ID"], use_wandb=use_wandb)
|
|
138
|
+
init_wandb(config_path, config, run_id=os.environ["RUN_ID"], use_wandb=use_wandb)
|
|
138
139
|
|
|
139
140
|
# reproducibility:
|
|
140
141
|
# we default to reproducible training, but allow the user to
|
|
@@ -242,7 +243,7 @@ class ClassificationAPI:
|
|
|
242
243
|
mp.spawn(
|
|
243
244
|
ClassificationAPI._worker_ddp,
|
|
244
245
|
args=(
|
|
245
|
-
world_size, "train", self.config, queue,
|
|
246
|
+
world_size, "train", self.config_path, self.config, queue,
|
|
246
247
|
train_set, valid_set, None, pretrained, use_wandb
|
|
247
248
|
),
|
|
248
249
|
nprocs=world_size,
|
|
@@ -253,6 +254,7 @@ class ClassificationAPI:
|
|
|
253
254
|
rank=0,
|
|
254
255
|
world_size=1,
|
|
255
256
|
mode="train",
|
|
257
|
+
config_path=self.config_path,
|
|
256
258
|
config=self.config,
|
|
257
259
|
return_queue=queue,
|
|
258
260
|
train_set=train_set,
|
|
@@ -321,7 +323,7 @@ class ClassificationAPI:
|
|
|
321
323
|
mp.spawn(
|
|
322
324
|
ClassificationAPI._worker_ddp,
|
|
323
325
|
args=(
|
|
324
|
-
world_size, "infer", self.config, queue,
|
|
326
|
+
world_size, "infer",self.config_path, self.config, queue,
|
|
325
327
|
None, None, test_set, pretrained, use_wandb
|
|
326
328
|
),
|
|
327
329
|
nprocs=world_size,
|
|
@@ -331,6 +333,7 @@ class ClassificationAPI:
|
|
|
331
333
|
rank=0,
|
|
332
334
|
world_size=1,
|
|
333
335
|
mode="infer",
|
|
336
|
+
config_path=self.config_path,
|
|
334
337
|
config=self.config,
|
|
335
338
|
return_queue=queue,
|
|
336
339
|
test_set=test_set,
|
|
@@ -14,8 +14,8 @@ class LocalizationAPI:
|
|
|
14
14
|
|
|
15
15
|
# Load config
|
|
16
16
|
### load data_dir first, then load the config with omega to resolve $paths
|
|
17
|
-
config_path = expand(config)
|
|
18
|
-
self.config = load_config_omega(config_path)
|
|
17
|
+
self.config_path = expand(config)
|
|
18
|
+
self.config = load_config_omega(self.config_path)
|
|
19
19
|
# User must control dataset folder
|
|
20
20
|
self.config.DATA.data_dir = expand(data_dir or self.config.DATA.data_dir)
|
|
21
21
|
print(self.config.DATA.classes)
|
|
@@ -46,7 +46,7 @@ class LocalizationAPI:
|
|
|
46
46
|
|
|
47
47
|
logger = logging.getLogger(__name__)
|
|
48
48
|
|
|
49
|
-
print("CONFIG PATH :", config_path)
|
|
49
|
+
print("CONFIG PATH :", self.config_path)
|
|
50
50
|
print("DATA DIR :", self.config.DATA.data_dir)
|
|
51
51
|
print("SAVEDIR:", self.config.SYSTEM.save_dir)
|
|
52
52
|
print("Classes :", self.config.DATA.classes)
|
|
@@ -76,7 +76,7 @@ class LocalizationAPI:
|
|
|
76
76
|
|
|
77
77
|
self.config = resolve_config_omega(self.config)
|
|
78
78
|
check_config(self.config, split="train")
|
|
79
|
-
init_wandb(self.config, run_id=os.environ["RUN_ID"], use_wandb=use_wandb)
|
|
79
|
+
init_wandb(self.config_path, self.config, run_id=os.environ["RUN_ID"], use_wandb=use_wandb)
|
|
80
80
|
logging.info("Configuration:")
|
|
81
81
|
logging.info(self.config)
|
|
82
82
|
#print(self.config)
|
|
@@ -164,7 +164,7 @@ class LocalizationAPI:
|
|
|
164
164
|
self.config = resolve_config_omega(self.config)
|
|
165
165
|
check_config(self.config, split="test")
|
|
166
166
|
self.config.infer_split = whether_infer_split(self.config.DATA.test)
|
|
167
|
-
init_wandb(self.config, run_id=os.environ["RUN_ID"], use_wandb=use_wandb)
|
|
167
|
+
init_wandb(self.config_path, self.config, run_id=os.environ["RUN_ID"], use_wandb=use_wandb)
|
|
168
168
|
logging.info("Configuration:")
|
|
169
169
|
logging.info(self.config)
|
|
170
170
|
# Start Timing
|
|
@@ -286,14 +286,15 @@ class BaseTrainerClassification:
|
|
|
286
286
|
|
|
287
287
|
if self.rank == 0:
|
|
288
288
|
# ---------------- W&B LOG ----------------
|
|
289
|
-
wandb.
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
289
|
+
if wandb.run is not None:
|
|
290
|
+
wandb.log({
|
|
291
|
+
"epoch": epoch + 1,
|
|
292
|
+
"lr": current_lr,
|
|
293
|
+
"train/loss": train_loss,
|
|
294
|
+
"valid/loss": val_loss,
|
|
295
|
+
**{f"train/{k}": v for k, v in train_metrics.items()},
|
|
296
|
+
**{f"valid/{k}": v for k, v in val_metrics.items()},
|
|
297
|
+
})
|
|
297
298
|
|
|
298
299
|
logging.info(f"Train Loss: {train_loss:.4f} | Train Bal Acc: {train_metric:.4f}")
|
|
299
300
|
logging.info(f"Val Loss: {val_loss:.4f} | Val Bal Acc: {val_metric:.4f}")
|
|
@@ -316,9 +317,10 @@ class BaseTrainerClassification:
|
|
|
316
317
|
best_path = self._save_checkpoint("best", epoch + 1, tag="best")
|
|
317
318
|
self.best_checkpoint_path = best_path
|
|
318
319
|
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
320
|
+
if wandb.run is not None:
|
|
321
|
+
artifact = wandb.Artifact("model-checkpoint", type="model")
|
|
322
|
+
artifact.add_file(best_path)
|
|
323
|
+
wandb.log_artifact(artifact)
|
|
322
324
|
|
|
323
325
|
if self.rank == 0:
|
|
324
326
|
logging.info(f"Best checkpoint : {self.best_checkpoint_path}")
|
|
@@ -350,10 +352,11 @@ class BaseTrainerClassification:
|
|
|
350
352
|
pbar.close()
|
|
351
353
|
|
|
352
354
|
if self.rank==0:
|
|
353
|
-
wandb.
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
355
|
+
if wandb.run is not None:
|
|
356
|
+
wandb.log({
|
|
357
|
+
"test/loss": test_loss,
|
|
358
|
+
**{f"test/{k}": v for k, v in test_metrics.items()},
|
|
359
|
+
})
|
|
357
360
|
|
|
358
361
|
if detailed_results:
|
|
359
362
|
from opensportslib.metrics.classification_metric import (
|
|
@@ -1128,4 +1131,4 @@ class Trainer_Classification:
|
|
|
1128
1131
|
self.scheduler = scheduler
|
|
1129
1132
|
self.epoch = epoch
|
|
1130
1133
|
logging.info(f"Model loaded from {path}, epoch: {epoch}")
|
|
1131
|
-
return self.model, self.optimizer, self.scheduler, self.epoch
|
|
1134
|
+
return self.model, self.optimizer, self.scheduler, self.epoch
|
{opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/core/trainer/localization_trainer.py
RENAMED
|
@@ -403,15 +403,16 @@ class Trainer_e2e(Trainer):
|
|
|
403
403
|
)
|
|
404
404
|
|
|
405
405
|
# ---------------- W&B LOG ----------------
|
|
406
|
-
wandb.
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
406
|
+
if wandb.run is not None:
|
|
407
|
+
wandb.log({
|
|
408
|
+
"epoch": epoch + 1,
|
|
409
|
+
"train/loss": train_loss,
|
|
410
|
+
"valid/loss": valid_loss,
|
|
411
|
+
"valid/mAP": valid_mAP,
|
|
412
|
+
"lr": self.optimizer.param_groups[0]["lr"],
|
|
413
|
+
"best/mAP": self.best_criterion_valid if self.criterion_valid == "map" else None,
|
|
414
|
+
"best/loss": self.best_criterion_valid if self.criterion_valid == "loss" else None,
|
|
415
|
+
})
|
|
415
416
|
|
|
416
417
|
if self.save_dir is not None:
|
|
417
418
|
os.makedirs(self.save_dir, exist_ok=True)
|
|
@@ -630,9 +631,10 @@ class Inferer:
|
|
|
630
631
|
cfg.DATA.test.dataloader,
|
|
631
632
|
return_pred=False,
|
|
632
633
|
)
|
|
633
|
-
wandb.
|
|
634
|
-
|
|
635
|
-
|
|
634
|
+
if wandb.run is not None:
|
|
635
|
+
wandb.log({
|
|
636
|
+
"test/Avg_mAP": mAP,
|
|
637
|
+
})
|
|
636
638
|
pred_json_file = os.path.join(pred_file + ".json")
|
|
637
639
|
pred_recall_file = os.path.join(pred_file + ".recall.json.gz")
|
|
638
640
|
logging.info("Predictions saved")
|
|
@@ -1092,4 +1094,3 @@ class Evaluator:
|
|
|
1092
1094
|
|
|
1093
1095
|
return results
|
|
1094
1096
|
|
|
1095
|
-
|
|
@@ -26,7 +26,22 @@ def dict_to_namespace(d, skip_keys=("classes",)):
|
|
|
26
26
|
return d
|
|
27
27
|
|
|
28
28
|
def namespace_to_dict(ns):
    """
    Recursively convert namespace/mapping/sequence containers into plain Python types.

    Conversion rules, applied in this order:
      * ``None``, ``str``, ``int``, ``float`` and ``bool`` are returned unchanged.
      * Any mapping (``dict`` or another ``collections.abc.Mapping``) becomes a
        ``dict`` whose keys are stringified with ``str(k)``.
      * Lists, tuples, sets, frozensets and other non-bytes sequences become a
        ``list`` (iteration order of a set is whatever the set yields).
      * Any other object exposing ``__dict__`` (e.g. ``SimpleNamespace``) becomes
        a ``dict`` built from ``vars(obj)``.
      * Everything else (including ``bytes``) is returned unchanged.

    Args:
        ns: The value to convert.

    Returns:
        A structure made of ``dict``/``list`` containers and leaf values.
    """
    from collections.abc import Mapping, Sequence

    if ns is None or isinstance(ns, (str, int, float, bool)):
        return ns

    # Check Mapping before __dict__: mapping objects that are not dict
    # subclasses would otherwise be dumped through vars() (exposing their
    # internals) or returned unconverted.
    if isinstance(ns, Mapping):
        return {str(k): namespace_to_dict(v) for k, v in ns.items()}

    # Bytes-like values are sequences of ints but are meant as opaque leaves.
    if isinstance(ns, (bytes, bytearray, memoryview)):
        return ns

    if isinstance(ns, (set, frozenset, Sequence)):
        return [namespace_to_dict(v) for v in ns]

    if hasattr(ns, "__dict__"):
        return {str(k): namespace_to_dict(v) for k, v in vars(ns).items()}

    return ns
|
|
30
45
|
|
|
31
46
|
def namespace_to_omegaconf(ns):
|
|
32
47
|
"""
|
|
@@ -196,4 +211,4 @@ def is_local_path(p):
|
|
|
196
211
|
return p and (
|
|
197
212
|
os.path.exists(p) or
|
|
198
213
|
p.endswith((".pt", ".pth", ".tar"))
|
|
199
|
-
)
|
|
214
|
+
)
|
|
@@ -52,7 +52,16 @@ def tracking_collate_fn(batch):
|
|
|
52
52
|
Custom collate function for tracking data.
|
|
53
53
|
Uses PyG Batch.from_data_list for efficient C++ batching.
|
|
54
54
|
"""
|
|
55
|
-
|
|
55
|
+
try:
|
|
56
|
+
from torch_geometric.data import Batch
|
|
57
|
+
except ImportError as exc:
|
|
58
|
+
raise ImportError(
|
|
59
|
+
"torch-geometric is required for tracking_collate_fn. "
|
|
60
|
+
"Install with: pip install \"opensportslib[py-geometric]\" "
|
|
61
|
+
"-f https://pytorch-geometric.com/whl/torch-2.10.0+cu128.html "
|
|
62
|
+
"or (editable): pip install -e \".[py-geometric]\" "
|
|
63
|
+
"-f https://pytorch-geometric.com/whl/torch-2.10.0+cu128.html"
|
|
64
|
+
) from exc
|
|
56
65
|
|
|
57
66
|
batch_size = len(batch)
|
|
58
67
|
seq_len = batch[0]['seq_len']
|
|
@@ -82,4 +91,4 @@ def mixup_data(x, y, alpha=0.2):
|
|
|
82
91
|
index = torch.randperm(x.size(0)).to(x.device)
|
|
83
92
|
mixed_x = lam * x + (1 - lam) * x[index]
|
|
84
93
|
return mixed_x, y, y[index], lam
|
|
85
|
-
|
|
94
|
+
|
|
@@ -0,0 +1,280 @@
|
|
|
1
|
+
import wandb
|
|
2
|
+
import matplotlib.pyplot as plt
|
|
3
|
+
import numpy as np
|
|
4
|
+
import logging
|
|
5
|
+
import os
|
|
6
|
+
from opensportslib.core.utils.config import namespace_to_dict
|
|
7
|
+
|
|
8
|
+
def build_wandb_config(cfg):
    """
    Extract a minimal, dashboard-friendly subset of the config for W&B.

    The config is first converted with ``namespace_to_dict``; selected
    dot-separated paths are then copied into a flat dict keyed by the path
    itself (e.g. ``"MODEL.backbone.type"``). Paths that are absent or whose
    value is ``None`` are skipped silently, including the "required" ones.

    Two normalized keys are added when available:
      * ``"TRAIN.batch_size"`` from ``DATA.train.dataloader.batch_size``
      * ``"TRAIN.total_epochs"`` from the first of ``TRAIN.epochs``,
        ``TRAIN.num_epochs``, ``TRAIN.max_epochs`` that is not ``None``

    Args:
        cfg: Config object accepted by ``namespace_to_dict``.

    Returns:
        A FLAT dict (ready for ``wandb.init(config=...)``).
    """
    cfg_dict = namespace_to_dict(cfg)

    def get(d, path, default=None):
        """Safe nested get using dot notation."""
        for k in path.split("."):
            if not isinstance(d, dict) or k not in d:
                return default
            d = d[k]
        return d

    def pick(paths):
        """Return {path: value} for every path that resolves to a non-None value."""
        found = ((p, get(cfg_dict, p)) for p in paths)
        return {p: v for p, v in found if v is not None}

    # -------------------------
    # REQUIRED (core columns)
    # -------------------------
    REQUIRED_KEYS = [
        "TASK",

        "DATA.dataset_name",
        "DATA.data_modality",

        "MODEL.type",
        "MODEL.backbone.type",
        "MODEL.neck.type",
        "MODEL.head.type",

        "TRAIN.optimizer.type",
        "TRAIN.optimizer.lr",
        "TRAIN.scheduler.type",

        "TRAIN.monitor",
        "TRAIN.mode",

        "SYSTEM.device",
    ]

    # -------------------------
    # OPTIONAL (useful knobs)
    # -------------------------
    OPTIONAL_KEYS = [
        # DATA
        "DATA.num_frames",
        "DATA.clip_len",
        "DATA.input_fps",
        "DATA.extract_fps",
        "DATA.frame_size",
        "DATA.view_type",
        "DATA.num_classes",

        # MODEL
        "MODEL.backbone.encoder",
        "MODEL.backbone.hidden_dim",
        "MODEL.backbone.freeze",
        "MODEL.unfreeze_last_n_layers",
        "MODEL.neck.agr_type",
        "MODEL.edge",

        # TRAIN
        "TRAIN.epochs",
        "TRAIN.num_epochs",
        "TRAIN.max_epochs",
        "TRAIN.use_amp",
        "TRAIN.use_weighted_loss",
        "TRAIN.use_weighted_sampler",
        "TRAIN.mixup",
        "TRAIN.mixup_alpha",

        # SYSTEM
        "SYSTEM.GPU",
        "SYSTEM.seed",
    ]

    config = {}
    config.update(pick(REQUIRED_KEYS))
    config.update(pick(OPTIONAL_KEYS))

    # -------------------------
    # SPECIAL HANDLING
    # -------------------------

    # Normalize batch_size (from nested dataloader)
    batch_size = get(cfg_dict, "DATA.train.dataloader.batch_size")
    if batch_size is not None:
        config["TRAIN.batch_size"] = batch_size

    # Normalize epochs (different configs use different names).
    # Take the first value that is not None rather than the first truthy one,
    # so an explicitly configured 0 is not mistaken for "missing".
    epoch_candidates = (
        get(cfg_dict, p)
        for p in ("TRAIN.epochs", "TRAIN.num_epochs", "TRAIN.max_epochs")
    )
    epochs = next((v for v in epoch_candidates if v is not None), None)
    if epochs is not None:
        config["TRAIN.total_epochs"] = epochs

    return config
|
|
123
|
+
|
|
124
|
+
def _flatten_config(data, parent_key="", sep="."):
|
|
125
|
+
"""Flatten nested dict/list config for W&B table-friendly columns."""
|
|
126
|
+
items = {}
|
|
127
|
+
|
|
128
|
+
if isinstance(data, dict):
|
|
129
|
+
for k, v in data.items():
|
|
130
|
+
key = f"{parent_key}{sep}{k}" if parent_key else str(k)
|
|
131
|
+
items.update(_flatten_config(v, key, sep=sep))
|
|
132
|
+
return items
|
|
133
|
+
|
|
134
|
+
if isinstance(data, list):
|
|
135
|
+
for i, v in enumerate(data):
|
|
136
|
+
key = f"{parent_key}{sep}{i}" if parent_key else str(i)
|
|
137
|
+
items.update(_flatten_config(v, key, sep=sep))
|
|
138
|
+
return items
|
|
139
|
+
|
|
140
|
+
if parent_key:
|
|
141
|
+
items[parent_key] = data
|
|
142
|
+
|
|
143
|
+
return items
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def _wandb_ready():
    """Return True when the imported ``wandb`` module currently has a non-None ``run``."""
    active_run = getattr(wandb, "run", None)
    return active_run is not None
|
|
148
|
+
|
|
149
|
+
def init_wandb(cfg_path, cfg, run_id, use_wandb=False):
    """
    Initialize Weights & Biases if enabled.

    Args:
        cfg_path: Path to the configuration file; uploaded to the run as a
            "config" artifact when the file exists.
        cfg: Config object. This function reads ``cfg.TASK`` (project name and
            artifact name prefix), ``cfg.MODEL.backbone.type`` and, if
            present, ``cfg.DATA.data_modality`` (both used for the run name).
            The object is also passed to ``build_wandb_config``.
        run_id: Run identifier forwarded to ``wandb.init`` as ``id`` together
            with ``resume="allow"``.
        use_wandb: When False (the default) nothing is initialized.

    Returns:
        The ``wandb`` module once a run is active, or ``None`` when W&B is
        disabled, not installed, or this process is not rank 0.
    """
    if not use_wandb:
        logging.info("W&B disabled.")
        return None

    try:
        import wandb
    except ImportError:
        logging.warning("wandb not installed. Install with `pip install wandb`.")
        return None

    # Prevent multiple processes from initializing wandb
    rank = int(os.environ.get("RANK", os.environ.get("LOCAL_RANK", 0)))
    if rank != 0:
        return None

    # Prevent re-initialization
    if wandb.run is not None:
        return wandb

    modality = getattr(cfg.DATA, "data_modality", None)
    if modality:
        run_name = f"{cfg.MODEL.backbone.type}_{modality}"
    else:
        run_name = f"{cfg.MODEL.backbone.type}"

    config_flat = build_wandb_config(cfg)

    wandb.init(
        project=cfg.TASK,
        name=run_name,
        id=run_id,
        resume="allow",
        config=config_flat,
    )

    # The config artifact is auxiliary: the run already exists at this point,
    # so a missing file is reported instead of aborting the caller.
    if cfg_path and os.path.isfile(cfg_path):
        artifact = wandb.Artifact(
            name=f"{cfg.TASK}-config",
            type="config",
            description="configuration (YAML)",
        )
        artifact.add_file(cfg_path)
        wandb.log_artifact(artifact)
    else:
        logging.warning(
            "Config file %r not found; skipping W&B config artifact.", cfg_path
        )

    logging.info("Wandb initialised")
    return wandb
|
|
206
|
+
|
|
207
|
+
def log_table_wandb(name, rows, headers):
    """
    Log a table to Weights & Biases.

    Does nothing when no W&B run is active.

    Args:
        name (str): Key under which the table is logged.
        rows (list[list]): Table rows; each row is unpacked into ``add_data``.
        headers (list[str]): Column headers.
    """
    if _wandb_ready():
        sheet = wandb.Table(columns=headers)
        for record in rows:
            sheet.add_data(*record)
        wandb.log({name: sheet})
|
|
225
|
+
|
|
226
|
+
def log_attention_wandb(attention, split_name):
    """
    Log an attention heat-map image to Weights & Biases.

    Does nothing when no W&B run is active.

    Args:
        attention: Object supporting ``.detach().cpu().numpy()`` (presumably a
            torch tensor). The result is drawn with ``imshow``, so it is
            assumed to be 2-D, rows labelled "Batch" and columns
            "Views / Time" — TODO confirm against callers.
        split_name (str): Prefix for the logged key
            (``"<split_name>/attention_map"``) and the plot title.
    """
    if not _wandb_ready():
        return

    attn = attention.detach().cpu().numpy()

    fig, ax = plt.subplots(figsize=(6, 3))
    try:
        ax.imshow(attn, aspect="auto", cmap="viridis")
        ax.set_title(f"{split_name} Attention Map")
        ax.set_xlabel("Views / Time")
        ax.set_ylabel("Batch")

        wandb.log({
            f"{split_name}/attention_map": wandb.Image(fig)
        })
    finally:
        # Always release the figure, even if plotting or logging raises;
        # otherwise pyplot keeps every failed figure alive.
        plt.close(fig)
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
def log_sample_videos_wandb(mvclips, preds, labels, split_name, max_samples=2, fps=5):
    """
    Log a few multi-view sample clips to Weights & Biases as videos.

    Does nothing when no W&B run is active. Each view of each selected sample
    is logged under ``"<split_name>/sample_<i>_view_<v>"`` with a caption
    showing the prediction and ground truth.

    Args:
        mvclips: Object supporting ``.detach().cpu().numpy()``; assumed to be
            laid out as (B, V, C, T, H, W) — TODO confirm against callers.
        preds: Indexable by sample index; used in the caption.
        labels: Indexable by sample index; used in the caption.
        split_name (str): Prefix for the logged keys.
        max_samples (int): Maximum number of samples (first axis) to log.
        fps (int): Frame rate passed to ``wandb.Video``.
    """
    if not _wandb_ready():
        return

    # mvclips: (B, V, C, T, H, W)
    mvclips = mvclips.detach().cpu().numpy()

    for i in range(min(len(mvclips), max_samples)):
        views = mvclips[i]  # (V, C, T, H, W)

        # Log each view separately
        for v in range(views.shape[0]):
            # wandb.Video documents numpy input as (time, channel, height,
            # width), so only the first two axes are swapped.
            video = views[v].transpose(1, 0, 2, 3)  # (T, C, H, W)

            # Float clips in [0, 1] are rescaled to [0, 255]; everything is
            # then clipped so out-of-range values (e.g. negatives from
            # normalized inputs) saturate instead of wrapping in uint8.
            if np.issubdtype(video.dtype, np.floating) and video.max() <= 1.0:
                video = video * 255
            video = np.clip(video, 0, 255).astype(np.uint8)

            wandb.log({
                f"{split_name}/sample_{i}_view_{v}": wandb.Video(
                    video,
                    fps=fps,
                    caption=f"Pred: {preds[i]}, GT: {labels[i]}"
                )
            })
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
def log_confusion_matrix_wandb(y_true, y_pred, class_names, split_name):
    """
    Log a confusion-matrix plot to Weights & Biases.

    Does nothing when no W&B run is active. The plot is built with
    ``wandb.plot.confusion_matrix`` from hard predictions (``probs=None``)
    and logged under ``"<split_name>/confusion_matrix"``.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels.
        class_names: Class names forwarded to the plot.
        split_name (str): Prefix for the logged key.
    """
    if _wandb_ready():
        chart = wandb.plot.confusion_matrix(
            probs=None,
            y_true=y_true,
            preds=y_pred,
            class_names=class_names,
        )
        wandb.log({f"{split_name}/confusion_matrix": chart})
|
{opensportslib-0.1.1 → opensportslib-0.1.1.dev2}/opensportslib/datasets/classification_dataset.py
RENAMED
|
@@ -496,7 +496,16 @@ class TrackingDataset(ClassificationDataset):
|
|
|
496
496
|
a copy of the feature array is made before augmentation and
|
|
497
497
|
normalization so the cached data is never mutated.
|
|
498
498
|
"""
|
|
499
|
-
|
|
499
|
+
try:
|
|
500
|
+
from torch_geometric.data import Data
|
|
501
|
+
except ImportError as exc:
|
|
502
|
+
raise ImportError(
|
|
503
|
+
"torch-geometric is required for tracking_parquet datasets. "
|
|
504
|
+
"Install with: pip install \"opensportslib[py-geometric]\" "
|
|
505
|
+
"-f https://pytorch-geometric.com/whl/torch-2.10.0+cu128.html "
|
|
506
|
+
"or (editable): pip install -e \".[py-geometric]\" "
|
|
507
|
+
"-f https://pytorch-geometric.com/whl/torch-2.10.0+cu128.html"
|
|
508
|
+
) from exc
|
|
500
509
|
|
|
501
510
|
from opensportslib.datasets.utils.tracking import normalize_features
|
|
502
511
|
|
|
@@ -533,7 +542,16 @@ class TrackingDataset(ClassificationDataset):
|
|
|
533
542
|
|
|
534
543
|
def _getitem_on_the_fly(self, idx):
|
|
535
544
|
"""load, parse, and process a single sample from disk."""
|
|
536
|
-
|
|
545
|
+
try:
|
|
546
|
+
from torch_geometric.data import Data
|
|
547
|
+
except ImportError as exc:
|
|
548
|
+
raise ImportError(
|
|
549
|
+
"torch-geometric is required for tracking_parquet datasets. "
|
|
550
|
+
"Install with: pip install \"opensportslib[py-geometric]\" "
|
|
551
|
+
"-f https://pytorch-geometric.com/whl/torch-2.10.0+cu128.html "
|
|
552
|
+
"or (editable): pip install -e \".[py-geometric]\" "
|
|
553
|
+
"-f https://pytorch-geometric.com/whl/torch-2.10.0+cu128.html"
|
|
554
|
+
) from exc
|
|
537
555
|
|
|
538
556
|
from opensportslib.datasets.utils.tracking import (
|
|
539
557
|
compute_deltas,
|
|
@@ -600,4 +618,4 @@ class TrackingDataset(ClassificationDataset):
|
|
|
600
618
|
if label is not None:
|
|
601
619
|
out["label"] = label
|
|
602
620
|
return out
|
|
603
|
-
|
|
621
|
+
|