opensportslib 0.0.1.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. opensportslib/__init__.py +18 -0
  2. opensportslib/apis/__init__.py +21 -0
  3. opensportslib/apis/classification.py +361 -0
  4. opensportslib/apis/localization.py +228 -0
  5. opensportslib/config/classification.yaml +104 -0
  6. opensportslib/config/classification_tracking.yaml +103 -0
  7. opensportslib/config/graph_tracking_classification/avgpool.yaml +79 -0
  8. opensportslib/config/graph_tracking_classification/gin.yaml +79 -0
  9. opensportslib/config/graph_tracking_classification/graphconv.yaml +79 -0
  10. opensportslib/config/graph_tracking_classification/graphsage.yaml +79 -0
  11. opensportslib/config/graph_tracking_classification/maxpool.yaml +79 -0
  12. opensportslib/config/graph_tracking_classification/noedges.yaml +79 -0
  13. opensportslib/config/localization.yaml +132 -0
  14. opensportslib/config/sngar_frames.yaml +98 -0
  15. opensportslib/core/__init__.py +0 -0
  16. opensportslib/core/loss/__init__.py +0 -0
  17. opensportslib/core/loss/builder.py +40 -0
  18. opensportslib/core/loss/calf.py +258 -0
  19. opensportslib/core/loss/ce.py +23 -0
  20. opensportslib/core/loss/combine.py +42 -0
  21. opensportslib/core/loss/nll.py +25 -0
  22. opensportslib/core/optimizer/__init__.py +0 -0
  23. opensportslib/core/optimizer/builder.py +38 -0
  24. opensportslib/core/sampler/weighted_sampler.py +104 -0
  25. opensportslib/core/scheduler/__init__.py +0 -0
  26. opensportslib/core/scheduler/builder.py +77 -0
  27. opensportslib/core/trainer/__init__.py +0 -0
  28. opensportslib/core/trainer/classification_trainer.py +1131 -0
  29. opensportslib/core/trainer/localization_trainer.py +1009 -0
  30. opensportslib/core/utils/checkpoint.py +238 -0
  31. opensportslib/core/utils/config.py +199 -0
  32. opensportslib/core/utils/data.py +85 -0
  33. opensportslib/core/utils/ddp.py +77 -0
  34. opensportslib/core/utils/default_args.py +110 -0
  35. opensportslib/core/utils/load_annotations.py +485 -0
  36. opensportslib/core/utils/seed.py +26 -0
  37. opensportslib/core/utils/video_processing.py +389 -0
  38. opensportslib/core/utils/wandb.py +110 -0
  39. opensportslib/datasets/__init__.py +0 -0
  40. opensportslib/datasets/builder.py +42 -0
  41. opensportslib/datasets/classification_dataset.py +582 -0
  42. opensportslib/datasets/localization_dataset.py +813 -0
  43. opensportslib/datasets/utils/__init__.py +15 -0
  44. opensportslib/datasets/utils/tracking.py +615 -0
  45. opensportslib/metrics/classification_metric.py +176 -0
  46. opensportslib/metrics/localization_metric.py +1482 -0
  47. opensportslib/models/__init__.py +0 -0
  48. opensportslib/models/backbones/builder.py +590 -0
  49. opensportslib/models/base/e2e.py +252 -0
  50. opensportslib/models/base/tracking.py +73 -0
  51. opensportslib/models/base/vars.py +29 -0
  52. opensportslib/models/base/video.py +130 -0
  53. opensportslib/models/base/video_mae.py +60 -0
  54. opensportslib/models/builder.py +43 -0
  55. opensportslib/models/heads/builder.py +266 -0
  56. opensportslib/models/neck/builder.py +210 -0
  57. opensportslib/models/utils/common.py +176 -0
  58. opensportslib/models/utils/impl/__init__.py +0 -0
  59. opensportslib/models/utils/impl/asformer.py +390 -0
  60. opensportslib/models/utils/impl/calf.py +74 -0
  61. opensportslib/models/utils/impl/gsm.py +112 -0
  62. opensportslib/models/utils/impl/gtad.py +347 -0
  63. opensportslib/models/utils/impl/tsm.py +123 -0
  64. opensportslib/models/utils/litebase.py +59 -0
  65. opensportslib/models/utils/modules.py +120 -0
  66. opensportslib/models/utils/shift.py +135 -0
  67. opensportslib/models/utils/utils.py +276 -0
  68. opensportslib-0.0.1.dev2.dist-info/METADATA +566 -0
  69. opensportslib-0.0.1.dev2.dist-info/RECORD +73 -0
  70. opensportslib-0.0.1.dev2.dist-info/WHEEL +5 -0
  71. opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE +661 -0
  72. opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE-COMMERCIAL +5 -0
  73. opensportslib-0.0.1.dev2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,18 @@
