nextrec-0.1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__init__.py +41 -0
- nextrec/__version__.py +1 -0
- nextrec/basic/__init__.py +0 -0
- nextrec/basic/activation.py +92 -0
- nextrec/basic/callback.py +35 -0
- nextrec/basic/dataloader.py +447 -0
- nextrec/basic/features.py +87 -0
- nextrec/basic/layers.py +985 -0
- nextrec/basic/loggers.py +124 -0
- nextrec/basic/metrics.py +557 -0
- nextrec/basic/model.py +1438 -0
- nextrec/data/__init__.py +27 -0
- nextrec/data/data_utils.py +132 -0
- nextrec/data/preprocessor.py +662 -0
- nextrec/loss/__init__.py +35 -0
- nextrec/loss/loss_utils.py +136 -0
- nextrec/loss/match_losses.py +294 -0
- nextrec/models/generative/hstu.py +0 -0
- nextrec/models/generative/tiger.py +0 -0
- nextrec/models/match/__init__.py +13 -0
- nextrec/models/match/dssm.py +200 -0
- nextrec/models/match/dssm_v2.py +162 -0
- nextrec/models/match/mind.py +210 -0
- nextrec/models/match/sdm.py +253 -0
- nextrec/models/match/youtube_dnn.py +172 -0
- nextrec/models/multi_task/esmm.py +129 -0
- nextrec/models/multi_task/mmoe.py +161 -0
- nextrec/models/multi_task/ple.py +260 -0
- nextrec/models/multi_task/share_bottom.py +126 -0
- nextrec/models/ranking/__init__.py +17 -0
- nextrec/models/ranking/afm.py +118 -0
- nextrec/models/ranking/autoint.py +140 -0
- nextrec/models/ranking/dcn.py +120 -0
- nextrec/models/ranking/deepfm.py +95 -0
- nextrec/models/ranking/dien.py +214 -0
- nextrec/models/ranking/din.py +181 -0
- nextrec/models/ranking/fibinet.py +130 -0
- nextrec/models/ranking/fm.py +87 -0
- nextrec/models/ranking/masknet.py +125 -0
- nextrec/models/ranking/pnn.py +128 -0
- nextrec/models/ranking/widedeep.py +105 -0
- nextrec/models/ranking/xdeepfm.py +117 -0
- nextrec/utils/__init__.py +18 -0
- nextrec/utils/common.py +14 -0
- nextrec/utils/embedding.py +19 -0
- nextrec/utils/initializer.py +47 -0
- nextrec/utils/optimizer.py +75 -0
- nextrec-0.1.1.dist-info/METADATA +302 -0
- nextrec-0.1.1.dist-info/RECORD +51 -0
- nextrec-0.1.1.dist-info/WHEEL +4 -0
- nextrec-0.1.1.dist-info/licenses/LICENSE +21 -0
nextrec/basic/loggers.py
ADDED
@@ -0,0 +1,124 @@
"""
NextRec Basic Loggers

Date: create on 27/10/2025
Author:
    Yang Zhou,zyaztec@gmail.com
"""

import os
import re
import sys
import copy
import datetime
import logging

ANSI_CODES = {
    'black': '\033[30m',
    'red': '\033[31m',
    'green': '\033[32m',
    'yellow': '\033[33m',
    'blue': '\033[34m',
    'magenta': '\033[35m',
    'cyan': '\033[36m',
    'white': '\033[37m',
    'bright_black': '\033[90m',
    'bright_red': '\033[91m',
    'bright_green': '\033[92m',
    'bright_yellow': '\033[93m',
    'bright_blue': '\033[94m',
    'bright_magenta': '\033[95m',
    'bright_cyan': '\033[96m',
    'bright_white': '\033[97m',
}

ANSI_BOLD = '\033[1m'
ANSI_RESET = '\033[0m'
ANSI_ESCAPE_PATTERN = re.compile(r'\033\[[0-9;]*m')

DEFAULT_LEVEL_COLORS = {
    'DEBUG': 'cyan',
    'INFO': None,
    'WARNING': 'yellow',
    'ERROR': 'red',
    'CRITICAL': 'bright_red',
}

