opensportslib 0.1.2.dev5__tar.gz → 0.1.2.dev7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105)
  1. {opensportslib-0.1.2.dev5/opensportslib.egg-info → opensportslib-0.1.2.dev7}/PKG-INFO +1 -1
  2. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/apis/localization.py +59 -12
  3. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/trainer/classification_trainer.py +2 -1
  4. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/trainer/localization_trainer.py +13 -44
  5. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/utils/checkpoint.py +26 -4
  6. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/datasets/localization_dataset.py +28 -18
  7. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7/opensportslib.egg-info}/PKG-INFO +1 -1
  8. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/pyproject.toml +1 -1
  9. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/LICENSE +0 -0
  10. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/LICENSE-COMMERCIAL +0 -0
  11. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/MANIFEST.in +0 -0
  12. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/README.md +0 -0
  13. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/examples/quickstart/basic_classification.py +0 -0
  14. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/examples/quickstart/basic_localization.py +0 -0
  15. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/__init__.py +0 -0
  16. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/apis/__init__.py +0 -0
  17. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/apis/base_task_model.py +0 -0
  18. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/apis/classification.py +0 -0
  19. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/cli.py +0 -0
  20. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/config/classification.yaml +0 -0
  21. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/config/localization-e2e-ocv.yaml +0 -0
  22. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/config/localization-json_calf_resnetpca512.yaml +0 -0
  23. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/config/localization-json_netvlad++_resnetpca512.yaml +0 -0
  24. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/config/localization.yaml +0 -0
  25. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/config/sngar-frames.yaml +0 -0
  26. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/config/sngar-tracking.yaml +0 -0
  27. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/__init__.py +0 -0
  28. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/loss/__init__.py +0 -0
  29. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/loss/builder.py +0 -0
  30. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/loss/calf.py +0 -0
  31. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/loss/ce.py +0 -0
  32. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/loss/combine.py +0 -0
  33. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/loss/nll.py +0 -0
  34. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/optimizer/__init__.py +0 -0
  35. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/optimizer/builder.py +0 -0
  36. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/sampler/weighted_sampler.py +0 -0
  37. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/scheduler/__init__.py +0 -0
  38. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/scheduler/builder.py +0 -0
  39. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/trainer/__init__.py +0 -0
  40. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/utils/config.py +0 -0
  41. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/utils/data.py +0 -0
  42. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/utils/ddp.py +0 -0
  43. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/utils/default_args.py +0 -0
  44. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/utils/lightning.py +0 -0
  45. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/utils/load_annotations.py +0 -0
  46. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/utils/seed.py +0 -0
  47. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/utils/video_processing.py +0 -0
  48. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/core/utils/wandb.py +0 -0
  49. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/datasets/__init__.py +0 -0
  50. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/datasets/builder.py +0 -0
  51. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/datasets/classification_dataset.py +0 -0
  52. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/datasets/utils/__init__.py +0 -0
  53. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/datasets/utils/tracking.py +0 -0
  54. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/metrics/classification_metric.py +0 -0
  55. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/metrics/localization_metric.py +0 -0
  56. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/__init__.py +0 -0
  57. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/backbones/builder.py +0 -0
  58. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/base/contextaware.py +0 -0
  59. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/base/e2e.py +0 -0
  60. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/base/learnablepooling.py +0 -0
  61. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/base/tracking.py +0 -0
  62. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/base/vars.py +0 -0
  63. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/base/video.py +0 -0
  64. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/base/video_mae.py +0 -0
  65. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/builder.py +0 -0
  66. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/heads/builder.py +0 -0
  67. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/neck/builder.py +0 -0
  68. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/common.py +0 -0
  69. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/impl/__init__.py +0 -0
  70. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/impl/asformer.py +0 -0
  71. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/impl/calf.py +0 -0
  72. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/impl/gsm.py +0 -0
  73. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/impl/gtad.py +0 -0
  74. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/impl/tsm.py +0 -0
  75. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/litebase.py +0 -0
  76. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/modules.py +0 -0
  77. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/shift.py +0 -0
  78. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/models/utils/utils.py +0 -0
  79. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/setup/setup.py +0 -0
  80. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/tools/__init__.py +0 -0
  81. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/tools/_common.py +0 -0
  82. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/tools/hf_transfer.py +0 -0
  83. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/tools/osl_json_to_parquet.py +0 -0
  84. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib/tools/parquet_to_osl_json.py +0 -0
  85. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib.egg-info/SOURCES.txt +0 -0
  86. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib.egg-info/dependency_links.txt +0 -0
  87. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib.egg-info/entry_points.txt +0 -0
  88. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib.egg-info/requires.txt +0 -0
  89. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/opensportslib.egg-info/top_level.txt +0 -0
  90. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/setup.cfg +0 -0
  91. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/tests/conftest.py +0 -0
  92. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/tests/test_config_utils_smoke.py +0 -0
  93. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/tests/test_conversion_tools.py +0 -0
  94. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/tests/test_hf_transfer_tools.py +0 -0
  95. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/tests/test_package_smoke.py +0 -0
  96. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/tests/test_public_apis_smoke.py +0 -0
  97. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/tests/test_subset_train_infer_integration.py +0 -0
  98. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/tests/test_task_model_api_contract.py +0 -0
  99. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/tools/convert/osl_json_to_parquet_webdataset.py +0 -0
  100. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/tools/convert/parquet_webdataset_to_osl_json.py +0 -0
  101. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/tools/download/download_hf_repo.py +0 -0
  102. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/tools/download/download_osl_hf.py +0 -0
  103. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/tools/download/upload_osl_hf.py +0 -0
  104. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/tools/training/classification.py +0 -0
  105. {opensportslib-0.1.2.dev5 → opensportslib-0.1.2.dev7}/tools/training/localization.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: opensportslib
