autogluon.multimodal 1.2.1b20250303__py3-none-any.whl → 1.2.1b20250304__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- autogluon/multimodal/__init__.py +4 -2
- autogluon/multimodal/configs/data/default.yaml +4 -2
- autogluon/multimodal/configs/{environment → env}/default.yaml +2 -3
- autogluon/multimodal/configs/model/default.yaml +58 -11
- autogluon/multimodal/configs/{optimization → optim}/default.yaml +21 -4
- autogluon/multimodal/constants.py +16 -5
- autogluon/multimodal/data/__init__.py +14 -2
- autogluon/multimodal/data/dataset.py +2 -2
- autogluon/multimodal/data/infer_types.py +16 -2
- autogluon/multimodal/data/label_encoder.py +3 -3
- autogluon/multimodal/{utils → data}/nlpaug.py +4 -4
- autogluon/multimodal/data/preprocess_dataframe.py +55 -38
- autogluon/multimodal/data/process_categorical.py +35 -6
- autogluon/multimodal/data/process_document.py +59 -33
- autogluon/multimodal/data/process_image.py +198 -163
- autogluon/multimodal/data/process_label.py +7 -3
- autogluon/multimodal/data/process_mmlab/process_mmdet.py +1 -8
- autogluon/multimodal/data/process_mmlab/process_mmlab_base.py +2 -9
- autogluon/multimodal/data/process_mmlab/process_mmocr.py +1 -9
- autogluon/multimodal/data/process_ner.py +192 -4
- autogluon/multimodal/data/process_numerical.py +32 -5
- autogluon/multimodal/data/process_semantic_seg_img.py +23 -28
- autogluon/multimodal/data/process_text.py +95 -58
- autogluon/multimodal/data/template_engine.py +7 -9
- autogluon/multimodal/data/templates.py +0 -2
- autogluon/multimodal/data/trivial_augmenter.py +2 -2
- autogluon/multimodal/data/utils.py +564 -338
- autogluon/multimodal/learners/__init__.py +2 -1
- autogluon/multimodal/learners/base.py +189 -189
- autogluon/multimodal/learners/ensemble.py +748 -0
- autogluon/multimodal/learners/few_shot_svm.py +6 -15
- autogluon/multimodal/learners/matching.py +59 -84
- autogluon/multimodal/learners/ner.py +23 -22
- autogluon/multimodal/learners/object_detection.py +26 -21
- autogluon/multimodal/learners/semantic_segmentation.py +16 -18
- autogluon/multimodal/models/__init__.py +12 -3
- autogluon/multimodal/models/augmenter.py +175 -0
- autogluon/multimodal/models/categorical_mlp.py +13 -8
- autogluon/multimodal/models/clip.py +92 -18
- autogluon/multimodal/models/custom_transformer.py +75 -75
- autogluon/multimodal/models/document_transformer.py +23 -9
- autogluon/multimodal/models/ft_transformer.py +40 -35
- autogluon/multimodal/models/fusion/base.py +2 -4
- autogluon/multimodal/models/fusion/fusion_mlp.py +82 -18
- autogluon/multimodal/models/fusion/fusion_ner.py +1 -1
- autogluon/multimodal/models/fusion/fusion_transformer.py +23 -23
- autogluon/multimodal/models/{huggingface_text.py → hf_text.py} +21 -2
- autogluon/multimodal/models/meta_transformer.py +336 -0
- autogluon/multimodal/models/mlp.py +6 -6
- autogluon/multimodal/models/mmocr_text_detection.py +1 -1
- autogluon/multimodal/models/mmocr_text_recognition.py +0 -1
- autogluon/multimodal/models/ner_text.py +1 -8
- autogluon/multimodal/models/numerical_mlp.py +14 -8
- autogluon/multimodal/models/sam.py +12 -2
- autogluon/multimodal/models/t_few.py +21 -5
- autogluon/multimodal/models/timm_image.py +74 -32
- autogluon/multimodal/models/utils.py +877 -16
- autogluon/multimodal/optim/__init__.py +17 -0
- autogluon/multimodal/{optimization → optim}/lit_distiller.py +2 -1
- autogluon/multimodal/{optimization → optim}/lit_matcher.py +4 -10
- autogluon/multimodal/{optimization → optim}/lit_mmdet.py +2 -10
- autogluon/multimodal/{optimization → optim}/lit_module.py +139 -14
- autogluon/multimodal/{optimization → optim}/lit_ner.py +3 -3
- autogluon/multimodal/{optimization → optim}/lit_semantic_seg.py +1 -1
- autogluon/multimodal/optim/losses/__init__.py +14 -0
- autogluon/multimodal/optim/losses/bce_loss.py +25 -0
- autogluon/multimodal/optim/losses/focal_loss.py +81 -0
- autogluon/multimodal/optim/losses/lemda_loss.py +39 -0
- autogluon/multimodal/optim/losses/rkd_loss.py +103 -0
- autogluon/multimodal/optim/losses/softmax_losses.py +177 -0
- autogluon/multimodal/optim/losses/structure_loss.py +26 -0
- autogluon/multimodal/optim/losses/utils.py +313 -0
- autogluon/multimodal/optim/lr/__init__.py +1 -0
- autogluon/multimodal/optim/lr/utils.py +332 -0
- autogluon/multimodal/optim/metrics/__init__.py +4 -0
- autogluon/multimodal/optim/metrics/coverage_metrics.py +42 -0
- autogluon/multimodal/optim/metrics/hit_rate_metrics.py +78 -0
- autogluon/multimodal/optim/metrics/ranking_metrics.py +231 -0
- autogluon/multimodal/optim/metrics/utils.py +359 -0
- autogluon/multimodal/optim/utils.py +284 -0
- autogluon/multimodal/predictor.py +51 -12
- autogluon/multimodal/utils/__init__.py +19 -45
- autogluon/multimodal/utils/cache.py +23 -2
- autogluon/multimodal/utils/checkpoint.py +58 -5
- autogluon/multimodal/utils/config.py +127 -55
- autogluon/multimodal/utils/device.py +120 -0
- autogluon/multimodal/utils/distillation.py +8 -8
- autogluon/multimodal/utils/download.py +1 -1
- autogluon/multimodal/utils/env.py +22 -0
- autogluon/multimodal/utils/export.py +3 -3
- autogluon/multimodal/utils/hpo.py +5 -5
- autogluon/multimodal/utils/inference.py +37 -4
- autogluon/multimodal/utils/install.py +91 -0
- autogluon/multimodal/utils/load.py +52 -47
- autogluon/multimodal/utils/log.py +6 -41
- autogluon/multimodal/utils/matcher.py +3 -2
- autogluon/multimodal/utils/onnx.py +0 -4
- autogluon/multimodal/utils/path.py +10 -0
- autogluon/multimodal/utils/precision.py +130 -0
- autogluon/multimodal/{presets.py → utils/presets.py} +259 -66
- autogluon/multimodal/{problem_types.py → utils/problem_types.py} +30 -1
- autogluon/multimodal/utils/save.py +47 -29
- autogluon/multimodal/utils/strategy.py +24 -0
- autogluon/multimodal/version.py +1 -1
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/METADATA +5 -5
- autogluon.multimodal-1.2.1b20250304.dist-info/RECORD +163 -0
- autogluon/multimodal/optimization/__init__.py +0 -16
- autogluon/multimodal/optimization/losses.py +0 -394
- autogluon/multimodal/optimization/utils.py +0 -1054
- autogluon/multimodal/utils/cloud_io.py +0 -80
- autogluon/multimodal/utils/data.py +0 -701
- autogluon/multimodal/utils/environment.py +0 -395
- autogluon/multimodal/utils/metric.py +0 -500
- autogluon/multimodal/utils/model.py +0 -558
- autogluon.multimodal-1.2.1b20250303.dist-info/RECORD +0 -145
- /autogluon/multimodal/{optimization → optim}/deepspeed.py +0 -0
- /autogluon/multimodal/{optimization/lr_scheduler.py → optim/lr/lr_schedulers.py} +0 -0
- /autogluon/multimodal/{optimization → optim/metrics}/semantic_seg_metrics.py +0 -0
- /autogluon/multimodal/{registry.py → utils/registry.py} +0 -0
- /autogluon.multimodal-1.2.1b20250303-py3.9-nspkg.pth → /autogluon.multimodal-1.2.1b20250304-py3.9-nspkg.pth +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/LICENSE +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/NOTICE +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/WHEEL +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/namespace_packages.txt +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/top_level.txt +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/zip-safe +0 -0
@@ -0,0 +1,359 @@
|
|
1
|
+
import functools
|
2
|
+
import logging
|
3
|
+
import warnings
|
4
|
+
from typing import Dict, List, Optional, Tuple, Union
|
5
|
+
|
6
|
+
import evaluate
|
7
|
+
import torchmetrics
|
8
|
+
from torch.nn import functional as F
|
9
|
+
from torchmetrics.detection.mean_ap import MeanAveragePrecision
|
10
|
+
|
11
|
+
from autogluon.core.metrics import Scorer, compute_metric, get_metric
|
12
|
+
|
13
|
+
from ...constants import (
|
14
|
+
ACC,
|
15
|
+
ACCURACY,
|
16
|
+
AVERAGE_PRECISION,
|
17
|
+
BER,
|
18
|
+
COSINE_EMBEDDING_LOSS,
|
19
|
+
COVERAGE,
|
20
|
+
CROSS_ENTROPY,
|
21
|
+
DETECTION_METRICS,
|
22
|
+
DIRECT_LOSS,
|
23
|
+
EM,
|
24
|
+
F1,
|
25
|
+
F1_MACRO,
|
26
|
+
F1_MICRO,
|
27
|
+
F1_WEIGHTED,
|
28
|
+
FM,
|
29
|
+
HIT_RATE,
|
30
|
+
IOU,
|
31
|
+
LOG_LOSS,
|
32
|
+
MAE,
|
33
|
+
MATCHING_METRICS,
|
34
|
+
MATCHING_METRICS_WITHOUT_PROBLEM_TYPE,
|
35
|
+
MAX,
|
36
|
+
METRIC_MODE_MAP,
|
37
|
+
MIN,
|
38
|
+
MULTICLASS,
|
39
|
+
NER_TOKEN_F1,
|
40
|
+
OVERALL_ACCURACY,
|
41
|
+
OVERALL_F1,
|
42
|
+
PEARSONR,
|
43
|
+
QUADRATIC_KAPPA,
|
44
|
+
R2,
|
45
|
+
RECALL,
|
46
|
+
RETRIEVAL_METRICS,
|
47
|
+
RMSE,
|
48
|
+
ROC_AUC,
|
49
|
+
ROOT_MEAN_SQUARED_ERROR,
|
50
|
+
SM,
|
51
|
+
SPEARMANR,
|
52
|
+
Y_PRED,
|
53
|
+
Y_PRED_PROB,
|
54
|
+
Y_TRUE,
|
55
|
+
)
|
56
|
+
from .coverage_metrics import Coverage
|
57
|
+
from .hit_rate_metrics import CustomHitRate
|
58
|
+
from .semantic_seg_metrics import COD_METRICS_NAMES, Balanced_Error_Rate, Binary_IoU, Multiclass_IoU
|
59
|
+
|
60
|
+
logger = logging.getLogger(__name__)
|
61
|
+
|
62
|
+
|
63
|
+
def compute_score(
    metric_data: dict,
    metric: Union[str, Scorer],
    pos_label: Optional[int] = 1,
) -> float:
    """
    Compute the score of one metric.

    Sequence-labeling metrics (``OVERALL_ACCURACY`` / ``OVERALL_F1``) are computed with the
    ``evaluate`` "seqeval" metric. Every other metric is resolved with
    ``autogluon.core.metrics.get_metric`` and computed with ``compute_metric``.

    Parameters
    ----------
    metric_data
        A dictionary with the groundtruth (Y_TRUE) and predicted values (Y_PRED, Y_PRED_PROB).
        The predicted class probabilities are required for metrics that need probabilities
        or thresholds (e.g., roc_auc).
    metric
        The name of metric or the function of metric to compute.
    pos_label
        The encoded label (0 or 1) of binary classification's positive class.

    Returns
    -------
    Computed score. For the seqeval metrics the raw result of ``seqeval.compute`` is returned
    (presumably a dict of scores rather than a float — TODO confirm against callers).
    """
    if isinstance(metric, str) and metric in [OVERALL_ACCURACY, OVERALL_F1]:
        seqeval = evaluate.load("seqeval")
        # Silence seqeval's warnings for this computation only, instead of installing a
        # process-wide "ignore" filter that would hide unrelated warnings afterwards.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return seqeval.compute(references=metric_data[Y_TRUE], predictions=metric_data[Y_PRED])

    metric = get_metric(metric)

    y = metric_data[Y_TRUE]
    if metric.needs_proba or metric.needs_threshold:
        y_pred_proba = metric_data[Y_PRED_PROB]
        # With at most two probability columns (binary), keep only the positive-class column.
        y_pred_proba = y_pred_proba if y_pred_proba.shape[1] > 2 else y_pred_proba[:, pos_label]
        return metric.convert_score_to_original(
            compute_metric(y=y, y_pred_proba=y_pred_proba, metric=metric, weights=None)
        )

    y_pred = metric_data[Y_PRED]

    # TODO: This is a hack. Doesn't support `f1_macro`, `f1_micro`, `f1_weighted`, or custom `f1` metrics with different names.
    # TODO: Longterm the solution should be to have the input data to this function use the internal representation without the original class names. This way `pos_label` would not need to be specified.
    if metric.name == F1:  # only for binary classification
        y = (y == pos_label).astype(int)
        y_pred = (y_pred == pos_label).astype(int)

    return metric.convert_score_to_original(compute_metric(y=y, y_pred=y_pred, metric=metric, weights=None))
|
117
|
+
|
118
|
+
|
119
|
+
def infer_metrics(
    problem_type: Optional[str] = None,
    eval_metric: Optional[Union[str, Scorer]] = None,
    validation_metric_name: Optional[str] = None,
    is_matching: Optional[bool] = False,
):
    """
    Infer the validation metric and the evaluation metric if not provided.
    Validation metric is for early-stopping and selecting the best model checkpoints.
    Evaluation metric is to report performance to users.

    Parameters
    ----------
    problem_type
        Type of problem. Required unless ``is_matching`` is True.
    eval_metric
        Evaluation metric provided by users: a metric name, a custom ``Scorer``, or None.
    validation_metric_name
        The provided validation metric name. It is only returned unchanged when the evaluation
        metric cannot serve as the validation metric and the problem type defines no fallback
        validation metric; every other path overwrites it.
    is_matching
        Whether the task is a matching task.

    Returns
    -------
    validation_metric_name
        Name of validation metric.
    eval_metric_name
        Name of evaluation metric.

    Raises
    ------
    TypeError
        If ``eval_metric`` is not a str, a Scorer, or None.
    ValueError
        If ``problem_type`` is None for a non-matching task.
    NotImplementedError
        If ``problem_type`` is not supported for matching.
    """
    is_customized = False
    if eval_metric is None:
        eval_metric_name = None
    elif isinstance(eval_metric, str):
        eval_metric_name = eval_metric
    elif isinstance(eval_metric, Scorer):
        eval_metric_name = eval_metric.name
        is_customized = True
    else:
        raise TypeError(f"eval_metric can be a str, a Scorer, or None, but is type: {type(eval_metric)}")

    if problem_type is not None:
        from ...utils.problem_types import PROBLEM_TYPES_REG

        problem_property = PROBLEM_TYPES_REG.get(problem_type)

    if is_matching:
        if eval_metric_name is not None:
            # if eval_metric_name is a valid metric
            if eval_metric_name.lower() in METRIC_MODE_MAP:
                validation_metric_name = eval_metric_name
                return validation_metric_name, eval_metric_name
            elif eval_metric_name.lower() in RETRIEVAL_METRICS:
                # Currently only support recall as validation metric in retrieval tasks.
                validation_metric_name = RECALL
                return validation_metric_name, eval_metric_name

        # When eval_metric_name is either None or not supported:
        # Fallback based on problem type unless it's a customized metric
        if problem_type is None:
            validation_metric_name, fallback_evaluation_metric = MATCHING_METRICS_WITHOUT_PROBLEM_TYPE
        elif problem_type in MATCHING_METRICS:
            validation_metric_name, fallback_evaluation_metric = MATCHING_METRICS[problem_type]
        else:
            raise NotImplementedError(f"Problem type: {problem_type} is not yet supported for matching!")
        if not is_customized:
            if eval_metric_name is not None:
                warnings.warn(
                    f"Metric {eval_metric_name} is not supported as the evaluation metric for {problem_type} in matching tasks. "
                    f"The evaluation metric is changed to {fallback_evaluation_metric} by default."
                )
            eval_metric_name = fallback_evaluation_metric
        return validation_metric_name, eval_metric_name

    if problem_type is None:
        # Non-matching inference needs the problem type's registry entry; fail with a clear
        # message instead of an UnboundLocalError on `problem_property` below.
        raise ValueError("problem_type must be provided to infer metrics for non-matching tasks.")

    if eval_metric_name is not None:
        # Infer evaluation metric
        if eval_metric_name.lower() not in problem_property.supported_evaluation_metrics and not is_customized:
            warnings.warn(
                f"Metric {eval_metric_name} is not supported as the evaluation metric for {problem_type}. "
                f"The evaluation metric is changed to {problem_property.fallback_evaluation_metric} by default."
            )
            if problem_property.fallback_evaluation_metric is not None:
                eval_metric_name = problem_property.fallback_evaluation_metric
            else:
                # Problem types like extract_embedding do not need an eval/val metric
                return None, None

        # Infer validation metric
        if eval_metric_name.lower() in problem_property.supported_validation_metrics:
            validation_metric_name = eval_metric_name
        elif problem_property.fallback_validation_metric is not None:
            validation_metric_name = problem_property.fallback_validation_metric
    else:
        eval_metric_name = problem_property.fallback_evaluation_metric
        validation_metric_name = problem_property.fallback_validation_metric

    return validation_metric_name, eval_metric_name
