opensportslib 0.0.1.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73) hide show
  1. opensportslib/__init__.py +18 -0
  2. opensportslib/apis/__init__.py +21 -0
  3. opensportslib/apis/classification.py +361 -0
  4. opensportslib/apis/localization.py +228 -0
  5. opensportslib/config/classification.yaml +104 -0
  6. opensportslib/config/classification_tracking.yaml +103 -0
  7. opensportslib/config/graph_tracking_classification/avgpool.yaml +79 -0
  8. opensportslib/config/graph_tracking_classification/gin.yaml +79 -0
  9. opensportslib/config/graph_tracking_classification/graphconv.yaml +79 -0
  10. opensportslib/config/graph_tracking_classification/graphsage.yaml +79 -0
  11. opensportslib/config/graph_tracking_classification/maxpool.yaml +79 -0
  12. opensportslib/config/graph_tracking_classification/noedges.yaml +79 -0
  13. opensportslib/config/localization.yaml +132 -0
  14. opensportslib/config/sngar_frames.yaml +98 -0
  15. opensportslib/core/__init__.py +0 -0
  16. opensportslib/core/loss/__init__.py +0 -0
  17. opensportslib/core/loss/builder.py +40 -0
  18. opensportslib/core/loss/calf.py +258 -0
  19. opensportslib/core/loss/ce.py +23 -0
  20. opensportslib/core/loss/combine.py +42 -0
  21. opensportslib/core/loss/nll.py +25 -0
  22. opensportslib/core/optimizer/__init__.py +0 -0
  23. opensportslib/core/optimizer/builder.py +38 -0
  24. opensportslib/core/sampler/weighted_sampler.py +104 -0
  25. opensportslib/core/scheduler/__init__.py +0 -0
  26. opensportslib/core/scheduler/builder.py +77 -0
  27. opensportslib/core/trainer/__init__.py +0 -0
  28. opensportslib/core/trainer/classification_trainer.py +1131 -0
  29. opensportslib/core/trainer/localization_trainer.py +1009 -0
  30. opensportslib/core/utils/checkpoint.py +238 -0
  31. opensportslib/core/utils/config.py +199 -0
  32. opensportslib/core/utils/data.py +85 -0
  33. opensportslib/core/utils/ddp.py +77 -0
  34. opensportslib/core/utils/default_args.py +110 -0
  35. opensportslib/core/utils/load_annotations.py +485 -0
  36. opensportslib/core/utils/seed.py +26 -0
  37. opensportslib/core/utils/video_processing.py +389 -0
  38. opensportslib/core/utils/wandb.py +110 -0
  39. opensportslib/datasets/__init__.py +0 -0
  40. opensportslib/datasets/builder.py +42 -0
  41. opensportslib/datasets/classification_dataset.py +582 -0
  42. opensportslib/datasets/localization_dataset.py +813 -0
  43. opensportslib/datasets/utils/__init__.py +15 -0
  44. opensportslib/datasets/utils/tracking.py +615 -0
  45. opensportslib/metrics/classification_metric.py +176 -0
  46. opensportslib/metrics/localization_metric.py +1482 -0
  47. opensportslib/models/__init__.py +0 -0
  48. opensportslib/models/backbones/builder.py +590 -0
  49. opensportslib/models/base/e2e.py +252 -0
  50. opensportslib/models/base/tracking.py +73 -0
  51. opensportslib/models/base/vars.py +29 -0
  52. opensportslib/models/base/video.py +130 -0
  53. opensportslib/models/base/video_mae.py +60 -0
  54. opensportslib/models/builder.py +43 -0
  55. opensportslib/models/heads/builder.py +266 -0
  56. opensportslib/models/neck/builder.py +210 -0
  57. opensportslib/models/utils/common.py +176 -0
  58. opensportslib/models/utils/impl/__init__.py +0 -0
  59. opensportslib/models/utils/impl/asformer.py +390 -0
  60. opensportslib/models/utils/impl/calf.py +74 -0
  61. opensportslib/models/utils/impl/gsm.py +112 -0
  62. opensportslib/models/utils/impl/gtad.py +347 -0
  63. opensportslib/models/utils/impl/tsm.py +123 -0
  64. opensportslib/models/utils/litebase.py +59 -0
  65. opensportslib/models/utils/modules.py +120 -0
  66. opensportslib/models/utils/shift.py +135 -0
  67. opensportslib/models/utils/utils.py +276 -0
  68. opensportslib-0.0.1.dev2.dist-info/METADATA +566 -0
  69. opensportslib-0.0.1.dev2.dist-info/RECORD +73 -0
  70. opensportslib-0.0.1.dev2.dist-info/WHEEL +5 -0
  71. opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE +661 -0
  72. opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE-COMMERCIAL +5 -0
  73. opensportslib-0.0.1.dev2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,238 @@
