opensportslib 0.1.2.dev6__tar.gz → 0.1.2.dev7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105)
  1. {opensportslib-0.1.2.dev6/opensportslib.egg-info → opensportslib-0.1.2.dev7}/PKG-INFO +1 -1
  2. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/apis/localization.py +59 -16
  3. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/trainer/classification_trainer.py +2 -1
  4. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/trainer/localization_trainer.py +13 -44
  5. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/utils/checkpoint.py +26 -4
  6. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7/opensportslib.egg-info}/PKG-INFO +1 -1
  7. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/pyproject.toml +1 -1
  8. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/LICENSE +0 -0
  9. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/LICENSE-COMMERCIAL +0 -0
  10. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/MANIFEST.in +0 -0
  11. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/README.md +0 -0
  12. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/examples/quickstart/basic_classification.py +0 -0
  13. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/examples/quickstart/basic_localization.py +0 -0
  14. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/__init__.py +0 -0
  15. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/apis/__init__.py +0 -0
  16. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/apis/base_task_model.py +0 -0
  17. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/apis/classification.py +0 -0
  18. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/cli.py +0 -0
  19. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/config/classification.yaml +0 -0
  20. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/config/localization-e2e-ocv.yaml +0 -0
  21. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/config/localization-json_calf_resnetpca512.yaml +0 -0
  22. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/config/localization-json_netvlad++_resnetpca512.yaml +0 -0
  23. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/config/localization.yaml +0 -0
  24. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/config/sngar-frames.yaml +0 -0
  25. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/config/sngar-tracking.yaml +0 -0
  26. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/__init__.py +0 -0
  27. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/loss/__init__.py +0 -0
  28. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/loss/builder.py +0 -0
  29. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/loss/calf.py +0 -0
  30. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/loss/ce.py +0 -0
  31. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/loss/combine.py +0 -0
  32. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/loss/nll.py +0 -0
  33. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/optimizer/__init__.py +0 -0
  34. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/optimizer/builder.py +0 -0
  35. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/sampler/weighted_sampler.py +0 -0
  36. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/scheduler/__init__.py +0 -0
  37. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/scheduler/builder.py +0 -0
  38. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/trainer/__init__.py +0 -0
  39. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/utils/config.py +0 -0
  40. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/utils/data.py +0 -0
  41. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/utils/ddp.py +0 -0
  42. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/utils/default_args.py +0 -0
  43. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/utils/lightning.py +0 -0
  44. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/utils/load_annotations.py +0 -0
  45. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/utils/seed.py +0 -0
  46. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/utils/video_processing.py +0 -0
  47. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/core/utils/wandb.py +0 -0
  48. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/datasets/__init__.py +0 -0
  49. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/datasets/builder.py +0 -0
  50. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/datasets/classification_dataset.py +0 -0
  51. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/datasets/localization_dataset.py +0 -0
  52. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/datasets/utils/__init__.py +0 -0
  53. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/datasets/utils/tracking.py +0 -0
  54. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/metrics/classification_metric.py +0 -0
  55. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/metrics/localization_metric.py +0 -0
  56. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/__init__.py +0 -0
  57. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/backbones/builder.py +0 -0
  58. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/base/contextaware.py +0 -0
  59. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/base/e2e.py +0 -0
  60. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/base/learnablepooling.py +0 -0
  61. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/base/tracking.py +0 -0
  62. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/base/vars.py +0 -0
  63. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/base/video.py +0 -0
  64. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/base/video_mae.py +0 -0
  65. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/builder.py +0 -0
  66. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/heads/builder.py +0 -0
  67. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/neck/builder.py +0 -0
  68. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/common.py +0 -0
  69. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/impl/__init__.py +0 -0
  70. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/impl/asformer.py +0 -0
  71. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/impl/calf.py +0 -0
  72. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/impl/gsm.py +0 -0
  73. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/impl/gtad.py +0 -0
  74. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/impl/tsm.py +0 -0
  75. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/litebase.py +0 -0
  76. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/modules.py +0 -0
  77. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/shift.py +0 -0
  78. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/utils.py +0 -0
  79. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/setup/setup.py +0 -0
  80. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/tools/__init__.py +0 -0
  81. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/tools/_common.py +0 -0
  82. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/tools/hf_transfer.py +0 -0
  83. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/tools/osl_json_to_parquet.py +0 -0
  84. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib/tools/parquet_to_osl_json.py +0 -0
  85. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib.egg-info/SOURCES.txt +0 -0
  86. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib.egg-info/dependency_links.txt +0 -0
  87. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib.egg-info/entry_points.txt +0 -0
  88. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib.egg-info/requires.txt +0 -0
  89. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/opensportslib.egg-info/top_level.txt +0 -0
  90. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/setup.cfg +0 -0
  91. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/tests/conftest.py +0 -0
  92. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/tests/test_config_utils_smoke.py +0 -0
  93. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/tests/test_conversion_tools.py +0 -0
  94. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/tests/test_hf_transfer_tools.py +0 -0
  95. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/tests/test_package_smoke.py +0 -0
  96. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/tests/test_public_apis_smoke.py +0 -0
  97. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/tests/test_subset_train_infer_integration.py +0 -0
  98. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/tests/test_task_model_api_contract.py +0 -0
  99. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/tools/convert/osl_json_to_parquet_webdataset.py +0 -0
  100. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/tools/convert/parquet_webdataset_to_osl_json.py +0 -0
  101. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/tools/download/download_hf_repo.py +0 -0
  102. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/tools/download/download_osl_hf.py +0 -0
  103. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/tools/download/upload_osl_hf.py +0 -0
  104. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/tools/training/classification.py +0 -0
  105. {opensportslib-0.1.2.dev6 → opensportslib-0.1.2.dev7}/tools/training/localization.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: opensportslib
