opensportslib 0.1.2.dev7__tar.gz → 0.1.2.dev9__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {opensportslib-0.1.2.dev7/opensportslib.egg-info → opensportslib-0.1.2.dev9}/PKG-INFO +1 -1
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/apis/base_task_model.py +8 -1
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/apis/classification.py +4 -5
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/apis/localization.py +10 -11
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/core/trainer/classification_trainer.py +7 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/core/trainer/localization_trainer.py +14 -0
- opensportslib-0.1.2.dev9/opensportslib/core/utils/config.py +358 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9/opensportslib.egg-info}/PKG-INFO +1 -1
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib.egg-info/SOURCES.txt +2 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/pyproject.toml +1 -1
- opensportslib-0.1.2.dev9/tools/convert/build_soccernet_gar.py +1089 -0
- opensportslib-0.1.2.dev9/tools/convert/build_soccernet_gar_action_spotting.py +289 -0
- opensportslib-0.1.2.dev7/opensportslib/core/utils/config.py +0 -214
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/LICENSE +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/LICENSE-COMMERCIAL +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/MANIFEST.in +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/README.md +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/examples/quickstart/basic_classification.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/examples/quickstart/basic_localization.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/__init__.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/apis/__init__.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/cli.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/config/classification.yaml +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/config/localization-e2e-ocv.yaml +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/config/localization-json_calf_resnetpca512.yaml +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/config/localization-json_netvlad++_resnetpca512.yaml +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/config/localization.yaml +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/config/sngar-frames.yaml +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/config/sngar-tracking.yaml +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/core/__init__.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/core/loss/__init__.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/core/loss/builder.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/core/loss/calf.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/core/loss/ce.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/core/loss/combine.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/core/loss/nll.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/core/optimizer/__init__.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/core/optimizer/builder.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/core/sampler/weighted_sampler.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/core/scheduler/__init__.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/core/scheduler/builder.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/core/trainer/__init__.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/core/utils/checkpoint.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/core/utils/data.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/core/utils/ddp.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/core/utils/default_args.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/core/utils/lightning.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/core/utils/load_annotations.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/core/utils/seed.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/core/utils/video_processing.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/core/utils/wandb.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/datasets/__init__.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/datasets/builder.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/datasets/classification_dataset.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/datasets/localization_dataset.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/datasets/utils/__init__.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/datasets/utils/tracking.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/metrics/classification_metric.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/metrics/localization_metric.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/models/__init__.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/models/backbones/builder.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/models/base/contextaware.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/models/base/e2e.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/models/base/learnablepooling.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/models/base/tracking.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/models/base/vars.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/models/base/video.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/models/base/video_mae.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/models/builder.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/models/heads/builder.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/models/neck/builder.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/models/utils/common.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/models/utils/impl/__init__.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/models/utils/impl/asformer.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/models/utils/impl/calf.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/models/utils/impl/gsm.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/models/utils/impl/gtad.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/models/utils/impl/tsm.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/models/utils/litebase.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/models/utils/modules.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/models/utils/shift.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/models/utils/utils.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/setup/setup.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/tools/__init__.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/tools/_common.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/tools/hf_transfer.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/tools/osl_json_to_parquet.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib/tools/parquet_to_osl_json.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib.egg-info/dependency_links.txt +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib.egg-info/entry_points.txt +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib.egg-info/requires.txt +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/opensportslib.egg-info/top_level.txt +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/setup.cfg +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/tests/conftest.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/tests/test_config_utils_smoke.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/tests/test_conversion_tools.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/tests/test_hf_transfer_tools.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/tests/test_package_smoke.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/tests/test_public_apis_smoke.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/tests/test_subset_train_infer_integration.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/tests/test_task_model_api_contract.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/tools/convert/osl_json_to_parquet_webdataset.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/tools/convert/parquet_webdataset_to_osl_json.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/tools/download/download_hf_repo.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/tools/download/download_osl_hf.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/tools/download/upload_osl_hf.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/tools/training/classification.py +0 -0
- {opensportslib-0.1.2.dev7 → opensportslib-0.1.2.dev9}/tools/training/localization.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: opensportslib
|
|
3
|
-
Version: 0.1.2.dev7
|
|
3
|
+
Version: 0.1.2.dev9
|
|
4
4
|
Summary: OpenSportsLib is the professional library, designed for advanced video understanding in sports. It provides state-of-the-art tools for action recognition, spotting, retrieval, and captioning, making it ideal for researchers, analysts, and developers working with sports video data.
|
|
5
5
|
Author: Jeet Vora
|
|
6
6
|
Requires-Python: >=3.12
|
|
@@ -9,7 +9,7 @@ import uuid
|
|
|
9
9
|
from abc import ABC, abstractmethod
|
|
10
10
|
from typing import Any
|
|
11
11
|
|
|
12
|
-
from opensportslib.core.utils.config import expand, load_config_omega
|
|
12
|
+
from opensportslib.core.utils.config import expand, load_config_omega, fetch_and_merge_config_from_HF
|
|
13
13
|
|
|
14
14
|
|
|
15
15
|
class BaseTaskModel(ABC):
|
|
@@ -24,6 +24,13 @@ class BaseTaskModel(ABC):
|
|
|
24
24
|
self.config_path = expand(config)
|
|
25
25
|
self.config = load_config_omega(self.config_path)
|
|
26
26
|
|
|
27
|
+
if weights is not None:
|
|
28
|
+
self.config = fetch_and_merge_config_from_HF(self.config, weights, merge_policy="compatibility")
|
|
29
|
+
self.last_loaded_weights = weights
|
|
30
|
+
self.best_checkpoint = weights
|
|
31
|
+
|
|
32
|
+
self.train_flag = False # Flag to indicate whether we're in training mode (affects checkpoint loading behavior)
|
|
33
|
+
|
|
27
34
|
data_cfg = getattr(self.config, "DATA", None)
|
|
28
35
|
if data_cfg is not None and hasattr(data_cfg, "data_dir"):
|
|
29
36
|
data_cfg.data_dir = expand(data_cfg.data_dir)
|
|
@@ -8,7 +8,6 @@ import os
|
|
|
8
8
|
from opensportslib.apis.base_task_model import BaseTaskModel
|
|
9
9
|
from opensportslib.core.utils.config import expand
|
|
10
10
|
|
|
11
|
-
|
|
12
11
|
class ClassificationModel(BaseTaskModel):
|
|
13
12
|
"""Top-level task wrapper for classification."""
|
|
14
13
|
|
|
@@ -172,8 +171,8 @@ class ClassificationModel(BaseTaskModel):
|
|
|
172
171
|
|
|
173
172
|
train_set = self._resolve_split_path("train", train_set)
|
|
174
173
|
valid_set = self._resolve_split_path("valid", valid_set)
|
|
175
|
-
|
|
176
|
-
self.config = resolve_config_omega(self.config)
|
|
174
|
+
|
|
175
|
+
self.config = resolve_config_omega(self.config, weights=weights)
|
|
177
176
|
logging.info("Configuration:")
|
|
178
177
|
logging.info(self.config)
|
|
179
178
|
|
|
@@ -241,7 +240,7 @@ class ClassificationModel(BaseTaskModel):
|
|
|
241
240
|
|
|
242
241
|
test_set = self._resolve_split_path("test", test_set)
|
|
243
242
|
|
|
244
|
-
self.config = resolve_config_omega(self.config)
|
|
243
|
+
self.config = resolve_config_omega(self.config, weights=weights)
|
|
245
244
|
logging.info("Configuration:")
|
|
246
245
|
logging.info(self.config)
|
|
247
246
|
|
|
@@ -304,7 +303,7 @@ class ClassificationModel(BaseTaskModel):
|
|
|
304
303
|
|
|
305
304
|
test_set = self._resolve_split_path("test", test_set)
|
|
306
305
|
|
|
307
|
-
self.config = resolve_config_omega(self.config)
|
|
306
|
+
self.config = resolve_config_omega(self.config, weights=weights)
|
|
308
307
|
logging.info("Configuration:")
|
|
309
308
|
logging.info(self.config)
|
|
310
309
|
if predictions is None:
|
|
@@ -9,13 +9,13 @@ from opensportslib.core.utils.config import expand
|
|
|
9
9
|
class LocalizationModel(BaseTaskModel):
|
|
10
10
|
"""Top-level task wrapper for localization / spotting."""
|
|
11
11
|
|
|
12
|
-
def __init__(self, config=None, weights=None):
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
12
|
+
# def __init__(self, config=None, weights=None):
|
|
13
|
+
# super().__init__(config=config, weights=None)
|
|
14
|
+
# if weights is not None:
|
|
15
|
+
# self.last_loaded_weights = weights
|
|
16
|
+
# self.best_checkpoint = weights
|
|
17
17
|
|
|
18
|
-
|
|
18
|
+
# self.train_flag = False # Flag to indicate whether we're in training mode (affects checkpoint loading behavior)
|
|
19
19
|
|
|
20
20
|
def _resolve_split_path(self, split: str, override: str | None = None) -> str:
|
|
21
21
|
if override is not None:
|
|
@@ -178,8 +178,8 @@ class LocalizationModel(BaseTaskModel):
|
|
|
178
178
|
valid_set = self._resolve_split_path("valid", valid_set)
|
|
179
179
|
self._set_split_path("train", train_set)
|
|
180
180
|
self._set_split_path("valid", valid_set)
|
|
181
|
-
|
|
182
|
-
self.config = resolve_config_omega(self.config)
|
|
181
|
+
|
|
182
|
+
self.config = resolve_config_omega(self.config, weights=weights)
|
|
183
183
|
check_config(self.config, split="train")
|
|
184
184
|
init_wandb(
|
|
185
185
|
self.config_path,
|
|
@@ -292,7 +292,7 @@ class LocalizationModel(BaseTaskModel):
|
|
|
292
292
|
self._set_split_path("test", test_set)
|
|
293
293
|
|
|
294
294
|
self.config.MODEL.multi_gpu = False
|
|
295
|
-
self.config = resolve_config_omega(self.config)
|
|
295
|
+
self.config = resolve_config_omega(self.config, weights=weights)
|
|
296
296
|
check_config(self.config, split="test")
|
|
297
297
|
self.config.infer_split = whether_infer_split(self.config.DATA.test)
|
|
298
298
|
|
|
@@ -361,9 +361,8 @@ class LocalizationModel(BaseTaskModel):
|
|
|
361
361
|
|
|
362
362
|
test_set = self._resolve_split_path("test", test_set)
|
|
363
363
|
self._set_split_path("test", test_set)
|
|
364
|
-
|
|
365
364
|
self.config.MODEL.multi_gpu = False
|
|
366
|
-
self.config = resolve_config_omega(self.config)
|
|
365
|
+
self.config = resolve_config_omega(self.config, weights=weights)
|
|
367
366
|
check_config(self.config, split="test")
|
|
368
367
|
self.config.infer_split = whether_infer_split(self.config.DATA.test)
|
|
369
368
|
|
|
@@ -550,6 +550,13 @@ class BaseTrainerClassification:
|
|
|
550
550
|
path_aux = os.path.join(epoch_dir, name)
|
|
551
551
|
torch.save(state, path_aux)
|
|
552
552
|
logging.info(f"Saved checkpoint: {path_aux}")
|
|
553
|
+
|
|
554
|
+
if self.config is not None:
|
|
555
|
+
from opensportslib.core.utils.config import save_config
|
|
556
|
+
config_path = os.path.join(epoch_dir, "config.yaml")
|
|
557
|
+
save_config(self.config, config_path)
|
|
558
|
+
logging.info(f"Saved config: {config_path}")
|
|
559
|
+
|
|
553
560
|
return path_aux
|
|
554
561
|
|
|
555
562
|
|
|
@@ -155,6 +155,7 @@ class Trainer_pl(Trainer):
|
|
|
155
155
|
num_sanity_val_steps=0,
|
|
156
156
|
)
|
|
157
157
|
self.best_checkpoint_path = None
|
|
158
|
+
self.config = cfg
|
|
158
159
|
|
|
159
160
|
def train(self, **kwargs):
|
|
160
161
|
self.trainer.fit(**kwargs)
|
|
@@ -179,6 +180,13 @@ class Trainer_pl(Trainer):
|
|
|
179
180
|
logging.info("Done training")
|
|
180
181
|
logging.info(f"Best model saved at: {self.best_checkpoint_path}")
|
|
181
182
|
|
|
183
|
+
# Save the config file uniformly inside the work_dir
|
|
184
|
+
if hasattr(self, 'config') and self.config is not None:
|
|
185
|
+
from opensportslib.core.utils.config import save_config
|
|
186
|
+
config_path = os.path.join(self.work_dir, "config.yaml")
|
|
187
|
+
save_config(self.config, config_path)
|
|
188
|
+
logging.info(f"Saved config: {config_path}")
|
|
189
|
+
|
|
182
190
|
log()
|
|
183
191
|
|
|
184
192
|
|
|
@@ -297,6 +305,12 @@ class Trainer_e2e(Trainer):
|
|
|
297
305
|
self.best_checkpoint_path = best_path
|
|
298
306
|
torch.save(checkpoint, best_path)
|
|
299
307
|
logging.info(f"Best checkpoint saved: {best_path}")
|
|
308
|
+
|
|
309
|
+
if self.config is not None:
|
|
310
|
+
from opensportslib.core.utils.config import save_config
|
|
311
|
+
config_path = os.path.join(self.save_dir, "config.yaml")
|
|
312
|
+
save_config(self.config, config_path)
|
|
313
|
+
logging.info(f"Saved config: {config_path}")
|
|
300
314
|
|
|
301
315
|
def train(self, train_loader, valid_loader, classes):
|
|
302
316
|
"""Training loop with checkpoint management."""
|
|
@@ -0,0 +1,358 @@
|
|
|
1
|
+
|
|
2
|
+
import os
|
|
3
|
+
import re
|
|
4
|
+
import json
|
|
5
|
+
import gzip
|
|
6
|
+
import yaml
|
|
7
|
+
|
|
8
|
+
def dict_to_namespace(d, skip_keys=("classes",)):
    """
    Turn nested dicts into SimpleNamespace objects for attribute access.

    Lists are walked element by element and any other value is returned
    untouched. Values stored under a key listed in ``skip_keys`` (such as
    'classes') are kept exactly as they are, without any conversion.
    """
    from types import SimpleNamespace

    if isinstance(d, list):
        return [dict_to_namespace(item, skip_keys) for item in d]
    if not isinstance(d, dict):
        return d

    converted = {
        key: value if key in skip_keys else dict_to_namespace(value, skip_keys)
        for key, value in d.items()
    }
    return SimpleNamespace(**converted)
