nextrec 0.4.10__py3-none-any.whl → 0.4.12__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in the supported public registries. It is provided for informational purposes only.
- nextrec/__version__.py +1 -1
- nextrec/basic/callback.py +44 -54
- nextrec/basic/features.py +35 -22
- nextrec/basic/layers.py +64 -68
- nextrec/basic/loggers.py +2 -2
- nextrec/basic/metrics.py +9 -5
- nextrec/basic/model.py +162 -106
- nextrec/cli.py +16 -5
- nextrec/data/preprocessor.py +4 -4
- nextrec/loss/loss_utils.py +1 -1
- nextrec/models/generative/__init__.py +1 -1
- nextrec/models/ranking/eulernet.py +44 -75
- nextrec/models/ranking/ffm.py +275 -0
- nextrec/models/ranking/lr.py +1 -3
- nextrec/models/representation/autorec.py +0 -0
- nextrec/models/representation/bpr.py +0 -0
- nextrec/models/representation/cl4srec.py +0 -0
- nextrec/models/representation/lightgcn.py +0 -0
- nextrec/models/representation/mf.py +0 -0
- nextrec/models/representation/s3rec.py +0 -0
- nextrec/models/sequential/sasrec.py +0 -0
- nextrec/utils/__init__.py +2 -1
- nextrec/utils/console.py +9 -1
- nextrec/utils/model.py +14 -0
- {nextrec-0.4.10.dist-info → nextrec-0.4.12.dist-info}/METADATA +32 -11
- {nextrec-0.4.10.dist-info → nextrec-0.4.12.dist-info}/RECORD +30 -23
- /nextrec/models/{generative → sequential}/hstu.py +0 -0
- {nextrec-0.4.10.dist-info → nextrec-0.4.12.dist-info}/WHEEL +0 -0
- {nextrec-0.4.10.dist-info → nextrec-0.4.12.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.10.dist-info → nextrec-0.4.12.dist-info}/licenses/LICENSE +0 -0
nextrec/__version__.py
CHANGED
```diff
@@ -1 +1 @@
-__version__ = "0.4.10"
+__version__ = "0.4.12"
```
nextrec/basic/callback.py
CHANGED
```diff
@@ -22,10 +22,10 @@ class Callback:
     """
     Base callback.
 
-    Notes
-
-
-
+    Notes for DDP training:
+    In distributed training, the training loop runs on every rank.
+    For callbacks with side effects (saving, logging, etc.), set
+    ``run_on_main_process_only=True`` to avoid multi-rank duplication.
     """
 
    run_on_main_process_only: bool = False
```
```diff
@@ -70,7 +70,7 @@ class Callback:
 
 
 class CallbackList:
-    """
+    """Generates a list of callbacks"""
 
     def __init__(self, callbacks: Optional[list[Callback]] = None):
         self.callbacks = callbacks or []
@@ -78,61 +78,41 @@ class CallbackList:
     def append(self, callback: Callback):
         self.callbacks.append(callback)
 
-    def set_model(self, model):
+    def call(self, fn_name: str, *args, **kwargs):
         for callback in self.callbacks:
-            callback.set_model(model)
+            if not callback.should_run():
+                continue
+            getattr(callback, fn_name)(*args, **kwargs)
+
+    def set_model(self, model):
+        self.call("set_model", model)
 
     def set_params(self, params: dict):
-        for callback in self.callbacks:
-            callback.set_params(params)
+        self.call("set_params", params)
 
     def on_train_begin(self, logs: Optional[dict] = None):
-        for callback in self.callbacks:
-            if not callback.should_run():
-                continue
-            callback.on_train_begin(logs)
+        self.call("on_train_begin", logs)
 
     def on_train_end(self, logs: Optional[dict] = None):
-        for callback in self.callbacks:
-            if not callback.should_run():
-                continue
-            callback.on_train_end(logs)
+        self.call("on_train_end", logs)
 
     def on_epoch_begin(self, epoch: int, logs: Optional[dict] = None):
-        for callback in self.callbacks:
-            if not callback.should_run():
-                continue
-            callback.on_epoch_begin(epoch, logs)
+        self.call("on_epoch_begin", epoch, logs)
 
     def on_epoch_end(self, epoch: int, logs: Optional[dict] = None):
-        for callback in self.callbacks:
-            if not callback.should_run():
-                continue
-            callback.on_epoch_end(epoch, logs)
+        self.call("on_epoch_end", epoch, logs)
 
     def on_batch_begin(self, batch: int, logs: Optional[dict] = None):
-        for callback in self.callbacks:
-            if not callback.should_run():
-                continue
-            callback.on_batch_begin(batch, logs)
+        self.call("on_batch_begin", batch, logs)
 
     def on_batch_end(self, batch: int, logs: Optional[dict] = None):
-        for callback in self.callbacks:
-            if not callback.should_run():
-                continue
-            callback.on_batch_end(batch, logs)
+        self.call("on_batch_end", batch, logs)
 
     def on_validation_begin(self, logs: Optional[dict] = None):
-        for callback in self.callbacks:
-            if not callback.should_run():
-                continue
-            callback.on_validation_begin(logs)
+        self.call("on_validation_begin", logs)
 
     def on_validation_end(self, logs: Optional[dict] = None):
-        for callback in self.callbacks:
-            if not callback.should_run():
-                continue
-            callback.on_validation_end(logs)
+        self.call("on_validation_end", logs)
 
 
 class EarlyStopper(Callback):
```
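The change above collapses ten near-identical per-event loops into a single `call` dispatcher, so the `should_run()` gate (used for DDP main-process-only callbacks) lives in one place. A minimal standalone sketch of the pattern, with simplified names rather than the package's exact classes:

```python
class Handler:
    def should_run(self) -> bool:
        return True

    def on_train_begin(self, logs=None):
        print(f"{self.__class__.__name__}: train begin, logs={logs}")


class HandlerList:
    def __init__(self, handlers=None):
        self.handlers = handlers or []

    def call(self, fn_name: str, *args, **kwargs):
        # One gate + one dispatch replaces a copy of this loop per event.
        for h in self.handlers:
            if not h.should_run():
                continue
            getattr(h, fn_name)(*args, **kwargs)

    def on_train_begin(self, logs=None):
        self.call("on_train_begin", logs)


HandlerList([Handler()]).on_train_begin({"epoch": 0})
```

Dispatching by method name with `getattr` keeps the hook list extensible: adding a new event only requires a one-line forwarding method.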
```diff
@@ -146,6 +126,20 @@ class EarlyStopper(Callback):
         restore_best_weights: bool = True,
         verbose: int = 1,
     ):
+        """
+        Callback to stop training early if no improvement.
+
+        Args:
+            monitor: Metric name to monitor.
+            patience: Number of epochs with no improvement after which training will be stopped.
+            mode: One of {'min', 'max'}. In 'min' mode, training will stop when the
+                monitored metric has stopped decreasing; in 'max' mode it will stop
+                when the monitored metric has stopped increasing.
+            min_delta: Minimum change in the monitored metric to qualify as an improvement.
+            restore_best_weights: Whether to restore model weights from the epoch with the best value
+                of the monitored metric.
+            verbose: Verbosity mode. 1: messages will be printed. 0: silent.
+        """
         super().__init__()
         self.monitor = monitor
         self.patience = patience
```
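With the contract documented, typical usage would look like the sketch below; the constructor arguments come from the docstring above, while the trainer integration line is an assumption:

```python
from nextrec.basic.callback import EarlyStopper

# Stop when validation AUC fails to improve by at least 1e-4
# for 3 consecutive epochs, then roll back to the best weights.
stopper = EarlyStopper(
    monitor="val_auc",
    patience=3,
    mode="max",
    min_delta=1e-4,
    restore_best_weights=True,
    verbose=1,
)
# model.fit(..., callbacks=[stopper])  # assumed trainer integration
```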
```diff
@@ -233,6 +227,7 @@ class CheckpointSaver(Callback):
        save_best_only: If True, only save when the model is considered the "best".
        save_freq: Frequency of checkpoint saving ('epoch' or integer for every N epochs).
        verbose: Verbosity mode.
+       run_on_main_process_only: Whether to run this callback only on the main process in DDP.
    """
 
    def __init__(
@@ -274,7 +269,6 @@ class CheckpointSaver(Callback):
         self.checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
 
     def on_epoch_end(self, epoch: int, logs: Optional[dict] = None):
-        logging.info("")
         logs = logs or {}
 
         should_save = False
@@ -283,9 +277,6 @@ class CheckpointSaver(Callback):
         elif isinstance(self.save_freq, int) and (epoch + 1) % self.save_freq == 0:
             should_save = True
 
-        if not should_save and self.save_best_only:
-            should_save = False
-
         # Check if this is the best model
         current = logs.get(self.monitor)
         is_best = False
@@ -297,11 +288,7 @@ class CheckpointSaver(Callback):
 
         if should_save:
             if not self.save_best_only or is_best:
-                checkpoint_path = (
-                    self.checkpoint_path.parent
-                    / f"{self.checkpoint_path.stem}{self.checkpoint_path.suffix}"
-                )
-                self.save_checkpoint(checkpoint_path, epoch, logs)
+                self.save_checkpoint(self.checkpoint_path, epoch, logs)
 
         if is_best:
             # Use save_path directly without adding _best suffix since it may already contain it
```
```diff
@@ -371,7 +358,9 @@ class LearningRateScheduler(Callback):
         # Step the scheduler
         if hasattr(self.scheduler, "step"):
             # Some schedulers need metrics
-            if
+            if logs is None:
+                logs = {}
+            if "val_loss" in logs and hasattr(self.scheduler, "mode"):
                 self.scheduler.step(logs["val_loss"])
             else:
                 self.scheduler.step()
```
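The added guard normalizes `logs` before probing it; the `hasattr(self.scheduler, "mode")` test then distinguishes metric-driven schedulers (PyTorch's ReduceLROnPlateau exposes a `mode` attribute and expects the monitored value in `step()`) from epoch-driven ones. The same dispatch rule, demonstrated outside the callback:

```python
import torch

model = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
opt.step()  # a real loop would call this every batch before scheduler.step()

plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode="min", patience=2)
step_lr = torch.optim.lr_scheduler.StepLR(opt, step_size=5)

logs = {"val_loss": 0.42}
for sched in (plateau, step_lr):
    # Metric-driven schedulers expose `mode` and take the monitored value.
    if "val_loss" in logs and hasattr(sched, "mode"):
        sched.step(logs["val_loss"])
    else:
        sched.step()
```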
```diff
@@ -399,7 +388,6 @@ class MetricsLogger(Callback):
         self.run_on_main_process_only = True
         self.log_freq = log_freq
         self.verbose = verbose
-        self.batch_count = 0
 
     def on_epoch_end(self, epoch: int, logs: Optional[dict] = None):
         if self.verbose > 0 and (
@@ -416,8 +404,10 @@ class MetricsLogger(Callback):
             logging.info(f"Epoch {epoch + 1}: {metrics_str}")
 
     def on_batch_end(self, batch: int, logs: Optional[dict] = None):
-        self.batch_count += 1
-
+        if self.verbose > 1 and (
+            self.log_freq == "batch"
+            or (isinstance(self.log_freq, int) and (batch + 1) % self.log_freq == 0)
+        ):
             logs = logs or {}
             metrics_str = " - ".join(
                 [
```
nextrec/basic/features.py
CHANGED
```diff
@@ -2,7 +2,7 @@
 Feature definitions
 
 Date: create on 27/10/2025
-Checkpoint: edit on
+Checkpoint: edit on 20/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
```
```diff
@@ -12,22 +12,20 @@ from nextrec.utils.embedding import get_auto_embedding_dim
 from nextrec.utils.feature import normalize_to_list
 
 
-class BaseFeature
+class BaseFeature:
     def __repr__(self):
         params = {k: v for k, v in self.__dict__.items() if not k.startswith("_")}
         param_str = ", ".join(f"{k}={v!r}" for k, v in params.items())
         return f"{self.__class__.__name__}({param_str})"
 
 
-class SequenceFeature(BaseFeature):
+class EmbeddingFeature(BaseFeature):
     def __init__(
         self,
         name: str,
         vocab_size: int,
-        max_len: int = 20,
         embedding_name: str = "",
         embedding_dim: int | None = 4,
-        combiner: str = "mean",
         padding_idx: int | None = None,
         init_type: str = "normal",
         init_params: dict | None = None,
```
```diff
@@ -39,13 +37,15 @@ class SequenceFeature(BaseFeature):
     ):
         self.name = name
         self.vocab_size = vocab_size
-        self.max_len = max_len
         self.embedding_name = embedding_name or name
-        self.embedding_dim =
+        self.embedding_dim = (
+            get_auto_embedding_dim(vocab_size)
+            if embedding_dim is None
+            else embedding_dim
+        )
 
         self.init_type = init_type
         self.init_params = init_params or {}
-        self.combiner = combiner
         self.padding_idx = padding_idx
         self.l1_reg = l1_reg
         self.l2_reg = l2_reg
```
```diff
@@ -54,13 +54,15 @@ class SequenceFeature(BaseFeature):
         self.freeze_pretrained = freeze_pretrained
 
 
-class SparseFeature(BaseFeature):
+class SequenceFeature(EmbeddingFeature):
     def __init__(
         self,
         name: str,
         vocab_size: int,
+        max_len: int = 20,
         embedding_name: str = "",
         embedding_dim: int | None = 4,
+        combiner: str = "mean",
         padding_idx: int | None = None,
         init_type: str = "normal",
         init_params: dict | None = None,
```
```diff
@@ -70,19 +72,26 @@ class SparseFeature(BaseFeature):
         pretrained_weight: torch.Tensor | None = None,
         freeze_pretrained: bool = False,
     ):
-        self.name = name
-        self.vocab_size = vocab_size
-        self.embedding_name = embedding_name or name
-        self.embedding_dim =
+        super().__init__(
+            name=name,
+            vocab_size=vocab_size,
+            embedding_name=embedding_name,
+            embedding_dim=embedding_dim,
+            padding_idx=padding_idx,
+            init_type=init_type,
+            init_params=init_params,
+            l1_reg=l1_reg,
+            l2_reg=l2_reg,
+            trainable=trainable,
+            pretrained_weight=pretrained_weight,
+            freeze_pretrained=freeze_pretrained,
+        )
+        self.max_len = max_len
+        self.combiner = combiner
 
-        self.init_type = init_type
-        self.init_params = init_params or {}
-        self.padding_idx = padding_idx
-        self.l1_reg = l1_reg
-        self.l2_reg = l2_reg
-        self.trainable = trainable
-        self.pretrained_weight = pretrained_weight
-        self.freeze_pretrained = freeze_pretrained
+
+class SparseFeature(EmbeddingFeature):
+    pass
 
 
 class DenseFeature(BaseFeature):
```
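features.py now factors the shared embedding fields into an `EmbeddingFeature` base: `SparseFeature` becomes a bare alias, and `SequenceFeature` adds only `max_len` and `combiner` on top. A usage sketch of the resulting hierarchy (field values are illustrative, and constructor defaults are assumed where not shown in the diff):

```python
from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature

user_id = SparseFeature(name="user_id", vocab_size=10_000, embedding_dim=16)
hist = SequenceFeature(
    name="click_history",
    vocab_size=50_000,
    max_len=50,          # sequence-only: padding/truncation length
    combiner="mean",     # sequence-only: how token embeddings are pooled
    embedding_dim=None,  # None -> get_auto_embedding_dim(vocab_size)
)
age = DenseFeature(name="age", input_dim=1)

print(user_id)  # BaseFeature.__repr__ prints all public fields
```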
```diff
@@ -95,7 +104,11 @@ class DenseFeature(BaseFeature):
     ):
         self.name = name
         self.input_dim = max(int(input_dim or 1), 1)
-        self.embedding_dim = embedding_dim
+        self.embedding_dim = self.input_dim if embedding_dim is None else embedding_dim
+        if use_embedding and self.embedding_dim == 0:
+            raise ValueError(
+                "[Features Error] DenseFeature: use_embedding=True is incompatible with embedding_dim=0"
+            )
         if embedding_dim is not None and embedding_dim > 1:
             self.use_embedding = True
         else:
```
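For `DenseFeature`, `embedding_dim=None` now falls back to the feature's own `input_dim`, with an explicit error when an embedding is requested at width zero. The rule in isolation, as a standalone mirror of the changed lines:

```python
def resolve_dense_dim(input_dim: int | None, embedding_dim: int | None,
                      use_embedding: bool) -> int:
    # Mirrors DenseFeature: clamp input_dim, default embedding_dim to it.
    in_dim = max(int(input_dim or 1), 1)
    emb_dim = in_dim if embedding_dim is None else embedding_dim
    if use_embedding and emb_dim == 0:
        raise ValueError("use_embedding=True is incompatible with embedding_dim=0")
    return emb_dim

assert resolve_dense_dim(8, None, False) == 8  # None -> pass-through width
assert resolve_dense_dim(None, 4, True) == 4   # explicit dim wins
```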
nextrec/basic/layers.py
CHANGED
```diff
@@ -2,7 +2,7 @@
 Layer implementations used across NextRec models.
 
 Date: create on 27/10/2025
-Checkpoint: edit on
+Checkpoint: edit on 20/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
```
```diff
@@ -28,6 +28,16 @@ class PredictionLayer(nn.Module):
         use_bias: bool = True,
         return_logits: bool = False,
     ):
+        """
+        Prediction layer supporting binary and regression outputs.
+
+        Args:
+            task_type: A string or list of strings specifying the type of each task. supported types are "binary" and "regression".
+            task_dims: An integer or list of integers specifying the output dimension for each task.
+                If None, defaults to 1 for each task. If a single integer is provided, it is shared across all tasks.
+            use_bias: Whether to include a bias term in the prediction layer.
+            return_logits: If True, returns raw logits without applying activation functions.
+        """
         super().__init__()
         self.task_types = [task_type] if isinstance(task_type, str) else list(task_type)
         if len(self.task_types) == 0:
```
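The new docstring pins down the `PredictionLayer` constructor contract; a construction sketch based on it (surrounding model wiring omitted):

```python
from nextrec.basic.layers import PredictionLayer

# Single binary task, activated probabilities out.
ctr_head = PredictionLayer(task_type="binary")

# Two tasks in one head: CTR (binary) + watch time (regression).
# task_dims defaults to 1 output per task when left as None.
multi_head = PredictionLayer(
    task_type=["binary", "regression"],
    return_logits=False,
)
```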
```diff
@@ -253,8 +263,11 @@ class EmbeddingLayer(nn.Module):
         for feat in unique_feats.values():
             if isinstance(feat, DenseFeature):
                 in_dim = max(int(getattr(feat, "input_dim", 1)), 1)
-                emb_dim = getattr(feat, "embedding_dim", None)
-                out_dim = max(int(emb_dim), 1) if emb_dim else in_dim
+                if getattr(feat, "use_embedding", False):
+                    emb_dim = getattr(feat, "embedding_dim", None)
+                    out_dim = max(int(emb_dim), 1) if emb_dim else in_dim
+                else:
+                    out_dim = in_dim
                 dim += out_dim
             elif isinstance(feat, SequenceFeature) and feat.combiner == "concat":
                 dim += feat.embedding_dim * feat.max_len
```
```diff
@@ -518,13 +531,17 @@ class MultiHeadSelfAttention(nn.Module):
         self.use_residual = use_residual
         self.dropout_rate = dropout
 
-        self.W_Q = nn.Linear(embedding_dim, embedding_dim, bias=False)
-        self.W_K = nn.Linear(embedding_dim, embedding_dim, bias=False)
-        self.W_V = nn.Linear(embedding_dim, embedding_dim, bias=False)
-        self.W_O = nn.Linear(embedding_dim, embedding_dim, bias=False)
+        self.W_Q = nn.Linear(
+            embedding_dim, embedding_dim, bias=False
+        )  # Query projection
+        self.W_K = nn.Linear(embedding_dim, embedding_dim, bias=False)  # Key projection
+        self.W_V = nn.Linear(
+            embedding_dim, embedding_dim, bias=False
+        )  # Value projection
+        self.W_O = nn.Linear(
+            embedding_dim, embedding_dim, bias=False
+        )  # Output projection
 
-        if self.use_residual:
-            self.W_Res = nn.Linear(embedding_dim, embedding_dim, bias=False)
         if use_layer_norm:
             self.layer_norm = nn.LayerNorm(embedding_dim)
         else:
```
```diff
@@ -537,81 +554,60 @@ class MultiHeadSelfAttention(nn.Module):
     def forward(
         self, x: torch.Tensor, attention_mask: torch.Tensor | None = None
     ) -> torch.Tensor:
-        """
-
-
-
-        Returns:
-            output: [batch_size, seq_len, embedding_dim]
-        """
-        batch_size, seq_len, _ = x.shape
-        Q = self.W_Q(x)  # [batch_size, seq_len, embedding_dim]
+        # x: [Batch, Length, Dim]
+        B, L, D = x.shape
+
+        Q = self.W_Q(x)
         K = self.W_K(x)
         V = self.W_V(x)
 
-        # Split into multiple heads
-        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
-        K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
-        V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+        Q = Q.view(B, L, self.num_heads, self.head_dim).transpose(
+            1, 2
+        )  # [Batch, Heads, Length, head_dim]
+        K = K.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
+        V = V.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
+
+        key_padding_mask = None
+        if attention_mask is not None:
+            if attention_mask.dim() == 2:  # [B,L], 1=valid, 0=pad
+                key_padding_mask = ~attention_mask.bool()
+                attn_mask = key_padding_mask[:, None, None, :]
+                attn_mask = attn_mask.expand(B, 1, L, L)
+            elif attention_mask.dim() == 3:  # [B,L,L], 1=allowed, 0=masked
+                attn_mask = (~attention_mask.bool()).view(B, 1, L, L)
+            else:
+                raise ValueError("attention_mask must be [B,L] or [B,L,L]")
+        else:
+            attn_mask = None
 
         if self.use_flash_attention:
-
-            if attention_mask is not None:
-                # Convert mask to [batch_size, 1, seq_len, seq_len] format
-                if attention_mask.dim() == 2:
-                    # [B, L] -> [B, 1, 1, L]
-                    attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
-                elif attention_mask.dim() == 3:
-                    # [B, L, L] -> [B, 1, L, L]
-                    attention_mask = attention_mask.unsqueeze(1)
-            attention_output = F.scaled_dot_product_attention(
+            attn = F.scaled_dot_product_attention(
                 Q,
                 K,
                 V,
-                attn_mask=
+                attn_mask=attn_mask,
                 dropout_p=self.dropout_rate if self.training else 0.0,
-            )
-            # Handle potential NaN values
-            attention_output = torch.nan_to_num(attention_output, nan=0.0)
+            )  # [B,H,L,dh]
         else:
-            # Fallback to standard attention
             scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim**0.5)
+            if attn_mask is not None:
+                scores = scores.masked_fill(attn_mask, float("-inf"))
+            attn_weights = torch.softmax(scores, dim=-1)
+            attn_weights = self.dropout(attn_weights)
+            attn = torch.matmul(attn_weights, V)  # [B,H,L,dh]
 
-
-            if attention_mask is not None:
-                if attention_mask.dim() == 2:
-                    # [B, L] -> [B, 1, 1, L]
-                    attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
-                elif attention_mask.dim() == 3:
-                    # [B, L, L] -> [B, 1, L, L]
-                    attention_mask = attention_mask.unsqueeze(1)
-            scores = scores.masked_fill(~attention_mask, float("-1e9"))
-
-            attention_weights = F.softmax(scores, dim=-1)
-            attention_weights = self.dropout(attention_weights)
-            attention_output = torch.matmul(
-                attention_weights, V
-            )  # [batch_size, num_heads, seq_len, head_dim]
-
-            # Concatenate heads
-            attention_output = attention_output.transpose(1, 2).contiguous()
-            attention_output = attention_output.view(
-                batch_size, seq_len, self.embedding_dim
-            )
+        attn = attn.transpose(1, 2).contiguous().view(B, L, D)
+        out = self.W_O(attn)
 
-        # Output projection
-        output = self.W_O(attention_output)
-
-        # Residual connection
         if self.use_residual:
-            output = output + self.W_Res(x)
-
-        # Layer normalization
+            out = out + x
         if self.layer_norm is not None:
-            output = self.layer_norm(output)
+            out = self.layer_norm(out)
 
-
-        return output
+        if key_padding_mask is not None:
+            out = out * (~key_padding_mask).unsqueeze(-1)
+
+        return out
 
 
 class AttentionPoolingLayer(nn.Module):
```
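The rewritten `forward` normalizes both mask shapes to a boolean `[B, 1, L, L]` tensor in which `True` marks pairs to block, applies it with `masked_fill` in the fallback branch, and re-zeroes padded positions on the way out. A standalone sketch of that fallback convention (note that, per the PyTorch docs, a boolean `attn_mask` passed to `F.scaled_dot_product_attention` uses the opposite convention, `True` = may attend, so the two branches treat the same tensor differently):

```python
import torch

B, H, L, dh = 2, 4, 6, 8
Q = torch.randn(B, H, L, dh)
K = torch.randn(B, H, L, dh)
V = torch.randn(B, H, L, dh)

# [B, L] padding mask: 1 = valid token, 0 = padding.
valid = torch.tensor([[1, 1, 1, 1, 0, 0],
                      [1, 1, 1, 0, 0, 0]])
key_padding_mask = ~valid.bool()                   # True = pad
block = key_padding_mask[:, None, None, :].expand(B, 1, L, L)

scores = Q @ K.transpose(-2, -1) / dh**0.5         # [B, H, L, L]
scores = scores.masked_fill(block, float("-inf"))  # True = blocked
attn = torch.softmax(scores, dim=-1) @ V           # [B, H, L, dh]

# Zero out outputs at padded query positions, as the new forward does.
out = attn.transpose(1, 2).reshape(B, L, H * dh)
out = out * valid.unsqueeze(-1)
```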
nextrec/basic/loggers.py
CHANGED
```diff
@@ -2,7 +2,7 @@
 NextRec Basic Loggers
 
 Date: create on 27/10/2025
-Checkpoint: edit on
+Checkpoint: edit on 20/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
```
```diff
@@ -185,7 +185,7 @@ class TrainingLogger:
     ) -> dict[str, float]:
         formatted: dict[str, float] = {}
         for key, value in metrics.items():
-            if isinstance(value, numbers.Number):
+            if isinstance(value, numbers.Real):
                 formatted[f"{split}/{key}"] = float(value)
             elif hasattr(value, "item"):
                 try:
```
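The one-token change to `numbers.Real` (the diff truncates the original check; `numbers.Number` is assumed here) matters because complex values are `Number`s but reject `float()`; `Real` admits ints, floats, and NumPy scalars while filtering complex out:

```python
import numbers

import numpy as np

for v in (3, 2.5, np.float32(0.7), complex(1, 2)):
    print(
        type(v).__name__,
        isinstance(v, numbers.Number),  # True for all four, complex included
        isinstance(v, numbers.Real),    # False for complex -> float(v) never runs
    )
```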
nextrec/basic/metrics.py
CHANGED
```diff
@@ -2,7 +2,7 @@
 Metrics computation and configuration for model evaluation.
 
 Date: create on 27/10/2025
-Checkpoint: edit on
+Checkpoint: edit on 20/12/2025
 Author: Yang Zhou,zyaztec@gmail.com
 """
 
```
```diff
@@ -49,8 +49,8 @@ TASK_DEFAULT_METRICS = {
 
 def check_user_id(*metric_sources: Any) -> bool:
     """Return True when GAUC or ranking@K metrics appear in the provided sources."""
-    metric_names
-    stack
+    metric_names = set()
+    stack = list(metric_sources)
     while stack:
         item = stack.pop()
         if not item:
```
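`check_user_id` walks arbitrarily nested metric specs (flat lists, per-task dicts) looking for names such as `gauc` or `hit@10` that require per-user grouping. A simplified reimplementation of that stack-based walk, not the package's exact code:

```python
def needs_user_id(*metric_sources) -> bool:
    names = set()
    stack = list(metric_sources)
    while stack:
        item = stack.pop()
        if not item:
            continue
        if isinstance(item, str):
            names.add(item.lower())
        elif isinstance(item, dict):
            stack.extend(item.values())   # per-task metric lists
        elif isinstance(item, (list, tuple, set)):
            stack.extend(item)            # flatten nested lists
    return any(n == "gauc" or "@" in n for n in names)

assert needs_user_id(["auc", "logloss"]) is False
assert needs_user_id({"ctr": ["gauc"], "cvr": ["mse"]}) is True
assert needs_user_id("recall@20") is True
```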
```diff
@@ -367,10 +367,12 @@ def configure_metrics(
     target_names: list[str],  # ['target1', 'target2']
 ) -> tuple[list[str], dict[str, list[str]] | None, str]:
     """Configure metrics based on task and user input."""
+
     primary_task = task[0] if isinstance(task, list) else task
     nums_task = len(task) if isinstance(task, list) else 1
-    metrics_list
-    task_specific_metrics
+    metrics_list = []
+    task_specific_metrics = None
+
     if isinstance(metrics, dict):
         metrics_list = []
         task_specific_metrics = {}
```
```diff
@@ -462,6 +464,7 @@ def compute_single_metric(
     user_ids: np.ndarray | None = None,
 ) -> float:
     """Compute a single metric given true and predicted values."""
+
     y_p_binary = (y_pred > 0.5).astype(int)
     metric_lower = metric.lower()
     try:
```
```diff
@@ -575,6 +578,7 @@ def evaluate_metrics(
     user_ids: np.ndarray | None = None,  # example: User IDs for GAUC computation
 ) -> dict:  # {'auc': 0.75, 'logloss': 0.45, 'mse_target2': 3.2}
     """Evaluate specified metrics for given true and predicted values."""
+
     result = {}
     if y_true is None or y_pred is None:
         return result
```