opensportslib 0.1.2.dev6__tar.gz → 0.1.2.dev8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. {opensportslib-0.1.2.dev6/opensportslib.egg-info → opensportslib-0.1.2.dev8}/PKG-INFO +1 -1
  2. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/apis/base_task_model.py +8 -1
  3. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/apis/classification.py +4 -5
  4. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/apis/localization.py +68 -26
  5. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/core/trainer/classification_trainer.py +9 -1
  6. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/core/trainer/localization_trainer.py +27 -44
  7. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/core/utils/checkpoint.py +26 -4
  8. opensportslib-0.1.2.dev8/opensportslib/core/utils/config.py +358 -0
  9. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8/opensportslib.egg-info}/PKG-INFO +1 -1
  10. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/pyproject.toml +1 -1
  11. opensportslib-0.1.2.dev6/opensportslib/core/utils/config.py +0 -214
  12. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/LICENSE +0 -0
  13. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/LICENSE-COMMERCIAL +0 -0
  14. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/MANIFEST.in +0 -0
  15. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/README.md +0 -0
  16. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/examples/quickstart/basic_classification.py +0 -0
  17. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/examples/quickstart/basic_localization.py +0 -0
  18. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/__init__.py +0 -0
  19. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/apis/__init__.py +0 -0
  20. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/cli.py +0 -0
  21. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/config/classification.yaml +0 -0
  22. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/config/localization-e2e-ocv.yaml +0 -0
  23. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/config/localization-json_calf_resnetpca512.yaml +0 -0
  24. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/config/localization-json_netvlad++_resnetpca512.yaml +0 -0
  25. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/config/localization.yaml +0 -0
  26. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/config/sngar-frames.yaml +0 -0
  27. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/config/sngar-tracking.yaml +0 -0
  28. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/core/__init__.py +0 -0
  29. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/core/loss/__init__.py +0 -0
  30. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/core/loss/builder.py +0 -0
  31. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/core/loss/calf.py +0 -0
  32. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/core/loss/ce.py +0 -0
  33. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/core/loss/combine.py +0 -0
  34. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/core/loss/nll.py +0 -0
  35. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/core/optimizer/__init__.py +0 -0
  36. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/core/optimizer/builder.py +0 -0
  37. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/core/sampler/weighted_sampler.py +0 -0
  38. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/core/scheduler/__init__.py +0 -0
  39. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/core/scheduler/builder.py +0 -0
  40. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/core/trainer/__init__.py +0 -0
  41. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/core/utils/data.py +0 -0
  42. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/core/utils/ddp.py +0 -0
  43. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/core/utils/default_args.py +0 -0
  44. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/core/utils/lightning.py +0 -0
  45. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/core/utils/load_annotations.py +0 -0
  46. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/core/utils/seed.py +0 -0
  47. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/core/utils/video_processing.py +0 -0
  48. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/core/utils/wandb.py +0 -0
  49. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/datasets/__init__.py +0 -0
  50. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/datasets/builder.py +0 -0
  51. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/datasets/classification_dataset.py +0 -0
  52. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/datasets/localization_dataset.py +0 -0
  53. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/datasets/utils/__init__.py +0 -0
  54. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/datasets/utils/tracking.py +0 -0
  55. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/metrics/classification_metric.py +0 -0
  56. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/metrics/localization_metric.py +0 -0
  57. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/models/__init__.py +0 -0
  58. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/models/backbones/builder.py +0 -0
  59. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/models/base/contextaware.py +0 -0
  60. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/models/base/e2e.py +0 -0
  61. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/models/base/learnablepooling.py +0 -0
  62. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/models/base/tracking.py +0 -0
  63. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/models/base/vars.py +0 -0
  64. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/models/base/video.py +0 -0
  65. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/models/base/video_mae.py +0 -0
  66. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/models/builder.py +0 -0
  67. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/models/heads/builder.py +0 -0
  68. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/models/neck/builder.py +0 -0
  69. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/models/utils/common.py +0 -0
  70. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/models/utils/impl/__init__.py +0 -0
  71. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/models/utils/impl/asformer.py +0 -0
  72. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/models/utils/impl/calf.py +0 -0
  73. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/models/utils/impl/gsm.py +0 -0
  74. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/models/utils/impl/gtad.py +0 -0
  75. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/models/utils/impl/tsm.py +0 -0
  76. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/models/utils/litebase.py +0 -0
  77. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/models/utils/modules.py +0 -0
  78. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/models/utils/shift.py +0 -0
  79. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/models/utils/utils.py +0 -0
  80. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/setup/setup.py +0 -0
  81. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/tools/__init__.py +0 -0
  82. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/tools/_common.py +0 -0
  83. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/tools/hf_transfer.py +0 -0
  84. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/tools/osl_json_to_parquet.py +0 -0
  85. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib/tools/parquet_to_osl_json.py +0 -0
  86. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib.egg-info/SOURCES.txt +0 -0
  87. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib.egg-info/dependency_links.txt +0 -0
  88. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib.egg-info/entry_points.txt +0 -0
  89. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib.egg-info/requires.txt +0 -0
  90. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/opensportslib.egg-info/top_level.txt +0 -0
  91. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/setup.cfg +0 -0
  92. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/tests/conftest.py +0 -0
  93. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/tests/test_config_utils_smoke.py +0 -0
  94. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/tests/test_conversion_tools.py +0 -0
  95. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/tests/test_hf_transfer_tools.py +0 -0
  96. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/tests/test_package_smoke.py +0 -0
  97. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/tests/test_public_apis_smoke.py +0 -0
  98. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/tests/test_subset_train_infer_integration.py +0 -0
  99. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/tests/test_task_model_api_contract.py +0 -0
  100. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/tools/convert/osl_json_to_parquet_webdataset.py +0 -0
  101. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/tools/convert/parquet_webdataset_to_osl_json.py +0 -0
  102. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/tools/download/download_hf_repo.py +0 -0
  103. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/tools/download/download_osl_hf.py +0 -0
  104. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/tools/download/upload_osl_hf.py +0 -0
  105. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/tools/training/classification.py +0 -0
  106. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev8}/tools/training/localization.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: opensportslib
