nextrec-0.4.20-py3-none-any.whl → nextrec-0.4.21-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in the public registry. It is provided for informational purposes only.
Files changed (54)
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/activation.py +9 -4
  3. nextrec/basic/callback.py +39 -87
  4. nextrec/basic/features.py +149 -28
  5. nextrec/basic/heads.py +4 -1
  6. nextrec/basic/layers.py +375 -94
  7. nextrec/basic/loggers.py +236 -39
  8. nextrec/basic/model.py +209 -316
  9. nextrec/basic/session.py +2 -2
  10. nextrec/basic/summary.py +323 -0
  11. nextrec/cli.py +3 -3
  12. nextrec/data/data_processing.py +45 -1
  13. nextrec/data/dataloader.py +2 -2
  14. nextrec/data/preprocessor.py +2 -2
  15. nextrec/loss/loss_utils.py +5 -30
  16. nextrec/models/multi_task/esmm.py +4 -6
  17. nextrec/models/multi_task/mmoe.py +4 -6
  18. nextrec/models/multi_task/ple.py +6 -8
  19. nextrec/models/multi_task/poso.py +5 -7
  20. nextrec/models/multi_task/share_bottom.py +6 -8
  21. nextrec/models/ranking/afm.py +4 -6
  22. nextrec/models/ranking/autoint.py +4 -6
  23. nextrec/models/ranking/dcn.py +8 -7
  24. nextrec/models/ranking/dcn_v2.py +4 -6
  25. nextrec/models/ranking/deepfm.py +5 -7
  26. nextrec/models/ranking/dien.py +8 -7
  27. nextrec/models/ranking/din.py +8 -7
  28. nextrec/models/ranking/eulernet.py +5 -7
  29. nextrec/models/ranking/ffm.py +5 -7
  30. nextrec/models/ranking/fibinet.py +4 -6
  31. nextrec/models/ranking/fm.py +4 -6
  32. nextrec/models/ranking/lr.py +4 -6
  33. nextrec/models/ranking/masknet.py +8 -9
  34. nextrec/models/ranking/pnn.py +4 -6
  35. nextrec/models/ranking/widedeep.py +5 -7
  36. nextrec/models/ranking/xdeepfm.py +8 -7
  37. nextrec/models/retrieval/dssm.py +4 -10
  38. nextrec/models/retrieval/dssm_v2.py +0 -6
  39. nextrec/models/retrieval/mind.py +4 -10
  40. nextrec/models/retrieval/sdm.py +4 -10
  41. nextrec/models/retrieval/youtube_dnn.py +4 -10
  42. nextrec/models/sequential/hstu.py +1 -3
  43. nextrec/utils/__init__.py +12 -14
  44. nextrec/utils/config.py +15 -5
  45. nextrec/utils/console.py +2 -2
  46. nextrec/utils/feature.py +2 -2
  47. nextrec/utils/torch_utils.py +57 -112
  48. nextrec/utils/types.py +59 -0
  49. {nextrec-0.4.20.dist-info → nextrec-0.4.21.dist-info}/METADATA +7 -5
  50. nextrec-0.4.21.dist-info/RECORD +81 -0
  51. nextrec-0.4.20.dist-info/RECORD +0 -79
  52. {nextrec-0.4.20.dist-info → nextrec-0.4.21.dist-info}/WHEEL +0 -0
  53. {nextrec-0.4.20.dist-info → nextrec-0.4.21.dist-info}/entry_points.txt +0 -0
  54. {nextrec-0.4.20.dist-info → nextrec-0.4.21.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/session.py CHANGED
@@ -12,7 +12,7 @@ from pathlib import Path
 
 __all__ = [
     "Session",
-    "resolve_save_path",
+    "get_save_path",
    "create_session",
 ]
 
@@ -89,7 +89,7 @@ def create_session(experiment_id: str | Path | None = None) -> Session:
     return Session(experiment_id=exp_id, root=root, log_basename=log_basename)
 
 
-def resolve_save_path(
+def get_save_path(
     path: str | os.PathLike | Path | None,
     default_dir: str | Path,
     default_name: str,
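Usage note: `resolve_save_path` is renamed to `get_save_path` with the same leading parameters. A minimal call sketch, assuming any parameters not shown in the hunk above have defaults and that the helper returns the resolved `Path` (neither is confirmed by this diff); the call shape mirrors the `DataProcessor.save()` hunk later in this diff:

    from pathlib import Path
    from nextrec.basic.session import get_save_path

    # Fall back to <cwd>/fitted_processor when no explicit path is given.
    target_path = get_save_path(
        path=None,
        default_dir=Path.cwd(),
        default_name="fitted_processor",
    )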
nextrec/basic/summary.py ADDED
@@ -0,0 +1,323 @@
+"""
+Summary utilities for BaseModel.
+"""
+
+from __future__ import annotations
+
+import logging
+from typing import Any, Literal
+
+import numpy as np
+from torch.utils.data import DataLoader
+
+from nextrec.basic.loggers import colorize, format_kv
+from nextrec.data.data_processing import extract_label_arrays, get_data_length
+
+
+class SummarySet:
+    def build_data_summary(
+        self, data: Any, data_loader: DataLoader | None, sample_key: str
+    ):
+        dataset = data_loader.dataset if data_loader else None
+
+        train_size = get_data_length(dataset)
+        if train_size is None:
+            train_size = get_data_length(data)
+
+        labels = extract_label_arrays(dataset, self.target_columns)
+        if labels is None:
+            labels = extract_label_arrays(data, self.target_columns)
+
+        summary = {}
+        if train_size is not None:
+            summary[sample_key] = int(train_size)
+
+        if labels:
+            task_types = list(self.task) if isinstance(self.task, list) else [self.task]
+            if len(task_types) != len(self.target_columns):
+                task_types = [task_types[0]] * len(self.target_columns)
+
+            label_distributions = {}
+            for target_name, task_type in zip(self.target_columns, task_types):
+                values = labels.get(target_name)
+                if values is None:
+                    continue
+                values = np.asarray(values).reshape(-1)
+                if values.size == 0:
+                    continue
+                if task_type == "regression":
+                    values = values.astype(float)
+                    stats = {
+                        "mean": np.nanmean(values),
+                        "std": np.nanstd(values),
+                        "min": np.nanmin(values),
+                        "p25": np.nanpercentile(values, 25),
+                        "p50": np.nanpercentile(values, 50),
+                        "p75": np.nanpercentile(values, 75),
+                        "max": np.nanmax(values),
+                    }
+                    stat_text = ", ".join(
+                        f"{key}={value:.6g}" for key, value in stats.items()
+                    )
+                    label_distributions[target_name] = {
+                        "task": task_type,
+                        "lines": [("stats", stat_text)],
+                    }
+                else:
+                    uniques, counts = np.unique(values, return_counts=True)
+                    total = counts.sum()
+                    if total == 0:
+                        continue
+                    label_parts = []
+                    for label_value, count in zip(uniques, counts):
+                        if isinstance(label_value, (int, np.integer)):
+                            label_str = f"{int(label_value)}"
+                        elif isinstance(
+                            label_value, (float, np.floating)
+                        ) and np.isclose(label_value, int(label_value)):
+                            label_str = f"{int(label_value)}"
+                        else:
+                            label_str = f"{label_value}"
+                        ratio = count / total
+                        label_parts.append((label_str, f"{count} ({ratio:.2%})"))
+                    label_distributions[target_name] = {
+                        "task": task_type,
+                        "lines": label_parts,
+                    }
+
+            if label_distributions:
+                summary["label_distributions"] = label_distributions
+
+        return summary or None
+
+    def build_train_data_summary(
+        self, train_data: Any, train_loader: DataLoader | None
+    ):
+        return self.build_data_summary(
+            data=train_data,
+            data_loader=train_loader,
+            sample_key="train_samples",
+        )
+
+    def build_valid_data_summary(
+        self, valid_data: Any, valid_loader: DataLoader | None
+    ):
+        return self.build_data_summary(
+            data=valid_data,
+            data_loader=valid_loader,
+            sample_key="valid_samples",
+        )
+
+    def summary(
+        self,
+        sections: list[Literal["feature", "model", "train", "data"]] | None = None,
+    ):
+        logger = logging.getLogger()
+        allowed_sections = {
+            "feature": "Feature Configuration",
+            "model": "Model Parameters",
+            "train": "Training Configuration",
+            "data": "Data Summary",
+        }
+        if sections is None:
+            selected_sections = set(allowed_sections.values())
+        else:
+            selected_sections = set()
+            invalid_sections = []
+            for section in sections:
+                key = str(section).strip().lower()
+                if key in allowed_sections:
+                    selected_sections.add(allowed_sections[key])
+                else:
+                    invalid_sections.append(section)
+            if invalid_sections:
+                raise ValueError(
+                    "[BaseModel-summary Error] Unknown summary section(s): "
+                    f"{invalid_sections}. Allowed: {list(allowed_sections.keys())}"
+                )
+
+        logger.info("")
+        logger.info(
+            colorize(
+                f"Model Summary: {self.model_name.upper()}",
+                color="bright_blue",
+                bold=True,
+            )
+        )
+        logger.info("")
+
+        if "Feature Configuration" in selected_sections:
+            logger.info("")
+            logger.info(colorize("Feature Configuration", color="cyan", bold=True))
+            logger.info(colorize("-" * 80, color="cyan"))
+
+            if self.dense_features:
+                logger.info(f"Dense Features ({len(self.dense_features)}):")
+                for i, feat in enumerate(self.dense_features, 1):
+                    embed_dim = (
+                        feat.embedding_dim if hasattr(feat, "embedding_dim") else 1
+                    )
+                    logger.info(f" {i}. {feat.name:20s}")
+
+            if self.sparse_features:
+                logger.info(f"\nSparse Features ({len(self.sparse_features)}):")
+
+                max_name_len = max(len(feat.name) for feat in self.sparse_features)
+                max_embed_name_len = max(
+                    len(feat.embedding_name) for feat in self.sparse_features
+                )
+                name_width = max(max_name_len, 10) + 2
+                embed_name_width = max(max_embed_name_len, 15) + 2
+
+                logger.info(
+                    f" {'#':<4} {'Name':<{name_width}} {'Vocab Size':>12} {'Embed Name':>{embed_name_width}} {'Embed Dim':>10}"
+                )
+                logger.info(
+                    f" {'-'*4} {'-'*name_width} {'-'*12} {'-'*embed_name_width} {'-'*10}"
+                )
+                for i, feat in enumerate(self.sparse_features, 1):
+                    vocab_size = (
+                        feat.vocab_size if hasattr(feat, "vocab_size") else "N/A"
+                    )
+                    embed_dim = (
+                        feat.embedding_dim if hasattr(feat, "embedding_dim") else "N/A"
+                    )
+                    logger.info(
+                        f" {i:<4} {feat.name:<{name_width}} {str(vocab_size):>12} {feat.embedding_name:>{embed_name_width}} {str(embed_dim):>10}"
+                    )
+
+            if self.sequence_features:
+                logger.info(f"\nSequence Features ({len(self.sequence_features)}):")
+
+                max_name_len = max(len(feat.name) for feat in self.sequence_features)
+                max_embed_name_len = max(
+                    len(feat.embedding_name) for feat in self.sequence_features
+                )
+                name_width = max(max_name_len, 10) + 2
+                embed_name_width = max(max_embed_name_len, 15) + 2
+
+                logger.info(
+                    f" {'#':<4} {'Name':<{name_width}} {'Vocab Size':>12} {'Embed Name':>{embed_name_width}} {'Embed Dim':>10} {'Max Len':>10}"
+                )
+                logger.info(
+                    f" {'-'*4} {'-'*name_width} {'-'*12} {'-'*embed_name_width} {'-'*10} {'-'*10}"
+                )
+                for i, feat in enumerate(self.sequence_features, 1):
+                    vocab_size = (
+                        feat.vocab_size if hasattr(feat, "vocab_size") else "N/A"
+                    )
+                    embed_dim = (
+                        feat.embedding_dim if hasattr(feat, "embedding_dim") else "N/A"
+                    )
+                    max_len = feat.max_len if hasattr(feat, "max_len") else "N/A"
+                    logger.info(
+                        f" {i:<4} {feat.name:<{name_width}} {str(vocab_size):>12} {feat.embedding_name:>{embed_name_width}} {str(embed_dim):>10} {str(max_len):>10}"
+                    )
+
+        if "Model Parameters" in selected_sections:
+            logger.info("")
+            logger.info(colorize("Model Parameters", color="cyan", bold=True))
+            logger.info(colorize("-" * 80, color="cyan"))
+
+            # Model Architecture
+            logger.info("Model Architecture:")
+            logger.info(str(self))
+            logger.info("")
+
+            total_params = sum(p.numel() for p in self.parameters())
+            trainable_params = sum(
+                p.numel() for p in self.parameters() if p.requires_grad
+            )
+            non_trainable_params = total_params - trainable_params
+
+            logger.info(f"Total Parameters: {total_params:,}")
+            logger.info(f"Trainable Parameters: {trainable_params:,}")
+            logger.info(f"Non-trainable Parameters: {non_trainable_params:,}")
+
+            logger.info("Layer-wise Parameters:")
+            for name, module in self.named_children():
+                layer_params = sum(p.numel() for p in module.parameters())
+                if layer_params > 0:
+                    logger.info(f" {name:30s}: {layer_params:,}")
+
+        if "Training Configuration" in selected_sections:
+            logger.info("")
+            logger.info(colorize("Training Configuration", color="cyan", bold=True))
+            logger.info(colorize("-" * 80, color="cyan"))
+
+            logger.info(f"Task Type: {self.task}")
+            logger.info(f"Number of Tasks: {self.nums_task}")
+            logger.info(f"Metrics: {self.metrics}")
+            logger.info(f"Target Columns: {self.target_columns}")
+            logger.info(f"Device: {self.device}")
+
+            if hasattr(self, "optimizer_name"):
+                logger.info(f"Optimizer: {self.optimizer_name}")
+                if self.optimizer_params:
+                    for key, value in self.optimizer_params.items():
+                        logger.info(f" {key:25s}: {value}")
+
+            if hasattr(self, "scheduler_name") and self.scheduler_name:
+                logger.info(f"Scheduler: {self.scheduler_name}")
+                if self.scheduler_params:
+                    for key, value in self.scheduler_params.items():
+                        logger.info(f" {key:25s}: {value}")
+
+            if hasattr(self, "loss_config"):
+                logger.info(f"Loss Function: {self.loss_config}")
+            if hasattr(self, "loss_weights"):
+                logger.info(f"Loss Weights: {self.loss_weights}")
+            if hasattr(self, "grad_norm"):
+                logger.info(f"GradNorm Enabled: {self.grad_norm is not None}")
+                if self.grad_norm is not None:
+                    grad_lr = self.grad_norm.optimizer.param_groups[0].get("lr")
+                    logger.info(f" GradNorm alpha: {self.grad_norm.alpha}")
+                    logger.info(f" GradNorm lr: {grad_lr}")
+
+            logger.info("Regularization:")
+            logger.info(f" Embedding L1: {self.embedding_l1_reg}")
+            logger.info(f" Embedding L2: {self.embedding_l2_reg}")
+            logger.info(f" Dense L1: {self.dense_l1_reg}")
+            logger.info(f" Dense L2: {self.dense_l2_reg}")
+
+            logger.info("Other Settings:")
+            logger.info(f" Early Stop Patience: {self.early_stop_patience}")
+            logger.info(f" Max Gradient Norm: {self.max_gradient_norm}")
+            logger.info(f" Max Metrics Samples: {self.metrics_sample_limit}")
+            logger.info(f" Session ID: {self.session_id}")
+            logger.info(f" Features Config Path: {self.features_config_path}")
+            logger.info(f" Latest Checkpoint: {self.checkpoint_path}")
+
+        if "Data Summary" in selected_sections and (
+            self.train_data_summary or self.valid_data_summary
+        ):
+            logger.info("")
+            logger.info(colorize("Data Summary", color="cyan", bold=True))
+            logger.info(colorize("-" * 80, color="cyan"))
+            if self.train_data_summary:
+                train_samples = self.train_data_summary.get("train_samples")
+                if train_samples is not None:
+                    logger.info(format_kv("Train Samples", f"{train_samples:,}"))
+
+                label_distributions = self.train_data_summary.get("label_distributions")
+                if isinstance(label_distributions, dict):
+                    for target_name, details in label_distributions.items():
+                        lines = details.get("lines", [])
+                        logger.info(f"{target_name}:")
+                        for label, value in lines:
+                            logger.info(format_kv(label, value))
+
+            if self.valid_data_summary:
+                if self.train_data_summary:
+                    logger.info("")
+                valid_samples = self.valid_data_summary.get("valid_samples")
+                if valid_samples is not None:
+                    logger.info(format_kv("Valid Samples", f"{valid_samples:,}"))
+
+                label_distributions = self.valid_data_summary.get("label_distributions")
+                if isinstance(label_distributions, dict):
+                    for target_name, details in label_distributions.items():
+                        lines = details.get("lines", [])
+                        logger.info(f"{target_name}:")
+                        for label, value in lines:
+                            logger.info(format_kv(label, value))
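Usage note: the new `SummarySet` mixin provides the `summary()` method used by `BaseModel`. A minimal call sketch; `model` stands for an already constructed nextrec model, whose construction is omitted because its arguments are not part of this diff. Only the `sections` behavior comes from the code above:

    # Log all four sections: feature, model, train, and data.
    model.summary()

    # Log selected sections only; unknown keys raise ValueError.
    model.summary(sections=["feature", "model"])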
nextrec/cli.py CHANGED
@@ -48,7 +48,7 @@ from nextrec.utils.data import (
     read_yaml,
     resolve_file_paths,
 )
-from nextrec.utils.feature import normalize_to_list
+from nextrec.utils.feature import to_list
 
 logger = logging.getLogger(__name__)
 
@@ -111,7 +111,7 @@ def train_model(train_config_path: str) -> None:
 
     # train data
     data_path = resolve_path(data_cfg["path"], config_dir)
-    target = normalize_to_list(data_cfg["target"])
+    target = to_list(data_cfg["target"])
     file_paths: List[str] = []
     file_type: str | None = None
     streaming_train_files: List[str] | None = None
@@ -507,7 +507,7 @@ def predict_model(predict_config_path: str) -> None:
         or model_cfg.get("params", {}).get("target")
     )
     if target_override:
-        target_cols = normalize_to_list(target_override)
+        target_cols = to_list(target_override)
 
     model = build_model_instance(
         model_cfg=model_cfg,
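Usage note: `normalize_to_list` is renamed to `to_list`. Its implementation is not shown in this diff; the sketch below assumes it keeps the old normalization semantics (a scalar target is wrapped into a list, lists pass through unchanged):

    from nextrec.utils.feature import to_list

    to_list("label")            # assumed -> ["label"]
    to_list(["ctr", "ctcvr"])   # assumed -> ["ctr", "ctcvr"]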
nextrec/data/data_processing.py CHANGED
@@ -2,7 +2,7 @@
 Data processing utilities for NextRec
 
 Date: create on 03/12/2025
-Checkpoint: edit on 19/12/2025
+Checkpoint: edit on 25/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
@@ -28,6 +28,50 @@ def get_column_data(data: dict | pd.DataFrame, name: str):
     raise KeyError(f"Unsupported data type for extracting column {name}")
 
 
+def to_numpy(values: Any) -> np.ndarray:
+    if isinstance(values, torch.Tensor):
+        return values.detach().cpu().numpy()
+    return np.asarray(values)
+
+
+def get_data_length(data: Any) -> int | None:
+    if data is None:
+        return None
+    if isinstance(data, pd.DataFrame):
+        return len(data)
+    if isinstance(data, dict):
+        if not data:
+            return None
+        sample_key = next(iter(data))
+        return len(data[sample_key])
+    try:
+        return len(data)
+    except TypeError:
+        return None
+
+
+def extract_label_arrays(
+    data: Any, target_columns: list[str]
+) -> dict[str, np.ndarray] | None:
+    if not target_columns or data is None:
+        return None
+
+    if isinstance(data, (dict, pd.DataFrame)):
+        label_source = data
+    elif hasattr(data, "labels"):
+        label_source = data.labels
+    else:
+        return None
+
+    labels: dict[str, np.ndarray] = {}
+    for name in target_columns:
+        column = get_column_data(label_source, name)
+        if column is None:
+            continue
+        labels[name] = to_numpy(column)
+    return labels or None
+
+
 def split_dict_random(data_dict, test_size=0.2, random_state=None):
 
     lengths = [len(v) for v in data_dict.values()]
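Usage note: a minimal sketch of the helpers added above, using a toy DataFrame and assuming `get_column_data` returns the named column for DataFrame input:

    import pandas as pd
    from nextrec.data.data_processing import extract_label_arrays, get_data_length

    toy = pd.DataFrame({"ctr": [0, 1, 0, 1], "ctcvr": [0, 0, 0, 1]})
    get_data_length(toy)                              # -> 4
    labels = extract_label_arrays(toy, ["ctr", "ctcvr"])
    # labels["ctr"] and labels["ctcvr"] are 1-D numpy arrays of the label values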
@@ -424,10 +424,10 @@ def normalize_sequence_column(column, feature: SequenceFeature) -> np.ndarray:
             sequences.append(np.asarray(seq, dtype=np.int64))
         else:
             sequences.append(np.asarray([seq], dtype=np.int64))
-    max_len = getattr(feature, "max_len", 0)
+    max_len = feature.max_len if feature.max_len is not None else 0
     if max_len <= 0:
         max_len = max((len(seq) for seq in sequences), default=1)
-    pad_value = getattr(feature, "padding_idx", 0)
+    pad_value = feature.padding_idx if feature.padding_idx is not None else 0
     padded = [
         (
             seq[:max_len]
nextrec/data/preprocessor.py CHANGED
@@ -30,7 +30,7 @@ from sklearn.preprocessing import (
 from nextrec.__version__ import __version__
 from nextrec.basic.features import FeatureSet
 from nextrec.basic.loggers import colorize
-from nextrec.basic.session import resolve_save_path
+from nextrec.basic.session import get_save_path
 from nextrec.data.data_processing import hash_md5_mod
 from nextrec.utils.console import progress
 from nextrec.utils.data import (
@@ -957,7 +957,7 @@ class DataProcessor(FeatureSet):
         save_path = Path(save_path)
         if not self.is_fitted:
             logger.warning("Saving unfitted DataProcessor")
-        target_path = resolve_save_path(
+        target_path = get_save_path(
             path=save_path,
             default_dir=Path(os.getcwd()),
             default_name="fitted_processor",
nextrec/loss/loss_utils.py CHANGED
@@ -6,8 +6,6 @@ Checkpoint: edit on 19/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
-from typing import Literal
-
 import torch.nn as nn
 
 from nextrec.loss.listwise import (
@@ -19,6 +17,7 @@ from nextrec.loss.listwise import (
 )
 from nextrec.loss.pairwise import BPRLoss, HingeLoss, TripletLoss
 from nextrec.loss.pointwise import ClassBalancedFocalLoss, FocalLoss, WeightedBCELoss
+from nextrec.utils.types import LossName
 
 VALID_TASK_TYPES = [
     "binary",
@@ -26,33 +25,6 @@ VALID_TASK_TYPES = [
     "regression",
 ]
 
-# Define all supported loss types
-LossType = Literal[
-    # Pointwise losses
-    "bce",
-    "binary_crossentropy",
-    "weighted_bce",
-    "focal",
-    "focal_loss",
-    "cb_focal",
-    "class_balanced_focal",
-    "crossentropy",
-    "ce",
-    "mse",
-    "mae",
-    # Pairwise ranking losses
-    "bpr",
-    "hinge",
-    "triplet",
-    # Listwise ranking losses
-    "sampled_softmax",
-    "softmax",
-    "infonce",
-    "listnet",
-    "listmle",
-    "approx_ndcg",
-]
-
 
 def build_cb_focal(kw):
     if "class_counts" not in kw:
@@ -60,7 +32,10 @@ def build_cb_focal(kw):
     return ClassBalancedFocalLoss(**kw)
 
 
-def get_loss_fn(loss=None, **kw) -> nn.Module:
+def get_loss_fn(
+    loss: LossName | None | nn.Module = None,
+    **kw,
+) -> nn.Module:
     """
     Get loss function by name or return the provided loss module.
 
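Usage note: `get_loss_fn` now accepts a `LossName` literal (moved to nextrec/utils/types.py) or an `nn.Module`. A short sketch, assuming the names from the removed `LossType` literal carry over unchanged to `LossName`:

    import torch.nn as nn
    from nextrec.loss.loss_utils import get_loss_fn

    bce = get_loss_fn("bce")                 # resolve by name
    mse = get_loss_fn("mse")
    custom = get_loss_fn(nn.SmoothL1Loss())  # an nn.Module is returned as-is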
 
nextrec/models/multi_task/esmm.py CHANGED
@@ -83,11 +83,10 @@ class ESMM(BaseModel):
         optimizer_params: dict | None = None,
         loss: str | nn.Module | list[str | nn.Module] | None = "bce",
         loss_params: dict | list[dict] | None = None,
-        device: str = "cpu",
-        embedding_l1_reg=1e-6,
-        dense_l1_reg=1e-5,
-        embedding_l2_reg=1e-5,
-        dense_l2_reg=1e-4,
+        embedding_l1_reg=0.0,
+        dense_l1_reg=0.0,
+        embedding_l2_reg=0.0,
+        dense_l2_reg=0.0,
         **kwargs,
     ):
 
@@ -121,7 +120,6 @@
             sequence_features=sequence_features,
             target=target,
             task=resolved_task,  # Both CTR and CTCVR are binary classification
-            device=device,
             embedding_l1_reg=embedding_l1_reg,
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
nextrec/models/multi_task/mmoe.py CHANGED
@@ -86,11 +86,10 @@ class MMOE(BaseModel):
         optimizer_params: dict | None = None,
         loss: str | nn.Module | list[str | nn.Module] | None = "bce",
         loss_params: dict | list[dict] | None = None,
-        device: str = "cpu",
-        embedding_l1_reg=1e-6,
-        dense_l1_reg=1e-5,
-        embedding_l2_reg=1e-5,
-        dense_l2_reg=1e-4,
+        embedding_l1_reg=0.0,
+        dense_l1_reg=0.0,
+        embedding_l2_reg=0.0,
+        dense_l2_reg=0.0,
         **kwargs,
     ):
 
@@ -127,7 +126,6 @@
             sequence_features=sequence_features,
             target=target,
             task=resolved_task,
-            device=device,
             embedding_l1_reg=embedding_l1_reg,
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
nextrec/models/multi_task/ple.py CHANGED
@@ -195,7 +195,7 @@ class PLE(BaseModel):
 
     @property
     def default_task(self):
-        nums_task = getattr(self, "nums_task", None)
+        nums_task = self.nums_task if hasattr(self, "nums_task") else None
         if nums_task is not None and nums_task > 0:
             return ["binary"] * nums_task
         return ["binary"]
@@ -217,11 +217,10 @@
         optimizer_params: dict | None = None,
         loss: str | nn.Module | list[str | nn.Module] | None = "bce",
         loss_params: dict | list[dict] | None = None,
-        device: str = "cpu",
-        embedding_l1_reg=1e-6,
-        dense_l1_reg=1e-5,
-        embedding_l2_reg=1e-5,
-        dense_l2_reg=1e-4,
+        embedding_l1_reg=0.0,
+        dense_l1_reg=0.0,
+        embedding_l2_reg=0.0,
+        dense_l2_reg=0.0,
         **kwargs,
     ):
 
@@ -245,7 +244,6 @@
             sequence_features=sequence_features,
             target=target,
             task=resolved_task,
-            device=device,
             embedding_l1_reg=embedding_l1_reg,
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
@@ -273,7 +271,7 @@
         # Calculate input dimension
         input_dim = self.embedding.input_dim
         # emb_dim_total = sum([f.embedding_dim for f in self.all_features if not isinstance(f, DenseFeature)])
-        # dense_input_dim = sum([getattr(f, "embedding_dim", 1) or 1 for f in dense_features])
+        # dense_input_dim = sum([(f.embedding_dim or 1) for f in dense_features])
        # input_dim = emb_dim_total + dense_input_dim
 
         # Get expert output dimension
nextrec/models/multi_task/poso.py CHANGED
@@ -290,7 +290,7 @@ class POSO(BaseModel):
 
     @property
     def default_task(self) -> list[str]:
-        nums_task = getattr(self, "nums_task", None)
+        nums_task = self.nums_task if hasattr(self, "nums_task") else None
         if nums_task is not None and nums_task > 0:
             return ["binary"] * nums_task
         return ["binary"]
@@ -327,11 +327,10 @@
         optimizer_params: dict | None = None,
         loss: str | nn.Module | list[str | nn.Module] | None = "bce",
         loss_params: dict | list[dict] | None = None,
-        device: str = "cpu",
-        embedding_l1_reg: float = 1e-6,
-        dense_l1_reg: float = 1e-5,
-        embedding_l2_reg: float = 1e-5,
-        dense_l2_reg: float = 1e-4,
+        embedding_l1_reg=0.0,
+        dense_l1_reg=0.0,
+        embedding_l2_reg=0.0,
+        dense_l2_reg=0.0,
         **kwargs,
     ):
         self.nums_task = len(target)
@@ -360,7 +359,6 @@
             sequence_features=sequence_features,
             target=target,
             task=resolved_task,
-            device=device,
             embedding_l1_reg=embedding_l1_reg,
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
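Usage note: across ESMM, MMOE, PLE, and POSO the `device` constructor argument is removed and the regularization defaults drop from small non-zero values to 0.0, so the previous implicit regularization must now be requested explicitly. A hedged sketch; the feature lists and any constructor arguments not shown in this diff are illustrative assumptions:

    from nextrec.models.multi_task.mmoe import MMOE

    model = MMOE(
        dense_features=dense_features,        # assumed to be defined elsewhere
        sparse_features=sparse_features,
        sequence_features=sequence_features,
        target=["ctr", "ctcvr"],
        embedding_l1_reg=1e-6,                # pre-0.4.21 default, now 0.0
        dense_l1_reg=1e-5,
        embedding_l2_reg=1e-5,
        dense_l2_reg=1e-4,
    )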