nextrec-0.4.32-py3-none-any.whl → nextrec-0.4.34-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +14 -16
- nextrec/basic/asserts.py +1 -22
- nextrec/basic/callback.py +2 -2
- nextrec/basic/features.py +6 -37
- nextrec/basic/heads.py +13 -1
- nextrec/basic/layers.py +9 -33
- nextrec/basic/loggers.py +3 -2
- nextrec/basic/metrics.py +85 -4
- nextrec/basic/model.py +19 -12
- nextrec/basic/summary.py +89 -42
- nextrec/cli.py +54 -41
- nextrec/data/preprocessor.py +74 -25
- nextrec/loss/grad_norm.py +78 -76
- nextrec/models/multi_task/ple.py +1 -0
- nextrec/models/multi_task/share_bottom.py +1 -0
- nextrec/models/tree_base/base.py +1 -1
- nextrec/utils/__init__.py +2 -1
- nextrec/utils/config.py +1 -1
- nextrec/utils/console.py +1 -1
- nextrec/utils/torch_utils.py +63 -56
- nextrec/utils/types.py +43 -0
- {nextrec-0.4.32.dist-info → nextrec-0.4.34.dist-info}/METADATA +4 -4
- {nextrec-0.4.32.dist-info → nextrec-0.4.34.dist-info}/RECORD +27 -35
- nextrec/models/representation/autorec.py +0 -0
- nextrec/models/representation/bpr.py +0 -0
- nextrec/models/representation/cl4srec.py +0 -0
- nextrec/models/representation/lightgcn.py +0 -0
- nextrec/models/representation/mf.py +0 -0
- nextrec/models/representation/s3rec.py +0 -0
- nextrec/models/sequential/sasrec.py +0 -0
- nextrec/utils/feature.py +0 -29
- {nextrec-0.4.32.dist-info → nextrec-0.4.34.dist-info}/WHEEL +0 -0
- {nextrec-0.4.32.dist-info → nextrec-0.4.34.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.32.dist-info → nextrec-0.4.34.dist-info}/licenses/LICENSE +0 -0
nextrec/__version__.py
CHANGED
@@ -1 +1 @@
-__version__ = "0.4.32"
+__version__ = "0.4.34"
nextrec/basic/activation.py
CHANGED
@@ -1,8 +1,8 @@
 """
-Activation function definitions
+Activation function definitions.
 
 Date: create on 27/10/2025
-Checkpoint: edit on
+Checkpoint: edit on 20/01/2026
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
@@ -22,26 +22,24 @@ class Dice(nn.Module):
     where p(x) = sigmoid((x - E[x]) / sqrt(Var[x] + epsilon))
     """
 
-    def __init__(self, emb_size: int, epsilon: float = 1e-
+    def __init__(self, emb_size: int, epsilon: float = 1e-3):
         super(Dice, self).__init__()
-        self.epsilon = epsilon
         self.alpha = nn.Parameter(torch.zeros(emb_size))
-        self.bn = nn.BatchNorm1d(emb_size)
+        self.bn = nn.BatchNorm1d(emb_size, eps=epsilon)
 
     def forward(self, x):
         # x shape: (batch_size, emb_size) or (batch_size, seq_len, emb_size)
+        if x.dim() == 2:  # (B, E)
+            x_norm = self.bn(x)
+            p = torch.sigmoid(x_norm)
+            return x * (self.alpha + (1 - self.alpha) * p)
 
-        if x.dim() == 3:
-            output = p * x + (1 - p) * self.alpha * x
-        if len(original_shape) == 3:
-            output = output.view(original_shape)
-        return output
+        if x.dim() == 3:  # (B, T, E)
+            b, t, e = x.shape
+            x2 = x.reshape(-1, e)  # (B*T, E)
+            x_norm = self.bn(x2)
+            p = torch.sigmoid(x_norm).reshape(b, t, e)
+            return x * (self.alpha + (1 - self.alpha) * p)
 
 
 def activation_layer(
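The rewritten Dice applies the same formula, x * (alpha + (1 - alpha) * sigmoid(BN(x))), to both 2D and 3D inputs, flattening (B, T, E) to (B*T, E) so BatchNorm1d normalizes per embedding dimension. A minimal usage sketch based on the diff above (shapes are illustrative):

import torch
from nextrec.basic.activation import Dice

dice = Dice(emb_size=8)
dice.eval()  # use running BatchNorm stats so the sketch is deterministic

x2d = torch.randn(4, 8)      # (B, E)
x3d = torch.randn(4, 10, 8)  # (B, T, E)
assert dice(x2d).shape == (4, 8)
assert dice(x3d).shape == (4, 10, 8)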
nextrec/basic/asserts.py
CHANGED
@@ -8,7 +8,7 @@ Author: Yang Zhou, zyaztec@gmail.com
 
 from __future__ import annotations
 
-from nextrec.utils.types import TaskTypeName, TrainingModeName
+from nextrec.utils.types import TaskTypeName
 
 
 def assert_task(
@@ -49,24 +49,3 @@ def assert_task(
         raise ValueError(
             f"{model_name} requires task length {nums_task}, got {len(task)}."
         )
-
-
-def assert_training_mode(
-    training_mode: TrainingModeName | list[TrainingModeName],
-    nums_task: int,
-    *,
-    model_name: str,
-) -> None:
-    valid_modes = {"pointwise", "pairwise", "listwise"}
-    if not isinstance(training_mode, list):
-        raise TypeError(
-            f"[{model_name}-init Error] training_mode must be a list with length {nums_task}."
-        )
-    if len(training_mode) != nums_task:
-        raise ValueError(
-            f"[{model_name}-init Error] training_mode list length must match number of tasks."
-        )
-    if any(mode not in valid_modes for mode in training_mode):
-        raise ValueError(
-            f"[{model_name}-init Error] training_mode must be one of {'pointwise', 'pairwise', 'listwise'}."
-        )
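Note that assert_training_mode is removed without a replacement in this file. A caller that still needs the check can inline an equivalent; the sketch below is hypothetical (not part of nextrec) and also fixes the removed version's error message, whose f-string interpolated a set literal:

# Hypothetical inline replacement for the removed assert_training_mode helper.
VALID_MODES = ("pointwise", "pairwise", "listwise")

def check_training_mode(training_mode, nums_task: int, model_name: str) -> None:
    if not isinstance(training_mode, list):
        raise TypeError(
            f"[{model_name}-init Error] training_mode must be a list with length {nums_task}."
        )
    if len(training_mode) != nums_task:
        raise ValueError(
            f"[{model_name}-init Error] training_mode list length must match number of tasks."
        )
    invalid = [m for m in training_mode if m not in VALID_MODES]
    if invalid:
        # Joining the allowed names gives a stable, readable message.
        raise ValueError(
            f"[{model_name}-init Error] training_mode must be one of {', '.join(VALID_MODES)}; got {invalid}."
        )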
nextrec/basic/callback.py
CHANGED
@@ -2,7 +2,7 @@
 Callback System for Training Process
 
 Date: create on 27/10/2025
-Checkpoint: edit on
+Checkpoint: edit on 21/01/2026
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
@@ -69,7 +69,7 @@ class Callback:
 
 class CallbackList:
     """
-    Generates a list of callbacks
+    Generates a list of callbacks, used to manage and invoke multiple callbacks during training.
     """
 
     def __init__(self, callbacks: Optional[list[Callback]] = None):
nextrec/basic/features.py
CHANGED
@@ -8,10 +8,9 @@ Author: Yang Zhou, zyaztec@gmail.com
 
 import torch
 
-from typing import Literal
-
 from nextrec.utils.embedding import get_auto_embedding_dim
-from nextrec.utils.feature import to_list
+from nextrec.utils.torch_utils import to_list
+from nextrec.utils.types import EmbeddingInitType, SequenceCombinerType
 
 
 class BaseFeature:
@@ -29,15 +28,7 @@ class EmbeddingFeature(BaseFeature):
         embedding_name: str = "",
         embedding_dim: int | None = None,
         padding_idx: int = 0,
-        init_type: Literal[
-            "normal",
-            "uniform",
-            "xavier_uniform",
-            "xavier_normal",
-            "kaiming_uniform",
-            "kaiming_normal",
-            "orthogonal",
-        ] = "normal",
+        init_type: EmbeddingInitType = "normal",
         init_params: dict | None = None,
         l1_reg: float = 0.0,
         l2_reg: float = 0.0,
@@ -73,23 +64,9 @@ class SequenceFeature(EmbeddingFeature):
         max_len: int = 50,
         embedding_name: str = "",
         embedding_dim: int | None = None,
-        combiner: Literal[
-            "mean",
-            "sum",
-            "concat",
-            "dot_attention",
-            "self_attention",
-        ] = "mean",
+        combiner: SequenceCombinerType = "mean",
         padding_idx: int = 0,
-        init_type: Literal[
-            "normal",
-            "uniform",
-            "xavier_uniform",
-            "xavier_normal",
-            "kaiming_uniform",
-            "kaiming_normal",
-            "orthogonal",
-        ] = "normal",
+        init_type: EmbeddingInitType = "normal",
         init_params: dict | None = None,
         l1_reg: float = 0.0,
         l2_reg: float = 0.0,
@@ -143,15 +120,7 @@ class SparseFeature(EmbeddingFeature):
         embedding_name: str = "",
         embedding_dim: int | None = None,
         padding_idx: int = 0,
-        init_type: Literal[
-            "normal",
-            "uniform",
-            "xavier_uniform",
-            "xavier_normal",
-            "kaiming_uniform",
-            "kaiming_normal",
-            "orthogonal",
-        ] = "normal",
+        init_type: EmbeddingInitType = "normal",
         init_params: dict | None = None,
         l1_reg: float = 0.0,
         l2_reg: float = 0.0,
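These signatures now share string-literal aliases from the new nextrec/utils/types.py (+43 lines). That file is not shown in this diff, but from the removed inline Literal lists the aliases presumably look something like this sketch:

from typing import Literal

EmbeddingInitType = Literal[
    "normal", "uniform", "xavier_uniform", "xavier_normal",
    "kaiming_uniform", "kaiming_normal", "orthogonal",
]
SequenceCombinerType = Literal[
    "mean", "sum", "concat", "dot_attention", "self_attention",
]

Centralizing the aliases means a new init type or combiner only has to be added in one place instead of in every feature class.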
nextrec/basic/heads.py
CHANGED
@@ -2,7 +2,7 @@
 Task head implementations for NextRec models.
 
 Date: create on 23/12/2025
-Checkpoint: edit on
+Checkpoint: edit on 22/01/2026
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
@@ -24,6 +24,12 @@ class TaskHead(nn.Module):
 
     This wraps PredictionLayer so models can depend on a "Head" abstraction
     without changing their existing forward signatures.
+
+    Args:
+        task_type: The type of task(s) this head is responsible for.
+        task_dims: The dimensionality of each task's output.
+        use_bias: Whether to include a bias term in the prediction layer.
+        return_logits: Whether to return raw logits or apply activation.
     """
 
     def __init__(
@@ -56,6 +62,12 @@ class RetrievalHead(nn.Module):
 
     It computes similarity for pointwise training/inference, and returns
    raw embeddings for in-batch negative sampling in pairwise/listwise modes.
+
+    Args:
+        similarity_metric: The metric used to compute similarity between embeddings.
+        temperature: Scaling factor for similarity scores.
+        training_mode: The training mode, which can be pointwise, pairwise, or listwise.
+        apply_sigmoid: Whether to apply sigmoid activation to the similarity scores in pointwise mode.
     """
 
    def __init__(
nextrec/basic/layers.py
CHANGED
@@ -2,7 +2,7 @@
 Layer implementations used across NextRec.
 
 Date: create on 27/10/2025
-Checkpoint: edit on
+Checkpoint: edit on 22/01/2026
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
@@ -20,15 +20,13 @@ import torch.nn.functional as F
 from nextrec.basic.activation import activation_layer
 from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
 from nextrec.utils.torch_utils import get_initializer
-from nextrec.utils.types import ActivationName
+from nextrec.utils.types import ActivationName, TaskTypeName
 
 
 class PredictionLayer(nn.Module):
     def __init__(
         self,
-        task_type: (
-            Literal["binary", "regression"] | list[Literal["binary", "regression"]]
-        ) = "binary",
+        task_type: TaskTypeName | list[TaskTypeName] = "binary",
         task_dims: int | list[int] | None = None,
         use_bias: bool = True,
         return_logits: bool = False,
@@ -92,10 +90,9 @@
             if self.return_logits:
                 outputs.append(task_logits)
                 continue
-
-            if task == "binary":
+            if task_type == "binary":
                 outputs.append(torch.sigmoid(task_logits))
-            elif task == "regression":
+            elif task_type == "regression":
                 outputs.append(task_logits)
             else:
                 raise ValueError(
@@ -897,30 +894,7 @@ class AttentionPoolingLayer(nn.Module):
         self,
         embedding_dim: int,
         hidden_units: list = [80, 40],
-        activation: Literal[
-            "dice",
-            "relu",
-            "relu6",
-            "elu",
-            "selu",
-            "leaky_relu",
-            "prelu",
-            "gelu",
-            "sigmoid",
-            "tanh",
-            "softplus",
-            "softsign",
-            "hardswish",
-            "mish",
-            "silu",
-            "swish",
-            "hardsigmoid",
-            "tanhshrink",
-            "softshrink",
-            "none",
-            "linear",
-            "identity",
-        ] = "sigmoid",
+        activation: ActivationName = "sigmoid",
         use_softmax: bool = False,
     ):
         super().__init__()
@@ -1029,7 +1003,9 @@ class RMSNorm(torch.nn.Module):
 
 
 class DomainBatchNorm(nn.Module):
-    """
+    """
+    Domain-specific BatchNorm (applied per-domain with a shared interface).
+    """
 
     def __init__(self, num_features: int, num_domains: int):
         super().__init__()
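The PredictionLayer fix makes the per-task activation dispatch explicit: binary tasks get a sigmoid, regression tasks pass logits through, and anything else raises. A condensed sketch of that dispatch, with the surrounding loop and head projection elided (helper name is illustrative, not part of the package):

import torch

def apply_task_activation(task_type: str, task_logits: torch.Tensor, return_logits: bool) -> torch.Tensor:
    # Mirrors the corrected branch in PredictionLayer.forward.
    if return_logits:
        return task_logits
    if task_type == "binary":
        return torch.sigmoid(task_logits)
    if task_type == "regression":
        return task_logits
    raise ValueError(f"Unsupported task type: {task_type}")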
nextrec/basic/loggers.py
CHANGED
@@ -2,7 +2,7 @@
 NextRec Basic Loggers
 
 Date: create on 27/10/2025
-Checkpoint: edit on
+Checkpoint: edit on 22/01/2026
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
@@ -99,7 +99,8 @@ def format_kv(label: str, value: Any, width: int = 34, indent: int = 0) -> str:
 
 
 def setup_logger(session_id: str | os.PathLike | None = None):
-    """
+    """
+    Set up a logger that logs to both console and a file with ANSI formatting.
     Only console output has colors; file output is stripped of ANSI codes.
 
     Logs are stored under ``log/<experiment_id>/logs`` by default. A stable
nextrec/basic/metrics.py
CHANGED
@@ -23,7 +23,6 @@ from sklearn.metrics import (
 )
 from nextrec.utils.types import TaskTypeName, MetricsName
 
-
 TASK_DEFAULT_METRICS = {
     "binary": ["auc", "gauc", "ks", "logloss", "accuracy", "precision", "recall", "f1"],
     "regression": ["mse", "mae", "rmse", "r2", "mape"],
@@ -334,6 +333,60 @@ def compute_map_at_k(
     return float(np.mean(aps)) if aps else 0.0
 
 
+def compute_topk_counts(
+    y_true: np.ndarray, y_pred: np.ndarray, k_percent: int
+) -> tuple[int, int, int]:
+    """Compute Top-K% sample size, hits, and positives for binary labels."""
+    y_true = (y_true > 0).astype(int)
+    n = y_true.size
+    if n == 0:
+        return 0, 0, 0
+    if k_percent <= 0:
+        return 0, 0, int(y_true.sum())
+    if k_percent >= 100:
+        k_count = n
+    else:
+        k_count = int(np.ceil(n * (k_percent / 100.0)))
+        k_count = max(k_count, 1)
+    order = np.argsort(y_pred)[::-1]
+    topk = order[:k_count]
+    hits = int(y_true[topk].sum())
+    total_pos = int(y_true.sum())
+    return k_count, hits, total_pos
+
+
+def compute_topk_precision(
+    y_true: np.ndarray, y_pred: np.ndarray, k_percent: int
+) -> float:
+    """Compute Top-K% Precision."""
+    k_count, hits, _ = compute_topk_counts(y_true, y_pred, k_percent)
+    if k_count == 0:
+        return 0.0
+    return float(hits / k_count)
+
+
+def compute_topk_recall(
+    y_true: np.ndarray, y_pred: np.ndarray, k_percent: int
+) -> float:
+    """Compute Top-K% Recall."""
+    _, hits, total_pos = compute_topk_counts(y_true, y_pred, k_percent)
+    if total_pos == 0:
+        return 0.0
+    return float(hits / total_pos)
+
+
+def compute_lift_at_k(y_true: np.ndarray, y_pred: np.ndarray, k_percent: int) -> float:
+    """Compute Lift@K from Top-K% precision and overall rate."""
+    k_count, hits, total_pos = compute_topk_counts(y_true, y_pred, k_percent)
+    if k_count == 0:
+        return 0.0
+    base_rate = total_pos / float(y_true.size)
+    if base_rate == 0.0:
+        return 0.0
+    precision = hits / float(k_count)
+    return float(precision / base_rate)
+
+
 def compute_cosine_separation(y_true: np.ndarray, y_pred: np.ndarray) -> float:
     """Compute Cosine Separation."""
     y_true = (y_true > 0).astype(int)
@@ -399,11 +452,11 @@ def configure_metrics(
     if primary_task not in TASK_DEFAULT_METRICS:
         raise ValueError(f"Unsupported task type: {primary_task}")
     metrics_list = TASK_DEFAULT_METRICS[primary_task]
-    best_metrics_mode = getbest_metric_mode(metrics_list[0], primary_task)
+    best_metrics_mode = get_best_metric_mode(metrics_list[0], primary_task)
     return metrics_list, task_specific_metrics, best_metrics_mode
 
 
-def getbest_metric_mode(first_metric: MetricsName, primary_task: TaskTypeName) -> str:
+def get_best_metric_mode(first_metric: MetricsName, primary_task: TaskTypeName) -> str:
     """Determine if metric should be maximized or minimized."""
     # Metrics that should be maximized
     if first_metric in {
@@ -429,6 +482,9 @@ def getbest_metric_mode(first_metric: MetricsName, primary_task: TaskTypeName) -> str:
         or first_metric.startswith("mrr@")
         or first_metric.startswith("ndcg@")
         or first_metric.startswith("map@")
+        or first_metric.startswith("topk_recall@")
+        or first_metric.startswith("topk_precision@")
+        or first_metric.startswith("lift@")
     ):
         return "max"
     # Cosine separation should be maximized
@@ -457,6 +513,15 @@ def compute_single_metric(
 
     y_p_binary = (y_pred > 0.5).astype(int)
     try:
+        if metric.startswith("topk_recall@"):
+            k_percent = int(metric.split("@")[1])
+            return compute_topk_recall(y_true, y_pred, k_percent)
+        if metric.startswith("topk_precision@"):
+            k_percent = int(metric.split("@")[1])
+            return compute_topk_precision(y_true, y_pred, k_percent)
+        if metric.startswith("lift@"):
+            k_percent = int(metric.split("@")[1])
+            return compute_lift_at_k(y_true, y_pred, k_percent)
         if metric.startswith("recall@"):
             k = int(metric.split("@")[1])
             return compute_recall_at_k(y_true, y_pred, user_ids, k)  # type: ignore
@@ -650,7 +715,23 @@ def evaluate_metrics(
         allowed_metrics = metric_allowlist.get(task_type)
         for metric in metrics:
             if allowed_metrics is not None and metric not in allowed_metrics:
-                continue
+                if metric.startswith(
+                    (
+                        "recall@",
+                        "precision@",
+                        "hitrate@",
+                        "hr@",
+                        "mrr@",
+                        "ndcg@",
+                        "map@",
+                        "topk_recall@",
+                        "topk_precision@",
+                        "lift@",
+                    )
+                ):
+                    pass
+                else:
+                    continue
             y_true_task = y_true[:, task_idx]
             y_pred_task = y_pred[:, task_idx]
             task_user_ids = user_ids
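The new Top-K% metrics rank by score and take the top k_percent of all samples (a global cutoff, unlike the per-user recall@k family), and compute_single_metric parses them from strings such as "topk_precision@20" or "lift@10". A worked example on synthetic arrays, following the added code:

import numpy as np
from nextrec.basic.metrics import compute_topk_precision, compute_topk_recall, compute_lift_at_k

y_true = np.array([1, 0, 0, 1, 0, 0, 0, 1, 0, 0])
y_pred = np.array([0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.05])

# Top 20% of 10 samples = 2 samples (the two highest scores, indices 0 and 1).
print(compute_topk_precision(y_true, y_pred, 20))  # 1 hit / 2 -> 0.5
print(compute_topk_recall(y_true, y_pred, 20))     # 1 hit / 3 positives -> 0.333...
print(compute_lift_at_k(y_true, y_pred, 20))       # 0.5 / (3/10 base rate) -> 1.666...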
nextrec/basic/model.py
CHANGED
@@ -2,7 +2,7 @@
 Base Model & Base Match Model Class
 
 Date: create on 27/10/2025
-Checkpoint: edit on
+Checkpoint: edit on 22/01/2026
 Author: Yang Zhou,zyaztec@gmail.com
 """
 
@@ -155,9 +155,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         session_id: Session id for logging. If None, a default id with timestamps will be created. e.g., 'session_tutorial'.
 
         distributed: Enable DistributedDataParallel flow, set True to enable distributed training.
-        rank: Global rank (defaults to env RANK).
-        world_size: Number of processes (defaults to env WORLD_SIZE).
-        local_rank: Local rank for selecting CUDA device (defaults to env LOCAL_RANK).
+        rank: Global rank (defaults to env RANK). e.g., 0 for the main process.
+        world_size: Number of processes (defaults to env WORLD_SIZE). e.g., 4 for a 4-process training.
+        local_rank: Local rank for selecting CUDA device (defaults to env LOCAL_RANK). e.g., 0 for the first GPU.
         ddp_find_unused_parameters: Default False, set it True only when exist unused parameters in ddp model, in most cases should be False.
 
     Note:
@@ -933,6 +933,13 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
 
         existing_callbacks = self.callbacks.callbacks
 
+        has_validation = valid_data is not None or valid_split is not None
+        checkpoint_monitor = monitor_metric
+        checkpoint_mode = self.best_metrics_mode
+        if not has_validation:
+            checkpoint_monitor = "loss"
+            checkpoint_mode = "min"
+
         if self.early_stop_patience > 0 and not any(
             isinstance(cb, EarlyStopper) for cb in existing_callbacks
         ):
@@ -946,6 +953,8 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             )
         )
 
+        has_validation = valid_data is not None or valid_split is not None
+
         if self.is_main_process and not any(
             isinstance(cb, CheckpointSaver) for cb in existing_callbacks
         ):
@@ -953,9 +962,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
                 CheckpointSaver(
                     best_path=self.best_path,
                     checkpoint_path=self.checkpoint_path,
-                    monitor=monitor_metric,
-                    mode=self.best_metrics_mode,
-                    save_best_only=True,
+                    monitor=checkpoint_monitor,
+                    mode=checkpoint_mode,
+                    save_best_only=has_validation,
                     verbose=1,
                 )
             )
@@ -1246,11 +1255,6 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
                 epoch_logs[f"val_{k}"] = v
             else:
                 epoch_logs = {**train_log_payload}
-                if self.is_main_process:
-                    self.save_model(
-                        self.checkpoint_path, add_timestamp=False, verbose=False
-                    )
-                    self.best_checkpoint_path = self.checkpoint_path
 
             # Call on_epoch_end for all callbacks (handles early stopping, checkpointing, lr scheduling)
             self.callbacks.on_epoch_end(epoch, epoch_logs)
@@ -1347,6 +1351,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             nn.utils.clip_grad_norm_(params, self.max_gradient_norm)
         self.optimizer_fn.step()
         if self.grad_norm is not None:
+            # Synchronize GradNorm buffers across DDP ranks before stepping
+            if self.distributed and dist.is_available() and dist.is_initialized():
+                self.grad_norm.sync()
             self.grad_norm.step()
         accumulated_loss += loss.item()
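When no validation data is available, the checkpoint monitor now falls back to training loss with mode "min", so CheckpointSaver still has a well-defined signal. The last hunk adds a DDP guard that calls self.grad_norm.sync() before grad_norm.step(); grad_norm.py changed in this release (+78 -76) but is not shown here, so the following is only a sketch of what such a sync plausibly does, namely average GradNorm's per-task weight state across ranks so every process steps from identical values:

import torch
import torch.distributed as dist

def sync_task_weights(task_weights: torch.nn.Parameter) -> None:
    # Hypothetical helper: average GradNorm's per-task loss weights across
    # DDP ranks so each process applies the same weighting on the next step.
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(task_weights.data, op=dist.ReduceOp.SUM)
        task_weights.data /= dist.get_world_size()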