opensportslib 0.0.1.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73) hide show
  1. opensportslib/__init__.py +18 -0
  2. opensportslib/apis/__init__.py +21 -0
  3. opensportslib/apis/classification.py +361 -0
  4. opensportslib/apis/localization.py +228 -0
  5. opensportslib/config/classification.yaml +104 -0
  6. opensportslib/config/classification_tracking.yaml +103 -0
  7. opensportslib/config/graph_tracking_classification/avgpool.yaml +79 -0
  8. opensportslib/config/graph_tracking_classification/gin.yaml +79 -0
  9. opensportslib/config/graph_tracking_classification/graphconv.yaml +79 -0
  10. opensportslib/config/graph_tracking_classification/graphsage.yaml +79 -0
  11. opensportslib/config/graph_tracking_classification/maxpool.yaml +79 -0
  12. opensportslib/config/graph_tracking_classification/noedges.yaml +79 -0
  13. opensportslib/config/localization.yaml +132 -0
  14. opensportslib/config/sngar_frames.yaml +98 -0
  15. opensportslib/core/__init__.py +0 -0
  16. opensportslib/core/loss/__init__.py +0 -0
  17. opensportslib/core/loss/builder.py +40 -0
  18. opensportslib/core/loss/calf.py +258 -0
  19. opensportslib/core/loss/ce.py +23 -0
  20. opensportslib/core/loss/combine.py +42 -0
  21. opensportslib/core/loss/nll.py +25 -0
  22. opensportslib/core/optimizer/__init__.py +0 -0
  23. opensportslib/core/optimizer/builder.py +38 -0
  24. opensportslib/core/sampler/weighted_sampler.py +104 -0
  25. opensportslib/core/scheduler/__init__.py +0 -0
  26. opensportslib/core/scheduler/builder.py +77 -0
  27. opensportslib/core/trainer/__init__.py +0 -0
  28. opensportslib/core/trainer/classification_trainer.py +1131 -0
  29. opensportslib/core/trainer/localization_trainer.py +1009 -0
  30. opensportslib/core/utils/checkpoint.py +238 -0
  31. opensportslib/core/utils/config.py +199 -0
  32. opensportslib/core/utils/data.py +85 -0
  33. opensportslib/core/utils/ddp.py +77 -0
  34. opensportslib/core/utils/default_args.py +110 -0
  35. opensportslib/core/utils/load_annotations.py +485 -0
  36. opensportslib/core/utils/seed.py +26 -0
  37. opensportslib/core/utils/video_processing.py +389 -0
  38. opensportslib/core/utils/wandb.py +110 -0
  39. opensportslib/datasets/__init__.py +0 -0
  40. opensportslib/datasets/builder.py +42 -0
  41. opensportslib/datasets/classification_dataset.py +582 -0
  42. opensportslib/datasets/localization_dataset.py +813 -0
  43. opensportslib/datasets/utils/__init__.py +15 -0
  44. opensportslib/datasets/utils/tracking.py +615 -0
  45. opensportslib/metrics/classification_metric.py +176 -0
  46. opensportslib/metrics/localization_metric.py +1482 -0
  47. opensportslib/models/__init__.py +0 -0
  48. opensportslib/models/backbones/builder.py +590 -0
  49. opensportslib/models/base/e2e.py +252 -0
  50. opensportslib/models/base/tracking.py +73 -0
  51. opensportslib/models/base/vars.py +29 -0
  52. opensportslib/models/base/video.py +130 -0
  53. opensportslib/models/base/video_mae.py +60 -0
  54. opensportslib/models/builder.py +43 -0
  55. opensportslib/models/heads/builder.py +266 -0
  56. opensportslib/models/neck/builder.py +210 -0
  57. opensportslib/models/utils/common.py +176 -0
  58. opensportslib/models/utils/impl/__init__.py +0 -0
  59. opensportslib/models/utils/impl/asformer.py +390 -0
  60. opensportslib/models/utils/impl/calf.py +74 -0
  61. opensportslib/models/utils/impl/gsm.py +112 -0
  62. opensportslib/models/utils/impl/gtad.py +347 -0
  63. opensportslib/models/utils/impl/tsm.py +123 -0
  64. opensportslib/models/utils/litebase.py +59 -0
  65. opensportslib/models/utils/modules.py +120 -0
  66. opensportslib/models/utils/shift.py +135 -0
  67. opensportslib/models/utils/utils.py +276 -0
  68. opensportslib-0.0.1.dev2.dist-info/METADATA +566 -0
  69. opensportslib-0.0.1.dev2.dist-info/RECORD +73 -0
  70. opensportslib-0.0.1.dev2.dist-info/WHEEL +5 -0
  71. opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE +661 -0
  72. opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE-COMMERCIAL +5 -0
  73. opensportslib-0.0.1.dev2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,252 @@