1
+ import importlib
2
+ # from . import apis as model
3
+ # from . import metrics
4
+ # from . import datasets
5
+ # from . import core
6
+
7
def __getattr__(name):
    """Resolve the package's public submodules on first access (PEP 562).

    ``opensportslib.model`` is an alias for ``opensportslib.apis``; the
    names ``metrics``, ``datasets`` and ``core`` map to the subpackages of
    the same name. The import is deferred until the attribute is accessed.

    Raises:
        AttributeError: for any other attribute name.
    """
    lazy_targets = {
        "model": "opensportslib.apis",
        "metrics": "opensportslib.metrics",
        "datasets": "opensportslib.datasets",
        "core": "opensportslib.core",
    }
    target = lazy_targets.get(name)
    if target is None:
        raise AttributeError(f"module 'opensportslib' has no attribute '{name}'")
    return importlib.import_module(target)
17
+
18
# Names exported by ``from opensportslib import *``; each one is resolved
# lazily by the module-level ``__getattr__``.
__all__ = [
    "model",
    "metrics",
    "datasets",
    "core",
]
@@ -0,0 +1,21 @@
1
+ # opensportslib/apis/__init__.py
2
+
3
+ # Import task APIs
4
+ from opensportslib.apis.classification import ClassificationAPI
5
+ from opensportslib.apis.localization import LocalizationAPI
6
+ import warnings
7
+ warnings.filterwarnings("ignore")
8
+
9
+ # Factory functions for user-facing calls
10
# Factory functions for user-facing calls.
def classification(config=None, data_dir=None, save_dir=None):
    """Create a ``ClassificationAPI`` bound to a YAML config.

    Args:
        config: path to the YAML configuration file. ``ClassificationAPI``
            raises ``ValueError`` if this is None.
        data_dir: optional override for ``DATA.data_dir`` in the config.
        save_dir: optional override for the checkpoint output directory.

    Returns:
        A new ``ClassificationAPI`` instance.
    """
    return ClassificationAPI(
        config=config,
        data_dir=data_dir,
        save_dir=save_dir,
    )


def localization(config=None, data_dir=None, save_dir=None):
    """Create a ``LocalizationAPI`` bound to a YAML config.

    Args:
        config: path to the YAML configuration file. ``LocalizationAPI``
            raises ``ValueError`` if this is None.
        data_dir: optional override for ``DATA.data_dir`` in the config.
        save_dir: optional override for the checkpoint output directory.

    Returns:
        A new ``LocalizationAPI`` instance.
    """
    return LocalizationAPI(
        config=config,
        data_dir=data_dir,
        save_dir=save_dir,
    )


