autogluon.multimodal 1.2.1b20250303__py3-none-any.whl → 1.2.1b20250305__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126)
  1. autogluon/multimodal/__init__.py +4 -2
  2. autogluon/multimodal/configs/data/default.yaml +4 -2
  3. autogluon/multimodal/configs/{environment → env}/default.yaml +2 -3
  4. autogluon/multimodal/configs/model/default.yaml +58 -11
  5. autogluon/multimodal/configs/{optimization → optim}/default.yaml +21 -4
  6. autogluon/multimodal/constants.py +16 -5
  7. autogluon/multimodal/data/__init__.py +14 -2
  8. autogluon/multimodal/data/dataset.py +2 -2
  9. autogluon/multimodal/data/infer_types.py +16 -2
  10. autogluon/multimodal/data/label_encoder.py +3 -3
  11. autogluon/multimodal/{utils → data}/nlpaug.py +4 -4
  12. autogluon/multimodal/data/preprocess_dataframe.py +55 -38
  13. autogluon/multimodal/data/process_categorical.py +35 -6
  14. autogluon/multimodal/data/process_document.py +59 -33
  15. autogluon/multimodal/data/process_image.py +198 -163
  16. autogluon/multimodal/data/process_label.py +7 -3
  17. autogluon/multimodal/data/process_mmlab/process_mmdet.py +1 -8
  18. autogluon/multimodal/data/process_mmlab/process_mmlab_base.py +2 -9
  19. autogluon/multimodal/data/process_mmlab/process_mmocr.py +1 -9
  20. autogluon/multimodal/data/process_ner.py +192 -4
  21. autogluon/multimodal/data/process_numerical.py +32 -5
  22. autogluon/multimodal/data/process_semantic_seg_img.py +23 -28
  23. autogluon/multimodal/data/process_text.py +95 -58
  24. autogluon/multimodal/data/template_engine.py +7 -9
  25. autogluon/multimodal/data/templates.py +0 -2
  26. autogluon/multimodal/data/trivial_augmenter.py +2 -2
  27. autogluon/multimodal/data/utils.py +564 -338
  28. autogluon/multimodal/learners/__init__.py +2 -1
  29. autogluon/multimodal/learners/base.py +189 -189
  30. autogluon/multimodal/learners/ensemble.py +748 -0
  31. autogluon/multimodal/learners/few_shot_svm.py +6 -15
  32. autogluon/multimodal/learners/matching.py +59 -84
  33. autogluon/multimodal/learners/ner.py +23 -22
  34. autogluon/multimodal/learners/object_detection.py +26 -21
  35. autogluon/multimodal/learners/semantic_segmentation.py +16 -18
  36. autogluon/multimodal/models/__init__.py +12 -3
  37. autogluon/multimodal/models/augmenter.py +175 -0
  38. autogluon/multimodal/models/categorical_mlp.py +13 -8
  39. autogluon/multimodal/models/clip.py +92 -18
  40. autogluon/multimodal/models/custom_transformer.py +75 -75
  41. autogluon/multimodal/models/document_transformer.py +23 -9
  42. autogluon/multimodal/models/ft_transformer.py +40 -35
  43. autogluon/multimodal/models/fusion/base.py +2 -4
  44. autogluon/multimodal/models/fusion/fusion_mlp.py +82 -18
  45. autogluon/multimodal/models/fusion/fusion_ner.py +1 -1
  46. autogluon/multimodal/models/fusion/fusion_transformer.py +23 -23
  47. autogluon/multimodal/models/{huggingface_text.py → hf_text.py} +21 -2
  48. autogluon/multimodal/models/meta_transformer.py +336 -0
  49. autogluon/multimodal/models/mlp.py +6 -6
  50. autogluon/multimodal/models/mmocr_text_detection.py +1 -1
  51. autogluon/multimodal/models/mmocr_text_recognition.py +0 -1
  52. autogluon/multimodal/models/ner_text.py +1 -8
  53. autogluon/multimodal/models/numerical_mlp.py +14 -8
  54. autogluon/multimodal/models/sam.py +12 -2
  55. autogluon/multimodal/models/t_few.py +21 -5
  56. autogluon/multimodal/models/timm_image.py +74 -32
  57. autogluon/multimodal/models/utils.py +877 -16
  58. autogluon/multimodal/optim/__init__.py +17 -0
  59. autogluon/multimodal/{optimization → optim}/lit_distiller.py +2 -1
  60. autogluon/multimodal/{optimization → optim}/lit_matcher.py +4 -10
  61. autogluon/multimodal/{optimization → optim}/lit_mmdet.py +2 -10
  62. autogluon/multimodal/{optimization → optim}/lit_module.py +139 -14
  63. autogluon/multimodal/{optimization → optim}/lit_ner.py +3 -3
  64. autogluon/multimodal/{optimization → optim}/lit_semantic_seg.py +1 -1
  65. autogluon/multimodal/optim/losses/__init__.py +14 -0
  66. autogluon/multimodal/optim/losses/bce_loss.py +25 -0
  67. autogluon/multimodal/optim/losses/focal_loss.py +81 -0
  68. autogluon/multimodal/optim/losses/lemda_loss.py +39 -0
  69. autogluon/multimodal/optim/losses/rkd_loss.py +103 -0
  70. autogluon/multimodal/optim/losses/softmax_losses.py +177 -0
  71. autogluon/multimodal/optim/losses/structure_loss.py +26 -0
  72. autogluon/multimodal/optim/losses/utils.py +313 -0
  73. autogluon/multimodal/optim/lr/__init__.py +1 -0
  74. autogluon/multimodal/optim/lr/utils.py +332 -0
  75. autogluon/multimodal/optim/metrics/__init__.py +4 -0
  76. autogluon/multimodal/optim/metrics/coverage_metrics.py +42 -0
  77. autogluon/multimodal/optim/metrics/hit_rate_metrics.py +78 -0
  78. autogluon/multimodal/optim/metrics/ranking_metrics.py +231 -0
  79. autogluon/multimodal/optim/metrics/utils.py +359 -0
  80. autogluon/multimodal/optim/utils.py +284 -0
  81. autogluon/multimodal/predictor.py +51 -12
  82. autogluon/multimodal/utils/__init__.py +19 -45
  83. autogluon/multimodal/utils/cache.py +23 -2
  84. autogluon/multimodal/utils/checkpoint.py +58 -5
  85. autogluon/multimodal/utils/config.py +127 -55
  86. autogluon/multimodal/utils/device.py +120 -0
  87. autogluon/multimodal/utils/distillation.py +8 -8
  88. autogluon/multimodal/utils/download.py +1 -1
  89. autogluon/multimodal/utils/env.py +22 -0
  90. autogluon/multimodal/utils/export.py +3 -3
  91. autogluon/multimodal/utils/hpo.py +5 -5
  92. autogluon/multimodal/utils/inference.py +37 -4
  93. autogluon/multimodal/utils/install.py +91 -0
  94. autogluon/multimodal/utils/load.py +52 -47
  95. autogluon/multimodal/utils/log.py +6 -41
  96. autogluon/multimodal/utils/matcher.py +3 -2
  97. autogluon/multimodal/utils/onnx.py +0 -4
  98. autogluon/multimodal/utils/path.py +10 -0
  99. autogluon/multimodal/utils/precision.py +130 -0
  100. autogluon/multimodal/{presets.py → utils/presets.py} +259 -66
  101. autogluon/multimodal/{problem_types.py → utils/problem_types.py} +30 -1
  102. autogluon/multimodal/utils/save.py +47 -29
  103. autogluon/multimodal/utils/strategy.py +24 -0
  104. autogluon/multimodal/version.py +1 -1
  105. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/METADATA +5 -5
  106. autogluon.multimodal-1.2.1b20250305.dist-info/RECORD +163 -0
  107. autogluon/multimodal/optimization/__init__.py +0 -16
  108. autogluon/multimodal/optimization/losses.py +0 -394
  109. autogluon/multimodal/optimization/utils.py +0 -1054
  110. autogluon/multimodal/utils/cloud_io.py +0 -80
  111. autogluon/multimodal/utils/data.py +0 -701
  112. autogluon/multimodal/utils/environment.py +0 -395
  113. autogluon/multimodal/utils/metric.py +0 -500
  114. autogluon/multimodal/utils/model.py +0 -558
  115. autogluon.multimodal-1.2.1b20250303.dist-info/RECORD +0 -145
  116. /autogluon/multimodal/{optimization → optim}/deepspeed.py +0 -0
  117. /autogluon/multimodal/{optimization/lr_scheduler.py → optim/lr/lr_schedulers.py} +0 -0
  118. /autogluon/multimodal/{optimization → optim/metrics}/semantic_seg_metrics.py +0 -0
  119. /autogluon/multimodal/{registry.py → utils/registry.py} +0 -0
  120. /autogluon.multimodal-1.2.1b20250303-py3.9-nspkg.pth → /autogluon.multimodal-1.2.1b20250305-py3.9-nspkg.pth +0 -0
  121. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/LICENSE +0 -0
  122. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/NOTICE +0 -0
  123. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/WHEEL +0 -0
  124. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/namespace_packages.txt +0 -0
  125. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/top_level.txt +0 -0
  126. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/zip-safe +0 -0
