opensportslib 0.1.2.dev5__tar.gz → 0.1.2.dev7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {opensportslib-0.1.2.dev5/opensportslib.egg-info → opensportslib-0.1.2.dev7}/PKG-INFO +1 -1
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/apis/localization.py +59 -12
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/trainer/classification_trainer.py +2 -1
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/trainer/localization_trainer.py +13 -44
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/utils/checkpoint.py +26 -4
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/datasets/localization_dataset.py +28 -18
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7/opensportslib.egg-info}/PKG-INFO +1 -1
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/pyproject.toml +1 -1
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/LICENSE +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/LICENSE-COMMERCIAL +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/MANIFEST.in +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/README.md +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/examples/quickstart/basic_classification.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/examples/quickstart/basic_localization.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/__init__.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/apis/__init__.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/apis/base_task_model.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/apis/classification.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/cli.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/config/classification.yaml +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/config/localization-e2e-ocv.yaml +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/config/localization-json_calf_resnetpca512.yaml +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/config/localization-json_netvlad++_resnetpca512.yaml +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/config/localization.yaml +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/config/sngar-frames.yaml +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/config/sngar-tracking.yaml +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/__init__.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/loss/__init__.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/loss/builder.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/loss/calf.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/loss/ce.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/loss/combine.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/loss/nll.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/optimizer/__init__.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/optimizer/builder.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/sampler/weighted_sampler.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/scheduler/__init__.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/scheduler/builder.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/trainer/__init__.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/utils/config.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/utils/data.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/utils/ddp.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/utils/default_args.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/utils/lightning.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/utils/load_annotations.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/utils/seed.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/utils/video_processing.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/utils/wandb.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/datasets/__init__.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/datasets/builder.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/datasets/classification_dataset.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/datasets/utils/__init__.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/datasets/utils/tracking.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/metrics/classification_metric.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/metrics/localization_metric.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/__init__.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/backbones/builder.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/base/contextaware.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/base/e2e.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/base/learnablepooling.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/base/tracking.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/base/vars.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/base/video.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/base/video_mae.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/builder.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/heads/builder.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/neck/builder.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/common.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/impl/__init__.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/impl/asformer.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/impl/calf.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/impl/gsm.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/impl/gtad.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/impl/tsm.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/litebase.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/modules.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/shift.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/utils.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/setup/setup.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/tools/__init__.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/tools/_common.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/tools/hf_transfer.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/tools/osl_json_to_parquet.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/tools/parquet_to_osl_json.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib.egg-info/SOURCES.txt +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib.egg-info/dependency_links.txt +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib.egg-info/entry_points.txt +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib.egg-info/requires.txt +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib.egg-info/top_level.txt +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/setup.cfg +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/tests/conftest.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/tests/test_config_utils_smoke.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/tests/test_conversion_tools.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/tests/test_hf_transfer_tools.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/tests/test_package_smoke.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/tests/test_public_apis_smoke.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/tests/test_subset_train_infer_integration.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/tests/test_task_model_api_contract.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/tools/convert/osl_json_to_parquet_webdataset.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/tools/convert/parquet_webdataset_to_osl_json.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/tools/download/download_hf_repo.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/tools/download/download_osl_hf.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/tools/download/upload_osl_hf.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/tools/training/classification.py +0 -0
- {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/tools/training/localization.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: opensportslib
|
|
3
|
-
Version: 0.1.2.dev5
|
|
3
|
+
Version: 0.1.2.dev7
|
|
4
4
|
Summary: OpenSportsLib is the professional library, designed for advanced video understanding in sports. It provides state-of-the-art tools for action recognition, spotting, retrieval, and captioning, making it ideal for researchers, analysts, and developers working with sports video data.
|
|
5
5
|
Author: Jeet Vora
|
|
6
6
|
Requires-Python: >=3.12
|
|
@@ -15,6 +15,8 @@ class LocalizationModel(BaseTaskModel):
|
|
|
15
15
|
self.last_loaded_weights = weights
|
|
16
16
|
self.best_checkpoint = weights
|
|
17
17
|
|
|
18
|
+
self.train_flag = False # Flag to indicate whether we're in training mode (affects checkpoint loading behavior)
|
|
19
|
+
|
|
18
20
|
def _resolve_split_path(self, split: str, override: str | None = None) -> str:
|
|
19
21
|
if override is not None:
|
|
20
22
|
return expand(override)
|
|
@@ -69,15 +71,18 @@ class LocalizationModel(BaseTaskModel):
|
|
|
69
71
|
load_checkpoint,
|
|
70
72
|
localization_remap,
|
|
71
73
|
)
|
|
72
|
-
|
|
74
|
+
from opensportslib.core.optimizer.builder import build_optimizer
|
|
75
|
+
from opensportslib.core.scheduler.builder import build_scheduler
|
|
76
|
+
default_args = kwargs.get("default_args", None)
|
|
73
77
|
del kwargs
|
|
74
78
|
if weights is None:
|
|
75
79
|
raise ValueError("`weights` must be provided to load_weights().")
|
|
76
80
|
|
|
77
81
|
model_cfg = getattr(self.config, "MODEL", None)
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
model_cfg
|
|
82
|
+
if not self.train_flag:
|
|
83
|
+
original_multi_gpu = getattr(model_cfg, "multi_gpu", None)
|
|
84
|
+
if model_cfg is not None and original_multi_gpu is not None:
|
|
85
|
+
model_cfg.multi_gpu = False
|
|
81
86
|
|
|
82
87
|
device = select_device(self.config.SYSTEM)
|
|
83
88
|
if self.model is None:
|
|
@@ -90,9 +95,28 @@ class LocalizationModel(BaseTaskModel):
|
|
|
90
95
|
if is_local_path(weights):
|
|
91
96
|
self.config.SYSTEM.work_dir = os.path.dirname(os.path.abspath(weights))
|
|
92
97
|
|
|
93
|
-
|
|
98
|
+
if default_args is not None:
|
|
99
|
+
logging.info("Building optimizer + scaler for checkpoint restore...")
|
|
100
|
+
optimizer, scaler = build_optimizer(
|
|
101
|
+
inner_model.parameters(), # or _get_params() if required
|
|
102
|
+
self.config.TRAIN.optimizer
|
|
103
|
+
)
|
|
104
|
+
|
|
105
|
+
logging.info("Building scheduler for checkpoint restore...")
|
|
106
|
+
scheduler = build_scheduler(
|
|
107
|
+
optimizer,
|
|
108
|
+
self.config.TRAIN.scheduler,
|
|
109
|
+
default_args
|
|
110
|
+
)
|
|
111
|
+
else:
|
|
112
|
+
optimizer = scheduler = scaler = None
|
|
113
|
+
|
|
114
|
+
inner_model, optimizer, scheduler, scaler, epoch, checkpoint = load_checkpoint(
|
|
94
115
|
model=inner_model,
|
|
95
116
|
path=weights,
|
|
117
|
+
optimizer=optimizer,
|
|
118
|
+
scheduler=scheduler,
|
|
119
|
+
scaler=scaler,
|
|
96
120
|
device=device,
|
|
97
121
|
key_remap_fn=localization_remap,
|
|
98
122
|
)
|
|
@@ -107,8 +131,24 @@ class LocalizationModel(BaseTaskModel):
|
|
|
107
131
|
self.last_loaded_weights = weights
|
|
108
132
|
self.best_checkpoint = weights
|
|
109
133
|
|
|
110
|
-
|
|
111
|
-
|
|
134
|
+
best_epoch = checkpoint.get("best_epoch", 0)
|
|
135
|
+
|
|
136
|
+
best_criterion_valid = checkpoint.get(
|
|
137
|
+
"best_criterion_valid",
|
|
138
|
+
0 if self.config.TRAIN.criterion_valid == "map" else float("inf")
|
|
139
|
+
)
|
|
140
|
+
self._resume_state = {
|
|
141
|
+
"optimizer": optimizer,
|
|
142
|
+
"scheduler": scheduler,
|
|
143
|
+
"scaler": scaler,
|
|
144
|
+
"epoch": epoch if epoch is not None else 0,
|
|
145
|
+
"best_epoch": best_epoch,
|
|
146
|
+
"best_criterion_valid": best_criterion_valid,
|
|
147
|
+
}
|
|
148
|
+
|
|
149
|
+
if not self.train_flag:
|
|
150
|
+
if model_cfg is not None and original_multi_gpu is not None:
|
|
151
|
+
model_cfg.multi_gpu = original_multi_gpu
|
|
112
152
|
|
|
113
153
|
def train(
|
|
114
154
|
self,
|
|
@@ -167,9 +207,6 @@ class LocalizationModel(BaseTaskModel):
|
|
|
167
207
|
|
|
168
208
|
start = time.time()
|
|
169
209
|
|
|
170
|
-
device = select_device(self.config.SYSTEM)
|
|
171
|
-
self.model = build_model(self.config, device=device)
|
|
172
|
-
|
|
173
210
|
data_obj_train = build_dataset(self.config, split="train")
|
|
174
211
|
dataset_train = data_obj_train.building_dataset(
|
|
175
212
|
cfg=data_obj_train.cfg,
|
|
@@ -196,11 +233,21 @@ class LocalizationModel(BaseTaskModel):
|
|
|
196
233
|
dali=self.config.dali,
|
|
197
234
|
)
|
|
198
235
|
|
|
236
|
+
default_args = get_default_args_trainer(self.config, len(train_loader))
|
|
237
|
+
|
|
238
|
+
self.train_flag = True # Set flag to indicate training mode for checkpoint loading
|
|
239
|
+
if effective_weights is not None:
|
|
240
|
+
if self.model is None or self.last_loaded_weights != effective_weights:
|
|
241
|
+
self.load_weights(weights=effective_weights, default_args=default_args)
|
|
242
|
+
elif self.model is None:
|
|
243
|
+
device = select_device(self.config.SYSTEM)
|
|
244
|
+
self.model = build_model(self.config, device=device)
|
|
245
|
+
|
|
199
246
|
self.trainer = build_trainer(
|
|
200
247
|
cfg=self.config,
|
|
201
248
|
model=self.model,
|
|
202
|
-
default_args=
|
|
203
|
-
resume_from=
|
|
249
|
+
default_args=default_args,
|
|
250
|
+
resume_from=self._resume_state if hasattr(self, "_resume_state") else None,
|
|
204
251
|
)
|
|
205
252
|
|
|
206
253
|
logging.info("Start training")
|
|
@@ -1167,11 +1167,12 @@ class Trainer_Classification:
|
|
|
1167
1167
|
from opensportslib.models.builder import build_model
|
|
1168
1168
|
if self.model is None:
|
|
1169
1169
|
self.model, _ = build_model(self.config, self.device)
|
|
1170
|
-
self.model, optimizer, scheduler, epoch = load_checkpoint(
|
|
1170
|
+
self.model, optimizer, scheduler, scaler, epoch, checkpoint = load_checkpoint(
|
|
1171
1171
|
self.model, path, optimizer, scheduler, device=self.device
|
|
1172
1172
|
)
|
|
1173
1173
|
self.optimizer = optimizer
|
|
1174
1174
|
self.scheduler = scheduler
|
|
1175
|
+
self.scaler = scaler
|
|
1175
1176
|
self.epoch = epoch
|
|
1176
1177
|
logging.info(f"Model loaded from {path}, epoch: {epoch}")
|
|
1177
1178
|
return self.model, self.optimizer, self.scheduler, self.epoch
|
|
@@ -29,7 +29,6 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
|
29
29
|
"""
|
|
30
30
|
from opensportslib.metrics.localization_metric import *
|
|
31
31
|
from opensportslib.core.optimizer.builder import build_optimizer
|
|
32
|
-
from opensportslib.core.optimizer.builder import build_optimizer
|
|
33
32
|
from opensportslib.core.scheduler.builder import build_scheduler
|
|
34
33
|
from opensportslib.core.utils.config import store_json
|
|
35
34
|
from opensportslib.datasets.builder import build_dataset
|
|
@@ -67,20 +66,10 @@ def build_trainer(cfg, model=None, default_args=None, resume_from=None):
|
|
|
67
66
|
|
|
68
67
|
# Handle checkpoint loading
|
|
69
68
|
if resume_from is not None:
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
checkpoint = torch.load(resume_from)
|
|
75
|
-
|
|
76
|
-
# Load model state
|
|
77
|
-
model.load(checkpoint['model_state_dict'])
|
|
78
|
-
logging.info("Model state loaded successfully")
|
|
79
|
-
|
|
80
|
-
# Get current training progress
|
|
81
|
-
start_epoch = checkpoint['epoch'] + 1
|
|
82
|
-
logging.info(f"Resuming from epoch {start_epoch}")
|
|
83
|
-
|
|
69
|
+
optimizer = resume_from["optimizer"]
|
|
70
|
+
scheduler = resume_from["scheduler"]
|
|
71
|
+
scaler = resume_from["scaler"]
|
|
72
|
+
start_epoch = resume_from["epoch"] + 1
|
|
84
73
|
# Check if we've already reached target epochs
|
|
85
74
|
if start_epoch >= cfg.TRAIN.num_epochs:
|
|
86
75
|
logging.error(f"Model already trained for {start_epoch} epochs")
|
|
@@ -89,38 +78,18 @@ def build_trainer(cfg, model=None, default_args=None, resume_from=None):
|
|
|
89
78
|
raise ValueError("Need to increase num_epochs to continue training")
|
|
90
79
|
|
|
91
80
|
logging.info(f"Will continue training from epoch {start_epoch} to {cfg.TRAIN.num_epochs}")
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
if resume_from is not None and 'optimizer_state_dict' in checkpoint:
|
|
98
|
-
try:
|
|
99
|
-
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
|
100
|
-
scaler.load_state_dict(checkpoint['scaler_state_dict'])
|
|
101
|
-
logging.info("Optimizer and scaler states loaded")
|
|
102
|
-
except Exception as e:
|
|
103
|
-
logging.warning(f"Could not load optimizer state: {e}")
|
|
104
|
-
logging.warning("Will start with fresh optimizer state")
|
|
105
|
-
|
|
106
|
-
logging.info("Building scheduler...")
|
|
107
|
-
lr_scheduler = build_scheduler(optimizer, cfg.TRAIN.scheduler, default_args)
|
|
108
|
-
|
|
109
|
-
# Load scheduler state if available
|
|
110
|
-
if resume_from is not None and 'lr_state_dict' in checkpoint:
|
|
111
|
-
try:
|
|
112
|
-
lr_scheduler.load_state_dict(checkpoint['lr_state_dict'])
|
|
113
|
-
logging.info("Scheduler state loaded")
|
|
114
|
-
except Exception as e:
|
|
115
|
-
logging.warning(f"Could not load scheduler state: {e}")
|
|
116
|
-
logging.warning("Will start with fresh scheduler state")
|
|
81
|
+
else:
|
|
82
|
+
logging.info("Building optimizer...")
|
|
83
|
+
optimizer, scaler = build_optimizer(model._get_params(), cfg.TRAIN.optimizer)
|
|
84
|
+
logging.info("Building scheduler...")
|
|
85
|
+
scheduler = build_scheduler(optimizer, cfg.TRAIN.scheduler, default_args)
|
|
117
86
|
|
|
118
87
|
trainer = Trainer_e2e(
|
|
119
88
|
cfg,
|
|
120
89
|
model,
|
|
121
90
|
optimizer,
|
|
122
91
|
scaler,
|
|
123
|
-
|
|
92
|
+
scheduler,
|
|
124
93
|
default_args["work_dir"],
|
|
125
94
|
default_args["dali"],
|
|
126
95
|
default_args["repartitions"],
|
|
@@ -132,8 +101,8 @@ def build_trainer(cfg, model=None, default_args=None, resume_from=None):
|
|
|
132
101
|
|
|
133
102
|
# Load training history if resuming
|
|
134
103
|
if resume_from is not None:
|
|
135
|
-
trainer.best_epoch =
|
|
136
|
-
trainer.best_criterion_valid =
|
|
104
|
+
trainer.best_epoch = resume_from.get('best_epoch', 0)
|
|
105
|
+
trainer.best_criterion_valid = resume_from.get('best_criterion_valid',
|
|
137
106
|
0 if cfg.TRAIN.criterion_valid == "map" else float("inf"))
|
|
138
107
|
logging.info(f"Restored best epoch: {trainer.best_epoch}")
|
|
139
108
|
|
|
@@ -441,7 +410,7 @@ class Trainer_e2e(Trainer):
|
|
|
441
410
|
best_checkpoint_path = os.path.join(
|
|
442
411
|
self.save_dir, f"best_checkpoint.pt"
|
|
443
412
|
)
|
|
444
|
-
self.model._model, _, _, epoch = load_checkpoint(model=self.model._model,
|
|
413
|
+
self.model._model, _, _, _, epoch, _ = load_checkpoint(model=self.model._model,
|
|
445
414
|
path=best_checkpoint_path,
|
|
446
415
|
key_remap_fn=localization_remap)
|
|
447
416
|
logging.info(f"Loaded best model from epoch {self.best_epoch}")
|
{opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/utils/checkpoint.py
RENAMED
|
@@ -76,6 +76,7 @@ def load_checkpoint(
|
|
|
76
76
|
path,
|
|
77
77
|
optimizer=None,
|
|
78
78
|
scheduler=None,
|
|
79
|
+
scaler=None,
|
|
79
80
|
device=None,
|
|
80
81
|
key_remap_fn=None,
|
|
81
82
|
hf_filename="model.pth.tar", # required if loading from HF repo
|
|
@@ -164,7 +165,7 @@ def load_checkpoint(
|
|
|
164
165
|
# --------------------------------------------------
|
|
165
166
|
# Load checkpoint
|
|
166
167
|
# --------------------------------------------------
|
|
167
|
-
checkpoint = torch.load(ckpt_path, map_location=
|
|
168
|
+
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
|
|
168
169
|
|
|
169
170
|
# ---------------- MODEL STATE ----------------
|
|
170
171
|
if isinstance(checkpoint, dict):
|
|
@@ -201,8 +202,24 @@ def load_checkpoint(
|
|
|
201
202
|
for k, v in state_dict.items()
|
|
202
203
|
}
|
|
203
204
|
|
|
204
|
-
state_dict = strip_prefix(state_dict, "module.")
|
|
205
|
+
# state_dict = strip_prefix(state_dict, "module.")
|
|
206
|
+
# state_dict = strip_prefix(state_dict, "model.")
|
|
207
|
+
|
|
208
|
+
# First remove known wrappers (safe ones)
|
|
205
209
|
state_dict = strip_prefix(state_dict, "model.")
|
|
210
|
+
state_dict = strip_prefix(state_dict, "_model.")
|
|
211
|
+
|
|
212
|
+
# Now handle module dynamically
|
|
213
|
+
model_keys = list(model.state_dict().keys())
|
|
214
|
+
ckpt_keys = list(state_dict.keys())
|
|
215
|
+
|
|
216
|
+
model_has_module = model_keys[0].startswith("module.")
|
|
217
|
+
ckpt_has_module = ckpt_keys[0].startswith("module.")
|
|
218
|
+
|
|
219
|
+
if model_has_module and not ckpt_has_module:
|
|
220
|
+
state_dict = {f"module.{k}": v for k, v in state_dict.items()}
|
|
221
|
+
elif not model_has_module and ckpt_has_module:
|
|
222
|
+
state_dict = {k.replace("module.", "", 1): v for k, v in state_dict.items()}
|
|
206
223
|
|
|
207
224
|
# Optional custom remap
|
|
208
225
|
if key_remap_fn:
|
|
@@ -229,15 +246,20 @@ def load_checkpoint(
|
|
|
229
246
|
|
|
230
247
|
# ---------------- SCHEDULER ----------------
|
|
231
248
|
if scheduler and isinstance(checkpoint, dict):
|
|
232
|
-
sch_state = checkpoint.get("scheduler") or checkpoint.get("scheduler_state_dict")
|
|
249
|
+
sch_state = checkpoint.get("scheduler") or checkpoint.get("scheduler_state_dict") or checkpoint.get("lr_scheduler") # some use "lr_scheduler"
|
|
233
250
|
if sch_state:
|
|
234
251
|
scheduler.load_state_dict(sch_state)
|
|
235
252
|
|
|
253
|
+
if scaler and isinstance(checkpoint, dict):
|
|
254
|
+
scaler_state = checkpoint.get("scaler") or checkpoint.get("scaler_state_dict")
|
|
255
|
+
if scaler_state:
|
|
256
|
+
scaler.load_state_dict(scaler_state)
|
|
257
|
+
|
|
236
258
|
print(f"[Checkpoint] Loaded from {ckpt_path} | epoch: {epoch}")
|
|
237
259
|
print(f"Missing keys: {len(missing)}")
|
|
238
260
|
print(f"Unexpected keys: {len(unexpected)}")
|
|
239
261
|
|
|
240
|
-
return model, optimizer, scheduler, epoch
|
|
262
|
+
return model, optimizer, scheduler, scaler, epoch, checkpoint
|
|
241
263
|
|
|
242
264
|
|
|
243
265
|
|
{opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/datasets/localization_dataset.py
RENAMED
|
@@ -1016,7 +1016,10 @@ if DALI_AVAILABLE:
|
|
|
1016
1016
|
for pipe in self.pipes:
|
|
1017
1017
|
pipe.build()
|
|
1018
1018
|
|
|
1019
|
-
|
|
1019
|
+
# Pipeline returns (video, label_idx, frame_num) - label processing
|
|
1020
|
+
# is done post-hoc in get_attr to avoid DALI 2.0 fn.python_function issues
|
|
1021
|
+
internal_output_map = ['data', 'label_idx', 'frame_num']
|
|
1022
|
+
super().__init__(self.pipes, internal_output_map, size=self.nb_videos)
|
|
1020
1023
|
|
|
1021
1024
|
self.device = torch.device(
|
|
1022
1025
|
"cuda:{}".format(self.devices[1 if len(self.devices) > 1 else 0])
|
|
@@ -1052,8 +1055,19 @@ if DALI_AVAILABLE:
|
|
|
1052
1055
|
Returns:
|
|
1053
1056
|
dict :{"frames","contains_event","labels"}.
|
|
1054
1057
|
"""
|
|
1055
|
-
|
|
1058
|
+
batch_label_idx = batch["label_idx"]
|
|
1059
|
+
batch_frame_num = batch["frame_num"]
|
|
1056
1060
|
batch_images = batch["data"]
|
|
1061
|
+
|
|
1062
|
+
batch_size = batch_label_idx.shape[0]
|
|
1063
|
+
batch_labels = torch.zeros(batch_size, self.clip_len, dtype=torch.int64)
|
|
1064
|
+
for b in range(batch_size):
|
|
1065
|
+
video_idx = int(batch_label_idx[b].item())
|
|
1066
|
+
frame_num = int(batch_frame_num[b].item())
|
|
1067
|
+
batch_labels[b] = torch.from_numpy(
|
|
1068
|
+
self._compute_labels(video_idx, frame_num)
|
|
1069
|
+
)
|
|
1070
|
+
|
|
1057
1071
|
sum_labels = torch.sum(
|
|
1058
1072
|
batch_labels, dim=1 if len(batch_labels.shape) == 2 else 0
|
|
1059
1073
|
)
|
|
@@ -1229,26 +1243,22 @@ if DALI_AVAILABLE:
|
|
|
1229
1243
|
std=[255, 255, 255],
|
|
1230
1244
|
mirror=fn.random.coin_flip(),
|
|
1231
1245
|
)
|
|
1232
|
-
label
|
|
1233
|
-
label, frame_num, function=self.edit_labels, device="gpu"
|
|
1234
|
-
)
|
|
1235
|
-
return video, label
|
|
1246
|
+
return video, label, frame_num
|
|
1236
1247
|
|
|
1237
|
-
def
|
|
1238
|
-
"""Construct a
|
|
1248
|
+
def _compute_labels(self, video_idx, frame_num):
|
|
1249
|
+
"""Construct a label array for a clip. Each element is the class index
|
|
1250
|
+
(starting at 1) where an event occurs, 0 otherwise.
|
|
1239
1251
|
|
|
1240
1252
|
Args:
|
|
1241
|
-
|
|
1242
|
-
frame_num :
|
|
1253
|
+
video_idx (int): Index of the video in self._labels.
|
|
1254
|
+
frame_num (int): Raw start frame number from the reader.
|
|
1243
1255
|
|
|
1244
1256
|
Returns:
|
|
1245
|
-
labels (
|
|
1257
|
+
labels (np.ndarray): Label array of shape (clip_len,).
|
|
1246
1258
|
"""
|
|
1247
|
-
|
|
1248
|
-
|
|
1249
|
-
|
|
1250
|
-
base_idx = frame_num.item() // self._stride
|
|
1251
|
-
labels = cupy.zeros(self.clip_len, np.int64)
|
|
1259
|
+
video_meta = self._labels[video_idx]
|
|
1260
|
+
base_idx = frame_num // self._stride
|
|
1261
|
+
labels = np.zeros(self.clip_len, np.int64)
|
|
1252
1262
|
|
|
1253
1263
|
for event in video_meta["events"]:
|
|
1254
1264
|
event_frame = event["frame"]
|
|
@@ -1258,12 +1268,12 @@ if DALI_AVAILABLE:
|
|
|
1258
1268
|
label_idx >= self.dilate_len
|
|
1259
1269
|
and label_idx < self.clip_len + self.dilate_len
|
|
1260
1270
|
):
|
|
1261
|
-
|
|
1271
|
+
label_val = self._class_dict[event["label"]]
|
|
1262
1272
|
for i in range(
|
|
1263
1273
|
max(0, label_idx - self.dilate_len),
|
|
1264
1274
|
min(self.clip_len, label_idx + self.dilate_len + 1),
|
|
1265
1275
|
):
|
|
1266
|
-
labels[i] =
|
|
1276
|
+
labels[i] = label_val
|
|
1267
1277
|
return labels
|
|
1268
1278
|
|
|
1269
1279
|
def print_info(self):
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: opensportslib
|
|
3
|
-
Version: 0.1.2.dev5
|
|
3
|
+
Version: 0.1.2.dev7
|
|
4
4
|
Summary: OpenSportsLib is the professional library, designed for advanced video understanding in sports. It provides state-of-the-art tools for action recognition, spotting, retrieval, and captioning, making it ideal for researchers, analysts, and developers working with sports video data.
|
|
5
5
|
Author: Jeet Vora
|
|
6
6
|
Requires-Python: >=3.12
|
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "opensportslib"
|
|
7
|
-
version = "0.1.2.dev5"
|
|
7
|
+
version = "0.1.2.dev7"
|
|
8
8
|
description = "OpenSportsLib is the professional library, designed for advanced video understanding in sports. It provides state-of-the-art tools for action recognition, spotting, retrieval, and captioning, making it ideal for researchers, analysts, and developers working with sports video data."
|
|
9
9
|
readme = "README.md"
|
|
10
10
|
requires-python = ">=3.12"
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/examples/quickstart/basic_classification.py
RENAMED
|
File without changes
|
{opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/examples/quickstart/basic_localization.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/config/classification.yaml
RENAMED
|
File without changes
|
{opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/config/localization-e2e-ocv.yaml
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
{opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/config/localization.yaml
RENAMED
|
File without changes
|
{opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/config/sngar-frames.yaml
RENAMED
|
File without changes
|
{opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/config/sngar-tracking.yaml
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/optimizer/__init__.py
RENAMED
|
File without changes
|
{opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/optimizer/builder.py
RENAMED
|
File without changes
|
{opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/sampler/weighted_sampler.py
RENAMED
|
File without changes
|
{opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/scheduler/__init__.py
RENAMED
|
File without changes
|
{opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/scheduler/builder.py
RENAMED
|
File without changes
|
{opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/trainer/__init__.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/utils/default_args.py
RENAMED
|
File without changes
|
|
File without changes
|
{opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/utils/load_annotations.py
RENAMED
|
File without changes
|
|
File without changes
|
{opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/utils/video_processing.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/datasets/utils/__init__.py
RENAMED
|
File without changes
|
{opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/datasets/utils/tracking.py
RENAMED
|
File without changes
|
{opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/metrics/classification_metric.py
RENAMED
|
File without changes
|
{opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/metrics/localization_metric.py
RENAMED
|
File without changes
|
|
File without changes
|
{opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/backbones/builder.py
RENAMED
|
File without changes
|
{opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/base/contextaware.py
RENAMED
|
File without changes
|
|
File without changes
|
{opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/base/learnablepooling.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/base/video_mae.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/impl/__init__.py
RENAMED
|
File without changes
|
{opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/impl/asformer.py
RENAMED
|
File without changes
|
{opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/impl/calf.py
RENAMED
|
File without changes
|
{opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/impl/gsm.py
RENAMED
|
File without changes
|
{opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/impl/gtad.py
RENAMED
|
File without changes
|
{opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/impl/tsm.py
RENAMED
|
File without changes
|
{opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/litebase.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/tools/osl_json_to_parquet.py
RENAMED
|
File without changes
|
{opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/tools/parquet_to_osl_json.py
RENAMED
|
File without changes
|
|
File without changes
|
{opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib.egg-info/dependency_links.txt
RENAMED
|
File without changes
|
{opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib.egg-info/entry_points.txt
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/tests/test_subset_train_infer_integration.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|