# Only the factory functions are part of this subpackage's star-export.
__all__ = [
    "classification",
    "localization",
]
@@ -0,0 +1,361 @@
1
+ # opensportslib/apis/classification.py
2
+
3
+ """public API for classification tasks.
4
+
5
+ supports three dataset/task combinations:
6
+ - MVFoul (video-based foul classification)
7
+ - SN-GAR with video modality
8
+ - SN-GAR with tracking modality
9
+
10
+ handles single-GPU and multi-GPU (DDP) training and inference,
11
+ delegating heavy lifting to Trainer_Classification.
12
+ """
13
+
14
+ import os
15
+ import logging
16
+ from opensportslib.core.utils.config import expand
17
+
18
+
19
class ClassificationAPI:
    """Top-level entry point for classification training and inference.

    Loads a YAML config, resolves paths, and exposes ``train()`` /
    ``infer()`` methods that transparently handle single-GPU and
    DDP execution.

    Args:
        config: path to the YAML configuration file.
        data_dir: override for ``DATA.data_dir`` in the config.
            If None, the value from the config is used.
        save_dir: override for the checkpoint output directory.
            Falls back to ``SYSTEM.save_dir``, then ``"./checkpoints"``.
    """

    def __init__(self, config=None, data_dir=None, save_dir=None):
        from opensportslib.core.utils.config import load_config_omega
        import uuid

        if config is None:
            raise ValueError("config path is required")

        config_path = expand(config)
        self.config = load_config_omega(config_path)

        # Let the caller override the dataset root directory.
        self.config.DATA.data_dir = expand(
            data_dir or self.config.DATA.data_dir
        )

        # The run id is published through the environment because spawned
        # DDP workers read RUN_ID (see _worker_ddp) to tag their W&B run.
        self.run_id = os.environ.get("RUN_ID") or str(uuid.uuid4())[:8]
        os.environ["RUN_ID"] = self.run_id

        # Checkpoints go to <save_dir>/<backbone type>/<run id>; never
        # derived from BASE_DIR so the user controls where weights land.
        self.save_dir = expand(
            save_dir or self.config.SYSTEM.save_dir or "./checkpoints"
        )
        save_filename = os.path.join(self.config.MODEL.backbone.type, self.run_id)
        self.config.SYSTEM.save_dir = os.path.join(self.save_dir, save_filename)
        os.makedirs(self.config.SYSTEM.save_dir, exist_ok=True)

        # DDP rank; only rank 0 reports the resolved paths.
        rank = int(os.environ.get("RANK", 0))
        self.trainer = None
        self.best_checkpoint = None

        # The directory that is created and the directory the file handler
        # writes to must be the same path. A relative log_dir lives inside
        # the run directory; an absolute one is used as-is (os.path.join
        # discards earlier components when a later one is absolute).
        log_dir = expand(self.config.SYSTEM.log_dir or "./log_dir")
        log_path = os.path.join(self.config.SYSTEM.save_dir, log_dir)
        os.makedirs(log_path, exist_ok=True)
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s | %(levelname)s | %(message)s",
            handlers=[
                logging.FileHandler(os.path.join(log_path, "train.log")),
                logging.StreamHandler(),
            ],
            force=True,
        )
        if rank == 0:
            logging.info(f"DATA DIR : {self.config.DATA.data_dir}")
            logging.info(f"MODEL SAVEDIR: {self.config.SYSTEM.save_dir}")

    def _default_checkpoint(self):
        """Return the checkpoint recorded by the last ``train()`` call.

        ``__init__`` always sets ``best_checkpoint`` (to None), so the
        value, not the attribute's presence, decides whether a training
        run has happened.

        Raises:
            ValueError: if no checkpoint has been recorded on this instance.
        """
        checkpoint = getattr(self, "best_checkpoint", None)
        if checkpoint is None:
            raise ValueError("No pretrained checkpoint provided and no training run found.")
        logging.info(f"Using last trained checkpoint: {checkpoint}")
        return checkpoint

    # -----------------------------------------------------------------
    # internal DDP worker
    # -----------------------------------------------------------------
    @staticmethod
    def _worker_ddp(
        rank,
        world_size,
        mode,
        config,
        return_queue=None,
        train_set=None,
        valid_set=None,
        test_set=None,
        pretrained=None,
        use_wandb=False
    ):
        """Execute a single training or inference job on one device.

        Spawned once per GPU by ``train()`` / ``infer()`` in DDP mode, or
        called directly in the caller's process otherwise. Each call builds
        its own ``Trainer_Classification`` so no mutable state is shared
        across ranks.

        Args:
            rank: GPU rank (0-indexed).
            world_size: total number of participating processes.
            mode: ``"train"`` or ``"infer"``.
            config: resolved OmegaConf configuration.
            return_queue: multiprocessing queue used by rank 0 to send the
                best checkpoint (train) or metrics (infer) back to the caller.
            train_set: path to the training set annotations file.
            valid_set: path to the validation set annotations file.
            test_set: path to the test set annotations file.
            pretrained: path to a saved checkpoint for warm-starting
                or inference.
            use_wandb: forwarded to ``init_wandb`` on rank 0.

        Raises:
            ValueError: if ``mode`` is neither ``"train"`` nor ``"infer"``.
        """
        import torch
        from opensportslib.core.trainer.classification_trainer import Trainer_Classification
        from opensportslib.core.utils.config import select_device
        from opensportslib.core.utils.ddp import ddp_setup, ddp_cleanup
        from opensportslib.core.utils.wandb import init_wandb
        from opensportslib.core.utils.seed import set_reproducibility
        from opensportslib.datasets.builder import build_dataset
        from opensportslib.models.builder import build_model

        is_ddp = world_size > 1

        # Spawned processes start without any logging setup, so give them a
        # rank-tagged format and silence non-zero ranks. In single-process
        # mode this runs in the caller's process, where force=True would
        # discard the train.log file handler installed by __init__.
        if is_ddp:
            logging.basicConfig(
                level=logging.INFO,
                format=f"[RANK {rank}] %(asctime)s | %(levelname)s | %(message)s",
                force=True,
            )
            if rank != 0:
                logging.getLogger().setLevel(logging.ERROR)

        if rank == 0:
            init_wandb(config, run_id=os.environ["RUN_ID"], use_wandb=use_wandb)

        # Seeding is opt-in: it only runs when SYSTEM.use_seed is truthy
        # (a missing key counts as False).
        if getattr(config.SYSTEM, "use_seed", False):
            set_reproducibility(config.SYSTEM.seed)

        if is_ddp:
            torch.cuda.set_device(rank)
            ddp_setup(rank, world_size)
            device = torch.device(f"cuda:{rank}")
        else:
            device = select_device(config.SYSTEM)

        try:
            # Each process creates a fresh trainer to avoid shared mutable state.
            trainer = Trainer_Classification(config)
            trainer.device = device

            # Build or restore the model.
            if pretrained:
                model, processor, _scheduler, _epoch = trainer.load(pretrained)
            else:
                model, processor = build_model(config, device)
            trainer.model = model

            if mode == "train":
                train_data = build_dataset(config, train_set, processor, split="train")
                valid_data = build_dataset(config, valid_set, processor, split="valid")
                best_ckpt = trainer.train(
                    model, train_data, valid_data,
                    rank=rank, world_size=world_size
                )
                if rank == 0 and return_queue is not None:
                    # Prefer the path recorded by the inner trainer, falling
                    # back to train()'s return value if it is unavailable.
                    inner = getattr(trainer, "trainer", None)
                    best_ckpt = getattr(inner, "best_checkpoint_path", None) or best_ckpt
                    return_queue.put(best_ckpt)

            elif mode == "infer":
                test_data = build_dataset(config, test_set, processor, split="test")
                metrics = trainer.infer(test_data, rank=rank, world_size=world_size)
                if rank == 0 and return_queue is not None:
                    return_queue.put(metrics)

            else:
                raise ValueError(f"unknown mode: {mode!r}")
        finally:
            # Tear down the process group even if training/inference failed.
            if is_ddp:
                ddp_cleanup()

    # -----------------------------------------------------------------
    # public training interface
    # -----------------------------------------------------------------

    def train(
        self,
        train_set=None,
        valid_set=None,
        test_set=None,
        pretrained=None,
        use_ddp=False,
        use_wandb=True
    ):
        """Run a full training loop.

        Args:
            train_set: path to training annotations. Defaults to
                ``DATA.annotations.train`` in the loaded config.
            valid_set: path to validation annotations. Defaults to
                ``DATA.annotations.valid``.
            test_set: currently unused.
            pretrained: optional checkpoint path for warm-starting.
            use_ddp: if True and more than one CUDA device is visible,
                spawn one process per GPU via ``torch.multiprocessing.spawn``.
            use_wandb: forwarded to ``init_wandb`` on rank 0.

        Returns:
            The best checkpoint path reported by rank 0 (also stored on
            ``self.best_checkpoint``).
        """
        import torch
        import torch.multiprocessing as mp
        from opensportslib.core.utils.config import resolve_config_omega

        train_set = expand(train_set or self.config.DATA.annotations.train)
        valid_set = expand(valid_set or self.config.DATA.annotations.valid)

        self.config = resolve_config_omega(self.config)
        logging.info("Configuration:")
        logging.info(self.config)

        # DDP binds one CUDA device per rank, so only visible GPUs count
        # (same rule as infer()).
        world_size = torch.cuda.device_count()
        use_ddp = use_ddp and world_size > 1

        ctx = mp.get_context("spawn")
        queue = ctx.Queue()

        if use_ddp:
            logging.info(f"Launching DDP on {world_size} GPUs")
            mp.spawn(
                ClassificationAPI._worker_ddp,
                args=(
                    world_size, "train", self.config, queue,
                    train_set, valid_set, None, pretrained, use_wandb
                ),
                nprocs=world_size,
            )
        else:
            logging.info("Single GPU training")
            ClassificationAPI._worker_ddp(
                rank=0,
                world_size=1,
                mode="train",
                config=self.config,
                return_queue=queue,
                train_set=train_set,
                valid_set=valid_set,
                pretrained=pretrained,
                use_wandb=use_wandb
            )

        self.best_checkpoint = queue.get()
        return self.best_checkpoint

    def infer(
        self,
        test_set=None,
        pretrained=None,
        predictions=None,
        use_ddp=False,
        use_wandb=True
    ):
        """Run inference or evaluate saved predictions.

        When ``predictions`` is falsy, the model runs a forward pass over
        the test set and returns live metrics. When ``predictions`` points
        to a saved prediction file, only the evaluation step runs.

        Args:
            test_set: path to test annotations. Defaults to
                ``DATA.annotations.test``.
            pretrained: checkpoint path. If both this and ``predictions``
                are None, the checkpoint from the last ``train()`` is used.
            predictions: path to a previously saved prediction file.
                If provided, evaluation is run offline without a model.
            use_ddp: if True, distribute inference across all visible GPUs.
            use_wandb: forwarded to ``init_wandb`` on rank 0.

        Returns:
            A metrics dictionary produced by the trainer.

        Raises:
            ValueError: if neither ``pretrained`` nor ``predictions`` is
                given and no training run has recorded a checkpoint.
        """
        # Fail fast, before heavy imports or config resolution.
        if pretrained is None and predictions is None:
            pretrained = self._default_checkpoint()

        import torch
        import torch.multiprocessing as mp
        from opensportslib.core.utils.config import resolve_config_omega

        test_set = expand(test_set or self.config.DATA.annotations.test)

        self.config = resolve_config_omega(self.config)
        logging.info("Configuration:")
        logging.info(self.config)

        if not predictions:
            # Live inference: run the model on test data.
            world_size = torch.cuda.device_count()
            use_ddp = use_ddp and world_size > 1

            ctx = mp.get_context("spawn")
            queue = ctx.Queue()

            if use_ddp:
                mp.spawn(
                    ClassificationAPI._worker_ddp,
                    args=(
                        world_size, "infer", self.config, queue,
                        None, None, test_set, pretrained, use_wandb
                    ),
                    nprocs=world_size,
                )
            else:
                ClassificationAPI._worker_ddp(
                    rank=0,
                    world_size=1,
                    mode="infer",
                    config=self.config,
                    return_queue=queue,
                    test_set=test_set,
                    pretrained=pretrained,
                    use_wandb=use_wandb
                )

            # Rank 0 pushes metrics into the queue; retrieve them here.
            metrics = queue.get()
        else:
            # Offline evaluation from a saved prediction file.
            from opensportslib.datasets.builder import build_dataset
            from opensportslib.core.trainer.classification_trainer import Trainer_Classification

            self.trainer = Trainer_Classification(self.config)
            test_data = build_dataset(self.config, test_set, None, split="test")
            metrics = self.trainer.evaluate(
                pred_path=predictions,
                gt_path=test_set,
                class_names=test_data.label_map,
                exclude_labels=test_data.exclude_labels
            )

        logging.info(f"TEST METRICS : {metrics}")
        return metrics