|
|
27
|
+
|
|
28
|
+
def namespace_to_dict(ns):
    """
    Flatten namespace/dict/list containers into plain Python types, recursively.

    ``None`` and str/int/float/bool values come back unchanged. OmegaConf
    containers are first converted with ``OmegaConf.to_container`` (skipped
    when omegaconf is not installed). Mappings and objects exposing
    ``__dict__`` become dicts with string keys; lists, tuples and sets become
    lists. Anything else is returned as-is.
    """
    if ns is None or isinstance(ns, (str, int, float, bool)):
        return ns

    try:
        from omegaconf import DictConfig, ListConfig, OmegaConf
        if isinstance(ns, (DictConfig, ListConfig)):
            ns = OmegaConf.to_container(ns, resolve=True)
    except ImportError:
        pass

    if isinstance(ns, dict):
        mapping = ns
    elif isinstance(ns, (list, tuple, set)):
        return [namespace_to_dict(item) for item in ns]
    elif hasattr(ns, "__dict__"):
        mapping = vars(ns)
    else:
        return ns

    return {str(key): namespace_to_dict(value) for key, value in mapping.items()}
|
|
52
|
+
|
|
53
|
+
def namespace_to_omegaconf(ns):
    """
    Build an OmegaConf object from a SimpleNamespace (or dict/list) tree.

    SimpleNamespace nodes and dicts are turned into plain dicts, lists are
    converted element by element, and the result is passed to
    ``OmegaConf.create``.
    """
    from omegaconf import OmegaConf
    from types import SimpleNamespace

    def _plain(node):
        if isinstance(node, list):
            return [_plain(item) for item in node]
        if isinstance(node, SimpleNamespace):
            node = vars(node)
        elif not isinstance(node, dict):
            return node
        return {key: _plain(value) for key, value in node.items()}

    return OmegaConf.create(_plain(ns))
