opensportslib 0.1.3.dev3__tar.gz → 0.1.3.dev4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (145) hide show
  1. {opensportslib-0.1.3.dev3/opensportslib.egg-info → opensportslib-0.1.3.dev4}/PKG-INFO +4 -3
  2. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/README.md +3 -2
  3. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/apis/base_task_model.py +18 -9
  4. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/apis/classification.py +63 -43
  5. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/apis/localization.py +61 -61
  6. opensportslib-0.1.3.dev4/opensportslib/configs/captioning/encoder_decoder/video_captioning.yaml +136 -0
  7. opensportslib-0.1.3.dev4/opensportslib/configs/captioning/llava/llava_style.yaml +148 -0
  8. opensportslib-0.1.3.dev4/opensportslib/configs/classification/default.yaml +125 -0
  9. opensportslib-0.1.3.dev4/opensportslib/configs/classification/frames_npy/sngar-frames.yaml +171 -0
  10. opensportslib-0.1.3.dev4/opensportslib/configs/classification/tracking/sngar-tracking.yaml +170 -0
  11. opensportslib-0.1.3.dev4/opensportslib/configs/classification/video/classification.yaml +178 -0
  12. opensportslib-0.1.3.dev4/opensportslib/configs/localization/default.yaml +134 -0
  13. opensportslib-0.1.3.dev4/opensportslib/configs/localization/video/localization-dali.yaml +180 -0
  14. opensportslib-0.1.3.dev4/opensportslib/configs/localization/video/localization-ocv.yaml +189 -0
  15. opensportslib-0.1.3.dev4/opensportslib/configs/localization/video_features/localization-calf-resnetpca512.yaml +201 -0
  16. opensportslib-0.1.3.dev4/opensportslib/configs/localization/video_features/localization-netvladpp-resnetpca512.yaml +192 -0
  17. opensportslib-0.1.3.dev4/opensportslib/configs/reasoning/multimodal/video_text_fusion.yaml +161 -0
  18. opensportslib-0.1.3.dev4/opensportslib/configs/retrieval/two_tower/video_text_retrieval.yaml +150 -0
  19. opensportslib-0.1.3.dev4/opensportslib/core/config/__init__.py +22 -0
  20. opensportslib-0.1.3.dev4/opensportslib/core/config/accessors.py +374 -0
  21. opensportslib-0.1.3.dev4/opensportslib/core/config/conflicts.py +130 -0
  22. opensportslib-0.1.3.dev4/opensportslib/core/config/loader.py +77 -0
  23. opensportslib-0.1.3.dev4/opensportslib/core/config/migrate.py +22 -0
  24. opensportslib-0.1.3.dev4/opensportslib/core/config/migrations/__init__.py +5 -0
  25. opensportslib-0.1.3.dev4/opensportslib/core/config/migrations/legacy_to_canonical.py +614 -0
  26. opensportslib-0.1.3.dev4/opensportslib/core/config/runtime_adapter.py +60 -0
  27. opensportslib-0.1.3.dev4/opensportslib/core/config/schema.py +61 -0
  28. opensportslib-0.1.3.dev4/opensportslib/core/config/schemas/__init__.py +6 -0
  29. opensportslib-0.1.3.dev4/opensportslib/core/config/schemas/schema_canonical.py +22 -0
  30. opensportslib-0.1.3.dev4/opensportslib/core/config/schemas/schema_legacy.py +17 -0
  31. opensportslib-0.1.3.dev4/opensportslib/core/config/validate.py +89 -0
  32. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/core/loss/builder.py +6 -2
  33. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/core/optimizer/builder.py +4 -0
  34. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/core/sampler/weighted_sampler.py +8 -5
  35. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/core/scheduler/builder.py +3 -0
  36. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/core/trainer/classification_trainer.py +138 -82
  37. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/core/trainer/localization_trainer.py +81 -49
  38. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/core/utils/config.py +51 -42
  39. opensportslib-0.1.3.dev4/opensportslib/core/utils/config_normalize.py +29 -0
  40. opensportslib-0.1.3.dev4/opensportslib/core/utils/default_args.py +143 -0
  41. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/core/utils/load_annotations.py +52 -30
  42. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/core/utils/video_processing.py +10 -4
  43. opensportslib-0.1.3.dev4/opensportslib/core/utils/wandb.py +217 -0
  44. opensportslib-0.1.3.dev4/opensportslib/datasets/builder.py +16 -0
  45. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/datasets/classification_dataset.py +140 -44
  46. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/datasets/localization_dataset.py +243 -103
  47. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/datasets/utils/tracking.py +18 -12
  48. {opensportslib-0.1.3.dev3/opensportslib/config → opensportslib-0.1.3.dev4/opensportslib/legacy_config}/classification.yaml +4 -4
  49. {opensportslib-0.1.3.dev3/opensportslib/config → opensportslib-0.1.3.dev4/opensportslib/legacy_config}/localization-e2e-ocv.yaml +1 -1
  50. {opensportslib-0.1.3.dev3/opensportslib/config → opensportslib-0.1.3.dev4/opensportslib/legacy_config}/localization-json_calf_resnetpca512.yaml +1 -1
  51. {opensportslib-0.1.3.dev3/opensportslib/config → opensportslib-0.1.3.dev4/opensportslib/legacy_config}/localization-json_netvlad++_resnetpca512.yaml +1 -1
  52. {opensportslib-0.1.3.dev3/opensportslib/config → opensportslib-0.1.3.dev4/opensportslib/legacy_config}/localization.yaml +5 -6
  53. {opensportslib-0.1.3.dev3/opensportslib/config → opensportslib-0.1.3.dev4/opensportslib/legacy_config}/sngar-frames.yaml +4 -4
  54. {opensportslib-0.1.3.dev3/opensportslib/config → opensportslib-0.1.3.dev4/opensportslib/legacy_config}/sngar-tracking.yaml +5 -5
  55. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/metrics/localization_metric.py +15 -0
  56. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/models/backbones/builder.py +1 -1
  57. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/models/base/contextaware.py +36 -17
  58. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/models/base/learnablepooling.py +44 -22
  59. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/models/base/tracking.py +24 -9
  60. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/models/base/vars.py +7 -2
  61. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/models/base/video.py +41 -11
  62. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/models/base/video_mae.py +20 -5
  63. opensportslib-0.1.3.dev4/opensportslib/models/builder.py +145 -0
  64. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4/opensportslib.egg-info}/PKG-INFO +4 -3
  65. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib.egg-info/SOURCES.txt +37 -7
  66. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/pyproject.toml +2 -2
  67. opensportslib-0.1.3.dev4/tests/conftest.py +581 -0
  68. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/tests/test_classification_dataset_paths.py +32 -12
  69. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/tests/test_classification_trainer_dataloader.py +47 -11
  70. opensportslib-0.1.3.dev4/tests/test_config_architecture.py +104 -0
  71. opensportslib-0.1.3.dev4/tests/test_config_split_override_sync.py +37 -0
  72. opensportslib-0.1.3.dev4/tests/test_pretrained_config_merge_policy.py +223 -0
  73. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/tests/test_public_apis_smoke.py +4 -4
  74. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/tests/test_task_model_api_contract.py +35 -9
  75. opensportslib-0.1.3.dev3/opensportslib/core/utils/default_args.py +0 -110
  76. opensportslib-0.1.3.dev3/opensportslib/core/utils/wandb.py +0 -280
  77. opensportslib-0.1.3.dev3/opensportslib/datasets/builder.py +0 -42
  78. opensportslib-0.1.3.dev3/opensportslib/models/builder.py +0 -66
  79. opensportslib-0.1.3.dev3/tests/conftest.py +0 -377
  80. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/LICENSE +0 -0
  81. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/LICENSE-COMMERCIAL +0 -0
  82. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/MANIFEST.in +0 -0
  83. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/examples/quickstart/basic_classification.py +0 -0
  84. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/examples/quickstart/basic_localization.py +0 -0
  85. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/__init__.py +0 -0
  86. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/apis/__init__.py +0 -0
  87. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/cli.py +0 -0
  88. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/core/__init__.py +0 -0
  89. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/core/loss/__init__.py +0 -0
  90. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/core/loss/calf.py +0 -0
  91. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/core/loss/ce.py +0 -0
  92. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/core/loss/combine.py +0 -0
  93. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/core/loss/nll.py +0 -0
  94. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/core/optimizer/__init__.py +0 -0
  95. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/core/scheduler/__init__.py +0 -0
  96. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/core/trainer/__init__.py +0 -0
  97. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/core/utils/checkpoint.py +0 -0
  98. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/core/utils/data.py +0 -0
  99. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/core/utils/ddp.py +0 -0
  100. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/core/utils/lightning.py +0 -0
  101. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/core/utils/seed.py +0 -0
  102. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/datasets/__init__.py +0 -0
  103. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/datasets/utils/__init__.py +0 -0
  104. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/metrics/classification_metric.py +0 -0
  105. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/models/__init__.py +0 -0
  106. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/models/base/e2e.py +0 -0
  107. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/models/heads/builder.py +0 -0
  108. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/models/neck/builder.py +0 -0
  109. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/models/utils/common.py +0 -0
  110. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/models/utils/impl/__init__.py +0 -0
  111. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/models/utils/impl/asformer.py +0 -0
  112. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/models/utils/impl/calf.py +0 -0
  113. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/models/utils/impl/gsm.py +0 -0
  114. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/models/utils/impl/gtad.py +0 -0
  115. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/models/utils/impl/tsm.py +0 -0
  116. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/models/utils/litebase.py +0 -0
  117. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/models/utils/modules.py +0 -0
  118. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/models/utils/shift.py +0 -0
  119. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/models/utils/utils.py +0 -0
  120. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/setup/setup.py +0 -0
  121. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/tools/__init__.py +0 -0
  122. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/tools/_common.py +0 -0
  123. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/tools/hf_transfer.py +0 -0
  124. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/tools/osl_json_to_parquet.py +0 -0
  125. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/tools/parquet_to_osl_json.py +0 -0
  126. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib.egg-info/dependency_links.txt +0 -0
  127. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib.egg-info/entry_points.txt +0 -0
  128. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib.egg-info/requires.txt +0 -0
  129. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib.egg-info/top_level.txt +0 -0
  130. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/setup.cfg +0 -0
  131. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/tests/test_config_utils_smoke.py +0 -0
  132. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/tests/test_conversion_tools.py +0 -0
  133. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/tests/test_hf_transfer_tools.py +0 -0
  134. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/tests/test_localization_dali_filenames.py +0 -0
  135. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/tests/test_package_smoke.py +0 -0
  136. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/tests/test_subset_train_infer_integration.py +0 -0
  137. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/tools/convert/build_soccernet_gar.py +0 -0
  138. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/tools/convert/build_soccernet_gar_action_spotting.py +0 -0
  139. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/tools/convert/osl_json_to_parquet_webdataset.py +0 -0
  140. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/tools/convert/parquet_webdataset_to_osl_json.py +0 -0
  141. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/tools/download/download_hf_repo.py +0 -0
  142. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/tools/download/download_osl_hf.py +0 -0
  143. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/tools/download/upload_osl_hf.py +0 -0
  144. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/tools/training/classification.py +0 -0
  145. {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/tools/training/localization.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: opensportslib
3
- Version: 0.1.3.dev3
3
+ Version: 0.1.3.dev4
4
4
  Summary: OpenSportsLib is the professional library, designed for advanced video understanding in sports. It provides state-of-the-art tools for action recognition, spotting, retrieval, and captioning, making it ideal for researchers, analysts, and developers working with sports video data.
5
5
  Author: Jeet Vora
6
6
  Requires-Python: >=3.12
@@ -34,6 +34,7 @@ Requires-Dist: pytest-cov; extra == "test"
34
34
  Dynamic: license-file
35
35
 
36
36
  # OpenSportsLib
37
+ <img src="docs/assets/osl.jpg" height="400">
37
38
 
38
39
  OpenSportsLib is a modular Python library for sports video understanding.
39
40
 
@@ -188,7 +189,7 @@ Minimal localization sample:
188
189
  ```
189
190
 
190
191
  Relative paths in `inputs[].path` are resolved from the split media root in the
191
- YAML config, for example `DATA.train.video_path`. See the full
192
+ YAML config, for example `DATA.common.splits.train.source_path`. See the full
192
193
  [OSL JSON format guide](docs/data/osl-json-format.md) for field definitions,
193
194
  multi-modal examples, prediction payloads, and conversion notes.
194
195
 
@@ -345,7 +346,7 @@ Use the README for the fast start, then go deeper through:
345
346
  - Full documentation: https://opensportslab.github.io/opensportslib/
346
347
  - OSL JSON format: [docs/data/osl-json-format.md](docs/data/osl-json-format.md)
347
348
  - High-level API guide: [opensportslib/apis/README.md](opensportslib/apis/README.md)
348
- - Configuration guide: https://opensportslab.github.io/opensportslib/tni/config-guide/
349
+ - Configuration guide: https://opensportslab.github.io/opensportslib/config/configuration-guide/
349
350
  - Example configs: [examples/configs/](examples/configs/)
350
351
  - Quickstart scripts: [examples/quickstart/](examples/quickstart/)
351
352
  - Contribution guide: [CONTRIBUTING.md](CONTRIBUTING.md)
@@ -1,4 +1,5 @@
1
1
  # OpenSportsLib
2
+ <img src="docs/assets/osl.jpg" height="400">
2
3
 
3
4
  OpenSportsLib is a modular Python library for sports video understanding.
4
5
 
@@ -153,7 +154,7 @@ Minimal localization sample:
153
154
  ```
154
155
 
155
156
  Relative paths in `inputs[].path` are resolved from the split media root in the
156
- YAML config, for example `DATA.train.video_path`. See the full
157
+ YAML config, for example `DATA.common.splits.train.source_path`. See the full
157
158
  [OSL JSON format guide](docs/data/osl-json-format.md) for field definitions,
158
159
  multi-modal examples, prediction payloads, and conversion notes.
159
160
 
@@ -310,7 +311,7 @@ Use the README for the fast start, then go deeper through:
310
311
  - Full documentation: https://opensportslab.github.io/opensportslib/
311
312
  - OSL JSON format: [docs/data/osl-json-format.md](docs/data/osl-json-format.md)
312
313
  - High-level API guide: [opensportslib/apis/README.md](opensportslib/apis/README.md)
313
- - Configuration guide: https://opensportslab.github.io/opensportslib/tni/config-guide/
314
+ - Configuration guide: https://opensportslab.github.io/opensportslib/config/configuration-guide/
314
315
  - Example configs: [examples/configs/](examples/configs/)
315
316
  - Quickstart scripts: [examples/quickstart/](examples/quickstart/)
316
317
  - Contribution guide: [CONTRIBUTING.md](CONTRIBUTING.md)
@@ -9,6 +9,7 @@ import uuid
9
9
  from abc import ABC, abstractmethod
10
10
  from typing import Any
11
11
 
12
+ from opensportslib.core.config.accessors import get_component_name_by_kind
12
13
  from opensportslib.core.utils.config import expand, load_config_omega, fetch_and_merge_config_from_HF
13
14
 
14
15
 
@@ -23,6 +24,8 @@ class BaseTaskModel(ABC):
23
24
 
24
25
  self.config_path = expand(config)
25
26
  self.config = load_config_omega(self.config_path)
27
+ self.last_loaded_weights = None
28
+ self.best_checkpoint = None
26
29
 
27
30
  if weights is not None:
28
31
  self.config = fetch_and_merge_config_from_HF(self.config, weights, merge_policy="compatibility")
@@ -41,15 +44,23 @@ class BaseTaskModel(ABC):
41
44
 
42
45
  system_cfg = getattr(self.config, "SYSTEM", None)
43
46
  if system_cfg is not None:
44
- base_save_dir = expand(getattr(system_cfg, "save_dir", None) or "./checkpoints")
45
- model_cfg = getattr(self.config, "MODEL", None)
46
- backbone_cfg = getattr(model_cfg, "backbone", None)
47
- model_name = getattr(backbone_cfg, "type", None) or "model"
47
+ system_paths = getattr(system_cfg, "paths", None)
48
+ base_save_dir = expand(
49
+ getattr(system_paths, "save_dir", None)
50
+ or getattr(system_cfg, "save_dir", None)
51
+ or "./checkpoints"
52
+ )
53
+ model_name = get_component_name_by_kind(self.config, "encoder") or "model"
48
54
  run_save_dir = os.path.join(base_save_dir, model_name, self.run_id)
49
55
  self.save_dir = run_save_dir
50
- system_cfg.save_dir = run_save_dir
51
- if hasattr(system_cfg, "work_dir"):
52
- system_cfg.work_dir = run_save_dir
56
+ if system_paths is not None:
57
+ system_paths.save_dir = run_save_dir
58
+ if hasattr(system_paths, "work_dir"):
59
+ system_paths.work_dir = run_save_dir
60
+ else:
61
+ system_cfg.save_dir = run_save_dir
62
+ if hasattr(system_cfg, "work_dir"):
63
+ system_cfg.work_dir = run_save_dir
53
64
  os.makedirs(run_save_dir, exist_ok=True)
54
65
  else:
55
66
  self.save_dir = expand("./checkpoints")
@@ -60,8 +71,6 @@ class BaseTaskModel(ABC):
60
71
  self.model = None
61
72
  self.processor = None
62
73
  self.trainer = None
63
- self.best_checkpoint = None
64
- self.last_loaded_weights = None
65
74
 
66
75
  if weights is not None:
67
76
  self.load_weights(weights=weights)
@@ -4,8 +4,17 @@
4
4
 
5
5
  import logging
6
6
  import os
7
+ import json
7
8
 
8
9
  from opensportslib.apis.base_task_model import BaseTaskModel
10
+ from opensportslib.core.config.accessors import (
11
+ get_component_provider_by_kind,
12
+ get_data_modality,
13
+ get_split_annotation_path,
14
+ get_system_gpu_count,
15
+ get_system_seed,
16
+ get_system_use_seed,
17
+ )
9
18
  from opensportslib.core.utils.config import expand
10
19
 
11
20
  class ClassificationModel(BaseTaskModel):
@@ -15,24 +24,13 @@ class ClassificationModel(BaseTaskModel):
15
24
  if override is not None:
16
25
  return expand(override)
17
26
 
18
- data_cfg = getattr(self.config, "DATA", None)
19
- split_cfg = getattr(data_cfg, split, None)
20
- path = getattr(split_cfg, "path", None) if split_cfg is not None else None
21
- if path:
22
- return expand(path)
23
-
24
- annotations_cfg = getattr(data_cfg, "annotations", None)
25
- path = (
26
- getattr(annotations_cfg, split, None)
27
- if annotations_cfg is not None
28
- else None
29
- )
27
+ path = get_split_annotation_path(self.config, split)
30
28
  if path:
31
29
  return expand(path)
32
30
 
33
31
  raise ValueError(
34
32
  f"Could not resolve path for split '{split}'. "
35
- f"Expected DATA.{split}.path or DATA.annotations.{split}."
33
+ f"Expected DATA.common.splits.{split}.annotation_path."
36
34
  )
37
35
 
38
36
  # -----------------------------------------------------------------
@@ -54,6 +52,7 @@ class ClassificationModel(BaseTaskModel):
54
52
  ):