1
+ """
2
+ Copyright 2022 James Hong, Haotian Zhang, Matthew Fisher, Michael Gharbi,
3
+ Kayvon Fatahalian
4
+
5
+ Redistribution and use in source and binary forms, with or without modification,
6
+ are permitted provided that the following conditions are met:
7
+
8
+ 1. Redistributions of source code must retain the above copyright notice, this
9
+ list of conditions and the following disclaimer.
10
+
11
+ 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ this list of conditions and the following disclaimer in the documentation and/or
13
+ other materials provided with the distribution.
14
+
15
+ 3. Neither the name of the copyright holder nor the names of its contributors
16
+ may be used to endorse or promote products derived from this software without
17
+ specific prior written permission.
18
+
19
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
20
+ ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
21
+ WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
23
+ ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
24
+ (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
25
+ LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
26
+ ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
28
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ """
30
+
31
+ import logging
32
+ from opensportslib.core.loss.builder import build_criterion
33
+ from opensportslib.models.backbones.builder import build_backbone
34
+ from opensportslib.models.utils.common import step, BaseRGBModel
35
+ from opensportslib.models.heads.builder import build_head
36
+ from contextlib import nullcontext
37
+ import torch
38
+ from tqdm import tqdm
39
+ import torch.nn as nn
40
+
41
+
42
class E2EModel(BaseRGBModel):
    """Wrapper that owns the E2E network and drives training and inference.

    This class is not the network itself; the network is an ``E2EModel.Impl``
    instance stored in ``self._model``. The wrapper builds it, runs
    training/validation epochs and produces predictions.

    Args:
        cfg: Config object; ``cfg.TRAIN.criterion`` is passed to
            ``build_criterion``.
        num_classes (int): Number of classes.
        backbone: Backbone config handed to ``build_backbone``. It is mutated
            in place (``clip_len``, ``is_rgb`` and ``in_channels`` are set).
        head: Head config handed to ``build_head``. It is mutated in place
            (``num_classes`` and ``feat_dim`` are set).
        clip_len (int): The length of the clip.
        modality (str): One of ``"rgb"``, ``"bw"`` or ``"flow"``; selects the
            number of input channels (3, 1 and 2 respectively).
        device (string): The device to use.
            Default: 'cuda'.
        multi_gpu (bool): Whether to wrap the network in ``nn.DataParallel``.
            Default: False.
    """

    class Impl(nn.Module):
        """Network for the E2E method: a backbone followed by a head.

        Args:
            num_classes (int): Number of classes.
            backbone: Backbone config handed to ``build_backbone``.
            head: Head config handed to ``build_head``.
            clip_len (int): The length of the clip.
            modality (str): ``"rgb"``, ``"bw"`` or ``"flow"``.

        Raises:
            KeyError: If ``modality`` is not one of the three known values.
        """

        def __init__(self, num_classes, backbone, head, clip_len, modality):
            super().__init__()
            is_rgb = modality == "rgb"
            in_channels = {"flow": 2, "bw": 1, "rgb": 3}[modality]

            # The builders read these values from the config objects.
            backbone.clip_len = clip_len
            backbone.is_rgb = is_rgb
            backbone.in_channels = in_channels
            self.backbone = build_backbone(backbone)

            head.num_classes = num_classes
            head.feat_dim = self.backbone._feat_dim
            self.head = build_head(head)

        def forward(self, x):
            """Run the backbone on ``x`` and feed its output to the head."""
            im_feat = self.backbone(x)
            return self.head(im_feat)

        def print_stats(self):
            """Print parameter counts for the whole model and its two parts.

            Relies on the backbone exposing ``_features`` and the head
            exposing ``_pred_fine``.
            """
            print("Model params:", sum(p.numel() for p in self.parameters()))
            print(
                " CNN features:",
                sum(p.numel() for p in self.backbone._features.parameters()),
            )
            print(
                " Temporal:",
                sum(p.numel() for p in self.head._pred_fine.parameters()),
            )

    def __init__(
        self,
        cfg,
        num_classes,
        backbone,
        head,
        clip_len,
        modality,
        device="cuda",
        multi_gpu=False,
    ):
        self.device = device
        self._multi_gpu = multi_gpu
        self._model = E2EModel.Impl(num_classes, backbone, head, clip_len, modality)
        self._model.print_stats()

        logging.info("Build criterion")
        self.criterion = build_criterion(cfg.TRAIN.criterion)
        logging.info(self.criterion)

        if multi_gpu:
            self._model = nn.DataParallel(self._model)
        # The network is moved to ``device`` whether or not it was wrapped.
        self._model.to(device)

        self._num_classes = num_classes

    def epoch(
        self,
        loader,
        dali,
        optimizer=None,
        scaler=None,
        lr_scheduler=None,
        acc_grad_iter=1,
        fg_weight=5,
    ):
        """Performs an epoch for training/validating.

        Args:
            loader: The dataloader.
            dali (bool): Whether dali has been used or opencv for data
                processing. With dali the frames are read from
                ``batch["frame"]``; otherwise they are loaded through
                ``loader.dataset.load_frame_gpu``.
            optimizer (torch.optim.Optimizer): The optimizer to update model
                parameters. Set to None if validation epoch.
                Default: None.
            scaler (torch.cuda.amp.GradScaler): The gradient scaler for mixed
                precision training.
                Default: None.
            lr_scheduler: The learning rate scheduler.
                Default: None.
            acc_grad_iter (int): Use gradient accumulation.
                Default: 1.
            fg_weight (int): Weight given to every class except class 0 in
                the ``weight`` tensor passed to the criterion. No weight
                tensor is passed when this is 1.
                Default: 5.

        Returns:
            (float): The average loss over the batches.

        Raises:
            ZeroDivisionError: If ``loader`` is empty.
        """
        training = optimizer is not None
        if training:
            optimizer.zero_grad()
            self._model.train()
        else:
            self._model.eval()

        ce_kwargs = {}
        if fg_weight != 1:
            # Class 0 keeps weight 1; every other class gets ``fg_weight``.
            ce_kwargs["weight"] = torch.tensor(
                [1] + [fg_weight] * (self._num_classes - 1),
                dtype=torch.float32,
            ).to(self.device)

        epoch_loss = 0.0

        with nullcontext() if training else torch.no_grad():
            for batch_idx, batch in enumerate(tqdm(loader)):
                if dali:
                    frame = batch["frame"].to(self.device)
                else:
                    frame = loader.dataset.load_frame_gpu(batch, self.device)
                label = batch["label"].to(self.device)

                # 2-D labels are flattened to 1-D; any other rank keeps its
                # last dimension and collapses the leading ones.
                label = (
                    label.flatten()
                    if label.dim() == 2
                    else label.view(-1, label.shape[-1])
                )

                with torch.cuda.amp.autocast():
                    pred = self._model(frame)
                    pred = pred.to(self.device)

                    loss = 0.0
                    # A 3-D prediction is treated as a single output; a 4-D
                    # one contributes one loss term per leading entry.
                    if pred.dim() == 3:
                        pred = pred.unsqueeze(0)

                    for i in range(pred.shape[0]):
                        loss += self.criterion(
                            pred[i].reshape(-1, self._num_classes),
                            label,
                            **ce_kwargs,
                        )

                if training:
                    # NOTE(review): presumably step() only accumulates
                    # gradients while backward_only is True -- confirm in
                    # opensportslib.models.utils.common.step.
                    step(
                        optimizer,
                        scaler,
                        loss / acc_grad_iter,
                        lr_scheduler=lr_scheduler,
                        backward_only=(batch_idx + 1) % acc_grad_iter != 0,
                    )

                epoch_loss += loss.detach().item()

        num_batches = len(loader)
        avg_loss = epoch_loss / num_batches
        print(epoch_loss, num_batches, avg_loss)
        return avg_loss

    def predict(self, seq, use_amp=True):
        """Perform prediction on the input.

        Args:
            seq (torch.tensor): The input. Non-tensor input is converted with
                ``torch.FloatTensor``; a 4-D input (L, C, H, W) gets a
                leading batch dimension.
            use_amp (bool): Whether to use automatic mixed precision.
                Default: True.

        Returns:
            pred_cls (numpy.ndarray): Predicted class indices.
            pred (numpy.ndarray): Predicted probabilities.
        """
        if not isinstance(seq, torch.Tensor):
            seq = torch.FloatTensor(seq)
        if seq.dim() == 4:  # (L, C, H, W)
            seq = seq.unsqueeze(0)
        # ``Tensor.to`` is a no-op when the tensor already lives on the
        # target device, so no explicit device comparison is needed.
        seq = seq.to(self.device)

        self._model.eval()
        with torch.no_grad():
            with torch.cuda.amp.autocast() if use_amp else nullcontext():
                pred = self._model(seq)
            if isinstance(pred, tuple):
                pred = pred[0]
            if pred.dim() > 3:
                # Keep only the last entry along the leading dimension.
                pred = pred[-1]
            pred = torch.softmax(pred, dim=2)
            pred_cls = torch.argmax(pred, dim=2)
            return pred_cls.cpu().numpy(), pred.cpu().numpy()
