opensportslib 0.1.2.dev6__tar.gz → 0.1.2.dev8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {opensportslib-0.1.2.dev6/opensportslib.egg-info → opensportslib-0.1.2.dev8}/PKG-INFO +1 -1
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/apis/base_task_model.py +8 -1
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/apis/classification.py +4 -5
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/apis/localization.py +68 -26
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/core/trainer/classification_trainer.py +9 -1
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/core/trainer/localization_trainer.py +27 -44
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/core/utils/checkpoint.py +26 -4
- opensportslib-0.1.2.dev8/opensportslib/core/utils/config.py +358 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8/opensportslib.egg-info}/PKG-INFO +1 -1
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/pyproject.toml +1 -1
- opensportslib-0.1.2.dev6/opensportslib/core/utils/config.py +0 -214
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/LICENSE +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/LICENSE-COMMERCIAL +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/MANIFEST.in +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/README.md +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/examples/quickstart/basic_classification.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/examples/quickstart/basic_localization.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/__init__.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/apis/__init__.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/cli.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/config/classification.yaml +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/config/localization-e2e-ocv.yaml +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/config/localization-json_calf_resnetpca512.yaml +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/config/localization-json_netvlad++_resnetpca512.yaml +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/config/localization.yaml +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/config/sngar-frames.yaml +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/config/sngar-tracking.yaml +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/core/__init__.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/core/loss/__init__.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/core/loss/builder.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/core/loss/calf.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/core/loss/ce.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/core/loss/combine.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/core/loss/nll.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/core/optimizer/__init__.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/core/optimizer/builder.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/core/sampler/weighted_sampler.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/core/scheduler/__init__.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/core/scheduler/builder.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/core/trainer/__init__.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/core/utils/data.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/core/utils/ddp.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/core/utils/default_args.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/core/utils/lightning.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/core/utils/load_annotations.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/core/utils/seed.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/core/utils/video_processing.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/core/utils/wandb.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/datasets/__init__.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/datasets/builder.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/datasets/classification_dataset.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/datasets/localization_dataset.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/datasets/utils/__init__.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/datasets/utils/tracking.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/metrics/classification_metric.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/metrics/localization_metric.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/models/__init__.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/models/backbones/builder.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/models/base/contextaware.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/models/base/e2e.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/models/base/learnablepooling.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/models/base/tracking.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/models/base/vars.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/models/base/video.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/models/base/video_mae.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/models/builder.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/models/heads/builder.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/models/neck/builder.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/models/utils/common.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/models/utils/impl/__init__.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/models/utils/impl/asformer.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/models/utils/impl/calf.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/models/utils/impl/gsm.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/models/utils/impl/gtad.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/models/utils/impl/tsm.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/models/utils/litebase.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/models/utils/modules.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/models/utils/shift.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/models/utils/utils.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/setup/setup.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/tools/__init__.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/tools/_common.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/tools/hf_transfer.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/tools/osl_json_to_parquet.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/tools/parquet_to_osl_json.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib.egg-info/SOURCES.txt +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib.egg-info/dependency_links.txt +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib.egg-info/entry_points.txt +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib.egg-info/requires.txt +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib.egg-info/top_level.txt +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/setup.cfg +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/tests/conftest.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/tests/test_config_utils_smoke.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/tests/test_conversion_tools.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/tests/test_hf_transfer_tools.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/tests/test_package_smoke.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/tests/test_public_apis_smoke.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/tests/test_subset_train_infer_integration.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/tests/test_task_model_api_contract.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/tools/convert/osl_json_to_parquet_webdataset.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/tools/convert/parquet_webdataset_to_osl_json.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/tools/download/download_hf_repo.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/tools/download/download_osl_hf.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/tools/download/upload_osl_hf.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/tools/training/classification.py +0 -0
- {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/tools/training/localization.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: opensportslib
|
|
3
|
-
Version: 0.1.2.dev6
|
|
3
|
+
Version: 0.1.2.dev8
|
|
4
4
|
Summary: OpenSportsLib is the professional library, designed for advanced video understanding in sports. It provides state-of-the-art tools for action recognition, spotting, retrieval, and captioning, making it ideal for researchers, analysts, and developers working with sports video data.
|
|
5
5
|
Author: Jeet Vora
|
|
6
6
|
Requires-Python: >=3.12
|
|
@@ -9,7 +9,7 @@ import uuid
|
|
|
9
9
|
from abc import ABC, abstractmethod
|
|
10
10
|
from typing import Any
|
|
11
11
|
|
|
12
|
-
from opensportslib.core.utils.config import expand, load_config_omega
|
|
12
|
+
from opensportslib.core.utils.config import expand, load_config_omega, fetch_and_merge_config_from_HF
|
|
13
13
|
|
|
14
14
|
|
|
15
15
|
class BaseTaskModel(ABC):
|
|
@@ -24,6 +24,13 @@ class BaseTaskModel(ABC):
|
|
|
24
24
|
self.config_path = expand(config)
|
|
25
25
|
self.config = load_config_omega(self.config_path)
|
|
26
26
|
|
|
27
|
+
if weights is not None:
|
|
28
|
+
self.config = fetch_and_merge_config_from_HF(self.config, weights, merge_policy="compatibility")
|
|
29
|
+
self.last_loaded_weights = weights
|
|
30
|
+
self.best_checkpoint = weights
|
|
31
|
+
|
|
32
|
+
self.train_flag = False # Flag to indicate whether we're in training mode (affects checkpoint loading behavior)
|
|
33
|
+
|
|
27
34
|
data_cfg = getattr(self.config, "DATA", None)
|
|
28
35
|
if data_cfg is not None and hasattr(data_cfg, "data_dir"):
|
|
29
36
|
data_cfg.data_dir = expand(data_cfg.data_dir)
|
|
@@ -8,7 +8,6 @@ import os
|
|
|
8
8
|
from opensportslib.apis.base_task_model import BaseTaskModel
|
|
9
9
|
from opensportslib.core.utils.config import expand
|
|
10
10
|
|
|
11
|
-
|
|
12
11
|
class ClassificationModel(BaseTaskModel):
|
|
13
12
|
"""Top-level task wrapper for classification."""
|
|
14
13
|
|
|
@@ -172,8 +171,8 @@ class ClassificationModel(BaseTaskModel):
|
|
|
172
171
|
|
|
173
172
|
train_set = self._resolve_split_path("train", train_set)
|
|
174
173
|
valid_set = self._resolve_split_path("valid", valid_set)
|
|
175
|
-
|
|
176
|
-
self.config = resolve_config_omega(self.config)
|
|
174
|
+
|
|
175
|
+
self.config = resolve_config_omega(self.config, weights=weights)
|
|
177
176
|
logging.info("Configuration:")
|
|
178
177
|
logging.info(self.config)
|
|
179
178
|
|
|
@@ -241,7 +240,7 @@ class ClassificationModel(BaseTaskModel):
|
|
|
241
240
|
|
|
242
241
|
test_set = self._resolve_split_path("test", test_set)
|
|
243
242
|
|
|
244
|
-
self.config = resolve_config_omega(self.config)
|
|
243
|
+
self.config = resolve_config_omega(self.config, weights=weights)
|
|
245
244
|
logging.info("Configuration:")
|
|
246
245
|
logging.info(self.config)
|
|
247
246
|
|
|
@@ -304,7 +303,7 @@ class ClassificationModel(BaseTaskModel):
|
|
|
304
303
|
|
|
305
304
|
test_set = self._resolve_split_path("test", test_set)
|
|
306
305
|
|
|
307
|
-
self.config = resolve_config_omega(self.config)
|
|
306
|
+
self.config = resolve_config_omega(self.config, weights=weights)
|
|
308
307
|
logging.info("Configuration:")
|
|
309
308
|
logging.info(self.config)
|
|
310
309
|
if predictions is None:
|
|
@@ -9,11 +9,13 @@ from opensportslib.core.utils.config import expand
|
|
|
9
9
|
class LocalizationModel(BaseTaskModel):
|
|
10
10
|
"""Top-level task wrapper for localization / spotting."""
|
|
11
11
|
|
|
12
|
-
def __init__(self, config=None, weights=None):
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
12
|
+
# def __init__(self, config=None, weights=None):
|
|
13
|
+
# super().__init__(config=config, weights=None)
|
|
14
|
+
# if weights is not None:
|
|
15
|
+
# self.last_loaded_weights = weights
|
|
16
|
+
# self.best_checkpoint = weights
|
|
17
|
+
|
|
18
|
+
# self.train_flag = False # Flag to indicate whether we're in training mode (affects checkpoint loading behavior)
|
|
17
19
|
|
|
18
20
|
def _resolve_split_path(self, split: str, override: str | None = None) -> str:
|
|
19
21
|
if override is not None:
|
|
@@ -69,15 +71,18 @@ class LocalizationModel(BaseTaskModel):
|
|
|
69
71
|
load_checkpoint,
|
|
70
72
|
localization_remap,
|
|
71
73
|
)
|
|
72
|
-
|
|
74
|
+
from opensportslib.core.optimizer.builder import build_optimizer
|
|
75
|
+
from opensportslib.core.scheduler.builder import build_scheduler
|
|
76
|
+
default_args = kwargs.get("default_args", None)
|
|
73
77
|
del kwargs
|
|
74
78
|
if weights is None:
|
|
75
79
|
raise ValueError("`weights` must be provided to load_weights().")
|
|
76
80
|
|
|
77
81
|
model_cfg = getattr(self.config, "MODEL", None)
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
model_cfg
|
|
82
|
+
if not self.train_flag:
|
|
83
|
+
original_multi_gpu = getattr(model_cfg, "multi_gpu", None)
|
|
84
|
+
if model_cfg is not None and original_multi_gpu is not None:
|
|
85
|
+
model_cfg.multi_gpu = False
|
|
81
86
|
|
|
82
87
|
device = select_device(self.config.SYSTEM)
|
|
83
88
|
if self.model is None:
|
|
@@ -90,9 +95,28 @@ class LocalizationModel(BaseTaskModel):
|
|
|
90
95
|
if is_local_path(weights):
|
|
91
96
|
self.config.SYSTEM.work_dir = os.path.dirname(os.path.abspath(weights))
|
|
92
97
|
|
|
93
|
-
|
|
98
|
+
if default_args is not None:
|
|
99
|
+
logging.info("Building optimizer + scaler for checkpoint restore...")
|
|
100
|
+
optimizer, scaler = build_optimizer(
|
|
101
|
+
inner_model.parameters(), # or _get_params() if required
|
|
102
|
+
self.config.TRAIN.optimizer
|
|
103
|
+
)
|
|
104
|
+
|
|
105
|
+
logging.info("Building scheduler for checkpoint restore...")
|
|
106
|
+
scheduler = build_scheduler(
|
|
107
|
+
optimizer,
|
|
108
|
+
self.config.TRAIN.scheduler,
|
|
109
|
+
default_args
|
|
110
|
+
)
|
|
111
|
+
else:
|
|
112
|
+
optimizer = scheduler = scaler = None
|
|
113
|
+
|
|
114
|
+
inner_model, optimizer, scheduler, scaler, epoch, checkpoint = load_checkpoint(
|
|
94
115
|
model=inner_model,
|
|
95
116
|
path=weights,
|
|
117
|
+
optimizer=optimizer,
|
|
118
|
+
scheduler=scheduler,
|
|
119
|
+
scaler=scaler,
|
|
96
120
|
device=device,
|
|
97
121
|
key_remap_fn=localization_remap,
|
|
98
122
|
)
|
|
@@ -107,8 +131,24 @@ class LocalizationModel(BaseTaskModel):
|
|
|
107
131
|
self.last_loaded_weights = weights
|
|
108
132
|
self.best_checkpoint = weights
|
|
109
133
|
|
|
110
|
-
|
|
111
|
-
|
|
134
|
+
best_epoch = checkpoint.get("best_epoch", 0)
|
|
135
|
+
|
|
136
|
+
best_criterion_valid = checkpoint.get(
|
|
137
|
+
"best_criterion_valid",
|
|
138
|
+
0 if self.config.TRAIN.criterion_valid == "map" else float("inf")
|
|
139
|
+
)
|
|
140
|
+
self._resume_state = {
|
|
141
|
+
"optimizer": optimizer,
|
|
142
|
+
"scheduler": scheduler,
|
|
143
|
+
"scaler": scaler,
|
|
144
|
+
"epoch": epoch if epoch is not None else 0,
|
|
145
|
+
"best_epoch": best_epoch,
|
|
146
|
+
"best_criterion_valid": best_criterion_valid,
|
|
147
|
+
}
|
|
148
|
+
|
|
149
|
+
if not self.train_flag:
|
|
150
|
+
if model_cfg is not None and original_multi_gpu is not None:
|
|
151
|
+
model_cfg.multi_gpu = original_multi_gpu
|
|
112
152
|
|
|
113
153
|
def train(
|
|
114
154
|
self,
|
|
@@ -138,8 +178,8 @@ class LocalizationModel(BaseTaskModel):
|
|
|
138
178
|
valid_set = self._resolve_split_path("valid", valid_set)
|
|
139
179
|
self._set_split_path("train", train_set)
|
|
140
180
|
self._set_split_path("valid", valid_set)
|
|
141
|
-
|
|
142
|
-
self.config = resolve_config_omega(self.config)
|
|
181
|
+
|
|
182
|
+
self.config = resolve_config_omega(self.config, weights=weights)
|
|
143
183
|
check_config(self.config, split="train")
|
|
144
184
|
init_wandb(
|
|
145
185
|
self.config_path,
|
|
@@ -167,13 +207,6 @@ class LocalizationModel(BaseTaskModel):
|
|
|
167
207
|
|
|
168
208
|
start = time.time()
|
|
169
209
|
|
|
170
|
-
if effective_weights is not None:
|
|
171
|
-
if self.model is None or self.last_loaded_weights != effective_weights:
|
|
172
|
-
self.load_weights(weights=effective_weights)
|
|
173
|
-
elif self.model is None:
|
|
174
|
-
device = select_device(self.config.SYSTEM)
|
|
175
|
-
self.model = build_model(self.config, device=device)
|
|
176
|
-
|
|
177
210
|
data_obj_train = build_dataset(self.config, split="train")
|
|
178
211
|
dataset_train = data_obj_train.building_dataset(
|
|
179
212
|
cfg=data_obj_train.cfg,
|
|
@@ -200,11 +233,21 @@ class LocalizationModel(BaseTaskModel):
|
|
|
200
233
|
dali=self.config.dali,
|
|
201
234
|
)
|
|
202
235
|
|
|
236
|
+
default_args = get_default_args_trainer(self.config, len(train_loader))
|
|
237
|
+
|
|
238
|
+
self.train_flag = True # Set flag to indicate training mode for checkpoint loading
|
|
239
|
+
if effective_weights is not None:
|
|
240
|
+
if self.model is None or self.last_loaded_weights != effective_weights:
|
|
241
|
+
self.load_weights(weights=effective_weights, default_args=default_args)
|
|
242
|
+
elif self.model is None:
|
|
243
|
+
device = select_device(self.config.SYSTEM)
|
|
244
|
+
self.model = build_model(self.config, device=device)
|
|
245
|
+
|
|
203
246
|
self.trainer = build_trainer(
|
|
204
247
|
cfg=self.config,
|
|
205
248
|
model=self.model,
|
|
206
|
-
default_args=
|
|
207
|
-
resume_from=
|
|
249
|
+
default_args=default_args,
|
|
250
|
+
resume_from=self._resume_state if hasattr(self, "_resume_state") else None,
|
|
208
251
|
)
|
|
209
252
|
|
|
210
253
|
logging.info("Start training")
|
|
@@ -249,7 +292,7 @@ class LocalizationModel(BaseTaskModel):
|
|
|
249
292
|
self._set_split_path("test", test_set)
|
|
250
293
|
|
|
251
294
|
self.config.MODEL.multi_gpu = False
|
|
252
|
-
self.config = resolve_config_omega(self.config)
|
|
295
|
+
self.config = resolve_config_omega(self.config, weights=weights)
|
|
253
296
|
check_config(self.config, split="test")
|
|
254
297
|
self.config.infer_split = whether_infer_split(self.config.DATA.test)
|
|
255
298
|
|
|
@@ -318,9 +361,8 @@ class LocalizationModel(BaseTaskModel):
|
|
|
318
361
|
|
|
319
362
|
test_set = self._resolve_split_path("test", test_set)
|
|
320
363
|
self._set_split_path("test", test_set)
|
|
321
|
-
|
|
322
364
|
self.config.MODEL.multi_gpu = False
|
|
323
|
-
self.config = resolve_config_omega(self.config)
|
|
365
|
+
self.config = resolve_config_omega(self.config, weights=weights)
|
|
324
366
|
check_config(self.config, split="test")
|
|
325
367
|
self.config.infer_split = whether_infer_split(self.config.DATA.test)
|
|
326
368
|
|
|
@@ -550,6 +550,13 @@ class BaseTrainerClassification:
|
|
|
550
550
|
path_aux = os.path.join(epoch_dir, name)
|
|
551
551
|
torch.save(state, path_aux)
|
|
552
552
|
logging.info(f"Saved checkpoint: {path_aux}")
|
|
553
|
+
|
|
554
|
+
if self.config is not None:
|
|
555
|
+
from opensportslib.core.utils.config import save_config
|
|
556
|
+
config_path = os.path.join(epoch_dir, "config.yaml")
|
|
557
|
+
save_config(self.config, config_path)
|
|
558
|
+
logging.info(f"Saved config: {config_path}")
|
|
559
|
+
|
|
553
560
|
return path_aux
|
|
554
561
|
|
|
555
562
|
|
|
@@ -1167,11 +1174,12 @@ class Trainer_Classification:
|
|
|
1167
1174
|
from opensportslib.models.builder import build_model
|
|
1168
1175
|
if self.model is None:
|
|
1169
1176
|
self.model, _ = build_model(self.config, self.device)
|
|
1170
|
-
self.model, optimizer, scheduler, epoch = load_checkpoint(
|
|
1177
|
+
self.model, optimizer, scheduler, scaler, epoch, checkpoint = load_checkpoint(
|
|
1171
1178
|
self.model, path, optimizer, scheduler, device=self.device
|
|
1172
1179
|
)
|
|
1173
1180
|
self.optimizer = optimizer
|
|
1174
1181
|
self.scheduler = scheduler
|
|
1182
|
+
self.scaler = scaler
|
|
1175
1183
|
self.epoch = epoch
|
|
1176
1184
|
logging.info(f"Model loaded from {path}, epoch: {epoch}")
|
|
1177
1185
|
return self.model, self.optimizer, self.scheduler, self.epoch
|
|
@@ -29,7 +29,6 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
|
29
29
|
"""
|
|
30
30
|
from opensportslib.metrics.localization_metric import *
|
|
31
31
|
from opensportslib.core.optimizer.builder import build_optimizer
|
|
32
|
-
from opensportslib.core.optimizer.builder import build_optimizer
|
|
33
32
|
from opensportslib.core.scheduler.builder import build_scheduler
|
|
34
33
|
from opensportslib.core.utils.config import store_json
|
|
35
34
|
from opensportslib.datasets.builder import build_dataset
|
|
@@ -67,20 +66,10 @@ def build_trainer(cfg, model=None, default_args=None, resume_from=None):
|
|
|
67
66
|
|
|
68
67
|
# Handle checkpoint loading
|
|
69
68
|
if resume_from is not None:
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
checkpoint = torch.load(resume_from)
|
|
75
|
-
|
|
76
|
-
# Load model state
|
|
77
|
-
model.load(checkpoint['model_state_dict'])
|
|
78
|
-
logging.info("Model state loaded successfully")
|
|
79
|
-
|
|
80
|
-
# Get current training progress
|
|
81
|
-
start_epoch = checkpoint['epoch'] + 1
|
|
82
|
-
logging.info(f"Resuming from epoch {start_epoch}")
|
|
83
|
-
|
|
69
|
+
optimizer = resume_from["optimizer"]
|
|
70
|
+
scheduler = resume_from["scheduler"]
|
|
71
|
+
scaler = resume_from["scaler"]
|
|
72
|
+
start_epoch = resume_from["epoch"] + 1
|
|
84
73
|
# Check if we've already reached target epochs
|
|
85
74
|
if start_epoch >= cfg.TRAIN.num_epochs:
|
|
86
75
|
logging.error(f"Model already trained for {start_epoch} epochs")
|
|
@@ -89,38 +78,18 @@ def build_trainer(cfg, model=None, default_args=None, resume_from=None):
|
|
|
89
78
|
raise ValueError("Need to increase num_epochs to continue training")
|
|
90
79
|
|
|
91
80
|
logging.info(f"Will continue training from epoch {start_epoch} to {cfg.TRAIN.num_epochs}")
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
if resume_from is not None and 'optimizer_state_dict' in checkpoint:
|
|
98
|
-
try:
|
|
99
|
-
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
|
100
|
-
scaler.load_state_dict(checkpoint['scaler_state_dict'])
|
|
101
|
-
logging.info("Optimizer and scaler states loaded")
|
|
102
|
-
except Exception as e:
|
|
103
|
-
logging.warning(f"Could not load optimizer state: {e}")
|
|
104
|
-
logging.warning("Will start with fresh optimizer state")
|
|
105
|
-
|
|
106
|
-
logging.info("Building scheduler...")
|
|
107
|
-
lr_scheduler = build_scheduler(optimizer, cfg.TRAIN.scheduler, default_args)
|
|
108
|
-
|
|
109
|
-
# Load scheduler state if available
|
|
110
|
-
if resume_from is not None and 'lr_state_dict' in checkpoint:
|
|
111
|
-
try:
|
|
112
|
-
lr_scheduler.load_state_dict(checkpoint['lr_state_dict'])
|
|
113
|
-
logging.info("Scheduler state loaded")
|
|
114
|
-
except Exception as e:
|
|
115
|
-
logging.warning(f"Could not load scheduler state: {e}")
|
|
116
|
-
logging.warning("Will start with fresh scheduler state")
|
|
81
|
+
else:
|
|
82
|
+
logging.info("Building optimizer...")
|
|
83
|
+
optimizer, scaler = build_optimizer(model._get_params(), cfg.TRAIN.optimizer)
|
|
84
|
+
logging.info("Building scheduler...")
|
|
85
|
+
scheduler = build_scheduler(optimizer, cfg.TRAIN.scheduler, default_args)
|
|
117
86
|
|
|
118
87
|
trainer = Trainer_e2e(
|
|
119
88
|
cfg,
|
|
120
89
|
model,
|
|
121
90
|
optimizer,
|
|
122
91
|
scaler,
|
|
123
|
-
|
|
92
|
+
scheduler,
|
|
124
93
|
default_args["work_dir"],
|
|
125
94
|
default_args["dali"],
|
|
126
95
|
default_args["repartitions"],
|
|
@@ -132,8 +101,8 @@ def build_trainer(cfg, model=None, default_args=None, resume_from=None):
|
|
|
132
101
|
|
|
133
102
|
# Load training history if resuming
|
|
134
103
|
if resume_from is not None:
|
|
135
|
-
trainer.best_epoch =
|
|
136
|
-
trainer.best_criterion_valid =
|
|
104
|
+
trainer.best_epoch = resume_from.get('best_epoch', 0)
|
|
105
|
+
trainer.best_criterion_valid = resume_from.get('best_criterion_valid',
|
|
137
106
|
0 if cfg.TRAIN.criterion_valid == "map" else float("inf"))
|
|
138
107
|
logging.info(f"Restored best epoch: {trainer.best_epoch}")
|
|
139
108
|
|
|
@@ -186,6 +155,7 @@ class Trainer_pl(Trainer):
|
|
|
186
155
|
num_sanity_val_steps=0,
|
|
187
156
|
)
|
|
188
157
|
self.best_checkpoint_path = None
|
|
158
|
+
self.config = cfg
|
|
189
159
|
|
|
190
160
|
def train(self, **kwargs):
|
|
191
161
|
self.trainer.fit(**kwargs)
|
|
@@ -210,6 +180,13 @@ class Trainer_pl(Trainer):
|
|
|
210
180
|
logging.info("Done training")
|
|
211
181
|
logging.info(f"Best model saved at: {self.best_checkpoint_path}")
|
|
212
182
|
|
|
183
|
+
# Save the config file uniformly inside the work_dir
|
|
184
|
+
if hasattr(self, 'config') and self.config is not None:
|
|
185
|
+
from opensportslib.core.utils.config import save_config
|
|
186
|
+
config_path = os.path.join(self.work_dir, "config.yaml")
|
|
187
|
+
save_config(self.config, config_path)
|
|
188
|
+
logging.info(f"Saved config: {config_path}")
|
|
189
|
+
|
|
213
190
|
log()
|
|
214
191
|
|
|
215
192
|
|
|
@@ -328,6 +305,12 @@ class Trainer_e2e(Trainer):
|
|
|
328
305
|
self.best_checkpoint_path = best_path
|
|
329
306
|
torch.save(checkpoint, best_path)
|
|
330
307
|
logging.info(f"Best checkpoint saved: {best_path}")
|
|
308
|
+
|
|
309
|
+
if self.config is not None:
|
|
310
|
+
from opensportslib.core.utils.config import save_config
|
|
311
|
+
config_path = os.path.join(self.save_dir, "config.yaml")
|
|
312
|
+
save_config(self.config, config_path)
|
|
313
|
+
logging.info(f"Saved config: {config_path}")
|
|
331
314
|
|
|
332
315
|
def train(self, train_loader, valid_loader, classes):
|
|
333
316
|
"""Training loop with checkpoint management."""
|
|
@@ -441,7 +424,7 @@ class Trainer_e2e(Trainer):
|
|
|
441
424
|
best_checkpoint_path = os.path.join(
|
|
442
425
|
self.save_dir, f"best_checkpoint.pt"
|
|
443
426
|
)
|
|
444
|
-
self.model._model, _, _, epoch = load_checkpoint(model=self.model._model,
|
|
427
|
+
self.model._model, _, _, _, epoch, _ = load_checkpoint(model=self.model._model,
|
|
445
428
|
path=best_checkpoint_path,
|
|
446
429
|
key_remap_fn=localization_remap)
|
|
447
430
|
logging.info(f"Loaded best model from epoch {self.best_epoch}")
|
{opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/core/utils/checkpoint.py
RENAMED
|
@@ -76,6 +76,7 @@ def load_checkpoint(
|
|
|
76
76
|
path,
|
|
77
77
|
optimizer=None,
|
|
78
78
|
scheduler=None,
|
|
79
|
+
scaler=None,
|
|
79
80
|
device=None,
|
|
80
81
|
key_remap_fn=None,
|
|
81
82
|
hf_filename="model.pth.tar", # required if loading from HF repo
|
|
@@ -164,7 +165,7 @@ def load_checkpoint(
|
|
|
164
165
|
# --------------------------------------------------
|
|
165
166
|
# Load checkpoint
|
|
166
167
|
# --------------------------------------------------
|
|
167
|
-
checkpoint = torch.load(ckpt_path, map_location=
|
|
168
|
+
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
|
|
168
169
|
|
|
169
170
|
# ---------------- MODEL STATE ----------------
|
|
170
171
|
if isinstance(checkpoint, dict):
|
|
@@ -201,8 +202,24 @@ def load_checkpoint(
|
|
|
201
202
|
for k, v in state_dict.items()
|
|
202
203
|
}
|
|
203
204
|
|
|
204
|
-
state_dict = strip_prefix(state_dict, "module.")
|
|
205
|
+
# state_dict = strip_prefix(state_dict, "module.")
|
|
206
|
+
# state_dict = strip_prefix(state_dict, "model.")
|
|
207
|
+
|
|
208
|
+
# First remove known wrappers (safe ones)
|
|
205
209
|
state_dict = strip_prefix(state_dict, "model.")
|
|
210
|
+
state_dict = strip_prefix(state_dict, "_model.")
|
|
211
|
+
|
|
212
|
+
# Now handle module dynamically
|
|
213
|
+
model_keys = list(model.state_dict().keys())
|
|
214
|
+
ckpt_keys = list(state_dict.keys())
|
|
215
|
+
|
|
216
|
+
model_has_module = model_keys[0].startswith("module.")
|
|
217
|
+
ckpt_has_module = ckpt_keys[0].startswith("module.")
|
|
218
|
+
|
|
219
|
+
if model_has_module and not ckpt_has_module:
|
|
220
|
+
state_dict = {f"module.{k}": v for k, v in state_dict.items()}
|
|
221
|
+
elif not model_has_module and ckpt_has_module:
|
|
222
|
+
state_dict = {k.replace("module.", "", 1): v for k, v in state_dict.items()}
|
|
206
223
|
|
|
207
224
|
# Optional custom remap
|
|
208
225
|
if key_remap_fn:
|
|
@@ -229,15 +246,20 @@ def load_checkpoint(
|
|
|
229
246
|
|
|
230
247
|
# ---------------- SCHEDULER ----------------
|
|
231
248
|
if scheduler and isinstance(checkpoint, dict):
|
|
232
|
-
sch_state = checkpoint.get("scheduler") or checkpoint.get("scheduler_state_dict")
|
|
249
|
+
sch_state = checkpoint.get("scheduler") or checkpoint.get("scheduler_state_dict") or checkpoint.get("lr_scheduler") # some use "lr_scheduler"
|
|
233
250
|
if sch_state:
|
|
234
251
|
scheduler.load_state_dict(sch_state)
|
|
235
252
|
|
|
253
|
+
if scaler and isinstance(checkpoint, dict):
|
|
254
|
+
scaler_state = checkpoint.get("scaler") or checkpoint.get("scaler_state_dict")
|
|
255
|
+
if scaler_state:
|
|
256
|
+
scaler.load_state_dict(scaler_state)
|
|
257
|
+
|
|
236
258
|
print(f"[Checkpoint] Loaded from {ckpt_path} | epoch: {epoch}")
|
|
237
259
|
print(f"Missing keys: {len(missing)}")
|
|
238
260
|
print(f"Unexpected keys: {len(unexpected)}")
|
|
239
261
|
|
|
240
|
-
return model, optimizer, scheduler, epoch
|
|
262
|
+
return model, optimizer, scheduler, scaler, epoch, checkpoint
|
|
241
263
|
|
|
242
264
|
|
|
243
265
|
|