|
|
71
|
+
|
|
72
|
+
def load_config(config_path):
    """
    Load a YAML or JSON configuration file into a namespace.

    Args:
        config_path (str): Path to a ``.yaml``, ``.yml`` or ``.json`` file.

    Returns:
        The parsed content passed through ``dict_to_namespace``.

    Raises:
        ValueError: If the file extension is not one of the supported ones.
    """
    print(config_path)
    if config_path.endswith((".yaml", ".yml")):
        parse = yaml.safe_load
    elif config_path.endswith(".json"):
        parse = json.load
    else:
        raise ValueError("Unsupported config format. Use YAML or JSON.")

    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(config_path, "r", encoding="utf-8") as f:
        cfg_dict = parse(f)
    return dict_to_namespace(cfg_dict)
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def load_config_omega(path):
    """
    Read a config file with ``OmegaConf.load`` and pass it to ``dict_to_namespace``.

    Interpolations are not resolved here.

    NOTE(review): ``dict_to_namespace`` only converts plain ``dict``/``list``
    values, so the OmegaConf container presumably comes back unchanged —
    confirm against omegaconf's class hierarchy.
    """
    from omegaconf import OmegaConf

    return dict_to_namespace(OmegaConf.load(path))
|
|
96
|
+
|
|
97
|
+
def resolve_config_omega(cfg, weights=None):
    """
    Resolve an OmegaConf config into a namespace, optionally merging a pretrained config.

    When ``weights`` is given, the config found for those weights is merged in
    first using the "compatibility" policy. A ``DictConfig`` is then resolved
    and converted with ``dict_to_namespace``; any other object is returned
    unchanged.
    """
    from omegaconf import OmegaConf, DictConfig

    if weights is not None:
        cfg = fetch_and_merge_config_from_HF(cfg, weights, merge_policy="compatibility")

    if isinstance(cfg, DictConfig):
        OmegaConf.resolve(cfg)
        cfg = dict_to_namespace(OmegaConf.to_container(cfg, resolve=True))
    return cfg
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def expand(path):
    """Return ``path`` as an absolute path with a leading ``~`` expanded."""
    with_home = os.path.expanduser(path)
    return os.path.abspath(with_home)
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def load_json(fpath):
    """Parse the UTF-8 encoded JSON file at ``fpath`` and return its content."""
    with open(fpath, encoding="utf-8") as handle:
        parsed = json.load(handle)
    return parsed
