nextrec 0.4.24__py3-none-any.whl → 0.4.27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/asserts.py +72 -0
  3. nextrec/basic/loggers.py +18 -1
  4. nextrec/basic/model.py +191 -71
  5. nextrec/basic/summary.py +58 -0
  6. nextrec/cli.py +13 -0
  7. nextrec/data/data_processing.py +3 -9
  8. nextrec/data/dataloader.py +25 -2
  9. nextrec/data/preprocessor.py +283 -36
  10. nextrec/models/multi_task/[pre]aitm.py +173 -0
  11. nextrec/models/multi_task/[pre]snr_trans.py +232 -0
  12. nextrec/models/multi_task/[pre]star.py +192 -0
  13. nextrec/models/multi_task/apg.py +330 -0
  14. nextrec/models/multi_task/cross_stitch.py +229 -0
  15. nextrec/models/multi_task/escm.py +290 -0
  16. nextrec/models/multi_task/esmm.py +8 -21
  17. nextrec/models/multi_task/hmoe.py +203 -0
  18. nextrec/models/multi_task/mmoe.py +20 -28
  19. nextrec/models/multi_task/pepnet.py +68 -66
  20. nextrec/models/multi_task/ple.py +30 -44
  21. nextrec/models/multi_task/poso.py +13 -22
  22. nextrec/models/multi_task/share_bottom.py +14 -25
  23. nextrec/models/ranking/afm.py +2 -2
  24. nextrec/models/ranking/autoint.py +2 -4
  25. nextrec/models/ranking/dcn.py +2 -3
  26. nextrec/models/ranking/dcn_v2.py +2 -3
  27. nextrec/models/ranking/deepfm.py +2 -3
  28. nextrec/models/ranking/dien.py +7 -9
  29. nextrec/models/ranking/din.py +8 -10
  30. nextrec/models/ranking/eulernet.py +1 -2
  31. nextrec/models/ranking/ffm.py +1 -2
  32. nextrec/models/ranking/fibinet.py +2 -3
  33. nextrec/models/ranking/fm.py +1 -1
  34. nextrec/models/ranking/lr.py +1 -1
  35. nextrec/models/ranking/masknet.py +1 -2
  36. nextrec/models/ranking/pnn.py +1 -2
  37. nextrec/models/ranking/widedeep.py +2 -3
  38. nextrec/models/ranking/xdeepfm.py +2 -4
  39. nextrec/models/representation/rqvae.py +4 -4
  40. nextrec/models/retrieval/dssm.py +18 -26
  41. nextrec/models/retrieval/dssm_v2.py +15 -22
  42. nextrec/models/retrieval/mind.py +9 -15
  43. nextrec/models/retrieval/sdm.py +36 -33
  44. nextrec/models/retrieval/youtube_dnn.py +16 -24
  45. nextrec/models/sequential/hstu.py +2 -2
  46. nextrec/utils/__init__.py +5 -1
  47. nextrec/utils/config.py +2 -0
  48. nextrec/utils/model.py +16 -77
  49. nextrec/utils/torch_utils.py +11 -0
  50. {nextrec-0.4.24.dist-info → nextrec-0.4.27.dist-info}/METADATA +72 -62
  51. nextrec-0.4.27.dist-info/RECORD +90 -0
  52. nextrec/models/multi_task/aitm.py +0 -0
  53. nextrec/models/multi_task/snr_trans.py +0 -0
  54. nextrec-0.4.24.dist-info/RECORD +0 -86
  55. {nextrec-0.4.24.dist-info → nextrec-0.4.27.dist-info}/WHEEL +0 -0
  56. {nextrec-0.4.24.dist-info → nextrec-0.4.27.dist-info}/entry_points.txt +0 -0
  57. {nextrec-0.4.24.dist-info → nextrec-0.4.27.dist-info}/licenses/LICENSE +0 -0
