autogluon.multimodal 1.2.1b20250303__py3-none-any.whl → 1.2.1b20250304__py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
- autogluon/multimodal/__init__.py +4 -2
- autogluon/multimodal/configs/data/default.yaml +4 -2
- autogluon/multimodal/configs/{environment → env}/default.yaml +2 -3
- autogluon/multimodal/configs/model/default.yaml +58 -11
- autogluon/multimodal/configs/{optimization → optim}/default.yaml +21 -4
- autogluon/multimodal/constants.py +16 -5
- autogluon/multimodal/data/__init__.py +14 -2
- autogluon/multimodal/data/dataset.py +2 -2
- autogluon/multimodal/data/infer_types.py +16 -2
- autogluon/multimodal/data/label_encoder.py +3 -3
- autogluon/multimodal/{utils → data}/nlpaug.py +4 -4
- autogluon/multimodal/data/preprocess_dataframe.py +55 -38
- autogluon/multimodal/data/process_categorical.py +35 -6
- autogluon/multimodal/data/process_document.py +59 -33
- autogluon/multimodal/data/process_image.py +198 -163
- autogluon/multimodal/data/process_label.py +7 -3
- autogluon/multimodal/data/process_mmlab/process_mmdet.py +1 -8
- autogluon/multimodal/data/process_mmlab/process_mmlab_base.py +2 -9
- autogluon/multimodal/data/process_mmlab/process_mmocr.py +1 -9
- autogluon/multimodal/data/process_ner.py +192 -4
- autogluon/multimodal/data/process_numerical.py +32 -5
- autogluon/multimodal/data/process_semantic_seg_img.py +23 -28
- autogluon/multimodal/data/process_text.py +95 -58
- autogluon/multimodal/data/template_engine.py +7 -9
- autogluon/multimodal/data/templates.py +0 -2
- autogluon/multimodal/data/trivial_augmenter.py +2 -2
- autogluon/multimodal/data/utils.py +564 -338
- autogluon/multimodal/learners/__init__.py +2 -1
- autogluon/multimodal/learners/base.py +189 -189
- autogluon/multimodal/learners/ensemble.py +748 -0
- autogluon/multimodal/learners/few_shot_svm.py +6 -15
- autogluon/multimodal/learners/matching.py +59 -84
- autogluon/multimodal/learners/ner.py +23 -22
- autogluon/multimodal/learners/object_detection.py +26 -21
- autogluon/multimodal/learners/semantic_segmentation.py +16 -18
- autogluon/multimodal/models/__init__.py +12 -3
- autogluon/multimodal/models/augmenter.py +175 -0
- autogluon/multimodal/models/categorical_mlp.py +13 -8
- autogluon/multimodal/models/clip.py +92 -18
- autogluon/multimodal/models/custom_transformer.py +75 -75
- autogluon/multimodal/models/document_transformer.py +23 -9
- autogluon/multimodal/models/ft_transformer.py +40 -35
- autogluon/multimodal/models/fusion/base.py +2 -4
- autogluon/multimodal/models/fusion/fusion_mlp.py +82 -18
- autogluon/multimodal/models/fusion/fusion_ner.py +1 -1
- autogluon/multimodal/models/fusion/fusion_transformer.py +23 -23
- autogluon/multimodal/models/{huggingface_text.py → hf_text.py} +21 -2
- autogluon/multimodal/models/meta_transformer.py +336 -0
- autogluon/multimodal/models/mlp.py +6 -6
- autogluon/multimodal/models/mmocr_text_detection.py +1 -1
- autogluon/multimodal/models/mmocr_text_recognition.py +0 -1
- autogluon/multimodal/models/ner_text.py +1 -8
- autogluon/multimodal/models/numerical_mlp.py +14 -8
- autogluon/multimodal/models/sam.py +12 -2
- autogluon/multimodal/models/t_few.py +21 -5
- autogluon/multimodal/models/timm_image.py +74 -32
- autogluon/multimodal/models/utils.py +877 -16
- autogluon/multimodal/optim/__init__.py +17 -0
- autogluon/multimodal/{optimization → optim}/lit_distiller.py +2 -1
- autogluon/multimodal/{optimization → optim}/lit_matcher.py +4 -10
- autogluon/multimodal/{optimization → optim}/lit_mmdet.py +2 -10
- autogluon/multimodal/{optimization → optim}/lit_module.py +139 -14
- autogluon/multimodal/{optimization → optim}/lit_ner.py +3 -3
- autogluon/multimodal/{optimization → optim}/lit_semantic_seg.py +1 -1
- autogluon/multimodal/optim/losses/__init__.py +14 -0
- autogluon/multimodal/optim/losses/bce_loss.py +25 -0
- autogluon/multimodal/optim/losses/focal_loss.py +81 -0
- autogluon/multimodal/optim/losses/lemda_loss.py +39 -0
- autogluon/multimodal/optim/losses/rkd_loss.py +103 -0
- autogluon/multimodal/optim/losses/softmax_losses.py +177 -0
- autogluon/multimodal/optim/losses/structure_loss.py +26 -0
- autogluon/multimodal/optim/losses/utils.py +313 -0
- autogluon/multimodal/optim/lr/__init__.py +1 -0
- autogluon/multimodal/optim/lr/utils.py +332 -0
- autogluon/multimodal/optim/metrics/__init__.py +4 -0
- autogluon/multimodal/optim/metrics/coverage_metrics.py +42 -0
- autogluon/multimodal/optim/metrics/hit_rate_metrics.py +78 -0
- autogluon/multimodal/optim/metrics/ranking_metrics.py +231 -0
- autogluon/multimodal/optim/metrics/utils.py +359 -0
- autogluon/multimodal/optim/utils.py +284 -0
- autogluon/multimodal/predictor.py +51 -12
- autogluon/multimodal/utils/__init__.py +19 -45
- autogluon/multimodal/utils/cache.py +23 -2
- autogluon/multimodal/utils/checkpoint.py +58 -5
- autogluon/multimodal/utils/config.py +127 -55
- autogluon/multimodal/utils/device.py +120 -0
- autogluon/multimodal/utils/distillation.py +8 -8
- autogluon/multimodal/utils/download.py +1 -1
- autogluon/multimodal/utils/env.py +22 -0
- autogluon/multimodal/utils/export.py +3 -3
- autogluon/multimodal/utils/hpo.py +5 -5
- autogluon/multimodal/utils/inference.py +37 -4
- autogluon/multimodal/utils/install.py +91 -0
- autogluon/multimodal/utils/load.py +52 -47
- autogluon/multimodal/utils/log.py +6 -41
- autogluon/multimodal/utils/matcher.py +3 -2
- autogluon/multimodal/utils/onnx.py +0 -4
- autogluon/multimodal/utils/path.py +10 -0
- autogluon/multimodal/utils/precision.py +130 -0
- autogluon/multimodal/{presets.py → utils/presets.py} +259 -66
- autogluon/multimodal/{problem_types.py → utils/problem_types.py} +30 -1
- autogluon/multimodal/utils/save.py +47 -29
- autogluon/multimodal/utils/strategy.py +24 -0
- autogluon/multimodal/version.py +1 -1
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/METADATA +5 -5
- autogluon.multimodal-1.2.1b20250304.dist-info/RECORD +163 -0
- autogluon/multimodal/optimization/__init__.py +0 -16
- autogluon/multimodal/optimization/losses.py +0 -394
- autogluon/multimodal/optimization/utils.py +0 -1054
- autogluon/multimodal/utils/cloud_io.py +0 -80
- autogluon/multimodal/utils/data.py +0 -701
- autogluon/multimodal/utils/environment.py +0 -395
- autogluon/multimodal/utils/metric.py +0 -500
- autogluon/multimodal/utils/model.py +0 -558
- autogluon.multimodal-1.2.1b20250303.dist-info/RECORD +0 -145
- /autogluon/multimodal/{optimization → optim}/deepspeed.py +0 -0
- /autogluon/multimodal/{optimization/lr_scheduler.py → optim/lr/lr_schedulers.py} +0 -0
- /autogluon/multimodal/{optimization → optim/metrics}/semantic_seg_metrics.py +0 -0
- /autogluon/multimodal/{registry.py → utils/registry.py} +0 -0
- /autogluon.multimodal-1.2.1b20250303-py3.9-nspkg.pth → /autogluon.multimodal-1.2.1b20250304-py3.9-nspkg.pth +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/LICENSE +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/NOTICE +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/WHEEL +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/namespace_packages.txt +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/top_level.txt +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/zip-safe +0 -0
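The largest single change is the removal of autogluon/multimodal/optimization/utils.py (1,054 lines), whose diff is reproduced in full below; the file listing above suggests its helpers were redistributed across the new autogluon/multimodal/optim/ package (optim/utils.py, optim/lr/, optim/losses/, optim/metrics/). A minimal, hedged sketch of what that likely means for downstream imports follows; the post-reorganization path is an assumption inferred from the file listing and is not confirmed by this diff.

# Hedged sketch: resolve an optimizer helper across the 1.2.1b20250303 -> 1.2.1b20250304
# module reorganization. The new path is a guess based on the added optim/utils.py above.
try:
    # assumed post-reorganization location (1.2.1b20250304)
    from autogluon.multimodal.optim.utils import get_optimizer
except ImportError:
    # location removed in this release (present in 1.2.1b20250303 and earlier)
    from autogluon.multimodal.optimization.utils import get_optimizer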
@@ -1,1054 +0,0 @@
-import functools
-import logging
-import re
-import warnings
-from typing import Any, Dict, List, Optional, Tuple, Union
-
-import torch
-import torchmetrics
-from omegaconf import DictConfig, OmegaConf
-from packaging import version
-from pytorch_metric_learning import distances, losses, miners
-from torch import nn, optim
-from torch.nn import functional as F
-from torchmetrics.detection.mean_ap import MeanAveragePrecision
-from transformers import Adafactor
-from transformers.models.mask2former.modeling_mask2former import Mask2FormerConfig, Mask2FormerLoss
-from transformers.trainer_pt_utils import get_parameter_names
-
-from ..constants import (
-    ACC,
-    ACCURACY,
-    AVERAGE_PRECISION,
-    BER,
-    BINARY,
-    BIT_FIT,
-    COLUMN_FEATURES,
-    CONTRASTIVE_LOSS,
-    CONV_LORA,
-    COSINE_EMBEDDING_LOSS,
-    COSINE_SIMILARITY,
-    CROSS_ENTROPY,
-    DETECTION_METRICS,
-    DIRECT_LOSS,
-    EM,
-    F1,
-    F1_MACRO,
-    F1_MICRO,
-    F1_WEIGHTED,
-    FEATURES,
-    FEW_SHOT_CLASSIFICATION,
-    FM,
-    HIT_RATE,
-    IA3,
-    IA3_BIAS,
-    IA3_LORA,
-    IA3_LORA_BIAS,
-    IA3_LORA_NORM,
-    IA3_NORM,
-    IOU,
-    LOG_LOSS,
-    LORA,
-    LORA_BIAS,
-    LORA_NORM,
-    MAE,
-    MULTI_NEGATIVES_SOFTMAX_LOSS,
-    MULTICLASS,
-    NER,
-    NER_TOKEN_F1,
-    NORM_FIT,
-    OBJECT_DETECTION,
-    OVERALL_ACCURACY,
-    PAIR_MARGIN_MINER,
-    PEARSONR,
-    PEFT_STRATEGIES,
-    QUADRATIC_KAPPA,
-    R2,
-    RECALL,
-    REGRESSION,
-    RMSE,
-    ROC_AUC,
-    ROOT_MEAN_SQUARED_ERROR,
-    SEMANTIC_SEGMENTATION,
-    SM,
-    SPEARMANR,
-)
-from .losses import BBCEWithLogitLoss, FocalLoss, MultiNegativesSoftmaxLoss, SoftTargetCrossEntropy, StructureLoss
-from .lr_scheduler import (
-    get_cosine_schedule_with_warmup,
-    get_linear_schedule_with_warmup,
-    get_polynomial_decay_schedule_with_warmup,
-)
-from .semantic_seg_metrics import COD_METRICS_NAMES, Balanced_Error_Rate, Binary_IoU, Multiclass_IoU
-
-logger = logging.getLogger(__name__)
-
-
-def get_loss_func(
-    problem_type: str,
-    mixup_active: Optional[bool] = None,
-    loss_func_name: Optional[str] = None,
-    config: Optional[DictConfig] = None,
-    **kwargs,
-):
-    """
-    Choose a suitable Pytorch loss module based on the provided problem type.
-
-    Parameters
-    ----------
-    problem_type
-        Type of problem.
-    mixup_active
-        The activation determining whether to use mixup.
-    loss_func_name
-        The name of the function the user wants to use.
-    config
-        The optimization configs containing values such as i.e. optimization.loss_function
-        An example purpose of this config here is to pass through the parameters for focal loss, i.e.:
-        alpha = optimization.focal_loss.alpha
-    Returns
-    -------
-    A Pytorch loss module.
-    """
-    if problem_type in [BINARY, MULTICLASS]:
-        if mixup_active:
-            loss_func = SoftTargetCrossEntropy()
-        else:
-            if loss_func_name is not None and loss_func_name.lower() == "focal_loss":
-                loss_func = FocalLoss(
-                    alpha=OmegaConf.select(config, "focal_loss.alpha"),
-                    gamma=OmegaConf.select(config, "focal_loss.gamma", default=2.0),
-                    reduction=OmegaConf.select(config, "focal_loss.reduction", default="mean"),
-                )
-            else:
-                loss_func = nn.CrossEntropyLoss()
-    elif problem_type == REGRESSION:
-        if loss_func_name is not None:
-            if "bcewithlogitsloss" in loss_func_name.lower():
-                loss_func = nn.BCEWithLogitsLoss()
-            else:
-                loss_func = nn.MSELoss()
-        else:
-            loss_func = nn.MSELoss()
-    elif problem_type == NER:
-        loss_func = nn.CrossEntropyLoss(ignore_index=0)
-    elif problem_type in [OBJECT_DETECTION, FEW_SHOT_CLASSIFICATION]:
-        return None
-    elif problem_type == SEMANTIC_SEGMENTATION:
-        if "structure_loss" in loss_func_name.lower():
-            loss_func = StructureLoss()
-        elif "balanced_bce" in loss_func_name.lower():
-            loss_func = BBCEWithLogitLoss()
-        elif "mask2former_loss" in loss_func_name.lower():
-            weight_dict = {
-                "loss_cross_entropy": config.mask2former_loss.loss_cross_entropy_weight,
-                "loss_mask": config.mask2former_loss.loss_mask_weight,
-                "loss_dice": config.mask2former_loss.loss_dice_weight,
-            }
-            loss_func = Mask2FormerLoss(
-                config=Mask2FormerConfig(num_labels=kwargs["num_classes"]), weight_dict=weight_dict
-            )
-        else:
-            loss_func = nn.BCEWithLogitsLoss()
-    elif problem_type is None:
-        return None
-    else:
-        raise NotImplementedError
-
-    return loss_func
-
-
-class CustomHitRate(torchmetrics.Metric):
-    """
-    Compute the hit rate when doing semantic search between two group of embeddings.
-    We assume that (a_i, p_i) are a positive pair and (a_i, p_j) for i!=j a negative pair.
-    """
-
-    def __init__(
-        self,
-    ):
-        super().__init__()
-        self.add_state("query_embeddings", default=[], dist_reduce_fx=None)
-        self.add_state("response_embeddings", default=[], dist_reduce_fx=None)
-        self.add_state("logit_scale", default=[], dist_reduce_fx=None)
-
-    def update(
-        self,
-        batch_query_embeds: torch.Tensor,
-        batch_response_embeds: torch.Tensor,
-        logit_scale: Optional[torch.Tensor] = None,
-    ):
-        self.query_embeddings.append(batch_query_embeds)
-        self.response_embeddings.append(batch_response_embeds)
-        if logit_scale is not None:
-            self.logit_scale.append(logit_scale)
-
-    def compute(self):
-        query_embeddings = torch.cat(self.query_embeddings)
-        response_embeddings = torch.cat(self.response_embeddings)
-        if self.logit_scale:
-            logit_scale = torch.mean(torch.stack(self.logit_scale))
-        else:
-            logit_scale = 1
-
-        return compute_hit_rate(query_embeddings, response_embeddings, logit_scale)
-
-
-def compute_hit_rate(features_a, features_b, logit_scale, top_ks=[1, 5, 10]):
-    """
-    Compute symmetric hit rates between two groups of features.
-
-    Parameters
-    ----------
-    features_a
-        One group of features.
-    features_b
-        The other group of features.
-    logit_scale
-        The scale of logit (Used in CLIP).
-    top_ks
-        Consider only the top k elements for each query.
-
-    Returns
-    -------
-    The accumulated hit rate.
-    """
-    assert len(features_a) == len(features_b)
-    hit_rate = 0
-    logits_per_a = (logit_scale * features_a @ features_b.t()).detach().cpu()
-    logits_per_b = logits_per_a.t().detach().cpu()
-
-    logits = {"logits_per_a": logits_per_a, "logits_per_b": logits_per_b}
-    ground_truth = torch.arange(len(features_b)).view(-1, 1)
-
-    for name, logit in logits.items():
-        ranking = torch.argsort(logit, descending=True)
-        preds = torch.where(ranking == ground_truth)[1]
-
-        for k in top_ks:
-            hit_rate += (preds < k).float().mean()
-
-    hit_rate /= len(top_ks) * len(logits)
-    return hit_rate
-
-
-def get_metric(
-    metric_name: str,
-    num_classes: Optional[int] = None,
-    is_matching: Optional[bool] = False,
-    problem_type: Optional[str] = None,
-):
-    """
-    Obtain a torchmerics.Metric from its name.
-    Define a customized metric function in case that torchmetrics doesn't support some metric.
-
-    Parameters
-    ----------
-    metric_name
-        Name of metric.
-    num_classes
-        Number of classes.
-    is_matching
-        Whether is matching.
-    problem_type
-        Type of problem, e.g., binary and multiclass.
-
-    Returns
-    -------
-    torchmetrics.Metric
-        A torchmetrics.Metric object.
-    custom_metric_func
-        A customized metric function.
-    """
-    metric_name = metric_name.lower()
-    if metric_name in [ACC, ACCURACY, OVERALL_ACCURACY]:
-        # use MULTICLASS since the head output dim is 2 for the binary problem type.
-        return torchmetrics.Accuracy(task=MULTICLASS, num_classes=num_classes), None
-    elif metric_name == NER_TOKEN_F1:
-        return torchmetrics.F1Score(task=MULTICLASS, num_classes=num_classes, ignore_index=1), None
-    elif metric_name in [RMSE, ROOT_MEAN_SQUARED_ERROR]:
-        return torchmetrics.MeanSquaredError(squared=False), None
-    elif metric_name == R2:
-        return torchmetrics.R2Score(), None
-    elif metric_name == QUADRATIC_KAPPA:
-        return (
-            torchmetrics.CohenKappa(task=problem_type, num_classes=num_classes, weights="quadratic"),
-            None,
-        )
-    elif metric_name == ROC_AUC:
-        return torchmetrics.AUROC(task=problem_type, num_classes=num_classes), None
-    elif metric_name == AVERAGE_PRECISION:
-        return torchmetrics.AveragePrecision(task=problem_type, num_classes=num_classes)
-    elif metric_name in [LOG_LOSS, CROSS_ENTROPY]:
-        return torchmetrics.MeanMetric(), functools.partial(F.cross_entropy, reduction="none")
-    elif metric_name == COSINE_EMBEDDING_LOSS:
-        return torchmetrics.MeanMetric(), functools.partial(F.cosine_embedding_loss, reduction="none")
-    elif metric_name == PEARSONR:
-        return torchmetrics.PearsonCorrCoef(), None
-    elif metric_name == SPEARMANR:
-        if is_matching:  # TODO: add support for matching.
-            raise ValueError("spearman relation is not supported for matching yet.")
-        else:
-            return torchmetrics.SpearmanCorrCoef(), None
-    elif metric_name == F1:
-        return torchmetrics.F1Score(task=problem_type, num_classes=num_classes), None
-    elif metric_name in [F1_MACRO, F1_MICRO, F1_WEIGHTED]:
-        average = metric_name.split("_")[1]
-        return torchmetrics.F1Score(task=problem_type, num_classes=num_classes, average=average), None
-    elif metric_name in DETECTION_METRICS:
-        return (
-            MeanAveragePrecision(box_format="xyxy", iou_type="bbox", class_metrics=False),
-            None,
-        )  # TODO: remove parameter hardcodings here, and add class_metrics
-    elif metric_name == DIRECT_LOSS:
-        return (
-            torchmetrics.MeanMetric(nan_strategy="warn"),
-            None,
-        )  # This only works for detection where custom_metric is not required for BaseAggregator
-    elif metric_name in [RECALL, HIT_RATE]:
-        if is_matching:
-            return CustomHitRate(), None
-        else:  # TODO: support recall for general classification tasks.
-            raise ValueError("Recall is not supported yet.")
-    elif metric_name == BER:
-        return Balanced_Error_Rate(), None
-    elif metric_name in [SM, EM, FM, MAE]:
-        return COD_METRICS_NAMES[metric_name], None
-    elif metric_name == IOU:
-        if num_classes == 1:
-            return Binary_IoU(), None
-        else:
-            return Multiclass_IoU(num_classes=num_classes), None
-    else:
-        raise ValueError(f"Unknown metric {metric_name}")
-
-
-def get_optimizer(
-    optim_type: str,
-    optimizer_grouped_parameters,
-    lr: float,
-    weight_decay: float,
-    eps: Optional[float] = 1e-6,
-    betas: Optional[Tuple[float, float]] = (0.9, 0.999),
-    momentum: Optional[float] = 0.9,
-):
-    """
-    Choose a Pytorch optimizer based on its name.
-
-    Parameters
-    ----------
-    optim_type
-        Name of optimizer.
-    optimizer_grouped_parameters
-        The model parameters to be optimized.
-    lr
-        Learning rate.
-    weight_decay
-        Optimizer weight decay.
-    eps
-        Optimizer eps.
-    betas
-        Optimizer betas.
-    momentum
-        Momentum used in the SGD optimizer.
-
-    Returns
-    -------
-    A Pytorch optimizer.
-    """
-    if optim_type == "adamw":
-        optimizer = optim.AdamW(
-            optimizer_grouped_parameters,
-            lr=lr,
-            weight_decay=weight_decay,
-            eps=eps,
-            betas=betas,
-        )
-    elif optim_type == "adam":
-        optimizer = optim.Adam(
-            optimizer_grouped_parameters,
-            lr=lr,
-            weight_decay=weight_decay,
-        )
-    elif optim_type == "sgd":
-        optimizer = optim.SGD(
-            optimizer_grouped_parameters,
-            lr=lr,
-            weight_decay=weight_decay,
-            momentum=momentum,
-        )
-    elif optim_type == "adafactor":
-        optimizer = Adafactor(
-            optimizer_grouped_parameters,
-            lr=lr,
-            weight_decay=weight_decay,
-            scale_parameter=True,  # Generally recommended to enable scaling
-            relative_step=False,
-            warmup_init=False,
-        )
-    else:
-        raise ValueError(f"unknown optimizer: {optim_type}")
-
-    return optimizer
-
-
-def get_lr_scheduler(
-    optimizer: optim.Optimizer,
-    num_max_steps: int,
-    num_warmup_steps: int,
-    lr_schedule: str,
-    end_lr: Union[float, int],
-):
-    """
-    Get the learning rate scheduler from its name. Here we use our defined learning rate
-    scheduler instead of those imported from "transformers" because we want to support
-    Pytorch lightning's "ddp_spawn" training strategy.
-
-    Parameters
-    ----------
-    optimizer
-        A Pytorch optimizer.
-    num_max_steps
-        Number of maximum training steps.
-    num_warmup_steps
-        Number of steps to do learning rate warmup.
-    lr_schedule
-        Name of the learning rate scheduler.
-    end_lr
-        The final learning rate after decay.
-
-    Returns
-    -------
-    A learning rate scheduler.
-    """
-    if lr_schedule == "cosine_decay":
-        scheduler = get_cosine_schedule_with_warmup(
-            optimizer=optimizer,
-            num_warmup_steps=num_warmup_steps,
-            num_training_steps=num_max_steps,
-        )
-    elif lr_schedule == "polynomial_decay":
-        scheduler = get_polynomial_decay_schedule_with_warmup(
-            optimizer=optimizer,
-            num_warmup_steps=num_warmup_steps,
-            num_training_steps=num_max_steps,
-            lr_end=end_lr,
-            power=1,
-        )
-    elif lr_schedule == "linear_decay":
-        scheduler = get_linear_schedule_with_warmup(
-            optimizer=optimizer,
-            num_warmup_steps=num_warmup_steps,
-            num_training_steps=num_max_steps,
-        )
-    elif lr_schedule == "multi_step":
-        # TODO: add milestones, gamma into hyperparameters
-        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=[30, 55], gamma=0.1)
-    else:
-        raise ValueError(f"unknown lr schedule: {lr_schedule}")
-
-    return scheduler
-
-
-def get_weight_decay_param_names(model: nn.Module):
-    """
-    Set the layer normalization parameters and other layers' bias parameters not to use weight decay.
-
-    Parameters
-    ----------
-    model
-        A Pytorch model.
-
-    Returns
-    -------
-    A list of parameter names not using weight decay.
-    """
-    # By default, we should not apply weight decay for all the norm layers
-    decay_param_names = get_parameter_names(
-        model,
-        [nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm],
-    )
-    decay_param_names = [
-        name
-        for name in decay_param_names
-        if (
-            "bias" not in name
-            and "cls_token" not in name
-            and "categorical_feature_tokenizer" not in name
-            and "numerical_feature_tokenizer" not in name
-        )
-    ]
-    return decay_param_names
-
-
-def get_norm_layer_param_names(model: nn.Module):
-    """
-    Get parameters associated with the normalization layers
-
-    Parameters
-    ----------
-    model
-        A Pytorch model
-
-    Returns
-    -------
-    norm_param_names
-        A list of normalization parameter names
-    """
-    all_param_names = [name for name, _ in model.named_parameters()]
-    all_param_names_except_norm_names = get_parameter_names(
-        model,
-        [nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm],
-    )
-    norm_param_names = [name for name in all_param_names if name not in all_param_names_except_norm_names]
-    return norm_param_names
-
-
-def apply_single_lr(
-    model: nn.Module,
-    lr: float,
-    weight_decay: float,
-    return_params: Optional[bool] = True,
-    efficient_finetune: Optional[str] = None,
-    trainable_param_names: Optional[List] = None,
-):
-    """
-    Set to use a single learning rate for all parameters. Layer normalization parameters and other
-    layers' bias parameters don't use weight decay.
-
-    Parameters
-    ----------
-    model
-        A Pytorch model.
-    lr
-        Learning rate.
-    weight_decay
-        Weight decay.
-    return_params
-        Whether to return parameters or their names. If you want to double-check
-        whether the learning rate setup is as expected, you can set "return_params=False",
-        and print the layer names along with their learning rates through
-        "print("Param groups = %s" % json.dumps(optimizer_grouped_parameters, indent=2))".
-
-    Returns
-    -------
-    The grouped parameters or their names.
-    """
-    decay_param_names = get_weight_decay_param_names(model)
-    decay_grad_param_names = []
-    no_decay_grad_param_names = []
-
-    for name, param in model.named_parameters():
-        if (
-            efficient_finetune is not None
-            and efficient_finetune != "None"
-            and trainable_param_names
-            and not any([re.match(trainable_param_name, name) for trainable_param_name in trainable_param_names])
-        ):
-            param.requires_grad = False
-
-        if not param.requires_grad:
-            continue  # frozen weights
-
-        if name in decay_param_names:
-            if return_params:
-                decay_grad_param_names.append(param)
-            else:
-                decay_grad_param_names.append(name)
-
-        else:
-            if return_params:
-                no_decay_grad_param_names.append(param)
-            else:
-                no_decay_grad_param_names.append(name)
-
-    optimizer_grouped_parameters = [
-        {
-            "params": decay_grad_param_names,
-            "weight_decay": weight_decay,
-            "lr": lr,
-        },
-        {
-            "params": no_decay_grad_param_names,
-            "weight_decay": 0.0,
-            "lr": lr,
-        },
-    ]
-    return optimizer_grouped_parameters
-
-
-def apply_two_stages_lr(
-    model: nn.Module,
-    lr: float,
-    lr_mult: Union[float, int],
-    weight_decay: float,
-    return_params: Optional[bool] = True,
-):
-    """
-    Set up the pretrained backbone to use a smaller learning rate (lr * lr_mult).
-    The newly added head layers use the normal learning rate (lr).
-    Layer normalization parameters and other layers' bias parameters don't use weight decay.
-
-    Parameters
-    ----------
-    model
-        A Pytorch model.
-    lr
-        The learning rate.
-    lr_mult
-        The multiplier (0, 1) to scale down the learning rate.
-    weight_decay
-        Weight decay.
-    return_params
-        return_params
-        Whether to return parameters or their names. If you want to double-check
-        whether the learning rate setup is as expected, you can set "return_params=False",
-        and print the layer names along with their learning rates through
-        "print("Param groups = %s" % json.dumps(optimizer_grouped_parameters, indent=2))".
-
-    Returns
-    -------
-    The grouped parameters or their names.
-    """
-    decay_param_names = get_weight_decay_param_names(model)
-
-    optimizer_grouped_parameters = [
-        {
-            "params": [
-                p if return_params else n
-                for n, p in model.named_parameters()
-                if n in decay_param_names and not any(bb in n for bb in model.head_layer_names)
-            ],
-            "weight_decay": weight_decay,
-            "lr": lr,
-        },
-        {
-            "params": [
-                p if return_params else n
-                for n, p in model.named_parameters()
-                if n not in decay_param_names and not any(bb in n for bb in model.head_layer_names)
-            ],
-            "weight_decay": 0.0,
-            "lr": lr,
-        },
-        {
-            "params": [
-                p if return_params else n
-                for n, p in model.named_parameters()
-                if n in decay_param_names and any(bb in n for bb in model.head_layer_names)
-            ],
-            "weight_decay": weight_decay,
-            "lr": lr * lr_mult,
-        },
-        {
-            "params": [
-                p if return_params else n
-                for n, p in model.named_parameters()
-                if n not in decay_param_names and any(bb in n for bb in model.head_layer_names)
-            ],
-            "weight_decay": 0.0,
-            "lr": lr * lr_mult,
-        },
-    ]
-
-    return optimizer_grouped_parameters
-
-
-def get_trainable_params_efficient_finetune(
-    norm_param_names: List[str], efficient_finetune: Optional[str] = None, extra_params: Optional[List] = None
-):
-    """
-    Get the list of trainable parameters according to the provided efficient finetuning method.
-
-    Parameters
-    ----------
-    norm_param_names
-        The parameters associated with the normalization layers
-    efficient_finetune
-        Efficient finetuning strategy. Trainable parameters will be adjusted according to the method.
-
-    Returns
-    -------
-    Get list of trainable parameter names according to the provided efficient finetuning method.
-    """
-    trainable_param_names = []
-
-    if efficient_finetune == BIT_FIT:
-        trainable_param_names.append(".*bias*.")
-    elif efficient_finetune == NORM_FIT:
-        trainable_param_names.append(".*bias*.")
-        trainable_param_names += norm_param_names
-    elif efficient_finetune in [LORA, IA3, IA3_LORA, CONV_LORA]:
-        trainable_param_names.append(".*lora_*.")
-    elif efficient_finetune in [LORA_BIAS, IA3_BIAS, IA3_LORA_BIAS]:
-        trainable_param_names.append(".*lora_*.")
-        trainable_param_names.append(".*bias*.")
-    elif efficient_finetune in [LORA_NORM, IA3_NORM, IA3_LORA_NORM]:
-        trainable_param_names.append(".*lora_*.")
-        trainable_param_names.append(".*bias*.")
-        trainable_param_names += norm_param_names
-    elif efficient_finetune is not None and efficient_finetune != "None":
-        raise NotImplementedError(
-            f"The efficient finetuning strategy '{efficient_finetune}'"
-            f" is not supported. We only support"
-            f" {', '.join(PEFT_STRATEGIES)}."
-        )
-
-    if extra_params:
-        trainable_param_names.extend(extra_params)
-
-    return trainable_param_names
-
-
-def remove_parameters_without_grad(
-    grouped_parameters: List[Dict],
-):
-    """
-    Remove layers
-
-    Parameters
-    ----------
-    grouped_parameters
-        The grouped parameters or their names output from lr_choice.
-
-    Returns
-    -------
-    The updated grouped parameters or their names.
-    """
-    for group_idx, group_param in enumerate(grouped_parameters):
-        updated_params = []
-        for p in group_param["params"]:
-            if p.requires_grad:
-                updated_params.append(p)
-        grouped_parameters[group_idx]["params"] = updated_params
-
-    return grouped_parameters
-
-
-def apply_layerwise_lr_decay(
-    model: nn.Module,
-    lr: float,
-    lr_decay: float,
-    weight_decay: float,
-    efficient_finetune: Optional[str] = None,
-    trainable_param_names: Optional[List] = None,
-):
-    """
-    Assign monotonically decreasing learning rates for layers from the output end to the input end.
-    The intuition behind is that later layers are more task-related compared to the early layers.
-    Layer normalization parameters and other layers' bias parameters don't use weight decay.
-    If you want to double-check whether the learning rate setup is as expected,
-    you can print the layer names along with their learning rates through
-    "print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))".
-
-    Parameters
-    ----------
-    model
-        A Pytorch model.
-    lr
-        The learning rate.
-    lr_decay
-        The learning rate decay factor (0, 1).
-    weight_decay
-        Weight decay.
-    efficient_finetune
-        Efficient finetuning strategy. It will only finetune part of the parameters
-
-    Returns
-    -------
-    The grouped parameters based on their layer ids and whether using weight decay.
-    """
-    parameter_group_names = {}
-    parameter_group_vars = {}
-    decay_param_names = get_weight_decay_param_names(model)
-
-    for name, param in model.named_parameters():
-        if name.startswith("_orig_mod."):
-            name = "".join(name.split("_orig_mod."))
-        layer_id = model.name_to_id[name]
-        if layer_id == 0:  # Set top layer (e.g. head, fusion_mlp, adapter) as being trainable.
-            param.requires_grad = True
-        elif (
-            efficient_finetune is not None
-            and efficient_finetune != "None"
-            and trainable_param_names
-            and not any([re.match(trainable_param_name, name) for trainable_param_name in trainable_param_names])
-        ):
-            param.requires_grad = False
-
-        if not param.requires_grad:
-            continue  # frozen weights
-
-        if name in decay_param_names:
-            group_name = "decay"
-            this_weight_decay = weight_decay
-        else:
-            group_name = "no_decay"
-            this_weight_decay = 0.0
-
-        layer_id = model.name_to_id[name]
-        group_name = "layer_%d_%s" % (layer_id, group_name)
-
-        if group_name not in parameter_group_names:
-            scale = lr_decay**layer_id
-            parameter_group_names[group_name] = {
-                "weight_decay": this_weight_decay,
-                "params": [],
-                "lr": scale * lr,
-            }
-            parameter_group_vars[group_name] = {
-                "weight_decay": this_weight_decay,
-                "params": [],
-                "lr": scale * lr,
-            }
-
-        parameter_group_vars[group_name]["params"].append(param)
-        parameter_group_names[group_name]["params"].append(name)
-
-    return list(parameter_group_vars.values())
-
-
-def gather_column_features(
-    output: Dict[str, Dict],
-    column_names: Union[str, List[str]],
-):
-    """
-    Gather column features from models' outputs.
-    For each feature name in one model's output, we enumerate the provided column names to see
-    whether (partial) the provided columns share one cls feature or they have independent features.
-
-    TODO: return features' masks and use them to filter the losses.
-
-    Parameters
-    ----------
-    output
-        The models' outputs.
-    column_names
-        The columns whose features we want to get.
-
-    Returns
-    -------
-    The gathered feature vectors. Each sample should only have one feature vector.
-    """
-    if isinstance(column_names, str):
-        column_names = [column_names]
-
-    gathered_features = []
-    # logger.debug(f"gather features for columns: {column_names}")
-    for per_model_name, per_model_output in output.items():
-        # logger.debug(f"gather column features from model: {per_model_name}")
-        for feature_name in per_model_output[COLUMN_FEATURES][FEATURES]:
-            # logger.debug(f"processing feature: {feature_name}")
-            columns_share_one_feature = []
-            for col_name in column_names:
-                if col_name in feature_name:
-                    # this column feature is part of the cls feature
-                    if not (feature_name.startswith(col_name) and feature_name.endswith(col_name)):
-                        columns_share_one_feature.append(col_name)
-                        # logger.debug(f"column {col_name} is included in feature {feature_name}")
-                    else:  # this column's feature is independent of other columns'
-                        gathered_features.append(per_model_output[COLUMN_FEATURES][FEATURES][col_name])
-                        # logger.debug(f"col_name {col_name} has an independent feature in model: {per_model_name}")
-
-            # two or more columns share one cls feature, and no other columns share it.
-            if len(columns_share_one_feature) > 0:
-                assert (
-                    len("_".join(columns_share_one_feature)) == len(feature_name)
-                ), f"model `{per_model_name}`'s cls feature name `{feature_name}` doesn't match `{columns_share_one_feature}`"
-                gathered_features.append(per_model_output[COLUMN_FEATURES][FEATURES][feature_name])
-
-    if len(gathered_features) > 1:
-        # currently only support features of the same shape
-        assert all(
-            per_features.shape == gathered_features[0].shape for per_features in gathered_features
-        ), "Currently we only support gathering features of the same dimension."
-
-    if len(gathered_features) == 0:
-        raise ValueError(f"No features are found for columns names {column_names}.")
-
-    gathered_features = torch.stack(gathered_features, dim=0).mean(dim=0)  # (b, d)
-
-    return gathered_features
-
-
-def get_metric_learning_distance_func(
-    name: str,
-):
-    """
-    Return one pytorch metric learning's distance function based on its name.
-
-    Parameters
-    ----------
-    name
-        distance function name
-
-    Returns
-    -------
-    A distance function from the pytorch metric learning package.
-    """
-    if name.lower() == COSINE_SIMILARITY:
-        return distances.CosineSimilarity()
-    else:
-        raise ValueError(f"Unknown distance measure: {name}")
-
-
-def infer_matcher_loss(data_format: str, problem_type: str):
-    """
-    Infer the loss type to train the matcher.
-
-    Parameters
-    ----------
-    data_format
-        The training data format, e.g., pair or triplet.
-    problem_type
-        Type of problem.
-
-    Returns
-    -------
-    The loss name.
-    """
-    if data_format == "pair":
-        if problem_type is None:
-            return [MULTI_NEGATIVES_SOFTMAX_LOSS]
-        elif problem_type == BINARY:
-            return [CONTRASTIVE_LOSS]
-        elif problem_type == REGRESSION:
-            return ["cosine_similarity_loss"]
-        else:
-            raise ValueError(f"Unsupported data format {data_format} with problem type {problem_type}")
-    elif data_format == "triplet":
-        if problem_type is None:
-            return [MULTI_NEGATIVES_SOFTMAX_LOSS]
-        else:
-            raise ValueError(f"Unsupported data format {data_format} with problem type {problem_type}")
-    else:
-        raise ValueError(f"Unsupported data format: {data_format}")
-
-
-def get_matcher_loss_func(
-    data_format: str,
-    problem_type: str,
-    loss_type: Optional[str] = None,
-    pos_margin: Optional[float] = None,
-    neg_margin: Optional[float] = None,
-    distance_type: Optional[str] = None,
-):
-    """
-    Return a list of pytorch metric learning's loss functions based on their names.
-
-    Parameters
-    ----------
-    data_format
-        The training data format, e.g., pair or triplet.
-    problem_type
-        Type of problem.
-    loss_type
-        The provided loss type.
-    pos_margin
-        The positive margin in computing the metric learning loss.
-    neg_margin
-        The negative margin in computing the metric learning loss.
-    distance_type
-        The distance function type.
-
-    Returns
-    -------
-    A loss function of metric learning.
-    """
-
-    allowable_loss_types = infer_matcher_loss(data_format=data_format, problem_type=problem_type)
-    if loss_type is not None:
-        assert loss_type in allowable_loss_types, f"data format {data_format} can't use loss {loss_type}."
-    else:
-        loss_type = allowable_loss_types[0]
-
-    if loss_type.lower() == CONTRASTIVE_LOSS:
-        return losses.ContrastiveLoss(
-            pos_margin=pos_margin,
-            neg_margin=neg_margin,
-            distance=get_metric_learning_distance_func(distance_type),
-        )
-    elif loss_type.lower() == MULTI_NEGATIVES_SOFTMAX_LOSS:
-        return MultiNegativesSoftmaxLoss(
-            local_loss=True,
-            gather_with_grad=True,
-            cache_labels=False,
-        )
-    else:
-        raise ValueError(f"Unknown metric learning loss: {loss_type}")
-
-
-def get_matcher_miner_func(
-    miner_type: str,
-    pos_margin: float,
-    neg_margin: float,
-    distance_type: str,
-):
-    """
-    Return a pytorch metric learning's miner functions based on their names.
-    The miners are used to mine the positive and negative examples.
-
-    Parameters
-    ----------
-    miner_type
-        The miner function type.
-    pos_margin
-        The positive margin used by the miner function.
-    neg_margin
-        The negative margin used by the miner function.
-    distance_type
-        The distance function type.
-
-    Returns
-    -------
-    A miner function to mine positive and negative samples.
-    """
-    if miner_type.lower() == PAIR_MARGIN_MINER:
-        return miners.PairMarginMiner(
-            pos_margin=pos_margin,
-            neg_margin=neg_margin,
-            distance=get_metric_learning_distance_func(distance_type),
-        )
-    else:
-        raise ValueError(f"Unknown metric learning miner: {miner_type}")
-
-
-def generate_metric_learning_labels(
-    num_samples: int,
-    match_label: int,
-    labels: torch.Tensor,
-):
-    """
-    Generate labels to compute the metric learning loss of one mini-batch.
-    For n samples, it generates 2*n labels since each match has two sides, each of which
-    has one label. If we know the matching label, then it determines the two sides' labels
-    according to whether their label is the matching label. If the matching label is None,
-    it assigns a unique label for each side.
-
-    Parameters
-    ----------
-    num_samples
-        number of samples.
-    match_label
-        The matching label, which can be None.
-    labels
-        The sample labels used in the supervised setting. It's required only when match_label is not None.
-
-    Returns
-    -------
-    The labels used in computing the metric learning loss.
-    """
-    device = labels.device
-    labels_1 = torch.arange(num_samples, device=device)
-
-    if match_label is not None:
-        labels_2 = torch.arange(num_samples, num_samples * 2, device=device)
-        # users need to specify the match_label based on the raw label's semantic meaning.
-        mask = labels == match_label
-        labels_2[mask] = labels_1[mask]
-    else:
-        labels_2 = torch.arange(num_samples, device=device)
-
-    metric_learning_labels = torch.cat([labels_1, labels_2], dim=0)
-
-    return metric_learning_labels