nextrec 0.4.1__py3-none-any.whl → 0.4.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. nextrec/__init__.py +1 -1
  2. nextrec/__version__.py +1 -1
  3. nextrec/basic/activation.py +10 -5
  4. nextrec/basic/callback.py +1 -0
  5. nextrec/basic/features.py +30 -22
  6. nextrec/basic/layers.py +250 -112
  7. nextrec/basic/loggers.py +63 -44
  8. nextrec/basic/metrics.py +270 -120
  9. nextrec/basic/model.py +1084 -402
  10. nextrec/basic/session.py +10 -3
  11. nextrec/cli.py +492 -0
  12. nextrec/data/__init__.py +19 -25
  13. nextrec/data/batch_utils.py +11 -3
  14. nextrec/data/data_processing.py +51 -45
  15. nextrec/data/data_utils.py +26 -15
  16. nextrec/data/dataloader.py +273 -96
  17. nextrec/data/preprocessor.py +320 -199
  18. nextrec/loss/listwise.py +17 -9
  19. nextrec/loss/loss_utils.py +7 -8
  20. nextrec/loss/pairwise.py +2 -0
  21. nextrec/loss/pointwise.py +30 -12
  22. nextrec/models/generative/hstu.py +103 -38
  23. nextrec/models/match/dssm.py +82 -68
  24. nextrec/models/match/dssm_v2.py +72 -57
  25. nextrec/models/match/mind.py +175 -107
  26. nextrec/models/match/sdm.py +104 -87
  27. nextrec/models/match/youtube_dnn.py +73 -59
  28. nextrec/models/multi_task/esmm.py +69 -46
  29. nextrec/models/multi_task/mmoe.py +91 -53
  30. nextrec/models/multi_task/ple.py +117 -58
  31. nextrec/models/multi_task/poso.py +163 -55
  32. nextrec/models/multi_task/share_bottom.py +63 -36
  33. nextrec/models/ranking/afm.py +80 -45
  34. nextrec/models/ranking/autoint.py +74 -57
  35. nextrec/models/ranking/dcn.py +110 -48
  36. nextrec/models/ranking/dcn_v2.py +265 -45
  37. nextrec/models/ranking/deepfm.py +39 -24
  38. nextrec/models/ranking/dien.py +335 -146
  39. nextrec/models/ranking/din.py +158 -92
  40. nextrec/models/ranking/fibinet.py +134 -52
  41. nextrec/models/ranking/fm.py +68 -26
  42. nextrec/models/ranking/masknet.py +95 -33
  43. nextrec/models/ranking/pnn.py +128 -58
  44. nextrec/models/ranking/widedeep.py +40 -28
  45. nextrec/models/ranking/xdeepfm.py +67 -40
  46. nextrec/utils/__init__.py +59 -34
  47. nextrec/utils/config.py +496 -0
  48. nextrec/utils/device.py +30 -20
  49. nextrec/utils/distributed.py +36 -9
  50. nextrec/utils/embedding.py +1 -0
  51. nextrec/utils/feature.py +1 -0
  52. nextrec/utils/file.py +33 -11
  53. nextrec/utils/initializer.py +61 -16
  54. nextrec/utils/model.py +22 -0
  55. nextrec/utils/optimizer.py +25 -9
  56. nextrec/utils/synthetic_data.py +283 -165
  57. nextrec/utils/tensor.py +24 -13
  58. {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/METADATA +53 -24
  59. nextrec-0.4.3.dist-info/RECORD +69 -0
  60. nextrec-0.4.3.dist-info/entry_points.txt +2 -0
  61. nextrec-0.4.1.dist-info/RECORD +0 -66
  62. {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/WHEEL +0 -0
  63. {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/loggers.py CHANGED
@@ -13,40 +13,42 @@ import json
13
13
  import copy
14
14
  import logging
15
15
  import numbers
16
+
16
17
  from typing import Mapping, Any
17
18
  from nextrec.basic.session import create_session, Session
18
19
 
19
20
  ANSI_CODES = {
20
- 'black': '\033[30m',
21
- 'red': '\033[31m',
22
- 'green': '\033[32m',
23
- 'yellow': '\033[33m',
24
- 'blue': '\033[34m',
25
- 'magenta': '\033[35m',
26
- 'cyan': '\033[36m',
27
- 'white': '\033[37m',
28
- 'bright_black': '\033[90m',
29
- 'bright_red': '\033[91m',
30
- 'bright_green': '\033[92m',
31
- 'bright_yellow': '\033[93m',
32
- 'bright_blue': '\033[94m',
33
- 'bright_magenta': '\033[95m',
34
- 'bright_cyan': '\033[96m',
35
- 'bright_white': '\033[97m',
21
+ "black": "\033[30m",
22
+ "red": "\033[31m",
23
+ "green": "\033[32m",
24
+ "yellow": "\033[33m",
25
+ "blue": "\033[34m",
26
+ "magenta": "\033[35m",
27
+ "cyan": "\033[36m",
28
+ "white": "\033[37m",
29
+ "bright_black": "\033[90m",
30
+ "bright_red": "\033[91m",
31
+ "bright_green": "\033[92m",
32
+ "bright_yellow": "\033[93m",
33
+ "bright_blue": "\033[94m",
34
+ "bright_magenta": "\033[95m",
35
+ "bright_cyan": "\033[96m",
36
+ "bright_white": "\033[97m",
36
37
  }
37
38
 
38
- ANSI_BOLD = '\033[1m'
39
- ANSI_RESET = '\033[0m'
40
- ANSI_ESCAPE_PATTERN = re.compile(r'\033\[[0-9;]*m')
39
+ ANSI_BOLD = "\033[1m"
40
+ ANSI_RESET = "\033[0m"
41
+ ANSI_ESCAPE_PATTERN = re.compile(r"\033\[[0-9;]*m")
41
42
 
42
43
  DEFAULT_LEVEL_COLORS = {
43
- 'DEBUG': 'cyan',
44
- 'INFO': None,
45
- 'WARNING': 'yellow',
46
- 'ERROR': 'red',
47
- 'CRITICAL': 'bright_red',
44
+ "DEBUG": "cyan",
45
+ "INFO": None,
46
+ "WARNING": "yellow",
47
+ "ERROR": "red",
48
+ "CRITICAL": "bright_red",
48
49
  }
49
50
 
51
+
50
52
  class AnsiFormatter(logging.Formatter):
51
53
  def __init__(
52
54
  self,
@@ -65,16 +67,17 @@ class AnsiFormatter(logging.Formatter):
65
67
  record_copy = copy.copy(record)
66
68
  formatted = super().format(record_copy)
67
69
 
68
- if self.auto_color_level and '\033[' not in formatted:
70
+ if self.auto_color_level and "\033[" not in formatted:
69
71
  color = self.level_colors.get(record.levelname)
70
72
  if color:
71
73
  formatted = colorize(formatted, color=color)
72
74
 
73
75
  if self.strip_ansi:
74
- return ANSI_ESCAPE_PATTERN.sub('', formatted)
76
+ return ANSI_ESCAPE_PATTERN.sub("", formatted)
75
77
 
76
78
  return formatted
77
79
 
80
+
78
81
  def colorize(text: str, color: str | None = None, bold: bool = False) -> str:
79
82
  """Apply ANSI color and bold formatting to the given text."""
80
83
  if not color and not bold:
@@ -87,43 +90,53 @@ def colorize(text: str, color: str | None = None, bold: bool = False) -> str:
87
90
  result += text + ANSI_RESET
88
91
  return result
89
92
 
93
+
90
94
  def setup_logger(session_id: str | os.PathLike | None = None):
91
95
  """Set up a logger that logs to both console and a file with ANSI formatting.
92
- Only console output has colors; file output is stripped of ANSI codes.
93
- Logs are stored under ``log/<experiment_id>/logs`` by default. A stable
94
- log file is used per experiment so multiple components (e.g. data
95
- processor and model training) append to the same file instead of creating
96
- separate timestamped files.
96
+ Only console output has colors; file output is stripped of ANSI codes.
97
+ Logs are stored under ``log/<experiment_id>/logs`` by default. A stable
98
+ log file is used per experiment so multiple components (e.g. data
99
+ processor and model training) append to the same file instead of creating
100
+ separate timestamped files.
97
101
  """
98
102
 
99
103
  session = create_session(str(session_id) if session_id is not None else None)
100
104
  log_dir = session.logs_dir
101
105
  log_dir.mkdir(parents=True, exist_ok=True)
102
- log_file = log_dir / f"{session.log_basename}.log"
106
+ log_file = log_dir / "runs.log"
107
+
108
+ console_format = "%(message)s"
109
+ file_format = "%(asctime)s - %(levelname)s - %(message)s"
110
+ date_format = "%Y-%m-%d %H:%M:%S"
103
111
 
104
- console_format = '%(message)s'
105
- file_format = '%(asctime)s - %(levelname)s - %(message)s'
106
- date_format = '%Y-%m-%d %H:%M:%S'
107
-
108
112
  logger = logging.getLogger()
109
113
  logger.setLevel(logging.INFO)
110
114
 
111
115
  if logger.hasHandlers():
112
116
  logger.handlers.clear()
113
117
 
114
- file_handler = logging.FileHandler(log_file, encoding='utf-8')
118
+ file_handler = logging.FileHandler(log_file, encoding="utf-8")
115
119
  file_handler.setLevel(logging.INFO)
116
- file_handler.setFormatter(AnsiFormatter(file_format, datefmt=date_format, strip_ansi=True))
120
+ file_handler.setFormatter(
121
+ AnsiFormatter(file_format, datefmt=date_format, strip_ansi=True)
122
+ )
117
123
 
118
124
  console_handler = logging.StreamHandler(sys.stdout)
119
125
  console_handler.setLevel(logging.INFO)
120
- console_handler.setFormatter(AnsiFormatter(console_format, datefmt=date_format, auto_color_level=True,))
126
+ console_handler.setFormatter(
127
+ AnsiFormatter(
128
+ console_format,
129
+ datefmt=date_format,
130
+ auto_color_level=True,
131
+ )
132
+ )
121
133
 
122
134
  logger.addHandler(file_handler)
123
135
  logger.addHandler(console_handler)
124
-
136
+
125
137
  return logger
126
138
 
139
+
127
140
  class TrainingLogger:
128
141
  def __init__(
129
142
  self,
@@ -146,7 +159,9 @@ class TrainingLogger:
146
159
  try:
147
160
  from torch.utils.tensorboard import SummaryWriter # type: ignore
148
161
  except ImportError:
149
- logging.warning("[TrainingLogger] tensorboard not installed, disable tensorboard logging.")
162
+ logging.warning(
163
+ "[TrainingLogger] tensorboard not installed, disable tensorboard logging."
164
+ )
150
165
  self.enable_tensorboard = False
151
166
  return
152
167
  tb_dir = self.session.logs_dir / "tensorboard"
@@ -158,7 +173,9 @@ class TrainingLogger:
158
173
  def tensorboard_logdir(self):
159
174
  return self.tb_dir
160
175
 
161
- def format_metrics(self, metrics: Mapping[str, Any], split: str) -> dict[str, float]:
176
+ def format_metrics(
177
+ self, metrics: Mapping[str, Any], split: str
178
+ ) -> dict[str, float]:
162
179
  formatted: dict[str, float] = {}
163
180
  for key, value in metrics.items():
164
181
  if isinstance(value, numbers.Number):
@@ -170,7 +187,9 @@ class TrainingLogger:
170
187
  continue
171
188
  return formatted
172
189
 
173
- def log_metrics(self, metrics: Mapping[str, Any], step: int, split: str = "train") -> None:
190
+ def log_metrics(
191
+ self, metrics: Mapping[str, Any], step: int, split: str = "train"
192
+ ) -> None:
174
193
  payload = self.format_metrics(metrics, split)
175
194
  payload["step"] = int(step)
176
195
  with self.log_path.open("a", encoding="utf-8") as f:
@@ -188,4 +207,4 @@ class TrainingLogger:
188
207
  if self.tb_writer:
189
208
  self.tb_writer.flush()
190
209
  self.tb_writer.close()
191
- self.tb_writer = None
210
+ self.tb_writer = None