1
+ # opensportslib/core/utils/checkpoint.py
2
+
3
+ import torch
4
+ import os
5
+
6
def localization_remap(key):
    """Map a legacy localization checkpoint key onto the current module layout.

    Keys starting with ``_features`` get a ``backbone.`` prefix, keys starting
    with ``_pred_fine`` get a ``head.`` prefix; every other key is returned
    unchanged.
    """
    legacy_prefixes = (("_features", "backbone."), ("_pred_fine", "head."))
    for legacy_start, new_prefix in legacy_prefixes:
        if key.startswith(legacy_start):
            return new_prefix + key
    return key
12
+
13
+
14
def save_checkpoint(model, path, processor=None, tokenizer=None, optimizer=None, epoch=None):
    """
    Save model checkpoint to `path`. Uses HF save_pretrained if available,
    otherwise falls back to saving a PyTorch checkpoint.

    Args:
        model (torch.nn.Module or HF PreTrainedModel): model to save
        path (str): path to save torch checkpoint (file path, e.g., /.../checkpoint.pt)
        processor (optional): HF processor / feature extractor to save with model
        tokenizer (optional): HF tokenizer to save with model
        optimizer (torch.optim.Optimizer, optional): optimizer to save (torch fallback only)
        epoch (int, optional): current epoch number (torch fallback only)

    The torch fallback writes a dict with ``model_state_dict`` and, when given,
    ``optimizer_state_dict`` and ``epoch``.

    NOTE(review): ``save_pretrained`` is called with the same ``path``; HF
    treats that argument as a directory, while the torch fallback writes a
    file at ``path`` - confirm callers pass a path suitable for both.
    """
    # os.path.dirname("ckpt.pt") is "" and os.makedirs("") raises
    # FileNotFoundError, so only create the parent directory when there is one.
    parent_dir = os.path.dirname(path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    hf_saved = False

    # 1) Try to save HuggingFace model if available (best effort: on failure we
    #    fall back to a plain torch checkpoint below).
    if hasattr(model, "save_pretrained"):
        try:
            model.save_pretrained(path)
            hf_saved = True
            print(f"[Checkpoint] HuggingFace model saved to {path}")
        except Exception as e:
            print(f"[Checkpoint] Warning: could not save HF model: {e}")
            hf_saved = False

    # 2) Save processor / tokenizer if provided (only if HF save succeeded)
    if hf_saved:
        if processor is not None:
            try:
                processor.save_pretrained(path)
                print(f"[Checkpoint] Processor saved to {path}")
            except Exception as e:
                print(f"[Checkpoint] Warning: could not save processor: {e}")

        if tokenizer is not None:
            try:
                tokenizer.save_pretrained(path)
                print(f"[Checkpoint] Tokenizer saved to {path}")
            except Exception as e:
                print(f"[Checkpoint] Warning: could not save tokenizer: {e}")

    # 3) Fallback: Save a PyTorch checkpoint if HF save is unavailable or failed
    if not hf_saved:
        checkpoint = {}
        if hasattr(model, "state_dict"):
            checkpoint["model_state_dict"] = model.state_dict()
        else:
            print("[Checkpoint] Warning: model has no state_dict(), skipping model_state_dict.")

        if optimizer is not None:
            checkpoint["optimizer_state_dict"] = optimizer.state_dict()
        if epoch is not None:
            checkpoint["epoch"] = epoch

        torch.save(checkpoint, path)
        print(f"[Checkpoint] Torch checkpoint saved at: {path}")
72
+
73
+
74
def _download_hf_checkpoint(repo_id, hf_filename, hf_token):
    """Download a checkpoint file from HuggingFace Hub repo ``repo_id`` and return its cached local path.

    Raises ImportError when ``huggingface_hub`` is missing, FileNotFoundError
    when the repo holds no checkpoint-like file, and propagates whatever
    ``huggingface_hub`` raises on auth / network / validation failures.
    """
    import sys

    from huggingface_hub import hf_hub_download, list_repo_files, login, whoami

    # Ensure auth if needed: without a token, check for a cached login and only
    # prompt when a terminal is attached (sys.stdin may be None when detached).
    if hf_token is None:
        try:
            whoami()
        except Exception:
            if sys.stdin is not None and sys.stdin.isatty():
                login()

    print(f"[HF] Inspecting repo: {repo_id}")
    files = list_repo_files(repo_id, token=hf_token)

    candidates = [f for f in files if f.endswith((".pt", ".pth", ".pth.tar", ".bin"))]
    if hf_filename and hf_filename in files:
        # Honour the explicitly requested file when the repo contains it.
        filename = hf_filename
    elif candidates:
        # Otherwise pick the first checkpoint-like file.
        filename = candidates[0]
    else:
        raise FileNotFoundError(
            f"No checkpoint file found in HF repo {repo_id}. "
            f"Files: {files}"
        )
    print(f"[HF] Using checkpoint file: {filename}")

    ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename, token=hf_token)
    print(f"[HF] Loaded from cache: {ckpt_path}")
    return ckpt_path


def _extract_state_dict(checkpoint):
    """Return the model weights held by a loaded checkpoint object.

    Prefers ``model_state_dict``, then ``state_dict``; otherwise collects every
    tensor-valued top-level entry. Raises ValueError for non-dict checkpoints.
    """
    if not isinstance(checkpoint, dict):
        raise ValueError("Checkpoint format not recognized")
    if "model_state_dict" in checkpoint:
        return checkpoint["model_state_dict"]
    if "state_dict" in checkpoint:
        return checkpoint["state_dict"]
    return {k: v for k, v in checkpoint.items() if isinstance(v, torch.Tensor)}


def _align_module_prefix(state_dict, model_keys):
    """Add or strip the DataParallel/DDP ``module.`` prefix so checkpoint keys match the model.

    The decision looks at the first key on each side. Empty inputs are
    returned untouched instead of raising IndexError. Only a leading
    ``module.`` is stripped, never one in the middle of a key.
    """
    ckpt_keys = list(state_dict.keys())
    if not ckpt_keys or not model_keys:
        return state_dict

    prefix = "module."
    ckpt_has_module = ckpt_keys[0].startswith(prefix)
    model_has_module = model_keys[0].startswith(prefix)

    # Case 1: checkpoint has module., model doesn't
    if ckpt_has_module and not model_has_module:
        return {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state_dict.items()}
    # Case 2: checkpoint doesn't have module., model does
    if not ckpt_has_module and model_has_module:
        return {f"{prefix}{k}": v for k, v in state_dict.items()}
    return state_dict