|
216
|
+
|
217
|
+
|
218
|
+
def get_minmax_mode(
    metric_name: Union[str, Scorer],
):
    """
    Get the min/max mode of a metric, used when selecting model checkpoints.

    Parameters
    ----------
    metric_name
        A metric name, or a ``Scorer``. Names are looked up in METRIC_MODE_MAP as given first,
        then lower-cased, matching the case-insensitive handling elsewhere in this module.

    Returns
    -------
    mode
        The min/max mode used in selecting model checkpoints.
        - min
            It means that a smaller metric is better.
        - max
            It means that a larger metric is better.

    Raises
    ------
    ValueError
        If ``metric_name`` is a string not found in METRIC_MODE_MAP.
    """
    if not isinstance(metric_name, str):
        return MAX if metric_name.greater_is_better else MIN

    # Try the exact name first so any non-lowercase keys keep working, then fall back to lowercase.
    key = metric_name if metric_name in METRIC_MODE_MAP else metric_name.lower()
    if key not in METRIC_MODE_MAP:
        # Explicit raise: an `assert` would be stripped under `python -O` and silently return None.
        raise ValueError(f"{metric_name} is not a supported metric. Options are: {METRIC_MODE_MAP.keys()}")
    return METRIC_MODE_MAP[key]
|
245
|
+
|
246
|
+
|
247
|
+
def get_stopping_threshold(metric_name: str):
    """
    Get the metric threshold for early stopping.

    The threshold is the metric's ``optimum`` offset by 1e-7 in the direction given by the
    metric's ``_sign``.

    Parameters
    ----------
    metric_name
        Name of validation metric.

    Returns
    -------
    The stopping threshold, or None if the metric cannot be resolved by ``get_metric`` or its
    threshold cannot be computed.
    """
    try:
        metric = get_metric(metric_name)
        stopping_threshold = metric.optimum - metric._sign * 1e-7
    except Exception:
        # Best effort: no threshold for metrics get_metric cannot handle. Catch Exception rather
        # than using a bare `except:` so KeyboardInterrupt/SystemExit still propagate.
        stopping_threshold = None

    return stopping_threshold
|
267
|
+
|
268
|
+
|
269
|
+
def get_torchmetric(
    metric_name: str,
    num_classes: Optional[int] = None,
    is_matching: Optional[bool] = False,
    problem_type: Optional[str] = None,
):
    """
    Obtain a torchmetrics.Metric from its name.
    Define a customized metric function in case that torchmetrics doesn't support some metric.

    Every branch returns a ``(metric, custom_metric_func)`` pair, so callers can always unpack
    the result.

    Parameters
    ----------
    metric_name
        Name of metric. It is lower-cased before matching.
    num_classes
        Number of classes.
    is_matching
        Whether the task is a matching task.
    problem_type
        Type of problem, e.g., binary and multiclass. Passed as the ``task`` argument of the
        task-dependent torchmetrics classes (quadratic kappa, roc_auc, average precision, f1*).

    Returns
    -------
    torchmetrics.Metric
        A torchmetrics.Metric object.
    custom_metric_func
        A customized metric function, or None. It is set for log_loss/cross_entropy and
        cosine_embedding_loss, where the metric is a MeanMetric and the function computes
        unreduced (``reduction="none"``) losses.

    Raises
    ------
    ValueError
        If the metric is unknown, spearmanr is requested for matching, or recall/hit_rate is
        requested outside matching.
    """
    metric_name = metric_name.lower()
    if metric_name in [ACC, ACCURACY, OVERALL_ACCURACY]:
        # use MULTICLASS since the head output dim is 2 for the binary problem type.
        return torchmetrics.Accuracy(task=MULTICLASS, num_classes=num_classes), None
    elif metric_name == NER_TOKEN_F1:
        return torchmetrics.F1Score(task=MULTICLASS, num_classes=num_classes, ignore_index=1), None
    elif metric_name in [RMSE, ROOT_MEAN_SQUARED_ERROR]:
        return torchmetrics.MeanSquaredError(squared=False), None
    elif metric_name == R2:
        return torchmetrics.R2Score(), None
    elif metric_name == QUADRATIC_KAPPA:
        return (
            torchmetrics.CohenKappa(task=problem_type, num_classes=num_classes, weights="quadratic"),
            None,
        )
    elif metric_name == ROC_AUC:
        return torchmetrics.AUROC(task=problem_type, num_classes=num_classes), None
    elif metric_name == AVERAGE_PRECISION:
        # Return a pair like every other branch; a bare metric breaks tuple unpacking in callers.
        return torchmetrics.AveragePrecision(task=problem_type, num_classes=num_classes), None
    elif metric_name in [LOG_LOSS, CROSS_ENTROPY]:
        return torchmetrics.MeanMetric(), functools.partial(F.cross_entropy, reduction="none")
    elif metric_name == COSINE_EMBEDDING_LOSS:
        return torchmetrics.MeanMetric(), functools.partial(F.cosine_embedding_loss, reduction="none")
    elif metric_name == PEARSONR:
        return torchmetrics.PearsonCorrCoef(), None
    elif metric_name == SPEARMANR:
        if is_matching:  # TODO: add support for matching.
            raise ValueError("spearman relation is not supported for matching yet.")
        else:
            return torchmetrics.SpearmanCorrCoef(), None
    elif metric_name == F1:
        return torchmetrics.F1Score(task=problem_type, num_classes=num_classes), None
    elif metric_name in [F1_MACRO, F1_MICRO, F1_WEIGHTED]:
        # e.g. "f1_macro" -> average="macro"
        average = metric_name.split("_")[1]
        return torchmetrics.F1Score(task=problem_type, num_classes=num_classes, average=average), None
    elif metric_name in DETECTION_METRICS:
        return (
            MeanAveragePrecision(box_format="xyxy", iou_type="bbox", class_metrics=False),
            None,
        )  # TODO: remove parameter hardcodings here, and add class_metrics
    elif metric_name == DIRECT_LOSS:
        return (
            torchmetrics.MeanMetric(nan_strategy="warn"),
            None,
        )  # This only works for detection where custom_metric is not required for BaseAggregator
    elif metric_name in [RECALL, HIT_RATE]:
        if is_matching:
            return CustomHitRate(), None
        else:  # TODO: support recall for general classification tasks.
            raise ValueError("Recall is not supported yet.")
    elif metric_name == BER:
        return Balanced_Error_Rate(), None
    elif metric_name in [SM, EM, FM, MAE]:
        return COD_METRICS_NAMES[metric_name], None
    elif metric_name == IOU:
        if num_classes == 1:
            return Binary_IoU(), None
        else:
            return Multiclass_IoU(num_classes=num_classes), None
    elif metric_name == COVERAGE:
        return Coverage(), None
    else:
        raise ValueError(f"Unknown metric {metric_name}")
|
@@ -0,0 +1,284 @@
|
|
1
|
+
import logging
|
2
|
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
3
|
+
|
4
|
+
import torch
|
5
|
+
from torch import nn, optim
|
6
|
+
from transformers import Adafactor
|
7
|
+
from transformers.trainer_pt_utils import get_parameter_names
|
8
|
+
|
9
|
+
from ..constants import (
|
10
|
+
BIT_FIT,
|
11
|
+
COLUMN_FEATURES,
|
12
|
+
CONV_LORA,
|
13
|
+
FEATURES,
|
14
|
+
IA3,
|
15
|
+
IA3_BIAS,
|
16
|
+
IA3_LORA,
|
17
|
+
IA3_LORA_BIAS,
|
18
|
+
IA3_LORA_NORM,
|
19
|
+
IA3_NORM,
|
20
|
+
LORA,
|
21
|
+
LORA_BIAS,
|
22
|
+
LORA_NORM,
|
23
|
+
NORM_FIT,
|
24
|
+
PEFT_STRATEGIES,
|
25
|
+
)
|
26
|
+
|
27
|
+
logger = logging.getLogger(__name__)
|
28
|
+
|
29
|
+
|
30
|
+
def get_optimizer(
    optim_type: str,
    optimizer_grouped_parameters,
    lr: float,
    weight_decay: float,
    eps: Optional[float] = 1e-6,
    betas: Optional[Tuple[float, float]] = (0.9, 0.999),
    momentum: Optional[float] = 0.9,
):
    """
    Build a PyTorch optimizer from its name.

    Supported names are "adamw", "adam", "sgd" and "adafactor". Only "adamw" consumes ``eps``
    and ``betas``, and only "sgd" consumes ``momentum``. The other optimizers use their own
    library defaults for those settings.

    Parameters
    ----------
    optim_type
        Name of the optimizer.
    optimizer_grouped_parameters
        The model parameters (or parameter groups) to be optimized.
    lr
        Learning rate.
    weight_decay
        Weight decay passed to every optimizer.
    eps
        Epsilon for AdamW.
    betas
        Beta coefficients for AdamW.
    momentum
        Momentum for SGD.

    Returns
    -------
    A PyTorch optimizer.

    Raises
    ------
    ValueError
        If ``optim_type`` is not a supported name.
    """
    if optim_type == "adamw":
        return optim.AdamW(optimizer_grouped_parameters, lr=lr, weight_decay=weight_decay, eps=eps, betas=betas)
    if optim_type == "adam":
        return optim.Adam(optimizer_grouped_parameters, lr=lr, weight_decay=weight_decay)
    if optim_type == "sgd":
        return optim.SGD(optimizer_grouped_parameters, lr=lr, weight_decay=weight_decay, momentum=momentum)
    if optim_type == "adafactor":
        # Drive Adafactor with the explicit `lr` (relative_step=False, no warmup init) while keeping
        # scale_parameter enabled; see the transformers Adafactor docs for these flags.
        return Adafactor(
            optimizer_grouped_parameters,
            lr=lr,
            weight_decay=weight_decay,
            scale_parameter=True,
            relative_step=False,
            warmup_init=False,
        )
    raise ValueError(f"unknown optimizer: {optim_type}")
|
97
|
+
|
98
|
+
|
99
|
+
def get_weight_decay_param_names(model: nn.Module):
    """
    Get the names of the parameters that should use weight decay.

    A parameter is excluded (i.e., gets no weight decay) when it:
    - lives inside a normalization layer (LayerNorm, BatchNorm1d/2d/3d, GroupNorm), or
    - has "bias", "cls_token", "categorical_feature_tokenizer" or "numerical_feature_tokenizer"
      anywhere in its name.

    Parameters
    ----------
    model
        A PyTorch model.

    Returns
    -------
    A list of parameter names that use weight decay.
    """
    # By default, we should not apply weight decay for any of the norm layers.
    decay_param_names = get_parameter_names(
        model,
        [nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm],
    )
    # Substrings of parameter names that are exempt from weight decay.
    no_decay_keywords = ("bias", "cls_token", "categorical_feature_tokenizer", "numerical_feature_tokenizer")
    return [name for name in decay_param_names if not any(keyword in name for keyword in no_decay_keywords)]
|
128
|
+
|
129
|
+
|
130
|
+
def get_norm_layer_param_names(model: nn.Module):
    """
    Get the names of parameters associated with the normalization layers.

    These are the parameters from ``model.named_parameters()`` that ``get_parameter_names``
    does not report once LayerNorm, BatchNorm1d/2d/3d and GroupNorm layers are excluded.

    Parameters
    ----------
    model
        A PyTorch model.

    Returns
    -------
    norm_param_names
        A list of normalization parameter names, in ``model.named_parameters()`` order.
    """
    non_norm_param_names = set(
        get_parameter_names(
            model,
            [nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm],
        )
    )
    # Set membership keeps this linear; a list lookup per parameter was quadratic in model size.
    return [name for name, _ in model.named_parameters() if name not in non_norm_param_names]
|
151
|
+
|
152
|
+
|
153
|
+
def get_peft_param_names(norm_param_names: List[str], peft: Optional[str] = None, extra_params: Optional[List] = None):
    """
    Return the trainable-parameter name patterns for a parameter-efficient finetuning strategy.

    Parameters
    ----------
    norm_param_names
        Names of the parameters that belong to normalization layers; they are appended as-is
        for the norm-fit style strategies.
    peft
        The efficient finetuning strategy, or None for no strategy-specific patterns.
    extra_params
        Additional names/patterns appended to the result when non-empty.

    Returns
    -------
    A list of trainable parameter names/patterns for the chosen strategy.

    Raises
    ------
    NotImplementedError
        If ``peft`` is not None and is not a supported strategy.
    """
    bias_pattern = ".*bias*."
    lora_pattern = ".*lora_*."

    if peft == BIT_FIT:
        names = [bias_pattern]
    elif peft == NORM_FIT:
        names = [bias_pattern, *norm_param_names]
    elif peft in [LORA, IA3, IA3_LORA, CONV_LORA]:
        names = [lora_pattern]
    elif peft in [LORA_BIAS, IA3_BIAS, IA3_LORA_BIAS]:
        names = [lora_pattern, bias_pattern]
    elif peft in [LORA_NORM, IA3_NORM, IA3_LORA_NORM]:
        names = [lora_pattern, bias_pattern, *norm_param_names]
    elif peft is None:
        names = []
    else:
        raise NotImplementedError(
            f"The efficient finetuning strategy '{peft}'"
            f" is not supported. We only support"
            f" {', '.join(PEFT_STRATEGIES)}."
        )

    if extra_params:
        names.extend(extra_params)

    return names
|
197
|
+
|
198
|
+
|
199
|
+
def remove_parameters_without_grad(
    grouped_parameters: List[Dict],
):
    """
    Drop parameters that do not require gradients from each parameter group.

    The groups are modified in place: each group's "params" entry is replaced by a new list
    containing only the parameters whose ``requires_grad`` is True.

    Parameters
    ----------
    grouped_parameters
        The grouped parameters output from lr_choice.

    Returns
    -------
    The same ``grouped_parameters`` object, with filtered "params" lists.
    """
    for group in grouped_parameters:
        group["params"] = [param for param in group["params"] if param.requires_grad]
    return grouped_parameters
|
222
|
+
|
223
|
+
|
224
|
+
def gather_column_features(
    output: Dict[str, Dict],
    column_names: Union[str, List[str]],
):
    """
    Gather column features from models' outputs.

    For each feature name in one model's output, enumerate the provided column names to see
    whether (some of) the provided columns share one cls feature or each has its own feature:
    - a feature whose name equals a column name is that column's independent feature;
    - a feature whose name merely contains provided column names is shared by those columns, and
      its name must have the same length as those column names joined by "_".

    TODO: return features' masks and use them to filter the losses.

    Parameters
    ----------
    output
        The models' outputs.
    column_names
        The columns whose features we want to get.

    Returns
    -------
    The gathered feature vectors averaged over all matched features. Each sample should only
    have one feature vector.

    Raises
    ------
    ValueError
        If no feature matches any of the column names.
    """
    if isinstance(column_names, str):
        column_names = [column_names]

    gathered_features = []
    for per_model_name, per_model_output in output.items():
        features = per_model_output[COLUMN_FEATURES][FEATURES]
        for feature_name in features:
            columns_share_one_feature = []
            for col_name in column_names:
                if col_name not in feature_name:
                    continue
                if feature_name == col_name:
                    # this column's feature is independent of other columns'
                    gathered_features.append(features[col_name])
                else:
                    # this column feature is part of a shared cls feature
                    columns_share_one_feature.append(col_name)

            # two or more columns share one cls feature, and no other columns share it.
            if columns_share_one_feature:
                assert len("_".join(columns_share_one_feature)) == len(
                    feature_name
                ), f"model `{per_model_name}`'s cls feature name `{feature_name}` doesn't match `{columns_share_one_feature}`"
                gathered_features.append(features[feature_name])

    if not gathered_features:
        raise ValueError(f"No features are found for columns names {column_names}.")

    # currently only support features of the same shape
    assert all(
        per_features.shape == gathered_features[0].shape for per_features in gathered_features
    ), "Currently we only support gathering features of the same dimension."

    return torch.stack(gathered_features, dim=0).mean(dim=0)  # (b, d)
|