opensportslib 0.1.2__tar.gz → 0.1.2.dev2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99) hide show
  1. {opensportslib-0.1.2/opensportslib.egg-info → opensportslib-0.1.2.dev2}/PKG-INFO +48 -15
  2. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/README.md +47 -14
  3. opensportslib-0.1.2.dev2/examples/quickstart/basic_classification.py +46 -0
  4. opensportslib-0.1.2.dev2/examples/quickstart/basic_localization.py +46 -0
  5. opensportslib-0.1.2.dev2/opensportslib/apis/__init__.py +15 -0
  6. opensportslib-0.1.2.dev2/opensportslib/apis/base_task_model.py +131 -0
  7. opensportslib-0.1.2.dev2/opensportslib/apis/classification.py +328 -0
  8. opensportslib-0.1.2.dev2/opensportslib/apis/localization.py +354 -0
  9. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/core/trainer/classification_trainer.py +57 -14
  10. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/core/trainer/localization_trainer.py +18 -20
  11. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/core/utils/config.py +4 -4
  12. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/metrics/localization_metric.py +1 -1
  13. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2/opensportslib.egg-info}/PKG-INFO +48 -15
  14. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib.egg-info/SOURCES.txt +3 -1
  15. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/pyproject.toml +4 -1
  16. opensportslib-0.1.2.dev2/tests/conftest.py +359 -0
  17. opensportslib-0.1.2.dev2/tests/test_public_apis_smoke.py +38 -0
  18. opensportslib-0.1.2.dev2/tests/test_subset_train_infer_integration.py +292 -0
  19. opensportslib-0.1.2.dev2/tests/test_task_model_api_contract.py +375 -0
  20. opensportslib-0.1.2/examples/quickstart/basic_classification.py +0 -30
  21. opensportslib-0.1.2/examples/quickstart/basic_localization.py +0 -30
  22. opensportslib-0.1.2/opensportslib/apis/__init__.py +0 -21
  23. opensportslib-0.1.2/opensportslib/apis/classification.py +0 -364
  24. opensportslib-0.1.2/opensportslib/apis/localization.py +0 -239
  25. opensportslib-0.1.2/tests/conftest.py +0 -59
  26. opensportslib-0.1.2/tests/test_public_apis_smoke.py +0 -29
  27. opensportslib-0.1.2/tests/test_subset_train_infer_integration.py +0 -172
  28. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/LICENSE +0 -0
  29. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/LICENSE-COMMERCIAL +0 -0
  30. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/MANIFEST.in +0 -0
  31. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/__init__.py +0 -0
  32. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/cli.py +0 -0
  33. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/config/classification.yaml +0 -0
  34. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/config/localization-e2e-ocv.yaml +0 -0
  35. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/config/localization-json_calf_resnetpca512.yaml +0 -0
  36. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/config/localization-json_netvlad++_resnetpca512.yaml +0 -0
  37. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/config/localization.yaml +0 -0
  38. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/config/sngar-frames.yaml +0 -0
  39. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/config/sngar-tracking.yaml +0 -0
  40. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/core/__init__.py +0 -0
  41. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/core/loss/__init__.py +0 -0
  42. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/core/loss/builder.py +0 -0
  43. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/core/loss/calf.py +0 -0
  44. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/core/loss/ce.py +0 -0
  45. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/core/loss/combine.py +0 -0
  46. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/core/loss/nll.py +0 -0
  47. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/core/optimizer/__init__.py +0 -0
  48. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/core/optimizer/builder.py +0 -0
  49. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/core/sampler/weighted_sampler.py +0 -0
  50. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/core/scheduler/__init__.py +0 -0
  51. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/core/scheduler/builder.py +0 -0
  52. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/core/trainer/__init__.py +0 -0
  53. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/core/utils/checkpoint.py +0 -0
  54. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/core/utils/data.py +0 -0
  55. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/core/utils/ddp.py +0 -0
  56. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/core/utils/default_args.py +0 -0
  57. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/core/utils/lightning.py +0 -0
  58. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/core/utils/load_annotations.py +0 -0
  59. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/core/utils/seed.py +0 -0
  60. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/core/utils/video_processing.py +0 -0
  61. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/core/utils/wandb.py +0 -0
  62. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/datasets/__init__.py +0 -0
  63. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/datasets/builder.py +0 -0
  64. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/datasets/classification_dataset.py +0 -0
  65. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/datasets/localization_dataset.py +0 -0
  66. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/datasets/utils/__init__.py +0 -0
  67. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/datasets/utils/tracking.py +0 -0
  68. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/metrics/classification_metric.py +0 -0
  69. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/models/__init__.py +0 -0
  70. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/models/backbones/builder.py +0 -0
  71. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/models/base/contextaware.py +0 -0
  72. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/models/base/e2e.py +0 -0
  73. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/models/base/learnablepooling.py +0 -0
  74. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/models/base/tracking.py +0 -0
  75. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/models/base/vars.py +0 -0
  76. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/models/base/video.py +0 -0
  77. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/models/base/video_mae.py +0 -0
  78. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/models/builder.py +0 -0
  79. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/models/heads/builder.py +0 -0
  80. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/models/neck/builder.py +0 -0
  81. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/models/utils/common.py +0 -0
  82. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/models/utils/impl/__init__.py +0 -0
  83. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/models/utils/impl/asformer.py +0 -0
  84. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/models/utils/impl/calf.py +0 -0
  85. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/models/utils/impl/gsm.py +0 -0
  86. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/models/utils/impl/gtad.py +0 -0
  87. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/models/utils/impl/tsm.py +0 -0
  88. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/models/utils/litebase.py +0 -0
  89. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/models/utils/modules.py +0 -0
  90. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/models/utils/shift.py +0 -0
  91. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/models/utils/utils.py +0 -0
  92. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib/setup/setup.py +0 -0
  93. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib.egg-info/dependency_links.txt +0 -0
  94. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib.egg-info/entry_points.txt +0 -0
  95. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib.egg-info/requires.txt +0 -0
  96. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/opensportslib.egg-info/top_level.txt +0 -0
  97. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/setup.cfg +0 -0
  98. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/tests/test_config_utils_smoke.py +0 -0
  99. {opensportslib-0.1.2 → opensportslib-0.1.2.dev2}/tests/test_package_smoke.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: opensportslib
