nextrec 0.1.4__py3-none-any.whl → 0.1.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__init__.py +4 -4
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +9 -10
- nextrec/basic/callback.py +0 -1
- nextrec/basic/dataloader.py +127 -168
- nextrec/basic/features.py +27 -24
- nextrec/basic/layers.py +159 -328
- nextrec/basic/loggers.py +37 -50
- nextrec/basic/metrics.py +147 -255
- nextrec/basic/model.py +462 -817
- nextrec/data/__init__.py +5 -5
- nextrec/data/data_utils.py +12 -16
- nextrec/data/preprocessor.py +252 -276
- nextrec/loss/__init__.py +12 -12
- nextrec/loss/loss_utils.py +22 -30
- nextrec/loss/match_losses.py +83 -116
- nextrec/models/match/__init__.py +5 -5
- nextrec/models/match/dssm.py +61 -70
- nextrec/models/match/dssm_v2.py +51 -61
- nextrec/models/match/mind.py +71 -89
- nextrec/models/match/sdm.py +81 -93
- nextrec/models/match/youtube_dnn.py +53 -62
- nextrec/models/multi_task/esmm.py +43 -49
- nextrec/models/multi_task/mmoe.py +56 -65
- nextrec/models/multi_task/ple.py +65 -92
- nextrec/models/multi_task/share_bottom.py +42 -48
- nextrec/models/ranking/__init__.py +7 -7
- nextrec/models/ranking/afm.py +30 -39
- nextrec/models/ranking/autoint.py +57 -70
- nextrec/models/ranking/dcn.py +35 -43
- nextrec/models/ranking/deepfm.py +28 -34
- nextrec/models/ranking/dien.py +79 -115
- nextrec/models/ranking/din.py +60 -84
- nextrec/models/ranking/fibinet.py +35 -51
- nextrec/models/ranking/fm.py +26 -28
- nextrec/models/ranking/masknet.py +31 -31
- nextrec/models/ranking/pnn.py +31 -30
- nextrec/models/ranking/widedeep.py +31 -36
- nextrec/models/ranking/xdeepfm.py +39 -46
- nextrec/utils/__init__.py +9 -9
- nextrec/utils/embedding.py +1 -1
- nextrec/utils/initializer.py +15 -23
- nextrec/utils/optimizer.py +10 -14
- {nextrec-0.1.4.dist-info → nextrec-0.1.8.dist-info}/METADATA +16 -7
- nextrec-0.1.8.dist-info/RECORD +51 -0
- nextrec-0.1.4.dist-info/RECORD +0 -51
- {nextrec-0.1.4.dist-info → nextrec-0.1.8.dist-info}/WHEEL +0 -0
- {nextrec-0.1.4.dist-info → nextrec-0.1.8.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/metrics.py
CHANGED
@@ -5,55 +5,33 @@ Date: create on 27/10/2025
 Author:
 Yang Zhou,zyaztec@gmail.com
 """
-
 import logging
 import numpy as np
 from sklearn.metrics import (
-    roc_auc_score,
-    log_loss,
-    mean_squared_error,
-    mean_absolute_error,
-    accuracy_score,
-    precision_score,
-    recall_score,
-    f1_score,
-    r2_score,
+    roc_auc_score, log_loss, mean_squared_error, mean_absolute_error,
+    accuracy_score, precision_score, recall_score, f1_score, r2_score,
 )
 
 
-CLASSIFICATION_METRICS = {
-    "auc",
-    "gauc",
-    "ks",
-    "logloss",
-    "accuracy",
-    "acc",
-    "precision",
-    "recall",
-    "f1",
-    "micro_f1",
-    "macro_f1",
-}
-REGRESSION_METRICS = {"mse", "mae", "rmse", "r2", "mape", "msle"}
+CLASSIFICATION_METRICS = {'auc', 'gauc', 'ks', 'logloss', 'accuracy', 'acc', 'precision', 'recall', 'f1', 'micro_f1', 'macro_f1'}
+REGRESSION_METRICS = {'mse', 'mae', 'rmse', 'r2', 'mape', 'msle'}
 TASK_DEFAULT_METRICS = {
-    "binary": ["auc", "gauc", "ks", "logloss", "accuracy", "precision", "recall", "f1"],
-    "regression": ["mse", "mae", "rmse", "r2", "mape"],
-    "multilabel": ["auc", "hamming_loss", "subset_accuracy", "micro_f1", "macro_f1"],
-    "matching": ["auc", "gauc", "precision@10", "hitrate@10", "map@10", "cosine"]
-    + [f"recall@{k}" for k in (5, 10, 20)]
-    + [f"ndcg@{k}" for k in (5, 10, 20)]
-    + [f"mrr@{k}" for k in (5, 10, 20)],
+    'binary': ['auc', 'gauc', 'ks', 'logloss', 'accuracy', 'precision', 'recall', 'f1'],
+    'regression': ['mse', 'mae', 'rmse', 'r2', 'mape'],
+    'multilabel': ['auc', 'hamming_loss', 'subset_accuracy', 'micro_f1', 'macro_f1'],
+    'matching': ['auc', 'gauc', 'precision@10', 'hitrate@10', 'map@10','cosine']+ [f'recall@{k}' for k in (5,10,20)] + [f'ndcg@{k}' for k in (5,10,20)] + [f'mrr@{k}' for k in (5,10,20)]
 }
 
 
+
 def compute_ks(y_true: np.ndarray, y_pred: np.ndarray) -> float:
     """Compute Kolmogorov-Smirnov statistic."""
     sorted_indices = np.argsort(y_pred)[::-1]
     y_true_sorted = y_true[sorted_indices]
-
+
     n_pos = np.sum(y_true_sorted == 1)
     n_neg = np.sum(y_true_sorted == 0)
-
+
     if n_pos > 0 and n_neg > 0:
         cum_pos_rate = np.cumsum(y_true_sorted == 1) / n_pos
         cum_neg_rate = np.cumsum(y_true_sorted == 0) / n_neg
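
The metric registries are collapsed onto single lines without changing their contents, and the 'matching' defaults are still assembled by concatenating fixed names with @K variants. A standalone sketch of what that entry expands to (it mirrors the definition above; nothing here is imported from nextrec):

    matching_defaults = (
        ['auc', 'gauc', 'precision@10', 'hitrate@10', 'map@10', 'cosine']
        + [f'recall@{k}' for k in (5, 10, 20)]
        + [f'ndcg@{k}' for k in (5, 10, 20)]
        + [f'mrr@{k}' for k in (5, 10, 20)]
    )
    print(matching_defaults)
    # ['auc', 'gauc', 'precision@10', 'hitrate@10', 'map@10', 'cosine',
    #  'recall@5', 'recall@10', 'recall@20', 'ndcg@5', 'ndcg@10', 'ndcg@20',
    #  'mrr@5', 'mrr@10', 'mrr@20']
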
@@ -66,9 +44,7 @@ def compute_mape(y_true: np.ndarray, y_pred: np.ndarray) -> float:
     """Compute Mean Absolute Percentage Error."""
     mask = y_true != 0
     if np.any(mask):
-        return float(
-            np.mean(np.abs((y_true[mask] - y_pred[mask]) / y_true[mask])) * 100
-        )
+        return float(np.mean(np.abs((y_true[mask] - y_pred[mask]) / y_true[mask])) * 100)
     return 0.0
 
 
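
compute_mape masks out zero targets before averaging, so zero labels are skipped rather than dividing by zero. The same computation on toy values:

    import numpy as np

    y_true = np.array([0.0, 2.0, 4.0])  # the zero target is dropped by the mask
    y_pred = np.array([1.0, 1.0, 5.0])
    mask = y_true != 0
    print(np.mean(np.abs((y_true[mask] - y_pred[mask]) / y_true[mask])) * 100)
    # (|2-1|/2 + |4-5|/4) / 2 * 100 = 37.5
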
@@ -79,7 +55,9 @@ def compute_msle(y_true: np.ndarray, y_pred: np.ndarray) -> float:
 
 
 def compute_gauc(
-    y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None = None
+    y_true: np.ndarray, 
+    y_pred: np.ndarray, 
+    user_ids: np.ndarray | None = None
 ) -> float:
     if user_ids is None:
         # If no user_ids provided, fall back to regular AUC
@@ -87,37 +65,37 @@ def compute_gauc(
             return float(roc_auc_score(y_true, y_pred))
         except:
             return 0.0
-
+
     # Group by user_id and calculate AUC for each user
     user_aucs = []
     user_weights = []
-
+
     unique_users = np.unique(user_ids)
-
+
     for user_id in unique_users:
         mask = user_ids == user_id
         user_y_true = y_true[mask]
         user_y_pred = y_pred[mask]
-
+
         # Skip users with only one class (cannot compute AUC)
         if len(np.unique(user_y_true)) < 2:
             continue
-
+
         try:
             user_auc = roc_auc_score(user_y_true, user_y_pred)
            user_aucs.append(user_auc)
            user_weights.append(len(user_y_true))
        except:
            continue
-
+
    if len(user_aucs) == 0:
        return 0.0
-
+
    # Weighted average
    user_aucs = np.array(user_aucs)
    user_weights = np.array(user_weights)
    gauc = float(np.sum(user_aucs * user_weights) / np.sum(user_weights))
-
+
    return gauc
 
 
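
compute_gauc averages per-user AUCs weighted by each user's sample count and skips users whose labels contain only one class. The same grouping and weighting, sketched standalone on toy data:

    import numpy as np
    from sklearn.metrics import roc_auc_score

    user_ids = np.array(['a', 'a', 'a', 'b', 'b'])
    y_true = np.array([1, 0, 1, 0, 1])
    y_pred = np.array([0.9, 0.2, 0.6, 0.4, 0.8])

    aucs, weights = [], []
    for u in np.unique(user_ids):
        m = user_ids == u
        if len(np.unique(y_true[m])) < 2:  # single-class users are skipped
            continue
        aucs.append(roc_auc_score(y_true[m], y_pred[m]))
        weights.append(m.sum())
    print(np.sum(np.array(aucs) * np.array(weights)) / np.sum(weights))
    # (1.0 * 3 + 1.0 * 2) / 5 = 1.0
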
@@ -125,7 +103,7 @@ def _group_indices_by_user(user_ids: np.ndarray, n_samples: int) -> list[np.ndar
     """Group sample indices by user_id. If user_ids is None, treat all as one group."""
     if user_ids is None:
         return [np.arange(n_samples)]
-
+
     user_ids = np.asarray(user_ids)
     if user_ids.shape[0] != n_samples:
         logging.warning(
@@ -135,19 +113,17 @@ def _group_indices_by_user(user_ids: np.ndarray, n_samples: int) -> list[np.ndar
             n_samples,
         )
         return [np.arange(n_samples)]
-
+
     unique_users = np.unique(user_ids)
     groups = [np.where(user_ids == u)[0] for u in unique_users]
     return groups
 
 
-def _compute_precision_at_k(
-    y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int
-) -> float:
+def _compute_precision_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int) -> float:
     y_true = (y_true > 0).astype(int)
     n = len(y_true)
     groups = _group_indices_by_user(user_ids, n)
-
+
     precisions = []
     for idx in groups:
         if idx.size == 0:
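
The @K helpers change only in formatting; the per-user logic is untouched. For reference, Precision@K ranks one group by score, takes the top min(k, group size), and divides the hits by that cutoff:

    import numpy as np

    y_true = np.array([0, 1, 1, 0])
    y_pred = np.array([0.1, 0.9, 0.8, 0.3])
    k = 2
    order = np.argsort(y_pred)[::-1]  # scores ranked 0.9, 0.8, 0.3, 0.1
    hits = y_true[order[:k]].sum()    # both top-2 items are positives
    print(hits / float(k))            # 1.0
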
@@ -159,18 +135,16 @@ def _compute_precision_at_k(
         topk = order[:k_user]
         hits = labels[topk].sum()
         precisions.append(hits / float(k_user))
-
+
     return float(np.mean(precisions)) if precisions else 0.0
 
 
-def _compute_recall_at_k(
-    y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int
-) -> float:
+def _compute_recall_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int) -> float:
     """Compute Recall@K."""
     y_true = (y_true > 0).astype(int)
     n = len(y_true)
     groups = _group_indices_by_user(user_ids, n)
-
+
     recalls = []
     for idx in groups:
         if idx.size == 0:
@@ -185,18 +159,16 @@ def _compute_recall_at_k(
         topk = order[:k_user]
         hits = labels[topk].sum()
         recalls.append(hits / float(num_pos))
-
+
     return float(np.mean(recalls)) if recalls else 0.0
 
 
-def _compute_hitrate_at_k(
-    y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int
-) -> float:
+def _compute_hitrate_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int) -> float:
     """Compute HitRate@K."""
     y_true = (y_true > 0).astype(int)
     n = len(y_true)
     groups = _group_indices_by_user(user_ids, n)
-
+
     hits_per_user = []
     for idx in groups:
         if idx.size == 0:
@@ -210,18 +182,16 @@ def _compute_hitrate_at_k(
         topk = order[:k_user]
         hits = labels[topk].sum()
         hits_per_user.append(1.0 if hits > 0 else 0.0)
-
+
     return float(np.mean(hits_per_user)) if hits_per_user else 0.0
 
 
-def _compute_mrr_at_k(
-    y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int
-) -> float:
+def _compute_mrr_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int) -> float:
     """Compute MRR@K."""
     y_true = (y_true > 0).astype(int)
     n = len(y_true)
     groups = _group_indices_by_user(user_ids, n)
-
+
     mrrs = []
     for idx in groups:
         if idx.size == 0:
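
MRR@K scores each group by the reciprocal rank of its first positive within the top K, or 0.0 if none appears. Worked on a single group:

    import numpy as np

    y_true = np.array([0, 0, 1, 1])
    y_pred = np.array([0.9, 0.1, 0.8, 0.2])
    ranked = y_true[np.argsort(y_pred)[::-1]]  # labels in score order: [0, 1, 1, 0]
    rr = 0.0
    for rank, lab in enumerate(ranked[:3], start=1):
        if lab > 0:
            rr = 1.0 / rank  # first positive sits at rank 2
            break
    print(rr)  # 0.5
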
@@ -234,14 +204,14 @@ def _compute_mrr_at_k(
         k_user = min(k, idx.size)
         topk = order[:k_user]
         ranked_labels = labels[order]
-
+
         rr = 0.0
         for rank, lab in enumerate(ranked_labels[:k_user], start=1):
             if lab > 0:
                 rr = 1.0 / rank
                 break
         mrrs.append(rr)
-
+
     return float(np.mean(mrrs)) if mrrs else 0.0
 
 
@@ -254,14 +224,12 @@ def _compute_dcg_at_k(labels: np.ndarray, k: int) -> float:
     return float(np.sum(gains / discounts))
 
 
-def _compute_ndcg_at_k(
-    y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int
-) -> float:
+def _compute_ndcg_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int) -> float:
     """Compute NDCG@K."""
     y_true = (y_true > 0).astype(int)
     n = len(y_true)
     groups = _group_indices_by_user(user_ids, n)
-
+
     ndcgs = []
     for idx in groups:
         if idx.size == 0:
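
_compute_ndcg_at_k normalizes each group's DCG by the DCG of the ideally sorted labels. Only the return line of _compute_dcg_at_k is visible in this hunk, so the sketch below assumes the standard binary-gain, log2-discount form; the helper name dcg_at_k is illustrative, not the package's:

    import numpy as np

    def dcg_at_k(labels, k):
        # assumed standard form: gain = label, discount = log2(rank + 1)
        labels = np.asarray(labels, dtype=float)[:k]
        return float(np.sum(labels / np.log2(np.arange(2, labels.size + 2))))

    ranked = np.array([1, 0, 1])   # labels in predicted order
    ideal = np.sort(ranked)[::-1]  # [1, 1, 0]
    print(dcg_at_k(ranked, 3) / dcg_at_k(ideal, 3))  # 1.5 / ~1.631, about 0.92
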
@@ -270,29 +238,27 @@ def _compute_ndcg_at_k(
         if labels.sum() == 0:
             continue
         scores = y_pred[idx]
-
+
         order = np.argsort(scores)[::-1]
         ranked_labels = labels[order]
         dcg = _compute_dcg_at_k(ranked_labels, k)
-
+
         # ideal DCG
         ideal_labels = np.sort(labels)[::-1]
         idcg = _compute_dcg_at_k(ideal_labels, k)
         if idcg == 0.0:
             continue
         ndcgs.append(dcg / idcg)
-
+
     return float(np.mean(ndcgs)) if ndcgs else 0.0
 
 
-def _compute_map_at_k(
-    y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int
-) -> float:
+def _compute_map_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int) -> float:
     """Mean Average Precision@K."""
     y_true = (y_true > 0).astype(int)
     n = len(y_true)
     groups = _group_indices_by_user(user_ids, n)
-
+
     aps = []
     for idx in groups:
         if idx.size == 0:
@@ -301,23 +267,23 @@ def _compute_map_at_k(
         num_pos = labels.sum()
         if num_pos == 0:
             continue
-
+
         scores = y_pred[idx]
         order = np.argsort(scores)[::-1]
         k_user = min(k, idx.size)
-
+
         hits = 0
         sum_precisions = 0.0
         for rank, i in enumerate(order[:k_user], start=1):
             if labels[i] > 0:
                 hits += 1
                 sum_precisions += hits / float(rank)
-
+
         if hits == 0:
             aps.append(0.0)
         else:
             aps.append(sum_precisions / float(num_pos))
-
+
     return float(np.mean(aps)) if aps else 0.0
 
 
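
Note that _compute_map_at_k divides the summed precisions by num_pos, the total number of positives in the group, so a positive that falls outside the top K still lowers AP. Worked on a single group:

    import numpy as np

    y_true = np.array([1, 0, 1])   # two positives, one outside the top 2
    y_pred = np.array([0.9, 0.8, 0.1])
    order = np.argsort(y_pred)[::-1]
    hits, sum_prec = 0, 0.0
    for rank, i in enumerate(order[:2], start=1):
        if y_true[i] > 0:
            hits += 1
            sum_prec += hits / float(rank)
    print(sum_prec / float(y_true.sum()))  # (1/1) / 2 = 0.5, not 1.0
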
@@ -326,26 +292,24 @@ def _compute_cosine_separation(y_true: np.ndarray, y_pred: np.ndarray) -> float:
     y_true = (y_true > 0).astype(int)
     pos_mask = y_true == 1
     neg_mask = y_true == 0
-
+
     if not np.any(pos_mask) or not np.any(neg_mask):
         return 0.0
-
+
     pos_mean = float(np.mean(y_pred[pos_mask]))
     neg_mean = float(np.mean(y_pred[neg_mask]))
     return pos_mean - neg_mean
 
 
 def configure_metrics(
-    task: str | list[str],
-    metrics: (
-        list[str] | dict[str, list[str]] | None
-    ),  # ['auc', 'logloss'] or {'task1': ['auc'], 'task2': ['mse']}
-    target_names: list[str],  # ['target1', 'target2']
+    task: str | list[str], # 'binary' or ['binary', 'regression']
+    metrics: list[str] | dict[str, list[str]] | None, # ['auc', 'logloss'] or {'task1': ['auc'], 'task2': ['mse']}
+    target_names: list[str] # ['target1', 'target2']
 ) -> tuple[list[str], dict[str, list[str]] | None, str]:
     """Configure metrics based on task and user input."""
     primary_task = task[0] if isinstance(task, list) else task
     nums_task = len(task) if isinstance(task, list) else 1
-
+
     metrics_list: list[str] = []
     task_specific_metrics: dict[str, list[str]] | None = None
 
@@ -387,62 +351,48 @@ configure_metrics(
     if primary_task not in TASK_DEFAULT_METRICS:
         raise ValueError(f"Unsupported task type: {primary_task}")
     metrics_list = TASK_DEFAULT_METRICS[primary_task]
-
+
     if not metrics_list:
         # Inline get_default_metrics_for_task logic
         if primary_task not in TASK_DEFAULT_METRICS:
             raise ValueError(f"Unsupported task type: {primary_task}")
         metrics_list = TASK_DEFAULT_METRICS[primary_task]
-
+
     best_metrics_mode = get_best_metric_mode(metrics_list[0], primary_task)
-
+
     return metrics_list, task_specific_metrics, best_metrics_mode
 
 
 def get_best_metric_mode(first_metric: str, primary_task: str) -> str:
     """Determine if metric should be maximized or minimized."""
     first_metric_lower = first_metric.lower()
-
+
     # Metrics that should be maximized
-    if first_metric_lower in {
-        "auc",
-        "gauc",
-        "ks",
-        "accuracy",
-        "acc",
-        "precision",
-        "recall",
-        "f1",
-        "r2",
-        "micro_f1",
-        "macro_f1",
-    }:
-        return "max"
-
+    if first_metric_lower in {'auc', 'gauc', 'ks', 'accuracy', 'acc', 'precision', 'recall', 'f1', 'r2', 'micro_f1', 'macro_f1'}:
+        return 'max'
+
     # Ranking metrics that should be maximized (with @K suffix)
-    if (
-        first_metric_lower.startswith("recall@")
-        or first_metric_lower.startswith("precision@")
-        or first_metric_lower.startswith("hitrate@")
-        or first_metric_lower.startswith("hr@")
-        or first_metric_lower.startswith("mrr@")
-        or first_metric_lower.startswith("ndcg@")
-        or first_metric_lower.startswith("map@")
-    ):
-        return "max"
-
+    if (first_metric_lower.startswith('recall@') or
+        first_metric_lower.startswith('precision@') or
+        first_metric_lower.startswith('hitrate@') or
+        first_metric_lower.startswith('hr@') or
+        first_metric_lower.startswith('mrr@') or
+        first_metric_lower.startswith('ndcg@') or
+        first_metric_lower.startswith('map@')):
+        return 'max'
+
     # Cosine separation should be maximized
-    if first_metric_lower == "cosine":
-        return "max"
-
+    if first_metric_lower == 'cosine':
+        return 'max'
+
     # Metrics that should be minimized
-    if first_metric_lower in {"logloss", "mse", "mae", "rmse", "mape", "msle"}:
-        return "min"
-
+    if first_metric_lower in {'logloss', 'mse', 'mae', 'rmse', 'mape', 'msle'}:
+        return 'min'
+
     # Default based on task type
-    if primary_task == "regression":
-        return "min"
-    return "max"
+    if primary_task == 'regression':
+        return 'min'
+    return 'max'
 
 
 def compute_single_metric(
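
How the two functions compose, sketched as a usage example. It assumes nextrec 0.1.8 is installed and that passing metrics=None takes the default-metrics path shown in this hunk; the signatures themselves are as they appear in the diff:

    from nextrec.basic.metrics import configure_metrics, get_best_metric_mode

    # assumption: metrics=None falls back to TASK_DEFAULT_METRICS[task]
    metrics_list, task_specific, mode = configure_metrics(
        task='binary', metrics=None, target_names=['ctr'],
    )
    print(metrics_list)  # ['auc', 'gauc', 'ks', 'logloss', 'accuracy', 'precision', 'recall', 'f1']
    print(mode)          # 'max', since the first metric 'auc' is maximized
    print(get_best_metric_mode('logloss', 'binary'))  # 'min'
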
@@ -450,111 +400,80 @@ def compute_single_metric(
     y_true: np.ndarray,
     y_pred: np.ndarray,
     task_type: str,
-    user_ids: np.ndarray | None = None
+    user_ids: np.ndarray | None = None
 ) -> float:
     """Compute a single metric given true and predicted values."""
     y_p_binary = (y_pred > 0.5).astype(int)
-
+
     try:
         metric_lower = metric.lower()
-
+
         # recall@K
-        if metric_lower.startswith("recall@"):
-            k = int(metric_lower.split("@")[1])
+        if metric_lower.startswith('recall@'):
+            k = int(metric_lower.split('@')[1])
             return _compute_recall_at_k(y_true, y_pred, user_ids, k)
 
         # precision@K
-        if metric_lower.startswith("precision@"):
-            k = int(metric_lower.split("@")[1])
+        if metric_lower.startswith('precision@'):
+            k = int(metric_lower.split('@')[1])
             return _compute_precision_at_k(y_true, y_pred, user_ids, k)
 
         # hitrate@K / hr@K
-        if metric_lower.startswith("hitrate@") or metric_lower.startswith("hr@"):
-            k_str = metric_lower.split("@")[1]
+        if metric_lower.startswith('hitrate@') or metric_lower.startswith('hr@'):
+            k_str = metric_lower.split('@')[1]
             k = int(k_str)
             return _compute_hitrate_at_k(y_true, y_pred, user_ids, k)
 
         # mrr@K
-        if metric_lower.startswith("mrr@"):
-            k = int(metric_lower.split("@")[1])
+        if metric_lower.startswith('mrr@'):
+            k = int(metric_lower.split('@')[1])
             return _compute_mrr_at_k(y_true, y_pred, user_ids, k)
 
         # ndcg@K
-        if metric_lower.startswith("ndcg@"):
-            k = int(metric_lower.split("@")[1])
+        if metric_lower.startswith('ndcg@'):
+            k = int(metric_lower.split('@')[1])
             return _compute_ndcg_at_k(y_true, y_pred, user_ids, k)
 
         # map@K
-        if metric_lower.startswith("map@"):
-            k = int(metric_lower.split("@")[1])
+        if metric_lower.startswith('map@'):
+            k = int(metric_lower.split('@')[1])
             return _compute_map_at_k(y_true, y_pred, user_ids, k)
 
         # cosine for matching task
-        if metric_lower == "cosine":
+        if metric_lower == 'cosine':
             return _compute_cosine_separation(y_true, y_pred)
-
-        if metric == "auc":
-            value = float(
-                roc_auc_score(
-                    y_true,
-                    y_pred,
-                    average="macro" if task_type == "multilabel" else None,
-                )
-            )
-        elif metric == "gauc":
+
+        if metric == 'auc':
+            value = float(roc_auc_score(y_true, y_pred, average='macro' if task_type == 'multilabel' else None))
+        elif metric == 'gauc':
             value = float(compute_gauc(y_true, y_pred, user_ids))
-        elif metric == "ks":
+        elif metric == 'ks':
             value = float(compute_ks(y_true, y_pred))
-        elif metric == "logloss":
+        elif metric == 'logloss':
             value = float(log_loss(y_true, y_pred))
-        elif metric in ("accuracy", "acc"):
+        elif metric in ('accuracy', 'acc'):
             value = float(accuracy_score(y_true, y_p_binary))
-        elif metric == "precision":
-            value = float(
-                precision_score(
-                    y_true,
-                    y_p_binary,
-                    average="samples" if task_type == "multilabel" else "binary",
-                    zero_division=0,
-                )
-            )
-        elif metric == "recall":
-            value = float(
-                recall_score(
-                    y_true,
-                    y_p_binary,
-                    average="samples" if task_type == "multilabel" else "binary",
-                    zero_division=0,
-                )
-            )
-        elif metric == "f1":
-            value = float(
-                f1_score(
-                    y_true,
-                    y_p_binary,
-                    average="samples" if task_type == "multilabel" else "binary",
-                    zero_division=0,
-                )
-            )
-        elif metric == "micro_f1":
-            value = float(
-                f1_score(y_true, y_p_binary, average="micro", zero_division=0)
-            )
-        elif metric == "macro_f1":
-            value = float(
-                f1_score(y_true, y_p_binary, average="macro", zero_division=0)
-            )
-        elif metric == "mse":
+        elif metric == 'precision':
+            value = float(precision_score(y_true, y_p_binary, average='samples' if task_type == 'multilabel' else 'binary', zero_division=0))
+        elif metric == 'recall':
+            value = float(recall_score(y_true, y_p_binary, average='samples' if task_type == 'multilabel' else 'binary', zero_division=0))
+        elif metric == 'f1':
+            value = float(f1_score(y_true, y_p_binary, average='samples' if task_type == 'multilabel' else 'binary', zero_division=0))
+        elif metric == 'micro_f1':
+            value = float(f1_score(y_true, y_p_binary, average='micro', zero_division=0))
+        elif metric == 'macro_f1':
+            value = float(f1_score(y_true, y_p_binary, average='macro', zero_division=0))
+        elif metric == 'mse':
             value = float(mean_squared_error(y_true, y_pred))
-        elif metric == "mae":
+        elif metric == 'mae':
             value = float(mean_absolute_error(y_true, y_pred))
-        elif metric == "rmse":
+        elif metric == 'rmse':
             value = float(np.sqrt(mean_squared_error(y_true, y_pred)))
-        elif metric == "r2":
+        elif metric == 'r2':
             value = float(r2_score(y_true, y_pred))
-        elif metric == "mape":
+        elif metric == 'mape':
             value = float(compute_mape(y_true, y_pred))
-        elif metric == "msle":
+        elif metric == 'msle':
             value = float(compute_msle(y_true, y_pred))
         else:
             logging.warning(f"Metric '{metric}' is not supported, returning 0.0")
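
compute_single_metric dispatches on the metric name, parsing @K suffixes by splitting on '@' before falling through to the scalar metrics. A usage sketch, assuming nextrec 0.1.8 is installed:

    import numpy as np
    from nextrec.basic.metrics import compute_single_metric

    y_true = np.array([1, 0, 1, 0, 1])
    y_pred = np.array([0.9, 0.2, 0.6, 0.4, 0.8])
    user_ids = np.array(['a', 'a', 'a', 'b', 'b'])

    print(compute_single_metric('recall@2', y_true, y_pred, 'matching', user_ids))  # 1.0
    print(compute_single_metric('auc', y_true, y_pred, 'binary'))                   # 1.0
    print(compute_single_metric('nope', y_true, y_pred, 'binary'))                  # warns, returns 0.0
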
@@ -562,40 +481,35 @@ def compute_single_metric(
     except Exception as exception:
         logging.warning(f"Failed to compute metric {metric}: {exception}")
         value = 0.0
-
+
     return value
 
-
 def evaluate_metrics(
     y_true: np.ndarray | None,
     y_pred: np.ndarray | None,
-    metrics: list[str],
-    task: str | list[str],
-    target_names: list[str],
-    task_specific_metrics: (
-        dict[str, list[str]] | None
-    ) = None,  # example: {'target1': ['auc', 'logloss'], 'target2': ['mse']}
-    user_ids: np.ndarray | None = None,  # example: User IDs for GAUC computation
-) -> dict:  # {'auc': 0.75, 'logloss': 0.45, 'mse_target2': 3.2}
+    metrics: list[str], # example: ['auc', 'logloss']
+    task: str | list[str], # example: 'binary' or ['binary', 'regression']
+    target_names: list[str], # example: ['target1', 'target2']
+    task_specific_metrics: dict[str, list[str]] | None = None, # example: {'target1': ['auc', 'logloss'], 'target2': ['mse']}
+    user_ids: np.ndarray | None = None # example: User IDs for GAUC computation
+) -> dict: # {'auc': 0.75, 'logloss': 0.45, 'mse_target2': 3.2}
     """Evaluate specified metrics for given true and predicted values."""
     result = {}
-
+
     if y_true is None or y_pred is None:
         return result
-
+
     # Main evaluation logic
     primary_task = task[0] if isinstance(task, list) else task
     nums_task = len(task) if isinstance(task, list) else 1
-
+
     # Single task evaluation
     if nums_task == 1:
         for metric in metrics:
             metric_lower = metric.lower()
-            value = compute_single_metric(
-                metric_lower, y_true, y_pred, primary_task, user_ids
-            )
+            value = compute_single_metric(metric_lower, y_true, y_pred, primary_task, user_ids)
             result[metric_lower] = value
-
+
     # Multi-task evaluation
     else:
         for metric in metrics:
@@ -605,9 +519,7 @@ def evaluate_metrics(
                 should_compute = True
                 if task_specific_metrics is not None and task_idx < len(target_names):
                     task_name = target_names[task_idx]
-                    should_compute = metric_lower in task_specific_metrics.get(
-                        task_name, []
-                    )
+                    should_compute = metric_lower in task_specific_metrics.get(task_name, [])
                 else:
                     # Get task type for specific index
                     if isinstance(task, list) and task_idx < len(task):
@@ -615,51 +527,31 @@
                     elif isinstance(task, str):
                         task_type = task
                     else:
-                        task_type = "binary"
-
-                    if task_type in ["binary", "multilabel"]:
-                        should_compute = metric_lower in {
-                            "auc",
-                            "ks",
-                            "logloss",
-                            "accuracy",
-                            "acc",
-                            "precision",
-                            "recall",
-                            "f1",
-                            "micro_f1",
-                            "macro_f1",
-                        }
-                    elif task_type == "regression":
-                        should_compute = metric_lower in {
-                            "mse",
-                            "mae",
-                            "rmse",
-                            "r2",
-                            "mape",
-                            "msle",
-                        }
-
+                        task_type = 'binary'
+
+                    if task_type in ['binary', 'multilabel']:
+                        should_compute = metric_lower in {'auc', 'ks', 'logloss', 'accuracy', 'acc', 'precision', 'recall', 'f1', 'micro_f1', 'macro_f1'}
+                    elif task_type == 'regression':
+                        should_compute = metric_lower in {'mse', 'mae', 'rmse', 'r2', 'mape', 'msle'}
+
                 if not should_compute:
                     continue
-
+
                 target_name = target_names[task_idx]
-
+
                 # Get task type for specific index
                 if isinstance(task, list) and task_idx < len(task):
                     task_type = task[task_idx]
                 elif isinstance(task, str):
                     task_type = task
                 else:
-                    task_type = "binary"
-
+                    task_type = 'binary'
+
                 y_true_task = y_true[:, task_idx]
                 y_pred_task = y_pred[:, task_idx]
-
+
                 # Compute metric
-                value = compute_single_metric(
-                    metric_lower, y_true_task, y_pred_task, task_type, user_ids
-                )
-                result[f"{metric_lower}_{target_name}"] = value
-
+                value = compute_single_metric(metric_lower, y_true_task, y_pred_task, task_type, user_ids)
+                result[f'{metric_lower}_{target_name}'] = value
+
     return result
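
In the multi-task path, each metric is computed per target column and keyed as '<metric>_<target>', with incompatible metric/task pairs skipped. A usage sketch, assuming nextrec 0.1.8 is installed and y_true/y_pred shaped (n_samples, n_tasks):

    import numpy as np
    from nextrec.basic.metrics import evaluate_metrics

    # column 0: binary CTR head; column 1: regression head
    y_true = np.array([[1, 2.0], [0, 3.5], [1, 1.0], [0, 4.0]])
    y_pred = np.array([[0.8, 2.2], [0.3, 3.0], [0.6, 1.5], [0.4, 4.1]])

    result = evaluate_metrics(
        y_true, y_pred,
        metrics=['auc', 'mse'],
        task=['binary', 'regression'],
        target_names=['ctr', 'duration'],
    )
    print(result)  # {'auc_ctr': 1.0, 'mse_duration': 0.1375}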