3
- Version: 0.1.2.dev6
3
+ Version: 0.1.2.dev8
4
4
  Summary: OpenSportsLib is the professional library, designed for advanced video understanding in sports. It provides state-of-the-art tools for action recognition, spotting, retrieval, and captioning, making it ideal for researchers, analysts, and developers working with sports video data.
5
5
  Author: Jeet Vora
6
6
  Requires-Python: >=3.12
@@ -9,7 +9,7 @@ import uuid
9
9
  from abc import ABC, abstractmethod
10
10
  from typing import Any
11
11
 
12
- from opensportslib.core.utils.config import expand, load_config_omega
12
+ from opensportslib.core.utils.config import expand, load_config_omega, fetch_and_merge_config_from_HF
13
13
 
14
14
 
15
15
  class BaseTaskModel(ABC):
@@ -24,6 +24,13 @@ class BaseTaskModel(ABC):
24
24
  self.config_path = expand(config)
25
25
  self.config = load_config_omega(self.config_path)
26
26
 
27
+ if weights is not None:
28
+ self.config = fetch_and_merge_config_from_HF(self.config, weights, merge_policy="compatibility")
29
+ self.last_loaded_weights = weights
30
+ self.best_checkpoint = weights
31
+
32
+ self.train_flag = False # Flag to indicate whether we're in training mode (affects checkpoint loading behavior)
33
+
27
34
  data_cfg = getattr(self.config, "DATA", None)
28
35
  if data_cfg is not None and hasattr(data_cfg, "data_dir"):
29
36
  data_cfg.data_dir = expand(data_cfg.data_dir)
@@ -8,7 +8,6 @@ import os
8
8
  from opensportslib.apis.base_task_model import BaseTaskModel
9
9
  from opensportslib.core.utils.config import expand
10
10
 
11
-
12
11
  class ClassificationModel(BaseTaskModel):
13
12
  """Top-level task wrapper for classification."""
14
13
 