nextrec/__version__.py CHANGED
@@ -1 +1 @@
1
- __version__ = "0.4.24"
1
+ __version__ = "0.4.27"
@@ -0,0 +1,72 @@
1
+ """
2
+ Assert function definitions for NextRec models.
3
+
4
+ Date: create on 01/01/2026
5
+ Checkpoint: edit on 01/01/2026
6
+ Author: Yang Zhou, zyaztec@gmail.com
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from nextrec.utils.types import TaskTypeName, TrainingModeName
12
+
13
+
14
def assert_task(
    task: list[TaskTypeName] | TaskTypeName | None,
    nums_task: int,
    *,
    model_name: str,
) -> None:
    """Check that ``task`` agrees with the number of tasks the model declares.

    Accepted shapes:
      * a single string, only when ``nums_task == 1``;
      * a length-1 list, only when ``nums_task == 1``;
      * any other list, whose length must equal ``nums_task``.

    List elements are not type-checked here.

    Args:
        task: A task name, a list of task names, or ``None``.
        nums_task: Number of tasks declared by the model.
        model_name: Name used as the prefix of every error message.

    Raises:
        ValueError: ``task`` is ``None``; a string or a length-1 list is given
            while ``nums_task != 1``; or a longer list's length differs from
            ``nums_task``.
        TypeError: ``task`` is neither a string nor a list.
    """
    if task is None:
        raise ValueError(f"{model_name} requires task to be specified.")

    # A bare string can only describe a single task.
    if isinstance(task, str):
        if nums_task == 1:
            return
        raise ValueError(
            f"{model_name} received task='{task}' but nums_task={nums_task}. "
            "String task is only allowed for single-task models."
        )

    if not isinstance(task, list):
        raise TypeError(
            f"{model_name} requires task to be a string or a list of strings."
        )

    given = len(task)

    # A one-element list is treated the same as a bare string.
    if given == 1:
        if nums_task == 1:
            return
        raise ValueError(
            f"{model_name} received task list of length 1 but nums_task={nums_task}. "
            "Length-1 task list is only allowed for single-task models."
        )

    # Multi-task: exactly one task entry per declared task.
    if given != nums_task:
        raise ValueError(
            f"{model_name} requires task length {nums_task}, got {given}."
        )
52
+
53
+
54
def assert_training_mode(
    training_mode: TrainingModeName | list[TrainingModeName],
    nums_task: int,
    *,
    model_name: str,
) -> None:
    """Validate a per-task list of training modes.

    Args:
        training_mode: One mode per task. The annotation admits a bare mode
            string, but this function rejects anything that is not a list.
            NOTE(review): callers presumably broadcast a single mode to a list
            before calling; confirm.
        nums_task: Number of tasks declared by the model; the list length must
            equal it.
        model_name: Name embedded in the ``[<model_name>-init Error]`` prefix.

    Raises:
        TypeError: ``training_mode`` is not a list.
        ValueError: the list length differs from ``nums_task``, or any entry
            is not one of ``pointwise``, ``pairwise`` or ``listwise``.
    """
    # Use a tuple, not a set: the original interpolated a set literal into the
    # error message, so the listed order depended on str hash randomization.
    valid_modes = ("pointwise", "pairwise", "listwise")
    if not isinstance(training_mode, list):
        raise TypeError(
            f"[{model_name}-init Error] training_mode must be a list with length {nums_task}."
        )
    if len(training_mode) != nums_task:
        raise ValueError(
            f"[{model_name}-init Error] training_mode list length must match number of tasks."
        )
    # Collect every offending entry so the message says exactly what was wrong.
    invalid = [mode for mode in training_mode if mode not in valid_modes]
    if invalid:
        raise ValueError(
            f"[{model_name}-init Error] training_mode must be one of "
            f"{valid_modes}, got {invalid!r}."
        )
nextrec/basic/loggers.py CHANGED
@@ -2,7 +2,7 @@
2
2
  NextRec Basic Loggers
3
3
 
4
4
  Date: create on 27/10/2025
5
- Checkpoint: edit on 27/12/2025
5
+ Checkpoint: edit on 01/01/2026
6
6
  Author: Yang Zhou, zyaztec@gmail.com
7
7
  """
8
8
 
@@ -190,6 +190,19 @@ class BasicLogger:
190
190
  def close(self) -> None:
191
191
  for backend in self.backends:
192
192
  backend.close()
193
+ for backend in self.backends:
194
+ if isinstance(backend, SwanLabLogger):
195
+ swanlab = backend.swanlab
196
+ if not backend.enabled or swanlab is None:
197
+ continue
198
+ finish_fn = getattr(swanlab, "finish", None)
199
+ if finish_fn is None:
200
+ continue
201
+ try:
202
+ finish_fn()
203
+ except TypeError:
204
+ finish_fn()
205
+ break
193
206
 
194
207
 
195
208
  class TensorBoardLogger(MetricsLoggerBackend):
@@ -369,10 +382,14 @@ class TrainingLogger(BasicLogger):
369
382
  wandb_kwargs = dict(wandb_kwargs or {})
370
383
  wandb_kwargs.setdefault("config", {})
371
384
  wandb_kwargs["config"].update(config)
385
+ if "notes" in wandb_kwargs:
386
+ wandb_kwargs["config"].pop("note", None)
372
387
 
373
388
  swanlab_kwargs = dict(swanlab_kwargs or {})
374
389
  swanlab_kwargs.setdefault("config", {})
375
390
  swanlab_kwargs["config"].update(config)
391
+ if "description" in swanlab_kwargs:
392
+ swanlab_kwargs["config"].pop("note", None)
376
393
 
377
394
  self.wandb_logger = None
378
395
  if use_wandb: