nextrec-0.4.22-py3-none-any.whl → nextrec-0.4.23-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/metrics.py +1 -2
  3. nextrec/basic/model.py +68 -73
  4. nextrec/basic/summary.py +36 -2
  5. nextrec/data/preprocessor.py +137 -5
  6. nextrec/loss/listwise.py +19 -6
  7. nextrec/loss/pairwise.py +6 -4
  8. nextrec/loss/pointwise.py +8 -6
  9. nextrec/models/multi_task/esmm.py +3 -26
  10. nextrec/models/multi_task/mmoe.py +2 -24
  11. nextrec/models/multi_task/ple.py +13 -35
  12. nextrec/models/multi_task/poso.py +4 -28
  13. nextrec/models/multi_task/share_bottom.py +1 -24
  14. nextrec/models/ranking/afm.py +3 -27
  15. nextrec/models/ranking/autoint.py +5 -38
  16. nextrec/models/ranking/dcn.py +1 -26
  17. nextrec/models/ranking/dcn_v2.py +5 -33
  18. nextrec/models/ranking/deepfm.py +2 -29
  19. nextrec/models/ranking/dien.py +2 -28
  20. nextrec/models/ranking/din.py +2 -27
  21. nextrec/models/ranking/eulernet.py +3 -30
  22. nextrec/models/ranking/ffm.py +0 -26
  23. nextrec/models/ranking/fibinet.py +8 -32
  24. nextrec/models/ranking/fm.py +0 -29
  25. nextrec/models/ranking/lr.py +0 -30
  26. nextrec/models/ranking/masknet.py +4 -30
  27. nextrec/models/ranking/pnn.py +4 -28
  28. nextrec/models/ranking/widedeep.py +0 -32
  29. nextrec/models/ranking/xdeepfm.py +0 -30
  30. nextrec/models/retrieval/dssm.py +0 -24
  31. nextrec/models/retrieval/dssm_v2.py +0 -24
  32. nextrec/models/retrieval/mind.py +0 -20
  33. nextrec/models/retrieval/sdm.py +0 -20
  34. nextrec/models/retrieval/youtube_dnn.py +0 -21
  35. nextrec/models/sequential/hstu.py +0 -18
  36. nextrec/utils/model.py +79 -1
  37. nextrec/utils/types.py +35 -0
  38. {nextrec-0.4.22.dist-info → nextrec-0.4.23.dist-info}/METADATA +7 -5
  39. nextrec-0.4.23.dist-info/RECORD +81 -0
  40. nextrec-0.4.22.dist-info/RECORD +0 -81
  41. {nextrec-0.4.22.dist-info → nextrec-0.4.23.dist-info}/WHEEL +0 -0
  42. {nextrec-0.4.22.dist-info → nextrec-0.4.23.dist-info}/entry_points.txt +0 -0
  43. {nextrec-0.4.22.dist-info → nextrec-0.4.23.dist-info}/licenses/LICENSE +0 -0
nextrec/__version__.py CHANGED
@@ -1 +1 @@
- __version__ = "0.4.22"
+ __version__ = "0.4.23"
nextrec/basic/metrics.py CHANGED
@@ -2,7 +2,7 @@
  Metrics computation and configuration for model evaluation.

  Date: create on 27/10/2025
- Checkpoint: edit on 20/12/2025
+ Checkpoint: edit on 29/12/2025
  Author: Yang Zhou,zyaztec@gmail.com
  """

@@ -39,7 +39,6 @@ REGRESSION_METRICS = {"mse", "mae", "rmse", "r2", "mape", "msle"}
  TASK_DEFAULT_METRICS = {
      "binary": ["auc", "gauc", "ks", "logloss", "accuracy", "precision", "recall", "f1"],
      "regression": ["mse", "mae", "rmse", "r2", "mape"],
-     "multilabel": ["auc", "hamming_loss", "subset_accuracy", "micro_f1", "macro_f1"],
      "matching": ["auc", "gauc", "precision@10", "hitrate@10", "map@10", "cosine"]
      + [f"recall@{k}" for k in (5, 10, 20)]
      + [f"ndcg@{k}" for k in (5, 10, 20)]
nextrec/basic/model.py CHANGED
@@ -2,13 +2,14 @@
  Base Model & Base Match Model Class

  Date: create on 27/10/2025
- Checkpoint: edit on 28/12/2025
+ Checkpoint: edit on 29/12/2025
  Author: Yang Zhou,zyaztec@gmail.com
  """

  import getpass
  import logging
  import os
+ import sys
  import pickle
  import socket
  from pathlib import Path
@@ -16,6 +17,16 @@ from typing import Any, Literal

  import numpy as np
  import pandas as pd
+
+ try:
+     import swanlab  # type: ignore
+ except ModuleNotFoundError:
+     swanlab = None
+ try:
+     import wandb  # type: ignore
+ except ModuleNotFoundError:
+     wandb = None
+
  import torch
  import torch.distributed as dist
  import torch.nn as nn
@@ -74,13 +85,19 @@ from nextrec.utils.torch_utils import (
      to_tensor,
  )
  from nextrec.utils.config import safe_value
- from nextrec.utils.model import compute_ranking_loss
+ from nextrec.utils.model import (
+     compute_ranking_loss,
+     get_loss_list,
+     resolve_loss_weights,
+     get_training_modes,
+ )
  from nextrec.utils.types import (
      LossName,
      OptimizerName,
      SchedulerName,
      TrainingModeName,
      TaskTypeName,
+     MetricsName,
  )
@@ -90,7 +107,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
          raise NotImplementedError

      @property
-     def default_task(self) -> str | list[str]:
+     def default_task(self) -> TaskTypeName | list[TaskTypeName]:
          raise NotImplementedError

      def __init__(
@@ -139,6 +156,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
              world_size: Number of processes (defaults to env WORLD_SIZE).
              local_rank: Local rank for selecting CUDA device (defaults to env LOCAL_RANK).
              ddp_find_unused_parameters: Default False, set it True only when exist unused parameters in ddp model, in most cases should be False.
+
+         Note:
+             Optimizer, scheduler, and loss are configured via compile().
          """
          super(BaseModel, self).__init__()
@@ -171,24 +191,12 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
              dense_features, sparse_features, sequence_features, target, id_columns
          )

-         self.task = self.default_task if task is None else task
+         self.task = task or self.default_task
          self.nums_task = len(self.task) if isinstance(self.task, list) else 1
-         if isinstance(training_mode, list):
-             training_modes = list(training_mode)
-             if len(training_modes) != self.nums_task:
-                 raise ValueError(
-                     "[BaseModel-init Error] training_mode list length must match number of tasks."
-                 )
-         else:
-             training_modes = [training_mode] * self.nums_task
-         if any(
-             mode not in {"pointwise", "pairwise", "listwise"} for mode in training_modes
-         ):
-             raise ValueError(
-                 "[BaseModel-init Error] training_mode must be one of {'pointwise', 'pairwise', 'listwise'}."
-             )
-         self.training_modes = training_modes
-         self.training_mode = training_modes if self.nums_task > 1 else training_modes[0]
+         self.training_modes = get_training_modes(training_mode, self.nums_task)
+         self.training_mode = (
+             self.training_modes if self.nums_task > 1 else self.training_modes[0]
+         )

          self.embedding_l1_reg = embedding_l1_reg
          self.dense_l1_reg = dense_l1_reg
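
The training-mode validation that used to live inline here moved to nextrec.utils.model.get_training_modes. A minimal sketch of that helper, inferred from the deleted lines above (the real implementation in nextrec/utils/model.py may differ in details):

    def get_training_modes(training_mode, nums_task):
        # Broadcast a single mode to every task, or validate a per-task list.
        if isinstance(training_mode, list):
            modes = list(training_mode)
            if len(modes) != nums_task:
                raise ValueError("training_mode list length must match number of tasks.")
        else:
            modes = [training_mode] * nums_task
        if any(m not in {"pointwise", "pairwise", "listwise"} for m in modes):
            raise ValueError("training_mode must be one of {'pointwise', 'pairwise', 'listwise'}.")
        return modes
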
@@ -196,8 +204,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
          self.dense_l2_reg = dense_l2_reg
          self.regularization_weights = []
          self.embedding_params = []
-         self.loss_weight = None
+
          self.ignore_label = None
+         self.compiled = False

          self.max_gradient_norm = 1.0
          self.logger_initialized = False
@@ -431,28 +440,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
              "pairwise": "bpr",
              "listwise": "listnet",
          }
-         effective_loss = loss
-         if effective_loss is None:
-             loss_list = [default_losses[mode] for mode in self.training_modes]
-         elif isinstance(effective_loss, list):
-             if not effective_loss:
-                 loss_list = [default_losses[mode] for mode in self.training_modes]
-             else:
-                 if len(effective_loss) != self.nums_task:
-                     raise ValueError(
-                         f"[BaseModel-compile Error] Number of loss functions ({len(effective_loss)}) must match number of tasks ({self.nums_task})."
-                     )
-                 loss_list = list(effective_loss)
-         else:
-             loss_list = [effective_loss] * self.nums_task
-
-         for idx, mode in enumerate(self.training_modes):
-             if isinstance(loss_list[idx], str) and loss_list[idx] in {
-                 "bce",
-                 "binary_crossentropy",
-             }:
-                 if mode in {"pairwise", "listwise"}:
-                     loss_list[idx] = default_losses[mode]
+         loss_list = get_loss_list(
+             loss, self.training_modes, self.nums_task, default_losses
+         )
          self.loss_params = loss_params or {}
          optimizer_params = optimizer_params or {}
          self.optimizer_name = (
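
Loss selection likewise moved into nextrec.utils.model.get_loss_list. A plausible sketch reconstructed from the deleted block (names and error messages are illustrative, not the actual helper body):

    def get_loss_list(loss, training_modes, nums_task, default_losses):
        # Fall back to per-mode defaults, broadcast a scalar, or validate a list.
        if loss is None or (isinstance(loss, list) and not loss):
            loss_list = [default_losses[mode] for mode in training_modes]
        elif isinstance(loss, list):
            if len(loss) != nums_task:
                raise ValueError("Number of loss functions must match number of tasks.")
            loss_list = list(loss)
        else:
            loss_list = [loss] * nums_task
        # BCE only makes sense pointwise; swap in the mode default otherwise.
        for idx, mode in enumerate(training_modes):
            if loss_list[idx] in ("bce", "binary_crossentropy") and mode in ("pairwise", "listwise"):
                loss_list[idx] = default_losses[mode]
        return loss_list
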
@@ -516,30 +506,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
                  nums_task=self.nums_task, device=self.device, **grad_norm_params
              )
              self.loss_weights = None
-         elif loss_weights is None:
-             self.loss_weights = None
-         elif self.nums_task == 1:
-             if isinstance(loss_weights, (list, tuple)):
-                 if len(loss_weights) != 1:
-                     raise ValueError(
-                         "[BaseModel-compile Error] loss_weights list must have exactly one element for single-task setup."
-                     )
-                 loss_weights = loss_weights[0]
-             self.loss_weights = [float(loss_weights)]  # type: ignore
          else:
-             if isinstance(loss_weights, (int, float)):
-                 weights = [float(loss_weights)] * self.nums_task
-             elif isinstance(loss_weights, (list, tuple)):
-                 weights = [float(w) for w in loss_weights]
-                 if len(weights) != self.nums_task:
-                     raise ValueError(
-                         f"[BaseModel-compile Error] Number of loss_weights ({len(weights)}) must match number of tasks ({self.nums_task})."
-                     )
-             else:
-                 raise TypeError(
-                     f"[BaseModel-compile Error] loss_weights must be int, float, list or tuple, got {type(loss_weights)}"
-                 )
-             self.loss_weights = weights
+             self.loss_weights = resolve_loss_weights(loss_weights, self.nums_task)
+         self.compiled = True

      def compute_loss(self, y_pred, y_true):
          if y_true is None:
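
And resolve_loss_weights absorbs the weight normalization. A sketch matching the removed semantics (assumption: the helper returns None for unweighted training):

    def resolve_loss_weights(loss_weights, nums_task):
        if loss_weights is None:
            return None  # unweighted sum across tasks
        if isinstance(loss_weights, (int, float)):
            return [float(loss_weights)] * nums_task
        if isinstance(loss_weights, (list, tuple)):
            weights = [float(w) for w in loss_weights]
            if len(weights) != nums_task:
                raise ValueError("Number of loss_weights must match number of tasks.")
            return weights
        raise TypeError(f"loss_weights must be int, float, list or tuple, got {type(loss_weights)}")
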
@@ -602,9 +571,6 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
          for i, (start, end) in enumerate(slices):  # type: ignore
              y_pred_i = y_pred[:, start:end]
              y_true_i = y_true[:, start:end]
-             total_count = y_true_i.shape[0]
-             # valid_count = None
-
              # mask ignored labels
              if self.ignore_label is not None:
                  valid_mask = y_true_i != self.ignore_label
@@ -613,11 +579,8 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
                  if not torch.any(valid_mask):
                      task_losses.append(y_pred_i.sum() * 0.0)
                      continue
-                 # valid_count = valid_mask.sum().to(dtype=y_true_i.dtype)
                  y_pred_i = y_pred_i[valid_mask]
                  y_true_i = y_true_i[valid_mask]
-             # else:
-             #     valid_count = y_true_i.new_tensor(float(total_count))

              mode = self.training_modes[i]

@@ -691,7 +654,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
          train_data=None,
          valid_data=None,
          metrics: (
-             list[str] | dict[str, list[str]] | None
+             list[MetricsName] | dict[str, list[MetricsName]] | None
          ) = None,  # ['auc', 'logloss'] or {'target1': ['auc', 'logloss'], 'target2': ['mse']}
          epochs: int = 1,
          shuffle: bool = True,
@@ -705,6 +668,8 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
          use_tensorboard: bool = True,
          use_wandb: bool = False,
          use_swanlab: bool = False,
+         wandb_api: str | None = None,
+         swanlab_api: str | None = None,
          wandb_kwargs: dict | None = None,
          swanlab_kwargs: dict | None = None,
          auto_ddp_sampler: bool = True,
@@ -734,6 +699,8 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
              use_tensorboard: Enable tensorboard logging.
              use_wandb: Enable Weights & Biases logging.
              use_swanlab: Enable SwanLab logging.
+             wandb_api: W&B API key for non-tty login.
+             swanlab_api: SwanLab API key for non-tty login.
              wandb_kwargs: Optional kwargs for wandb.init(...).
              swanlab_kwargs: Optional kwargs for swanlab.init(...).
              auto_ddp_sampler: Attach DistributedSampler automatically when distributed, set False to when data is already sharded per rank.
@@ -751,6 +718,16 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
          )
          self.to(self.device)

+         if not self.compiled:
+             self.compile(
+                 optimizer="adam",
+                 optimizer_params={},
+                 scheduler=None,
+                 scheduler_params={},
+                 loss=None,
+                 loss_params={},
+             )
+
          if (
              self.distributed
              and dist.is_available()
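
Together with the new compiled flag, this makes compile() optional: a model fitted without an explicit compile() call now self-compiles with Adam and the default loss for each training mode. A hypothetical single-task run (model choice and constructor arguments are placeholders):

    model = DeepFM(dense_features=dense, sparse_features=sparse, target=["label"])
    # No compile() call: fit() falls back to optimizer="adam" and the
    # per-mode default loss (e.g. "bce" for pointwise training).
    model.fit(train_data=train_df, valid_data=valid_df, epochs=3)
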
@@ -825,6 +802,24 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
          }
          training_config: dict = safe_value(training_config)  # type: ignore

+         if self.is_main_process:
+             is_tty = sys.stdin.isatty() and sys.stdout.isatty()
+             if not is_tty:
+                 if use_wandb and wandb_api:
+                     if wandb is None:
+                         logging.warning(
+                             "[BaseModel-fit] wandb not installed, skip wandb login."
+                         )
+                     else:
+                         wandb.login(key=wandb_api)
+                 if use_swanlab and swanlab_api:
+                     if swanlab is None:
+                         logging.warning(
+                             "[BaseModel-fit] swanlab not installed, skip swanlab login."
+                         )
+                     else:
+                         swanlab.login(api_key=swanlab_api)
+
          self.training_logger = (
              TrainingLogger(
                  session=self.session,
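
The new wandb_api / swanlab_api parameters exist for non-interactive jobs, where wandb.login() cannot prompt for a key on stdin; on the main process the key is handed straight to the client library. A hypothetical call (values are placeholders):

    import os

    model.fit(
        train_data=train_df,
        valid_data=valid_df,
        epochs=5,
        use_wandb=True,
        wandb_api=os.environ["WANDB_API_KEY"],  # consumed only when no TTY is attached
        wandb_kwargs={"project": "nextrec-demo"},
    )
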
nextrec/basic/summary.py CHANGED
@@ -1,5 +1,9 @@
  """
  Summary utilities for BaseModel.
+
+ Date: create on 03/12/2025
+ Checkpoint: edit on 29/12/2025
+ Author: Yang Zhou,zyaztec@gmail.com
  """

  from __future__ import annotations
@@ -12,9 +16,39 @@ from torch.utils.data import DataLoader

  from nextrec.basic.loggers import colorize, format_kv
  from nextrec.data.data_processing import extract_label_arrays, get_data_length
+ from nextrec.utils.types import TaskTypeName


  class SummarySet:
+     model_name: str
+     dense_features: list[Any]
+     sparse_features: list[Any]
+     sequence_features: list[Any]
+     task: TaskTypeName | list[TaskTypeName]
+     target_columns: list[str]
+     nums_task: int
+     metrics: Any
+     device: Any
+     optimizer_name: str
+     optimizer_params: dict[str, Any]
+     scheduler_name: str | None
+     scheduler_params: dict[str, Any]
+     loss_config: Any
+     loss_weights: Any
+     grad_norm: Any
+     embedding_l1_reg: float
+     embedding_l2_reg: float
+     dense_l1_reg: float
+     dense_l2_reg: float
+     early_stop_patience: int
+     max_gradient_norm: float | None
+     metrics_sample_limit: int | None
+     session_id: str | None
+     features_config_path: str
+     checkpoint_path: str
+     train_data_summary: dict[str, Any] | None
+     valid_data_summary: dict[str, Any] | None
+
      def build_data_summary(
          self, data: Any, data_loader: DataLoader | None, sample_key: str
      ):
@@ -305,7 +339,7 @@
              lines = details.get("lines", [])
              logger.info(f"{target_name}:")
              for label, value in lines:
-                 logger.info(format_kv(label, value))
+                 logger.info(f" {format_kv(label, value)}")

          if self.valid_data_summary:
              if self.train_data_summary:
@@ -320,4 +354,4 @@
              lines = details.get("lines", [])
              logger.info(f"{target_name}:")
              for label, value in lines:
-                 logger.info(format_kv(label, value))
+                 logger.info(f" {format_kv(label, value)}")
nextrec/data/preprocessor.py CHANGED
@@ -2,7 +2,7 @@
  DataProcessor for data preprocessing including numeric, sparse, sequence features and target processing.

  Date: create on 13/11/2025
- Checkpoint: edit on 24/12/2025
+ Checkpoint: edit on 29/12/2025
  Author: Yang Zhou, zyaztec@gmail.com
  """

@@ -79,6 +79,14 @@
          ] = "standard",
          fill_na: Optional[float] = None,
      ):
+         """Add a numeric feature configuration.
+
+         Args:
+             name (str): Feature name.
+             scaler (Optional[Literal["standard", "minmax", "robust", "maxabs", "log", "none"]], optional): Scaler type. Defaults to "standard".
+             fill_na (Optional[float], optional): Fill value for missing entries. Defaults to None.
+         """
+
          self.numeric_features[name] = {"scaler": scaler, "fill_na": fill_na}

      def add_sparse_feature(
@@ -88,6 +96,14 @@
          hash_size: Optional[int] = None,
          fill_na: str = "<UNK>",
      ):
+         """Add a sparse feature configuration.
+
+         Args:
+             name (str): Feature name.
+             encode_method (Literal["hash", "label"], optional): Encoding method, including "hash encoding" and "label encoding". Defaults to "label".
+             hash_size (Optional[int], optional): Hash size for hash encoding. Required if encode_method is "hash".
+             fill_na (str, optional): Fill value for missing entries. Defaults to "<UNK>".
+         """
          if encode_method == "hash" and hash_size is None:
              raise ValueError(
                  "[Data Processor Error] hash_size must be specified when encode_method='hash'"
@@ -101,7 +117,7 @@
      def add_sequence_feature(
          self,
          name: str,
-         encode_method: Literal["hash", "label"] = "label",
+         encode_method: Literal["hash", "label"] = "hash",
          hash_size: Optional[int] = None,
          max_len: Optional[int] = 50,
          pad_value: int = 0,
@@ -110,6 +126,17 @@
          ] = "pre",  # pre: keep last max_len items, post: keep first max_len items
          separator: str = ",",
      ):
+         """Add a sequence feature configuration.
+
+         Args:
+             name (str): Feature name.
+             encode_method (Literal["hash", "label"], optional): Encoding method, including "hash encoding" and "label encoding". Defaults to "hash".
+             hash_size (Optional[int], optional): Hash size for hash encoding. Required if encode_method is "hash".
+             max_len (Optional[int], optional): Maximum sequence length. Defaults to 50.
+             pad_value (int, optional): Padding value for sequences shorter than max_len. Defaults to 0.
+             truncate (Literal["pre", "post"], optional): Truncation strategy for sequences longer than max_len, including "pre" (keep last max_len items) and "post" (keep first max_len items). Defaults to "pre".
+             separator (str, optional): Separator for string sequences. Defaults to ",".
+         """
          if encode_method == "hash" and hash_size is None:
              raise ValueError(
                  "[Data Processor Error] hash_size must be specified when encode_method='hash'"
@@ -131,6 +158,14 @@
              Dict[str, int]
          ] = None,  # example: {'click': 1, 'no_click': 0}
      ):
+         """Add a target configuration.
+
+         Args:
+             name (str): Target name.
+             target_type (Literal["binary", "regression"], optional): Target type. Defaults to "binary".
+             label_map (Optional[Dict[str, int]], optional): Label mapping for binary targets. Defaults to None.
+         """
+
          self.target_features[name] = {
              "target_type": target_type,
              "label_map": label_map,
@@ -392,7 +427,15 @@
          )

      def load_dataframe_from_path(self, path: str) -> pd.DataFrame:
-         """Load all data from a file or directory path into a single DataFrame."""
+         """
+         Load all data from a file or directory path into a single DataFrame.
+
+         Args:
+             path (str): File or directory path.
+
+         Returns:
+             pd.DataFrame: Loaded DataFrame.
+         """
          file_paths, file_type = resolve_file_paths(path)
          frames = load_dataframes(file_paths, file_type)
          return pd.concat(frames, ignore_index=True) if len(frames) > 1 else frames[0]
@@ -411,7 +454,16 @@
              return [str(value)]

      def fit_from_path(self, path: str, chunk_size: int) -> "DataProcessor":
-         """Fit processor statistics by streaming files to reduce memory usage."""
+         """
+         Fit processor statistics by streaming files to reduce memory usage.
+
+         Args:
+             path (str): File or directory path.
+             chunk_size (int): Number of rows per chunk.
+
+         Returns:
+             DataProcessor: Fitted DataProcessor instance.
+         """
          logger = logging.getLogger()
          logger.info(
              colorize(
@@ -428,7 +480,7 @@
                  "Use fit(dataframe) with in-memory data or convert the data format."
              )

-         numeric_acc: Dict[str, Dict[str, float]] = {}
+         numeric_acc = {}
          for name in self.numeric_features.keys():
              numeric_acc[name] = {
                  "sum": 0.0,
@@ -609,6 +661,21 @@
          output_path: Optional[str],
          warn_missing: bool = True,
      ):
+         """
+         Transform in-memory data and optionally persist the transformed data.
+
+         Args:
+             data (Union[pd.DataFrame, Dict[str, Any]]): Input data.
+             return_dict (bool): Whether to return a dictionary of numpy arrays.
+             persist (bool): Whether to persist the transformed data to disk.
+             save_format (Optional[str]): Format to save the data if persisting.
+             output_path (Optional[str]): Output path to save the data if persisting.
+             warn_missing (bool): Whether to warn about missing features in the data.
+
+         Returns:
+             Union[pd.DataFrame, Dict[str, np.ndarray]]: Transformed data.
+         """
+
          logger = logging.getLogger()
          data_dict = data if isinstance(data, dict) else None

@@ -719,6 +786,12 @@
          """Transform data from files under a path and save them to a new location.

          Uses chunked reading/writing to keep peak memory bounded for large files.
+
+         Args:
+             input_path (str): Input file or directory path.
+             output_path (Optional[str]): Output directory path. If None, defaults to input_path/transformed_data.
+             save_format (Optional[str]): Format to save transformed files. If None, uses input file format.
+             chunk_size (int): Number of rows per chunk.
          """
          logger = logging.getLogger()
          file_paths, file_type = resolve_file_paths(input_path)
@@ -876,6 +949,17 @@
          data: Union[pd.DataFrame, Dict[str, Any], str, os.PathLike],
          chunk_size: int = 200000,
      ):
+         """
+         Fit the DataProcessor to the provided data.
+
+         Args:
+             data (Union[pd.DataFrame, Dict[str, Any], str, os.PathLike]): Input data for fitting.
+             chunk_size (int): Number of rows per chunk when streaming from path.
+
+         Returns:
+             DataProcessor: Fitted DataProcessor instance.
+         """
+
          logger = logging.getLogger()
          if isinstance(data, (str, os.PathLike)):
              path_str = str(data)
@@ -915,6 +999,19 @@
          output_path: Optional[str] = None,
          chunk_size: int = 200000,
      ):
+         """
+         Transform the provided data using the fitted DataProcessor.
+
+         Args:
+             data (Union[pd.DataFrame, Dict[str, Any], str, os.PathLike]): Input data to transform.
+             return_dict (bool): Whether to return a dictionary of numpy arrays.
+             save_format (Optional[str]): Format to save the data if output_path is provided.
+             output_path (Optional[str]): Output path to save the transformed data.
+             chunk_size (int): Number of rows per chunk when streaming from path.
+         Returns:
+             Union[pd.DataFrame, Dict[str, np.ndarray], List[str]]: Transformed data or list of saved file paths.
+         """
+
          if not self.is_fitted:
              raise ValueError(
                  "[Data Processor Error] DataProcessor must be fitted before transform"
@@ -943,6 +1040,19 @@
          output_path: Optional[str] = None,
          chunk_size: int = 200000,
      ):
+         """
+         Fit the DataProcessor to the provided data and then transform it.
+
+         Args:
+             data (Union[pd.DataFrame, Dict[str, Any], str, os.PathLike]): Input data for fitting and transforming.
+             return_dict (bool): Whether to return a dictionary of numpy arrays.
+             save_format (Optional[str]): Format to save the data if output_path is provided.
+             output_path (Optional[str]): Output path to save the transformed data.
+             chunk_size (int): Number of rows per chunk when streaming from path.
+         Returns:
+             Union[pd.DataFrame, Dict[str, np.ndarray], List[str]]: Transformed data or list of saved file paths.
+         """
+
          self.fit(data, chunk_size=chunk_size)
          return self.transform(
              data,
@@ -952,6 +1062,12 @@
          )

      def save(self, save_path: str | Path):
+         """
+         Save the fitted DataProcessor to a file.
+
+         Args:
+             save_path (str | Path): Path to save the DataProcessor.
+         """
          logger = logging.getLogger()
          assert isinstance(save_path, (str, Path)), "save_path must be a string or Path"
          save_path = Path(save_path)
@@ -983,6 +1099,16 @@

      @classmethod
      def load(cls, load_path: str | Path) -> "DataProcessor":
+         """
+         Load a fitted DataProcessor from a file.
+
+         Args:
+             load_path (str | Path): Path to load the DataProcessor from.
+
+         Returns:
+             DataProcessor: Loaded DataProcessor instance.
+         """
+
          logger = logging.getLogger()
          load_path = Path(load_path)
          with open(load_path, "rb") as f:
@@ -1003,6 +1129,12 @@
          return processor

      def get_vocab_sizes(self) -> Dict[str, int]:
+         """
+         Get vocabulary sizes for all sparse and sequence features.
+
+         Returns:
+             Dict[str, int]: Mapping of feature names to vocabulary sizes.
+         """
          vocab_sizes = {}
          for name, config in self.sparse_features.items():
              vocab_sizes[name] = config.get("vocab_size", 0)
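
The now-documented end-to-end flow, continuing the sketch above (paths and variables are placeholders):

    # Fit on a DataFrame, a dict of arrays, or a file/directory path (streamed
    # in chunks), then transform to a dict of numpy arrays for model input.
    processor.fit(train_df)
    arrays = processor.transform(train_df, return_dict=True)

    # Vocabulary sizes drive embedding-table construction downstream.
    vocab_sizes = processor.get_vocab_sizes()

    # Persist the fitted state and restore it elsewhere.
    processor.save("artifacts/processor.pkl")
    restored = DataProcessor.load("artifacts/processor.pkl")
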
nextrec/loss/listwise.py CHANGED
@@ -2,10 +2,11 @@
  Listwise loss functions for ranking and contrastive training.

  Date: create on 27/10/2025
- Checkpoint: edit on 29/11/2025
+ Checkpoint: edit on 29/12/2025
  Author: Yang Zhou, zyaztec@gmail.com
  """

+ from typing import Literal
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
@@ -16,7 +17,7 @@ class SampledSoftmaxLoss(nn.Module):
      Softmax over one positive and multiple sampled negatives.
      """

-     def __init__(self, reduction: str = "mean"):
+     def __init__(self, reduction: Literal["mean", "sum", "none"] = "mean"):
          super().__init__()
          self.reduction = reduction

@@ -37,7 +38,11 @@ class InfoNCELoss(nn.Module):
      InfoNCE loss for contrastive learning with one positive and many negatives.
      """

-     def __init__(self, temperature: float = 0.07, reduction: str = "mean"):
+     def __init__(
+         self,
+         temperature: float = 0.07,
+         reduction: Literal["mean", "sum", "none"] = "mean",
+     ):
          super().__init__()
          self.temperature = temperature
          self.reduction = reduction
@@ -61,7 +66,11 @@ class ListNetLoss(nn.Module):
      Reference: Cao et al. (ICML 2007)
      """

-     def __init__(self, temperature: float = 1.0, reduction: str = "mean"):
+     def __init__(
+         self,
+         temperature: float = 1.0,
+         reduction: Literal["mean", "sum", "none"] = "mean",
+     ):
          super().__init__()
          self.temperature = temperature
          self.reduction = reduction
@@ -84,7 +93,7 @@ class ListMLELoss(nn.Module):
      Reference: Xia et al. (ICML 2008)
      """

-     def __init__(self, reduction: str = "mean"):
+     def __init__(self, reduction: Literal["mean", "sum", "none"] = "mean"):
          super().__init__()
          self.reduction = reduction

@@ -117,7 +126,11 @@ class ApproxNDCGLoss(nn.Module):
      Reference: Qin et al. (2010)
      """

-     def __init__(self, temperature: float = 1.0, reduction: str = "mean"):
+     def __init__(
+         self,
+         temperature: float = 1.0,
+         reduction: Literal["mean", "sum", "none"] = "mean",
+     ):
          super().__init__()
          self.temperature = temperature
          self.reduction = reduction
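
With reduction now typed as Literal["mean", "sum", "none"] instead of a bare str, type checkers flag misspelled reduction values at the call site rather than at runtime. A small usage sketch, assuming the conventional (y_pred, y_true) forward signature of these loss modules (the forward bodies are not part of this diff):

    import torch
    from nextrec.loss.listwise import ListNetLoss

    criterion = ListNetLoss(temperature=1.0, reduction="mean")
    scores = torch.randn(8, 10)  # 8 lists, 10 candidates each
    labels = torch.rand(8, 10)   # graded relevance per candidate
    loss = criterion(scores, labels)
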