opensportslib 0.1.1.dev6__tar.gz → 0.1.2.dev1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (97)
  1. {opensportslib-0.1.1.dev6/opensportslib.egg-info → opensportslib-0.1.2.dev1}/PKG-INFO +11 -8
  2. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/README.md +10 -7
  3. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/examples/quickstart/basic_classification.py +11 -7
  4. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/examples/quickstart/basic_localization.py +11 -7
  5. opensportslib-0.1.2.dev1/opensportslib/apis/__init__.py +15 -0
  6. opensportslib-0.1.2.dev1/opensportslib/apis/base_task_model.py +96 -0
  7. opensportslib-0.1.2.dev1/opensportslib/apis/classification.py +322 -0
  8. opensportslib-0.1.2.dev1/opensportslib/apis/localization.py +333 -0
  9. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/core/trainer/classification_trainer.py +55 -12
  10. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/core/trainer/localization_trainer.py +16 -18
  11. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1/opensportslib.egg-info}/PKG-INFO +11 -8
  12. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib.egg-info/SOURCES.txt +3 -1
  13. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/pyproject.toml +4 -1
  14. opensportslib-0.1.2.dev1/tests/conftest.py +359 -0
  15. opensportslib-0.1.2.dev1/tests/test_public_apis_smoke.py +38 -0
  16. opensportslib-0.1.2.dev1/tests/test_subset_train_infer_integration.py +292 -0
  17. opensportslib-0.1.2.dev1/tests/test_task_model_api_contract.py +73 -0
  18. opensportslib-0.1.1.dev6/opensportslib/apis/__init__.py +0 -21
  19. opensportslib-0.1.1.dev6/opensportslib/apis/classification.py +0 -364
  20. opensportslib-0.1.1.dev6/opensportslib/apis/localization.py +0 -239
  21. opensportslib-0.1.1.dev6/tests/conftest.py +0 -59
  22. opensportslib-0.1.1.dev6/tests/test_public_apis_smoke.py +0 -29
  23. opensportslib-0.1.1.dev6/tests/test_subset_train_infer_integration.py +0 -172
  24. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/LICENSE +0 -0
  25. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/LICENSE-COMMERCIAL +0 -0
  26. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/MANIFEST.in +0 -0
  27. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/__init__.py +0 -0
  28. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/cli.py +0 -0
  29. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/config/classification.yaml +0 -0
  30. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/config/localization-e2e-ocv.yaml +0 -0
  31. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/config/localization-json_calf_resnetpca512.yaml +0 -0
  32. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/config/localization-json_netvlad++_resnetpca512.yaml +0 -0
  33. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/config/localization.yaml +0 -0
  34. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/config/sngar-frames.yaml +0 -0
  35. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/config/sngar-tracking.yaml +0 -0
  36. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/core/__init__.py +0 -0
  37. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/core/loss/__init__.py +0 -0
  38. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/core/loss/builder.py +0 -0
  39. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/core/loss/calf.py +0 -0
  40. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/core/loss/ce.py +0 -0
  41. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/core/loss/combine.py +0 -0
  42. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/core/loss/nll.py +0 -0
  43. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/core/optimizer/__init__.py +0 -0
  44. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/core/optimizer/builder.py +0 -0
  45. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/core/sampler/weighted_sampler.py +0 -0
  46. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/core/scheduler/__init__.py +0 -0
  47. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/core/scheduler/builder.py +0 -0
  48. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/core/trainer/__init__.py +0 -0
  49. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/core/utils/checkpoint.py +0 -0
  50. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/core/utils/config.py +0 -0
  51. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/core/utils/data.py +0 -0
  52. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/core/utils/ddp.py +0 -0
  53. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/core/utils/default_args.py +0 -0
  54. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/core/utils/lightning.py +0 -0
  55. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/core/utils/load_annotations.py +0 -0
  56. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/core/utils/seed.py +0 -0
  57. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/core/utils/video_processing.py +0 -0
  58. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/core/utils/wandb.py +0 -0
  59. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/datasets/__init__.py +0 -0
  60. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/datasets/builder.py +0 -0
  61. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/datasets/classification_dataset.py +0 -0
  62. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/datasets/localization_dataset.py +0 -0
  63. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/datasets/utils/__init__.py +0 -0
  64. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/datasets/utils/tracking.py +0 -0
  65. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/metrics/classification_metric.py +0 -0
  66. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/metrics/localization_metric.py +0 -0
  67. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/models/__init__.py +0 -0
  68. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/models/backbones/builder.py +0 -0
  69. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/models/base/contextaware.py +0 -0
  70. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/models/base/e2e.py +0 -0
  71. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/models/base/learnablepooling.py +0 -0
  72. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/models/base/tracking.py +0 -0
  73. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/models/base/vars.py +0 -0
  74. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/models/base/video.py +0 -0
  75. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/models/base/video_mae.py +0 -0
  76. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/models/builder.py +0 -0
  77. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/models/heads/builder.py +0 -0
  78. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/models/neck/builder.py +0 -0
  79. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/models/utils/common.py +0 -0
  80. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/models/utils/impl/__init__.py +0 -0
  81. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/models/utils/impl/asformer.py +0 -0
  82. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/models/utils/impl/calf.py +0 -0
  83. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/models/utils/impl/gsm.py +0 -0
  84. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/models/utils/impl/gtad.py +0 -0
  85. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/models/utils/impl/tsm.py +0 -0
  86. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/models/utils/litebase.py +0 -0
  87. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/models/utils/modules.py +0 -0
  88. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/models/utils/shift.py +0 -0
  89. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/models/utils/utils.py +0 -0
  90. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib/setup/setup.py +0 -0
  91. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib.egg-info/dependency_links.txt +0 -0
  92. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib.egg-info/entry_points.txt +0 -0
  93. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib.egg-info/requires.txt +0 -0
  94. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/opensportslib.egg-info/top_level.txt +0 -0
  95. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/setup.cfg +0 -0
  96. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/tests/test_config_utils_smoke.py +0 -0
  97. {opensportslib-0.1.1.dev6 → opensportslib-0.1.2.dev1}/tests/test_package_smoke.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: opensportslib
