autogluon.multimodal 1.2.1b20250303__py3-none-any.whl → 1.2.1b20250304__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126)
  1. autogluon/multimodal/__init__.py +4 -2
  2. autogluon/multimodal/configs/data/default.yaml +4 -2
  3. autogluon/multimodal/configs/{environment → env}/default.yaml +2 -3
  4. autogluon/multimodal/configs/model/default.yaml +58 -11
  5. autogluon/multimodal/configs/{optimization → optim}/default.yaml +21 -4
  6. autogluon/multimodal/constants.py +16 -5
  7. autogluon/multimodal/data/__init__.py +14 -2
  8. autogluon/multimodal/data/dataset.py +2 -2
  9. autogluon/multimodal/data/infer_types.py +16 -2
  10. autogluon/multimodal/data/label_encoder.py +3 -3
  11. autogluon/multimodal/{utils → data}/nlpaug.py +4 -4
  12. autogluon/multimodal/data/preprocess_dataframe.py +55 -38
  13. autogluon/multimodal/data/process_categorical.py +35 -6
  14. autogluon/multimodal/data/process_document.py +59 -33
  15. autogluon/multimodal/data/process_image.py +198 -163
  16. autogluon/multimodal/data/process_label.py +7 -3
  17. autogluon/multimodal/data/process_mmlab/process_mmdet.py +1 -8
  18. autogluon/multimodal/data/process_mmlab/process_mmlab_base.py +2 -9
  19. autogluon/multimodal/data/process_mmlab/process_mmocr.py +1 -9
  20. autogluon/multimodal/data/process_ner.py +192 -4
  21. autogluon/multimodal/data/process_numerical.py +32 -5
  22. autogluon/multimodal/data/process_semantic_seg_img.py +23 -28
  23. autogluon/multimodal/data/process_text.py +95 -58
  24. autogluon/multimodal/data/template_engine.py +7 -9
  25. autogluon/multimodal/data/templates.py +0 -2
  26. autogluon/multimodal/data/trivial_augmenter.py +2 -2
  27. autogluon/multimodal/data/utils.py +564 -338
  28. autogluon/multimodal/learners/__init__.py +2 -1
  29. autogluon/multimodal/learners/base.py +189 -189
  30. autogluon/multimodal/learners/ensemble.py +748 -0
  31. autogluon/multimodal/learners/few_shot_svm.py +6 -15
  32. autogluon/multimodal/learners/matching.py +59 -84
  33. autogluon/multimodal/learners/ner.py +23 -22
  34. autogluon/multimodal/learners/object_detection.py +26 -21
  35. autogluon/multimodal/learners/semantic_segmentation.py +16 -18
  36. autogluon/multimodal/models/__init__.py +12 -3
  37. autogluon/multimodal/models/augmenter.py +175 -0
  38. autogluon/multimodal/models/categorical_mlp.py +13 -8
  39. autogluon/multimodal/models/clip.py +92 -18
  40. autogluon/multimodal/models/custom_transformer.py +75 -75
  41. autogluon/multimodal/models/document_transformer.py +23 -9
  42. autogluon/multimodal/models/ft_transformer.py +40 -35
  43. autogluon/multimodal/models/fusion/base.py +2 -4
  44. autogluon/multimodal/models/fusion/fusion_mlp.py +82 -18
  45. autogluon/multimodal/models/fusion/fusion_ner.py +1 -1
  46. autogluon/multimodal/models/fusion/fusion_transformer.py +23 -23
  47. autogluon/multimodal/models/{huggingface_text.py → hf_text.py} +21 -2
  48. autogluon/multimodal/models/meta_transformer.py +336 -0
  49. autogluon/multimodal/models/mlp.py +6 -6
  50. autogluon/multimodal/models/mmocr_text_detection.py +1 -1
  51. autogluon/multimodal/models/mmocr_text_recognition.py +0 -1
  52. autogluon/multimodal/models/ner_text.py +1 -8
  53. autogluon/multimodal/models/numerical_mlp.py +14 -8
  54. autogluon/multimodal/models/sam.py +12 -2
  55. autogluon/multimodal/models/t_few.py +21 -5
  56. autogluon/multimodal/models/timm_image.py +74 -32
  57. autogluon/multimodal/models/utils.py +877 -16
  58. autogluon/multimodal/optim/__init__.py +17 -0
  59. autogluon/multimodal/{optimization → optim}/lit_distiller.py +2 -1
  60. autogluon/multimodal/{optimization → optim}/lit_matcher.py +4 -10
  61. autogluon/multimodal/{optimization → optim}/lit_mmdet.py +2 -10
  62. autogluon/multimodal/{optimization → optim}/lit_module.py +139 -14
  63. autogluon/multimodal/{optimization → optim}/lit_ner.py +3 -3
  64. autogluon/multimodal/{optimization → optim}/lit_semantic_seg.py +1 -1
  65. autogluon/multimodal/optim/losses/__init__.py +14 -0
  66. autogluon/multimodal/optim/losses/bce_loss.py +25 -0
  67. autogluon/multimodal/optim/losses/focal_loss.py +81 -0
  68. autogluon/multimodal/optim/losses/lemda_loss.py +39 -0
  69. autogluon/multimodal/optim/losses/rkd_loss.py +103 -0
  70. autogluon/multimodal/optim/losses/softmax_losses.py +177 -0
  71. autogluon/multimodal/optim/losses/structure_loss.py +26 -0
  72. autogluon/multimodal/optim/losses/utils.py +313 -0
  73. autogluon/multimodal/optim/lr/__init__.py +1 -0
  74. autogluon/multimodal/optim/lr/utils.py +332 -0
  75. autogluon/multimodal/optim/metrics/__init__.py +4 -0
  76. autogluon/multimodal/optim/metrics/coverage_metrics.py +42 -0
  77. autogluon/multimodal/optim/metrics/hit_rate_metrics.py +78 -0
  78. autogluon/multimodal/optim/metrics/ranking_metrics.py +231 -0
  79. autogluon/multimodal/optim/metrics/utils.py +359 -0
  80. autogluon/multimodal/optim/utils.py +284 -0
  81. autogluon/multimodal/predictor.py +51 -12
  82. autogluon/multimodal/utils/__init__.py +19 -45
  83. autogluon/multimodal/utils/cache.py +23 -2
  84. autogluon/multimodal/utils/checkpoint.py +58 -5
  85. autogluon/multimodal/utils/config.py +127 -55
  86. autogluon/multimodal/utils/device.py +120 -0
  87. autogluon/multimodal/utils/distillation.py +8 -8
  88. autogluon/multimodal/utils/download.py +1 -1
  89. autogluon/multimodal/utils/env.py +22 -0
  90. autogluon/multimodal/utils/export.py +3 -3
  91. autogluon/multimodal/utils/hpo.py +5 -5
  92. autogluon/multimodal/utils/inference.py +37 -4
  93. autogluon/multimodal/utils/install.py +91 -0
  94. autogluon/multimodal/utils/load.py +52 -47
  95. autogluon/multimodal/utils/log.py +6 -41
  96. autogluon/multimodal/utils/matcher.py +3 -2
  97. autogluon/multimodal/utils/onnx.py +0 -4
  98. autogluon/multimodal/utils/path.py +10 -0
  99. autogluon/multimodal/utils/precision.py +130 -0
  100. autogluon/multimodal/{presets.py → utils/presets.py} +259 -66
  101. autogluon/multimodal/{problem_types.py → utils/problem_types.py} +30 -1
  102. autogluon/multimodal/utils/save.py +47 -29
  103. autogluon/multimodal/utils/strategy.py +24 -0
  104. autogluon/multimodal/version.py +1 -1
  105. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/METADATA +5 -5
  106. autogluon.multimodal-1.2.1b20250304.dist-info/RECORD +163 -0
  107. autogluon/multimodal/optimization/__init__.py +0 -16
  108. autogluon/multimodal/optimization/losses.py +0 -394
  109. autogluon/multimodal/optimization/utils.py +0 -1054
  110. autogluon/multimodal/utils/cloud_io.py +0 -80
  111. autogluon/multimodal/utils/data.py +0 -701
  112. autogluon/multimodal/utils/environment.py +0 -395
  113. autogluon/multimodal/utils/metric.py +0 -500
  114. autogluon/multimodal/utils/model.py +0 -558
  115. autogluon.multimodal-1.2.1b20250303.dist-info/RECORD +0 -145
  116. /autogluon/multimodal/{optimization → optim}/deepspeed.py +0 -0
  117. /autogluon/multimodal/{optimization/lr_scheduler.py → optim/lr/lr_schedulers.py} +0 -0
  118. /autogluon/multimodal/{optimization → optim/metrics}/semantic_seg_metrics.py +0 -0
  119. /autogluon/multimodal/{registry.py → utils/registry.py} +0 -0
  120. /autogluon.multimodal-1.2.1b20250303-py3.9-nspkg.pth → /autogluon.multimodal-1.2.1b20250304-py3.9-nspkg.pth +0 -0
  121. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/LICENSE +0 -0
  122. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/NOTICE +0 -0
  123. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/WHEEL +0 -0
  124. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/namespace_packages.txt +0 -0
  125. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/top_level.txt +0 -0
  126. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/zip-safe +0 -0
autogluon/multimodal/optim/lr/utils.py
@@ -0,0 +1,332 @@
+ import logging
+ import re
+ from typing import Any, Dict, List, Optional, Tuple, Union
+
+ import torch
+ from torch import nn, optim
+
+ from ..utils import get_weight_decay_param_names
+ from .lr_schedulers import (
+     get_cosine_schedule_with_warmup,
+     get_linear_schedule_with_warmup,
+     get_polynomial_decay_schedule_with_warmup,
+ )
+
+ logger = logging.getLogger(__name__)
+
+
+ def get_lr_scheduler(
+     optimizer: optim.Optimizer,
+     num_max_steps: int,
+     num_warmup_steps: int,
+     lr_schedule: str,
+     end_lr: Union[float, int],
+ ):
+     """
+     Get the learning rate scheduler from its name. Here we use our own learning rate
+     schedulers instead of the ones imported from "transformers" because we want to support
+     PyTorch Lightning's "ddp_spawn" training strategy.
+
+     Parameters
+     ----------
+     optimizer
+         A PyTorch optimizer.
+     num_max_steps
+         Number of maximum training steps.
+     num_warmup_steps
+         Number of steps to do learning rate warmup.
+     lr_schedule
+         Name of the learning rate scheduler.
+     end_lr
+         The final learning rate after decay.
+
+     Returns
+     -------
+     A learning rate scheduler.
+     """
+     if lr_schedule == "cosine_decay":
+         scheduler = get_cosine_schedule_with_warmup(
+             optimizer=optimizer,
+             num_warmup_steps=num_warmup_steps,
+             num_training_steps=num_max_steps,
+         )
+     elif lr_schedule == "polynomial_decay":
+         scheduler = get_polynomial_decay_schedule_with_warmup(
+             optimizer=optimizer,
+             num_warmup_steps=num_warmup_steps,
+             num_training_steps=num_max_steps,
+             lr_end=end_lr,
+             power=1,
+         )
+     elif lr_schedule == "linear_decay":
+         scheduler = get_linear_schedule_with_warmup(
+             optimizer=optimizer,
+             num_warmup_steps=num_warmup_steps,
+             num_training_steps=num_max_steps,
+         )
+     elif lr_schedule == "multi_step":
+         # TODO: add milestones, gamma into hyperparameters
+         scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=[30, 55], gamma=0.1)
+     else:
+         raise ValueError(f"unknown lr schedule: {lr_schedule}")
+
+     return scheduler
+
+
+ def apply_single_lr(
+     model: nn.Module,
+     lr: float,
+     weight_decay: float,
+     return_params: Optional[bool] = True,
+     peft: Optional[str] = None,
+     trainable_param_names: Optional[List] = None,
+     exclude_keys: Optional[List] = None,
+ ):
+     """
+     Set to use a single learning rate for all parameters. Layer normalization parameters and other
+     layers' bias parameters don't use weight decay.
+
+     Parameters
+     ----------
+     model
+         A PyTorch model.
+     lr
+         Learning rate.
+     weight_decay
+         Weight decay.
+     return_params
+         Whether to return parameters or their names. If you want to double-check
+         whether the learning rate setup is as expected, you can set "return_params=False",
+         and print the layer names along with their learning rates through
+         "print("Param groups = %s" % json.dumps(optimizer_grouped_parameters, indent=2))".
+     peft
+         Efficient finetuning strategy. It will only finetune part of the parameters.
+     trainable_param_names
+         A list of trainable parameters. (Optional)
+     exclude_keys
+         A list of keys to be excluded.
+
+     Returns
+     -------
+     The grouped parameters or their names.
+     """
+     decay_param_names = get_weight_decay_param_names(model)
+     decay_grad_param_names = []
+     no_decay_grad_param_names = []
+
+     for name, param in model.named_parameters():
+         if exclude_keys and any([exc in name for exc in exclude_keys]):
+             continue
+
+         if (
+             peft is not None
+             and trainable_param_names
+             and not any([re.match(trainable_param_name, name) for trainable_param_name in trainable_param_names])
+         ):
+             param.requires_grad = False
+
+         if not param.requires_grad:
+             continue  # frozen weights
+
+         if name in decay_param_names:
+             if return_params:
+                 decay_grad_param_names.append(param)
+             else:
+                 decay_grad_param_names.append(name)
+
+         else:
+             if return_params:
+                 no_decay_grad_param_names.append(param)
+             else:
+                 no_decay_grad_param_names.append(name)
+
+     optimizer_grouped_parameters = [
+         {
+             "params": decay_grad_param_names,
+             "weight_decay": weight_decay,
+             "lr": lr,
+         },
+         {
+             "params": no_decay_grad_param_names,
+             "weight_decay": 0.0,
+             "lr": lr,
+         },
+     ]
+     return optimizer_grouped_parameters
+
+
+ def apply_two_stages_lr(
+     model: nn.Module,
+     lr: float,
+     lr_mult: Union[float, int],
+     weight_decay: float,
+     return_params: Optional[bool] = True,
+     exclude_keys: Optional[List] = None,
+ ):
+     """
+     Set up the parameters in the model's head layers to use a scaled learning rate (lr * lr_mult),
+     while all other parameters use the base learning rate (lr).
+     Layer normalization parameters and other layers' bias parameters don't use weight decay.
+
+     Parameters
+     ----------
+     model
+         A PyTorch model.
+     lr
+         The learning rate.
+     lr_mult
+         The multiplier (0, 1) to scale down the learning rate.
+     weight_decay
+         Weight decay.
+     return_params
+         Whether to return parameters or their names. If you want to double-check
+         whether the learning rate setup is as expected, you can set "return_params=False",
+         and print the layer names along with their learning rates through
+         "print("Param groups = %s" % json.dumps(optimizer_grouped_parameters, indent=2))".
+     exclude_keys
+         A list of keys to be excluded.
+
+     Returns
+     -------
+     The grouped parameters or their names.
+     """
+     decay_param_names = get_weight_decay_param_names(model)
+
+     optimizer_grouped_parameters = [
+         {
+             "params": [
+                 p if return_params else n
+                 for n, p in model.named_parameters()
+                 if n in decay_param_names
+                 and not any(bb in n for bb in model.head_layer_names)
+                 and (not exclude_keys or not any([exc in n for exc in exclude_keys]))
+             ],
+             "weight_decay": weight_decay,
+             "lr": lr,
+         },
+         {
+             "params": [
+                 p if return_params else n
+                 for n, p in model.named_parameters()
+                 if n not in decay_param_names
+                 and not any(bb in n for bb in model.head_layer_names)
+                 and (not exclude_keys or not any([exc in n for exc in exclude_keys]))
+             ],
+             "weight_decay": 0.0,
+             "lr": lr,
+         },
+         {
+             "params": [
+                 p if return_params else n
+                 for n, p in model.named_parameters()
+                 if n in decay_param_names
+                 and any(bb in n for bb in model.head_layer_names)
+                 and (not exclude_keys or not any([exc in n for exc in exclude_keys]))
+             ],
+             "weight_decay": weight_decay,
+             "lr": lr * lr_mult,
+         },
+         {
+             "params": [
+                 p if return_params else n
+                 for n, p in model.named_parameters()
+                 if n not in decay_param_names
+                 and any(bb in n for bb in model.head_layer_names)
+                 and (not exclude_keys or not any([exc in n for exc in exclude_keys]))
+             ],
+             "weight_decay": 0.0,
+             "lr": lr * lr_mult,
+         },
+     ]
+
+     return optimizer_grouped_parameters
+
+
+ def apply_layerwise_lr_decay(
+     model: nn.Module,
+     lr: float,
+     lr_decay: float,
+     weight_decay: float,
+     peft: Optional[str] = None,
+     trainable_param_names: Optional[List] = None,
+     exclude_keys: Optional[List] = None,
+ ):
+     """
+     Assign monotonically decreasing learning rates to layers from the output end to the input end.
+     The intuition behind this is that later layers are more task-specific than the earlier layers.
+     Layer normalization parameters and other layers' bias parameters don't use weight decay.
+     If you want to double-check whether the learning rate setup is as expected,
+     you can print the layer names along with their learning rates through
+     "print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))".
+
+     Parameters
+     ----------
+     model
+         A PyTorch model.
+     lr
+         The learning rate.
+     lr_decay
+         The learning rate decay factor (0, 1).
+     weight_decay
+         Weight decay.
+     peft
+         Efficient finetuning strategy. It will only finetune part of the parameters.
+     trainable_param_names
+         A list of trainable parameters. (Optional)
+     exclude_keys
+         A list of keys to be excluded.
+
+     Returns
+     -------
+     The grouped parameters based on their layer ids and whether using weight decay.
+     """
+     parameter_group_names = {}
+     parameter_group_vars = {}
+     decay_param_names = get_weight_decay_param_names(model)
+
+     for name, param in model.named_parameters():
+         if name.startswith("_orig_mod."):
+             name = "".join(name.split("_orig_mod."))
+         if exclude_keys and any([exc in name for exc in exclude_keys]):
+             continue
+         layer_id = model.name_to_id[name]
+         if layer_id == 0:  # Set top layer (e.g. head, fusion_mlp, adapter) as being trainable.
+             param.requires_grad = True
+         elif (
+             peft is not None
+             and trainable_param_names
+             and not any([re.match(trainable_param_name, name) for trainable_param_name in trainable_param_names])
+         ):
+             param.requires_grad = False
+
+         if not param.requires_grad:
+             continue  # frozen weights
+
+         if name in decay_param_names:
+             group_name = "decay"
+             this_weight_decay = weight_decay
+         else:
+             group_name = "no_decay"
+             this_weight_decay = 0.0
+
+         layer_id = model.name_to_id[name]
+         group_name = "layer_%d_%s" % (layer_id, group_name)
+
+         if group_name not in parameter_group_names:
+             scale = lr_decay**layer_id
+             parameter_group_names[group_name] = {
+                 "weight_decay": this_weight_decay,
+                 "params": [],
+                 "lr": scale * lr,
+             }
+             parameter_group_vars[group_name] = {
+                 "weight_decay": this_weight_decay,
+                 "params": [],
+                 "lr": scale * lr,
+             }
+
+         parameter_group_vars[group_name]["params"].append(param)
+         parameter_group_names[group_name]["params"].append(name)
+
+     return list(parameter_group_vars.values())
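For orientation, here is a minimal sketch (not part of the diff) of how these lr utilities compose. The toy model and its name_to_id mapping are hypothetical stand-ins for the layer-id maps that real AutoMM models provide; only the import path comes from this release.

import torch
from torch import nn

from autogluon.multimodal.optim.lr.utils import apply_layerwise_lr_decay, get_lr_scheduler

# Hypothetical model; AutoMM models expose a name_to_id map where id 0 marks
# the head (always trainable) and larger ids mark earlier layers.
toy = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
toy.name_to_id = {n: 0 if n.startswith("2") else 1 for n, _ in toy.named_parameters()}

# Layer id 0 keeps lr = 1e-3; layer id 1 is scaled to 0.9 * 1e-3.
param_groups = apply_layerwise_lr_decay(model=toy, lr=1e-3, lr_decay=0.9, weight_decay=1e-4)
optimizer = torch.optim.AdamW(param_groups)
scheduler = get_lr_scheduler(
    optimizer=optimizer,
    num_max_steps=1000,
    num_warmup_steps=100,
    lr_schedule="cosine_decay",
    end_lr=0.0,
)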
autogluon/multimodal/optim/metrics/__init__.py
@@ -0,0 +1,4 @@
+ from .coverage_metrics import Coverage
+ from .hit_rate_metrics import CustomHitRate
+ from .ranking_metrics import compute_ranking_score
+ from .utils import get_torchmetric, compute_score, get_minmax_mode, get_stopping_threshold, infer_metrics
autogluon/multimodal/optim/metrics/coverage_metrics.py
@@ -0,0 +1,42 @@
+ import torch
+ from torchmetrics import Metric
+ from torchmetrics.utilities import dim_zero_cat
+
+
+ class Coverage(Metric):
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+         self.add_state("pos_probs", default=[], dist_reduce_fx="cat")
+         self.add_state("targets", default=[], dist_reduce_fx="cat")
+         self.tp_threshold = 0.97
+         self.tn_threshold = 0.99
+         self.higher_is_better = True
+
+     def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
+         assert preds.dim() == 1
+         assert target.dim() == 1
+         self.pos_probs.append(preds)
+         self.targets.append(target)
+
+     def compute(self):
+         # parse inputs
+         pos_probs = dim_zero_cat(self.pos_probs)
+         targets = dim_zero_cat(self.targets)
+         y_pos = targets[torch.where(pos_probs >= self.tp_threshold)]
+         y_neg = targets[torch.where(pos_probs <= 1 - self.tn_threshold)]
+         tp = sum(y_pos == 1)
+         tn = sum(y_neg == 0)
+         if len(y_pos) == 0 or len(targets) == 0:
+             tp_precision, tp_coverage = 0, 0
+         else:
+             tp_precision, tp_coverage = tp / len(y_pos), len(y_pos) / len(targets)
+         if tp_precision < self.tp_threshold:
+             tp_coverage = 0
+         if len(y_neg) == 0 or len(targets) == 0:
+             tn_precision, tn_coverage = 0, 0
+         else:
+             tn_precision, tn_coverage = tn / len(y_neg), len(y_neg) / len(targets)
+         if tn_precision < self.tn_threshold:
+             tn_coverage = 0
+
+         return torch.tensor(tp_coverage) + torch.tensor(tn_coverage)
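As a sanity check, a self-contained example (made-up probabilities) of the Coverage metric above: it adds the fraction of samples confidently called positive (prob >= 0.97) to the fraction confidently called negative (prob <= 0.01), zeroing out either side whose precision misses its threshold.

import torch

from autogluon.multimodal.optim.metrics.coverage_metrics import Coverage

metric = Coverage()
# Made-up positive-class probabilities and binary targets for a toy batch.
metric.update(torch.tensor([0.99, 0.98, 0.005, 0.50]), torch.tensor([1, 1, 0, 1]))
print(metric.compute())  # 2/4 confident positives + 1/4 confident negatives = 0.75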
autogluon/multimodal/optim/metrics/hit_rate_metrics.py
@@ -0,0 +1,78 @@
+ from typing import Any, Dict, List, Optional, Tuple, Union
+
+ import torch
+ import torchmetrics
+
+
+ class CustomHitRate(torchmetrics.Metric):
+     """
+     Compute the hit rate when doing semantic search between two groups of embeddings.
+     We assume that (a_i, p_i) is a positive pair and (a_i, p_j) for i != j is a negative pair.
+     """
+
+     def __init__(
+         self,
+     ):
+         super().__init__()
+         self.add_state("query_embeddings", default=[], dist_reduce_fx=None)
+         self.add_state("response_embeddings", default=[], dist_reduce_fx=None)
+         self.add_state("logit_scale", default=[], dist_reduce_fx=None)
+
+     def update(
+         self,
+         batch_query_embeds: torch.Tensor,
+         batch_response_embeds: torch.Tensor,
+         logit_scale: Optional[torch.Tensor] = None,
+     ):
+         self.query_embeddings.append(batch_query_embeds)
+         self.response_embeddings.append(batch_response_embeds)
+         if logit_scale is not None:
+             self.logit_scale.append(logit_scale)
+
+     def compute(self):
+         query_embeddings = torch.cat(self.query_embeddings)
+         response_embeddings = torch.cat(self.response_embeddings)
+         if self.logit_scale:
+             logit_scale = torch.mean(torch.stack(self.logit_scale))
+         else:
+             logit_scale = 1
+
+         return self.compute_hit_rate(query_embeddings, response_embeddings, logit_scale)
+
+     @staticmethod
+     def compute_hit_rate(features_a, features_b, logit_scale, top_ks=[1, 5, 10]):
+         """
+         Compute symmetric hit rates between two groups of features.
+
+         Parameters
+         ----------
+         features_a
+             One group of features.
+         features_b
+             The other group of features.
+         logit_scale
+             The scale of the logits (used in CLIP).
+         top_ks
+             Consider only the top k elements for each query.
+
+         Returns
+         -------
+         The accumulated hit rate.
+         """
+         assert len(features_a) == len(features_b)
+         hit_rate = 0
+         logits_per_a = (logit_scale * features_a @ features_b.t()).detach().cpu()
+         logits_per_b = logits_per_a.t().detach().cpu()
+
+         logits = {"logits_per_a": logits_per_a, "logits_per_b": logits_per_b}
+         ground_truth = torch.arange(len(features_b)).view(-1, 1)
+
+         for name, logit in logits.items():
+             ranking = torch.argsort(logit, descending=True)
+             preds = torch.where(ranking == ground_truth)[1]
+
+             for k in top_ks:
+                 hit_rate += (preds < k).float().mean()
+
+         hit_rate /= len(top_ks) * len(logits)
+         return hit_rate
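A short usage sketch (random embeddings, hypothetical sizes) of CustomHitRate; row i of the query batch is assumed to match row i of the response batch.

import torch
import torch.nn.functional as F

from autogluon.multimodal.optim.metrics.hit_rate_metrics import CustomHitRate

metric = CustomHitRate()
query = F.normalize(torch.randn(32, 16), dim=-1)
response = F.normalize(torch.randn(32, 16), dim=-1)
metric.update(batch_query_embeds=query, batch_response_embeds=response)
print(metric.compute())  # top-1/5/10 hit rates averaged over both search directions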
autogluon/multimodal/optim/metrics/ranking_metrics.py
@@ -0,0 +1,231 @@
+ import logging
+ import math
+ import operator
+ from typing import Dict, List, Optional, Tuple, Union
+
+ import numpy as np
+
+ from ...constants import MAP, NDCG, PRECISION, RECALL
+
+ logger = logging.getLogger(__name__)
+
+
+ class RankingMetrics:
+     def __init__(
+         self,
+         pred: Dict[str, Dict],
+         target: Dict[str, Dict],
+         is_higher_better=True,
+     ):
+         """
+         Evaluation metrics for information retrieval tasks such as document retrieval, image retrieval, etc.
+         Reference: https://www.cs.cornell.edu/courses/cs4300/2013fa/lectures/metrics-2-4pp.pdf
+
+         Parameters
+         ----------
+         pred:
+             The predictions of the ranking model, in the following form:
+             pred = {
+                 'q1': {
+                     'd1': 1,
+                     'd3': 0,
+                 },
+                 'q2': {
+                     'd2': 1,
+                     'd3': 1,
+                 },
+             }
+             where q refers to queries and d refers to documents; each query has a few relevant documents.
+             0s and 1s are model-predicted scores (they need not be binary).
+         target:
+             The ground truth query and response relevance, which has the same form as pred.
+         is_higher_better:
+             Whether a higher relevance score means a higher ranking.
+             If the relevance score is a cosine similarity / dot product, it should be set to True;
+             if it is a Euclidean distance, it should be False.
+         """
+         self.pred = pred
+         self.target = target
+         self.is_higher_better = is_higher_better
+         # the metrics supported in this script
+         self.supported_metrics = {
+             "precision": 0,
+             "recall": 1,
+             "mrr": 2,
+             "map": 3,
+             "ndcg": 4,
+         }
+
+         assert len(pred) == len(
+             target
+         ), f"The prediction and ground truth target should have the same number of queries, \
+             while there are {len(pred)} queries in the prediction and {len(target)} in the target."
+
+         self.results = {}
+         for key in target.keys():
+             self.results.update({key: [target[key], pred[key]]})
+
+     def compute(self, metrics: Union[str, list] = None, k: Optional[int] = 10):
+         """
+         Compute and return the ranking scores.
+
+         Parameters
+         ----------
+         metrics:
+             User-provided metrics.
+         k:
+             The cutoff value for NDCG, MAP, Recall, MRR, and Precision.
+
+         Returns
+         -------
+         Computed scores.
+
+         """
+         if isinstance(metrics, str):
+             metrics = [metrics]
+         if not metrics:  # no metric is provided
+             metrics = self.supported_metrics.keys()
+
+         return_res = {}
+
+         eval_res = np.mean(
+             [list(self._compute_one(idx, k)) for idx in self.results.keys()],
+             axis=0,
+         )
+
+         for metric in metrics:
+             metric = metric.lower()
+             if metric in self.supported_metrics:
+                 return_res.update({f"{metric}@{k}": eval_res[self.supported_metrics[metric]]})
+
+         return return_res
+
+     def _compute_one(self, idx, k):
+         """
+         Compute and return the ranking scores for one query.
+         For the definitions of these metrics, please refer to
+         https://www.cs.cornell.edu/courses/cs4300/2013fa/lectures/metrics-2-4pp.pdf
+
+         Parameters
+         ----------
+         idx:
+             The index of the query.
+         k:
+             The cutoff value for NDCG, MAP, Recall, MRR, and Precision.
+
+         Returns
+         -------
+         Computed scores.
+         """
+         precision, recall, mrr, mAP, ndcg = 0, 0, 0, 0, 0
+         target, pred = self.results[idx][0], self.results[idx][1]
+
+         # sort the ground truth and predictions in descending order
+         sorted_target = dict(
+             sorted(
+                 target.items(),
+                 key=operator.itemgetter(1),
+                 reverse=self.is_higher_better,
+             )
+         )
+         sorted_pred = dict(
+             sorted(
+                 pred.items(),
+                 key=operator.itemgetter(1),
+                 reverse=self.is_higher_better,
+             )
+         )
+         sorted_target_values = list(sorted_target.values())
+         sorted_pred_values = list(sorted_pred.values())
+
+         # number of positive relevance labels in the target;
+         # negative numbers and zero are considered negative responses
+         num_pos_target = len([val for val in sorted_target_values if val > 0])
+
+         at_k = k if num_pos_target > k else num_pos_target
+
+         first_k_items_list = list(sorted_pred.items())[0:k]
+
+         rank = 0
+         hit_rank = []  # correctly retrieved items
+         for key, value in first_k_items_list:
+             if key in sorted_target and sorted_target[key] > 0:
+                 hit_rank.append(rank)
+             rank += 1
+         count = len(hit_rank)
+         # compute the precision and recall
+         precision = count / k
+         recall = count / num_pos_target
+
+         dcg = 0
+         if hit_rank:  # not empty
+             # compute the mean reciprocal rank
+             mrr = 1 / (hit_rank[0] + 1)
+             # compute the mean average precision
+             mAP = np.sum([sorted_pred_values[rank] * (i + 1) / (rank + 1) for i, rank in enumerate(hit_rank)])
+             # compute the discounted cumulative gain
+             dcg = np.sum([1 / math.log(rank + 2, 2) for rank in hit_rank])
+
+         # compute the ideal discounted cumulative gain
+         idcg = np.sum([1 / math.log(i + 2, 2) for i in range(at_k)])
+         # compute the normalized discounted cumulative gain
+         ndcg = dcg / idcg
+         mAP /= at_k
+
+         return precision, recall, mrr, mAP, ndcg
+
+
+ def compute_ranking_score(
+     results: Dict[str, Dict],
+     qrel_dict: Dict[str, Dict],
+     metrics: List[str],
+     cutoffs: Optional[List[int]] = [5, 10, 20],
+ ):
+     """
+     Compute the ranking metrics, e.g., NDCG, MAP, Recall, and Precision.
+     TODO: Consider MRR.
+
+     Parameters
+     ----------
+     results:
+         The query/document ranking list produced by the model.
+     qrel_dict:
+         The ground truth query and document relevance.
+     metrics:
+         A list of metrics to compute.
+     cutoffs:
+         The cutoff values for NDCG, MAP, Recall, and Precision.
+
+     Returns
+     -------
+     A dict of metric scores.
+     """
+     scores = {}
+     evaluator = RankingMetrics(pred=results, target=qrel_dict)
+     for k in cutoffs:
+         scores.update(evaluator.compute(k=k))
+
+     metric_results = dict()
+     for k in cutoffs:
+         for per_metric in metrics:
+             if per_metric.lower() == NDCG:
+                 metric_results[f"{NDCG}@{k}"] = 0.0
+             elif per_metric.lower() == MAP:
+                 metric_results[f"{MAP}@{k}"] = 0.0
+             elif per_metric.lower() == RECALL:
+                 metric_results[f"{RECALL}@{k}"] = 0.0
+             elif per_metric.lower() == PRECISION:
+                 metric_results[f"{PRECISION}@{k}"] = 0.0
+
+     for k in cutoffs:
+         for per_metric in metrics:
+             if per_metric.lower() == NDCG:
+                 metric_results[f"{NDCG}@{k}"] = round(scores[f"{NDCG}@{k}"], 5)
+             elif per_metric.lower() == MAP:
+                 metric_results[f"{MAP}@{k}"] = round(scores[f"{MAP}@{k}"], 5)
+             elif per_metric.lower() == RECALL:
+                 metric_results[f"{RECALL}@{k}"] = round(scores[f"{RECALL}@{k}"], 5)
+             elif per_metric.lower() == PRECISION:
+                 metric_results[f"{PRECISION}@{k}"] = round(scores[f"{PRECISION}@{k}"], 5)
+
+     return metric_results
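And a toy invocation (hypothetical queries q*/documents d*) of compute_ranking_score, assuming the NDCG/MAP/RECALL/PRECISION constants resolve to the lowercase metric names:

from autogluon.multimodal.optim.metrics.ranking_metrics import compute_ranking_score

qrel_dict = {"q1": {"d1": 1, "d2": 0}, "q2": {"d2": 1, "d3": 1}}  # ground truth relevance
results = {"q1": {"d1": 0.9, "d2": 0.2}, "q2": {"d2": 0.8, "d3": 0.7}}  # model scores
print(compute_ranking_score(results=results, qrel_dict=qrel_dict, metrics=["ndcg", "recall"], cutoffs=[5]))
# -> {'ndcg@5': 1.0, 'recall@5': 1.0} since both queries rank their relevant documents first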