3
- Version: 0.1.2
3
+ Version: 0.1.2.dev2
4
4
  Summary: OpenSportsLib is the professional library, designed for advanced video understanding in sports. It provides state-of-the-art tools for action recognition, spotting, retrieval, and captioning, making it ideal for researchers, analysts, and developers working with sports video data.
5
5
  Author: Jeet Vora
6
6
  Requires-Python: >=3.12
@@ -123,32 +123,45 @@ print("OpenSportsLib imported successfully")
123
123
  ### Train a classification model
124
124
 
125
125
  ```python
126
- from opensportslib import model
126
+ from opensportslib.apis import ClassificationModel
127
127
 
128
- myModel = model.classification(
129
- config="/path/to/classification.yaml"
128
+ my_model = ClassificationModel(
129
+ config="/path/to/classification.yaml",
130
+ weights="/path/to/weights.pt", # optional
130
131
  )
131
132
 
132
- myModel.train(
133
+ my_model.train(
133
134
  train_set="/path/to/train_annotations.json",
134
135
  valid_set="/path/to/valid_annotations.json",
135
- pretrained="/path/to/pretrained.pt", # optional
136
136
  )
137
137
  ```
138
138
 
139
139
  ### Run inference
140
140
 
141
141
  ```python
142
- from opensportslib import model
142
+ from opensportslib.apis import ClassificationModel
143
143
 
144
- myModel = model.classification(
145
- config="/path/to/classification.yaml"
144
+ my_model = ClassificationModel(
145
+ config="/path/to/classification.yaml",
146
+ weights="/path/to/weights.pt", # optional
146
147
  )
147
148
 
148
- metrics = myModel.infer(
149
+ predictions = my_model.infer(
149
150
  test_set="/path/to/test_annotations.json",
150
- pretrained="/path/to/checkpoints/final_model",
151
- predictions="/path/to/predictions.json"
151
+ )
152
+
153
+ saved_predictions = my_model.save_predictions(
154
+ output_path="/path/to/predictions.json",
155
+ predictions=predictions,
156
+ )
157
+
158
+ metrics = my_model.evaluate(
159
+ test_set="/path/to/test_annotations.json",
160
+ )
161
+
162
+ metrics_from_file = my_model.evaluate(
163
+ test_set="/path/to/test_annotations.json",
164
+ predictions=saved_predictions,
152
165
  )
153
166
 
154
167
  print(metrics)
@@ -157,10 +170,29 @@ print(metrics)
157
170
  ### Localization example
158
171
 
159
172
  ```python