3
- Version: 0.1.2.dev6
3
+ Version: 0.1.2.dev7
4
4
  Summary: OpenSportsLib is the professional library, designed for advanced video understanding in sports. It provides state-of-the-art tools for action recognition, spotting, retrieval, and captioning, making it ideal for researchers, analysts, and developers working with sports video data.
5
5
  Author: Jeet Vora
6
6
  Requires-Python: >=3.12
@@ -15,6 +15,8 @@ class LocalizationModel(BaseTaskModel):
15
15
  self.last_loaded_weights = weights
16
16
  self.best_checkpoint = weights
17
17
 
18
+ self.train_flag = False # Flag to indicate whether we're in training mode (affects checkpoint loading behavior)
19
+
18
20
  def _resolve_split_path(self, split: str, override: str | None = None) -> str:
19
21
  if override is not None:
20
22
  return expand(override)
@@ -69,15 +71,18 @@ class LocalizationModel(BaseTaskModel):
69
71
  load_checkpoint,
70
72
  localization_remap,
71
73
  )
72
-
74
+ from opensportslib.core.optimizer.builder import build_optimizer
75
+ from opensportslib.core.scheduler.builder import build_scheduler
76
+ default_args = kwargs.get("default_args", None)
73
77
  del kwargs
74
78
  if weights is None:
75
79
  raise ValueError("`weights` must be provided to load_weights().")
76
80
 
77
81
  model_cfg = getattr(self.config, "MODEL", None)
78
- original_multi_gpu = getattr(model_cfg, "multi_gpu", None)
79
- if model_cfg is not None and original_multi_gpu is not None:
80
- model_cfg.multi_gpu = False
82
+ if not self.train_flag:
83
+ original_multi_gpu = getattr(model_cfg, "multi_gpu", None)
84
+ if model_cfg is not None and original_multi_gpu is not None:
85
+ model_cfg.multi_gpu = False
81
86
 
82
87
  device = select_device(self.config.SYSTEM)
83
88
  if self.model is None:
@@ -90,9 +95,28 @@ class LocalizationModel(BaseTaskModel):
90
95
  if is_local_path(weights):
91
96
  self.config.SYSTEM.work_dir = os.path.dirname(os.path.abspath(weights))
92
97
 
