opensportslib 0.0.1.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opensportslib/__init__.py +18 -0
- opensportslib/apis/__init__.py +21 -0
- opensportslib/apis/classification.py +361 -0
- opensportslib/apis/localization.py +228 -0
- opensportslib/config/classification.yaml +104 -0
- opensportslib/config/classification_tracking.yaml +103 -0
- opensportslib/config/graph_tracking_classification/avgpool.yaml +79 -0
- opensportslib/config/graph_tracking_classification/gin.yaml +79 -0
- opensportslib/config/graph_tracking_classification/graphconv.yaml +79 -0
- opensportslib/config/graph_tracking_classification/graphsage.yaml +79 -0
- opensportslib/config/graph_tracking_classification/maxpool.yaml +79 -0
- opensportslib/config/graph_tracking_classification/noedges.yaml +79 -0
- opensportslib/config/localization.yaml +132 -0
- opensportslib/config/sngar_frames.yaml +98 -0
- opensportslib/core/__init__.py +0 -0
- opensportslib/core/loss/__init__.py +0 -0
- opensportslib/core/loss/builder.py +40 -0
- opensportslib/core/loss/calf.py +258 -0
- opensportslib/core/loss/ce.py +23 -0
- opensportslib/core/loss/combine.py +42 -0
- opensportslib/core/loss/nll.py +25 -0
- opensportslib/core/optimizer/__init__.py +0 -0
- opensportslib/core/optimizer/builder.py +38 -0
- opensportslib/core/sampler/weighted_sampler.py +104 -0
- opensportslib/core/scheduler/__init__.py +0 -0
- opensportslib/core/scheduler/builder.py +77 -0
- opensportslib/core/trainer/__init__.py +0 -0
- opensportslib/core/trainer/classification_trainer.py +1131 -0
- opensportslib/core/trainer/localization_trainer.py +1009 -0
- opensportslib/core/utils/checkpoint.py +238 -0
- opensportslib/core/utils/config.py +199 -0
- opensportslib/core/utils/data.py +85 -0
- opensportslib/core/utils/ddp.py +77 -0
- opensportslib/core/utils/default_args.py +110 -0
- opensportslib/core/utils/load_annotations.py +485 -0
- opensportslib/core/utils/seed.py +26 -0
- opensportslib/core/utils/video_processing.py +389 -0
- opensportslib/core/utils/wandb.py +110 -0
- opensportslib/datasets/__init__.py +0 -0
- opensportslib/datasets/builder.py +42 -0
- opensportslib/datasets/classification_dataset.py +582 -0
- opensportslib/datasets/localization_dataset.py +813 -0
- opensportslib/datasets/utils/__init__.py +15 -0
- opensportslib/datasets/utils/tracking.py +615 -0
- opensportslib/metrics/classification_metric.py +176 -0
- opensportslib/metrics/localization_metric.py +1482 -0
- opensportslib/models/__init__.py +0 -0
- opensportslib/models/backbones/builder.py +590 -0
- opensportslib/models/base/e2e.py +252 -0
- opensportslib/models/base/tracking.py +73 -0
- opensportslib/models/base/vars.py +29 -0
- opensportslib/models/base/video.py +130 -0
- opensportslib/models/base/video_mae.py +60 -0
- opensportslib/models/builder.py +43 -0
- opensportslib/models/heads/builder.py +266 -0
- opensportslib/models/neck/builder.py +210 -0
- opensportslib/models/utils/common.py +176 -0
- opensportslib/models/utils/impl/__init__.py +0 -0
- opensportslib/models/utils/impl/asformer.py +390 -0
- opensportslib/models/utils/impl/calf.py +74 -0
- opensportslib/models/utils/impl/gsm.py +112 -0
- opensportslib/models/utils/impl/gtad.py +347 -0
- opensportslib/models/utils/impl/tsm.py +123 -0
- opensportslib/models/utils/litebase.py +59 -0
- opensportslib/models/utils/modules.py +120 -0
- opensportslib/models/utils/shift.py +135 -0
- opensportslib/models/utils/utils.py +276 -0
- opensportslib-0.0.1.dev2.dist-info/METADATA +566 -0
- opensportslib-0.0.1.dev2.dist-info/RECORD +73 -0
- opensportslib-0.0.1.dev2.dist-info/WHEEL +5 -0
- opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE +661 -0
- opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE-COMMERCIAL +5 -0
- opensportslib-0.0.1.dev2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
TASK: localization
|
|
2
|
+
|
|
3
|
+
dali: True
|
|
4
|
+
|
|
5
|
+
DATA:
|
|
6
|
+
dataset_name: SoccerNet
|
|
7
|
+
data_dir: /home/vorajv/opensportslib/SoccerNet/annotations/
|
|
8
|
+
classes:
|
|
9
|
+
- PASS
|
|
10
|
+
- DRIVE
|
|
11
|
+
- HEADER
|
|
12
|
+
- HIGH PASS
|
|
13
|
+
- OUT
|
|
14
|
+
- CROSS
|
|
15
|
+
- THROW IN
|
|
16
|
+
- SHOT
|
|
17
|
+
- BALL PLAYER BLOCK
|
|
18
|
+
- PLAYER SUCCESSFUL TACKLE
|
|
19
|
+
- FREE KICK
|
|
20
|
+
- GOAL
|
|
21
|
+
|
|
22
|
+
epoch_num_frames: 500000
|
|
23
|
+
mixup: true
|
|
24
|
+
modality: rgb
|
|
25
|
+
crop_dim: -1
|
|
26
|
+
dilate_len: 0 # Dilate ground truth labels
|
|
27
|
+
clip_len: 100
|
|
28
|
+
input_fps: 25
|
|
29
|
+
extract_fps: 2
|
|
30
|
+
imagenet_mean: [0.485, 0.456, 0.406]
|
|
31
|
+
imagenet_std: [0.229, 0.224, 0.225]
|
|
32
|
+
target_height: 224
|
|
33
|
+
target_width: 398
|
|
34
|
+
|
|
35
|
+
train:
|
|
36
|
+
type: VideoGameWithDali
|
|
37
|
+
classes: ${DATA.classes}
|
|
38
|
+
output_map: [data, label]
|
|
39
|
+
video_path: ${DATA.data_dir}/train/
|
|
40
|
+
path: ${DATA.train.video_path}/annotations-2024-224p-train.json
|
|
41
|
+
dataloader:
|
|
42
|
+
batch_size: 8
|
|
43
|
+
shuffle: true
|
|
44
|
+
num_workers: 4
|
|
45
|
+
pin_memory: true
|
|
46
|
+
|
|
47
|
+
valid:
|
|
48
|
+
type: VideoGameWithDali
|
|
49
|
+
classes: ${DATA.classes}
|
|
50
|
+
output_map: [data, label]
|
|
51
|
+
video_path: ${DATA.data_dir}/valid/
|
|
52
|
+
path: ${DATA.valid.video_path}/annotations-2024-224p-valid.json
|
|
53
|
+
dataloader:
|
|
54
|
+
batch_size: 8
|
|
55
|
+
shuffle: true
|
|
56
|
+
|
|
57
|
+
valid_data_frames:
|
|
58
|
+
type: VideoGameWithDaliVideo
|
|
59
|
+
classes: ${DATA.classes}
|
|
60
|
+
output_map: [data, label]
|
|
61
|
+
video_path: ${DATA.valid.video_path}
|
|
62
|
+
path: ${DATA.valid.path}
|
|
63
|
+
overlap_len: 0
|
|
64
|
+
dataloader:
|
|
65
|
+
batch_size: 4
|
|
66
|
+
shuffle: false
|
|
67
|
+
|
|
68
|
+
test:
|
|
69
|
+
type: VideoGameWithDaliVideo
|
|
70
|
+
classes: ${DATA.classes}
|
|
71
|
+
output_map: [data, label]
|
|
72
|
+
video_path: ${DATA.data_dir}/test/
|
|
73
|
+
path: ${DATA.test.video_path}/annotations-2024-224p-test.json
|
|
74
|
+
results: results_spotting_test
|
|
75
|
+
nms_window: 2
|
|
76
|
+
metric: tight
|
|
77
|
+
overlap_len: 50
|
|
78
|
+
dataloader:
|
|
79
|
+
batch_size: 4
|
|
80
|
+
shuffle: false
|
|
81
|
+
|
|
82
|
+
challenge:
|
|
83
|
+
type: VideoGameWithDaliVideo
|
|
84
|
+
overlap_len: 50
|
|
85
|
+
output_map: [data, label]
|
|
86
|
+
path: ${DATA.data_dir}/challenge/annotations.json
|
|
87
|
+
dataloader:
|
|
88
|
+
batch_size: 4
|
|
89
|
+
shuffle: false
|
|
90
|
+
|
|
91
|
+
MODEL:
|
|
92
|
+
type: E2E
|
|
93
|
+
runner:
|
|
94
|
+
type: runner_e2e
|
|
95
|
+
backbone:
|
|
96
|
+
type: rny008_gsm
|
|
97
|
+
head:
|
|
98
|
+
type: gru
|
|
99
|
+
multi_gpu: true
|
|
100
|
+
load_weights: null
|
|
101
|
+
|
|
102
|
+
TRAIN:
|
|
103
|
+
type: trainer_e2e
|
|
104
|
+
num_epochs: 10
|
|
105
|
+
acc_grad_iter: 1
|
|
106
|
+
base_num_valid_epochs: 30
|
|
107
|
+
start_valid_epoch: 4
|
|
108
|
+
valid_map_every: 1
|
|
109
|
+
criterion_valid: map
|
|
110
|
+
|
|
111
|
+
criterion:
|
|
112
|
+
type: CrossEntropyLoss
|
|
113
|
+
|
|
114
|
+
optimizer:
|
|
115
|
+
type: AdamWithScaler
|
|
116
|
+
lr: 0.01
|
|
117
|
+
|
|
118
|
+
scheduler:
|
|
119
|
+
type: ChainedSchedulerE2E
|
|
120
|
+
acc_grad_iter: 1
|
|
121
|
+
num_epochs: ${TRAIN.num_epochs}
|
|
122
|
+
warm_up_epochs: 3
|
|
123
|
+
|
|
124
|
+
SYSTEM:
|
|
125
|
+
log_dir: ./logs
|
|
126
|
+
save_dir: ./checkpoints
|
|
127
|
+
work_dir: ${SYSTEM.save_dir}
|
|
128
|
+
seed: 42
|
|
129
|
+
GPU: 4 # number of gpus to use
|
|
130
|
+
device: cuda # auto | cuda | cpu
|
|
131
|
+
gpu_id: 0 # device id for single gpu training
|
|
132
|
+
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
TASK: classification
|
|
2
|
+
|
|
3
|
+
DATA:
|
|
4
|
+
dataset_name: sngar
|
|
5
|
+
data_dir: /home/spark_user1/opensportslib/sngar-frames
|
|
6
|
+
data_modality: frames_npy
|
|
7
|
+
max_samples: 100
|
|
8
|
+
num_frames: 16
|
|
9
|
+
frame_size: [224, 224]
|
|
10
|
+
train:
|
|
11
|
+
path: ${DATA.data_dir}/annotations_train.json
|
|
12
|
+
dataloader:
|
|
13
|
+
batch_size: 64
|
|
14
|
+
shuffle: true
|
|
15
|
+
num_workers: 8
|
|
16
|
+
pin_memory: true
|
|
17
|
+
valid:
|
|
18
|
+
path: ${DATA.data_dir}/annotations_valid.json
|
|
19
|
+
dataloader:
|
|
20
|
+
batch_size: 64
|
|
21
|
+
num_workers: 8
|
|
22
|
+
shuffle: false
|
|
23
|
+
test:
|
|
24
|
+
path: ${DATA.data_dir}/annotations_test.json
|
|
25
|
+
dataloader:
|
|
26
|
+
batch_size: 64
|
|
27
|
+
num_workers: 8
|
|
28
|
+
shuffle: false
|
|
29
|
+
augmentations:
|
|
30
|
+
random_horizontal_flip: true
|
|
31
|
+
flip_prob: 0.5
|
|
32
|
+
color_jitter: true
|
|
33
|
+
jitter_prob: 0.5
|
|
34
|
+
jitter_params: [0.2, 0.2, 0.2, 0.1]
|
|
35
|
+
|
|
36
|
+
MODEL:
|
|
37
|
+
type: custom
|
|
38
|
+
backbone:
|
|
39
|
+
type: dinov3 # dinov3 | clip | videomae | videomae2
|
|
40
|
+
pretrained_model: facebook/dinov3-vitb16-pretrain-lvd1689m
|
|
41
|
+
# facebook/dinov3-vitb16-pretrain-lvd1689m | openai/clip-vit-base-patch16 | MCG-NJU/videomae-base | OpenGVLab/VideoMAEv2-Base
|
|
42
|
+
hidden_dim: 768
|
|
43
|
+
freeze: true
|
|
44
|
+
unfreeze_last_n_layers: 0 # 0 = frozen backbone, >0 = unfreeze last N layers
|
|
45
|
+
neck:
|
|
46
|
+
type: TemporalAggregation
|
|
47
|
+
agr_type: maxpool # avgpool | maxpool | bilstm | tcn | attention
|
|
48
|
+
hidden_dim: 768
|
|
49
|
+
dropout: 0.1
|
|
50
|
+
lstm_dropout: 0.3 # only used when agr_type is bilstm
|
|
51
|
+
num_attention_heads: 8 # only used when agr_type is attention (8 for video, 4 for tracking)
|
|
52
|
+
head:
|
|
53
|
+
type: TrackingClassifier
|
|
54
|
+
hidden_dim: 64
|
|
55
|
+
num_classes: 10 # must match the number of action classes in the dataset
|
|
56
|
+
dropout: 0.1
|
|
57
|
+
|
|
58
|
+
TRAIN:
|
|
59
|
+
monitor: balanced_accuracy
|
|
60
|
+
mode: max
|
|
61
|
+
enabled: true
|
|
62
|
+
use_amp: true
|
|
63
|
+
mixup_alpha: 0.2
|
|
64
|
+
use_weighted_sampler: false
|
|
65
|
+
samples_per_class: 4000
|
|
66
|
+
use_weighted_loss: false
|
|
67
|
+
epochs: 100
|
|
68
|
+
patience: 5
|
|
69
|
+
log_interval: 10
|
|
70
|
+
save_every: 5
|
|
71
|
+
detailed_results: true
|
|
72
|
+
|
|
73
|
+
criterion:
|
|
74
|
+
type: CrossEntropyLoss
|
|
75
|
+
|
|
76
|
+
optimizer:
|
|
77
|
+
type: AdamW
|
|
78
|
+
lr: 0.0001
|
|
79
|
+
betas: [0.9, 0.999]
|
|
80
|
+
eps: 0.0000001
|
|
81
|
+
weight_decay: 0.0001
|
|
82
|
+
amsgrad: false
|
|
83
|
+
|
|
84
|
+
scheduler:
|
|
85
|
+
type: ReduceLROnPlateau
|
|
86
|
+
mode: min
|
|
87
|
+
patience: 5
|
|
88
|
+
factor: 0.1
|
|
89
|
+
min_lr: 1e-8
|
|
90
|
+
|
|
91
|
+
SYSTEM:
|
|
92
|
+
log_dir: ./logs
|
|
93
|
+
save_dir: ./checkpoints_video
|
|
94
|
+
use_seed: true
|
|
95
|
+
seed: 42
|
|
96
|
+
GPU: 1
|
|
97
|
+
device: cuda
|
|
98
|
+
gpu_id: 0
|
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
from opensportslib.core.loss.ce import CELoss
|
|
2
|
+
from .nll import NLLLoss
|
|
3
|
+
from .calf import ContextAwareLoss, SpottingLoss
|
|
4
|
+
from .combine import Combined2x
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def build_criterion(cfg, default_args=None):
    """Build a criterion from a config object.

    Args:
        cfg: Config object read through attribute access. It must define
            ``type``; the other attributes read depend on that type:
            ``ContextAwareLoss`` reads ``K``, ``framerate``, ``hit_radius``
            and ``miss_radius``; ``SpottingLoss`` reads ``lambda_coord`` and
            ``lambda_noobj``; ``Combined2x`` reads ``loss_1`` and ``loss_2``
            (nested criterion configs) plus the weights ``w_1`` and ``w_2``.
            ``NLLLoss`` and ``CrossEntropyLoss`` read nothing else.
        default_args (dict | None, optional): Default initialization
            arguments. Not used by any branch directly; forwarded to the
            nested builds of ``Combined2x``. Default: None.

    Returns:
        The constructed criterion, or ``None`` (after emitting a
        ``UserWarning``) when ``cfg.type`` is not a recognized name.
    """
    loss_type = cfg.type
    if loss_type == "NLLLoss":
        criterion = NLLLoss()
    elif loss_type == "ContextAwareLoss":
        criterion = ContextAwareLoss(
            K=cfg.K,
            framerate=cfg.framerate,
            hit_radius=cfg.hit_radius,
            miss_radius=cfg.miss_radius,
        )
    elif loss_type == "SpottingLoss":
        criterion = SpottingLoss(
            lambda_coord=cfg.lambda_coord, lambda_noobj=cfg.lambda_noobj
        )
    elif loss_type == "Combined2x":
        # Each sub-loss is itself described by a full criterion config.
        c_1 = build_criterion(cfg.loss_1, default_args)
        c_2 = build_criterion(cfg.loss_2, default_args)
        criterion = Combined2x(c_1, c_2, cfg.w_1, cfg.w_2)
    elif loss_type == "CrossEntropyLoss":
        criterion = CELoss()
    else:
        # Returning None is kept for backward compatibility, but a config
        # typo should be visible here rather than as an opaque
        # "'NoneType' object is not callable" later in training.
        import warnings

        warnings.warn(
            f"build_criterion: unknown criterion type {loss_type!r}; "
            "returning None.",
            stacklevel=2,
        )
        criterion = None
    return criterion