55
53
  """Execute one training/inference job on a single process."""
56
54
  import torch
55
+ import wandb
57
56
  from opensportslib.core.trainer.classification_trainer import Trainer_Classification
58
57
  from opensportslib.core.utils.ddp import ddp_cleanup, ddp_setup
59
58
  from opensportslib.core.utils.wandb import init_wandb
@@ -77,8 +76,8 @@ class ClassificationModel(BaseTaskModel):
77
76
  use_wandb=use_wandb,
78
77
  )
79
78
 
80
- if getattr(config.SYSTEM, "use_seed", False):
81
- set_reproducibility(config.SYSTEM.seed)
79
+ if get_system_use_seed(config):
80
+ set_reproducibility(get_system_seed(config))
82
81
 
83
82
  is_ddp = world_size > 1
84
83
  if is_ddp:
@@ -100,32 +99,50 @@ class ClassificationModel(BaseTaskModel):
100
99
 
101
100
  trainer.model = model
102
101
 
103
- if mode == "train":
104
- train_data = build_dataset(config, train_set, processor, split="train")
105
- valid_data = build_dataset(config, valid_set, processor, split="valid")
106
- best_ckpt = trainer.train(
107
- model,
108
- train_data,
109
- valid_data,
110
- rank=rank,
111
- world_size=world_size,
112
- )
113
- if rank == 0 and return_queue is not None:
114
- best_ckpt = best_ckpt or getattr(trainer.trainer, "best_checkpoint_path", None)
115
- return_queue.put(best_ckpt)
116
-
117
- elif mode == "infer":
118
- test_data = build_dataset(config, test_set, processor, split="test")
119
- predictions = trainer.infer(
120
- test_data,
121
- rank=rank,
122
- world_size=world_size,
123
- )
124
- if rank == 0 and return_queue is not None:
125
- return_queue.put(predictions)
102
+ modality = get_data_modality(config)
103
+ use_tracking_collate = modality in {"tracking", "tracking_parquet"}
104
+ logging.info(
105
+ "Worker setup | mode=%s | modality=%s | tracking_collate=%s",
106
+ mode,
107
+ modality,
108
+ use_tracking_collate,
109
+ )
126
110
 