@@ -0,0 +1,359 @@
1
+ import functools
2
+ import logging
3
+ import warnings
4
+ from typing import Dict, List, Optional, Tuple, Union
5
+
6
+ import evaluate
7
+ import torchmetrics
8
+ from torch.nn import functional as F
9
+ from torchmetrics.detection.mean_ap import MeanAveragePrecision
10
+
11
+ from autogluon.core.metrics import Scorer, compute_metric, get_metric
12
+
13
+ from ...constants import (
14
+ ACC,
15
+ ACCURACY,
16
+ AVERAGE_PRECISION,
17
+ BER,
18
+ COSINE_EMBEDDING_LOSS,
19
+ COVERAGE,
20
+ CROSS_ENTROPY,
21
+ DETECTION_METRICS,
22
+ DIRECT_LOSS,
23
+ EM,
24
+ F1,
25
+ F1_MACRO,
26
+ F1_MICRO,
27
+ F1_WEIGHTED,
28
+ FM,
29
+ HIT_RATE,
30
+ IOU,
31
+ LOG_LOSS,
32
+ MAE,
33
+ MATCHING_METRICS,
34
+ MATCHING_METRICS_WITHOUT_PROBLEM_TYPE,
35
+ MAX,
36
+ METRIC_MODE_MAP,
37
+ MIN,
38
+ MULTICLASS,
39
+ NER_TOKEN_F1,
40
+ OVERALL_ACCURACY,
41
+ OVERALL_F1,
42
+ PEARSONR,
43
+ QUADRATIC_KAPPA,
44
+ R2,
45
+ RECALL,
46
+ RETRIEVAL_METRICS,
47
+ RMSE,
48
+ ROC_AUC,
49
+ ROOT_MEAN_SQUARED_ERROR,
50
+ SM,
51
+ SPEARMANR,
52
+ Y_PRED,
53
+ Y_PRED_PROB,
54
+ Y_TRUE,
55
+ )
56
+ from .coverage_metrics import Coverage
57
+ from .hit_rate_metrics import CustomHitRate
58
+ from .semantic_seg_metrics import COD_METRICS_NAMES, Balanced_Error_Rate, Binary_IoU, Multiclass_IoU
59
+
60
+ logger = logging.getLogger(__name__)
61
+
62
+
63
def compute_score(
    metric_data: dict,
    metric: Union[str, Scorer],
    pos_label: Optional[int] = 1,
) -> float:
    """
    Compute the score of one metric.

    Sequence-labeling metrics (OVERALL_ACCURACY, OVERALL_F1) are delegated to the ``seqeval``
    implementation loaded through ``evaluate``. Every other metric is resolved with
    ``get_metric`` and computed with ``compute_metric``.

    Parameters
    ----------
    metric_data
        A dictionary with the groundtruth (Y_TRUE) and predicted values (Y_PRED, Y_PRED_PROB).
        The predicted class probabilities (Y_PRED_PROB) are required by metrics that need
        probabilities or thresholds (e.g., roc_auc).
    metric
        The name of metric or the function of metric to compute.
    pos_label
        The encoded label (0 or 1) of binary classification's positive class.

    Returns
    -------
    Computed score. For the seqeval metrics, the raw output of seqeval's ``compute`` call is
    returned unchanged.
    """
    if isinstance(metric, str) and metric in [OVERALL_ACCURACY, OVERALL_F1]:
        seqeval = evaluate.load("seqeval")
        # Suppress warnings only for the seqeval computation. A bare
        # ``warnings.filterwarnings("ignore")`` would silence warnings for the whole process.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return seqeval.compute(references=metric_data[Y_TRUE], predictions=metric_data[Y_PRED])

    metric = get_metric(metric)

    y = metric_data[Y_TRUE]
    if metric.needs_proba or metric.needs_threshold:
        y_pred_proba = metric_data[Y_PRED_PROB]
        # More than two columns: keep the full probability matrix (multiclass).
        # Otherwise keep only the positive-class column (binary classification).
        y_pred_proba = y_pred_proba if y_pred_proba.shape[1] > 2 else y_pred_proba[:, pos_label]
        return metric.convert_score_to_original(
            compute_metric(y=y, y_pred_proba=y_pred_proba, metric=metric, weights=None)
        )

    y_pred = metric_data[Y_PRED]

    # TODO: This is a hack. Doesn't support `f1_macro`, `f1_micro`, `f1_weighted`, or custom `f1` metrics with different names.
    # TODO: Longterm the solution should be to have the input data to this function use the internal representation without the original class names. This way `pos_label` would not need to be specified.
    if metric.name == F1:  # only for binary classification
        y = (y == pos_label).astype(int)
        y_pred = (y_pred == pos_label).astype(int)

    return metric.convert_score_to_original(compute_metric(y=y, y_pred=y_pred, metric=metric, weights=None))