|
|
@@ -0,0 +1,258 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
####################################################################################################################################################
|
|
5
|
+
|
|
6
|
+
# Context-aware loss function
|
|
7
|
+
|
|
8
|
+
####################################################################################################################################################
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ContextAwareLoss(torch.nn.Module):
    """Context Aware Loss.

    Element-wise loss whose penalty depends on which interval of the
    context slicing ``K`` each label value falls in; the per-element
    penalties are summed.

    Args:
        K (list[list[float]] | torch.Tensor): Context slicing parameters.
            The first four rows are used as thresholds; the six cases in
            ``forward`` only partition the label axis when
            ``K[0] <= K[1] <= 0 <= K[2] <= K[3]``. Each row is broadcast
            against the last dimension of ``labels``/``output`` (presumably
            one value per class — confirm against the dataset). The values
            are multiplied by ``framerate``, so they are presumably given in
            seconds.
        framerate (int): Framerate at which the features have been extracted.
            Default: 2.
        hit_radius (float): The hit radius.
            Default: 0.1.
        miss_radius (float): The miss radius.
            Default: 0.9.

    """

    def __init__(self, K, framerate=2, hit_radius=0.1, miss_radius=0.9):

        super(ContextAwareLoss, self).__init__()

        # Convert to a tensor *before* scaling: on a Python list,
        # ``K * framerate`` would repeat the list instead of multiplying its
        # values. The tensor is built on the CPU and moved to the inputs'
        # device in ``forward``, so the loss also works without CUDA.
        self.K = torch.as_tensor(K, dtype=torch.float32) * framerate
        self.hit_radius = float(hit_radius)
        self.miss_radius = float(miss_radius)

    def forward(self, labels, output):
        """Forward function.

        Args:
            labels (torch.Tensor): The ground truth labels; every element is
                compared against the thresholds in ``K``.
            output (torch.Tensor): The predictions, presumably of the same
                shape as ``labels`` and with values in [0, 1] (the log terms
                are undefined outside that range).

        Returns:
            torch.Tensor: The returned loss (sum over all elements).
        """
        K = self.K.to(output.device)
        hit_radius = self.hit_radius
        miss_radius = self.miss_radius

        zeros = torch.zeros_like(output, dtype=torch.float)
        output = 1.0 - output

        # labels < K[0]
        case1 = self.DownStep(labels, K[0]) * torch.max(
            zeros, -torch.log(output) + torch.log(zeros + miss_radius)
        )
        # K[0] <= labels < K[1]: penalty ramps linearly across the interval.
        case2 = self.Interval(labels, K[0], K[1]) * torch.max(
            zeros,
            -torch.log(
                output
                + (1.0 - output)
                * (self.PartialIdentity(labels, K[0], K[1]) - K[0])
                / (K[1] - K[0])
            )
            + torch.log(zeros + miss_radius),
        )
        # K[1] <= labels < 0: no loss.
        case3 = self.Interval(labels, K[1], 0.0) * zeros
        # 0 <= labels < K[2]: bounded by the hit radius.
        case4 = self.Interval(labels, 0.0, K[2]) * torch.max(
            zeros,
            -torch.log(
                1.0
                - output
                + output
                * (self.PartialIdentity(labels, 0.0, K[2]) - 0.0)
                / (K[2] - 0.0)
            )
            + torch.log(zeros + 1.0 - hit_radius),
        )
        # K[2] <= labels < K[3]: penalty ramps linearly across the interval.
        case5 = self.Interval(labels, K[2], K[3]) * torch.max(
            zeros,
            -torch.log(
                output
                + (1.0 - output)
                * (self.PartialIdentity(labels, K[2], K[3]) - K[3])
                / (K[2] - K[3])
            )
            + torch.log(zeros + miss_radius),
        )
        # labels >= K[3]
        case6 = self.UpStep(labels, K[3]) * torch.max(
            zeros, -torch.log(output) + torch.log(zeros + miss_radius)
        )

        L = case1 + case2 + case3 + case4 + case5 + case6

        return torch.sum(L)

    def UpStep(self, x, a):
        """Element-wise step that switches on at ``a``.

        Args :
            x (torch.Tensor).
            a (torch.Tensor | float).

        Returns:
            0 if x<a, 1 if x >= a
        """

        return 1.0 - torch.max(0.0 * x, torch.sign(a - x))

    def DownStep(self, x, a):
        """Element-wise step that switches off at ``a``.

        Args :
            x (torch.Tensor).
            a (torch.Tensor | float).

        Returns:
            1 if x < a, 0 if x >=a
        """

        return torch.max(0.0 * x, torch.sign(a - x))

    def Interval(self, x, a, b):
        """Element-wise indicator of the half-open interval ``[a, b)``.

        Args :
            x (torch.Tensor).
            a (torch.Tensor | float).
            b (torch.Tensor | float).

        Returns:
            1 if a<= x < b, 0 otherwise
        """

        return self.UpStep(x, a) * self.DownStep(x, b)

    def PartialIdentity(self, x, a, b):
        """Clamp ``x`` element-wise to ``[a, b]``.

        Args :
            x (torch.Tensor).
            a (torch.Tensor | float).
            b (torch.Tensor | float).

        Returns:
            a if x<a, x if a<= x <b, b if x >= b
        """

        return torch.min(torch.max(x, 0.0 * x + a), 0.0 * x + b)
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
####################################################################################################################################################
|
|
145
|
+
|
|
146
|
+
# Spotting loss
|
|
147
|
+
|
|
148
|
+
####################################################################################################################################################
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
class SpottingLoss(torch.nn.Module):
    """Spotting loss.

    Squared-error loss over a fixed number of prediction slots. The
    predictions are first re-ordered so that every ground-truth slot is
    paired with one prediction slot (see ``permute_output_for_matching``),
    then all terms are summed over the batch and the slots. The weighting
    appears to follow the YOLO detection-loss layout
    (``lambda_coord`` / ``lambda_noobj``).

    Args:
        lambda_coord (float): Weight of the squared error on column 1
            (presumably the event position), applied where the ground-truth
            column 0 is set.
        lambda_noobj (float): Weight of the squared error on column 0
            (presumably the event-presence confidence), applied where the
            ground-truth column 0 is 0.

    """

    def __init__(self, lambda_coord, lambda_noobj):
        super(SpottingLoss, self).__init__()

        self.lambda_coord = lambda_coord
        self.lambda_noobj = lambda_noobj

    def forward(self, labels, output):
        """Forward function.

        Args:
            labels (torch.Tensor): The ground truth labels, indexed as
                ``labels[:, :, c]`` (rank 3). Assumed layout — TODO confirm:
                (batch, nb_pred, 2 + num_classes) with column 0 a 0/1
                presence flag, column 1 the position and columns 2: the
                class scores.
            output (torch.Tensor): The predictions, presumably in the same
                layout as ``labels``.

        Returns:
            torch.Tensor: The returned spotting loss (a scalar: the sum over
            all batch items and slots).
        """
        # Re-order the prediction slots so that output[:, j] is the
        # prediction matched to ground-truth slot j.
        output = self.permute_output_for_matching(labels, output)
        # NOTE: the comment after the closing parenthesis is a disabled
        # cross-entropy alternative to the class-score term.
        loss = torch.sum(
            # Position term, weighted by the ground-truth presence column.
            labels[:, :, 0]
            * self.lambda_coord
            * torch.square(labels[:, :, 1] - output[:, :, 1])
            # Confidence term where the ground-truth presence column is set.
            + labels[:, :, 0] * torch.square(labels[:, :, 0] - output[:, :, 0])
            # Confidence term for empty slots, weighted by lambda_noobj.
            + (1 - labels[:, :, 0])
            * self.lambda_noobj
            * torch.square(labels[:, :, 0] - output[:, :, 0])
            # Class-score term (sum of squared errors over columns 2:).
            + labels[:, :, 0]
            * torch.sum(torch.square(labels[:, :, 2:] - output[:, :, 2:]), axis=-1)
        )  # -labels[:,:,0]*torch.sum(labels[:,:,2:]*torch.log(output[:,:,2:]),axis=-1)
        return loss

    def permute_output_for_matching(self, labels, output):
        """Greedily pair every ground-truth slot with a prediction slot.

        The matching score between ground-truth slot i and prediction j is
        ``1 - |labels[:, i, 1] - output[:, j, 1]|``. Rows where
        ``labels[:, :, 0]`` is set are matched in a first pass; the other
        rows are matched in a second pass against the prediction columns
        still unassigned. In every round, a (row, column) pair is accepted
        when the column is the row's best remaining prediction and the row
        is the best among the rows that picked that column.

        Args:
            labels (torch.Tensor): The ground truth labels.
            output (torch.Tensor): The predictions; dim 1 must have the same
                size as in ``labels``.

        Returns:
            torch.Tensor: The permuted pred: ``output`` re-ordered along
            dim 1 so that row j is the prediction matched to ground-truth
            slot j.
        """
        alpha = labels[:, :, 0]  # selects the rows handled by the first pass
        x = labels[:, :, 1]
        p = output[:, :, 1]
        nb_pred = x.shape[-1]

        # D[b, i, j] = |x[b, i] - p[b, j]|; D1 turns the distance into a
        # similarity (higher is better).
        D = torch.abs(
            x.unsqueeze(-1).repeat(1, 1, nb_pred)
            - p.unsqueeze(-2).repeat(1, nb_pred, 1)
        )
        D1 = 1 - D
        # Accumulated 0/1 assignment matrix: Permut[b, i, j] == 1 when
        # ground-truth slot i has been matched to prediction j.
        Permut = 0 * D

        alpha_filter = alpha.unsqueeze(-1).repeat(1, 1, nb_pred)

        # First pass: only rows selected by alpha; every column available.
        v_filter = alpha_filter
        h_filter = 0 * v_filter + 1
        D2 = v_filter * D1

        for i in range(nb_pred):
            # Mask out inactive rows (v_filter) and taken columns (h_filter).
            D2 = v_filter * D2
            D2 = h_filter * D2
            # A: one-hot of each row's best remaining column.
            A = torch.nn.functional.one_hot(torch.argmax(D2, axis=-1), nb_pred)
            # B keeps each row's score only at its picked column; C marks,
            # for every column, the best row among those that picked it.
            B = v_filter * A * D2
            C = torch.nn.functional.one_hot(torch.argmax(B, axis=-2), nb_pred).permute(
                0, 2, 1
            )
            # E: pairs accepted this round (row picked the column and is the
            # column's best picker).
            E = v_filter * A * C
            Permut = Permut + E
            # Rows still active: selected by alpha and not yet matched.
            v_filter = (1 - torch.sum(Permut, axis=-1)) * alpha
            v_filter = v_filter.unsqueeze(-1).repeat(1, 1, nb_pred)
            # Columns still available: predictions not yet assigned.
            h_filter = 1 - torch.sum(Permut, axis=-2)
            h_filter = h_filter.unsqueeze(-2).repeat(1, nb_pred, 1)

        # Second pass: the rows not selected by alpha, restricted to the
        # columns left over by the first pass (h_filter carries over).
        v_filter = 1 - alpha_filter
        D2 = v_filter * D1
        D2 = h_filter * D2

        for i in range(nb_pred):
            D2 = v_filter * D2
            D2 = h_filter * D2
            A = torch.nn.functional.one_hot(torch.argmax(D2, axis=-1), nb_pred)
            B = v_filter * A * D2
            C = torch.nn.functional.one_hot(torch.argmax(B, axis=-2), nb_pred).permute(
                0, 2, 1
            )
            E = v_filter * A * C
            Permut = Permut + E
            v_filter = (1 - torch.sum(Permut, axis=-1)) * (
                1 - alpha
            )  # here comes the change: active rows are now those outside alpha
            v_filter = v_filter.unsqueeze(-1).repeat(1, 1, nb_pred)
            h_filter = 1 - torch.sum(Permut, axis=-2)
            h_filter = h_filter.unsqueeze(-2).repeat(1, nb_pred, 1)

        # Index of the matched prediction for every ground-truth slot. A row
        # that was never matched is all zeros, so argmax returns 0 for it.
        permutation = torch.argmax(Permut, axis=-1)
        permuted = torch.gather(
            output, 1, permutation.unsqueeze(-1).repeat(1, 1, labels.shape[-1])
        )

        return permuted
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn.functional as F
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class CELoss(torch.nn.Module):
    """Module wrapper around ``torch.nn.functional.cross_entropy``.

    Unlike the other losses next to it (``NLLLoss``, ``Combined2x``,
    ``ContextAwareLoss``), ``forward`` takes the predictions first and the
    targets second.
    """

    def __init__(self):
        super().__init__()

    def forward(self, output, labels, **ce_kwargs):
        """Compute the cross-entropy between ``output`` and ``labels``.

        Args:
            output (torch.Tensor): The predictions (unnormalized scores, as
                expected by ``F.cross_entropy``).
            labels (torch.Tensor): The ground truth labels.
            **ce_kwargs: Extra keyword arguments forwarded unchanged to
                ``F.cross_entropy`` (e.g. ``weight`` or ``reduction``).

        Returns:
            torch.Tensor: The cross-entropy loss.
        """
        loss_value = F.cross_entropy(output, labels, **ce_kwargs)
        return loss_value
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
####################################################################################################################################################
|
|
5
|
+
|
|
6
|
+
# Combined loss function
|
|
7
|
+
|
|
8
|
+
####################################################################################################################################################
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class Combined2x(torch.nn.Module):
    """Weighted sum of two criteria, each applied to its own input pair.

    Args:
        c_1 : Criterion applied to ``(labels[0], output[0])``.
        c_2 : Criterion applied to ``(labels[1], output[1])``.
        w_1 (float): Weight for the first criterion.
        w_2 (float): Weight for the second criterion.
    """

    def __init__(self, c_1, c_2, w_1, w_2):
        super().__init__()
        self.c_1, self.c_2 = c_1, c_2
        self.w_1, self.w_2 = w_1, w_2

    def forward(self, labels, output):
        """Evaluate ``w_1 * c_1(...) + w_2 * c_2(...)``.

        Args:
            labels: Indexable pair of ground truth labels; item 0 feeds
                ``c_1`` and item 1 feeds ``c_2``.
            output: Indexable pair of predictions, paired like ``labels``.

        Returns:
            torch.Tensor: The combined loss.
        """
        first_term = self.w_1 * self.c_1(labels[0], output[0])
        second_term = self.w_2 * self.c_2(labels[1], output[1])
        return first_term + second_term
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
|
|
2
|
+
import torch
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class NLLLoss(torch.nn.Module):
    """Negative Log LikeLihood Loss.

    Element-wise binary negative log-likelihood between predicted
    probabilities ``output`` and targets ``labels`` (presumably in [0, 1]),
    averaged over all elements. Note: this is a binary cross-entropy, not
    ``torch.nn.NLLLoss``.
    """

    # Lower bound for each log term, the same clamp PyTorch's
    # ``binary_cross_entropy`` applies, so predictions of exactly 0 or 1
    # give a finite loss instead of ``inf``/``nan`` (``0 * -inf``).
    _LOG_MIN = -100.0

    def __init__(self):
        super(NLLLoss, self).__init__()

    def forward(self, labels, output):
        """Forward function.

        Args:
            labels (torch.Tensor): The ground truth labels.
            output (torch.Tensor): The predictions (probabilities).

        Returns:
            torch.Tensor: The returned negative log likelihood loss, a scalar
            mean over all elements.
        """
        log_p = torch.clamp(torch.log(output), min=self._LOG_MIN)
        log_not_p = torch.clamp(torch.log(1 - output), min=self._LOG_MIN)
        return torch.mean(labels * -log_p + (1 - labels) * -log_not_p)
