autogluon.multimodal 1.2.1b20250303__py3-none-any.whl → 1.2.1b20250304__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- autogluon/multimodal/__init__.py +4 -2
- autogluon/multimodal/configs/data/default.yaml +4 -2
- autogluon/multimodal/configs/{environment → env}/default.yaml +2 -3
- autogluon/multimodal/configs/model/default.yaml +58 -11
- autogluon/multimodal/configs/{optimization → optim}/default.yaml +21 -4
- autogluon/multimodal/constants.py +16 -5
- autogluon/multimodal/data/__init__.py +14 -2
- autogluon/multimodal/data/dataset.py +2 -2
- autogluon/multimodal/data/infer_types.py +16 -2
- autogluon/multimodal/data/label_encoder.py +3 -3
- autogluon/multimodal/{utils → data}/nlpaug.py +4 -4
- autogluon/multimodal/data/preprocess_dataframe.py +55 -38
- autogluon/multimodal/data/process_categorical.py +35 -6
- autogluon/multimodal/data/process_document.py +59 -33
- autogluon/multimodal/data/process_image.py +198 -163
- autogluon/multimodal/data/process_label.py +7 -3
- autogluon/multimodal/data/process_mmlab/process_mmdet.py +1 -8
- autogluon/multimodal/data/process_mmlab/process_mmlab_base.py +2 -9
- autogluon/multimodal/data/process_mmlab/process_mmocr.py +1 -9
- autogluon/multimodal/data/process_ner.py +192 -4
- autogluon/multimodal/data/process_numerical.py +32 -5
- autogluon/multimodal/data/process_semantic_seg_img.py +23 -28
- autogluon/multimodal/data/process_text.py +95 -58
- autogluon/multimodal/data/template_engine.py +7 -9
- autogluon/multimodal/data/templates.py +0 -2
- autogluon/multimodal/data/trivial_augmenter.py +2 -2
- autogluon/multimodal/data/utils.py +564 -338
- autogluon/multimodal/learners/__init__.py +2 -1
- autogluon/multimodal/learners/base.py +189 -189
- autogluon/multimodal/learners/ensemble.py +748 -0
- autogluon/multimodal/learners/few_shot_svm.py +6 -15
- autogluon/multimodal/learners/matching.py +59 -84
- autogluon/multimodal/learners/ner.py +23 -22
- autogluon/multimodal/learners/object_detection.py +26 -21
- autogluon/multimodal/learners/semantic_segmentation.py +16 -18
- autogluon/multimodal/models/__init__.py +12 -3
- autogluon/multimodal/models/augmenter.py +175 -0
- autogluon/multimodal/models/categorical_mlp.py +13 -8
- autogluon/multimodal/models/clip.py +92 -18
- autogluon/multimodal/models/custom_transformer.py +75 -75
- autogluon/multimodal/models/document_transformer.py +23 -9
- autogluon/multimodal/models/ft_transformer.py +40 -35
- autogluon/multimodal/models/fusion/base.py +2 -4
- autogluon/multimodal/models/fusion/fusion_mlp.py +82 -18
- autogluon/multimodal/models/fusion/fusion_ner.py +1 -1
- autogluon/multimodal/models/fusion/fusion_transformer.py +23 -23
- autogluon/multimodal/models/{huggingface_text.py → hf_text.py} +21 -2
- autogluon/multimodal/models/meta_transformer.py +336 -0
- autogluon/multimodal/models/mlp.py +6 -6
- autogluon/multimodal/models/mmocr_text_detection.py +1 -1
- autogluon/multimodal/models/mmocr_text_recognition.py +0 -1
- autogluon/multimodal/models/ner_text.py +1 -8
- autogluon/multimodal/models/numerical_mlp.py +14 -8
- autogluon/multimodal/models/sam.py +12 -2
- autogluon/multimodal/models/t_few.py +21 -5
- autogluon/multimodal/models/timm_image.py +74 -32
- autogluon/multimodal/models/utils.py +877 -16
- autogluon/multimodal/optim/__init__.py +17 -0
- autogluon/multimodal/{optimization → optim}/lit_distiller.py +2 -1
- autogluon/multimodal/{optimization → optim}/lit_matcher.py +4 -10
- autogluon/multimodal/{optimization → optim}/lit_mmdet.py +2 -10
- autogluon/multimodal/{optimization → optim}/lit_module.py +139 -14
- autogluon/multimodal/{optimization → optim}/lit_ner.py +3 -3
- autogluon/multimodal/{optimization → optim}/lit_semantic_seg.py +1 -1
- autogluon/multimodal/optim/losses/__init__.py +14 -0
- autogluon/multimodal/optim/losses/bce_loss.py +25 -0
- autogluon/multimodal/optim/losses/focal_loss.py +81 -0
- autogluon/multimodal/optim/losses/lemda_loss.py +39 -0
- autogluon/multimodal/optim/losses/rkd_loss.py +103 -0
- autogluon/multimodal/optim/losses/softmax_losses.py +177 -0
- autogluon/multimodal/optim/losses/structure_loss.py +26 -0
- autogluon/multimodal/optim/losses/utils.py +313 -0
- autogluon/multimodal/optim/lr/__init__.py +1 -0
- autogluon/multimodal/optim/lr/utils.py +332 -0
- autogluon/multimodal/optim/metrics/__init__.py +4 -0
- autogluon/multimodal/optim/metrics/coverage_metrics.py +42 -0
- autogluon/multimodal/optim/metrics/hit_rate_metrics.py +78 -0
- autogluon/multimodal/optim/metrics/ranking_metrics.py +231 -0
- autogluon/multimodal/optim/metrics/utils.py +359 -0
- autogluon/multimodal/optim/utils.py +284 -0
- autogluon/multimodal/predictor.py +51 -12
- autogluon/multimodal/utils/__init__.py +19 -45
- autogluon/multimodal/utils/cache.py +23 -2
- autogluon/multimodal/utils/checkpoint.py +58 -5
- autogluon/multimodal/utils/config.py +127 -55
- autogluon/multimodal/utils/device.py +120 -0
- autogluon/multimodal/utils/distillation.py +8 -8
- autogluon/multimodal/utils/download.py +1 -1
- autogluon/multimodal/utils/env.py +22 -0
- autogluon/multimodal/utils/export.py +3 -3
- autogluon/multimodal/utils/hpo.py +5 -5
- autogluon/multimodal/utils/inference.py +37 -4
- autogluon/multimodal/utils/install.py +91 -0
- autogluon/multimodal/utils/load.py +52 -47
- autogluon/multimodal/utils/log.py +6 -41
- autogluon/multimodal/utils/matcher.py +3 -2
- autogluon/multimodal/utils/onnx.py +0 -4
- autogluon/multimodal/utils/path.py +10 -0
- autogluon/multimodal/utils/precision.py +130 -0
- autogluon/multimodal/{presets.py → utils/presets.py} +259 -66
- autogluon/multimodal/{problem_types.py → utils/problem_types.py} +30 -1
- autogluon/multimodal/utils/save.py +47 -29
- autogluon/multimodal/utils/strategy.py +24 -0
- autogluon/multimodal/version.py +1 -1
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/METADATA +5 -5
- autogluon.multimodal-1.2.1b20250304.dist-info/RECORD +163 -0
- autogluon/multimodal/optimization/__init__.py +0 -16
- autogluon/multimodal/optimization/losses.py +0 -394
- autogluon/multimodal/optimization/utils.py +0 -1054
- autogluon/multimodal/utils/cloud_io.py +0 -80
- autogluon/multimodal/utils/data.py +0 -701
- autogluon/multimodal/utils/environment.py +0 -395
- autogluon/multimodal/utils/metric.py +0 -500
- autogluon/multimodal/utils/model.py +0 -558
- autogluon.multimodal-1.2.1b20250303.dist-info/RECORD +0 -145
- /autogluon/multimodal/{optimization → optim}/deepspeed.py +0 -0
- /autogluon/multimodal/{optimization/lr_scheduler.py → optim/lr/lr_schedulers.py} +0 -0
- /autogluon/multimodal/{optimization → optim/metrics}/semantic_seg_metrics.py +0 -0
- /autogluon/multimodal/{registry.py → utils/registry.py} +0 -0
- /autogluon.multimodal-1.2.1b20250303-py3.9-nspkg.pth → /autogluon.multimodal-1.2.1b20250304-py3.9-nspkg.pth +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/LICENSE +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/NOTICE +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/WHEEL +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/namespace_packages.txt +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/top_level.txt +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/zip-safe +0 -0
autogluon/multimodal/optim/__init__.py

@@ -0,0 +1,17 @@
+from .lit_distiller import DistillerLitModule
+from .lit_matcher import MatcherLitModule
+from .lit_mmdet import MMDetLitModule
+from .lit_module import LitModule
+from .lit_ner import NerLitModule
+from .lit_semantic_seg import SemanticSegmentationLitModule
+from .losses import get_aug_loss_func, get_loss_func, get_matcher_loss_func, get_matcher_miner_func
+from .metrics import (
+    CustomHitRate,
+    compute_ranking_score,
+    compute_score,
+    get_minmax_mode,
+    get_stopping_threshold,
+    get_torchmetric,
+    infer_metrics,
+)
+from .utils import get_norm_layer_param_names, get_peft_param_names
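The new optim package re-exports its Lightning modules, losses, and metric helpers at the package level. A minimal, illustrative sketch of the resulting import surface (not taken from the package's own code or docs; names follow the re-exports added in optim/__init__.py and optim/losses/__init__.py in this release):

# Illustrative only: assumes the 1.2.1b20250304 wheel is installed.
from autogluon.multimodal.optim import LitModule, get_torchmetric, infer_metrics
from autogluon.multimodal.optim.losses import FocalLoss, RKDLoss, get_loss_func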

autogluon/multimodal/optim/lit_distiller.py

@@ -13,7 +13,8 @@ from torchmetrics.aggregation import BaseAggregator
 
 from ..constants import FEATURES, LOGITS, WEIGHT
 from ..models.utils import run_model
-from .
+from .lr import apply_layerwise_lr_decay, apply_single_lr, apply_two_stages_lr, get_lr_scheduler
+from .utils import get_optimizer
 
 logger = logging.getLogger(__name__)
 

autogluon/multimodal/optim/lit_matcher.py

@@ -13,16 +13,10 @@ from torchmetrics.aggregation import BaseAggregator
 from ..constants import FEATURES, LOGIT_SCALE, PROBABILITY, QUERY, RESPONSE
 from ..models.utils import run_model
 from ..utils.matcher import compute_matching_probability
-from .losses import MultiNegativesSoftmaxLoss
-from .
-
-
-    apply_single_lr,
-    apply_two_stages_lr,
-    generate_metric_learning_labels,
-    get_lr_scheduler,
-    get_optimizer,
-)
+from .losses import MultiNegativesSoftmaxLoss, generate_metric_learning_labels
+from .lr import apply_layerwise_lr_decay, apply_single_lr, apply_two_stages_lr, get_lr_scheduler
+from .metrics import CustomHitRate
+from .utils import get_optimizer
 
 logger = logging.getLogger(__name__)
 

autogluon/multimodal/optim/lit_mmdet.py

@@ -2,21 +2,13 @@ import logging
 from typing import Callable, Optional, Union
 
 import lightning.pytorch as pl
-import torch
 import torchmetrics
 from lightning.pytorch.utilities import grad_norm
-from torch.nn.modules.loss import _Loss
 from torchmetrics.aggregation import BaseAggregator
 
 from ..constants import BBOX, IMAGE, LABEL
-from .
-
-    apply_single_lr,
-    apply_two_stages_lr,
-    get_lr_scheduler,
-    get_optimizer,
-    remove_parameters_without_grad,
-)
+from .lr import apply_layerwise_lr_decay, apply_single_lr, apply_two_stages_lr, get_lr_scheduler
+from .utils import get_optimizer, remove_parameters_without_grad
 
 try:
     import mmdet

autogluon/multimodal/optim/lit_module.py

@@ -11,11 +11,25 @@ from torch import nn
 from torch.nn.modules.loss import _Loss
 from torchmetrics.aggregation import BaseAggregator
 
-from ..constants import
+from ..constants import (
+    AUG_LOGITS,
+    LM_TARGET,
+    LOGITS,
+    MULTIMODAL_FEATURES,
+    MULTIMODAL_FEATURES_POST_AUG,
+    MULTIMODAL_FEATURES_PRE_AUG,
+    ORI_LOGITS,
+    T_FEW,
+    TEMPLATE_LOGITS,
+    VAE_MEAN,
+    VAE_VAR,
+    WEIGHT,
+)
 from ..data.mixup import MixupModule, multimodel_mixup
 from ..models.utils import run_model
-from .
-from .
+from .lr import apply_layerwise_lr_decay, apply_single_lr, apply_two_stages_lr, get_lr_scheduler
+from .metrics import Coverage
+from .utils import get_optimizer
 
 logger = logging.getLogger(__name__)
 
@@ -44,13 +58,24 @@ class LitModule(pl.LightningModule):
         validation_metric_name: Optional[str] = None,
         custom_metric_func: Callable = None,
         test_metric: Optional[torchmetrics.Metric] = None,
-
+        peft: Optional[str] = None,
         trainable_param_names: Optional[List] = None,
         mixup_fn: Optional[MixupModule] = None,
         mixup_off_epoch: Optional[int] = 0,
         model_postprocess_fn: Callable = None,
         skip_final_val: Optional[bool] = False,
         track_grad_norm: Optional[Union[int, str]] = -1,
+        cross_modal_align: Optional[str] = None,
+        cross_modal_align_weight: Optional[float] = 0,
+        automatic_optimization: Optional[bool] = True,
+        accumulate_grad_batches: Optional[int] = None,
+        gradient_clip_val: Optional[float] = None,
+        gradient_clip_algorithm: Optional[str] = None,
+        use_aug_optim: Optional[bool] = False,
+        aug_loss_func: Optional[_Loss] = None,
+        aug_lr: Optional[float] = None,
+        aug_weight_decay: Optional[float] = None,
+        aug_optim_type: Optional[str] = None,
     ):
         """
         Parameters
@@ -104,7 +129,7 @@ class LitModule(pl.LightningModule):
             Refer to https://github.com/PyTorchLightning/metrics/blob/master/torchmetrics/aggregation.py
         test_metric
             A torchmetrics module used in the test stage, e.g., torchmetrics.Accuracy().
-
+        peft
             Whether to use efficient finetuning strategies. This will be helpful for fast finetuning of large backbones.
             We support options such as:
 
@@ -128,6 +153,8 @@ class LitModule(pl.LightningModule):
                 "model_postprocess_fn",
                 "mixup_fn",
                 "trainable_param_names",
+                "custom_metric_func",
+                "aug_loss_func",
             ]
         )
         self.model = model
@@ -144,7 +171,12 @@ class LitModule(pl.LightningModule):
         self.model_postprocess_fn = model_postprocess_fn
         self.trainable_param_names = trainable_param_names if trainable_param_names else []
         self.skip_final_val = skip_final_val
-        self.
+        self.automatic_optimization = automatic_optimization
+        self.aug_loss_func = aug_loss_func
+        if self.hparams.cross_modal_align:
+            assert self.hparams.cross_modal_align_weight > 0
+            logger.debug(f"Cross modal alignment mode: {self.hparams.cross_modal_align}")
+            logger.debug(f"Cross modal alignment loss weight: {self.hparams.cross_modal_align_weight}")
 
     def _compute_template_loss(
         self,
@@ -179,19 +211,41 @@ class LitModule(pl.LightningModule):
 
         return lm_loss + mc_loss * self.model.mc_loss + unlikely_loss * self.model.unlikely_loss
 
+    def _compute_cross_modal_align_loss(self, multimodal_features):
+        if self.hparams.cross_modal_align == "positive_only":
+            kl_loss = nn.KLDivLoss(reduction="batchmean", log_target=True)
+            loss = 0
+            num = 0
+            for i in range(len(multimodal_features)):
+                # input should be a distribution in the log space
+                a = F.log_softmax(multimodal_features[i], dim=1)
+                # kl divergence is not symmetric, so need to compute both (i, j) and (j, i)
+                for j in range(len(multimodal_features)):
+                    if i == j:
+                        continue
+                    # input should be a distribution in the log space
+                    b = F.log_softmax(multimodal_features[j], dim=1)
+                    loss += kl_loss(a, b)
+                    num += 1
+            return self.hparams.cross_modal_align_weight * loss / num
+        else:
+            raise ValueError(f"Unsupported cross modal alignment loss: {self.hparams.cross_modal_align}.")
+
     def _compute_loss(
         self,
         output: Dict,
         label: torch.Tensor,
     ):
         loss = 0
-        for
+        for per_prefix, per_output in output.items():
             weight = per_output[WEIGHT] if WEIGHT in per_output else 1
             if (
                 TEMPLATE_LOGITS in per_output and self.model.prefix == T_FEW
             ):  # Do only add template loss if T-Few. #TODO Add compatibility to Fusion models.
                 loss += self._compute_template_loss(per_output, label) * weight
             else:
+                if self.training and self.hparams.use_aug_optim and per_prefix.startswith("fusion"):
+                    label = label.tile((2,))
                 loss += (
                     self.loss_func(
                         input=per_output[LOGITS].squeeze(dim=1),
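The _compute_cross_modal_align_loss method added above averages a symmetric pairwise KL divergence over per-modality feature distributions. A self-contained sketch of the same computation outside the LightningModule, using hypothetical feature tensors (only the "positive_only" branch shown in the diff is reproduced):

from typing import List

import torch
import torch.nn.functional as F
from torch import nn


def pairwise_kl_alignment(features: List[torch.Tensor], weight: float = 0.1) -> torch.Tensor:
    # Average KL divergence over all ordered (i, j) pairs with i != j; KLDivLoss expects
    # log-probabilities for both arguments when log_target=True.
    kl_loss = nn.KLDivLoss(reduction="batchmean", log_target=True)
    loss, num = 0.0, 0
    for i in range(len(features)):
        a = F.log_softmax(features[i], dim=1)
        for j in range(len(features)):
            if i == j:
                continue
            b = F.log_softmax(features[j], dim=1)
            loss = loss + kl_loss(a, b)
            num += 1
    return weight * loss / num


# e.g., image and text features for a batch of 8 samples with 64-dim fused features
align_loss = pairwise_kl_alignment([torch.randn(8, 64), torch.randn(8, 64)])
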
@@ -199,6 +253,22 @@ class LitModule(pl.LightningModule):
                     )
                     * weight
                 )
+
+        if self.hparams.cross_modal_align:
+            loss += self._compute_cross_modal_align_loss(
+                multimodal_features=output[self.model.prefix][MULTIMODAL_FEATURES]
+            )
+
+        if self.training and self.hparams.use_aug_optim:
+            loss += self.aug_loss_func(
+                pre_aug=output[self.model.prefix][MULTIMODAL_FEATURES_PRE_AUG],
+                post_aug=output[self.model.prefix][MULTIMODAL_FEATURES_POST_AUG],
+                vae_mean=output[self.model.prefix][VAE_MEAN],
+                vae_var=output[self.model.prefix][VAE_VAR],
+                ori_logits=output[self.model.prefix][ORI_LOGITS],
+                aug_logits=output[self.model.prefix][AUG_LOGITS],
+            )
+
         return loss
 
     def _compute_metric_score(
@@ -214,6 +284,7 @@ class LitModule(pl.LightningModule):
                 torchmetrics.classification.BinaryAUROC,
                 torchmetrics.classification.BinaryAveragePrecision,
                 torchmetrics.classification.BinaryF1Score,
+                Coverage,
             ),
         ):
             prob = F.softmax(logits.float(), dim=1)
@@ -255,6 +326,27 @@ class LitModule(pl.LightningModule):
         Average loss of the mini-batch data.
         """
         output, loss = self._shared_step(batch)
+
+        if not self.automatic_optimization:
+            if self.hparams.use_aug_optim:
+                optimizer, aug_optimizer = self.optimizers()
+            else:
+                optimizer = self.optimizers()
+                aug_optimizer = None
+
+            lr_scheduler = self.lr_schedulers()
+            loss = loss / self.hparams.accumulate_grad_batches
+            self.manual_backward(loss)
+
+            if (batch_idx + 1) % self.hparams.accumulate_grad_batches == 0 or self.trainer.is_last_batch:
+                optimizer.step()
+                optimizer.zero_grad()
+                lr_scheduler.step()
+
+                if aug_optimizer is not None:
+                    aug_optimizer.step()
+                    aug_optimizer.zero_grad()
+
         self.log("train_loss", loss)
         return loss
 
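The training step above switches to Lightning's manual optimization when automatic_optimization is disabled: it scales the loss by the accumulation factor, calls manual_backward every batch, and only steps the optimizer(s) and scheduler at accumulation boundaries. A minimal standalone sketch of the same accumulation pattern in plain PyTorch (hypothetical model and data, not the predictor's actual training loop):

import torch
import torch.nn.functional as F

# Hypothetical model, optimizer, scheduler, and data loader.
accumulate_grad_batches = 4
model = torch.nn.Linear(16, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100)
loader = [(torch.randn(8, 16), torch.randint(0, 2, (8,))) for _ in range(10)]

for batch_idx, (x, y) in enumerate(loader):
    loss = F.cross_entropy(model(x), y)
    # Scale so the accumulated gradient matches one large batch, then backprop every step.
    (loss / accumulate_grad_batches).backward()
    if (batch_idx + 1) % accumulate_grad_batches == 0 or batch_idx == len(loader) - 1:
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()
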
@@ -345,6 +437,9 @@ class LitModule(pl.LightningModule):
             model=self.model,
             lr=self.hparams.lr,
             weight_decay=self.hparams.weight_decay,
+            exclude_keys=[
+                "augmenter"
+            ],  # exclude augmenter parameters from the optimizer as they would use an independent optimizer
         )
         if self.hparams.lr_choice == "two_stages":
             logger.debug("applying 2-stage learning rate...")
@@ -357,14 +452,14 @@ class LitModule(pl.LightningModule):
             logger.debug("applying layerwise learning rate decay...")
             grouped_parameters = apply_layerwise_lr_decay(
                 lr_decay=self.hparams.lr_decay,
-
+                peft=self.hparams.peft,
                 trainable_param_names=self.trainable_param_names,
                 **kwargs,
             )
         else:
             logger.debug("applying single learning rate...")
             grouped_parameters = apply_single_lr(
-
+                peft=self.hparams.peft,
                 trainable_param_names=self.trainable_param_names,
                 **kwargs,
             )
@@ -381,16 +476,21 @@ class LitModule(pl.LightningModule):
             if isinstance(self.trainer.strategy, DeepSpeedStrategy):
                 max_steps = 1
             else:
+                accumulate_grad_batches = (
+                    self.trainer.accumulate_grad_batches
+                    if self.automatic_optimization
+                    else self.hparams.accumulate_grad_batches
+                )
                 max_steps = (
                     len(self.trainer.datamodule.train_dataloader())
                     * self.trainer.max_epochs
-                    //
+                    // accumulate_grad_batches
                 )
                 logger.debug(
                     f"len(trainer.datamodule.train_dataloader()): {len(self.trainer.datamodule.train_dataloader())}"
                 )
                 logger.debug(f"trainer.max_epochs: {self.trainer.max_epochs}")
-                logger.debug(f"
+                logger.debug(f"accumulate_grad_batches: {accumulate_grad_batches}")
         else:
             max_steps = self.trainer.max_steps
 
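With manual optimization the effective accumulation factor now comes from the module's own hyperparameter rather than the trainer, and the scheduler's total step budget scales with it. A tiny numeric sketch of the formula above (hypothetical values):

# hypothetical values, mirroring the max_steps computation above
len_train_dataloader, max_epochs, accumulate_grad_batches = 1000, 10, 4
max_steps = len_train_dataloader * max_epochs // accumulate_grad_batches  # 1000 * 10 // 4 == 2500
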
@@ -411,10 +511,35 @@ class LitModule(pl.LightningModule):
         )
 
         sched = {"scheduler": scheduler, "interval": "step"}
+        ret_optimizers = [optimizer]
+        ret_schedulers = [sched]
+        if self.hparams.use_aug_optim:
+            logger.debug("initializing augment optimizer")
+            # augmenter's optimizer
+            aug_grouped_parameters = apply_single_lr(
+                model=self.model.augmenter,
+                lr=self.hparams.aug_lr,
+                weight_decay=self.hparams.aug_weight_decay,
+            )
+            aug_optimizer = get_optimizer(
+                optim_type=self.hparams.aug_optim_type,
+                optimizer_grouped_parameters=aug_grouped_parameters,
+                lr=self.hparams.aug_lr,
+                weight_decay=self.hparams.aug_weight_decay,
+            )
+            ret_optimizers.append(aug_optimizer)
+
         logger.debug("done configuring optimizer and scheduler")
-        return
+        return ret_optimizers, ret_schedulers
 
     def on_before_optimizer_step(self, optimizer):
         # If using mixed precision, the gradients are already unscaled here
-
-
+        # TODO: apply gradient clip only to the target optimizer
+        if not self.automatic_optimization and self.hparams.gradient_clip_val > 0:
+            self.clip_gradients(
+                optimizer,
+                gradient_clip_val=self.hparams.gradient_clip_val,
+                gradient_clip_algorithm=self.hparams.gradient_clip_algorithm,
+            )
+        if self.hparams.track_grad_norm != -1:
+            self.log_dict(grad_norm(self, norm_type=self.hparams.track_grad_norm))

autogluon/multimodal/optim/lit_ner.py

@@ -37,7 +37,7 @@ class NerLitModule(LitModule):
         validation_metric_name: Optional[str] = None,
         custom_metric_func: Callable = None,
         test_metric: Optional[torchmetrics.Metric] = None,
-
+        peft: Optional[str] = None,
         trainable_param_names: Optional[List] = None,
         mixup_fn: Optional[MixupModule] = None,
         mixup_off_epoch: Optional[int] = 0,
@@ -97,7 +97,7 @@ class NerLitModule(LitModule):
             Refer to https://github.com/PyTorchLightning/metrics/blob/master/torchmetrics/aggregation.py
         test_metric
             A torchmetrics module used in the test stage, e.g., torchmetrics.Accuracy().
-
+        peft
             Whether to use efficient finetuning strategies. This will be helpful for fast finetuning of large backbones.
             We support options such as:
 
@@ -127,7 +127,7 @@ class NerLitModule(LitModule):
             validation_metric_name=validation_metric_name,
             custom_metric_func=custom_metric_func,
             test_metric=test_metric,
-
+            peft=peft,
             trainable_param_names=trainable_param_names,
             mixup_fn=mixup_fn,
             mixup_off_epoch=mixup_off_epoch,

autogluon/multimodal/optim/lit_semantic_seg.py

@@ -8,7 +8,7 @@ from transformers.models.mask2former.modeling_mask2former import Mask2FormerLoss
 from ..constants import CLASS_LOGITS, LOGITS, MOE_LOSS, SEMANTIC_MASK, WEIGHT
 from ..models.utils import run_model
 from .lit_module import LitModule
-from .semantic_seg_metrics import Multiclass_IoU
+from .metrics.semantic_seg_metrics import Multiclass_IoU
 
 logger = logging.getLogger(__name__)
 

autogluon/multimodal/optim/losses/__init__.py

@@ -0,0 +1,14 @@
+from .softmax_losses import MultiNegativesSoftmaxLoss, SoftTargetCrossEntropy
+from .focal_loss import FocalLoss
+from .lemda_loss import LemdaLoss
+from .rkd_loss import RKDLoss
+from .bce_loss import BBCEWithLogitLoss
+from .structure_loss import StructureLoss
+from .utils import (
+    generate_metric_learning_labels,
+    get_aug_loss_func,
+    get_loss_func,
+    get_matcher_loss_func,
+    get_matcher_miner_func,
+    get_metric_learning_distance_func,
+)

autogluon/multimodal/optim/losses/bce_loss.py

@@ -0,0 +1,25 @@
+import torch
+import torch.nn as nn
+
+
+class BBCEWithLogitLoss(nn.Module):
+    """
+    Balanced BCEWithLogitLoss based on https://github.com/NiFangBaAGe/Explicit-Visual-Prompt/blob/latest_branch/models/segformer.py
+    """
+
+    def __init__(self):
+        super(BBCEWithLogitLoss, self).__init__()
+
+    def forward(self, input: torch.Tensor, target: torch.Tensor):
+        if input.dim() == 3:
+            input = input.unsqueeze(1)
+        eps = 1e-10
+        count_pos = torch.sum(target) + eps
+        count_neg = torch.sum(1.0 - target)
+        ratio = count_neg / count_pos
+        w_neg = count_pos / (count_pos + count_neg)
+
+        bce1 = nn.BCEWithLogitsLoss(pos_weight=ratio)
+        loss = w_neg * bce1(input, target)
+
+        return loss
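A brief, illustrative usage sketch of the new balanced BCE loss on random logits and binary masks (shapes are hypothetical; how the learners wire this loss up is not shown in this hunk):

import torch
from autogluon.multimodal.optim.losses import BBCEWithLogitLoss

criterion = BBCEWithLogitLoss()
logits = torch.randn(4, 1, 32, 32)                      # raw mask logits for a batch of 4
targets = torch.randint(0, 2, (4, 1, 32, 32)).float()   # binary ground-truth masks
loss = criterion(logits, targets)                       # positives up-weighted by the neg/pos ratio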

autogluon/multimodal/optim/losses/focal_loss.py

@@ -0,0 +1,81 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class FocalLoss(nn.Module):
+    """
+    Focal loss based on https://github.com/AdeelH/pytorch-multi-class-focal-loss/blob/master/focal_loss.py
+
+    References:
+        [1] https://arxiv.org/abs/1708.02002
+    """
+
+    def __init__(
+        self,
+        alpha: Optional[torch.Tensor] = None,
+        gamma: Optional[float] = 2.0,
+        reduction: Optional[str] = "mean",
+        eps: Optional[float] = 1e-6,
+    ):
+        """
+
+        Parameters
+        ----------
+        alpha
+            weighting factor for each class. Should be of shape (num_classes)
+        gamma
+            the focal parameter for calculating weights on easy/hard samples
+        reduction
+            the reduction to apply to the final loss output. Default: "mean". Options:
+            "mean", "sum"
+        eps
+            epsilon for numerical stability
+        """
+        super(FocalLoss, self).__init__()
+
+        self.gamma = gamma
+        self.reduction = reduction
+        self.eps = eps
+        if alpha is not None:
+            if isinstance(alpha, str):  # handles Ray Tune HPO sampled hyperparameter
+                try:
+                    numbers = alpha.strip("()").split(",")
+                    alpha = [float(num) for num in numbers]
+                except:
+                    raise ValueError(f"{type(alpha)} {alpha} is not in a supported format.")
+            alpha = torch.tensor(alpha)
+        self.nll_loss = nn.NLLLoss(weight=alpha, reduction="none")
+
+    def forward(self, input: torch.Tensor, target: torch.Tensor):
+        if not torch.is_tensor(input):
+            raise TypeError("input type is not a torch.Tensor. Got {}".format(type(input)))
+        if input.ndim > 2:
+            # (N, C, d1, d2, ..., dK) --> (N * d1 * ... * dK, C)
+            num_class = input.shape[1]
+            input = input.permute(0, *range(2, input.ndim), 1).reshape(-1, num_class)
+            # (N, d1, d2, ..., dK) --> (N * d1 * ... * dK,)
+            target = target.view(-1)
+
+        pt = F.softmax(input, dim=-1)
+
+        # -alpha_t * log(pt) term
+        log_p = torch.log_softmax(input, dim=-1)
+        ce = self.nll_loss(log_p, target)
+
+        # (1 - pt)^gamma term
+        all_rows = torch.arange(input.shape[0])
+        pt = pt[all_rows, target]
+        focal_term = (1 - pt) ** self.gamma
+
+        loss = focal_term * ce
+
+        if self.reduction == "mean":
+            loss = loss.mean()
+
+        elif self.reduction == "sum":
+            loss = loss.sum()
+
+        return loss
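An illustrative usage sketch of FocalLoss on a toy 3-class batch (values are hypothetical; per the comment above, alpha may also arrive as a string when sampled by Ray Tune HPO):

import torch
from autogluon.multimodal.optim.losses import FocalLoss

criterion = FocalLoss(alpha=[1.0, 2.0, 2.0], gamma=2.0, reduction="mean")
logits = torch.randn(8, 3)              # unnormalized class scores
targets = torch.randint(0, 3, (8,))     # integer class labels
loss = criterion(logits, targets)       # hard, misclassified samples dominate the loss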

autogluon/multimodal/optim/losses/lemda_loss.py

@@ -0,0 +1,39 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...constants import BINARY, REGRESSION
+
+
+class LemdaLoss(nn.Module):
+    def __init__(self, mse_weight, kld_weight, consist_weight, consist_threshold, problem_type):
+        super().__init__()
+        self.mse_loss = nn.MSELoss(reduction="mean")
+        self.mse_weight = mse_weight
+        self.kld_weight = kld_weight
+        self.consist_weight = consist_weight
+        self.consist_threshold = consist_threshold
+        self.problem_type = problem_type
+
+    def consist_loss(self, p_logits, q_logits):
+        p = F.softmax(p_logits, dim=1)
+        logp = F.log_softmax(p_logits, dim=1)
+        logq = F.log_softmax(q_logits, dim=1)
+        loss = torch.sum(p * (logp - logq), dim=-1)
+        q = F.softmax(q_logits, dim=1)
+        q_largest = torch.max(q, dim=1)[0]
+        loss_mask = torch.gt(q_largest, self.consist_threshold).float()
+        loss = loss * loss_mask
+        return torch.mean(loss)
+
+    def forward(self, pre_aug, post_aug, vae_mean, vae_var, ori_logits, aug_logits):
+        mse_loss = self.mse_loss(pre_aug, post_aug) * self.mse_weight
+        # see Appendix B from VAE paper: https://arxiv.org/abs/1312.6114
+        # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
+        kld_loss = -0.5 * torch.mean(1 + vae_var - vae_mean.pow(2) - vae_var.exp()) * self.kld_weight
+        if self.problem_type in [REGRESSION, BINARY]:
+            consist_loss = self.mse_loss(ori_logits, aug_logits) * self.consist_weight
+        else:
+            consist_loss = self.consist_loss(ori_logits, aug_logits) * self.consist_weight
+
+        return mse_loss + kld_loss + consist_loss
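An illustrative usage sketch of LemdaLoss with random stand-ins for the augmenter outputs it expects; in LitModule._compute_loss above these tensors are read from the model output under MULTIMODAL_FEATURES_PRE_AUG, VAE_MEAN, and related keys (shapes here are hypothetical):

import torch
from autogluon.multimodal.optim.losses import LemdaLoss

criterion = LemdaLoss(
    mse_weight=1.0,
    kld_weight=0.1,
    consist_weight=1.0,
    consist_threshold=0.8,
    problem_type="multiclass",  # problem-type string; REGRESSION/BINARY switch to an MSE consistency term
)
pre_aug, post_aug = torch.randn(8, 64), torch.randn(8, 64)    # features before/after augmentation
vae_mean, vae_var = torch.randn(8, 64), torch.randn(8, 64)    # VAE posterior statistics
ori_logits, aug_logits = torch.randn(8, 3), torch.randn(8, 3)
loss = criterion(pre_aug, post_aug, vae_mean, vae_var, ori_logits, aug_logits)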

autogluon/multimodal/optim/losses/rkd_loss.py

@@ -0,0 +1,103 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class RKDLoss(nn.Module):
+    """
+    Compute RKD Distance Loss.
+    Paper Refer to: Relational Knowledge Disitllation, CVPR2019. https://arxiv.org/abs/1904.05068
+    Code Refer to: https://github.com/HobbitLong/RepDistiller/blob/master/distiller_zoo/RKD.py
+    and https://github.com/lenscloth/RKD/blob/master/metric/loss.py
+    """
+
+    def __init__(self, distance_loss_weight: Optional[float] = 25.0, angle_loss_weight: Optional[float] = 50.0):
+        """
+        Parameters
+        ----------
+        distance_loss_weight
+            Weight of RKD distance loss
+        angle_loss_weight
+            Weight of RKD angle loss
+        Returns
+        -------
+        """
+        super(RKDLoss, self).__init__()
+        self.distance_loss_weight = distance_loss_weight
+        self.angle_loss_weight = angle_loss_weight
+
+    def forward(self, feature_student: Optional[torch.Tensor], feature_teacher: Optional[torch.Tensor]):
+        """
+        Parameters
+        ----------
+        feature_student
+            Output feature of student model, shape: (N, D)
+        feature_teacher
+            Output feature of teacher model, shape: (N, D)
+        Returns
+        -------
+        The RKD Loss between teacher and student
+        """
+        # RKD loss
+        if self.distance_loss_weight > 0:
+            with torch.no_grad():
+                t_dist = self.pdist(feature_teacher, squared=False)
+                mean_td = t_dist[t_dist > 0].mean()
+                t_dist = t_dist / mean_td
+
+            s_dist = self.pdist(feature_student, squared=False)
+            mean_d = s_dist[s_dist > 0].mean()
+            s_dist = s_dist / mean_d
+
+            loss_distance = F.smooth_l1_loss(s_dist, t_dist)
+
+        # RKD Angle loss
+        if self.angle_loss_weight > 0:
+            with torch.no_grad():
+                td = feature_teacher.unsqueeze(0) - feature_teacher.unsqueeze(1)
+                norm_td = F.normalize(td, p=2, dim=2)
+                t_angle = torch.bmm(norm_td, norm_td.transpose(1, 2)).view(-1)
+
+            sd = feature_student.unsqueeze(0) - feature_student.unsqueeze(1)
+            norm_sd = F.normalize(sd, p=2, dim=2)
+            s_angle = torch.bmm(norm_sd, norm_sd.transpose(1, 2)).view(-1)
+
+            loss_angle = F.smooth_l1_loss(s_angle, t_angle)
+
+        loss = ((self.distance_loss_weight * loss_distance) if self.distance_loss_weight > 0 else 0) + (
+            (self.angle_loss_weight * loss_angle) if self.angle_loss_weight > 0 else 0
+        )
+
+        return loss
+
+    @staticmethod
+    def pdist(embeddings: Optional[torch.Tensor], squared: Optional[bool] = False, eps: Optional[float] = 1e-12):
+        """
+        Compute pairwise Euclidean distances between embeddings in n-dimensional space.
+
+        Parameters
+        ----------
+        embeddings
+            The embeddings to compute pairwise distance between. Shape: (N,D)
+        squared
+            If the result is square of Euclidean distance.
+        eps
+            Min value of each entry.
+
+        Returns
+        -------
+        Pairwise Euclidean distances. Shape: (N,N)
+        """
+        e_square = embeddings.pow(2).sum(dim=1)
+        prod = embeddings @ embeddings.t()
+        res = (e_square.unsqueeze(1) + e_square.unsqueeze(0) - 2 * prod).clamp(min=eps)
+
+        if not squared:
+            res = res.sqrt()
+
+        res = res.clone()
+        res[range(len(embeddings)), range(len(embeddings))] = 0
+
+        return res
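An illustrative usage sketch of RKDLoss for distilling relational structure between teacher and student embeddings (random tensors as stand-ins; student and teacher widths may differ since only pairwise relations are compared):

import torch
from autogluon.multimodal.optim.losses import RKDLoss

criterion = RKDLoss(distance_loss_weight=25.0, angle_loss_weight=50.0)
student_feat = torch.randn(16, 128, requires_grad=True)   # student embeddings, shape (N, D_s)
teacher_feat = torch.randn(16, 256)                       # teacher embeddings, shape (N, D_t)
loss = criterion(student_feat, teacher_feat)
loss.backward()  # teacher-side statistics are computed under no_grad, so only the student receives gradients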