117
+
118
+
119
def infer_metrics(
    problem_type: Optional[str] = None,
    eval_metric: Optional[Union[str, Scorer]] = None,
    validation_metric_name: Optional[str] = None,
    is_matching: Optional[bool] = False,
):
    """
    Infer the validation metric and the evaluation metric if not provided.
    Validation metric is for early-stopping and selecting the best model checkpoints.
    Evaluation metric is to report performance to users.

    Parameters
    ----------
    problem_type
        Type of problem. Required unless ``is_matching`` is True.
    eval_metric
        Evaluation metric provided by users: a metric name, a Scorer (treated as a
        customized metric), or None.
    validation_metric_name
        The provided validation metric name. It is returned unchanged when none of the
        inference rules below overrides it.
    is_matching
        Whether is matching.

    Returns
    -------
    validation_metric_name
        Name of validation metric.
    eval_metric_name
        Name of evaluation metric. ``(None, None)`` is returned when the user metric is
        unsupported and the problem type has no fallback evaluation metric.

    Raises
    ------
    TypeError
        If ``eval_metric`` is not a str, a Scorer, or None.
    NotImplementedError
        If a matching task uses a problem type that is not in MATCHING_METRICS.
    ValueError
        If ``problem_type`` is None (or its registry lookup yields None) for a non-matching task.
    """
    is_customized = False
    if eval_metric is None:
        eval_metric_name = None
    elif isinstance(eval_metric, str):
        eval_metric_name = eval_metric
    elif isinstance(eval_metric, Scorer):
        eval_metric_name = eval_metric.name
        is_customized = True
    else:
        raise TypeError(f"eval_metric can be a str, a Scorer, or None, but is type: {type(eval_metric)}")

    # Bind unconditionally so the non-matching path can report a clear error instead of
    # failing with UnboundLocalError when no problem type is given.
    problem_property = None
    if problem_type is not None:
        from ...utils.problem_types import PROBLEM_TYPES_REG

        problem_property = PROBLEM_TYPES_REG.get(problem_type)

    if is_matching:
        if eval_metric_name is not None:
            # if eval_metric_name is a valid metric
            if eval_metric_name.lower() in METRIC_MODE_MAP.keys():
                validation_metric_name = eval_metric_name
                return validation_metric_name, eval_metric_name
            elif eval_metric_name.lower() in RETRIEVAL_METRICS:
                # Currently only support recall as validation metric in retrieval tasks.
                validation_metric_name = RECALL
                return validation_metric_name, eval_metric_name

        # When eval_metric_name is either None or not supported:
        # Fallback based on problem type unless it's a customized metric
        if problem_type is None:
            validation_metric_name, fallback_evaluation_metric = MATCHING_METRICS_WITHOUT_PROBLEM_TYPE
        elif problem_type in MATCHING_METRICS:
            validation_metric_name, fallback_evaluation_metric = MATCHING_METRICS[problem_type]
        else:
            raise NotImplementedError(f"Problem type: {problem_type} is not yet supported for matching!")
        if not is_customized:
            if eval_metric_name is not None:
                warnings.warn(
                    f"Metric {eval_metric_name} is not supported as the evaluation metric for {problem_type} in matching tasks. "
                    f"The evaluation metric is changed to {fallback_evaluation_metric} by default."
                )
            eval_metric_name = fallback_evaluation_metric
        return validation_metric_name, eval_metric_name

    if problem_property is None:
        raise ValueError(
            f"A registered problem_type is required to infer metrics for non-matching tasks, but got: {problem_type}"
        )

    if eval_metric_name is not None:
        # Infer evaluation metric
        if eval_metric_name.lower() not in problem_property.supported_evaluation_metrics and not is_customized:
            warnings.warn(
                f"Metric {eval_metric_name} is not supported as the evaluation metric for {problem_type}. "
                f"The evaluation metric is changed to {problem_property.fallback_evaluation_metric} by default."
            )
            if problem_property.fallback_evaluation_metric is not None:
                eval_metric_name = problem_property.fallback_evaluation_metric
            else:
                # Problem types like extract_embedding does not need a eval/val metric
                return None, None

        # Infer validation metric
        if eval_metric_name.lower() in problem_property.supported_validation_metrics:
            validation_metric_name = eval_metric_name
        else:
            if problem_property.fallback_validation_metric is not None:
                validation_metric_name = problem_property.fallback_validation_metric
    else:
        eval_metric_name = problem_property.fallback_evaluation_metric
        validation_metric_name = problem_property.fallback_validation_metric

    return validation_metric_name, eval_metric_name
