nextrec 0.3.2__tar.gz → 0.3.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. {nextrec-0.3.2 → nextrec-0.3.4}/PKG-INFO +3 -3
  2. {nextrec-0.3.2 → nextrec-0.3.4}/README.md +2 -2
  3. {nextrec-0.3.2 → nextrec-0.3.4}/README_zh.md +2 -2
  4. {nextrec-0.3.2 → nextrec-0.3.4}/docs/rtd/conf.py +1 -1
  5. nextrec-0.3.4/nextrec/__version__.py +1 -0
  6. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/basic/features.py +10 -23
  7. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/basic/layers.py +18 -61
  8. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/basic/loggers.py +71 -8
  9. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/basic/metrics.py +55 -33
  10. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/basic/model.py +287 -397
  11. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/data/__init__.py +2 -2
  12. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/data/data_utils.py +80 -4
  13. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/data/dataloader.py +38 -59
  14. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/data/preprocessor.py +38 -73
  15. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/models/generative/hstu.py +1 -1
  16. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/models/match/dssm.py +2 -2
  17. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/models/match/dssm_v2.py +2 -2
  18. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/models/match/mind.py +2 -2
  19. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/models/match/sdm.py +2 -2
  20. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/models/match/youtube_dnn.py +2 -2
  21. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/models/multi_task/esmm.py +1 -1
  22. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/models/multi_task/mmoe.py +1 -1
  23. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/models/multi_task/ple.py +1 -1
  24. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/models/multi_task/poso.py +1 -1
  25. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/models/multi_task/share_bottom.py +1 -1
  26. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/models/ranking/afm.py +1 -1
  27. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/models/ranking/autoint.py +1 -1
  28. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/models/ranking/dcn.py +1 -1
  29. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/models/ranking/deepfm.py +1 -1
  30. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/models/ranking/dien.py +1 -1
  31. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/models/ranking/din.py +1 -1
  32. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/models/ranking/fibinet.py +1 -1
  33. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/models/ranking/fm.py +1 -1
  34. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/models/ranking/masknet.py +2 -2
  35. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/models/ranking/pnn.py +1 -1
  36. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/models/ranking/widedeep.py +1 -1
  37. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/models/ranking/xdeepfm.py +1 -1
  38. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/utils/__init__.py +2 -1
  39. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/utils/common.py +21 -2
  40. {nextrec-0.3.2 → nextrec-0.3.4}/pyproject.toml +1 -1
  41. {nextrec-0.3.2 → nextrec-0.3.4}/requirements.txt +2 -1
  42. {nextrec-0.3.2 → nextrec-0.3.4}/test/test_losses.py +1 -1
  43. {nextrec-0.3.2 → nextrec-0.3.4}/test/test_multitask_models.py +1 -1
  44. {nextrec-0.3.2 → nextrec-0.3.4}/tutorials/example_match_dssm.py +1 -1
  45. {nextrec-0.3.2 → nextrec-0.3.4}/tutorials/example_multitask.py +6 -61
  46. {nextrec-0.3.2 → nextrec-0.3.4}/tutorials/example_ranking_din.py +1 -47
  47. {nextrec-0.3.2 → nextrec-0.3.4}/tutorials/movielen_match_dssm.py +3 -2
  48. nextrec-0.3.4/tutorials/run_all_tutorials.py +59 -0
  49. nextrec-0.3.2/nextrec/__version__.py +0 -1
  50. {nextrec-0.3.2 → nextrec-0.3.4}/.github/workflows/publish.yml +0 -0
  51. {nextrec-0.3.2 → nextrec-0.3.4}/.github/workflows/tests.yml +0 -0
  52. {nextrec-0.3.2 → nextrec-0.3.4}/.gitignore +0 -0
  53. {nextrec-0.3.2 → nextrec-0.3.4}/.readthedocs.yaml +0 -0
  54. {nextrec-0.3.2 → nextrec-0.3.4}/CODE_OF_CONDUCT.md +0 -0
  55. {nextrec-0.3.2 → nextrec-0.3.4}/CONTRIBUTING.md +0 -0
  56. {nextrec-0.3.2 → nextrec-0.3.4}/LICENSE +0 -0
  57. {nextrec-0.3.2 → nextrec-0.3.4}/MANIFEST.in +0 -0
  58. {nextrec-0.3.2 → nextrec-0.3.4}/asserts/Feature Configuration.png +0 -0
  59. {nextrec-0.3.2 → nextrec-0.3.4}/asserts/Model Parameters.png +0 -0
  60. {nextrec-0.3.2 → nextrec-0.3.4}/asserts/Training Configuration.png +0 -0
  61. {nextrec-0.3.2 → nextrec-0.3.4}/asserts/Training logs.png +0 -0
  62. {nextrec-0.3.2 → nextrec-0.3.4}/asserts/logo.png +0 -0
  63. {nextrec-0.3.2 → nextrec-0.3.4}/asserts/mmoe_tutorial.png +0 -0
  64. {nextrec-0.3.2 → nextrec-0.3.4}/asserts/nextrec_diagram_en.png +0 -0
  65. {nextrec-0.3.2 → nextrec-0.3.4}/asserts/nextrec_diagram_zh.png +0 -0
  66. {nextrec-0.3.2 → nextrec-0.3.4}/asserts/test data.png +0 -0
  67. {nextrec-0.3.2 → nextrec-0.3.4}/dataset/ctcvr_task.csv +0 -0
  68. {nextrec-0.3.2 → nextrec-0.3.4}/dataset/match_task.csv +0 -0
  69. {nextrec-0.3.2 → nextrec-0.3.4}/dataset/movielens_100k.csv +0 -0
  70. {nextrec-0.3.2 → nextrec-0.3.4}/dataset/multitask_task.csv +0 -0
  71. {nextrec-0.3.2 → nextrec-0.3.4}/dataset/ranking_task.csv +0 -0
  72. {nextrec-0.3.2 → nextrec-0.3.4}/docs/en/Getting started guide.md +0 -0
  73. {nextrec-0.3.2 → nextrec-0.3.4}/docs/rtd/Makefile +0 -0
  74. {nextrec-0.3.2 → nextrec-0.3.4}/docs/rtd/index.md +0 -0
  75. {nextrec-0.3.2 → nextrec-0.3.4}/docs/rtd/make.bat +0 -0
  76. {nextrec-0.3.2 → nextrec-0.3.4}/docs/rtd/modules.rst +0 -0
  77. {nextrec-0.3.2 → nextrec-0.3.4}/docs/rtd/nextrec.basic.rst +0 -0
  78. {nextrec-0.3.2 → nextrec-0.3.4}/docs/rtd/nextrec.data.rst +0 -0
  79. {nextrec-0.3.2 → nextrec-0.3.4}/docs/rtd/nextrec.loss.rst +0 -0
  80. {nextrec-0.3.2 → nextrec-0.3.4}/docs/rtd/nextrec.rst +0 -0
  81. {nextrec-0.3.2 → nextrec-0.3.4}/docs/rtd/nextrec.utils.rst +0 -0
  82. {nextrec-0.3.2 → nextrec-0.3.4}/docs/rtd/requirements.txt +0 -0
  83. {nextrec-0.3.2 → nextrec-0.3.4}/docs/zh/快速上手.md +0 -0
  84. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/__init__.py +0 -0
  85. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/basic/__init__.py +0 -0
  86. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/basic/activation.py +0 -0
  87. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/basic/callback.py +0 -0
  88. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/basic/session.py +0 -0
  89. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/loss/__init__.py +0 -0
  90. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/loss/listwise.py +0 -0
  91. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/loss/loss_utils.py +0 -0
  92. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/loss/pairwise.py +0 -0
  93. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/loss/pointwise.py +0 -0
  94. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/models/generative/__init__.py +0 -0
  95. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/models/generative/tiger.py +0 -0
  96. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/models/match/__init__.py +0 -0
  97. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/models/ranking/__init__.py +0 -0
  98. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/models/ranking/dcn_v2.py +0 -0
  99. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/utils/embedding.py +0 -0
  100. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/utils/initializer.py +0 -0
  101. {nextrec-0.3.2 → nextrec-0.3.4}/nextrec/utils/optimizer.py +0 -0
  102. {nextrec-0.3.2 → nextrec-0.3.4}/pytest.ini +0 -0
  103. {nextrec-0.3.2 → nextrec-0.3.4}/test/__init__.py +0 -0
  104. {nextrec-0.3.2 → nextrec-0.3.4}/test/conftest.py +0 -0
  105. {nextrec-0.3.2 → nextrec-0.3.4}/test/run_tests.py +0 -0
  106. {nextrec-0.3.2 → nextrec-0.3.4}/test/test_layers.py +0 -0
  107. {nextrec-0.3.2 → nextrec-0.3.4}/test/test_match_models.py +0 -0
  108. {nextrec-0.3.2 → nextrec-0.3.4}/test/test_preprocessor.py +0 -0
  109. {nextrec-0.3.2 → nextrec-0.3.4}/test/test_ranking_models.py +0 -0
  110. {nextrec-0.3.2 → nextrec-0.3.4}/test/test_utils.py +0 -0
  111. {nextrec-0.3.2 → nextrec-0.3.4}/test_requirements.txt +0 -0
  112. {nextrec-0.3.2 → nextrec-0.3.4}/tutorials/movielen_ranking_deepfm.py +0 -0
  113. {nextrec-0.3.2 → nextrec-0.3.4}/tutorials/notebooks/en/Hands on dataprocessor.ipynb +0 -0
  114. {nextrec-0.3.2 → nextrec-0.3.4}/tutorials/notebooks/en/Hands on nextrec.ipynb +0 -0
  115. {nextrec-0.3.2 → nextrec-0.3.4}/tutorials/notebooks/zh/Hands on dataprocessor.ipynb +0 -0
  116. {nextrec-0.3.2 → nextrec-0.3.4}/tutorials/notebooks/zh/Hands on nextrec.ipynb +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: nextrec
