autogluon.multimodal 1.2.1b20250303__py3-none-any.whl → 1.2.1b20250304__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126)
  1. autogluon/multimodal/__init__.py +4 -2
  2. autogluon/multimodal/configs/data/default.yaml +4 -2
  3. autogluon/multimodal/configs/{environment → env}/default.yaml +2 -3
  4. autogluon/multimodal/configs/model/default.yaml +58 -11
  5. autogluon/multimodal/configs/{optimization → optim}/default.yaml +21 -4
  6. autogluon/multimodal/constants.py +16 -5
  7. autogluon/multimodal/data/__init__.py +14 -2
  8. autogluon/multimodal/data/dataset.py +2 -2
  9. autogluon/multimodal/data/infer_types.py +16 -2
  10. autogluon/multimodal/data/label_encoder.py +3 -3
  11. autogluon/multimodal/{utils → data}/nlpaug.py +4 -4
  12. autogluon/multimodal/data/preprocess_dataframe.py +55 -38
  13. autogluon/multimodal/data/process_categorical.py +35 -6
  14. autogluon/multimodal/data/process_document.py +59 -33
  15. autogluon/multimodal/data/process_image.py +198 -163
  16. autogluon/multimodal/data/process_label.py +7 -3
  17. autogluon/multimodal/data/process_mmlab/process_mmdet.py +1 -8
  18. autogluon/multimodal/data/process_mmlab/process_mmlab_base.py +2 -9
  19. autogluon/multimodal/data/process_mmlab/process_mmocr.py +1 -9
  20. autogluon/multimodal/data/process_ner.py +192 -4
  21. autogluon/multimodal/data/process_numerical.py +32 -5
  22. autogluon/multimodal/data/process_semantic_seg_img.py +23 -28
  23. autogluon/multimodal/data/process_text.py +95 -58
  24. autogluon/multimodal/data/template_engine.py +7 -9
  25. autogluon/multimodal/data/templates.py +0 -2
  26. autogluon/multimodal/data/trivial_augmenter.py +2 -2
  27. autogluon/multimodal/data/utils.py +564 -338
  28. autogluon/multimodal/learners/__init__.py +2 -1
  29. autogluon/multimodal/learners/base.py +189 -189
  30. autogluon/multimodal/learners/ensemble.py +748 -0
  31. autogluon/multimodal/learners/few_shot_svm.py +6 -15
  32. autogluon/multimodal/learners/matching.py +59 -84
  33. autogluon/multimodal/learners/ner.py +23 -22
  34. autogluon/multimodal/learners/object_detection.py +26 -21
  35. autogluon/multimodal/learners/semantic_segmentation.py +16 -18
  36. autogluon/multimodal/models/__init__.py +12 -3
  37. autogluon/multimodal/models/augmenter.py +175 -0
  38. autogluon/multimodal/models/categorical_mlp.py +13 -8
  39. autogluon/multimodal/models/clip.py +92 -18
  40. autogluon/multimodal/models/custom_transformer.py +75 -75
  41. autogluon/multimodal/models/document_transformer.py +23 -9
  42. autogluon/multimodal/models/ft_transformer.py +40 -35
  43. autogluon/multimodal/models/fusion/base.py +2 -4
  44. autogluon/multimodal/models/fusion/fusion_mlp.py +82 -18
  45. autogluon/multimodal/models/fusion/fusion_ner.py +1 -1
  46. autogluon/multimodal/models/fusion/fusion_transformer.py +23 -23
  47. autogluon/multimodal/models/{huggingface_text.py → hf_text.py} +21 -2
  48. autogluon/multimodal/models/meta_transformer.py +336 -0
  49. autogluon/multimodal/models/mlp.py +6 -6
  50. autogluon/multimodal/models/mmocr_text_detection.py +1 -1
  51. autogluon/multimodal/models/mmocr_text_recognition.py +0 -1
  52. autogluon/multimodal/models/ner_text.py +1 -8
  53. autogluon/multimodal/models/numerical_mlp.py +14 -8
  54. autogluon/multimodal/models/sam.py +12 -2
  55. autogluon/multimodal/models/t_few.py +21 -5
  56. autogluon/multimodal/models/timm_image.py +74 -32
  57. autogluon/multimodal/models/utils.py +877 -16
  58. autogluon/multimodal/optim/__init__.py +17 -0
  59. autogluon/multimodal/{optimization → optim}/lit_distiller.py +2 -1
  60. autogluon/multimodal/{optimization → optim}/lit_matcher.py +4 -10
  61. autogluon/multimodal/{optimization → optim}/lit_mmdet.py +2 -10
  62. autogluon/multimodal/{optimization → optim}/lit_module.py +139 -14
  63. autogluon/multimodal/{optimization → optim}/lit_ner.py +3 -3
  64. autogluon/multimodal/{optimization → optim}/lit_semantic_seg.py +1 -1
  65. autogluon/multimodal/optim/losses/__init__.py +14 -0
  66. autogluon/multimodal/optim/losses/bce_loss.py +25 -0
  67. autogluon/multimodal/optim/losses/focal_loss.py +81 -0
  68. autogluon/multimodal/optim/losses/lemda_loss.py +39 -0
  69. autogluon/multimodal/optim/losses/rkd_loss.py +103 -0
  70. autogluon/multimodal/optim/losses/softmax_losses.py +177 -0
  71. autogluon/multimodal/optim/losses/structure_loss.py +26 -0
  72. autogluon/multimodal/optim/losses/utils.py +313 -0
  73. autogluon/multimodal/optim/lr/__init__.py +1 -0
  74. autogluon/multimodal/optim/lr/utils.py +332 -0
  75. autogluon/multimodal/optim/metrics/__init__.py +4 -0
  76. autogluon/multimodal/optim/metrics/coverage_metrics.py +42 -0
  77. autogluon/multimodal/optim/metrics/hit_rate_metrics.py +78 -0
  78. autogluon/multimodal/optim/metrics/ranking_metrics.py +231 -0
  79. autogluon/multimodal/optim/metrics/utils.py +359 -0
  80. autogluon/multimodal/optim/utils.py +284 -0
  81. autogluon/multimodal/predictor.py +51 -12
  82. autogluon/multimodal/utils/__init__.py +19 -45
  83. autogluon/multimodal/utils/cache.py +23 -2
  84. autogluon/multimodal/utils/checkpoint.py +58 -5
  85. autogluon/multimodal/utils/config.py +127 -55
  86. autogluon/multimodal/utils/device.py +120 -0
  87. autogluon/multimodal/utils/distillation.py +8 -8
  88. autogluon/multimodal/utils/download.py +1 -1
  89. autogluon/multimodal/utils/env.py +22 -0
  90. autogluon/multimodal/utils/export.py +3 -3
  91. autogluon/multimodal/utils/hpo.py +5 -5
  92. autogluon/multimodal/utils/inference.py +37 -4
  93. autogluon/multimodal/utils/install.py +91 -0
  94. autogluon/multimodal/utils/load.py +52 -47
  95. autogluon/multimodal/utils/log.py +6 -41
  96. autogluon/multimodal/utils/matcher.py +3 -2
  97. autogluon/multimodal/utils/onnx.py +0 -4
  98. autogluon/multimodal/utils/path.py +10 -0
  99. autogluon/multimodal/utils/precision.py +130 -0
  100. autogluon/multimodal/{presets.py → utils/presets.py} +259 -66
  101. autogluon/multimodal/{problem_types.py → utils/problem_types.py} +30 -1
  102. autogluon/multimodal/utils/save.py +47 -29
  103. autogluon/multimodal/utils/strategy.py +24 -0
  104. autogluon/multimodal/version.py +1 -1
  105. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/METADATA +5 -5
  106. autogluon.multimodal-1.2.1b20250304.dist-info/RECORD +163 -0
  107. autogluon/multimodal/optimization/__init__.py +0 -16
  108. autogluon/multimodal/optimization/losses.py +0 -394
  109. autogluon/multimodal/optimization/utils.py +0 -1054
  110. autogluon/multimodal/utils/cloud_io.py +0 -80
  111. autogluon/multimodal/utils/data.py +0 -701
  112. autogluon/multimodal/utils/environment.py +0 -395
  113. autogluon/multimodal/utils/metric.py +0 -500
  114. autogluon/multimodal/utils/model.py +0 -558
  115. autogluon.multimodal-1.2.1b20250303.dist-info/RECORD +0 -145
  116. /autogluon/multimodal/{optimization → optim}/deepspeed.py +0 -0
  117. /autogluon/multimodal/{optimization/lr_scheduler.py → optim/lr/lr_schedulers.py} +0 -0
  118. /autogluon/multimodal/{optimization → optim/metrics}/semantic_seg_metrics.py +0 -0
  119. /autogluon/multimodal/{registry.py → utils/registry.py} +0 -0
  120. /autogluon.multimodal-1.2.1b20250303-py3.9-nspkg.pth → /autogluon.multimodal-1.2.1b20250304-py3.9-nspkg.pth +0 -0
  121. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/LICENSE +0 -0
  122. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/NOTICE +0 -0
  123. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/WHEEL +0 -0
  124. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/namespace_packages.txt +0 -0
  125. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/top_level.txt +0 -0
  126. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/zip-safe +0 -0
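The headline change in this build is a module re-organization: `optimization/` becomes `optim/` (with dedicated `losses`, `lr`, and `metrics` subpackages), `environment/` becomes `env/`, `huggingface_text.py` becomes `hf_text.py`, and `presets.py`, `problem_types.py`, and `registry.py` move under `utils/`. A minimal sketch of how downstream imports would shift; the pre-move paths are inferred from the deleted files listed above and are assumptions, not verified against the old wheel:

```python
# Sketch only: illustrates the module moves in this diff.

# before (1.2.1b20250303) -- assumed layout based on the removed files above
# from autogluon.multimodal.optimization.lit_module import LitModule
# from autogluon.multimodal.optimization.losses import FocalLoss

# after (1.2.1b20250304) -- paths taken from the added files below
from autogluon.multimodal.optim import LitModule
from autogluon.multimodal.optim.losses import FocalLoss, get_loss_func
```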
autogluon/multimodal/optim/__init__.py (new file)
@@ -0,0 +1,17 @@
+from .lit_distiller import DistillerLitModule
+from .lit_matcher import MatcherLitModule
+from .lit_mmdet import MMDetLitModule
+from .lit_module import LitModule
+from .lit_ner import NerLitModule
+from .lit_semantic_seg import SemanticSegmentationLitModule
+from .losses import get_aug_loss_func, get_loss_func, get_matcher_loss_func, get_matcher_miner_func
+from .metrics import (
+    CustomHitRate,
+    compute_ranking_score,
+    compute_score,
+    get_minmax_mode,
+    get_stopping_threshold,
+    get_torchmetric,
+    infer_metrics,
+)
+from .utils import get_norm_layer_param_names, get_peft_param_names
autogluon/multimodal/optim/lit_distiller.py
@@ -13,7 +13,8 @@ from torchmetrics.aggregation import BaseAggregator

 from ..constants import FEATURES, LOGITS, WEIGHT
 from ..models.utils import run_model
-from .utils import apply_layerwise_lr_decay, apply_single_lr, apply_two_stages_lr, get_lr_scheduler, get_optimizer
+from .lr import apply_layerwise_lr_decay, apply_single_lr, apply_two_stages_lr, get_lr_scheduler
+from .utils import get_optimizer

 logger = logging.getLogger(__name__)

autogluon/multimodal/optim/lit_matcher.py
@@ -13,16 +13,10 @@ from torchmetrics.aggregation import BaseAggregator
 from ..constants import FEATURES, LOGIT_SCALE, PROBABILITY, QUERY, RESPONSE
 from ..models.utils import run_model
 from ..utils.matcher import compute_matching_probability
-from .losses import MultiNegativesSoftmaxLoss
-from .utils import (
-    CustomHitRate,
-    apply_layerwise_lr_decay,
-    apply_single_lr,
-    apply_two_stages_lr,
-    generate_metric_learning_labels,
-    get_lr_scheduler,
-    get_optimizer,
-)
+from .losses import MultiNegativesSoftmaxLoss, generate_metric_learning_labels
+from .lr import apply_layerwise_lr_decay, apply_single_lr, apply_two_stages_lr, get_lr_scheduler
+from .metrics import CustomHitRate
+from .utils import get_optimizer

 logger = logging.getLogger(__name__)

autogluon/multimodal/optim/lit_mmdet.py
@@ -2,21 +2,13 @@ import logging
 from typing import Callable, Optional, Union

 import lightning.pytorch as pl
-import torch
 import torchmetrics
 from lightning.pytorch.utilities import grad_norm
-from torch.nn.modules.loss import _Loss
 from torchmetrics.aggregation import BaseAggregator

 from ..constants import BBOX, IMAGE, LABEL
-from .utils import (
-    apply_layerwise_lr_decay,
-    apply_single_lr,
-    apply_two_stages_lr,
-    get_lr_scheduler,
-    get_optimizer,
-    remove_parameters_without_grad,
-)
+from .lr import apply_layerwise_lr_decay, apply_single_lr, apply_two_stages_lr, get_lr_scheduler
+from .utils import get_optimizer, remove_parameters_without_grad

 try:
     import mmdet
autogluon/multimodal/optim/lit_module.py
@@ -11,11 +11,25 @@ from torch import nn
 from torch.nn.modules.loss import _Loss
 from torchmetrics.aggregation import BaseAggregator

-from ..constants import LM_TARGET, LOGITS, T_FEW, TEMPLATE_LOGITS, WEIGHT
+from ..constants import (
+    AUG_LOGITS,
+    LM_TARGET,
+    LOGITS,
+    MULTIMODAL_FEATURES,
+    MULTIMODAL_FEATURES_POST_AUG,
+    MULTIMODAL_FEATURES_PRE_AUG,
+    ORI_LOGITS,
+    T_FEW,
+    TEMPLATE_LOGITS,
+    VAE_MEAN,
+    VAE_VAR,
+    WEIGHT,
+)
 from ..data.mixup import MixupModule, multimodel_mixup
 from ..models.utils import run_model
-from .semantic_seg_metrics import COD, Balanced_Error_Rate
-from .utils import apply_layerwise_lr_decay, apply_single_lr, apply_two_stages_lr, get_lr_scheduler, get_optimizer
+from .lr import apply_layerwise_lr_decay, apply_single_lr, apply_two_stages_lr, get_lr_scheduler
+from .metrics import Coverage
+from .utils import get_optimizer

 logger = logging.getLogger(__name__)

@@ -44,13 +58,24 @@ class LitModule(pl.LightningModule):
         validation_metric_name: Optional[str] = None,
         custom_metric_func: Callable = None,
         test_metric: Optional[torchmetrics.Metric] = None,
-        efficient_finetune: Optional[str] = None,
+        peft: Optional[str] = None,
         trainable_param_names: Optional[List] = None,
         mixup_fn: Optional[MixupModule] = None,
         mixup_off_epoch: Optional[int] = 0,
         model_postprocess_fn: Callable = None,
         skip_final_val: Optional[bool] = False,
         track_grad_norm: Optional[Union[int, str]] = -1,
+        cross_modal_align: Optional[str] = None,
+        cross_modal_align_weight: Optional[float] = 0,
+        automatic_optimization: Optional[bool] = True,
+        accumulate_grad_batches: Optional[int] = None,
+        gradient_clip_val: Optional[float] = None,
+        gradient_clip_algorithm: Optional[str] = None,
+        use_aug_optim: Optional[bool] = False,
+        aug_loss_func: Optional[_Loss] = None,
+        aug_lr: Optional[float] = None,
+        aug_weight_decay: Optional[float] = None,
+        aug_optim_type: Optional[str] = None,
     ):
         """
         Parameters
@@ -104,7 +129,7 @@ class LitModule(pl.LightningModule):
             Refer to https://github.com/PyTorchLightning/metrics/blob/master/torchmetrics/aggregation.py
         test_metric
             A torchmetrics module used in the test stage, e.g., torchmetrics.Accuracy().
-        efficient_finetune
+        peft
             Whether to use efficient finetuning strategies. This will be helpful for fast finetuning of large backbones.
             We support options such as:

@@ -128,6 +153,8 @@ class LitModule(pl.LightningModule):
                 "model_postprocess_fn",
                 "mixup_fn",
                 "trainable_param_names",
+                "custom_metric_func",
+                "aug_loss_func",
             ]
         )
         self.model = model
@@ -144,7 +171,12 @@ class LitModule(pl.LightningModule):
         self.model_postprocess_fn = model_postprocess_fn
         self.trainable_param_names = trainable_param_names if trainable_param_names else []
         self.skip_final_val = skip_final_val
-        self.track_grad_norm = track_grad_norm
+        self.automatic_optimization = automatic_optimization
+        self.aug_loss_func = aug_loss_func
+        if self.hparams.cross_modal_align:
+            assert self.hparams.cross_modal_align_weight > 0
+            logger.debug(f"Cross modal alignment mode: {self.hparams.cross_modal_align}")
+            logger.debug(f"Cross modal alignment loss weight: {self.hparams.cross_modal_align_weight}")

     def _compute_template_loss(
         self,
@@ -179,19 +211,41 @@ class LitModule(pl.LightningModule):

         return lm_loss + mc_loss * self.model.mc_loss + unlikely_loss * self.model.unlikely_loss

+    def _compute_cross_modal_align_loss(self, multimodal_features):
+        if self.hparams.cross_modal_align == "positive_only":
+            kl_loss = nn.KLDivLoss(reduction="batchmean", log_target=True)
+            loss = 0
+            num = 0
+            for i in range(len(multimodal_features)):
+                # input should be a distribution in the log space
+                a = F.log_softmax(multimodal_features[i], dim=1)
+                # kl divergence is not symmetric, so need to compute both (i, j) and (j, i)
+                for j in range(len(multimodal_features)):
+                    if i == j:
+                        continue
+                    # input should be a distribution in the log space
+                    b = F.log_softmax(multimodal_features[j], dim=1)
+                    loss += kl_loss(a, b)
+                    num += 1
+            return self.hparams.cross_modal_align_weight * loss / num
+        else:
+            raise ValueError(f"Unsupported cross modal alignment loss: {self.hparams.cross_modal_align}.")
+
     def _compute_loss(
         self,
         output: Dict,
         label: torch.Tensor,
     ):
         loss = 0
-        for _, per_output in output.items():
+        for per_prefix, per_output in output.items():
             weight = per_output[WEIGHT] if WEIGHT in per_output else 1
             if (
                 TEMPLATE_LOGITS in per_output and self.model.prefix == T_FEW
             ):  # Do only add template loss if T-Few. #TODO Add compatibility to Fusion models.
                 loss += self._compute_template_loss(per_output, label) * weight
             else:
+                if self.training and self.hparams.use_aug_optim and per_prefix.startswith("fusion"):
+                    label = label.tile((2,))
                 loss += (
                     self.loss_func(
                         input=per_output[LOGITS].squeeze(dim=1),
@@ -199,6 +253,22 @@
                     )
                     * weight
                 )
+
+        if self.hparams.cross_modal_align:
+            loss += self._compute_cross_modal_align_loss(
+                multimodal_features=output[self.model.prefix][MULTIMODAL_FEATURES]
+            )
+
+        if self.training and self.hparams.use_aug_optim:
+            loss += self.aug_loss_func(
+                pre_aug=output[self.model.prefix][MULTIMODAL_FEATURES_PRE_AUG],
+                post_aug=output[self.model.prefix][MULTIMODAL_FEATURES_POST_AUG],
+                vae_mean=output[self.model.prefix][VAE_MEAN],
+                vae_var=output[self.model.prefix][VAE_VAR],
+                ori_logits=output[self.model.prefix][ORI_LOGITS],
+                aug_logits=output[self.model.prefix][AUG_LOGITS],
+            )
+
         return loss

     def _compute_metric_score(
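For context on the two hunks above: the new `positive_only` cross-modal alignment term averages the KL divergence between the log-softmax distributions of every ordered pair of per-modality feature tensors, then scales the result by `cross_modal_align_weight`. A small self-contained sketch of the same computation on toy tensors (function and variable names below are illustrative, not package API):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def pairwise_kl_alignment(features, weight=0.1):
    """Average KL(a || b) over all ordered pairs of (batch, dim) feature tensors."""
    kl = nn.KLDivLoss(reduction="batchmean", log_target=True)
    loss, num = 0.0, 0
    for i, feat_i in enumerate(features):
        a = F.log_softmax(feat_i, dim=1)  # inputs must be log-probabilities
        for j, feat_j in enumerate(features):
            if i == j:
                continue
            # KL is asymmetric, so both (i, j) and (j, i) contribute
            loss = loss + kl(a, F.log_softmax(feat_j, dim=1))
            num += 1
    return weight * loss / num


# toy usage: align "text" and "image" features for a batch of 4 samples
align_loss = pairwise_kl_alignment([torch.randn(4, 8), torch.randn(4, 8)])
```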
@@ -214,6 +284,7 @@
                 torchmetrics.classification.BinaryAUROC,
                 torchmetrics.classification.BinaryAveragePrecision,
                 torchmetrics.classification.BinaryF1Score,
+                Coverage,
             ),
         ):
             prob = F.softmax(logits.float(), dim=1)
@@ -255,6 +326,27 @@
             Average loss of the mini-batch data.
         """
         output, loss = self._shared_step(batch)
+
+        if not self.automatic_optimization:
+            if self.hparams.use_aug_optim:
+                optimizer, aug_optimizer = self.optimizers()
+            else:
+                optimizer = self.optimizers()
+                aug_optimizer = None
+
+            lr_scheduler = self.lr_schedulers()
+            loss = loss / self.hparams.accumulate_grad_batches
+            self.manual_backward(loss)
+
+            if (batch_idx + 1) % self.hparams.accumulate_grad_batches == 0 or self.trainer.is_last_batch:
+                optimizer.step()
+                optimizer.zero_grad()
+                lr_scheduler.step()
+
+                if aug_optimizer is not None:
+                    aug_optimizer.step()
+                    aug_optimizer.zero_grad()
+
         self.log("train_loss", loss)
         return loss
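The manual-optimization branch above follows the standard Lightning pattern: with `automatic_optimization` disabled, the module scales the loss by the accumulation factor, calls `manual_backward`, and only steps and zeroes the optimizer(s) and LR scheduler every `accumulate_grad_batches` batches. A condensed, self-contained sketch of that pattern (a toy module, not the package's class):

```python
import torch
import torch.nn as nn
import lightning.pytorch as pl


class ManualAccumulationModule(pl.LightningModule):
    def __init__(self, accumulate_grad_batches: int = 4):
        super().__init__()
        self.automatic_optimization = False  # hand control of backward/step to training_step
        self.accumulate_grad_batches = accumulate_grad_batches
        self.net = nn.Linear(8, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = nn.functional.cross_entropy(self.net(x), y)
        opt, sched = self.optimizers(), self.lr_schedulers()

        # scale so the accumulated gradient matches a single large-batch step
        self.manual_backward(loss / self.accumulate_grad_batches)

        if (batch_idx + 1) % self.accumulate_grad_batches == 0 or self.trainer.is_last_batch:
            opt.step()
            opt.zero_grad()
            sched.step()  # the diff also steps the LR scheduler once per effective optimizer step
        return loss

    def configure_optimizers(self):
        opt = torch.optim.AdamW(self.parameters(), lr=1e-3)
        sched = torch.optim.lr_scheduler.LinearLR(opt)
        return [opt], [{"scheduler": sched, "interval": "step"}]
```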
 
@@ -345,6 +437,9 @@
             model=self.model,
             lr=self.hparams.lr,
             weight_decay=self.hparams.weight_decay,
+            exclude_keys=[
+                "augmenter"
+            ],  # exclude augmenter parameters from the optimizer as they would use an independent optimizer
         )
         if self.hparams.lr_choice == "two_stages":
             logger.debug("applying 2-stage learning rate...")
@@ -357,14 +452,14 @@
             logger.debug("applying layerwise learning rate decay...")
             grouped_parameters = apply_layerwise_lr_decay(
                 lr_decay=self.hparams.lr_decay,
-                efficient_finetune=self.hparams.efficient_finetune,
+                peft=self.hparams.peft,
                 trainable_param_names=self.trainable_param_names,
                 **kwargs,
             )
         else:
             logger.debug("applying single learning rate...")
             grouped_parameters = apply_single_lr(
-                efficient_finetune=self.hparams.efficient_finetune,
+                peft=self.hparams.peft,
                 trainable_param_names=self.trainable_param_names,
                 **kwargs,
             )
@@ -381,16 +476,21 @@
             if isinstance(self.trainer.strategy, DeepSpeedStrategy):
                 max_steps = 1
             else:
+                accumulate_grad_batches = (
+                    self.trainer.accumulate_grad_batches
+                    if self.automatic_optimization
+                    else self.hparams.accumulate_grad_batches
+                )
                 max_steps = (
                     len(self.trainer.datamodule.train_dataloader())
                     * self.trainer.max_epochs
-                    // self.trainer.accumulate_grad_batches
+                    // accumulate_grad_batches
                 )
                 logger.debug(
                     f"len(trainer.datamodule.train_dataloader()): {len(self.trainer.datamodule.train_dataloader())}"
                 )
                 logger.debug(f"trainer.max_epochs: {self.trainer.max_epochs}")
-                logger.debug(f"trainer.accumulate_grad_batches: {self.trainer.accumulate_grad_batches}")
+                logger.debug(f"accumulate_grad_batches: {accumulate_grad_batches}")
         else:
             max_steps = self.trainer.max_steps
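Either way, the scheduler is sized the same: `max_steps` is the number of training batches per epoch times `max_epochs`, floor-divided by the effective accumulation factor. A worked example with illustrative numbers:

```python
# 1,000 batches per epoch, 10 epochs, accumulate_grad_batches = 4
max_steps = 1_000 * 10 // 4  # -> 2,500 optimizer/scheduler steps
```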
 
@@ -411,10 +511,35 @@
         )

         sched = {"scheduler": scheduler, "interval": "step"}
+        ret_optimizers = [optimizer]
+        ret_schedulers = [sched]
+        if self.hparams.use_aug_optim:
+            logger.debug("initializing augment optimizer")
+            # augmenter's optimizer
+            aug_grouped_parameters = apply_single_lr(
+                model=self.model.augmenter,
+                lr=self.hparams.aug_lr,
+                weight_decay=self.hparams.aug_weight_decay,
+            )
+            aug_optimizer = get_optimizer(
+                optim_type=self.hparams.aug_optim_type,
+                optimizer_grouped_parameters=aug_grouped_parameters,
+                lr=self.hparams.aug_lr,
+                weight_decay=self.hparams.aug_weight_decay,
+            )
+            ret_optimizers.append(aug_optimizer)
+
         logger.debug("done configuring optimizer and scheduler")
-        return [optimizer], [sched]
+        return ret_optimizers, ret_schedulers

     def on_before_optimizer_step(self, optimizer):
         # If using mixed precision, the gradients are already unscaled here
-        if self.track_grad_norm != -1:
-            self.log_dict(grad_norm(self, norm_type=self.track_grad_norm))
+        # TODO: apply gradient clip only to the target optimizer
+        if not self.automatic_optimization and self.hparams.gradient_clip_val > 0:
+            self.clip_gradients(
+                optimizer,
+                gradient_clip_val=self.hparams.gradient_clip_val,
+                gradient_clip_algorithm=self.hparams.gradient_clip_algorithm,
+            )
+        if self.hparams.track_grad_norm != -1:
+            self.log_dict(grad_norm(self, norm_type=self.hparams.track_grad_norm))
autogluon/multimodal/optim/lit_ner.py
@@ -37,7 +37,7 @@ class NerLitModule(LitModule):
         validation_metric_name: Optional[str] = None,
         custom_metric_func: Callable = None,
         test_metric: Optional[torchmetrics.Metric] = None,
-        efficient_finetune: Optional[str] = None,
+        peft: Optional[str] = None,
         trainable_param_names: Optional[List] = None,
         mixup_fn: Optional[MixupModule] = None,
         mixup_off_epoch: Optional[int] = 0,
@@ -97,7 +97,7 @@ class NerLitModule(LitModule):
             Refer to https://github.com/PyTorchLightning/metrics/blob/master/torchmetrics/aggregation.py
         test_metric
             A torchmetrics module used in the test stage, e.g., torchmetrics.Accuracy().
-        efficient_finetune
+        peft
            Whether to use efficient finetuning strategies. This will be helpful for fast finetuning of large backbones.
            We support options such as:

@@ -127,7 +127,7 @@ class NerLitModule(LitModule):
             validation_metric_name=validation_metric_name,
             custom_metric_func=custom_metric_func,
             test_metric=test_metric,
-            efficient_finetune=efficient_finetune,
+            peft=peft,
             trainable_param_names=trainable_param_names,
             mixup_fn=mixup_fn,
             mixup_off_epoch=mixup_off_epoch,
autogluon/multimodal/optim/lit_semantic_seg.py
@@ -8,7 +8,7 @@ from transformers.models.mask2former.modeling_mask2former import Mask2FormerLoss
 from ..constants import CLASS_LOGITS, LOGITS, MOE_LOSS, SEMANTIC_MASK, WEIGHT
 from ..models.utils import run_model
 from .lit_module import LitModule
-from .semantic_seg_metrics import Multiclass_IoU
+from .metrics.semantic_seg_metrics import Multiclass_IoU

 logger = logging.getLogger(__name__)

autogluon/multimodal/optim/losses/__init__.py (new file)
@@ -0,0 +1,14 @@
+from .softmax_losses import MultiNegativesSoftmaxLoss, SoftTargetCrossEntropy
+from .focal_loss import FocalLoss
+from .lemda_loss import LemdaLoss
+from .rkd_loss import RKDLoss
+from .bce_loss import BBCEWithLogitLoss
+from .structure_loss import StructureLoss
+from .utils import (
+    generate_metric_learning_labels,
+    get_aug_loss_func,
+    get_loss_func,
+    get_matcher_loss_func,
+    get_matcher_miner_func,
+    get_metric_learning_distance_func,
+)
autogluon/multimodal/optim/losses/bce_loss.py (new file)
@@ -0,0 +1,25 @@
+import torch
+import torch.nn as nn
+
+
+class BBCEWithLogitLoss(nn.Module):
+    """
+    Balanced BCEWithLogitLoss based on https://github.com/NiFangBaAGe/Explicit-Visual-Prompt/blob/latest_branch/models/segformer.py
+    """
+
+    def __init__(self):
+        super(BBCEWithLogitLoss, self).__init__()
+
+    def forward(self, input: torch.Tensor, target: torch.Tensor):
+        if input.dim() == 3:
+            input = input.unsqueeze(1)
+        eps = 1e-10
+        count_pos = torch.sum(target) + eps
+        count_neg = torch.sum(1.0 - target)
+        ratio = count_neg / count_pos
+        w_neg = count_pos / (count_pos + count_neg)
+
+        bce1 = nn.BCEWithLogitsLoss(pos_weight=ratio)
+        loss = w_neg * bce1(input, target)
+
+        return loss
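A short usage sketch of the balanced BCE loss above on a toy binary-mask batch (shapes are illustrative):

```python
import torch
from autogluon.multimodal.optim.losses import BBCEWithLogitLoss

criterion = BBCEWithLogitLoss()
logits = torch.randn(2, 1, 16, 16)                    # raw scores, (N, 1, H, W)
target = torch.randint(0, 2, (2, 1, 16, 16)).float()  # binary ground-truth mask
loss = criterion(input=logits, target=target)         # pos_weight is derived from the target itself
```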
autogluon/multimodal/optim/losses/focal_loss.py (new file)
@@ -0,0 +1,81 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class FocalLoss(nn.Module):
+    """
+    Focal loss based on https://github.com/AdeelH/pytorch-multi-class-focal-loss/blob/master/focal_loss.py
+
+    References:
+        [1] https://arxiv.org/abs/1708.02002
+    """
+
+    def __init__(
+        self,
+        alpha: Optional[torch.Tensor] = None,
+        gamma: Optional[float] = 2.0,
+        reduction: Optional[str] = "mean",
+        eps: Optional[float] = 1e-6,
+    ):
+        """
+
+        Parameters
+        ----------
+        alpha
+            weighting factor for each class. Should be of shape (num_classes)
+        gamma
+            the focal parameter for calculating weights on easy/hard samples
+        reduction
+            the reduction to apply to the final loss output. Default: "mean". Options:
+            "mean", "sum"
+        eps
+            epsilon for numerical stability
+        """
+        super(FocalLoss, self).__init__()
+
+        self.gamma = gamma
+        self.reduction = reduction
+        self.eps = eps
+        if alpha is not None:
+            if isinstance(alpha, str):  # handles Ray Tune HPO sampled hyperparameter
+                try:
+                    numbers = alpha.strip("()").split(",")
+                    alpha = [float(num) for num in numbers]
+                except:
+                    raise ValueError(f"{type(alpha)} {alpha} is not in a supported format.")
+            alpha = torch.tensor(alpha)
+        self.nll_loss = nn.NLLLoss(weight=alpha, reduction="none")
+
+    def forward(self, input: torch.Tensor, target: torch.Tensor):
+        if not torch.is_tensor(input):
+            raise TypeError("input type is not a torch.Tensor. Got {}".format(type(input)))
+        if input.ndim > 2:
+            # (N, C, d1, d2, ..., dK) --> (N * d1 * ... * dK, C)
+            num_class = input.shape[1]
+            input = input.permute(0, *range(2, input.ndim), 1).reshape(-1, num_class)
+            # (N, d1, d2, ..., dK) --> (N * d1 * ... * dK,)
+            target = target.view(-1)
+
+        pt = F.softmax(input, dim=-1)
+
+        # -alpha_t * log(pt) term
+        log_p = torch.log_softmax(input, dim=-1)
+        ce = self.nll_loss(log_p, target)
+
+        # (1 - pt)^gamma term
+        all_rows = torch.arange(input.shape[0])
+        pt = pt[all_rows, target]
+        focal_term = (1 - pt) ** self.gamma
+
+        loss = focal_term * ce
+
+        if self.reduction == "mean":
+            loss = loss.mean()
+
+        elif self.reduction == "sum":
+            loss = loss.sum()
+
+        return loss
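Usage sketch for the focal loss above: it takes raw logits and integer class targets, with an optional per-class `alpha` weighting (the values below are illustrative):

```python
import torch
from autogluon.multimodal.optim.losses import FocalLoss

criterion = FocalLoss(alpha=torch.tensor([0.6, 0.2, 0.2]), gamma=2.0, reduction="mean")
logits = torch.randn(8, 3)           # (N, C) raw scores for a 3-class problem
target = torch.randint(0, 3, (8,))   # (N,) class indices
loss = criterion(input=logits, target=target)
```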
autogluon/multimodal/optim/losses/lemda_loss.py (new file)
@@ -0,0 +1,39 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...constants import BINARY, REGRESSION
+
+
+class LemdaLoss(nn.Module):
+    def __init__(self, mse_weight, kld_weight, consist_weight, consist_threshold, problem_type):
+        super().__init__()
+        self.mse_loss = nn.MSELoss(reduction="mean")
+        self.mse_weight = mse_weight
+        self.kld_weight = kld_weight
+        self.consist_weight = consist_weight
+        self.consist_threshold = consist_threshold
+        self.problem_type = problem_type
+
+    def consist_loss(self, p_logits, q_logits):
+        p = F.softmax(p_logits, dim=1)
+        logp = F.log_softmax(p_logits, dim=1)
+        logq = F.log_softmax(q_logits, dim=1)
+        loss = torch.sum(p * (logp - logq), dim=-1)
+        q = F.softmax(q_logits, dim=1)
+        q_largest = torch.max(q, dim=1)[0]
+        loss_mask = torch.gt(q_largest, self.consist_threshold).float()
+        loss = loss * loss_mask
+        return torch.mean(loss)
+
+    def forward(self, pre_aug, post_aug, vae_mean, vae_var, ori_logits, aug_logits):
+        mse_loss = self.mse_loss(pre_aug, post_aug) * self.mse_weight
+        # see Appendix B from VAE paper: https://arxiv.org/abs/1312.6114
+        # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
+        kld_loss = -0.5 * torch.mean(1 + vae_var - vae_mean.pow(2) - vae_var.exp()) * self.kld_weight
+        if self.problem_type in [REGRESSION, BINARY]:
+            consist_loss = self.mse_loss(ori_logits, aug_logits) * self.consist_weight
+        else:
+            consist_loss = self.consist_loss(ori_logits, aug_logits) * self.consist_weight
+
+        return mse_loss + kld_loss + consist_loss
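A construction sketch for the LEMDA loss above; the weights and threshold are illustrative, and `problem_type` only needs to distinguish regression/binary (MSE consistency) from everything else (KL-style consistency):

```python
import torch
from autogluon.multimodal.optim.losses import LemdaLoss

criterion = LemdaLoss(
    mse_weight=1.0,
    kld_weight=0.1,
    consist_weight=0.5,
    consist_threshold=0.8,
    problem_type="multiclass",  # any value outside {REGRESSION, BINARY} takes the KL-consistency branch
)

batch, dim, num_classes = 4, 16, 3
pre_aug, post_aug = torch.randn(batch, dim), torch.randn(batch, dim)
vae_mean, vae_var = torch.randn(batch, dim), torch.randn(batch, dim)  # vae_var is treated as log-variance
ori_logits, aug_logits = torch.randn(batch, num_classes), torch.randn(batch, num_classes)
loss = criterion(pre_aug, post_aug, vae_mean, vae_var, ori_logits, aug_logits)
```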
autogluon/multimodal/optim/losses/rkd_loss.py (new file)
@@ -0,0 +1,103 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class RKDLoss(nn.Module):
+    """
+    Compute RKD Distance Loss.
+    Paper Refer to: Relational Knowledge Disitllation, CVPR2019. https://arxiv.org/abs/1904.05068
+    Code Refer to: https://github.com/HobbitLong/RepDistiller/blob/master/distiller_zoo/RKD.py
+    and https://github.com/lenscloth/RKD/blob/master/metric/loss.py
+    """
+
+    def __init__(self, distance_loss_weight: Optional[float] = 25.0, angle_loss_weight: Optional[float] = 50.0):
+        """
+        Parameters
+        ----------
+        distance_loss_weight
+            Weight of RKD distance loss
+        angle_loss_weight
+            Weight of RKD angle loss
+        Returns
+        -------
+        """
+        super(RKDLoss, self).__init__()
+        self.distance_loss_weight = distance_loss_weight
+        self.angle_loss_weight = angle_loss_weight
+
+    def forward(self, feature_student: Optional[torch.Tensor], feature_teacher: Optional[torch.Tensor]):
+        """
+        Parameters
+        ----------
+        feature_student
+            Output feature of student model, shape: (N, D)
+        feature_teacher
+            Output feature of teacher model, shape: (N, D)
+        Returns
+        -------
+        The RKD Loss between teacher and student
+        """
+        # RKD loss
+        if self.distance_loss_weight > 0:
+            with torch.no_grad():
+                t_dist = self.pdist(feature_teacher, squared=False)
+                mean_td = t_dist[t_dist > 0].mean()
+                t_dist = t_dist / mean_td
+
+            s_dist = self.pdist(feature_student, squared=False)
+            mean_d = s_dist[s_dist > 0].mean()
+            s_dist = s_dist / mean_d
+
+            loss_distance = F.smooth_l1_loss(s_dist, t_dist)
+
+        # RKD Angle loss
+        if self.angle_loss_weight > 0:
+            with torch.no_grad():
+                td = feature_teacher.unsqueeze(0) - feature_teacher.unsqueeze(1)
+                norm_td = F.normalize(td, p=2, dim=2)
+                t_angle = torch.bmm(norm_td, norm_td.transpose(1, 2)).view(-1)
+
+            sd = feature_student.unsqueeze(0) - feature_student.unsqueeze(1)
+            norm_sd = F.normalize(sd, p=2, dim=2)
+            s_angle = torch.bmm(norm_sd, norm_sd.transpose(1, 2)).view(-1)
+
+            loss_angle = F.smooth_l1_loss(s_angle, t_angle)
+
+        loss = ((self.distance_loss_weight * loss_distance) if self.distance_loss_weight > 0 else 0) + (
+            (self.angle_loss_weight * loss_angle) if self.angle_loss_weight > 0 else 0
+        )
+
+        return loss
+
+    @staticmethod
+    def pdist(embeddings: Optional[torch.Tensor], squared: Optional[bool] = False, eps: Optional[float] = 1e-12):
+        """
+        Compute pairwise Euclidean distances between embeddings in n-dimensional space.
+
+        Parameters
+        ----------
+        embeddings
+            The embeddings to compute pairwise distance between. Shape: (N,D)
+        squared
+            If the result is square of Euclidean distance.
+        eps
+            Min value of each entry.
+
+        Returns
+        -------
+        Pairwise Euclidean distances. Shape: (N,N)
+        """
+        e_square = embeddings.pow(2).sum(dim=1)
+        prod = embeddings @ embeddings.t()
+        res = (e_square.unsqueeze(1) + e_square.unsqueeze(0) - 2 * prod).clamp(min=eps)
+
+        if not squared:
+            res = res.sqrt()
+
+        res = res.clone()
+        res[range(len(embeddings)), range(len(embeddings))] = 0
+
+        return res
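Finally, a usage sketch of the relational KD loss: it pairs a student feature batch with a teacher feature batch of the same shape and combines the distance and angle terms with the given weights:

```python
import torch
from autogluon.multimodal.optim.losses import RKDLoss

criterion = RKDLoss(distance_loss_weight=25.0, angle_loss_weight=50.0)
student = torch.randn(16, 128)   # (N, D) student embeddings
teacher = torch.randn(16, 128)   # (N, D) teacher embeddings
loss = criterion(feature_student=student, feature_teacher=teacher)
```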