def load_checkpoint(
    model,
    path,
    optimizer=None,
    scheduler=None,
    device=None,
    key_remap_fn=None,
    hf_filename="model.pth.tar",  # preferred file when loading from an HF repo
    hf_token=None,  # optional (for private repos / non-interactive envs)
):
    """
    Load a checkpoint into ``model`` (and optionally ``optimizer`` / ``scheduler``).

    ``path`` may be:
      - a local .pt/.pth/.tar file (``~`` and relative paths are expanded), or
      - a HuggingFace repo id; a checkpoint file is downloaded from it.

    Auth behavior (only when ``path`` is not a local file):
      - If logged in -> no token needed
      - If not logged in -> interactive prompt (only when stdin is a TTY)
      - If non-interactive -> hf_token required for private repos

    Args:
        model (torch.nn.Module): receives the weights via ``load_state_dict(strict=False)``.
        path (str): local file path or HF repo id.
        optimizer (optional): restored from ``optimizer`` / ``optimizer_state_dict`` if present.
        scheduler (optional): restored from ``scheduler`` / ``scheduler_state_dict`` if present.
        device (torch.device, optional): ``map_location`` and target device; CUDA if available else CPU.
        key_remap_fn (callable, optional): applied to every checkpoint key after ``module.`` alignment.
        hf_filename (str): repo file to load when present; otherwise the first
            ``.pt``/``.pth``/``.pth.tar``/``.bin`` file in the repo is used.
        hf_token (str, optional): HF access token.

    Returns:
        tuple: ``(model, optimizer, scheduler, epoch)``; ``epoch`` is None when not stored.

    Raises:
        FileNotFoundError: if ``path`` is neither a local file nor loadable from the Hub.
        ValueError: if the loaded checkpoint is not a dict.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Expand "~" / relative paths BEFORE the existence check (same as
    # config.expand); otherwise "~/ckpt.pt" is never recognised as local.
    local_path = os.path.abspath(os.path.expanduser(path))

    # --------------------------------------------------
    # Resolve the checkpoint file: local first, then HuggingFace Hub.
    # Local files never touch the Hub (no whoami / interactive login).
    # --------------------------------------------------
    if os.path.exists(local_path):
        ckpt_path = local_path
        print(f"[Local] Using local checkpoint: {ckpt_path}")
    else:
        try:
            ckpt_path = _download_hf_checkpoint(path, hf_filename, hf_token)
        except Exception as hf_error:
            raise FileNotFoundError(
                f"Checkpoint not found on HuggingFace OR locally: {local_path}"
            ) from hf_error

    # --------------------------------------------------
    # Load checkpoint
    # NOTE: weights_only=False unpickles arbitrary objects - only load
    # checkpoints (including Hub downloads) from trusted sources.
    # --------------------------------------------------
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)

    # ---------------- MODEL STATE ----------------
    state_dict = _extract_state_dict(checkpoint)
    state_dict = _align_module_prefix(state_dict, list(model.state_dict().keys()))
    if key_remap_fn:
        state_dict = {key_remap_fn(k): v for k, v in state_dict.items()}

    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    print("\n--- MISSING KEYS ---")
    for k in missing[:20]:
        print(k)

    print("\n--- UNEXPECTED KEYS ---")
    for k in unexpected[:20]:
        print(k)
    model.to(device)

    # ---------------- EPOCH ----------------
    epoch = checkpoint.get("epoch")

    # ---------------- OPTIMIZER ----------------
    if optimizer:
        opt_state = checkpoint.get("optimizer") or checkpoint.get("optimizer_state_dict")
        if opt_state:
            optimizer.load_state_dict(opt_state)

    # ---------------- SCHEDULER ----------------
    if scheduler:
        sch_state = checkpoint.get("scheduler") or checkpoint.get("scheduler_state_dict")
        if sch_state:
            scheduler.load_state_dict(sch_state)

    print(f"[Checkpoint] Loaded from {ckpt_path} | epoch: {epoch}")
    print(f"Missing keys: {len(missing)}")
    print(f"Unexpected keys: {len(unexpected)}")

    return model, optimizer, scheduler, epoch
232
+
233
+
234
+
235
def load_huggingface_checkpoint(config, path, device):
    """Load a VideoMAE checkpoint from ``path`` onto ``device``.

    Thin wrapper around ``load_video_mae_checkpoint``; the import happens at
    call time rather than at module import.
    """
    from opensportslib.models.base.video_mae import load_video_mae_checkpoint as _load_video_mae

    return _load_video_mae(config, ckpt_path=path, device=device)
238
+
@@ -0,0 +1,199 @@
1
+
2
+ import os
3
+ import re
4
+ import json
5
+ import gzip
6
+ import yaml
7
+
8
+ def dict_to_namespace(d, skip_keys=("classes",)):
9
+ """
10
+ Recursively convert dict to namespace for easy access,
11
+ but keep certain keys (like 'classes') as raw dict/list.
12
+ """
13
+ from types import SimpleNamespace
14
+
15
+ if isinstance(d, dict):
16
+ out = {}
17
+ for k, v in d.items():
18
+ if k in skip_keys:
19
+ out[k] = v # leave as-is
20
+ else:
21
+ out[k] = dict_to_namespace(v, skip_keys)
22
+ return SimpleNamespace(**out)
23
+ elif isinstance(d, list):
24
+ return [dict_to_namespace(v, skip_keys) for v in d]
25
+ else:
26
+ return d
27
+
28
+ def namespace_to_dict(ns):
29
+ return {k: vars(v) if hasattr(v, "__dict__") else v for k, v in vars(ns).items()}
30
+
31
+ def namespace_to_omegaconf(ns):
32
+ """
33
+ Recursively convert SimpleNamespace (or dict/list) back to OmegaConf
34
+ """
35
+ from omegaconf import OmegaConf
36
+ from types import SimpleNamespace
37
+
38
+ def to_dict(obj):
39
+ if isinstance(obj, SimpleNamespace):
40
+ return {k: to_dict(v) for k, v in vars(obj).items()}
41
+ elif isinstance(obj, dict):
42
+ return {k: to_dict(v) for k, v in obj.items()}
43
+ elif isinstance(obj, list):
44
+ return [to_dict(v) for v in obj]
45
+ else:
46
+ return obj
47
+
48
+ return OmegaConf.create(to_dict(ns))
49
+
50
+ def load_config(config_path):
51
+ """
52
+ Loading configurations
53
+ """
54
+ print(config_path)
55
+ if config_path.endswith(".yaml") or config_path.endswith(".yml"):
56
+ with open(config_path, "r") as f:
57
+ cfg_dict = yaml.safe_load(f)
58
+ elif config_path.endswith(".json"):
59
+ with open(config_path, "r") as f:
60
+ cfg_dict = json.load(f)
61
+ else:
62
+ raise ValueError("Unsupported config format. Use YAML or JSON.")
63
+ return dict_to_namespace(cfg_dict)
64
+
65
+
66
+
67
+ def load_config_omega(path):
68
+
69
+ from omegaconf import OmegaConf
70
+ cfg = OmegaConf.load(path)
71
+ # OmegaConf.resolve(cfg)
72
+ # cfg = OmegaConf.to_container(cfg, resolve=True)
73
+ return dict_to_namespace(cfg)
74
+
75
+ def resolve_config_omega(cfg):
76
+ from omegaconf import OmegaConf, DictConfig
77
+ #cfg = namespace_to_omegaconf(cfg)
78
+ #cfg = namespace_to_dict(cfg)
79
+ #print(type(cfg))
80
+ #cfg = OmegaConf.create(cfg)
81
+ if not isinstance(cfg, DictConfig):
82
+ return cfg
83
+ OmegaConf.resolve(cfg)
84
+ cfg = dict_to_namespace(OmegaConf.to_container(cfg, resolve=True))
85
+ return cfg
86
+
87
+
88
+ def expand(path):
89
+ return os.path.abspath(os.path.expanduser(path))
90
+
91
+
92
+ def load_json(fpath):
93
+ with open(fpath) as fp:
94
+ return json.load(fp)
95
+
96
+ def load_gz_json(fpath):
97
+ with gzip.open(fpath, "rt", encoding="ascii") as fp:
98
+ return json.load(fp)
99
+
100
+
101
+ def store_json(fpath, obj, pretty=False):
102
+ kwargs = {}
103
+ if pretty:
104
+ kwargs["indent"] = 4
105
+ kwargs["sort_keys"] = False
106
+ with open(fpath, "w") as fp:
107
+ json.dump(obj, fp, **kwargs)
108
+
109
+
110
+ def store_gz_json(fpath, obj):
111
+ with gzip.open(fpath, "wt", encoding="ascii") as fp:
112
+ json.dump(obj, fp)
113
+
114
+
115
+ def load_text(fpath):
116
+ """Load text from a given file.
117
+
118
+ Args:
119
+ fpath (string): The path of the file.
120
+
121
+ Returns:
122
+ lines (List): List in which element is a line of the file.
123
+
124
+ """
125
+ lines = []
126
+ with open(fpath, "r") as fp:
127
+ for l in fp:
128
+ l = l.strip()
129
+ if l:
130
+ lines.append(l)
131
+ return lines
132
+
133
+ def load_classes(input):
134
+ """Load classes from either list or txt file.
135
+
136
+ Args:
137
+ input (string): Path of the file that contains one class per line or list of classes.
138
+
139
+ Returns:
140
+ Dictionnary with classes associated to indexes.
141
+ """
142
+ from omegaconf import ListConfig
143
+ if isinstance(input, (list, ListConfig)):
144
+ return {x: i + 1 for i, x in enumerate(input)}
145
+ return {x: i + 1 for i, x in enumerate(load_text(input))}
146
+
147
+ def clear_files(dir_name, re_str, exclude=[]):
148
+ for file_name in os.listdir(dir_name):
149
+ if re.match(re_str, file_name):
150
+ if file_name not in exclude:
151
+ file_path = os.path.join(dir_name, file_name)
152
+ os.remove(file_path)
153
+
154
+
155
+ def _print_info_helper(src_file, labels):
156
+ """Print informations about videos contained in a json file.
157
+
158
+ Args:
159
+ src_file (string): The source file.
160
+ labels (list(dict)): List containing a dict fro each video.
161
+ """
162
+ num_frames = sum([x["num_frames"] for x in labels])
163
+ num_events = sum([len(x["events"]) for x in labels])
164
+ print(
165
+ "{} : {} videos, {} frames, {:0.5f}% non-bg".format(
166
+ src_file, len(labels), num_frames, num_events / num_frames * 100
167
+ )
168
+ )
169
+
170
+ def select_device(config):
171
+ import torch
172
+ mode = config.device.lower()
173
+
174
+ if mode == "auto":
175
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
176
+
177
+ elif mode == "cuda":
178
+ assert torch.cuda.is_available(), "CUDA requested but not available"
179
+ gpu_id = getattr(config, "gpu_id", 0)
180
+ torch.cuda.set_device(gpu_id)
181
+ device = torch.device(f"cuda:{gpu_id}")
182
+
183
+ elif mode == "cpu":
184
+ device = torch.device("cpu")
185
+
186
+ else:
187
+ raise ValueError(f"Unknown device mode: {mode}")
188
+
189
+ print(f"Using device: {device}")
190
+ if device.type == "cuda" or device.type == "auto":
191
+ print(f"GPU: {torch.cuda.get_device_name(device)}")
192
+
193
+ return device
194
+
195
+ def is_local_path(p):
196
+ return p and (
197
+ os.path.exists(p) or
198
+ p.endswith((".pt", ".pth", ".tar"))
199
+ )
@@ -0,0 +1,85 @@
1
+ from collections import defaultdict
2
+ from torch.utils.data import Subset
3
+ import torch
4
+ import numpy as np
5
+
6
def balanced_subset(dataset, samples_per_class=5):
    """Return a ``Subset`` that keeps at most ``samples_per_class`` items per label.

    Labels are read from ``dataset.samples[idx]["label"]``. Within each label
    the lowest indices are kept, and labels appear in order of first
    occurrence. The discovered labels and the chosen indices are printed.
    """
    by_label = defaultdict(list)
    for position in range(len(dataset)):
        by_label[dataset.samples[position]["label"]].append(position)

    print(by_label.keys())
    chosen = [
        position
        for positions in by_label.values()
        for position in positions[:samples_per_class]
    ]
    print(chosen)
    return Subset(dataset, chosen)
20
+
21
+
22
def batch_tensor(tensor, dim=1, squeeze=False):
    """
    Fold dimension ``dim`` of ``tensor`` into the batch dimension 0 so the
    items along ``dim`` can be processed in parallel.

    For an input of shape ``(B, ..., D at dim, ...)`` the result has shape
    ``(D * B, ..., 1 at dim, ...)``. With ``squeeze=True`` the size-1 axis at
    ``dim`` is removed. Row ``d * B + b`` of the result holds
    ``tensor[b, ..., d, ...]``. See ``unbatch_tensor()`` for the reverse.
    """
    batch_size, dim_size = tensor.shape[0], tensor.shape[dim]
    # Move `dim` to the front while keeping every other axis in its original
    # order, so reshape only merges the two leading axes (D, B). The previous
    # transpose(0, dim) also moved B into position `dim`, which scrambled
    # elements whenever dim > 1 (results are unchanged for dim == 1).
    moved = tensor.movedim(dim, 0)
    folded = moved.reshape(dim_size * batch_size, *moved.shape[2:])
    if squeeze:
        return folded
    return folded.unsqueeze(dim)
35
+
36
+
37
def unbatch_tensor(tensor, batch_size, dim=1, unsqueeze=False):
    """
    Inverse of ``batch_tensor()``: split ``tensor`` along dimension 0 into
    ``tensor.shape[0] // batch_size`` chunks and concatenate them along ``dim``.

    If ``unsqueeze=True``, a size-1 axis is first inserted at ``dim`` (e.g. to
    undo ``batch_tensor(..., squeeze=True)``). The input tensor is not
    modified; previously ``unsqueeze_`` changed the caller's tensor shape.
    """
    fake_batch_size = tensor.shape[0]
    nb_chunks = fake_batch_size // batch_size
    if unsqueeze:
        tensor = tensor.unsqueeze(dim)  # out-of-place
    return torch.cat(torch.chunk(tensor, nb_chunks, dim=0), dim=dim).contiguous()
48
+
49
+
50
def tracking_collate_fn(batch):
    """
    Collate tracking samples into a single PyG mini-batch.

    Each sample's ``graphs`` list is flattened (sample by sample, in list
    order) and merged with ``Batch.from_data_list``. ``seq_len`` is taken
    from the first sample; labels become a ``torch.long`` tensor.
    """
    from torch_geometric.data import Batch

    seq_len = batch[0]['seq_len']
    flat_graphs = [graph for item in batch for graph in item['graphs']]

    # PyG handles node offsets for edge_index automatically
    merged = Batch.from_data_list(flat_graphs)

    return {
        'x': merged.x,
        'edge_index': merged.edge_index,
        'batch': merged.batch,
        'batch_size': len(batch),
        'seq_len': seq_len,
        'labels': torch.tensor([item['label'] for item in batch], dtype=torch.long),
        'id': [item['id'] for item in batch],
    }
78
+
79
def mixup_data(x, y, alpha=0.2):
    """Mixup augmentation: blend every sample with a randomly permuted partner.

    Returns ``(mixed_x, y_a, y_b, lam)`` where
    ``mixed_x = lam * x + (1 - lam) * x[perm]``, ``y_a`` is ``y`` and
    ``y_b`` is ``y[perm]``. ``lam`` is drawn from ``Beta(alpha, alpha)`` when
    ``alpha > 0`` and is 1.0 otherwise.
    """
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1.0
    perm = torch.randperm(x.size(0)).to(x.device)
    blended = lam * x + (1 - lam) * x[perm]
    return blended, y, y[perm], lam
85
+
@@ -0,0 +1,77 @@
1
+ import torch.distributed as dist
2
def ddp_setup(rank, world_size, backend="nccl"):
    """Initialise the default torch.distributed process group for single-node DDP.

    ``MASTER_ADDR`` / ``MASTER_PORT`` default to ``localhost`` / ``12355``.
    Values already present in the environment (e.g. exported by a launcher
    such as ``torchrun``) are now respected instead of being overwritten.

    Args:
        rank (int): rank of this process.
        world_size (int): total number of processes.
        backend (str): process-group backend; ``"nccl"`` (previous hard-coded
            value) by default, ``"gloo"`` works on CPU-only machines.
    """
    import os
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")  # any free port
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
7
+
8
def ddp_cleanup():
    """Destroy the default process group if one is active.

    Safe to call when no group exists (e.g. from a ``finally`` after a failed
    ``ddp_setup``) or more than once; ``destroy_process_group`` itself raises
    in that situation.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
