nextrec 0.4.21__py3-none-any.whl → 0.4.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/activation.py +1 -1
  3. nextrec/basic/heads.py +2 -3
  4. nextrec/basic/metrics.py +1 -2
  5. nextrec/basic/model.py +115 -80
  6. nextrec/basic/summary.py +36 -2
  7. nextrec/data/preprocessor.py +137 -5
  8. nextrec/loss/__init__.py +0 -4
  9. nextrec/loss/grad_norm.py +3 -3
  10. nextrec/loss/listwise.py +19 -6
  11. nextrec/loss/pairwise.py +6 -4
  12. nextrec/loss/pointwise.py +8 -6
  13. nextrec/models/multi_task/esmm.py +3 -26
  14. nextrec/models/multi_task/mmoe.py +2 -24
  15. nextrec/models/multi_task/ple.py +13 -35
  16. nextrec/models/multi_task/poso.py +4 -28
  17. nextrec/models/multi_task/share_bottom.py +1 -24
  18. nextrec/models/ranking/afm.py +3 -27
  19. nextrec/models/ranking/autoint.py +5 -38
  20. nextrec/models/ranking/dcn.py +1 -26
  21. nextrec/models/ranking/dcn_v2.py +5 -33
  22. nextrec/models/ranking/deepfm.py +2 -29
  23. nextrec/models/ranking/dien.py +2 -28
  24. nextrec/models/ranking/din.py +2 -27
  25. nextrec/models/ranking/eulernet.py +3 -30
  26. nextrec/models/ranking/ffm.py +0 -26
  27. nextrec/models/ranking/fibinet.py +8 -32
  28. nextrec/models/ranking/fm.py +0 -29
  29. nextrec/models/ranking/lr.py +0 -30
  30. nextrec/models/ranking/masknet.py +4 -30
  31. nextrec/models/ranking/pnn.py +4 -28
  32. nextrec/models/ranking/widedeep.py +0 -32
  33. nextrec/models/ranking/xdeepfm.py +0 -30
  34. nextrec/models/retrieval/dssm.py +0 -24
  35. nextrec/models/retrieval/dssm_v2.py +0 -24
  36. nextrec/models/retrieval/mind.py +0 -20
  37. nextrec/models/retrieval/sdm.py +0 -20
  38. nextrec/models/retrieval/youtube_dnn.py +0 -21
  39. nextrec/models/sequential/hstu.py +0 -18
  40. nextrec/utils/__init__.py +5 -1
  41. nextrec/{loss/loss_utils.py → utils/loss.py} +17 -7
  42. nextrec/utils/model.py +79 -1
  43. nextrec/utils/types.py +62 -23
  44. {nextrec-0.4.21.dist-info → nextrec-0.4.23.dist-info}/METADATA +8 -6
  45. nextrec-0.4.23.dist-info/RECORD +81 -0
  46. nextrec-0.4.21.dist-info/RECORD +0 -81
  47. {nextrec-0.4.21.dist-info → nextrec-0.4.23.dist-info}/WHEEL +0 -0
  48. {nextrec-0.4.21.dist-info → nextrec-0.4.23.dist-info}/entry_points.txt +0 -0
  49. {nextrec-0.4.21.dist-info → nextrec-0.4.23.dist-info}/licenses/LICENSE +0 -0
nextrec/data/preprocessor.py CHANGED
@@ -2,7 +2,7 @@
  DataProcessor for data preprocessing including numeric, sparse, sequence features and target processing.
 
  Date: create on 13/11/2025
- Checkpoint: edit on 24/12/2025
+ Checkpoint: edit on 29/12/2025
  Author: Yang Zhou, zyaztec@gmail.com
  """
 
@@ -79,6 +79,14 @@ class DataProcessor(FeatureSet):
  ] = "standard",
  fill_na: Optional[float] = None,
  ):
+ """Add a numeric feature configuration.
+
+ Args:
+ name (str): Feature name.
+ scaler (Optional[Literal["standard", "minmax", "robust", "maxabs", "log", "none"]], optional): Scaler type. Defaults to "standard".
+ fill_na (Optional[float], optional): Fill value for missing entries. Defaults to None.
+ """
+
  self.numeric_features[name] = {"scaler": scaler, "fill_na": fill_na}
 
  def add_sparse_feature(
@@ -88,6 +96,14 @@ class DataProcessor(FeatureSet):
  hash_size: Optional[int] = None,
  fill_na: str = "<UNK>",
  ):
+ """Add a sparse feature configuration.
+
+ Args:
+ name (str): Feature name.
+ encode_method (Literal["hash", "label"], optional): Encoding method, including "hash encoding" and "label encoding". Defaults to "label".
+ hash_size (Optional[int], optional): Hash size for hash encoding. Required if encode_method is "hash".
+ fill_na (str, optional): Fill value for missing entries. Defaults to "<UNK>".
+ """
  if encode_method == "hash" and hash_size is None:
  raise ValueError(
  "[Data Processor Error] hash_size must be specified when encode_method='hash'"
@@ -101,7 +117,7 @@ class DataProcessor(FeatureSet):
  def add_sequence_feature(
  self,
  name: str,
- encode_method: Literal["hash", "label"] = "label",
+ encode_method: Literal["hash", "label"] = "hash",
  hash_size: Optional[int] = None,
  max_len: Optional[int] = 50,
  pad_value: int = 0,
@@ -110,6 +126,17 @@ class DataProcessor(FeatureSet):
  ] = "pre", # pre: keep last max_len items, post: keep first max_len items
  separator: str = ",",
  ):
+ """Add a sequence feature configuration.
+
+ Args:
+ name (str): Feature name.
+ encode_method (Literal["hash", "label"], optional): Encoding method, including "hash encoding" and "label encoding". Defaults to "hash".
+ hash_size (Optional[int], optional): Hash size for hash encoding. Required if encode_method is "hash".
+ max_len (Optional[int], optional): Maximum sequence length. Defaults to 50.
+ pad_value (int, optional): Padding value for sequences shorter than max_len. Defaults to 0.
+ truncate (Literal["pre", "post"], optional): Truncation strategy for sequences longer than max_len, including "pre" (keep last max_len items) and "post" (keep first max_len items). Defaults to "pre".
+ separator (str, optional): Separator for string sequences. Defaults to ",".
+ """
  if encode_method == "hash" and hash_size is None:
  raise ValueError(
  "[Data Processor Error] hash_size must be specified when encode_method='hash'"
@@ -131,6 +158,14 @@ class DataProcessor(FeatureSet):
  Dict[str, int]
  ] = None, # example: {'click': 1, 'no_click': 0}
  ):
+ """Add a target configuration.
+
+ Args:
+ name (str): Target name.
+ target_type (Literal["binary", "regression"], optional): Target type. Defaults to "binary".
+ label_map (Optional[Dict[str, int]], optional): Label mapping for binary targets. Defaults to None.
+ """
+
  self.target_features[name] = {
  "target_type": target_type,
  "label_map": label_map,
@@ -392,7 +427,15 @@ class DataProcessor(FeatureSet):
  )
 
  def load_dataframe_from_path(self, path: str) -> pd.DataFrame:
- """Load all data from a file or directory path into a single DataFrame."""
+ """
+ Load all data from a file or directory path into a single DataFrame.
+
+ Args:
+ path (str): File or directory path.
+
+ Returns:
+ pd.DataFrame: Loaded DataFrame.
+ """
  file_paths, file_type = resolve_file_paths(path)
  frames = load_dataframes(file_paths, file_type)
  return pd.concat(frames, ignore_index=True) if len(frames) > 1 else frames[0]
@@ -411,7 +454,16 @@ class DataProcessor(FeatureSet):
  return [str(value)]
 
  def fit_from_path(self, path: str, chunk_size: int) -> "DataProcessor":
- """Fit processor statistics by streaming files to reduce memory usage."""
+ """
+ Fit processor statistics by streaming files to reduce memory usage.
+
+ Args:
+ path (str): File or directory path.
+ chunk_size (int): Number of rows per chunk.
+
+ Returns:
+ DataProcessor: Fitted DataProcessor instance.
+ """
  logger = logging.getLogger()
  logger.info(
  colorize(
@@ -428,7 +480,7 @@ class DataProcessor(FeatureSet):
  "Use fit(dataframe) with in-memory data or convert the data format."
  )
 
- numeric_acc: Dict[str, Dict[str, float]] = {}
+ numeric_acc = {}
  for name in self.numeric_features.keys():
  numeric_acc[name] = {
  "sum": 0.0,
@@ -609,6 +661,21 @@ class DataProcessor(FeatureSet):
  output_path: Optional[str],
  warn_missing: bool = True,
  ):
+ """
+ Transform in-memory data and optionally persist the transformed data.
+
+ Args:
+ data (Union[pd.DataFrame, Dict[str, Any]]): Input data.
+ return_dict (bool): Whether to return a dictionary of numpy arrays.
+ persist (bool): Whether to persist the transformed data to disk.
+ save_format (Optional[str]): Format to save the data if persisting.
+ output_path (Optional[str]): Output path to save the data if persisting.
+ warn_missing (bool): Whether to warn about missing features in the data.
+
+ Returns:
+ Union[pd.DataFrame, Dict[str, np.ndarray]]: Transformed data.
+ """
+
  logger = logging.getLogger()
  data_dict = data if isinstance(data, dict) else None
 
@@ -719,6 +786,12 @@ class DataProcessor(FeatureSet):
  """Transform data from files under a path and save them to a new location.
 
  Uses chunked reading/writing to keep peak memory bounded for large files.
+
+ Args:
+ input_path (str): Input file or directory path.
+ output_path (Optional[str]): Output directory path. If None, defaults to input_path/transformed_data.
+ save_format (Optional[str]): Format to save transformed files. If None, uses input file format.
+ chunk_size (int): Number of rows per chunk.
  """
  logger = logging.getLogger()
  file_paths, file_type = resolve_file_paths(input_path)
@@ -876,6 +949,17 @@ class DataProcessor(FeatureSet):
  data: Union[pd.DataFrame, Dict[str, Any], str, os.PathLike],
  chunk_size: int = 200000,
  ):
+ """
+ Fit the DataProcessor to the provided data.
+
+ Args:
+ data (Union[pd.DataFrame, Dict[str, Any], str, os.PathLike]): Input data for fitting.
+ chunk_size (int): Number of rows per chunk when streaming from path.
+
+ Returns:
+ DataProcessor: Fitted DataProcessor instance.
+ """
+
  logger = logging.getLogger()
  if isinstance(data, (str, os.PathLike)):
  path_str = str(data)
@@ -915,6 +999,19 @@ class DataProcessor(FeatureSet):
  output_path: Optional[str] = None,
  chunk_size: int = 200000,
  ):
+ """
+ Transform the provided data using the fitted DataProcessor.
+
+ Args:
+ data (Union[pd.DataFrame, Dict[str, Any], str, os.PathLike]): Input data to transform.
+ return_dict (bool): Whether to return a dictionary of numpy arrays.
+ save_format (Optional[str]): Format to save the data if output_path is provided.
+ output_path (Optional[str]): Output path to save the transformed data.
+ chunk_size (int): Number of rows per chunk when streaming from path.
+ Returns:
+ Union[pd.DataFrame, Dict[str, np.ndarray], List[str]]: Transformed data or list of saved file paths.
+ """
+
  if not self.is_fitted:
  raise ValueError(
  "[Data Processor Error] DataProcessor must be fitted before transform"
@@ -943,6 +1040,19 @@ class DataProcessor(FeatureSet):
  output_path: Optional[str] = None,
  chunk_size: int = 200000,
  ):
+ """
+ Fit the DataProcessor to the provided data and then transform it.
+
+ Args:
+ data (Union[pd.DataFrame, Dict[str, Any], str, os.PathLike]): Input data for fitting and transforming.
+ return_dict (bool): Whether to return a dictionary of numpy arrays.
+ save_format (Optional[str]): Format to save the data if output_path is provided.
+ output_path (Optional[str]): Output path to save the transformed data.
+ chunk_size (int): Number of rows per chunk when streaming from path.
+ Returns:
+ Union[pd.DataFrame, Dict[str, np.ndarray], List[str]]: Transformed data or list of saved file paths.
+ """
+
  self.fit(data, chunk_size=chunk_size)
  return self.transform(
  data,
@@ -952,6 +1062,12 @@ class DataProcessor(FeatureSet):
  )
 
  def save(self, save_path: str | Path):
+ """
+ Save the fitted DataProcessor to a file.
+
+ Args:
+ save_path (str | Path): Path to save the DataProcessor.
+ """
  logger = logging.getLogger()
  assert isinstance(save_path, (str, Path)), "save_path must be a string or Path"
  save_path = Path(save_path)
@@ -983,6 +1099,16 @@ class DataProcessor(FeatureSet):
 
  @classmethod
  def load(cls, load_path: str | Path) -> "DataProcessor":
+ """
+ Load a fitted DataProcessor from a file.
+
+ Args:
+ load_path (str | Path): Path to load the DataProcessor from.
+
+ Returns:
+ DataProcessor: Loaded DataProcessor instance.
+ """
+
  logger = logging.getLogger()
  load_path = Path(load_path)
  with open(load_path, "rb") as f:
@@ -1003,6 +1129,12 @@ class DataProcessor(FeatureSet):
  return processor
 
  def get_vocab_sizes(self) -> Dict[str, int]:
+ """
+ Get vocabulary sizes for all sparse and sequence features.
+
+ Returns:
+ Dict[str, int]: Mapping of feature names to vocabulary sizes.
+ """
  vocab_sizes = {}
  for name, config in self.sparse_features.items():
  vocab_sizes[name] = config.get("vocab_size", 0)
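
The hunks above add Google-style docstrings to the DataProcessor configuration and fit/transform API and flip the default encode_method for sequence features from "label" to "hash", which makes hash_size mandatory on the default path. A minimal usage sketch assembled from those docstrings follows; the bare DataProcessor() constructor call, the column names, and the file paths are illustrative assumptions, not taken from this diff.

    from nextrec.data.preprocessor import DataProcessor

    processor = DataProcessor()  # constructor arguments are not shown in this diff
    processor.add_numeric_feature("price", scaler="standard", fill_na=0.0)
    processor.add_sparse_feature("item_id", encode_method="hash", hash_size=100_000)
    # add_sequence_feature now defaults to encode_method="hash", so hash_size must be given
    processor.add_sequence_feature("click_history", hash_size=100_000, max_len=50, truncate="pre")
    processor.add_target("click", target_type="binary", label_map={"click": 1, "no_click": 0})

    processor.fit("data/train/", chunk_size=200_000)  # streams files chunk by chunk when given a path
    transformed = processor.transform("data/train/", output_path="data/transformed/")
    processor.save("artifacts/processor.pkl")
    vocab_sizes = processor.get_vocab_sizes()  # feature name -> vocabulary size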
nextrec/loss/__init__.py CHANGED
@@ -6,7 +6,6 @@ from nextrec.loss.listwise import (
  SampledSoftmaxLoss,
  )
  from nextrec.loss.grad_norm import GradNormLossWeighting
- from nextrec.loss.loss_utils import VALID_TASK_TYPES, get_loss_fn, get_loss_kwargs
  from nextrec.loss.pairwise import BPRLoss, HingeLoss, TripletLoss
  from nextrec.loss.pointwise import (
  ClassBalancedFocalLoss,
@@ -34,7 +33,4 @@ __all__ = [
  # Multi-task weighting
  "GradNormLossWeighting",
  # Utilities
- "get_loss_fn",
- "get_loss_kwargs",
- "VALID_TASK_TYPES",
  ]
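
With get_loss_fn, get_loss_kwargs, and VALID_TASK_TYPES dropped from nextrec.loss, and loss_utils.py moved to nextrec/utils/loss.py (file 41 in the list above), imports of these helpers need to change. A hedged migration sketch, assuming the helpers keep their names under the new module path:

    # 0.4.21
    # from nextrec.loss import get_loss_fn, get_loss_kwargs, VALID_TASK_TYPES

    # 0.4.23 (assumed new location after the loss_utils.py -> utils/loss.py move)
    from nextrec.utils.loss import get_loss_fn, get_loss_kwargs, VALID_TASK_TYPES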
nextrec/loss/grad_norm.py CHANGED
@@ -20,9 +20,9 @@ import torch.nn.functional as F
 
 
  def get_grad_norm_shared_params(
- model: torch.nn.Module,
- shared_modules: Iterable[str] | None = None,
- ) -> list[torch.nn.Parameter]:
+ model,
+ shared_modules=None,
+ ):
  if not shared_modules:
  return [p for p in model.parameters() if p.requires_grad]
  shared_params = []
nextrec/loss/listwise.py CHANGED
@@ -2,10 +2,11 @@
  Listwise loss functions for ranking and contrastive training.
 
  Date: create on 27/10/2025
- Checkpoint: edit on 29/11/2025
+ Checkpoint: edit on 29/12/2025
  Author: Yang Zhou, zyaztec@gmail.com
  """
 
+ from typing import Literal
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
@@ -16,7 +17,7 @@ class SampledSoftmaxLoss(nn.Module):
  Softmax over one positive and multiple sampled negatives.
  """
 
- def __init__(self, reduction: str = "mean"):
+ def __init__(self, reduction: Literal["mean", "sum", "none"] = "mean"):
  super().__init__()
  self.reduction = reduction
 
@@ -37,7 +38,11 @@ class InfoNCELoss(nn.Module):
  InfoNCE loss for contrastive learning with one positive and many negatives.
  """
 
- def __init__(self, temperature: float = 0.07, reduction: str = "mean"):
+ def __init__(
+ self,
+ temperature: float = 0.07,
+ reduction: Literal["mean", "sum", "none"] = "mean",
+ ):
  super().__init__()
  self.temperature = temperature
  self.reduction = reduction
@@ -61,7 +66,11 @@ class ListNetLoss(nn.Module):
  Reference: Cao et al. (ICML 2007)
  """
 
- def __init__(self, temperature: float = 1.0, reduction: str = "mean"):
+ def __init__(
+ self,
+ temperature: float = 1.0,
+ reduction: Literal["mean", "sum", "none"] = "mean",
+ ):
  super().__init__()
  self.temperature = temperature
  self.reduction = reduction
@@ -84,7 +93,7 @@ class ListMLELoss(nn.Module):
  Reference: Xia et al. (ICML 2008)
  """
 
- def __init__(self, reduction: str = "mean"):
+ def __init__(self, reduction: Literal["mean", "sum", "none"] = "mean"):
  super().__init__()
  self.reduction = reduction
 
@@ -117,7 +126,11 @@ class ApproxNDCGLoss(nn.Module):
  Reference: Qin et al. (2010)
  """
 
- def __init__(self, temperature: float = 1.0, reduction: str = "mean"):
+ def __init__(
+ self,
+ temperature: float = 1.0,
+ reduction: Literal["mean", "sum", "none"] = "mean",
+ ):
  super().__init__()
  self.temperature = temperature
  self.reduction = reduction
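
The loss hunks in this release repeat one typing change: reduction is narrowed from a free-form str to Literal["mean", "sum", "none"], so type checkers now reject other values while runtime behaviour is unchanged. A small construction sketch against the classes shown above (forward signatures are not part of this diff, so only construction is shown):

    from nextrec.loss.listwise import InfoNCELoss, ListNetLoss, SampledSoftmaxLoss

    infonce = InfoNCELoss(temperature=0.07, reduction="mean")
    listnet = ListNetLoss(temperature=1.0, reduction="sum")
    sampled = SampledSoftmaxLoss(reduction="none")  # keeps per-sample losses
    # InfoNCELoss(reduction="avg")  # now flagged by a type checker; only the three literals are accepted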
nextrec/loss/pairwise.py CHANGED
@@ -2,7 +2,7 @@
  Pairwise loss functions for learning-to-rank and matching tasks.
 
  Date: create on 27/10/2025
- Checkpoint: edit on 29/11/2025
+ Checkpoint: edit on 29/12/2025
  Author: Yang Zhou, zyaztec@gmail.com
  """
 
@@ -18,7 +18,7 @@ class BPRLoss(nn.Module):
  Bayesian Personalized Ranking loss with support for multiple negatives.
  """
 
- def __init__(self, reduction: str = "mean"):
+ def __init__(self, reduction: Literal["mean", "sum", "none"] = "mean"):
  super().__init__()
  self.reduction = reduction
 
@@ -42,7 +42,9 @@ class HingeLoss(nn.Module):
  Hinge loss for pairwise ranking.
  """
 
- def __init__(self, margin: float = 1.0, reduction: str = "mean"):
+ def __init__(
+ self, margin: float = 1.0, reduction: Literal["mean", "sum", "none"] = "mean"
+ ):
  super().__init__()
  self.margin = margin
  self.reduction = reduction
@@ -69,7 +71,7 @@ class TripletLoss(nn.Module):
  def __init__(
  self,
  margin: float = 1.0,
- reduction: str = "mean",
+ reduction: Literal["mean", "sum", "none"] = "mean",
  distance: Literal["euclidean", "cosine"] = "euclidean",
  ):
  super().__init__()
nextrec/loss/pointwise.py CHANGED
@@ -2,11 +2,11 @@
  Pointwise loss functions, including imbalance-aware variants.
 
  Date: create on 27/10/2025
- Checkpoint: edit on 29/11/2025
+ Checkpoint: edit on 29/12/2025
  Author: Yang Zhou, zyaztec@gmail.com
  """
 
- from typing import Optional, Sequence
+ from typing import Optional, Sequence, Literal
 
  import torch
  import torch.nn as nn
@@ -18,7 +18,9 @@ class CosineContrastiveLoss(nn.Module):
  Contrastive loss using cosine similarity for positive/negative pairs.
  """
 
- def __init__(self, margin: float = 0.5, reduction: str = "mean"):
+ def __init__(
+ self, margin: float = 0.5, reduction: Literal["mean", "sum", "none"] = "mean"
+ ):
  super().__init__()
  self.margin = margin
  self.reduction = reduction
@@ -50,7 +52,7 @@ class WeightedBCELoss(nn.Module):
  def __init__(
  self,
  pos_weight: float | torch.Tensor | None = None,
- reduction: str = "mean",
+ reduction: Literal["mean", "sum", "none"] = "mean",
  logits: bool = False,
  auto_balance: bool = False,
  ):
@@ -110,7 +112,7 @@ class FocalLoss(nn.Module):
  self,
  gamma: float = 2.0,
  alpha: Optional[float | Sequence[float] | torch.Tensor] = None,
- reduction: str = "mean",
+ reduction: Literal["mean", "sum", "none"] = "mean",
  logits: bool = False,
  ):
  super().__init__()
@@ -187,7 +189,7 @@ class ClassBalancedFocalLoss(nn.Module):
  class_counts: Sequence[int] | torch.Tensor,
  beta: float = 0.9999,
  gamma: float = 2.0,
- reduction: str = "mean",
+ reduction: Literal["mean", "sum", "none"] = "mean",
  ):
  super().__init__()
  self.gamma = gamma
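
For the imbalance-aware pointwise losses, the constructor arguments visible in these hunks are unchanged apart from the same reduction narrowing. A brief construction sketch using only the parameters shown above (the values are illustrative, not recommendations):

    from nextrec.loss.pointwise import FocalLoss, WeightedBCELoss

    bce = WeightedBCELoss(pos_weight=5.0, logits=True, reduction="mean")  # fixed positive-class weight
    focal = FocalLoss(gamma=2.0, alpha=0.25, logits=True, reduction="sum")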
nextrec/models/multi_task/esmm.py CHANGED
@@ -42,12 +42,12 @@ CVR 预测 P(conversion|click),二者相乘得到 CTCVR 并在曝光标签上
  """
 
  import torch
- import torch.nn as nn
 
  from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
  from nextrec.basic.layers import MLP, EmbeddingLayer
  from nextrec.basic.heads import TaskHead
  from nextrec.basic.model import BaseModel
+ from nextrec.utils.types import TaskTypeName
 
 
  class ESMM(BaseModel):
@@ -77,23 +77,12 @@ class ESMM(BaseModel):
  sequence_features: list[SequenceFeature],
  ctr_params: dict,
  cvr_params: dict,
+ task: TaskTypeName | list[TaskTypeName] | None = None,
  target: list[str] | None = None, # Note: ctcvr = ctr * cvr
- task: list[str] | None = None,
- optimizer: str = "adam",
- optimizer_params: dict | None = None,
- loss: str | nn.Module | list[str | nn.Module] | None = "bce",
- loss_params: dict | list[dict] | None = None,
- embedding_l1_reg=0.0,
- dense_l1_reg=0.0,
- embedding_l2_reg=0.0,
- dense_l2_reg=0.0,
  **kwargs,
  ):
 
  target = target or ["ctr", "ctcvr"]
- optimizer_params = optimizer_params or {}
- if loss is None:
- loss = "bce"
 
  if len(target) != 2:
  raise ValueError(
@@ -120,15 +109,9 @@ class ESMM(BaseModel):
  sequence_features=sequence_features,
  target=target,
  task=resolved_task, # Both CTR and CTCVR are binary classification
- embedding_l1_reg=embedding_l1_reg,
- dense_l1_reg=dense_l1_reg,
- embedding_l2_reg=embedding_l2_reg,
- dense_l2_reg=dense_l2_reg,
  **kwargs,
  )
 
- self.loss = loss
-
  self.embedding = EmbeddingLayer(features=self.all_features)
  input_dim = self.embedding.input_dim
 
@@ -138,17 +121,11 @@ class ESMM(BaseModel):
  # CVR tower
  self.cvr_tower = MLP(input_dim=input_dim, output_layer=True, **cvr_params)
  self.grad_norm_shared_modules = ["embedding"]
- self.prediction_layer = TaskHead(task_type=self.default_task, task_dims=[1, 1])
+ self.prediction_layer = TaskHead(task_type=self.task, task_dims=[1, 1])
  # Register regularization weights
  self.register_regularization_weights(
  embedding_attr="embedding", include_modules=["ctr_tower", "cvr_tower"]
  )
- self.compile(
- optimizer=optimizer,
- optimizer_params=optimizer_params,
- loss=loss,
- loss_params=loss_params,
- )
 
  def forward(self, x):
  # Get all embeddings and flatten
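
The multi-task models in this release (ESMM above, and MMOE and PLE below) follow the same pattern: optimizer, loss, and regularization arguments are removed from the constructors along with the internal self.compile(...) call, and ESMM gains an explicit task parameter typed as TaskTypeName. A hedged sketch of the resulting two-step setup, assuming compile() keeps the (optimizer, optimizer_params, loss, loss_params) signature seen in the removed call; the feature lists and MLP parameter dicts are placeholders:

    from nextrec.models.multi_task.esmm import ESMM

    model = ESMM(
        dense_features=dense_features,            # feature definitions built elsewhere
        sparse_features=sparse_features,
        sequence_features=sequence_features,
        ctr_params={"hidden_dims": [256, 128]},   # hypothetical MLP params
        cvr_params={"hidden_dims": [256, 128]},
        target=["ctr", "ctcvr"],
    )
    # training configuration now goes through compile() rather than the constructor (assumed signature)
    model.compile(optimizer="adam", optimizer_params={"lr": 1e-3}, loss="bce")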
nextrec/models/multi_task/mmoe.py CHANGED
@@ -82,14 +82,6 @@ class MMOE(BaseModel):
  tower_params_list: list[dict] | None = None,
  target: list[str] | str | None = None,
  task: str | list[str] = "binary",
- optimizer: str = "adam",
- optimizer_params: dict | None = None,
- loss: str | nn.Module | list[str | nn.Module] | None = "bce",
- loss_params: dict | list[dict] | None = None,
- embedding_l1_reg=0.0,
- dense_l1_reg=0.0,
- embedding_l2_reg=0.0,
- dense_l2_reg=0.0,
  **kwargs,
  ):
 
@@ -98,9 +90,7 @@ class MMOE(BaseModel):
  sequence_features = sequence_features or []
  expert_params = expert_params or {}
  tower_params_list = tower_params_list or []
- optimizer_params = optimizer_params or {}
- if loss is None:
- loss = "bce"
+
  if target is None:
  target = []
  elif isinstance(target, str):
@@ -126,15 +116,9 @@ class MMOE(BaseModel):
  sequence_features=sequence_features,
  target=target,
  task=resolved_task,
- embedding_l1_reg=embedding_l1_reg,
- dense_l1_reg=dense_l1_reg,
- embedding_l2_reg=embedding_l2_reg,
- dense_l2_reg=dense_l2_reg,
  **kwargs,
  )
 
- self.loss = loss
-
  # Number of tasks and experts
  self.nums_task = len(target)
  self.num_experts = num_experts
@@ -172,18 +156,12 @@ class MMOE(BaseModel):
  tower = MLP(input_dim=expert_output_dim, output_layer=True, **tower_params)
  self.towers.append(tower)
  self.prediction_layer = TaskHead(
- task_type=self.default_task, task_dims=[1] * self.nums_task
+ task_type=self.task, task_dims=[1] * self.nums_task
  )
  # Register regularization weights
  self.register_regularization_weights(
  embedding_attr="embedding", include_modules=["experts", "gates", "towers"]
  )
- self.compile(
- optimizer=optimizer,
- optimizer_params=optimizer_params,
- loss=self.loss,
- loss_params=loss_params,
- )
 
  def forward(self, x):
  # Get all embeddings and flatten
nextrec/models/multi_task/ple.py CHANGED
@@ -202,29 +202,21 @@ class PLE(BaseModel):
 
  def __init__(
  self,
- dense_features: list[DenseFeature],
- sparse_features: list[SparseFeature],
- sequence_features: list[SequenceFeature],
- shared_expert_params: dict,
- specific_expert_params: dict | list[dict],
- num_shared_experts: int,
- num_specific_experts: int,
- num_levels: int,
- tower_params_list: list[dict],
- target: list[str],
+ dense_features: list[DenseFeature] | None = None,
+ sparse_features: list[SparseFeature] | None = None,
+ sequence_features: list[SequenceFeature] | None = None,
+ shared_expert_params: dict | None = None,
+ specific_expert_params: dict | list[dict] | None = None,
+ num_shared_experts: int = 2,
+ num_specific_experts: int = 2,
+ num_levels: int = 2,
+ tower_params_list: list[dict] | None = None,
+ target: list[str] | None = None,
  task: str | list[str] | None = None,
- optimizer: str = "adam",
- optimizer_params: dict | None = None,
- loss: str | nn.Module | list[str | nn.Module] | None = "bce",
- loss_params: dict | list[dict] | None = None,
- embedding_l1_reg=0.0,
- dense_l1_reg=0.0,
- embedding_l2_reg=0.0,
- dense_l2_reg=0.0,
  **kwargs,
  ):
 
- self.nums_task = len(target)
+ self.nums_task = len(target) if target is not None else 1
 
  resolved_task = task
  if resolved_task is None:
@@ -244,23 +236,15 @@ class PLE(BaseModel):
  sequence_features=sequence_features,
  target=target,
  task=resolved_task,
- embedding_l1_reg=embedding_l1_reg,
- dense_l1_reg=dense_l1_reg,
- embedding_l2_reg=embedding_l2_reg,
- dense_l2_reg=dense_l2_reg,
  **kwargs,
  )
 
- self.loss = loss
- if self.loss is None:
- self.loss = "bce"
  # Number of tasks, experts, and levels
  self.nums_task = len(target)
  self.num_shared_experts = num_shared_experts
  self.num_specific_experts = num_specific_experts
  self.num_levels = num_levels
- if optimizer_params is None:
- optimizer_params = {}
+
  if len(tower_params_list) != self.nums_task:
  raise ValueError(
  f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.nums_task})"
@@ -302,18 +286,12 @@ class PLE(BaseModel):
  tower = MLP(input_dim=expert_output_dim, output_layer=True, **tower_params)
  self.towers.append(tower)
  self.prediction_layer = TaskHead(
- task_type=self.default_task, task_dims=[1] * self.nums_task
+ task_type=self.task, task_dims=[1] * self.nums_task
  )
  # Register regularization weights
  self.register_regularization_weights(
  embedding_attr="embedding", include_modules=["cgc_layers", "towers"]
  )
- self.compile(
- optimizer=optimizer,
- optimizer_params=optimizer_params,
- loss=self.loss,
- loss_params=loss_params,
- )
 
  def forward(self, x):
  # Get all embeddings and flatten