nextrec 0.1.1-py3-none-any.whl → 0.1.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__init__.py +4 -4
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +10 -9
- nextrec/basic/callback.py +1 -0
- nextrec/basic/dataloader.py +168 -127
- nextrec/basic/features.py +24 -27
- nextrec/basic/layers.py +328 -159
- nextrec/basic/loggers.py +50 -37
- nextrec/basic/metrics.py +255 -147
- nextrec/basic/model.py +817 -462
- nextrec/data/__init__.py +5 -5
- nextrec/data/data_utils.py +16 -12
- nextrec/data/preprocessor.py +276 -252
- nextrec/loss/__init__.py +12 -12
- nextrec/loss/loss_utils.py +30 -22
- nextrec/loss/match_losses.py +116 -83
- nextrec/models/match/__init__.py +5 -5
- nextrec/models/match/dssm.py +70 -61
- nextrec/models/match/dssm_v2.py +61 -51
- nextrec/models/match/mind.py +89 -71
- nextrec/models/match/sdm.py +93 -81
- nextrec/models/match/youtube_dnn.py +62 -53
- nextrec/models/multi_task/esmm.py +49 -43
- nextrec/models/multi_task/mmoe.py +65 -56
- nextrec/models/multi_task/ple.py +92 -65
- nextrec/models/multi_task/share_bottom.py +48 -42
- nextrec/models/ranking/__init__.py +7 -7
- nextrec/models/ranking/afm.py +39 -30
- nextrec/models/ranking/autoint.py +70 -57
- nextrec/models/ranking/dcn.py +43 -35
- nextrec/models/ranking/deepfm.py +34 -28
- nextrec/models/ranking/dien.py +115 -79
- nextrec/models/ranking/din.py +84 -60
- nextrec/models/ranking/fibinet.py +51 -35
- nextrec/models/ranking/fm.py +28 -26
- nextrec/models/ranking/masknet.py +31 -31
- nextrec/models/ranking/pnn.py +30 -31
- nextrec/models/ranking/widedeep.py +36 -31
- nextrec/models/ranking/xdeepfm.py +46 -39
- nextrec/utils/__init__.py +9 -9
- nextrec/utils/embedding.py +1 -1
- nextrec/utils/initializer.py +23 -15
- nextrec/utils/optimizer.py +14 -10
- {nextrec-0.1.1.dist-info → nextrec-0.1.2.dist-info}/METADATA +6 -40
- nextrec-0.1.2.dist-info/RECORD +51 -0
- nextrec-0.1.1.dist-info/RECORD +0 -51
- {nextrec-0.1.1.dist-info → nextrec-0.1.2.dist-info}/WHEEL +0 -0
- {nextrec-0.1.1.dist-info → nextrec-0.1.2.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/metrics.py
CHANGED
@@ -5,33 +5,55 @@ Date: create on 27/10/2025
 Author:
     Yang Zhou,zyaztec@gmail.com
 """
+
 import logging
 import numpy as np
 from sklearn.metrics import (
-    roc_auc_score,
-    …
+    roc_auc_score,
+    log_loss,
+    mean_squared_error,
+    mean_absolute_error,
+    accuracy_score,
+    precision_score,
+    recall_score,
+    f1_score,
+    r2_score,
 )


-CLASSIFICATION_METRICS = {
-    …
+CLASSIFICATION_METRICS = {
+    "auc",
+    "gauc",
+    "ks",
+    "logloss",
+    "accuracy",
+    "acc",
+    "precision",
+    "recall",
+    "f1",
+    "micro_f1",
+    "macro_f1",
+}
+REGRESSION_METRICS = {"mse", "mae", "rmse", "r2", "mape", "msle"}
 TASK_DEFAULT_METRICS = {
-    …
-    …
-    …
-    …
+    "binary": ["auc", "gauc", "ks", "logloss", "accuracy", "precision", "recall", "f1"],
+    "regression": ["mse", "mae", "rmse", "r2", "mape"],
+    "multilabel": ["auc", "hamming_loss", "subset_accuracy", "micro_f1", "macro_f1"],
+    "matching": ["auc", "gauc", "precision@10", "hitrate@10", "map@10", "cosine"]
+    + [f"recall@{k}" for k in (5, 10, 20)]
+    + [f"ndcg@{k}" for k in (5, 10, 20)]
+    + [f"mrr@{k}" for k in (5, 10, 20)],
 }


-
 def compute_ks(y_true: np.ndarray, y_pred: np.ndarray) -> float:
     """Compute Kolmogorov-Smirnov statistic."""
     sorted_indices = np.argsort(y_pred)[::-1]
     y_true_sorted = y_true[sorted_indices]
-
+
     n_pos = np.sum(y_true_sorted == 1)
     n_neg = np.sum(y_true_sorted == 0)
-
+
     if n_pos > 0 and n_neg > 0:
         cum_pos_rate = np.cumsum(y_true_sorted == 1) / n_pos
         cum_neg_rate = np.cumsum(y_true_sorted == 0) / n_neg
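Since the "matching" default list is now assembled by concatenating a literal list with three comprehensions, it may help to see what the expression above evaluates to:

    from nextrec.basic.metrics import TASK_DEFAULT_METRICS

    # Result of evaluating the "matching" entry shown in the hunk above.
    assert TASK_DEFAULT_METRICS["matching"] == [
        "auc", "gauc", "precision@10", "hitrate@10", "map@10", "cosine",
        "recall@5", "recall@10", "recall@20",
        "ndcg@5", "ndcg@10", "ndcg@20",
        "mrr@5", "mrr@10", "mrr@20",
    ]
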
@@ -44,7 +66,9 @@ def compute_mape(y_true: np.ndarray, y_pred: np.ndarray) -> float:
     """Compute Mean Absolute Percentage Error."""
     mask = y_true != 0
     if np.any(mask):
-        return float(…
+        return float(
+            np.mean(np.abs((y_true[mask] - y_pred[mask]) / y_true[mask])) * 100
+        )
     return 0.0


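A quick numeric check of the masked formula (synthetic values, not from the package's tests); zero targets are excluded before the percentages are averaged:

    import numpy as np

    y_true = np.array([0.0, 2.0, 4.0])  # the zero target is masked out
    y_pred = np.array([1.0, 1.0, 5.0])

    mask = y_true != 0
    # |2-1|/2 = 0.5 and |4-5|/4 = 0.25 -> mean 0.375 -> 37.5
    mape = float(np.mean(np.abs((y_true[mask] - y_pred[mask]) / y_true[mask])) * 100)
    assert mape == 37.5
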
@@ -55,9 +79,7 @@ def compute_msle(y_true: np.ndarray, y_pred: np.ndarray) -> float:


 def compute_gauc(
-    y_true: np.ndarray,
-    y_pred: np.ndarray,
-    user_ids: np.ndarray | None = None
+    y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None = None
 ) -> float:
     if user_ids is None:
         # If no user_ids provided, fall back to regular AUC
@@ -65,37 +87,37 @@ def compute_gauc(
             return float(roc_auc_score(y_true, y_pred))
         except:
             return 0.0
-
+
     # Group by user_id and calculate AUC for each user
     user_aucs = []
     user_weights = []
-
+
     unique_users = np.unique(user_ids)
-
+
     for user_id in unique_users:
         mask = user_ids == user_id
         user_y_true = y_true[mask]
         user_y_pred = y_pred[mask]
-
+
         # Skip users with only one class (cannot compute AUC)
         if len(np.unique(user_y_true)) < 2:
             continue
-
+
         try:
             user_auc = roc_auc_score(user_y_true, user_y_pred)
             user_aucs.append(user_auc)
             user_weights.append(len(user_y_true))
         except:
             continue
-
+
     if len(user_aucs) == 0:
         return 0.0
-
+
     # Weighted average
     user_aucs = np.array(user_aucs)
     user_weights = np.array(user_weights)
     gauc = float(np.sum(user_aucs * user_weights) / np.sum(user_weights))
-
+
     return gauc


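A minimal sketch of the per-user weighting in compute_gauc (synthetic values chosen for illustration):

    import numpy as np
    from nextrec.basic.metrics import compute_gauc

    user_ids = np.array([1, 1, 1, 2, 2])  # user 1 has 3 samples, user 2 has 2
    y_true = np.array([1, 0, 1, 1, 0])
    y_pred = np.array([0.9, 0.2, 0.6, 0.3, 0.7])

    # user 1 ranks both positives above its negative -> AUC 1.0, weight 3
    # user 2 ranks its negative above its positive   -> AUC 0.0, weight 2
    # GAUC = (1.0 * 3 + 0.0 * 2) / 5 = 0.6
    print(compute_gauc(y_true, y_pred, user_ids))  # 0.6
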
@@ -103,7 +125,7 @@ def _group_indices_by_user(user_ids: np.ndarray, n_samples: int) -> list[np.ndar
     """Group sample indices by user_id. If user_ids is None, treat all as one group."""
     if user_ids is None:
         return [np.arange(n_samples)]
-
+
     user_ids = np.asarray(user_ids)
     if user_ids.shape[0] != n_samples:
         logging.warning(
@@ -113,17 +135,19 @@ def _group_indices_by_user(user_ids: np.ndarray, n_samples: int) -> list[np.ndar
             n_samples,
         )
         return [np.arange(n_samples)]
-
+
     unique_users = np.unique(user_ids)
     groups = [np.where(user_ids == u)[0] for u in unique_users]
     return groups


-def _compute_precision_at_k(…
+def _compute_precision_at_k(
+    y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int
+) -> float:
     y_true = (y_true > 0).astype(int)
     n = len(y_true)
     groups = _group_indices_by_user(user_ids, n)
-
+
     precisions = []
     for idx in groups:
         if idx.size == 0:
@@ -135,16 +159,18 @@ def _compute_precision_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np
         topk = order[:k_user]
         hits = labels[topk].sum()
         precisions.append(hits / float(k_user))
-
+
     return float(np.mean(precisions)) if precisions else 0.0


-def _compute_recall_at_k(…
+def _compute_recall_at_k(
+    y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int
+) -> float:
     """Compute Recall@K."""
     y_true = (y_true > 0).astype(int)
     n = len(y_true)
     groups = _group_indices_by_user(user_ids, n)
-
+
     recalls = []
     for idx in groups:
         if idx.size == 0:
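For intuition, a single-group example (user_ids=None puts every row into one group; the leading underscores mark these helpers as internal, called directly here only for illustration):

    import numpy as np
    from nextrec.basic.metrics import _compute_precision_at_k, _compute_recall_at_k

    y_true = np.array([1, 0, 1, 0, 0])  # 2 relevant items
    y_pred = np.array([0.9, 0.8, 0.3, 0.2, 0.1])

    # the top 2 by score holds 1 of the 2 positives:
    print(_compute_precision_at_k(y_true, y_pred, None, 2))  # 1/2 = 0.5
    print(_compute_recall_at_k(y_true, y_pred, None, 2))     # 1/2 = 0.5
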
@@ -159,16 +185,18 @@ def _compute_recall_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.nd
         topk = order[:k_user]
         hits = labels[topk].sum()
         recalls.append(hits / float(num_pos))
-
+
     return float(np.mean(recalls)) if recalls else 0.0


-def _compute_hitrate_at_k(…
+def _compute_hitrate_at_k(
+    y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int
+) -> float:
     """Compute HitRate@K."""
     y_true = (y_true > 0).astype(int)
     n = len(y_true)
     groups = _group_indices_by_user(user_ids, n)
-
+
     hits_per_user = []
     for idx in groups:
         if idx.size == 0:
@@ -182,16 +210,18 @@ def _compute_hitrate_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.n
         topk = order[:k_user]
         hits = labels[topk].sum()
         hits_per_user.append(1.0 if hits > 0 else 0.0)
-
+
     return float(np.mean(hits_per_user)) if hits_per_user else 0.0


-def _compute_mrr_at_k(…
+def _compute_mrr_at_k(
+    y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int
+) -> float:
     """Compute MRR@K."""
     y_true = (y_true > 0).astype(int)
     n = len(y_true)
     groups = _group_indices_by_user(user_ids, n)
-
+
     mrrs = []
     for idx in groups:
         if idx.size == 0:
@@ -204,14 +234,14 @@ def _compute_mrr_at_k(y_true: np.ndarray, y_pred: np.ndarr
         k_user = min(k, idx.size)
         topk = order[:k_user]
         ranked_labels = labels[order]
-
+
         rr = 0.0
         for rank, lab in enumerate(ranked_labels[:k_user], start=1):
             if lab > 0:
                 rr = 1.0 / rank
                 break
         mrrs.append(rr)
-
+
     return float(np.mean(mrrs)) if mrrs else 0.0


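The same single-group convention works for the hit-rate and reciprocal-rank helpers (synthetic values):

    import numpy as np
    from nextrec.basic.metrics import _compute_hitrate_at_k, _compute_mrr_at_k

    y_true = np.array([0, 1, 0, 0])
    y_pred = np.array([0.9, 0.7, 0.5, 0.1])  # the only positive ranks 2nd

    print(_compute_hitrate_at_k(y_true, y_pred, None, 3))  # 1.0 (a hit within the top 3)
    print(_compute_mrr_at_k(y_true, y_pred, None, 3))      # 0.5 (first positive at rank 2)
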
@@ -224,12 +254,14 @@ def _compute_dcg_at_k(labels: np.ndarray, k: int) -> float:
     return float(np.sum(gains / discounts))


-def _compute_ndcg_at_k(…
+def _compute_ndcg_at_k(
+    y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int
+) -> float:
     """Compute NDCG@K."""
     y_true = (y_true > 0).astype(int)
     n = len(y_true)
     groups = _group_indices_by_user(user_ids, n)
-
+
     ndcgs = []
     for idx in groups:
         if idx.size == 0:
@@ -238,27 +270,29 @@ def _compute_ndcg_at_k(y_true: np.ndarray, y_pred: np.ndar
         if labels.sum() == 0:
             continue
         scores = y_pred[idx]
-
+
         order = np.argsort(scores)[::-1]
         ranked_labels = labels[order]
         dcg = _compute_dcg_at_k(ranked_labels, k)
-
+
         # ideal DCG
         ideal_labels = np.sort(labels)[::-1]
         idcg = _compute_dcg_at_k(ideal_labels, k)
         if idcg == 0.0:
             continue
         ndcgs.append(dcg / idcg)
-
+
     return float(np.mean(ndcgs)) if ndcgs else 0.0


-def _compute_map_at_k(…
+def _compute_map_at_k(
+    y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int
+) -> float:
     """Mean Average Precision@K."""
     y_true = (y_true > 0).astype(int)
     n = len(y_true)
     groups = _group_indices_by_user(user_ids, n)
-
+
     aps = []
     for idx in groups:
         if idx.size == 0:
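A worked NDCG@3 for one user. The gains/discounts definitions inside _compute_dcg_at_k fall outside the hunks shown here, so this sketch assumes the conventional binary gain with a log2 rank discount:

    import numpy as np

    ranked = np.array([0, 1, 1])  # labels ordered by predicted score
    ideal = np.array([1, 1, 0])   # labels sorted descending

    dcg = np.sum(ranked / np.log2(np.arange(2, 5)))  # 1/log2(3) + 1/log2(4) ~ 1.13
    idcg = np.sum(ideal / np.log2(np.arange(2, 5)))  # 1/log2(2) + 1/log2(3) ~ 1.63
    print(dcg / idcg)  # ~0.69
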
@@ -267,23 +301,23 @@ def _compute_map_at_k(y_true: np.ndarray, y_pred: np.ndarr
         num_pos = labels.sum()
         if num_pos == 0:
             continue
-
+
         scores = y_pred[idx]
         order = np.argsort(scores)[::-1]
         k_user = min(k, idx.size)
-
+
         hits = 0
         sum_precisions = 0.0
         for rank, i in enumerate(order[:k_user], start=1):
             if labels[i] > 0:
                 hits += 1
                 sum_precisions += hits / float(rank)
-
+
         if hits == 0:
             aps.append(0.0)
         else:
             aps.append(sum_precisions / float(num_pos))
-
+
     return float(np.mean(aps)) if aps else 0.0


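Note that the average precision divides by all of the user's positives (num_pos), not only those retrieved in the top k. A single-group check:

    import numpy as np
    from nextrec.basic.metrics import _compute_map_at_k

    y_true = np.array([1, 0, 1, 0, 1])  # positives at ranks 1, 3, 5
    y_pred = np.array([0.9, 0.8, 0.7, 0.6, 0.5])

    # AP@4 = (1/1 + 2/3) / 3; the rank-5 positive is never retrieved
    print(_compute_map_at_k(y_true, y_pred, None, 4))  # ~0.556
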
@@ -292,24 +326,26 @@ def _compute_cosine_separation(y_true: np.ndarray, y_pred: np.ndarray) -> float:
     y_true = (y_true > 0).astype(int)
     pos_mask = y_true == 1
     neg_mask = y_true == 0
-
+
     if not np.any(pos_mask) or not np.any(neg_mask):
         return 0.0
-
+
     pos_mean = float(np.mean(y_pred[pos_mask]))
     neg_mean = float(np.mean(y_pred[neg_mask]))
     return pos_mean - neg_mean


 def configure_metrics(
-    task: str | list[str],
-    metrics: …
-    …
+    task: str | list[str],  # 'binary' or ['binary', 'regression']
+    metrics: (
+        list[str] | dict[str, list[str]] | None
+    ),  # ['auc', 'logloss'] or {'task1': ['auc'], 'task2': ['mse']}
+    target_names: list[str],  # ['target1', 'target2']
 ) -> tuple[list[str], dict[str, list[str]] | None, str]:
     """Configure metrics based on task and user input."""
     primary_task = task[0] if isinstance(task, list) else task
     nums_task = len(task) if isinstance(task, list) else 1
-
+
     metrics_list: list[str] = []
     task_specific_metrics: dict[str, list[str]] | None = None

@@ -351,48 +387,62 @@
     if primary_task not in TASK_DEFAULT_METRICS:
         raise ValueError(f"Unsupported task type: {primary_task}")
     metrics_list = TASK_DEFAULT_METRICS[primary_task]
-
+
     if not metrics_list:
         # Inline get_default_metrics_for_task logic
         if primary_task not in TASK_DEFAULT_METRICS:
             raise ValueError(f"Unsupported task type: {primary_task}")
         metrics_list = TASK_DEFAULT_METRICS[primary_task]
-
+
     best_metrics_mode = get_best_metric_mode(metrics_list[0], primary_task)
-
+
     return metrics_list, task_specific_metrics, best_metrics_mode


 def get_best_metric_mode(first_metric: str, primary_task: str) -> str:
     """Determine if metric should be maximized or minimized."""
     first_metric_lower = first_metric.lower()
-
+
     # Metrics that should be maximized
-    if first_metric_lower in {…
-    …
-    …
+    if first_metric_lower in {
+        "auc",
+        "gauc",
+        "ks",
+        "accuracy",
+        "acc",
+        "precision",
+        "recall",
+        "f1",
+        "r2",
+        "micro_f1",
+        "macro_f1",
+    }:
+        return "max"
+
     # Ranking metrics that should be maximized (with @K suffix)
-    if (…
-        first_metric_lower.startswith(…
-        first_metric_lower.startswith(…
-        first_metric_lower.startswith(…
-        first_metric_lower.startswith(…
-        first_metric_lower.startswith(…
-        first_metric_lower.startswith(…
-    …
-    …
+    if (
+        first_metric_lower.startswith("recall@")
+        or first_metric_lower.startswith("precision@")
+        or first_metric_lower.startswith("hitrate@")
+        or first_metric_lower.startswith("hr@")
+        or first_metric_lower.startswith("mrr@")
+        or first_metric_lower.startswith("ndcg@")
+        or first_metric_lower.startswith("map@")
+    ):
+        return "max"
+
     # Cosine separation should be maximized
-    if first_metric_lower == …
-        return …
-
+    if first_metric_lower == "cosine":
+        return "max"
+
     # Metrics that should be minimized
-    if first_metric_lower in {…
-        return …
-
+    if first_metric_lower in {"logloss", "mse", "mae", "rmse", "mape", "msle"}:
+        return "min"
+
     # Default based on task type
-    if primary_task == …
-        return …
-    return …
+    if primary_task == "regression":
+        return "min"
+    return "max"


 def compute_single_metric(
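A usage sketch. The middle of configure_metrics (handling explicit metric lists) falls outside these hunks, but the visible tail implies that passing metrics=None yields the task defaults and a mode derived from the first default metric:

    from nextrec.basic.metrics import configure_metrics, get_best_metric_mode

    metrics_list, per_task, mode = configure_metrics("binary", None, ["label"])
    print(metrics_list[0], mode)  # 'auc' 'max'

    print(get_best_metric_mode("logloss", "binary"))    # 'min'
    print(get_best_metric_mode("ndcg@10", "matching"))  # 'max' (@K ranking prefix)
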
@@ -400,80 +450,111 @@
     y_true: np.ndarray,
     y_pred: np.ndarray,
     task_type: str,
-    user_ids: np.ndarray | None = None
+    user_ids: np.ndarray | None = None,
 ) -> float:
     """Compute a single metric given true and predicted values."""
     y_p_binary = (y_pred > 0.5).astype(int)
-
+
     try:
         metric_lower = metric.lower()
-
+
         # recall@K
-        if metric_lower.startswith(…
-            k = int(metric_lower.split(…
+        if metric_lower.startswith("recall@"):
+            k = int(metric_lower.split("@")[1])
             return _compute_recall_at_k(y_true, y_pred, user_ids, k)

         # precision@K
-        if metric_lower.startswith(…
-            k = int(metric_lower.split(…
+        if metric_lower.startswith("precision@"):
+            k = int(metric_lower.split("@")[1])
             return _compute_precision_at_k(y_true, y_pred, user_ids, k)

         # hitrate@K / hr@K
-        if metric_lower.startswith(…
-            k_str = metric_lower.split(…
+        if metric_lower.startswith("hitrate@") or metric_lower.startswith("hr@"):
+            k_str = metric_lower.split("@")[1]
             k = int(k_str)
             return _compute_hitrate_at_k(y_true, y_pred, user_ids, k)

         # mrr@K
-        if metric_lower.startswith(…
-            k = int(metric_lower.split(…
+        if metric_lower.startswith("mrr@"):
+            k = int(metric_lower.split("@")[1])
             return _compute_mrr_at_k(y_true, y_pred, user_ids, k)

         # ndcg@K
-        if metric_lower.startswith(…
-            k = int(metric_lower.split(…
+        if metric_lower.startswith("ndcg@"):
+            k = int(metric_lower.split("@")[1])
             return _compute_ndcg_at_k(y_true, y_pred, user_ids, k)

         # map@K
-        if metric_lower.startswith(…
-            k = int(metric_lower.split(…
+        if metric_lower.startswith("map@"):
+            k = int(metric_lower.split("@")[1])
             return _compute_map_at_k(y_true, y_pred, user_ids, k)

         # cosine for matching task
-        if metric_lower == …
+        if metric_lower == "cosine":
             return _compute_cosine_separation(y_true, y_pred)
-
-        if metric == …
-            value = float(…
-        …
+
+        if metric == "auc":
+            value = float(
+                roc_auc_score(
+                    y_true,
+                    y_pred,
+                    average="macro" if task_type == "multilabel" else None,
+                )
+            )
+        elif metric == "gauc":
             value = float(compute_gauc(y_true, y_pred, user_ids))
-        elif metric == …
+        elif metric == "ks":
             value = float(compute_ks(y_true, y_pred))
-        elif metric == …
+        elif metric == "logloss":
             value = float(log_loss(y_true, y_pred))
-        elif metric in (…
+        elif metric in ("accuracy", "acc"):
             value = float(accuracy_score(y_true, y_p_binary))
-        elif metric == …
-            value = float(…
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
+        elif metric == "precision":
+            value = float(
+                precision_score(
+                    y_true,
+                    y_p_binary,
+                    average="samples" if task_type == "multilabel" else "binary",
+                    zero_division=0,
+                )
+            )
+        elif metric == "recall":
+            value = float(
+                recall_score(
+                    y_true,
+                    y_p_binary,
+                    average="samples" if task_type == "multilabel" else "binary",
+                    zero_division=0,
+                )
+            )
+        elif metric == "f1":
+            value = float(
+                f1_score(
+                    y_true,
+                    y_p_binary,
+                    average="samples" if task_type == "multilabel" else "binary",
+                    zero_division=0,
+                )
+            )
+        elif metric == "micro_f1":
+            value = float(
+                f1_score(y_true, y_p_binary, average="micro", zero_division=0)
+            )
+        elif metric == "macro_f1":
+            value = float(
+                f1_score(y_true, y_p_binary, average="macro", zero_division=0)
+            )
+        elif metric == "mse":
             value = float(mean_squared_error(y_true, y_pred))
-        elif metric == …
+        elif metric == "mae":
             value = float(mean_absolute_error(y_true, y_pred))
-        elif metric == …
+        elif metric == "rmse":
             value = float(np.sqrt(mean_squared_error(y_true, y_pred)))
-        elif metric == …
+        elif metric == "r2":
             value = float(r2_score(y_true, y_pred))
-        elif metric == …
+        elif metric == "mape":
             value = float(compute_mape(y_true, y_pred))
-        elif metric == …
+        elif metric == "msle":
             value = float(compute_msle(y_true, y_pred))
         else:
             logging.warning(f"Metric '{metric}' is not supported, returning 0.0")
@@ -481,35 +562,40 @@
     except Exception as exception:
         logging.warning(f"Failed to compute metric {metric}: {exception}")
         value = 0.0
-
+
     return value

+
 def evaluate_metrics(
     y_true: np.ndarray | None,
     y_pred: np.ndarray | None,
-    metrics: list[str],
-    task: str | list[str],
-    target_names: list[str],
-    task_specific_metrics: …
-    …
-) …
+    metrics: list[str],  # example: ['auc', 'logloss']
+    task: str | list[str],  # example: 'binary' or ['binary', 'regression']
+    target_names: list[str],  # example: ['target1', 'target2']
+    task_specific_metrics: (
+        dict[str, list[str]] | None
+    ) = None,  # example: {'target1': ['auc', 'logloss'], 'target2': ['mse']}
+    user_ids: np.ndarray | None = None,  # example: User IDs for GAUC computation
+) -> dict:  # {'auc': 0.75, 'logloss': 0.45, 'mse_target2': 3.2}
     """Evaluate specified metrics for given true and predicted values."""
     result = {}
-
+
     if y_true is None or y_pred is None:
         return result
-
+
     # Main evaluation logic
     primary_task = task[0] if isinstance(task, list) else task
     nums_task = len(task) if isinstance(task, list) else 1
-
+
     # Single task evaluation
     if nums_task == 1:
        for metric in metrics:
            metric_lower = metric.lower()
-            value = compute_single_metric(…
+            value = compute_single_metric(
+                metric_lower, y_true, y_pred, primary_task, user_ids
+            )
             result[metric_lower] = value
-
+
     # Multi-task evaluation
     else:
         for metric in metrics:
@@ -519,7 +605,9 @@ def evaluate_metrics(
             should_compute = True
             if task_specific_metrics is not None and task_idx < len(target_names):
                 task_name = target_names[task_idx]
-                should_compute = metric_lower in task_specific_metrics.get(…
+                should_compute = metric_lower in task_specific_metrics.get(
+                    task_name, []
+                )
             else:
                 # Get task type for specific index
                 if isinstance(task, list) and task_idx < len(task):
@@ -527,31 +615,51 @@
                 elif isinstance(task, str):
                     task_type = task
                 else:
-                    task_type = …
-
-                if task_type in […
-                    should_compute = metric_lower in {…
-                    …
-                    …
-                    …
+                    task_type = "binary"
+
+                if task_type in ["binary", "multilabel"]:
+                    should_compute = metric_lower in {
+                        "auc",
+                        "ks",
+                        "logloss",
+                        "accuracy",
+                        "acc",
+                        "precision",
+                        "recall",
+                        "f1",
+                        "micro_f1",
+                        "macro_f1",
+                    }
+                elif task_type == "regression":
+                    should_compute = metric_lower in {
+                        "mse",
+                        "mae",
+                        "rmse",
+                        "r2",
+                        "mape",
+                        "msle",
+                    }
+
             if not should_compute:
                 continue
-
+
             target_name = target_names[task_idx]
-
+
             # Get task type for specific index
             if isinstance(task, list) and task_idx < len(task):
                 task_type = task[task_idx]
             elif isinstance(task, str):
                 task_type = task
             else:
-                task_type = …
-
+                task_type = "binary"
+
             y_true_task = y_true[:, task_idx]
             y_pred_task = y_pred[:, task_idx]
-
+
             # Compute metric
-            value = compute_single_metric(…
-            …
-            …
+            value = compute_single_metric(
+                metric_lower, y_true_task, y_pred_task, task_type, user_ids
+            )
+            result[f"{metric_lower}_{target_name}"] = value
+
     return result