|
|
File without changes
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def build_optimizer(parameters, cfg, default_args=None):
    """Build an optimizer from a config object.

    Args:
        parameters: Parameters (or parameter-group dicts) to optimize,
            passed straight to the ``torch.optim`` constructor.
        cfg: Config object read through attribute access. It must define
            ``type`` (``"Adam"``, ``"AdamW"`` or ``"AdamWithScaler"``) and
            ``lr``. ``Adam`` and ``AdamW`` also read the optional attributes
            ``betas`` (default ``(0.9, 0.999)``), ``eps`` (default ``1e-8``),
            ``weight_decay`` (default ``1e-4``) and ``amsgrad`` (default
            ``False``).
        default_args (dict | None, optional): Unused. Default: None.

    Returns:
        A ``torch.optim`` optimizer or, for ``"AdamWithScaler"``, the tuple
        ``(torch.optim.AdamW, torch.cuda.amp.GradScaler)``.

    Raises:
        ValueError: If ``cfg.type`` is not one of the supported names.
    """
    opt_type = cfg.type
    if opt_type == "AdamWithScaler":
        # Mixed-precision setup: the caller gets the optimizer together with
        # a gradient scaler.
        return (
            torch.optim.AdamW(parameters, lr=cfg.lr),
            torch.cuda.amp.GradScaler(),
        )

    if opt_type == "Adam":
        optimizer_cls = torch.optim.Adam
    elif opt_type == "AdamW":
        optimizer_cls = torch.optim.AdamW
    else:
        raise ValueError(
            f"Unsupported optimizer type {opt_type!r}; expected one of "
            "'Adam', 'AdamW', 'AdamWithScaler'."
        )

    # Both Adam variants share the same optional hyper-parameters and
    # defaults, so configs may omit any of them.
    return optimizer_cls(
        parameters,
        lr=cfg.lr,
        betas=tuple(getattr(cfg, "betas", (0.9, 0.999))),
        eps=getattr(cfg, "eps", 1e-8),
        weight_decay=getattr(cfg, "weight_decay", 1e-4),
        amsgrad=getattr(cfg, "amsgrad", False),
    )
|