opensportslib 0.0.1.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73) hide show
  1. opensportslib/__init__.py +18 -0
  2. opensportslib/apis/__init__.py +21 -0
  3. opensportslib/apis/classification.py +361 -0
  4. opensportslib/apis/localization.py +228 -0
  5. opensportslib/config/classification.yaml +104 -0
  6. opensportslib/config/classification_tracking.yaml +103 -0
  7. opensportslib/config/graph_tracking_classification/avgpool.yaml +79 -0
  8. opensportslib/config/graph_tracking_classification/gin.yaml +79 -0
  9. opensportslib/config/graph_tracking_classification/graphconv.yaml +79 -0
  10. opensportslib/config/graph_tracking_classification/graphsage.yaml +79 -0
  11. opensportslib/config/graph_tracking_classification/maxpool.yaml +79 -0
  12. opensportslib/config/graph_tracking_classification/noedges.yaml +79 -0
  13. opensportslib/config/localization.yaml +132 -0
  14. opensportslib/config/sngar_frames.yaml +98 -0
  15. opensportslib/core/__init__.py +0 -0
  16. opensportslib/core/loss/__init__.py +0 -0
  17. opensportslib/core/loss/builder.py +40 -0
  18. opensportslib/core/loss/calf.py +258 -0
  19. opensportslib/core/loss/ce.py +23 -0
  20. opensportslib/core/loss/combine.py +42 -0
  21. opensportslib/core/loss/nll.py +25 -0
  22. opensportslib/core/optimizer/__init__.py +0 -0
  23. opensportslib/core/optimizer/builder.py +38 -0
  24. opensportslib/core/sampler/weighted_sampler.py +104 -0
  25. opensportslib/core/scheduler/__init__.py +0 -0
  26. opensportslib/core/scheduler/builder.py +77 -0
  27. opensportslib/core/trainer/__init__.py +0 -0
  28. opensportslib/core/trainer/classification_trainer.py +1131 -0
  29. opensportslib/core/trainer/localization_trainer.py +1009 -0
  30. opensportslib/core/utils/checkpoint.py +238 -0
  31. opensportslib/core/utils/config.py +199 -0
  32. opensportslib/core/utils/data.py +85 -0
  33. opensportslib/core/utils/ddp.py +77 -0
  34. opensportslib/core/utils/default_args.py +110 -0
  35. opensportslib/core/utils/load_annotations.py +485 -0
  36. opensportslib/core/utils/seed.py +26 -0
  37. opensportslib/core/utils/video_processing.py +389 -0
  38. opensportslib/core/utils/wandb.py +110 -0
  39. opensportslib/datasets/__init__.py +0 -0
  40. opensportslib/datasets/builder.py +42 -0
  41. opensportslib/datasets/classification_dataset.py +582 -0
  42. opensportslib/datasets/localization_dataset.py +813 -0
  43. opensportslib/datasets/utils/__init__.py +15 -0
  44. opensportslib/datasets/utils/tracking.py +615 -0
  45. opensportslib/metrics/classification_metric.py +176 -0
  46. opensportslib/metrics/localization_metric.py +1482 -0
  47. opensportslib/models/__init__.py +0 -0
  48. opensportslib/models/backbones/builder.py +590 -0
  49. opensportslib/models/base/e2e.py +252 -0
  50. opensportslib/models/base/tracking.py +73 -0
  51. opensportslib/models/base/vars.py +29 -0
  52. opensportslib/models/base/video.py +130 -0
  53. opensportslib/models/base/video_mae.py +60 -0
  54. opensportslib/models/builder.py +43 -0
  55. opensportslib/models/heads/builder.py +266 -0
  56. opensportslib/models/neck/builder.py +210 -0
  57. opensportslib/models/utils/common.py +176 -0
  58. opensportslib/models/utils/impl/__init__.py +0 -0
  59. opensportslib/models/utils/impl/asformer.py +390 -0
  60. opensportslib/models/utils/impl/calf.py +74 -0
  61. opensportslib/models/utils/impl/gsm.py +112 -0
  62. opensportslib/models/utils/impl/gtad.py +347 -0
  63. opensportslib/models/utils/impl/tsm.py +123 -0
  64. opensportslib/models/utils/litebase.py +59 -0
  65. opensportslib/models/utils/modules.py +120 -0
  66. opensportslib/models/utils/shift.py +135 -0
  67. opensportslib/models/utils/utils.py +276 -0
  68. opensportslib-0.0.1.dev2.dist-info/METADATA +566 -0
  69. opensportslib-0.0.1.dev2.dist-info/RECORD +73 -0
  70. opensportslib-0.0.1.dev2.dist-info/WHEEL +5 -0
  71. opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE +661 -0
  72. opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE-COMMERCIAL +5 -0
  73. opensportslib-0.0.1.dev2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,132 @@