@@ -0,0 +1,73 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch_geometric.nn import global_mean_pool
4
+
5
+ from opensportslib.models.backbones.builder import build_backbone
6
+ from opensportslib.models.neck.builder import build_neck
7
+ from opensportslib.models.heads.builder import build_head
8
+ from opensportslib.datasets.utils.tracking import FEATURE_DIM
9
+
10
+
11
class TrackingModel(nn.Module):
    """Classification model operating on tracking graphs.

    The pipeline is: graph backbone -> temporal neck -> classification head.
    Each stage is created through the corresponding library builder.
    """

    def __init__(self, config, device):
        """Build the three stages from ``config.MODEL``.

        Args:
            config: Config object; ``DATA.num_frames`` and the
                ``MODEL.backbone`` / ``MODEL.neck`` / ``MODEL.head`` entries
                are read.
            device: Stored on the instance as ``self.device``.
        """
        super().__init__()
        print("Building TrackingModel")

        self.device = device
        self.num_frames = config.DATA.num_frames

        # Graph encoder; node feature size defaults to FEATURE_DIM.
        backbone_defaults = dict(input_dim=FEATURE_DIM)
        self.backbone = build_backbone(
            config.MODEL.backbone, default_args=backbone_defaults
        )

        # Temporal aggregation; window defaults to the number of frames.
        neck_defaults = dict(window_size=self.num_frames)
        self.neck = build_neck(config.MODEL.neck, default_args=neck_defaults)

        # Classifier; its input size defaults to what the neck reports.
        head_defaults = dict(input_dim=self.neck.feat_dim)
        self.head = build_head(config.MODEL.head, default_args=head_defaults)

    def forward(self, batch):
        """Classify a batch of graph sequences.

        Args:
            batch: dict with keys:
                - x: (B*T*N, F) all node features batched
                - edge_index: (2, E) all edges with proper offsets
                - batch: (B*T*N,) graph assignment per node
                - batch_size: int
                - seq_len: int

        Returns:
            logits: (B, num_classes)
        """
        keys = ("x", "edge_index", "batch", "batch_size", "seq_len")
        nodes, edges, assignment, n_samples, n_steps = (batch[k] for k in keys)

        # One backbone pass covers every graph of every sample: (B*T, H).
        per_graph = self.backbone(nodes, edges, assignment)

        # Regroup the flat graph embeddings per sample: (B, T, H).
        per_sample = per_graph.view(n_samples, n_steps, -1)

        # Temporal aggregation to (B, H), then the classifier.
        pooled = self.neck(per_sample)
        return self.head(pooled)
