autogluon.multimodal 1.2.1b20250303__py3-none-any.whl → 1.2.1b20250304__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- autogluon/multimodal/__init__.py +4 -2
- autogluon/multimodal/configs/data/default.yaml +4 -2
- autogluon/multimodal/configs/{environment → env}/default.yaml +2 -3
- autogluon/multimodal/configs/model/default.yaml +58 -11
- autogluon/multimodal/configs/{optimization → optim}/default.yaml +21 -4
- autogluon/multimodal/constants.py +16 -5
- autogluon/multimodal/data/__init__.py +14 -2
- autogluon/multimodal/data/dataset.py +2 -2
- autogluon/multimodal/data/infer_types.py +16 -2
- autogluon/multimodal/data/label_encoder.py +3 -3
- autogluon/multimodal/{utils → data}/nlpaug.py +4 -4
- autogluon/multimodal/data/preprocess_dataframe.py +55 -38
- autogluon/multimodal/data/process_categorical.py +35 -6
- autogluon/multimodal/data/process_document.py +59 -33
- autogluon/multimodal/data/process_image.py +198 -163
- autogluon/multimodal/data/process_label.py +7 -3
- autogluon/multimodal/data/process_mmlab/process_mmdet.py +1 -8
- autogluon/multimodal/data/process_mmlab/process_mmlab_base.py +2 -9
- autogluon/multimodal/data/process_mmlab/process_mmocr.py +1 -9
- autogluon/multimodal/data/process_ner.py +192 -4
- autogluon/multimodal/data/process_numerical.py +32 -5
- autogluon/multimodal/data/process_semantic_seg_img.py +23 -28
- autogluon/multimodal/data/process_text.py +95 -58
- autogluon/multimodal/data/template_engine.py +7 -9
- autogluon/multimodal/data/templates.py +0 -2
- autogluon/multimodal/data/trivial_augmenter.py +2 -2
- autogluon/multimodal/data/utils.py +564 -338
- autogluon/multimodal/learners/__init__.py +2 -1
- autogluon/multimodal/learners/base.py +189 -189
- autogluon/multimodal/learners/ensemble.py +748 -0
- autogluon/multimodal/learners/few_shot_svm.py +6 -15
- autogluon/multimodal/learners/matching.py +59 -84
- autogluon/multimodal/learners/ner.py +23 -22
- autogluon/multimodal/learners/object_detection.py +26 -21
- autogluon/multimodal/learners/semantic_segmentation.py +16 -18
- autogluon/multimodal/models/__init__.py +12 -3
- autogluon/multimodal/models/augmenter.py +175 -0
- autogluon/multimodal/models/categorical_mlp.py +13 -8
- autogluon/multimodal/models/clip.py +92 -18
- autogluon/multimodal/models/custom_transformer.py +75 -75
- autogluon/multimodal/models/document_transformer.py +23 -9
- autogluon/multimodal/models/ft_transformer.py +40 -35
- autogluon/multimodal/models/fusion/base.py +2 -4
- autogluon/multimodal/models/fusion/fusion_mlp.py +82 -18
- autogluon/multimodal/models/fusion/fusion_ner.py +1 -1
- autogluon/multimodal/models/fusion/fusion_transformer.py +23 -23
- autogluon/multimodal/models/{huggingface_text.py → hf_text.py} +21 -2
- autogluon/multimodal/models/meta_transformer.py +336 -0
- autogluon/multimodal/models/mlp.py +6 -6
- autogluon/multimodal/models/mmocr_text_detection.py +1 -1
- autogluon/multimodal/models/mmocr_text_recognition.py +0 -1
- autogluon/multimodal/models/ner_text.py +1 -8
- autogluon/multimodal/models/numerical_mlp.py +14 -8
- autogluon/multimodal/models/sam.py +12 -2
- autogluon/multimodal/models/t_few.py +21 -5
- autogluon/multimodal/models/timm_image.py +74 -32
- autogluon/multimodal/models/utils.py +877 -16
- autogluon/multimodal/optim/__init__.py +17 -0
- autogluon/multimodal/{optimization → optim}/lit_distiller.py +2 -1
- autogluon/multimodal/{optimization → optim}/lit_matcher.py +4 -10
- autogluon/multimodal/{optimization → optim}/lit_mmdet.py +2 -10
- autogluon/multimodal/{optimization → optim}/lit_module.py +139 -14
- autogluon/multimodal/{optimization → optim}/lit_ner.py +3 -3
- autogluon/multimodal/{optimization → optim}/lit_semantic_seg.py +1 -1
- autogluon/multimodal/optim/losses/__init__.py +14 -0
- autogluon/multimodal/optim/losses/bce_loss.py +25 -0
- autogluon/multimodal/optim/losses/focal_loss.py +81 -0
- autogluon/multimodal/optim/losses/lemda_loss.py +39 -0
- autogluon/multimodal/optim/losses/rkd_loss.py +103 -0
- autogluon/multimodal/optim/losses/softmax_losses.py +177 -0
- autogluon/multimodal/optim/losses/structure_loss.py +26 -0
- autogluon/multimodal/optim/losses/utils.py +313 -0
- autogluon/multimodal/optim/lr/__init__.py +1 -0
- autogluon/multimodal/optim/lr/utils.py +332 -0
- autogluon/multimodal/optim/metrics/__init__.py +4 -0
- autogluon/multimodal/optim/metrics/coverage_metrics.py +42 -0
- autogluon/multimodal/optim/metrics/hit_rate_metrics.py +78 -0
- autogluon/multimodal/optim/metrics/ranking_metrics.py +231 -0
- autogluon/multimodal/optim/metrics/utils.py +359 -0
- autogluon/multimodal/optim/utils.py +284 -0
- autogluon/multimodal/predictor.py +51 -12
- autogluon/multimodal/utils/__init__.py +19 -45
- autogluon/multimodal/utils/cache.py +23 -2
- autogluon/multimodal/utils/checkpoint.py +58 -5
- autogluon/multimodal/utils/config.py +127 -55
- autogluon/multimodal/utils/device.py +120 -0
- autogluon/multimodal/utils/distillation.py +8 -8
- autogluon/multimodal/utils/download.py +1 -1
- autogluon/multimodal/utils/env.py +22 -0
- autogluon/multimodal/utils/export.py +3 -3
- autogluon/multimodal/utils/hpo.py +5 -5
- autogluon/multimodal/utils/inference.py +37 -4
- autogluon/multimodal/utils/install.py +91 -0
- autogluon/multimodal/utils/load.py +52 -47
- autogluon/multimodal/utils/log.py +6 -41
- autogluon/multimodal/utils/matcher.py +3 -2
- autogluon/multimodal/utils/onnx.py +0 -4
- autogluon/multimodal/utils/path.py +10 -0
- autogluon/multimodal/utils/precision.py +130 -0
- autogluon/multimodal/{presets.py → utils/presets.py} +259 -66
- autogluon/multimodal/{problem_types.py → utils/problem_types.py} +30 -1
- autogluon/multimodal/utils/save.py +47 -29
- autogluon/multimodal/utils/strategy.py +24 -0
- autogluon/multimodal/version.py +1 -1
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/METADATA +5 -5
- autogluon.multimodal-1.2.1b20250304.dist-info/RECORD +163 -0
- autogluon/multimodal/optimization/__init__.py +0 -16
- autogluon/multimodal/optimization/losses.py +0 -394
- autogluon/multimodal/optimization/utils.py +0 -1054
- autogluon/multimodal/utils/cloud_io.py +0 -80
- autogluon/multimodal/utils/data.py +0 -701
- autogluon/multimodal/utils/environment.py +0 -395
- autogluon/multimodal/utils/metric.py +0 -500
- autogluon/multimodal/utils/model.py +0 -558
- autogluon.multimodal-1.2.1b20250303.dist-info/RECORD +0 -145
- /autogluon/multimodal/{optimization → optim}/deepspeed.py +0 -0
- /autogluon/multimodal/{optimization/lr_scheduler.py → optim/lr/lr_schedulers.py} +0 -0
- /autogluon/multimodal/{optimization → optim/metrics}/semantic_seg_metrics.py +0 -0
- /autogluon/multimodal/{registry.py → utils/registry.py} +0 -0
- /autogluon.multimodal-1.2.1b20250303-py3.9-nspkg.pth → /autogluon.multimodal-1.2.1b20250304-py3.9-nspkg.pth +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/LICENSE +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/NOTICE +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/WHEEL +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/namespace_packages.txt +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/top_level.txt +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/zip-safe +0 -0
@@ -0,0 +1,177 @@
|
|
1
|
+
import torch
|
2
|
+
import torch.nn as nn
|
3
|
+
import torch.nn.functional as F
|
4
|
+
|
5
|
+
try:
|
6
|
+
import torch.distributed.nn
|
7
|
+
from torch import distributed as dist
|
8
|
+
|
9
|
+
has_distributed = True
|
10
|
+
except ImportError:
|
11
|
+
has_distributed = False
|
12
|
+
|
13
|
+
try:
|
14
|
+
import horovod.torch as hvd
|
15
|
+
except ImportError:
|
16
|
+
hvd = None
|
17
|
+
|
18
|
+
|
19
|
+
class SoftTargetCrossEntropy(nn.Module):
    """
    Cross-entropy against soft targets, adapted from timm:
    https://github.com/rwightman/pytorch-image-models/blob/e4360e6125bb0bb4279785810c8eb33b40af3ebd/timm/loss/cross_entropy.py

    Instead of class indices, ``target`` carries a weight per class along the last
    dimension (e.g. one-hot labels, or the blended labels produced by mixup).
    The per-sample loss is ``-sum(target * log_softmax(input))`` over the last
    dimension; the returned value is the mean of those per-sample losses.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        log_probs = F.log_softmax(input, dim=-1)
        per_sample = -(target * log_probs).sum(dim=-1)
        return per_sample.mean()
|
33
|
+
|
34
|
+
|
35
|
+
class MultiNegativesSoftmaxLoss(nn.Module):
    """
    Symmetric in-batch-negatives softmax loss.

    The batch is a list of pairs ``(a_1, p_1), (a_2, p_2), ..., (a_n, p_n)`` where each
    ``(a_i, p_i)`` is a positive pair and every ``(a_i, p_j)`` with ``i != j`` is a negative
    pair, so each ``a_i`` has one positive (``p_i``) and ``n - 1`` negatives (``p_j``).
    The loss is the negative log-likelihood of the softmax-normalized similarity scores,
    averaged over the a->b and b->a directions. Negatives can optionally be gathered
    across processes.
    """

    def __init__(
        self,
        local_loss=False,
        gather_with_grad=False,
        cache_labels=False,
        use_horovod=False,
    ):
        """
        Parameters
        ----------
        local_loss
            If True (and world_size > 1), only this process's rows are scored against the
            gathered batch; otherwise the full gathered similarity matrix is used.
        gather_with_grad
            Whether to gather all features with gradients enabled.
        cache_labels
            Whether to keep the generated label tensor (per device) and reuse it in later
            calls while the number of logits stays the same.
        use_horovod
            Whether to gather with horovod instead of torch.distributed.
        """
        super().__init__()
        self.local_loss = local_loss
        self.gather_with_grad = gather_with_grad
        self.cache_labels = cache_labels
        self.use_horovod = use_horovod

        # Label cache state; only written to when ``cache_labels`` is True.
        self.labels = {}
        self.prev_num_logits = 0

    def forward(self, features_a, features_b, logit_scale, rank=0, world_size=1):
        """
        Parameters
        ----------
        features_a
            Embeddings for the first side of each pair on this process
            (presumably shape (n, d) — TODO confirm with callers).
        features_b
            Embeddings for the second side, laid out like ``features_a``.
        logit_scale
            Multiplier applied to the dot-product similarities.
        rank
            Rank of the current process.
        world_size
            Number of processes; values > 1 trigger cross-process gathering.

        Returns
        -------
        Scalar loss: the mean of the a->b and b->a cross-entropies.
        """
        device = features_a.device
        distributed = world_size > 1

        if not distributed:
            logits_per_a = logit_scale * features_a @ features_b.T
            logits_per_b = logit_scale * features_b @ features_a.T
        else:
            gathered_a, gathered_b = self.gather_features(
                features_a, features_b, self.local_loss, self.gather_with_grad, rank, world_size, self.use_horovod
            )
            if self.local_loss:
                # Local rows scored against every gathered candidate.
                logits_per_a = logit_scale * features_a @ gathered_b.T
                logits_per_b = logit_scale * features_b @ gathered_a.T
            else:
                # Full global similarity matrix; the b->a view is just its transpose.
                logits_per_a = logit_scale * gathered_a @ gathered_b.T
                logits_per_b = logits_per_a.T

        # Positives lie on the diagonal; build (or reuse) the matching target indices.
        num_logits = logits_per_a.shape[0]
        cache_hit = self.prev_num_logits == num_logits and device in self.labels
        if cache_hit:
            labels = self.labels[device]
        else:
            labels = torch.arange(num_logits, device=device, dtype=torch.long)
            if distributed and self.local_loss:
                # This rank's positives sit in its own slice of the gathered batch.
                labels = labels + num_logits * rank
            if self.cache_labels:
                self.labels[device] = labels
                self.prev_num_logits = num_logits

        loss_a = F.cross_entropy(logits_per_a, labels)
        loss_b = F.cross_entropy(logits_per_b, labels)
        return (loss_a + loss_b) / 2

    @staticmethod
    def gather_features(
        image_features,
        text_features,
        local_loss=False,
        gather_with_grad=False,
        rank=0,
        world_size=1,
        use_horovod=False,
    ):
        """
        Gather features from every process and concatenate them along dim 0.

        Parameters
        ----------
        image_features
            Image features of the current process.
        text_features
            Text features of the current process.
        local_loss
            If False and the gather does not track gradients, this process's own
            (gradient-carrying) features are written back into its slot of the result.
        gather_with_grad
            Whether to gather all features with gradients enabled.
        rank
            Rank of the current process (between 0 and world_size-1).
        world_size
            Number of processes participating in the job.
        use_horovod
            Whether to use horovod instead of torch.distributed.

        Returns
        -------
        Gathered image and text features from all processes.
        """
        assert (
            has_distributed
        ), "torch.distributed did not import correctly, please use a PyTorch version with support."

        if use_horovod:
            assert hvd is not None, "Please install horovod"
            if gather_with_grad:
                all_image_features = hvd.allgather(image_features)
                all_text_features = hvd.allgather(text_features)
                return all_image_features, all_text_features

            with torch.no_grad():
                all_image_features = hvd.allgather(image_features)
                all_text_features = hvd.allgather(text_features)
            if not local_loss:
                # Restore gradient flow for the local slice (assumes equal per-rank batches).
                image_chunks = list(all_image_features.chunk(world_size, dim=0))
                text_chunks = list(all_text_features.chunk(world_size, dim=0))
                image_chunks[rank] = image_features
                text_chunks[rank] = text_features
                all_image_features = torch.cat(image_chunks, dim=0)
                all_text_features = torch.cat(text_chunks, dim=0)
            return all_image_features, all_text_features

        if gather_with_grad:
            # torch.distributed.nn's all_gather propagates gradients to every rank.
            all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
            all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
            return all_image_features, all_text_features

        image_slots = [torch.zeros_like(image_features) for _ in range(world_size)]
        text_slots = [torch.zeros_like(text_features) for _ in range(world_size)]
        dist.all_gather(image_slots, image_features)
        dist.all_gather(text_slots, text_features)
        if not local_loss:
            # Restore gradient flow for the local slice; dist.all_gather does not track it.
            image_slots[rank] = image_features
            text_slots[rank] = text_features
        return torch.cat(image_slots, dim=0), torch.cat(text_slots, dim=0)
|
@@ -0,0 +1,26 @@
|
|
1
|
+
import torch
|
2
|
+
import torch.nn as nn
|
3
|
+
import torch.nn.functional as F
|
4
|
+
|
5
|
+
|
6
|
+
class StructureLoss(nn.Module):
    """
    Structure loss adapted from PraNet:
    https://github.com/DengPingFan/PraNet/blob/master/MyTrain.py

    The loss is the sum of a weighted binary cross-entropy (BCE) term, a local
    (pixel-level) constraint, and a weighted IoU term, a global constraint. Each pixel
    is weighted by ``1 + 5 * |avg_pool31(target) - target|``, so pixels whose target
    differs from their 31x31 neighbourhood average get larger weight.

    References:
        [1] https://arxiv.org/abs/2006.11392
    """

    def forward(self, input: torch.Tensor, target: torch.Tensor):
        """
        Parameters
        ----------
        input
            Predicted logits, 4-D (N, C, H, W), or 3-D (N, H, W) in which case a channel
            axis is inserted at dim 1.
        target
            Float mask with the same layout as ``input`` (values presumably in [0, 1]);
            a 3-D target also gets a channel axis at dim 1.

        Returns
        -------
        Scalar tensor: the mean of weighted BCE + weighted IoU over the (N, C) entries.
        """
        if input.dim() == 3:
            input = input.unsqueeze(1)
        if target.dim() == 3:
            # Mirror the input handling so both tensors share the (N, C, H, W) layout.
            target = target.unsqueeze(1)

        weit = 1 + 5 * torch.abs(F.avg_pool2d(target, kernel_size=31, stride=1, padding=15) - target)

        # `reduction="none"` keeps the per-pixel BCE map so it can be weighted. The upstream
        # code passed `reduce="none"`, a legacy boolean flag: the truthy string made PyTorch
        # reduce to a scalar mean, silently discarding the weighting.
        wbce = F.binary_cross_entropy_with_logits(input, target, reduction="none")
        wbce = (weit * wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3))

        pred = torch.sigmoid(input)
        inter = ((pred * target) * weit).sum(dim=(2, 3))
        union = ((pred + target) * weit).sum(dim=(2, 3))
        # The +1 smoothing keeps the ratio finite when both masks are empty.
        wiou = 1 - (inter + 1) / (union - inter + 1)
        return (wbce + wiou).mean()
|
@@ -0,0 +1,313 @@
|
|
1
|
+
import logging
|
2
|
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
3
|
+
|
4
|
+
import torch
|
5
|
+
from omegaconf import DictConfig, OmegaConf
|
6
|
+
from pytorch_metric_learning import distances, losses, miners
|
7
|
+
from torch import nn
|
8
|
+
from transformers.models.mask2former.modeling_mask2former import Mask2FormerConfig, Mask2FormerLoss
|
9
|
+
|
10
|
+
from ...constants import (
|
11
|
+
BINARY,
|
12
|
+
CONTRASTIVE_LOSS,
|
13
|
+
COSINE_SIMILARITY,
|
14
|
+
FEW_SHOT_CLASSIFICATION,
|
15
|
+
MULTI_NEGATIVES_SOFTMAX_LOSS,
|
16
|
+
MULTICLASS,
|
17
|
+
NER,
|
18
|
+
OBJECT_DETECTION,
|
19
|
+
PAIR_MARGIN_MINER,
|
20
|
+
REGRESSION,
|
21
|
+
SEMANTIC_SEGMENTATION,
|
22
|
+
)
|
23
|
+
from .bce_loss import BBCEWithLogitLoss
|
24
|
+
from .focal_loss import FocalLoss
|
25
|
+
from .lemda_loss import LemdaLoss
|
26
|
+
from .softmax_losses import MultiNegativesSoftmaxLoss, SoftTargetCrossEntropy
|
27
|
+
from .structure_loss import StructureLoss
|
28
|
+
|
29
|
+
logger = logging.getLogger(__name__)
|
30
|
+
|
31
|
+
|
32
|
+
def get_loss_func(
    problem_type: str,
    mixup_active: Optional[bool] = None,
    loss_func_name: Optional[str] = None,
    config: Optional[DictConfig] = None,
    **kwargs,
):
    """
    Choose a suitable Pytorch loss module based on the provided problem type.

    Parameters
    ----------
    problem_type
        Type of problem.
    mixup_active
        Whether mixup is active; if truthy, classification uses a soft-target cross-entropy.
    loss_func_name
        Optional, case-insensitive name of the loss the user wants to use. When None, each
        problem type falls back to its default loss.
    config
        The optimization config (e.g. ``optim``). It supplies ``label_smoothing`` for
        cross-entropy, ``focal_loss.{alpha,gamma,reduction}`` for focal loss and
        ``mask2former_loss.*_weight`` for the Mask2Former loss. Only ``label_smoothing``
        has a fallback: when ``config`` is None it is 0.0.
    **kwargs
        Extra arguments; ``num_classes`` is required by the Mask2Former loss.

    Returns
    -------
    A Pytorch loss module, or None for problem types that do not build one here
    (None, object detection, few-shot classification).

    Raises
    ------
    NotImplementedError
        If ``problem_type`` is not supported.
    """
    # Normalize once. An empty string matches none of the named losses below, so a missing
    # name falls through to each problem type's default instead of crashing on `.lower()`.
    name = loss_func_name.lower() if loss_func_name is not None else ""

    if problem_type in [BINARY, MULTICLASS]:
        if mixup_active:
            loss_func = SoftTargetCrossEntropy()
        elif name == "focal_loss":
            loss_func = FocalLoss(
                alpha=config.focal_loss.alpha,
                gamma=config.focal_loss.gamma,
                reduction=config.focal_loss.reduction,
            )
        else:
            # 0.0 is nn.CrossEntropyLoss's own default, used when no config is given.
            label_smoothing = config.label_smoothing if config is not None else 0.0
            loss_func = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
            logger.debug("loss_func.label_smoothing: %s", loss_func.label_smoothing)
    elif problem_type == REGRESSION:
        loss_func = nn.BCEWithLogitsLoss() if "bcewithlogitsloss" in name else nn.MSELoss()
    elif problem_type == NER:
        # Targets equal to 0 do not contribute to the loss.
        loss_func = nn.CrossEntropyLoss(ignore_index=0)
    elif problem_type in [None, OBJECT_DETECTION, FEW_SHOT_CLASSIFICATION]:
        return None
    elif problem_type == SEMANTIC_SEGMENTATION:
        if "structure_loss" in name:
            loss_func = StructureLoss()
        elif "balanced_bce" in name:
            loss_func = BBCEWithLogitLoss()
        elif "mask2former_loss" in name:
            weight_dict = {
                "loss_cross_entropy": config.mask2former_loss.loss_cross_entropy_weight,
                "loss_mask": config.mask2former_loss.loss_mask_weight,
                "loss_dice": config.mask2former_loss.loss_dice_weight,
            }
            loss_func = Mask2FormerLoss(
                config=Mask2FormerConfig(num_labels=kwargs["num_classes"]), weight_dict=weight_dict
            )
        else:
            loss_func = nn.BCEWithLogitsLoss()
    else:
        raise NotImplementedError(f"Unsupported problem type: {problem_type}")

    return loss_func
|
103
|
+
|
104
|
+
|
105
|
+
def get_metric_learning_distance_func(
    name: str,
):
    """
    Look up a pytorch-metric-learning distance object by its (case-insensitive) name.

    Parameters
    ----------
    name
        Distance function name; only the cosine-similarity name is recognized.

    Returns
    -------
    A distance function from the pytorch metric learning package.

    Raises
    ------
    ValueError
        If ``name`` is not a supported distance.
    """
    if name.lower() != COSINE_SIMILARITY:
        raise ValueError(f"Unknown distance measure: {name}")
    return distances.CosineSimilarity()
|
124
|
+
|
125
|
+
|
126
|
+
def infer_matcher_loss(data_format: str, problem_type: str):
    """
    Work out which loss types may be used to train the matcher.

    Parameters
    ----------
    data_format
        The training data format: "pair" or "triplet".
    problem_type
        Type of problem, or None when there are no supervision labels.

    Returns
    -------
    A list of allowable loss names; the first one is the default.

    Raises
    ------
    ValueError
        For an unknown data format, or a format/problem-type combination with no loss.
    """
    if data_format not in ("pair", "triplet"):
        raise ValueError(f"Unsupported data format: {data_format}")

    if problem_type is None:
        # Both formats fall back to in-batch negatives when there is no problem type.
        return [MULTI_NEGATIVES_SOFTMAX_LOSS]

    if data_format == "pair":
        if problem_type == BINARY:
            return [CONTRASTIVE_LOSS]
        if problem_type == REGRESSION:
            return ["cosine_similarity_loss"]

    raise ValueError(f"Unsupported data format {data_format} with problem type {problem_type}")
|
157
|
+
|
158
|
+
|
159
|
+
def get_matcher_loss_func(
    data_format: str,
    problem_type: str,
    loss_type: Optional[str] = None,
    pos_margin: Optional[float] = None,
    neg_margin: Optional[float] = None,
    distance_type: Optional[str] = None,
):
    """
    Build the metric-learning loss used to train a matcher.

    Parameters
    ----------
    data_format
        The training data format, e.g., pair or triplet.
    problem_type
        Type of problem.
    loss_type
        The requested loss type. If None, the first allowable loss for
        ``data_format``/``problem_type`` is used.
    pos_margin
        The positive margin in computing the metric learning loss (contrastive loss only).
    neg_margin
        The negative margin in computing the metric learning loss (contrastive loss only).
    distance_type
        The distance function type (contrastive loss only).

    Returns
    -------
    A loss function of metric learning.

    Raises
    ------
    ValueError
        If the chosen loss type cannot be constructed here.
    """
    allowable_loss_types = infer_matcher_loss(data_format=data_format, problem_type=problem_type)
    if loss_type is None:
        loss_type = allowable_loss_types[0]
    else:
        assert loss_type in allowable_loss_types, f"data format {data_format} can't use loss {loss_type}."

    # NOTE(review): infer_matcher_loss can return "cosine_similarity_loss" (pair + regression),
    # which has no branch below and therefore ends in the ValueError.
    normalized = loss_type.lower()
    if normalized == CONTRASTIVE_LOSS:
        return losses.ContrastiveLoss(
            pos_margin=pos_margin,
            neg_margin=neg_margin,
            distance=get_metric_learning_distance_func(distance_type),
        )
    if normalized == MULTI_NEGATIVES_SOFTMAX_LOSS:
        return MultiNegativesSoftmaxLoss(
            local_loss=True,
            gather_with_grad=True,
            cache_labels=False,
        )
    raise ValueError(f"Unknown metric learning loss: {loss_type}")
|
210
|
+
|
211
|
+
|
212
|
+
def get_matcher_miner_func(
    miner_type: str,
    pos_margin: float,
    neg_margin: float,
    distance_type: str,
):
    """
    Build a pytorch-metric-learning miner, used to mine positive and negative examples.

    Parameters
    ----------
    miner_type
        The (case-insensitive) miner function type; only the pair-margin miner is supported.
    pos_margin
        The positive margin used by the miner function.
    neg_margin
        The negative margin used by the miner function.
    distance_type
        The distance function type.

    Returns
    -------
    A miner function to mine positive and negative samples.

    Raises
    ------
    ValueError
        If ``miner_type`` is not supported.
    """
    if miner_type.lower() != PAIR_MARGIN_MINER:
        raise ValueError(f"Unknown metric learning miner: {miner_type}")
    return miners.PairMarginMiner(
        pos_margin=pos_margin,
        neg_margin=neg_margin,
        distance=get_metric_learning_distance_func(distance_type),
    )
|
245
|
+
|
246
|
+
|
247
|
+
def generate_metric_learning_labels(
    num_samples: int,
    match_label: Optional[int],
    labels: Optional[torch.Tensor],
):
    """
    Generate labels to compute the metric learning loss of one mini-batch.

    For n samples it generates 2*n labels, since each match has two sides and each side
    gets one label. The first side of sample i always gets label i.

    - If ``match_label`` is given, the second side of sample i gets label ``n + i`` (a
      fresh label, i.e. a non-match), except where ``labels[i] == match_label``, in which
      case it shares label i with the first side (a match).
    - If ``match_label`` is None, the second side of sample i also gets label i, so
      every pair is treated as a match.

    Parameters
    ----------
    num_samples
        Number of samples.
    match_label
        The matching label, which can be None.
    labels
        The sample labels used in the supervised setting. Required only when
        ``match_label`` is not None; when given, the output is placed on its device.

    Returns
    -------
    A 1-D tensor of 2*n labels used in computing the metric learning loss.

    Raises
    ------
    ValueError
        If ``match_label`` is given but ``labels`` is None.
    """
    if match_label is not None and labels is None:
        raise ValueError("`labels` must be provided when `match_label` is not None.")

    # `labels` is optional in the unsupervised case, so fall back to the default device.
    device = labels.device if labels is not None else None
    labels_1 = torch.arange(num_samples, device=device)

    if match_label is not None:
        labels_2 = torch.arange(num_samples, num_samples * 2, device=device)
        # users need to specify the match_label based on the raw label's semantic meaning.
        mask = labels == match_label
        labels_2[mask] = labels_1[mask]
    else:
        labels_2 = torch.arange(num_samples, device=device)

    return torch.cat([labels_1, labels_2], dim=0)
|
286
|
+
|
287
|
+
|
288
|
+
def get_aug_loss_func(config: Optional[DictConfig] = None, problem_type: Optional[str] = None):
    """
    Return the loss function for LeMDA augmentation, if it is enabled.

    Parameters
    ----------
    config
        The optimization configuration; ``config.lemda.turn_on`` toggles the loss and the
        ``config.lemda.*_weight`` / ``consist_threshold`` fields configure it. If None,
        no augmentation loss is built.
    problem_type
        Problem type (binary, multiclass, or regression).

    Returns
    -------
    The augmentation loss function, or None when there is no config or LeMDA is off.
    """
    # `config` defaults to None, so check it before dereferencing `config.lemda`.
    if config is None or not config.lemda.turn_on:
        return None

    return LemdaLoss(
        mse_weight=config.lemda.mse_weight,
        kld_weight=config.lemda.kld_weight,
        consist_weight=config.lemda.consist_weight,
        consist_threshold=config.lemda.consist_threshold,
        problem_type=problem_type,
    )
|
@@ -0,0 +1 @@
|
|
1
|
+
from .utils import apply_layerwise_lr_decay, apply_single_lr, apply_two_stages_lr, get_lr_scheduler
|