1
+ TASK: localization
2
+
3
+ dali: True
4
+
5
+ DATA:
6
+ dataset_name: SoccerNet
7
+ data_dir: /home/vorajv/opensportslib/SoccerNet/annotations/
8
+ classes:
9
+ - PASS
10
+ - DRIVE
11
+ - HEADER
12
+ - HIGH PASS
13
+ - OUT
14
+ - CROSS
15
+ - THROW IN
16
+ - SHOT
17
+ - BALL PLAYER BLOCK
18
+ - PLAYER SUCCESSFUL TACKLE
19
+ - FREE KICK
20
+ - GOAL
21
+
22
+ epoch_num_frames: 500000
23
+ mixup: true
24
+ modality: rgb
25
+ crop_dim: -1
26
+ dilate_len: 0 # Dilate ground truth labels
27
+ clip_len: 100
28
+ input_fps: 25
29
+ extract_fps: 2
30
+ imagenet_mean: [0.485, 0.456, 0.406]
31
+ imagenet_std: [0.229, 0.224, 0.225]
32
+ target_height: 224
33
+ target_width: 398
34
+
35
+ train:
36
+ type: VideoGameWithDali
37
+ classes: ${DATA.classes}
38
+ output_map: [data, label]
39
+ video_path: ${DATA.data_dir}/train/
40
+ path: ${DATA.train.video_path}/annotations-2024-224p-train.json
41
+ dataloader:
42
+ batch_size: 8
43
+ shuffle: true
44
+ num_workers: 4
45
+ pin_memory: true
46
+
47
+ valid:
48
+ type: VideoGameWithDali
49
+ classes: ${DATA.classes}
50
+ output_map: [data, label]
51
+ video_path: ${DATA.data_dir}/valid/
52
+ path: ${DATA.valid.video_path}/annotations-2024-224p-valid.json
53
+ dataloader:
54
+ batch_size: 8
55
+ shuffle: true
56
+
57
+ valid_data_frames:
58
+ type: VideoGameWithDaliVideo
59
+ classes: ${DATA.classes}
60
+ output_map: [data, label]
61
+ video_path: ${DATA.valid.video_path}
62
+ path: ${DATA.valid.path}
63
+ overlap_len: 0
64
+ dataloader:
65
+ batch_size: 4
66
+ shuffle: false
67
+
68
+ test:
69
+ type: VideoGameWithDaliVideo
70
+ classes: ${DATA.classes}
71
+ output_map: [data, label]
72
+ video_path: ${DATA.data_dir}/test/
73
+ path: ${DATA.test.video_path}/annotations-2024-224p-test.json
74
+ results: results_spotting_test
75
+ nms_window: 2
76
+ metric: tight
77
+ overlap_len: 50
78
+ dataloader:
79
+ batch_size: 4
80
+ shuffle: false
81
+
82
+ challenge:
83
+ type: VideoGameWithDaliVideo
84
+ overlap_len: 50
85
+ output_map: [data, label]
86
+ path: ${DATA.data_dir}/challenge/annotations.json
87
+ dataloader:
88
+ batch_size: 4
89
+ shuffle: false
90
+
91
+ MODEL:
92
+ type: E2E
93
+ runner:
94
+ type: runner_e2e
95
+ backbone:
96
+ type: rny008_gsm
97
+ head:
98
+ type: gru
99
+ multi_gpu: true
100
+ load_weights: null
101
+
102
+ TRAIN:
103
+ type: trainer_e2e
104
+ num_epochs: 10
105
+ acc_grad_iter: 1
106
+ base_num_valid_epochs: 30
107
+ start_valid_epoch: 4
108
+ valid_map_every: 1
109
+ criterion_valid: map
110
+
111
+ criterion:
112
+ type: CrossEntropyLoss
113
+
114
+ optimizer:
115
+ type: AdamWithScaler
116
+ lr: 0.01
117
+
118
+ scheduler:
119
+ type: ChainedSchedulerE2E
120
+ acc_grad_iter: 1
121
+ num_epochs: ${TRAIN.num_epochs}
122
+ warm_up_epochs: 3
123
+
124
+ SYSTEM:
125
+ log_dir: ./logs
126
+ save_dir: ./checkpoints
127
+ work_dir: ${SYSTEM.save_dir}
128
+ seed: 42
129
+ GPU: 4 # number of gpus to use
130
+ device: cuda # auto | cuda | cpu
131
+ gpu_id: 0 # device id for single gpu training
132
+
@@ -0,0 +1,98 @@
1
+ TASK: classification
2
+
3
+ DATA:
4
+ dataset_name: sngar
5
+ data_dir: /home/spark_user1/opensportslib/sngar-frames
6
+ data_modality: frames_npy
7
+ max_samples: 100
8
+ num_frames: 16
9
+ frame_size: [224, 224]
10
+ train:
11
+ path: ${DATA.data_dir}/annotations_train.json
12
+ dataloader:
13
+ batch_size: 64
14
+ shuffle: true
15
+ num_workers: 8
16
+ pin_memory: true
17
+ valid:
18
+ path: ${DATA.data_dir}/annotations_valid.json
19
+ dataloader:
20
+ batch_size: 64
21
+ num_workers: 8
22
+ shuffle: false
23
+ test:
24
+ path: ${DATA.data_dir}/annotations_test.json
25
+ dataloader:
26
+ batch_size: 64
27
+ num_workers: 8
28
+ shuffle: false
29
+ augmentations:
30
+ random_horizontal_flip: true
31
+ flip_prob: 0.5
32
+ color_jitter: true
33
+ jitter_prob: 0.5
34
+ jitter_params: [0.2, 0.2, 0.2, 0.1]
35
+
36
+ MODEL:
37
+ type: custom
38
+ backbone:
39
+ type: dinov3 # dinov3 | clip | videomae | videomae2
40
+ pretrained_model: facebook/dinov3-vitb16-pretrain-lvd1689m
41
+ # facebook/dinov3-vitb16-pretrain-lvd1689m | openai/clip-vit-base-patch16 | MCG-NJU/videomae-base | OpenGVLab/VideoMAEv2-Base
42
+ hidden_dim: 768
43
+ freeze: true
44
+ unfreeze_last_n_layers: 0 # 0 = frozen backbone, >0 = unfreeze last N layers
45
+ neck:
46
+ type: TemporalAggregation
47
+ agr_type: maxpool # avgpool | maxpool | bilstm | tcn | attention
48
+ hidden_dim: 768
49
+ dropout: 0.1
50
+ lstm_dropout: 0.3 # only used when agr_type is bilstm
51
+ num_attention_heads: 8 # only used when agr_type is attention (8 for video, 4 for tracking)
52
+ head:
53
+ type: TrackingClassifier
54
+ hidden_dim: 64
55
+ num_classes: 10 # must match the number of action classes in the dataset
56
+ dropout: 0.1
57
+
58
+ TRAIN:
59
+ monitor: balanced_accuracy
60
+ mode: max
61
+ enabled: true
62
+ use_amp: true
63
+ mixup_alpha: 0.2
64
+ use_weighted_sampler: false
65
+ samples_per_class: 4000
66
+ use_weighted_loss: false
67
+ epochs: 100
68
+ patience: 5
69
+ log_interval: 10
70
+ save_every: 5
71
+ detailed_results: true
72
+
73
+ criterion:
74
+ type: CrossEntropyLoss
75
+
76
+ optimizer:
77
+ type: AdamW
78
+ lr: 0.0001
79
+ betas: [0.9, 0.999]
80
+ eps: 0.0000001
81
+ weight_decay: 0.0001
82
+ amsgrad: false
83
+
84
+ scheduler:
85
+ type: ReduceLROnPlateau
86
+ mode: min
87
+ patience: 5
88
+ factor: 0.1
89
+ min_lr: 1.0e-8 # decimal point added so YAML 1.1 loaders (e.g. PyYAML) parse this as a float, not a string
90
+
91
+ SYSTEM:
92
+ log_dir: ./logs
93
+ save_dir: ./checkpoints_video
94
+ use_seed: true
95
+ seed: 42
96
+ GPU: 1
97
+ device: cuda
98
+ gpu_id: 0
File without changes
File without changes
@@ -0,0 +1,40 @@
1
+ from opensportslib.core.loss.ce import CELoss
2
+ from .nll import NLLLoss
3
+ from .calf import ContextAwareLoss, SpottingLoss
4
+ from .combine import Combined2x
5
+
6
+
7
def build_criterion(cfg, default_args=None):
    """Instantiate the loss selected by ``cfg.type``.

    Args:
        cfg: Config object exposing a ``type`` attribute plus whatever
            hyper-parameters the selected loss reads (``K``, ``framerate``,
            ``hit_radius``, ``miss_radius`` for "ContextAwareLoss";
            ``lambda_coord``, ``lambda_noobj`` for "SpottingLoss";
            ``loss_1``, ``loss_2``, ``w_1``, ``w_2`` for "Combined2x").
        default_args (dict | None, optional): Accepted for interface
            compatibility; this builder does not read it. Default: None.

    Returns:
        The constructed criterion, or ``None`` when ``cfg.type`` does not
        name one of the supported losses.
    """
    loss_type = cfg.type

    if loss_type == "NLLLoss":
        return NLLLoss()

    if loss_type == "ContextAwareLoss":
        return ContextAwareLoss(
            K=cfg.K,
            framerate=cfg.framerate,
            hit_radius=cfg.hit_radius,
            miss_radius=cfg.miss_radius,
        )

    if loss_type == "SpottingLoss":
        return SpottingLoss(
            lambda_coord=cfg.lambda_coord,
            lambda_noobj=cfg.lambda_noobj,
        )

    if loss_type == "Combined2x":
        # Sub-losses are themselves config nodes, so build them recursively.
        first = build_criterion(cfg.loss_1)
        second = build_criterion(cfg.loss_2)
        return Combined2x(first, second, cfg.w_1, cfg.w_2)

    if loss_type == "CrossEntropyLoss":
        return CELoss()

    # Unknown type: callers receive None rather than an exception.
    return None
