nextrec 0.3.6__py3-none-any.whl → 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__init__.py +1 -1
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +10 -5
- nextrec/basic/callback.py +1 -0
- nextrec/basic/features.py +30 -22
- nextrec/basic/layers.py +244 -113
- nextrec/basic/loggers.py +62 -43
- nextrec/basic/metrics.py +268 -119
- nextrec/basic/model.py +1373 -443
- nextrec/basic/session.py +10 -3
- nextrec/cli.py +498 -0
- nextrec/data/__init__.py +19 -25
- nextrec/data/batch_utils.py +11 -3
- nextrec/data/data_processing.py +42 -24
- nextrec/data/data_utils.py +26 -15
- nextrec/data/dataloader.py +303 -96
- nextrec/data/preprocessor.py +320 -199
- nextrec/loss/listwise.py +17 -9
- nextrec/loss/loss_utils.py +7 -8
- nextrec/loss/pairwise.py +2 -0
- nextrec/loss/pointwise.py +30 -12
- nextrec/models/generative/hstu.py +106 -40
- nextrec/models/match/dssm.py +82 -69
- nextrec/models/match/dssm_v2.py +72 -58
- nextrec/models/match/mind.py +175 -108
- nextrec/models/match/sdm.py +104 -88
- nextrec/models/match/youtube_dnn.py +73 -60
- nextrec/models/multi_task/esmm.py +53 -39
- nextrec/models/multi_task/mmoe.py +70 -47
- nextrec/models/multi_task/ple.py +107 -50
- nextrec/models/multi_task/poso.py +121 -41
- nextrec/models/multi_task/share_bottom.py +54 -38
- nextrec/models/ranking/afm.py +172 -45
- nextrec/models/ranking/autoint.py +84 -61
- nextrec/models/ranking/dcn.py +59 -42
- nextrec/models/ranking/dcn_v2.py +64 -23
- nextrec/models/ranking/deepfm.py +36 -26
- nextrec/models/ranking/dien.py +158 -102
- nextrec/models/ranking/din.py +88 -60
- nextrec/models/ranking/fibinet.py +55 -35
- nextrec/models/ranking/fm.py +32 -26
- nextrec/models/ranking/masknet.py +95 -34
- nextrec/models/ranking/pnn.py +34 -31
- nextrec/models/ranking/widedeep.py +37 -29
- nextrec/models/ranking/xdeepfm.py +63 -41
- nextrec/utils/__init__.py +61 -32
- nextrec/utils/config.py +490 -0
- nextrec/utils/device.py +52 -12
- nextrec/utils/distributed.py +141 -0
- nextrec/utils/embedding.py +1 -0
- nextrec/utils/feature.py +1 -0
- nextrec/utils/file.py +32 -11
- nextrec/utils/initializer.py +61 -16
- nextrec/utils/optimizer.py +25 -9
- nextrec/utils/synthetic_data.py +531 -0
- nextrec/utils/tensor.py +24 -13
- {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/METADATA +15 -5
- nextrec-0.4.2.dist-info/RECORD +69 -0
- nextrec-0.4.2.dist-info/entry_points.txt +2 -0
- nextrec-0.3.6.dist-info/RECORD +0 -64
- {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/WHEEL +0 -0
- {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/metrics.py
CHANGED
@@ -5,22 +5,45 @@ Date: create on 27/10/2025
 Checkpoint: edit on 02/12/2025
 Author: Yang Zhou,zyaztec@gmail.com
 """
+
 import logging
 from typing import Any
 
 import numpy as np
 from sklearn.metrics import (
-    roc_auc_score,
-    …
+    roc_auc_score,
+    log_loss,
+    mean_squared_error,
+    mean_absolute_error,
+    accuracy_score,
+    precision_score,
+    recall_score,
+    f1_score,
+    r2_score,
 )
 
-CLASSIFICATION_METRICS = {
-    …
+CLASSIFICATION_METRICS = {
+    "auc",
+    "gauc",
+    "ks",
+    "logloss",
+    "accuracy",
+    "acc",
+    "precision",
+    "recall",
+    "f1",
+    "micro_f1",
+    "macro_f1",
+}
+REGRESSION_METRICS = {"mse", "mae", "rmse", "r2", "mape", "msle"}
 TASK_DEFAULT_METRICS = {
-    …
-    …
-    …
-    …
+    "binary": ["auc", "gauc", "ks", "logloss", "accuracy", "precision", "recall", "f1"],
+    "regression": ["mse", "mae", "rmse", "r2", "mape"],
+    "multilabel": ["auc", "hamming_loss", "subset_accuracy", "micro_f1", "macro_f1"],
+    "matching": ["auc", "gauc", "precision@10", "hitrate@10", "map@10", "cosine"]
+    + [f"recall@{k}" for k in (5, 10, 20)]
+    + [f"ndcg@{k}" for k in (5, 10, 20)]
+    + [f"mrr@{k}" for k in (5, 10, 20)],
 }
 
 
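For reference, the list-comprehension concatenation in the new "matching" defaults expands to one flat list. A standalone sketch (illustrative only, not package code):

base = ["auc", "gauc", "precision@10", "hitrate@10", "map@10", "cosine"]
matching = (
    base
    + [f"recall@{k}" for k in (5, 10, 20)]
    + [f"ndcg@{k}" for k in (5, 10, 20)]
    + [f"mrr@{k}" for k in (5, 10, 20)]
)
print(matching[-3:])  # ['mrr@5', 'mrr@10', 'mrr@20']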
@@ -45,18 +68,21 @@ def check_user_id(*metric_sources: Any) -> bool:
     for name in metric_names:
         if name == "gauc":
             return True
-        if name.startswith(
+        if name.startswith(
+            ("recall@", "precision@", "hitrate@", "hr@", "mrr@", "ndcg@", "map@")
+        ):
             return True
     return False
 
+
 def compute_ks(y_true: np.ndarray, y_pred: np.ndarray) -> float:
     """Compute Kolmogorov-Smirnov statistic."""
     sorted_indices = np.argsort(y_pred)[::-1]
     y_true_sorted = y_true[sorted_indices]
-
+
     n_pos = np.sum(y_true_sorted == 1)
     n_neg = np.sum(y_true_sorted == 0)
-
+
     if n_pos > 0 and n_neg > 0:
         cum_pos_rate = np.cumsum(y_true_sorted == 1) / n_pos
         cum_neg_rate = np.cumsum(y_true_sorted == 0) / n_neg
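The reworked compute_ks above is the standard KS statistic: rank samples by score, then take the maximum gap between the cumulative positive and negative rates. A minimal sketch with made-up numbers (illustrative, not package code):

import numpy as np

y_true = np.array([1, 0, 1, 0, 0, 1])
y_pred = np.array([0.9, 0.8, 0.7, 0.4, 0.3, 0.2])
order = np.argsort(y_pred)[::-1]                    # highest scores first
y_sorted = y_true[order]
cum_pos = np.cumsum(y_sorted == 1) / (y_sorted == 1).sum()
cum_neg = np.cumsum(y_sorted == 0) / (y_sorted == 0).sum()
print(float(np.max(np.abs(cum_pos - cum_neg))))     # 0.333...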
@@ -64,24 +90,34 @@ def compute_ks(y_true: np.ndarray, y_pred: np.ndarray) -> float:
         return float(ks_value)
     return 0.0
 
+
 def compute_mape(y_true: np.ndarray, y_pred: np.ndarray) -> float:
     """Compute Mean Absolute Percentage Error."""
     mask = y_true != 0
     if np.any(mask):
-        return float(
+        return float(
+            np.mean(np.abs((y_true[mask] - y_pred[mask]) / y_true[mask])) * 100
+        )
     return 0.0
 
+
 def compute_msle(y_true: np.ndarray, y_pred: np.ndarray) -> float:
     """Compute Mean Squared Log Error."""
     y_pred_pos = np.maximum(y_pred, 0)
     return float(mean_squared_error(np.log1p(y_true), np.log1p(y_pred_pos)))
 
-def compute_gauc(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None = None) -> float:
+
+def compute_gauc(
+    y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None = None
+) -> float:
     if user_ids is None:
         # If no user_ids provided, fall back to regular AUC
         try:
             return float(roc_auc_score(y_true, y_pred))
-        except:
+        except Exception as e:
+            logging.warning(
+                f"[Metrics Warning: GAUC] Failed to compute AUC without user_ids: {e}"
+            )
             return 0.0
     # Group by user_id and calculate AUC for each user
     user_aucs = []
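The compute_mape change only reflows the expression, but the zero mask is the important detail: rows with y_true == 0 are dropped so the percentage is well defined. A tiny worked example (illustrative values):

import numpy as np

y_true = np.array([0.0, 2.0, 4.0])
y_pred = np.array([1.0, 1.0, 5.0])
mask = y_true != 0                      # drop the zero target
mape = np.mean(np.abs((y_true[mask] - y_pred[mask]) / y_true[mask])) * 100
print(mape)                             # (0.5 + 0.25) / 2 * 100 = 37.5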
@@ -94,12 +130,10 @@ def compute_gauc(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray |
             # Skip users with only one class (cannot compute AUC)
             if len(np.unique(user_y_true)) < 2:
                 continue
-            …
-            …
-            …
-            …
-        except:
-            continue
+            user_auc = roc_auc_score(user_y_true, user_y_pred)
+            user_aucs.append(user_auc)
+            user_weights.append(len(user_y_true))
+
     if len(user_aucs) == 0:
         return 0.0
     # Weighted average
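The rewritten loop above (the old blanket try/except is gone) computes a weighted per-user AUC, i.e. GAUC: users with a single label class are skipped, and the remaining per-user AUCs are averaged weighted by sample count. A standalone sketch of the same aggregation (hypothetical data, not package code):

import numpy as np
from sklearn.metrics import roc_auc_score

def gauc_sketch(y_true, y_pred, user_ids):
    aucs, weights = [], []
    for u in np.unique(user_ids):
        idx = user_ids == u
        if len(np.unique(y_true[idx])) < 2:
            continue                      # AUC undefined for one class
        aucs.append(roc_auc_score(y_true[idx], y_pred[idx]))
        weights.append(int(idx.sum()))
    return float(np.average(aucs, weights=weights)) if aucs else 0.0

y_true = np.array([1, 0, 1, 0, 1, 1])
y_pred = np.array([0.9, 0.2, 0.4, 0.6, 0.8, 0.7])
users = np.array([1, 1, 2, 2, 3, 3])
print(gauc_sketch(y_true, y_pred, users))  # (1.0*2 + 0.0*2) / 4 = 0.5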
@@ -108,22 +142,30 @@ def compute_gauc(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray |
     gauc = float(np.sum(user_aucs * user_weights) / np.sum(user_weights))
     return gauc
 
+
 def group_indices_by_user(user_ids: np.ndarray, n_samples: int) -> list[np.ndarray]:
     """Group sample indices by user_id. If user_ids is None, treat all as one group."""
     if user_ids is None:
         return [np.arange(n_samples)]
     user_ids = np.asarray(user_ids)
     if user_ids.shape[0] != n_samples:
-        logging.warning(
+        logging.warning(
+            f"[Metrics Warning: GAUC] user_ids length {user_ids.shape[0]} != number of samples {n_samples}, treating all samples as a single group for ranking metrics."
+        )
         return [np.arange(n_samples)]
     unique_users = np.unique(user_ids)
     groups = [np.where(user_ids == u)[0] for u in unique_users]
     return groups
 
-def compute_precision_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int) -> float:
+
+def compute_precision_at_k(
+    y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int
+) -> float:
     """Compute Precision@K."""
     if user_ids is None:
-        raise ValueError(
+        raise ValueError(
+            "[Metrics Error: Precision@K] user_ids must be provided for Precision@K computation."
+        )
     y_true = (y_true > 0).astype(int)
     n = len(y_true)
     groups = group_indices_by_user(user_ids, n)
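group_indices_by_user, which every @K metric below relies on, just buckets row indices per user and falls back to one global group when user_ids are unusable. An equivalent sketch (illustrative data):

import numpy as np

user_ids = np.array(["u1", "u2", "u1", "u3", "u2"])
groups = [np.where(user_ids == u)[0] for u in np.unique(user_ids)]
print(groups)  # [array([0, 2]), array([1, 4]), array([3])]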
@@ -140,10 +182,15 @@ def compute_precision_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.
             precisions.append(hits / float(k_user))
     return float(np.mean(precisions)) if precisions else 0.0
 
-def compute_recall_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int) -> float:
+
+def compute_recall_at_k(
+    y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int
+) -> float:
     """Compute Recall@K."""
     if user_ids is None:
-        raise ValueError(
+        raise ValueError(
+            "[Metrics Error: Recall@K] user_ids must be provided for Recall@K computation."
+        )
     y_true = (y_true > 0).astype(int)
     n = len(y_true)
     groups = group_indices_by_user(user_ids, n)
@@ -163,10 +210,15 @@ def compute_recall_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.nda
             recalls.append(hits / float(num_pos))
     return float(np.mean(recalls)) if recalls else 0.0
 
-def compute_hitrate_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int) -> float:
+
+def compute_hitrate_at_k(
+    y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int
+) -> float:
     """Compute HitRate@K."""
     if user_ids is None:
-        raise ValueError(
+        raise ValueError(
+            "[Metrics Error: HitRate@K] user_ids must be provided for HitRate@K computation."
+        )
     y_true = (y_true > 0).astype(int)
     n = len(y_true)
     groups = group_indices_by_user(user_ids, n)
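Precision@K, Recall@K, and HitRate@K all follow the same per-user pattern: sort each user's items by predicted score, inspect the top k, then average across users. A compact sketch of the first two (illustrative data; the package additionally clamps k to the group size via k_user):

import numpy as np

y_true = np.array([1, 0, 1, 1, 0, 0])
y_pred = np.array([0.9, 0.8, 0.1, 0.7, 0.6, 0.5])
user_ids = np.array([1, 1, 1, 2, 2, 2])
k = 2

precisions, recalls = [], []
for u in np.unique(user_ids):
    idx = np.where(user_ids == u)[0]
    top = idx[np.argsort(y_pred[idx])[::-1][:k]]  # user's top-k items
    hits = int(y_true[top].sum())
    precisions.append(hits / k)
    if y_true[idx].sum() > 0:
        recalls.append(hits / y_true[idx].sum())
print(np.mean(precisions), np.mean(recalls))      # 0.5 0.75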
@@ -185,10 +237,15 @@ def compute_hitrate_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.nd
         hits_per_user.append(1.0 if hits > 0 else 0.0)
     return float(np.mean(hits_per_user)) if hits_per_user else 0.0
 
-def compute_mrr_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int) -> float:
+
+def compute_mrr_at_k(
+    y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int
+) -> float:
     """Compute MRR@K."""
     if user_ids is None:
-        raise ValueError(
+        raise ValueError(
+            "[Metrics Error: MRR@K] user_ids must be provided for MRR@K computation."
+        )
     y_true = (y_true > 0).astype(int)
     n = len(y_true)
     groups = group_indices_by_user(user_ids, n)
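MRR@K rewards how early the first relevant item appears: the reciprocal of its 1-based rank within the top k, or 0 when nothing relevant makes the cut. A sketch (illustrative data, not package code):

import numpy as np

y_true = np.array([0, 1, 0, 0, 0, 1])
y_pred = np.array([0.9, 0.8, 0.7, 0.9, 0.8, 0.7])
user_ids = np.array([1, 1, 1, 2, 2, 2])
k = 3

rrs = []
for u in np.unique(user_ids):
    idx = np.where(user_ids == u)[0]
    top = idx[np.argsort(y_pred[idx])[::-1][:k]]
    hit_ranks = np.where(y_true[top] == 1)[0]
    rrs.append(1.0 / (hit_ranks[0] + 1) if hit_ranks.size else 0.0)
print(np.mean(rrs))  # (1/2 + 1/3) / 2 ≈ 0.4167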
@@ -212,6 +269,7 @@ def compute_mrr_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarra
         mrrs.append(rr)
     return float(np.mean(mrrs)) if mrrs else 0.0
 
+
 def compute_dcg_at_k(labels: np.ndarray, k: int) -> float:
     k_user = min(k, labels.size)
     if k_user == 0:
@@ -220,10 +278,15 @@ def compute_dcg_at_k(labels: np.ndarray, k: int) -> float:
     discounts = np.log2(np.arange(2, k_user + 2))
     return float(np.sum(gains / discounts))
 
-def compute_ndcg_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int) -> float:
+
+def compute_ndcg_at_k(
+    y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int
+) -> float:
     """Compute NDCG@K."""
     if user_ids is None:
-        raise ValueError(
+        raise ValueError(
+            "[Metrics Error: NDCG@K] user_ids must be provided for NDCG@K computation."
+        )
     y_true = (y_true > 0).astype(int)
     n = len(y_true)
     groups = group_indices_by_user(user_ids, n)
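compute_dcg_at_k uses binary gains discounted by log2(rank + 1), and NDCG@K divides by the DCG of the ideal ordering. A worked sketch of the same arithmetic (illustrative data):

import numpy as np

def dcg(labels, k):
    labels = labels[: min(k, labels.size)]
    return float(np.sum(labels / np.log2(np.arange(2, labels.size + 2))))

ranked = np.array([1, 0, 1, 0])       # labels in predicted-score order
ideal = np.sort(ranked)[::-1]         # [1, 1, 0, 0]
print(round(dcg(ranked, 4) / dcg(ideal, 4), 4))  # 1.5 / 1.631 ≈ 0.9197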
@@ -247,10 +310,14 @@ def compute_ndcg_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarr
     return float(np.mean(ndcgs)) if ndcgs else 0.0
 
 
-def compute_map_at_k(
+def compute_map_at_k(
+    y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int
+) -> float:
     """Mean Average Precision@K."""
     if user_ids is None:
-        raise ValueError(
+        raise ValueError(
+            "[Metrics Error: MAP@K] user_ids must be provided for MAP@K computation."
+        )
     y_true = (y_true > 0).astype(int)
     n = len(y_true)
     groups = group_indices_by_user(user_ids, n)
@@ -283,19 +350,21 @@ def compute_cosine_separation(y_true: np.ndarray, y_pred: np.ndarray) -> float:
     y_true = (y_true > 0).astype(int)
     pos_mask = y_true == 1
     neg_mask = y_true == 0
-
+
     if not np.any(pos_mask) or not np.any(neg_mask):
         return 0.0
-
+
     pos_mean = float(np.mean(y_pred[pos_mask]))
     neg_mean = float(np.mean(y_pred[neg_mask]))
     return pos_mean - neg_mean
 
 
 def configure_metrics(
-    task: str | list[str],
-    metrics:
-    target_names: list[str],
+    task: str | list[str],  # 'binary' or ['binary', 'regression']
+    metrics: (
+        list[str] | dict[str, list[str]] | None
+    ),  # ['auc', 'logloss'] or {'task1': ['auc'], 'task2': ['mse']}
+    target_names: list[str],  # ['target1', 'target2']
 ) -> tuple[list[str], dict[str, list[str]] | None, str]:
     """Configure metrics based on task and user input."""
     primary_task = task[0] if isinstance(task, list) else task
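compute_cosine_separation is deliberately simple: for matching models, where predictions are similarity scores, it reports the mean positive-pair score minus the mean negative-pair score, so larger is better. A sketch (illustrative values):

import numpy as np

y_true = np.array([1, 1, 0, 0])
y_pred = np.array([0.8, 0.6, 0.3, 0.1])   # e.g. cosine similarities
sep = y_pred[y_true == 1].mean() - y_pred[y_true == 0].mean()
print(round(float(sep), 2))               # 0.7 - 0.2 = 0.5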
@@ -307,7 +376,9 @@
     task_specific_metrics = {}
     for task_name, task_metrics in metrics.items():
         if task_name not in target_names:
-            logging.warning(
+            logging.warning(
+                f"[Metrics Warning] Task {task_name} not found in targets {target_names}, skipping its metrics"
+            )
             continue
         lowered = [m.lower() for m in task_metrics]
         task_specific_metrics[task_name] = lowered
@@ -341,114 +412,168 @@
     best_metrics_mode = getbest_metric_mode(metrics_list[0], primary_task)
     return metrics_list, task_specific_metrics, best_metrics_mode
 
+
 def getbest_metric_mode(first_metric: str, primary_task: str) -> str:
     """Determine if metric should be maximized or minimized."""
     first_metric_lower = first_metric.lower()
     # Metrics that should be maximized
-    if first_metric_lower in {
-        …
+    if first_metric_lower in {
+        "auc",
+        "gauc",
+        "ks",
+        "accuracy",
+        "acc",
+        "precision",
+        "recall",
+        "f1",
+        "r2",
+        "micro_f1",
+        "macro_f1",
+    }:
+        return "max"
     # Ranking metrics that should be maximized (with @K suffix)
-    if (
-        first_metric_lower.startswith(
-        first_metric_lower.startswith(
-        first_metric_lower.startswith(
-        first_metric_lower.startswith(
-        first_metric_lower.startswith(
-        first_metric_lower.startswith(
-        …
+    if (
+        first_metric_lower.startswith("recall@")
+        or first_metric_lower.startswith("precision@")
+        or first_metric_lower.startswith("hitrate@")
+        or first_metric_lower.startswith("hr@")
+        or first_metric_lower.startswith("mrr@")
+        or first_metric_lower.startswith("ndcg@")
+        or first_metric_lower.startswith("map@")
+    ):
+        return "max"
     # Cosine separation should be maximized
-    if first_metric_lower ==
-        return
+    if first_metric_lower == "cosine":
+        return "max"
     # Metrics that should be minimized
-    if first_metric_lower in {
-        return
+    if first_metric_lower in {"logloss", "mse", "mae", "rmse", "mape", "msle"}:
+        return "min"
     # Default based on task type
-    if primary_task ==
-        return
-    return
+    if primary_task == "regression":
+        return "min"
+    return "max"
+
 
 def compute_single_metric(
     metric: str,
     y_true: np.ndarray,
     y_pred: np.ndarray,
     task_type: str,
-    user_ids: np.ndarray | None = None
+    user_ids: np.ndarray | None = None,
 ) -> float:
     """Compute a single metric given true and predicted values."""
     y_p_binary = (y_pred > 0.5).astype(int)
     try:
         metric_lower = metric.lower()
-        if metric_lower.startswith(
-            k = int(metric_lower.split(
-            return compute_recall_at_k(y_true, y_pred, user_ids, k)
-        if metric_lower.startswith(
-            k = int(metric_lower.split(
-            return compute_precision_at_k(y_true, y_pred, user_ids, k)
-        if metric_lower.startswith(
-            k_str = metric_lower.split(
+        if metric_lower.startswith("recall@"):
+            k = int(metric_lower.split("@")[1])
+            return compute_recall_at_k(y_true, y_pred, user_ids, k)  # type: ignore
+        if metric_lower.startswith("precision@"):
+            k = int(metric_lower.split("@")[1])
+            return compute_precision_at_k(y_true, y_pred, user_ids, k)  # type: ignore
+        if metric_lower.startswith("hitrate@") or metric_lower.startswith("hr@"):
+            k_str = metric_lower.split("@")[1]
             k = int(k_str)
-            return compute_hitrate_at_k(y_true, y_pred, user_ids, k)
-        if metric_lower.startswith(
-            k = int(metric_lower.split(
-            return compute_mrr_at_k(y_true, y_pred, user_ids, k)
-        if metric_lower.startswith(
-            k = int(metric_lower.split(
-            return compute_ndcg_at_k(y_true, y_pred, user_ids, k)
-        if metric_lower.startswith(
-            k = int(metric_lower.split(
-            return compute_map_at_k(y_true, y_pred, user_ids, k)
+            return compute_hitrate_at_k(y_true, y_pred, user_ids, k)  # type: ignore
+        if metric_lower.startswith("mrr@"):
+            k = int(metric_lower.split("@")[1])
+            return compute_mrr_at_k(y_true, y_pred, user_ids, k)  # type: ignore
+        if metric_lower.startswith("ndcg@"):
+            k = int(metric_lower.split("@")[1])
+            return compute_ndcg_at_k(y_true, y_pred, user_ids, k)  # type: ignore
+        if metric_lower.startswith("map@"):
+            k = int(metric_lower.split("@")[1])
+            return compute_map_at_k(y_true, y_pred, user_ids, k)  # type: ignore
         # cosine for matching task
-        if metric_lower ==
+        if metric_lower == "cosine":
            return compute_cosine_separation(y_true, y_pred)
-        if metric ==
-            value = float(
-            …
+        if metric == "auc":
+            value = float(
+                roc_auc_score(
+                    y_true,
+                    y_pred,
+                    average="macro" if task_type == "multilabel" else None,
+                )
+            )
+        elif metric == "gauc":
            value = float(compute_gauc(y_true, y_pred, user_ids))
-        elif metric ==
            value = float(compute_ks(y_true, y_pred))
+        elif metric == "ks":
-        elif metric ==
+        elif metric == "logloss":
            value = float(log_loss(y_true, y_pred))
-        elif metric in (
+        elif metric in ("accuracy", "acc"):
            value = float(accuracy_score(y_true, y_p_binary))
-        elif metric ==
-            value = float(
-            …
-            …
-            …
-            …
-            …
-            …
-            …
-            …
-            …
+        elif metric == "precision":
+            value = float(
+                precision_score(
+                    y_true,
+                    y_p_binary,
+                    average="samples" if task_type == "multilabel" else "binary",
+                    zero_division=0,
+                )
+            )
+        elif metric == "recall":
+            value = float(
+                recall_score(
+                    y_true,
+                    y_p_binary,
+                    average="samples" if task_type == "multilabel" else "binary",
+                    zero_division=0,
+                )
+            )
+        elif metric == "f1":
+            value = float(
+                f1_score(
+                    y_true,
+                    y_p_binary,
+                    average="samples" if task_type == "multilabel" else "binary",
+                    zero_division=0,
+                )
+            )
+        elif metric == "micro_f1":
+            value = float(
+                f1_score(y_true, y_p_binary, average="micro", zero_division=0)
+            )
+        elif metric == "macro_f1":
+            value = float(
+                f1_score(y_true, y_p_binary, average="macro", zero_division=0)
+            )
+        elif metric == "mse":
            value = float(mean_squared_error(y_true, y_pred))
-        elif metric ==
+        elif metric == "mae":
            value = float(mean_absolute_error(y_true, y_pred))
-        elif metric ==
+        elif metric == "rmse":
            value = float(np.sqrt(mean_squared_error(y_true, y_pred)))
-        elif metric ==
+        elif metric == "r2":
            value = float(r2_score(y_true, y_pred))
-        elif metric ==
+        elif metric == "mape":
            value = float(compute_mape(y_true, y_pred))
-        elif metric ==
+        elif metric == "msle":
            value = float(compute_msle(y_true, y_pred))
         else:
-            logging.warning(
+            logging.warning(
+                f"[Metric Warning] Metric '{metric}' is not supported, returning 0.0"
+            )
            value = 0.0
     except Exception as exception:
-        logging.warning(
+        logging.warning(
+            f"[Metric Warning] Failed to compute metric {metric}: {exception}"
+        )
        value = 0.0
     return value
 
+
 def evaluate_metrics(
     y_true: np.ndarray | None,
     y_pred: np.ndarray | None,
-    metrics: list[str],
-    task: str | list[str],
-    target_names: list[str],
-    task_specific_metrics:
-    …
-    )
+    metrics: list[str],  # example: ['auc', 'logloss']
+    task: str | list[str],  # example: 'binary' or ['binary', 'regression']
+    target_names: list[str],  # example: ['target1', 'target2']
+    task_specific_metrics: (
+        dict[str, list[str]] | None
+    ) = None,  # example: {'target1': ['auc', 'logloss'], 'target2': ['mse']}
+    user_ids: np.ndarray | None = None,  # example: User IDs for GAUC computation
+) -> dict:  # {'auc': 0.75, 'logloss': 0.45, 'mse_target2': 3.2}
     """Evaluate specified metrics for given true and predicted values."""
     result = {}
     if y_true is None or y_pred is None:
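Two conventions from this hunk are worth spelling out: @K metrics carry their cutoff inside the name, and the first configured metric decides whether checkpointing maximizes or minimizes. A sketch of both behaviors (names illustrative, not package code):

metric = "ndcg@10"
name, k = metric.split("@")[0], int(metric.split("@")[1])
print(name, k)                              # ndcg 10

minimized = {"logloss", "mse", "mae", "rmse", "mape", "msle"}
for m in ("auc", "logloss", "recall@20"):
    print(m, "min" if m in minimized else "max")
# auc max / logloss min / recall@20 max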
@@ -460,7 +585,9 @@ def evaluate_metrics(
     if nums_task == 1:
         for metric in metrics:
             metric_lower = metric.lower()
-            value = compute_single_metric(
+            value = compute_single_metric(
+                metric_lower, y_true, y_pred, primary_task, user_ids
+            )
             result[metric_lower] = value
     # Multi-task evaluation
     else:
@@ -471,7 +598,9 @@
                should_compute = True
                if task_specific_metrics is not None and task_idx < len(target_names):
                    task_name = target_names[task_idx]
-                    should_compute = metric_lower in task_specific_metrics.get(
+                    should_compute = metric_lower in task_specific_metrics.get(
+                        task_name, []
+                    )
                else:
                    # Get task type for specific index
                    if isinstance(task, list) and task_idx < len(task):
@@ -479,24 +608,44 @@
                    elif isinstance(task, str):
                        task_type = task
                    else:
-                        task_type =
-                    if task_type in [
-                        should_compute = metric_lower in {
-                        …
-                        …
+                        task_type = "binary"
+                    if task_type in ["binary", "multilabel"]:
+                        should_compute = metric_lower in {
+                            "auc",
+                            "ks",
+                            "logloss",
+                            "accuracy",
+                            "acc",
+                            "precision",
+                            "recall",
+                            "f1",
+                            "micro_f1",
+                            "macro_f1",
+                        }
+                    elif task_type == "regression":
+                        should_compute = metric_lower in {
+                            "mse",
+                            "mae",
+                            "rmse",
+                            "r2",
+                            "mape",
+                            "msle",
+                        }
                if not should_compute:
-                    continue
-                target_name = target_names[task_idx]
+                    continue
+                target_name = target_names[task_idx]
                # Get task type for specific index
                if isinstance(task, list) and task_idx < len(task):
                    task_type = task[task_idx]
                elif isinstance(task, str):
                    task_type = task
                else:
-                    task_type =
+                    task_type = "binary"
                y_true_task = y_true[:, task_idx]
-                y_pred_task = y_pred[:, task_idx]
+                y_pred_task = y_pred[:, task_idx]
                # Compute metric
-                value = compute_single_metric(
-                …
+                value = compute_single_metric(
+                    metric_lower, y_true_task, y_pred_task, task_type, user_ids
+                )
+                result[f"{metric_lower}_{target_name}"] = value
    return result
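Finally, the multi-task branch keys each score as f"{metric}_{target}" so per-head results do not collide, while single-task results keep the bare metric name. A sketch of the resulting dict shape (hypothetical targets and placeholder values):

result = {}
for metric, target in [("auc", "ctr"), ("mse", "watch_time")]:
    result[f"{metric}_{target}"] = 0.0      # placeholder value
print(result)  # {'auc_ctr': 0.0, 'mse_watch_time': 0.0}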