3
- Version: 0.1.2.dev5
3
+ Version: 0.1.2.dev7
4
4
  Summary: OpenSportsLib is the professional library, designed for advanced video understanding in sports. It provides state-of-the-art tools for action recognition, spotting, retrieval, and captioning, making it ideal for researchers, analysts, and developers working with sports video data.
5
5
  Author: Jeet Vora
6
6
  Requires-Python: >=3.12
@@ -15,6 +15,8 @@ class LocalizationModel(BaseTaskModel):
15
15
  self.last_loaded_weights = weights
16
16
  self.best_checkpoint = weights
17
17
 
18
+ self.train_flag = False # Flag to indicate whether we're in training mode (affects checkpoint loading behavior)
19
+
18
20
  def _resolve_split_path(self, split: str, override: str | None = None) -> str:
19
21
  if override is not None:
20
22
  return expand(override)
@@ -69,15 +71,18 @@ class LocalizationModel(BaseTaskModel):
69
71
  load_checkpoint,
70
72
  localization_remap,
71
73
  )
72
-
74
+ from opensportslib.core.optimizer.builder import build_optimizer
75
+ from opensportslib.core.scheduler.builder import build_scheduler
76
+ default_args = kwargs.get("default_args", None)
73
77
  del kwargs
74
78
  if weights is None:
75
79
  raise ValueError("`weights` must be provided to load_weights().")
76
80
 
77
81
  model_cfg = getattr(self.config, "MODEL", None)
78
- original_multi_gpu = getattr(model_cfg, "multi_gpu", None)
79
- if model_cfg is not None and original_multi_gpu is not None:
80
- model_cfg.multi_gpu = False
82
+ if not self.train_flag:
83
+ original_multi_gpu = getattr(model_cfg, "multi_gpu", None)
84
+ if model_cfg is not None and original_multi_gpu is not None:
85
+ model_cfg.multi_gpu = False
81
86
 
82
87
  device = select_device(self.config.SYSTEM)
83
88
  if self.model is None:
@@ -90,9 +95,28 @@ class LocalizationModel(BaseTaskModel):
90
95
  if is_local_path(weights):
91
96
  self.config.SYSTEM.work_dir = os.path.dirname(os.path.abspath(weights))
92
97
 