@@ -0,0 +1,258 @@
1
+ import torch
2
+
3
+
4
+ ####################################################################################################################################################
5
+
6
+ # Context-aware loss function
7
+
8
+ ####################################################################################################################################################
9
+
10
+
11
class ContextAwareLoss(torch.nn.Module):
    """Context-aware loss.

    Args:
        K (list[list[float]] | torch.Tensor): Context boundaries. ``forward``
            reads the rows ``K[0]`` .. ``K[3]`` and broadcasts each of them
            against ``labels``. Every value is multiplied by ``framerate``
            when the loss is constructed.
        framerate (int): Framerate at which the features have been extracted.
            Default: 2.
        hit_radius (float): The hit radius.
            Default: 0.1.
        miss_radius (float): The miss radius.
            Default: 0.9.

    """

    def __init__(self, K, framerate=2, hit_radius=0.1, miss_radius=0.9):
        super().__init__()

        # Convert to a tensor BEFORE multiplying: ``K * framerate`` on a plain
        # Python list repeats the list ``framerate`` times instead of scaling
        # its values, which left rows 0-3 unscaled.
        scaled_k = torch.as_tensor(K, dtype=torch.float32) * framerate

        # Keep the previous placement (GPU) when one exists, but do not crash
        # on CPU-only hosts.
        if torch.cuda.is_available():
            scaled_k = scaled_k.cuda()

        # Non-persistent buffer: follows ``.to()`` / ``.cuda()`` on the module
        # without adding an entry to ``state_dict``.
        self.register_buffer("K", scaled_k, persistent=False)
        self.hit_radius = float(hit_radius)
        self.miss_radius = float(miss_radius)

    def forward(self, labels, output):
        """Forward function.

        Args:
            labels (torch.Tensor): The ground truth labels.
            output (torch.Tensor): The predictions.

        Returns:
            torch.Tensor: The returned loss (sum over all elements).
        """
        # Make sure the boundaries live on the same device as the predictions.
        K = self.K.to(output.device)
        hit_radius = self.hit_radius
        miss_radius = self.miss_radius

        zeros = torch.zeros_like(output, dtype=torch.float)
        output = 1.0 - output

        # Each case is an indicator over a range of label values multiplied by
        # a hinge-style penalty max(0, -log(.) + log(radius)).
        case1 = self.DownStep(labels, K[0]) * torch.max(
            zeros, -torch.log(output) + torch.log(zeros + miss_radius)
        )
        case2 = self.Interval(labels, K[0], K[1]) * torch.max(
            zeros,
            -torch.log(
                output
                + (1.0 - output)
                * (self.PartialIdentity(labels, K[0], K[1]) - K[0])
                / (K[1] - K[0])
            )
            + torch.log(zeros + miss_radius),
        )
        # Labels in [K[1], 0) contribute nothing to the loss.
        case3 = self.Interval(labels, K[1], 0.0) * zeros
        case4 = self.Interval(labels, 0.0, K[2]) * torch.max(
            zeros,
            -torch.log(
                1.0
                - output
                + output
                * (self.PartialIdentity(labels, 0.0, K[2]) - 0.0)
                / (K[2] - 0.0)
            )
            + torch.log(zeros + 1.0 - hit_radius),
        )
        case5 = self.Interval(labels, K[2], K[3]) * torch.max(
            zeros,
            -torch.log(
                output
                + (1.0 - output)
                * (self.PartialIdentity(labels, K[2], K[3]) - K[3])
                / (K[2] - K[3])
            )
            + torch.log(zeros + miss_radius),
        )
        case6 = self.UpStep(labels, K[3]) * torch.max(
            zeros, -torch.log(output) + torch.log(zeros + miss_radius)
        )

        L = case1 + case2 + case3 + case4 + case5 + case6

        return torch.sum(L)

    def UpStep(self, x, a):
        """
        Args :
            x (torch.Tensor).
            a (torch.Tensor).

        Returns:
            0 if x<a, 1 if x >= a
        """

        return 1.0 - torch.max(0.0 * x, torch.sign(a - x))

    def DownStep(self, x, a):
        """
        Args :
            x (torch.Tensor).
            a (torch.Tensor).

        Returns:
            1 if x < a, 0 if x >=a
        """

        return torch.max(0.0 * x, torch.sign(a - x))

    def Interval(self, x, a, b):
        """
        Args :
            x (torch.Tensor).
            a (torch.Tensor).
            b (torch.Tensor).

        Returns:
            1 if a<= x < b, 0 otherwise
        """

        return self.UpStep(x, a) * self.DownStep(x, b)

    def PartialIdentity(self, x, a, b):
        """
        Args :
            x (torch.Tensor).
            a (torch.Tensor).
            b (torch.Tensor).

        Returns:
            a if x<a, x if a<= x <b, b if x >= b
        """

        return torch.min(torch.max(x, 0.0 * x + a), 0.0 * x + b)