@@ -0,0 +1,29 @@
1
+ # opensportslib/models/base/vars.py
2
+
3
+ import __future__
4
+ import torch
5
+ from opensportslib.models.backbones.builder import build_backbone
6
+ from opensportslib.models.neck.builder import build_neck
7
+ from opensportslib.models.heads.builder import build_head
8
+
9
class MVNetwork(torch.nn.Module):
    """Multi-view network: backbone, multi-view aggregation neck and head."""

    def __init__(self, config, backbone, neck, head):
        """Build the backbone, the aggregation neck and the head.

        Args:
            config: Config object; ``config.DATA.num_classes`` is read.
            backbone: Backbone config handed to ``build_backbone``.
            neck: Neck config handed to ``build_neck``.
            head: Head config handed to ``build_head``; ``num_classes`` and
                ``feat_dim`` are set on it in place first.
        """
        super().__init__()
        print("Building MVNetwork Model")

        # Empty Sequential: forwarded to the neck as ``lifting_net``.
        self.lifting_net = torch.nn.Sequential()
        self.backbone = build_backbone(backbone)

        # The neck receives the backbone itself plus its feature size.
        neck_defaults = {
            "model": self.backbone,
            "feat_dim": self.backbone.feat_dim,
            "lifting_net": self.lifting_net,
        }
        self.mvaggregate = build_neck(neck, default_args=neck_defaults)

        head.num_classes = config.DATA.num_classes
        head.feat_dim = self.backbone.feat_dim
        self.head = build_head(head)

    def forward(self, mvimages):
        """Return ``(head(features), attention)`` for the given views.

        ``self.mvaggregate`` must return a ``(features, attention)`` pair.
        """
        aggregated = self.mvaggregate(mvimages)
        features, attention = aggregated
        return self.head(features), attention