@@ -0,0 +1,228 @@
1
+ from opensportslib.core.utils.config import expand
2
+ import os
3
+ import logging
4
+ import time
5
+
6
class LocalizationAPI:
    """Top-level entry point for localization training and inference.

    Loads a YAML config, resolves dataset/checkpoint/log paths, and exposes
    ``train()`` and ``infer()``.

    Args:
        config: path to the YAML configuration file.
        data_dir: override for ``DATA.data_dir`` in the config.
            If None, the value from the config is used.
        save_dir: override for the checkpoint output directory.
            Falls back to ``SYSTEM.save_dir``, then ``"./checkpoints"``.
    """

    def __init__(self, config=None, data_dir=None, save_dir=None):
        from opensportslib.core.utils.config import load_config_omega
        import uuid

        if config is None:
            raise ValueError("config path is required")

        # Load the config with OmegaConf so ${...} references can be resolved later.
        config_path = expand(config)
        self.config = load_config_omega(config_path)

        # Let the caller override the dataset root directory.
        self.config.DATA.data_dir = expand(data_dir or self.config.DATA.data_dir)

        # The run id is published through the environment; train()/infer()
        # read RUN_ID back when initialising W&B.
        self.run_id = os.environ.get("RUN_ID") or str(uuid.uuid4())[:8]
        os.environ["RUN_ID"] = self.run_id

        # Checkpoints go to <save_dir>/<backbone type>/<run id>; never
        # derived from BASE_DIR so the user controls where weights land.
        self.save_dir = expand(
            save_dir or self.config.SYSTEM.save_dir or "./checkpoints"
        )
        save_filename = os.path.join(self.config.MODEL.backbone.type, self.run_id)
        self.config.SYSTEM.save_dir = os.path.join(self.save_dir, save_filename)
        os.makedirs(self.config.SYSTEM.save_dir, exist_ok=True)

        # The directory that is created and the directory the file handler
        # writes to must be the same path. A relative log_dir lives inside
        # the run directory; an absolute one is used as-is (os.path.join
        # discards earlier components when a later one is absolute).
        # force=True so the handlers are installed even if logging was
        # already configured (matches ClassificationAPI).
        log_dir = expand(self.config.SYSTEM.log_dir or "./log_dir")
        log_path = os.path.join(self.config.SYSTEM.save_dir, log_dir)
        os.makedirs(log_path, exist_ok=True)
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s | %(levelname)s | %(message)s",
            handlers=[
                logging.FileHandler(os.path.join(log_path, "train.log")),
                logging.StreamHandler(),  # still prints to console
            ],
            force=True,
        )

        self.best_checkpoint = None

        logging.info(f"CONFIG PATH : {config_path}")
        logging.info(f"DATA DIR : {self.config.DATA.data_dir}")
        logging.info(f"SAVEDIR: {self.config.SYSTEM.save_dir}")
        logging.info(f"Classes : {self.config.DATA.classes}")

    def _default_checkpoint(self):
        """Return the checkpoint recorded by the last ``train()`` call.

        ``__init__`` always sets ``best_checkpoint`` (to None), so the
        value, not the attribute's presence, decides whether a training
        run has happened.

        Raises:
            ValueError: if no checkpoint has been recorded on this instance.
        """
        checkpoint = getattr(self, "best_checkpoint", None)
        if checkpoint is None:
            raise ValueError("No pretrained checkpoint provided and no training run found.")
        logging.info(f"Using last trained checkpoint: {checkpoint}")
        return checkpoint

    @staticmethod
    def _set_seed(seed):
        """Seed Python, NumPy and PyTorch RNGs and request deterministic kernels."""
        import random
        import numpy as np
        import torch

        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        # Deterministic cuDNN behaviour (disables autotuning).
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        # Deterministic CUDA ops where available; warn instead of failing otherwise.
        torch.use_deterministic_algorithms(True, warn_only=True)

    def _build_loader(self, split):
        """Build the dataset for ``split`` and wrap it in a dataloader (``dali=True``)."""
        from opensportslib.datasets.builder import build_dataset

        data_obj = build_dataset(self.config, split=split)
        dataset = data_obj.building_dataset(
            cfg=data_obj.cfg,
            gpu=self.config.SYSTEM.GPU,
            default_args=data_obj.default_args,
        )
        loader = data_obj.building_dataloader(
            dataset,
            cfg=data_obj.cfg.dataloader,
            gpu=self.config.SYSTEM.GPU,
            dali=True,
        )
        logging.info(f"{split} loader length: {len(loader)}")
        return loader

    def train(self, train_set=None, valid_set=None, pretrained=None, use_ddp=False, use_wandb=True):
        """Train a localization model.

        Args:
            train_set: path to training annotations. Defaults to
                ``DATA.train.path`` in the config.
            valid_set: path to validation annotations. Defaults to
                ``DATA.valid.path``.
            pretrained: optional checkpoint passed to the trainer as
                ``resume_from``.
            use_ddp: currently unused.
            use_wandb: forwarded to ``init_wandb``.

        Returns:
            The trainer's ``best_checkpoint_path`` (also stored on
            ``self.best_checkpoint``).
        """
        from opensportslib.models.builder import build_model
        from opensportslib.core.trainer.localization_trainer import build_trainer
        from opensportslib.core.utils.default_args import get_default_args_trainer, get_default_args_train
        from opensportslib.core.utils.config import select_device, resolve_config_omega
        from opensportslib.core.utils.load_annotations import check_config
        from opensportslib.core.utils.wandb import init_wandb

        # Expand annotation paths (user override or config value).
        self.config.DATA.train.path = expand(train_set or self.config.DATA.train.path)
        self.config.DATA.valid.path = expand(valid_set or self.config.DATA.valid.path)

        self.config = resolve_config_omega(self.config)
        check_config(self.config, split="train")
        init_wandb(self.config, run_id=os.environ["RUN_ID"], use_wandb=use_wandb)
        logging.info("Configuration:")
        logging.info(self.config)

        self._set_seed(self.config.SYSTEM.seed)
        start = time.time()

        device = select_device(self.config.SYSTEM)
        self.model = build_model(self.config, device=device)
        logging.info(self.model)

        train_loader = self._build_loader("train")
        valid_loader = self._build_loader("valid")

        trainer = build_trainer(
            cfg=self.config,
            model=self.model,
            default_args=get_default_args_trainer(self.config, len(train_loader)),
            resume_from=pretrained,
        )

        logging.info("Start training")
        trainer.train(
            **get_default_args_train(
                self.model,
                train_loader,
                valid_loader,
                self.config.DATA.classes,
                self.config.TRAIN.type,
            )
        )
        self.best_checkpoint = trainer.best_checkpoint_path

        logging.info(f"Total Execution Time is {time.time()-start} seconds")
        return self.best_checkpoint

    def infer(self, test_set=None, pretrained=None, predictions=None, use_ddp=False, use_wandb=True):
        """Run inference on the test split and evaluate when labels exist.

        Args:
            test_set: path to test annotations. Defaults to
                ``DATA.test.path`` in the config.
            pretrained: checkpoint to load. If both this and
                ``predictions`` are None, the checkpoint from the last
                ``train()`` is used.
            predictions: path to a saved predictions file; when truthy,
                the model is not run and this file is evaluated instead.
            use_ddp: currently unused (``MODEL.multi_gpu`` is forced False).
            use_wandb: forwarded to ``init_wandb``.

        Returns:
            The evaluator's metrics, or None when the test annotations
            contain no localization events.

        Raises:
            ValueError: if neither ``pretrained`` nor ``predictions`` is
                given and no training run has recorded a checkpoint.
        """
        # Fail fast, before heavy imports, config resolution or W&B setup.
        if pretrained is None and predictions is None:
            pretrained = self._default_checkpoint()

        from opensportslib.models.builder import build_model
        from opensportslib.core.trainer.localization_trainer import build_inferer, build_evaluator
        from opensportslib.core.utils.config import select_device, resolve_config_omega, is_local_path
        from opensportslib.core.utils.checkpoint import load_checkpoint, localization_remap
        from opensportslib.core.utils.load_annotations import check_config, has_localization_events
        from opensportslib.core.utils.wandb import init_wandb

        self.config.DATA.test.path = expand(test_set or self.config.DATA.test.path)
        self.config.MODEL.multi_gpu = False
        self.config = resolve_config_omega(self.config)
        check_config(self.config, split="test")
        init_wandb(self.config, run_id=os.environ["RUN_ID"], use_wandb=use_wandb)
        logging.info("Configuration:")
        logging.info(self.config)

        start = time.time()

        if predictions:
            json_gz_file = predictions
        else:
            logging.info("No predictions provided, running inference.")
            device = select_device(self.config.SYSTEM)
            self.model = build_model(self.config, device=device)
            logging.info(f"Model type: {type(self.model)}")
            logging.info(f"Torch model type: {type(self.model._model)}")

            if pretrained:
                # Local checkpoints also define the work dir used by the inferer.
                if is_local_path(pretrained):
                    self.config.SYSTEM.work_dir = os.path.dirname(os.path.abspath(pretrained))
                self.model._model, _, _, _ = load_checkpoint(
                    model=self.model._model,
                    path=pretrained,
                    device=device,
                    key_remap_fn=localization_remap,
                )

            test_loader = self._build_loader("test")
            inferer = build_inferer(cfg=self.config.MODEL, model=self.model)
            json_gz_file = inferer.infer(cfg=self.config, data=test_loader)

        metrics = None
        if has_localization_events(self.config.DATA.test.path):
            logging.info("Ground truth labels detected → running evaluation")
            evaluator = build_evaluator(cfg=self.config)
            metrics = evaluator.evaluate(
                cfg_testset=self.config.DATA.test,
                json_gz_file=json_gz_file,
            )
        else:
            logging.info("No labels found in annotation file → skipping evaluation")

        logging.info(f"Total Execution Time is {time.time()-start} seconds")
        return metrics