autogluon.multimodal 1.2.1b20250303__py3-none-any.whl → 1.2.1b20250305__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- autogluon/multimodal/__init__.py +4 -2
- autogluon/multimodal/configs/data/default.yaml +4 -2
- autogluon/multimodal/configs/{environment → env}/default.yaml +2 -3
- autogluon/multimodal/configs/model/default.yaml +58 -11
- autogluon/multimodal/configs/{optimization → optim}/default.yaml +21 -4
- autogluon/multimodal/constants.py +16 -5
- autogluon/multimodal/data/__init__.py +14 -2
- autogluon/multimodal/data/dataset.py +2 -2
- autogluon/multimodal/data/infer_types.py +16 -2
- autogluon/multimodal/data/label_encoder.py +3 -3
- autogluon/multimodal/{utils → data}/nlpaug.py +4 -4
- autogluon/multimodal/data/preprocess_dataframe.py +55 -38
- autogluon/multimodal/data/process_categorical.py +35 -6
- autogluon/multimodal/data/process_document.py +59 -33
- autogluon/multimodal/data/process_image.py +198 -163
- autogluon/multimodal/data/process_label.py +7 -3
- autogluon/multimodal/data/process_mmlab/process_mmdet.py +1 -8
- autogluon/multimodal/data/process_mmlab/process_mmlab_base.py +2 -9
- autogluon/multimodal/data/process_mmlab/process_mmocr.py +1 -9
- autogluon/multimodal/data/process_ner.py +192 -4
- autogluon/multimodal/data/process_numerical.py +32 -5
- autogluon/multimodal/data/process_semantic_seg_img.py +23 -28
- autogluon/multimodal/data/process_text.py +95 -58
- autogluon/multimodal/data/template_engine.py +7 -9
- autogluon/multimodal/data/templates.py +0 -2
- autogluon/multimodal/data/trivial_augmenter.py +2 -2
- autogluon/multimodal/data/utils.py +564 -338
- autogluon/multimodal/learners/__init__.py +2 -1
- autogluon/multimodal/learners/base.py +189 -189
- autogluon/multimodal/learners/ensemble.py +748 -0
- autogluon/multimodal/learners/few_shot_svm.py +6 -15
- autogluon/multimodal/learners/matching.py +59 -84
- autogluon/multimodal/learners/ner.py +23 -22
- autogluon/multimodal/learners/object_detection.py +26 -21
- autogluon/multimodal/learners/semantic_segmentation.py +16 -18
- autogluon/multimodal/models/__init__.py +12 -3
- autogluon/multimodal/models/augmenter.py +175 -0
- autogluon/multimodal/models/categorical_mlp.py +13 -8
- autogluon/multimodal/models/clip.py +92 -18
- autogluon/multimodal/models/custom_transformer.py +75 -75
- autogluon/multimodal/models/document_transformer.py +23 -9
- autogluon/multimodal/models/ft_transformer.py +40 -35
- autogluon/multimodal/models/fusion/base.py +2 -4
- autogluon/multimodal/models/fusion/fusion_mlp.py +82 -18
- autogluon/multimodal/models/fusion/fusion_ner.py +1 -1
- autogluon/multimodal/models/fusion/fusion_transformer.py +23 -23
- autogluon/multimodal/models/{huggingface_text.py → hf_text.py} +21 -2
- autogluon/multimodal/models/meta_transformer.py +336 -0
- autogluon/multimodal/models/mlp.py +6 -6
- autogluon/multimodal/models/mmocr_text_detection.py +1 -1
- autogluon/multimodal/models/mmocr_text_recognition.py +0 -1
- autogluon/multimodal/models/ner_text.py +1 -8
- autogluon/multimodal/models/numerical_mlp.py +14 -8
- autogluon/multimodal/models/sam.py +12 -2
- autogluon/multimodal/models/t_few.py +21 -5
- autogluon/multimodal/models/timm_image.py +74 -32
- autogluon/multimodal/models/utils.py +877 -16
- autogluon/multimodal/optim/__init__.py +17 -0
- autogluon/multimodal/{optimization → optim}/lit_distiller.py +2 -1
- autogluon/multimodal/{optimization → optim}/lit_matcher.py +4 -10
- autogluon/multimodal/{optimization → optim}/lit_mmdet.py +2 -10
- autogluon/multimodal/{optimization → optim}/lit_module.py +139 -14
- autogluon/multimodal/{optimization → optim}/lit_ner.py +3 -3
- autogluon/multimodal/{optimization → optim}/lit_semantic_seg.py +1 -1
- autogluon/multimodal/optim/losses/__init__.py +14 -0
- autogluon/multimodal/optim/losses/bce_loss.py +25 -0
- autogluon/multimodal/optim/losses/focal_loss.py +81 -0
- autogluon/multimodal/optim/losses/lemda_loss.py +39 -0
- autogluon/multimodal/optim/losses/rkd_loss.py +103 -0
- autogluon/multimodal/optim/losses/softmax_losses.py +177 -0
- autogluon/multimodal/optim/losses/structure_loss.py +26 -0
- autogluon/multimodal/optim/losses/utils.py +313 -0
- autogluon/multimodal/optim/lr/__init__.py +1 -0
- autogluon/multimodal/optim/lr/utils.py +332 -0
- autogluon/multimodal/optim/metrics/__init__.py +4 -0
- autogluon/multimodal/optim/metrics/coverage_metrics.py +42 -0
- autogluon/multimodal/optim/metrics/hit_rate_metrics.py +78 -0
- autogluon/multimodal/optim/metrics/ranking_metrics.py +231 -0
- autogluon/multimodal/optim/metrics/utils.py +359 -0
- autogluon/multimodal/optim/utils.py +284 -0
- autogluon/multimodal/predictor.py +51 -12
- autogluon/multimodal/utils/__init__.py +19 -45
- autogluon/multimodal/utils/cache.py +23 -2
- autogluon/multimodal/utils/checkpoint.py +58 -5
- autogluon/multimodal/utils/config.py +127 -55
- autogluon/multimodal/utils/device.py +120 -0
- autogluon/multimodal/utils/distillation.py +8 -8
- autogluon/multimodal/utils/download.py +1 -1
- autogluon/multimodal/utils/env.py +22 -0
- autogluon/multimodal/utils/export.py +3 -3
- autogluon/multimodal/utils/hpo.py +5 -5
- autogluon/multimodal/utils/inference.py +37 -4
- autogluon/multimodal/utils/install.py +91 -0
- autogluon/multimodal/utils/load.py +52 -47
- autogluon/multimodal/utils/log.py +6 -41
- autogluon/multimodal/utils/matcher.py +3 -2
- autogluon/multimodal/utils/onnx.py +0 -4
- autogluon/multimodal/utils/path.py +10 -0
- autogluon/multimodal/utils/precision.py +130 -0
- autogluon/multimodal/{presets.py → utils/presets.py} +259 -66
- autogluon/multimodal/{problem_types.py → utils/problem_types.py} +30 -1
- autogluon/multimodal/utils/save.py +47 -29
- autogluon/multimodal/utils/strategy.py +24 -0
- autogluon/multimodal/version.py +1 -1
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/METADATA +5 -5
- autogluon.multimodal-1.2.1b20250305.dist-info/RECORD +163 -0
- autogluon/multimodal/optimization/__init__.py +0 -16
- autogluon/multimodal/optimization/losses.py +0 -394
- autogluon/multimodal/optimization/utils.py +0 -1054
- autogluon/multimodal/utils/cloud_io.py +0 -80
- autogluon/multimodal/utils/data.py +0 -701
- autogluon/multimodal/utils/environment.py +0 -395
- autogluon/multimodal/utils/metric.py +0 -500
- autogluon/multimodal/utils/model.py +0 -558
- autogluon.multimodal-1.2.1b20250303.dist-info/RECORD +0 -145
- /autogluon/multimodal/{optimization → optim}/deepspeed.py +0 -0
- /autogluon/multimodal/{optimization/lr_scheduler.py → optim/lr/lr_schedulers.py} +0 -0
- /autogluon/multimodal/{optimization → optim/metrics}/semantic_seg_metrics.py +0 -0
- /autogluon/multimodal/{registry.py → utils/registry.py} +0 -0
- /autogluon.multimodal-1.2.1b20250303-py3.9-nspkg.pth → /autogluon.multimodal-1.2.1b20250305-py3.9-nspkg.pth +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/LICENSE +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/NOTICE +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/WHEEL +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/namespace_packages.txt +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/top_level.txt +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/zip-safe +0 -0
@@ -1,394 +0,0 @@
|
|
1
|
-
import logging
|
2
|
-
from typing import Optional
|
3
|
-
|
4
|
-
import torch
|
5
|
-
import torch.nn as nn
|
6
|
-
import torch.nn.functional as F
|
7
|
-
|
8
|
-
try:
|
9
|
-
import torch.distributed.nn
|
10
|
-
from torch import distributed as dist
|
11
|
-
|
12
|
-
has_distributed = True
|
13
|
-
except ImportError:
|
14
|
-
has_distributed = False
|
15
|
-
|
16
|
-
try:
|
17
|
-
import horovod.torch as hvd
|
18
|
-
except ImportError:
|
19
|
-
hvd = None
|
20
|
-
|
21
|
-
logger = logging.getLogger(__name__)
|
22
|
-
|
23
|
-
|
24
|
-
class RKDLoss(nn.Module):
    """
    Relational Knowledge Distillation (RKD) loss.

    Sums a weighted distance term (smooth-L1 between mean-normalized pairwise distance
    matrices) and a weighted angle term (smooth-L1 between the cosines of the angles
    formed by feature triplets) of the student and teacher features.

    Paper Refer to: Relational Knowledge Distillation, CVPR2019. https://arxiv.org/abs/1904.05068
    Code Refer to: https://github.com/HobbitLong/RepDistiller/blob/master/distiller_zoo/RKD.py
    and https://github.com/lenscloth/RKD/blob/master/metric/loss.py
    """

    def __init__(self, distance_loss_weight: Optional[float] = 25.0, angle_loss_weight: Optional[float] = 50.0):
        """
        Parameters
        ----------
        distance_loss_weight
            Weight of RKD distance loss. The distance term is skipped unless the weight is > 0.
        angle_loss_weight
            Weight of RKD angle loss. The angle term is skipped unless the weight is > 0.
        """
        super().__init__()
        self.distance_loss_weight = distance_loss_weight
        self.angle_loss_weight = angle_loss_weight

    def forward(self, feature_student: Optional[torch.Tensor], feature_teacher: Optional[torch.Tensor]):
        """
        Parameters
        ----------
        feature_student
            Output feature of student model, shape: (N, D)
        feature_teacher
            Output feature of teacher model, shape: (N, D)

        Returns
        -------
        The RKD Loss between teacher and student, as a scalar tensor. It is a zero tensor
        when both terms are disabled.
        """
        # Start from a tensor so the result is a tensor even when both terms are disabled.
        loss = feature_student.new_zeros(())

        # RKD distance loss
        if self.distance_loss_weight > 0:
            with torch.no_grad():  # the teacher side is a fixed target
                t_dist = self._mean_normalize(self.pdist(feature_teacher, squared=False))
            s_dist = self._mean_normalize(self.pdist(feature_student, squared=False))
            loss = loss + self.distance_loss_weight * F.smooth_l1_loss(s_dist, t_dist)

        # RKD angle loss
        if self.angle_loss_weight > 0:
            with torch.no_grad():
                t_angle = self._pairwise_angles(feature_teacher)
            s_angle = self._pairwise_angles(feature_student)
            loss = loss + self.angle_loss_weight * F.smooth_l1_loss(s_angle, t_angle)

        return loss

    @staticmethod
    def _mean_normalize(dist: torch.Tensor) -> torch.Tensor:
        """Divide a pairwise distance matrix by the mean of its strictly positive entries."""
        positive = dist[dist > 0]
        if positive.numel() == 0:
            # E.g. a batch with a single sample: the only entry is the zeroed diagonal,
            # so the mean would be NaN. Leave the matrix as it is instead.
            return dist
        return dist / positive.mean()

    @staticmethod
    def _pairwise_angles(features: torch.Tensor) -> torch.Tensor:
        """Return the flattened cosines between all pairs of normalized difference vectors."""
        diff = features.unsqueeze(0) - features.unsqueeze(1)  # (N, N, D)
        unit = F.normalize(diff, p=2, dim=2)
        return torch.bmm(unit, unit.transpose(1, 2)).view(-1)

    @staticmethod
    def pdist(embeddings: Optional[torch.Tensor], squared: Optional[bool] = False, eps: Optional[float] = 1e-12):
        """
        Compute pairwise Euclidean distances between embeddings in n-dimensional space.

        Parameters
        ----------
        embeddings
            The embeddings to compute pairwise distance between. Shape: (N,D)
        squared
            If the result is square of Euclidean distance.
        eps
            Min value of each (squared) entry before the optional square root.

        Returns
        -------
        Pairwise Euclidean distances with a zero diagonal. Shape: (N,N)
        """
        e_square = embeddings.pow(2).sum(dim=1)
        prod = embeddings @ embeddings.t()
        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, clamped so the sqrt never sees <= 0.
        res = (e_square.unsqueeze(1) + e_square.unsqueeze(0) - 2 * prod).clamp(min=eps)

        if not squared:
            res = res.sqrt()

        # Work on a copy so the in-place diagonal reset does not touch a tensor autograd saved.
        res = res.clone()
        diag = torch.arange(embeddings.shape[0], device=res.device)
        res[diag, diag] = 0

        return res