@@ -0,0 +1,130 @@
1
+ # opensportslib/models/base/video.py
2
+
3
+ """video backbone and model for frames_npy modality.
4
+
5
+ this file contains two independent things:
6
+
7
+ 1. the existing VideoMAE HuggingFace full-model builder functions
8
+ (build_video_mae_backbone, load_video_mae_checkpoint).
9
+ these are left exactly as they were and route through MODEL.type == "huggingface".
10
+
11
+ 2. the new VideoBackbone + VideoModel classes for the custom frames_npy path,
12
+ supporting dinov3, clip, videomae, videomae2 as pure feature extractors
13
+ wired to the library's existing TemporalAggregation neck and
14
+ TrackingClassifierHead head.
15
+ """
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from opensportslib.models.backbones.builder import build_backbone
21
+ from opensportslib.models.neck.builder import build_neck
22
+ from opensportslib.models.heads.builder import build_head
23
+
24
+
25
+ # -----------------------------------------------------------------------
26
+ # video mae backbone for MVFoul
27
+ # -----------------------------------------------------------------------
28
+
29
def build_video_mae_backbone(config, device, ckpt_path=None, infer=False):
    """Create a HuggingFace VideoMAE classifier and its image processor.

    The returned model contains both the backbone and the classification
    head. All parameters are frozen first; unless ``infer`` is set, the head
    and/or the last encoder layers are then unfrozen according to
    ``config.MODEL``.

    Args:
        config: Config object; reads ``MODEL.num_classes``,
            ``MODEL.pretrained_model``, ``MODEL.unfreeze_head`` and the
            optional ``MODEL.unfreeze_last_n_layers`` (default 0).
        device: Passed to ``from_pretrained`` as ``device_map``.
        ckpt_path: Optional checkpoint to load the model weights from; when
            falsy, ``MODEL.pretrained_model`` is used. The processor is
            always loaded from ``MODEL.pretrained_model``.
        infer (bool): When True, nothing is unfrozen.

    Returns:
        tuple: ``(model, processor)``.
    """
    from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor

    num_classes = config.MODEL.num_classes
    weights_source = ckpt_path or config.MODEL.pretrained_model
    processor = VideoMAEImageProcessor.from_pretrained(config.MODEL.pretrained_model)

    # NOTE(review): trust_remote_code=True lets code shipped with the
    # checkpoint run locally -- confirm the sources are trusted.
    model = VideoMAEForVideoClassification.from_pretrained(
        weights_source,
        num_labels=num_classes,
        ignore_mismatched_sizes=True,
        trust_remote_code=True,
        device_map=device,
    )

    # Start from a fully frozen model.
    for weight in model.parameters():
        weight.requires_grad = False

    if not infer:
        to_unfreeze = []
        if config.MODEL.unfreeze_head:
            to_unfreeze.append(model.classifier)

        n_unfreeze = getattr(config.MODEL, "unfreeze_last_n_layers", 0)
        if n_unfreeze > 0:
            to_unfreeze.extend(model.videomae.encoder.layer[-n_unfreeze:])

        for module in to_unfreeze:
            for weight in module.parameters():
                weight.requires_grad = True

    # Report which parameters ended up trainable.
    trainable = []
    for name, weight in model.named_parameters():
        if weight.requires_grad:
            trainable.append(name)
    print("Number of trainable params:", len(trainable))
    for name in trainable:
        print(name)

    return model, processor