127
- if is_ddp:
128
- ddp_cleanup()
111
+ try:
112
+ if mode == "train":
113
+ train_data = build_dataset(config, train_set, processor, split="train")
114
+ valid_data = build_dataset(config, valid_set, processor, split="valid")
115
+ best_ckpt = trainer.train(
116
+ model,
117
+ train_data,
118
+ valid_data,
119
+ rank=rank,
120
+ world_size=world_size,
121
+ )
122
+ if rank == 0 and return_queue is not None:
123
+ best_ckpt = best_ckpt or getattr(trainer.trainer, "best_checkpoint_path", None)
124
+ return_queue.put(best_ckpt)
125
+
126
+ elif mode == "infer":
127
+ test_data = build_dataset(config, test_set, processor, split="test")
128
+ _ = trainer.infer(
129
+ test_data,
130
+ rank=rank,
131
+ world_size=world_size,
132
+ )
133
+ if rank == 0 and return_queue is not None:
134
+ if world_size > 1:
135
+ return_queue.put(getattr(trainer, "predictions_path", None))
136
+ else:
137
+ return_queue.put(getattr(trainer, "predictions_payload", None))
138
+ finally:
139
+ if rank == 0 and use_wandb and getattr(wandb, "run", None) is not None:
140
+ try:
141
+ wandb.finish(quiet=True)
142
+ except Exception:
143
+ pass
144
+ if is_ddp:
145
+ ddp_cleanup()
129
146
 
