nextrec 0.3.3__tar.gz → 0.3.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nextrec-0.3.3 → nextrec-0.3.5}/PKG-INFO +3 -3
- {nextrec-0.3.3 → nextrec-0.3.5}/README.md +2 -2
- {nextrec-0.3.3 → nextrec-0.3.5}/README_zh.md +2 -2
- {nextrec-0.3.3 → nextrec-0.3.5}/docs/rtd/conf.py +1 -1
- nextrec-0.3.5/nextrec/__version__.py +1 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/basic/features.py +1 -1
- {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/basic/loggers.py +71 -8
- {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/basic/model.py +45 -11
- {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/basic/session.py +3 -10
- {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/data/__init__.py +47 -9
- nextrec-0.3.5/nextrec/data/batch_utils.py +80 -0
- nextrec-0.3.5/nextrec/data/data_processing.py +152 -0
- nextrec-0.3.5/nextrec/data/data_utils.py +35 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/data/dataloader.py +6 -4
- {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/data/preprocessor.py +39 -85
- {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/multi_task/poso.py +1 -1
- nextrec-0.3.5/nextrec/utils/__init__.py +68 -0
- nextrec-0.3.5/nextrec/utils/device.py +37 -0
- nextrec-0.3.5/nextrec/utils/feature.py +13 -0
- nextrec-0.3.5/nextrec/utils/file.py +70 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/utils/initializer.py +0 -8
- nextrec-0.3.5/nextrec/utils/model.py +22 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/utils/optimizer.py +0 -19
- nextrec-0.3.5/nextrec/utils/tensor.py +61 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/pyproject.toml +1 -1
- {nextrec-0.3.3 → nextrec-0.3.5}/requirements.txt +2 -1
- {nextrec-0.3.3 → nextrec-0.3.5}/tutorials/example_match_dssm.py +1 -1
- {nextrec-0.3.3 → nextrec-0.3.5}/tutorials/example_multitask.py +5 -60
- {nextrec-0.3.3 → nextrec-0.3.5}/tutorials/example_ranking_din.py +0 -42
- nextrec-0.3.3/nextrec/__version__.py +0 -1
- nextrec-0.3.3/nextrec/data/data_utils.py +0 -268
- nextrec-0.3.3/nextrec/utils/__init__.py +0 -18
- nextrec-0.3.3/nextrec/utils/common.py +0 -60
- {nextrec-0.3.3 → nextrec-0.3.5}/.github/workflows/publish.yml +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/.github/workflows/tests.yml +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/.gitignore +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/.readthedocs.yaml +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/CODE_OF_CONDUCT.md +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/CONTRIBUTING.md +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/LICENSE +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/MANIFEST.in +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/asserts/Feature Configuration.png +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/asserts/Model Parameters.png +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/asserts/Training Configuration.png +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/asserts/Training logs.png +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/asserts/logo.png +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/asserts/mmoe_tutorial.png +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/asserts/nextrec_diagram_en.png +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/asserts/nextrec_diagram_zh.png +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/asserts/test data.png +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/dataset/ctcvr_task.csv +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/dataset/match_task.csv +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/dataset/movielens_100k.csv +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/dataset/multitask_task.csv +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/dataset/ranking_task.csv +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/docs/en/Getting started guide.md +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/docs/rtd/Makefile +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/docs/rtd/index.md +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/docs/rtd/make.bat +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/docs/rtd/modules.rst +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/docs/rtd/nextrec.basic.rst +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/docs/rtd/nextrec.data.rst +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/docs/rtd/nextrec.loss.rst +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/docs/rtd/nextrec.rst +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/docs/rtd/nextrec.utils.rst +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/docs/rtd/requirements.txt +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/docs/zh/快速上手.md +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/__init__.py +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/basic/__init__.py +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/basic/activation.py +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/basic/callback.py +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/basic/layers.py +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/basic/metrics.py +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/loss/__init__.py +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/loss/listwise.py +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/loss/loss_utils.py +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/loss/pairwise.py +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/loss/pointwise.py +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/generative/__init__.py +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/generative/hstu.py +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/generative/tiger.py +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/match/__init__.py +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/match/dssm.py +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/match/dssm_v2.py +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/match/mind.py +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/match/sdm.py +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/match/youtube_dnn.py +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/multi_task/esmm.py +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/multi_task/mmoe.py +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/multi_task/ple.py +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/multi_task/share_bottom.py +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/ranking/__init__.py +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/ranking/afm.py +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/ranking/autoint.py +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/ranking/dcn.py +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/ranking/dcn_v2.py +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/ranking/deepfm.py +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/ranking/dien.py +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/ranking/din.py +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/ranking/fibinet.py +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/ranking/fm.py +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/ranking/masknet.py +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/ranking/pnn.py +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/ranking/widedeep.py +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/models/ranking/xdeepfm.py +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/nextrec/utils/embedding.py +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/pytest.ini +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/test/__init__.py +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/test/conftest.py +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/test/run_tests.py +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/test/test_layers.py +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/test/test_losses.py +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/test/test_match_models.py +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/test/test_multitask_models.py +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/test/test_preprocessor.py +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/test/test_ranking_models.py +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/test/test_utils.py +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/test_requirements.txt +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/tutorials/movielen_match_dssm.py +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/tutorials/movielen_ranking_deepfm.py +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/tutorials/notebooks/en/Hands on dataprocessor.ipynb +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/tutorials/notebooks/en/Hands on nextrec.ipynb +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/tutorials/notebooks/zh/Hands on dataprocessor.ipynb +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/tutorials/notebooks/zh/Hands on nextrec.ipynb +0 -0
- {nextrec-0.3.3 → nextrec-0.3.5}/tutorials/run_all_tutorials.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: nextrec
|
|
3
|
-
Version: 0.3.
|
|
3
|
+
Version: 0.3.5
|
|
4
4
|
Summary: A comprehensive recommendation library with match, ranking, and multi-task learning models
|
|
5
5
|
Project-URL: Homepage, https://github.com/zerolovesea/NextRec
|
|
6
6
|
Project-URL: Repository, https://github.com/zerolovesea/NextRec
|
|
@@ -63,7 +63,7 @@ Description-Content-Type: text/markdown
|
|
|
63
63
|

|
|
64
64
|

|
|
65
65
|

|
|
66
|
-

|
|
67
67
|
|
|
68
68
|
English | [中文文档](README_zh.md)
|
|
69
69
|
|
|
@@ -110,7 +110,7 @@ To dive deeper, Jupyter notebooks are available:
|
|
|
110
110
|
- [Hands on the NextRec framework](/tutorials/notebooks/en/Hands%20on%20nextrec.ipynb)
|
|
111
111
|
- [Using the data processor for preprocessing](/tutorials/notebooks/en/Hands%20on%20dataprocessor.ipynb)
|
|
112
112
|
|
|
113
|
-
> Current version [0.3.
|
|
113
|
+
> Current version [0.3.5]: the matching module is not fully polished yet and may have compatibility issues or unexpected errors. Please raise an issue if you run into problems.
|
|
114
114
|
|
|
115
115
|
## 5-Minute Quick Start
|
|
116
116
|
|
|
@@ -7,7 +7,7 @@
|
|
|
7
7
|

|
|
8
8
|

|
|
9
9
|

|
|
10
|
-

|
|
11
11
|
|
|
12
12
|
English | [中文文档](README_zh.md)
|
|
13
13
|
|
|
@@ -54,7 +54,7 @@ To dive deeper, Jupyter notebooks are available:
|
|
|
54
54
|
- [Hands on the NextRec framework](/tutorials/notebooks/en/Hands%20on%20nextrec.ipynb)
|
|
55
55
|
- [Using the data processor for preprocessing](/tutorials/notebooks/en/Hands%20on%20dataprocessor.ipynb)
|
|
56
56
|
|
|
57
|
-
> Current version [0.3.
|
|
57
|
+
> Current version [0.3.5]: the matching module is not fully polished yet and may have compatibility issues or unexpected errors. Please raise an issue if you run into problems.
|
|
58
58
|
|
|
59
59
|
## 5-Minute Quick Start
|
|
60
60
|
|
|
@@ -7,7 +7,7 @@
|
|
|
7
7
|

|
|
8
8
|

|
|
9
9
|

|
|
10
|
-

|
|
11
11
|
|
|
12
12
|
[English Version](README.md) | 中文文档
|
|
13
13
|
|
|
@@ -54,7 +54,7 @@ NextRec采用模块化、低耦合的工程设计,使得推荐系统从数据
|
|
|
54
54
|
- [如何上手NextRec框架](/tutorials/notebooks/zh/Hands%20on%20nextrec.ipynb)
|
|
55
55
|
- [如何使用数据处理器进行数据预处理](/tutorials/notebooks/zh/Hands%20on%20dataprocessor.ipynb)
|
|
56
56
|
|
|
57
|
-
> 当前版本[0.3.
|
|
57
|
+
> 当前版本[0.3.5],召回模型模块尚不完善,可能存在一些兼容性问题或意外报错,如果遇到问题,欢迎开发者在Issue区提出问题。
|
|
58
58
|
|
|
59
59
|
## 5分钟快速上手
|
|
60
60
|
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Release version string of the nextrec package.
# NOTE(review): pyproject.toml also carries a version; presumably kept in sync by hand — confirm.
__version__ = "0.3.5"
|
|
@@ -7,7 +7,7 @@ Author: Yang Zhou, zyaztec@gmail.com
|
|
|
7
7
|
"""
|
|
8
8
|
import torch
|
|
9
9
|
from nextrec.utils.embedding import get_auto_embedding_dim
|
|
10
|
-
from nextrec.utils.
|
|
10
|
+
from nextrec.utils.feature import normalize_to_list
|
|
11
11
|
|
|
12
12
|
class BaseFeature(object):
|
|
13
13
|
def __repr__(self):
|
|
@@ -2,17 +2,19 @@
|
|
|
2
2
|
NextRec Basic Loggers
|
|
3
3
|
|
|
4
4
|
Date: create on 27/10/2025
|
|
5
|
-
Checkpoint: edit on
|
|
5
|
+
Checkpoint: edit on 03/12/2025
|
|
6
6
|
Author: Yang Zhou, zyaztec@gmail.com
|
|
7
7
|
"""
|
|
8
8
|
|
|
9
|
-
|
|
10
9
|
import os
|
|
11
10
|
import re
|
|
12
11
|
import sys
|
|
12
|
+
import json
|
|
13
13
|
import copy
|
|
14
14
|
import logging
|
|
15
|
-
|
|
15
|
+
import numbers
|
|
16
|
+
from typing import Mapping, Any
|
|
17
|
+
from nextrec.basic.session import create_session, Session
|
|
16
18
|
|
|
17
19
|
ANSI_CODES = {
|
|
18
20
|
'black': '\033[30m',
|
|
@@ -77,17 +79,12 @@ def colorize(text: str, color: str | None = None, bold: bool = False) -> str:
|
|
|
77
79
|
"""Apply ANSI color and bold formatting to the given text."""
|
|
78
80
|
if not color and not bold:
|
|
79
81
|
return text
|
|
80
|
-
|
|
81
82
|
result = ""
|
|
82
|
-
|
|
83
83
|
if bold:
|
|
84
84
|
result += ANSI_BOLD
|
|
85
|
-
|
|
86
85
|
if color and color in ANSI_CODES:
|
|
87
86
|
result += ANSI_CODES[color]
|
|
88
|
-
|
|
89
87
|
result += text + ANSI_RESET
|
|
90
|
-
|
|
91
88
|
return result
|
|
92
89
|
|
|
93
90
|
def setup_logger(session_id: str | os.PathLike | None = None):
|
|
@@ -126,3 +123,69 @@ def setup_logger(session_id: str | os.PathLike | None = None):
|
|
|
126
123
|
logger.addHandler(console_handler)
|
|
127
124
|
|
|
128
125
|
return logger
|
|
126
|
+
|
|
127
|
+
class TrainingLogger:
    """Record per-epoch metrics to a JSONL file and, optionally, to TensorBoard.

    Each :meth:`log_metrics` call appends one JSON object (one line) to
    ``session.metrics_dir / log_name``. If TensorBoard logging is requested and
    ``torch.utils.tensorboard`` is importable, the same scalars are also written
    under ``session.logs_dir / "tensorboard"``.
    """

    def __init__(
        self,
        session: "Session",
        enable_tensorboard: bool,
        log_name: str = "training_metrics.jsonl",
    ) -> None:
        """Prepare the metrics file location and, if requested, a TensorBoard writer.

        Args:
            session: Session providing ``metrics_dir`` (JSONL output) and
                ``logs_dir`` (TensorBoard output).
            enable_tensorboard: Whether to mirror metrics to TensorBoard. Reset
                to ``False`` (with a warning) if TensorBoard cannot be imported.
            log_name: File name of the JSONL log inside ``session.metrics_dir``.
        """
        self.session = session
        self.enable_tensorboard = enable_tensorboard
        self.log_path = session.metrics_dir / log_name
        self.log_path.parent.mkdir(parents=True, exist_ok=True)

        self.tb_writer = None
        self.tb_dir = None

        if self.enable_tensorboard:
            self._init_tensorboard()

    def _init_tensorboard(self) -> None:
        """Create ``<logs_dir>/tensorboard`` and open a ``SummaryWriter`` on it.

        The import is deferred so that a missing TensorBoard install only
        disables this feature instead of breaking module import.
        """
        try:
            from torch.utils.tensorboard import SummaryWriter  # type: ignore
        except ImportError:
            logging.warning("[TrainingLogger] tensorboard not installed, disable tensorboard logging.")
            self.enable_tensorboard = False
            return
        tb_dir = self.session.logs_dir / "tensorboard"
        tb_dir.mkdir(parents=True, exist_ok=True)
        self.tb_dir = tb_dir
        self.tb_writer = SummaryWriter(log_dir=str(tb_dir))

    @property
    def tensorboard_logdir(self):
        """Directory receiving TensorBoard event files, or ``None`` if TensorBoard is off."""
        return self.tb_dir

    def format_metrics(self, metrics: Mapping[str, Any], split: str) -> dict[str, float]:
        """Convert metric values to floats keyed as ``"<split>/<name>"``.

        Plain numbers are converted directly. Objects exposing ``.item()``
        (e.g. single-element tensors/arrays) are unwrapped first. Values that
        cannot become a real float (strings, complex numbers, multi-element
        arrays, ...) are skipped, so one bad entry never aborts the record.

        Args:
            metrics: Raw metric values keyed by metric name.
            split: Prefix for the output keys, e.g. ``"train"`` or ``"valid"``.

        Returns:
            Mapping of ``"<split>/<name>"`` to float for every convertible value.
        """
        formatted: dict[str, float] = {}
        for key, value in metrics.items():
            if isinstance(value, numbers.Number):
                scalar = value
            elif hasattr(value, "item"):
                try:
                    scalar = value.item()
                except (TypeError, ValueError, RuntimeError):
                    # e.g. multi-element arrays/tensors have no single scalar value.
                    continue
            else:
                continue
            try:
                formatted[f"{split}/{key}"] = float(scalar)
            except (TypeError, ValueError, OverflowError):
                # Not representable as a real float (e.g. complex): skip it rather
                # than failing the whole logging call.
                continue
        return formatted

    def log_metrics(self, metrics: Mapping[str, Any], step: int, split: str = "train") -> None:
        """Append one JSONL record and, if a writer is open, mirror it to TensorBoard.

        Args:
            metrics: Raw metric values keyed by metric name.
            step: Global step for the record (``fit`` passes ``epoch + 1``).
            split: Prefix for metric keys, e.g. ``"train"`` or ``"valid"``.
        """
        step = int(step)
        scalars = self.format_metrics(metrics, split)
        record = {**scalars, "step": step}
        with self.log_path.open("a", encoding="utf-8") as f:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")

        if self.tb_writer is None:
            return
        for key, value in scalars.items():
            self.tb_writer.add_scalar(key, value, global_step=step)

    def close(self) -> None:
        """Flush and close the TensorBoard writer, if any; later calls are no-ops."""
        if self.tb_writer is not None:
            self.tb_writer.flush()
            self.tb_writer.close()
            self.tb_writer = None
|
|
@@ -10,6 +10,8 @@ import os
|
|
|
10
10
|
import tqdm
|
|
11
11
|
import pickle
|
|
12
12
|
import logging
|
|
13
|
+
import getpass
|
|
14
|
+
import socket
|
|
13
15
|
import numpy as np
|
|
14
16
|
import pandas as pd
|
|
15
17
|
import torch
|
|
@@ -24,15 +26,17 @@ from nextrec.basic.callback import EarlyStopper
|
|
|
24
26
|
from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature, FeatureSet
|
|
25
27
|
from nextrec.data.dataloader import TensorDictDataset, RecDataLoader
|
|
26
28
|
|
|
27
|
-
from nextrec.basic.loggers import setup_logger, colorize
|
|
29
|
+
from nextrec.basic.loggers import setup_logger, colorize, TrainingLogger
|
|
28
30
|
from nextrec.basic.session import resolve_save_path, create_session
|
|
29
31
|
from nextrec.basic.metrics import configure_metrics, evaluate_metrics, check_user_id
|
|
30
32
|
|
|
31
33
|
from nextrec.data.dataloader import build_tensors_from_data
|
|
32
|
-
from nextrec.data.
|
|
34
|
+
from nextrec.data.data_processing import get_column_data, get_user_ids
|
|
35
|
+
from nextrec.data.batch_utils import collate_fn, batch_to_dict
|
|
33
36
|
|
|
34
37
|
from nextrec.loss import get_loss_fn, get_loss_kwargs
|
|
35
|
-
from nextrec.utils import get_optimizer, get_scheduler
|
|
38
|
+
from nextrec.utils import get_optimizer, get_scheduler
|
|
39
|
+
from nextrec.utils.tensor import to_tensor
|
|
36
40
|
|
|
37
41
|
from nextrec import __version__
|
|
38
42
|
|
|
@@ -88,6 +92,7 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
88
92
|
self.early_stop_patience = early_stop_patience
|
|
89
93
|
self.max_gradient_norm = 1.0
|
|
90
94
|
self.logger_initialized = False
|
|
95
|
+
self.training_logger: TrainingLogger | None = None
|
|
91
96
|
|
|
92
97
|
def register_regularization_weights(self, embedding_attr: str = "embedding", exclude_modules: list[str] | None = None, include_modules: list[str] | None = None) -> None:
|
|
93
98
|
exclude_modules = exclude_modules or []
|
|
@@ -275,11 +280,13 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
275
280
|
metrics: list[str] | dict[str, list[str]] | None = None, # ['auc', 'logloss'] or {'target1': ['auc', 'logloss'], 'target2': ['mse']}
|
|
276
281
|
epochs:int=1, shuffle:bool=True, batch_size:int=32,
|
|
277
282
|
user_id_column: str | None = None,
|
|
278
|
-
validation_split: float | None = None
|
|
283
|
+
validation_split: float | None = None,
|
|
284
|
+
tensorboard: bool = True,):
|
|
279
285
|
self.to(self.device)
|
|
280
286
|
if not self.logger_initialized:
|
|
281
287
|
setup_logger(session_id=self.session_id)
|
|
282
288
|
self.logger_initialized = True
|
|
289
|
+
self.training_logger = TrainingLogger(session=self.session, enable_tensorboard=tensorboard)
|
|
283
290
|
|
|
284
291
|
self.metrics, self.task_specific_metrics, self.best_metrics_mode = configure_metrics(task=self.task, metrics=metrics, target_names=self.target_columns) # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'
|
|
285
292
|
self.early_stopper = EarlyStopper(patience=self.early_stop_patience, mode=self.best_metrics_mode)
|
|
@@ -303,6 +310,20 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
303
310
|
is_streaming = True
|
|
304
311
|
|
|
305
312
|
self.summary()
|
|
313
|
+
logging.info("")
|
|
314
|
+
if self.training_logger and self.training_logger.enable_tensorboard:
|
|
315
|
+
tb_dir = self.training_logger.tensorboard_logdir
|
|
316
|
+
if tb_dir:
|
|
317
|
+
user = getpass.getuser()
|
|
318
|
+
host = socket.gethostname()
|
|
319
|
+
tb_cmd = f"tensorboard --logdir {tb_dir} --port 6006"
|
|
320
|
+
ssh_hint = f"ssh -L 6006:localhost:6006 {user}@{host}"
|
|
321
|
+
logging.info(colorize(f"TensorBoard logs saved to: {tb_dir}", color="cyan"))
|
|
322
|
+
logging.info(colorize("To view logs, run:", color="cyan"))
|
|
323
|
+
logging.info(colorize(f" {tb_cmd}", color="cyan"))
|
|
324
|
+
logging.info(colorize("Then SSH port forward:", color="cyan"))
|
|
325
|
+
logging.info(colorize(f" {ssh_hint}", color="cyan"))
|
|
326
|
+
|
|
306
327
|
logging.info("")
|
|
307
328
|
logging.info(colorize("=" * 80, bold=True))
|
|
308
329
|
if is_streaming:
|
|
@@ -312,7 +333,7 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
312
333
|
logging.info(colorize("=" * 80, bold=True))
|
|
313
334
|
logging.info("")
|
|
314
335
|
logging.info(colorize(f"Model device: {self.device}", bold=True))
|
|
315
|
-
|
|
336
|
+
|
|
316
337
|
for epoch in range(epochs):
|
|
317
338
|
self.epoch_index = epoch
|
|
318
339
|
if is_streaming:
|
|
@@ -326,7 +347,8 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
326
347
|
else:
|
|
327
348
|
train_loss = train_result
|
|
328
349
|
train_metrics = None
|
|
329
|
-
|
|
350
|
+
|
|
351
|
+
train_log_payload: dict[str, float] = {}
|
|
330
352
|
# handle logging for single-task and multi-task
|
|
331
353
|
if self.nums_task == 1:
|
|
332
354
|
log_str = f"Epoch {epoch + 1}/{epochs} - Train: loss={train_loss:.4f}"
|
|
@@ -334,6 +356,9 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
334
356
|
metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in train_metrics.items()])
|
|
335
357
|
log_str += f", {metrics_str}"
|
|
336
358
|
logging.info(colorize(log_str))
|
|
359
|
+
train_log_payload["loss"] = float(train_loss)
|
|
360
|
+
if train_metrics:
|
|
361
|
+
train_log_payload.update(train_metrics)
|
|
337
362
|
else:
|
|
338
363
|
total_loss_val = np.sum(train_loss) if isinstance(train_loss, np.ndarray) else train_loss # type: ignore
|
|
339
364
|
log_str = f"Epoch {epoch + 1}/{epochs} - Train: loss={total_loss_val:.4f}"
|
|
@@ -356,12 +381,17 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
356
381
|
task_metric_strs.append(f"{target_name}[{metrics_str}]")
|
|
357
382
|
log_str += ", " + ", ".join(task_metric_strs)
|
|
358
383
|
logging.info(colorize(log_str))
|
|
384
|
+
train_log_payload["loss"] = float(total_loss_val)
|
|
385
|
+
if train_metrics:
|
|
386
|
+
train_log_payload.update(train_metrics)
|
|
387
|
+
if self.training_logger:
|
|
388
|
+
self.training_logger.log_metrics(train_log_payload, step=epoch + 1, split="train")
|
|
359
389
|
if valid_loader is not None:
|
|
360
390
|
# pass user_ids only if needed for GAUC metric
|
|
361
391
|
val_metrics = self.evaluate(valid_loader, user_ids=valid_user_ids if self.needs_user_ids else None) # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
|
|
362
392
|
if self.nums_task == 1:
|
|
363
393
|
metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in val_metrics.items()])
|
|
364
|
-
logging.info(colorize(f"Epoch {epoch + 1}/{epochs} - Valid: {metrics_str}", color="cyan"))
|
|
394
|
+
logging.info(colorize(f" Epoch {epoch + 1}/{epochs} - Valid: {metrics_str}", color="cyan"))
|
|
365
395
|
else:
|
|
366
396
|
# multi task metrics
|
|
367
397
|
task_metrics = {}
|
|
@@ -378,7 +408,9 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
378
408
|
if target_name in task_metrics:
|
|
379
409
|
metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in task_metrics[target_name].items()])
|
|
380
410
|
task_metric_strs.append(f"{target_name}[{metrics_str}]")
|
|
381
|
-
logging.info(colorize(f"Epoch {epoch + 1}/{epochs} - Valid: " + ", ".join(task_metric_strs), color="cyan"))
|
|
411
|
+
logging.info(colorize(f" Epoch {epoch + 1}/{epochs} - Valid: " + ", ".join(task_metric_strs), color="cyan"))
|
|
412
|
+
if val_metrics and self.training_logger:
|
|
413
|
+
self.training_logger.log_metrics(val_metrics, step=epoch + 1, split="valid")
|
|
382
414
|
# Handle empty validation metrics
|
|
383
415
|
if not val_metrics:
|
|
384
416
|
self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
|
|
@@ -401,6 +433,7 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
401
433
|
self.best_metric = primary_metric
|
|
402
434
|
improved = True
|
|
403
435
|
self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
|
|
436
|
+
logging.info(" ")
|
|
404
437
|
if improved:
|
|
405
438
|
logging.info(colorize(f"Validation {primary_metric_key} improved to {self.best_metric:.4f}"))
|
|
406
439
|
self.save_model(self.best_path, add_timestamp=False, verbose=False)
|
|
@@ -431,6 +464,8 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
431
464
|
if valid_loader is not None:
|
|
432
465
|
logging.info(colorize(f"Load best model from: {self.best_checkpoint_path}"))
|
|
433
466
|
self.load_model(self.best_checkpoint_path, map_location=self.device, verbose=False)
|
|
467
|
+
if self.training_logger:
|
|
468
|
+
self.training_logger.close()
|
|
434
469
|
return self
|
|
435
470
|
|
|
436
471
|
def train_epoch(self, train_loader: DataLoader, is_streaming: bool = False) -> Union[float, np.ndarray, tuple[Union[float, np.ndarray], dict]]:
|
|
@@ -527,6 +562,7 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
527
562
|
batch_user_id = get_user_ids(data=batch_dict, id_columns=self.id_columns)
|
|
528
563
|
if batch_user_id is not None:
|
|
529
564
|
collected_user_ids.append(batch_user_id)
|
|
565
|
+
logging.info(" ")
|
|
530
566
|
logging.info(colorize(f" Evaluation batches processed: {batch_count}", color="cyan"))
|
|
531
567
|
if len(y_true_list) > 0:
|
|
532
568
|
y_true_all = np.concatenate(y_true_list, axis=0)
|
|
@@ -956,9 +992,7 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
956
992
|
logger.info(f" Session ID: {self.session_id}")
|
|
957
993
|
logger.info(f" Features Config Path: {self.features_config_path}")
|
|
958
994
|
logger.info(f" Latest Checkpoint: {self.checkpoint_path}")
|
|
959
|
-
|
|
960
|
-
logger.info("")
|
|
961
|
-
logger.info("")
|
|
995
|
+
|
|
962
996
|
|
|
963
997
|
|
|
964
998
|
class BaseMatchModel(BaseModel):
|
|
@@ -1,14 +1,5 @@
|
|
|
1
1
|
"""Session and experiment utilities.
|
|
2
2
|
|
|
3
|
-
This module centralizes session/experiment management so the rest of the
|
|
4
|
-
framework writes all artifacts to a consistent location:: <pwd>/log/<experiment_id>/
|
|
5
|
-
|
|
6
|
-
Within that folder we keep model parameters, checkpoints, training metrics,
|
|
7
|
-
evaluation metrics, and consolidated log output. When users do not provide an
|
|
8
|
-
``experiment_id`` a timestamp-based identifier is generated once per process to
|
|
9
|
-
avoid scattering files across multiple directories. Test runs are redirected to
|
|
10
|
-
temporary folders so local trees are not polluted.
|
|
11
|
-
|
|
12
3
|
Date: create on 23/11/2025
|
|
13
4
|
Author: Yang Zhou,zyaztec@gmail.com
|
|
14
5
|
"""
|
|
@@ -16,7 +7,7 @@ Author: Yang Zhou,zyaztec@gmail.com
|
|
|
16
7
|
import os
|
|
17
8
|
import tempfile
|
|
18
9
|
from dataclasses import dataclass
|
|
19
|
-
from datetime import datetime
|
|
10
|
+
from datetime import datetime, timezone
|
|
20
11
|
from pathlib import Path
|
|
21
12
|
|
|
22
13
|
__all__ = [
|
|
@@ -74,6 +65,7 @@ def create_session(experiment_id: str | Path | None = None) -> Session:
|
|
|
74
65
|
if experiment_id is not None and str(experiment_id).strip():
|
|
75
66
|
exp_id = str(experiment_id).strip()
|
|
76
67
|
else:
|
|
68
|
+
# Use local time for session naming
|
|
77
69
|
exp_id = "nextrec_session_" + datetime.now().strftime("%Y%m%d")
|
|
78
70
|
|
|
79
71
|
if (
|
|
@@ -111,6 +103,7 @@ def resolve_save_path(
|
|
|
111
103
|
timestamp.
|
|
112
104
|
- Parent directories are created.
|
|
113
105
|
"""
|
|
106
|
+
# Use local time for file timestamps
|
|
114
107
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") if add_timestamp else None
|
|
115
108
|
|
|
116
109
|
normalized_suffix = suffix if suffix.startswith(".") else f".{suffix}"
|
|
@@ -1,48 +1,86 @@
|
|
|
1
1
|
"""
|
|
2
2
|
Data utilities package for NextRec
|
|
3
3
|
|
|
4
|
-
This package provides data processing and manipulation utilities
|
|
4
|
+
This package provides data processing and manipulation utilities organized by category:
|
|
5
|
+
- batch_utils: Batch collation and processing
|
|
6
|
+
- data_processing: Data manipulation and user ID extraction
|
|
7
|
+
- data_utils: Legacy module (re-exports from specialized modules)
|
|
8
|
+
- dataloader: Dataset and DataLoader implementations
|
|
9
|
+
- preprocessor: Data preprocessing pipeline
|
|
5
10
|
|
|
6
11
|
Date: create on 13/11/2025
|
|
12
|
+
Last update: 03/12/2025 (refactored)
|
|
7
13
|
Author: Yang Zhou, zyaztec@gmail.com
|
|
8
14
|
"""
|
|
9
15
|
|
|
10
|
-
|
|
11
|
-
|
|
16
|
+
# Batch utilities
|
|
17
|
+
from nextrec.data.batch_utils import collate_fn, batch_to_dict, stack_section
|
|
18
|
+
|
|
19
|
+
# Data processing utilities
|
|
20
|
+
from nextrec.data.data_processing import (
|
|
12
21
|
get_column_data,
|
|
13
|
-
default_output_dir,
|
|
14
22
|
split_dict_random,
|
|
15
23
|
build_eval_candidates,
|
|
24
|
+
get_user_ids,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
# File utilities (from utils package)
|
|
28
|
+
from nextrec.utils.file import (
|
|
16
29
|
resolve_file_paths,
|
|
17
30
|
iter_file_chunks,
|
|
18
31
|
read_table,
|
|
19
32
|
load_dataframes,
|
|
33
|
+
default_output_dir,
|
|
20
34
|
)
|
|
21
|
-
|
|
22
|
-
|
|
35
|
+
|
|
36
|
+
# DataLoader components
|
|
23
37
|
from nextrec.data.dataloader import (
|
|
24
38
|
TensorDictDataset,
|
|
25
39
|
FileDataset,
|
|
26
40
|
RecDataLoader,
|
|
27
41
|
build_tensors_from_data,
|
|
28
42
|
)
|
|
43
|
+
|
|
44
|
+
# Preprocessor
|
|
29
45
|
from nextrec.data.preprocessor import DataProcessor
|
|
30
46
|
|
|
47
|
+
# Feature definitions
|
|
48
|
+
from nextrec.basic.features import FeatureSet
|
|
49
|
+
|
|
50
|
+
# Legacy module (for backward compatibility)
|
|
51
|
+
from nextrec.data import data_utils
|
|
52
|
+
|
|
31
53
|
__all__ = [
|
|
54
|
+
# Batch utilities
|
|
32
55
|
'collate_fn',
|
|
56
|
+
'batch_to_dict',
|
|
57
|
+
'stack_section',
|
|
58
|
+
|
|
59
|
+
# Data processing
|
|
33
60
|
'get_column_data',
|
|
34
|
-
'default_output_dir',
|
|
35
61
|
'split_dict_random',
|
|
36
62
|
'build_eval_candidates',
|
|
63
|
+
'get_user_ids',
|
|
64
|
+
|
|
65
|
+
# File utilities
|
|
37
66
|
'resolve_file_paths',
|
|
38
67
|
'iter_file_chunks',
|
|
39
68
|
'read_table',
|
|
40
69
|
'load_dataframes',
|
|
41
|
-
'
|
|
42
|
-
|
|
70
|
+
'default_output_dir',
|
|
71
|
+
|
|
72
|
+
# DataLoader
|
|
43
73
|
'TensorDictDataset',
|
|
44
74
|
'FileDataset',
|
|
45
75
|
'RecDataLoader',
|
|
46
76
|
'build_tensors_from_data',
|
|
77
|
+
|
|
78
|
+
# Preprocessor
|
|
47
79
|
'DataProcessor',
|
|
80
|
+
|
|
81
|
+
# Features
|
|
82
|
+
'FeatureSet',
|
|
83
|
+
|
|
84
|
+
# Legacy module
|
|
85
|
+
'data_utils',
|
|
48
86
|
]
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Batch collation utilities for NextRec
|
|
3
|
+
|
|
4
|
+
Date: create on 03/12/2025
|
|
5
|
+
Author: Yang Zhou, zyaztec@gmail.com
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
import numpy as np
|
|
10
|
+
from typing import Any, Mapping
|
|
11
|
+
|
|
12
|
+
def stack_section(batch: list[dict], section: str):
    """Stack one section (e.g. ``"features"``) of per-sample dicts into batched tensors.

    Args:
        batch: Per-sample dicts, each optionally holding ``section`` as a
            ``{name: Tensor}`` mapping.
        section: Key of the section to merge, e.g. ``"features"``, ``"labels"``
            or ``"ids"``.

    Returns:
        ``{name: torch.stack([...], dim=0)}`` using the names of the first
        sample that carries the section, or ``None`` when no sample carries it.

    Raises:
        ValueError: If only some samples carry the section, or a sample lacks a
            name present in the first one. Stacking the remainder would yield
            fewer rows than the batch and silently misalign this section with
            the others.
    """
    entries = [item.get(section) for item in batch]
    present = [entry for entry in entries if entry is not None]
    if not present:
        return None
    if len(present) != len(entries):
        raise ValueError(
            f"[stack_section Error] Section '{section}' is missing in "
            f"{len(entries) - len(present)} of {len(entries)} samples."
        )
    merged: dict = {}
    for name in present[0]:
        missing = sum(1 for entry in present if name not in entry)
        if missing:
            raise ValueError(
                f"[stack_section Error] Key '{name}' of section '{section}' is missing in "
                f"{missing} of {len(present)} samples."
            )
        merged[name] = torch.stack([entry[name] for entry in present], dim=0)
    return merged
|
|
21
|
+
|
|
22
|
+
def _concat_column(column):
    """Merge one tuple position across samples by concatenating along axis 0."""
    head = column[0]
    if isinstance(head, torch.Tensor):
        return torch.cat(column, dim=0)
    if isinstance(head, np.ndarray):
        return np.concatenate(column, axis=0)
    if isinstance(head, list):
        return [value for chunk in column for value in chunk]
    # Unknown element types are returned as the raw per-sample list.
    return column


def collate_fn(batch):
    """Merge DataLoader samples into a single batch.

    Dict samples carrying a ``"features"`` key produce the unified layout::

        {
            "features": {name: Tensor(B, ...)},
            "labels": {target: Tensor(B, ...)} or None,
            "ids": {id_name: Tensor(B, ...)} or None,
        }

    A lone sample flagged ``_already_batched`` (a streaming chunk) has its
    features/labels/ids returned as-is, so no extra leading dimension is added.
    Any other sample is treated as a tuple/list and merged position by
    position via concatenation along the first axis (see ``_concat_column``).

    Args:
        batch: Samples collected by the DataLoader.

    Returns:
        The unified batch dict, or a tuple of merged columns for tuple samples.
    """
    if not batch:
        return {"features": {}, "labels": None, "ids": None}

    head = batch[0]
    if not isinstance(head, dict) or "features" not in head:
        width = len(head)
        return tuple(_concat_column([sample[pos] for sample in batch]) for pos in range(width))

    if head.get("_already_batched") and len(batch) == 1:
        # Streaming dataset yields already-batched chunks; avoid adding an extra dim.
        return {
            "features": head.get("features", {}),
            "labels": head.get("labels"),
            "ids": head.get("ids"),
        }

    features = stack_section(batch, "features") or {}
    labels = stack_section(batch, "labels")
    ids = stack_section(batch, "ids")
    return {"features": features, "labels": labels, "ids": ids}
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def batch_to_dict(batch_data: Any, include_ids: bool = True) -> dict:
    """Normalize a collated batch into ``{"features", "labels", "ids"}``.

    Args:
        batch_data: A mapping produced by the current DataLoader's collate step;
            it must contain a ``"features"`` key.
        include_ids: When ``False``, ``"ids"`` is forced to ``None``.

    Returns:
        A new dict with ``"features"``, ``"labels"`` and ``"ids"`` taken from
        ``batch_data`` (missing labels/ids become ``None``).

    Raises:
        TypeError: If ``batch_data`` is not a mapping with a ``"features"`` key.
    """
    is_unified_batch = isinstance(batch_data, Mapping) and "features" in batch_data
    if not is_unified_batch:
        raise TypeError("[BaseModel-batch_to_dict Error] Batch data must be a dict with 'features' produced by the current DataLoader.")
    features = batch_data.get("features", {})
    labels = batch_data.get("labels")
    ids = batch_data.get("ids") if include_ids else None
    return {"features": features, "labels": labels, "ids": ids}
|