10
+
11
+
12
+ import torch
13
+ from torch.utils.data import Sampler
14
+ import math
15
+
16
class DistributedWeightedSampler(Sampler):
    """Weighted random sampler that shards a dataset across distributed ranks.

    The index space is padded up to a multiple of ``num_replicas``; padded
    positions wrap around to the start of the dataset. Each rank owns the
    strided slice ``rank, rank + num_replicas, ...`` and draws ``num_samples``
    indices from it with ``torch.multinomial``, proportionally to ``weights``.
    Draws are seeded with ``seed + epoch``; call ``set_epoch`` each epoch to
    get a fresh draw.
    """

    def __init__(
        self,
        weights,
        num_replicas=None,
        rank=None,
        replacement=True,
        num_samples=None,
        seed=0
    ):
        """
        Args:
            weights: per-sample sampling weights (sequence or tensor).
            num_replicas (int, optional): number of ranks; defaults to
                ``dist.get_world_size()``.
            rank (int, optional): this process's rank; defaults to ``dist.get_rank()``.
            replacement (bool): passed to ``torch.multinomial``.
            num_samples (int, optional): total draws across all ranks; each rank
                draws ``num_samples // num_replicas``. Defaults to one draw per
                index owned by the rank.
            seed (int): base RNG seed, shared by all ranks.
        """
        if num_replicas is None:
            num_replicas = dist.get_world_size()
        if rank is None:
            rank = dist.get_rank()

        self.num_replicas = num_replicas
        self.rank = rank
        self.replacement = replacement
        self.seed = seed

        self.weights = torch.as_tensor(weights, dtype=torch.double)
        self.dataset_size = len(self.weights)

        # Split dataset across ranks
        self.num_samples_per_rank = math.ceil(self.dataset_size / self.num_replicas)
        self.total_size = self.num_samples_per_rank * self.num_replicas

        # Pad by tiling the weights so that padded position j carries the weight of
        # index j % dataset_size. The previous single `weights[:padding]` slice was
        # too short when padding > dataset_size (more replicas than samples),
        # leaving rank_indices out of range.
        if self.total_size > self.dataset_size:
            repeats = math.ceil(self.total_size / self.dataset_size)
            self.weights = self.weights.repeat(repeats)[: self.total_size]

        # Indices for this rank
        self.rank_indices = list(range(self.rank, self.total_size, self.num_replicas))
        self.rank_weights = self.weights[self.rank_indices]

        # Number of samples to draw
        if num_samples is None:
            self.num_samples = len(self.rank_indices)
        else:
            self.num_samples = num_samples // self.num_replicas

        self.epoch = 0

    def set_epoch(self, epoch):
        """Set the epoch used to seed the next iteration."""
        self.epoch = epoch

    def __iter__(self):
        """Yield ``num_samples`` dataset indices for this rank, seeded by ``seed + epoch``."""
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)

        sampled = torch.multinomial(
            self.rank_weights,
            self.num_samples,
            self.replacement,
            generator=g
        )

        # Padded positions (>= dataset_size) wrap back onto real dataset indices.
        return (self.rank_indices[i] % self.dataset_size for i in sampled)

    def __len__(self):
        """Number of indices yielded per iteration on this rank."""
        return self.num_samples
@@ -0,0 +1,110 @@
1
def get_default_args_data_train_e2e_dali(cfg):
    """Dataset kwargs for the training split of an E2E runner using DALI."""
    return dict(
        classes=cfg.DATA.classes,
        train=True,
        acc_grad_iter=cfg.TRAIN.acc_grad_iter,
        num_epochs=cfg.TRAIN.num_epochs,
        repartitions=cfg.TRAIN.repartitions,
    )
9
+
10
+
11
def get_default_args_data_valid_e2e_dali(cfg):
    """Dataset kwargs for the validation split of an E2E runner using DALI."""
    return dict(
        classes=cfg.DATA.classes,
        train=False,
        acc_grad_iter=cfg.TRAIN.acc_grad_iter,
        num_epochs=cfg.TRAIN.num_epochs,
        repartitions=cfg.TRAIN.repartitions,
    )
19
+
20
+
21
def get_default_args_data_train_e2e_opencv(cfg):
    """Dataset kwargs for the training split of an E2E runner using OpenCV decoding."""
    return dict(classes=cfg.DATA.classes, train=True)
23
+
24
+
25
def get_default_args_data_valid_e2e_opencv(cfg):
    """Dataset kwargs for the validation split of an E2E runner using OpenCV decoding."""
    return dict(classes=cfg.DATA.classes, train=False)
