opensportslib 0.1.2.dev6__tar.gz → 0.1.2.dev7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {opensportslib-0.1.2.dev6/opensportslib.egg-info → opensportslib-0.1.2.dev7}/PKG-INFO +1 -1
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/apis/localization.py +59 -16
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/trainer/classification_trainer.py +2 -1
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/trainer/localization_trainer.py +13 -44
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/utils/checkpoint.py +26 -4
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7/opensportslib.egg-info}/PKG-INFO +1 -1
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/pyproject.toml +1 -1
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/LICENSE +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/LICENSE-COMMERCIAL +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/MANIFEST.in +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/README.md +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/examples/quickstart/basic_classification.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/examples/quickstart/basic_localization.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/__init__.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/apis/__init__.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/apis/base_task_model.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/apis/classification.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/cli.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/config/classification.yaml +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/config/localization-e2e-ocv.yaml +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/config/localization-json_calf_resnetpca512.yaml +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/config/localization-json_netvlad++_resnetpca512.yaml +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/config/localization.yaml +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/config/sngar-frames.yaml +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/config/sngar-tracking.yaml +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/__init__.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/loss/__init__.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/loss/builder.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/loss/calf.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/loss/ce.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/loss/combine.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/loss/nll.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/optimizer/__init__.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/optimizer/builder.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/sampler/weighted_sampler.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/scheduler/__init__.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/scheduler/builder.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/trainer/__init__.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/utils/config.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/utils/data.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/utils/ddp.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/utils/default_args.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/utils/lightning.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/utils/load_annotations.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/utils/seed.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/utils/video_processing.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/utils/wandb.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/datasets/__init__.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/datasets/builder.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/datasets/classification_dataset.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/datasets/localization_dataset.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/datasets/utils/__init__.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/datasets/utils/tracking.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/metrics/classification_metric.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/metrics/localization_metric.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/__init__.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/backbones/builder.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/base/contextaware.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/base/e2e.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/base/learnablepooling.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/base/tracking.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/base/vars.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/base/video.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/base/video_mae.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/builder.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/heads/builder.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/neck/builder.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/common.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/impl/__init__.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/impl/asformer.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/impl/calf.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/impl/gsm.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/impl/gtad.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/impl/tsm.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/litebase.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/modules.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/shift.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/utils.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/setup/setup.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/tools/__init__.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/tools/_common.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/tools/hf_transfer.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/tools/osl_json_to_parquet.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/tools/parquet_to_osl_json.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib.egg-info/SOURCES.txt +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib.egg-info/dependency_links.txt +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib.egg-info/entry_points.txt +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib.egg-info/requires.txt +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib.egg-info/top_level.txt +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/setup.cfg +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/tests/conftest.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/tests/test_config_utils_smoke.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/tests/test_conversion_tools.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/tests/test_hf_transfer_tools.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/tests/test_package_smoke.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/tests/test_public_apis_smoke.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/tests/test_subset_train_infer_integration.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/tests/test_task_model_api_contract.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/tools/convert/osl_json_to_parquet_webdataset.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/tools/convert/parquet_webdataset_to_osl_json.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/tools/download/download_hf_repo.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/tools/download/download_osl_hf.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/tools/download/upload_osl_hf.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/tools/training/classification.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/tools/training/localization.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: opensportslib
|
|
3
|
-
Version: 0.1.2.dev6
|
|
3
|
+
Version: 0.1.2.dev7
|
|
4
4
|
Summary: OpenSportsLib is the professional library, designed for advanced video understanding in sports. It provides state-of-the-art tools for action recognition, spotting, retrieval, and captioning, making it ideal for researchers, analysts, and developers working with sports video data.
|
|
5
5
|
Author: Jeet Vora
|
|
6
6
|
Requires-Python: >=3.12
|
|
@@ -15,6 +15,8 @@ class LocalizationModel(BaseTaskModel):
|
|
|
15
15
|
self.last_loaded_weights = weights
|
|
16
16
|
self.best_checkpoint = weights
|
|
17
17
|
|
|
18
|
+
self.train_flag = False # Flag to indicate whether we're in training mode (affects checkpoint loading behavior)
|
|
19
|
+
|
|
18
20
|
def _resolve_split_path(self, split: str, override: str | None = None) -> str:
|
|
19
21
|
if override is not None:
|
|
20
22
|
return expand(override)
|
|
@@ -69,15 +71,18 @@ class LocalizationModel(BaseTaskModel):
|
|
|
69
71
|
load_checkpoint,
|
|
70
72
|
localization_remap,
|
|
71
73
|
)
|
|
72
|
-
|
|
74
|
+
from opensportslib.core.optimizer.builder import build_optimizer
|
|
75
|
+
from opensportslib.core.scheduler.builder import build_scheduler
|
|
76
|
+
default_args = kwargs.get("default_args", None)
|
|
73
77
|
del kwargs
|
|
74
78
|
if weights is None:
|
|
75
79
|
raise ValueError("`weights` must be provided to load_weights().")
|
|
76
80
|
|
|
77
81
|
model_cfg = getattr(self.config, "MODEL", None)
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
model_cfg
|
|
82
|
+
if not self.train_flag:
|
|
83
|
+
original_multi_gpu = getattr(model_cfg, "multi_gpu", None)
|
|
84
|
+
if model_cfg is not None and original_multi_gpu is not None:
|
|
85
|
+
model_cfg.multi_gpu = False
|
|
81
86
|
|
|
82
87
|
device = select_device(self.config.SYSTEM)
|
|
83
88
|
if self.model is None:
|
|
@@ -90,9 +95,28 @@ class LocalizationModel(BaseTaskModel):
|
|
|
90
95
|
if is_local_path(weights):
|
|
91
96
|
self.config.SYSTEM.work_dir = os.path.dirname(os.path.abspath(weights))
|
|
92
97
|
|
|
93
|
-
|
|
98
|
+
if default_args is not None:
|
|
99
|
+
logging.info("Building optimizer + scaler for checkpoint restore...")
|
|
100
|
+
optimizer, scaler = build_optimizer(
|
|
101
|
+
inner_model.parameters(), # or _get_params() if required
|
|
102
|
+
self.config.TRAIN.optimizer
|
|
103
|
+
)
|
|
104
|
+
|
|
105
|
+
logging.info("Building scheduler for checkpoint restore...")
|
|
106
|
+
scheduler = build_scheduler(
|
|
107
|
+
optimizer,
|
|
108
|
+
self.config.TRAIN.scheduler,
|
|
109
|
+
default_args
|
|
110
|
+
)
|
|
111
|
+
else:
|
|
112
|
+
optimizer = scheduler = scaler = None
|
|
113
|
+
|
|
114
|
+
inner_model, optimizer, scheduler, scaler, epoch, checkpoint = load_checkpoint(
|
|
94
115
|
model=inner_model,
|
|
95
116
|
path=weights,
|
|
117
|
+
optimizer=optimizer,
|
|
118
|
+
scheduler=scheduler,
|
|
119
|
+
scaler=scaler,
|
|
96
120
|
device=device,
|
|
97
121
|
key_remap_fn=localization_remap,
|
|
98
122
|
)
|
|
@@ -107,8 +131,24 @@ class LocalizationModel(BaseTaskModel):
|
|
|
107
131
|
self.last_loaded_weights = weights
|
|
108
132
|
self.best_checkpoint = weights
|
|
109
133
|
|
|
110
|
-
|
|
111
|
-
|
|
134
|
+
best_epoch = checkpoint.get("best_epoch", 0)
|
|
135
|
+
|
|
136
|
+
best_criterion_valid = checkpoint.get(
|
|
137
|
+
"best_criterion_valid",
|
|
138
|
+
0 if self.config.TRAIN.criterion_valid == "map" else float("inf")
|
|
139
|
+
)
|
|
140
|
+
self._resume_state = {
|
|
141
|
+
"optimizer": optimizer,
|
|
142
|
+
"scheduler": scheduler,
|
|
143
|
+
"scaler": scaler,
|
|
144
|
+
"epoch": epoch if epoch is not None else 0,
|
|
145
|
+
"best_epoch": best_epoch,
|
|
146
|
+
"best_criterion_valid": best_criterion_valid,
|
|
147
|
+
}
|
|
148
|
+
|
|
149
|
+
if not self.train_flag:
|
|
150
|
+
if model_cfg is not None and original_multi_gpu is not None:
|
|
151
|
+
model_cfg.multi_gpu = original_multi_gpu
|
|
112
152
|
|
|
113
153
|
def train(
|
|
114
154
|
self,
|
|
@@ -167,13 +207,6 @@ class LocalizationModel(BaseTaskModel):
|
|
|
167
207
|
|
|
168
208
|
start = time.time()
|
|
169
209
|
|
|
170
|
-
if effective_weights is not None:
|
|
171
|
-
if self.model is None or self.last_loaded_weights != effective_weights:
|
|
172
|
-
self.load_weights(weights=effective_weights)
|
|
173
|
-
elif self.model is None:
|
|
174
|
-
device = select_device(self.config.SYSTEM)
|
|
175
|
-
self.model = build_model(self.config, device=device)
|
|
176
|
-
|
|
177
210
|
data_obj_train = build_dataset(self.config, split="train")
|
|
178
211
|
dataset_train = data_obj_train.building_dataset(
|
|
179
212
|
cfg=data_obj_train.cfg,
|
|
@@ -200,11 +233,21 @@ class LocalizationModel(BaseTaskModel):
|
|
|
200
233
|
dali=self.config.dali,
|
|
201
234
|
)
|
|
202
235
|
|
|
236
|
+
default_args = get_default_args_trainer(self.config, len(train_loader))
|
|
237
|
+
|
|
238
|
+
self.train_flag = True # Set flag to indicate training mode for checkpoint loading
|
|
239
|
+
if effective_weights is not None:
|
|
240
|
+
if self.model is None or self.last_loaded_weights != effective_weights:
|
|
241
|
+
self.load_weights(weights=effective_weights, default_args=default_args)
|
|
242
|
+
elif self.model is None:
|
|
243
|
+
device = select_device(self.config.SYSTEM)
|
|
244
|
+
self.model = build_model(self.config, device=device)
|
|
245
|
+
|
|
203
246
|
self.trainer = build_trainer(
|
|
204
247
|
cfg=self.config,
|
|
205
248
|
model=self.model,
|
|
206
|
-
default_args=
|
|
207
|
-
resume_from=
|
|
249
|
+
default_args=default_args,
|
|
250
|
+
resume_from=self._resume_state if hasattr(self, "_resume_state") else None,
|
|
208
251
|
)
|
|
209
252
|
|
|
210
253
|
logging.info("Start training")
|
|
@@ -1167,11 +1167,12 @@ class Trainer_Classification:
|
|
|
1167
1167
|
from opensportslib.models.builder import build_model
|
|
1168
1168
|
if self.model is None:
|
|
1169
1169
|
self.model, _ = build_model(self.config, self.device)
|
|
1170
|
-
self.model, optimizer, scheduler, epoch = load_checkpoint(
|
|
1170
|
+
self.model, optimizer, scheduler, scaler, epoch, checkpoint = load_checkpoint(
|
|
1171
1171
|
self.model, path, optimizer, scheduler, device=self.device
|
|
1172
1172
|
)
|
|
1173
1173
|
self.optimizer = optimizer
|
|
1174
1174
|
self.scheduler = scheduler
|
|
1175
|
+
self.scaler = scaler
|
|
1175
1176
|
self.epoch = epoch
|
|
1176
1177
|
logging.info(f"Model loaded from {path}, epoch: {epoch}")
|
|
1177
1178
|
return self.model, self.optimizer, self.scheduler, self.epoch
|
|
@@ -29,7 +29,6 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
|
29
29
|
"""
|
|
30
30
|
from opensportslib.metrics.localization_metric import *
|
|
31
31
|
from opensportslib.core.optimizer.builder import build_optimizer
|
|
32
|
-
from opensportslib.core.optimizer.builder import build_optimizer
|
|
33
32
|
from opensportslib.core.scheduler.builder import build_scheduler
|
|
34
33
|
from opensportslib.core.utils.config import store_json
|
|
35
34
|
from opensportslib.datasets.builder import build_dataset
|
|
@@ -67,20 +66,10 @@ def build_trainer(cfg, model=None, default_args=None, resume_from=None):
|
|
|
67
66
|
|
|
68
67
|
# Handle checkpoint loading
|
|
69
68
|
if resume_from is not None:
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
checkpoint = torch.load(resume_from)
|
|
75
|
-
|
|
76
|
-
# Load model state
|
|
77
|
-
model.load(checkpoint['model_state_dict'])
|
|
78
|
-
logging.info("Model state loaded successfully")
|
|
79
|
-
|
|
80
|
-
# Get current training progress
|
|
81
|
-
start_epoch = checkpoint['epoch'] + 1
|
|
82
|
-
logging.info(f"Resuming from epoch {start_epoch}")
|
|
83
|
-
|
|
69
|
+
optimizer = resume_from["optimizer"]
|
|
70
|
+
scheduler = resume_from["scheduler"]
|
|
71
|
+
scaler = resume_from["scaler"]
|
|
72
|
+
start_epoch = resume_from["epoch"] + 1
|
|
84
73
|
# Check if we've already reached target epochs
|
|
85
74
|
if start_epoch >= cfg.TRAIN.num_epochs:
|
|
86
75
|
logging.error(f"Model already trained for {start_epoch} epochs")
|
|
@@ -89,38 +78,18 @@ def build_trainer(cfg, model=None, default_args=None, resume_from=None):
|
|
|
89
78
|
raise ValueError("Need to increase num_epochs to continue training")
|
|
90
79
|
|
|
91
80
|
logging.info(f"Will continue training from epoch {start_epoch} to {cfg.TRAIN.num_epochs}")
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
if resume_from is not None and 'optimizer_state_dict' in checkpoint:
|
|
98
|
-
try:
|
|
99
|
-
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
|
100
|
-
scaler.load_state_dict(checkpoint['scaler_state_dict'])
|
|
101
|
-
logging.info("Optimizer and scaler states loaded")
|
|
102
|
-
except Exception as e:
|
|
103
|
-
logging.warning(f"Could not load optimizer state: {e}")
|
|
104
|
-
logging.warning("Will start with fresh optimizer state")
|
|
105
|
-
|
|
106
|
-
logging.info("Building scheduler...")
|
|
107
|
-
lr_scheduler = build_scheduler(optimizer, cfg.TRAIN.scheduler, default_args)
|
|
108
|
-
|
|
109
|
-
# Load scheduler state if available
|
|
110
|
-
if resume_from is not None and 'lr_state_dict' in checkpoint:
|
|
111
|
-
try:
|
|
112
|
-
lr_scheduler.load_state_dict(checkpoint['lr_state_dict'])
|
|
113
|
-
logging.info("Scheduler state loaded")
|
|
114
|
-
except Exception as e:
|
|
115
|
-
logging.warning(f"Could not load scheduler state: {e}")
|
|
116
|
-
logging.warning("Will start with fresh scheduler state")
|
|
81
|
+
else:
|
|
82
|
+
logging.info("Building optimizer...")
|
|
83
|
+
optimizer, scaler = build_optimizer(model._get_params(), cfg.TRAIN.optimizer)
|
|
84
|
+
logging.info("Building scheduler...")
|
|
85
|
+
scheduler = build_scheduler(optimizer, cfg.TRAIN.scheduler, default_args)
|
|
117
86
|
|
|
118
87
|
trainer = Trainer_e2e(
|
|
119
88
|
cfg,
|
|
120
89
|
model,
|
|
121
90
|
optimizer,
|
|
122
91
|
scaler,
|
|
123
|
-
|
|
92
|
+
scheduler,
|
|
124
93
|
default_args["work_dir"],
|
|
125
94
|
default_args["dali"],
|
|
126
95
|
default_args["repartitions"],
|
|
@@ -132,8 +101,8 @@ def build_trainer(cfg, model=None, default_args=None, resume_from=None):
|
|
|
132
101
|
|
|
133
102
|
# Load training history if resuming
|
|
134
103
|
if resume_from is not None:
|
|
135
|
-
trainer.best_epoch =
|
|
136
|
-
trainer.best_criterion_valid =
|
|
104
|
+
trainer.best_epoch = resume_from.get('best_epoch', 0)
|
|
105
|
+
trainer.best_criterion_valid = resume_from.get('best_criterion_valid',
|
|
137
106
|
0 if cfg.TRAIN.criterion_valid == "map" else float("inf"))
|
|
138
107
|
logging.info(f"Restored best epoch: {trainer.best_epoch}")
|
|
139
108
|
|
|
@@ -441,7 +410,7 @@ class Trainer_e2e(Trainer):
|
|
|
441
410
|
best_checkpoint_path = os.path.join(
|
|
442
411
|
self.save_dir, f"best_checkpoint.pt"
|
|
443
412
|
)
|
|
444
|
-
self.model._model, _, _, epoch = load_checkpoint(model=self.model._model,
|
|
413
|
+
self.model._model, _, _, _, epoch, _ = load_checkpoint(model=self.model._model,
|
|
445
414
|
path=best_checkpoint_path,
|
|
446
415
|
key_remap_fn=localization_remap)
|
|
447
416
|
logging.info(f"Loaded best model from epoch {self.best_epoch}")
|
{opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/utils/checkpoint.py
RENAMED
|
@@ -76,6 +76,7 @@ def load_checkpoint(
|
|
|
76
76
|
path,
|
|
77
77
|
optimizer=None,
|
|
78
78
|
scheduler=None,
|
|
79
|
+
scaler=None,
|
|
79
80
|
device=None,
|
|
80
81
|
key_remap_fn=None,
|
|
81
82
|
hf_filename="model.pth.tar", # required if loading from HF repo
|
|
@@ -164,7 +165,7 @@ def load_checkpoint(
|
|
|
164
165
|
# --------------------------------------------------
|
|
165
166
|
# Load checkpoint
|
|
166
167
|
# --------------------------------------------------
|
|
167
|
-
checkpoint = torch.load(ckpt_path, map_location=
|
|
168
|
+
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
|
|
168
169
|
|
|
169
170
|
# ---------------- MODEL STATE ----------------
|
|
170
171
|
if isinstance(checkpoint, dict):
|
|
@@ -201,8 +202,24 @@ def load_checkpoint(
|
|
|
201
202
|
for k, v in state_dict.items()
|
|
202
203
|
}
|
|
203
204
|
|
|
204
|
-
state_dict = strip_prefix(state_dict, "module.")
|
|
205
|
+
# state_dict = strip_prefix(state_dict, "module.")
|
|
206
|
+
# state_dict = strip_prefix(state_dict, "model.")
|
|
207
|
+
|
|
208
|
+
# First remove known wrappers (safe ones)
|
|
205
209
|
state_dict = strip_prefix(state_dict, "model.")
|
|
210
|
+
state_dict = strip_prefix(state_dict, "_model.")
|
|
211
|
+
|
|
212
|
+
# Now handle module dynamically
|
|
213
|
+
model_keys = list(model.state_dict().keys())
|
|
214
|
+
ckpt_keys = list(state_dict.keys())
|
|
215
|
+
|
|
216
|
+
model_has_module = model_keys[0].startswith("module.")
|
|
217
|
+
ckpt_has_module = ckpt_keys[0].startswith("module.")
|
|
218
|
+
|
|
219
|
+
if model_has_module and not ckpt_has_module:
|
|
220
|
+
state_dict = {f"module.{k}": v for k, v in state_dict.items()}
|
|
221
|
+
elif not model_has_module and ckpt_has_module:
|
|
222
|
+
state_dict = {k.replace("module.", "", 1): v for k, v in state_dict.items()}
|
|
206
223
|
|
|
207
224
|
# Optional custom remap
|
|
208
225
|
if key_remap_fn:
|
|
@@ -229,15 +246,20 @@ def load_checkpoint(
|
|
|
229
246
|
|
|
230
247
|
# ---------------- SCHEDULER ----------------
|
|
231
248
|
if scheduler and isinstance(checkpoint, dict):
|
|
232
|
-
sch_state = checkpoint.get("scheduler") or checkpoint.get("scheduler_state_dict")
|
|
249
|
+
sch_state = checkpoint.get("scheduler") or checkpoint.get("scheduler_state_dict") or checkpoint.get("lr_scheduler") # some use "lr_scheduler"
|
|
233
250
|
if sch_state:
|
|
234
251
|
scheduler.load_state_dict(sch_state)
|
|
235
252
|
|
|
253
|
+
if scaler and isinstance(checkpoint, dict):
|
|
254
|
+
scaler_state = checkpoint.get("scaler") or checkpoint.get("scaler_state_dict")
|
|
255
|
+
if scaler_state:
|
|
256
|
+
scaler.load_state_dict(scaler_state)
|
|
257
|
+
|
|
236
258
|
print(f"[Checkpoint] Loaded from {ckpt_path} | epoch: {epoch}")
|
|
237
259
|
print(f"Missing keys: {len(missing)}")
|
|
238
260
|
print(f"Unexpected keys: {len(unexpected)}")
|
|
239
261
|
|
|
240
|
-
return model, optimizer, scheduler, epoch
|
|
262
|
+
return model, optimizer, scheduler, scaler, epoch, checkpoint
|
|
241
263
|
|
|
242
264
|
|
|
243
265
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: opensportslib
|
|
3
|
-
Version: 0.1.2.dev6
|
|
3
|
+
Version: 0.1.2.dev7
|
|
4
4
|
Summary: OpenSportsLib is the professional library, designed for advanced video understanding in sports. It provides state-of-the-art tools for action recognition, spotting, retrieval, and captioning, making it ideal for researchers, analysts, and developers working with sports video data.
|
|
5
5
|
Author: Jeet Vora
|
|
6
6
|
Requires-Python: >=3.12
|
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "opensportslib"
|
|
7
|
-
version = "0.1.2.dev6"
|
|
7
|
+
version = "0.1.2.dev7"
|
|
8
8
|
description = "OpenSportsLib is the professional library, designed for advanced video understanding in sports. It provides state-of-the-art tools for action recognition, spotting, retrieval, and captioning, making it ideal for researchers, analysts, and developers working with sports video data."
|
|
9
9
|
readme = "README.md"
|
|
10
10
|
requires-python = ">=3.12"
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/examples/quickstart/basic_classification.py
RENAMED
|
File without changes
|
{opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/examples/quickstart/basic_localization.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/config/classification.yaml
RENAMED
|
File without changes
|
{opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/config/localization-e2e-ocv.yaml
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
{opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/config/localization.yaml
RENAMED
|
File without changes
|
{opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/config/sngar-frames.yaml
RENAMED
|
File without changes
|
{opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/config/sngar-tracking.yaml
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/optimizer/__init__.py
RENAMED
|
File without changes
|
{opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/optimizer/builder.py
RENAMED
|
File without changes
|
{opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/sampler/weighted_sampler.py
RENAMED
|
File without changes
|
{opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/scheduler/__init__.py
RENAMED
|
File without changes
|
{opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/scheduler/builder.py
RENAMED
|
File without changes
|
{opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/trainer/__init__.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/utils/default_args.py
RENAMED
|
File without changes
|
|
File without changes
|
{opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/utils/load_annotations.py
RENAMED
|
File without changes
|
|
File without changes
|
{opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/utils/video_processing.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/datasets/localization_dataset.py
RENAMED
|
File without changes
|
{opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/datasets/utils/__init__.py
RENAMED
|
File without changes
|
{opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/datasets/utils/tracking.py
RENAMED
|
File without changes
|
{opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/metrics/classification_metric.py
RENAMED
|
File without changes
|
{opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/metrics/localization_metric.py
RENAMED
|
File without changes
|
|
File without changes
|
{opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/backbones/builder.py
RENAMED
|
File without changes
|
{opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/base/contextaware.py
RENAMED
|
File without changes
|
|
File without changes
|
{opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/base/learnablepooling.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/base/video_mae.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/impl/__init__.py
RENAMED
|
File without changes
|
{opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/impl/asformer.py
RENAMED
|
File without changes
|
{opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/impl/calf.py
RENAMED
|
File without changes
|
{opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/impl/gsm.py
RENAMED
|
File without changes
|
{opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/impl/gtad.py
RENAMED
|
File without changes
|
{opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/impl/tsm.py
RENAMED
|
File without changes
|
{opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/litebase.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/tools/osl_json_to_parquet.py
RENAMED
|
File without changes
|
{opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/tools/parquet_to_osl_json.py
RENAMED
|
File without changes
|
|
File without changes
|
{opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib.egg-info/dependency_links.txt
RENAMED
|
File without changes
|
{opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib.egg-info/entry_points.txt
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/tests/test_subset_train_infer_integration.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|