autogluon.multimodal 1.2.1b20250303__py3-none-any.whl → 1.2.1b20250304__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126). An import-path sketch for the renamed modules follows the list.
  1. autogluon/multimodal/__init__.py +4 -2
  2. autogluon/multimodal/configs/data/default.yaml +4 -2
  3. autogluon/multimodal/configs/{environment → env}/default.yaml +2 -3
  4. autogluon/multimodal/configs/model/default.yaml +58 -11
  5. autogluon/multimodal/configs/{optimization → optim}/default.yaml +21 -4
  6. autogluon/multimodal/constants.py +16 -5
  7. autogluon/multimodal/data/__init__.py +14 -2
  8. autogluon/multimodal/data/dataset.py +2 -2
  9. autogluon/multimodal/data/infer_types.py +16 -2
  10. autogluon/multimodal/data/label_encoder.py +3 -3
  11. autogluon/multimodal/{utils → data}/nlpaug.py +4 -4
  12. autogluon/multimodal/data/preprocess_dataframe.py +55 -38
  13. autogluon/multimodal/data/process_categorical.py +35 -6
  14. autogluon/multimodal/data/process_document.py +59 -33
  15. autogluon/multimodal/data/process_image.py +198 -163
  16. autogluon/multimodal/data/process_label.py +7 -3
  17. autogluon/multimodal/data/process_mmlab/process_mmdet.py +1 -8
  18. autogluon/multimodal/data/process_mmlab/process_mmlab_base.py +2 -9
  19. autogluon/multimodal/data/process_mmlab/process_mmocr.py +1 -9
  20. autogluon/multimodal/data/process_ner.py +192 -4
  21. autogluon/multimodal/data/process_numerical.py +32 -5
  22. autogluon/multimodal/data/process_semantic_seg_img.py +23 -28
  23. autogluon/multimodal/data/process_text.py +95 -58
  24. autogluon/multimodal/data/template_engine.py +7 -9
  25. autogluon/multimodal/data/templates.py +0 -2
  26. autogluon/multimodal/data/trivial_augmenter.py +2 -2
  27. autogluon/multimodal/data/utils.py +564 -338
  28. autogluon/multimodal/learners/__init__.py +2 -1
  29. autogluon/multimodal/learners/base.py +189 -189
  30. autogluon/multimodal/learners/ensemble.py +748 -0
  31. autogluon/multimodal/learners/few_shot_svm.py +6 -15
  32. autogluon/multimodal/learners/matching.py +59 -84
  33. autogluon/multimodal/learners/ner.py +23 -22
  34. autogluon/multimodal/learners/object_detection.py +26 -21
  35. autogluon/multimodal/learners/semantic_segmentation.py +16 -18
  36. autogluon/multimodal/models/__init__.py +12 -3
  37. autogluon/multimodal/models/augmenter.py +175 -0
  38. autogluon/multimodal/models/categorical_mlp.py +13 -8
  39. autogluon/multimodal/models/clip.py +92 -18
  40. autogluon/multimodal/models/custom_transformer.py +75 -75
  41. autogluon/multimodal/models/document_transformer.py +23 -9
  42. autogluon/multimodal/models/ft_transformer.py +40 -35
  43. autogluon/multimodal/models/fusion/base.py +2 -4
  44. autogluon/multimodal/models/fusion/fusion_mlp.py +82 -18
  45. autogluon/multimodal/models/fusion/fusion_ner.py +1 -1
  46. autogluon/multimodal/models/fusion/fusion_transformer.py +23 -23
  47. autogluon/multimodal/models/{huggingface_text.py → hf_text.py} +21 -2
  48. autogluon/multimodal/models/meta_transformer.py +336 -0
  49. autogluon/multimodal/models/mlp.py +6 -6
  50. autogluon/multimodal/models/mmocr_text_detection.py +1 -1
  51. autogluon/multimodal/models/mmocr_text_recognition.py +0 -1
  52. autogluon/multimodal/models/ner_text.py +1 -8
  53. autogluon/multimodal/models/numerical_mlp.py +14 -8
  54. autogluon/multimodal/models/sam.py +12 -2
  55. autogluon/multimodal/models/t_few.py +21 -5
  56. autogluon/multimodal/models/timm_image.py +74 -32
  57. autogluon/multimodal/models/utils.py +877 -16
  58. autogluon/multimodal/optim/__init__.py +17 -0
  59. autogluon/multimodal/{optimization → optim}/lit_distiller.py +2 -1
  60. autogluon/multimodal/{optimization → optim}/lit_matcher.py +4 -10
  61. autogluon/multimodal/{optimization → optim}/lit_mmdet.py +2 -10
  62. autogluon/multimodal/{optimization → optim}/lit_module.py +139 -14
  63. autogluon/multimodal/{optimization → optim}/lit_ner.py +3 -3
  64. autogluon/multimodal/{optimization → optim}/lit_semantic_seg.py +1 -1
  65. autogluon/multimodal/optim/losses/__init__.py +14 -0
  66. autogluon/multimodal/optim/losses/bce_loss.py +25 -0
  67. autogluon/multimodal/optim/losses/focal_loss.py +81 -0
  68. autogluon/multimodal/optim/losses/lemda_loss.py +39 -0
  69. autogluon/multimodal/optim/losses/rkd_loss.py +103 -0
  70. autogluon/multimodal/optim/losses/softmax_losses.py +177 -0
  71. autogluon/multimodal/optim/losses/structure_loss.py +26 -0
  72. autogluon/multimodal/optim/losses/utils.py +313 -0
  73. autogluon/multimodal/optim/lr/__init__.py +1 -0
  74. autogluon/multimodal/optim/lr/utils.py +332 -0
  75. autogluon/multimodal/optim/metrics/__init__.py +4 -0
  76. autogluon/multimodal/optim/metrics/coverage_metrics.py +42 -0
  77. autogluon/multimodal/optim/metrics/hit_rate_metrics.py +78 -0
  78. autogluon/multimodal/optim/metrics/ranking_metrics.py +231 -0
  79. autogluon/multimodal/optim/metrics/utils.py +359 -0
  80. autogluon/multimodal/optim/utils.py +284 -0
  81. autogluon/multimodal/predictor.py +51 -12
  82. autogluon/multimodal/utils/__init__.py +19 -45
  83. autogluon/multimodal/utils/cache.py +23 -2
  84. autogluon/multimodal/utils/checkpoint.py +58 -5
  85. autogluon/multimodal/utils/config.py +127 -55
  86. autogluon/multimodal/utils/device.py +120 -0
  87. autogluon/multimodal/utils/distillation.py +8 -8
  88. autogluon/multimodal/utils/download.py +1 -1
  89. autogluon/multimodal/utils/env.py +22 -0
  90. autogluon/multimodal/utils/export.py +3 -3
  91. autogluon/multimodal/utils/hpo.py +5 -5
  92. autogluon/multimodal/utils/inference.py +37 -4
  93. autogluon/multimodal/utils/install.py +91 -0
  94. autogluon/multimodal/utils/load.py +52 -47
  95. autogluon/multimodal/utils/log.py +6 -41
  96. autogluon/multimodal/utils/matcher.py +3 -2
  97. autogluon/multimodal/utils/onnx.py +0 -4
  98. autogluon/multimodal/utils/path.py +10 -0
  99. autogluon/multimodal/utils/precision.py +130 -0
  100. autogluon/multimodal/{presets.py → utils/presets.py} +259 -66
  101. autogluon/multimodal/{problem_types.py → utils/problem_types.py} +30 -1
  102. autogluon/multimodal/utils/save.py +47 -29
  103. autogluon/multimodal/utils/strategy.py +24 -0
  104. autogluon/multimodal/version.py +1 -1
  105. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/METADATA +5 -5
  106. autogluon.multimodal-1.2.1b20250304.dist-info/RECORD +163 -0
  107. autogluon/multimodal/optimization/__init__.py +0 -16
  108. autogluon/multimodal/optimization/losses.py +0 -394
  109. autogluon/multimodal/optimization/utils.py +0 -1054
  110. autogluon/multimodal/utils/cloud_io.py +0 -80
  111. autogluon/multimodal/utils/data.py +0 -701
  112. autogluon/multimodal/utils/environment.py +0 -395
  113. autogluon/multimodal/utils/metric.py +0 -500
  114. autogluon/multimodal/utils/model.py +0 -558
  115. autogluon.multimodal-1.2.1b20250303.dist-info/RECORD +0 -145
  116. /autogluon/multimodal/{optimization → optim}/deepspeed.py +0 -0
  117. /autogluon/multimodal/{optimization/lr_scheduler.py → optim/lr/lr_schedulers.py} +0 -0
  118. /autogluon/multimodal/{optimization → optim/metrics}/semantic_seg_metrics.py +0 -0
  119. /autogluon/multimodal/{registry.py → utils/registry.py} +0 -0
  120. /autogluon.multimodal-1.2.1b20250303-py3.9-nspkg.pth → /autogluon.multimodal-1.2.1b20250304-py3.9-nspkg.pth +0 -0
  121. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/LICENSE +0 -0
  122. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/NOTICE +0 -0
  123. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/WHEEL +0 -0
  124. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/namespace_packages.txt +0 -0
  125. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/top_level.txt +0 -0
  126. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/zip-safe +0 -0
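The most visible change in this build is a module reorganization: optimization/ becomes optim/ (gaining losses, lr, and metrics subpackages), configs/environment becomes configs/env, and presets.py, problem_types.py, and registry.py move under utils/. Below is a minimal sketch of how import paths shift, assuming importable modules simply follow the file moves listed above (this diff does not confirm any compatibility shims).

# Old layout (1.2.1b20250303); these module paths are removed by the renames above:
#   from autogluon.multimodal.optimization import lit_module
#   from autogluon.multimodal import presets
# New layout (1.2.1b20250304), inferred from the file list:
from autogluon.multimodal.optim import lit_module
from autogluon.multimodal.utils import presets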
autogluon/multimodal/optim/losses/softmax_losses.py (new file)
@@ -0,0 +1,177 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ try:
+     import torch.distributed.nn
+     from torch import distributed as dist
+
+     has_distributed = True
+ except ImportError:
+     has_distributed = False
+
+ try:
+     import horovod.torch as hvd
+ except ImportError:
+     hvd = None
+
+
+ class SoftTargetCrossEntropy(nn.Module):
+     """
+     The soft target CrossEntropy from timm.
+     https://github.com/rwightman/pytorch-image-models/blob/e4360e6125bb0bb4279785810c8eb33b40af3ebd/timm/loss/cross_entropy.py
+     It works with mixup.
+     It computes the cross entropy between the input and one-hot (soft) targets.
+     """
+
+     def __init__(self):
+         super(SoftTargetCrossEntropy, self).__init__()
+
+     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+         loss = torch.sum(-target * F.log_softmax(input, dim=-1), dim=-1)
+         return loss.mean()
+
+
+ class MultiNegativesSoftmaxLoss(nn.Module):
+     """
+     This loss expects as input a batch consisting of pairs (a_1, p_1), (a_2, p_2)…, (a_n, p_n) where
+     we assume that (a_i, p_i) are a positive pair and (a_i, p_j) for i!=j a negative pair.
+     For each a_i, it uses all other p_j as negative samples, i.e., for a_i,
+     we have 1 positive example (p_i) and n-1 negative examples (p_j).
+     It then minimizes the negative log-likelihood for softmax-normalized scores.
+     It can also gather negatives across processes.
+     """
+
+     def __init__(
+         self,
+         local_loss=False,
+         gather_with_grad=False,
+         cache_labels=False,
+         use_horovod=False,
+     ):
+         """
+         Parameters
+         ----------
+         local_loss
+             Whether to compute the loss only for the current process's samples.
+         gather_with_grad
+             Whether to gather all features with gradients enabled.
+         cache_labels
+             Whether to cache labels for the loss in later iterations.
+         use_horovod
+             Whether to use horovod.
+         """
+         super().__init__()
+         self.local_loss = local_loss
+         self.gather_with_grad = gather_with_grad
+         self.cache_labels = cache_labels
+         self.use_horovod = use_horovod
+
+         # cache state
+         self.prev_num_logits = 0
+         self.labels = {}
+
+     def forward(self, features_a, features_b, logit_scale, rank=0, world_size=1):
+         device = features_a.device
+         if world_size > 1:
+             all_features_a, all_features_b = self.gather_features(
+                 features_a, features_b, self.local_loss, self.gather_with_grad, rank, world_size, self.use_horovod
+             )
+
+             if self.local_loss:
+                 logits_per_a = logit_scale * features_a @ all_features_b.T
+                 logits_per_b = logit_scale * features_b @ all_features_a.T
+             else:
+                 logits_per_a = logit_scale * all_features_a @ all_features_b.T
+                 logits_per_b = logits_per_a.T
+         else:
+             logits_per_a = logit_scale * features_a @ features_b.T
+             logits_per_b = logit_scale * features_b @ features_a.T
+
+         # calculate ground-truth and cache if enabled
+         num_logits = logits_per_a.shape[0]
+         if self.prev_num_logits != num_logits or device not in self.labels:
+             labels = torch.arange(num_logits, device=device, dtype=torch.long)
+             if world_size > 1 and self.local_loss:
+                 labels = labels + num_logits * rank
+             if self.cache_labels:
+                 self.labels[device] = labels
+                 self.prev_num_logits = num_logits
+         else:
+             labels = self.labels[device]
+
+         total_loss = (F.cross_entropy(logits_per_a, labels) + F.cross_entropy(logits_per_b, labels)) / 2
+         return total_loss
+
+     @staticmethod
+     def gather_features(
+         image_features,
+         text_features,
+         local_loss=False,
+         gather_with_grad=False,
+         rank=0,
+         world_size=1,
+         use_horovod=False,
+     ):
+         """
+         Gather features across GPUs.
+
+         Parameters
+         ----------
+         image_features
+             image features of the current process.
+         text_features
+             text features of the current process.
+         local_loss
+             If False, make sure the features on the current GPU have gradients.
+         gather_with_grad
+             Whether to gather all features with gradients enabled.
+         rank
+             Rank of the current process (it should be a number between 0 and world_size-1).
+         world_size
+             Number of processes participating in the job.
+         use_horovod
+             Whether to use horovod.
+
+         Returns
+         -------
+         Gathered image and text features from all processes.
+         """
+         assert (
+             has_distributed
+         ), "torch.distributed did not import correctly, please use a PyTorch version with support."
+         if use_horovod:
+             assert hvd is not None, "Please install horovod"
+             if gather_with_grad:
+                 all_image_features = hvd.allgather(image_features)
+                 all_text_features = hvd.allgather(text_features)
+             else:
+                 with torch.no_grad():
+                     all_image_features = hvd.allgather(image_features)
+                     all_text_features = hvd.allgather(text_features)
+                 if not local_loss:
+                     # ensure grads for local rank when all_* features don't have a gradient
+                     gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
+                     gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
+                     gathered_image_features[rank] = image_features
+                     gathered_text_features[rank] = text_features
+                     all_image_features = torch.cat(gathered_image_features, dim=0)
+                     all_text_features = torch.cat(gathered_text_features, dim=0)
+         else:
+             # We gather tensors from all gpus
+             if gather_with_grad:
+                 all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
+                 all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
+             else:
+                 gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
+                 gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
+                 dist.all_gather(gathered_image_features, image_features)
+                 dist.all_gather(gathered_text_features, text_features)
+                 if not local_loss:
+                     # ensure grads for local rank when all_* features don't have a gradient
+                     gathered_image_features[rank] = image_features
+                     gathered_text_features[rank] = text_features
+                 all_image_features = torch.cat(gathered_image_features, dim=0)
+                 all_text_features = torch.cat(gathered_text_features, dim=0)
+
+         return all_image_features, all_text_features
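For orientation, here is a minimal single-process usage sketch of the two losses added above; the tensor shapes, the normalization, and the logit_scale value are illustrative assumptions, not values taken from the package.

import torch
import torch.nn.functional as F

from autogluon.multimodal.optim.losses.softmax_losses import (
    MultiNegativesSoftmaxLoss,
    SoftTargetCrossEntropy,
)

batch, num_classes, dim = 4, 10, 32

# SoftTargetCrossEntropy: logits vs. soft (e.g., mixup) targets.
logits = torch.randn(batch, num_classes)
soft_targets = F.softmax(torch.randn(batch, num_classes), dim=-1)
print(SoftTargetCrossEntropy()(logits, soft_targets))

# MultiNegativesSoftmaxLoss with world_size=1, so no cross-process gathering happens;
# the i-th row of a is treated as the positive match of the i-th row of p.
a = F.normalize(torch.randn(batch, dim), dim=-1)
p = F.normalize(torch.randn(batch, dim), dim=-1)
print(MultiNegativesSoftmaxLoss()(a, p, logit_scale=100.0))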
autogluon/multimodal/optim/losses/structure_loss.py (new file)
@@ -0,0 +1,26 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class StructureLoss(nn.Module):
+     """
+     Structure Loss based on https://github.com/DengPingFan/PraNet/blob/master/MyTrain.py
+     The loss combines a weighted IoU loss (global restriction) with a weighted binary cross entropy (BCE) loss (local, pixel-level restriction).
+
+     References:
+         [1] https://arxiv.org/abs/2006.11392
+     """
+
+     def forward(self, input: torch.Tensor, target: torch.Tensor):
+         if input.dim() == 3:
+             input = input.unsqueeze(1)
+         weit = 1 + 5 * torch.abs(F.avg_pool2d(target, kernel_size=31, stride=1, padding=15) - target)
+         wbce = F.binary_cross_entropy_with_logits(input, target, reduction="none")
+         wbce = (weit * wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3))
+
+         input = torch.sigmoid(input)
+         inter = ((input * target) * weit).sum(dim=(2, 3))
+         union = ((input + target) * weit).sum(dim=(2, 3))
+         wiou = 1 - (inter + 1) / (union - inter + 1)
+         return (wbce + wiou).mean()
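A minimal sketch of calling StructureLoss on dummy segmentation logits and a binary mask; the shapes are illustrative assumptions (the loss pools the target with a 31x31 window, so spatial sizes should be reasonably large).

import torch

from autogluon.multimodal.optim.losses.structure_loss import StructureLoss

# logits: (batch, 1, H, W) before sigmoid; target: binary mask of the same shape.
logits = torch.randn(2, 1, 64, 64)
target = (torch.rand(2, 1, 64, 64) > 0.5).float()

loss = StructureLoss()(logits, target)  # scalar combining weighted BCE and weighted IoU
print(loss)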
autogluon/multimodal/optim/losses/utils.py (new file)
@@ -0,0 +1,313 @@
+ import logging
+ from typing import Any, Dict, List, Optional, Tuple, Union
+
+ import torch
+ from omegaconf import DictConfig, OmegaConf
+ from pytorch_metric_learning import distances, losses, miners
+ from torch import nn
+ from transformers.models.mask2former.modeling_mask2former import Mask2FormerConfig, Mask2FormerLoss
+
+ from ...constants import (
+     BINARY,
+     CONTRASTIVE_LOSS,
+     COSINE_SIMILARITY,
+     FEW_SHOT_CLASSIFICATION,
+     MULTI_NEGATIVES_SOFTMAX_LOSS,
+     MULTICLASS,
+     NER,
+     OBJECT_DETECTION,
+     PAIR_MARGIN_MINER,
+     REGRESSION,
+     SEMANTIC_SEGMENTATION,
+ )
+ from .bce_loss import BBCEWithLogitLoss
+ from .focal_loss import FocalLoss
+ from .lemda_loss import LemdaLoss
+ from .softmax_losses import MultiNegativesSoftmaxLoss, SoftTargetCrossEntropy
+ from .structure_loss import StructureLoss
+
+ logger = logging.getLogger(__name__)
+
+
+ def get_loss_func(
+     problem_type: str,
+     mixup_active: Optional[bool] = None,
+     loss_func_name: Optional[str] = None,
+     config: Optional[DictConfig] = None,
+     **kwargs,
+ ):
+     """
+     Choose a suitable Pytorch loss module based on the provided problem type.
+
+     Parameters
+     ----------
+     problem_type
+         Type of problem.
+     mixup_active
+         Whether mixup is active.
+     loss_func_name
+         The name of the loss function the user wants to use.
+     config
+         The optimization config containing values such as optim.loss_func.
+         It also passes through the parameters of specific losses, e.g., for focal loss:
+             alpha = optim.focal_loss.alpha
+     Returns
+     -------
+     A Pytorch loss module.
+     """
+     if problem_type in [BINARY, MULTICLASS]:
+         if mixup_active:
+             loss_func = SoftTargetCrossEntropy()
+         else:
+             if loss_func_name is not None and loss_func_name.lower() == "focal_loss":
+                 loss_func = FocalLoss(
+                     alpha=config.focal_loss.alpha,
+                     gamma=config.focal_loss.gamma,
+                     reduction=config.focal_loss.reduction,
+                 )
+             else:
+                 loss_func = nn.CrossEntropyLoss(label_smoothing=config.label_smoothing)
+                 logger.debug(f"loss_func.label_smoothing: {loss_func.label_smoothing}")
+     elif problem_type == REGRESSION:
+         if loss_func_name is not None:
+             if "bcewithlogitsloss" in loss_func_name.lower():
+                 loss_func = nn.BCEWithLogitsLoss()
+             else:
+                 loss_func = nn.MSELoss()
+         else:
+             loss_func = nn.MSELoss()
+     elif problem_type == NER:
+         loss_func = nn.CrossEntropyLoss(ignore_index=0)
+     elif problem_type in [None, OBJECT_DETECTION, FEW_SHOT_CLASSIFICATION]:
+         return None
+     elif problem_type == SEMANTIC_SEGMENTATION:
+         if "structure_loss" in loss_func_name.lower():
+             loss_func = StructureLoss()
+         elif "balanced_bce" in loss_func_name.lower():
+             loss_func = BBCEWithLogitLoss()
+         elif "mask2former_loss" in loss_func_name.lower():
+             weight_dict = {
+                 "loss_cross_entropy": config.mask2former_loss.loss_cross_entropy_weight,
+                 "loss_mask": config.mask2former_loss.loss_mask_weight,
+                 "loss_dice": config.mask2former_loss.loss_dice_weight,
+             }
+             loss_func = Mask2FormerLoss(
+                 config=Mask2FormerConfig(num_labels=kwargs["num_classes"]), weight_dict=weight_dict
+             )
+         else:
+             loss_func = nn.BCEWithLogitsLoss()
+     else:
+         raise NotImplementedError
+
+     return loss_func
+
+
+ def get_metric_learning_distance_func(
+     name: str,
+ ):
+     """
+     Return a pytorch metric learning distance function based on its name.
+
+     Parameters
+     ----------
+     name
+         The distance function name.
+
+     Returns
+     -------
+     A distance function from the pytorch metric learning package.
+     """
+     if name.lower() == COSINE_SIMILARITY:
+         return distances.CosineSimilarity()
+     else:
+         raise ValueError(f"Unknown distance measure: {name}")
+
+
+ def infer_matcher_loss(data_format: str, problem_type: str):
+     """
+     Infer the loss type to train the matcher.
+
+     Parameters
+     ----------
+     data_format
+         The training data format, e.g., pair or triplet.
+     problem_type
+         Type of problem.
+
+     Returns
+     -------
+     The loss name.
+     """
+     if data_format == "pair":
+         if problem_type is None:
+             return [MULTI_NEGATIVES_SOFTMAX_LOSS]
+         elif problem_type == BINARY:
+             return [CONTRASTIVE_LOSS]
+         elif problem_type == REGRESSION:
+             return ["cosine_similarity_loss"]
+         else:
+             raise ValueError(f"Unsupported data format {data_format} with problem type {problem_type}")
+     elif data_format == "triplet":
+         if problem_type is None:
+             return [MULTI_NEGATIVES_SOFTMAX_LOSS]
+         else:
+             raise ValueError(f"Unsupported data format {data_format} with problem type {problem_type}")
+     else:
+         raise ValueError(f"Unsupported data format: {data_format}")
+
+
+ def get_matcher_loss_func(
+     data_format: str,
+     problem_type: str,
+     loss_type: Optional[str] = None,
+     pos_margin: Optional[float] = None,
+     neg_margin: Optional[float] = None,
+     distance_type: Optional[str] = None,
+ ):
+     """
+     Return a matcher loss function based on the provided or inferred loss type.
+
+     Parameters
+     ----------
+     data_format
+         The training data format, e.g., pair or triplet.
+     problem_type
+         Type of problem.
+     loss_type
+         The provided loss type.
+     pos_margin
+         The positive margin in computing the metric learning loss.
+     neg_margin
+         The negative margin in computing the metric learning loss.
+     distance_type
+         The distance function type.
+
+     Returns
+     -------
+     A loss function of metric learning.
+     """
+
+     allowable_loss_types = infer_matcher_loss(data_format=data_format, problem_type=problem_type)
+     if loss_type is not None:
+         assert loss_type in allowable_loss_types, f"data format {data_format} can't use loss {loss_type}."
+     else:
+         loss_type = allowable_loss_types[0]
+
+     if loss_type.lower() == CONTRASTIVE_LOSS:
+         return losses.ContrastiveLoss(
+             pos_margin=pos_margin,
+             neg_margin=neg_margin,
+             distance=get_metric_learning_distance_func(distance_type),
+         )
+     elif loss_type.lower() == MULTI_NEGATIVES_SOFTMAX_LOSS:
+         return MultiNegativesSoftmaxLoss(
+             local_loss=True,
+             gather_with_grad=True,
+             cache_labels=False,
+         )
+     else:
+         raise ValueError(f"Unknown metric learning loss: {loss_type}")
+
+
+ def get_matcher_miner_func(
+     miner_type: str,
+     pos_margin: float,
+     neg_margin: float,
+     distance_type: str,
+ ):
+     """
+     Return a pytorch metric learning miner function based on its name.
+     The miner is used to mine positive and negative examples.
+
+     Parameters
+     ----------
+     miner_type
+         The miner function type.
+     pos_margin
+         The positive margin used by the miner function.
+     neg_margin
+         The negative margin used by the miner function.
+     distance_type
+         The distance function type.
+
+     Returns
+     -------
+     A miner function to mine positive and negative samples.
+     """
+     if miner_type.lower() == PAIR_MARGIN_MINER:
+         return miners.PairMarginMiner(
+             pos_margin=pos_margin,
+             neg_margin=neg_margin,
+             distance=get_metric_learning_distance_func(distance_type),
+         )
+     else:
+         raise ValueError(f"Unknown metric learning miner: {miner_type}")
+
+
+ def generate_metric_learning_labels(
+     num_samples: int,
+     match_label: int,
+     labels: torch.Tensor,
+ ):
+     """
+     Generate labels to compute the metric learning loss of one mini-batch.
+     For n samples, it generates 2*n labels since each match has two sides, each of which
+     has one label. If the matching label is known, the two sides of a pair share one label
+     only when the pair's raw label equals the matching label. If the matching label is None,
+     each pair is assigned a unique label shared by both of its sides.
+
+     Parameters
+     ----------
+     num_samples
+         number of samples.
+     match_label
+         The matching label, which can be None.
+     labels
+         The sample labels used in the supervised setting. It's required only when match_label is not None.
+
+     Returns
+     -------
+     The labels used in computing the metric learning loss.
+     """
+     device = labels.device
+     labels_1 = torch.arange(num_samples, device=device)
+
+     if match_label is not None:
+         labels_2 = torch.arange(num_samples, num_samples * 2, device=device)
+         # users need to specify the match_label based on the raw label's semantic meaning.
+         mask = labels == match_label
+         labels_2[mask] = labels_1[mask]
+     else:
+         labels_2 = torch.arange(num_samples, device=device)
+
+     metric_learning_labels = torch.cat([labels_1, labels_2], dim=0)
+
+     return metric_learning_labels
+
+
+ def get_aug_loss_func(config: Optional[DictConfig] = None, problem_type: Optional[str] = None):
+     """
+     Return the loss function for lemda augmentation.
+
+     Parameters
+     ----------
+     config
+         The optimization configuration.
+     problem_type
+         Problem type (binary, multiclass, or regression).
+
+     Returns
+     -------
+     Augmentation loss function.
+     """
+     loss_func = None
+     if config.lemda.turn_on:
+         loss_func = LemdaLoss(
+             mse_weight=config.lemda.mse_weight,
+             kld_weight=config.lemda.kld_weight,
+             consist_weight=config.lemda.consist_weight,
+             consist_threshold=config.lemda.consist_threshold,
+             problem_type=problem_type,
+         )
+
+     return loss_func
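To make the dispatch above concrete, here is a minimal sketch of the two helpers most likely to be called directly. The "multiclass" string is assumed to equal the MULTICLASS constant, and the config carries only the label_smoothing field read in get_loss_func; both are assumptions for illustration.

import torch
from omegaconf import OmegaConf

from autogluon.multimodal.optim.losses.utils import generate_metric_learning_labels, get_loss_func

# A plain multiclass setup (no mixup, no focal loss) falls through to CrossEntropyLoss
# with the configured label smoothing.
cfg = OmegaConf.create({"label_smoothing": 0.1})
loss_func = get_loss_func(problem_type="multiclass", mixup_active=False, loss_func_name=None, config=cfg)
print(type(loss_func).__name__)  # CrossEntropyLoss

# Matched pairs share a label; unmatched sides get distinct ones.
labels = torch.tensor([1, 0, 1])
print(generate_metric_learning_labels(num_samples=3, match_label=1, labels=labels))
# tensor([0, 1, 2, 0, 4, 2])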
autogluon/multimodal/optim/lr/__init__.py (new file)
@@ -0,0 +1 @@
+ from .utils import apply_layerwise_lr_decay, apply_single_lr, apply_two_stages_lr, get_lr_scheduler
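The new optim.lr subpackage simply re-exports the learning-rate helpers from optim/lr/utils.py, whose body is not shown in this diff, so the sketch below only demonstrates the new import paths; call signatures are not asserted here.

from autogluon.multimodal.optim.lr import (
    apply_layerwise_lr_decay,
    apply_single_lr,
    apply_two_stages_lr,
    get_lr_scheduler,
)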