160
- from opensportslib import model
173
+ from opensportslib.apis import LocalizationModel
174
+
175
+ my_model = LocalizationModel(
176
+ config="/path/to/localization.yaml",
177
+ weights="/path/to/weights.pt", # optional
178
+ )
161
179
 
162
- myModel = model.localization(
163
- config="/path/to/localization.yaml"
180
+ predictions = my_model.infer(
181
+ test_set="/path/to/test_annotations.json",
182
+ )
183
+
184
+ saved_predictions = my_model.save_predictions(
185
+ output_path="/path/to/predictions.json",
186
+ predictions=predictions,
187
+ )
188
+
189
+ metrics = my_model.evaluate(
190
+ test_set="/path/to/test_annotations.json",
191
+ )
192
+
193
+ metrics_from_file = my_model.evaluate(
194
+ test_set="/path/to/test_annotations.json",
195
+ predictions=saved_predictions,
164
196
  )
165
197
  ```
166
198
 
@@ -198,6 +230,7 @@ Generate text descriptions for sports events and temporal segments.
198
230
  Use the README for the fast start, then go deeper through:
199
231
 
200
232
  - Full documentation: https://opensportslab.github.io/opensportslib/
233
+ - High-level API guide: [opensportslib/apis/README.md](opensportslib/apis/README.md)
201
234
  - Configuration guide: https://opensportslab.github.io/opensportslib/tni/config-guide/
202
235
  - Example configs: [examples/configs/](examples/configs/)
203
236
  - Quickstart scripts: [examples/quickstart/](examples/quickstart/)
@@ -92,32 +92,45 @@ print("OpenSportsLib imported successfully")
92
92
  ### Train a classification model
93
93
 
94
94
  ```python
95
- from opensportslib import model
95
+ from opensportslib.apis import ClassificationModel
96
96
 
97
- myModel = model.classification(
98
- config="/path/to/classification.yaml"
97
+ my_model = ClassificationModel(
98
+ config="/path/to/classification.yaml",
99
+ weights="/path/to/weights.pt", # optional
99
100
  )
100
101
 
101
- myModel.train(
102
+ my_model.train(
102
103
  train_set="/path/to/train_annotations.json",
103
104
  valid_set="/path/to/valid_annotations.json",
104
- pretrained="/path/to/pretrained.pt", # optional
105
105
  )
106
106
  ```
107
107
 
108
108
  ### Run inference
109
109
 
110
110
  ```python
111
- from opensportslib import model
111
+ from opensportslib.apis import ClassificationModel
112
112
 
113
- myModel = model.classification(
114
- config="/path/to/classification.yaml"
113
+ my_model = ClassificationModel(
114
+ config="/path/to/classification.yaml",
115
+ weights="/path/to/weights.pt", # optional
115
116
  )
116
117
 
117
- metrics = myModel.infer(
118
+ predictions = my_model.infer(
118
119
  test_set="/path/to/test_annotations.json",
119
- pretrained="/path/to/checkpoints/final_model",
120
- predictions="/path/to/predictions.json"
120
+ )
121
+
122
+ saved_predictions = my_model.save_predictions(
123
+ output_path="/path/to/predictions.json",
124
+ predictions=predictions,
125
+ )
126
+
127
+ metrics = my_model.evaluate(
128
+ test_set="/path/to/test_annotations.json",
129
+ )
130
+
131
+ metrics_from_file = my_model.evaluate(
132
+ test_set="/path/to/test_annotations.json",
133
+ predictions=saved_predictions,
121
134
  )
122
135
 
123
136
  print(metrics)
@@ -126,10 +139,29 @@ print(metrics)
126
139
  ### Localization example
127
140
 
128
141
  ```python
129
- from opensportslib import model
142
+ from opensportslib.apis import LocalizationModel
143
+
144
+ my_model = LocalizationModel(
145
+ config="/path/to/localization.yaml",
146
+ weights="/path/to/weights.pt", # optional
147
+ )
130
148
 
131
- myModel = model.localization(
132
- config="/path/to/localization.yaml"
149
+ predictions = my_model.infer(
150
+ test_set="/path/to/test_annotations.json",
151
+ )
152
+
153
+ saved_predictions = my_model.save_predictions(
154
+ output_path="/path/to/predictions.json",
155
+ predictions=predictions,
156
+ )
157
+
158
+ metrics = my_model.evaluate(
159
+ test_set="/path/to/test_annotations.json",
160
+ )
161
+
162
+ metrics_from_file = my_model.evaluate(
163
+ test_set="/path/to/test_annotations.json",
164
+ predictions=saved_predictions,
133
165
  )
134
166
  ```
135
167
 
@@ -167,6 +199,7 @@ Generate text descriptions for sports events and temporal segments.
167
199
  Use the README for the fast start, then go deeper through:
168
200
 
169
201
  - Full documentation: https://opensportslab.github.io/opensportslib/
202
+ - High-level API guide: [opensportslib/apis/README.md](opensportslib/apis/README.md)
170
203
  - Configuration guide: https://opensportslab.github.io/opensportslib/tni/config-guide/
171
204
  - Example configs: [examples/configs/](examples/configs/)
172
205
  - Quickstart scripts: [examples/quickstart/](examples/quickstart/)
@@ -0,0 +1,46 @@
1
+ from opensportslib.apis import ClassificationModel
2
+
3
+
4
def main():
    """Run the full ClassificationModel workflow end to end.

    Steps: build the model, train, infer on the test split, evaluate, save
    the in-memory predictions to disk, and evaluate again from the saved
    file. Every path here is a placeholder: point the config, annotation
    files and (optional) weights at real locations before running.
    """
    test_annotations = "/path/to/test_annotations.json"

    model = ClassificationModel(
        config="examples/configs/classification_video.yaml",
        weights="/path/to/weights.pt",  # optional
    )

    model.train(
        train_set="/path/to/train_annotations.json",
        valid_set="/path/to/valid_annotations.json",
    )

    # infer() hands the predictions back in memory.
    predictions = model.infer(test_set=test_annotations)
    print(predictions)

    # Score the test split; no predictions argument is passed here.
    print(model.evaluate(test_set=test_annotations))

    # Persist the predictions, then score using the saved file's path.
    saved_path = model.save_predictions(
        output_path="/path/to/predictions.json",
        predictions=predictions,
    )
    print(model.evaluate(test_set=test_annotations, predictions=saved_path))


if __name__ == "__main__":
    main()
@@ -0,0 +1,46 @@
1
+ from opensportslib.apis import LocalizationModel
2
+
3
+
4
def main():
    """Run the full LocalizationModel workflow end to end.

    Steps: build the model, train, infer on the test split, evaluate, save
    the in-memory predictions to disk, and evaluate again from the saved
    file. Every path here is a placeholder: point the config, annotation
    files and (optional) weights at real locations before running.
    """
    test_annotations = "/path/to/test_annotations.json"

    model = LocalizationModel(
        config="examples/configs/localization.yaml",
        weights="/path/to/weights.pt",  # optional
    )

    model.train(
        train_set="/path/to/train_annotations.json",
        valid_set="/path/to/valid_annotations.json",
    )

    # infer() hands the predictions back in memory.
    predictions = model.infer(test_set=test_annotations)
    print(predictions)

    # Score the test split; no predictions argument is passed here.
    print(model.evaluate(test_set=test_annotations))

    # Persist the predictions, then score using the saved file's path.
    saved_path = model.save_predictions(
        output_path="/path/to/predictions.json",
        predictions=predictions,
    )
    print(model.evaluate(test_set=test_annotations, predictions=saved_path))


if __name__ == "__main__":
    main()
@@ -0,0 +1,15 @@
1
+ # opensportslib/apis/__init__.py
2
+
3
+ # Import task APIs
4
+ from opensportslib.apis.base_task_model import BaseTaskModel
5
+ from opensportslib.apis.classification import ClassificationModel
6
+ from opensportslib.apis.localization import LocalizationModel
7
import warnings

# NOTE(review): filterwarnings("ignore") with no other arguments installs a
# process-wide filter that silences every warning category, so importing
# opensportslib.apis also hides warnings raised by the caller's own code and
# by other libraries. It runs after the task-API imports above, so warnings
# emitted while importing those modules are not suppressed. Consider scoping
# this (e.g. warnings.catch_warnings around noisy calls) or leaving warning
# policy to the application.
warnings.filterwarnings("ignore")

# Public names exported by ``from opensportslib.apis import *``.
__all__ = ["BaseTaskModel", "ClassificationModel", "LocalizationModel"]
@@ -0,0 +1,131 @@
1
+ """Shared task-level wrapper base for OpenSportsLib APIs."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import logging
7
+ import os
8
+ import uuid
9
+ from abc import ABC, abstractmethod
10
+ from typing import Any
11
+
12
+ from opensportslib.core.utils.config import expand, load_config_omega
13
+
14
+
15
+ class BaseTaskModel(ABC):
16
+ """Thin shared contract for task-level OpenSportsLib wrappers."""
17
+
18
+ def __init__(self, config=None, weights=None):
19
+ self._configure_logging()
20
+
21
+ if config is None:
22
+ raise ValueError("config path is required")
23
+
24
+ self.config_path = expand(config)
25
+ self.config = load_config_omega(self.config_path)
26
+
27
+ data_cfg = getattr(self.config, "DATA", None)
28
+ if data_cfg is not None and hasattr(data_cfg, "data_dir"):
29
+ data_cfg.data_dir = expand(data_cfg.data_dir)
30
+ logging.info(f"Data directory: {data_cfg.data_dir}")
31
+
32
+ self.run_id = os.environ.get("RUN_ID") or str(uuid.uuid4())[:8]
33
+ os.environ["RUN_ID"] = self.run_id
34
+
35
+ system_cfg = getattr(self.config, "SYSTEM", None)
36
+ if system_cfg is not None:
37
+ base_save_dir = expand(getattr(system_cfg, "save_dir", None) or "./checkpoints")
38
+ model_cfg = getattr(self.config, "MODEL", None)
39
+ backbone_cfg = getattr(model_cfg, "backbone", None)
40
+ model_name = getattr(backbone_cfg, "type", None) or "model"
41
+ run_save_dir = os.path.join(base_save_dir, model_name, self.run_id)
42
+ self.save_dir = run_save_dir
43
+ system_cfg.save_dir = run_save_dir
44
+ if hasattr(system_cfg, "work_dir"):
45
+ system_cfg.work_dir = run_save_dir
46
+ os.makedirs(run_save_dir, exist_ok=True)
47
+ else:
48
+ self.save_dir = expand("./checkpoints")
49
+ os.makedirs(self.save_dir, exist_ok=True)
50
+
51
+ logging.info(f"Save directory: {self.save_dir}")
52
+
53
+ self.model = None
54
+ self.processor = None
55
+ self.trainer = None
56
+ self.best_checkpoint = None
57
+ self.last_loaded_weights = None
58
+
59
+ if weights is not None:
60
+ self.load_weights(weights=weights)
61
+
62
+ @staticmethod
63
+ def _configure_logging() -> None:
64
+ root_logger = logging.getLogger()
65
+ if not root_logger.handlers:
66
+ logging.basicConfig(
67
+ level=logging.INFO,
68
+ format="%(asctime)s | %(levelname)s | %(message)s",
69
+ )
70
+ elif root_logger.level > logging.INFO:
71
+ root_logger.setLevel(logging.INFO)
72
+
73
+ @abstractmethod
74
+ def load_weights(
75
+ self,
76
+ weights: str | None = None,
77
+ **kwargs,
78
+ ) -> None:
79
+ raise NotImplementedError
80
+
81
+ @abstractmethod
82
+ def train(
83
+ self,
84
+ train_set: str | None = None,
85
+ valid_set: str | None = None,
86
+ weights: str | None = None,
87
+ use_wandb: bool = True,
88
+ **kwargs,
89
+ ) -> str | None:
90
+ raise NotImplementedError
91
+
92
+ @abstractmethod
93
+ def infer(
94
+ self,
95
+ test_set: str | None = None,
96
+ weights: str | None = None,
97
+ use_wandb: bool = True,
98
+ **kwargs,
99
+ ) -> dict:
100
+ raise NotImplementedError
101
+
102
+ @abstractmethod
103
+ def evaluate(
104
+ self,
105
+ test_set: str | None = None,
106
+ weights: str | None = None,
107
+ predictions: str | dict[str, Any] | None = None,
108
+ use_wandb: bool = True,
109
+ **kwargs,
110
+ ) -> dict | str | None:
111
+ raise NotImplementedError
112
+
113
+ def save_predictions(
114
+ self,
115
+ output_path: str,
116
+ predictions: dict,
117
+ ) -> str:
118
+ """Persist in-memory prediction JSON payload to a target file path."""
119
+
120
+ dst = expand(output_path)
121
+ os.makedirs(os.path.dirname(dst) or ".", exist_ok=True)
122
+
123
+ if not isinstance(predictions, dict):
124
+ raise TypeError(
125
+ f"Unsupported predictions type: {type(predictions).__name__}. "
126
+ "Expected dict."
127
+ )
128
+
129
+ with open(dst, "w", encoding="utf-8") as f:
130
+ json.dump(predictions, f)
131
+ return dst