3
- Version: 0.3.2
3
+ Version: 0.3.4
4
4
  Summary: A comprehensive recommendation library with match, ranking, and multi-task learning models
5
5
  Project-URL: Homepage, https://github.com/zerolovesea/NextRec
6
6
  Project-URL: Repository, https://github.com/zerolovesea/NextRec
@@ -63,7 +63,7 @@ Description-Content-Type: text/markdown
63
63
  ![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)
64
64
  ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
65
65
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
66
- ![Version](https://img.shields.io/badge/Version-0.3.2-orange.svg)
66
+ ![Version](https://img.shields.io/badge/Version-0.3.4-orange.svg)
67
67
 
68
68
  English | [中文文档](README_zh.md)
69
69
 
@@ -110,7 +110,7 @@ To dive deeper, Jupyter notebooks are available:
110
110
  - [Hands on the NextRec framework](/tutorials/notebooks/en/Hands%20on%20nextrec.ipynb)
111
111
  - [Using the data processor for preprocessing](/tutorials/notebooks/en/Hands%20on%20dataprocessor.ipynb)
112
112
 
113
- > Current version [0.3.2]: the matching module is not fully polished yet and may have compatibility issues or unexpected errors. Please raise an issue if you run into problems.
113
+ > Current version [0.3.4]: the matching module is not fully polished yet and may have compatibility issues or unexpected errors. Please raise an issue if you run into problems.
114
114
 
115
115
  ## 5-Minute Quick Start
116
116
 
@@ -7,7 +7,7 @@
7
7
  ![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)
8
8
  ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
9
9
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
10
- ![Version](https://img.shields.io/badge/Version-0.3.2-orange.svg)
10
+ ![Version](https://img.shields.io/badge/Version-0.3.4-orange.svg)
11
11
 
12
12
  English | [中文文档](README_zh.md)
13
13
 
@@ -54,7 +54,7 @@ To dive deeper, Jupyter notebooks are available:
54
54
  - [Hands on the NextRec framework](/tutorials/notebooks/en/Hands%20on%20nextrec.ipynb)
55
55
  - [Using the data processor for preprocessing](/tutorials/notebooks/en/Hands%20on%20dataprocessor.ipynb)
56
56
 
57
- > Current version [0.3.2]: the matching module is not fully polished yet and may have compatibility issues or unexpected errors. Please raise an issue if you run into problems.
57
+ > Current version [0.3.4]: the matching module is not fully polished yet and may have compatibility issues or unexpected errors. Please raise an issue if you run into problems.
58
58
 
59
59
  ## 5-Minute Quick Start
60
60
 
@@ -7,7 +7,7 @@
7
7
  ![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)
8
8
  ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
9
9
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
10
- ![Version](https://img.shields.io/badge/Version-0.3.2-orange.svg)
10
+ ![Version](https://img.shields.io/badge/Version-0.3.4-orange.svg)
11
11
 
12
12
  [English Version](README.md) | 中文文档
13
13
 
@@ -54,7 +54,7 @@ NextRec采用模块化、低耦合的工程设计,使得推荐系统从数据
54
54
  - [如何上手NextRec框架](/tutorials/notebooks/zh/Hands%20on%20nextrec.ipynb)
55
55
  - [如何使用数据处理器进行数据预处理](/tutorials/notebooks/zh/Hands%20on%20dataprocessor.ipynb)
56
56
 
57
- > 当前版本[0.3.2],召回模型模块尚不完善,可能存在一些兼容性问题或意外报错,如果遇到问题,欢迎开发者在Issue区提出问题。
57
+ > 当前版本[0.3.4],召回模型模块尚不完善,可能存在一些兼容性问题或意外报错,如果遇到问题,欢迎开发者在Issue区提出问题。
58
58
 
59
59
  ## 5分钟快速上手
60
60
 
@@ -11,7 +11,7 @@ sys.path.insert(0, str(PROJECT_ROOT / "nextrec"))
11
11
  project = "NextRec"
12
12
  copyright = "2025, Yang Zhou"
13
13
  author = "Yang Zhou"
14
- release = "0.3.2"
14
+ release = "0.3.4"
15
15
 
16
16
  extensions = [
17
17
  "myst_parser",
@@ -0,0 +1 @@
1
+ __version__ = "0.3.4"
@@ -2,19 +2,16 @@
2
2
  Feature definitions
3
3
 
4
4
  Date: create on 27/10/2025
5
- Checkpoint: edit on 29/11/2025
5
+ Checkpoint: edit on 02/12/2025
6
6
  Author: Yang Zhou, zyaztec@gmail.com
7
7
  """
8
8
  import torch
9
9
  from nextrec.utils.embedding import get_auto_embedding_dim
10
+ from nextrec.utils.common import normalize_to_list
10
11
 
11
12
  class BaseFeature(object):
12
13
  def __repr__(self):
13
- params = {
14
- k: v
15
- for k, v in self.__dict__.items()
16
- if not k.startswith("_")
17
- }
14
+ params = {k: v for k, v in self.__dict__.items() if not k.startswith("_") }
18
15
  param_str = ", ".join(f"{k}={v!r}" for k, v in params.items())
19
16
  return f"{self.__class__.__name__}({param_str})"
20
17
 
@@ -93,11 +90,8 @@ class DenseFeature(BaseFeature):
93
90
  else:
94
91
  self.use_embedding = use_embedding # user decides for dim <= 1
95
92
 
96
- class FeatureSpecMixin:
97
- """
98
- Mixin that normalizes dense/sparse/sequence feature lists and target/id columns.
99
- """
100
- def _set_feature_config(
93
+ class FeatureSet:
94
+ def set_all_features(
101
95
  self,
102
96
  dense_features: list[DenseFeature] | None = None,
103
97
  sparse_features: list[SparseFeature] | None = None,
@@ -111,21 +105,14 @@ class FeatureSpecMixin:
111
105
 
112
106
  self.all_features = self.dense_features + self.sparse_features + self.sequence_features
113
107
  self.feature_names = [feat.name for feat in self.all_features]
114
- self.target_columns = self._normalize_to_list(target)
115
- self.id_columns = self._normalize_to_list(id_columns)
108
+ self.target_columns = normalize_to_list(target)
109
+ self.id_columns = normalize_to_list(id_columns)
116
110
 
117
- def _set_target_id_config(
111
+ def set_target_id(
118
112
  self,
119
113
  target: str | list[str] | None = None,
120
114
  id_columns: str | list[str] | None = None,
121
115
  ) -> None:
122
- self.target_columns = self._normalize_to_list(target)
123
- self.id_columns = self._normalize_to_list(id_columns)
116
+ self.target_columns = normalize_to_list(target)
117
+ self.id_columns = normalize_to_list(id_columns)
124
118
 
125
- @staticmethod
126
- def _normalize_to_list(value: str | list[str] | None) -> list[str]:
127
- if value is None:
128
- return []
129
- if isinstance(value, str):
130
- return [value]
131
- return list(value)
@@ -18,23 +18,6 @@ from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
18
18
  from nextrec.utils.initializer import get_initializer
19
19
  from nextrec.basic.activation import activation_layer
20
20
 
21
- __all__ = [
22
- "PredictionLayer",
23
- "EmbeddingLayer",
24
- "InputMask",
25
- "LR",
26
- "ConcatPooling",
27
- "AveragePooling",
28
- "SumPooling",
29
- "MLP",
30
- "FM",
31
- "CrossLayer",
32
- "SENETLayer",
33
- "BiLinearInteractionLayer",
34
- "MultiHeadSelfAttention",
35
- "AttentionPoolingLayer",
36
- ]
37
-
38
21
  class PredictionLayer(nn.Module):
39
22
  def __init__(
40
23
  self,
@@ -44,12 +27,10 @@ class PredictionLayer(nn.Module):
44
27
  return_logits: bool = False,
45
28
  ):
46
29
  super().__init__()
47
- if isinstance(task_type, str):
48
- self.task_types = [task_type]
49
- else:
50
- self.task_types = list(task_type)
30
+ self.task_types = [task_type] if isinstance(task_type, str) else list(task_type)
51
31
  if len(self.task_types) == 0:
52
32
  raise ValueError("At least one task_type must be specified.")
33
+
53
34
  if task_dims is None:
54
35
  dims = [1] * len(self.task_types)
55
36
  elif isinstance(task_dims, int):
@@ -64,7 +45,7 @@ class PredictionLayer(nn.Module):
64
45
  self.total_dim = sum(self.task_dims)
65
46
  self.return_logits = return_logits
66
47
 
67
- # Keep slice offsets per task
48
+ # slice offsets per task
68
49
  start = 0
69
50
  self._task_slices: list[tuple[int, int]] = []
70
51
  for dim in self.task_dims:
@@ -85,27 +66,25 @@ class PredictionLayer(nn.Module):
85
66
  logits = x if self.bias is None else x + self.bias
86
67
  outputs = []
87
68
  for task_type, (start, end) in zip(self.task_types, self._task_slices):
88
- task_logits = logits[..., start:end] # Extract logits for the current task
69
+ task_logits = logits[..., start:end] # logits for the current task
89
70
  if self.return_logits:
90
71
  outputs.append(task_logits)
91
72
  continue
92
- activation = self._get_activation(task_type)
73
+ task = task_type.lower()
74
+ if task == 'binary':
75
+ activation = torch.sigmoid
76
+ elif task == 'regression':
77
+ activation = lambda x: x
78
+ elif task == 'multiclass':
79
+ activation = lambda x: torch.softmax(x, dim=-1)
80
+ else:
81
+ raise ValueError(f"[PredictionLayer Error]: Unsupported task_type '{task_type}'.")
93
82
  outputs.append(activation(task_logits))
94
83
  result = torch.cat(outputs, dim=-1)
95
84
  if result.shape[-1] == 1:
96
85
  result = result.squeeze(-1)
97
86
  return result
98
87
 
99
- def _get_activation(self, task_type: str):
100
- task = task_type.lower()
101
- if task == 'binary':
102
- return torch.sigmoid
103
- if task == 'regression':
104
- return lambda x: x
105
- if task == 'multiclass':
106
- return lambda x: torch.softmax(x, dim=-1)
107
- raise ValueError(f"[PredictionLayer Error]: Unsupported task_type '{task_type}'.")
108
-
109
88
  class EmbeddingLayer(nn.Module):
110
89
  def __init__(self, features: list):
111
90
  super().__init__()
@@ -145,7 +124,7 @@ class EmbeddingLayer(nn.Module):
145
124
  self.dense_input_dims[feature.name] = in_dim
146
125
  else:
147
126
  raise TypeError(f"[EmbeddingLayer Error]: Unsupported feature type: {type(feature)}")
148
- self.output_dim = self._compute_output_dim()
127
+ self.output_dim = self.compute_output_dim()
149
128
 
150
129
  def forward(
151
130
  self,
@@ -181,7 +160,7 @@ class EmbeddingLayer(nn.Module):
181
160
  sparse_embeds.append(pooling_layer(seq_emb, feature_mask).unsqueeze(1))
182
161
 
183
162
  elif isinstance(feature, DenseFeature):
184
- dense_embeds.append(self._project_dense(feature, x))
163
+ dense_embeds.append(self.project_dense(feature, x))
185
164
 
186
165
  if squeeze_dim:
187
166
  flattened_sparse = [emb.flatten(start_dim=1) for emb in sparse_embeds]
@@ -212,7 +191,7 @@ class EmbeddingLayer(nn.Module):
212
191
  raise ValueError("[EmbeddingLayer Error]: squeeze_dim=False requires at least one sparse/sequence feature or dense features with identical projected dimensions.")
213
192
  return torch.cat(output_embeddings, dim=1)
214
193
 
215
- def _project_dense(self, feature: DenseFeature, x: dict[str, torch.Tensor]) -> torch.Tensor:
194
+ def project_dense(self, feature: DenseFeature, x: dict[str, torch.Tensor]) -> torch.Tensor:
216
195
  if feature.name not in x:
217
196
  raise KeyError(f"[EmbeddingLayer Error]:Dense feature '{feature.name}' is missing from input.")
218
197
  value = x[feature.name].float()
@@ -228,11 +207,7 @@ class EmbeddingLayer(nn.Module):
228
207
  dense_layer = self.dense_transforms[feature.name]
229
208
  return dense_layer(value)
230
209
 
231
- def _compute_output_dim(self, features: list[DenseFeature | SequenceFeature | SparseFeature] | None = None) -> int:
232
- """
233
- Compute flattened embedding dimension for provided features or all tracked features.
234
- Deduplicates by feature name to avoid double-counting shared embeddings.
235
- """
210
+ def compute_output_dim(self, features: list[DenseFeature | SequenceFeature | SparseFeature] | None = None) -> int:
236
211
  candidates = list(features) if features is not None else self.features
237
212
  unique_feats = OrderedDict((feat.name, feat) for feat in candidates) # type: ignore[assignment]
238
213
  dim = 0
@@ -249,14 +224,13 @@ class EmbeddingLayer(nn.Module):
249
224
  return dim
250
225
 
251
226
  def get_input_dim(self, features: list[object] | None = None) -> int:
252
- return self._compute_output_dim(features) # type: ignore[assignment]
227
+ return self.compute_output_dim(features) # type: ignore[assignment]
253
228
 
254
229
  @property
255
230
  def input_dim(self) -> int:
256
231
  return self.output_dim
257
232
 
258
233
  class InputMask(nn.Module):
259
- """Utility module to build sequence masks for pooling layers."""
260
234
  def __init__(self):
261
235
  super().__init__()
262
236
 
@@ -271,7 +245,6 @@ class InputMask(nn.Module):
271
245
  return mask.unsqueeze(1).float()
272
246
 
273
247
  class LR(nn.Module):
274
- """Wide component from Wide&Deep (Cheng et al., 2016)."""
275
248
  def __init__(
276
249
  self,
277
250
  input_dim: int,
@@ -287,7 +260,6 @@ class LR(nn.Module):
287
260
  return self.fc(x)
288
261
 
289
262
  class ConcatPooling(nn.Module):
290
- """Concatenates sequence embeddings along the temporal dimension."""
291
263
  def __init__(self):
292
264
  super().__init__()
293
265
 
@@ -295,7 +267,6 @@ class ConcatPooling(nn.Module):
295
267
  return x.flatten(start_dim=1, end_dim=2)
296
268
 
297
269
  class AveragePooling(nn.Module):
298
- """Mean pooling with optional padding mask."""
299
270
  def __init__(self):
300
271
  super().__init__()
301
272
 
@@ -308,7 +279,6 @@ class AveragePooling(nn.Module):
308
279
  return sum_pooling_matrix / (non_padding_length.float() + 1e-16)
309
280
 
310
281
  class SumPooling(nn.Module):
311
- """Sum pooling with optional padding mask."""
312
282
  def __init__(self):
313
283
  super().__init__()
314
284
 
@@ -319,7 +289,6 @@ class SumPooling(nn.Module):
319
289
  return torch.bmm(mask, x).squeeze(1)
320
290
 
321
291
  class MLP(nn.Module):
322
- """Stacked fully connected layers used in the deep component."""
323
292
  def __init__(
324
293
  self,
325
294
  input_dim: int,
@@ -345,7 +314,6 @@ class MLP(nn.Module):
345
314
  return self.mlp(x)
346
315
 
347
316
  class FM(nn.Module):
348
- """Factorization Machine (Rendle, 2010) second-order interaction term."""
349
317
  def __init__(self, reduce_sum: bool = True):
350
318
  super().__init__()
351
319
  self.reduce_sum = reduce_sum
@@ -359,7 +327,6 @@ class FM(nn.Module):
359
327
  return 0.5 * ix
360
328
 
361
329
  class CrossLayer(nn.Module):
362
- """Single cross layer used in DCN (Wang et al., 2017)."""
363
330
  def __init__(self, input_dim: int):
364
331
  super(CrossLayer, self).__init__()
365
332
  self.w = torch.nn.Linear(input_dim, 1, bias=False)
@@ -370,7 +337,6 @@ class CrossLayer(nn.Module):
370
337
  return x
371
338
 
372
339
  class SENETLayer(nn.Module):
373
- """Squeeze-and-Excitation block adopted by FiBiNET (Huang et al., 2019)."""
374
340
  def __init__(
375
341
  self,
376
342
  num_fields: int,
@@ -388,7 +354,6 @@ class SENETLayer(nn.Module):
388
354
  return v
389
355
 
390
356
  class BiLinearInteractionLayer(nn.Module):
391
- """Bilinear feature interaction from FiBiNET (Huang et al., 2019)."""
392
357
  def __init__(
393
358
  self,
394
359
  input_dim: int,
@@ -416,7 +381,6 @@ class BiLinearInteractionLayer(nn.Module):
416
381
  return torch.cat(bilinear_list, dim=1)
417
382
 
418
383
  class MultiHeadSelfAttention(nn.Module):
419
- """Multi-head self-attention layer from AutoInt (Song et al., 2019)."""
420
384
  def __init__(
421
385
  self,
422
386
  embedding_dim: int,
@@ -438,13 +402,6 @@ class MultiHeadSelfAttention(nn.Module):
438
402
  self.dropout = nn.Dropout(dropout)
439
403
 
440
404
  def forward(self, x: torch.Tensor) -> torch.Tensor:
441
- """
442
- Args:
443
- x (torch.Tensor): Tensor of shape (batch_size, num_fields, embedding_dim)
444
-
445
- Returns:
446
- torch.Tensor: Output tensor of shape (batch_size, num_fields, embedding_dim)
447
- """
448
405
  batch_size, num_fields, _ = x.shape
449
406
  Q = self.W_Q(x) # [batch_size, num_fields, embedding_dim]
450
407
  K = self.W_K(x)
@@ -2,17 +2,19 @@
2
2
  NextRec Basic Loggers
3
3
 
4
4
  Date: create on 27/10/2025
5
- Checkpoint: edit on 29/11/2025
5
+ Checkpoint: edit on 03/12/2025
6
6
  Author: Yang Zhou, zyaztec@gmail.com
7
7
  """
8
8
 
9
-
10
9
  import os
11
10
  import re
12
11
  import sys
12
+ import json
13
13
  import copy
14
14
  import logging
15
- from nextrec.basic.session import create_session
15
+ import numbers
16
+ from typing import Mapping, Any
17
+ from nextrec.basic.session import create_session, Session
16
18
 
17
19
  ANSI_CODES = {
18
20
  'black': '\033[30m',
@@ -77,17 +79,12 @@ def colorize(text: str, color: str | None = None, bold: bool = False) -> str:
77
79
  """Apply ANSI color and bold formatting to the given text."""
78
80
  if not color and not bold:
79
81
  return text
80
-
81
82
  result = ""
82
-
83
83
  if bold:
84
84
  result += ANSI_BOLD
85
-
86
85
  if color and color in ANSI_CODES:
87
86
  result += ANSI_CODES[color]
88
-
89
87
  result += text + ANSI_RESET
90
-
91
88
  return result
92
89
 
93
90
  def setup_logger(session_id: str | os.PathLike | None = None):
@@ -126,3 +123,69 @@ def setup_logger(session_id: str | os.PathLike | None = None):
126
123
  logger.addHandler(console_handler)
127
124
 
128
125
  return logger
126
+
127
+ class TrainingLogger:
128
+ def __init__(
129
+ self,
130
+ session: Session,
131
+ enable_tensorboard: bool,
132
+ log_name: str = "training_metrics.jsonl",
133
+ ) -> None:
134
+ self.session = session
135
+ self.enable_tensorboard = enable_tensorboard
136
+ self.log_path = session.metrics_dir / log_name
137
+ self.log_path.parent.mkdir(parents=True, exist_ok=True)
138
+
139
+ self.tb_writer = None
140
+ self.tb_dir = None
141
+
142
+ if self.enable_tensorboard:
143
+ self._init_tensorboard()
144
+
145
+ def _init_tensorboard(self) -> None:
146
+ try:
147
+ from torch.utils.tensorboard import SummaryWriter # type: ignore
148
+ except ImportError:
149
+ logging.warning("[TrainingLogger] tensorboard not installed, disable tensorboard logging.")
150
+ self.enable_tensorboard = False
151
+ return
152
+ tb_dir = self.session.logs_dir / "tensorboard"
153
+ tb_dir.mkdir(parents=True, exist_ok=True)
154
+ self.tb_dir = tb_dir
155
+ self.tb_writer = SummaryWriter(log_dir=str(tb_dir))
156
+
157
+ @property
158
+ def tensorboard_logdir(self):
159
+ return self.tb_dir
160
+
161
+ def format_metrics(self, metrics: Mapping[str, Any], split: str) -> dict[str, float]:
162
+ formatted: dict[str, float] = {}
163
+ for key, value in metrics.items():
164
+ if isinstance(value, numbers.Number):
165
+ formatted[f"{split}/{key}"] = float(value)
166
+ elif hasattr(value, "item"):
167
+ try:
168
+ formatted[f"{split}/{key}"] = float(value.item())
169
+ except Exception:
170
+ continue
171
+ return formatted
172
+
173
+ def log_metrics(self, metrics: Mapping[str, Any], step: int, split: str = "train") -> None:
174
+ payload = self.format_metrics(metrics, split)
175
+ payload["step"] = int(step)
176
+ with self.log_path.open("a", encoding="utf-8") as f:
177
+ f.write(json.dumps(payload, ensure_ascii=False) + "\n")
178
+
179
+ if not self.tb_writer:
180
+ return
181
+ step = int(payload.get("step", 0))
182
+ for key, value in payload.items():
183
+ if key == "step":
184
+ continue
185
+ self.tb_writer.add_scalar(key, value, global_step=step)
186
+
187
+ def close(self) -> None:
188
+ if self.tb_writer:
189
+ self.tb_writer.flush()
190
+ self.tb_writer.close()
191
+ self.tb_writer = None