autogluon.multimodal 1.2.1b20250303__py3-none-any.whl → 1.2.1b20250304__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- autogluon/multimodal/__init__.py +4 -2
- autogluon/multimodal/configs/data/default.yaml +4 -2
- autogluon/multimodal/configs/{environment → env}/default.yaml +2 -3
- autogluon/multimodal/configs/model/default.yaml +58 -11
- autogluon/multimodal/configs/{optimization → optim}/default.yaml +21 -4
- autogluon/multimodal/constants.py +16 -5
- autogluon/multimodal/data/__init__.py +14 -2
- autogluon/multimodal/data/dataset.py +2 -2
- autogluon/multimodal/data/infer_types.py +16 -2
- autogluon/multimodal/data/label_encoder.py +3 -3
- autogluon/multimodal/{utils → data}/nlpaug.py +4 -4
- autogluon/multimodal/data/preprocess_dataframe.py +55 -38
- autogluon/multimodal/data/process_categorical.py +35 -6
- autogluon/multimodal/data/process_document.py +59 -33
- autogluon/multimodal/data/process_image.py +198 -163
- autogluon/multimodal/data/process_label.py +7 -3
- autogluon/multimodal/data/process_mmlab/process_mmdet.py +1 -8
- autogluon/multimodal/data/process_mmlab/process_mmlab_base.py +2 -9
- autogluon/multimodal/data/process_mmlab/process_mmocr.py +1 -9
- autogluon/multimodal/data/process_ner.py +192 -4
- autogluon/multimodal/data/process_numerical.py +32 -5
- autogluon/multimodal/data/process_semantic_seg_img.py +23 -28
- autogluon/multimodal/data/process_text.py +95 -58
- autogluon/multimodal/data/template_engine.py +7 -9
- autogluon/multimodal/data/templates.py +0 -2
- autogluon/multimodal/data/trivial_augmenter.py +2 -2
- autogluon/multimodal/data/utils.py +564 -338
- autogluon/multimodal/learners/__init__.py +2 -1
- autogluon/multimodal/learners/base.py +189 -189
- autogluon/multimodal/learners/ensemble.py +748 -0
- autogluon/multimodal/learners/few_shot_svm.py +6 -15
- autogluon/multimodal/learners/matching.py +59 -84
- autogluon/multimodal/learners/ner.py +23 -22
- autogluon/multimodal/learners/object_detection.py +26 -21
- autogluon/multimodal/learners/semantic_segmentation.py +16 -18
- autogluon/multimodal/models/__init__.py +12 -3
- autogluon/multimodal/models/augmenter.py +175 -0
- autogluon/multimodal/models/categorical_mlp.py +13 -8
- autogluon/multimodal/models/clip.py +92 -18
- autogluon/multimodal/models/custom_transformer.py +75 -75
- autogluon/multimodal/models/document_transformer.py +23 -9
- autogluon/multimodal/models/ft_transformer.py +40 -35
- autogluon/multimodal/models/fusion/base.py +2 -4
- autogluon/multimodal/models/fusion/fusion_mlp.py +82 -18
- autogluon/multimodal/models/fusion/fusion_ner.py +1 -1
- autogluon/multimodal/models/fusion/fusion_transformer.py +23 -23
- autogluon/multimodal/models/{huggingface_text.py → hf_text.py} +21 -2
- autogluon/multimodal/models/meta_transformer.py +336 -0
- autogluon/multimodal/models/mlp.py +6 -6
- autogluon/multimodal/models/mmocr_text_detection.py +1 -1
- autogluon/multimodal/models/mmocr_text_recognition.py +0 -1
- autogluon/multimodal/models/ner_text.py +1 -8
- autogluon/multimodal/models/numerical_mlp.py +14 -8
- autogluon/multimodal/models/sam.py +12 -2
- autogluon/multimodal/models/t_few.py +21 -5
- autogluon/multimodal/models/timm_image.py +74 -32
- autogluon/multimodal/models/utils.py +877 -16
- autogluon/multimodal/optim/__init__.py +17 -0
- autogluon/multimodal/{optimization → optim}/lit_distiller.py +2 -1
- autogluon/multimodal/{optimization → optim}/lit_matcher.py +4 -10
- autogluon/multimodal/{optimization → optim}/lit_mmdet.py +2 -10
- autogluon/multimodal/{optimization → optim}/lit_module.py +139 -14
- autogluon/multimodal/{optimization → optim}/lit_ner.py +3 -3
- autogluon/multimodal/{optimization → optim}/lit_semantic_seg.py +1 -1
- autogluon/multimodal/optim/losses/__init__.py +14 -0
- autogluon/multimodal/optim/losses/bce_loss.py +25 -0
- autogluon/multimodal/optim/losses/focal_loss.py +81 -0
- autogluon/multimodal/optim/losses/lemda_loss.py +39 -0
- autogluon/multimodal/optim/losses/rkd_loss.py +103 -0
- autogluon/multimodal/optim/losses/softmax_losses.py +177 -0
- autogluon/multimodal/optim/losses/structure_loss.py +26 -0
- autogluon/multimodal/optim/losses/utils.py +313 -0
- autogluon/multimodal/optim/lr/__init__.py +1 -0
- autogluon/multimodal/optim/lr/utils.py +332 -0
- autogluon/multimodal/optim/metrics/__init__.py +4 -0
- autogluon/multimodal/optim/metrics/coverage_metrics.py +42 -0
- autogluon/multimodal/optim/metrics/hit_rate_metrics.py +78 -0
- autogluon/multimodal/optim/metrics/ranking_metrics.py +231 -0
- autogluon/multimodal/optim/metrics/utils.py +359 -0
- autogluon/multimodal/optim/utils.py +284 -0
- autogluon/multimodal/predictor.py +51 -12
- autogluon/multimodal/utils/__init__.py +19 -45
- autogluon/multimodal/utils/cache.py +23 -2
- autogluon/multimodal/utils/checkpoint.py +58 -5
- autogluon/multimodal/utils/config.py +127 -55
- autogluon/multimodal/utils/device.py +120 -0
- autogluon/multimodal/utils/distillation.py +8 -8
- autogluon/multimodal/utils/download.py +1 -1
- autogluon/multimodal/utils/env.py +22 -0
- autogluon/multimodal/utils/export.py +3 -3
- autogluon/multimodal/utils/hpo.py +5 -5
- autogluon/multimodal/utils/inference.py +37 -4
- autogluon/multimodal/utils/install.py +91 -0
- autogluon/multimodal/utils/load.py +52 -47
- autogluon/multimodal/utils/log.py +6 -41
- autogluon/multimodal/utils/matcher.py +3 -2
- autogluon/multimodal/utils/onnx.py +0 -4
- autogluon/multimodal/utils/path.py +10 -0
- autogluon/multimodal/utils/precision.py +130 -0
- autogluon/multimodal/{presets.py → utils/presets.py} +259 -66
- autogluon/multimodal/{problem_types.py → utils/problem_types.py} +30 -1
- autogluon/multimodal/utils/save.py +47 -29
- autogluon/multimodal/utils/strategy.py +24 -0
- autogluon/multimodal/version.py +1 -1
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/METADATA +5 -5
- autogluon.multimodal-1.2.1b20250304.dist-info/RECORD +163 -0
- autogluon/multimodal/optimization/__init__.py +0 -16
- autogluon/multimodal/optimization/losses.py +0 -394
- autogluon/multimodal/optimization/utils.py +0 -1054
- autogluon/multimodal/utils/cloud_io.py +0 -80
- autogluon/multimodal/utils/data.py +0 -701
- autogluon/multimodal/utils/environment.py +0 -395
- autogluon/multimodal/utils/metric.py +0 -500
- autogluon/multimodal/utils/model.py +0 -558
- autogluon.multimodal-1.2.1b20250303.dist-info/RECORD +0 -145
- /autogluon/multimodal/{optimization → optim}/deepspeed.py +0 -0
- /autogluon/multimodal/{optimization/lr_scheduler.py → optim/lr/lr_schedulers.py} +0 -0
- /autogluon/multimodal/{optimization → optim/metrics}/semantic_seg_metrics.py +0 -0
- /autogluon/multimodal/{registry.py → utils/registry.py} +0 -0
- /autogluon.multimodal-1.2.1b20250303-py3.9-nspkg.pth → /autogluon.multimodal-1.2.1b20250304-py3.9-nspkg.pth +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/LICENSE +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/NOTICE +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/WHEEL +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/namespace_packages.txt +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/top_level.txt +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/zip-safe +0 -0
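The listing above renames the internal `optimization` package to `optim` (and `environment` to `env`), splits losses, LR scheduling, and metrics into sub-packages, and deletes several `utils` modules such as `utils/metric.py`. A minimal sketch of how imports of these internal modules would change between the two wheels, using only module paths taken from the old and new file lists; whether the old `utils/metric.py` functionality landed in `optim/metrics/utils.py` is an assumption, and the exact public symbols are not visible in this diff:

```python
# Old layout (1.2.1b20250303), per the left-hand paths above:
# from autogluon.multimodal.optimization import lit_module
# from autogluon.multimodal.optimization import lr_scheduler
# from autogluon.multimodal.utils import metric          # deleted in 1.2.1b20250304

# New layout (1.2.1b20250304), per the right-hand paths above:
from autogluon.multimodal.optim import lit_module         # optimization -> optim
from autogluon.multimodal.optim.lr import lr_schedulers   # optimization/lr_scheduler.py -> optim/lr/lr_schedulers.py
from autogluon.multimodal.optim.metrics import utils      # assumed new home of the deleted utils/metric.py helpers
```

These are internal modules rather than documented public API, so any code importing them directly should expect this kind of breakage between nightly builds.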
--- autogluon/multimodal/utils/metric.py
+++ /dev/null
@@ -1,500 +0,0 @@
-import logging
-import math
-import operator
-import warnings
-from typing import Dict, List, Optional, Tuple, Union
-
-import evaluate
-import numpy as np
-from sklearn.metrics import f1_score
-
-from autogluon.core.metrics import Scorer, compute_metric, get_metric
-
-from ..constants import (
-    ACCURACY,
-    AVERAGE_PRECISION,
-    BINARY,
-    DIRECT_LOSS,
-    F1,
-    FEW_SHOT_CLASSIFICATION,
-    IOU,
-    MAP,
-    MATCHING_METRICS,
-    MATCHING_METRICS_WITHOUT_PROBLEM_TYPE,
-    MAX,
-    METRIC_MODE_MAP,
-    MIN,
-    MULTICLASS,
-    NDCG,
-    NER,
-    NER_TOKEN_F1,
-    NUMERICAL,
-    OBJECT_DETECTION,
-    OVERALL_ACCURACY,
-    OVERALL_F1,
-    PRECISION,
-    RECALL,
-    REGRESSION,
-    RETRIEVAL_METRICS,
-    RMSE,
-    ROC_AUC,
-    SEMANTIC_SEGMENTATION,
-    SPEARMANR,
-    Y_PRED,
-    Y_PRED_PROB,
-    Y_TRUE,
-)
-from ..problem_types import PROBLEM_TYPES_REG
-
-logger = logging.getLogger(__name__)
-
-
-def infer_metrics(
-    problem_type: Optional[str] = None,
-    eval_metric: Optional[Union[str, Scorer]] = None,
-    validation_metric_name: Optional[str] = None,
-    is_matching: Optional[bool] = False,
-):
-    """
-    Infer the validation metric and the evaluation metric if not provided.
-    Validation metric is for early-stopping and selecting the best model checkpoints.
-    Evaluation metric is to report performance to users.
-
-    Parameters
-    ----------
-    problem_type
-        Type of problem.
-    eval_metric_name
-        Name of evaluation metric provided by users.
-    validation_metric_name
-        The provided validation metric name
-    is_matching
-        Whether is matching.
-
-    Returns
-    -------
-    validation_metric_name
-        Name of validation metric.
-    eval_metric_name
-        Name of evaluation metric.
-    """
-    is_customized = False
-    if eval_metric is None:
-        eval_metric_name = None
-    elif isinstance(eval_metric, str):
-        eval_metric_name = eval_metric
-    elif isinstance(eval_metric, Scorer):
-        eval_metric_name = eval_metric.name
-        is_customized = True
-    else:
-        raise TypeError(f"eval_metric can be a str, a Scorer, or None, but is type: {type(eval_metric)}")
-
-    if problem_type is not None:
-        problem_property = PROBLEM_TYPES_REG.get(problem_type)
-
-    if is_matching:
-        if eval_metric_name is not None:
-            # if eval_metric_name is a valid metric
-            if eval_metric_name.lower() in METRIC_MODE_MAP.keys():
-                validation_metric_name = eval_metric_name
-                return validation_metric_name, eval_metric_name
-            elif eval_metric_name.lower() in RETRIEVAL_METRICS:
-                # Currently only support recall as validation metric in retrieval tasks.
-                validation_metric_name = RECALL
-                return validation_metric_name, eval_metric_name
-
-        # When eval_metric_name is either None or not supported:
-        # Fallback based on problem type unless it's a customized metric
-        if problem_type is None:
-            validation_metric_name, fallback_evaluation_metric = MATCHING_METRICS_WITHOUT_PROBLEM_TYPE
-        elif problem_type in MATCHING_METRICS:
-            validation_metric_name, fallback_evaluation_metric = MATCHING_METRICS[problem_type]
-        else:
-            raise NotImplementedError(f"Problem type: {problem_type} is not yet supported for matching!")
-        if not is_customized:
-            if eval_metric_name is not None:
-                warnings.warn(
-                    f"Metric {eval_metric_name} is not supported as the evaluation metric for {problem_type} in matching tasks."
-                    f"The evaluation metric is changed to {fallback_evaluation_metric} by default."
-                )
-            eval_metric_name = fallback_evaluation_metric
-        return validation_metric_name, eval_metric_name
-
-    if eval_metric_name is not None:
-        # Infer evaluation metric
-        if eval_metric_name.lower() not in problem_property.supported_evaluation_metrics and not is_customized:
-            warnings.warn(
-                f"Metric {eval_metric_name} is not supported as the evaluation metric for {problem_type}. "
-                f"The evaluation metric is changed to {problem_property.fallback_evaluation_metric} by default."
-            )
-            if problem_property.fallback_evaluation_metric is not None:
-                eval_metric_name = problem_property.fallback_evaluation_metric
-            else:
-                # Problem types like extract_embedding does not need a eval/val metric
-                return None, None
-
-        # Infer validation metric
-        if eval_metric_name.lower() in problem_property.supported_validation_metrics:
-            validation_metric_name = eval_metric_name
-        else:
-            if problem_property.fallback_validation_metric is not None:
-                validation_metric_name = problem_property.fallback_validation_metric
-    else:
-        eval_metric_name = problem_property.fallback_evaluation_metric
-        validation_metric_name = problem_property.fallback_validation_metric
-
-    return validation_metric_name, eval_metric_name
-
-
-def get_minmax_mode(
-    metric_name: Union[str, Scorer],
-):
-    """
-    Get minmax mode based on metric name
-
-    Parameters
-    ----------
-    metric_name
-        A string representing metric
-
-    Returns
-    -------
-    mode
-        The min/max mode used in selecting model checkpoints.
-        - min
-             Its means that smaller metric is better.
-        - max
-            It means that larger metric is better.
-    """
-    if isinstance(metric_name, str):
-        assert (
-            metric_name in METRIC_MODE_MAP
-        ), f"{metric_name} is not a supported metric. Options are: {METRIC_MODE_MAP.keys()}"
-        return METRIC_MODE_MAP.get(metric_name)
-    else:
-        return MAX if metric_name.greater_is_better else MIN
-
-
-def get_stopping_threshold(metric_name: str):
-    """
-    Get the metric threshold for early stopping.
-
-    Parameters
-    ----------
-    metric_name
-        Name of validation metric.
-
-    Returns
-    -------
-    The stopping threshold.
-    """
-    try:
-        metric = get_metric(metric_name)
-        stopping_threshold = metric.optimum - metric._sign * 1e-7
-    except:
-        stopping_threshold = None
-
-    return stopping_threshold
-
-
-def compute_score(
-    metric_data: dict,
-    metric: Union[str, Scorer],
-    pos_label: Optional[int] = 1,
-) -> float:
-    """
-    Use sklearn to compute the score of one metric.
-
-    Parameters
-    ----------
-    metric_data
-        A dictionary with the groundtruth (Y_TRUE) and predicted values (Y_PRED, Y_PRED_PROB).
-        The predicted class probabilities are required to compute the roc_auc score.
-    metric
-        The name of metric or the function of metric to compute.
-    pos_label
-        The encoded label (0 or 1) of binary classification's positive class.
-
-    Returns
-    -------
-    Computed score.
-    """
-    if isinstance(metric, str) and metric in [OVERALL_ACCURACY, OVERALL_F1]:
-        metric = evaluate.load("seqeval")
-        warnings.filterwarnings("ignore")
-        for p in metric_data[Y_TRUE]:
-            if "_" in p:
-                print(p)
-        for p in metric_data[Y_PRED]:
-            if "_" in p:
-                print(p)
-        return metric.compute(references=metric_data[Y_TRUE], predictions=metric_data[Y_PRED])
-
-    metric = get_metric(metric)
-
-    y = metric_data[Y_TRUE]
-    if metric.needs_proba or metric.needs_threshold:
-        y_pred_proba = metric_data[Y_PRED_PROB]
-        y_pred_proba = (
-            y_pred_proba if y_pred_proba.shape[1] > 2 else y_pred_proba[:, pos_label]
-        )  # only use pos_label for binary classification
-        return metric.convert_score_to_original(
-            compute_metric(y=y, y_pred_proba=y_pred_proba, metric=metric, weights=None)
-        )
-    else:
-        y_pred = metric_data[Y_PRED]
-
-        # TODO: This is a hack. Doesn't support `f1_macro`, `f1_micro`, `f1_weighted`, or custom `f1` metrics with different names.
-        # TODO: Longterm the solution should be to have the input data to this function use the internal representation without the original class names. This way `pos_label` would not need to be specified.
-        if metric.name == F1:  # only for binary classification
-            y = (y == pos_label).astype(int)
-            y_pred = (y_pred == pos_label).astype(int)
-
-        return metric.convert_score_to_original(compute_metric(y=y, y_pred=y_pred, metric=metric, weights=None))
-
-
-class RankingMetrics:
-    def __init__(
-        self,
-        pred: Dict[str, Dict],
-        target: Dict[str, Dict],
-        is_higher_better=True,
-    ):
-        """
-        Evaluation Metrics for information retrieval tasks such as document retrieval, image retrieval, etc.
-        Reference: https://www.cs.cornell.edu/courses/cs4300/2013fa/lectures/metrics-2-4pp.pdf
-
-        Parameters
-        ----------
-        pred:
-            the prediction of the ranking model. It has the following form.
-            pred = {
-                'q1': {
-                    'd1': 1,
-                    'd3': 0,
-                },
-                'q2': {
-                    'd2': 1,
-                    'd3': 1,
-                },
-            }
-            where q refers to queries, and d refers to documents, each query has a few relevant documents.
-            0s and 1s are model predicted scores (does not need to be binary).
-        target:
-            the ground truth query and response relevance which has the same form as pred.
-        is_higher_better:
-            if higher relevance score means higher ranking.
-            if the relevance score is cosine similarity / dot product, it should be set to True;
-            if it is Eulidean distance, it should be False.
-        """
-        self.pred = pred
-        self.target = target
-        self.is_higher_better = is_higher_better
-        # the supported metrics in this script
-        self.supported_metrics = {
-            "precision": 0,
-            "recall": 1,
-            "mrr": 2,
-            "map": 3,
-            "ndcg": 4,
-        }
-
-        assert len(pred) == len(
-            target
-        ), f"The prediction and groudtruth target should have the same number of queries, \
-            while there are {len(pred)} queries in prediction and {len(target)} in the target."
-
-        self.results = {}
-        for key in target.keys():
-            self.results.update({key: [target[key], pred[key]]})
-
-    def compute(self, metrics: Union[str, list] = None, k: Optional[int] = 10):
-        """
-        compute and return ranking scores.
-
-        Parameters
-        ----------
-        metrics:
-            user provided metrics
-        k:
-            the cutoff value for NDCG, MAP, Recall, MRR, and Precision
-
-        Returns
-        -------
-        Computed score.
-
-        """
-        if isinstance(metrics, str):
-            metrics = [metrics]
-        if not metrics:  # no metric is provided
-            metrics = self.supported_metrics.keys()
-
-        return_res = {}
-
-        eval_res = np.mean(
-            [list(self._compute_one(idx, k)) for idx in self.results.keys()],
-            axis=0,
-        )
-
-        for metric in metrics:
-            metric = metric.lower()
-            if metric in self.supported_metrics:
-                return_res.update({f"{metric}@{k}": eval_res[self.supported_metrics[metric]]})
-
-        return return_res
-
-    def _compute_one(self, idx, k):
-        """
-        compute and return the ranking scores for one query.
-        for definition of these metrics, please refer to
-        https://www.cs.cornell.edu/courses/cs4300/2013fa/lectures/metrics-2-4pp.pdf
-
-        Parameters
-        ----------
-        idx:
-            the index of the query
-        k:
-            the cutoff value for NDCG, MAP, Recall, MRR, and Precision
-
-        Returns
-        -------
-        Computed score.
-        """
-        precision, recall, mrr, mAP, ndcg = 0, 0, 0, 0, 0
-        target, pred = self.results[idx][0], self.results[idx][1]
-
-        # sort the ground truth and predictions in descending order
-        sorted_target = dict(
-            sorted(
-                target.items(),
-                key=operator.itemgetter(1),
-                reverse=self.is_higher_better,
-            )
-        )
-        sorted_pred = dict(
-            sorted(
-                pred.items(),
-                key=operator.itemgetter(1),
-                reverse=self.is_higher_better,
-            )
-        )
-        sorted_target_values = list(sorted_target.values())
-        sorted_pred_values = list(sorted_pred.values())
-
-        # number of positive relevance in target
-        # negative numbers and zero are considered as negative response
-        num_pos_target = len([val for val in sorted_target_values if val > 0])
-
-        at_k = k if num_pos_target > k else num_pos_target
-
-        first_k_items_list = list(sorted_pred.items())[0:k]
-
-        rank = 0
-        hit_rank = []  # correctly retrieved items
-        for key, value in first_k_items_list:
-            if key in sorted_target and sorted_target[key] > 0:
-                hit_rank.append(rank)
-            rank += 1
-        count = len(hit_rank)
-        # compute the precision and recall
-        precision = count / k
-        recall = count / num_pos_target
-
-        dcg = 0
-        if hit_rank:  # not empty
-            # compute the mean reciprocal rank
-            mrr = 1 / (hit_rank[0] + 1)
-            # compute the mean average precision
-            mAP = np.sum([sorted_pred_values[rank] * (i + 1) / (rank + 1) for i, rank in enumerate(hit_rank)])
-            # compute the discounted cumulative gain
-            dcg = np.sum([1 / math.log(rank + 2, 2) for rank in hit_rank])
-
-        # compute the ideal discounted cumulative gain
-        idcg = np.sum([1 / math.log(i + 2, 2) for i in range(at_k)])
-        # compute the normalized discounted cumulative gain
-        ndcg = dcg / idcg
-        mAP /= at_k
-
-        return precision, recall, mrr, mAP, ndcg
-
-
-def compute_ranking_score(
-    results: Dict[str, Dict],
-    qrel_dict: Dict[str, Dict],
-    metrics: List[str],
-    cutoffs: Optional[List[int]] = [5, 10, 20],
-):
-    """
-    Compute the ranking metrics, e.g., NDCG, MAP, Recall, and Precision.
-    TODO: Consider MRR.
-
-    Parameters
-    ----------
-    results:
-        The query/document ranking list by the model.
-    qrel_dict:
-        The groundtruth query and document relevance.
-    metrics
-        A list of metrics to compute.
-    cutoffs:
-        The cutoff values for NDCG, MAP, Recall, and Precision.
-
-    Returns
-    -------
-    A dict of metric scores.
-    """
-    scores = {}
-    evaluator = RankingMetrics(pred=results, target=qrel_dict)
-    for k in cutoffs:
-        scores.update(evaluator.compute(k=k))
-
-    metric_results = dict()
-    for k in cutoffs:
-        for per_metric in metrics:
-            if per_metric.lower() == NDCG:
-                metric_results[f"{NDCG}@{k}"] = 0.0
-            elif per_metric.lower() == MAP:
-                metric_results[f"{MAP}@{k}"] = 0.0
-            elif per_metric.lower() == RECALL:
-                metric_results[f"{RECALL}@{k}"] = 0.0
-            elif per_metric.lower() == PRECISION:
-                metric_results[f"{PRECISION}@{k}"] = 0.0
-
-    for k in cutoffs:
-        for per_metric in metrics:
-            if per_metric.lower() == NDCG:
-                metric_results[f"{NDCG}@{k}"] = round(scores[f"{NDCG}@{k}"], 5)
-            elif per_metric.lower() == MAP:
-                metric_results[f"{MAP}@{k}"] = round(scores[f"{MAP}@{k}"], 5)
-            elif per_metric.lower() == RECALL:
-                metric_results[f"{RECALL}@{k}"] = round(scores[f"{RECALL}@{k}"], 5)
-            elif per_metric.lower() == PRECISION:
-                metric_results[f"{PRECISION}@{k}"] = round(scores[f"{PRECISION}@{k}"], 5)
-
-    return metric_results
-
-
-def infer_problem_type_by_eval_metric(eval_metric_name: str, problem_type: str):
-    if eval_metric_name is not None and eval_metric_name.lower() in [
-        "rmse",
-        "r2",
-        "pearsonr",
-        "spearmanr",
-    ]:
-        if problem_type is None:
-            logger.debug(
-                f"Infer problem type to be a regression problem "
-                f"since the evaluation metric is set as {eval_metric_name}."
-            )
-            problem_type = REGRESSION
-        else:
-            problem_prop = PROBLEM_TYPES_REG.get(problem_type)
-            if NUMERICAL not in problem_prop.supported_label_type:
-                raise ValueError(
-                    f"The provided evaluation metric will require the problem "
-                    f"to support label type = {NUMERICAL}. However, "
-                    f"the provided problem type = {problem_type} only "
-                    f"supports label type = {problem_prop.supported_label_type}."
-                )
-
-    return problem_type