27
+
28
+
29
def get_default_args_data_train():
    """Dataset kwargs for the training split of non-E2E runners."""
    return dict(train=True)
31
+
32
+
33
def get_default_args_data_valid():
    """Dataset kwargs for the validation split of non-E2E runners."""
    return dict(train=False)
35
+
36
+
37
def get_default_args_data_valid_data_frames_e2e_dali(cfg):
    """Dataset kwargs for frame-level inference splits of an E2E runner using DALI."""
    return dict(classes=cfg.DATA.classes, repartitions=cfg.TRAIN.repartitions)
39
+
40
+
41
def get_default_args_data_valid_data_frames_e2e_opencv(cfg):
    """Dataset kwargs for frame-level inference splits of an E2E runner using OpenCV decoding."""
    return dict(classes=cfg.DATA.classes)
43
+
44
+
45
def get_default_args_dataset(split, cfg):
    """Return the default dataset kwargs for ``split``.

    For ``cfg.MODEL.runner.type == "runner_e2e"`` the DALI variant is used when
    ``cfg.dali`` is truthy (missing counts as False), otherwise the OpenCV one.
    For other runners, ``"train"`` / ``"valid"`` get ``{"train": ...}`` and the
    inference splits (``"valid_data_frames"``, ``"test"``, ``"challenge"``) get
    ``None``. Unknown splits return ``None`` without reading ``cfg``.
    """
    if split == "train":
        dali_builder = get_default_args_data_train_e2e_dali
        opencv_builder = get_default_args_data_train_e2e_opencv
        plain_builder = get_default_args_data_train
    elif split == "valid":
        dali_builder = get_default_args_data_valid_e2e_dali
        opencv_builder = get_default_args_data_valid_e2e_opencv
        plain_builder = get_default_args_data_valid
    elif split in ("valid_data_frames", "test", "challenge"):
        dali_builder = get_default_args_data_valid_data_frames_e2e_dali
        opencv_builder = get_default_args_data_valid_data_frames_e2e_opencv
        plain_builder = None
    else:
        return None

    if cfg.MODEL.runner.type != "runner_e2e":
        return plain_builder() if plain_builder is not None else None
    if getattr(cfg, "dali", False):
        return dali_builder(cfg)
    return opencv_builder(cfg)
