nextrec-0.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. nextrec/__init__.py +41 -0
  2. nextrec/__version__.py +1 -0
  3. nextrec/basic/__init__.py +0 -0
  4. nextrec/basic/activation.py +92 -0
  5. nextrec/basic/callback.py +35 -0
  6. nextrec/basic/dataloader.py +447 -0
  7. nextrec/basic/features.py +87 -0
  8. nextrec/basic/layers.py +985 -0
  9. nextrec/basic/loggers.py +124 -0
  10. nextrec/basic/metrics.py +557 -0
  11. nextrec/basic/model.py +1438 -0
  12. nextrec/data/__init__.py +27 -0
  13. nextrec/data/data_utils.py +132 -0
  14. nextrec/data/preprocessor.py +662 -0
  15. nextrec/loss/__init__.py +35 -0
  16. nextrec/loss/loss_utils.py +136 -0
  17. nextrec/loss/match_losses.py +294 -0
  18. nextrec/models/generative/hstu.py +0 -0
  19. nextrec/models/generative/tiger.py +0 -0
  20. nextrec/models/match/__init__.py +13 -0
  21. nextrec/models/match/dssm.py +200 -0
  22. nextrec/models/match/dssm_v2.py +162 -0
  23. nextrec/models/match/mind.py +210 -0
  24. nextrec/models/match/sdm.py +253 -0
  25. nextrec/models/match/youtube_dnn.py +172 -0
  26. nextrec/models/multi_task/esmm.py +129 -0
  27. nextrec/models/multi_task/mmoe.py +161 -0
  28. nextrec/models/multi_task/ple.py +260 -0
  29. nextrec/models/multi_task/share_bottom.py +126 -0
  30. nextrec/models/ranking/__init__.py +17 -0
  31. nextrec/models/ranking/afm.py +118 -0
  32. nextrec/models/ranking/autoint.py +140 -0
  33. nextrec/models/ranking/dcn.py +120 -0
  34. nextrec/models/ranking/deepfm.py +95 -0
  35. nextrec/models/ranking/dien.py +214 -0
  36. nextrec/models/ranking/din.py +181 -0
  37. nextrec/models/ranking/fibinet.py +130 -0
  38. nextrec/models/ranking/fm.py +87 -0
  39. nextrec/models/ranking/masknet.py +125 -0
  40. nextrec/models/ranking/pnn.py +128 -0
  41. nextrec/models/ranking/widedeep.py +105 -0
  42. nextrec/models/ranking/xdeepfm.py +117 -0
  43. nextrec/utils/__init__.py +18 -0
  44. nextrec/utils/common.py +14 -0
  45. nextrec/utils/embedding.py +19 -0
  46. nextrec/utils/initializer.py +47 -0
  47. nextrec/utils/optimizer.py +75 -0
  48. nextrec-0.1.1.dist-info/METADATA +302 -0
  49. nextrec-0.1.1.dist-info/RECORD +51 -0
  50. nextrec-0.1.1.dist-info/WHEEL +4 -0
  51. nextrec-0.1.1.dist-info/licenses/LICENSE +21 -0
nextrec/basic/loggers.py
@@ -0,0 +1,124 @@
+ """
+ NextRec Basic Loggers
+
+ Date: create on 27/10/2025
+ Author:
+     Yang Zhou, zyaztec@gmail.com
+ """
+
+ import os
+ import re
+ import sys
+ import copy
+ import datetime
+ import logging
+
+ ANSI_CODES = {
+     'black': '\033[30m',
+     'red': '\033[31m',
+     'green': '\033[32m',
+     'yellow': '\033[33m',
+     'blue': '\033[34m',
+     'magenta': '\033[35m',
+     'cyan': '\033[36m',
+     'white': '\033[37m',
+     'bright_black': '\033[90m',
+     'bright_red': '\033[91m',
+     'bright_green': '\033[92m',
+     'bright_yellow': '\033[93m',
+     'bright_blue': '\033[94m',
+     'bright_magenta': '\033[95m',
+     'bright_cyan': '\033[96m',
+     'bright_white': '\033[97m',
+ }
+
+ ANSI_BOLD = '\033[1m'
+ ANSI_RESET = '\033[0m'
+ ANSI_ESCAPE_PATTERN = re.compile(r'\033\[[0-9;]*m')
+
+ DEFAULT_LEVEL_COLORS = {
+     'DEBUG': 'cyan',
+     'INFO': None,
+     'WARNING': 'yellow',
+     'ERROR': 'red',
+     'CRITICAL': 'bright_red',
+ }
+
+ class AnsiFormatter(logging.Formatter):
+     def __init__(
+         self,
+         *args,
+         strip_ansi: bool = False,
+         auto_color_level: bool = False,
+         level_colors: dict[str, str] | None = None,
+         **kwargs,
+     ) -> None:
+         super().__init__(*args, **kwargs)
+         self.strip_ansi = strip_ansi
+         self.auto_color_level = auto_color_level
+         self.level_colors = level_colors or DEFAULT_LEVEL_COLORS
+
+     def format(self, record: logging.LogRecord) -> str:
+         record_copy = copy.copy(record)
+         formatted = super().format(record_copy)
+
+         if self.auto_color_level and '\033[' not in formatted:
+             color = self.level_colors.get(record.levelname)
+             if color:
+                 formatted = colorize(formatted, color=color)
+
+         if self.strip_ansi:
+             return ANSI_ESCAPE_PATTERN.sub('', formatted)
+
+         return formatted
+
+ def colorize(text: str, color: str | None = None, bold: bool = False) -> str:
+     """Apply ANSI color and bold formatting to the given text."""
+     if not color and not bold:
+         return text
+
+     result = ""
+
+     if bold:
+         result += ANSI_BOLD
+
+     if color and color in ANSI_CODES:
+         result += ANSI_CODES[color]
+
+     result += text + ANSI_RESET
+
+     return result
+
+ def setup_logger(log_dir: str | None = None):
+     """Set up a logger that logs to both console and a file with ANSI formatting.
+     Only console output has colors; file output is stripped of ANSI codes.
+     """
+     if log_dir is None:
+         project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+         log_dir = os.path.join(project_root, "..", "logs")
+
+     os.makedirs(log_dir, exist_ok=True)
+     log_file = os.path.join(log_dir, f"nextrec_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
+
+     console_format = '%(message)s'
+     file_format = '%(asctime)s - %(levelname)s - %(message)s'
+     date_format = '%H:%M:%S'
+
+     logger = logging.getLogger()
+     logger.setLevel(logging.INFO)
+
+     if logger.hasHandlers():
+         logger.handlers.clear()
+
+     file_handler = logging.FileHandler(log_file, encoding='utf-8')
+     file_handler.setLevel(logging.INFO)
+     file_handler.setFormatter(AnsiFormatter(file_format, datefmt=date_format, strip_ansi=True))
+
+     console_handler = logging.StreamHandler(sys.stdout)
+     console_handler.setLevel(logging.INFO)
+     console_handler.setFormatter(AnsiFormatter(console_format, datefmt=date_format, auto_color_level=True))
+
+     logger.addHandler(file_handler)
+     logger.addHandler(console_handler)
+
+     return logger
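
The module above configures the root logger with a colored console handler and a plain-text file handler. As orientation only, here is a minimal usage sketch (editorial illustration, not part of the wheel); it assumes the package is installed so that nextrec.basic.loggers is importable, and the log directory name is arbitrary:

import logging

from nextrec.basic.loggers import colorize, setup_logger

logger = setup_logger(log_dir="./logs")  # creates ./logs/nextrec_<timestamp>.log
logger.info(colorize("training started", color="green", bold=True))  # colored on console, ANSI stripped in the file
logger.warning("validation loss did not improve")  # auto-colored yellow on the console via DEFAULT_LEVEL_COLORS
logging.getLogger(__name__).info("child loggers propagate to the same handlers")  # setup_logger configures the root logger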
nextrec/basic/metrics.py
@@ -0,0 +1,557 @@
+ """
+ Metrics computation and configuration for model evaluation.
+
+ Date: create on 27/10/2025
+ Author:
+     Yang Zhou, zyaztec@gmail.com
+ """
+ import logging
+ import numpy as np
+ from sklearn.metrics import (
+     roc_auc_score, log_loss, mean_squared_error, mean_absolute_error,
+     accuracy_score, precision_score, recall_score, f1_score, r2_score,
+ )
+
+
+ CLASSIFICATION_METRICS = {'auc', 'gauc', 'ks', 'logloss', 'accuracy', 'acc', 'precision', 'recall', 'f1', 'micro_f1', 'macro_f1'}
+ REGRESSION_METRICS = {'mse', 'mae', 'rmse', 'r2', 'mape', 'msle'}
+ TASK_DEFAULT_METRICS = {
+     'binary': ['auc', 'gauc', 'ks', 'logloss', 'accuracy', 'precision', 'recall', 'f1'],
+     'regression': ['mse', 'mae', 'rmse', 'r2', 'mape'],
+     'multilabel': ['auc', 'hamming_loss', 'subset_accuracy', 'micro_f1', 'macro_f1'],
+     'matching': ['auc', 'gauc', 'precision@10', 'hitrate@10', 'map@10', 'cosine'] + [f'recall@{k}' for k in (5, 10, 20)] + [f'ndcg@{k}' for k in (5, 10, 20)] + [f'mrr@{k}' for k in (5, 10, 20)],
+ }
+
+
+
+ def compute_ks(y_true: np.ndarray, y_pred: np.ndarray) -> float:
+     """Compute Kolmogorov-Smirnov statistic."""
+     sorted_indices = np.argsort(y_pred)[::-1]
+     y_true_sorted = y_true[sorted_indices]
+
+     n_pos = np.sum(y_true_sorted == 1)
+     n_neg = np.sum(y_true_sorted == 0)
+
+     if n_pos > 0 and n_neg > 0:
+         cum_pos_rate = np.cumsum(y_true_sorted == 1) / n_pos
+         cum_neg_rate = np.cumsum(y_true_sorted == 0) / n_neg
+         ks_value = np.max(np.abs(cum_pos_rate - cum_neg_rate))
+         return float(ks_value)
+     return 0.0
+
+
+ def compute_mape(y_true: np.ndarray, y_pred: np.ndarray) -> float:
+     """Compute Mean Absolute Percentage Error."""
+     mask = y_true != 0
+     if np.any(mask):
+         return float(np.mean(np.abs((y_true[mask] - y_pred[mask]) / y_true[mask])) * 100)
+     return 0.0
+
+
+ def compute_msle(y_true: np.ndarray, y_pred: np.ndarray) -> float:
+     """Compute Mean Squared Log Error."""
+     y_pred_pos = np.maximum(y_pred, 0)
+     return float(mean_squared_error(np.log1p(y_true), np.log1p(y_pred_pos)))
+
+
+ def compute_gauc(
+     y_true: np.ndarray,
+     y_pred: np.ndarray,
+     user_ids: np.ndarray | None = None
+ ) -> float:
+     if user_ids is None:
+         # If no user_ids provided, fall back to regular AUC
+         try:
+             return float(roc_auc_score(y_true, y_pred))
+         except Exception:
+             return 0.0
+
+     # Group by user_id and calculate AUC for each user
+     user_aucs = []
+     user_weights = []
+
+     unique_users = np.unique(user_ids)
+
+     for user_id in unique_users:
+         mask = user_ids == user_id
+         user_y_true = y_true[mask]
+         user_y_pred = y_pred[mask]
+
+         # Skip users with only one class (cannot compute AUC)
+         if len(np.unique(user_y_true)) < 2:
+             continue
+
+         try:
+             user_auc = roc_auc_score(user_y_true, user_y_pred)
+             user_aucs.append(user_auc)
+             user_weights.append(len(user_y_true))
+         except Exception:
+             continue
+
+     if len(user_aucs) == 0:
+         return 0.0
+
+     # Weighted average
+     user_aucs = np.array(user_aucs)
+     user_weights = np.array(user_weights)
+     gauc = float(np.sum(user_aucs * user_weights) / np.sum(user_weights))
+
+     return gauc
+
+
+ def _group_indices_by_user(user_ids: np.ndarray | None, n_samples: int) -> list[np.ndarray]:
+     """Group sample indices by user_id. If user_ids is None, treat all as one group."""
+     if user_ids is None:
+         return [np.arange(n_samples)]
+
+     user_ids = np.asarray(user_ids)
+     if user_ids.shape[0] != n_samples:
+         logging.warning(
+             "user_ids length (%d) != number of samples (%d), "
+             "treating all samples as a single group for ranking metrics.",
+             user_ids.shape[0],
+             n_samples,
+         )
+         return [np.arange(n_samples)]
+
+     unique_users = np.unique(user_ids)
+     groups = [np.where(user_ids == u)[0] for u in unique_users]
+     return groups
+
+
+ def _compute_precision_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int) -> float:
+     y_true = (y_true > 0).astype(int)
+     n = len(y_true)
+     groups = _group_indices_by_user(user_ids, n)
+
+     precisions = []
+     for idx in groups:
+         if idx.size == 0:
+             continue
+         k_user = min(k, idx.size)
+         scores = y_pred[idx]
+         labels = y_true[idx]
+         order = np.argsort(scores)[::-1]
+         topk = order[:k_user]
+         hits = labels[topk].sum()
+         precisions.append(hits / float(k_user))
+
+     return float(np.mean(precisions)) if precisions else 0.0
+
+
+ def _compute_recall_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int) -> float:
+     """Compute Recall@K."""
+     y_true = (y_true > 0).astype(int)
+     n = len(y_true)
+     groups = _group_indices_by_user(user_ids, n)
+
+     recalls = []
+     for idx in groups:
+         if idx.size == 0:
+             continue
+         labels = y_true[idx]
+         num_pos = labels.sum()
+         if num_pos == 0:
+             continue  # skip users with no positive samples
+         scores = y_pred[idx]
+         order = np.argsort(scores)[::-1]
+         k_user = min(k, idx.size)
+         topk = order[:k_user]
+         hits = labels[topk].sum()
+         recalls.append(hits / float(num_pos))
+
+     return float(np.mean(recalls)) if recalls else 0.0
+
+
+ def _compute_hitrate_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int) -> float:
+     """Compute HitRate@K."""
+     y_true = (y_true > 0).astype(int)
+     n = len(y_true)
+     groups = _group_indices_by_user(user_ids, n)
+
+     hits_per_user = []
+     for idx in groups:
+         if idx.size == 0:
+             continue
+         labels = y_true[idx]
+         if labels.sum() == 0:
+             continue  # users with no positive samples are not counted
+         scores = y_pred[idx]
+         order = np.argsort(scores)[::-1]
+         k_user = min(k, idx.size)
+         topk = order[:k_user]
+         hits = labels[topk].sum()
+         hits_per_user.append(1.0 if hits > 0 else 0.0)
+
+     return float(np.mean(hits_per_user)) if hits_per_user else 0.0
+
+
+ def _compute_mrr_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int) -> float:
+     """Compute MRR@K."""
+     y_true = (y_true > 0).astype(int)
+     n = len(y_true)
+     groups = _group_indices_by_user(user_ids, n)
+
+     mrrs = []
+     for idx in groups:
+         if idx.size == 0:
+             continue
+         labels = y_true[idx]
+         if labels.sum() == 0:
+             continue
+         scores = y_pred[idx]
+         order = np.argsort(scores)[::-1]
+         k_user = min(k, idx.size)
+         topk = order[:k_user]
+         ranked_labels = labels[order]
+
+         rr = 0.0
+         for rank, lab in enumerate(ranked_labels[:k_user], start=1):
+             if lab > 0:
+                 rr = 1.0 / rank
+                 break
+         mrrs.append(rr)
+
+     return float(np.mean(mrrs)) if mrrs else 0.0
+
+
+ def _compute_dcg_at_k(labels: np.ndarray, k: int) -> float:
+     k_user = min(k, labels.size)
+     if k_user == 0:
+         return 0.0
+     gains = (2 ** labels[:k_user] - 1).astype(float)
+     discounts = np.log2(np.arange(2, k_user + 2))
+     return float(np.sum(gains / discounts))
+
+
+ def _compute_ndcg_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int) -> float:
+     """Compute NDCG@K."""
+     y_true = (y_true > 0).astype(int)
+     n = len(y_true)
+     groups = _group_indices_by_user(user_ids, n)
+
+     ndcgs = []
+     for idx in groups:
+         if idx.size == 0:
+             continue
+         labels = y_true[idx]
+         if labels.sum() == 0:
+             continue
+         scores = y_pred[idx]
+
+         order = np.argsort(scores)[::-1]
+         ranked_labels = labels[order]
+         dcg = _compute_dcg_at_k(ranked_labels, k)
+
+         # ideal DCG
+         ideal_labels = np.sort(labels)[::-1]
+         idcg = _compute_dcg_at_k(ideal_labels, k)
+         if idcg == 0.0:
+             continue
+         ndcgs.append(dcg / idcg)
+
+     return float(np.mean(ndcgs)) if ndcgs else 0.0
+
+
+ def _compute_map_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int) -> float:
+     """Mean Average Precision@K."""
+     y_true = (y_true > 0).astype(int)
+     n = len(y_true)
+     groups = _group_indices_by_user(user_ids, n)
+
+     aps = []
+     for idx in groups:
+         if idx.size == 0:
+             continue
+         labels = y_true[idx]
+         num_pos = labels.sum()
+         if num_pos == 0:
+             continue
+
+         scores = y_pred[idx]
+         order = np.argsort(scores)[::-1]
+         k_user = min(k, idx.size)
+
+         hits = 0
+         sum_precisions = 0.0
+         for rank, i in enumerate(order[:k_user], start=1):
+             if labels[i] > 0:
+                 hits += 1
+                 sum_precisions += hits / float(rank)
+
+         if hits == 0:
+             aps.append(0.0)
+         else:
+             aps.append(sum_precisions / float(num_pos))
+
+     return float(np.mean(aps)) if aps else 0.0
+
+
+ def _compute_cosine_separation(y_true: np.ndarray, y_pred: np.ndarray) -> float:
+     """Compute cosine separation: mean predicted score of positives minus that of negatives."""
+     y_true = (y_true > 0).astype(int)
+     pos_mask = y_true == 1
+     neg_mask = y_true == 0
+
+     if not np.any(pos_mask) or not np.any(neg_mask):
+         return 0.0
+
+     pos_mean = float(np.mean(y_pred[pos_mask]))
+     neg_mean = float(np.mean(y_pred[neg_mask]))
+     return pos_mean - neg_mean
+
+
+ def configure_metrics(
+     task: str | list[str],  # 'binary' or ['binary', 'regression']
+     metrics: list[str] | dict[str, list[str]] | None,  # ['auc', 'logloss'] or {'task1': ['auc'], 'task2': ['mse']}
+     target_names: list[str]  # ['target1', 'target2']
+ ) -> tuple[list[str], dict[str, list[str]] | None, str]:
+     """Configure metrics based on task and user input."""
+     primary_task = task[0] if isinstance(task, list) else task
+     nums_task = len(task) if isinstance(task, list) else 1
+
+     metrics_list: list[str] = []
+     task_specific_metrics: dict[str, list[str]] | None = None
+
+     if isinstance(metrics, dict):
+         metrics_list = []
+         task_specific_metrics = {}
+         for task_name, task_metrics in metrics.items():
+             if task_name not in target_names:
+                 logging.warning(
+                     "Task '%s' not found in targets %s, skipping its metrics",
+                     task_name,
+                     target_names,
+                 )
+                 continue
+
+             lowered = [m.lower() for m in task_metrics]
+             task_specific_metrics[task_name] = lowered
+             for metric in lowered:
+                 if metric not in metrics_list:
+                     metrics_list.append(metric)
+
+     elif metrics:
+         metrics_list = [m.lower() for m in metrics]
+
+     else:
+         # No user provided metrics, derive per task type
+         if nums_task > 1 and isinstance(task, list):
+             deduped: list[str] = []
+             for task_type in task:
+                 # Inline get_default_metrics_for_task logic
+                 if task_type not in TASK_DEFAULT_METRICS:
+                     raise ValueError(f"Unsupported task type: {task_type}")
+                 for metric in TASK_DEFAULT_METRICS[task_type]:
+                     if metric not in deduped:
+                         deduped.append(metric)
+             metrics_list = deduped
+         else:
+             # Inline get_default_metrics_for_task logic
+             if primary_task not in TASK_DEFAULT_METRICS:
+                 raise ValueError(f"Unsupported task type: {primary_task}")
+             metrics_list = TASK_DEFAULT_METRICS[primary_task]
+
+     if not metrics_list:
+         # Inline get_default_metrics_for_task logic
+         if primary_task not in TASK_DEFAULT_METRICS:
+             raise ValueError(f"Unsupported task type: {primary_task}")
+         metrics_list = TASK_DEFAULT_METRICS[primary_task]
+
+     best_metrics_mode = get_best_metric_mode(metrics_list[0], primary_task)
+
+     return metrics_list, task_specific_metrics, best_metrics_mode
+
+
+ def get_best_metric_mode(first_metric: str, primary_task: str) -> str:
+     """Determine if metric should be maximized or minimized."""
+     first_metric_lower = first_metric.lower()
+
+     # Metrics that should be maximized
+     if first_metric_lower in {'auc', 'gauc', 'ks', 'accuracy', 'acc', 'precision', 'recall', 'f1', 'r2', 'micro_f1', 'macro_f1'}:
+         return 'max'
+
+     # Ranking metrics that should be maximized (with @K suffix)
+     if (first_metric_lower.startswith('recall@') or
+             first_metric_lower.startswith('precision@') or
+             first_metric_lower.startswith('hitrate@') or
+             first_metric_lower.startswith('hr@') or
+             first_metric_lower.startswith('mrr@') or
+             first_metric_lower.startswith('ndcg@') or
+             first_metric_lower.startswith('map@')):
+         return 'max'
+
+     # Cosine separation should be maximized
+     if first_metric_lower == 'cosine':
+         return 'max'
+
+     # Metrics that should be minimized
+     if first_metric_lower in {'logloss', 'mse', 'mae', 'rmse', 'mape', 'msle'}:
+         return 'min'
+
+     # Default based on task type
+     if primary_task == 'regression':
+         return 'min'
+     return 'max'
+
+
+ def compute_single_metric(
+     metric: str,
+     y_true: np.ndarray,
+     y_pred: np.ndarray,
+     task_type: str,
+     user_ids: np.ndarray | None = None
+ ) -> float:
+     """Compute a single metric given true and predicted values."""
+     y_p_binary = (y_pred > 0.5).astype(int)
+
+     try:
+         metric_lower = metric.lower()
+
+         # recall@K
+         if metric_lower.startswith('recall@'):
+             k = int(metric_lower.split('@')[1])
+             return _compute_recall_at_k(y_true, y_pred, user_ids, k)
+
+         # precision@K
+         if metric_lower.startswith('precision@'):
+             k = int(metric_lower.split('@')[1])
+             return _compute_precision_at_k(y_true, y_pred, user_ids, k)
+
+         # hitrate@K / hr@K
+         if metric_lower.startswith('hitrate@') or metric_lower.startswith('hr@'):
+             k_str = metric_lower.split('@')[1]
+             k = int(k_str)
+             return _compute_hitrate_at_k(y_true, y_pred, user_ids, k)
+
+         # mrr@K
+         if metric_lower.startswith('mrr@'):
+             k = int(metric_lower.split('@')[1])
+             return _compute_mrr_at_k(y_true, y_pred, user_ids, k)
+
+         # ndcg@K
+         if metric_lower.startswith('ndcg@'):
+             k = int(metric_lower.split('@')[1])
+             return _compute_ndcg_at_k(y_true, y_pred, user_ids, k)
+
+         # map@K
+         if metric_lower.startswith('map@'):
+             k = int(metric_lower.split('@')[1])
+             return _compute_map_at_k(y_true, y_pred, user_ids, k)
+
+         # cosine for matching task
+         if metric_lower == 'cosine':
+             return _compute_cosine_separation(y_true, y_pred)
+
+         if metric == 'auc':
+             value = float(roc_auc_score(y_true, y_pred, average='macro' if task_type == 'multilabel' else None))
+         elif metric == 'gauc':
+             value = float(compute_gauc(y_true, y_pred, user_ids))
+         elif metric == 'ks':
+             value = float(compute_ks(y_true, y_pred))
+         elif metric == 'logloss':
+             value = float(log_loss(y_true, y_pred))
+         elif metric in ('accuracy', 'acc'):
+             value = float(accuracy_score(y_true, y_p_binary))
+         elif metric == 'precision':
+             value = float(precision_score(y_true, y_p_binary, average='samples' if task_type == 'multilabel' else 'binary', zero_division=0))
+         elif metric == 'recall':
+             value = float(recall_score(y_true, y_p_binary, average='samples' if task_type == 'multilabel' else 'binary', zero_division=0))
+         elif metric == 'f1':
+             value = float(f1_score(y_true, y_p_binary, average='samples' if task_type == 'multilabel' else 'binary', zero_division=0))
+         elif metric == 'micro_f1':
+             value = float(f1_score(y_true, y_p_binary, average='micro', zero_division=0))
+         elif metric == 'macro_f1':
+             value = float(f1_score(y_true, y_p_binary, average='macro', zero_division=0))
+         elif metric == 'mse':
+             value = float(mean_squared_error(y_true, y_pred))
+         elif metric == 'mae':
+             value = float(mean_absolute_error(y_true, y_pred))
+         elif metric == 'rmse':
+             value = float(np.sqrt(mean_squared_error(y_true, y_pred)))
+         elif metric == 'r2':
+             value = float(r2_score(y_true, y_pred))
+         elif metric == 'mape':
+             value = float(compute_mape(y_true, y_pred))
+         elif metric == 'msle':
+             value = float(compute_msle(y_true, y_pred))
+         else:
+             logging.warning(f"Metric '{metric}' is not supported, returning 0.0")
+             value = 0.0
+     except Exception as exception:
+         logging.warning(f"Failed to compute metric {metric}: {exception}")
+         value = 0.0
+
+     return value
+
+ def evaluate_metrics(
+     y_true: np.ndarray | None,
+     y_pred: np.ndarray | None,
+     metrics: list[str],  # example: ['auc', 'logloss']
+     task: str | list[str],  # example: 'binary' or ['binary', 'regression']
+     target_names: list[str],  # example: ['target1', 'target2']
+     task_specific_metrics: dict[str, list[str]] | None = None,  # example: {'target1': ['auc', 'logloss'], 'target2': ['mse']}
+     user_ids: np.ndarray | None = None  # example: User IDs for GAUC computation
+ ) -> dict:  # {'auc': 0.75, 'logloss': 0.45, 'mse_target2': 3.2}
+     """Evaluate specified metrics for given true and predicted values."""
+     result = {}
+
+     if y_true is None or y_pred is None:
+         return result
+
+     # Main evaluation logic
+     primary_task = task[0] if isinstance(task, list) else task
+     nums_task = len(task) if isinstance(task, list) else 1
+
+     # Single task evaluation
+     if nums_task == 1:
+         for metric in metrics:
+             metric_lower = metric.lower()
+             value = compute_single_metric(metric_lower, y_true, y_pred, primary_task, user_ids)
+             result[metric_lower] = value
+
+     # Multi-task evaluation
+     else:
+         for metric in metrics:
+             metric_lower = metric.lower()
+             for task_idx in range(nums_task):
+                 # Check if metric should be computed for given task
+                 should_compute = True
+                 if task_specific_metrics is not None and task_idx < len(target_names):
+                     task_name = target_names[task_idx]
+                     should_compute = metric_lower in task_specific_metrics.get(task_name, [])
+                 else:
+                     # Get task type for specific index
+                     if isinstance(task, list) and task_idx < len(task):
+                         task_type = task[task_idx]
+                     elif isinstance(task, str):
+                         task_type = task
+                     else:
+                         task_type = 'binary'
+
+                     if task_type in ['binary', 'multilabel']:
+                         should_compute = metric_lower in {'auc', 'ks', 'logloss', 'accuracy', 'acc', 'precision', 'recall', 'f1', 'micro_f1', 'macro_f1'}
+                     elif task_type == 'regression':
+                         should_compute = metric_lower in {'mse', 'mae', 'rmse', 'r2', 'mape', 'msle'}
+
+                 if not should_compute:
+                     continue
+
+                 target_name = target_names[task_idx]
+
+                 # Get task type for specific index
+                 if isinstance(task, list) and task_idx < len(task):
+                     task_type = task[task_idx]
+                 elif isinstance(task, str):
+                     task_type = task
+                 else:
+                     task_type = 'binary'
+
+                 y_true_task = y_true[:, task_idx]
+                 y_pred_task = y_pred[:, task_idx]
+
+                 # Compute metric
+                 value = compute_single_metric(metric_lower, y_true_task, y_pred_task, task_type, user_ids)
+                 result[f'{metric_lower}_{target_name}'] = value
+
+     return result
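
Taken together, metrics.py exposes three entry points: configure_metrics resolves the metric list and the best-metric mode from the task type, compute_single_metric evaluates one metric (dispatching @K ranking metrics by their suffix), and evaluate_metrics loops over metrics and targets. A small end-to-end sketch follows (editorial illustration, not part of the wheel; the arrays are made-up toy data and 'ctr' is an arbitrary target name):

import numpy as np

from nextrec.basic.metrics import configure_metrics, evaluate_metrics

y_true = np.array([1, 0, 1, 0, 1, 0])
y_pred = np.array([0.9, 0.2, 0.7, 0.4, 0.3, 0.6])
user_ids = np.array([1, 1, 1, 2, 2, 2])

# Derive default metrics for a single binary task; mode is 'max' because 'auc' comes first.
metrics, per_target, mode = configure_metrics(task='binary', metrics=None, target_names=['ctr'])

# GAUC uses user_ids to group samples per user before averaging per-user AUC.
scores = evaluate_metrics(y_true, y_pred, metrics, task='binary', target_names=['ctr'], user_ids=user_ids)
print(mode, scores)  # e.g. max {'auc': ..., 'gauc': ..., 'ks': ..., 'logloss': ..., ...}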