93
- inner_model, _, _, _ = load_checkpoint(
98
+ if default_args is not None:
99
+ logging.info("Building optimizer + scaler for checkpoint restore...")
100
+ optimizer, scaler = build_optimizer(
101
+ inner_model.parameters(), # or _get_params() if required
102
+ self.config.TRAIN.optimizer
103
+ )
104
+
105
+ logging.info("Building scheduler for checkpoint restore...")
106
+ scheduler = build_scheduler(
107
+ optimizer,
108
+ self.config.TRAIN.scheduler,
109
+ default_args
110
+ )
111
+ else:
112
+ optimizer = scheduler = scaler = None
113
+
114
+ inner_model, optimizer, scheduler, scaler, epoch, checkpoint = load_checkpoint(
94
115
  model=inner_model,
95
116
  path=weights,
117
+ optimizer=optimizer,
118
+ scheduler=scheduler,
119
+ scaler=scaler,
96
120
  device=device,
97
121
  key_remap_fn=localization_remap,
98
122
  )
@@ -107,8 +131,24 @@ class LocalizationModel(BaseTaskModel):
107
131
  self.last_loaded_weights = weights
108
132
  self.best_checkpoint = weights
109
133
 
110
- if model_cfg is not None and original_multi_gpu is not None:
111
- model_cfg.multi_gpu = original_multi_gpu
134
+ best_epoch = checkpoint.get("best_epoch", 0)
135
+
136
+ best_criterion_valid = checkpoint.get(
137
+ "best_criterion_valid",
138
+ 0 if self.config.TRAIN.criterion_valid == "map" else float("inf")
139
+ )
140
+ self._resume_state = {
141
+ "optimizer": optimizer,
142
+ "scheduler": scheduler,
143
+ "scaler": scaler,
144
+ "epoch": epoch if epoch is not None else 0,
145
+ "best_epoch": best_epoch,
146
+ "best_criterion_valid": best_criterion_valid,
147
+ }
148
+
149
+ if not self.train_flag:
150
+ if model_cfg is not None and original_multi_gpu is not None:
151
+ model_cfg.multi_gpu = original_multi_gpu
112
152
 
113
153
  def train(
114
154
  self,
@@ -167,13 +207,6 @@ class LocalizationModel(BaseTaskModel):
167
207
 
168
208
  start = time.time()
169
209
 
170
- if effective_weights is not None:
171
- if self.model is None or self.last_loaded_weights != effective_weights:
172
- self.load_weights(weights=effective_weights)
173
- elif self.model is None:
174
- device = select_device(self.config.SYSTEM)
175
- self.model = build_model(self.config, device=device)
176
-
177
210
  data_obj_train = build_dataset(self.config, split="train")
178
211
  dataset_train = data_obj_train.building_dataset(
179
212
  cfg=data_obj_train.cfg,
@@ -200,11 +233,21 @@ class LocalizationModel(BaseTaskModel):
200
233
  dali=self.config.dali,
201
234
  )
202
235
 
236
+ default_args = get_default_args_trainer(self.config, len(train_loader))
237
+
238
+ self.train_flag = True # Set flag to indicate training mode for checkpoint loading
239
+ if effective_weights is not None:
240
+ if self.model is None or self.last_loaded_weights != effective_weights:
241
+ self.load_weights(weights=effective_weights, default_args=default_args)
242
+ elif self.model is None:
243
+ device = select_device(self.config.SYSTEM)
244
+ self.model = build_model(self.config, device=device)
245
+
203
246
  self.trainer = build_trainer(
204
247
  cfg=self.config,
205
248
  model=self.model,
206
- default_args=get_default_args_trainer(self.config, len(train_loader)),
207
- resume_from=effective_weights,
249
+ default_args=default_args,
250
+ resume_from=self._resume_state if hasattr(self, "_resume_state") else None,
208
251
  )
209
252
 
210
253
  logging.info("Start training")
@@ -1167,11 +1167,12 @@ class Trainer_Classification:
1167
1167
  from opensportslib.models.builder import build_model
1168
1168
  if self.model is None:
1169
1169
  self.model, _ = build_model(self.config, self.device)
1170
- self.model, optimizer, scheduler, epoch = load_checkpoint(
1170
+ self.model, optimizer, scheduler, scaler, epoch, checkpoint = load_checkpoint(
1171
1171
  self.model, path, optimizer, scheduler, device=self.device
1172
1172
  )
1173
1173
  self.optimizer = optimizer
1174
1174
  self.scheduler = scheduler
1175
+ self.scaler = scaler
1175
1176
  self.epoch = epoch
1176
1177
  logging.info(f"Model loaded from {path}, epoch: {epoch}")
1177
1178
  return self.model, self.optimizer, self.scheduler, self.epoch
@@ -29,7 +29,6 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
29
  """
30
30
  from opensportslib.metrics.localization_metric import *
31
31
  from opensportslib.core.optimizer.builder import build_optimizer
32
- from opensportslib.core.optimizer.builder import build_optimizer
33
32
  from opensportslib.core.scheduler.builder import build_scheduler
34
33
  from opensportslib.core.utils.config import store_json
35
34
  from opensportslib.datasets.builder import build_dataset
@@ -67,20 +66,10 @@ def build_trainer(cfg, model=None, default_args=None, resume_from=None):
67
66
 
68
67
  # Handle checkpoint loading
69
68
  if resume_from is not None:
70
- if not os.path.isfile(resume_from):
71
- raise ValueError(f"Checkpoint file not found: {resume_from}")
72
-
73
- logging.info(f"Loading checkpoint from: {resume_from}")
74
- checkpoint = torch.load(resume_from)
75
-
76
- # Load model state
77
- model.load(checkpoint['model_state_dict'])
78
- logging.info("Model state loaded successfully")
79
-
80
- # Get current training progress
81
- start_epoch = checkpoint['epoch'] + 1
82
- logging.info(f"Resuming from epoch {start_epoch}")
83
-
69
+ optimizer = resume_from["optimizer"]
70
+ scheduler = resume_from["scheduler"]
71
+ scaler = resume_from["scaler"]
72
+ start_epoch = resume_from["epoch"] + 1
84
73
  # Check if we've already reached target epochs
85
74
  if start_epoch >= cfg.TRAIN.num_epochs:
86
75
  logging.error(f"Model already trained for {start_epoch} epochs")
@@ -89,38 +78,18 @@ def build_trainer(cfg, model=None, default_args=None, resume_from=None):
89
78
  raise ValueError("Need to increase num_epochs to continue training")
90
79
 
91
80
  logging.info(f"Will continue training from epoch {start_epoch} to {cfg.TRAIN.num_epochs}")
92
-
93
- logging.info("Building optimizer...")
94
- optimizer, scaler = build_optimizer(model._get_params(), cfg.TRAIN.optimizer)
95
-
96
- # Load optimizer state if available in checkpoint
97
- if resume_from is not None and 'optimizer_state_dict' in checkpoint:
98
- try:
99
- optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
100
- scaler.load_state_dict(checkpoint['scaler_state_dict'])
101
- logging.info("Optimizer and scaler states loaded")
102
- except Exception as e:
103
- logging.warning(f"Could not load optimizer state: {e}")
104
- logging.warning("Will start with fresh optimizer state")
105
-
106
- logging.info("Building scheduler...")
107
- lr_scheduler = build_scheduler(optimizer, cfg.TRAIN.scheduler, default_args)
108
-
109
- # Load scheduler state if available
110
- if resume_from is not None and 'lr_state_dict' in checkpoint:
111
- try:
112
- lr_scheduler.load_state_dict(checkpoint['lr_state_dict'])
113
- logging.info("Scheduler state loaded")
114
- except Exception as e:
115
- logging.warning(f"Could not load scheduler state: {e}")
116
- logging.warning("Will start with fresh scheduler state")
81
+ else:
82
+ logging.info("Building optimizer...")
83
+ optimizer, scaler = build_optimizer(model._get_params(), cfg.TRAIN.optimizer)
84
+ logging.info("Building scheduler...")
85
+ scheduler = build_scheduler(optimizer, cfg.TRAIN.scheduler, default_args)
117
86
 
118
87
  trainer = Trainer_e2e(
119
88
  cfg,
120
89
  model,
121
90
  optimizer,
122
91
  scaler,
123
- lr_scheduler,
92
+ scheduler,
124
93
  default_args["work_dir"],
125
94
  default_args["dali"],
126
95
  default_args["repartitions"],
@@ -132,8 +101,8 @@ def build_trainer(cfg, model=None, default_args=None, resume_from=None):
132
101
 
133
102
  # Load training history if resuming
134
103
  if resume_from is not None:
135
- trainer.best_epoch = checkpoint.get('best_epoch', 0)
136
- trainer.best_criterion_valid = checkpoint.get('best_criterion_valid',
104
+ trainer.best_epoch = resume_from.get('best_epoch', 0)
105
+ trainer.best_criterion_valid = resume_from.get('best_criterion_valid',
137
106
  0 if cfg.TRAIN.criterion_valid == "map" else float("inf"))
138
107
  logging.info(f"Restored best epoch: {trainer.best_epoch}")
139
108
 
@@ -441,7 +410,7 @@ class Trainer_e2e(Trainer):
441
410
  best_checkpoint_path = os.path.join(
442
411
  self.save_dir, f"best_checkpoint.pt"
443
412
  )
444
- self.model._model, _, _, epoch = load_checkpoint(model=self.model._model,
413
+ self.model._model, _, _, _, epoch, _ = load_checkpoint(model=self.model._model,
445
414
  path=best_checkpoint_path,
446
415
  key_remap_fn=localization_remap)
447
416
  logging.info(f"Loaded best model from epoch {self.best_epoch}")
@@ -76,6 +76,7 @@ def load_checkpoint(
76
76
  path,
77
77
  optimizer=None,
78
78
  scheduler=None,
79
+ scaler=None,
79
80
  device=None,
80
81
  key_remap_fn=None,
81
82
  hf_filename="model.pth.tar", # required if loading from HF repo
@@ -164,7 +165,7 @@ def load_checkpoint(
164
165
  # --------------------------------------------------
165
166
  # Load checkpoint
166
167
  # --------------------------------------------------
167
- checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
168
+ checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
168
169
 
169
170
  # ---------------- MODEL STATE ----------------
170
171
  if isinstance(checkpoint, dict):
@@ -201,8 +202,24 @@ def load_checkpoint(
201
202
  for k, v in state_dict.items()
202
203
  }
203
204
 
204
- state_dict = strip_prefix(state_dict, "module.")
205
+ # state_dict = strip_prefix(state_dict, "module.")
206
+ # state_dict = strip_prefix(state_dict, "model.")
207
+
208
+ # First remove known wrappers (safe ones)
205
209
  state_dict = strip_prefix(state_dict, "model.")
210
+ state_dict = strip_prefix(state_dict, "_model.")
211
+
212
+ # Now handle module dynamically
213
+ model_keys = list(model.state_dict().keys())
214
+ ckpt_keys = list(state_dict.keys())
215
+
216
+ model_has_module = model_keys[0].startswith("module.")
217
+ ckpt_has_module = ckpt_keys[0].startswith("module.")
218
+
219
+ if model_has_module and not ckpt_has_module:
220
+ state_dict = {f"module.{k}": v for k, v in state_dict.items()}
221
+ elif not model_has_module and ckpt_has_module:
222
+ state_dict = {k.replace("module.", "", 1): v for k, v in state_dict.items()}
206
223
 
207
224
  # Optional custom remap
208
225
  if key_remap_fn:
@@ -229,15 +246,20 @@ def load_checkpoint(
229
246
 
230
247
  # ---------------- SCHEDULER ----------------
231
248
  if scheduler and isinstance(checkpoint, dict):
232
- sch_state = checkpoint.get("scheduler") or checkpoint.get("scheduler_state_dict")
249
+ sch_state = checkpoint.get("scheduler") or checkpoint.get("scheduler_state_dict") or checkpoint.get("lr_scheduler") # some use "lr_scheduler"
233
250
  if sch_state:
234
251
  scheduler.load_state_dict(sch_state)
235
252
 
253
+ if scaler and isinstance(checkpoint, dict):
254
+ scaler_state = checkpoint.get("scaler") or checkpoint.get("scaler_state_dict")
255
+ if scaler_state:
256
+ scaler.load_state_dict(scaler_state)
257
+
236
258
  print(f"[Checkpoint] Loaded from {ckpt_path} | epoch: {epoch}")
237
259
  print(f"Missing keys: {len(missing)}")
238
260
  print(f"Unexpected keys: {len(unexpected)}")
239
261
 
240
- return model, optimizer, scheduler, epoch
262
+ return model, optimizer, scheduler, scaler, epoch, checkpoint
241
263
 
242
264
 
243
265
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: opensportslib
3
- Version: 0.1.2.dev6
3
+ Version: 0.1.2.dev7
4
4
  Summary: OpenSportsLib is the professional library, designed for advanced video understanding in sports. It provides state-of-the-art tools for action recognition, spotting, retrieval, and captioning, making it ideal for researchers, analysts, and developers working with sports video data.
5
5
  Author: Jeet Vora
6
6
  Requires-Python: >=3.12
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "opensportslib"
7
- version = "0.1.2.dev6"
7
+ version = "0.1.2.dev7"
8
8
  description = "OpenSportsLib is the professional library, designed for advanced video understanding in sports. It provides state-of-the-art tools for action recognition, spotting, retrieval, and captioning, making it ideal for researchers, analysts, and developers working with sports video data."
9
9
  readme = "README.md"
10
10
  requires-python = ">=3.12"