216
+
217
+
218
def get_minmax_mode(
    metric_name: Union[str, Scorer],
):
    """
    Return the checkpoint-selection mode for a metric.

    Parameters
    ----------
    metric_name
        Either a metric name (looked up in METRIC_MODE_MAP) or a Scorer-like object
        exposing ``greater_is_better``.

    Returns
    -------
    mode
        The min/max mode used in selecting model checkpoints.
        - max: a larger metric value is better.
        - min: a smaller metric value is better.
        A string is resolved through METRIC_MODE_MAP. Any other object is resolved from its
        ``greater_is_better`` flag.

    Raises
    ------
    AssertionError
        If ``metric_name`` is a string that is not a key of METRIC_MODE_MAP.
    """
    if not isinstance(metric_name, str):
        # Scorer objects carry their own optimization direction.
        return MAX if metric_name.greater_is_better else MIN

    assert (
        metric_name in METRIC_MODE_MAP
    ), f"{metric_name} is not a supported metric. Options are: {METRIC_MODE_MAP.keys()}"
    return METRIC_MODE_MAP.get(metric_name)
245
+
246
+
247
def get_stopping_threshold(metric_name: str):
    """
    Get the metric threshold for early stopping.

    The threshold is the metric's optimum shifted by 1e-7 in the direction given by the
    metric's ``_sign`` (``optimum - _sign * 1e-7``).

    Parameters
    ----------
    metric_name
        Name of validation metric.

    Returns
    -------
    The stopping threshold. None is returned if ``get_metric`` cannot resolve the metric or
    the threshold cannot be computed from it.
    """
    try:
        metric = get_metric(metric_name)
        stopping_threshold = metric.optimum - metric._sign * 1e-7
    except Exception:
        # Best effort: metrics that cannot be resolved (or lack a numeric optimum) simply
        # disable threshold-based stopping. ``Exception`` rather than a bare ``except`` keeps
        # KeyboardInterrupt / SystemExit propagating.
        stopping_threshold = None

    return stopping_threshold
267
+
268
+
269
def get_torchmetric(
    metric_name: str,
    num_classes: Optional[int] = None,
    is_matching: Optional[bool] = False,
    problem_type: Optional[str] = None,
):
    """
    Obtain a torchmetrics.Metric from its name.
    For metrics that torchmetrics doesn't compute directly from model outputs (e.g., losses
    averaged with MeanMetric), a customized metric function is returned as well.

    Parameters
    ----------
    metric_name
        Name of metric (matched case-insensitively).
    num_classes
        Number of classes.
    is_matching
        Whether is matching.
    problem_type
        Type of problem, e.g., binary and multiclass. Forwarded as torchmetrics' ``task`` for
        task-dependent metrics.

    Returns
    -------
    A 2-tuple ``(metric, custom_metric_func)``:
    torchmetrics.Metric
        A torchmetrics.Metric object.
    custom_metric_func
        A customized metric function, or None if the metric needs none.

    Raises
    ------
    ValueError
        If the metric is unknown, spearmanr is requested for matching, or recall/hit_rate is
        requested outside matching.
    """
    metric_name = metric_name.lower()
    if metric_name in [ACC, ACCURACY, OVERALL_ACCURACY]:
        # use MULTICLASS since the head output dim is 2 for the binary problem type.
        return torchmetrics.Accuracy(task=MULTICLASS, num_classes=num_classes), None
    elif metric_name == NER_TOKEN_F1:
        return torchmetrics.F1Score(task=MULTICLASS, num_classes=num_classes, ignore_index=1), None
    elif metric_name in [RMSE, ROOT_MEAN_SQUARED_ERROR]:
        return torchmetrics.MeanSquaredError(squared=False), None
    elif metric_name == R2:
        return torchmetrics.R2Score(), None
    elif metric_name == QUADRATIC_KAPPA:
        return (
            torchmetrics.CohenKappa(task=problem_type, num_classes=num_classes, weights="quadratic"),
            None,
        )
    elif metric_name == ROC_AUC:
        return torchmetrics.AUROC(task=problem_type, num_classes=num_classes), None
    elif metric_name == AVERAGE_PRECISION:
        # Return the same (metric, custom_metric_func) pair as every other branch so callers
        # can always unpack the result.
        return torchmetrics.AveragePrecision(task=problem_type, num_classes=num_classes), None
    elif metric_name in [LOG_LOSS, CROSS_ENTROPY]:
        return torchmetrics.MeanMetric(), functools.partial(F.cross_entropy, reduction="none")
    elif metric_name == COSINE_EMBEDDING_LOSS:
        return torchmetrics.MeanMetric(), functools.partial(F.cosine_embedding_loss, reduction="none")
    elif metric_name == PEARSONR:
        return torchmetrics.PearsonCorrCoef(), None
    elif metric_name == SPEARMANR:
        if is_matching:  # TODO: add support for matching.
            raise ValueError("spearman relation is not supported for matching yet.")
        else:
            return torchmetrics.SpearmanCorrCoef(), None
    elif metric_name == F1:
        return torchmetrics.F1Score(task=problem_type, num_classes=num_classes), None
    elif metric_name in [F1_MACRO, F1_MICRO, F1_WEIGHTED]:
        # e.g. "f1_macro" -> average="macro"
        average = metric_name.split("_")[1]
        return torchmetrics.F1Score(task=problem_type, num_classes=num_classes, average=average), None
    elif metric_name in DETECTION_METRICS:
        return (
            MeanAveragePrecision(box_format="xyxy", iou_type="bbox", class_metrics=False),
            None,
        )  # TODO: remove parameter hardcodings here, and add class_metrics
    elif metric_name == DIRECT_LOSS:
        return (
            torchmetrics.MeanMetric(nan_strategy="warn"),
            None,
        )  # This only works for detection where custom_metric is not required for BaseAggregator
    elif metric_name in [RECALL, HIT_RATE]:
        if is_matching:
            return CustomHitRate(), None
        else:  # TODO: support recall for general classification tasks.
            raise ValueError("Recall is not supported yet.")
    elif metric_name == BER:
        return Balanced_Error_Rate(), None
    elif metric_name in [SM, EM, FM, MAE]:
        # NOTE(review): this returns the object stored in COD_METRICS_NAMES; if it is a metric
        # instance rather than a class/factory, it is shared across calls — confirm.
        return COD_METRICS_NAMES[metric_name], None
    elif metric_name == IOU:
        if num_classes == 1:
            return Binary_IoU(), None
        else:
            return Multiclass_IoU(num_classes=num_classes), None
    elif metric_name == COVERAGE:
        return Coverage(), None
    else:
        raise ValueError(f"Unknown metric {metric_name}")
@@ -0,0 +1,284 @@
1
+ import logging
2
+ from typing import Any, Dict, List, Optional, Tuple, Union
3
+
4
+ import torch
5
+ from torch import nn, optim
6
+ from transformers import Adafactor
7
+ from transformers.trainer_pt_utils import get_parameter_names
8
+
9
+ from ..constants import (
10
+ BIT_FIT,
11
+ COLUMN_FEATURES,
12
+ CONV_LORA,
13
+ FEATURES,
14
+ IA3,
15
+ IA3_BIAS,
16
+ IA3_LORA,
17
+ IA3_LORA_BIAS,
18
+ IA3_LORA_NORM,
19
+ IA3_NORM,
20
+ LORA,
21
+ LORA_BIAS,
22
+ LORA_NORM,
23
+ NORM_FIT,
24
+ PEFT_STRATEGIES,
25
+ )
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+
30
def get_optimizer(
    optim_type: str,
    optimizer_grouped_parameters,
    lr: float,
    weight_decay: float,
    eps: Optional[float] = 1e-6,
    betas: Optional[Tuple[float, float]] = (0.9, 0.999),
    momentum: Optional[float] = 0.9,
):
    """
    Build a PyTorch optimizer from its name.

    Parameters
    ----------
    optim_type
        One of "adamw", "adam", "sgd" or "adafactor".
    optimizer_grouped_parameters
        The model parameters (or parameter groups) to be optimized.
    lr
        Learning rate, passed to every optimizer.
    weight_decay
        Weight decay, passed to every optimizer.
    eps
        Optimizer eps. Only forwarded to AdamW.
    betas
        Optimizer betas. Only forwarded to AdamW.
    momentum
        Momentum. Only forwarded to SGD.

    Returns
    -------
    The constructed optimizer.

    Raises
    ------
    ValueError
        If ``optim_type`` is not a supported name.
    """
    shared_kwargs = {"lr": lr, "weight_decay": weight_decay}

    if optim_type == "adamw":
        return optim.AdamW(optimizer_grouped_parameters, eps=eps, betas=betas, **shared_kwargs)
    if optim_type == "adam":
        # eps/betas are not forwarded here, so torch's Adam defaults apply.
        return optim.Adam(optimizer_grouped_parameters, **shared_kwargs)
    if optim_type == "sgd":
        return optim.SGD(optimizer_grouped_parameters, momentum=momentum, **shared_kwargs)
    if optim_type == "adafactor":
        return Adafactor(
            optimizer_grouped_parameters,
            scale_parameter=True,  # Generally recommended to enable scaling
            relative_step=False,
            warmup_init=False,
            **shared_kwargs,
        )
    raise ValueError(f"unknown optimizer: {optim_type}")
97
+
98
+
99
def get_weight_decay_param_names(model: nn.Module):
    """
    Get the names of the parameters that should receive weight decay.

    A parameter is excluded if either of these holds:
    - it belongs to a normalization layer (LayerNorm, BatchNorm1d/2d/3d, GroupNorm), as
      filtered by ``get_parameter_names``;
    - its name contains "bias", "cls_token", "categorical_feature_tokenizer" or
      "numerical_feature_tokenizer".

    Parameters
    ----------
    model
        A Pytorch model.

    Returns
    -------
    A list of parameter names that use weight decay.
    """
    norm_layer_types = [nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm]
    no_decay_markers = (
        "bias",
        "cls_token",
        "categorical_feature_tokenizer",
        "numerical_feature_tokenizer",
    )

    # By default, we should not apply weight decay for all the norm layers
    non_norm_names = get_parameter_names(model, norm_layer_types)
    return [name for name in non_norm_names if not any(marker in name for marker in no_decay_markers)]
128
+
129
+
130
def get_norm_layer_param_names(model: nn.Module):
    """
    Get parameters associated with the normalization layers.

    These are the names from ``model.named_parameters()`` that ``get_parameter_names``
    excludes when LayerNorm, BatchNorm1d/2d/3d and GroupNorm are passed as forbidden layer
    types.

    Parameters
    ----------
    model
        A Pytorch model

    Returns
    -------
    norm_param_names
        A list of normalization parameter names, in ``named_parameters()`` order.
    """
    non_norm_param_names = set(
        get_parameter_names(
            model,
            [nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm],
        )
    )
    # A set gives O(1) membership per parameter. The previous list lookup made the whole
    # pass O(n^2) in the number of parameters.
    norm_param_names = [name for name, _ in model.named_parameters() if name not in non_norm_param_names]
    return norm_param_names