130
147
  def load_weights(
131
148
  self,
@@ -142,7 +159,7 @@ class ClassificationModel(BaseTaskModel):
142
159
  loaded = self.trainer.load(weights)
143
160
  self.model = loaded[0]
144
161
 
145
- if getattr(self.config.MODEL, "type", "custom") == "huggingface":
162
+ if get_component_provider_by_kind(self.config, "encoder") == "huggingface":
146
163
  self.processor = loaded[1]
147
164
 
148
165
  self.last_loaded_weights = weights
@@ -180,11 +197,11 @@ class ClassificationModel(BaseTaskModel):
180
197
 
181
198
  effective_weights = weights if weights is not None else self.last_loaded_weights
182
199
 
183
- world_size = torch.cuda.device_count() or self.config.SYSTEM.GPU
200
+ world_size = torch.cuda.device_count() or get_system_gpu_count(self.config)
184
201
  use_ddp = use_ddp and world_size > 1
185
202
 
186
203
  ctx = mp.get_context("spawn")
187
- queue = ctx.Queue()
204
+ queue = ctx.SimpleQueue()
188
205
 
189
206
  if use_ddp:
190
207
  logging.info(f"Launching DDP on {world_size} GPUs")
@@ -283,6 +300,9 @@ class ClassificationModel(BaseTaskModel):
283
300
  )
