nextrec-0.4.33-py3-none-any.whl → nextrec-0.5.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/activation.py +10 -18
  3. nextrec/basic/asserts.py +1 -22
  4. nextrec/basic/callback.py +2 -2
  5. nextrec/basic/features.py +6 -37
  6. nextrec/basic/heads.py +13 -1
  7. nextrec/basic/layers.py +33 -123
  8. nextrec/basic/loggers.py +3 -2
  9. nextrec/basic/metrics.py +85 -4
  10. nextrec/basic/model.py +518 -7
  11. nextrec/basic/summary.py +88 -42
  12. nextrec/cli.py +117 -30
  13. nextrec/data/data_processing.py +8 -13
  14. nextrec/data/preprocessor.py +449 -844
  15. nextrec/loss/grad_norm.py +78 -76
  16. nextrec/models/multi_task/ple.py +1 -0
  17. nextrec/models/multi_task/share_bottom.py +1 -0
  18. nextrec/models/ranking/afm.py +4 -9
  19. nextrec/models/ranking/dien.py +7 -8
  20. nextrec/models/ranking/ffm.py +2 -2
  21. nextrec/models/retrieval/sdm.py +1 -2
  22. nextrec/models/sequential/hstu.py +0 -2
  23. nextrec/models/tree_base/base.py +1 -1
  24. nextrec/utils/__init__.py +2 -1
  25. nextrec/utils/config.py +1 -1
  26. nextrec/utils/console.py +1 -1
  27. nextrec/utils/onnx_utils.py +252 -0
  28. nextrec/utils/torch_utils.py +63 -56
  29. nextrec/utils/types.py +43 -0
  30. {nextrec-0.4.33.dist-info → nextrec-0.5.0.dist-info}/METADATA +10 -4
  31. {nextrec-0.4.33.dist-info → nextrec-0.5.0.dist-info}/RECORD +34 -42
  32. nextrec/models/multi_task/[pre]star.py +0 -192
  33. nextrec/models/representation/autorec.py +0 -0
  34. nextrec/models/representation/bpr.py +0 -0
  35. nextrec/models/representation/cl4srec.py +0 -0
  36. nextrec/models/representation/lightgcn.py +0 -0
  37. nextrec/models/representation/mf.py +0 -0
  38. nextrec/models/representation/s3rec.py +0 -0
  39. nextrec/models/sequential/sasrec.py +0 -0
  40. nextrec/utils/feature.py +0 -29
  41. {nextrec-0.4.33.dist-info → nextrec-0.5.0.dist-info}/WHEEL +0 -0
  42. {nextrec-0.4.33.dist-info → nextrec-0.5.0.dist-info}/entry_points.txt +0 -0
  43. {nextrec-0.4.33.dist-info → nextrec-0.5.0.dist-info}/licenses/LICENSE +0 -0
nextrec/__version__.py CHANGED
@@ -1 +1 @@
-__version__ = "0.4.33"
+__version__ = "0.5.0"
nextrec/basic/activation.py CHANGED
@@ -1,8 +1,8 @@
 """
-Activation function definitions for NextRec models.
+Activation function definitions.
 
 Date: create on 27/10/2025
-Checkpoint: edit on 28/12/2025
+Checkpoint: edit on 20/01/2026
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
@@ -22,26 +22,18 @@ class Dice(nn.Module):
     where p(x) = sigmoid((x - E[x]) / sqrt(Var[x] + epsilon))
     """
 
-    def __init__(self, emb_size: int, epsilon: float = 1e-9):
+    def __init__(self, emb_size: int, epsilon: float = 1e-3):
         super(Dice, self).__init__()
-        self.epsilon = epsilon
         self.alpha = nn.Parameter(torch.zeros(emb_size))
-        self.bn = nn.BatchNorm1d(emb_size)
+        self.bn = nn.BatchNorm1d(emb_size, eps=epsilon, affine=False)
 
     def forward(self, x):
-        # x shape: (batch_size, emb_size) or (batch_size, seq_len, emb_size)
-        original_shape = x.shape
-
-        if x.dim() == 3:
-            # For 3D input (batch_size, seq_len, emb_size), reshape to 2D
-            batch_size, seq_len, emb_size = x.shape
-            x = x.view(-1, emb_size)
-        x_norm = self.bn(x)
-        p = torch.sigmoid(x_norm)
-        output = p * x + (1 - p) * self.alpha * x
-        if len(original_shape) == 3:
-            output = output.view(original_shape)
-        return output
+        # keep original shape for reshaping back after batch norm
+        orig_shape = x.shape  # x: [N, L, emb_size] or [N, emb_size]
+        x2 = x.reshape(-1, orig_shape[-1])  # x2: [N*L, emb_size] or [N, emb_size]
+        x_norm = self.bn(x2)
+        p = torch.sigmoid(x_norm).reshape(orig_shape)
+        return x * (self.alpha + (1 - self.alpha) * p)
 
 
 def activation_layer(
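
Note: the new forward is a refactor of the Dice gate, not a change to the gating formula; p*x + (1-p)*alpha*x and x*(alpha + (1-alpha)*p) expand to the same expression. What does change is the normalization: epsilon now feeds BatchNorm1d's eps (loosened from 1e-9 to 1e-3), and affine=False leaves alpha as the gate's only learned parameter. A quick check of the gate identity:

    import torch

    x = torch.randn(4, 8)
    p = torch.sigmoid(torch.randn(4, 8))
    alpha = torch.empty(8).uniform_(-1, 1)

    old = p * x + (1 - p) * alpha * x    # 0.4.33 gate
    new = x * (alpha + (1 - alpha) * p)  # 0.5.0 gate
    assert torch.allclose(old, new)      # identical up to floating-point rounding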
nextrec/basic/asserts.py CHANGED
@@ -8,7 +8,7 @@ Author: Yang Zhou, zyaztec@gmail.com
 
 from __future__ import annotations
 
-from nextrec.utils.types import TaskTypeName, TrainingModeName
+from nextrec.utils.types import TaskTypeName
 
 
 def assert_task(
@@ -49,24 +49,3 @@ def assert_task(
         raise ValueError(
             f"{model_name} requires task length {nums_task}, got {len(task)}."
         )
-
-
-def assert_training_mode(
-    training_mode: TrainingModeName | list[TrainingModeName],
-    nums_task: int,
-    *,
-    model_name: str,
-) -> None:
-    valid_modes = {"pointwise", "pairwise", "listwise"}
-    if not isinstance(training_mode, list):
-        raise TypeError(
-            f"[{model_name}-init Error] training_mode must be a list with length {nums_task}."
-        )
-    if len(training_mode) != nums_task:
-        raise ValueError(
-            f"[{model_name}-init Error] training_mode list length must match number of tasks."
-        )
-    if any(mode not in valid_modes for mode in training_mode):
-        raise ValueError(
-            f"[{model_name}-init Error] training_mode must be one of {'pointwise', 'pairwise', 'listwise'}."
-        )
nextrec/basic/callback.py CHANGED
@@ -2,7 +2,7 @@
 Callback System for Training Process
 
 Date: create on 27/10/2025
-Checkpoint: edit on 27/12/2025
+Checkpoint: edit on 21/01/2026
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
@@ -69,7 +69,7 @@ class Callback:
 
 class CallbackList:
     """
-    Generates a list of callbacks
+    Generates a list of callbacks, used to manage and invoke multiple callbacks during training.
     """
 
     def __init__(self, callbacks: Optional[list[Callback]] = None):
@@ -8,10 +8,9 @@ Author: Yang Zhou, zyaztec@gmail.com
8
8
 
9
9
  import torch
10
10
 
11
- from typing import Literal
12
-
13
11
  from nextrec.utils.embedding import get_auto_embedding_dim
14
- from nextrec.utils.feature import to_list
12
+ from nextrec.utils.torch_utils import to_list
13
+ from nextrec.utils.types import EmbeddingInitType, SequenceCombinerType
15
14
 
16
15
 
17
16
  class BaseFeature:
@@ -29,15 +28,7 @@ class EmbeddingFeature(BaseFeature):
29
28
  embedding_name: str = "",
30
29
  embedding_dim: int | None = None,
31
30
  padding_idx: int = 0,
32
- init_type: Literal[
33
- "normal",
34
- "uniform",
35
- "xavier_uniform",
36
- "xavier_normal",
37
- "kaiming_uniform",
38
- "kaiming_normal",
39
- "orthogonal",
40
- ] = "normal",
31
+ init_type: EmbeddingInitType = "normal",
41
32
  init_params: dict | None = None,
42
33
  l1_reg: float = 0.0,
43
34
  l2_reg: float = 0.0,
@@ -73,23 +64,9 @@ class SequenceFeature(EmbeddingFeature):
73
64
  max_len: int = 50,
74
65
  embedding_name: str = "",
75
66
  embedding_dim: int | None = None,
76
- combiner: Literal[
77
- "mean",
78
- "sum",
79
- "concat",
80
- "dot_attention",
81
- "self_attention",
82
- ] = "mean",
67
+ combiner: SequenceCombinerType = "mean",
83
68
  padding_idx: int = 0,
84
- init_type: Literal[
85
- "normal",
86
- "uniform",
87
- "xavier_uniform",
88
- "xavier_normal",
89
- "kaiming_uniform",
90
- "kaiming_normal",
91
- "orthogonal",
92
- ] = "normal",
69
+ init_type: EmbeddingInitType = "normal",
93
70
  init_params: dict | None = None,
94
71
  l1_reg: float = 0.0,
95
72
  l2_reg: float = 0.0,
@@ -143,15 +120,7 @@ class SparseFeature(EmbeddingFeature):
143
120
  embedding_name: str = "",
144
121
  embedding_dim: int | None = None,
145
122
  padding_idx: int = 0,
146
- init_type: Literal[
147
- "normal",
148
- "uniform",
149
- "xavier_uniform",
150
- "xavier_normal",
151
- "kaiming_uniform",
152
- "kaiming_normal",
153
- "orthogonal",
154
- ] = "normal",
123
+ init_type: EmbeddingInitType = "normal",
155
124
  init_params: dict | None = None,
156
125
  l1_reg: float = 0.0,
157
126
  l2_reg: float = 0.0,
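
Note: the repeated inline Literal annotations are consolidated into named aliases in nextrec/utils/types.py (which gains 43 lines in this release). The alias definitions themselves are not shown in this diff; inferred from the removed annotations, they plausibly look like:

    from typing import Literal

    EmbeddingInitType = Literal[
        "normal", "uniform", "xavier_uniform", "xavier_normal",
        "kaiming_uniform", "kaiming_normal", "orthogonal",
    ]
    SequenceCombinerType = Literal[
        "mean", "sum", "concat", "dot_attention", "self_attention",
    ]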
nextrec/basic/heads.py CHANGED
@@ -2,7 +2,7 @@
 Task head implementations for NextRec models.
 
 Date: create on 23/12/2025
-Checkpoint: edit on 27/12/2025
+Checkpoint: edit on 22/01/2026
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
@@ -24,6 +24,12 @@ class TaskHead(nn.Module):
 
     This wraps PredictionLayer so models can depend on a "Head" abstraction
     without changing their existing forward signatures.
+
+    Args:
+        task_type: The type of task(s) this head is responsible for.
+        task_dims: The dimensionality of each task's output.
+        use_bias: Whether to include a bias term in the prediction layer.
+        return_logits: Whether to return raw logits or apply activation.
     """
 
     def __init__(
@@ -56,6 +62,12 @@ class RetrievalHead(nn.Module):
 
     It computes similarity for pointwise training/inference, and returns
     raw embeddings for in-batch negative sampling in pairwise/listwise modes.
+
+    Args:
+        similarity_metric: The metric used to compute similarity between embeddings.
+        temperature: Scaling factor for similarity scores.
+        training_mode: The training mode, which can be pointwise, pairwise, or listwise.
+        apply_sigmoid: Whether to apply sigmoid activation to the similarity scores in pointwise mode.
     """
 
     def __init__(
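
Note: the Args sections document the constructor parameters only; the full signatures are not shown in this diff. A minimal construction sketch, assuming the documented names are __init__ keywords (argument values here are illustrative, not defaults from the package):

    # Hypothetical usage based on the documented Args above.
    head = TaskHead(task_type=["binary", "regression"], task_dims=[1, 1],
                    use_bias=True, return_logits=False)
    retrieval = RetrievalHead(similarity_metric="cosine", temperature=0.05,
                              training_mode="listwise", apply_sigmoid=False)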
nextrec/basic/layers.py CHANGED
@@ -2,7 +2,7 @@
 Layer implementations used across NextRec.
 
 Date: create on 27/10/2025
-Checkpoint: edit on 27/12/2025
+Checkpoint: edit on 25/01/2026
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
@@ -20,15 +20,13 @@ import torch.nn.functional as F
 from nextrec.basic.activation import activation_layer
 from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
 from nextrec.utils.torch_utils import get_initializer
-from nextrec.utils.types import ActivationName
+from nextrec.utils.types import ActivationName, TaskTypeName
 
 
 class PredictionLayer(nn.Module):
     def __init__(
         self,
-        task_type: (
-            Literal["binary", "regression"] | list[Literal["binary", "regression"]]
-        ) = "binary",
+        task_type: TaskTypeName | list[TaskTypeName] = "binary",
         task_dims: int | list[int] | None = None,
         use_bias: bool = True,
         return_logits: bool = False,
@@ -81,10 +79,12 @@ class PredictionLayer(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if x.dim() == 1:
             x = x.unsqueeze(0)  # (1 * total_dim)
-        if x.shape[-1] != self.total_dim:
-            raise ValueError(
-                f"[PredictionLayer Error]: Input last dimension ({x.shape[-1]}) does not match expected total dimension ({self.total_dim})."
-            )
+        if not torch.onnx.is_in_onnx_export():
+            if x.shape[-1] != self.total_dim:
+                raise ValueError(
+                    f"[PredictionLayer Error]: Input last dimension ({x.shape[-1]}) does not match expected total dimension ({self.total_dim})."
+                )
+
         logits = x if self.bias is None else x + self.bias
         outputs = []
         for task_type, (start, end) in zip(self.task_types, self.task_slices):
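
Note: these guards pair with the new nextrec/utils/onnx_utils.py (+252 lines). torch.onnx.is_in_onnx_export() returns True only while torch.onnx.export() is tracing the module, so Python-side shape checks like the one above run in eager mode but are skipped during export, instead of baking a data-dependent assertion (or a spurious failure) into the traced graph. A minimal sketch of the pattern:

    import torch

    class Square(torch.nn.Module):
        def forward(self, x):
            # Eager-mode validation only: skipped while torch.onnx.export() traces.
            if not torch.onnx.is_in_onnx_export():
                if x.dim() != 2:
                    raise ValueError("expected a 2D input")
            return x * x

    torch.onnx.export(Square(), (torch.randn(2, 3),), "square.onnx")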
@@ -92,10 +92,9 @@ class PredictionLayer(nn.Module):
             if self.return_logits:
                 outputs.append(task_logits)
                 continue
-            task = task_type.lower()
-            if task == "binary":
+            if task_type == "binary":
                 outputs.append(torch.sigmoid(task_logits))
-            elif task == "regression":
+            elif task_type == "regression":
                 outputs.append(task_logits)
             else:
                 raise ValueError(
@@ -219,7 +218,7 @@ class EmbeddingLayer(nn.Module):
 
         elif isinstance(feature, SequenceFeature):
             seq_input = x[feature.name].long()
-            if feature.max_len is not None and seq_input.size(1) > feature.max_len:
+            if feature.max_len is not None:
                 seq_input = seq_input[:, -feature.max_len :]
 
             embed = self.embed_dict[feature.embedding_name]
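
Note: the length comparison could be dropped because negative slicing is already a no-op on shorter sequences, and removing the Python comparison on a traced tensor size keeps the graph export-friendly. Concretely:

    import torch

    seq = torch.arange(6).reshape(1, 6)
    print(seq[:, -50:].shape)  # torch.Size([1, 6]); shorter than max_len, unchanged
    print(seq[:, -4:])         # tensor([[2, 3, 4, 5]]); longer, keeps the last 4 steps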
@@ -282,10 +281,11 @@ class EmbeddingLayer(nn.Module):
         value = value.view(value.size(0), -1)  # [B, input_dim]
         input_dim = feature.input_dim
         assert_input_dim = self.dense_input_dims.get(feature.name, input_dim)
-        if value.shape[1] != assert_input_dim:
-            raise ValueError(
-                f"[EmbeddingLayer Error]:Dense feature '{feature.name}' expects {assert_input_dim} inputs but got {value.shape[1]}."
-            )
+        if not torch.onnx.is_in_onnx_export():
+            if value.shape[1] != assert_input_dim:
+                raise ValueError(
+                    f"[EmbeddingLayer Error]:Dense feature '{feature.name}' expects {assert_input_dim} inputs but got {value.shape[1]}."
+                )
         if not feature.use_projection:
             return value
         dense_layer = self.dense_transforms[feature.name]
@@ -331,29 +331,10 @@ class InputMask(nn.Module):
         feature: SequenceFeature,
         seq_tensor: torch.Tensor | None = None,
     ):
-        if seq_tensor is not None:
-            values = seq_tensor
-        else:
-            values = x[feature.name]
-        values = values.long()
+        values = seq_tensor if seq_tensor is not None else x[feature.name]
+        values = values.long().view(values.size(0), -1)
         padding_idx = feature.padding_idx if feature.padding_idx is not None else 0
-        mask = values != padding_idx
-
-        if mask.dim() == 1:
-            # [B] -> [B, 1, 1]
-            mask = mask.unsqueeze(1).unsqueeze(2)
-        elif mask.dim() == 2:
-            # [B, L] -> [B, 1, L]
-            mask = mask.unsqueeze(1)
-        elif mask.dim() == 3:
-            # [B, 1, L]
-            # [B, L, 1] -> [B, L] -> [B, 1, L]
-            if mask.size(1) != 1 and mask.size(2) == 1:
-                mask = mask.squeeze(-1).unsqueeze(1)
-        else:
-            raise ValueError(
-                f"InputMask only supports 1D/2D/3D tensors, got shape {values.shape}"
-            )
+        mask = (values != padding_idx).unsqueeze(1)
         return mask.float()
 
 
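Note: the rewrite first flattens any input to [B, L] and then emits a fixed [B, 1, L] mask, replacing the old per-rank branching (which traces poorly). A quick shape check:

    import torch

    seq = torch.tensor([[3, 7, 0, 0]])         # padding_idx = 0
    values = seq.long().view(seq.size(0), -1)  # [B, L]
    mask = (values != 0).unsqueeze(1)          # [B, 1, L]
    print(mask.float())                        # tensor([[[1., 1., 0., 0.]]])
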
@@ -897,30 +878,7 @@ class AttentionPoolingLayer(nn.Module):
         self,
         embedding_dim: int,
         hidden_units: list = [80, 40],
-        activation: Literal[
-            "dice",
-            "relu",
-            "relu6",
-            "elu",
-            "selu",
-            "leaky_relu",
-            "prelu",
-            "gelu",
-            "sigmoid",
-            "tanh",
-            "softplus",
-            "softsign",
-            "hardswish",
-            "mish",
-            "silu",
-            "swish",
-            "hardsigmoid",
-            "tanhshrink",
-            "softshrink",
-            "none",
-            "linear",
-            "identity",
-        ] = "sigmoid",
+        activation: ActivationName = "sigmoid",
         use_softmax: bool = False,
     ):
         super().__init__()
@@ -954,39 +912,22 @@ class AttentionPoolingLayer(nn.Module):
         output: [batch_size, embedding_dim] - attention pooled representation
         """
         batch_size, sequence_length, embedding_dim = keys.shape
-        assert query.shape == (
-            batch_size,
-            embedding_dim,
-        ), f"query shape {query.shape} != ({batch_size}, {embedding_dim})"
-        if mask is None and keys_length is not None:
-            # keys_length: (batch_size,)
-            device = keys.device
-            seq_range = torch.arange(sequence_length, device=device).unsqueeze(
-                0
-            )  # (1, sequence_length)
-            mask = (seq_range < keys_length.unsqueeze(1)).unsqueeze(-1).float()
-        if mask is not None:
-            if mask.dim() == 2:
-                # (B, L)
-                mask = mask.unsqueeze(-1)
-            elif (
-                mask.dim() == 3
-                and mask.shape[1] == 1
-                and mask.shape[2] == sequence_length
-            ):
-                # (B, 1, L) -> (B, L, 1)
-                mask = mask.transpose(1, 2)
-            elif (
-                mask.dim() == 3
-                and mask.shape[1] == sequence_length
-                and mask.shape[2] == 1
-            ):
-                pass
+        if mask is None:
+            if keys_length is None:
+                mask = torch.ones(
+                    (batch_size, sequence_length), device=keys.device, dtype=keys.dtype
+                )
             else:
+                device = keys.device
+                seq_range = torch.arange(sequence_length, device=device).unsqueeze(0)
+                mask = (seq_range < keys_length.unsqueeze(1)).to(keys.dtype)
+        else:
+            mask = mask.to(keys.dtype).reshape(batch_size, -1)
+            if mask.shape[1] != sequence_length:
                 raise ValueError(
                     f"[AttentionPoolingLayer Error]: Unsupported mask shape: {mask.shape}"
                 )
-            mask = mask.to(keys.dtype)
+        mask = mask.unsqueeze(-1)
         # Expand query to (B, L, D)
         query_expanded = query.unsqueeze(1).expand(-1, sequence_length, -1)
         # [query, key, query-key, query*key] -> (B, L, 4D)
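
Note: mask handling is now canonical: whatever the caller provides is flattened to (B, L) and expanded once to (B, L, 1), and when only keys_length is given the mask comes from a range comparison. The length-mask construction in isolation:

    import torch

    keys_length = torch.tensor([2, 4])        # valid steps per sample
    seq_range = torch.arange(5).unsqueeze(0)  # (1, L)
    mask = (seq_range < keys_length.unsqueeze(1)).float()
    print(mask)
    # tensor([[1., 1., 0., 0., 0.],
    #         [1., 1., 1., 1., 0.]])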
@@ -1026,34 +967,3 @@ class RMSNorm(torch.nn.Module):
         variance = torch.mean(x**2, dim=-1, keepdim=True)
         x_normalized = x * torch.rsqrt(variance + self.eps)
         return self.weight * x_normalized
-
-
-class DomainBatchNorm(nn.Module):
-    """Domain-specific BatchNorm (applied per-domain with a shared interface)."""
-
-    def __init__(self, num_features: int, num_domains: int):
-        super().__init__()
-        if num_domains < 1:
-            raise ValueError("num_domains must be >= 1")
-        self.bns = nn.ModuleList(
-            [nn.BatchNorm1d(num_features) for _ in range(num_domains)]
-        )
-
-    def forward(self, x: torch.Tensor, domain_mask: torch.Tensor) -> torch.Tensor:
-        if x.dim() != 2:
-            raise ValueError("DomainBatchNorm expects 2D inputs [B, D].")
-        output = x.clone()
-        if domain_mask.dim() == 1:
-            domain_ids = domain_mask.long()
-            for idx, bn in enumerate(self.bns):
-                mask = domain_ids == idx
-                if mask.any():
-                    output[mask] = bn(x[mask])
-            return output
-        if domain_mask.dim() != 2:
-            raise ValueError("domain_mask must be 1D indices or 2D one-hot mask.")
-        for idx, bn in enumerate(self.bns):
-            mask = domain_mask[:, idx] > 0
-            if mask.any():
-                output[mask] = bn(x[mask])
-        return output
nextrec/basic/loggers.py CHANGED
@@ -2,7 +2,7 @@
 NextRec Basic Loggers
 
 Date: create on 27/10/2025
-Checkpoint: edit on 01/01/2026
+Checkpoint: edit on 22/01/2026
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
@@ -99,7 +99,8 @@ def format_kv(label: str, value: Any, width: int = 34, indent: int = 0) -> str:
 
 
 def setup_logger(session_id: str | os.PathLike | None = None):
-    """Set up a logger that logs to both console and a file with ANSI formatting.
+    """
+    Set up a logger that logs to both console and a file with ANSI formatting.
     Only console output has colors; file output is stripped of ANSI codes.
 
     Logs are stored under ``log/<experiment_id>/logs`` by default. A stable
nextrec/basic/metrics.py CHANGED
@@ -23,7 +23,6 @@ from sklearn.metrics import (
 )
 from nextrec.utils.types import TaskTypeName, MetricsName
 
-
 TASK_DEFAULT_METRICS = {
     "binary": ["auc", "gauc", "ks", "logloss", "accuracy", "precision", "recall", "f1"],
     "regression": ["mse", "mae", "rmse", "r2", "mape"],
@@ -334,6 +333,60 @@ def compute_map_at_k(
     return float(np.mean(aps)) if aps else 0.0
 
 
+def compute_topk_counts(
+    y_true: np.ndarray, y_pred: np.ndarray, k_percent: int
+) -> tuple[int, int, int]:
+    """Compute Top-K% sample size, hits, and positives for binary labels."""
+    y_true = (y_true > 0).astype(int)
+    n = y_true.size
+    if n == 0:
+        return 0, 0, 0
+    if k_percent <= 0:
+        return 0, 0, int(y_true.sum())
+    if k_percent >= 100:
+        k_count = n
+    else:
+        k_count = int(np.ceil(n * (k_percent / 100.0)))
+    k_count = max(k_count, 1)
+    order = np.argsort(y_pred)[::-1]
+    topk = order[:k_count]
+    hits = int(y_true[topk].sum())
+    total_pos = int(y_true.sum())
+    return k_count, hits, total_pos
+
+
+def compute_topk_precision(
+    y_true: np.ndarray, y_pred: np.ndarray, k_percent: int
+) -> float:
+    """Compute Top-K% Precision."""
+    k_count, hits, _ = compute_topk_counts(y_true, y_pred, k_percent)
+    if k_count == 0:
+        return 0.0
+    return float(hits / k_count)
+
+
+def compute_topk_recall(
+    y_true: np.ndarray, y_pred: np.ndarray, k_percent: int
+) -> float:
+    """Compute Top-K% Recall."""
+    _, hits, total_pos = compute_topk_counts(y_true, y_pred, k_percent)
+    if total_pos == 0:
+        return 0.0
+    return float(hits / total_pos)
+
+
+def compute_lift_at_k(y_true: np.ndarray, y_pred: np.ndarray, k_percent: int) -> float:
+    """Compute Lift@K from Top-K% precision and overall rate."""
+    k_count, hits, total_pos = compute_topk_counts(y_true, y_pred, k_percent)
+    if k_count == 0:
+        return 0.0
+    base_rate = total_pos / float(y_true.size)
+    if base_rate == 0.0:
+        return 0.0
+    precision = hits / float(k_count)
+    return float(precision / base_rate)
+
+
 def compute_cosine_separation(y_true: np.ndarray, y_pred: np.ndarray) -> float:
     """Compute Cosine Separation."""
     y_true = (y_true > 0).astype(int)
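
Note: the Top-K% family ranks all samples by score and keeps the top k_percent of them, so "lift@10" reads as "how much more positive-dense the top 10% of scores is than the population". A worked example against the new helpers:

    import numpy as np

    y_true = np.array([0, 1, 0, 0, 1, 0, 0, 0, 1, 0])
    y_pred = np.array([0.1, 0.9, 0.2, 0.3, 0.8, 0.05, 0.4, 0.6, 0.7, 0.15])

    # Top 20% of 10 samples -> k_count = 2; both top-scored samples are positives.
    print(compute_topk_precision(y_true, y_pred, 20))  # 1.0
    print(compute_topk_recall(y_true, y_pred, 20))     # 2 of 3 positives -> 0.666...
    print(compute_lift_at_k(y_true, y_pred, 20))       # 1.0 / (3/10) -> 3.333...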
@@ -399,11 +452,11 @@ def configure_metrics(
     if primary_task not in TASK_DEFAULT_METRICS:
         raise ValueError(f"Unsupported task type: {primary_task}")
     metrics_list = TASK_DEFAULT_METRICS[primary_task]
-    best_metrics_mode = getbest_metric_mode(metrics_list[0], primary_task)
+    best_metrics_mode = get_best_metric_mode(metrics_list[0], primary_task)
     return metrics_list, task_specific_metrics, best_metrics_mode
 
 
-def getbest_metric_mode(first_metric: MetricsName, primary_task: TaskTypeName) -> str:
+def get_best_metric_mode(first_metric: MetricsName, primary_task: TaskTypeName) -> str:
     """Determine if metric should be maximized or minimized."""
     # Metrics that should be maximized
     if first_metric in {
@@ -429,6 +482,9 @@ def getbest_metric_mode(first_metric: MetricsName, primary_task: TaskTypeName) -
         or first_metric.startswith("mrr@")
         or first_metric.startswith("ndcg@")
         or first_metric.startswith("map@")
+        or first_metric.startswith("topk_recall@")
+        or first_metric.startswith("topk_precision@")
+        or first_metric.startswith("lift@")
     ):
         return "max"
     # Cosine separation should be maximized
@@ -457,6 +513,15 @@ def compute_single_metric(
 
     y_p_binary = (y_pred > 0.5).astype(int)
     try:
+        if metric.startswith("topk_recall@"):
+            k_percent = int(metric.split("@")[1])
+            return compute_topk_recall(y_true, y_pred, k_percent)
+        if metric.startswith("topk_precision@"):
+            k_percent = int(metric.split("@")[1])
+            return compute_topk_precision(y_true, y_pred, k_percent)
+        if metric.startswith("lift@"):
+            k_percent = int(metric.split("@")[1])
+            return compute_lift_at_k(y_true, y_pred, k_percent)
         if metric.startswith("recall@"):
             k = int(metric.split("@")[1])
             return compute_recall_at_k(y_true, y_pred, user_ids, k)  # type: ignore
@@ -650,7 +715,23 @@ def evaluate_metrics(
         allowed_metrics = metric_allowlist.get(task_type)
         for metric in metrics:
             if allowed_metrics is not None and metric not in allowed_metrics:
-                continue
+                if metric.startswith(
+                    (
+                        "recall@",
+                        "precision@",
+                        "hitrate@",
+                        "hr@",
+                        "mrr@",
+                        "ndcg@",
+                        "map@",
+                        "topk_recall@",
+                        "topk_precision@",
+                        "lift@",
+                    )
+                ):
+                    pass
+                else:
+                    continue
             y_true_task = y_true[:, task_idx]
             y_pred_task = y_pred[:, task_idx]
             task_user_ids = user_ids
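
Note: with this change, @-suffixed ranking and Top-K% metrics bypass the per-task allowlist instead of being silently dropped. A minimal sketch of the effective filter (names are illustrative, not the package API):

    RANKED_PREFIXES = ("recall@", "precision@", "hitrate@", "hr@", "mrr@",
                       "ndcg@", "map@", "topk_recall@", "topk_precision@", "lift@")

    allowed = {"auc", "logloss"}              # example per-task allowlist
    for metric in ["auc", "lift@10", "mse"]:
        if metric not in allowed and not metric.startswith(RANKED_PREFIXES):
            continue                          # "mse" is still skipped
        print("evaluating", metric)           # "auc" and "lift@10" pass through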