+
143
+
144
+ ####################################################################################################################################################
145
+
146
+ # Spotting loss
147
+
148
+ ####################################################################################################################################################
149
+
150
+
151
class SpottingLoss(torch.nn.Module):
    """Spotting loss.

    Predictions are first re-ordered so that each label slot is paired with a
    prediction (see ``permute_output_for_matching``), then a weighted sum of
    squared errors is computed.

    Args:
        lambda_coord (float): Weight of the squared error on channel 1,
            applied where ``labels[:, :, 0]`` is non-zero.
        lambda_noobj (float): Weight of the squared error on channel 0,
            applied through the factor ``1 - labels[:, :, 0]``.

    """

    def __init__(self, lambda_coord, lambda_noobj):
        super(SpottingLoss, self).__init__()

        self.lambda_coord = lambda_coord
        self.lambda_noobj = lambda_noobj

    def forward(self, labels, output):
        """Forward function.

        Args:
            labels (torch.Tensor): The ground truth labels. Indexed as a
                3-D tensor whose last axis holds channel 0, channel 1 and
                the remaining channels ``2:``.
            output (torch.Tensor): The predictions, indexed the same way.

        Returns:
            torch.Tensor: The returned spotting loss (a single summed value).
        """
        # Align predictions with label slots before comparing them.
        output = self.permute_output_for_matching(labels, output)
        # Four terms, all summed over every element:
        #   1. channel-1 squared error, weighted by lambda_coord and labels[..., 0]
        #   2. channel-0 squared error, weighted by labels[..., 0]
        #   3. channel-0 squared error, weighted by lambda_noobj and (1 - labels[..., 0])
        #   4. squared error over channels 2:, weighted by labels[..., 0]
        loss = torch.sum(
            labels[:, :, 0]
            * self.lambda_coord
            * torch.square(labels[:, :, 1] - output[:, :, 1])
            + labels[:, :, 0] * torch.square(labels[:, :, 0] - output[:, :, 0])
            + (1 - labels[:, :, 0])
            * self.lambda_noobj
            * torch.square(labels[:, :, 0] - output[:, :, 0])
            + labels[:, :, 0]
            * torch.sum(torch.square(labels[:, :, 2:] - output[:, :, 2:]), axis=-1)
        )  # -labels[:,:,0]*torch.sum(labels[:,:,2:]*torch.log(output[:,:,2:]),axis=-1)
        return loss

    def permute_output_for_matching(self, labels, output):
        """Re-order the predictions along axis 1 to pair them with the labels.

        Pairing is driven by the absolute difference between channel 1 of the
        labels and channel 1 of the predictions. Label slots weighted by
        ``labels[:, :, 0]`` are processed in a first pass, slots weighted by
        ``1 - labels[:, :, 0]`` in a second pass.

        Args:
            labels (torch.Tensor): The ground truth labels.
            output (torch.Tensor): The predictions.

        Returns:
            torch.Tensor: The permuted pred.
        """
        alpha = labels[:, :, 0]
        x = labels[:, :, 1]
        p = output[:, :, 1]
        # Size of axis 1 of ``labels``; also used as the number of iterations
        # of each matching pass below.
        nb_pred = x.shape[-1]

        # D[b, i, j] = |x[b, i] - p[b, j]| for every (label slot, prediction) pair.
        D = torch.abs(
            x.unsqueeze(-1).repeat(1, 1, nb_pred)
            - p.unsqueeze(-2).repeat(1, nb_pred, 1)
        )
        # Turn the distance into a score: a smaller distance gives a larger value.
        D1 = 1 - D
        # Accumulates the selected (label slot, prediction) pairs; starts at zero.
        Permut = 0 * D

        alpha_filter = alpha.unsqueeze(-1).repeat(1, 1, nb_pred)

        # v_filter masks rows (label slots), h_filter masks columns (predictions).
        v_filter = alpha_filter
        h_filter = 0 * v_filter + 1
        D2 = v_filter * D1

        # First pass: rows weighted by alpha.
        for i in range(nb_pred):
            D2 = v_filter * D2
            D2 = h_filter * D2
            # A: one-hot of the best column for each row.
            A = torch.nn.functional.one_hot(torch.argmax(D2, axis=-1), nb_pred)
            B = v_filter * A * D2
            # C: one-hot of the best row for each column (transposed back to
            # the row/column layout of A).
            C = torch.nn.functional.one_hot(torch.argmax(B, axis=-2), nb_pred).permute(
                0, 2, 1
            )
            # E keeps the pairs selected by both A and C on still-active rows.
            E = v_filter * A * C
            Permut = Permut + E
            # Deactivate rows and columns that already hold a selected pair.
            v_filter = (1 - torch.sum(Permut, axis=-1)) * alpha
            v_filter = v_filter.unsqueeze(-1).repeat(1, 1, nb_pred)
            h_filter = 1 - torch.sum(Permut, axis=-2)
            h_filter = h_filter.unsqueeze(-2).repeat(1, nb_pred, 1)

        # Second pass: same procedure on the rows weighted by (1 - alpha),
        # restricted to the columns left over from the first pass.
        v_filter = 1 - alpha_filter
        D2 = v_filter * D1
        D2 = h_filter * D2

        for i in range(nb_pred):
            D2 = v_filter * D2
            D2 = h_filter * D2
            A = torch.nn.functional.one_hot(torch.argmax(D2, axis=-1), nb_pred)
            B = v_filter * A * D2
            C = torch.nn.functional.one_hot(torch.argmax(B, axis=-2), nb_pred).permute(
                0, 2, 1
            )
            E = v_filter * A * C
            Permut = Permut + E
            v_filter = (1 - torch.sum(Permut, axis=-1)) * (
                1 - alpha
            )  # here comes the change
            v_filter = v_filter.unsqueeze(-1).repeat(1, 1, nb_pred)
            h_filter = 1 - torch.sum(Permut, axis=-2)
            h_filter = h_filter.unsqueeze(-2).repeat(1, nb_pred, 1)

        # For each label slot, the column index with the largest entry in Permut.
        permutation = torch.argmax(Permut, axis=-1)
        # Gather predictions along axis 1; the index is widened to
        # labels.shape[-1] channels.
        # NOTE(review): this presumes output has at least as many channels as
        # labels — confirm against the caller.
        permuted = torch.gather(
            output, 1, permutation.unsqueeze(-1).repeat(1, 1, labels.shape[-1])
        )

        return permuted
