nextrec 0.4.24__py3-none-any.whl → 0.4.27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/asserts.py +72 -0
- nextrec/basic/loggers.py +18 -1
- nextrec/basic/model.py +191 -71
- nextrec/basic/summary.py +58 -0
- nextrec/cli.py +13 -0
- nextrec/data/data_processing.py +3 -9
- nextrec/data/dataloader.py +25 -2
- nextrec/data/preprocessor.py +283 -36
- nextrec/models/multi_task/[pre]aitm.py +173 -0
- nextrec/models/multi_task/[pre]snr_trans.py +232 -0
- nextrec/models/multi_task/[pre]star.py +192 -0
- nextrec/models/multi_task/apg.py +330 -0
- nextrec/models/multi_task/cross_stitch.py +229 -0
- nextrec/models/multi_task/escm.py +290 -0
- nextrec/models/multi_task/esmm.py +8 -21
- nextrec/models/multi_task/hmoe.py +203 -0
- nextrec/models/multi_task/mmoe.py +20 -28
- nextrec/models/multi_task/pepnet.py +68 -66
- nextrec/models/multi_task/ple.py +30 -44
- nextrec/models/multi_task/poso.py +13 -22
- nextrec/models/multi_task/share_bottom.py +14 -25
- nextrec/models/ranking/afm.py +2 -2
- nextrec/models/ranking/autoint.py +2 -4
- nextrec/models/ranking/dcn.py +2 -3
- nextrec/models/ranking/dcn_v2.py +2 -3
- nextrec/models/ranking/deepfm.py +2 -3
- nextrec/models/ranking/dien.py +7 -9
- nextrec/models/ranking/din.py +8 -10
- nextrec/models/ranking/eulernet.py +1 -2
- nextrec/models/ranking/ffm.py +1 -2
- nextrec/models/ranking/fibinet.py +2 -3
- nextrec/models/ranking/fm.py +1 -1
- nextrec/models/ranking/lr.py +1 -1
- nextrec/models/ranking/masknet.py +1 -2
- nextrec/models/ranking/pnn.py +1 -2
- nextrec/models/ranking/widedeep.py +2 -3
- nextrec/models/ranking/xdeepfm.py +2 -4
- nextrec/models/representation/rqvae.py +4 -4
- nextrec/models/retrieval/dssm.py +18 -26
- nextrec/models/retrieval/dssm_v2.py +15 -22
- nextrec/models/retrieval/mind.py +9 -15
- nextrec/models/retrieval/sdm.py +36 -33
- nextrec/models/retrieval/youtube_dnn.py +16 -24
- nextrec/models/sequential/hstu.py +2 -2
- nextrec/utils/__init__.py +5 -1
- nextrec/utils/config.py +2 -0
- nextrec/utils/model.py +16 -77
- nextrec/utils/torch_utils.py +11 -0
- {nextrec-0.4.24.dist-info → nextrec-0.4.27.dist-info}/METADATA +72 -62
- nextrec-0.4.27.dist-info/RECORD +90 -0
- nextrec/models/multi_task/aitm.py +0 -0
- nextrec/models/multi_task/snr_trans.py +0 -0
- nextrec-0.4.24.dist-info/RECORD +0 -86
- {nextrec-0.4.24.dist-info → nextrec-0.4.27.dist-info}/WHEEL +0 -0
- {nextrec-0.4.24.dist-info → nextrec-0.4.27.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.24.dist-info → nextrec-0.4.27.dist-info}/licenses/LICENSE +0 -0
nextrec/__version__.py
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
__version__ = "0.4.24"
|
|
1
|
+
__version__ = "0.4.27"
|
nextrec/basic/asserts.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Assert function definitions for NextRec models.
|
|
3
|
+
|
|
4
|
+
Date: create on 01/01/2026
|
|
5
|
+
Checkpoint: edit on 01/01/2026
|
|
6
|
+
Author: Yang Zhou, zyaztec@gmail.com
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
from nextrec.utils.types import TaskTypeName, TrainingModeName
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def assert_task(
|
|
15
|
+
task: list[TaskTypeName] | TaskTypeName | None,
|
|
16
|
+
nums_task: int,
|
|
17
|
+
*,
|
|
18
|
+
model_name: str,
|
|
19
|
+
) -> None:
|
|
20
|
+
if task is None:
|
|
21
|
+
raise ValueError(f"{model_name} requires task to be specified.")
|
|
22
|
+
|
|
23
|
+
# case 1: task is str
|
|
24
|
+
if isinstance(task, str):
|
|
25
|
+
if nums_task != 1:
|
|
26
|
+
raise ValueError(
|
|
27
|
+
f"{model_name} received task='{task}' but nums_task={nums_task}. "
|
|
28
|
+
"String task is only allowed for single-task models."
|
|
29
|
+
)
|
|
30
|
+
return # single-task, valid
|
|
31
|
+
|
|
32
|
+
# case 2: task is list
|
|
33
|
+
if not isinstance(task, list):
|
|
34
|
+
raise TypeError(
|
|
35
|
+
f"{model_name} requires task to be a string or a list of strings."
|
|
36
|
+
)
|
|
37
|
+
|
|
38
|
+
# list but length == 1
|
|
39
|
+
if len(task) == 1:
|
|
40
|
+
if nums_task != 1:
|
|
41
|
+
raise ValueError(
|
|
42
|
+
f"{model_name} received task list of length 1 but nums_task={nums_task}. "
|
|
43
|
+
"Length-1 task list is only allowed for single-task models."
|
|
44
|
+
)
|
|
45
|
+
return # single-task, valid
|
|
46
|
+
|
|
47
|
+
# multi-task: length must match nums_task
|
|
48
|
+
if len(task) != nums_task:
|
|
49
|
+
raise ValueError(
|
|
50
|
+
f"{model_name} requires task length {nums_task}, got {len(task)}."
|
|
51
|
+
)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def assert_training_mode(
|
|
55
|
+
training_mode: TrainingModeName | list[TrainingModeName],
|
|
56
|
+
nums_task: int,
|
|
57
|
+
*,
|
|
58
|
+
model_name: str,
|
|
59
|
+
) -> None:
|
|
60
|
+
valid_modes = {"pointwise", "pairwise", "listwise"}
|
|
61
|
+
if not isinstance(training_mode, list):
|
|
62
|
+
raise TypeError(
|
|
63
|
+
f"[{model_name}-init Error] training_mode must be a list with length {nums_task}."
|
|
64
|
+
)
|
|
65
|
+
if len(training_mode) != nums_task:
|
|
66
|
+
raise ValueError(
|
|
67
|
+
f"[{model_name}-init Error] training_mode list length must match number of tasks."
|
|
68
|
+
)
|
|
69
|
+
if any(mode not in valid_modes for mode in training_mode):
|
|
70
|
+
raise ValueError(
|
|
71
|
+
f"[{model_name}-init Error] training_mode must be one of {'pointwise', 'pairwise', 'listwise'}."
|
|
72
|
+
)
|
nextrec/basic/loggers.py
CHANGED
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
NextRec Basic Loggers
|
|
3
3
|
|
|
4
4
|
Date: create on 27/10/2025
|
|
5
|
-
Checkpoint: edit on
|
|
5
|
+
Checkpoint: edit on 01/01/2026
|
|
6
6
|
Author: Yang Zhou, zyaztec@gmail.com
|
|
7
7
|
"""
|
|
8
8
|
|
|
@@ -190,6 +190,19 @@ class BasicLogger:
|
|
|
190
190
|
def close(self) -> None:
|
|
191
191
|
for backend in self.backends:
|
|
192
192
|
backend.close()
|
|
193
|
+
for backend in self.backends:
|
|
194
|
+
if isinstance(backend, SwanLabLogger):
|
|
195
|
+
swanlab = backend.swanlab
|
|
196
|
+
if not backend.enabled or swanlab is None:
|
|
197
|
+
continue
|
|
198
|
+
finish_fn = getattr(swanlab, "finish", None)
|
|
199
|
+
if finish_fn is None:
|
|
200
|
+
continue
|
|
201
|
+
try:
|
|
202
|
+
finish_fn()
|
|
203
|
+
except TypeError:
|
|
204
|
+
finish_fn()
|
|
205
|
+
break
|
|
193
206
|
|
|
194
207
|
|
|
195
208
|
class TensorBoardLogger(MetricsLoggerBackend):
|
|
@@ -369,10 +382,14 @@ class TrainingLogger(BasicLogger):
|
|
|
369
382
|
wandb_kwargs = dict(wandb_kwargs or {})
|
|
370
383
|
wandb_kwargs.setdefault("config", {})
|
|
371
384
|
wandb_kwargs["config"].update(config)
|
|
385
|
+
if "notes" in wandb_kwargs:
|
|
386
|
+
wandb_kwargs["config"].pop("note", None)
|
|
372
387
|
|
|
373
388
|
swanlab_kwargs = dict(swanlab_kwargs or {})
|
|
374
389
|
swanlab_kwargs.setdefault("config", {})
|
|
375
390
|
swanlab_kwargs["config"].update(config)
|
|
391
|
+
if "description" in swanlab_kwargs:
|
|
392
|
+
swanlab_kwargs["config"].pop("note", None)
|
|
376
393
|
|
|
377
394
|
self.wandb_logger = None
|
|
378
395
|
if use_wandb:
|