nextrec 0.4.1__py3-none-any.whl → 0.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__init__.py +1 -1
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +10 -5
- nextrec/basic/callback.py +1 -0
- nextrec/basic/features.py +30 -22
- nextrec/basic/layers.py +250 -112
- nextrec/basic/loggers.py +63 -44
- nextrec/basic/metrics.py +270 -120
- nextrec/basic/model.py +1084 -402
- nextrec/basic/session.py +10 -3
- nextrec/cli.py +492 -0
- nextrec/data/__init__.py +19 -25
- nextrec/data/batch_utils.py +11 -3
- nextrec/data/data_processing.py +51 -45
- nextrec/data/data_utils.py +26 -15
- nextrec/data/dataloader.py +273 -96
- nextrec/data/preprocessor.py +320 -199
- nextrec/loss/listwise.py +17 -9
- nextrec/loss/loss_utils.py +7 -8
- nextrec/loss/pairwise.py +2 -0
- nextrec/loss/pointwise.py +30 -12
- nextrec/models/generative/hstu.py +103 -38
- nextrec/models/match/dssm.py +82 -68
- nextrec/models/match/dssm_v2.py +72 -57
- nextrec/models/match/mind.py +175 -107
- nextrec/models/match/sdm.py +104 -87
- nextrec/models/match/youtube_dnn.py +73 -59
- nextrec/models/multi_task/esmm.py +69 -46
- nextrec/models/multi_task/mmoe.py +91 -53
- nextrec/models/multi_task/ple.py +117 -58
- nextrec/models/multi_task/poso.py +163 -55
- nextrec/models/multi_task/share_bottom.py +63 -36
- nextrec/models/ranking/afm.py +80 -45
- nextrec/models/ranking/autoint.py +74 -57
- nextrec/models/ranking/dcn.py +110 -48
- nextrec/models/ranking/dcn_v2.py +265 -45
- nextrec/models/ranking/deepfm.py +39 -24
- nextrec/models/ranking/dien.py +335 -146
- nextrec/models/ranking/din.py +158 -92
- nextrec/models/ranking/fibinet.py +134 -52
- nextrec/models/ranking/fm.py +68 -26
- nextrec/models/ranking/masknet.py +95 -33
- nextrec/models/ranking/pnn.py +128 -58
- nextrec/models/ranking/widedeep.py +40 -28
- nextrec/models/ranking/xdeepfm.py +67 -40
- nextrec/utils/__init__.py +59 -34
- nextrec/utils/config.py +496 -0
- nextrec/utils/device.py +30 -20
- nextrec/utils/distributed.py +36 -9
- nextrec/utils/embedding.py +1 -0
- nextrec/utils/feature.py +1 -0
- nextrec/utils/file.py +33 -11
- nextrec/utils/initializer.py +61 -16
- nextrec/utils/model.py +22 -0
- nextrec/utils/optimizer.py +25 -9
- nextrec/utils/synthetic_data.py +283 -165
- nextrec/utils/tensor.py +24 -13
- {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/METADATA +53 -24
- nextrec-0.4.3.dist-info/RECORD +69 -0
- nextrec-0.4.3.dist-info/entry_points.txt +2 -0
- nextrec-0.4.1.dist-info/RECORD +0 -66
- {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/WHEEL +0 -0
- {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/loggers.py
CHANGED
|
@@ -13,40 +13,42 @@ import json
|
|
|
13
13
|
import copy
|
|
14
14
|
import logging
|
|
15
15
|
import numbers
|
|
16
|
+
|
|
16
17
|
from typing import Mapping, Any
|
|
17
18
|
from nextrec.basic.session import create_session, Session
|
|
18
19
|
|
|
19
20
|
ANSI_CODES = {
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
21
|
+
"black": "\033[30m",
|
|
22
|
+
"red": "\033[31m",
|
|
23
|
+
"green": "\033[32m",
|
|
24
|
+
"yellow": "\033[33m",
|
|
25
|
+
"blue": "\033[34m",
|
|
26
|
+
"magenta": "\033[35m",
|
|
27
|
+
"cyan": "\033[36m",
|
|
28
|
+
"white": "\033[37m",
|
|
29
|
+
"bright_black": "\033[90m",
|
|
30
|
+
"bright_red": "\033[91m",
|
|
31
|
+
"bright_green": "\033[92m",
|
|
32
|
+
"bright_yellow": "\033[93m",
|
|
33
|
+
"bright_blue": "\033[94m",
|
|
34
|
+
"bright_magenta": "\033[95m",
|
|
35
|
+
"bright_cyan": "\033[96m",
|
|
36
|
+
"bright_white": "\033[97m",
|
|
36
37
|
}
|
|
37
38
|
|
|
38
|
-
ANSI_BOLD =
|
|
39
|
-
ANSI_RESET =
|
|
40
|
-
ANSI_ESCAPE_PATTERN = re.compile(r
|
|
39
|
+
ANSI_BOLD = "\033[1m"
|
|
40
|
+
ANSI_RESET = "\033[0m"
|
|
41
|
+
ANSI_ESCAPE_PATTERN = re.compile(r"\033\[[0-9;]*m")
|
|
41
42
|
|
|
42
43
|
DEFAULT_LEVEL_COLORS = {
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
44
|
+
"DEBUG": "cyan",
|
|
45
|
+
"INFO": None,
|
|
46
|
+
"WARNING": "yellow",
|
|
47
|
+
"ERROR": "red",
|
|
48
|
+
"CRITICAL": "bright_red",
|
|
48
49
|
}
|
|
49
50
|
|
|
51
|
+
|
|
50
52
|
class AnsiFormatter(logging.Formatter):
|
|
51
53
|
def __init__(
|
|
52
54
|
self,
|
|
@@ -65,16 +67,17 @@ class AnsiFormatter(logging.Formatter):
|
|
|
65
67
|
record_copy = copy.copy(record)
|
|
66
68
|
formatted = super().format(record_copy)
|
|
67
69
|
|
|
68
|
-
if self.auto_color_level and
|
|
70
|
+
if self.auto_color_level and "\033[" not in formatted:
|
|
69
71
|
color = self.level_colors.get(record.levelname)
|
|
70
72
|
if color:
|
|
71
73
|
formatted = colorize(formatted, color=color)
|
|
72
74
|
|
|
73
75
|
if self.strip_ansi:
|
|
74
|
-
return ANSI_ESCAPE_PATTERN.sub(
|
|
76
|
+
return ANSI_ESCAPE_PATTERN.sub("", formatted)
|
|
75
77
|
|
|
76
78
|
return formatted
|
|
77
79
|
|
|
80
|
+
|
|
78
81
|
def colorize(text: str, color: str | None = None, bold: bool = False) -> str:
|
|
79
82
|
"""Apply ANSI color and bold formatting to the given text."""
|
|
80
83
|
if not color and not bold:
|
|
@@ -87,43 +90,53 @@ def colorize(text: str, color: str | None = None, bold: bool = False) -> str:
|
|
|
87
90
|
result += text + ANSI_RESET
|
|
88
91
|
return result
|
|
89
92
|
|
|
93
|
+
|
|
90
94
|
def setup_logger(session_id: str | os.PathLike | None = None):
|
|
91
95
|
"""Set up a logger that logs to both console and a file with ANSI formatting.
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
96
|
+
Only console output has colors; file output is stripped of ANSI codes.
|
|
97
|
+
Logs are stored under ``log/<experiment_id>/logs`` by default. A stable
|
|
98
|
+
log file is used per experiment so multiple components (e.g. data
|
|
99
|
+
processor and model training) append to the same file instead of creating
|
|
100
|
+
separate timestamped files.
|
|
97
101
|
"""
|
|
98
102
|
|
|
99
103
|
session = create_session(str(session_id) if session_id is not None else None)
|
|
100
104
|
log_dir = session.logs_dir
|
|
101
105
|
log_dir.mkdir(parents=True, exist_ok=True)
|
|
102
|
-
log_file = log_dir /
|
|
106
|
+
log_file = log_dir / "runs.log"
|
|
107
|
+
|
|
108
|
+
console_format = "%(message)s"
|
|
109
|
+
file_format = "%(asctime)s - %(levelname)s - %(message)s"
|
|
110
|
+
date_format = "%Y-%m-%d %H:%M:%S"
|
|
103
111
|
|
|
104
|
-
console_format = '%(message)s'
|
|
105
|
-
file_format = '%(asctime)s - %(levelname)s - %(message)s'
|
|
106
|
-
date_format = '%Y-%m-%d %H:%M:%S'
|
|
107
|
-
|
|
108
112
|
logger = logging.getLogger()
|
|
109
113
|
logger.setLevel(logging.INFO)
|
|
110
114
|
|
|
111
115
|
if logger.hasHandlers():
|
|
112
116
|
logger.handlers.clear()
|
|
113
117
|
|
|
114
|
-
file_handler = logging.FileHandler(log_file, encoding=
|
|
118
|
+
file_handler = logging.FileHandler(log_file, encoding="utf-8")
|
|
115
119
|
file_handler.setLevel(logging.INFO)
|
|
116
|
-
file_handler.setFormatter(
|
|
120
|
+
file_handler.setFormatter(
|
|
121
|
+
AnsiFormatter(file_format, datefmt=date_format, strip_ansi=True)
|
|
122
|
+
)
|
|
117
123
|
|
|
118
124
|
console_handler = logging.StreamHandler(sys.stdout)
|
|
119
125
|
console_handler.setLevel(logging.INFO)
|
|
120
|
-
console_handler.setFormatter(
|
|
126
|
+
console_handler.setFormatter(
|
|
127
|
+
AnsiFormatter(
|
|
128
|
+
console_format,
|
|
129
|
+
datefmt=date_format,
|
|
130
|
+
auto_color_level=True,
|
|
131
|
+
)
|
|
132
|
+
)
|
|
121
133
|
|
|
122
134
|
logger.addHandler(file_handler)
|
|
123
135
|
logger.addHandler(console_handler)
|
|
124
|
-
|
|
136
|
+
|
|
125
137
|
return logger
|
|
126
138
|
|
|
139
|
+
|
|
127
140
|
class TrainingLogger:
|
|
128
141
|
def __init__(
|
|
129
142
|
self,
|
|
@@ -146,7 +159,9 @@ class TrainingLogger:
|
|
|
146
159
|
try:
|
|
147
160
|
from torch.utils.tensorboard import SummaryWriter # type: ignore
|
|
148
161
|
except ImportError:
|
|
149
|
-
logging.warning(
|
|
162
|
+
logging.warning(
|
|
163
|
+
"[TrainingLogger] tensorboard not installed, disable tensorboard logging."
|
|
164
|
+
)
|
|
150
165
|
self.enable_tensorboard = False
|
|
151
166
|
return
|
|
152
167
|
tb_dir = self.session.logs_dir / "tensorboard"
|
|
@@ -158,7 +173,9 @@ class TrainingLogger:
|
|
|
158
173
|
def tensorboard_logdir(self):
|
|
159
174
|
return self.tb_dir
|
|
160
175
|
|
|
161
|
-
def format_metrics(
|
|
176
|
+
def format_metrics(
|
|
177
|
+
self, metrics: Mapping[str, Any], split: str
|
|
178
|
+
) -> dict[str, float]:
|
|
162
179
|
formatted: dict[str, float] = {}
|
|
163
180
|
for key, value in metrics.items():
|
|
164
181
|
if isinstance(value, numbers.Number):
|
|
@@ -170,7 +187,9 @@ class TrainingLogger:
|
|
|
170
187
|
continue
|
|
171
188
|
return formatted
|
|
172
189
|
|
|
173
|
-
def log_metrics(
|
|
190
|
+
def log_metrics(
|
|
191
|
+
self, metrics: Mapping[str, Any], step: int, split: str = "train"
|
|
192
|
+
) -> None:
|
|
174
193
|
payload = self.format_metrics(metrics, split)
|
|
175
194
|
payload["step"] = int(step)
|
|
176
195
|
with self.log_path.open("a", encoding="utf-8") as f:
|
|
@@ -188,4 +207,4 @@ class TrainingLogger:
|
|
|
188
207
|
if self.tb_writer:
|
|
189
208
|
self.tb_writer.flush()
|
|
190
209
|
self.tb_writer.close()
|
|
191
|
-
self.tb_writer = None
|
|
210
|
+
self.tb_writer = None
|