@@ -0,0 +1,23 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+
5
class CELoss(torch.nn.Module):
    """Module wrapper around ``torch.nn.functional.cross_entropy``."""

    def __init__(self):
        super().__init__()

    def forward(self, output, labels, **ce_kwargs):
        """Compute the cross-entropy between predictions and targets.

        Note the argument order: predictions come first, targets second.

        Args:
            output (torch.Tensor): The predictions.
            labels (torch.Tensor): The ground truth labels.
            ce_kwargs: Extra keyword arguments forwarded unchanged to
                ``F.cross_entropy``.

        Returns:
            torch.Tensor: The returned CrossEntropy loss.
        """
        loss = F.cross_entropy(output, labels, **ce_kwargs)
        return loss
@@ -0,0 +1,42 @@
1
+ import torch
2
+
3
+
4
+ ####################################################################################################################################################
5
+
6
+ # Combined loss function
7
+
8
+ ####################################################################################################################################################
9
+
10
+
11
class Combined2x(torch.nn.Module):
    """Weighted sum of two criteria.

    Args:
        c_1 : The first criterion.
        c_2 : The second criterion.
        w_1 (float): Weight for the first criterion.
        w_2 (float): Weight for the second criterion.
    """

    def __init__(self, c_1, c_2, w_1, w_2):
        super().__init__()

        self.c_1, self.c_2 = c_1, c_2
        self.w_1, self.w_2 = w_1, w_2

    def forward(self, labels, output):
        """Forward function.

        Element 0 of ``labels`` / ``output`` is fed to the first criterion and
        element 1 to the second; each criterion is called as
        ``criterion(labels, output)``.

        Args:
            labels (torch.Tensor): The ground truth labels.
            output (torch.Tensor): The predictions.

        Returns:
            torch.Tensor: The returned combined loss.
        """
        first_term = self.w_1 * self.c_1(labels[0], output[0])
        second_term = self.w_2 * self.c_2(labels[1], output[1])
        return first_term + second_term
@@ -0,0 +1,25 @@
1
+
2
+ import torch
3
+
4
+
5
class NLLLoss(torch.nn.Module):
    """Negative Log LikeLihood Loss."""

    def __init__(self):
        super().__init__()

    def forward(self, labels, output):
        """Forward function.

        Computes ``mean(labels * -log(output) + (1 - labels) * -log(1 - output))``.
        The arguments of both logarithms are clamped to the smallest positive
        normal value of the dtype, so saturated predictions (exactly 0 or 1)
        give a large finite loss instead of ``inf`` / ``NaN``. Inputs for
        which the unclamped formula was finite are unaffected, apart from
        subnormal values.

        Args:
            labels (torch.Tensor): The ground truth labels.
            output (torch.Tensor): The predictions.

        Returns:
            torch.Tensor: The returned negative log likelihood loss.
        """
        # torch.log promotes integer inputs to the default float dtype; do the
        # same up front so torch.finfo below is well defined.
        if not output.is_floating_point():
            output = output.to(torch.get_default_dtype())

        tiny = torch.finfo(output.dtype).tiny

        # Clamp the log ARGUMENT (not its result) so the backward pass stays
        # finite as well: 0 * log(0) would otherwise evaluate to NaN.
        positive_term = labels * -torch.log(output.clamp(min=tiny))
        negative_term = (1 - labels) * -torch.log((1 - output).clamp(min=tiny))

        return torch.mean(positive_term + negative_term)
File without changes
@@ -0,0 +1,38 @@
1
+ import torch
2
+
3
+
4
def build_optimizer(parameters, cfg, default_args=None):
    """Build an optimizer from a config object.

    Supported ``cfg.type`` values:

    * ``"Adam"`` -- ``torch.optim.Adam``.
    * ``"AdamW"`` -- ``torch.optim.AdamW``.
    * ``"AdamWithScaler"`` -- a ``(torch.optim.AdamW, GradScaler)`` tuple;
      only ``cfg.lr`` is read.

    For "Adam" and "AdamW", ``betas``, ``eps``, ``weight_decay`` and
    ``amsgrad`` are optional and fall back to defaults when absent.

    Args:
        parameters: Iterable of parameters (or parameter groups) to optimize.
        cfg: Config object. It must expose ``type`` and ``lr``.
        default_args (dict | None, optional): Accepted for interface
            compatibility; this builder does not read it. Default: None.

    Returns:
        optimizer: The constructed optimizer (a tuple for "AdamWithScaler").

    Raises:
        ValueError: If ``cfg.type`` is not one of the supported names.
    """
    optimizer_type = cfg.type

    if optimizer_type == "AdamWithScaler":
        return (
            torch.optim.AdamW(parameters, lr=cfg.lr),
            torch.cuda.amp.GradScaler(),
        )

    if optimizer_type == "Adam":
        return torch.optim.Adam(
            parameters,
            lr=cfg.lr,
            betas=tuple(getattr(cfg, "betas", (0.9, 0.999))),
            eps=getattr(cfg, "eps", 1e-8),
            weight_decay=getattr(cfg, "weight_decay", 1e-4),
            amsgrad=getattr(cfg, "amsgrad", False),
        )

    if optimizer_type == "AdamW":
        # Same optional-key handling as the "Adam" branch. The weight_decay
        # fallback is torch's own AdamW default, which is also what the
        # "AdamWithScaler" branch ends up using.
        return torch.optim.AdamW(
            parameters,
            lr=cfg.lr,
            betas=tuple(getattr(cfg, "betas", (0.9, 0.999))),
            eps=getattr(cfg, "eps", 1e-8),
            weight_decay=getattr(cfg, "weight_decay", 1e-2),
            amsgrad=getattr(cfg, "amsgrad", False),
        )

    # Previously an unknown type fell through to ``return optimizer`` and
    # failed with an opaque UnboundLocalError.
    raise ValueError(
        f"Unsupported optimizer type {optimizer_type!r}; "
        "expected one of 'Adam', 'AdamW', 'AdamWithScaler'."
    )