autogluon.multimodal 1.2.1b20250303-py3-none-any.whl → 1.2.1b20250305-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126)
  1. autogluon/multimodal/__init__.py +4 -2
  2. autogluon/multimodal/configs/data/default.yaml +4 -2
  3. autogluon/multimodal/configs/{environment → env}/default.yaml +2 -3
  4. autogluon/multimodal/configs/model/default.yaml +58 -11
  5. autogluon/multimodal/configs/{optimization → optim}/default.yaml +21 -4
  6. autogluon/multimodal/constants.py +16 -5
  7. autogluon/multimodal/data/__init__.py +14 -2
  8. autogluon/multimodal/data/dataset.py +2 -2
  9. autogluon/multimodal/data/infer_types.py +16 -2
  10. autogluon/multimodal/data/label_encoder.py +3 -3
  11. autogluon/multimodal/{utils → data}/nlpaug.py +4 -4
  12. autogluon/multimodal/data/preprocess_dataframe.py +55 -38
  13. autogluon/multimodal/data/process_categorical.py +35 -6
  14. autogluon/multimodal/data/process_document.py +59 -33
  15. autogluon/multimodal/data/process_image.py +198 -163
  16. autogluon/multimodal/data/process_label.py +7 -3
  17. autogluon/multimodal/data/process_mmlab/process_mmdet.py +1 -8
  18. autogluon/multimodal/data/process_mmlab/process_mmlab_base.py +2 -9
  19. autogluon/multimodal/data/process_mmlab/process_mmocr.py +1 -9
  20. autogluon/multimodal/data/process_ner.py +192 -4
  21. autogluon/multimodal/data/process_numerical.py +32 -5
  22. autogluon/multimodal/data/process_semantic_seg_img.py +23 -28
  23. autogluon/multimodal/data/process_text.py +95 -58
  24. autogluon/multimodal/data/template_engine.py +7 -9
  25. autogluon/multimodal/data/templates.py +0 -2
  26. autogluon/multimodal/data/trivial_augmenter.py +2 -2
  27. autogluon/multimodal/data/utils.py +564 -338
  28. autogluon/multimodal/learners/__init__.py +2 -1
  29. autogluon/multimodal/learners/base.py +189 -189
  30. autogluon/multimodal/learners/ensemble.py +748 -0
  31. autogluon/multimodal/learners/few_shot_svm.py +6 -15
  32. autogluon/multimodal/learners/matching.py +59 -84
  33. autogluon/multimodal/learners/ner.py +23 -22
  34. autogluon/multimodal/learners/object_detection.py +26 -21
  35. autogluon/multimodal/learners/semantic_segmentation.py +16 -18
  36. autogluon/multimodal/models/__init__.py +12 -3
  37. autogluon/multimodal/models/augmenter.py +175 -0
  38. autogluon/multimodal/models/categorical_mlp.py +13 -8
  39. autogluon/multimodal/models/clip.py +92 -18
  40. autogluon/multimodal/models/custom_transformer.py +75 -75
  41. autogluon/multimodal/models/document_transformer.py +23 -9
  42. autogluon/multimodal/models/ft_transformer.py +40 -35
  43. autogluon/multimodal/models/fusion/base.py +2 -4
  44. autogluon/multimodal/models/fusion/fusion_mlp.py +82 -18
  45. autogluon/multimodal/models/fusion/fusion_ner.py +1 -1
  46. autogluon/multimodal/models/fusion/fusion_transformer.py +23 -23
  47. autogluon/multimodal/models/{huggingface_text.py → hf_text.py} +21 -2
  48. autogluon/multimodal/models/meta_transformer.py +336 -0
  49. autogluon/multimodal/models/mlp.py +6 -6
  50. autogluon/multimodal/models/mmocr_text_detection.py +1 -1
  51. autogluon/multimodal/models/mmocr_text_recognition.py +0 -1
  52. autogluon/multimodal/models/ner_text.py +1 -8
  53. autogluon/multimodal/models/numerical_mlp.py +14 -8
  54. autogluon/multimodal/models/sam.py +12 -2
  55. autogluon/multimodal/models/t_few.py +21 -5
  56. autogluon/multimodal/models/timm_image.py +74 -32
  57. autogluon/multimodal/models/utils.py +877 -16
  58. autogluon/multimodal/optim/__init__.py +17 -0
  59. autogluon/multimodal/{optimization → optim}/lit_distiller.py +2 -1
  60. autogluon/multimodal/{optimization → optim}/lit_matcher.py +4 -10
  61. autogluon/multimodal/{optimization → optim}/lit_mmdet.py +2 -10
  62. autogluon/multimodal/{optimization → optim}/lit_module.py +139 -14
  63. autogluon/multimodal/{optimization → optim}/lit_ner.py +3 -3
  64. autogluon/multimodal/{optimization → optim}/lit_semantic_seg.py +1 -1
  65. autogluon/multimodal/optim/losses/__init__.py +14 -0
  66. autogluon/multimodal/optim/losses/bce_loss.py +25 -0
  67. autogluon/multimodal/optim/losses/focal_loss.py +81 -0
  68. autogluon/multimodal/optim/losses/lemda_loss.py +39 -0
  69. autogluon/multimodal/optim/losses/rkd_loss.py +103 -0
  70. autogluon/multimodal/optim/losses/softmax_losses.py +177 -0
  71. autogluon/multimodal/optim/losses/structure_loss.py +26 -0
  72. autogluon/multimodal/optim/losses/utils.py +313 -0
  73. autogluon/multimodal/optim/lr/__init__.py +1 -0
  74. autogluon/multimodal/optim/lr/utils.py +332 -0
  75. autogluon/multimodal/optim/metrics/__init__.py +4 -0
  76. autogluon/multimodal/optim/metrics/coverage_metrics.py +42 -0
  77. autogluon/multimodal/optim/metrics/hit_rate_metrics.py +78 -0
  78. autogluon/multimodal/optim/metrics/ranking_metrics.py +231 -0
  79. autogluon/multimodal/optim/metrics/utils.py +359 -0
  80. autogluon/multimodal/optim/utils.py +284 -0
  81. autogluon/multimodal/predictor.py +51 -12
  82. autogluon/multimodal/utils/__init__.py +19 -45
  83. autogluon/multimodal/utils/cache.py +23 -2
  84. autogluon/multimodal/utils/checkpoint.py +58 -5
  85. autogluon/multimodal/utils/config.py +127 -55
  86. autogluon/multimodal/utils/device.py +120 -0
  87. autogluon/multimodal/utils/distillation.py +8 -8
  88. autogluon/multimodal/utils/download.py +1 -1
  89. autogluon/multimodal/utils/env.py +22 -0
  90. autogluon/multimodal/utils/export.py +3 -3
  91. autogluon/multimodal/utils/hpo.py +5 -5
  92. autogluon/multimodal/utils/inference.py +37 -4
  93. autogluon/multimodal/utils/install.py +91 -0
  94. autogluon/multimodal/utils/load.py +52 -47
  95. autogluon/multimodal/utils/log.py +6 -41
  96. autogluon/multimodal/utils/matcher.py +3 -2
  97. autogluon/multimodal/utils/onnx.py +0 -4
  98. autogluon/multimodal/utils/path.py +10 -0
  99. autogluon/multimodal/utils/precision.py +130 -0
  100. autogluon/multimodal/{presets.py → utils/presets.py} +259 -66
  101. autogluon/multimodal/{problem_types.py → utils/problem_types.py} +30 -1
  102. autogluon/multimodal/utils/save.py +47 -29
  103. autogluon/multimodal/utils/strategy.py +24 -0
  104. autogluon/multimodal/version.py +1 -1
  105. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/METADATA +5 -5
  106. autogluon.multimodal-1.2.1b20250305.dist-info/RECORD +163 -0
  107. autogluon/multimodal/optimization/__init__.py +0 -16
  108. autogluon/multimodal/optimization/losses.py +0 -394
  109. autogluon/multimodal/optimization/utils.py +0 -1054
  110. autogluon/multimodal/utils/cloud_io.py +0 -80
  111. autogluon/multimodal/utils/data.py +0 -701
  112. autogluon/multimodal/utils/environment.py +0 -395
  113. autogluon/multimodal/utils/metric.py +0 -500
  114. autogluon/multimodal/utils/model.py +0 -558
  115. autogluon.multimodal-1.2.1b20250303.dist-info/RECORD +0 -145
  116. /autogluon/multimodal/{optimization → optim}/deepspeed.py +0 -0
  117. /autogluon/multimodal/{optimization/lr_scheduler.py → optim/lr/lr_schedulers.py} +0 -0
  118. /autogluon/multimodal/{optimization → optim/metrics}/semantic_seg_metrics.py +0 -0
  119. /autogluon/multimodal/{registry.py → utils/registry.py} +0 -0
  120. /autogluon.multimodal-1.2.1b20250303-py3.9-nspkg.pth → /autogluon.multimodal-1.2.1b20250305-py3.9-nspkg.pth +0 -0
  121. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/LICENSE +0 -0
  122. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/NOTICE +0 -0
  123. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/WHEEL +0 -0
  124. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/namespace_packages.txt +0 -0
  125. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/top_level.txt +0 -0
  126. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/zip-safe +0 -0
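The dominant change in this listing is a package reorganization: `optimization` becomes `optim` (with losses, metrics, and LR utilities split into subpackages), `configs/environment` becomes `configs/env`, and top-level helpers such as `presets.py` and `problem_types.py` move under `utils`. A minimal version-tolerant import sketch, assuming only the module paths shown above (symbol names inside the new subpackages are not verified here):

import importlib


def import_optim_package():
    # Prefer the new path used by 1.2.1b20250305; fall back to the
    # pre-rename path shipped in wheels up to 1.2.1b20250303.
    for path in ("autogluon.multimodal.optim", "autogluon.multimodal.optimization"):
        try:
            return importlib.import_module(path)
        except ImportError:
            continue
    raise ImportError("autogluon.multimodal optimizer utilities not found")


if __name__ == "__main__":
    try:
        print(import_optim_package().__name__)
    except ImportError as err:
        print(err)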
@@ -1,1054 +0,0 @@
- import functools
- import logging
- import re
- import warnings
- from typing import Any, Dict, List, Optional, Tuple, Union
-
- import torch
- import torchmetrics
- from omegaconf import DictConfig, OmegaConf
- from packaging import version
- from pytorch_metric_learning import distances, losses, miners
- from torch import nn, optim
- from torch.nn import functional as F
- from torchmetrics.detection.mean_ap import MeanAveragePrecision
- from transformers import Adafactor
- from transformers.models.mask2former.modeling_mask2former import Mask2FormerConfig, Mask2FormerLoss
- from transformers.trainer_pt_utils import get_parameter_names
-
- from ..constants import (
-     ACC,
-     ACCURACY,
-     AVERAGE_PRECISION,
-     BER,
-     BINARY,
-     BIT_FIT,
-     COLUMN_FEATURES,
-     CONTRASTIVE_LOSS,
-     CONV_LORA,
-     COSINE_EMBEDDING_LOSS,
-     COSINE_SIMILARITY,
-     CROSS_ENTROPY,
-     DETECTION_METRICS,
-     DIRECT_LOSS,
-     EM,
-     F1,
-     F1_MACRO,
-     F1_MICRO,
-     F1_WEIGHTED,
-     FEATURES,
-     FEW_SHOT_CLASSIFICATION,
-     FM,
-     HIT_RATE,
-     IA3,
-     IA3_BIAS,
-     IA3_LORA,
-     IA3_LORA_BIAS,
-     IA3_LORA_NORM,
-     IA3_NORM,
-     IOU,
-     LOG_LOSS,
-     LORA,
-     LORA_BIAS,
-     LORA_NORM,
-     MAE,
-     MULTI_NEGATIVES_SOFTMAX_LOSS,
-     MULTICLASS,
-     NER,
-     NER_TOKEN_F1,
-     NORM_FIT,
-     OBJECT_DETECTION,
-     OVERALL_ACCURACY,
-     PAIR_MARGIN_MINER,
-     PEARSONR,
-     PEFT_STRATEGIES,
-     QUADRATIC_KAPPA,
-     R2,
-     RECALL,
-     REGRESSION,
-     RMSE,
-     ROC_AUC,
-     ROOT_MEAN_SQUARED_ERROR,
-     SEMANTIC_SEGMENTATION,
-     SM,
-     SPEARMANR,
- )
- from .losses import BBCEWithLogitLoss, FocalLoss, MultiNegativesSoftmaxLoss, SoftTargetCrossEntropy, StructureLoss
- from .lr_scheduler import (
-     get_cosine_schedule_with_warmup,
-     get_linear_schedule_with_warmup,
-     get_polynomial_decay_schedule_with_warmup,
- )
- from .semantic_seg_metrics import COD_METRICS_NAMES, Balanced_Error_Rate, Binary_IoU, Multiclass_IoU
-
- logger = logging.getLogger(__name__)
-
-
- def get_loss_func(
-     problem_type: str,
-     mixup_active: Optional[bool] = None,
-     loss_func_name: Optional[str] = None,
-     config: Optional[DictConfig] = None,
-     **kwargs,
- ):
-     """
-     Choose a suitable Pytorch loss module based on the provided problem type.
-
-     Parameters
-     ----------
-     problem_type
-         Type of problem.
-     mixup_active
-         The activation determining whether to use mixup.
-     loss_func_name
-         The name of the function the user wants to use.
-     config
-         The optimization configs containing values such as i.e. optimization.loss_function
-         An example purpose of this config here is to pass through the parameters for focal loss, i.e.:
-         alpha = optimization.focal_loss.alpha
-     Returns
-     -------
-     A Pytorch loss module.
-     """
-     if problem_type in [BINARY, MULTICLASS]:
-         if mixup_active:
-             loss_func = SoftTargetCrossEntropy()
-         else:
-             if loss_func_name is not None and loss_func_name.lower() == "focal_loss":
-                 loss_func = FocalLoss(
-                     alpha=OmegaConf.select(config, "focal_loss.alpha"),
-                     gamma=OmegaConf.select(config, "focal_loss.gamma", default=2.0),
-                     reduction=OmegaConf.select(config, "focal_loss.reduction", default="mean"),
-                 )
-             else:
-                 loss_func = nn.CrossEntropyLoss()
-     elif problem_type == REGRESSION:
-         if loss_func_name is not None:
-             if "bcewithlogitsloss" in loss_func_name.lower():
-                 loss_func = nn.BCEWithLogitsLoss()
-             else:
-                 loss_func = nn.MSELoss()
-         else:
-             loss_func = nn.MSELoss()
-     elif problem_type == NER:
-         loss_func = nn.CrossEntropyLoss(ignore_index=0)
-     elif problem_type in [OBJECT_DETECTION, FEW_SHOT_CLASSIFICATION]:
-         return None
-     elif problem_type == SEMANTIC_SEGMENTATION:
-         if "structure_loss" in loss_func_name.lower():
-             loss_func = StructureLoss()
-         elif "balanced_bce" in loss_func_name.lower():
-             loss_func = BBCEWithLogitLoss()
-         elif "mask2former_loss" in loss_func_name.lower():
-             weight_dict = {
-                 "loss_cross_entropy": config.mask2former_loss.loss_cross_entropy_weight,
-                 "loss_mask": config.mask2former_loss.loss_mask_weight,
-                 "loss_dice": config.mask2former_loss.loss_dice_weight,
-             }
-             loss_func = Mask2FormerLoss(
-                 config=Mask2FormerConfig(num_labels=kwargs["num_classes"]), weight_dict=weight_dict
-             )
-         else:
-             loss_func = nn.BCEWithLogitsLoss()
-     elif problem_type is None:
-         return None
-     else:
-         raise NotImplementedError
-
-     return loss_func
-
-
- class CustomHitRate(torchmetrics.Metric):
-     """
-     Compute the hit rate when doing semantic search between two group of embeddings.
-     We assume that (a_i, p_i) are a positive pair and (a_i, p_j) for i!=j a negative pair.
-     """
-
-     def __init__(
-         self,
-     ):
-         super().__init__()
-         self.add_state("query_embeddings", default=[], dist_reduce_fx=None)
-         self.add_state("response_embeddings", default=[], dist_reduce_fx=None)
-         self.add_state("logit_scale", default=[], dist_reduce_fx=None)
-
-     def update(
-         self,
-         batch_query_embeds: torch.Tensor,
-         batch_response_embeds: torch.Tensor,
-         logit_scale: Optional[torch.Tensor] = None,
-     ):
-         self.query_embeddings.append(batch_query_embeds)
-         self.response_embeddings.append(batch_response_embeds)
-         if logit_scale is not None:
-             self.logit_scale.append(logit_scale)
-
-     def compute(self):
-         query_embeddings = torch.cat(self.query_embeddings)
-         response_embeddings = torch.cat(self.response_embeddings)
-         if self.logit_scale:
-             logit_scale = torch.mean(torch.stack(self.logit_scale))
-         else:
-             logit_scale = 1
-
-         return compute_hit_rate(query_embeddings, response_embeddings, logit_scale)
-
-
- def compute_hit_rate(features_a, features_b, logit_scale, top_ks=[1, 5, 10]):
-     """
-     Compute symmetric hit rates between two groups of features.
-
-     Parameters
-     ----------
-     features_a
-         One group of features.
-     features_b
-         The other group of features.
-     logit_scale
-         The scale of logit (Used in CLIP).
-     top_ks
-         Consider only the top k elements for each query.
-
-     Returns
-     -------
-     The accumulated hit rate.
-     """
-     assert len(features_a) == len(features_b)
-     hit_rate = 0
-     logits_per_a = (logit_scale * features_a @ features_b.t()).detach().cpu()
-     logits_per_b = logits_per_a.t().detach().cpu()
-
-     logits = {"logits_per_a": logits_per_a, "logits_per_b": logits_per_b}
-     ground_truth = torch.arange(len(features_b)).view(-1, 1)
-
-     for name, logit in logits.items():
-         ranking = torch.argsort(logit, descending=True)
-         preds = torch.where(ranking == ground_truth)[1]
-
-         for k in top_ks:
-             hit_rate += (preds < k).float().mean()
-
-     hit_rate /= len(top_ks) * len(logits)
-     return hit_rate
-
-
- def get_metric(
-     metric_name: str,
-     num_classes: Optional[int] = None,
-     is_matching: Optional[bool] = False,
-     problem_type: Optional[str] = None,
- ):
-     """
-     Obtain a torchmerics.Metric from its name.
-     Define a customized metric function in case that torchmetrics doesn't support some metric.
-
-     Parameters
-     ----------
-     metric_name
-         Name of metric.
-     num_classes
-         Number of classes.
-     is_matching
-         Whether is matching.
-     problem_type
-         Type of problem, e.g., binary and multiclass.
-
-     Returns
-     -------
-     torchmetrics.Metric
-         A torchmetrics.Metric object.
-     custom_metric_func
-         A customized metric function.
-     """
-     metric_name = metric_name.lower()
-     if metric_name in [ACC, ACCURACY, OVERALL_ACCURACY]:
-         # use MULTICLASS since the head output dim is 2 for the binary problem type.
-         return torchmetrics.Accuracy(task=MULTICLASS, num_classes=num_classes), None
-     elif metric_name == NER_TOKEN_F1:
-         return torchmetrics.F1Score(task=MULTICLASS, num_classes=num_classes, ignore_index=1), None
-     elif metric_name in [RMSE, ROOT_MEAN_SQUARED_ERROR]:
-         return torchmetrics.MeanSquaredError(squared=False), None
-     elif metric_name == R2:
-         return torchmetrics.R2Score(), None
-     elif metric_name == QUADRATIC_KAPPA:
-         return (
-             torchmetrics.CohenKappa(task=problem_type, num_classes=num_classes, weights="quadratic"),
-             None,
-         )
-     elif metric_name == ROC_AUC:
-         return torchmetrics.AUROC(task=problem_type, num_classes=num_classes), None
-     elif metric_name == AVERAGE_PRECISION:
-         return torchmetrics.AveragePrecision(task=problem_type, num_classes=num_classes)
-     elif metric_name in [LOG_LOSS, CROSS_ENTROPY]:
-         return torchmetrics.MeanMetric(), functools.partial(F.cross_entropy, reduction="none")
-     elif metric_name == COSINE_EMBEDDING_LOSS:
-         return torchmetrics.MeanMetric(), functools.partial(F.cosine_embedding_loss, reduction="none")
-     elif metric_name == PEARSONR:
-         return torchmetrics.PearsonCorrCoef(), None
-     elif metric_name == SPEARMANR:
-         if is_matching:  # TODO: add support for matching.
-             raise ValueError("spearman relation is not supported for matching yet.")
-         else:
-             return torchmetrics.SpearmanCorrCoef(), None
-     elif metric_name == F1:
-         return torchmetrics.F1Score(task=problem_type, num_classes=num_classes), None
-     elif metric_name in [F1_MACRO, F1_MICRO, F1_WEIGHTED]:
-         average = metric_name.split("_")[1]
-         return torchmetrics.F1Score(task=problem_type, num_classes=num_classes, average=average), None
-     elif metric_name in DETECTION_METRICS:
-         return (
-             MeanAveragePrecision(box_format="xyxy", iou_type="bbox", class_metrics=False),
-             None,
-         )  # TODO: remove parameter hardcodings here, and add class_metrics
-     elif metric_name == DIRECT_LOSS:
-         return (
-             torchmetrics.MeanMetric(nan_strategy="warn"),
-             None,
-         )  # This only works for detection where custom_metric is not required for BaseAggregator
-     elif metric_name in [RECALL, HIT_RATE]:
-         if is_matching:
-             return CustomHitRate(), None
-         else:  # TODO: support recall for general classification tasks.
-             raise ValueError("Recall is not supported yet.")
-     elif metric_name == BER:
-         return Balanced_Error_Rate(), None
-     elif metric_name in [SM, EM, FM, MAE]:
-         return COD_METRICS_NAMES[metric_name], None
-     elif metric_name == IOU:
-         if num_classes == 1:
-             return Binary_IoU(), None
-         else:
-             return Multiclass_IoU(num_classes=num_classes), None
-     else:
-         raise ValueError(f"Unknown metric {metric_name}")
-
-
- def get_optimizer(
-     optim_type: str,
-     optimizer_grouped_parameters,
-     lr: float,
-     weight_decay: float,
-     eps: Optional[float] = 1e-6,
-     betas: Optional[Tuple[float, float]] = (0.9, 0.999),
-     momentum: Optional[float] = 0.9,
- ):
-     """
-     Choose a Pytorch optimizer based on its name.
-
-     Parameters
-     ----------
-     optim_type
-         Name of optimizer.
-     optimizer_grouped_parameters
-         The model parameters to be optimized.
-     lr
-         Learning rate.
-     weight_decay
-         Optimizer weight decay.
-     eps
-         Optimizer eps.
-     betas
-         Optimizer betas.
-     momentum
-         Momentum used in the SGD optimizer.
-
-     Returns
-     -------
-     A Pytorch optimizer.
-     """
-     if optim_type == "adamw":
-         optimizer = optim.AdamW(
-             optimizer_grouped_parameters,
-             lr=lr,
-             weight_decay=weight_decay,
-             eps=eps,
-             betas=betas,
-         )
-     elif optim_type == "adam":
-         optimizer = optim.Adam(
-             optimizer_grouped_parameters,
-             lr=lr,
-             weight_decay=weight_decay,
-         )
-     elif optim_type == "sgd":
-         optimizer = optim.SGD(
-             optimizer_grouped_parameters,
-             lr=lr,
-             weight_decay=weight_decay,
-             momentum=momentum,
-         )
-     elif optim_type == "adafactor":
-         optimizer = Adafactor(
-             optimizer_grouped_parameters,
-             lr=lr,
-             weight_decay=weight_decay,
-             scale_parameter=True,  # Generally recommended to enable scaling
-             relative_step=False,
-             warmup_init=False,
-         )
-     else:
-         raise ValueError(f"unknown optimizer: {optim_type}")
-
-     return optimizer
-
-
- def get_lr_scheduler(
-     optimizer: optim.Optimizer,
-     num_max_steps: int,
-     num_warmup_steps: int,
-     lr_schedule: str,
-     end_lr: Union[float, int],
- ):
-     """
-     Get the learning rate scheduler from its name. Here we use our defined learning rate
-     scheduler instead of those imported from "transformers" because we want to support
-     Pytorch lightning's "ddp_spawn" training strategy.
-
-     Parameters
-     ----------
-     optimizer
-         A Pytorch optimizer.
-     num_max_steps
-         Number of maximum training steps.
-     num_warmup_steps
-         Number of steps to do learning rate warmup.
-     lr_schedule
-         Name of the learning rate scheduler.
-     end_lr
-         The final learning rate after decay.
-
-     Returns
-     -------
-     A learning rate scheduler.
-     """
-     if lr_schedule == "cosine_decay":
-         scheduler = get_cosine_schedule_with_warmup(
-             optimizer=optimizer,
-             num_warmup_steps=num_warmup_steps,
-             num_training_steps=num_max_steps,
-         )
-     elif lr_schedule == "polynomial_decay":
-         scheduler = get_polynomial_decay_schedule_with_warmup(
-             optimizer=optimizer,
-             num_warmup_steps=num_warmup_steps,
-             num_training_steps=num_max_steps,
-             lr_end=end_lr,
-             power=1,
-         )
-     elif lr_schedule == "linear_decay":
-         scheduler = get_linear_schedule_with_warmup(
-             optimizer=optimizer,
-             num_warmup_steps=num_warmup_steps,
-             num_training_steps=num_max_steps,
-         )
-     elif lr_schedule == "multi_step":
-         # TODO: add milestones, gamma into hyperparameters
-         scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=[30, 55], gamma=0.1)
-     else:
-         raise ValueError(f"unknown lr schedule: {lr_schedule}")
-
-     return scheduler
-
-
- def get_weight_decay_param_names(model: nn.Module):
-     """
-     Set the layer normalization parameters and other layers' bias parameters not to use weight decay.
-
-     Parameters
-     ----------
-     model
-         A Pytorch model.
-
-     Returns
-     -------
-     A list of parameter names not using weight decay.
-     """
-     # By default, we should not apply weight decay for all the norm layers
-     decay_param_names = get_parameter_names(
-         model,
-         [nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm],
-     )
-     decay_param_names = [
-         name
-         for name in decay_param_names
-         if (
-             "bias" not in name
-             and "cls_token" not in name
-             and "categorical_feature_tokenizer" not in name
-             and "numerical_feature_tokenizer" not in name
-         )
-     ]
-     return decay_param_names
-
-
- def get_norm_layer_param_names(model: nn.Module):
-     """
-     Get parameters associated with the normalization layers
-
-     Parameters
-     ----------
-     model
-         A Pytorch model
-
-     Returns
-     -------
-     norm_param_names
-         A list of normalization parameter names
-     """
-     all_param_names = [name for name, _ in model.named_parameters()]
-     all_param_names_except_norm_names = get_parameter_names(
-         model,
-         [nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm],
-     )
-     norm_param_names = [name for name in all_param_names if name not in all_param_names_except_norm_names]
-     return norm_param_names
-
-
- def apply_single_lr(
-     model: nn.Module,
-     lr: float,
-     weight_decay: float,
-     return_params: Optional[bool] = True,
-     efficient_finetune: Optional[str] = None,
-     trainable_param_names: Optional[List] = None,
- ):
-     """
-     Set to use a single learning rate for all parameters. Layer normalization parameters and other
-     layers' bias parameters don't use weight decay.
-
-     Parameters
-     ----------
-     model
-         A Pytorch model.
-     lr
-         Learning rate.
-     weight_decay
-         Weight decay.
-     return_params
-         Whether to return parameters or their names. If you want to double-check
-         whether the learning rate setup is as expected, you can set "return_params=False",
-         and print the layer names along with their learning rates through
-         "print("Param groups = %s" % json.dumps(optimizer_grouped_parameters, indent=2))".
-
-     Returns
-     -------
-     The grouped parameters or their names.
-     """
-     decay_param_names = get_weight_decay_param_names(model)
-     decay_grad_param_names = []
-     no_decay_grad_param_names = []
-
-     for name, param in model.named_parameters():
-         if (
-             efficient_finetune is not None
-             and efficient_finetune != "None"
-             and trainable_param_names
-             and not any([re.match(trainable_param_name, name) for trainable_param_name in trainable_param_names])
-         ):
-             param.requires_grad = False
-
-         if not param.requires_grad:
-             continue  # frozen weights
-
-         if name in decay_param_names:
-             if return_params:
-                 decay_grad_param_names.append(param)
-             else:
-                 decay_grad_param_names.append(name)
-
-         else:
-             if return_params:
-                 no_decay_grad_param_names.append(param)
-             else:
-                 no_decay_grad_param_names.append(name)
-
-     optimizer_grouped_parameters = [
-         {
-             "params": decay_grad_param_names,
-             "weight_decay": weight_decay,
-             "lr": lr,
-         },
-         {
-             "params": no_decay_grad_param_names,
-             "weight_decay": 0.0,
-             "lr": lr,
-         },
-     ]
-     return optimizer_grouped_parameters
-
-
- def apply_two_stages_lr(
-     model: nn.Module,
-     lr: float,
-     lr_mult: Union[float, int],
-     weight_decay: float,
-     return_params: Optional[bool] = True,
- ):
-     """
-     Set up the pretrained backbone to use a smaller learning rate (lr * lr_mult).
-     The newly added head layers use the normal learning rate (lr).
-     Layer normalization parameters and other layers' bias parameters don't use weight decay.
-
-     Parameters
-     ----------
-     model
-         A Pytorch model.
-     lr
-         The learning rate.
-     lr_mult
-         The multiplier (0, 1) to scale down the learning rate.
-     weight_decay
-         Weight decay.
-     return_params
-     return_params
-         Whether to return parameters or their names. If you want to double-check
-         whether the learning rate setup is as expected, you can set "return_params=False",
-         and print the layer names along with their learning rates through
-         "print("Param groups = %s" % json.dumps(optimizer_grouped_parameters, indent=2))".
-
-     Returns
-     -------
-     The grouped parameters or their names.
-     """
-     decay_param_names = get_weight_decay_param_names(model)
-
-     optimizer_grouped_parameters = [
-         {
-             "params": [
-                 p if return_params else n
-                 for n, p in model.named_parameters()
-                 if n in decay_param_names and not any(bb in n for bb in model.head_layer_names)
-             ],
-             "weight_decay": weight_decay,
-             "lr": lr,
-         },
-         {
-             "params": [
-                 p if return_params else n
-                 for n, p in model.named_parameters()
-                 if n not in decay_param_names and not any(bb in n for bb in model.head_layer_names)
-             ],
-             "weight_decay": 0.0,
-             "lr": lr,
-         },
-         {
-             "params": [
-                 p if return_params else n
-                 for n, p in model.named_parameters()
-                 if n in decay_param_names and any(bb in n for bb in model.head_layer_names)
-             ],
-             "weight_decay": weight_decay,
-             "lr": lr * lr_mult,
-         },
-         {
-             "params": [
-                 p if return_params else n
-                 for n, p in model.named_parameters()
-                 if n not in decay_param_names and any(bb in n for bb in model.head_layer_names)
-             ],
-             "weight_decay": 0.0,
-             "lr": lr * lr_mult,
-         },
-     ]
-
-     return optimizer_grouped_parameters
-
-
- def get_trainable_params_efficient_finetune(
-     norm_param_names: List[str], efficient_finetune: Optional[str] = None, extra_params: Optional[List] = None
- ):
-     """
-     Get the list of trainable parameters according to the provided efficient finetuning method.
-
-     Parameters
-     ----------
-     norm_param_names
-         The parameters associated with the normalization layers
-     efficient_finetune
-         Efficient finetuning strategy. Trainable parameters will be adjusted according to the method.
-
-     Returns
-     -------
-     Get list of trainable parameter names according to the provided efficient finetuning method.
-     """
-     trainable_param_names = []
-
-     if efficient_finetune == BIT_FIT:
-         trainable_param_names.append(".*bias*.")
-     elif efficient_finetune == NORM_FIT:
-         trainable_param_names.append(".*bias*.")
-         trainable_param_names += norm_param_names
-     elif efficient_finetune in [LORA, IA3, IA3_LORA, CONV_LORA]:
-         trainable_param_names.append(".*lora_*.")
-     elif efficient_finetune in [LORA_BIAS, IA3_BIAS, IA3_LORA_BIAS]:
-         trainable_param_names.append(".*lora_*.")
-         trainable_param_names.append(".*bias*.")
-     elif efficient_finetune in [LORA_NORM, IA3_NORM, IA3_LORA_NORM]:
-         trainable_param_names.append(".*lora_*.")
-         trainable_param_names.append(".*bias*.")
-         trainable_param_names += norm_param_names
-     elif efficient_finetune is not None and efficient_finetune != "None":
-         raise NotImplementedError(
-             f"The efficient finetuning strategy '{efficient_finetune}'"
-             f" is not supported. We only support"
-             f" {', '.join(PEFT_STRATEGIES)}."
-         )
-
-     if extra_params:
-         trainable_param_names.extend(extra_params)
-
-     return trainable_param_names
-
-
- def remove_parameters_without_grad(
-     grouped_parameters: List[Dict],
- ):
-     """
-     Remove layers
-
-     Parameters
-     ----------
-     grouped_parameters
-         The grouped parameters or their names output from lr_choice.
-
-     Returns
-     -------
-     The updated grouped parameters or their names.
-     """
-     for group_idx, group_param in enumerate(grouped_parameters):
-         updated_params = []
-         for p in group_param["params"]:
-             if p.requires_grad:
-                 updated_params.append(p)
-         grouped_parameters[group_idx]["params"] = updated_params
-
-     return grouped_parameters
-
-
- def apply_layerwise_lr_decay(
-     model: nn.Module,
-     lr: float,
-     lr_decay: float,
-     weight_decay: float,
-     efficient_finetune: Optional[str] = None,
-     trainable_param_names: Optional[List] = None,
- ):
-     """
-     Assign monotonically decreasing learning rates for layers from the output end to the input end.
-     The intuition behind is that later layers are more task-related compared to the early layers.
-     Layer normalization parameters and other layers' bias parameters don't use weight decay.
-     If you want to double-check whether the learning rate setup is as expected,
-     you can print the layer names along with their learning rates through
-     "print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))".
-
-     Parameters
-     ----------
-     model
-         A Pytorch model.
-     lr
-         The learning rate.
-     lr_decay
-         The learning rate decay factor (0, 1).
-     weight_decay
-         Weight decay.
-     efficient_finetune
-         Efficient finetuning strategy. It will only finetune part of the parameters
-
-     Returns
-     -------
-     The grouped parameters based on their layer ids and whether using weight decay.
-     """
-     parameter_group_names = {}
-     parameter_group_vars = {}
-     decay_param_names = get_weight_decay_param_names(model)
-
-     for name, param in model.named_parameters():
-         if name.startswith("_orig_mod."):
-             name = "".join(name.split("_orig_mod."))
-         layer_id = model.name_to_id[name]
-         if layer_id == 0:  # Set top layer (e.g. head, fusion_mlp, adapter) as being trainable.
-             param.requires_grad = True
-         elif (
-             efficient_finetune is not None
-             and efficient_finetune != "None"
-             and trainable_param_names
-             and not any([re.match(trainable_param_name, name) for trainable_param_name in trainable_param_names])
-         ):
-             param.requires_grad = False
-
-         if not param.requires_grad:
-             continue  # frozen weights
-
-         if name in decay_param_names:
-             group_name = "decay"
-             this_weight_decay = weight_decay
-         else:
-             group_name = "no_decay"
-             this_weight_decay = 0.0
-
-         layer_id = model.name_to_id[name]
-         group_name = "layer_%d_%s" % (layer_id, group_name)
-
-         if group_name not in parameter_group_names:
-             scale = lr_decay**layer_id
-             parameter_group_names[group_name] = {
-                 "weight_decay": this_weight_decay,
-                 "params": [],
-                 "lr": scale * lr,
-             }
-             parameter_group_vars[group_name] = {
-                 "weight_decay": this_weight_decay,
-                 "params": [],
-                 "lr": scale * lr,
-             }
-
-         parameter_group_vars[group_name]["params"].append(param)
-         parameter_group_names[group_name]["params"].append(name)
-
-     return list(parameter_group_vars.values())
-
-
- def gather_column_features(
-     output: Dict[str, Dict],
-     column_names: Union[str, List[str]],
- ):
-     """
-     Gather column features from models' outputs.
-     For each feature name in one model's output, we enumerate the provided column names to see
-     whether (partial) the provided columns share one cls feature or they have independent features.
-
-     TODO: return features' masks and use them to filter the losses.
-
-     Parameters
-     ----------
-     output
-         The models' outputs.
-     column_names
-         The columns whose features we want to get.
-
-     Returns
-     -------
-     The gathered feature vectors. Each sample should only have one feature vector.
-     """
-     if isinstance(column_names, str):
-         column_names = [column_names]
-
-     gathered_features = []
-     # logger.debug(f"gather features for columns: {column_names}")
-     for per_model_name, per_model_output in output.items():
-         # logger.debug(f"gather column features from model: {per_model_name}")
-         for feature_name in per_model_output[COLUMN_FEATURES][FEATURES]:
-             # logger.debug(f"processing feature: {feature_name}")
-             columns_share_one_feature = []
-             for col_name in column_names:
-                 if col_name in feature_name:
-                     # this column feature is part of the cls feature
-                     if not (feature_name.startswith(col_name) and feature_name.endswith(col_name)):
-                         columns_share_one_feature.append(col_name)
-                         # logger.debug(f"column {col_name} is included in feature {feature_name}")
-                     else:  # this column's feature is independent of other columns'
-                         gathered_features.append(per_model_output[COLUMN_FEATURES][FEATURES][col_name])
-                         # logger.debug(f"col_name {col_name} has an independent feature in model: {per_model_name}")
-
-             # two or more columns share one cls feature, and no other columns share it.
-             if len(columns_share_one_feature) > 0:
-                 assert (
-                     len("_".join(columns_share_one_feature)) == len(feature_name)
-                 ), f"model `{per_model_name}`'s cls feature name `{feature_name}` doesn't match `{columns_share_one_feature}`"
-                 gathered_features.append(per_model_output[COLUMN_FEATURES][FEATURES][feature_name])
-
-     if len(gathered_features) > 1:
-         # currently only support features of the same shape
-         assert all(
-             per_features.shape == gathered_features[0].shape for per_features in gathered_features
-         ), "Currently we only support gathering features of the same dimension."
-
-     if len(gathered_features) == 0:
-         raise ValueError(f"No features are found for columns names {column_names}.")
-
-     gathered_features = torch.stack(gathered_features, dim=0).mean(dim=0)  # (b, d)
-
-     return gathered_features
-
-
- def get_metric_learning_distance_func(
-     name: str,
- ):
-     """
-     Return one pytorch metric learning's distance function based on its name.
-
-     Parameters
-     ----------
-     name
-         distance function name
-
-     Returns
-     -------
-     A distance function from the pytorch metric learning package.
-     """
-     if name.lower() == COSINE_SIMILARITY:
-         return distances.CosineSimilarity()
-     else:
-         raise ValueError(f"Unknown distance measure: {name}")
-
-
- def infer_matcher_loss(data_format: str, problem_type: str):
-     """
-     Infer the loss type to train the matcher.
-
-     Parameters
-     ----------
-     data_format
-         The training data format, e.g., pair or triplet.
-     problem_type
-         Type of problem.
-
-     Returns
-     -------
-     The loss name.
-     """
-     if data_format == "pair":
-         if problem_type is None:
-             return [MULTI_NEGATIVES_SOFTMAX_LOSS]
-         elif problem_type == BINARY:
-             return [CONTRASTIVE_LOSS]
-         elif problem_type == REGRESSION:
-             return ["cosine_similarity_loss"]
-         else:
-             raise ValueError(f"Unsupported data format {data_format} with problem type {problem_type}")
-     elif data_format == "triplet":
-         if problem_type is None:
-             return [MULTI_NEGATIVES_SOFTMAX_LOSS]
-         else:
-             raise ValueError(f"Unsupported data format {data_format} with problem type {problem_type}")
-     else:
-         raise ValueError(f"Unsupported data format: {data_format}")
-
-
- def get_matcher_loss_func(
-     data_format: str,
-     problem_type: str,
-     loss_type: Optional[str] = None,
-     pos_margin: Optional[float] = None,
-     neg_margin: Optional[float] = None,
-     distance_type: Optional[str] = None,
- ):
-     """
-     Return a list of pytorch metric learning's loss functions based on their names.
-
-     Parameters
-     ----------
-     data_format
-         The training data format, e.g., pair or triplet.
-     problem_type
-         Type of problem.
-     loss_type
-         The provided loss type.
-     pos_margin
-         The positive margin in computing the metric learning loss.
-     neg_margin
-         The negative margin in computing the metric learning loss.
-     distance_type
-         The distance function type.
-
-     Returns
-     -------
-     A loss function of metric learning.
-     """
-
-     allowable_loss_types = infer_matcher_loss(data_format=data_format, problem_type=problem_type)
-     if loss_type is not None:
-         assert loss_type in allowable_loss_types, f"data format {data_format} can't use loss {loss_type}."
-     else:
-         loss_type = allowable_loss_types[0]
-
-     if loss_type.lower() == CONTRASTIVE_LOSS:
-         return losses.ContrastiveLoss(
-             pos_margin=pos_margin,
-             neg_margin=neg_margin,
-             distance=get_metric_learning_distance_func(distance_type),
-         )
-     elif loss_type.lower() == MULTI_NEGATIVES_SOFTMAX_LOSS:
-         return MultiNegativesSoftmaxLoss(
-             local_loss=True,
-             gather_with_grad=True,
-             cache_labels=False,
-         )
-     else:
-         raise ValueError(f"Unknown metric learning loss: {loss_type}")
-
-
- def get_matcher_miner_func(
-     miner_type: str,
-     pos_margin: float,
-     neg_margin: float,
-     distance_type: str,
- ):
-     """
-     Return a pytorch metric learning's miner functions based on their names.
-     The miners are used to mine the positive and negative examples.
-
-     Parameters
-     ----------
-     miner_type
-         The miner function type.
-     pos_margin
-         The positive margin used by the miner function.
-     neg_margin
-         The negative margin used by the miner function.
-     distance_type
-         The distance function type.
-
-     Returns
-     -------
-     A miner function to mine positive and negative samples.
-     """
-     if miner_type.lower() == PAIR_MARGIN_MINER:
-         return miners.PairMarginMiner(
-             pos_margin=pos_margin,
-             neg_margin=neg_margin,
-             distance=get_metric_learning_distance_func(distance_type),
-         )
-     else:
-         raise ValueError(f"Unknown metric learning miner: {miner_type}")
-
-
- def generate_metric_learning_labels(
-     num_samples: int,
-     match_label: int,
-     labels: torch.Tensor,
- ):
-     """
-     Generate labels to compute the metric learning loss of one mini-batch.
-     For n samples, it generates 2*n labels since each match has two sides, each of which
-     has one label. If we know the matching label, then it determines the two sides' labels
-     according to whether their label is the matching label. If the matching label is None,
-     it assigns a unique label for each side.
-
-     Parameters
-     ----------
-     num_samples
-         number of samples.
-     match_label
-         The matching label, which can be None.
-     labels
-         The sample labels used in the supervised setting. It's required only when match_label is not None.
-
-     Returns
-     -------
-     The labels used in computing the metric learning loss.
-     """
-     device = labels.device
-     labels_1 = torch.arange(num_samples, device=device)
-
-     if match_label is not None:
-         labels_2 = torch.arange(num_samples, num_samples * 2, device=device)
-         # users need to specify the match_label based on the raw label's semantic meaning.
-         mask = labels == match_label
-         labels_2[mask] = labels_1[mask]
-     else:
-         labels_2 = torch.arange(num_samples, device=device)
-
-     metric_learning_labels = torch.cat([labels_1, labels_2], dim=0)
-
-     return metric_learning_labels
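The label scheme documented in generate_metric_learning_labels above is easiest to see on a tiny batch. A standalone sketch (it re-derives the labels directly rather than importing the deleted module), assuming a binary task where match_label=1 marks a matching query/response pair:

import torch

labels = torch.tensor([1, 0, 1, 0])  # raw labels for 4 query/response pairs
num_samples = labels.numel()
labels_1 = torch.arange(num_samples)                   # left side:  [0, 1, 2, 3]
labels_2 = torch.arange(num_samples, 2 * num_samples)  # right side: [4, 5, 6, 7]
mask = labels == 1                                     # rows whose pair actually matches
labels_2[mask] = labels_1[mask]                        # matching sides share a label
print(torch.cat([labels_1, labels_2]))                 # tensor([0, 1, 2, 3, 0, 5, 2, 7])

Rows 0 and 2 end up with the same label on both sides, so the metric-learning loss treats them as positives, while the non-matching rows keep unique labels on each side and act as negatives.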