|
120
|
-
|
121
|
-
|
122
|
-
class SoftTargetCrossEntropy(nn.Module):
    """
    Cross entropy against soft (per-class weight) targets, taken from timm.
    https://github.com/rwightman/pytorch-image-models/blob/e4360e6125bb0bb4279785810c8eb33b40af3ebd/timm/loss/cross_entropy.py

    Meant to be used together with mixup. One-hot label vectors work as targets too.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        input
            Unnormalized scores; softmax is taken over the last dimension.
        target
            Per-class target weights, broadcastable against `input`.

        Returns
        -------
        The mean over all remaining dimensions of the per-sample cross entropy.
        """
        log_probs = F.log_softmax(input, dim=-1)
        per_sample = (-target * log_probs).sum(dim=-1)
        return per_sample.mean()
|
136
|
-
|
137
|
-
|
138
|
-
def gather_features(
    image_features, text_features, local_loss=False, gather_with_grad=False, rank=0, world_size=1, use_horovod=False
):
    """
    Gather features across GPUs.

    Parameters
    ----------
    image_features
        image features of the current process.
    text_features
        text features of the current process.
    local_loss
        If False, the current process's slot in the gathered result is replaced by the
        local (gradient-carrying) features whenever the gather itself ran without gradients.
    gather_with_grad
        Whether to gather all features with gradients enabled.
    rank
        Rank of the current process (it should be a number between 0 and world_size-1).
    world_size
        Number of processes participating in the job.
    use_horovod
        Whether to use horovod instead of torch.distributed for the gather.

    Returns
    -------
    Gathered image and text features from all processes, concatenated along dim 0.
    """
    assert has_distributed, "torch.distributed did not import correctly, please use a PyTorch version with support."

    if not use_horovod:
        # torch.distributed path: we gather tensors from all gpus.
        if gather_with_grad:
            all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
            all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
            return all_image_features, all_text_features

        image_chunks = [torch.zeros_like(image_features) for _ in range(world_size)]
        text_chunks = [torch.zeros_like(text_features) for _ in range(world_size)]
        dist.all_gather(image_chunks, image_features)
        dist.all_gather(text_chunks, text_features)
        if not local_loss:
            # ensure grads for local rank when all_* features don't have a gradient
            image_chunks[rank] = image_features
            text_chunks[rank] = text_features
        return torch.cat(image_chunks, dim=0), torch.cat(text_chunks, dim=0)

    # horovod path
    assert hvd is not None, "Please install horovod"
    if gather_with_grad:
        all_image_features = hvd.allgather(image_features)
        all_text_features = hvd.allgather(text_features)
        return all_image_features, all_text_features

    with torch.no_grad():
        all_image_features = hvd.allgather(image_features)
        all_text_features = hvd.allgather(text_features)
    if local_loss:
        return all_image_features, all_text_features

    # ensure grads for local rank when all_* features don't have a gradient
    image_chunks = list(all_image_features.chunk(world_size, dim=0))
    text_chunks = list(all_text_features.chunk(world_size, dim=0))
    image_chunks[rank] = image_features
    text_chunks[rank] = text_features
    return torch.cat(image_chunks, dim=0), torch.cat(text_chunks, dim=0)