284
301
 
285
302
  predictions = queue.get()
303
+ if use_ddp and isinstance(predictions, str):
304
+ with open(predictions, encoding="utf-8") as f:
305
+ predictions = json.load(f)
286
306
  return predictions
287
307
 
288
308
  def evaluate(
@@ -3,6 +3,18 @@ import os
3
3
  import time
4
4
 
5
5
  from opensportslib.apis.base_task_model import BaseTaskModel
6
+ from opensportslib.core.config.accessors import (
7
+ get_data_classes,
8
+ get_loader_backend,
9
+ get_system_gpu_count,
10
+ get_system_seed,
11
+ set_system_path,
12
+ get_train_trainer_type,
13
+ get_train_execution,
14
+ get_split_annotation_path,
15
+ get_split_cfg,
16
+ set_split_annotation_path,
17
+ )
6
18
  from opensportslib.core.utils.config import expand
7
19
 
8
20
 
@@ -21,44 +33,33 @@ class LocalizationModel(BaseTaskModel):
21
33
  if override is not None:
22
34
  return expand(override)
23
35
 
24
- data_cfg = getattr(self.config, "DATA", None)
25
- split_cfg = getattr(data_cfg, split, None)
26
- path = getattr(split_cfg, "path", None) if split_cfg is not None else None
27
- if path:
28
- return expand(path)
29
-
30
- annotations_cfg = getattr(data_cfg, "annotations", None)
31
- path = (
32
- getattr(annotations_cfg, split, None)
33
- if annotations_cfg is not None
34
- else None
35
- )
36
+ path = get_split_annotation_path(self.config, split)
36
37
  if path:
37
38
  return expand(path)
38
39
 
39
40
  raise ValueError(
40
41
  f"Could not resolve path for split '{split}'. "
41
- f"Expected DATA.{split}.path or DATA.annotations.{split}."
42
+ f"Expected DATA.common.splits.{split}.annotation_path."
42
43
  )
43
44
 
44
45
  def _set_split_path(self, split: str, value: str) -> str:
45
46
  resolved = expand(value)
46
- data_cfg = getattr(self.config, "DATA", None)
47
- split_cfg = getattr(data_cfg, split, None)
48
-
49
- if split_cfg is not None and hasattr(split_cfg, "path"):
50
- split_cfg.path = resolved
51
- return resolved
52
-
53
- annotations_cfg = getattr(data_cfg, "annotations", None)
54
- if annotations_cfg is not None and hasattr(annotations_cfg, split):
55
- setattr(annotations_cfg, split, resolved)
56
- return resolved
57
-
58
- raise ValueError(
59
- f"Could not set path for split '{split}'. "
60
- f"Expected DATA.{split}.path or DATA.annotations.{split}."
61
- )
47
+ set_split_annotation_path(self.config, split, resolved)
48
+ return resolved
49
+
50
+ def _gate_multi_gpu_by_device(self, device) -> None:
51
+ """Disable TRAIN.execution.multi_gpu when effective device is CPU."""
52
+ execution = getattr(getattr(self.config, "TRAIN", None), "execution", None)
53
+ if execution is None:
54
+ return
55
+
56
+ multi_gpu = bool(getattr(execution, "multi_gpu", False))
57
+ if device.type == "cpu" and multi_gpu:
58
+ execution.multi_gpu = False
59
+ logging.warning(
60
+ "Detected SYSTEM.device=%s; forcing TRAIN.execution.multi_gpu=false for localization runtime.",
61
+ device,
62
+ )
62
63
 
63
64
  def load_weights(
64
65
  self,
@@ -78,13 +79,8 @@ class LocalizationModel(BaseTaskModel):
78
79
  if weights is None:
79
80
  raise ValueError("`weights` must be provided to load_weights().")
80
81
 
81
- model_cfg = getattr(self.config, "MODEL", None)
82
- if not self.train_flag:
83
- original_multi_gpu = getattr(model_cfg, "multi_gpu", None)
84
- if model_cfg is not None and original_multi_gpu is not None:
85
- model_cfg.multi_gpu = False
86
-
87
82
  device = select_device(self.config.SYSTEM)
83
+ self._gate_multi_gpu_by_device(device)
88
84
  if self.model is None:
89
85
  self.model = build_model(self.config, device=device)
90
86
 
@@ -93,7 +89,11 @@ class LocalizationModel(BaseTaskModel):
93
89
  inner_model = getattr(self.model, "model", self.model)
94
90
 
95
91
  if is_local_path(weights):
96
- self.config.SYSTEM.work_dir = os.path.dirname(os.path.abspath(weights))
92
+ set_system_path(
93
+ self.config,
94
+ "work_dir",
95
+ os.path.dirname(os.path.abspath(weights)),
96
+ )
97
97
 
98
98
  if default_args is not None:
99
99
  logging.info("Building optimizer + scaler for checkpoint restore...")
@@ -135,7 +135,7 @@ class LocalizationModel(BaseTaskModel):
135
135
 
136
136
  best_criterion_valid = checkpoint.get(
137
137
  "best_criterion_valid",
138
- 0 if self.config.TRAIN.criterion_valid == "map" else float("inf")
138
+ 0 if get_train_execution(self.config).get("criterion_valid") == "map" else float("inf")
139
139
  )
140
140
  self._resume_state = {
141
141
  "optimizer": optimizer,
@@ -146,10 +146,6 @@ class LocalizationModel(BaseTaskModel):
146
146
  "best_criterion_valid": best_criterion_valid,
147
147
  }
148
148
 
149
- if not self.train_flag:
150
- if model_cfg is not None and original_multi_gpu is not None:
151
- model_cfg.multi_gpu = original_multi_gpu
152
-
153
149
  def train(
154
150
  self,
155
151
  train_set=None,
@@ -178,6 +174,9 @@ class LocalizationModel(BaseTaskModel):
178
174
  valid_set = self._resolve_split_path("valid", valid_set)
179
175
  self._set_split_path("train", train_set)
180
176
  self._set_split_path("valid", valid_set)
177
+ # E2E validation mAP uses the `valid_data_frames` split; keep it in sync
178
+ # with explicit valid annotation overrides.
179
+ self._set_split_path("valid_data_frames", valid_set)
181
180
 
182
181
  self.config = resolve_config_omega(self.config, weights=weights)
183
182
  check_config(self.config, split="train")
@@ -203,34 +202,34 @@ class LocalizationModel(BaseTaskModel):
203
202
  torch.backends.cudnn.benchmark = False
204
203
  torch.use_deterministic_algorithms(True, warn_only=True)
205
204
 
206
- set_seed(self.config.SYSTEM.seed)
205
+ set_seed(get_system_seed(self.config))
207
206
 
208
207
  start = time.time()
209
208
 
210
209
  data_obj_train = build_dataset(self.config, split="train")
211
210
  dataset_train = data_obj_train.building_dataset(
212
211
  cfg=data_obj_train.cfg,
213
- gpu=self.config.SYSTEM.GPU,
212
+ gpu=get_system_gpu_count(self.config),
214
213
  default_args=data_obj_train.default_args,
215
214
  )
216
215
  train_loader = data_obj_train.building_dataloader(
217
216
  dataset_train,
218
217
  cfg=data_obj_train.cfg.dataloader,
219
- gpu=self.config.SYSTEM.GPU,
220
- dali=self.config.dali,
218
+ gpu=get_system_gpu_count(self.config),
219
+ dali=(get_loader_backend(self.config) == "dali"),
221
220
  )
222
221
 
223
222
  data_obj_valid = build_dataset(self.config, split="valid")
224
223
  dataset_valid = data_obj_valid.building_dataset(
225
224
  cfg=data_obj_valid.cfg,
226
- gpu=self.config.SYSTEM.GPU,
225
+ gpu=get_system_gpu_count(self.config),
227
226
  default_args=data_obj_valid.default_args,
228
227
  )
229
228
  valid_loader = data_obj_valid.building_dataloader(
230
229
  dataset_valid,
231
230
  cfg=data_obj_valid.cfg.dataloader,
232
- gpu=self.config.SYSTEM.GPU,
233
- dali=self.config.dali,
231
+ gpu=get_system_gpu_count(self.config),
232
+ dali=(get_loader_backend(self.config) == "dali"),
234
233
  )
235
234
 
236
235
  default_args = get_default_args_trainer(self.config, len(train_loader))
@@ -241,6 +240,7 @@ class LocalizationModel(BaseTaskModel):
241
240
  self.load_weights(weights=effective_weights, default_args=default_args)
242
241
  elif self.model is None:
243
242
  device = select_device(self.config.SYSTEM)
243
+ self._gate_multi_gpu_by_device(device)
244
244
  self.model = build_model(self.config, device=device)
245
245
 
246
246
  self.trainer = build_trainer(
@@ -257,8 +257,8 @@ class LocalizationModel(BaseTaskModel):
257
257
  self.model,
258
258
  train_loader,
259
259
  valid_loader,
260
- self.config.DATA.classes,
261
- self.config.TRAIN.type,
260
+ get_data_classes(self.config),
261
+ get_train_trainer_type(self.config),
262
262
  )
263
263
  )
264
264
 
@@ -291,10 +291,9 @@ class LocalizationModel(BaseTaskModel):
291
291
  test_set = self._resolve_split_path("test", test_set)
292
292
  self._set_split_path("test", test_set)
293
293
 
294
- self.config.MODEL.multi_gpu = False
295
294
  self.config = resolve_config_omega(self.config, weights=weights)
296
295
  check_config(self.config, split="test")
297
- self.config.infer_split = whether_infer_split(self.config.DATA.test)
296
+ self.config.infer_split = whether_infer_split(get_split_cfg(self.config, "test"))
298
297
 
299
298
  init_wandb(
300
299
  self.config_path,
@@ -315,22 +314,23 @@ class LocalizationModel(BaseTaskModel):
315
314
  self.load_weights(weights=effective_weights)
316
315
  elif self.model is None:
317
316
  device = select_device(self.config.SYSTEM)
317
+ self._gate_multi_gpu_by_device(device)
318
318
  self.model = build_model(self.config, device=device)
319
319
 
320
320
  data_obj_test = build_dataset(self.config, split="test")
321
321
  dataset_test = data_obj_test.building_dataset(
322
322
  cfg=data_obj_test.cfg,
323
- gpu=self.config.SYSTEM.GPU,
323
+ gpu=get_system_gpu_count(self.config),
324
324
  default_args=data_obj_test.default_args,
325
325
  )
326
326
  test_loader = data_obj_test.building_dataloader(
327
327
  dataset_test,
328
328
  cfg=data_obj_test.cfg.dataloader,
329
- gpu=self.config.SYSTEM.GPU,
330
- dali=self.config.dali,
329
+ gpu=get_system_gpu_count(self.config),
330
+ dali=(get_loader_backend(self.config) == "dali"),
331
331
  )
332
332
 
333
- inferer = build_inferer(cfg=self.config.MODEL, model=self.model)
333
+ inferer = build_inferer(cfg=self.config, model=self.model)
334
334
  predictions = inferer.infer(
335
335
  cfg=self.config,
336
336
  data=dataset_test,
@@ -361,10 +361,9 @@ class LocalizationModel(BaseTaskModel):
361
361
 
362
362
  test_set = self._resolve_split_path("test", test_set)
363
363
  self._set_split_path("test", test_set)
364
- self.config.MODEL.multi_gpu = False
365
364
  self.config = resolve_config_omega(self.config, weights=weights)
366
365
  check_config(self.config, split="test")
367
- self.config.infer_split = whether_infer_split(self.config.DATA.test)
366
+ self.config.infer_split = whether_infer_split(get_split_cfg(self.config, "test"))
368
367
 
369
368
  init_wandb(
370
369
  self.config_path,
@@ -382,16 +381,17 @@ class LocalizationModel(BaseTaskModel):
382
381
 
383
382
  metrics = None
384
383
 
385
- if has_localization_events(self.config.DATA.test.path):
384
+ test_path = get_split_annotation_path(self.config, "test")
385
+ if has_localization_events(test_path):
386
386
  logging.info("Ground truth labels detected -> running evaluation")
387
387
  evaluator = build_evaluator(cfg=self.config)
388
388
  eval_input = (
389
- self.config.DATA.test.results
389
+ getattr(get_split_cfg(self.config, "test"), "results")
390
390
  if isinstance(predictions, dict)
391
391
  else predictions
392
392
  )
393
393
  metrics = evaluator.evaluate(
394
- cfg_testset=self.config.DATA.test,
394
+ cfg_testset=get_split_cfg(self.config, "test"),
395
395
  json_gz_file=eval_input,
396
396
  )
397
397
  else: