autogluon.multimodal 1.2.1b20250303__py3-none-any.whl → 1.2.1b20250305__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- autogluon/multimodal/__init__.py +4 -2
- autogluon/multimodal/configs/data/default.yaml +4 -2
- autogluon/multimodal/configs/{environment → env}/default.yaml +2 -3
- autogluon/multimodal/configs/model/default.yaml +58 -11
- autogluon/multimodal/configs/{optimization → optim}/default.yaml +21 -4
- autogluon/multimodal/constants.py +16 -5
- autogluon/multimodal/data/__init__.py +14 -2
- autogluon/multimodal/data/dataset.py +2 -2
- autogluon/multimodal/data/infer_types.py +16 -2
- autogluon/multimodal/data/label_encoder.py +3 -3
- autogluon/multimodal/{utils → data}/nlpaug.py +4 -4
- autogluon/multimodal/data/preprocess_dataframe.py +55 -38
- autogluon/multimodal/data/process_categorical.py +35 -6
- autogluon/multimodal/data/process_document.py +59 -33
- autogluon/multimodal/data/process_image.py +198 -163
- autogluon/multimodal/data/process_label.py +7 -3
- autogluon/multimodal/data/process_mmlab/process_mmdet.py +1 -8
- autogluon/multimodal/data/process_mmlab/process_mmlab_base.py +2 -9
- autogluon/multimodal/data/process_mmlab/process_mmocr.py +1 -9
- autogluon/multimodal/data/process_ner.py +192 -4
- autogluon/multimodal/data/process_numerical.py +32 -5
- autogluon/multimodal/data/process_semantic_seg_img.py +23 -28
- autogluon/multimodal/data/process_text.py +95 -58
- autogluon/multimodal/data/template_engine.py +7 -9
- autogluon/multimodal/data/templates.py +0 -2
- autogluon/multimodal/data/trivial_augmenter.py +2 -2
- autogluon/multimodal/data/utils.py +564 -338
- autogluon/multimodal/learners/__init__.py +2 -1
- autogluon/multimodal/learners/base.py +189 -189
- autogluon/multimodal/learners/ensemble.py +748 -0
- autogluon/multimodal/learners/few_shot_svm.py +6 -15
- autogluon/multimodal/learners/matching.py +59 -84
- autogluon/multimodal/learners/ner.py +23 -22
- autogluon/multimodal/learners/object_detection.py +26 -21
- autogluon/multimodal/learners/semantic_segmentation.py +16 -18
- autogluon/multimodal/models/__init__.py +12 -3
- autogluon/multimodal/models/augmenter.py +175 -0
- autogluon/multimodal/models/categorical_mlp.py +13 -8
- autogluon/multimodal/models/clip.py +92 -18
- autogluon/multimodal/models/custom_transformer.py +75 -75
- autogluon/multimodal/models/document_transformer.py +23 -9
- autogluon/multimodal/models/ft_transformer.py +40 -35
- autogluon/multimodal/models/fusion/base.py +2 -4
- autogluon/multimodal/models/fusion/fusion_mlp.py +82 -18
- autogluon/multimodal/models/fusion/fusion_ner.py +1 -1
- autogluon/multimodal/models/fusion/fusion_transformer.py +23 -23
- autogluon/multimodal/models/{huggingface_text.py → hf_text.py} +21 -2
- autogluon/multimodal/models/meta_transformer.py +336 -0
- autogluon/multimodal/models/mlp.py +6 -6
- autogluon/multimodal/models/mmocr_text_detection.py +1 -1
- autogluon/multimodal/models/mmocr_text_recognition.py +0 -1
- autogluon/multimodal/models/ner_text.py +1 -8
- autogluon/multimodal/models/numerical_mlp.py +14 -8
- autogluon/multimodal/models/sam.py +12 -2
- autogluon/multimodal/models/t_few.py +21 -5
- autogluon/multimodal/models/timm_image.py +74 -32
- autogluon/multimodal/models/utils.py +877 -16
- autogluon/multimodal/optim/__init__.py +17 -0
- autogluon/multimodal/{optimization → optim}/lit_distiller.py +2 -1
- autogluon/multimodal/{optimization → optim}/lit_matcher.py +4 -10
- autogluon/multimodal/{optimization → optim}/lit_mmdet.py +2 -10
- autogluon/multimodal/{optimization → optim}/lit_module.py +139 -14
- autogluon/multimodal/{optimization → optim}/lit_ner.py +3 -3
- autogluon/multimodal/{optimization → optim}/lit_semantic_seg.py +1 -1
- autogluon/multimodal/optim/losses/__init__.py +14 -0
- autogluon/multimodal/optim/losses/bce_loss.py +25 -0
- autogluon/multimodal/optim/losses/focal_loss.py +81 -0
- autogluon/multimodal/optim/losses/lemda_loss.py +39 -0
- autogluon/multimodal/optim/losses/rkd_loss.py +103 -0
- autogluon/multimodal/optim/losses/softmax_losses.py +177 -0
- autogluon/multimodal/optim/losses/structure_loss.py +26 -0
- autogluon/multimodal/optim/losses/utils.py +313 -0
- autogluon/multimodal/optim/lr/__init__.py +1 -0
- autogluon/multimodal/optim/lr/utils.py +332 -0
- autogluon/multimodal/optim/metrics/__init__.py +4 -0
- autogluon/multimodal/optim/metrics/coverage_metrics.py +42 -0
- autogluon/multimodal/optim/metrics/hit_rate_metrics.py +78 -0
- autogluon/multimodal/optim/metrics/ranking_metrics.py +231 -0
- autogluon/multimodal/optim/metrics/utils.py +359 -0
- autogluon/multimodal/optim/utils.py +284 -0
- autogluon/multimodal/predictor.py +51 -12
- autogluon/multimodal/utils/__init__.py +19 -45
- autogluon/multimodal/utils/cache.py +23 -2
- autogluon/multimodal/utils/checkpoint.py +58 -5
- autogluon/multimodal/utils/config.py +127 -55
- autogluon/multimodal/utils/device.py +120 -0
- autogluon/multimodal/utils/distillation.py +8 -8
- autogluon/multimodal/utils/download.py +1 -1
- autogluon/multimodal/utils/env.py +22 -0
- autogluon/multimodal/utils/export.py +3 -3
- autogluon/multimodal/utils/hpo.py +5 -5
- autogluon/multimodal/utils/inference.py +37 -4
- autogluon/multimodal/utils/install.py +91 -0
- autogluon/multimodal/utils/load.py +52 -47
- autogluon/multimodal/utils/log.py +6 -41
- autogluon/multimodal/utils/matcher.py +3 -2
- autogluon/multimodal/utils/onnx.py +0 -4
- autogluon/multimodal/utils/path.py +10 -0
- autogluon/multimodal/utils/precision.py +130 -0
- autogluon/multimodal/{presets.py → utils/presets.py} +259 -66
- autogluon/multimodal/{problem_types.py → utils/problem_types.py} +30 -1
- autogluon/multimodal/utils/save.py +47 -29
- autogluon/multimodal/utils/strategy.py +24 -0
- autogluon/multimodal/version.py +1 -1
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/METADATA +5 -5
- autogluon.multimodal-1.2.1b20250305.dist-info/RECORD +163 -0
- autogluon/multimodal/optimization/__init__.py +0 -16
- autogluon/multimodal/optimization/losses.py +0 -394
- autogluon/multimodal/optimization/utils.py +0 -1054
- autogluon/multimodal/utils/cloud_io.py +0 -80
- autogluon/multimodal/utils/data.py +0 -701
- autogluon/multimodal/utils/environment.py +0 -395
- autogluon/multimodal/utils/metric.py +0 -500
- autogluon/multimodal/utils/model.py +0 -558
- autogluon.multimodal-1.2.1b20250303.dist-info/RECORD +0 -145
- /autogluon/multimodal/{optimization → optim}/deepspeed.py +0 -0
- /autogluon/multimodal/{optimization/lr_scheduler.py → optim/lr/lr_schedulers.py} +0 -0
- /autogluon/multimodal/{optimization → optim/metrics}/semantic_seg_metrics.py +0 -0
- /autogluon/multimodal/{registry.py → utils/registry.py} +0 -0
- /autogluon.multimodal-1.2.1b20250303-py3.9-nspkg.pth → /autogluon.multimodal-1.2.1b20250305-py3.9-nspkg.pth +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/LICENSE +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/NOTICE +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/WHEEL +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/namespace_packages.txt +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/top_level.txt +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/zip-safe +0 -0
autogluon/multimodal/optim/lr/utils.py
@@ -0,0 +1,332 @@
+import logging
+import re
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import nn, optim
+
+from ..utils import get_weight_decay_param_names
+from .lr_schedulers import (
+    get_cosine_schedule_with_warmup,
+    get_linear_schedule_with_warmup,
+    get_polynomial_decay_schedule_with_warmup,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def get_lr_scheduler(
+    optimizer: optim.Optimizer,
+    num_max_steps: int,
+    num_warmup_steps: int,
+    lr_schedule: str,
+    end_lr: Union[float, int],
+):
+    """
+    Get the learning rate scheduler from its name. Here we use our defined learning rate
+    scheduler instead of those imported from "transformers" because we want to support
+    Pytorch lightning's "ddp_spawn" training strategy.
+
+    Parameters
+    ----------
+    optimizer
+        A Pytorch optimizer.
+    num_max_steps
+        Number of maximum training steps.
+    num_warmup_steps
+        Number of steps to do learning rate warmup.
+    lr_schedule
+        Name of the learning rate scheduler.
+    end_lr
+        The final learning rate after decay.
+
+    Returns
+    -------
+    A learning rate scheduler.
+    """
+    if lr_schedule == "cosine_decay":
+        scheduler = get_cosine_schedule_with_warmup(
+            optimizer=optimizer,
+            num_warmup_steps=num_warmup_steps,
+            num_training_steps=num_max_steps,
+        )
+    elif lr_schedule == "polynomial_decay":
+        scheduler = get_polynomial_decay_schedule_with_warmup(
+            optimizer=optimizer,
+            num_warmup_steps=num_warmup_steps,
+            num_training_steps=num_max_steps,
+            lr_end=end_lr,
+            power=1,
+        )
+    elif lr_schedule == "linear_decay":
+        scheduler = get_linear_schedule_with_warmup(
+            optimizer=optimizer,
+            num_warmup_steps=num_warmup_steps,
+            num_training_steps=num_max_steps,
+        )
+    elif lr_schedule == "multi_step":
+        # TODO: add milestones, gamma into hyperparameters
+        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=[30, 55], gamma=0.1)
+    else:
+        raise ValueError(f"unknown lr schedule: {lr_schedule}")
+
+    return scheduler
+
+
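For orientation, a minimal sketch of how this helper can be driven, assuming the relocated module path shown above; the toy model, optimizer, and step counts are illustrative, not part of the diff:

```python
import torch
from torch import nn

from autogluon.multimodal.optim.lr.utils import get_lr_scheduler

# Toy setup (assumed for illustration): any model and optimizer work here.
model = nn.Linear(16, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Warm up for the first 100 of 1000 steps, then follow a cosine decay.
scheduler = get_lr_scheduler(
    optimizer=optimizer,
    num_max_steps=1000,
    num_warmup_steps=100,
    lr_schedule="cosine_decay",
    end_lr=0.0,  # only consulted by the "polynomial_decay" schedule
)

for _ in range(1000):
    optimizer.step()  # loss computation and backward pass omitted
    scheduler.step()  # these schedulers are stepped once per training step
```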
+def apply_single_lr(
+    model: nn.Module,
+    lr: float,
+    weight_decay: float,
+    return_params: Optional[bool] = True,
+    peft: Optional[str] = None,
+    trainable_param_names: Optional[List] = None,
+    exclude_keys: Optional[List] = None,
+):
+    """
+    Set to use a single learning rate for all parameters. Layer normalization parameters and other
+    layers' bias parameters don't use weight decay.
+
+    Parameters
+    ----------
+    model
+        A Pytorch model.
+    lr
+        Learning rate.
+    weight_decay
+        Weight decay.
+    return_params
+        Whether to return parameters or their names. If you want to double-check
+        whether the learning rate setup is as expected, you can set "return_params=False",
+        and print the layer names along with their learning rates through
+        "print("Param groups = %s" % json.dumps(optimizer_grouped_parameters, indent=2))".
+    peft
+        Efficient finetuning strategy. It will only finetune part of the parameters.
+    trainable_param_names
+        A list of trainable parameters. (Optional)
+    exclude_keys
+        A list of keys to be excluded.
+
+    Returns
+    -------
+    The grouped parameters or their names.
+    """
+    decay_param_names = get_weight_decay_param_names(model)
+    decay_grad_param_names = []
+    no_decay_grad_param_names = []
+
+    for name, param in model.named_parameters():
+        if exclude_keys and any([exc in name for exc in exclude_keys]):
+            continue
+
+        if (
+            peft is not None
+            and trainable_param_names
+            and not any([re.match(trainable_param_name, name) for trainable_param_name in trainable_param_names])
+        ):
+            param.requires_grad = False
+
+        if not param.requires_grad:
+            continue  # frozen weights
+
+        if name in decay_param_names:
+            if return_params:
+                decay_grad_param_names.append(param)
+            else:
+                decay_grad_param_names.append(name)
+
+        else:
+            if return_params:
+                no_decay_grad_param_names.append(param)
+            else:
+                no_decay_grad_param_names.append(name)
+
+    optimizer_grouped_parameters = [
+        {
+            "params": decay_grad_param_names,
+            "weight_decay": weight_decay,
+            "lr": lr,
+        },
+        {
+            "params": no_decay_grad_param_names,
+            "weight_decay": 0.0,
+            "lr": lr,
+        },
+    ]
+    return optimizer_grouped_parameters
+
+
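A short usage sketch (toy model assumed; the second call demonstrates the name-based debugging view the docstring mentions):

```python
import json

import torch
from torch import nn

from autogluon.multimodal.optim.lr.utils import apply_single_lr

# Toy model (assumed for illustration); biases and LayerNorm weights
# should land in the no-decay group.
model = nn.Sequential(nn.Linear(16, 8), nn.LayerNorm(8), nn.Linear(8, 2))

param_groups = apply_single_lr(model=model, lr=1e-4, weight_decay=0.01)
optimizer = torch.optim.AdamW(param_groups)

# Inspect the grouping by parameter names instead of tensors.
group_names = apply_single_lr(model=model, lr=1e-4, weight_decay=0.01, return_params=False)
print("Param groups = %s" % json.dumps(group_names, indent=2))
```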
+def apply_two_stages_lr(
+    model: nn.Module,
+    lr: float,
+    lr_mult: Union[float, int],
+    weight_decay: float,
+    return_params: Optional[bool] = True,
+    exclude_keys: Optional[List] = None,
+):
+    """
+    Set up the pretrained backbone to use a smaller learning rate (lr * lr_mult).
+    The newly added head layers use the normal learning rate (lr).
+    Layer normalization parameters and other layers' bias parameters don't use weight decay.
+
+    Parameters
+    ----------
+    model
+        A Pytorch model.
+    lr
+        The learning rate.
+    lr_mult
+        The multiplier (0, 1) to scale down the learning rate.
+    weight_decay
+        Weight decay.
+    return_params
+        Whether to return parameters or their names. If you want to double-check
+        whether the learning rate setup is as expected, you can set "return_params=False",
+        and print the layer names along with their learning rates through
+        "print("Param groups = %s" % json.dumps(optimizer_grouped_parameters, indent=2))".
+    exclude_keys
+        A list of keys to be excluded.
+
+    Returns
+    -------
+    The grouped parameters or their names.
+    """
+    decay_param_names = get_weight_decay_param_names(model)
+
+    optimizer_grouped_parameters = [
+        {
+            "params": [
+                p if return_params else n
+                for n, p in model.named_parameters()
+                if n in decay_param_names
+                and not any(bb in n for bb in model.head_layer_names)
+                and (not exclude_keys or not any([exc in n for exc in exclude_keys]))
+            ],
+            "weight_decay": weight_decay,
+            "lr": lr,
+        },
+        {
+            "params": [
+                p if return_params else n
+                for n, p in model.named_parameters()
+                if n not in decay_param_names
+                and not any(bb in n for bb in model.head_layer_names)
+                and (not exclude_keys or not any([exc in n for exc in exclude_keys]))
+            ],
+            "weight_decay": 0.0,
+            "lr": lr,
+        },
+        {
+            "params": [
+                p if return_params else n
+                for n, p in model.named_parameters()
+                if n in decay_param_names
+                and any(bb in n for bb in model.head_layer_names)
+                and (not exclude_keys or not any([exc in n for exc in exclude_keys]))
+            ],
+            "weight_decay": weight_decay,
+            "lr": lr * lr_mult,
+        },
+        {
+            "params": [
+                p if return_params else n
+                for n, p in model.named_parameters()
+                if n not in decay_param_names
+                and any(bb in n for bb in model.head_layer_names)
+                and (not exclude_keys or not any([exc in n for exc in exclude_keys]))
+            ],
+            "weight_decay": 0.0,
+            "lr": lr * lr_mult,
+        },
+    ]
+
+    return optimizer_grouped_parameters
+
+
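This helper reads `model.head_layer_names`, which the package's model classes define; a sketch with a stand-in attribute (the attribute value here is a fabricated assumption for illustration):

```python
import torch
from torch import nn

from autogluon.multimodal.optim.lr.utils import apply_two_stages_lr

model = nn.Sequential(nn.Linear(16, 8), nn.Linear(8, 2))
# Stand-in for the attribute that real AutoGluon models provide.
model.head_layer_names = ["1"]  # treat the last Linear as the head

# Per the code above, head-layer groups receive lr * lr_mult;
# the remaining groups receive lr.
param_groups = apply_two_stages_lr(model=model, lr=1e-4, lr_mult=0.1, weight_decay=0.01)
optimizer = torch.optim.AdamW(param_groups)
```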
+def apply_layerwise_lr_decay(
+    model: nn.Module,
+    lr: float,
+    lr_decay: float,
+    weight_decay: float,
+    peft: Optional[str] = None,
+    trainable_param_names: Optional[List] = None,
+    exclude_keys: Optional[List] = None,
+):
+    """
+    Assign monotonically decreasing learning rates for layers from the output end to the input end.
+    The intuition behind this is that later layers are more task-related compared to the early layers.
+    Layer normalization parameters and other layers' bias parameters don't use weight decay.
+    If you want to double-check whether the learning rate setup is as expected,
+    you can print the layer names along with their learning rates through
+    "print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))".
+
+    Parameters
+    ----------
+    model
+        A Pytorch model.
+    lr
+        The learning rate.
+    lr_decay
+        The learning rate decay factor (0, 1).
+    weight_decay
+        Weight decay.
+    peft
+        Efficient finetuning strategy. It will only finetune part of the parameters.
+    trainable_param_names
+        A list of trainable parameters. (Optional)
+    exclude_keys
+        A list of keys to be excluded.
+
+    Returns
+    -------
+    The grouped parameters based on their layer ids and whether using weight decay.
+    """
+    parameter_group_names = {}
+    parameter_group_vars = {}
+    decay_param_names = get_weight_decay_param_names(model)
+
+    for name, param in model.named_parameters():
+        if name.startswith("_orig_mod."):
+            name = "".join(name.split("_orig_mod."))
+        if exclude_keys and any([exc in name for exc in exclude_keys]):
+            continue
+        layer_id = model.name_to_id[name]
+        if layer_id == 0:  # Set top layer (e.g. head, fusion_mlp, adapter) as being trainable.
+            param.requires_grad = True
+        elif (
+            peft is not None
+            and trainable_param_names
+            and not any([re.match(trainable_param_name, name) for trainable_param_name in trainable_param_names])
+        ):
+            param.requires_grad = False
+
+        if not param.requires_grad:
+            continue  # frozen weights
+
+        if name in decay_param_names:
+            group_name = "decay"
+            this_weight_decay = weight_decay
+        else:
+            group_name = "no_decay"
+            this_weight_decay = 0.0
+
+        layer_id = model.name_to_id[name]
+        group_name = "layer_%d_%s" % (layer_id, group_name)
+
+        if group_name not in parameter_group_names:
+            scale = lr_decay**layer_id
+            parameter_group_names[group_name] = {
+                "weight_decay": this_weight_decay,
+                "params": [],
+                "lr": scale * lr,
+            }
+            parameter_group_vars[group_name] = {
+                "weight_decay": this_weight_decay,
+                "params": [],
+                "lr": scale * lr,
+            }
+
+        parameter_group_vars[group_name]["params"].append(param)
+        parameter_group_names[group_name]["params"].append(name)
+
+    return list(parameter_group_vars.values())
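Likewise, this function relies on `model.name_to_id` (layer 0 is the output end); a minimal sketch with a fabricated mapping, assumed only for illustration:

```python
import torch
from torch import nn

from autogluon.multimodal.optim.lr.utils import apply_layerwise_lr_decay

model = nn.Sequential(nn.Linear(16, 8), nn.Linear(8, 2))
# Fabricated layer-id mapping (real AutoGluon models ship their own name_to_id).
model.name_to_id = {name: 0 if name.startswith("1.") else 1 for name, _ in model.named_parameters()}

# With lr_decay=0.9: layer 0 params get lr, layer 1 params get 0.9 * lr.
param_groups = apply_layerwise_lr_decay(model=model, lr=1e-4, lr_decay=0.9, weight_decay=0.01)
optimizer = torch.optim.AdamW(param_groups)
```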
autogluon/multimodal/optim/metrics/coverage_metrics.py
@@ -0,0 +1,42 @@
+import torch
+from torchmetrics import Metric
+from torchmetrics.utilities import dim_zero_cat
+
+
+class Coverage(Metric):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.add_state("pos_probs", default=[], dist_reduce_fx="cat")
+        self.add_state("targets", default=[], dist_reduce_fx="cat")
+        self.tp_threshold = 0.97
+        self.tn_threshold = 0.99
+        self.higher_is_better = True
+
+    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
+        assert preds.dim() == 1
+        assert target.dim() == 1
+        self.pos_probs.append(preds)
+        self.targets.append(target)
+
+    def compute(self):
+        # parse inputs
+        pos_probs = dim_zero_cat(self.pos_probs)
+        targets = dim_zero_cat(self.targets)
+        y_pos = targets[torch.where(pos_probs >= self.tp_threshold)]
+        y_neg = targets[torch.where(pos_probs <= 1 - self.tn_threshold)]
+        tp = sum(y_pos == 1)
+        tn = sum(y_neg == 0)
+        if len(y_pos) == 0 or len(targets) == 0:
+            tp_precision, tp_coverage = 0, 0
+        else:
+            tp_precision, tp_coverage = tp / len(y_pos), len(y_pos) / len(targets)
+        if tp_precision < self.tp_threshold:
+            tp_coverage = 0
+        if len(y_neg) == 0 or len(targets) == 0:
+            tn_precision, tn_coverage = 0, 0
+        else:
+            tn_precision, tn_coverage = tn / len(y_neg), len(y_neg) / len(targets)
+        if tn_precision < self.tn_threshold:
+            tn_coverage = 0
+
+        return torch.tensor(tp_coverage) + torch.tensor(tn_coverage)
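A worked toy example (values assumed) of what the new metric accumulates: with tp_threshold=0.97 and tn_threshold=0.99, only predictions at least that confident count toward coverage:

```python
import torch

from autogluon.multimodal.optim.metrics.coverage_metrics import Coverage

metric = Coverage()
# Positive-class probabilities and binary targets (toy values).
metric.update(preds=torch.tensor([0.99, 0.98, 0.50]), target=torch.tensor([1, 1, 0]))
metric.update(preds=torch.tensor([0.005, 0.40]), target=torch.tensor([0, 1]))

# Confident positives: 0.99, 0.98 (both correct) -> tp_coverage = 2/5.
# Confident negatives: 0.005 (correct)           -> tn_coverage = 1/5.
score = metric.compute()  # tensor(0.6000)
```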
autogluon/multimodal/optim/metrics/hit_rate_metrics.py
@@ -0,0 +1,78 @@
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torchmetrics
+
+
+class CustomHitRate(torchmetrics.Metric):
+    """
+    Compute the hit rate when doing semantic search between two groups of embeddings.
+    We assume that (a_i, p_i) are a positive pair and (a_i, p_j) for i!=j a negative pair.
+    """
+
+    def __init__(
+        self,
+    ):
+        super().__init__()
+        self.add_state("query_embeddings", default=[], dist_reduce_fx=None)
+        self.add_state("response_embeddings", default=[], dist_reduce_fx=None)
+        self.add_state("logit_scale", default=[], dist_reduce_fx=None)
+
+    def update(
+        self,
+        batch_query_embeds: torch.Tensor,
+        batch_response_embeds: torch.Tensor,
+        logit_scale: Optional[torch.Tensor] = None,
+    ):
+        self.query_embeddings.append(batch_query_embeds)
+        self.response_embeddings.append(batch_response_embeds)
+        if logit_scale is not None:
+            self.logit_scale.append(logit_scale)
+
+    def compute(self):
+        query_embeddings = torch.cat(self.query_embeddings)
+        response_embeddings = torch.cat(self.response_embeddings)
+        if self.logit_scale:
+            logit_scale = torch.mean(torch.stack(self.logit_scale))
+        else:
+            logit_scale = 1
+
+        return self.compute_hit_rate(query_embeddings, response_embeddings, logit_scale)
+
+    @staticmethod
+    def compute_hit_rate(features_a, features_b, logit_scale, top_ks=[1, 5, 10]):
+        """
+        Compute symmetric hit rates between two groups of features.
+
+        Parameters
+        ----------
+        features_a
+            One group of features.
+        features_b
+            The other group of features.
+        logit_scale
+            The scale of logit (Used in CLIP).
+        top_ks
+            Consider only the top k elements for each query.
+
+        Returns
+        -------
+        The accumulated hit rate.
+        """
+        assert len(features_a) == len(features_b)
+        hit_rate = 0
+        logits_per_a = (logit_scale * features_a @ features_b.t()).detach().cpu()
+        logits_per_b = logits_per_a.t().detach().cpu()
+
+        logits = {"logits_per_a": logits_per_a, "logits_per_b": logits_per_b}
+        ground_truth = torch.arange(len(features_b)).view(-1, 1)
+
+        for name, logit in logits.items():
+            ranking = torch.argsort(logit, descending=True)
+            preds = torch.where(ranking == ground_truth)[1]
+
+            for k in top_ks:
+                hit_rate += (preds < k).float().mean()
+
+        hit_rate /= len(top_ks) * len(logits)
+        return hit_rate
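A sketch of the call pattern (random embeddings assumed for illustration): row i of each side forms the positive pair, and the score averages top-1/5/10 hits over both search directions:

```python
import torch
import torch.nn.functional as F

from autogluon.multimodal.optim.metrics.hit_rate_metrics import CustomHitRate

metric = CustomHitRate()
# Toy L2-normalized embeddings standing in for real query/response encodings.
query_embeds = F.normalize(torch.randn(32, 64), dim=-1)
response_embeds = F.normalize(torch.randn(32, 64), dim=-1)

metric.update(batch_query_embeds=query_embeds, batch_response_embeds=response_embeds)
hit_rate = metric.compute()  # scalar tensor in [0, 1]
```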
autogluon/multimodal/optim/metrics/ranking_metrics.py
@@ -0,0 +1,231 @@
+import logging
+import math
+import operator
+from typing import Dict, List, Optional, Tuple, Union
+
+import numpy as np
+
+from ...constants import MAP, NDCG, PRECISION, RECALL
+
+logger = logging.getLogger(__name__)
+
+
+class RankingMetrics:
+    def __init__(
+        self,
+        pred: Dict[str, Dict],
+        target: Dict[str, Dict],
+        is_higher_better=True,
+    ):
+        """
+        Evaluation Metrics for information retrieval tasks such as document retrieval, image retrieval, etc.
+        Reference: https://www.cs.cornell.edu/courses/cs4300/2013fa/lectures/metrics-2-4pp.pdf
+
+        Parameters
+        ----------
+        pred:
+            the prediction of the ranking model. It has the following form.
+            pred = {
+                'q1': {
+                    'd1': 1,
+                    'd3': 0,
+                },
+                'q2': {
+                    'd2': 1,
+                    'd3': 1,
+                },
+            }
+            where q refers to queries, and d refers to documents; each query has a few relevant documents.
+            0s and 1s are model predicted scores (does not need to be binary).
+        target:
+            the ground truth query and response relevance which has the same form as pred.
+        is_higher_better:
+            if higher relevance score means higher ranking.
+            if the relevance score is cosine similarity / dot product, it should be set to True;
+            if it is Euclidean distance, it should be False.
+        """
+        self.pred = pred
+        self.target = target
+        self.is_higher_better = is_higher_better
+        # the supported metrics in this script
+        self.supported_metrics = {
+            "precision": 0,
+            "recall": 1,
+            "mrr": 2,
+            "map": 3,
+            "ndcg": 4,
+        }
+
+        assert len(pred) == len(
+            target
+        ), f"The prediction and ground truth target should have the same number of queries, \
+            while there are {len(pred)} queries in prediction and {len(target)} in the target."
+
+        self.results = {}
+        for key in target.keys():
+            self.results.update({key: [target[key], pred[key]]})
+
+    def compute(self, metrics: Union[str, list] = None, k: Optional[int] = 10):
+        """
+        Compute and return ranking scores.
+
+        Parameters
+        ----------
+        metrics:
+            user provided metrics
+        k:
+            the cutoff value for NDCG, MAP, Recall, MRR, and Precision
+
+        Returns
+        -------
+        Computed score.
+
+        """
+        if isinstance(metrics, str):
+            metrics = [metrics]
+        if not metrics:  # no metric is provided
+            metrics = self.supported_metrics.keys()
+
+        return_res = {}
+
+        eval_res = np.mean(
+            [list(self._compute_one(idx, k)) for idx in self.results.keys()],
+            axis=0,
+        )
+
+        for metric in metrics:
+            metric = metric.lower()
+            if metric in self.supported_metrics:
+                return_res.update({f"{metric}@{k}": eval_res[self.supported_metrics[metric]]})
+
+        return return_res
+
+    def _compute_one(self, idx, k):
+        """
+        Compute and return the ranking scores for one query.
+        For definitions of these metrics, please refer to
+        https://www.cs.cornell.edu/courses/cs4300/2013fa/lectures/metrics-2-4pp.pdf
+
+        Parameters
+        ----------
+        idx:
+            the index of the query
+        k:
+            the cutoff value for NDCG, MAP, Recall, MRR, and Precision
+
+        Returns
+        -------
+        Computed score.
+        """
+        precision, recall, mrr, mAP, ndcg = 0, 0, 0, 0, 0
+        target, pred = self.results[idx][0], self.results[idx][1]
+
+        # sort the ground truth and predictions in descending order
+        sorted_target = dict(
+            sorted(
+                target.items(),
+                key=operator.itemgetter(1),
+                reverse=self.is_higher_better,
+            )
+        )
+        sorted_pred = dict(
+            sorted(
+                pred.items(),
+                key=operator.itemgetter(1),
+                reverse=self.is_higher_better,
+            )
+        )
+        sorted_target_values = list(sorted_target.values())
+        sorted_pred_values = list(sorted_pred.values())
+
+        # number of positive relevance in target
+        # negative numbers and zero are considered as negative response
+        num_pos_target = len([val for val in sorted_target_values if val > 0])
+
+        at_k = k if num_pos_target > k else num_pos_target
+
+        first_k_items_list = list(sorted_pred.items())[0:k]
+
+        rank = 0
+        hit_rank = []  # correctly retrieved items
+        for key, value in first_k_items_list:
+            if key in sorted_target and sorted_target[key] > 0:
+                hit_rank.append(rank)
+            rank += 1
+        count = len(hit_rank)
+        # compute the precision and recall
+        precision = count / k
+        recall = count / num_pos_target
+
+        dcg = 0
+        if hit_rank:  # not empty
+            # compute the mean reciprocal rank
+            mrr = 1 / (hit_rank[0] + 1)
+            # compute the mean average precision
+            mAP = np.sum([sorted_pred_values[rank] * (i + 1) / (rank + 1) for i, rank in enumerate(hit_rank)])
+            # compute the discounted cumulative gain
+            dcg = np.sum([1 / math.log(rank + 2, 2) for rank in hit_rank])
+
+        # compute the ideal discounted cumulative gain
+        idcg = np.sum([1 / math.log(i + 2, 2) for i in range(at_k)])
+        # compute the normalized discounted cumulative gain
+        ndcg = dcg / idcg
+        mAP /= at_k
+
+        return precision, recall, mrr, mAP, ndcg
+
+
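A toy evaluation (data assumed) showing the input shape the class expects and what compute returns:

```python
from autogluon.multimodal.optim.metrics.ranking_metrics import RankingMetrics

pred = {  # model scores per query/document pair
    "q1": {"d1": 0.9, "d2": 0.1, "d3": 0.4},
    "q2": {"d1": 0.2, "d2": 0.8, "d3": 0.7},
}
target = {  # ground-truth relevance, same structure
    "q1": {"d1": 1, "d2": 0, "d3": 0},
    "q2": {"d1": 0, "d2": 1, "d3": 1},
}

evaluator = RankingMetrics(pred=pred, target=target)
print(evaluator.compute(metrics=["precision", "recall"], k=2))
# precision@2 = 0.75, recall@2 = 1.0 (averaged over the two queries)
```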
+def compute_ranking_score(
+    results: Dict[str, Dict],
+    qrel_dict: Dict[str, Dict],
+    metrics: List[str],
+    cutoffs: Optional[List[int]] = [5, 10, 20],
+):
+    """
+    Compute the ranking metrics, e.g., NDCG, MAP, Recall, and Precision.
+    TODO: Consider MRR.
+
+    Parameters
+    ----------
+    results:
+        The query/document ranking list by the model.
+    qrel_dict:
+        The groundtruth query and document relevance.
+    metrics:
+        A list of metrics to compute.
+    cutoffs:
+        The cutoff values for NDCG, MAP, Recall, and Precision.
+
+    Returns
+    -------
+    A dict of metric scores.
+    """
+    scores = {}
+    evaluator = RankingMetrics(pred=results, target=qrel_dict)
+    for k in cutoffs:
+        scores.update(evaluator.compute(k=k))
+
+    metric_results = dict()
+    for k in cutoffs:
+        for per_metric in metrics:
+            if per_metric.lower() == NDCG:
+                metric_results[f"{NDCG}@{k}"] = 0.0
+            elif per_metric.lower() == MAP:
+                metric_results[f"{MAP}@{k}"] = 0.0
+            elif per_metric.lower() == RECALL:
+                metric_results[f"{RECALL}@{k}"] = 0.0
+            elif per_metric.lower() == PRECISION:
+                metric_results[f"{PRECISION}@{k}"] = 0.0
+
+    for k in cutoffs:
+        for per_metric in metrics:
+            if per_metric.lower() == NDCG:
+                metric_results[f"{NDCG}@{k}"] = round(scores[f"{NDCG}@{k}"], 5)
+            elif per_metric.lower() == MAP:
+                metric_results[f"{MAP}@{k}"] = round(scores[f"{MAP}@{k}"], 5)
+            elif per_metric.lower() == RECALL:
+                metric_results[f"{RECALL}@{k}"] = round(scores[f"{RECALL}@{k}"], 5)
+            elif per_metric.lower() == PRECISION:
+                metric_results[f"{PRECISION}@{k}"] = round(scores[f"{PRECISION}@{k}"], 5)
+
+    return metric_results
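And the module-level wrapper, which fans the same computation out over several cutoffs; the metric-name strings below assume the package's constants are the lowercase names, so the exact result keys are illustrative:

```python
from autogluon.multimodal.optim.metrics.ranking_metrics import compute_ranking_score

results = {"q1": {"d1": 0.9, "d2": 0.1}}   # model scores
qrel_dict = {"q1": {"d1": 1, "d2": 0}}     # ground-truth relevance

scores = compute_ranking_score(
    results=results,
    qrel_dict=qrel_dict,
    metrics=["ndcg", "recall"],
    cutoffs=[5, 10],
)
# e.g. {'ndcg@5': 1.0, 'recall@5': 1.0, 'ndcg@10': 1.0, 'recall@10': 1.0}
```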