|
|
120
|
+
|
|
121
|
+
def load_gz_json(fpath):
    """Parse the gzip-compressed, UTF-8 encoded JSON file at ``fpath``."""
    with gzip.open(fpath, "rt", encoding="utf-8") as handle:
        parsed = json.load(handle)
    return parsed
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def store_json(fpath, obj, pretty=False):
    """
    Write ``obj`` as JSON to ``fpath`` using UTF-8.

    With ``pretty`` set, the output is indented by 4 spaces and keys keep
    their insertion order; otherwise the default compact form is written.
    """
    dump_options = {"indent": 4, "sort_keys": False} if pretty else {}
    with open(fpath, "w", encoding="utf-8") as handle:
        json.dump(obj, handle, **dump_options)
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def store_gz_json(fpath, obj):
    """Write ``obj`` as JSON into a gzip-compressed, UTF-8 encoded file at ``fpath``."""
    with gzip.open(fpath, mode="wt", encoding="utf-8") as handle:
        json.dump(obj, handle)
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def load_text(fpath):
    """Load the non-empty lines of a text file.

    Args:
        fpath (string): The path of the file.

    Returns:
        lines (List): One entry per line of the file, stripped of surrounding
            whitespace; lines that are empty after stripping are dropped.

    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(fpath, "r", encoding="utf-8") as fp:
        stripped = (line.strip() for line in fp)
        return [line for line in stripped if line]
|
|
157
|
+
|
|
158
|
+
def load_classes(input):
    """Load classes from either a sequence of names or a txt file.

    Args:
        input (string | list | tuple | ListConfig): Path of a file that
            contains one class per line, or the sequence of class names itself.

    Returns:
        Dictionary mapping each class name to its 1-based index.
    """
    sequence_types = (list, tuple)
    # omegaconf is treated as optional here: plain lists must keep working
    # when the package is not installed.
    try:
        from omegaconf import ListConfig
    except ImportError:
        pass
    else:
        sequence_types += (ListConfig,)

    names = input if isinstance(input, sequence_types) else load_text(input)
    return {name: index for index, name in enumerate(names, start=1)}
|
|
171
|
+
|
|
172
|
+
def clear_files(dir_name, re_str, exclude=()):
    """Delete the entries of ``dir_name`` whose name matches ``re_str``.

    Args:
        dir_name (string): Directory whose entries are inspected.
        re_str (string): Regular expression applied with ``re.match`` (i.e.
            anchored at the start of the file name).
        exclude (iterable of string): File names to keep even when they match.
    """
    pattern = re.compile(re_str)
    for file_name in os.listdir(dir_name):
        if file_name in exclude or not pattern.match(file_name):
            continue
        os.remove(os.path.join(dir_name, file_name))
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def _print_info_helper(src_file, labels):
    """Print information about the videos contained in a json file.

    Args:
        src_file (string): The source file.
        labels (list(dict)): List containing a dict for each video; each dict
            is read for its "num_frames" and "events" entries.
    """
    num_frames = sum(x["num_frames"] for x in labels)
    num_events = sum(len(x["events"]) for x in labels)
    # With no videos (or zero total frames) the ratio is undefined; report 0
    # instead of raising ZeroDivisionError.
    non_bg_percent = num_events / num_frames * 100 if num_frames else 0.0
    print(
        "{} : {} videos, {} frames, {:0.5f}% non-bg".format(
            src_file, len(labels), num_frames, non_bg_percent
        )
    )
|
|
194
|
+
|
|
195
|
+
def select_device(config):
    """Pick the torch device described by ``config``.

    Args:
        config: Object with a ``device`` string attribute ("auto", "cuda" or
            "cpu", case-insensitive) and an optional ``gpu_id`` attribute
            (defaults to 0) that is used in "cuda" mode.

    Returns:
        torch.device: The selected device.

    Raises:
        ValueError: If ``config.device`` is not one of the supported modes.
        AssertionError: If "cuda" is requested but CUDA is not available.
    """
    import torch

    mode = config.device.lower()

    if mode == "auto":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    elif mode == "cuda":
        assert torch.cuda.is_available(), "CUDA requested but not available"
        gpu_id = getattr(config, "gpu_id", 0)
        torch.cuda.set_device(gpu_id)
        device = torch.device(f"cuda:{gpu_id}")
    elif mode == "cpu":
        device = torch.device("cpu")
    else:
        raise ValueError(f"Unknown device mode: {mode}")

    print(f"Using device: {device}")
    # "auto" has already been resolved to "cuda" or "cpu" above, so only the
    # "cuda" device type needs to be checked here.
    if device.type == "cuda":
        print(f"GPU: {torch.cuda.get_device_name(device)}")

    return device
|
|
219
|
+
|
|
220
|
+
def is_local_path(p):
    """Tell whether ``p`` should be treated as a local checkpoint path.

    Args:
        p: A path string (or ``os.PathLike``); may be ``None`` or empty.

    Returns:
        bool: True when ``p`` exists on disk or ends with a checkpoint
        extension (".pt", ".pth", ".tar"); False otherwise, including for
        ``None`` and the empty string.
    """
    if not p:
        return False
    # Normalise pathlib-style objects so the suffix check works on them too.
    path = os.fspath(p) if isinstance(p, os.PathLike) else p
    return os.path.exists(path) or path.endswith((".pt", ".pth", ".tar"))
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
def _load_pretrained_config(weights, hf_token=None):
    """Return the config stored next to local ``weights`` or in the HF repo, else None."""
    import logging

    if is_local_path(weights):
        dir_name = os.path.dirname(os.path.abspath(weights))
        # config.yaml takes precedence over config.json.
        for file_name in ("config.yaml", "config.json"):
            candidate = os.path.join(dir_name, file_name)
            if os.path.exists(candidate):
                loaded_cfg = load_config_omega(candidate)
                logging.info(f"Loaded config from {candidate}")
                return loaded_cfg
        return None

    try:
        from huggingface_hub import hf_hub_download
        try:
            config_path = hf_hub_download(repo_id=weights, filename="config.yaml", token=hf_token)
            loaded_cfg = load_config_omega(config_path)
            logging.info(f"Loaded config.yaml from HF repo {weights}")
        except Exception:
            # Any failure with config.yaml falls back to config.json.
            config_path = hf_hub_download(repo_id=weights, filename="config.json", token=hf_token)
            loaded_cfg = load_config_omega(config_path)
            logging.info(f"Loaded config.json from HF repo {weights}")
        return loaded_cfg
    except Exception as e:
        # Best effort: a missing config must not prevent using the weights.
        logging.warning(f"Could not load config from HF repo {weights}: {e}")
    return None


def _merge_config_section(base, override):
    """Merge two plain containers with OmegaConf; values from ``override`` win."""
    from omegaconf import OmegaConf

    merged = OmegaConf.merge(OmegaConf.create(base), OmegaConf.create(override))
    return OmegaConf.to_container(merged, resolve=False)


def fetch_and_merge_config_from_HF(
    target_config, weights, hf_token=None, merge_policy="full"
):
    """
    Fetch config from a local path or HF repo and merge it with the local config.

    merge_policy:
    - "full": legacy behavior; local config overrides loaded config for
      TASK/MODEL/SYSTEM/TRAIN/DATA. The loaded DATA section is stripped of
      data_dir/train/valid/test before merging.
    - "compatibility": used for inference; only TASK/MODEL are updated from
      pretrained config while runtime/system/data settings remain local.

    Args:
        target_config: Local configuration (namespace, dict or OmegaConf object).
        weights: Local checkpoint path or HF repo id the config belongs to.
        hf_token: Optional token forwarded to ``hf_hub_download``.
        merge_policy: "full" or "compatibility" (see above).

    Returns:
        The merged config converted with ``dict_to_namespace``, or
        ``target_config`` unchanged when no pretrained config could be loaded.

    Raises:
        ValueError: If ``merge_policy`` is not a known policy.
    """
    import logging

    # Validate up front so a misspelled policy is reported even when no
    # pretrained config is found.
    if merge_policy not in ("full", "compatibility"):
        raise ValueError(f"Unknown merge_policy: {merge_policy}")

    loaded_cfg = _load_pretrained_config(weights, hf_token)
    if loaded_cfg is None:
        return target_config

    logging.info(f"Merging pretrained config from {weights}")

    target_dict = namespace_to_dict(target_config)
    loaded_dict = namespace_to_dict(loaded_cfg)

    _warn_critical_config_conflicts(target_dict, loaded_dict)

    if merge_policy == "compatibility":
        # Keep local runtime config as source of truth. Pull only compatibility-
        # critical sections from the pretrained config.
        for section in ("TASK", "MODEL"):
            if section not in loaded_dict:
                continue
            loaded_section = loaded_dict[section]
            if isinstance(loaded_section, dict):
                target_dict[section] = _merge_config_section(
                    target_dict.get(section, {}), loaded_section
                )
            else:
                target_dict[section] = loaded_section
    else:
        for section in ("TASK", "MODEL", "SYSTEM", "TRAIN", "DATA"):
            if section not in loaded_dict:
                continue
            loaded_section = loaded_dict[section]
            # Sanitize the DATA block to strip out remote machine-specific paths
            if section == "DATA" and isinstance(loaded_section, dict):
                for key in ("data_dir", "train", "valid", "test"):
                    loaded_section.pop(key, None)

            # Legacy merge logic: pretrained config as base, local config overrides it.
            if isinstance(loaded_section, dict):
                target_dict[section] = _merge_config_section(
                    loaded_section, target_dict.get(section, {})
                )
            else:
                target_dict[section] = target_dict.get(section, loaded_section)

    return dict_to_namespace(target_dict)
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
def _warn_critical_config_conflicts(target_dict, loaded_dict):
    """Log a warning when DATA.num_classes or DATA.classes differ between configs.

    Args:
        target_dict: Local configuration as a plain dict.
        loaded_dict: Pretrained (loaded) configuration as a plain dict.

    Only logs; neither argument is modified. A value is compared only when it
    is present (not ``None``) on both sides.
    """
    import logging

    def _data_section(cfg):
        # A missing or non-dict DATA entry (e.g. ``DATA: null``) is treated as
        # empty instead of raising AttributeError on ``.get``.
        section = cfg.get("DATA") if isinstance(cfg, dict) else None
        return section if isinstance(section, dict) else {}

    local_data = _data_section(target_dict)
    hf_data = _data_section(loaded_dict)

    local_num_classes = local_data.get("num_classes")
    hf_num_classes = hf_data.get("num_classes")
    if (
        local_num_classes is not None
        and hf_num_classes is not None
        and local_num_classes != hf_num_classes
    ):
        logging.warning(
            "Config mismatch: DATA.num_classes local=%s hf=%s. "
            "Keeping local runtime config values.",
            local_num_classes,
            hf_num_classes,
        )

    local_classes = local_data.get("classes")
    hf_classes = hf_data.get("classes")
    if (
        local_classes is not None
        and hf_classes is not None
        and local_classes != hf_classes
    ):
        logging.warning(
            "Config mismatch: DATA.classes differs between local and HF config. "
            "Keeping local runtime config values.",
        )
|
|
346
|
+
|
|
347
|
+
def save_config(config_obj, path):
    """Write the configuration object to ``path`` as block-style YAML (UTF-8).

    A ``DictConfig`` is converted with ``OmegaConf.to_container`` (with
    interpolations resolved); any other object goes through
    ``namespace_to_dict`` first.
    """
    from omegaconf import OmegaConf, DictConfig
    import yaml

    cfg_dict = (
        OmegaConf.to_container(config_obj, resolve=True)
        if isinstance(config_obj, DictConfig)
        else namespace_to_dict(config_obj)
    )

    with open(path, "w", encoding="utf-8") as out_file:
        yaml.dump(cfg_dict, out_file, default_flow_style=False)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: opensportslib
|
|
3
|
-
Version: 0.1.2.dev7
|
|
3
|
+
Version: 0.1.2.dev9
|
|
4
4
|
Summary: OpenSportsLib is the professional library, designed for advanced video understanding in sports. It provides state-of-the-art tools for action recognition, spotting, retrieval, and captioning, making it ideal for researchers, analysts, and developers working with sports video data.
|
|
5
5
|
Author: Jeet Vora
|
|
6
6
|
Requires-Python: >=3.12
|
|
@@ -94,6 +94,8 @@ tests/test_package_smoke.py
|
|
|
94
94
|
tests/test_public_apis_smoke.py
|
|
95
95
|
tests/test_subset_train_infer_integration.py
|
|
96
96
|
tests/test_task_model_api_contract.py
|
|
97
|
+
tools/convert/build_soccernet_gar.py
|
|
98
|
+
tools/convert/build_soccernet_gar_action_spotting.py
|
|
97
99
|
tools/convert/osl_json_to_parquet_webdataset.py
|
|
98
100
|
tools/convert/parquet_webdataset_to_osl_json.py
|
|
99
101
|
tools/download/download_hf_repo.py
|
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "opensportslib"
|
|
7
|
-
version = "0.1.2.dev7"
|
|
7
|
+
version = "0.1.2.dev9"
|
|
8
8
|
description = "OpenSportsLib is the professional library, designed for advanced video understanding in sports. It provides state-of-the-art tools for action recognition, spotting, retrieval, and captioning, making it ideal for researchers, analysts, and developers working with sports video data."
|
|
9
9
|
readme = "README.md"
|
|
10
10
|
requires-python = ">=3.12"
|