nextrec 0.3.3__tar.gz → 0.3.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (125)
  1. {nextrec-0.3.3 → nextrec-0.3.5}/PKG-INFO +3 -3
  2. {nextrec-0.3.3 → nextrec-0.3.5}/README.md +2 -2
  3. {nextrec-0.3.3 → nextrec-0.3.5}/README_zh.md +2 -2
  4. {nextrec-0.3.3 → nextrec-0.3.5}/docs/rtd/conf.py +1 -1
  5. nextrec-0.3.5/nextrec/__version__.py +1 -0
  6. {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/basic/features.py +1 -1
  7. {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/basic/loggers.py +71 -8
  8. {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/basic/model.py +45 -11
  9. {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/basic/session.py +3 -10
  10. {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/data/__init__.py +47 -9
  11. nextrec-0.3.5/nextrec/data/batch_utils.py +80 -0
  12. nextrec-0.3.5/nextrec/data/data_processing.py +152 -0
  13. nextrec-0.3.5/nextrec/data/data_utils.py +35 -0
  14. {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/data/dataloader.py +6 -4
  15. {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/data/preprocessor.py +39 -85
  16. {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/multi_task/poso.py +1 -1
  17. nextrec-0.3.5/nextrec/utils/__init__.py +68 -0
  18. nextrec-0.3.5/nextrec/utils/device.py +37 -0
  19. nextrec-0.3.5/nextrec/utils/feature.py +13 -0
  20. nextrec-0.3.5/nextrec/utils/file.py +70 -0
  21. {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/utils/initializer.py +0 -8
  22. nextrec-0.3.5/nextrec/utils/model.py +22 -0
  23. {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/utils/optimizer.py +0 -19
  24. nextrec-0.3.5/nextrec/utils/tensor.py +61 -0
  25. {nextrec-0.3.3 → nextrec-0.3.5}/pyproject.toml +1 -1
  26. {nextrec-0.3.3 → nextrec-0.3.5}/requirements.txt +2 -1
  27. {nextrec-0.3.3 → nextrec-0.3.5}/tutorials/example_match_dssm.py +1 -1
  28. {nextrec-0.3.3 → nextrec-0.3.5}/tutorials/example_multitask.py +5 -60
  29. {nextrec-0.3.3 → nextrec-0.3.5}/tutorials/example_ranking_din.py +0 -42
  30. nextrec-0.3.3/nextrec/__version__.py +0 -1
  31. nextrec-0.3.3/nextrec/data/data_utils.py +0 -268
  32. nextrec-0.3.3/nextrec/utils/__init__.py +0 -18
  33. nextrec-0.3.3/nextrec/utils/common.py +0 -60
  34. {nextrec-0.3.3 → nextrec-0.3.5}/.github/workflows/publish.yml +0 -0
  35. {nextrec-0.3.3 → nextrec-0.3.5}/.github/workflows/tests.yml +0 -0
  36. {nextrec-0.3.3 → nextrec-0.3.5}/.gitignore +0 -0
  37. {nextrec-0.3.3 → nextrec-0.3.5}/.readthedocs.yaml +0 -0
  38. {nextrec-0.3.3 → nextrec-0.3.5}/CODE_OF_CONDUCT.md +0 -0
  39. {nextrec-0.3.3 → nextrec-0.3.5}/CONTRIBUTING.md +0 -0
  40. {nextrec-0.3.3 → nextrec-0.3.5}/LICENSE +0 -0
  41. {nextrec-0.3.3 → nextrec-0.3.5}/MANIFEST.in +0 -0
  42. {nextrec-0.3.3 → nextrec-0.3.5}/asserts/Feature Configuration.png +0 -0
  43. {nextrec-0.3.3 → nextrec-0.3.5}/asserts/Model Parameters.png +0 -0
  44. {nextrec-0.3.3 → nextrec-0.3.5}/asserts/Training Configuration.png +0 -0
  45. {nextrec-0.3.3 → nextrec-0.3.5}/asserts/Training logs.png +0 -0
  46. {nextrec-0.3.3 → nextrec-0.3.5}/asserts/logo.png +0 -0
  47. {nextrec-0.3.3 → nextrec-0.3.5}/asserts/mmoe_tutorial.png +0 -0
  48. {nextrec-0.3.3 → nextrec-0.3.5}/asserts/nextrec_diagram_en.png +0 -0
  49. {nextrec-0.3.3 → nextrec-0.3.5}/asserts/nextrec_diagram_zh.png +0 -0
  50. {nextrec-0.3.3 → nextrec-0.3.5}/asserts/test data.png +0 -0
  51. {nextrec-0.3.3 → nextrec-0.3.5}/dataset/ctcvr_task.csv +0 -0
  52. {nextrec-0.3.3 → nextrec-0.3.5}/dataset/match_task.csv +0 -0
  53. {nextrec-0.3.3 → nextrec-0.3.5}/dataset/movielens_100k.csv +0 -0
  54. {nextrec-0.3.3 → nextrec-0.3.5}/dataset/multitask_task.csv +0 -0
  55. {nextrec-0.3.3 → nextrec-0.3.5}/dataset/ranking_task.csv +0 -0
  56. {nextrec-0.3.3 → nextrec-0.3.5}/docs/en/Getting started guide.md +0 -0
  57. {nextrec-0.3.3 → nextrec-0.3.5}/docs/rtd/Makefile +0 -0
  58. {nextrec-0.3.3 → nextrec-0.3.5}/docs/rtd/index.md +0 -0
  59. {nextrec-0.3.3 → nextrec-0.3.5}/docs/rtd/make.bat +0 -0
  60. {nextrec-0.3.3 → nextrec-0.3.5}/docs/rtd/modules.rst +0 -0
  61. {nextrec-0.3.3 → nextrec-0.3.5}/docs/rtd/nextrec.basic.rst +0 -0
  62. {nextrec-0.3.3 → nextrec-0.3.5}/docs/rtd/nextrec.data.rst +0 -0
  63. {nextrec-0.3.3 → nextrec-0.3.5}/docs/rtd/nextrec.loss.rst +0 -0
  64. {nextrec-0.3.3 → nextrec-0.3.5}/docs/rtd/nextrec.rst +0 -0
  65. {nextrec-0.3.3 → nextrec-0.3.5}/docs/rtd/nextrec.utils.rst +0 -0
  66. {nextrec-0.3.3 → nextrec-0.3.5}/docs/rtd/requirements.txt +0 -0
  67. {nextrec-0.3.3 → nextrec-0.3.5}/docs/zh/快速上手.md +0 -0
  68. {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/__init__.py +0 -0
  69. {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/basic/__init__.py +0 -0
  70. {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/basic/activation.py +0 -0
  71. {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/basic/callback.py +0 -0
  72. {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/basic/layers.py +0 -0
  73. {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/basic/metrics.py +0 -0
  74. {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/loss/__init__.py +0 -0
  75. {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/loss/listwise.py +0 -0
  76. {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/loss/loss_utils.py +0 -0
  77. {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/loss/pairwise.py +0 -0
  78. {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/loss/pointwise.py +0 -0
  79. {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/generative/__init__.py +0 -0
  80. {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/generative/hstu.py +0 -0
  81. {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/generative/tiger.py +0 -0
  82. {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/match/__init__.py +0 -0
  83. {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/match/dssm.py +0 -0
  84. {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/match/dssm_v2.py +0 -0
  85. {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/match/mind.py +0 -0
  86. {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/match/sdm.py +0 -0
  87. {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/match/youtube_dnn.py +0 -0
  88. {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/multi_task/esmm.py +0 -0
  89. {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/multi_task/mmoe.py +0 -0
  90. {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/multi_task/ple.py +0 -0
  91. {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/multi_task/share_bottom.py +0 -0
  92. {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/ranking/__init__.py +0 -0
  93. {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/ranking/afm.py +0 -0
  94. {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/ranking/autoint.py +0 -0
  95. {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/ranking/dcn.py +0 -0
  96. {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/ranking/dcn_v2.py +0 -0
  97. {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/ranking/deepfm.py +0 -0
  98. {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/ranking/dien.py +0 -0
  99. {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/ranking/din.py +0 -0
  100. {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/ranking/fibinet.py +0 -0
  101. {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/ranking/fm.py +0 -0
  102. {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/ranking/masknet.py +0 -0
  103. {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/ranking/pnn.py +0 -0
  104. {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/ranking/widedeep.py +0 -0
  105. {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/ranking/xdeepfm.py +0 -0
  106. {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/utils/embedding.py +0 -0
  107. {nextrec-0.3.3 → nextrec-0.3.5}/pytest.ini +0 -0
  108. {nextrec-0.3.3 → nextrec-0.3.5}/test/__init__.py +0 -0
  109. {nextrec-0.3.3 → nextrec-0.3.5}/test/conftest.py +0 -0
  110. {nextrec-0.3.3 → nextrec-0.3.5}/test/run_tests.py +0 -0
  111. {nextrec-0.3.3 → nextrec-0.3.5}/test/test_layers.py +0 -0
  112. {nextrec-0.3.3 → nextrec-0.3.5}/test/test_losses.py +0 -0
  113. {nextrec-0.3.3 → nextrec-0.3.5}/test/test_match_models.py +0 -0
  114. {nextrec-0.3.3 → nextrec-0.3.5}/test/test_multitask_models.py +0 -0
  115. {nextrec-0.3.3 → nextrec-0.3.5}/test/test_preprocessor.py +0 -0
  116. {nextrec-0.3.3 → nextrec-0.3.5}/test/test_ranking_models.py +0 -0
  117. {nextrec-0.3.3 → nextrec-0.3.5}/test/test_utils.py +0 -0
  118. {nextrec-0.3.3 → nextrec-0.3.5}/test_requirements.txt +0 -0
  119. {nextrec-0.3.3 → nextrec-0.3.5}/tutorials/movielen_match_dssm.py +0 -0
  120. {nextrec-0.3.3 → nextrec-0.3.5}/tutorials/movielen_ranking_deepfm.py +0 -0
  121. {nextrec-0.3.3 → nextrec-0.3.5}/tutorials/notebooks/en/Hands on dataprocessor.ipynb +0 -0
  122. {nextrec-0.3.3 → nextrec-0.3.5}/tutorials/notebooks/en/Hands on nextrec.ipynb +0 -0
  123. {nextrec-0.3.3 → nextrec-0.3.5}/tutorials/notebooks/zh/Hands on dataprocessor.ipynb +0 -0
  124. {nextrec-0.3.3 → nextrec-0.3.5}/tutorials/notebooks/zh/Hands on nextrec.ipynb +0 -0
  125. {nextrec-0.3.3 → nextrec-0.3.5}/tutorials/run_all_tutorials.py +0 -0
{nextrec-0.3.3 → nextrec-0.3.5}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: nextrec
- Version: 0.3.3
+ Version: 0.3.5
  Summary: A comprehensive recommendation library with match, ranking, and multi-task learning models
  Project-URL: Homepage, https://github.com/zerolovesea/NextRec
  Project-URL: Repository, https://github.com/zerolovesea/NextRec
@@ -63,7 +63,7 @@ Description-Content-Type: text/markdown
  ![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)
  ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
- ![Version](https://img.shields.io/badge/Version-0.3.3-orange.svg)
+ ![Version](https://img.shields.io/badge/Version-0.3.5-orange.svg)

  English | [中文文档](README_zh.md)

@@ -110,7 +110,7 @@ To dive deeper, Jupyter notebooks are available:
  - [Hands on the NextRec framework](/tutorials/notebooks/en/Hands%20on%20nextrec.ipynb)
  - [Using the data processor for preprocessing](/tutorials/notebooks/en/Hands%20on%20dataprocessor.ipynb)

- > Current version [0.3.3]: the matching module is not fully polished yet and may have compatibility issues or unexpected errors. Please raise an issue if you run into problems.
+ > Current version [0.3.5]: the matching module is not fully polished yet and may have compatibility issues or unexpected errors. Please raise an issue if you run into problems.

  ## 5-Minute Quick Start

{nextrec-0.3.3 → nextrec-0.3.5}/README.md
@@ -7,7 +7,7 @@
  ![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)
  ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
- ![Version](https://img.shields.io/badge/Version-0.3.3-orange.svg)
+ ![Version](https://img.shields.io/badge/Version-0.3.5-orange.svg)

  English | [中文文档](README_zh.md)

@@ -54,7 +54,7 @@ To dive deeper, Jupyter notebooks are available:
  - [Hands on the NextRec framework](/tutorials/notebooks/en/Hands%20on%20nextrec.ipynb)
  - [Using the data processor for preprocessing](/tutorials/notebooks/en/Hands%20on%20dataprocessor.ipynb)

- > Current version [0.3.3]: the matching module is not fully polished yet and may have compatibility issues or unexpected errors. Please raise an issue if you run into problems.
+ > Current version [0.3.5]: the matching module is not fully polished yet and may have compatibility issues or unexpected errors. Please raise an issue if you run into problems.

  ## 5-Minute Quick Start

{nextrec-0.3.3 → nextrec-0.3.5}/README_zh.md
@@ -7,7 +7,7 @@
  ![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)
  ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
- ![Version](https://img.shields.io/badge/Version-0.3.3-orange.svg)
+ ![Version](https://img.shields.io/badge/Version-0.3.5-orange.svg)

  [English Version](README.md) | 中文文档

@@ -54,7 +54,7 @@ NextRec采用模块化、低耦合的工程设计,使得推荐系统从数据
  - [如何上手NextRec框架](/tutorials/notebooks/zh/Hands%20on%20nextrec.ipynb)
  - [如何使用数据处理器进行数据预处理](/tutorials/notebooks/zh/Hands%20on%20dataprocessor.ipynb)

- > 当前版本[0.3.3],召回模型模块尚不完善,可能存在一些兼容性问题或意外报错,如果遇到问题,欢迎开发者在Issue区提出问题。
+ > 当前版本[0.3.5],召回模型模块尚不完善,可能存在一些兼容性问题或意外报错,如果遇到问题,欢迎开发者在Issue区提出问题。

  ## 5分钟快速上手

{nextrec-0.3.3 → nextrec-0.3.5}/docs/rtd/conf.py
@@ -11,7 +11,7 @@ sys.path.insert(0, str(PROJECT_ROOT / "nextrec"))
  project = "NextRec"
  copyright = "2025, Yang Zhou"
  author = "Yang Zhou"
- release = "0.3.3"
+ release = "0.3.5"

  extensions = [
      "myst_parser",
nextrec-0.3.5/nextrec/__version__.py (new file)
@@ -0,0 +1 @@
+ __version__ = "0.3.5"
{nextrec-0.3.3 → nextrec-0.3.5}/nextrec/basic/features.py
@@ -7,7 +7,7 @@ Author: Yang Zhou, zyaztec@gmail.com
  """
  import torch
  from nextrec.utils.embedding import get_auto_embedding_dim
- from nextrec.utils.common import normalize_to_list
+ from nextrec.utils.feature import normalize_to_list

  class BaseFeature(object):
      def __repr__(self):
{nextrec-0.3.3 → nextrec-0.3.5}/nextrec/basic/loggers.py
@@ -2,17 +2,19 @@
  NextRec Basic Loggers

  Date: create on 27/10/2025
- Checkpoint: edit on 29/11/2025
+ Checkpoint: edit on 03/12/2025
  Author: Yang Zhou, zyaztec@gmail.com
  """

-
  import os
  import re
  import sys
+ import json
  import copy
  import logging
- from nextrec.basic.session import create_session
+ import numbers
+ from typing import Mapping, Any
+ from nextrec.basic.session import create_session, Session

  ANSI_CODES = {
      'black': '\033[30m',
@@ -77,17 +79,12 @@ def colorize(text: str, color: str | None = None, bold: bool = False) -> str:
      """Apply ANSI color and bold formatting to the given text."""
      if not color and not bold:
          return text
-
      result = ""
-
      if bold:
          result += ANSI_BOLD
-
      if color and color in ANSI_CODES:
          result += ANSI_CODES[color]
-
      result += text + ANSI_RESET
-
      return result

  def setup_logger(session_id: str | os.PathLike | None = None):
@@ -126,3 +123,69 @@ def setup_logger(session_id: str | os.PathLike | None = None):
      logger.addHandler(console_handler)

      return logger
+
+ class TrainingLogger:
+     def __init__(
+         self,
+         session: Session,
+         enable_tensorboard: bool,
+         log_name: str = "training_metrics.jsonl",
+     ) -> None:
+         self.session = session
+         self.enable_tensorboard = enable_tensorboard
+         self.log_path = session.metrics_dir / log_name
+         self.log_path.parent.mkdir(parents=True, exist_ok=True)
+
+         self.tb_writer = None
+         self.tb_dir = None
+
+         if self.enable_tensorboard:
+             self._init_tensorboard()
+
+     def _init_tensorboard(self) -> None:
+         try:
+             from torch.utils.tensorboard import SummaryWriter  # type: ignore
+         except ImportError:
+             logging.warning("[TrainingLogger] tensorboard not installed, disable tensorboard logging.")
+             self.enable_tensorboard = False
+             return
+         tb_dir = self.session.logs_dir / "tensorboard"
+         tb_dir.mkdir(parents=True, exist_ok=True)
+         self.tb_dir = tb_dir
+         self.tb_writer = SummaryWriter(log_dir=str(tb_dir))
+
+     @property
+     def tensorboard_logdir(self):
+         return self.tb_dir
+
+     def format_metrics(self, metrics: Mapping[str, Any], split: str) -> dict[str, float]:
+         formatted: dict[str, float] = {}
+         for key, value in metrics.items():
+             if isinstance(value, numbers.Number):
+                 formatted[f"{split}/{key}"] = float(value)
+             elif hasattr(value, "item"):
+                 try:
+                     formatted[f"{split}/{key}"] = float(value.item())
+                 except Exception:
+                     continue
+         return formatted
+
+     def log_metrics(self, metrics: Mapping[str, Any], step: int, split: str = "train") -> None:
+         payload = self.format_metrics(metrics, split)
+         payload["step"] = int(step)
+         with self.log_path.open("a", encoding="utf-8") as f:
+             f.write(json.dumps(payload, ensure_ascii=False) + "\n")
+
+         if not self.tb_writer:
+             return
+         step = int(payload.get("step", 0))
+         for key, value in payload.items():
+             if key == "step":
+                 continue
+             self.tb_writer.add_scalar(key, value, global_step=step)
+
+     def close(self) -> None:
+         if self.tb_writer:
+             self.tb_writer.flush()
+             self.tb_writer.close()
+             self.tb_writer = None
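For orientation, the new `TrainingLogger` appends one JSON object per step to `training_metrics.jsonl` under the session's metrics directory (namespacing keys as `<split>/<key>`) and optionally mirrors the scalars to TensorBoard. A minimal usage sketch, assuming a writable working directory; `"demo_experiment"` is an arbitrary id:

```python
from nextrec.basic.session import create_session
from nextrec.basic.loggers import TrainingLogger

session = create_session("demo_experiment")
tlogger = TrainingLogger(session=session, enable_tensorboard=False)  # skip the SummaryWriter dependency

# Writes {"train/loss": 0.693, "train/auc": 0.512, "step": 1} as one JSONL row.
tlogger.log_metrics({"loss": 0.693, "auc": 0.512}, step=1, split="train")
tlogger.log_metrics({"auc": 0.534}, step=1, split="valid")
tlogger.close()
```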
{nextrec-0.3.3 → nextrec-0.3.5}/nextrec/basic/model.py
@@ -10,6 +10,8 @@ import os
  import tqdm
  import pickle
  import logging
+ import getpass
+ import socket
  import numpy as np
  import pandas as pd
  import torch
@@ -24,15 +26,17 @@ from nextrec.basic.callback import EarlyStopper
  from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature, FeatureSet
  from nextrec.data.dataloader import TensorDictDataset, RecDataLoader

- from nextrec.basic.loggers import setup_logger, colorize
+ from nextrec.basic.loggers import setup_logger, colorize, TrainingLogger
  from nextrec.basic.session import resolve_save_path, create_session
  from nextrec.basic.metrics import configure_metrics, evaluate_metrics, check_user_id

  from nextrec.data.dataloader import build_tensors_from_data
- from nextrec.data.data_utils import get_column_data, collate_fn, batch_to_dict, get_user_ids
+ from nextrec.data.data_processing import get_column_data, get_user_ids
+ from nextrec.data.batch_utils import collate_fn, batch_to_dict

  from nextrec.loss import get_loss_fn, get_loss_kwargs
- from nextrec.utils import get_optimizer, get_scheduler, to_tensor
+ from nextrec.utils import get_optimizer, get_scheduler
+ from nextrec.utils.tensor import to_tensor

  from nextrec import __version__

@@ -88,6 +92,7 @@ class BaseModel(FeatureSet, nn.Module):
          self.early_stop_patience = early_stop_patience
          self.max_gradient_norm = 1.0
          self.logger_initialized = False
+         self.training_logger: TrainingLogger | None = None

      def register_regularization_weights(self, embedding_attr: str = "embedding", exclude_modules: list[str] | None = None, include_modules: list[str] | None = None) -> None:
          exclude_modules = exclude_modules or []
@@ -275,11 +280,13 @@
          metrics: list[str] | dict[str, list[str]] | None = None, # ['auc', 'logloss'] or {'target1': ['auc', 'logloss'], 'target2': ['mse']}
          epochs:int=1, shuffle:bool=True, batch_size:int=32,
          user_id_column: str | None = None,
-         validation_split: float | None = None):
+         validation_split: float | None = None,
+         tensorboard: bool = True,):
          self.to(self.device)
          if not self.logger_initialized:
              setup_logger(session_id=self.session_id)
              self.logger_initialized = True
+         self.training_logger = TrainingLogger(session=self.session, enable_tensorboard=tensorboard)

          self.metrics, self.task_specific_metrics, self.best_metrics_mode = configure_metrics(task=self.task, metrics=metrics, target_names=self.target_columns) # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'
          self.early_stopper = EarlyStopper(patience=self.early_stop_patience, mode=self.best_metrics_mode)
@@ -303,6 +310,20 @@
              is_streaming = True

          self.summary()
+         logging.info("")
+         if self.training_logger and self.training_logger.enable_tensorboard:
+             tb_dir = self.training_logger.tensorboard_logdir
+             if tb_dir:
+                 user = getpass.getuser()
+                 host = socket.gethostname()
+                 tb_cmd = f"tensorboard --logdir {tb_dir} --port 6006"
+                 ssh_hint = f"ssh -L 6006:localhost:6006 {user}@{host}"
+                 logging.info(colorize(f"TensorBoard logs saved to: {tb_dir}", color="cyan"))
+                 logging.info(colorize("To view logs, run:", color="cyan"))
+                 logging.info(colorize(f"  {tb_cmd}", color="cyan"))
+                 logging.info(colorize("Then SSH port forward:", color="cyan"))
+                 logging.info(colorize(f"  {ssh_hint}", color="cyan"))
+
          logging.info("")
          logging.info(colorize("=" * 80, bold=True))
          if is_streaming:
@@ -312,7 +333,7 @@
          logging.info(colorize("=" * 80, bold=True))
          logging.info("")
          logging.info(colorize(f"Model device: {self.device}", bold=True))
-
+
          for epoch in range(epochs):
              self.epoch_index = epoch
              if is_streaming:
@@ -326,7 +347,8 @@
              else:
                  train_loss = train_result
                  train_metrics = None
-
+
+             train_log_payload: dict[str, float] = {}
              # handle logging for single-task and multi-task
              if self.nums_task == 1:
                  log_str = f"Epoch {epoch + 1}/{epochs} - Train: loss={train_loss:.4f}"
@@ -334,6 +356,9 @@
                      metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in train_metrics.items()])
                      log_str += f", {metrics_str}"
                  logging.info(colorize(log_str))
+                 train_log_payload["loss"] = float(train_loss)
+                 if train_metrics:
+                     train_log_payload.update(train_metrics)
              else:
                  total_loss_val = np.sum(train_loss) if isinstance(train_loss, np.ndarray) else train_loss # type: ignore
                  log_str = f"Epoch {epoch + 1}/{epochs} - Train: loss={total_loss_val:.4f}"
@@ -356,12 +381,17 @@
                          task_metric_strs.append(f"{target_name}[{metrics_str}]")
                      log_str += ", " + ", ".join(task_metric_strs)
                  logging.info(colorize(log_str))
+                 train_log_payload["loss"] = float(total_loss_val)
+                 if train_metrics:
+                     train_log_payload.update(train_metrics)
+             if self.training_logger:
+                 self.training_logger.log_metrics(train_log_payload, step=epoch + 1, split="train")
              if valid_loader is not None:
                  # pass user_ids only if needed for GAUC metric
                  val_metrics = self.evaluate(valid_loader, user_ids=valid_user_ids if self.needs_user_ids else None) # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
                  if self.nums_task == 1:
                      metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in val_metrics.items()])
-                     logging.info(colorize(f"Epoch {epoch + 1}/{epochs} - Valid: {metrics_str}", color="cyan"))
+                     logging.info(colorize(f" Epoch {epoch + 1}/{epochs} - Valid: {metrics_str}", color="cyan"))
                  else:
                      # multi task metrics
                      task_metrics = {}
@@ -378,7 +408,9 @@
                          if target_name in task_metrics:
                              metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in task_metrics[target_name].items()])
                              task_metric_strs.append(f"{target_name}[{metrics_str}]")
-                     logging.info(colorize(f"Epoch {epoch + 1}/{epochs} - Valid: " + ", ".join(task_metric_strs), color="cyan"))
+                     logging.info(colorize(f" Epoch {epoch + 1}/{epochs} - Valid: " + ", ".join(task_metric_strs), color="cyan"))
+                 if val_metrics and self.training_logger:
+                     self.training_logger.log_metrics(val_metrics, step=epoch + 1, split="valid")
              # Handle empty validation metrics
              if not val_metrics:
                  self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
@@ -401,6 +433,7 @@
                  self.best_metric = primary_metric
                  improved = True
              self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
+             logging.info(" ")
          if improved:
              logging.info(colorize(f"Validation {primary_metric_key} improved to {self.best_metric:.4f}"))
              self.save_model(self.best_path, add_timestamp=False, verbose=False)
@@ -431,6 +464,8 @@
          if valid_loader is not None:
              logging.info(colorize(f"Load best model from: {self.best_checkpoint_path}"))
              self.load_model(self.best_checkpoint_path, map_location=self.device, verbose=False)
+         if self.training_logger:
+             self.training_logger.close()
          return self

      def train_epoch(self, train_loader: DataLoader, is_streaming: bool = False) -> Union[float, np.ndarray, tuple[Union[float, np.ndarray], dict]]:
@@ -527,6 +562,7 @@
              batch_user_id = get_user_ids(data=batch_dict, id_columns=self.id_columns)
              if batch_user_id is not None:
                  collected_user_ids.append(batch_user_id)
+         logging.info(" ")
          logging.info(colorize(f" Evaluation batches processed: {batch_count}", color="cyan"))
          if len(y_true_list) > 0:
              y_true_all = np.concatenate(y_true_list, axis=0)
@@ -956,9 +992,7 @@
          logger.info(f" Session ID: {self.session_id}")
          logger.info(f" Features Config Path: {self.features_config_path}")
          logger.info(f" Latest Checkpoint: {self.checkpoint_path}")
-
-         logger.info("")
-         logger.info("")
+


  class BaseMatchModel(BaseModel):
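Taken together, the model.py hunks thread the new `TrainingLogger` through `fit()`: a new `tensorboard: bool = True` parameter, per-epoch train/valid rows appended to the session's `training_metrics.jsonl`, and a `close()` call once training finishes. A minimal sketch of reading that file back; the concrete path below is hypothetical, since the real location comes from the session's metrics directory:

```python
import json
from pathlib import Path

# Hypothetical path; the actual file lives under the session's metrics dir.
metrics_file = Path("log/nextrec_session_20251203/training_metrics.jsonl")

# One JSON object per logged step, e.g. {"train/loss": 0.69, "train/auc": 0.51, "step": 1}
for line in metrics_file.read_text(encoding="utf-8").splitlines():
    row = json.loads(line)
    step = row.pop("step")
    print(step, row)
```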
{nextrec-0.3.3 → nextrec-0.3.5}/nextrec/basic/session.py
@@ -1,14 +1,5 @@
  """Session and experiment utilities.

- This module centralizes session/experiment management so the rest of the
- framework writes all artifacts to a consistent location:: <pwd>/log/<experiment_id>/
-
- Within that folder we keep model parameters, checkpoints, training metrics,
- evaluation metrics, and consolidated log output. When users do not provide an
- ``experiment_id`` a timestamp-based identifier is generated once per process to
- avoid scattering files across multiple directories. Test runs are redirected to
- temporary folders so local trees are not polluted.
-
  Date: create on 23/11/2025
  Author: Yang Zhou,zyaztec@gmail.com
  """
@@ -16,7 +7,7 @@ Author: Yang Zhou,zyaztec@gmail.com
  import os
  import tempfile
  from dataclasses import dataclass
- from datetime import datetime
+ from datetime import datetime, timezone
  from pathlib import Path

  __all__ = [
@@ -74,6 +65,7 @@ def create_session(experiment_id: str | Path | None = None) -> Session:
      if experiment_id is not None and str(experiment_id).strip():
          exp_id = str(experiment_id).strip()
      else:
+         # Use local time for session naming
          exp_id = "nextrec_session_" + datetime.now().strftime("%Y%m%d")

      if (
@@ -111,6 +103,7 @@ def resolve_save_path(
      timestamp.
      - Parent directories are created.
      """
+     # Use local time for file timestamps
      timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") if add_timestamp else None

      normalized_suffix = suffix if suffix.startswith(".") else f".{suffix}"
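The session.py changes trim the module docstring and annotate the timestamp behaviour: default experiment IDs and save-path timestamps both use local time (the newly imported `timezone` is not used in the hunks shown). A small sketch of the two formats, using only stdlib `datetime`:

```python
from datetime import datetime

# Default session naming from create_session(): local date, day granularity,
# so repeated runs on the same day resolve to the same default id.
exp_id = "nextrec_session_" + datetime.now().strftime("%Y%m%d")   # e.g. "nextrec_session_20251203"

# Timestamp suffix used by resolve_save_path(add_timestamp=True): second granularity.
stamp = datetime.now().strftime("%Y%m%d_%H%M%S")                  # e.g. "20251203_141502"
```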
{nextrec-0.3.3 → nextrec-0.3.5}/nextrec/data/__init__.py
@@ -1,48 +1,86 @@
  """
  Data utilities package for NextRec

- This package provides data processing and manipulation utilities.
+ This package provides data processing and manipulation utilities organized by category:
+ - batch_utils: Batch collation and processing
+ - data_processing: Data manipulation and user ID extraction
+ - data_utils: Legacy module (re-exports from specialized modules)
+ - dataloader: Dataset and DataLoader implementations
+ - preprocessor: Data preprocessing pipeline

  Date: create on 13/11/2025
+ Last update: 03/12/2025 (refactored)
  Author: Yang Zhou, zyaztec@gmail.com
  """

- from nextrec.data.data_utils import (
-     collate_fn,
+ # Batch utilities
+ from nextrec.data.batch_utils import collate_fn, batch_to_dict, stack_section
+
+ # Data processing utilities
+ from nextrec.data.data_processing import (
      get_column_data,
-     default_output_dir,
      split_dict_random,
      build_eval_candidates,
+     get_user_ids,
+ )
+
+ # File utilities (from utils package)
+ from nextrec.utils.file import (
      resolve_file_paths,
      iter_file_chunks,
      read_table,
      load_dataframes,
+     default_output_dir,
  )
- from nextrec.basic.features import FeatureSet
- from nextrec.data import data_utils
+
+ # DataLoader components
  from nextrec.data.dataloader import (
      TensorDictDataset,
      FileDataset,
      RecDataLoader,
      build_tensors_from_data,
  )
+
+ # Preprocessor
  from nextrec.data.preprocessor import DataProcessor

+ # Feature definitions
+ from nextrec.basic.features import FeatureSet
+
+ # Legacy module (for backward compatibility)
+ from nextrec.data import data_utils
+
  __all__ = [
+     # Batch utilities
      'collate_fn',
+     'batch_to_dict',
+     'stack_section',
+
+     # Data processing
      'get_column_data',
-     'default_output_dir',
      'split_dict_random',
      'build_eval_candidates',
+     'get_user_ids',
+
+     # File utilities
      'resolve_file_paths',
      'iter_file_chunks',
      'read_table',
      'load_dataframes',
-     'FeatureSet',
-     'data_utils',
+     'default_output_dir',
+
+     # DataLoader
      'TensorDictDataset',
      'FileDataset',
      'RecDataLoader',
      'build_tensors_from_data',
+
+     # Preprocessor
      'DataProcessor',
+
+     # Features
      'FeatureSet',
+
+     # Legacy module
      'data_utils',
  ]
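The reorganized `__init__` keeps the package-level surface stable while the implementations move. Both import styles below should resolve after this change; the second relies on the re-exports shown above:

```python
# New category-specific modules introduced in 0.3.5:
from nextrec.data.batch_utils import collate_fn, batch_to_dict
from nextrec.data.data_processing import get_user_ids, split_dict_random

# Package-level imports still work via the re-exports:
from nextrec.data import collate_fn, DataProcessor, FeatureSet
```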
nextrec-0.3.5/nextrec/data/batch_utils.py (new file)
@@ -0,0 +1,80 @@
+ """
+ Batch collation utilities for NextRec
+
+ Date: create on 03/12/2025
+ Author: Yang Zhou, zyaztec@gmail.com
+ """
+
+ import torch
+ import numpy as np
+ from typing import Any, Mapping
+
+ def stack_section(batch: list[dict], section: str):
+     entries = [item.get(section) for item in batch if item.get(section) is not None]
+     if not entries:
+         return None
+     merged: dict = {}
+     for name in entries[0]: # type: ignore
+         tensors = [item[section][name] for item in batch if item.get(section) is not None and name in item[section]]
+         merged[name] = torch.stack(tensors, dim=0)
+     return merged
+
+ def collate_fn(batch):
+     """
+     Collate a list of sample dicts into the unified batch format:
+     {
+         "features": {name: Tensor(B, ...)},
+         "labels": {target: Tensor(B, ...)} or None,
+         "ids": {id_name: Tensor(B, ...)} or None,
+     }
+     Args: batch: List of samples from DataLoader
+
+     Returns: dict: Batched data in unified format
+     """
+     if not batch:
+         return {"features": {}, "labels": None, "ids": None}
+
+     first = batch[0]
+     if isinstance(first, dict) and "features" in first:
+         # Streaming dataset yields already-batched chunks; avoid adding an extra dim.
+         if first.get("_already_batched") and len(batch) == 1:
+             return {
+                 "features": first.get("features", {}),
+                 "labels": first.get("labels"),
+                 "ids": first.get("ids"),
+             }
+         return {
+             "features": stack_section(batch, "features") or {},
+             "labels": stack_section(batch, "labels"),
+             "ids": stack_section(batch, "ids"),
+         }
+
+     # Fallback: stack tuples/lists of tensors
+     num_tensors = len(first)
+     result = []
+     for i in range(num_tensors):
+         tensor_list = [item[i] for item in batch]
+         first_item = tensor_list[0]
+         if isinstance(first_item, torch.Tensor):
+             stacked = torch.cat(tensor_list, dim=0)
+         elif isinstance(first_item, np.ndarray):
+             stacked = np.concatenate(tensor_list, axis=0)
+         elif isinstance(first_item, list):
+             combined = []
+             for entry in tensor_list:
+                 combined.extend(entry)
+             stacked = combined
+         else:
+             stacked = tensor_list
+         result.append(stacked)
+     return tuple(result)
+
+
+ def batch_to_dict(batch_data: Any, include_ids: bool = True) -> dict:
+     if not (isinstance(batch_data, Mapping) and "features" in batch_data):
+         raise TypeError("[BaseModel-batch_to_dict Error] Batch data must be a dict with 'features' produced by the current DataLoader.")
+     return {
+         "features": batch_data.get("features", {}),
+         "labels": batch_data.get("labels"),
+         "ids": batch_data.get("ids") if include_ids else None,
+     }