67
+
68
+
69
def load_video_mae_checkpoint(config, device, ckpt_path, infer=True):
    """Load a fine-tuned VideoMAE checkpoint from a HuggingFace-style directory.

    Delegates to ``build_video_mae_backbone`` with ``ckpt_path`` as the
    weight source and returns its ``(model, processor)`` result unchanged.
    """
    loaded = build_video_mae_backbone(
        config, device, ckpt_path=ckpt_path, infer=infer
    )
    return loaded
74
+
75
+
76
+ # -----------------------------------------------------------------------
77
+ # new custom path: full model
78
+ # -----------------------------------------------------------------------
79
+
80
class VideoModel(nn.Module):
    """Video classification model for the frames_npy modality.

    Uses the same backbone -> neck -> head layout as TrackingModel. All
    three stages are created from ``config.MODEL`` through the library
    builders. Per the original notes the intended stages are a VideoBackbone
    feature extractor, a TemporalAggregation neck and a
    TrackingClassifierHead head; the actual types depend on the config.

    Args:
        config: full YAML config
        device: torch device string
    """

    def __init__(self, config, device):
        super().__init__()
        print("Building VideoModel")

        self.device = device
        self.num_frames = config.DATA.num_frames

        # Feature extractor.
        self.backbone = build_backbone(config.MODEL.backbone)

        # Temporal aggregation; window defaults to the number of frames.
        neck_defaults = dict(window_size=self.num_frames)
        self.neck = build_neck(config.MODEL.neck, default_args=neck_defaults)

        # Classifier; its input size defaults to what the neck reports.
        head_defaults = dict(input_dim=self.neck.feat_dim)
        self.head = build_head(config.MODEL.head, default_args=head_defaults)

    def forward(self, batch):
        """Classify a batch of clips.

        Args:
            batch: dict with key "pixel_values" of shape (B, T, H, W, C).

        Returns:
            logits: (B, num_classes)
        """
        # Stage outputs, in order: (B, T, hidden_dim) or (B, 1, hidden_dim)
        # -> (B, hidden_dim) -> (B, num_classes).
        features = batch["pixel_values"]
        for stage in (self.backbone, self.neck, self.head):
            features = stage(features)
        return features
@@ -0,0 +1,60 @@
1
+ # opensportslib/models/base/video_mae.py
2
+
3
+ from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor
4
+ import os
5
+
6
def build_video_mae_backbone(config, device, ckpt_path=None, infer=False):
    """Build a HuggingFace VideoMAE video classifier plus its processor.

    The model includes both the backbone and the classification head. Every
    parameter is frozen, then (unless ``infer`` is True) the head and the
    last N encoder layers are unfrozen as requested by ``config.MODEL``.

    Args:
        config: Config object; reads ``MODEL.num_classes``,
            ``MODEL.pretrained_model``, ``MODEL.unfreeze_head`` and the
            optional ``MODEL.unfreeze_last_n_layers`` (default 0).
        device: Passed to ``from_pretrained`` as ``device_map``.
        ckpt_path: Optional source for the model weights; falls back to
            ``MODEL.pretrained_model`` when falsy. The processor always
            comes from ``MODEL.pretrained_model``.
        infer (bool): When True, all parameters stay frozen.

    Returns:
        tuple: ``(model, processor)``.
    """
    model_source = ckpt_path if ckpt_path else config.MODEL.pretrained_model
    load_kwargs = dict(
        num_labels=config.MODEL.num_classes,
        ignore_mismatched_sizes=True,
        # NOTE(review): allows code shipped with the checkpoint to run
        # locally -- confirm the sources are trusted.
        trust_remote_code=True,
        device_map=device,
    )

    processor = VideoMAEImageProcessor.from_pretrained(config.MODEL.pretrained_model)
    model = VideoMAEForVideoClassification.from_pretrained(model_source, **load_kwargs)

    # Freeze everything.
    for tensor in model.parameters():
        tensor.requires_grad = False

    if not infer:
        # Unfreeze the classification head.
        if config.MODEL.unfreeze_head:
            for tensor in model.classifier.parameters():
                tensor.requires_grad = True

        # Unfreeze the last N encoder layers.
        n_unfreeze = getattr(config.MODEL, "unfreeze_last_n_layers", 0)
        if n_unfreeze > 0:
            for block in model.videomae.encoder.layer[-n_unfreeze:]:
                for tensor in block.parameters():
                    tensor.requires_grad = True

    trainable = [name for name, tensor in model.named_parameters() if tensor.requires_grad]
    print("Number of trainable params:", len(trainable))
    for name in trainable:
        print(name)
    return model, processor