74
+
75
+
76
def get_default_args_model(cfg):
    """Default model kwargs: the class list for ``E2E`` models, ``None`` for any other type."""
    return {"classes": cfg.DATA.classes} if cfg.MODEL.type == "E2E" else None
81
+
82
+
83
def get_default_args_trainer(cfg, len_train_loader):
    """Return the default trainer kwargs.

    ``trainer_e2e`` gets the loader length, work dir, DALI flag (with
    ``cfg.TRAIN.repartitions`` only when DALI is on) and the test /
    frame-level validation data configs. Any other trainer type only gets
    ``work_dir``.
    """
    if cfg.TRAIN.type != "trainer_e2e":
        return {"work_dir": cfg.SYSTEM.work_dir}

    # A missing `dali` attribute means "DALI off", matching the
    # getattr(cfg, "dali", False) used for the dataset args; direct access
    # used to raise AttributeError here.
    use_dali = getattr(cfg, "dali", False)
    return {
        "len_train_loader": len_train_loader,
        "work_dir": cfg.SYSTEM.work_dir,
        "dali": use_dali,
        "repartitions": cfg.TRAIN.repartitions if use_dali else None,
        "cfg_test": cfg.DATA.test,
        "cfg_valid_data_frames": cfg.DATA.valid_data_frames,
    }
96
+
97
+
98
def get_default_args_train(model, train_loader, valid_loader, classes, trainer_type):
    """Return the kwargs used to start training for ``trainer_type``.

    ``trainer_CALF`` / ``trainer_pooling`` receive ``model``,
    ``train_dataloaders`` and ``val_dataloaders``. ``trainer_e2e`` receives
    ``train_loader``, ``valid_loader`` and ``classes``. Any other trainer type
    yields ``None``.
    """
    if trainer_type in ("trainer_CALF", "trainer_pooling"):
        return dict(model=model, train_dataloaders=train_loader, val_dataloaders=valid_loader)
    if trainer_type == "trainer_e2e":
        return dict(train_loader=train_loader, valid_loader=valid_loader, classes=classes)
    return None