|
201
|
-
|
202
|
-
|
203
|
-
class MultiNegativesSoftmaxLoss(nn.Module):
    """
    This loss expects as input a batch consisting of pairs (a_1, p_1), (a_2, p_2)…, (a_n, p_n) where
    we assume that (a_i, p_i) are a positive pair and (a_i, p_j) for i!=j a negative pair.
    For each a_i, it uses all other p_j as negative samples, i.e., for a_i,
    we have 1 positive example (p_i) and n-1 negative examples (p_j).
    It then minimizes the negative log-likelihood for softmax normalized scores.
    It can also support gather negatives across processes.
    """

    def __init__(
        self,
        local_loss=False,
        gather_with_grad=False,
        cache_labels=False,
        use_horovod=False,
    ):
        """
        Parameters
        ----------
        local_loss
            Whether to compute the loss only for the current process's samples.
        gather_with_grad
            Whether to gather all features with gradients enabled.
        cache_labels
            Whether to cache labels for loss in next iterations.
        use_horovod
            Whether to use horovod.
        """
        super().__init__()
        self.local_loss = local_loss
        self.gather_with_grad = gather_with_grad
        self.cache_labels = cache_labels
        self.use_horovod = use_horovod

        # cache state
        self.prev_num_logits = 0
        self.labels = {}

    def forward(self, features_a, features_b, logit_scale, rank=0, world_size=1):
        """
        Parameters
        ----------
        features_a
            Features of the first element of each pair on the current process.
        features_b
            Features of the second element of each pair on the current process.
        logit_scale
            Multiplier applied to the similarity scores before the softmax.
        rank
            Rank of the current process.
        world_size
            Number of processes; features are gathered across processes when it is > 1.

        Returns
        -------
        The average of the a->b and b->a cross-entropy losses.
        """
        device = features_a.device
        if world_size > 1:
            all_features_a, all_features_b = gather_features(
                features_a, features_b, self.local_loss, self.gather_with_grad, rank, world_size, self.use_horovod
            )

            if self.local_loss:
                logits_per_a = logit_scale * features_a @ all_features_b.T
                logits_per_b = logit_scale * features_b @ all_features_a.T
            else:
                logits_per_a = logit_scale * all_features_a @ all_features_b.T
                logits_per_b = logits_per_a.T
        else:
            logits_per_a = logit_scale * features_a @ features_b.T
            logits_per_b = logit_scale * features_b @ features_a.T

        # Compute the ground truth, reusing the cached labels only when they fit.
        # The cache is keyed by device while `prev_num_logits` is shared, so the length
        # of the cached entry itself is checked to avoid reusing a stale entry.
        num_logits = logits_per_a.shape[0]
        labels = self.labels.get(device)
        if labels is None or labels.shape[0] != num_logits:
            labels = torch.arange(num_logits, device=device, dtype=torch.long)
            if world_size > 1 and self.local_loss:
                # Local rows are scored against all gathered columns; shift to this rank's slice.
                labels = labels + num_logits * rank
            if self.cache_labels:
                self.labels[device] = labels
                self.prev_num_logits = num_logits

        total_loss = (F.cross_entropy(logits_per_a, labels) + F.cross_entropy(logits_per_b, labels)) / 2
        return total_loss