49
+
50
+
51
def load_video_mae_checkpoint(config, device, ckpt_path, infer=True):
    """Load a fine-tuned VideoMAE checkpoint from a HuggingFace-style directory.

    Supports:
        - model.safetensors
        - pytorch_model.bin
        - config.json

    Delegates to ``build_video_mae_backbone`` and returns its
    ``(model, processor)`` result unchanged.
    """
    result = build_video_mae_backbone(
        config, device, ckpt_path=ckpt_path, infer=infer
    )
    return result
@@ -0,0 +1,43 @@
1
+ # opensportslib/models/builder.py
2
+
3
def build_model(config, device):
    """Build the model selected by ``config.TASK``.

    For ``classification`` the concrete model is chosen from
    ``config.MODEL.backbone.type``; for ``localization`` it is chosen from
    ``config.MODEL.type``. Model modules are imported lazily so that only
    the dependencies of the selected model are loaded.

    Args:
        config: Config object; ``TASK`` is compared case-insensitively.
        device: Device forwarded to the selected model builder.

    Returns:
        For classification, a ``(model, processor)`` tuple; ``processor`` is
        ``None`` for every backbone except ``"video_mae"``.
        For localization, the ``E2EModel`` instance alone (not a tuple).

    Raises:
        ValueError: If the task, the classification backbone type or the
            localization model type is not supported.
    """
    task = config.TASK.lower()

    if task == "classification":
        backbone_type = config.MODEL.backbone.type

        if backbone_type == "video_mae":
            from opensportslib.models.base.video import build_video_mae_backbone
            return build_video_mae_backbone(config, device)

        if backbone_type in ("r3d_18", "mc3_18", "r2plus1d_18", "s3d", "mvit_v2_s"):
            from opensportslib.models.base.vars import MVNetwork
            model = MVNetwork(
                config, config.MODEL.backbone, config.MODEL.neck, config.MODEL.head
            )
            return model, None

        if backbone_type == "graph_conv":
            from opensportslib.models.base.tracking import TrackingModel
            return TrackingModel(config, device), None

        if backbone_type in ("dinov3", "clip", "videomae", "videomae2"):
            from opensportslib.models.base.video import VideoModel
            return VideoModel(config, device), None

        raise ValueError(f"Unsupported backbone type: {backbone_type}")

    if task == "localization":
        # Validate before importing so a bad config fails with a clear
        # message that names the offending model type.
        if config.MODEL.type != "E2E":
            raise ValueError(
                f"Unsupported model type: {config.MODEL.type} for task: {task}"
            )

        from opensportslib.models.base.e2e import E2EModel
        # NOTE(review): the extra class is presumably the background
        # class -- confirm against the dataset labels.
        return E2EModel(
            config,
            len(config.DATA.classes) + 1,
            config.MODEL.backbone,
            config.MODEL.head,
            clip_len=config.DATA.clip_len,
            modality=config.DATA.modality,
            device=device,
            multi_gpu=config.MODEL.multi_gpu,
        )

    # Previously an unknown task fell through and returned None.
    raise ValueError(f"Unsupported task: {config.TASK}")