class AnsiFormatter(logging.Formatter):
    def __init__(
        self,
        *args,
        strip_ansi: bool = False,
        auto_color_level: bool = False,
        level_colors: dict[str, str] | None = None,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        self.strip_ansi = strip_ansi
        self.auto_color_level = auto_color_level
        self.level_colors = level_colors or DEFAULT_LEVEL_COLORS

    def format(self, record: logging.LogRecord) -> str:
        record_copy = copy.copy(record)
        formatted = super().format(record_copy)

        if self.auto_color_level and '\033[' not in formatted:
            color = self.level_colors.get(record.levelname)
            if color:
                formatted = colorize(formatted, color=color)

        if self.strip_ansi:
            return ANSI_ESCAPE_PATTERN.sub('', formatted)

        return formatted

def colorize(text: str, color: str | None = None, bold: bool = False) -> str:
    """Apply ANSI color and bold formatting to the given text."""
    if not color and not bold:
        return text

    result = ""

    if bold:
        result += ANSI_BOLD

    if color and color in ANSI_CODES:
        result += ANSI_CODES[color]

    result += text + ANSI_RESET

    return result

def setup_logger(log_dir: str | None = None):
    """Set up a logger that logs to both console and a file with ANSI formatting.
    Only console output has colors; file output is stripped of ANSI codes.
    """
    if log_dir is None:
        project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        log_dir = os.path.join(project_root, "..", "logs")

    os.makedirs(log_dir, exist_ok=True)
    log_file = os.path.join(log_dir, f"nextrec_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.log")

    console_format = '%(message)s'
    file_format = '%(asctime)s - %(levelname)s - %(message)s'
    date_format = '%H:%M:%S'

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    if logger.hasHandlers():
        logger.handlers.clear()

    file_handler = logging.FileHandler(log_file, encoding='utf-8')
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(AnsiFormatter(file_format, datefmt=date_format, strip_ansi=True))

    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(AnsiFormatter(console_format, datefmt=date_format, auto_color_level=True))

    logger.addHandler(file_handler)
    logger.addHandler(console_handler)

    return logger
nextrec/basic/metrics.py
ADDED
@@ -0,0 +1,557 @@
"""
Metrics computation and configuration for model evaluation.

Date: create on 27/10/2025
Author:
    Yang Zhou,zyaztec@gmail.com
"""
import logging
import numpy as np
from sklearn.metrics import (
    roc_auc_score, log_loss, mean_squared_error, mean_absolute_error,
    accuracy_score, precision_score, recall_score, f1_score, r2_score,
)


CLASSIFICATION_METRICS = {'auc', 'gauc', 'ks', 'logloss', 'accuracy', 'acc', 'precision', 'recall', 'f1', 'micro_f1', 'macro_f1'}
REGRESSION_METRICS = {'mse', 'mae', 'rmse', 'r2', 'mape', 'msle'}
TASK_DEFAULT_METRICS = {
    'binary': ['auc', 'gauc', 'ks', 'logloss', 'accuracy', 'precision', 'recall', 'f1'],
    'regression': ['mse', 'mae', 'rmse', 'r2', 'mape'],
    'multilabel': ['auc', 'hamming_loss', 'subset_accuracy', 'micro_f1', 'macro_f1'],
    'matching': ['auc', 'gauc', 'precision@10', 'hitrate@10', 'map@10', 'cosine'] + [f'recall@{k}' for k in (5, 10, 20)] + [f'ndcg@{k}' for k in (5, 10, 20)] + [f'mrr@{k}' for k in (5, 10, 20)]
}


def compute_ks(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Compute Kolmogorov-Smirnov statistic."""
    sorted_indices = np.argsort(y_pred)[::-1]
    y_true_sorted = y_true[sorted_indices]

    n_pos = np.sum(y_true_sorted == 1)
    n_neg = np.sum(y_true_sorted == 0)

    if n_pos > 0 and n_neg > 0:
        cum_pos_rate = np.cumsum(y_true_sorted == 1) / n_pos
        cum_neg_rate = np.cumsum(y_true_sorted == 0) / n_neg
        ks_value = np.max(np.abs(cum_pos_rate - cum_neg_rate))
        return float(ks_value)
    return 0.0


def compute_mape(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Compute Mean Absolute Percentage Error."""
    mask = y_true != 0
    if np.any(mask):
        return float(np.mean(np.abs((y_true[mask] - y_pred[mask]) / y_true[mask])) * 100)
    return 0.0


def compute_msle(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Compute Mean Squared Log Error."""
    y_pred_pos = np.maximum(y_pred, 0)
    return float(mean_squared_error(np.log1p(y_true), np.log1p(y_pred_pos)))


def compute_gauc(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    user_ids: np.ndarray | None = None
) -> float:
    if user_ids is None:
        # If no user_ids provided, fall back to regular AUC
        try:
            return float(roc_auc_score(y_true, y_pred))
        except:
            return 0.0

    # Group by user_id and calculate AUC for each user
    user_aucs = []
    user_weights = []

    unique_users = np.unique(user_ids)

    for user_id in unique_users:
        mask = user_ids == user_id
        user_y_true = y_true[mask]
        user_y_pred = y_pred[mask]

        # Skip users with only one class (cannot compute AUC)
        if len(np.unique(user_y_true)) < 2:
            continue

        try:
            user_auc = roc_auc_score(user_y_true, user_y_pred)
            user_aucs.append(user_auc)
            user_weights.append(len(user_y_true))
        except:
            continue

    if len(user_aucs) == 0:
        return 0.0

    # Weighted average
    user_aucs = np.array(user_aucs)
    user_weights = np.array(user_weights)
    gauc = float(np.sum(user_aucs * user_weights) / np.sum(user_weights))

    return gauc


def _group_indices_by_user(user_ids: np.ndarray, n_samples: int) -> list[np.ndarray]:
    """Group sample indices by user_id. If user_ids is None, treat all as one group."""
    if user_ids is None:
        return [np.arange(n_samples)]

    user_ids = np.asarray(user_ids)
    if user_ids.shape[0] != n_samples:
        logging.warning(
            "user_ids length (%d) != number of samples (%d), "
            "treating all samples as a single group for ranking metrics.",
            user_ids.shape[0],
            n_samples,
        )
        return [np.arange(n_samples)]

    unique_users = np.unique(user_ids)
    groups = [np.where(user_ids == u)[0] for u in unique_users]
    return groups


def _compute_precision_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int) -> float:
    y_true = (y_true > 0).astype(int)
    n = len(y_true)
    groups = _group_indices_by_user(user_ids, n)

    precisions = []
    for idx in groups:
        if idx.size == 0:
            continue
        k_user = min(k, idx.size)
        scores = y_pred[idx]
        labels = y_true[idx]
        order = np.argsort(scores)[::-1]
        topk = order[:k_user]
        hits = labels[topk].sum()
        precisions.append(hits / float(k_user))

    return float(np.mean(precisions)) if precisions else 0.0


def _compute_recall_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int) -> float:
    """Compute Recall@K."""
    y_true = (y_true > 0).astype(int)
    n = len(y_true)
    groups = _group_indices_by_user(user_ids, n)

    recalls = []
    for idx in groups:
        if idx.size == 0:
            continue
        labels = y_true[idx]
        num_pos = labels.sum()
        if num_pos == 0:
            continue  # Skip users with no positive samples
        scores = y_pred[idx]
        order = np.argsort(scores)[::-1]
        k_user = min(k, idx.size)
        topk = order[:k_user]
        hits = labels[topk].sum()
        recalls.append(hits / float(num_pos))

    return float(np.mean(recalls)) if recalls else 0.0


def _compute_hitrate_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int) -> float:
    """Compute HitRate@K."""
    y_true = (y_true > 0).astype(int)
    n = len(y_true)
    groups = _group_indices_by_user(user_ids, n)

    hits_per_user = []
    for idx in groups:
        if idx.size == 0:
            continue
        labels = y_true[idx]
        if labels.sum() == 0:
            continue  # Users without positive samples are not counted
        scores = y_pred[idx]
        order = np.argsort(scores)[::-1]
        k_user = min(k, idx.size)
        topk = order[:k_user]
        hits = labels[topk].sum()
        hits_per_user.append(1.0 if hits > 0 else 0.0)

    return float(np.mean(hits_per_user)) if hits_per_user else 0.0


def _compute_mrr_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int) -> float:
    """Compute MRR@K."""
    y_true = (y_true > 0).astype(int)
    n = len(y_true)
    groups = _group_indices_by_user(user_ids, n)

    mrrs = []
    for idx in groups:
        if idx.size == 0:
            continue
        labels = y_true[idx]
        if labels.sum() == 0:
            continue
        scores = y_pred[idx]
        order = np.argsort(scores)[::-1]
        k_user = min(k, idx.size)
        topk = order[:k_user]
        ranked_labels = labels[order]

        rr = 0.0
        for rank, lab in enumerate(ranked_labels[:k_user], start=1):
            if lab > 0:
                rr = 1.0 / rank
                break
        mrrs.append(rr)

    return float(np.mean(mrrs)) if mrrs else 0.0


def _compute_dcg_at_k(labels: np.ndarray, k: int) -> float:
    k_user = min(k, labels.size)
    if k_user == 0:
        return 0.0
    gains = (2 ** labels[:k_user] - 1).astype(float)
    discounts = np.log2(np.arange(2, k_user + 2))
    return float(np.sum(gains / discounts))


def _compute_ndcg_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int) -> float:
    """Compute NDCG@K."""
    y_true = (y_true > 0).astype(int)
    n = len(y_true)
    groups = _group_indices_by_user(user_ids, n)

    ndcgs = []
    for idx in groups:
        if idx.size == 0:
            continue
        labels = y_true[idx]
        if labels.sum() == 0:
            continue
        scores = y_pred[idx]

        order = np.argsort(scores)[::-1]
        ranked_labels = labels[order]
        dcg = _compute_dcg_at_k(ranked_labels, k)

        # ideal DCG
        ideal_labels = np.sort(labels)[::-1]
        idcg = _compute_dcg_at_k(ideal_labels, k)
        if idcg == 0.0:
            continue
        ndcgs.append(dcg / idcg)

    return float(np.mean(ndcgs)) if ndcgs else 0.0


def _compute_map_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int) -> float:
    """Mean Average Precision@K."""
    y_true = (y_true > 0).astype(int)
    n = len(y_true)
    groups = _group_indices_by_user(user_ids, n)

    aps = []
    for idx in groups:
        if idx.size == 0:
            continue
        labels = y_true[idx]
        num_pos = labels.sum()
        if num_pos == 0:
            continue

        scores = y_pred[idx]
        order = np.argsort(scores)[::-1]
        k_user = min(k, idx.size)

        hits = 0
        sum_precisions = 0.0
        for rank, i in enumerate(order[:k_user], start=1):
            if labels[i] > 0:
                hits += 1
                sum_precisions += hits / float(rank)

        if hits == 0:
            aps.append(0.0)
        else:
            aps.append(sum_precisions / float(num_pos))

    return float(np.mean(aps)) if aps else 0.0


def _compute_cosine_separation(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Compute Cosine Separation."""
    y_true = (y_true > 0).astype(int)
    pos_mask = y_true == 1
    neg_mask = y_true == 0

    if not np.any(pos_mask) or not np.any(neg_mask):
        return 0.0

    pos_mean = float(np.mean(y_pred[pos_mask]))
    neg_mean = float(np.mean(y_pred[neg_mask]))
    return pos_mean - neg_mean


def configure_metrics(
    task: str | list[str],  # 'binary' or ['binary', 'regression']
    metrics: list[str] | dict[str, list[str]] | None,  # ['auc', 'logloss'] or {'task1': ['auc'], 'task2': ['mse']}
    target_names: list[str]  # ['target1', 'target2']
) -> tuple[list[str], dict[str, list[str]] | None, str]:
    """Configure metrics based on task and user input."""
    primary_task = task[0] if isinstance(task, list) else task
    nums_task = len(task) if isinstance(task, list) else 1

    metrics_list: list[str] = []
    task_specific_metrics: dict[str, list[str]] | None = None

    if isinstance(metrics, dict):
        metrics_list = []
        task_specific_metrics = {}
        for task_name, task_metrics in metrics.items():
            if task_name not in target_names:
                logging.warning(
                    "Task '%s' not found in targets %s, skipping its metrics",
                    task_name,
                    target_names,
                )
                continue

            lowered = [m.lower() for m in task_metrics]
            task_specific_metrics[task_name] = lowered
            for metric in lowered:
                if metric not in metrics_list:
                    metrics_list.append(metric)

    elif metrics:
        metrics_list = [m.lower() for m in metrics]

    else:
        # No user provided metrics, derive per task type
        if nums_task > 1 and isinstance(task, list):
            deduped: list[str] = []
            for task_type in task:
                # Inline get_default_metrics_for_task logic
                if task_type not in TASK_DEFAULT_METRICS:
                    raise ValueError(f"Unsupported task type: {task_type}")
                for metric in TASK_DEFAULT_METRICS[task_type]:
                    if metric not in deduped:
                        deduped.append(metric)
            metrics_list = deduped
        else:
            # Inline get_default_metrics_for_task logic
            if primary_task not in TASK_DEFAULT_METRICS:
                raise ValueError(f"Unsupported task type: {primary_task}")
            metrics_list = TASK_DEFAULT_METRICS[primary_task]

    if not metrics_list:
        # Inline get_default_metrics_for_task logic
        if primary_task not in TASK_DEFAULT_METRICS:
            raise ValueError(f"Unsupported task type: {primary_task}")
        metrics_list = TASK_DEFAULT_METRICS[primary_task]

    best_metrics_mode = get_best_metric_mode(metrics_list[0], primary_task)

    return metrics_list, task_specific_metrics, best_metrics_mode


def get_best_metric_mode(first_metric: str, primary_task: str) -> str:
    """Determine if metric should be maximized or minimized."""
    first_metric_lower = first_metric.lower()

    # Metrics that should be maximized
    if first_metric_lower in {'auc', 'gauc', 'ks', 'accuracy', 'acc', 'precision', 'recall', 'f1', 'r2', 'micro_f1', 'macro_f1'}:
        return 'max'

    # Ranking metrics that should be maximized (with @K suffix)
    if (first_metric_lower.startswith('recall@') or
            first_metric_lower.startswith('precision@') or
            first_metric_lower.startswith('hitrate@') or
            first_metric_lower.startswith('hr@') or
            first_metric_lower.startswith('mrr@') or
            first_metric_lower.startswith('ndcg@') or
            first_metric_lower.startswith('map@')):
        return 'max'

    # Cosine separation should be maximized
    if first_metric_lower == 'cosine':
        return 'max'

    # Metrics that should be minimized
    if first_metric_lower in {'logloss', 'mse', 'mae', 'rmse', 'mape', 'msle'}:
        return 'min'

    # Default based on task type
    if primary_task == 'regression':
        return 'min'
    return 'max'


def compute_single_metric(
    metric: str,
    y_true: np.ndarray,
    y_pred: np.ndarray,
    task_type: str,
    user_ids: np.ndarray | None = None
) -> float:
    """Compute a single metric given true and predicted values."""
    y_p_binary = (y_pred > 0.5).astype(int)

    try:
        metric_lower = metric.lower()

        # recall@K
        if metric_lower.startswith('recall@'):
            k = int(metric_lower.split('@')[1])
            return _compute_recall_at_k(y_true, y_pred, user_ids, k)

        # precision@K
        if metric_lower.startswith('precision@'):
            k = int(metric_lower.split('@')[1])
            return _compute_precision_at_k(y_true, y_pred, user_ids, k)

        # hitrate@K / hr@K
        if metric_lower.startswith('hitrate@') or metric_lower.startswith('hr@'):
            k_str = metric_lower.split('@')[1]
            k = int(k_str)
            return _compute_hitrate_at_k(y_true, y_pred, user_ids, k)

        # mrr@K
        if metric_lower.startswith('mrr@'):
            k = int(metric_lower.split('@')[1])
            return _compute_mrr_at_k(y_true, y_pred, user_ids, k)

        # ndcg@K
        if metric_lower.startswith('ndcg@'):
            k = int(metric_lower.split('@')[1])
            return _compute_ndcg_at_k(y_true, y_pred, user_ids, k)

        # map@K
        if metric_lower.startswith('map@'):
            k = int(metric_lower.split('@')[1])
            return _compute_map_at_k(y_true, y_pred, user_ids, k)

        # cosine for matching task
        if metric_lower == 'cosine':
            return _compute_cosine_separation(y_true, y_pred)

        if metric == 'auc':
            value = float(roc_auc_score(y_true, y_pred, average='macro' if task_type == 'multilabel' else None))
        elif metric == 'gauc':
            value = float(compute_gauc(y_true, y_pred, user_ids))
        elif metric == 'ks':
            value = float(compute_ks(y_true, y_pred))
        elif metric == 'logloss':
            value = float(log_loss(y_true, y_pred))
        elif metric in ('accuracy', 'acc'):
            value = float(accuracy_score(y_true, y_p_binary))
        elif metric == 'precision':
            value = float(precision_score(y_true, y_p_binary, average='samples' if task_type == 'multilabel' else 'binary', zero_division=0))
        elif metric == 'recall':
            value = float(recall_score(y_true, y_p_binary, average='samples' if task_type == 'multilabel' else 'binary', zero_division=0))
        elif metric == 'f1':
            value = float(f1_score(y_true, y_p_binary, average='samples' if task_type == 'multilabel' else 'binary', zero_division=0))
        elif metric == 'micro_f1':
            value = float(f1_score(y_true, y_p_binary, average='micro', zero_division=0))
        elif metric == 'macro_f1':
            value = float(f1_score(y_true, y_p_binary, average='macro', zero_division=0))
        elif metric == 'mse':
            value = float(mean_squared_error(y_true, y_pred))
        elif metric == 'mae':
            value = float(mean_absolute_error(y_true, y_pred))
        elif metric == 'rmse':
            value = float(np.sqrt(mean_squared_error(y_true, y_pred)))
        elif metric == 'r2':
            value = float(r2_score(y_true, y_pred))
        elif metric == 'mape':
            value = float(compute_mape(y_true, y_pred))
        elif metric == 'msle':
            value = float(compute_msle(y_true, y_pred))
        else:
            logging.warning(f"Metric '{metric}' is not supported, returning 0.0")
            value = 0.0
    except Exception as exception:
        logging.warning(f"Failed to compute metric {metric}: {exception}")
        value = 0.0

    return value


def evaluate_metrics(
    y_true: np.ndarray | None,
    y_pred: np.ndarray | None,
    metrics: list[str],  # example: ['auc', 'logloss']
    task: str | list[str],  # example: 'binary' or ['binary', 'regression']
    target_names: list[str],  # example: ['target1', 'target2']
    task_specific_metrics: dict[str, list[str]] | None = None,  # example: {'target1': ['auc', 'logloss'], 'target2': ['mse']}
    user_ids: np.ndarray | None = None  # example: User IDs for GAUC computation
) -> dict:  # {'auc': 0.75, 'logloss': 0.45, 'mse_target2': 3.2}
    """Evaluate specified metrics for given true and predicted values."""
    result = {}

    if y_true is None or y_pred is None:
        return result

    # Main evaluation logic
    primary_task = task[0] if isinstance(task, list) else task
    nums_task = len(task) if isinstance(task, list) else 1

    # Single task evaluation
    if nums_task == 1:
        for metric in metrics:
            metric_lower = metric.lower()
            value = compute_single_metric(metric_lower, y_true, y_pred, primary_task, user_ids)
            result[metric_lower] = value

    # Multi-task evaluation
    else:
        for metric in metrics:
            metric_lower = metric.lower()
            for task_idx in range(nums_task):
                # Check if metric should be computed for given task
                should_compute = True
                if task_specific_metrics is not None and task_idx < len(target_names):
                    task_name = target_names[task_idx]
                    should_compute = metric_lower in task_specific_metrics.get(task_name, [])
                else:
                    # Get task type for specific index
                    if isinstance(task, list) and task_idx < len(task):
                        task_type = task[task_idx]
                    elif isinstance(task, str):
                        task_type = task
                    else:
                        task_type = 'binary'

                    if task_type in ['binary', 'multilabel']:
                        should_compute = metric_lower in {'auc', 'ks', 'logloss', 'accuracy', 'acc', 'precision', 'recall', 'f1', 'micro_f1', 'macro_f1'}
                    elif task_type == 'regression':
                        should_compute = metric_lower in {'mse', 'mae', 'rmse', 'r2', 'mape', 'msle'}

                if not should_compute:
                    continue

                target_name = target_names[task_idx]

                # Get task type for specific index
                if isinstance(task, list) and task_idx < len(task):
                    task_type = task[task_idx]
                elif isinstance(task, str):
                    task_type = task
                else:
                    task_type = 'binary'

                y_true_task = y_true[:, task_idx]
                y_pred_task = y_pred[:, task_idx]

                # Compute metric
                value = compute_single_metric(metric_lower, y_true_task, y_pred_task, task_type, user_ids)
                result[f'{metric_lower}_{target_name}'] = value

    return result