|
273
|
-
|
274
|
-
|
275
|
-
class FocalLoss(nn.Module):
    """
    Focal loss based on https://github.com/AdeelH/pytorch-multi-class-focal-loss/blob/master/focal_loss.py

    References:
        [1] https://arxiv.org/abs/1708.02002
    """

    def __init__(
        self,
        alpha: Optional[torch.Tensor] = None,
        gamma: Optional[float] = 2.0,
        reduction: Optional[str] = "mean",
        eps: Optional[float] = 1e-6,
    ):
        """

        Parameters
        ----------
        alpha
            weighting factor for each class. Should be of shape (num_classes).
            A string such as "(0.2,0.8)" is also accepted and parsed into floats.
        gamma
            the focal parameter for calculating weights on easy/hard samples
        reduction
            the reduction to apply to the final loss output. Default: "mean". Options:
                "mean", "sum". Any other value leaves the per-sample losses unreduced.
        eps
            epsilon for numerical stability (stored; not used by `forward`).

        Raises
        ------
        ValueError
            If `alpha` is a string that cannot be parsed into a list of floats.
        """
        super().__init__()

        self.gamma = gamma
        self.reduction = reduction
        self.eps = eps
        if alpha is not None:
            if isinstance(alpha, str):  # handles Ray Tune HPO sampled hyperparameter
                try:
                    alpha = [float(num) for num in alpha.strip("()").split(",")]
                except ValueError as err:
                    raise ValueError(f"{type(alpha)} {alpha} is not in a supported format.") from err
            # Copy an existing tensor instead of re-wrapping it with torch.tensor().
            alpha = alpha.detach().clone() if torch.is_tensor(alpha) else torch.tensor(alpha)
        self.nll_loss = nn.NLLLoss(weight=alpha, reduction="none")

    def forward(self, input: torch.Tensor, target: torch.Tensor):
        """
        Parameters
        ----------
        input
            Class logits of shape (N, C) or (N, C, d1, ..., dK).
        target
            Class indices of shape (N,) or (N, d1, ..., dK).

        Returns
        -------
        The focal loss, reduced according to `self.reduction`.

        Raises
        ------
        TypeError
            If `input` is not a tensor.
        """
        if not torch.is_tensor(input):
            raise TypeError("input type is not a torch.Tensor. Got {}".format(type(input)))
        if input.ndim > 2:
            # (N, C, d1, d2, ..., dK) --> (N * d1 * ... * dK, C)
            num_class = input.shape[1]
            input = input.permute(0, *range(2, input.ndim), 1).reshape(-1, num_class)
            # (N, d1, d2, ..., dK) --> (N * d1 * ... * dK,); reshape also handles non-contiguous targets
            target = target.reshape(-1)

        # -alpha_t * log(pt) term
        log_p = torch.log_softmax(input, dim=-1)
        ce = self.nll_loss(log_p, target)

        # (1 - pt)^gamma term; the row index is built on the input's device
        all_rows = torch.arange(input.shape[0], device=input.device)
        pt = F.softmax(input, dim=-1)[all_rows, target]
        focal_term = (1 - pt) ** self.gamma

        loss = focal_term * ce

        if self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            loss = loss.sum()

        return loss
