autogluon.multimodal 1.2.1b20250303__py3-none-any.whl → 1.2.1b20250305__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126)
  1. autogluon/multimodal/__init__.py +4 -2
  2. autogluon/multimodal/configs/data/default.yaml +4 -2
  3. autogluon/multimodal/configs/{environment → env}/default.yaml +2 -3
  4. autogluon/multimodal/configs/model/default.yaml +58 -11
  5. autogluon/multimodal/configs/{optimization → optim}/default.yaml +21 -4
  6. autogluon/multimodal/constants.py +16 -5
  7. autogluon/multimodal/data/__init__.py +14 -2
  8. autogluon/multimodal/data/dataset.py +2 -2
  9. autogluon/multimodal/data/infer_types.py +16 -2
  10. autogluon/multimodal/data/label_encoder.py +3 -3
  11. autogluon/multimodal/{utils → data}/nlpaug.py +4 -4
  12. autogluon/multimodal/data/preprocess_dataframe.py +55 -38
  13. autogluon/multimodal/data/process_categorical.py +35 -6
  14. autogluon/multimodal/data/process_document.py +59 -33
  15. autogluon/multimodal/data/process_image.py +198 -163
  16. autogluon/multimodal/data/process_label.py +7 -3
  17. autogluon/multimodal/data/process_mmlab/process_mmdet.py +1 -8
  18. autogluon/multimodal/data/process_mmlab/process_mmlab_base.py +2 -9
  19. autogluon/multimodal/data/process_mmlab/process_mmocr.py +1 -9
  20. autogluon/multimodal/data/process_ner.py +192 -4
  21. autogluon/multimodal/data/process_numerical.py +32 -5
  22. autogluon/multimodal/data/process_semantic_seg_img.py +23 -28
  23. autogluon/multimodal/data/process_text.py +95 -58
  24. autogluon/multimodal/data/template_engine.py +7 -9
  25. autogluon/multimodal/data/templates.py +0 -2
  26. autogluon/multimodal/data/trivial_augmenter.py +2 -2
  27. autogluon/multimodal/data/utils.py +564 -338
  28. autogluon/multimodal/learners/__init__.py +2 -1
  29. autogluon/multimodal/learners/base.py +189 -189
  30. autogluon/multimodal/learners/ensemble.py +748 -0
  31. autogluon/multimodal/learners/few_shot_svm.py +6 -15
  32. autogluon/multimodal/learners/matching.py +59 -84
  33. autogluon/multimodal/learners/ner.py +23 -22
  34. autogluon/multimodal/learners/object_detection.py +26 -21
  35. autogluon/multimodal/learners/semantic_segmentation.py +16 -18
  36. autogluon/multimodal/models/__init__.py +12 -3
  37. autogluon/multimodal/models/augmenter.py +175 -0
  38. autogluon/multimodal/models/categorical_mlp.py +13 -8
  39. autogluon/multimodal/models/clip.py +92 -18
  40. autogluon/multimodal/models/custom_transformer.py +75 -75
  41. autogluon/multimodal/models/document_transformer.py +23 -9
  42. autogluon/multimodal/models/ft_transformer.py +40 -35
  43. autogluon/multimodal/models/fusion/base.py +2 -4
  44. autogluon/multimodal/models/fusion/fusion_mlp.py +82 -18
  45. autogluon/multimodal/models/fusion/fusion_ner.py +1 -1
  46. autogluon/multimodal/models/fusion/fusion_transformer.py +23 -23
  47. autogluon/multimodal/models/{huggingface_text.py → hf_text.py} +21 -2
  48. autogluon/multimodal/models/meta_transformer.py +336 -0
  49. autogluon/multimodal/models/mlp.py +6 -6
  50. autogluon/multimodal/models/mmocr_text_detection.py +1 -1
  51. autogluon/multimodal/models/mmocr_text_recognition.py +0 -1
  52. autogluon/multimodal/models/ner_text.py +1 -8
  53. autogluon/multimodal/models/numerical_mlp.py +14 -8
  54. autogluon/multimodal/models/sam.py +12 -2
  55. autogluon/multimodal/models/t_few.py +21 -5
  56. autogluon/multimodal/models/timm_image.py +74 -32
  57. autogluon/multimodal/models/utils.py +877 -16
  58. autogluon/multimodal/optim/__init__.py +17 -0
  59. autogluon/multimodal/{optimization → optim}/lit_distiller.py +2 -1
  60. autogluon/multimodal/{optimization → optim}/lit_matcher.py +4 -10
  61. autogluon/multimodal/{optimization → optim}/lit_mmdet.py +2 -10
  62. autogluon/multimodal/{optimization → optim}/lit_module.py +139 -14
  63. autogluon/multimodal/{optimization → optim}/lit_ner.py +3 -3
  64. autogluon/multimodal/{optimization → optim}/lit_semantic_seg.py +1 -1
  65. autogluon/multimodal/optim/losses/__init__.py +14 -0
  66. autogluon/multimodal/optim/losses/bce_loss.py +25 -0
  67. autogluon/multimodal/optim/losses/focal_loss.py +81 -0
  68. autogluon/multimodal/optim/losses/lemda_loss.py +39 -0
  69. autogluon/multimodal/optim/losses/rkd_loss.py +103 -0
  70. autogluon/multimodal/optim/losses/softmax_losses.py +177 -0
  71. autogluon/multimodal/optim/losses/structure_loss.py +26 -0
  72. autogluon/multimodal/optim/losses/utils.py +313 -0
  73. autogluon/multimodal/optim/lr/__init__.py +1 -0
  74. autogluon/multimodal/optim/lr/utils.py +332 -0
  75. autogluon/multimodal/optim/metrics/__init__.py +4 -0
  76. autogluon/multimodal/optim/metrics/coverage_metrics.py +42 -0
  77. autogluon/multimodal/optim/metrics/hit_rate_metrics.py +78 -0
  78. autogluon/multimodal/optim/metrics/ranking_metrics.py +231 -0
  79. autogluon/multimodal/optim/metrics/utils.py +359 -0
  80. autogluon/multimodal/optim/utils.py +284 -0
  81. autogluon/multimodal/predictor.py +51 -12
  82. autogluon/multimodal/utils/__init__.py +19 -45
  83. autogluon/multimodal/utils/cache.py +23 -2
  84. autogluon/multimodal/utils/checkpoint.py +58 -5
  85. autogluon/multimodal/utils/config.py +127 -55
  86. autogluon/multimodal/utils/device.py +120 -0
  87. autogluon/multimodal/utils/distillation.py +8 -8
  88. autogluon/multimodal/utils/download.py +1 -1
  89. autogluon/multimodal/utils/env.py +22 -0
  90. autogluon/multimodal/utils/export.py +3 -3
  91. autogluon/multimodal/utils/hpo.py +5 -5
  92. autogluon/multimodal/utils/inference.py +37 -4
  93. autogluon/multimodal/utils/install.py +91 -0
  94. autogluon/multimodal/utils/load.py +52 -47
  95. autogluon/multimodal/utils/log.py +6 -41
  96. autogluon/multimodal/utils/matcher.py +3 -2
  97. autogluon/multimodal/utils/onnx.py +0 -4
  98. autogluon/multimodal/utils/path.py +10 -0
  99. autogluon/multimodal/utils/precision.py +130 -0
  100. autogluon/multimodal/{presets.py → utils/presets.py} +259 -66
  101. autogluon/multimodal/{problem_types.py → utils/problem_types.py} +30 -1
  102. autogluon/multimodal/utils/save.py +47 -29
  103. autogluon/multimodal/utils/strategy.py +24 -0
  104. autogluon/multimodal/version.py +1 -1
  105. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/METADATA +5 -5
  106. autogluon.multimodal-1.2.1b20250305.dist-info/RECORD +163 -0
  107. autogluon/multimodal/optimization/__init__.py +0 -16
  108. autogluon/multimodal/optimization/losses.py +0 -394
  109. autogluon/multimodal/optimization/utils.py +0 -1054
  110. autogluon/multimodal/utils/cloud_io.py +0 -80
  111. autogluon/multimodal/utils/data.py +0 -701
  112. autogluon/multimodal/utils/environment.py +0 -395
  113. autogluon/multimodal/utils/metric.py +0 -500
  114. autogluon/multimodal/utils/model.py +0 -558
  115. autogluon.multimodal-1.2.1b20250303.dist-info/RECORD +0 -145
  116. /autogluon/multimodal/{optimization → optim}/deepspeed.py +0 -0
  117. /autogluon/multimodal/{optimization/lr_scheduler.py → optim/lr/lr_schedulers.py} +0 -0
  118. /autogluon/multimodal/{optimization → optim/metrics}/semantic_seg_metrics.py +0 -0
  119. /autogluon/multimodal/{registry.py → utils/registry.py} +0 -0
  120. /autogluon.multimodal-1.2.1b20250303-py3.9-nspkg.pth → /autogluon.multimodal-1.2.1b20250305-py3.9-nspkg.pth +0 -0
  121. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/LICENSE +0 -0
  122. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/NOTICE +0 -0
  123. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/WHEEL +0 -0
  124. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/namespace_packages.txt +0 -0
  125. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/top_level.txt +0 -0
  126. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/zip-safe +0 -0
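The bulk of this diff is a module reorganization: optimization/ becomes optim/ (with losses, learning-rate utilities, and metrics split into subpackages), several top-level helpers (presets.py, problem_types.py, registry.py) move under utils/, and utils/metric.py, utils/model.py, utils/data.py, and utils/environment.py are removed. For downstream code that imported the removed utils.metric module directly, a defensive import along the following lines could bridge both layouts; the new module path and the names it exports are assumptions inferred from the file list above, not confirmed by this diff.

# Hypothetical compatibility shim for code that imported the removed
# autogluon/multimodal/utils/metric.py. The new location (optim/metrics/utils.py)
# and its exported names are assumptions based on the file list, not verified here.
try:
    # assumed layout in 1.2.1b20250305 and later
    from autogluon.multimodal.optim.metrics.utils import compute_score, infer_metrics
except ImportError:
    # layout in 1.2.1b20250303 and earlier
    from autogluon.multimodal.utils.metric import compute_score, infer_metrics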
@@ -1,500 +0,0 @@
-import logging
-import math
-import operator
-import warnings
-from typing import Dict, List, Optional, Tuple, Union
-
-import evaluate
-import numpy as np
-from sklearn.metrics import f1_score
-
-from autogluon.core.metrics import Scorer, compute_metric, get_metric
-
-from ..constants import (
-    ACCURACY,
-    AVERAGE_PRECISION,
-    BINARY,
-    DIRECT_LOSS,
-    F1,
-    FEW_SHOT_CLASSIFICATION,
-    IOU,
-    MAP,
-    MATCHING_METRICS,
-    MATCHING_METRICS_WITHOUT_PROBLEM_TYPE,
-    MAX,
-    METRIC_MODE_MAP,
-    MIN,
-    MULTICLASS,
-    NDCG,
-    NER,
-    NER_TOKEN_F1,
-    NUMERICAL,
-    OBJECT_DETECTION,
-    OVERALL_ACCURACY,
-    OVERALL_F1,
-    PRECISION,
-    RECALL,
-    REGRESSION,
-    RETRIEVAL_METRICS,
-    RMSE,
-    ROC_AUC,
-    SEMANTIC_SEGMENTATION,
-    SPEARMANR,
-    Y_PRED,
-    Y_PRED_PROB,
-    Y_TRUE,
-)
-from ..problem_types import PROBLEM_TYPES_REG
-
-logger = logging.getLogger(__name__)
-
-
-def infer_metrics(
-    problem_type: Optional[str] = None,
-    eval_metric: Optional[Union[str, Scorer]] = None,
-    validation_metric_name: Optional[str] = None,
-    is_matching: Optional[bool] = False,
-):
-    """
-    Infer the validation metric and the evaluation metric if not provided.
-    Validation metric is for early-stopping and selecting the best model checkpoints.
-    Evaluation metric is to report performance to users.
-
-    Parameters
-    ----------
-    problem_type
-        Type of problem.
-    eval_metric_name
-        Name of evaluation metric provided by users.
-    validation_metric_name
-        The provided validation metric name
-    is_matching
-        Whether is matching.
-
-    Returns
-    -------
-    validation_metric_name
-        Name of validation metric.
-    eval_metric_name
-        Name of evaluation metric.
-    """
-    is_customized = False
-    if eval_metric is None:
-        eval_metric_name = None
-    elif isinstance(eval_metric, str):
-        eval_metric_name = eval_metric
-    elif isinstance(eval_metric, Scorer):
-        eval_metric_name = eval_metric.name
-        is_customized = True
-    else:
-        raise TypeError(f"eval_metric can be a str, a Scorer, or None, but is type: {type(eval_metric)}")
-
-    if problem_type is not None:
-        problem_property = PROBLEM_TYPES_REG.get(problem_type)
-
-    if is_matching:
-        if eval_metric_name is not None:
-            # if eval_metric_name is a valid metric
-            if eval_metric_name.lower() in METRIC_MODE_MAP.keys():
-                validation_metric_name = eval_metric_name
-                return validation_metric_name, eval_metric_name
-            elif eval_metric_name.lower() in RETRIEVAL_METRICS:
-                # Currently only support recall as validation metric in retrieval tasks.
-                validation_metric_name = RECALL
-                return validation_metric_name, eval_metric_name
-
-        # When eval_metric_name is either None or not supported:
-        # Fallback based on problem type unless it's a customized metric
-        if problem_type is None:
-            validation_metric_name, fallback_evaluation_metric = MATCHING_METRICS_WITHOUT_PROBLEM_TYPE
-        elif problem_type in MATCHING_METRICS:
-            validation_metric_name, fallback_evaluation_metric = MATCHING_METRICS[problem_type]
-        else:
-            raise NotImplementedError(f"Problem type: {problem_type} is not yet supported for matching!")
-        if not is_customized:
-            if eval_metric_name is not None:
-                warnings.warn(
-                    f"Metric {eval_metric_name} is not supported as the evaluation metric for {problem_type} in matching tasks."
-                    f"The evaluation metric is changed to {fallback_evaluation_metric} by default."
-                )
-            eval_metric_name = fallback_evaluation_metric
-        return validation_metric_name, eval_metric_name
-
-    if eval_metric_name is not None:
-        # Infer evaluation metric
-        if eval_metric_name.lower() not in problem_property.supported_evaluation_metrics and not is_customized:
-            warnings.warn(
-                f"Metric {eval_metric_name} is not supported as the evaluation metric for {problem_type}. "
-                f"The evaluation metric is changed to {problem_property.fallback_evaluation_metric} by default."
-            )
-            if problem_property.fallback_evaluation_metric is not None:
-                eval_metric_name = problem_property.fallback_evaluation_metric
-            else:
-                # Problem types like extract_embedding does not need a eval/val metric
-                return None, None
-
-        # Infer validation metric
-        if eval_metric_name.lower() in problem_property.supported_validation_metrics:
-            validation_metric_name = eval_metric_name
-        else:
-            if problem_property.fallback_validation_metric is not None:
-                validation_metric_name = problem_property.fallback_validation_metric
-    else:
-        eval_metric_name = problem_property.fallback_evaluation_metric
-        validation_metric_name = problem_property.fallback_validation_metric
-
-    return validation_metric_name, eval_metric_name
-
-
-def get_minmax_mode(
-    metric_name: Union[str, Scorer],
-):
-    """
-    Get minmax mode based on metric name
-
-    Parameters
-    ----------
-    metric_name
-        A string representing metric
-
-    Returns
-    -------
-    mode
-        The min/max mode used in selecting model checkpoints.
-        - min
-            Its means that smaller metric is better.
-        - max
-            It means that larger metric is better.
-    """
-    if isinstance(metric_name, str):
-        assert (
-            metric_name in METRIC_MODE_MAP
-        ), f"{metric_name} is not a supported metric. Options are: {METRIC_MODE_MAP.keys()}"
-        return METRIC_MODE_MAP.get(metric_name)
-    else:
-        return MAX if metric_name.greater_is_better else MIN
-
-
-def get_stopping_threshold(metric_name: str):
-    """
-    Get the metric threshold for early stopping.
-
-    Parameters
-    ----------
-    metric_name
-        Name of validation metric.
-
-    Returns
-    -------
-    The stopping threshold.
-    """
-    try:
-        metric = get_metric(metric_name)
-        stopping_threshold = metric.optimum - metric._sign * 1e-7
-    except:
-        stopping_threshold = None
-
-    return stopping_threshold
-
-
-def compute_score(
-    metric_data: dict,
-    metric: Union[str, Scorer],
-    pos_label: Optional[int] = 1,
-) -> float:
-    """
-    Use sklearn to compute the score of one metric.
-
-    Parameters
-    ----------
-    metric_data
-        A dictionary with the groundtruth (Y_TRUE) and predicted values (Y_PRED, Y_PRED_PROB).
-        The predicted class probabilities are required to compute the roc_auc score.
-    metric
-        The name of metric or the function of metric to compute.
-    pos_label
-        The encoded label (0 or 1) of binary classification's positive class.
-
-    Returns
-    -------
-    Computed score.
-    """
-    if isinstance(metric, str) and metric in [OVERALL_ACCURACY, OVERALL_F1]:
-        metric = evaluate.load("seqeval")
-        warnings.filterwarnings("ignore")
-        for p in metric_data[Y_TRUE]:
-            if "_" in p:
-                print(p)
-        for p in metric_data[Y_PRED]:
-            if "_" in p:
-                print(p)
-        return metric.compute(references=metric_data[Y_TRUE], predictions=metric_data[Y_PRED])
-
-    metric = get_metric(metric)
-
-    y = metric_data[Y_TRUE]
-    if metric.needs_proba or metric.needs_threshold:
-        y_pred_proba = metric_data[Y_PRED_PROB]
-        y_pred_proba = (
-            y_pred_proba if y_pred_proba.shape[1] > 2 else y_pred_proba[:, pos_label]
-        )  # only use pos_label for binary classification
-        return metric.convert_score_to_original(
-            compute_metric(y=y, y_pred_proba=y_pred_proba, metric=metric, weights=None)
-        )
-    else:
-        y_pred = metric_data[Y_PRED]
-
-        # TODO: This is a hack. Doesn't support `f1_macro`, `f1_micro`, `f1_weighted`, or custom `f1` metrics with different names.
-        # TODO: Longterm the solution should be to have the input data to this function use the internal representation without the original class names. This way `pos_label` would not need to be specified.
-        if metric.name == F1:  # only for binary classification
-            y = (y == pos_label).astype(int)
-            y_pred = (y_pred == pos_label).astype(int)
-
-        return metric.convert_score_to_original(compute_metric(y=y, y_pred=y_pred, metric=metric, weights=None))
-
-
-class RankingMetrics:
-    def __init__(
-        self,
-        pred: Dict[str, Dict],
-        target: Dict[str, Dict],
-        is_higher_better=True,
-    ):
-        """
-        Evaluation Metrics for information retrieval tasks such as document retrieval, image retrieval, etc.
-        Reference: https://www.cs.cornell.edu/courses/cs4300/2013fa/lectures/metrics-2-4pp.pdf
-
-        Parameters
-        ----------
-        pred:
-            the prediction of the ranking model. It has the following form.
-            pred = {
-                'q1': {
-                    'd1': 1,
-                    'd3': 0,
-                },
-                'q2': {
-                    'd2': 1,
-                    'd3': 1,
-                },
-            }
-            where q refers to queries, and d refers to documents, each query has a few relevant documents.
-            0s and 1s are model predicted scores (does not need to be binary).
-        target:
-            the ground truth query and response relevance which has the same form as pred.
-        is_higher_better:
-            if higher relevance score means higher ranking.
-            if the relevance score is cosine similarity / dot product, it should be set to True;
-            if it is Eulidean distance, it should be False.
-        """
-        self.pred = pred
-        self.target = target
-        self.is_higher_better = is_higher_better
-        # the supported metrics in this script
-        self.supported_metrics = {
-            "precision": 0,
-            "recall": 1,
-            "mrr": 2,
-            "map": 3,
-            "ndcg": 4,
-        }
-
-        assert len(pred) == len(
-            target
-        ), f"The prediction and groudtruth target should have the same number of queries, \
-            while there are {len(pred)} queries in prediction and {len(target)} in the target."
-
-        self.results = {}
-        for key in target.keys():
-            self.results.update({key: [target[key], pred[key]]})
-
-    def compute(self, metrics: Union[str, list] = None, k: Optional[int] = 10):
-        """
-        compute and return ranking scores.
-
-        Parameters
-        ----------
-        metrics:
-            user provided metrics
-        k:
-            the cutoff value for NDCG, MAP, Recall, MRR, and Precision
-
-        Returns
-        -------
-        Computed score.
-
-        """
-        if isinstance(metrics, str):
-            metrics = [metrics]
-        if not metrics:  # no metric is provided
-            metrics = self.supported_metrics.keys()
-
-        return_res = {}
-
-        eval_res = np.mean(
-            [list(self._compute_one(idx, k)) for idx in self.results.keys()],
-            axis=0,
-        )
-
-        for metric in metrics:
-            metric = metric.lower()
-            if metric in self.supported_metrics:
-                return_res.update({f"{metric}@{k}": eval_res[self.supported_metrics[metric]]})
-
-        return return_res
-
-    def _compute_one(self, idx, k):
-        """
-        compute and return the ranking scores for one query.
-        for definition of these metrics, please refer to
-        https://www.cs.cornell.edu/courses/cs4300/2013fa/lectures/metrics-2-4pp.pdf
-
-        Parameters
-        ----------
-        idx:
-            the index of the query
-        k:
-            the cutoff value for NDCG, MAP, Recall, MRR, and Precision
-
-        Returns
-        -------
-        Computed score.
-        """
-        precision, recall, mrr, mAP, ndcg = 0, 0, 0, 0, 0
-        target, pred = self.results[idx][0], self.results[idx][1]
-
-        # sort the ground truth and predictions in descending order
-        sorted_target = dict(
-            sorted(
-                target.items(),
-                key=operator.itemgetter(1),
-                reverse=self.is_higher_better,
-            )
-        )
-        sorted_pred = dict(
-            sorted(
-                pred.items(),
-                key=operator.itemgetter(1),
-                reverse=self.is_higher_better,
-            )
-        )
-        sorted_target_values = list(sorted_target.values())
-        sorted_pred_values = list(sorted_pred.values())
-
-        # number of positive relevance in target
-        # negative numbers and zero are considered as negative response
-        num_pos_target = len([val for val in sorted_target_values if val > 0])
-
-        at_k = k if num_pos_target > k else num_pos_target
-
-        first_k_items_list = list(sorted_pred.items())[0:k]
-
-        rank = 0
-        hit_rank = []  # correctly retrieved items
-        for key, value in first_k_items_list:
-            if key in sorted_target and sorted_target[key] > 0:
-                hit_rank.append(rank)
-            rank += 1
-        count = len(hit_rank)
-        # compute the precision and recall
-        precision = count / k
-        recall = count / num_pos_target
-
-        dcg = 0
-        if hit_rank:  # not empty
-            # compute the mean reciprocal rank
-            mrr = 1 / (hit_rank[0] + 1)
-            # compute the mean average precision
-            mAP = np.sum([sorted_pred_values[rank] * (i + 1) / (rank + 1) for i, rank in enumerate(hit_rank)])
-            # compute the discounted cumulative gain
-            dcg = np.sum([1 / math.log(rank + 2, 2) for rank in hit_rank])
-
-        # compute the ideal discounted cumulative gain
-        idcg = np.sum([1 / math.log(i + 2, 2) for i in range(at_k)])
-        # compute the normalized discounted cumulative gain
-        ndcg = dcg / idcg
-        mAP /= at_k
-
-        return precision, recall, mrr, mAP, ndcg
-
-
-def compute_ranking_score(
-    results: Dict[str, Dict],
-    qrel_dict: Dict[str, Dict],
-    metrics: List[str],
-    cutoffs: Optional[List[int]] = [5, 10, 20],
-):
-    """
-    Compute the ranking metrics, e.g., NDCG, MAP, Recall, and Precision.
-    TODO: Consider MRR.
-
-    Parameters
-    ----------
-    results:
-        The query/document ranking list by the model.
-    qrel_dict:
-        The groundtruth query and document relevance.
-    metrics
-        A list of metrics to compute.
-    cutoffs:
-        The cutoff values for NDCG, MAP, Recall, and Precision.
-
-    Returns
-    -------
-    A dict of metric scores.
-    """
-    scores = {}
-    evaluator = RankingMetrics(pred=results, target=qrel_dict)
-    for k in cutoffs:
-        scores.update(evaluator.compute(k=k))
-
-    metric_results = dict()
-    for k in cutoffs:
-        for per_metric in metrics:
-            if per_metric.lower() == NDCG:
-                metric_results[f"{NDCG}@{k}"] = 0.0
-            elif per_metric.lower() == MAP:
-                metric_results[f"{MAP}@{k}"] = 0.0
-            elif per_metric.lower() == RECALL:
-                metric_results[f"{RECALL}@{k}"] = 0.0
-            elif per_metric.lower() == PRECISION:
-                metric_results[f"{PRECISION}@{k}"] = 0.0
-
-    for k in cutoffs:
-        for per_metric in metrics:
-            if per_metric.lower() == NDCG:
-                metric_results[f"{NDCG}@{k}"] = round(scores[f"{NDCG}@{k}"], 5)
-            elif per_metric.lower() == MAP:
-                metric_results[f"{MAP}@{k}"] = round(scores[f"{MAP}@{k}"], 5)
-            elif per_metric.lower() == RECALL:
-                metric_results[f"{RECALL}@{k}"] = round(scores[f"{RECALL}@{k}"], 5)
-            elif per_metric.lower() == PRECISION:
-                metric_results[f"{PRECISION}@{k}"] = round(scores[f"{PRECISION}@{k}"], 5)
-
-    return metric_results
-
-
-def infer_problem_type_by_eval_metric(eval_metric_name: str, problem_type: str):
-    if eval_metric_name is not None and eval_metric_name.lower() in [
-        "rmse",
-        "r2",
-        "pearsonr",
-        "spearmanr",
-    ]:
-        if problem_type is None:
-            logger.debug(
-                f"Infer problem type to be a regression problem "
-                f"since the evaluation metric is set as {eval_metric_name}."
-            )
-            problem_type = REGRESSION
-        else:
-            problem_prop = PROBLEM_TYPES_REG.get(problem_type)
-            if NUMERICAL not in problem_prop.supported_label_type:
-                raise ValueError(
-                    f"The provided evaluation metric will require the problem "
-                    f"to support label type = {NUMERICAL}. However, "
-                    f"the provided problem type = {problem_type} only "
-                    f"supports label type = {problem_prop.supported_label_type}."
-                )
-
-    return problem_type
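The deleted utils/metric.py above combined metric inference, score computation, and the information-retrieval ranking metrics (precision@k, recall@k, MRR, MAP, NDCG). As a reference for code migrating off the old API, here is a minimal sketch of how its compute_ranking_score helper was called, reconstructed from the docstrings and signatures in the removed code; the import path only exists in 1.2.1b20250303 and earlier, and the plain-string metric names assume the NDCG and RECALL constants resolve to "ndcg" and "recall".

# Sketch of the removed API; works only against the old 1.2.1b20250303 layout.
from autogluon.multimodal.utils.metric import compute_ranking_score

# Model scores per query/document pair, in the nested-dict form described
# in the RankingMetrics docstring above.
predictions = {
    "q1": {"d1": 0.9, "d2": 0.2, "d3": 0.4},
    "q2": {"d1": 0.1, "d2": 0.8, "d3": 0.7},
}
# Ground-truth relevance in the same form (0 or negative means irrelevant).
relevance = {
    "q1": {"d1": 1, "d2": 0, "d3": 0},
    "q2": {"d1": 0, "d2": 1, "d3": 1},
}

scores = compute_ranking_score(
    results=predictions,
    qrel_dict=relevance,
    metrics=["ndcg", "recall"],  # assumes NDCG == "ndcg" and RECALL == "recall"
    cutoffs=[5, 10],
)
# Returns a dict keyed like "ndcg@5", "recall@5", "ndcg@10", "recall@10".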