nextrec 0.1.4 → 0.1.7 (nextrec-0.1.4-py3-none-any.whl → nextrec-0.1.7-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. nextrec/__init__.py +4 -4
  2. nextrec/__version__.py +1 -1
  3. nextrec/basic/activation.py +9 -10
  4. nextrec/basic/callback.py +0 -1
  5. nextrec/basic/dataloader.py +127 -168
  6. nextrec/basic/features.py +27 -24
  7. nextrec/basic/layers.py +159 -328
  8. nextrec/basic/loggers.py +37 -50
  9. nextrec/basic/metrics.py +147 -255
  10. nextrec/basic/model.py +462 -817
  11. nextrec/data/__init__.py +5 -5
  12. nextrec/data/data_utils.py +12 -16
  13. nextrec/data/preprocessor.py +252 -276
  14. nextrec/loss/__init__.py +12 -12
  15. nextrec/loss/loss_utils.py +22 -30
  16. nextrec/loss/match_losses.py +83 -116
  17. nextrec/models/match/__init__.py +5 -5
  18. nextrec/models/match/dssm.py +61 -70
  19. nextrec/models/match/dssm_v2.py +51 -61
  20. nextrec/models/match/mind.py +71 -89
  21. nextrec/models/match/sdm.py +81 -93
  22. nextrec/models/match/youtube_dnn.py +53 -62
  23. nextrec/models/multi_task/esmm.py +43 -49
  24. nextrec/models/multi_task/mmoe.py +56 -65
  25. nextrec/models/multi_task/ple.py +65 -92
  26. nextrec/models/multi_task/share_bottom.py +42 -48
  27. nextrec/models/ranking/__init__.py +7 -7
  28. nextrec/models/ranking/afm.py +30 -39
  29. nextrec/models/ranking/autoint.py +57 -70
  30. nextrec/models/ranking/dcn.py +35 -43
  31. nextrec/models/ranking/deepfm.py +28 -34
  32. nextrec/models/ranking/dien.py +79 -115
  33. nextrec/models/ranking/din.py +60 -84
  34. nextrec/models/ranking/fibinet.py +35 -51
  35. nextrec/models/ranking/fm.py +26 -28
  36. nextrec/models/ranking/masknet.py +31 -31
  37. nextrec/models/ranking/pnn.py +31 -30
  38. nextrec/models/ranking/widedeep.py +31 -36
  39. nextrec/models/ranking/xdeepfm.py +39 -46
  40. nextrec/utils/__init__.py +9 -9
  41. nextrec/utils/embedding.py +1 -1
  42. nextrec/utils/initializer.py +15 -23
  43. nextrec/utils/optimizer.py +10 -14
  44. {nextrec-0.1.4.dist-info → nextrec-0.1.7.dist-info}/METADATA +16 -7
  45. nextrec-0.1.7.dist-info/RECORD +51 -0
  46. nextrec-0.1.4.dist-info/RECORD +0 -51
  47. {nextrec-0.1.4.dist-info → nextrec-0.1.7.dist-info}/WHEEL +0 -0
  48. {nextrec-0.1.4.dist-info → nextrec-0.1.7.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/loggers.py CHANGED
@@ -14,37 +14,36 @@ import datetime
14
14
  import logging
15
15
 
16
16
  ANSI_CODES = {
17
- "black": "\033[30m",
18
- "red": "\033[31m",
19
- "green": "\033[32m",
20
- "yellow": "\033[33m",
21
- "blue": "\033[34m",
22
- "magenta": "\033[35m",
23
- "cyan": "\033[36m",
24
- "white": "\033[37m",
25
- "bright_black": "\033[90m",
26
- "bright_red": "\033[91m",
27
- "bright_green": "\033[92m",
28
- "bright_yellow": "\033[93m",
29
- "bright_blue": "\033[94m",
30
- "bright_magenta": "\033[95m",
31
- "bright_cyan": "\033[96m",
32
- "bright_white": "\033[97m",
17
+ 'black': '\033[30m',
18
+ 'red': '\033[31m',
19
+ 'green': '\033[32m',
20
+ 'yellow': '\033[33m',
21
+ 'blue': '\033[34m',
22
+ 'magenta': '\033[35m',
23
+ 'cyan': '\033[36m',
24
+ 'white': '\033[37m',
25
+ 'bright_black': '\033[90m',
26
+ 'bright_red': '\033[91m',
27
+ 'bright_green': '\033[92m',
28
+ 'bright_yellow': '\033[93m',
29
+ 'bright_blue': '\033[94m',
30
+ 'bright_magenta': '\033[95m',
31
+ 'bright_cyan': '\033[96m',
32
+ 'bright_white': '\033[97m',
33
33
  }
34
34
 
35
- ANSI_BOLD = "\033[1m"
36
- ANSI_RESET = "\033[0m"
37
- ANSI_ESCAPE_PATTERN = re.compile(r"\033\[[0-9;]*m")
35
+ ANSI_BOLD = '\033[1m'
36
+ ANSI_RESET = '\033[0m'
37
+ ANSI_ESCAPE_PATTERN = re.compile(r'\033\[[0-9;]*m')
38
38
 
39
39
  DEFAULT_LEVEL_COLORS = {
40
- "DEBUG": "cyan",
41
- "INFO": None,
42
- "WARNING": "yellow",
43
- "ERROR": "red",
44
- "CRITICAL": "bright_red",
40
+ 'DEBUG': 'cyan',
41
+ 'INFO': None,
42
+ 'WARNING': 'yellow',
43
+ 'ERROR': 'red',
44
+ 'CRITICAL': 'bright_red',
45
45
  }
46
46
 
47
-
48
47
  class AnsiFormatter(logging.Formatter):
49
48
  def __init__(
50
49
  self,
@@ -63,17 +62,16 @@ class AnsiFormatter(logging.Formatter):
63
62
  record_copy = copy.copy(record)
64
63
  formatted = super().format(record_copy)
65
64
 
66
- if self.auto_color_level and "\033[" not in formatted:
65
+ if self.auto_color_level and '\033[' not in formatted:
67
66
  color = self.level_colors.get(record.levelname)
68
67
  if color:
69
68
  formatted = colorize(formatted, color=color)
70
69
 
71
70
  if self.strip_ansi:
72
- return ANSI_ESCAPE_PATTERN.sub("", formatted)
71
+ return ANSI_ESCAPE_PATTERN.sub('', formatted)
73
72
 
74
73
  return formatted
75
74
 
76
-
77
75
  def colorize(text: str, color: str | None = None, bold: bool = False) -> str:
78
76
  """Apply ANSI color and bold formatting to the given text."""
79
77
  if not color and not bold:
@@ -91,47 +89,36 @@ def colorize(text: str, color: str | None = None, bold: bool = False) -> str:
91
89
 
92
90
  return result
93
91
 
94
-
95
92
  def setup_logger(log_dir: str | None = None):
96
93
  """Set up a logger that logs to both console and a file with ANSI formatting.
97
- Only console output has colors; file output is stripped of ANSI codes.
94
+ Only console output has colors; file output is stripped of ANSI codes.
98
95
  """
99
96
  if log_dir is None:
100
97
  project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
101
98
  log_dir = os.path.join(project_root, "..", "logs")
102
-
99
+
103
100
  os.makedirs(log_dir, exist_ok=True)
104
- log_file = os.path.join(
105
- log_dir, f"nextrec_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
106
- )
107
-
108
- console_format = "%(message)s"
109
- file_format = "%(asctime)s - %(levelname)s - %(message)s"
110
- date_format = "%H:%M:%S"
101
+ log_file = os.path.join(log_dir, f"nextrec_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
111
102
 
103
+ console_format = '%(message)s'
104
+ file_format = '%(asctime)s - %(levelname)s - %(message)s'
105
+ date_format = '%H:%M:%S'
106
+
112
107
  logger = logging.getLogger()
113
108
  logger.setLevel(logging.INFO)
114
109
 
115
110
  if logger.hasHandlers():
116
111
  logger.handlers.clear()
117
112
 
118
- file_handler = logging.FileHandler(log_file, encoding="utf-8")
113
+ file_handler = logging.FileHandler(log_file, encoding='utf-8')
119
114
  file_handler.setLevel(logging.INFO)
120
- file_handler.setFormatter(
121
- AnsiFormatter(file_format, datefmt=date_format, strip_ansi=True)
122
- )
115
+ file_handler.setFormatter(AnsiFormatter(file_format, datefmt=date_format, strip_ansi=True))
123
116
 
124
117
  console_handler = logging.StreamHandler(sys.stdout)
125
118
  console_handler.setLevel(logging.INFO)
126
- console_handler.setFormatter(
127
- AnsiFormatter(
128
- console_format,
129
- datefmt=date_format,
130
- auto_color_level=True,
131
- )
132
- )
119
+ console_handler.setFormatter(AnsiFormatter(console_format, datefmt=date_format, auto_color_level=True,))
133
120
 
134
121
  logger.addHandler(file_handler)
135
122
  logger.addHandler(console_handler)
136
-
123
+
137
124
  return logger