151
+
152
+
153
def get_peft_param_names(norm_param_names: List[str], peft: Optional[str] = None, extra_params: Optional[List] = None):
    """
    Build the list of trainable-parameter name patterns for a parameter-efficient finetuning strategy.

    Parameters
    ----------
    norm_param_names
        The parameters associated with the normalization layers. They are appended for
        NORM_FIT and the *_NORM strategies.
    peft
        Efficient finetuning strategy, or None for no strategy (only ``extra_params`` are returned).
    extra_params
        Extra parameters to train, appended at the end when non-empty.

    Returns
    -------
    A new list with the regex patterns / names for the strategy, followed by ``extra_params``.

    Raises
    ------
    NotImplementedError
        If ``peft`` is not None and not a supported strategy.
    """
    # NOTE(review): these patterns are kept verbatim. `bias*.` means "bia", zero or more
    # "s", then one arbitrary character; presumably ".*bias.*" was intended — confirm how
    # callers match them before changing.
    lora_pattern = ".*lora_*."
    bias_pattern = ".*bias*."

    if peft == BIT_FIT:
        patterns = [bias_pattern]
    elif peft == NORM_FIT:
        patterns = [bias_pattern, *norm_param_names]
    elif peft in [LORA, IA3, IA3_LORA, CONV_LORA]:
        patterns = [lora_pattern]
    elif peft in [LORA_BIAS, IA3_BIAS, IA3_LORA_BIAS]:
        patterns = [lora_pattern, bias_pattern]
    elif peft in [LORA_NORM, IA3_NORM, IA3_LORA_NORM]:
        patterns = [lora_pattern, bias_pattern, *norm_param_names]
    elif peft is None:
        patterns = []
    else:
        raise NotImplementedError(
            f"The efficient finetuning strategy '{peft}'"
            f" is not supported. We only support"
            f" {', '.join(PEFT_STRATEGIES)}."
        )

    if extra_params:
        patterns.extend(extra_params)

    return patterns
197
+
198
+
199
def remove_parameters_without_grad(
    grouped_parameters: List[Dict],
):
    """
    Drop frozen parameters (``requires_grad`` is False) from every parameter group.

    Each group's ``"params"`` entry is replaced in place with a filtered list. All other
    keys of the group are left untouched.

    Parameters
    ----------
    grouped_parameters
        Parameter groups, each a dict whose ``"params"`` entry holds objects exposing
        ``requires_grad``.

    Returns
    -------
    The same ``grouped_parameters`` object, with every group's ``"params"`` filtered.
    """
    for group in grouped_parameters:
        group["params"] = [param for param in group["params"] if param.requires_grad]
    return grouped_parameters
222
+
223
+
224
def gather_column_features(
    output: Dict[str, Dict],
    column_names: Union[str, List[str]],
):
    """
    Gather column features from models' outputs.

    For each feature name under ``output[model][COLUMN_FEATURES][FEATURES]``, every requested
    column whose name occurs in the feature name is classified as one of:
    - independent: the feature name both starts and ends with the column name. The feature
      stored under the column name is collected.
    - shared: any other occurrence. The shared columns joined with "_" must have the same
      length as the feature name, and the feature stored under the feature name is collected
      once.
    All collected features must share one shape. They are stacked and averaged.

    TODO: return features' masks and use them to filter the losses.

    Parameters
    ----------
    output
        The models' outputs.
    column_names
        The columns whose features we want to get (a single name or a list of names).

    Returns
    -------
    The mean of the gathered feature vectors. Each sample should only have one feature vector.

    Raises
    ------
    AssertionError
        If shared columns don't account for the whole feature name, or gathered shapes differ.
    ValueError
        If no feature is found for the requested columns.
    """
    if isinstance(column_names, str):
        column_names = [column_names]

    collected = []
    for per_model_name, per_model_output in output.items():
        model_features = per_model_output[COLUMN_FEATURES][FEATURES]
        for feature_name in model_features:
            shared_columns = []
            for col_name in column_names:
                if col_name not in feature_name:
                    continue
                owns_feature = feature_name.startswith(col_name) and feature_name.endswith(col_name)
                if owns_feature:
                    # this column's feature is independent of other columns'
                    collected.append(model_features[col_name])
                else:
                    # this column feature is part of the cls feature
                    shared_columns.append(col_name)

            # two or more columns share one cls feature, and no other columns share it.
            if shared_columns:
                assert (
                    len("_".join(shared_columns)) == len(feature_name)
                ), f"model `{per_model_name}`'s cls feature name `{feature_name}` doesn't match `{shared_columns}`"
                collected.append(model_features[feature_name])

    if len(collected) > 1:
        # currently only support features of the same shape
        reference_shape = collected[0].shape
        assert all(
            feats.shape == reference_shape for feats in collected
        ), "Currently we only support gathering features of the same dimension."

    if not collected:
        raise ValueError(f"No features are found for columns names {column_names}.")

    return torch.stack(collected, dim=0).mean(dim=0)  # (b, d)