|
349
|
-
|
350
|
-
|
351
|
-
class StructureLoss(nn.Module):
    """
    Structure Loss based on https://github.com/DengPingFan/PraNet/blob/master/MyTrain.py
    The loss represents the weighted IoU loss and binary cross entropy (BCE) loss for the global restriction and local (pixel-level) restriction.

    References:
        [1] https://arxiv.org/abs/2006.11392
    """

    def forward(self, input: torch.Tensor, target: torch.Tensor):
        """
        Parameters
        ----------
        input
            Mask logits. A 3-D tensor gets a dimension of size 1 inserted at position 1.
        target
            Ground-truth mask with the same layout as `input`; a 3-D tensor is expanded the same way.

        Returns
        -------
        The mean over all samples of (weighted BCE + weighted IoU loss).
        """
        if input.dim() == 3:
            input = input.unsqueeze(1)
        if target.dim() == 3:
            # Keep the target aligned with the (possibly expanded) input.
            target = target.unsqueeze(1)

        # Pixels whose value differs from their 31x31 neighbourhood average get a larger weight.
        weit = 1 + 5 * torch.abs(F.avg_pool2d(target, kernel_size=31, stride=1, padding=15) - target)

        # `reduction="none"` keeps the per-pixel BCE so the weights below take effect.
        # The string "none" passed to the deprecated boolean `reduce` argument is truthy
        # and produced a mean-reduced scalar instead.
        wbce = F.binary_cross_entropy_with_logits(input, target, reduction="none")
        wbce = (weit * wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3))

        input = torch.sigmoid(input)
        inter = ((input * target) * weit).sum(dim=(2, 3))
        union = ((input + target) * weit).sum(dim=(2, 3))
        # The +1 terms smooth the ratio when the weighted union is close to zero.
        wiou = 1 - (inter + 1) / (union - inter + 1)
        return (wbce + wiou).mean()
|
372
|
-
|
373
|
-
|
374
|
-
class BBCEWithLogitLoss(nn.Module):
    """
    Class-balanced BCE-with-logits loss, following
    https://github.com/NiFangBaAGe/Explicit-Visual-Prompt/blob/latest_branch/models/segformer.py
    """

    def __init__(self):
        super().__init__()

    def forward(self, input: torch.Tensor, target: torch.Tensor):
        """
        Parameters
        ----------
        input
            Mask logits. A 3-D tensor gets a dimension of size 1 inserted at position 1.
        target
            Ground-truth mask; its sum is taken as the positive count.

        Returns
        -------
        The mean BCE-with-logits loss, with positives weighted by count_neg / count_pos,
        scaled by count_pos / (count_pos + count_neg).
        """
        if input.dim() == 3:
            input = input.unsqueeze(1)

        # The small constant keeps the negative/positive ratio finite when there are no positives.
        num_pos = torch.sum(target) + 1e-10
        num_neg = torch.sum(1.0 - target)
        neg_weight = num_pos / (num_pos + num_neg)
        pos_weight = num_neg / num_pos

        return neg_weight * F.binary_cross_entropy_with_logits(input, target, pos_weight=pos_weight)
|