93
- inner_model, _, _, _ = load_checkpoint(
98
+ if default_args is not None:
99
+ logging.info("Building optimizer + scaler for checkpoint restore...")
100
+ optimizer, scaler = build_optimizer(
101
+ inner_model.parameters(), # or _get_params() if required
102
+ self.config.TRAIN.optimizer
103
+ )
104
+
105
+ logging.info("Building scheduler for checkpoint restore...")
106
+ scheduler = build_scheduler(
107
+ optimizer,
108
+ self.config.TRAIN.scheduler,
109
+ default_args
110
+ )
111
+ else:
112
+ optimizer = scheduler = scaler = None
113
+
114
+ inner_model, optimizer, scheduler, scaler, epoch, checkpoint = load_checkpoint(
94
115
  model=inner_model,
95
116
  path=weights,
117
+ optimizer=optimizer,
118
+ scheduler=scheduler,
119
+ scaler=scaler,
96
120
  device=device,
97
121
  key_remap_fn=localization_remap,
98
122
  )
@@ -107,8 +131,24 @@ class LocalizationModel(BaseTaskModel):
107
131
  self.last_loaded_weights = weights
108
132
  self.best_checkpoint = weights
109
133
 
110
- if model_cfg is not None and original_multi_gpu is not None:
111
- model_cfg.multi_gpu = original_multi_gpu
134
+ best_epoch = checkpoint.get("best_epoch", 0)
135
+
136
+ best_criterion_valid = checkpoint.get(
137
+ "best_criterion_valid",
138
+ 0 if self.config.TRAIN.criterion_valid == "map" else float("inf")
139
+ )
140
+ self._resume_state = {
141
+ "optimizer": optimizer,
142
+ "scheduler": scheduler,
143
+ "scaler": scaler,
144
+ "epoch": epoch if epoch is not None else 0,
145
+ "best_epoch": best_epoch,
146
+ "best_criterion_valid": best_criterion_valid,
147
+ }
148
+
149
+ if not self.train_flag:
150
+ if model_cfg is not None and original_multi_gpu is not None:
151
+ model_cfg.multi_gpu = original_multi_gpu
112
152
 
113
153
  def train(
114
154
  self,
@@ -167,9 +207,6 @@ class LocalizationModel(BaseTaskModel):
167
207
 
168
208
  start = time.time()
169
209
 
170
- device = select_device(self.config.SYSTEM)
171
- self.model = build_model(self.config, device=device)
172
-
173
210
  data_obj_train = build_dataset(self.config, split="train")
174
211
  dataset_train = data_obj_train.building_dataset(
175
212
  cfg=data_obj_train.cfg,
@@ -196,11 +233,21 @@ class LocalizationModel(BaseTaskModel):
196
233
  dali=self.config.dali,
197
234
  )
198
235
 
236
+ default_args = get_default_args_trainer(self.config, len(train_loader))
237
+
238
+ self.train_flag = True # Set flag to indicate training mode for checkpoint loading
239
+ if effective_weights is not None:
240
+ if self.model is None or self.last_loaded_weights != effective_weights:
241
+ self.load_weights(weights=effective_weights, default_args=default_args)
242
+ elif self.model is None:
243
+ device = select_device(self.config.SYSTEM)
244
+ self.model = build_model(self.config, device=device)
245
+
199
246
  self.trainer = build_trainer(
200
247
  cfg=self.config,
201
248
  model=self.model,
202
- default_args=get_default_args_trainer(self.config, len(train_loader)),
203
- resume_from=effective_weights,
249
+ default_args=default_args,
250
+ resume_from=self._resume_state if hasattr(self, "_resume_state") else None,
204
251
  )
205
252
 
206
253
  logging.info("Start training")
@@ -1167,11 +1167,12 @@ class Trainer_Classification:
1167
1167
  from opensportslib.models.builder import build_model
1168
1168
  if self.model is None:
1169
1169
  self.model, _ = build_model(self.config, self.device)
1170
- self.model, optimizer, scheduler, epoch = load_checkpoint(
1170
+ self.model, optimizer, scheduler, scaler, epoch, checkpoint = load_checkpoint(
1171
1171
  self.model, path, optimizer, scheduler, device=self.device
1172
1172
  )
1173
1173
  self.optimizer = optimizer
1174
1174
  self.scheduler = scheduler
1175
+ self.scaler = scaler
1175
1176
  self.epoch = epoch
1176
1177
  logging.info(f"Model loaded from {path}, epoch: {epoch}")
1177
1178
  return self.model, self.optimizer, self.scheduler, self.epoch
@@ -29,7 +29,6 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
29
  """
30
30
  from opensportslib.metrics.localization_metric import *
31
31
  from opensportslib.core.optimizer.builder import build_optimizer
32
- from opensportslib.core.optimizer.builder import build_optimizer
33
32
  from opensportslib.core.scheduler.builder import build_scheduler
34
33
  from opensportslib.core.utils.config import store_json
35
34
  from opensportslib.datasets.builder import build_dataset
@@ -67,20 +66,10 @@ def build_trainer(cfg, model=None, default_args=None, resume_from=None):
67
66
 
68
67
  # Handle checkpoint loading
69
68
  if resume_from is not None:
70
- if not os.path.isfile(resume_from):
71
- raise ValueError(f"Checkpoint file not found: {resume_from}")
72
-
73
- logging.info(f"Loading checkpoint from: {resume_from}")
74
- checkpoint = torch.load(resume_from)
75
-
76
- # Load model state
77
- model.load(checkpoint['model_state_dict'])
78
- logging.info("Model state loaded successfully")
79
-
80
- # Get current training progress
81
- start_epoch = checkpoint['epoch'] + 1
82
- logging.info(f"Resuming from epoch {start_epoch}")
83
-
69
+ optimizer = resume_from["optimizer"]
70
+ scheduler = resume_from["scheduler"]
71
+ scaler = resume_from["scaler"]
72
+ start_epoch = resume_from["epoch"] + 1
84
73
  # Check if we've already reached target epochs
85
74
  if start_epoch >= cfg.TRAIN.num_epochs:
86
75
  logging.error(f"Model already trained for {start_epoch} epochs")
@@ -89,38 +78,18 @@ def build_trainer(cfg, model=None, default_args=None, resume_from=None):
89
78
  raise ValueError("Need to increase num_epochs to continue training")
90
79
 
91
80
  logging.info(f"Will continue training from epoch {start_epoch} to {cfg.TRAIN.num_epochs}")
92
-
93
- logging.info("Building optimizer...")
94
- optimizer, scaler = build_optimizer(model._get_params(), cfg.TRAIN.optimizer)
95
-
96
- # Load optimizer state if available in checkpoint
97
- if resume_from is not None and 'optimizer_state_dict' in checkpoint:
98
- try:
99
- optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
100
- scaler.load_state_dict(checkpoint['scaler_state_dict'])
101
- logging.info("Optimizer and scaler states loaded")
102
- except Exception as e:
103
- logging.warning(f"Could not load optimizer state: {e}")
104
- logging.warning("Will start with fresh optimizer state")
105
-
106
- logging.info("Building scheduler...")
107
- lr_scheduler = build_scheduler(optimizer, cfg.TRAIN.scheduler, default_args)
108
-
109
- # Load scheduler state if available
110
- if resume_from is not None and 'lr_state_dict' in checkpoint:
111
- try:
112
- lr_scheduler.load_state_dict(checkpoint['lr_state_dict'])
113
- logging.info("Scheduler state loaded")
114
- except Exception as e:
115
- logging.warning(f"Could not load scheduler state: {e}")
116
- logging.warning("Will start with fresh scheduler state")
81
+ else:
82
+ logging.info("Building optimizer...")
83
+ optimizer, scaler = build_optimizer(model._get_params(), cfg.TRAIN.optimizer)
84
+ logging.info("Building scheduler...")
85
+ scheduler = build_scheduler(optimizer, cfg.TRAIN.scheduler, default_args)
117
86
 
118
87
  trainer = Trainer_e2e(
119
88
  cfg,
120
89
  model,
121
90
  optimizer,
122
91
  scaler,
123
- lr_scheduler,
92
+ scheduler,
124
93
  default_args["work_dir"],
125
94
  default_args["dali"],
126
95
  default_args["repartitions"],
@@ -132,8 +101,8 @@ def build_trainer(cfg, model=None, default_args=None, resume_from=None):
132
101
 
133
102
  # Load training history if resuming
134
103
  if resume_from is not None:
135
- trainer.best_epoch = checkpoint.get('best_epoch', 0)
136
- trainer.best_criterion_valid = checkpoint.get('best_criterion_valid',
104
+ trainer.best_epoch = resume_from.get('best_epoch', 0)
105
+ trainer.best_criterion_valid = resume_from.get('best_criterion_valid',
137
106
  0 if cfg.TRAIN.criterion_valid == "map" else float("inf"))
138
107
  logging.info(f"Restored best epoch: {trainer.best_epoch}")
139
108
 
@@ -441,7 +410,7 @@ class Trainer_e2e(Trainer):
441
410
  best_checkpoint_path = os.path.join(
442
411
  self.save_dir, f"best_checkpoint.pt"
443
412
  )
444
- self.model._model, _, _, epoch = load_checkpoint(model=self.model._model,
413
+ self.model._model, _, _, _, epoch, _ = load_checkpoint(model=self.model._model,
445
414
  path=best_checkpoint_path,
446
415
  key_remap_fn=localization_remap)
447
416
  logging.info(f"Loaded best model from epoch {self.best_epoch}")
@@ -76,6 +76,7 @@ def load_checkpoint(
76
76
  path,
77
77
  optimizer=None,
78
78
  scheduler=None,
79
+ scaler=None,
79
80
  device=None,
80
81
  key_remap_fn=None,
81
82
  hf_filename="model.pth.tar", # required if loading from HF repo
@@ -164,7 +165,7 @@ def load_checkpoint(
164
165
  # --------------------------------------------------
165
166
  # Load checkpoint
166
167
  # --------------------------------------------------
167
- checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
168
+ checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
168
169
 
169
170
  # ---------------- MODEL STATE ----------------
170
171
  if isinstance(checkpoint, dict):
@@ -201,8 +202,24 @@ def load_checkpoint(
201
202
  for k, v in state_dict.items()
202
203
  }
203
204
 
204
- state_dict = strip_prefix(state_dict, "module.")
205
+ # state_dict = strip_prefix(state_dict, "module.")
206
+ # state_dict = strip_prefix(state_dict, "model.")
207
+
208
+ # First remove known wrappers (safe ones)
205
209
  state_dict = strip_prefix(state_dict, "model.")
210
+ state_dict = strip_prefix(state_dict, "_model.")
211
+
212
+ # Now handle module dynamically
213
+ model_keys = list(model.state_dict().keys())
214
+ ckpt_keys = list(state_dict.keys())
215
+
216
+ model_has_module = model_keys[0].startswith("module.")
217
+ ckpt_has_module = ckpt_keys[0].startswith("module.")
218
+
219
+ if model_has_module and not ckpt_has_module:
220
+ state_dict = {f"module.{k}": v for k, v in state_dict.items()}
221
+ elif not model_has_module and ckpt_has_module:
222
+ state_dict = {k.replace("module.", "", 1): v for k, v in state_dict.items()}
206
223
 
207
224
  # Optional custom remap
208
225
  if key_remap_fn:
@@ -229,15 +246,20 @@ def load_checkpoint(
229
246
 
230
247
  # ---------------- SCHEDULER ----------------
231
248
  if scheduler and isinstance(checkpoint, dict):
232
- sch_state = checkpoint.get("scheduler") or checkpoint.get("scheduler_state_dict")
249
+ sch_state = checkpoint.get("scheduler") or checkpoint.get("scheduler_state_dict") or checkpoint.get("lr_scheduler") # some use "lr_scheduler"
233
250
  if sch_state:
234
251
  scheduler.load_state_dict(sch_state)
235
252
 
253
+ if scaler and isinstance(checkpoint, dict):
254
+ scaler_state = checkpoint.get("scaler") or checkpoint.get("scaler_state_dict")
255
+ if scaler_state:
256
+ scaler.load_state_dict(scaler_state)
257
+
236
258
  print(f"[Checkpoint] Loaded from {ckpt_path} | epoch: {epoch}")
237
259
  print(f"Missing keys: {len(missing)}")
238
260
  print(f"Unexpected keys: {len(unexpected)}")
239
261
 
240
- return model, optimizer, scheduler, epoch
262
+ return model, optimizer, scheduler, scaler, epoch, checkpoint
241
263
 
242
264
 
243
265
 
@@ -1016,7 +1016,10 @@ if DALI_AVAILABLE:
1016
1016
  for pipe in self.pipes:
1017
1017
  pipe.build()
1018
1018
 
1019
- super().__init__(self.pipes, output_map, size=self.nb_videos)
1019
+ # Pipeline returns (video, label_idx, frame_num) - label processing
1020
+ # is done post-hoc in get_attr to avoid DALI 2.0 fn.python_function issues
1021
+ internal_output_map = ['data', 'label_idx', 'frame_num']
1022
+ super().__init__(self.pipes, internal_output_map, size=self.nb_videos)
1020
1023
 
1021
1024
  self.device = torch.device(
1022
1025
  "cuda:{}".format(self.devices[1 if len(self.devices) > 1 else 0])
@@ -1052,8 +1055,19 @@ if DALI_AVAILABLE:
1052
1055
  Returns:
1053
1056
  dict :{"frames","contains_event","labels"}.
1054
1057
  """
1055
- batch_labels = batch["label"]
1058
+ batch_label_idx = batch["label_idx"]
1059
+ batch_frame_num = batch["frame_num"]
1056
1060
  batch_images = batch["data"]
1061
+
1062
+ batch_size = batch_label_idx.shape[0]
1063
+ batch_labels = torch.zeros(batch_size, self.clip_len, dtype=torch.int64)
1064
+ for b in range(batch_size):
1065
+ video_idx = int(batch_label_idx[b].item())
1066
+ frame_num = int(batch_frame_num[b].item())
1067
+ batch_labels[b] = torch.from_numpy(
1068
+ self._compute_labels(video_idx, frame_num)
1069
+ )
1070
+
1057
1071
  sum_labels = torch.sum(
1058
1072
  batch_labels, dim=1 if len(batch_labels.shape) == 2 else 0
1059
1073
  )
@@ -1229,26 +1243,22 @@ if DALI_AVAILABLE:
1229
1243
  std=[255, 255, 255],
1230
1244
  mirror=fn.random.coin_flip(),
1231
1245
  )
1232
- label = fn.python_function(
1233
- label, frame_num, function=self.edit_labels, device="gpu"
1234
- )
1235
- return video, label
1246
+ return video, label, frame_num
1236
1247
 
1237
- def edit_labels(self, label, frame_num):
1238
- """Construct a list having the same length as the number of frames. The elements of the list are the indexes (starting at 1) of the class where an event occurs, 0 otherwise.
1248
+ def _compute_labels(self, video_idx, frame_num):
1249
+ """Construct a label array for a clip. Each element is the class index
1250
+ (starting at 1) where an event occurs, 0 otherwise.
1239
1251
 
1240
1252
  Args:
1241
- label :index of the video to get the metadata.
1242
- frame_num :index of start frame.
1253
+ video_idx (int): Index of the video in self._labels.
1254
+ frame_num (int): Raw start frame number from the reader.
1243
1255
 
1244
1256
  Returns:
1245
- labels (cupy.array): the list of labels (corresponding to events) corresponding with the extracted frames.
1257
+ labels (np.ndarray): Label array of shape (clip_len,).
1246
1258
  """
1247
- import cupy
1248
-
1249
- video_meta = self._labels[label.item()]
1250
- base_idx = frame_num.item() // self._stride
1251
- labels = cupy.zeros(self.clip_len, np.int64)
1259
+ video_meta = self._labels[video_idx]
1260
+ base_idx = frame_num // self._stride
1261
+ labels = np.zeros(self.clip_len, np.int64)
1252
1262
 
1253
1263
  for event in video_meta["events"]:
1254
1264
  event_frame = event["frame"]
@@ -1258,12 +1268,12 @@ if DALI_AVAILABLE:
1258
1268
  label_idx >= self.dilate_len
1259
1269
  and label_idx < self.clip_len + self.dilate_len
1260
1270
  ):
1261
- label = self._class_dict[event["label"]]
1271
+ label_val = self._class_dict[event["label"]]
1262
1272
  for i in range(
1263
1273
  max(0, label_idx - self.dilate_len),
1264
1274
  min(self.clip_len, label_idx + self.dilate_len + 1),
1265
1275
  ):
1266
- labels[i] = label
1276
+ labels[i] = label_val
1267
1277
  return labels
1268
1278
 
1269
1279
  def print_info(self):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: opensportslib
3
- Version: 0.1.2.dev5
3
+ Version: 0.1.2.dev7
4
4
  Summary: OpenSportsLib is the professional library, designed for advanced video understanding in sports. It provides state-of-the-art tools for action recognition, spotting, retrieval, and captioning, making it ideal for researchers, analysts, and developers working with sports video data.
5
5
  Author: Jeet Vora
6
6
  Requires-Python: >=3.12
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "opensportslib"
7
- version = "0.1.2.dev5"
7
+ version = "0.1.2.dev7"
8
8
  description = "OpenSportsLib is the professional library, designed for advanced video understanding in sports. It provides state-of-the-art tools for action recognition, spotting, retrieval, and captioning, making it ideal for researchers, analysts, and developers working with sports video data."
9
9
  readme = "README.md"
10
10
  requires-python = ">=3.12"