nextrec 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
Files changed (48)
  1. nextrec/__init__.py +4 -4
  2. nextrec/__version__.py +1 -1
  3. nextrec/basic/activation.py +10 -9
  4. nextrec/basic/callback.py +1 -0
  5. nextrec/basic/dataloader.py +168 -127
  6. nextrec/basic/features.py +24 -27
  7. nextrec/basic/layers.py +328 -159
  8. nextrec/basic/loggers.py +50 -37
  9. nextrec/basic/metrics.py +255 -147
  10. nextrec/basic/model.py +817 -462
  11. nextrec/data/__init__.py +5 -5
  12. nextrec/data/data_utils.py +16 -12
  13. nextrec/data/preprocessor.py +276 -252
  14. nextrec/loss/__init__.py +12 -12
  15. nextrec/loss/loss_utils.py +30 -22
  16. nextrec/loss/match_losses.py +116 -83
  17. nextrec/models/match/__init__.py +5 -5
  18. nextrec/models/match/dssm.py +70 -61
  19. nextrec/models/match/dssm_v2.py +61 -51
  20. nextrec/models/match/mind.py +89 -71
  21. nextrec/models/match/sdm.py +93 -81
  22. nextrec/models/match/youtube_dnn.py +62 -53
  23. nextrec/models/multi_task/esmm.py +49 -43
  24. nextrec/models/multi_task/mmoe.py +65 -56
  25. nextrec/models/multi_task/ple.py +92 -65
  26. nextrec/models/multi_task/share_bottom.py +48 -42
  27. nextrec/models/ranking/__init__.py +7 -7
  28. nextrec/models/ranking/afm.py +39 -30
  29. nextrec/models/ranking/autoint.py +70 -57
  30. nextrec/models/ranking/dcn.py +43 -35
  31. nextrec/models/ranking/deepfm.py +34 -28
  32. nextrec/models/ranking/dien.py +115 -79
  33. nextrec/models/ranking/din.py +84 -60
  34. nextrec/models/ranking/fibinet.py +51 -35
  35. nextrec/models/ranking/fm.py +28 -26
  36. nextrec/models/ranking/masknet.py +31 -31
  37. nextrec/models/ranking/pnn.py +30 -31
  38. nextrec/models/ranking/widedeep.py +36 -31
  39. nextrec/models/ranking/xdeepfm.py +46 -39
  40. nextrec/utils/__init__.py +9 -9
  41. nextrec/utils/embedding.py +1 -1
  42. nextrec/utils/initializer.py +23 -15
  43. nextrec/utils/optimizer.py +14 -10
  44. {nextrec-0.1.1.dist-info → nextrec-0.1.3.dist-info}/METADATA +7 -41
  45. nextrec-0.1.3.dist-info/RECORD +51 -0
  46. nextrec-0.1.1.dist-info/RECORD +0 -51
  47. {nextrec-0.1.1.dist-info → nextrec-0.1.3.dist-info}/WHEEL +0 -0
  48. {nextrec-0.1.1.dist-info → nextrec-0.1.3.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/loggers.py CHANGED
@@ -14,36 +14,37 @@ import datetime
14
14
  import logging
15
15
 
16
16
  ANSI_CODES = {
17
- 'black': '\033[30m',
18
- 'red': '\033[31m',
19
- 'green': '\033[32m',
20
- 'yellow': '\033[33m',
21
- 'blue': '\033[34m',
22
- 'magenta': '\033[35m',
23
- 'cyan': '\033[36m',
24
- 'white': '\033[37m',
25
- 'bright_black': '\033[90m',
26
- 'bright_red': '\033[91m',
27
- 'bright_green': '\033[92m',
28
- 'bright_yellow': '\033[93m',
29
- 'bright_blue': '\033[94m',
30
- 'bright_magenta': '\033[95m',
31
- 'bright_cyan': '\033[96m',
32
- 'bright_white': '\033[97m',
17
+ "black": "\033[30m",
18
+ "red": "\033[31m",
19
+ "green": "\033[32m",
20
+ "yellow": "\033[33m",
21
+ "blue": "\033[34m",
22
+ "magenta": "\033[35m",
23
+ "cyan": "\033[36m",
24
+ "white": "\033[37m",
25
+ "bright_black": "\033[90m",
26
+ "bright_red": "\033[91m",
27
+ "bright_green": "\033[92m",
28
+ "bright_yellow": "\033[93m",
29
+ "bright_blue": "\033[94m",
30
+ "bright_magenta": "\033[95m",
31
+ "bright_cyan": "\033[96m",
32
+ "bright_white": "\033[97m",
33
33
  }
34
34
 
35
- ANSI_BOLD = '\033[1m'
36
- ANSI_RESET = '\033[0m'
37
- ANSI_ESCAPE_PATTERN = re.compile(r'\033\[[0-9;]*m')
35
+ ANSI_BOLD = "\033[1m"
36
+ ANSI_RESET = "\033[0m"
37
+ ANSI_ESCAPE_PATTERN = re.compile(r"\033\[[0-9;]*m")
38
38
 
39
39
  DEFAULT_LEVEL_COLORS = {
40
- 'DEBUG': 'cyan',
41
- 'INFO': None,
42
- 'WARNING': 'yellow',
43
- 'ERROR': 'red',
44
- 'CRITICAL': 'bright_red',
40
+ "DEBUG": "cyan",
41
+ "INFO": None,
42
+ "WARNING": "yellow",
43
+ "ERROR": "red",
44
+ "CRITICAL": "bright_red",
45
45
  }
46
46
 
47
+
47
48
  class AnsiFormatter(logging.Formatter):
48
49
  def __init__(
49
50
  self,
@@ -62,16 +63,17 @@ class AnsiFormatter(logging.Formatter):
62
63
  record_copy = copy.copy(record)
63
64
  formatted = super().format(record_copy)
64
65
 
65
- if self.auto_color_level and '\033[' not in formatted:
66
+ if self.auto_color_level and "\033[" not in formatted:
66
67
  color = self.level_colors.get(record.levelname)
67
68
  if color:
68
69
  formatted = colorize(formatted, color=color)
69
70
 
70
71
  if self.strip_ansi:
71
- return ANSI_ESCAPE_PATTERN.sub('', formatted)
72
+ return ANSI_ESCAPE_PATTERN.sub("", formatted)
72
73
 
73
74
  return formatted
74
75
 
76
+
75
77
  def colorize(text: str, color: str | None = None, bold: bool = False) -> str:
76
78
  """Apply ANSI color and bold formatting to the given text."""
77
79
  if not color and not bold:
@@ -89,36 +91,47 @@ def colorize(text: str, color: str | None = None, bold: bool = False) -> str:
89
91
 
90
92
  return result
91
93
 
94
+
92
95
  def setup_logger(log_dir: str | None = None):
93
96
  """Set up a logger that logs to both console and a file with ANSI formatting.
94
- Only console output has colors; file output is stripped of ANSI codes.
97
+ Only console output has colors; file output is stripped of ANSI codes.
95
98
  """
96
99
  if log_dir is None:
97
100
  project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
98
101
  log_dir = os.path.join(project_root, "..", "logs")
99
-
102
+
100
103
  os.makedirs(log_dir, exist_ok=True)
101
- log_file = os.path.join(log_dir, f"nextrec_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
104
+ log_file = os.path.join(
105
+ log_dir, f"nextrec_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
106
+ )
107
+
108
+ console_format = "%(message)s"
109
+ file_format = "%(asctime)s - %(levelname)s - %(message)s"
110
+ date_format = "%H:%M:%S"
102
111
 
103
- console_format = '%(message)s'
104
- file_format = '%(asctime)s - %(levelname)s - %(message)s'
105
- date_format = '%H:%M:%S'
106
-
107
112
  logger = logging.getLogger()
108
113
  logger.setLevel(logging.INFO)
109
114
 
110
115
  if logger.hasHandlers():
111
116
  logger.handlers.clear()
112
117
 
113
- file_handler = logging.FileHandler(log_file, encoding='utf-8')
118
+ file_handler = logging.FileHandler(log_file, encoding="utf-8")
114
119
  file_handler.setLevel(logging.INFO)
115
- file_handler.setFormatter(AnsiFormatter(file_format, datefmt=date_format, strip_ansi=True))
120
+ file_handler.setFormatter(
121
+ AnsiFormatter(file_format, datefmt=date_format, strip_ansi=True)
122
+ )
116
123
 
117
124
  console_handler = logging.StreamHandler(sys.stdout)
118
125
  console_handler.setLevel(logging.INFO)
119
- console_handler.setFormatter(AnsiFormatter(console_format, datefmt=date_format, auto_color_level=True,))
126
+ console_handler.setFormatter(
127
+ AnsiFormatter(
128
+ console_format,
129
+ datefmt=date_format,
130
+ auto_color_level=True,
131
+ )
132
+ )
120
133
 
121
134
  logger.addHandler(file_handler)
122
135
  logger.addHandler(console_handler)
123
-
136
+
124
137
  return logger