@@ -172,8 +171,8 @@ class ClassificationModel(BaseTaskModel):
172
171
 
173
172
  train_set = self._resolve_split_path("train", train_set)
174
173
  valid_set = self._resolve_split_path("valid", valid_set)
175
-
176
- self.config = resolve_config_omega(self.config)
174
+
175
+ self.config = resolve_config_omega(self.config, weights=weights)
177
176
  logging.info("Configuration:")
178
177
  logging.info(self.config)
179
178
 
@@ -241,7 +240,7 @@ class ClassificationModel(BaseTaskModel):
241
240
 
242
241
  test_set = self._resolve_split_path("test", test_set)
243
242
 
244
- self.config = resolve_config_omega(self.config)
243
+ self.config = resolve_config_omega(self.config, weights=weights)
245
244
  logging.info("Configuration:")
246
245
  logging.info(self.config)
247
246
 
@@ -304,7 +303,7 @@ class ClassificationModel(BaseTaskModel):
304
303
 
305
304
  test_set = self._resolve_split_path("test", test_set)
306
305
 
307
- self.config = resolve_config_omega(self.config)
306
+ self.config = resolve_config_omega(self.config, weights=weights)
308
307
  logging.info("Configuration:")
309
308
  logging.info(self.config)
310
309
  if predictions is None:
@@ -9,11 +9,13 @@ from opensportslib.core.utils.config import expand
9
9
  class LocalizationModel(BaseTaskModel):
10
10
  """Top-level task wrapper for localization / spotting."""
11
11
 
12
- def __init__(self, config=None, weights=None):
13
- super().__init__(config=config, weights=None)
14
- if weights is not None:
15
- self.last_loaded_weights = weights
16
- self.best_checkpoint = weights
12
+ # def __init__(self, config=None, weights=None):
13
+ # super().__init__(config=config, weights=None)
14
+ # if weights is not None:
15
+ # self.last_loaded_weights = weights
16
+ # self.best_checkpoint = weights
17
+
18
+ # self.train_flag = False # Flag to indicate whether we're in training mode (affects checkpoint loading behavior)
17
19
 
18
20
  def _resolve_split_path(self, split: str, override: str | None = None) -> str:
19
21
  if override is not None:
@@ -69,15 +71,18 @@ class LocalizationModel(BaseTaskModel):
69
71
  load_checkpoint,
70
72
  localization_remap,
71
73
  )
72
-
74
+ from opensportslib.core.optimizer.builder import build_optimizer
75
+ from opensportslib.core.scheduler.builder import build_scheduler
76
+ default_args = kwargs.get("default_args", None)
73
77
  del kwargs
74
78
  if weights is None:
75
79
  raise ValueError("`weights` must be provided to load_weights().")
76
80
 
77
81
  model_cfg = getattr(self.config, "MODEL", None)
78
- original_multi_gpu = getattr(model_cfg, "multi_gpu", None)
79
- if model_cfg is not None and original_multi_gpu is not None:
80
- model_cfg.multi_gpu = False
82
+ if not self.train_flag:
83
+ original_multi_gpu = getattr(model_cfg, "multi_gpu", None)
84
+ if model_cfg is not None and original_multi_gpu is not None:
85
+ model_cfg.multi_gpu = False
81
86
 
82
87
  device = select_device(self.config.SYSTEM)
83
88
  if self.model is None:
@@ -90,9 +95,28 @@ class LocalizationModel(BaseTaskModel):
90
95
  if is_local_path(weights):
91
96
  self.config.SYSTEM.work_dir = os.path.dirname(os.path.abspath(weights))
92
97
 
