nextrec 0.4.20__py3-none-any.whl → 0.4.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/activation.py +9 -4
  3. nextrec/basic/callback.py +39 -87
  4. nextrec/basic/features.py +149 -28
  5. nextrec/basic/heads.py +3 -1
  6. nextrec/basic/layers.py +375 -94
  7. nextrec/basic/loggers.py +236 -39
  8. nextrec/basic/model.py +259 -326
  9. nextrec/basic/session.py +2 -2
  10. nextrec/basic/summary.py +323 -0
  11. nextrec/cli.py +3 -3
  12. nextrec/data/data_processing.py +45 -1
  13. nextrec/data/dataloader.py +2 -2
  14. nextrec/data/preprocessor.py +2 -2
  15. nextrec/loss/__init__.py +0 -4
  16. nextrec/loss/grad_norm.py +3 -3
  17. nextrec/models/multi_task/esmm.py +4 -6
  18. nextrec/models/multi_task/mmoe.py +4 -6
  19. nextrec/models/multi_task/ple.py +6 -8
  20. nextrec/models/multi_task/poso.py +5 -7
  21. nextrec/models/multi_task/share_bottom.py +6 -8
  22. nextrec/models/ranking/afm.py +4 -6
  23. nextrec/models/ranking/autoint.py +4 -6
  24. nextrec/models/ranking/dcn.py +8 -7
  25. nextrec/models/ranking/dcn_v2.py +4 -6
  26. nextrec/models/ranking/deepfm.py +5 -7
  27. nextrec/models/ranking/dien.py +8 -7
  28. nextrec/models/ranking/din.py +8 -7
  29. nextrec/models/ranking/eulernet.py +5 -7
  30. nextrec/models/ranking/ffm.py +5 -7
  31. nextrec/models/ranking/fibinet.py +4 -6
  32. nextrec/models/ranking/fm.py +4 -6
  33. nextrec/models/ranking/lr.py +4 -6
  34. nextrec/models/ranking/masknet.py +8 -9
  35. nextrec/models/ranking/pnn.py +4 -6
  36. nextrec/models/ranking/widedeep.py +5 -7
  37. nextrec/models/ranking/xdeepfm.py +8 -7
  38. nextrec/models/retrieval/dssm.py +4 -10
  39. nextrec/models/retrieval/dssm_v2.py +0 -6
  40. nextrec/models/retrieval/mind.py +4 -10
  41. nextrec/models/retrieval/sdm.py +4 -10
  42. nextrec/models/retrieval/youtube_dnn.py +4 -10
  43. nextrec/models/sequential/hstu.py +1 -3
  44. nextrec/utils/__init__.py +17 -15
  45. nextrec/utils/config.py +15 -5
  46. nextrec/utils/console.py +2 -2
  47. nextrec/utils/feature.py +2 -2
  48. nextrec/{loss/loss_utils.py → utils/loss.py} +21 -36
  49. nextrec/utils/torch_utils.py +57 -112
  50. nextrec/utils/types.py +63 -0
  51. {nextrec-0.4.20.dist-info → nextrec-0.4.22.dist-info}/METADATA +8 -6
  52. nextrec-0.4.22.dist-info/RECORD +81 -0
  53. nextrec-0.4.20.dist-info/RECORD +0 -79
  54. {nextrec-0.4.20.dist-info → nextrec-0.4.22.dist-info}/WHEEL +0 -0
  55. {nextrec-0.4.20.dist-info → nextrec-0.4.22.dist-info}/entry_points.txt +0 -0
  56. {nextrec-0.4.20.dist-info → nextrec-0.4.22.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/loggers.py CHANGED
@@ -2,7 +2,7 @@
2
2
  NextRec Basic Loggers
3
3
 
4
4
  Date: create on 27/10/2025
5
- Checkpoint: edit on 24/12/2025
5
+ Checkpoint: edit on 27/12/2025
6
6
  Author: Yang Zhou, zyaztec@gmail.com
7
7
  """
8
8
 
@@ -13,7 +13,7 @@ import numbers
13
13
  import os
14
14
  import re
15
15
  import sys
16
- from typing import Any, Mapping
16
+ from typing import Any
17
17
 
18
18
  from nextrec.basic.session import Session, create_session
19
19
 
@@ -101,6 +101,7 @@ def format_kv(label: str, value: Any, width: int = 34, indent: int = 0) -> str:
101
101
  def setup_logger(session_id: str | os.PathLike | None = None):
102
102
  """Set up a logger that logs to both console and a file with ANSI formatting.
103
103
  Only console output has colors; file output is stripped of ANSI codes.
104
+
104
105
  Logs are stored under ``log/<experiment_id>/logs`` by default. A stable
105
106
  log file is used per experiment so multiple components (e.g. data
106
107
  processor and model training) append to the same file instead of creating
@@ -144,45 +145,27 @@ def setup_logger(session_id: str | os.PathLike | None = None):
144
145
  return logger
145
146
 
146
147
 
147
- class TrainingLogger:
148
+ class MetricsLoggerBackend:
149
+ def log_payload(self, payload: dict[str, float]) -> None:
150
+ raise NotImplementedError
151
+
152
+ def close(self) -> None:
153
+ return None
154
+
155
+
156
+ class BasicLogger:
148
157
  def __init__(
149
158
  self,
150
159
  session: Session,
151
- use_tensorboard: bool,
152
160
  log_name: str = "training_metrics.jsonl",
161
+ backends: list[MetricsLoggerBackend] | None = None,
153
162
  ) -> None:
154
163
  self.session = session
155
- self.use_tensorboard = use_tensorboard
156
164
  self.log_path = session.metrics_dir / log_name
157
165
  self.log_path.parent.mkdir(parents=True, exist_ok=True)
166
+ self.backends = backends or []
158
167
 
159
- self.tb_writer = None
160
- self.tb_dir = None
161
-
162
- if self.use_tensorboard:
163
- self._init_tensorboard()
164
-
165
- def _init_tensorboard(self) -> None:
166
- try:
167
- from torch.utils.tensorboard import SummaryWriter # type: ignore
168
- except ImportError:
169
- logging.warning(
170
- "[TrainingLogger] tensorboard not installed, disable tensorboard logging."
171
- )
172
- self.use_tensorboard = False
173
- return
174
- tb_dir = self.session.logs_dir / "tensorboard"
175
- tb_dir.mkdir(parents=True, exist_ok=True)
176
- self.tb_dir = tb_dir
177
- self.tb_writer = SummaryWriter(log_dir=str(tb_dir))
178
-
179
- @property
180
- def tensorboard_logdir(self):
181
- return self.tb_dir
182
-
183
- def format_metrics(
184
- self, metrics: Mapping[str, Any], split: str
185
- ) -> dict[str, float]:
168
+ def format_metrics(self, metrics: dict[str, Any], split: str) -> dict[str, float]:
186
169
  formatted: dict[str, float] = {}
187
170
  for key, value in metrics.items():
188
171
  if isinstance(value, numbers.Real):
@@ -195,23 +178,237 @@ class TrainingLogger:
195
178
  return formatted
196
179
 
197
180
  def log_metrics(
198
- self, metrics: Mapping[str, Any], step: int, split: str = "train"
181
+ self, metrics: dict[str, Any], step: int, split: str = "train"
199
182
  ) -> None:
200
183
  payload = self.format_metrics(metrics, split)
201
184
  payload["step"] = int(step)
202
185
  with self.log_path.open("a", encoding="utf-8") as f:
203
186
  f.write(json.dumps(payload, ensure_ascii=False) + "\n")
187
+ for backend in self.backends:
188
+ backend.log_payload(payload)
189
+
190
+ def close(self) -> None:
191
+ for backend in self.backends:
192
+ backend.close()
193
+
194
+
195
+ class TensorBoardLogger(MetricsLoggerBackend):
196
+ def __init__(
197
+ self,
198
+ session: Session,
199
+ enabled: bool = True,
200
+ log_dir_name: str = "tensorboard",
201
+ ) -> None:
202
+ self.enabled = enabled
203
+ self.writer = None
204
+ self.log_dir = None
205
+ if self.enabled:
206
+ self._init_writer(session, log_dir_name)
204
207
 
205
- if not self.tb_writer:
208
+ def _init_writer(self, session: Session, log_dir_name: str) -> None:
209
+ try:
210
+ from torch.utils.tensorboard import SummaryWriter # type: ignore
211
+ except ImportError:
212
+ logging.warning(
213
+ "[TrainingLogger] tensorboard not installed, disable tensorboard logging."
214
+ )
215
+ self.enabled = False
216
+ return
217
+ log_dir = session.logs_dir / log_dir_name
218
+ log_dir.mkdir(parents=True, exist_ok=True)
219
+ self.log_dir = log_dir
220
+ self.writer = SummaryWriter(log_dir=str(log_dir))
221
+
222
+ def log_payload(self, payload: dict[str, float]) -> None:
223
+ if not self.writer:
206
224
  return
207
225
  step = int(payload.get("step", 0))
208
226
  for key, value in payload.items():
209
227
  if key == "step":
210
228
  continue
211
- self.tb_writer.add_scalar(key, value, global_step=step)
229
+ self.writer.add_scalar(key, value, global_step=step)
230
+
231
+ def close(self) -> None:
232
+ if self.writer:
233
+ self.writer.flush()
234
+ self.writer.close()
235
+ self.writer = None
236
+
237
+
238
+ class WandbLogger(MetricsLoggerBackend):
239
+ def __init__(
240
+ self,
241
+ session: Session,
242
+ enabled: bool = True,
243
+ project: str | None = None,
244
+ run_name: str | None = None,
245
+ init_run: bool = True,
246
+ **init_kwargs: Any,
247
+ ) -> None:
248
+ self.enabled = enabled
249
+ self.wandb = None
250
+ if not self.enabled:
251
+ return
252
+ try:
253
+ import wandb # type: ignore
254
+ except ImportError:
255
+ logging.warning("[WandbLogger] wandb not installed, disable wandb logging.")
256
+ self.enabled = False
257
+ return
258
+ self.wandb = wandb
259
+ if init_run and getattr(wandb, "run", None) is None:
260
+ kwargs = dict(init_kwargs)
261
+ if project is not None:
262
+ kwargs.setdefault("project", project)
263
+ if run_name is None:
264
+ run_name = session.experiment_id
265
+ if run_name is not None:
266
+ kwargs.setdefault("name", run_name)
267
+ try:
268
+ wandb.init(**kwargs)
269
+ except TypeError:
270
+ wandb.init()
271
+
272
+ def log_payload(self, payload: dict[str, float]) -> None:
273
+ if not self.enabled or self.wandb is None:
274
+ return
275
+ step = int(payload.get("step", 0))
276
+ log_payload = {k: v for k, v in payload.items() if k != "step"}
277
+ if not log_payload:
278
+ return
279
+ try:
280
+ self.wandb.log(log_payload, step=step)
281
+ except TypeError:
282
+ self.wandb.log(log_payload)
283
+
284
+
285
+ class SwanLabLogger(MetricsLoggerBackend):
286
+ def __init__(
287
+ self,
288
+ session: Session,
289
+ enabled: bool = True,
290
+ project: str | None = None,
291
+ run_name: str | None = None,
292
+ init_run: bool = True,
293
+ **init_kwargs: Any,
294
+ ) -> None:
295
+ self.enabled = enabled
296
+ self.swanlab = None
297
+ self._warned_missing_log = False
298
+ if not self.enabled:
299
+ return
300
+ try:
301
+ import swanlab # type: ignore
302
+ except ImportError:
303
+ logging.warning(
304
+ "[SwanLabLogger] swanlab not installed, disable swanlab logging."
305
+ )
306
+ self.enabled = False
307
+ return
308
+ self.swanlab = swanlab
309
+ if init_run and hasattr(swanlab, "init"):
310
+ kwargs = dict(init_kwargs)
311
+ kwargs.setdefault("logdir", str(session.logs_dir) + "/swanlog")
312
+ if project is not None:
313
+ kwargs.setdefault("project", project)
314
+ if run_name is None:
315
+ run_name = session.experiment_id
316
+ if run_name is not None:
317
+ kwargs.setdefault("name", run_name)
318
+ try:
319
+ swanlab.init(**kwargs)
320
+ except TypeError:
321
+ swanlab.init()
322
+
323
+ def log_payload(self, payload: dict[str, float]) -> None:
324
+ if not self.enabled or self.swanlab is None:
325
+ return
326
+ log_fn = getattr(self.swanlab, "log", None)
327
+ if log_fn is None:
328
+ if not self._warned_missing_log:
329
+ logging.warning(
330
+ "[SwanLabLogger] swanlab.log not found, disable swanlab logging."
331
+ )
332
+ self._warned_missing_log = True
333
+ return
334
+ step = int(payload.get("step", 0))
335
+ log_payload = {k: v for k, v in payload.items() if k != "step"}
336
+ if not log_payload:
337
+ return
338
+ try:
339
+ log_fn(log_payload, step=step)
340
+ except TypeError:
341
+ log_fn(log_payload)
342
+
343
+
344
+ class TrainingLogger(BasicLogger):
345
+ def __init__(
346
+ self,
347
+ session: Session,
348
+ use_tensorboard: bool,
349
+ log_name: str = "training_metrics.jsonl",
350
+ use_wandb: bool = False,
351
+ use_swanlab: bool = False,
352
+ config: dict[str, Any] = {},
353
+ wandb_kwargs: dict[str, Any] | None = None,
354
+ swanlab_kwargs: dict[str, Any] | None = None,
355
+ ):
356
+ self.session = session
357
+ self.use_tensorboard = use_tensorboard
358
+ self.tensorboard_logger = TensorBoardLogger(
359
+ session=session, enabled=use_tensorboard
360
+ )
361
+ self.use_tensorboard = self.tensorboard_logger.enabled
362
+ self.tb_writer = self.tensorboard_logger.writer
363
+ self.tb_dir = self.tensorboard_logger.log_dir
364
+
365
+ backends = []
366
+ if self.tensorboard_logger.enabled:
367
+ backends.append(self.tensorboard_logger)
368
+
369
+ wandb_kwargs = dict(wandb_kwargs or {})
370
+ wandb_kwargs.setdefault("config", {})
371
+ wandb_kwargs["config"].update(config)
372
+
373
+ swanlab_kwargs = dict(swanlab_kwargs or {})
374
+ swanlab_kwargs.setdefault("config", {})
375
+ swanlab_kwargs["config"].update(config)
376
+
377
+ self.wandb_logger = None
378
+ if use_wandb:
379
+ self.wandb_logger = WandbLogger(
380
+ session=session, enabled=use_wandb, **wandb_kwargs
381
+ )
382
+ if self.wandb_logger.enabled:
383
+ backends.append(self.wandb_logger)
384
+
385
+ self.swanlab_logger = None
386
+ if use_swanlab:
387
+ self.swanlab_logger = SwanLabLogger(
388
+ session=session, enabled=use_swanlab, **swanlab_kwargs
389
+ )
390
+ if self.swanlab_logger.enabled:
391
+ backends.append(self.swanlab_logger)
392
+
393
+ super().__init__(session=session, log_name=log_name, backends=backends)
394
+
395
+ def init_tensorboard(self) -> None:
396
+ if self.tensorboard_logger and self.tensorboard_logger.enabled:
397
+ return
398
+ self.tensorboard_logger = TensorBoardLogger(session=self.session, enabled=True)
399
+ self.use_tensorboard = self.tensorboard_logger.enabled
400
+ self.tb_writer = self.tensorboard_logger.writer
401
+ self.tb_dir = self.tensorboard_logger.log_dir
402
+ if (
403
+ self.tensorboard_logger.enabled
404
+ and self.tensorboard_logger not in self.backends
405
+ ):
406
+ self.backends.append(self.tensorboard_logger)
407
+
408
+ @property
409
+ def tensorboard_logdir(self):
410
+ return self.tb_dir
212
411
 
213
412
  def close(self) -> None:
214
- if self.tb_writer:
215
- self.tb_writer.flush()
216
- self.tb_writer.close()
217
- self.tb_writer = None
413
+ super().close()
414
+ self.tb_writer = None