nextrec 0.2.7__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +4 -8
- nextrec/basic/callback.py +1 -1
- nextrec/basic/features.py +33 -25
- nextrec/basic/layers.py +164 -601
- nextrec/basic/loggers.py +4 -5
- nextrec/basic/metrics.py +39 -115
- nextrec/basic/model.py +257 -177
- nextrec/basic/session.py +1 -5
- nextrec/data/__init__.py +12 -0
- nextrec/data/data_utils.py +3 -27
- nextrec/data/dataloader.py +26 -34
- nextrec/data/preprocessor.py +2 -1
- nextrec/loss/listwise.py +6 -4
- nextrec/loss/loss_utils.py +10 -6
- nextrec/loss/pairwise.py +5 -3
- nextrec/loss/pointwise.py +7 -13
- nextrec/models/generative/__init__.py +5 -0
- nextrec/models/generative/hstu.py +399 -0
- nextrec/models/match/mind.py +110 -1
- nextrec/models/multi_task/esmm.py +46 -27
- nextrec/models/multi_task/mmoe.py +48 -30
- nextrec/models/multi_task/ple.py +156 -141
- nextrec/models/multi_task/poso.py +413 -0
- nextrec/models/multi_task/share_bottom.py +43 -26
- nextrec/models/ranking/__init__.py +2 -0
- nextrec/models/ranking/dcn.py +20 -1
- nextrec/models/ranking/dcn_v2.py +84 -0
- nextrec/models/ranking/deepfm.py +44 -18
- nextrec/models/ranking/dien.py +130 -27
- nextrec/models/ranking/masknet.py +13 -67
- nextrec/models/ranking/widedeep.py +39 -18
- nextrec/models/ranking/xdeepfm.py +34 -1
- nextrec/utils/common.py +26 -1
- nextrec/utils/optimizer.py +7 -3
- nextrec-0.3.2.dist-info/METADATA +312 -0
- nextrec-0.3.2.dist-info/RECORD +57 -0
- nextrec-0.2.7.dist-info/METADATA +0 -281
- nextrec-0.2.7.dist-info/RECORD +0 -54
- {nextrec-0.2.7.dist-info → nextrec-0.3.2.dist-info}/WHEEL +0 -0
- {nextrec-0.2.7.dist-info → nextrec-0.3.2.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/loggers.py
CHANGED
@@ -2,7 +2,8 @@
 NextRec Basic Loggers

 Date: create on 27/10/2025
-
+Checkpoint: edit on 29/11/2025
+Author: Yang Zhou, zyaztec@gmail.com
 """


@@ -10,10 +11,8 @@ import os
 import re
 import sys
 import copy
-import datetime
 import logging
-from
-from nextrec.basic.session import resolve_save_path, create_session
+from nextrec.basic.session import create_session

 ANSI_CODES = {
     'black': '\033[30m',
@@ -107,7 +106,7 @@ def setup_logger(session_id: str | os.PathLike | None = None):

     console_format = '%(message)s'
     file_format = '%(asctime)s - %(levelname)s - %(message)s'
-    date_format = '%H:%M:%S'
+    date_format = '%Y-%m-%d %H:%M:%S'

     logger = logging.getLogger()
     logger.setLevel(logging.INFO)
nextrec/basic/metrics.py
CHANGED
@@ -2,6 +2,7 @@
 Metrics computation and configuration for model evaluation.

 Date: create on 27/10/2025
+Checkpoint: edit on 29/11/2025
 Author: Yang Zhou,zyaztec@gmail.com
 """
 import logging
@@ -11,7 +12,6 @@ from sklearn.metrics import (
     accuracy_score, precision_score, recall_score, f1_score, r2_score,
 )

-
 CLASSIFICATION_METRICS = {'auc', 'gauc', 'ks', 'logloss', 'accuracy', 'acc', 'precision', 'recall', 'f1', 'micro_f1', 'macro_f1'}
 REGRESSION_METRICS = {'mse', 'mae', 'rmse', 'r2', 'mape', 'msle'}
 TASK_DEFAULT_METRICS = {
@@ -21,8 +21,6 @@ TASK_DEFAULT_METRICS = {
     'matching': ['auc', 'gauc', 'precision@10', 'hitrate@10', 'map@10','cosine']+ [f'recall@{k}' for k in (5,10,20)] + [f'ndcg@{k}' for k in (5,10,20)] + [f'mrr@{k}' for k in (5,10,20)]
 }

-
-
 def compute_ks(y_true: np.ndarray, y_pred: np.ndarray) -> float:
     """Compute Kolmogorov-Smirnov statistic."""
     sorted_indices = np.argsort(y_pred)[::-1]
@@ -38,7 +36,6 @@ def compute_ks(y_true: np.ndarray, y_pred: np.ndarray) -> float:
         return float(ks_value)
     return 0.0

-
 def compute_mape(y_true: np.ndarray, y_pred: np.ndarray) -> float:
     """Compute Mean Absolute Percentage Error."""
     mask = y_true != 0
@@ -46,83 +43,62 @@ def compute_mape(y_true: np.ndarray, y_pred: np.ndarray) -> float:
         return float(np.mean(np.abs((y_true[mask] - y_pred[mask]) / y_true[mask])) * 100)
     return 0.0

-
 def compute_msle(y_true: np.ndarray, y_pred: np.ndarray) -> float:
     """Compute Mean Squared Log Error."""
     y_pred_pos = np.maximum(y_pred, 0)
     return float(mean_squared_error(np.log1p(y_true), np.log1p(y_pred_pos)))

-
-def compute_gauc(
-    y_true: np.ndarray,
-    y_pred: np.ndarray,
-    user_ids: np.ndarray | None = None
-) -> float:
+def compute_gauc(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None = None) -> float:
     if user_ids is None:
         # If no user_ids provided, fall back to regular AUC
         try:
             return float(roc_auc_score(y_true, y_pred))
         except:
             return 0.0
-
     # Group by user_id and calculate AUC for each user
     user_aucs = []
     user_weights = []
-
     unique_users = np.unique(user_ids)
-
     for user_id in unique_users:
         mask = user_ids == user_id
         user_y_true = y_true[mask]
         user_y_pred = y_pred[mask]
-
         # Skip users with only one class (cannot compute AUC)
         if len(np.unique(user_y_true)) < 2:
             continue
-
         try:
             user_auc = roc_auc_score(user_y_true, user_y_pred)
             user_aucs.append(user_auc)
             user_weights.append(len(user_y_true))
         except:
             continue
-
     if len(user_aucs) == 0:
         return 0.0
-
     # Weighted average
     user_aucs = np.array(user_aucs)
     user_weights = np.array(user_weights)
     gauc = float(np.sum(user_aucs * user_weights) / np.sum(user_weights))
-
     return gauc

-
 def _group_indices_by_user(user_ids: np.ndarray, n_samples: int) -> list[np.ndarray]:
     """Group sample indices by user_id. If user_ids is None, treat all as one group."""
     if user_ids is None:
         return [np.arange(n_samples)]
-
     user_ids = np.asarray(user_ids)
     if user_ids.shape[0] != n_samples:
-        logging.warning(
-            "user_ids length (%d) != number of samples (%d), "
-            "treating all samples as a single group for ranking metrics.",
-            user_ids.shape[0],
-            n_samples,
-        )
+        logging.warning(f"[Metrics Warning: GAUC] user_ids length {user_ids.shape[0]} != number of samples {n_samples}, treating all samples as a single group for ranking metrics.")
         return [np.arange(n_samples)]
-
     unique_users = np.unique(user_ids)
     groups = [np.where(user_ids == u)[0] for u in unique_users]
     return groups

-
-
+def _compute_precision_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int) -> float:
+    """Compute Precision@K."""
+    if user_ids is None:
+        raise ValueError("[Metrics Error: Precision@K] user_ids must be provided for Precision@K computation.")
     y_true = (y_true > 0).astype(int)
     n = len(y_true)
     groups = _group_indices_by_user(user_ids, n)
-
     precisions = []
     for idx in groups:
         if idx.size == 0:
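The compute_gauc change above collapses the signature onto one line but keeps the same logic: an AUC per user, weighted by each user's sample count, skipping users whose labels contain only one class. A self-contained sketch of that computation on invented data:

# Minimal sketch of the GAUC computation shown above (toy values).
import numpy as np
from sklearn.metrics import roc_auc_score

y_true = np.array([1, 0, 1, 0, 0, 1])
y_pred = np.array([0.9, 0.2, 0.6, 0.4, 0.1, 0.8])
user_ids = np.array([1, 1, 1, 2, 2, 2])

aucs, weights = [], []
for user in np.unique(user_ids):
    mask = user_ids == user
    if len(np.unique(y_true[mask])) < 2:
        continue  # single-class users are skipped, as in compute_gauc
    aucs.append(roc_auc_score(y_true[mask], y_pred[mask]))
    weights.append(mask.sum())

gauc = float(np.sum(np.array(aucs) * np.array(weights)) / np.sum(weights))
print(gauc)  # 1.0 for this toy data: both users rank their positives first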
@@ -134,16 +110,15 @@ def _compute_precision_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np
         topk = order[:k_user]
         hits = labels[topk].sum()
         precisions.append(hits / float(k_user))
-
     return float(np.mean(precisions)) if precisions else 0.0

-
-def _compute_recall_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int) -> float:
+def _compute_recall_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int) -> float:
     """Compute Recall@K."""
+    if user_ids is None:
+        raise ValueError("[Metrics Error: Recall@K] user_ids must be provided for Recall@K computation.")
     y_true = (y_true > 0).astype(int)
     n = len(y_true)
     groups = _group_indices_by_user(user_ids, n)
-
     recalls = []
     for idx in groups:
         if idx.size == 0:
@@ -151,46 +126,44 @@ def _compute_recall_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.nd
         labels = y_true[idx]
         num_pos = labels.sum()
         if num_pos == 0:
-            continue #
+            continue # dont count users with no positive labels
         scores = y_pred[idx]
         order = np.argsort(scores)[::-1]
         k_user = min(k, idx.size)
         topk = order[:k_user]
         hits = labels[topk].sum()
         recalls.append(hits / float(num_pos))
-
     return float(np.mean(recalls)) if recalls else 0.0

-
-def _compute_hitrate_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int) -> float:
+def _compute_hitrate_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int) -> float:
     """Compute HitRate@K."""
+    if user_ids is None:
+        raise ValueError("[Metrics Error: HitRate@K] user_ids must be provided for HitRate@K computation.")
     y_true = (y_true > 0).astype(int)
     n = len(y_true)
     groups = _group_indices_by_user(user_ids, n)
-
     hits_per_user = []
     for idx in groups:
         if idx.size == 0:
             continue
         labels = y_true[idx]
         if labels.sum() == 0:
-            continue #
+            continue # dont count users with no positive labels
         scores = y_pred[idx]
         order = np.argsort(scores)[::-1]
         k_user = min(k, idx.size)
         topk = order[:k_user]
         hits = labels[topk].sum()
         hits_per_user.append(1.0 if hits > 0 else 0.0)
-
     return float(np.mean(hits_per_user)) if hits_per_user else 0.0

-
-def _compute_mrr_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int) -> float:
+def _compute_mrr_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int) -> float:
     """Compute MRR@K."""
+    if user_ids is None:
+        raise ValueError("[Metrics Error: MRR@K] user_ids must be provided for MRR@K computation.")
     y_true = (y_true > 0).astype(int)
     n = len(y_true)
     groups = _group_indices_by_user(user_ids, n)
-
     mrrs = []
     for idx in groups:
         if idx.size == 0:
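In 0.3.2 these @K helpers require user_ids and raise a ValueError otherwise; scores are ranked per user and users with no positive labels are skipped. A toy Recall@K computation mirroring that logic, with invented data:

# Toy illustration of the per-user Recall@K logic above.
import numpy as np

y_true = np.array([1, 0, 1, 0, 1, 0])
y_pred = np.array([0.9, 0.8, 0.1, 0.7, 0.6, 0.2])
user_ids = np.array(['a', 'a', 'a', 'b', 'b', 'b'])

k = 2
recalls = []
for user in np.unique(user_ids):
    idx = np.where(user_ids == user)[0]
    labels = (y_true[idx] > 0).astype(int)
    if labels.sum() == 0:
        continue  # as in 0.3.2, users without positives are not counted
    topk = np.argsort(y_pred[idx])[::-1][:k]
    recalls.append(labels[topk].sum() / labels.sum())

print(float(np.mean(recalls)))  # user a: 1 of 2 positives in top-2; user b: 1 of 1 -> 0.75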
@@ -203,17 +176,14 @@ def _compute_mrr_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarr
         k_user = min(k, idx.size)
         topk = order[:k_user]
         ranked_labels = labels[order]
-
         rr = 0.0
         for rank, lab in enumerate(ranked_labels[:k_user], start=1):
             if lab > 0:
                 rr = 1.0 / rank
                 break
         mrrs.append(rr)
-
     return float(np.mean(mrrs)) if mrrs else 0.0

-
 def _compute_dcg_at_k(labels: np.ndarray, k: int) -> float:
     k_user = min(k, labels.size)
     if k_user == 0:
@@ -222,13 +192,13 @@ def _compute_dcg_at_k(labels: np.ndarray, k: int) -> float:
     discounts = np.log2(np.arange(2, k_user + 2))
     return float(np.sum(gains / discounts))

-
-def _compute_ndcg_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int) -> float:
+def _compute_ndcg_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int) -> float:
     """Compute NDCG@K."""
+    if user_ids is None:
+        raise ValueError("[Metrics Error: NDCG@K] user_ids must be provided for NDCG@K computation.")
     y_true = (y_true > 0).astype(int)
     n = len(y_true)
     groups = _group_indices_by_user(user_ids, n)
-
     ndcgs = []
     for idx in groups:
         if idx.size == 0:
@@ -237,27 +207,25 @@ def _compute_ndcg_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndar
         if labels.sum() == 0:
             continue
         scores = y_pred[idx]
-
         order = np.argsort(scores)[::-1]
         ranked_labels = labels[order]
         dcg = _compute_dcg_at_k(ranked_labels, k)
-
         # ideal DCG
         ideal_labels = np.sort(labels)[::-1]
         idcg = _compute_dcg_at_k(ideal_labels, k)
         if idcg == 0.0:
             continue
         ndcgs.append(dcg / idcg)
-
     return float(np.mean(ndcgs)) if ndcgs else 0.0


-def _compute_map_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray
+def _compute_map_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int) -> float:
     """Mean Average Precision@K."""
+    if user_ids is None:
+        raise ValueError("[Metrics Error: MAP@K] user_ids must be provided for MAP@K computation.")
     y_true = (y_true > 0).astype(int)
     n = len(y_true)
     groups = _group_indices_by_user(user_ids, n)
-
     aps = []
     for idx in groups:
         if idx.size == 0:
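A worked NDCG@K example following the DCG formula above (gain at rank r discounted by log2(r+1)), assuming the gains are the top-K binarized labels, which matches the binarization performed in _compute_ndcg_at_k:

# Worked NDCG@3 example with invented ranked relevance [1, 0, 1].
import numpy as np

def dcg_at_k(labels, k):
    k_user = min(k, labels.size)
    gains = labels[:k_user]
    discounts = np.log2(np.arange(2, k_user + 2))
    return float(np.sum(gains / discounts))

ranked = np.array([1, 0, 1])      # labels in predicted order
ideal = np.sort(ranked)[::-1]     # [1, 1, 0]
ndcg = dcg_at_k(ranked, 3) / dcg_at_k(ideal, 3)
print(round(ndcg, 4))  # (1/log2(2) + 1/log2(4)) / (1/log2(2) + 1/log2(3)) ≈ 0.9197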
@@ -266,23 +234,19 @@ def _compute_map_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarr
         num_pos = labels.sum()
         if num_pos == 0:
             continue
-
         scores = y_pred[idx]
         order = np.argsort(scores)[::-1]
         k_user = min(k, idx.size)
-
         hits = 0
         sum_precisions = 0.0
         for rank, i in enumerate(order[:k_user], start=1):
             if labels[i] > 0:
                 hits += 1
                 sum_precisions += hits / float(rank)
-
         if hits == 0:
             aps.append(0.0)
         else:
             aps.append(sum_precisions / float(num_pos))
-
     return float(np.mean(aps)) if aps else 0.0


@@ -308,31 +272,22 @@ configure_metrics(
     """Configure metrics based on task and user input."""
     primary_task = task[0] if isinstance(task, list) else task
     nums_task = len(task) if isinstance(task, list) else 1
-
     metrics_list: list[str] = []
     task_specific_metrics: dict[str, list[str]] | None = None
-
     if isinstance(metrics, dict):
         metrics_list = []
         task_specific_metrics = {}
         for task_name, task_metrics in metrics.items():
             if task_name not in target_names:
-                logging.warning(
-                    "Task '%s' not found in targets %s, skipping its metrics",
-                    task_name,
-                    target_names,
-                )
+                logging.warning(f"[Metrics Warning] Task {task_name} not found in targets {target_names}, skipping its metrics")
                 continue
-
             lowered = [m.lower() for m in task_metrics]
             task_specific_metrics[task_name] = lowered
             for metric in lowered:
                 if metric not in metrics_list:
                     metrics_list.append(metric)
-
     elif metrics:
         metrics_list = [m.lower() for m in metrics]
-
     else:
         # No user provided metrics, derive per task type
         if nums_task > 1 and isinstance(task, list):
@@ -350,26 +305,20 @@ configure_metrics(
             if primary_task not in TASK_DEFAULT_METRICS:
                 raise ValueError(f"Unsupported task type: {primary_task}")
             metrics_list = TASK_DEFAULT_METRICS[primary_task]
-
     if not metrics_list:
         # Inline get_default_metrics_for_task logic
         if primary_task not in TASK_DEFAULT_METRICS:
             raise ValueError(f"Unsupported task type: {primary_task}")
         metrics_list = TASK_DEFAULT_METRICS[primary_task]
-
     best_metrics_mode = get_best_metric_mode(metrics_list[0], primary_task)
-
     return metrics_list, task_specific_metrics, best_metrics_mode

-
 def get_best_metric_mode(first_metric: str, primary_task: str) -> str:
     """Determine if metric should be maximized or minimized."""
     first_metric_lower = first_metric.lower()
-
     # Metrics that should be maximized
     if first_metric_lower in {'auc', 'gauc', 'ks', 'accuracy', 'acc', 'precision', 'recall', 'f1', 'r2', 'micro_f1', 'macro_f1'}:
         return 'max'
-
     # Ranking metrics that should be maximized (with @K suffix)
     if (first_metric_lower.startswith('recall@') or
         first_metric_lower.startswith('precision@') or
@@ -379,21 +328,17 @@ def get_best_metric_mode(first_metric: str, primary_task: str) -> str:
         first_metric_lower.startswith('ndcg@') or
         first_metric_lower.startswith('map@')):
         return 'max'
-
     # Cosine separation should be maximized
     if first_metric_lower == 'cosine':
         return 'max'
-
     # Metrics that should be minimized
     if first_metric_lower in {'logloss', 'mse', 'mae', 'rmse', 'mape', 'msle'}:
         return 'min'
-
     # Default based on task type
     if primary_task == 'regression':
         return 'min'
     return 'max'

-
 def compute_single_metric(
     metric: str,
     y_true: np.ndarray,
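Illustrative calls showing how get_best_metric_mode resolves the optimization direction for the first configured metric; the import path simply mirrors the file shown in this diff.

from nextrec.basic.metrics import get_best_metric_mode

print(get_best_metric_mode('auc', 'binary'))        # 'max'
print(get_best_metric_mode('ndcg@10', 'matching'))  # 'max' via the '@K' prefix rules
print(get_best_metric_mode('mse', 'regression'))    # 'min'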
@@ -403,45 +348,36 @@ def compute_single_metric(
 ) -> float:
     """Compute a single metric given true and predicted values."""
     y_p_binary = (y_pred > 0.5).astype(int)
-
     try:
         metric_lower = metric.lower()
-
         # recall@K
         if metric_lower.startswith('recall@'):
             k = int(metric_lower.split('@')[1])
-            return _compute_recall_at_k(y_true, y_pred, user_ids, k)
-
+            return _compute_recall_at_k(y_true, y_pred, user_ids, k) # type: ignore
         # precision@K
         if metric_lower.startswith('precision@'):
             k = int(metric_lower.split('@')[1])
-            return _compute_precision_at_k(y_true, y_pred, user_ids, k)
-
+            return _compute_precision_at_k(y_true, y_pred, user_ids, k) # type: ignore
         # hitrate@K / hr@K
         if metric_lower.startswith('hitrate@') or metric_lower.startswith('hr@'):
             k_str = metric_lower.split('@')[1]
             k = int(k_str)
-            return _compute_hitrate_at_k(y_true, y_pred, user_ids, k)
-
+            return _compute_hitrate_at_k(y_true, y_pred, user_ids, k) # type: ignore
         # mrr@K
         if metric_lower.startswith('mrr@'):
             k = int(metric_lower.split('@')[1])
-            return _compute_mrr_at_k(y_true, y_pred, user_ids, k)
-
+            return _compute_mrr_at_k(y_true, y_pred, user_ids, k) # type: ignore
         # ndcg@K
         if metric_lower.startswith('ndcg@'):
             k = int(metric_lower.split('@')[1])
-            return _compute_ndcg_at_k(y_true, y_pred, user_ids, k)
-
+            return _compute_ndcg_at_k(y_true, y_pred, user_ids, k) # type: ignore
         # map@K
         if metric_lower.startswith('map@'):
             k = int(metric_lower.split('@')[1])
-            return _compute_map_at_k(y_true, y_pred, user_ids, k)
-
+            return _compute_map_at_k(y_true, y_pred, user_ids, k) # type: ignore
         # cosine for matching task
         if metric_lower == 'cosine':
             return _compute_cosine_separation(y_true, y_pred)
-
         if metric == 'auc':
             value = float(roc_auc_score(y_true, y_pred, average='macro' if task_type == 'multilabel' else None))
         elif metric == 'gauc':
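How the dispatch above parses metric strings: the text after '@' becomes K, and ranking metrics are routed to the per-user helpers, so user_ids must accompany them. The positional argument order follows the call used in evaluate_metrics; the arrays and import path are illustrative.

import numpy as np
from nextrec.basic.metrics import compute_single_metric

y_true = np.array([1, 0, 1, 0])
y_pred = np.array([0.9, 0.3, 0.4, 0.8])
user_ids = np.array([1, 1, 2, 2])

print(compute_single_metric('recall@1', y_true, y_pred, 'matching', user_ids))  # 0.5
print(compute_single_metric('auc', y_true, y_pred, 'binary', user_ids))         # 0.75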
@@ -475,12 +411,11 @@
         elif metric == 'msle':
             value = float(compute_msle(y_true, y_pred))
         else:
-            logging.warning(f"Metric '{metric}' is not supported, returning 0.0")
+            logging.warning(f"[Metric Warning] Metric '{metric}' is not supported, returning 0.0")
             value = 0.0
     except Exception as exception:
-        logging.warning(f"Failed to compute metric {metric}: {exception}")
+        logging.warning(f"[Metric Warning] Failed to compute metric {metric}: {exception}")
         value = 0.0
-
     return value

 def evaluate_metrics(
@@ -494,21 +429,17 @@
 ) -> dict: # {'auc': 0.75, 'logloss': 0.45, 'mse_target2': 3.2}
     """Evaluate specified metrics for given true and predicted values."""
     result = {}
-
     if y_true is None or y_pred is None:
         return result
-
     # Main evaluation logic
     primary_task = task[0] if isinstance(task, list) else task
     nums_task = len(task) if isinstance(task, list) else 1
-
     # Single task evaluation
     if nums_task == 1:
         for metric in metrics:
             metric_lower = metric.lower()
             value = compute_single_metric(metric_lower, y_true, y_pred, primary_task, user_ids)
             result[metric_lower] = value
-
     # Multi-task evaluation
     else:
         for metric in metrics:
@@ -526,31 +457,24 @@
                 elif isinstance(task, str):
                     task_type = task
                 else:
-                    task_type = 'binary'
-
+                    task_type = 'binary'
                 if task_type in ['binary', 'multilabel']:
                     should_compute = metric_lower in {'auc', 'ks', 'logloss', 'accuracy', 'acc', 'precision', 'recall', 'f1', 'micro_f1', 'macro_f1'}
                 elif task_type == 'regression':
-                    should_compute = metric_lower in {'mse', 'mae', 'rmse', 'r2', 'mape', 'msle'}
-
+                    should_compute = metric_lower in {'mse', 'mae', 'rmse', 'r2', 'mape', 'msle'}
                 if not should_compute:
-                    continue
-
-                target_name = target_names[task_idx]
-
+                    continue
+                target_name = target_names[task_idx]
                 # Get task type for specific index
                 if isinstance(task, list) and task_idx < len(task):
                     task_type = task[task_idx]
                 elif isinstance(task, str):
                     task_type = task
                 else:
-                    task_type = 'binary'
-
+                    task_type = 'binary'
                 y_true_task = y_true[:, task_idx]
-                y_pred_task = y_pred[:, task_idx]
-
+                y_pred_task = y_pred[:, task_idx]
                 # Compute metric
                 value = compute_single_metric(metric_lower, y_true_task, y_pred_task, task_type, user_ids)
                 result[f'{metric_lower}_{target_name}'] = value
-
     return result
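A hedged sketch of a multi-task evaluate_metrics call producing the '<metric>_<target>' keys built above. The keyword names mirror the variables used in the function body; the exact signature is not shown in this diff, so treat the call as illustrative only.

import numpy as np
from nextrec.basic.metrics import evaluate_metrics

y_true = np.array([[1, 0], [0, 1], [1, 1], [0, 0]])              # one column per target
y_pred = np.array([[0.8, 0.3], [0.2, 0.7], [0.9, 0.6], [0.4, 0.1]])

result = evaluate_metrics(
    metrics=['auc', 'logloss'],
    y_true=y_true,
    y_pred=y_pred,
    task=['binary', 'binary'],
    target_names=['ctr', 'cvr'],
    user_ids=None,
)
# Expected key pattern: {'auc_ctr': ..., 'logloss_ctr': ..., 'auc_cvr': ..., 'logloss_cvr': ...}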