93
- inner_model, _, _, _ = load_checkpoint(
98
+ if default_args is not None:
99
+ logging.info("Building optimizer + scaler for checkpoint restore...")
100
+ optimizer, scaler = build_optimizer(
101
+ inner_model.parameters(), # or _get_params() if required
102
+ self.config.TRAIN.optimizer
103
+ )
104
+
105
+ logging.info("Building scheduler for checkpoint restore...")
106
+ scheduler = build_scheduler(
107
+ optimizer,
108
+ self.config.TRAIN.scheduler,
109
+ default_args
110
+ )
111
+ else:
112
+ optimizer = scheduler = scaler = None
113
+
114
+ inner_model, optimizer, scheduler, scaler, epoch, checkpoint = load_checkpoint(
94
115
  model=inner_model,
95
116
  path=weights,
117
+ optimizer=optimizer,
118
+ scheduler=scheduler,
119
+ scaler=scaler,
96
120
  device=device,
97
121
  key_remap_fn=localization_remap,
98
122
  )
@@ -107,8 +131,24 @@ class LocalizationModel(BaseTaskModel):
107
131
  self.last_loaded_weights = weights
108
132
  self.best_checkpoint = weights
109
133
 
110
- if model_cfg is not None and original_multi_gpu is not None:
111
- model_cfg.multi_gpu = original_multi_gpu
134
+ best_epoch = checkpoint.get("best_epoch", 0)
135
+
136
+ best_criterion_valid = checkpoint.get(
137
+ "best_criterion_valid",
138
+ 0 if self.config.TRAIN.criterion_valid == "map" else float("inf")
139
+ )
140
+ self._resume_state = {
141
+ "optimizer": optimizer,
142
+ "scheduler": scheduler,
143
+ "scaler": scaler,
144
+ "epoch": epoch if epoch is not None else 0,
145
+ "best_epoch": best_epoch,
146
+ "best_criterion_valid": best_criterion_valid,
147
+ }
148
+
149
+ if not self.train_flag:
150
+ if model_cfg is not None and original_multi_gpu is not None:
151
+ model_cfg.multi_gpu = original_multi_gpu
112
152
 
113
153
  def train(
114
154
  self,
@@ -138,8 +178,8 @@ class LocalizationModel(BaseTaskModel):
138
178
  valid_set = self._resolve_split_path("valid", valid_set)
139
179
  self._set_split_path("train", train_set)
140
180
  self._set_split_path("valid", valid_set)
141
-
142
- self.config = resolve_config_omega(self.config)
181
+
182
+ self.config = resolve_config_omega(self.config, weights=weights)
143
183
  check_config(self.config, split="train")
144
184
  init_wandb(
145
185
  self.config_path,
@@ -167,13 +207,6 @@ class LocalizationModel(BaseTaskModel):
167
207
 
168
208
  start = time.time()
169
209
 
170
- if effective_weights is not None:
171
- if self.model is None or self.last_loaded_weights != effective_weights:
172
- self.load_weights(weights=effective_weights)
173
- elif self.model is None:
174
- device = select_device(self.config.SYSTEM)
175
- self.model = build_model(self.config, device=device)
176
-
177
210
  data_obj_train = build_dataset(self.config, split="train")
178
211
  dataset_train = data_obj_train.building_dataset(
179
212
  cfg=data_obj_train.cfg,
@@ -200,11 +233,21 @@ class LocalizationModel(BaseTaskModel):
200
233
  dali=self.config.dali,
201
234
  )
202
235
 
236
+ default_args = get_default_args_trainer(self.config, len(train_loader))
237
+
238
+ self.train_flag = True # Set flag to indicate training mode for checkpoint loading
239
+ if effective_weights is not None:
240
+ if self.model is None or self.last_loaded_weights != effective_weights:
241
+ self.load_weights(weights=effective_weights, default_args=default_args)
242
+ elif self.model is None:
243
+ device = select_device(self.config.SYSTEM)
244
+ self.model = build_model(self.config, device=device)
245
+
203
246
  self.trainer = build_trainer(
204
247
  cfg=self.config,
205
248
  model=self.model,
206
- default_args=get_default_args_trainer(self.config, len(train_loader)),
207
- resume_from=effective_weights,
249
+ default_args=default_args,
250
+ resume_from=self._resume_state if hasattr(self, "_resume_state") else None,
208
251
  )
209
252
 
210
253
  logging.info("Start training")
@@ -249,7 +292,7 @@ class LocalizationModel(BaseTaskModel):
249
292
  self._set_split_path("test", test_set)
250
293
 
251
294
  self.config.MODEL.multi_gpu = False
252
- self.config = resolve_config_omega(self.config)
295
+ self.config = resolve_config_omega(self.config, weights=weights)
253
296
  check_config(self.config, split="test")
254
297
  self.config.infer_split = whether_infer_split(self.config.DATA.test)
255
298
 
@@ -318,9 +361,8 @@ class LocalizationModel(BaseTaskModel):
318
361
 
319
362
  test_set = self._resolve_split_path("test", test_set)
320
363
  self._set_split_path("test", test_set)
321
-
322
364
  self.config.MODEL.multi_gpu = False
323
- self.config = resolve_config_omega(self.config)
365
+ self.config = resolve_config_omega(self.config, weights=weights)
324
366
  check_config(self.config, split="test")
325
367
  self.config.infer_split = whether_infer_split(self.config.DATA.test)
326
368
 
@@ -550,6 +550,13 @@ class BaseTrainerClassification:
550
550
  path_aux = os.path.join(epoch_dir, name)
551
551
  torch.save(state, path_aux)
552
552
  logging.info(f"Saved checkpoint: {path_aux}")
553
+
554
+ if self.config is not None:
555
+ from opensportslib.core.utils.config import save_config
556
+ config_path = os.path.join(epoch_dir, "config.yaml")
557
+ save_config(self.config, config_path)
558
+ logging.info(f"Saved config: {config_path}")
559
+
553
560
  return path_aux
554
561
 
555
562
 
@@ -1167,11 +1174,12 @@ class Trainer_Classification:
1167
1174
  from opensportslib.models.builder import build_model
1168
1175
  if self.model is None:
1169
1176
  self.model, _ = build_model(self.config, self.device)
1170
- self.model, optimizer, scheduler, epoch = load_checkpoint(
1177
+ self.model, optimizer, scheduler, scaler, epoch, checkpoint = load_checkpoint(
1171
1178
  self.model, path, optimizer, scheduler, device=self.device
1172
1179
  )
1173
1180
  self.optimizer = optimizer
1174
1181
  self.scheduler = scheduler
1182
+ self.scaler = scaler
1175
1183
  self.epoch = epoch
1176
1184
  logging.info(f"Model loaded from {path}, epoch: {epoch}")
1177
1185
  return self.model, self.optimizer, self.scheduler, self.epoch
@@ -29,7 +29,6 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
29
  """
30
30
  from opensportslib.metrics.localization_metric import *
31
31
  from opensportslib.core.optimizer.builder import build_optimizer
32
- from opensportslib.core.optimizer.builder import build_optimizer
33
32
  from opensportslib.core.scheduler.builder import build_scheduler
34
33
  from opensportslib.core.utils.config import store_json
35
34
  from opensportslib.datasets.builder import build_dataset
@@ -67,20 +66,10 @@ def build_trainer(cfg, model=None, default_args=None, resume_from=None):
67
66
 
68
67
  # Handle checkpoint loading
69
68
  if resume_from is not None:
70
- if not os.path.isfile(resume_from):
71
- raise ValueError(f"Checkpoint file not found: {resume_from}")
72
-
73
- logging.info(f"Loading checkpoint from: {resume_from}")
74
- checkpoint = torch.load(resume_from)
75
-
76
- # Load model state
77
- model.load(checkpoint['model_state_dict'])
78
- logging.info("Model state loaded successfully")
79
-
80
- # Get current training progress
81
- start_epoch = checkpoint['epoch'] + 1
82
- logging.info(f"Resuming from epoch {start_epoch}")
83
-
69
+ optimizer = resume_from["optimizer"]
70
+ scheduler = resume_from["scheduler"]
71
+ scaler = resume_from["scaler"]
72
+ start_epoch = resume_from["epoch"] + 1
84
73
  # Check if we've already reached target epochs
85
74
  if start_epoch >= cfg.TRAIN.num_epochs:
86
75
  logging.error(f"Model already trained for {start_epoch} epochs")
@@ -89,38 +78,18 @@ def build_trainer(cfg, model=None, default_args=None, resume_from=None):
89
78
  raise ValueError("Need to increase num_epochs to continue training")
90
79
 
91
80
  logging.info(f"Will continue training from epoch {start_epoch} to {cfg.TRAIN.num_epochs}")
92
-
93
- logging.info("Building optimizer...")
94
- optimizer, scaler = build_optimizer(model._get_params(), cfg.TRAIN.optimizer)
95
-
96
- # Load optimizer state if available in checkpoint
97
- if resume_from is not None and 'optimizer_state_dict' in checkpoint:
98
- try:
99
- optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
100
- scaler.load_state_dict(checkpoint['scaler_state_dict'])
101
- logging.info("Optimizer and scaler states loaded")
102
- except Exception as e:
103
- logging.warning(f"Could not load optimizer state: {e}")
104
- logging.warning("Will start with fresh optimizer state")
105
-
106
- logging.info("Building scheduler...")
107
- lr_scheduler = build_scheduler(optimizer, cfg.TRAIN.scheduler, default_args)
108
-
109
- # Load scheduler state if available
110
- if resume_from is not None and 'lr_state_dict' in checkpoint:
111
- try:
112
- lr_scheduler.load_state_dict(checkpoint['lr_state_dict'])
113
- logging.info("Scheduler state loaded")
114
- except Exception as e:
115
- logging.warning(f"Could not load scheduler state: {e}")
116
- logging.warning("Will start with fresh scheduler state")
81
+ else:
82
+ logging.info("Building optimizer...")
83
+ optimizer, scaler = build_optimizer(model._get_params(), cfg.TRAIN.optimizer)
84
+ logging.info("Building scheduler...")
85
+ scheduler = build_scheduler(optimizer, cfg.TRAIN.scheduler, default_args)
117
86
 
118
87
  trainer = Trainer_e2e(
119
88
  cfg,
120
89
  model,
121
90
  optimizer,
122
91
  scaler,
123
- lr_scheduler,
92
+ scheduler,
124
93
  default_args["work_dir"],
125
94
  default_args["dali"],
126
95
  default_args["repartitions"],
@@ -132,8 +101,8 @@ def build_trainer(cfg, model=None, default_args=None, resume_from=None):
132
101
 
133
102
  # Load training history if resuming
134
103
  if resume_from is not None:
135
- trainer.best_epoch = checkpoint.get('best_epoch', 0)
136
- trainer.best_criterion_valid = checkpoint.get('best_criterion_valid',
104
+ trainer.best_epoch = resume_from.get('best_epoch', 0)
105
+ trainer.best_criterion_valid = resume_from.get('best_criterion_valid',
137
106
  0 if cfg.TRAIN.criterion_valid == "map" else float("inf"))
138
107
  logging.info(f"Restored best epoch: {trainer.best_epoch}")
139
108
 
@@ -186,6 +155,7 @@ class Trainer_pl(Trainer):
186
155
  num_sanity_val_steps=0,
187
156
  )
188
157
  self.best_checkpoint_path = None
158
+ self.config = cfg
189
159
 
190
160
  def train(self, **kwargs):
191
161
  self.trainer.fit(**kwargs)
@@ -210,6 +180,13 @@ class Trainer_pl(Trainer):
210
180
  logging.info("Done training")
211
181
  logging.info(f"Best model saved at: {self.best_checkpoint_path}")
212
182
 
183
+ # Save the config file uniformly inside the work_dir
184
+ if hasattr(self, 'config') and self.config is not None:
185
+ from opensportslib.core.utils.config import save_config
186
+ config_path = os.path.join(self.work_dir, "config.yaml")
187
+ save_config(self.config, config_path)
188
+ logging.info(f"Saved config: {config_path}")
189
+
213
190
  log()
214
191
 
215
192
 
@@ -328,6 +305,12 @@ class Trainer_e2e(Trainer):
328
305
  self.best_checkpoint_path = best_path
329
306
  torch.save(checkpoint, best_path)
330
307
  logging.info(f"Best checkpoint saved: {best_path}")
308
+
309
+ if self.config is not None:
310
+ from opensportslib.core.utils.config import save_config
311
+ config_path = os.path.join(self.save_dir, "config.yaml")
312
+ save_config(self.config, config_path)
313
+ logging.info(f"Saved config: {config_path}")
331
314
 
332
315
  def train(self, train_loader, valid_loader, classes):
333
316
  """Training loop with checkpoint management."""
@@ -441,7 +424,7 @@ class Trainer_e2e(Trainer):
441
424
  best_checkpoint_path = os.path.join(
442
425
  self.save_dir, f"best_checkpoint.pt"
443
426
  )
444
- self.model._model, _, _, epoch = load_checkpoint(model=self.model._model,
427
+ self.model._model, _, _, _, epoch, _ = load_checkpoint(model=self.model._model,
445
428
  path=best_checkpoint_path,
446
429
  key_remap_fn=localization_remap)
447
430
  logging.info(f"Loaded best model from epoch {self.best_epoch}")
@@ -76,6 +76,7 @@ def load_checkpoint(
76
76
  path,
77
77
  optimizer=None,
78
78
  scheduler=None,
79
+ scaler=None,
79
80
  device=None,
80
81
  key_remap_fn=None,
81
82
  hf_filename="model.pth.tar", # required if loading from HF repo
@@ -164,7 +165,7 @@ def load_checkpoint(
164
165
  # --------------------------------------------------
165
166
  # Load checkpoint
166
167
  # --------------------------------------------------
167
- checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
168
+ checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
168
169
 
169
170
  # ---------------- MODEL STATE ----------------
170
171
  if isinstance(checkpoint, dict):
@@ -201,8 +202,24 @@ def load_checkpoint(
201
202
  for k, v in state_dict.items()
202
203
  }
203
204
 
204
- state_dict = strip_prefix(state_dict, "module.")
205
+ # state_dict = strip_prefix(state_dict, "module.")
206
+ # state_dict = strip_prefix(state_dict, "model.")
207
+
208
+ # First remove known wrappers (safe ones)
205
209
  state_dict = strip_prefix(state_dict, "model.")
210
+ state_dict = strip_prefix(state_dict, "_model.")
211
+
212
+ # Now handle module dynamically
213
+ model_keys = list(model.state_dict().keys())
214
+ ckpt_keys = list(state_dict.keys())
215
+
216
+ model_has_module = model_keys[0].startswith("module.")
217
+ ckpt_has_module = ckpt_keys[0].startswith("module.")
218
+
219
+ if model_has_module and not ckpt_has_module:
220
+ state_dict = {f"module.{k}": v for k, v in state_dict.items()}
221
+ elif not model_has_module and ckpt_has_module:
222
+ state_dict = {k.replace("module.", "", 1): v for k, v in state_dict.items()}
206
223
 
207
224
  # Optional custom remap
208
225
  if key_remap_fn:
@@ -229,15 +246,20 @@ def load_checkpoint(
229
246
 
230
247
  # ---------------- SCHEDULER ----------------
231
248
  if scheduler and isinstance(checkpoint, dict):
232
- sch_state = checkpoint.get("scheduler") or checkpoint.get("scheduler_state_dict")
249
+ sch_state = checkpoint.get("scheduler") or checkpoint.get("scheduler_state_dict") or checkpoint.get("lr_scheduler") # some use "lr_scheduler"
233
250
  if sch_state:
234
251
  scheduler.load_state_dict(sch_state)
235
252
 
253
+ if scaler and isinstance(checkpoint, dict):
254
+ scaler_state = checkpoint.get("scaler") or checkpoint.get("scaler_state_dict")
255
+ if scaler_state:
256
+ scaler.load_state_dict(scaler_state)
257
+
236
258
  print(f"[Checkpoint] Loaded from {ckpt_path} | epoch: {epoch}")
237
259
  print(f"Missing keys: {len(missing)}")
238
260
  print(f"Unexpected keys: {len(unexpected)}")
239
261
 
240
- return model, optimizer, scheduler, epoch
262
+ return model, optimizer, scheduler, scaler, epoch, checkpoint
241
263
 
242
264
 
243
265