nextrec-0.4.20-py3-none-any.whl → nextrec-0.4.21-py3-none-any.whl
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +9 -4
- nextrec/basic/callback.py +39 -87
- nextrec/basic/features.py +149 -28
- nextrec/basic/heads.py +4 -1
- nextrec/basic/layers.py +375 -94
- nextrec/basic/loggers.py +236 -39
- nextrec/basic/model.py +209 -316
- nextrec/basic/session.py +2 -2
- nextrec/basic/summary.py +323 -0
- nextrec/cli.py +3 -3
- nextrec/data/data_processing.py +45 -1
- nextrec/data/dataloader.py +2 -2
- nextrec/data/preprocessor.py +2 -2
- nextrec/loss/loss_utils.py +5 -30
- nextrec/models/multi_task/esmm.py +4 -6
- nextrec/models/multi_task/mmoe.py +4 -6
- nextrec/models/multi_task/ple.py +6 -8
- nextrec/models/multi_task/poso.py +5 -7
- nextrec/models/multi_task/share_bottom.py +6 -8
- nextrec/models/ranking/afm.py +4 -6
- nextrec/models/ranking/autoint.py +4 -6
- nextrec/models/ranking/dcn.py +8 -7
- nextrec/models/ranking/dcn_v2.py +4 -6
- nextrec/models/ranking/deepfm.py +5 -7
- nextrec/models/ranking/dien.py +8 -7
- nextrec/models/ranking/din.py +8 -7
- nextrec/models/ranking/eulernet.py +5 -7
- nextrec/models/ranking/ffm.py +5 -7
- nextrec/models/ranking/fibinet.py +4 -6
- nextrec/models/ranking/fm.py +4 -6
- nextrec/models/ranking/lr.py +4 -6
- nextrec/models/ranking/masknet.py +8 -9
- nextrec/models/ranking/pnn.py +4 -6
- nextrec/models/ranking/widedeep.py +5 -7
- nextrec/models/ranking/xdeepfm.py +8 -7
- nextrec/models/retrieval/dssm.py +4 -10
- nextrec/models/retrieval/dssm_v2.py +0 -6
- nextrec/models/retrieval/mind.py +4 -10
- nextrec/models/retrieval/sdm.py +4 -10
- nextrec/models/retrieval/youtube_dnn.py +4 -10
- nextrec/models/sequential/hstu.py +1 -3
- nextrec/utils/__init__.py +12 -14
- nextrec/utils/config.py +15 -5
- nextrec/utils/console.py +2 -2
- nextrec/utils/feature.py +2 -2
- nextrec/utils/torch_utils.py +57 -112
- nextrec/utils/types.py +59 -0
- {nextrec-0.4.20.dist-info → nextrec-0.4.21.dist-info}/METADATA +7 -5
- nextrec-0.4.21.dist-info/RECORD +81 -0
- nextrec-0.4.20.dist-info/RECORD +0 -79
- {nextrec-0.4.20.dist-info → nextrec-0.4.21.dist-info}/WHEEL +0 -0
- {nextrec-0.4.20.dist-info → nextrec-0.4.21.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.20.dist-info → nextrec-0.4.21.dist-info}/licenses/LICENSE +0 -0
nextrec/__version__.py
CHANGED
@@ -1 +1 @@
-__version__ = "0.4.20"
+__version__ = "0.4.21"
nextrec/basic/activation.py
CHANGED
@@ -1,14 +1,17 @@
 """
-Activation function definitions
+Activation function definitions for NextRec models.
 
 Date: create on 27/10/2025
-Checkpoint: edit on
+Checkpoint: edit on 28/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
 import torch
 import torch.nn as nn
 
+from typing import Literal
+
+from nextrec.utils.types import ActivationName
 
 class Dice(nn.Module):
     """
@@ -41,9 +44,11 @@ class Dice(nn.Module):
         return output
 
 
-def activation_layer(
+def activation_layer(
+    activation: ActivationName = "none",
+    emb_size: int | None = None,
+):
     """Create an activation layer based on the given activation name."""
-    activation = activation.lower()
    if activation == "dice":
        if emb_size is None:
            raise ValueError(
nextrec/basic/callback.py
CHANGED
@@ -2,7 +2,7 @@
 Callback System for Training Process
 
 Date: create on 27/10/2025
-Checkpoint: edit on
+Checkpoint: edit on 27/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
@@ -61,16 +61,16 @@ class Callback:
         self.params = params
 
     def should_run(self) -> bool:
-        if not
+        if not self.run_on_main_process_only:
             return True
-        model =
-
-            return True
-        return bool(getattr(model, "is_main_process", True))
+        model = self.model
+        return bool(model.is_main_process)
 
 
 class CallbackList:
-    """
+    """
+    Generates a list of callbacks
+    """
 
     def __init__(self, callbacks: Optional[list[Callback]] = None):
         self.callbacks = callbacks or []
@@ -85,7 +85,8 @@ class CallbackList:
             getattr(callback, fn_name)(*args, **kwargs)
 
     def set_model(self, model):
-        self.
+        for callback in self.callbacks:
+            callback.set_model(model)
 
     def set_params(self, params: dict):
         self.call("set_params", params)
@@ -194,9 +195,8 @@ class EarlyStopper(Callback):
             self.wait += 1
             if self.wait >= self.patience:
                 self.stopped_epoch = epoch
-
-
-                if self.verbose > 0:
+                self.model.stop_training = True
+                if self.verbose == 1:
                     logging.info(
                         f"Early stopping triggered at epoch {epoch + 1}. "
                         f"Best {self.monitor}: {self.best_value:.6f} at epoch {self.best_epoch + 1}"
@@ -218,14 +218,15 @@ class EarlyStopper(Callback):
 
 
 class CheckpointSaver(Callback):
-    """
+    """
+    Callback to save model checkpoints during training.
 
     Args:
         save_path: Path to save checkpoints.
         monitor: Metric name to monitor for saving best model.
         mode: One of {'min', 'max'}.
         save_best_only: If True, only save when the model is considered the "best".
-        save_freq: Frequency of checkpoint saving (
+        save_freq: Frequency of checkpoint saving (integer for every N epochs).
         verbose: Verbosity mode.
         run_on_main_process_only: Whether to run this callback only on the main process in DDP.
     """
@@ -237,7 +238,7 @@ class CheckpointSaver(Callback):
         monitor: str = "val_auc",
         mode: str = "max",
         save_best_only: bool = False,
-        save_freq:
+        save_freq: int = 1,
         verbose: int = 1,
         run_on_main_process_only: bool = True,
     ):
@@ -272,7 +273,7 @@ class CheckpointSaver(Callback):
         logs = logs or {}
 
         should_save = False
-        if self.save_freq ==
+        if self.save_freq == 1:
             should_save = True
         elif isinstance(self.save_freq, int) and (epoch + 1) % self.save_freq == 0:
             should_save = True
@@ -306,12 +307,10 @@ class CheckpointSaver(Callback):
 
     def save_checkpoint(self, path: Path, epoch: int, logs: dict):
 
-
-
-
-
-            else self.model
-        )
+        if hasattr(self.model, "ddp_model") and self.model.ddp_model is not None:
+            model_to_save = self.model.ddp_model.module
+        else:
+            model_to_save = self.model
 
         # Save only state_dict to match BaseModel.save_model() format
         torch.save(model_to_save.state_dict(), path)
@@ -328,12 +327,13 @@ class CheckpointSaver(Callback):
         with open(config_path, "wb") as f:
             pickle.dump(features_config, f)
 
-        if self.verbose
+        if self.verbose == 1:
             logging.info(f"Saved checkpoint to {path}")
 
 
 class LearningRateScheduler(Callback):
-    """
+    """
+    Callback for learning rate scheduling.
 
     Args:
         scheduler: Learning rate scheduler instance or name.
@@ -346,73 +346,25 @@ class LearningRateScheduler(Callback):
         self.verbose = verbose
 
     def on_train_begin(self, logs: Optional[dict] = None):
-        if self.scheduler is None
+        if self.scheduler is None:
            self.scheduler = self.model.scheduler_fn
 
     def on_epoch_end(self, epoch: int, logs: Optional[dict] = None):
        if self.scheduler is not None:
-
-            if
-
-
-            # Step the scheduler
-            if hasattr(self.scheduler, "step"):
-                # Some schedulers need metrics
-                if logs is None:
-                    logs = {}
-                if "val_loss" in logs and hasattr(self.scheduler, "mode"):
-                    self.scheduler.step(logs["val_loss"])
-                else:
-                    self.scheduler.step()
+            old_lr = self.model.optimizer_fn.param_groups[0]["lr"]
+            if logs is None:
+                logs = {}
 
-            #
-            if
-
-
-
-                logging.info(
-                    f"Learning rate changed from {old_lr:.6e} to {new_lr:.6e}"
-                )
+            # step for ReduceLROnPlateau
+            if "val_loss" in logs and hasattr(self.scheduler, "mode"):
+                self.scheduler.step(logs["val_loss"])
+            else:
+                self.scheduler.step()
 
-
-
-
-
-
-
-
-    """
-
-    def __init__(self, log_freq: str | int = "epoch", verbose: int = 1):
-        super().__init__()
-        self.run_on_main_process_only = True
-        self.log_freq = log_freq
-        self.verbose = verbose
-
-    def on_epoch_end(self, epoch: int, logs: Optional[dict] = None):
-        if self.verbose > 0 and (
-            self.log_freq == "epoch"
-            or (isinstance(self.log_freq, int) and (epoch + 1) % self.log_freq == 0)
-        ):
-            logs = logs or {}
-            metrics_str = " - ".join(
-                [
-                    f"{k}: {v:.6f}" if isinstance(v, float) else f"{k}: {v}"
-                    for k, v in logs.items()
-                ]
-            )
-            logging.info(f"Epoch {epoch + 1}: {metrics_str}")
-
-    def on_batch_end(self, batch: int, logs: Optional[dict] = None):
-        if self.verbose > 1 and (
-            self.log_freq == "batch"
-            or (isinstance(self.log_freq, int) and (batch + 1) % self.log_freq == 0)
-        ):
-            logs = logs or {}
-            metrics_str = " - ".join(
-                [
-                    f"{k}: {v:.6f}" if isinstance(v, float) else f"{k}: {v}"
-                    for k, v in logs.items()
-                ]
-            )
-            logging.info(f"Batch {batch}: {metrics_str}")
+            # Log new lr
+            if self.verbose == 1:
+                new_lr = self.model.optimizer_fn.param_groups[0]["lr"]
+                if new_lr != old_lr:
+                    logging.info(
+                        f"Learning rate changed from {old_lr:.6e} to {new_lr:.6e}"
+                    )
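
A hedged sketch of how the updated callbacks might be wired up. The CheckpointSaver keyword names come from the constructor and docstring in this diff; the LearningRateScheduler constructor arguments are inferred from the attributes shown above and may differ in the actual source.

from nextrec.basic.callback import CheckpointSaver, LearningRateScheduler

# save_freq is now a plain int (save every N epochs); messages log when verbose == 1
saver = CheckpointSaver(
    save_path="checkpoints/",
    monitor="val_auc",
    mode="max",
    save_best_only=True,
    save_freq=1,
    verbose=1,
)

# With scheduler=None, on_train_begin falls back to model.scheduler_fn;
# schedulers exposing a `mode` attribute (e.g. ReduceLROnPlateau) are stepped
# with logs["val_loss"], all others with a bare step().
lr_callback = LearningRateScheduler(scheduler=None, verbose=1)
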
nextrec/basic/features.py
CHANGED
@@ -1,15 +1,17 @@
 """
-Feature definitions
+Feature definitions for NextRec models.
 
 Date: create on 27/10/2025
-Checkpoint: edit on
+Checkpoint: edit on 27/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
 import torch
 
+from typing import Literal
+
 from nextrec.utils.embedding import get_auto_embedding_dim
-from nextrec.utils.feature import
+from nextrec.utils.feature import to_list
 
 
 class BaseFeature:
@@ -25,12 +27,20 @@ class EmbeddingFeature(BaseFeature):
         name: str,
         vocab_size: int,
         embedding_name: str = "",
-        embedding_dim: int | None =
-        padding_idx: int
-        init_type:
+        embedding_dim: int | None = None,
+        padding_idx: int = 0,
+        init_type: Literal[
+            "normal",
+            "uniform",
+            "xavier_uniform",
+            "xavier_normal",
+            "kaiming_uniform",
+            "kaiming_normal",
+            "orthogonal",
+        ] = "normal",
         init_params: dict | None = None,
         l1_reg: float = 0.0,
-        l2_reg: float =
+        l2_reg: float = 0.0,
         trainable: bool = True,
         pretrained_weight: torch.Tensor | None = None,
         freeze_pretrained: bool = False,
@@ -55,23 +65,57 @@ class EmbeddingFeature(BaseFeature):
 
 
 class SequenceFeature(EmbeddingFeature):
+
     def __init__(
         self,
         name: str,
         vocab_size: int,
-        max_len: int =
+        max_len: int = 50,
         embedding_name: str = "",
-        embedding_dim: int | None =
-        combiner:
-
-
+        embedding_dim: int | None = None,
+        combiner: Literal[
+            "mean",
+            "sum",
+            "concat",
+            "dot_attention",
+            "self_attention",
+        ] = "mean",
+        padding_idx: int = 0,
+        init_type: Literal[
+            "normal",
+            "uniform",
+            "xavier_uniform",
+            "xavier_normal",
+            "kaiming_uniform",
+            "kaiming_normal",
+            "orthogonal",
+        ] = "normal",
         init_params: dict | None = None,
         l1_reg: float = 0.0,
-        l2_reg: float =
+        l2_reg: float = 0.0,
         trainable: bool = True,
         pretrained_weight: torch.Tensor | None = None,
         freeze_pretrained: bool = False,
     ):
+        """
+        Sequence feature for variable-length categorical id sequences.
+
+        Args:
+            name: Feature name used as input key.
+            vocab_size: Number of unique ids in the sequence vocabulary.
+            max_len: Maximum sequence length for padding/truncation.
+            embedding_name: Shared embedding table name. Defaults to ``name``.
+            embedding_dim: Embedding dimension. Set to ``None`` for auto sizing.
+            combiner: Pooling method for sequence embeddings, e.g. ``"mean"`` or ``"sum"``.
+            padding_idx: Index used for padding tokens.
+            init_type: Embedding initializer type.
+            init_params: Initializer parameters.
+            l1_reg: L1 regularization weight on embedding.
+            l2_reg: L2 regularization weight on embedding.
+            trainable: Whether the embedding is trainable. [TODO] This is for representation learning.
+            pretrained_weight: Optional pretrained embedding weights. [TODO] This is for representation learning.
+            freeze_pretrained: If True, keep pretrained weights frozen. [TODO] This is for representation learning.
+        """
         super().__init__(
             name=name,
             vocab_size=vocab_size,
@@ -91,28 +135,105 @@ class SequenceFeature(EmbeddingFeature):
 
 
 class SparseFeature(EmbeddingFeature):
-
+
+    def __init__(
+        self,
+        name: str,
+        vocab_size: int,
+        embedding_name: str = "",
+        embedding_dim: int | None = None,
+        padding_idx: int = 0,
+        init_type: Literal[
+            "normal",
+            "uniform",
+            "xavier_uniform",
+            "xavier_normal",
+            "kaiming_uniform",
+            "kaiming_normal",
+            "orthogonal",
+        ] = "normal",
+        init_params: dict | None = None,
+        l1_reg: float = 0.0,
+        l2_reg: float = 0.0,
+        trainable: bool = True,
+        pretrained_weight: torch.Tensor | None = None,
+        freeze_pretrained: bool = False,
+    ):
+        """
+        Sparse feature for categorical ids.
+
+        Args:
+            name: Feature name used as input key.
+            vocab_size: Number of unique categorical ids.
+            embedding_name: Shared embedding table name. Defaults to ``name``.
+            embedding_dim: Embedding dimension. Set to ``None`` for auto sizing.
+            padding_idx: Index used for padding tokens.
+            init_type: Embedding initializer type.
+            init_params: Initializer parameters.
+            l1_reg: L1 regularization weight on embedding.
+            l2_reg: L2 regularization weight on embedding.
+            trainable: Whether the embedding is trainable.
+            pretrained_weight: Optional pretrained embedding weights.
+            freeze_pretrained: If True, keep pretrained weights frozen.
+        """
+        super().__init__(
+            name=name,
+            vocab_size=vocab_size,
+            embedding_name=embedding_name,
+            embedding_dim=embedding_dim,
+            padding_idx=padding_idx,
+            init_type=init_type,
+            init_params=init_params,
+            l1_reg=l1_reg,
+            l2_reg=l2_reg,
+            trainable=trainable,
+            pretrained_weight=pretrained_weight,
+            freeze_pretrained=freeze_pretrained,
+        )
 
 
 class DenseFeature(BaseFeature):
+
     def __init__(
         self,
         name: str,
-        embedding_dim: int | None = 1,
         input_dim: int = 1,
-
+        proj_dim: int | None = 0,
+        use_projection: bool = False,
+        trainable: bool = True,
+        pretrained_weight: torch.Tensor | None = None,
+        freeze_pretrained: bool = False,
     ):
+        """
+        Dense feature for continuous values.
+
+        Args:
+            name: Feature name used as input key.
+            input_dim: Input dimension for continuous values.
+            proj_dim: Projection dimension. If None or 0, no projection is applied.
+            use_projection: Whether to project inputs to higher dimension.
+            trainable: Whether the projection is trainable.
+            pretrained_weight: Optional pretrained projection weights.
+            freeze_pretrained: If True, keep pretrained weights frozen.
+        """
         self.name = name
-        self.input_dim = max(int(input_dim
-        self.
-        if
+        self.input_dim = max(int(input_dim), 1)
+        self.proj_dim = self.input_dim if proj_dim is None else proj_dim
+        if use_projection and self.proj_dim == 0:
             raise ValueError(
-                "[Features Error] DenseFeature:
+                "[Features Error] DenseFeature: use_projection=True is incompatible with proj_dim=0"
             )
-        if
-            self.
+        if proj_dim is not None and proj_dim > 1:
+            self.use_projection = True
         else:
-            self.
+            self.use_projection = use_projection
+        self.embedding_dim = (
+            self.input_dim if not self.use_projection else self.proj_dim
+        )  # for compatibility
+
+        self.trainable = trainable
+        self.pretrained_weight = pretrained_weight
+        self.freeze_pretrained = freeze_pretrained
 
 
 class FeatureSet:
@@ -123,7 +244,7 @@ class FeatureSet:
         sequence_features: list[SequenceFeature] | None = None,
         target: str | list[str] | None = None,
         id_columns: str | list[str] | None = None,
-    )
+    ):
         self.dense_features = list(dense_features) if dense_features else []
         self.sparse_features = list(sparse_features) if sparse_features else []
         self.sequence_features = list(sequence_features) if sequence_features else []
@@ -132,13 +253,13 @@ class FeatureSet:
             self.dense_features + self.sparse_features + self.sequence_features
         )
         self.feature_names = [feat.name for feat in self.all_features]
-        self.target_columns =
-        self.id_columns =
+        self.target_columns = to_list(target)
+        self.id_columns = to_list(id_columns)
 
     def set_target_id(
         self,
         target: str | list[str] | None = None,
        id_columns: str | list[str] | None = None,
     ) -> None:
-        self.target_columns =
-        self.id_columns =
+        self.target_columns = to_list(target)
+        self.id_columns = to_list(id_columns)
nextrec/basic/heads.py
CHANGED
@@ -2,6 +2,7 @@
 Task head implementations for NextRec models.
 
 Date: create on 23/12/2025
+Checkpoint: edit on 27/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
@@ -26,7 +27,9 @@ class TaskHead(nn.Module):
 
     def __init__(
         self,
-        task_type:
+        task_type: (
+            Literal["binary", "regression"] | list[Literal["binary", "regression"]]
+        ) = "binary",
         task_dims: int | list[int] | None = None,
         use_bias: bool = True,
         return_logits: bool = False,