3
- Version: 0.1.1.dev6
3
+ Version: 0.1.2.dev1
4
4
  Summary: OpenSportsLib is the professional library, designed for advanced video understanding in sports. It provides state-of-the-art tools for action recognition, spotting, retrieval, and captioning, making it ideal for researchers, analysts, and developers working with sports video data.
5
5
  Author: Jeet Vora
6
6
  Requires-Python: >=3.12
@@ -125,14 +125,14 @@ print("OpenSportsLib imported successfully")
125
125
  ```python
126
126
  from opensportslib import model
127
127
 
128
- myModel = model.classification(
129
- config="/path/to/classification.yaml"
128
+ myModel = model.ClassificationModel(
129
+ config="/path/to/classification.yaml",
130
+ weights="/path/to/weights.pt", # optional
130
131
  )
131
132
 
132
133
  myModel.train(
133
134
  train_set="/path/to/train_annotations.json",
134
135
  valid_set="/path/to/valid_annotations.json",
135
- pretrained="/path/to/pretrained.pt", # optional
136
136
  )
137
137
  ```
138
138
 
@@ -142,13 +142,16 @@ myModel.train(
142
142
  from opensportslib import model
143
143
 
144
144
  myModel = model.ClassificationModel(
145
- config="/path/to/classification.yaml"
145
+ config="/path/to/classification.yaml",
146
+ weights="/path/to/weights.pt", # optional
147
+ )
148
+
149
+ predictions = myModel.infer(
150
+ test_set="/path/to/test_annotations.json",
146
151
  )
147
152
 
148
- metrics = myModel.infer(
153
+ metrics = myModel.evaluate(
149
154
  test_set="/path/to/test_annotations.json",
150
- pretrained="/path/to/checkpoints/final_model",
151
- predictions="/path/to/predictions.json"
152
155
  )
153
156
 
154
157
  print(metrics)
@@ -94,14 +94,14 @@ print("OpenSportsLib imported successfully")
94
94
  ```python
95
95
  from opensportslib import model
96
96
 
97
- myModel = model.classification(
98
- config="/path/to/classification.yaml"
97
+ myModel = model.ClassificationModel(
98
+ config="/path/to/classification.yaml",
99
+ weights="/path/to/weights.pt", # optional
99
100
  )
100
101
 
101
102
  myModel.train(
102
103
  train_set="/path/to/train_annotations.json",
103
104
  valid_set="/path/to/valid_annotations.json",
104
- pretrained="/path/to/pretrained.pt", # optional
105
105
  )
106
106
  ```
107
107
 
@@ -111,13 +111,16 @@ myModel.train(
111
111
  from opensportslib import model
112
112
 
113
113
  myModel = model.ClassificationModel(
114
- config="/path/to/classification.yaml"
114
+ config="/path/to/classification.yaml",
115
+ weights="/path/to/weights.pt", # optional
116
+ )
117
+
118
+ predictions = myModel.infer(
119
+ test_set="/path/to/test_annotations.json",
115
120
  )
116
121
 
117
- metrics = myModel.infer(
122
+ metrics = myModel.evaluate(
118
123
  test_set="/path/to/test_annotations.json",
119
- pretrained="/path/to/checkpoints/final_model",
120
- predictions="/path/to/predictions.json"
121
124
  )
122
125
 
123
126
  print(metrics)
@@ -1,4 +1,4 @@
1
- from opensportslib import model
1
+ from opensportslib.apis import ClassificationModel
2
2
 
3
3
 
4
4
  def main():
@@ -7,20 +7,24 @@ def main():
7
7
  Update config and dataset paths before running.
8
8
  """
9
9
 
10
- my_model = model.classification(
11
- config="examples/configs/classification_video.yaml"
10
+ my_model = ClassificationModel(
11
+ config="examples/configs/classification_video.yaml",
12
+ weights="/path/to/weights.pt", # optional
12
13
  )
13
14
 
14
15
  my_model.train(
15
16
  train_set="/path/to/train_annotations.json",
16
17
  valid_set="/path/to/valid_annotations.json",
17
- pretrained="/path/to/pretrained.pt", # optional
18
18
  )
19
19
 
20
- metrics = my_model.infer(
20
+ predictions = my_model.infer(
21
+ test_set="/path/to/test_annotations.json",
22
+ )
23
+
24
+ print(predictions)
25
+
26
+ metrics = my_model.evaluate(
21
27
  test_set="/path/to/test_annotations.json",
22
- pretrained="/path/to/checkpoints/best.pt",
23
- predictions="/path/to/predictions.json",
24
28
  )
25
29
 
26
30
  print(metrics)
@@ -1,4 +1,4 @@
1
- from opensportslib import model
1
+ from opensportslib.apis import LocalizationModel
2
2
 
3
3
 
4
4
  def main():
@@ -7,20 +7,24 @@ def main():
7
7
  Update config and dataset paths before running.
8
8
  """
9
9
 
10
- my_model = model.localization(
11
- config="examples/configs/localization.yaml"
10
+ my_model = LocalizationModel(
11
+ config="examples/configs/localization.yaml",
12
+ weights="/path/to/weights.pt", # optional
12
13
  )
13
14
 
14
15
  my_model.train(
15
16
  train_set="/path/to/train_annotations.json",
16
17
  valid_set="/path/to/valid_annotations.json",
17
- pretrained="/path/to/pretrained.pt", # optional
18
18
  )
19
19
 
20
- metrics = my_model.infer(
20
+ predictions = my_model.infer(
21
+ test_set="/path/to/test_annotations.json",
22
+ )
23
+
24
+ print(predictions)
25
+
26
+ metrics = my_model.evaluate(
21
27
  test_set="/path/to/test_annotations.json",
22
- pretrained="/path/to/checkpoints/best.pt",
23
- predictions="/path/to/predictions.json",
24
28
  )
25
29
 
26
30
  print(metrics)
@@ -0,0 +1,15 @@
1
# opensportslib/apis/__init__.py
import warnings

# Task-level wrappers that make up the public API of ``opensportslib.apis``.
from opensportslib.apis.base_task_model import BaseTaskModel
from opensportslib.apis.classification import ClassificationModel
from opensportslib.apis.localization import LocalizationModel

# NOTE(review): importing this package installs an "ignore" filter for every
# warning category in the whole process, which also hides warnings raised by
# user code. Consider scoping it (warnings.catch_warnings(), or a
# category/module-specific filter). Kept as-is to preserve behavior.
warnings.filterwarnings("ignore")

# Names re-exported by ``from opensportslib.apis import *``.
__all__ = ["BaseTaskModel", "ClassificationModel", "LocalizationModel"]
@@ -0,0 +1,96 @@
1
+ """Shared task-level wrapper base for OpenSportsLib APIs."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import os
7
+ import uuid
8
+ from abc import ABC, abstractmethod
9
+
10
+ from opensportslib.core.utils.config import expand, load_config_omega
11
+
12
+
13
class BaseTaskModel(ABC):
    """Thin shared contract for task-level OpenSportsLib wrappers.

    Subclasses implement weight loading, training, inference and evaluation.
    This base class only loads the config, manages the run id and persists
    in-memory prediction payloads.
    """

    def __init__(self, config=None, weights=None):
        """Load the task config and optionally restore weights.

        Args:
            config: Path to the task config file (required); passed through
                ``expand`` and loaded with ``load_config_omega``.
            weights: Optional checkpoint path forwarded to ``load_weights``.

        Raises:
            ValueError: If ``config`` is ``None``.
        """
        if config is None:
            raise ValueError("config path is required")

        self.config_path = expand(config)
        self.config = load_config_omega(self.config_path)

        # Normalize DATA.data_dir (when the config defines it) the same way
        # as the config path itself.
        data_cfg = getattr(self.config, "DATA", None)
        if data_cfg is not None and hasattr(data_cfg, "data_dir"):
            data_cfg.data_dir = expand(data_cfg.data_dir)

        # Reuse an existing RUN_ID if one is set, and publish the chosen id
        # through the environment so code that reads os.environ["RUN_ID"]
        # later (e.g. worker processes) sees the same value.
        self.run_id = os.environ.get("RUN_ID") or str(uuid.uuid4())[:8]
        os.environ["RUN_ID"] = self.run_id

        # State populated by subclasses.
        self.model = None
        self.processor = None
        self.trainer = None
        self.best_checkpoint = None
        self.last_loaded_weights = None

        if weights is not None:
            self.load_weights(weights=weights)

    @abstractmethod
    def load_weights(
        self,
        weights: str | None = None,
        **kwargs,
    ) -> None:
        """Load model weights from ``weights`` into this wrapper."""
        raise NotImplementedError

    @abstractmethod
    def train(
        self,
        train_set: str | None = None,
        valid_set: str | None = None,
        weights: str | None = None,
        use_wandb: bool = True,
        **kwargs,
    ) -> str | None:
        """Train the task model; return the best checkpoint path (or ``None``)."""
        raise NotImplementedError

    @abstractmethod
    def infer(
        self,
        test_set: str | None = None,
        weights: str | None = None,
        use_wandb: bool = True,
        **kwargs,
    ) -> dict:
        """Run inference on ``test_set`` and return the predictions payload."""
        raise NotImplementedError

    @abstractmethod
    def evaluate(
        self,
        test_set: str | None = None,
        weights: str | None = None,
        use_wandb: bool = True,
        **kwargs,
    ) -> dict | str | None:
        """Evaluate the task model on ``test_set`` and return its metrics."""
        raise NotImplementedError

    def save_predictions(
        self,
        output_path: str,
        predictions: dict,
    ) -> str:
        """Persist an in-memory prediction JSON payload to ``output_path``.

        Args:
            output_path: Destination file path (passed through ``expand``).
                Missing parent directories are created.
            predictions: JSON-serializable prediction payload.

        Returns:
            The expanded path the payload was written to.

        Raises:
            TypeError: If ``predictions`` is not a ``dict``. This is checked
                before touching the filesystem, so a bad payload never
                leaves stray directories behind.
        """
        if not isinstance(predictions, dict):
            raise TypeError(
                f"Unsupported predictions type: {type(predictions).__name__}. "
                "Expected dict."
            )

        dst = expand(output_path)
        os.makedirs(os.path.dirname(dst) or ".", exist_ok=True)

        with open(dst, "w", encoding="utf-8") as f:
            json.dump(predictions, f)
        return dst
@@ -0,0 +1,322 @@
1
+ # opensportslib/apis/classification.py
2
+
3
+ """Public API for classification tasks."""
4
+
5
+ import logging
6
+ import os
7
+
8
+ from opensportslib.apis.base_task_model import BaseTaskModel
9
+ from opensportslib.core.utils.config import expand
10
+
11
+
12
class ClassificationModel(BaseTaskModel):
    """Top-level task wrapper for classification.

    Weights given to the constructor (or to ``load_weights``) and the best
    checkpoint produced by ``train`` are remembered in
    ``self.last_loaded_weights``. ``train``, ``infer`` and ``evaluate`` use
    that path whenever they are called without an explicit ``weights``.
    """

    def _resolve_split_path(self, split: str, override: str | None = None) -> str:
        """Return the expanded annotation path for ``split``.

        Resolution order: explicit ``override``, then ``DATA.<split>.path``,
        then ``DATA.annotations.<split>``.

        Raises:
            ValueError: If none of these sources yields a path.
        """
        if override is not None:
            return expand(override)

        data_cfg = getattr(self.config, "DATA", None)
        split_cfg = getattr(data_cfg, split, None)
        path = getattr(split_cfg, "path", None) if split_cfg is not None else None
        if path:
            return expand(path)

        annotations_cfg = getattr(data_cfg, "annotations", None)
        path = (
            getattr(annotations_cfg, split, None)
            if annotations_cfg is not None
            else None
        )
        if path:
            return expand(path)

        raise ValueError(
            f"Could not resolve path for split '{split}'. "
            f"Expected DATA.{split}.path or DATA.annotations.{split}."
        )

    # -----------------------------------------------------------------
    # internal DDP worker
    # -----------------------------------------------------------------
    @staticmethod
    def _worker_ddp(
        rank,
        world_size,
        mode,
        config_path,
        config,
        return_queue=None,
        train_set=None,
        valid_set=None,
        test_set=None,
        weights=None,
        use_wandb=False,
    ):
        """Execute one training/inference job on a single process.

        Runs either in-process (``rank=0``, ``world_size=1``) or as the
        ``mp.spawn`` target. The positional order of the parameters after
        ``rank`` must match the ``args`` tuple built in ``_launch``.

        Args:
            rank: Process rank (0 on the single-process path).
            world_size: Number of processes; ``> 1`` enables DDP setup.
            mode: ``"train"`` or ``"infer"``.
            config_path: Config path, forwarded to ``init_wandb``.
            config: Config object used to build trainer/model/datasets.
            return_queue: Queue rank 0 puts its result on: the best
                checkpoint path for ``"train"``, the predictions for
                ``"infer"``.
            train_set, valid_set, test_set: Annotation paths for the mode.
            weights: Optional checkpoint loaded via ``Trainer_Classification.load``;
                when falsy, a fresh model is built from ``config``.
            use_wandb: Forwarded to ``init_wandb`` on rank 0.
        """
        import torch
        from opensportslib.core.trainer.classification_trainer import Trainer_Classification
        from opensportslib.core.utils.ddp import ddp_cleanup, ddp_setup
        from opensportslib.core.utils.seed import set_reproducibility
        from opensportslib.core.utils.wandb import init_wandb
        from opensportslib.datasets.builder import build_dataset
        from opensportslib.models.builder import build_model

        logging.basicConfig(
            level=logging.INFO,
            format=f"[RANK {rank}] %(asctime)s | %(levelname)s | %(message)s",
            force=True,
        )
        if rank != 0:
            logging.getLogger().setLevel(logging.ERROR)

        if rank == 0:
            init_wandb(
                config_path,
                config,
                run_id=os.environ["RUN_ID"],
                use_wandb=use_wandb,
            )

        if getattr(config.SYSTEM, "use_seed", False):
            set_reproducibility(config.SYSTEM.seed)

        is_ddp = world_size > 1
        if is_ddp:
            torch.cuda.set_device(rank)
            ddp_setup(rank, world_size)
            device = torch.device(f"cuda:{rank}")
        else:
            from opensportslib.core.utils.config import select_device

            device = select_device(config.SYSTEM)

        try:
            trainer = Trainer_Classification(config)
            trainer.device = device

            if weights:
                model, processor, _, _ = trainer.load(weights)
            else:
                model, processor = build_model(config, device)

            trainer.model = model

            if mode == "train":
                train_data = build_dataset(config, train_set, processor, split="train")
                valid_data = build_dataset(config, valid_set, processor, split="valid")
                best_ckpt = trainer.train(
                    model,
                    train_data,
                    valid_data,
                    rank=rank,
                    world_size=world_size,
                )
                if rank == 0 and return_queue is not None:
                    if not best_ckpt:
                        # Fall back to a path recorded by an inner trainer.
                        # The getattr chain tolerates a trainer without one.
                        inner_trainer = getattr(trainer, "trainer", None)
                        best_ckpt = getattr(inner_trainer, "best_checkpoint_path", None)
                    return_queue.put(best_ckpt)

            elif mode == "infer":
                test_data = build_dataset(config, test_set, processor, split="test")
                predictions = trainer.infer(
                    test_data,
                    rank=rank,
                    world_size=world_size,
                )
                if rank == 0 and return_queue is not None:
                    return_queue.put(predictions)
        finally:
            # Tear the process group down even if training/inference raised.
            if is_ddp:
                ddp_cleanup()

    def _launch(
        self,
        mode,
        world_size,
        use_ddp,
        train_set=None,
        valid_set=None,
        test_set=None,
        weights=None,
        use_wandb=True,
    ):
        """Run ``_worker_ddp`` for ``mode`` and return rank 0's queued result.

        With ``use_ddp`` the worker is spawned on ``world_size`` processes;
        otherwise it runs in the current process as rank 0.
        """
        import queue as queue_lib

        import torch.multiprocessing as mp

        ctx = mp.get_context("spawn")
        result_queue = ctx.Queue()

        if not use_ddp:
            ClassificationModel._worker_ddp(
                rank=0,
                world_size=1,
                mode=mode,
                config_path=self.config_path,
                config=self.config,
                return_queue=result_queue,
                train_set=train_set,
                valid_set=valid_set,
                test_set=test_set,
                weights=weights,
                use_wandb=use_wandb,
            )
            return result_queue.get()

        process_context = mp.spawn(
            ClassificationModel._worker_ddp,
            args=(
                world_size,
                mode,
                self.config_path,
                self.config,
                result_queue,
                train_set,
                valid_set,
                test_set,
                weights,
                use_wandb,
            ),
            nprocs=world_size,
            join=False,
        )

        # Drain the queue BEFORE joining. A process that has put data on a
        # multiprocessing queue does not exit until that data is flushed to
        # the pipe, so joining first (mp.spawn's default) can deadlock when
        # rank 0 returns a large predictions payload.
        result = None
        while True:
            try:
                result = result_queue.get(timeout=1.0)
                break
            except queue_lib.Empty:
                # join() re-raises worker failures; True means every rank
                # exited cleanly, so anything rank 0 queued is already flushed.
                if process_context.join(timeout=0):
                    try:
                        result = result_queue.get(timeout=1.0)
                    except queue_lib.Empty:
                        result = None
                    break

        while not process_context.join():
            pass
        return result

    def load_weights(
        self,
        weights: str | None = None,
        **kwargs,
    ) -> None:
        """Load ``weights`` into ``self.model`` and remember the path.

        ``self.processor`` is also set for ``MODEL.type == "huggingface"``.
        The path is stored in ``last_loaded_weights`` / ``best_checkpoint``
        so later ``train``/``infer``/``evaluate`` calls use it by default.

        Raises:
            ValueError: If ``weights`` is ``None``.
        """
        from opensportslib.core.trainer.classification_trainer import Trainer_Classification

        del kwargs
        if weights is None:
            raise ValueError("`weights` must be provided to load_weights().")

        self.trainer = Trainer_Classification(self.config)
        loaded = self.trainer.load(weights)
        self.model = loaded[0]

        if getattr(self.config.MODEL, "type", "custom") == "huggingface":
            self.processor = loaded[1]

        self.last_loaded_weights = weights
        self.best_checkpoint = weights

    # -----------------------------------------------------------------
    # public training interface
    # -----------------------------------------------------------------

    def train(
        self,
        train_set=None,
        valid_set=None,
        test_set=None,
        weights=None,
        use_ddp=False,
        use_wandb=True,
        **kwargs,
    ):
        """Run full training and return the best checkpoint path.

        Args:
            train_set, valid_set: Annotation paths; resolved from the config
                when omitted.
            test_set: Ignored; kept for API compatibility.
            weights: Checkpoint to start from. Defaults to the weights given
                to the constructor / ``load_weights`` or a previous
                ``train`` result.
            use_ddp: Spawn one process per visible GPU when more than one
                is available.
            use_wandb: Enable Weights & Biases logging on rank 0.

        Returns:
            The best checkpoint path reported by rank 0 (may be ``None``).
        """
        import torch
        from opensportslib.core.utils.config import resolve_config_omega

        del test_set  # retained for API compatibility
        del kwargs

        train_set = self._resolve_split_path("train", train_set)
        valid_set = self._resolve_split_path("valid", valid_set)

        if weights is None:
            weights = self.last_loaded_weights

        self.config = resolve_config_omega(self.config)
        logging.info("Configuration:")
        logging.info(self.config)

        # NOTE(review): when no CUDA device is visible this falls back to
        # SYSTEM.GPU, so use_ddp=True on a CPU-only host would try to spawn
        # CUDA workers; infer() uses torch.cuda.device_count() alone.
        world_size = torch.cuda.device_count() or self.config.SYSTEM.GPU
        use_ddp = use_ddp and world_size > 1

        if use_ddp:
            logging.info(f"Launching DDP on {world_size} GPUs")
        else:
            logging.info("Single GPU training")

        self.best_checkpoint = self._launch(
            mode="train",
            world_size=world_size,
            use_ddp=use_ddp,
            train_set=train_set,
            valid_set=valid_set,
            weights=weights,
            use_wandb=use_wandb,
        )
        self.last_loaded_weights = self.best_checkpoint
        return self.best_checkpoint

    def infer(
        self,
        test_set=None,
        weights=None,
        use_ddp=False,
        use_wandb=True,
        **kwargs,
    ):
        """Run model inference and return predictions in OSL JSON format.

        Args:
            test_set: Annotation path; resolved from the config when omitted.
            weights: Checkpoint to run. Defaults to the weights given to the
                constructor / ``load_weights`` or the last ``train`` result;
                only when none exist is a fresh model built from the config.
            use_ddp: Spawn one process per visible GPU when more than one
                is available.
            use_wandb: Enable Weights & Biases logging on rank 0.
        """
        del kwargs

        import torch
        from opensportslib.core.utils.config import resolve_config_omega

        test_set = self._resolve_split_path("test", test_set)

        if weights is None:
            weights = self.last_loaded_weights

        self.config = resolve_config_omega(self.config)
        logging.info("Configuration:")
        logging.info(self.config)

        world_size = torch.cuda.device_count()
        use_ddp = use_ddp and world_size > 1

        return self._launch(
            mode="infer",
            world_size=world_size,
            use_ddp=use_ddp,
            test_set=test_set,
            weights=weights,
            use_wandb=use_wandb,
        )

    def evaluate(
        self,
        test_set=None,
        weights=None,
        use_ddp=False,
        use_wandb=True,
        **kwargs,
    ):
        """Run inference on the test set and return evaluation metrics.

        ``weights`` follows the same defaulting rules as ``infer``.
        """
        del kwargs

        from opensportslib.datasets.builder import build_dataset
        from opensportslib.core.trainer.classification_trainer import Trainer_Classification

        test_set = self._resolve_split_path("test", test_set)

        # infer() resolves and logs the config and applies the weights
        # fallback, so it is not repeated here.
        predictions = self.infer(
            test_set=test_set,
            weights=weights,
            use_ddp=use_ddp,
            use_wandb=use_wandb,
        )

        self.trainer = self.trainer or Trainer_Classification(self.config)
        # Dataset is built only to obtain the label map / excluded labels.
        test_data = build_dataset(self.config, test_set, None, split="test")
        metrics = self.trainer.evaluate(
            pred_path=predictions,
            gt_path=test_set,
            class_names=test_data.label_map,
            exclude_labels=test_data.exclude_labels,
        )

        logging.info(f"TEST METRICS : {metrics}")
        return metrics