nextrec-0.4.2-py3-none-any.whl → nextrec-0.4.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nextrec/__version__.py CHANGED
@@ -1 +1 @@
- __version__ = "0.4.2"
+ __version__ = "0.4.4"
nextrec/basic/layers.py CHANGED
@@ -51,11 +51,11 @@ class PredictionLayer(nn.Module):
 
  # slice offsets per task
  start = 0
- self._task_slices: list[tuple[int, int]] = []
+ self.task_slices: list[tuple[int, int]] = []
  for dim in self.task_dims:
  if dim < 1:
  raise ValueError("Each task dimension must be >= 1.")
- self._task_slices.append((start, start + dim))
+ self.task_slices.append((start, start + dim))
  start += dim
  if use_bias:
  self.bias = nn.Parameter(torch.zeros(self.total_dim))
@@ -71,7 +71,7 @@ class PredictionLayer(nn.Module):
  )
  logits = x if self.bias is None else x + self.bias
  outputs = []
- for task_type, (start, end) in zip(self.task_types, self._task_slices):
+ for task_type, (start, end) in zip(self.task_types, self.task_slices):
  task_logits = logits[..., start:end] # logits for the current task
  if self.return_logits:
  outputs.append(task_logits)
@@ -367,20 +367,29 @@ class MLP(nn.Module):
  dims: list[int] | None = None,
  dropout: float = 0.0,
  activation: str = "relu",
+ use_norm: bool = True,
+ norm_type: str = "layer_norm",
  ):
  super().__init__()
  if dims is None:
  dims = []
  layers = []
  current_dim = input_dim
-
  for i_dim in dims:
  layers.append(nn.Linear(current_dim, i_dim))
- layers.append(nn.BatchNorm1d(i_dim))
+ if use_norm:
+ if norm_type == "batch_norm":
+ # **IMPORTANT** be careful when using BatchNorm1d in distributed training, nextrec does not support sync batch norm now
+ layers.append(nn.BatchNorm1d(i_dim))
+ elif norm_type == "layer_norm":
+ layers.append(nn.LayerNorm(i_dim))
+ else:
+ raise ValueError(f"Unsupported norm_type: {norm_type}")
+
  layers.append(activation_layer(activation))
  layers.append(nn.Dropout(p=dropout))
  current_dim = i_dim
-
+ # output layer
  if output_layer:
  layers.append(nn.Linear(current_dim, 1))
  self.output_dim = 1
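For reference, a minimal sketch of how the new MLP normalization switches might be used. The argument names use_norm and norm_type come from the hunk above; the import path, the output_layer flag, and the concrete dimensions are assumptions:

    # Illustrative only: import path and sizes are assumptions, not nextrec docs.
    from nextrec.basic.layers import MLP

    mlp_ln = MLP(input_dim=64, output_layer=True, dims=[128, 64])                           # default: LayerNorm after each Linear
    mlp_bn = MLP(input_dim=64, output_layer=True, dims=[128, 64], norm_type="batch_norm")   # BatchNorm1d; no sync BN in distributed runs
    mlp_plain = MLP(input_dim=64, output_layer=True, dims=[128, 64], use_norm=False)        # no normalization layers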
@@ -471,6 +480,21 @@ class BiLinearInteractionLayer(nn.Module):
  return torch.cat(bilinear_list, dim=1)
 
 
+ class HadamardInteractionLayer(nn.Module):
+ """Hadamard interaction layer for Deep-FiBiNET (0 case in 01/11)."""
+
+ def __init__(self, num_fields: int):
+ super().__init__()
+ self.num_fields = num_fields
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # x: [B, F, D]
+ feature_emb = torch.split(x, 1, dim=1) # list of F tensors [B,1,D]
+
+ hadamard_list = [v_i * v_j for (v_i, v_j) in combinations(feature_emb, 2)]
+ return torch.cat(hadamard_list, dim=1) # [B, num_pairs, D]
+
+
  class MultiHeadSelfAttention(nn.Module):
  def __init__(
  self,
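A quick shape check for the new layer (illustrative; the import path and sizes are assumptions, and the pair count follows from itertools.combinations over F fields):

    import torch
    from nextrec.basic.layers import HadamardInteractionLayer  # assumed import path

    B, F, D = 32, 5, 16
    layer = HadamardInteractionLayer(num_fields=F)
    out = layer(torch.randn(B, F, D))
    # combinations of 5 fields taken 2 at a time -> 5 * 4 / 2 = 10 pairs
    assert out.shape == (B, F * (F - 1) // 2, D)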
@@ -542,7 +566,7 @@ class AttentionPoolingLayer(nn.Module):
  embedding_dim: int,
  hidden_units: list = [80, 40],
  activation: str = "sigmoid",
- use_softmax: bool = True,
+ use_softmax: bool = False,
  ):
  super().__init__()
  self.embedding_dim = embedding_dim
@@ -553,7 +577,7 @@ class AttentionPoolingLayer(nn.Module):
  layers = []
  for hidden_unit in hidden_units:
  layers.append(nn.Linear(input_dim, hidden_unit))
- layers.append(activation_layer(activation))
+ layers.append(activation_layer(activation, emb_size=hidden_unit))
  input_dim = hidden_unit
  layers.append(nn.Linear(input_dim, 1))
  self.attention_net = nn.Sequential(*layers)
nextrec/basic/loggers.py CHANGED
@@ -103,7 +103,7 @@ def setup_logger(session_id: str | os.PathLike | None = None):
  session = create_session(str(session_id) if session_id is not None else None)
  log_dir = session.logs_dir
  log_dir.mkdir(parents=True, exist_ok=True)
- log_file = log_dir / f"{session.log_basename}.log"
+ log_file = log_dir / "runs.log"
 
  console_format = "%(message)s"
  file_format = "%(asctime)s - %(levelname)s - %(message)s"
nextrec/basic/metrics.py CHANGED
@@ -260,7 +260,7 @@ def compute_mrr_at_k(
  order = np.argsort(scores)[::-1]
  k_user = min(k, idx.size)
  topk = order[:k_user]
- ranked_labels = labels[order]
+ ranked_labels = labels[topk]
  rr = 0.0
  for rank, lab in enumerate(ranked_labels[:k_user], start=1):
  if lab > 0:
@@ -612,6 +612,7 @@ def evaluate_metrics(
  if task_type in ["binary", "multilabel"]:
  should_compute = metric_lower in {
  "auc",
+ "gauc",
  "ks",
  "logloss",
  "accuracy",
nextrec/basic/model.py CHANGED
@@ -455,7 +455,7 @@ class BaseModel(FeatureSet, nn.Module):
  if hasattr(
  self, "prediction_layer"
  ): # we need to use registered task_slices for multi-task and multi-class
- slices = self.prediction_layer._task_slices # type: ignore
+ slices = self.prediction_layer.task_slices # type: ignore
  else:
  slices = [(i, i + 1) for i in range(self.nums_task)]
  task_losses = []
@@ -1369,7 +1369,7 @@ class BaseModel(FeatureSet, nn.Module):
  pred_columns: list[str] = []
  if self.target_columns:
  for name in self.target_columns[:num_outputs]:
- pred_columns.append(f"{name}_pred")
+ pred_columns.append(f"{name}")
  while len(pred_columns) < num_outputs:
  pred_columns.append(f"pred_{len(pred_columns)}")
  if include_ids and predict_id_columns:
@@ -1496,7 +1496,7 @@ class BaseModel(FeatureSet, nn.Module):
  pred_columns = []
  if self.target_columns:
  for name in self.target_columns[:num_outputs]:
- pred_columns.append(f"{name}_pred")
+ pred_columns.append(f"{name}")
  while len(pred_columns) < num_outputs:
  pred_columns.append(f"pred_{len(pred_columns)}")
 
nextrec/cli.py CHANGED
@@ -8,10 +8,10 @@ following script to execute the desired operations.
 
  Examples:
  # Train a model
- nextrec --mode=train --train_config=tutorials/iflytek/scripts/masknet/train_config.yaml
+ nextrec --mode=train --train_config=nextrec_cli_preset/train_config.yaml
 
  # Run prediction
- nextrec --mode=predict --predict_config=tutorials/iflytek/scripts/masknet/predict_config.yaml
+ nextrec --mode=predict --predict_config=nextrec_cli_preset/predict_config.yaml
 
  Date: create on 06/12/2025
  Author: Yang Zhou, zyaztec@gmail.com
@@ -33,7 +33,6 @@ from nextrec.data.preprocessor import DataProcessor
  from nextrec.utils.config import (
  build_feature_objects,
  build_model_instance,
- extract_feature_groups,
  register_processor_features,
  resolve_path,
  select_features,
@@ -115,16 +114,13 @@ def train_model(train_config_path: str) -> None:
  df = read_table(data_path, data_cfg.get("format"))
  df_columns = list(df.columns)
 
- # for some models have independent feature groups, we need to extract them here
- feature_groups, grouped_columns = extract_feature_groups(feature_cfg, df_columns)
- if feature_groups:
- model_cfg.setdefault("params", {})
- model_cfg["params"].setdefault("feature_groups", feature_groups)
-
  dense_names, sparse_names, sequence_names = select_features(feature_cfg, df_columns)
- used_columns = (
- dense_names + sparse_names + sequence_names + grouped_columns + target
- )
+
+ # Extract id_column from data config for GAUC metrics
+ id_column = data_cfg.get("id_column") or data_cfg.get("user_id_column")
+ id_columns = [id_column] if id_column else []
+
+ used_columns = dense_names + sparse_names + sequence_names + target + id_columns
 
  # keep order but drop duplicates
  seen = set()
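The new id handling reads optional keys from the parsed train config: id_column (or user_id_column) in the data section feeds GAUC via fit(user_id_column=...) later in this file, and num_workers in the dataloader section is forwarded to every create_dataloader call in the hunks below. A hypothetical config fragment in dict form; the key names come from the lookups in the diff, while the surrounding structure is an assumption about nextrec's YAML layout:

    # Hypothetical parsed train config fragment (dict form), not an official schema.
    train_config = {
        "data": {
            "path": "train.parquet",   # placeholder
            "id_column": "user_id",    # or "user_id_column"; enables GAUC
            "valid_ratio": 0.2,
        },
        "dataloader": {
            "train_batch_size": 512,
            "valid_batch_size": 512,
            "num_workers": 4,          # new: passed through to the dataloaders
        },
    }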
@@ -183,17 +179,16 @@ def train_model(train_config_path: str) -> None:
  streaming_valid_files = file_paths[-val_count:]
  streaming_train_files = file_paths[:-val_count]
  logger.info(
- "使用 valid_ratio=%.3f 切分文件: 训练 %d 个文件, 验证 %d 个文件",
- ratio,
- len(streaming_train_files),
- len(streaming_valid_files),
+ f"Split files for streaming training and validation using valid_ratio={ratio:.3f}: training {len(streaming_train_files)} files, validation {len(streaming_valid_files)} files"
  )
  train_data: Dict[str, Any]
  valid_data: Dict[str, Any] | None
 
  if val_data_path and not streaming:
  # Use specified validation dataset path
- logger.info("使用指定的验证集路径: %s", val_data_path)
+ logger.info(
+ f"Validation using specified validation dataset path: {val_data_path}"
+ )
  val_data_resolved = resolve_path(val_data_path, config_dir)
  val_df = read_table(val_data_resolved, data_cfg.get("format"))
  val_df = val_df[unique_used_columns]
@@ -206,17 +201,21 @@ def train_model(train_config_path: str) -> None:
  valid_data = valid_data_result
  train_size = len(list(train_data.values())[0])
  valid_size = len(list(valid_data.values())[0])
- logger.info("训练集样本数: %s, 验证集样本数: %s", train_size, valid_size)
+ logger.info(
+ f"Sample count - Training set: {train_size}, Validation set: {valid_size}"
+ )
  elif streaming:
  train_data = None # type: ignore[assignment]
  valid_data = None
  if not val_data_path and not streaming_valid_files:
  logger.info(
- "流式训练模式,未指定验证集路径且未配置 valid_ratio,跳过验证集创建"
+ "Streaming training mode: No validation dataset path specified and valid_ratio not configured, skipping validation dataset creation"
  )
  else:
  # Split data using valid_ratio
- logger.info("使用 valid_ratio 切分数据: %s", data_cfg.get("valid_ratio", 0.2))
+ logger.info(
+ f"Splitting data using valid_ratio: {data_cfg.get('valid_ratio', 0.2)}"
+ )
  if not isinstance(processed, dict):
  raise TypeError("Processed data must be a dictionary for splitting")
  train_data, valid_data = split_dict_random(
@@ -230,6 +229,7 @@ def train_model(train_config_path: str) -> None:
  sparse_features=sparse_features,
  sequence_features=sequence_features,
  target=target,
+ id_columns=id_columns,
  processor=processor if streaming else None,
  )
  if streaming:
@@ -240,6 +240,7 @@ def train_model(train_config_path: str) -> None:
  shuffle=dataloader_cfg.get("train_shuffle", True),
  load_full=False,
  chunk_size=dataloader_chunk_size,
+ num_workers=dataloader_cfg.get("num_workers", 0),
  )
  valid_loader = None
  if val_data_path:
@@ -250,6 +251,7 @@ def train_model(train_config_path: str) -> None:
  shuffle=dataloader_cfg.get("valid_shuffle", False),
  load_full=False,
  chunk_size=dataloader_chunk_size,
+ num_workers=dataloader_cfg.get("num_workers", 0),
  )
  elif streaming_valid_files:
  valid_loader = dataloader.create_dataloader(
@@ -258,17 +260,20 @@ def train_model(train_config_path: str) -> None:
  shuffle=dataloader_cfg.get("valid_shuffle", False),
  load_full=False,
  chunk_size=dataloader_chunk_size,
+ num_workers=dataloader_cfg.get("num_workers", 0),
  )
  else:
  train_loader = dataloader.create_dataloader(
  data=train_data,
  batch_size=dataloader_cfg.get("train_batch_size", 512),
  shuffle=dataloader_cfg.get("train_shuffle", True),
+ num_workers=dataloader_cfg.get("num_workers", 0),
  )
  valid_loader = dataloader.create_dataloader(
  data=valid_data,
  batch_size=dataloader_cfg.get("valid_batch_size", 512),
  shuffle=dataloader_cfg.get("valid_shuffle", False),
+ num_workers=dataloader_cfg.get("num_workers", 0),
  )
 
  model_cfg.setdefault("session_id", session_id)
@@ -300,6 +305,9 @@ def train_model(train_config_path: str) -> None:
  "batch_size", dataloader_cfg.get("train_batch_size", 512)
  ),
  shuffle=train_cfg.get("shuffle", True),
+ num_workers=dataloader_cfg.get("num_workers", 0),
+ user_id_column=id_column,
+ tensorboard=False,
  )
 
 
@@ -325,19 +333,15 @@ def predict_model(predict_config_path: str) -> None:
  model_cfg_path = resolve_path(
  cfg.get("model_config", "model_config.yaml"), config_dir
  )
- feature_cfg_path = resolve_path(
- cfg.get("feature_config", "feature_config.yaml"), config_dir
- )
+ # feature_cfg_path = resolve_path(
+ # cfg.get("feature_config", "feature_config.yaml"), config_dir
+ # )
 
  model_cfg = read_yaml(model_cfg_path)
- feature_cfg = read_yaml(feature_cfg_path)
+ # feature_cfg = read_yaml(feature_cfg_path)
  model_cfg.setdefault("session_id", session_id)
- feature_groups_raw = feature_cfg.get("feature_groups") or {}
  model_cfg.setdefault("params", {})
 
- # attach feature_groups in predict phase to avoid missing bindings
- model_cfg["params"]["feature_groups"] = feature_groups_raw
-
  processor = DataProcessor.load(processor_path)
 
  # Load checkpoint and ensure required parameters are passed
@@ -383,13 +387,6 @@ def predict_model(predict_config_path: str) -> None:
  if target_override:
  target_cols = normalize_to_list(target_override)
 
- # Recompute feature_groups with available feature names to drive bindings
- feature_group_names = [f.name for f in all_features if hasattr(f, "name")]
- parsed_feature_groups, _ = extract_feature_groups(feature_cfg, feature_group_names)
- if parsed_feature_groups:
- model_cfg.setdefault("params", {})
- model_cfg["params"]["feature_groups"] = parsed_feature_groups
-
  model = build_model_instance(
  model_cfg=model_cfg,
  model_cfg_path=model_cfg_path,
@@ -440,6 +437,7 @@ def predict_model(predict_config_path: str) -> None:
  return_dataframe=False,
  save_path=output_path,
  save_format=predict_cfg.get("save_format", "csv"),
+ num_workers=predict_cfg.get("num_workers", 0),
  )
  duration = time.time() - start
  logger.info(f"Prediction completed, results saved to: {output_path}")
@@ -448,7 +446,7 @@ def predict_model(predict_config_path: str) -> None:
  preview_rows = predict_cfg.get("preview_rows", 0)
  if preview_rows > 0:
  try:
- preview = pd.read_csv(output_path, nrows=preview_rows)
+ preview = pd.read_csv(output_path, nrows=preview_rows, low_memory=False)
  logger.info(f"Output preview:\n{preview}")
  except Exception as exc: # pragma: no cover
  logger.warning(f"Failed to read output preview: {exc}")
@@ -472,25 +470,21 @@ Examples:
  "--mode",
  choices=["train", "predict"],
  required=True,
- help="运行模式:train predict",
- )
- parser.add_argument("--train_config", help="训练配置文件路径")
- parser.add_argument("--predict_config", help="预测配置文件路径")
- parser.add_argument(
- "--config",
- help="通用配置文件路径(已废弃,建议使用 --train_config 或 --predict_config)",
+ help="Running mode: train or predict",
  )
+ parser.add_argument("--train_config", help="Training configuration file path")
+ parser.add_argument("--predict_config", help="Prediction configuration file path")
  args = parser.parse_args()
 
  if args.mode == "train":
- config_path = args.train_config or args.config
+ config_path = args.train_config
  if not config_path:
- parser.error("train 模式需要提供 --train_config")
+ parser.error("[NextRec CLI Error] train mode requires --train_config")
  train_model(config_path)
  else:
- config_path = args.predict_config or args.config
+ config_path = args.predict_config
  if not config_path:
- parser.error("predict 模式需要提供 --predict_config")
+ parser.error("[NextRec CLI Error] predict mode requires --predict_config")
  predict_model(config_path)
 
 
@@ -322,7 +322,7 @@ class RecDataLoader(FeatureSet):
  except OSError:
  pass
  try:
- df = read_table(file_path, file_type=file_type)
+ df = read_table(file_path, data_format=file_type)
  dfs.append(df)
  except MemoryError as exc:
  raise MemoryError(
@@ -76,10 +76,10 @@ class ESMM(BaseModel):
  sequence_features: list[SequenceFeature],
  ctr_params: dict,
  cvr_params: dict,
- target: list[str] = ["ctr", "ctcvr"], # Note: ctcvr = ctr * cvr
+ target: list[str] | None = None, # Note: ctcvr = ctr * cvr
  task: list[str] | None = None,
  optimizer: str = "adam",
- optimizer_params: dict = {},
+ optimizer_params: dict | None = None,
  loss: str | nn.Module | list[str | nn.Module] | None = "bce",
  loss_params: dict | list[dict] | None = None,
  device: str = "cpu",
@@ -90,19 +90,36 @@ class ESMM(BaseModel):
  **kwargs,
  ):
 
- # ESMM requires exactly 2 targets: ctr and ctcvr
+ target = target or ["ctr", "ctcvr"]
+ optimizer_params = optimizer_params or {}
+ if loss is None:
+ loss = "bce"
+
  if len(target) != 2:
  raise ValueError(
  f"ESMM requires exactly 2 targets (ctr and ctcvr), got {len(target)}"
  )
 
+ self.num_tasks = len(target)
+ resolved_task = task
+ if resolved_task is None:
+ resolved_task = self.default_task
+ elif isinstance(resolved_task, str):
+ resolved_task = [resolved_task] * self.num_tasks
+ elif len(resolved_task) == 1 and self.num_tasks > 1:
+ resolved_task = resolved_task * self.num_tasks
+ elif len(resolved_task) != self.num_tasks:
+ raise ValueError(
+ f"Length of task ({len(resolved_task)}) must match number of targets ({self.num_tasks})."
+ )
+ # resolved_task is now guaranteed to be a list[str]
+
  super(ESMM, self).__init__(
  dense_features=dense_features,
  sparse_features=sparse_features,
  sequence_features=sequence_features,
  target=target,
- task=task
- or self.default_task, # Both CTR and CTCVR are binary classification
+ task=resolved_task, # Both CTR and CTCVR are binary classification
  device=device,
  embedding_l1_reg=embedding_l1_reg,
  dense_l1_reg=dense_l1_reg,
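ESMM, MMOE, and PLE now share the same task-resolution block. A standalone reproduction of the broadcast rule for illustration; the models inline this logic rather than calling such a helper:

    def resolve_task(task: str | list[str] | None, num_tasks: int, default: list[str]) -> list[str]:
        # None -> model default; str -> repeated per task; single-item list -> broadcast;
        # otherwise the length must match the number of targets.
        if task is None:
            return default
        if isinstance(task, str):
            return [task] * num_tasks
        if len(task) == 1 and num_tasks > 1:
            return task * num_tasks
        if len(task) != num_tasks:
            raise ValueError(f"Length of task ({len(task)}) must match number of targets ({num_tasks}).")
        return task

    assert resolve_task("binary", 2, ["binary", "binary"]) == ["binary", "binary"]
    assert resolve_task(["binary"], 2, ["binary", "binary"]) == ["binary", "binary"]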
@@ -112,19 +129,9 @@ class ESMM(BaseModel):
  )
 
  self.loss = loss
- if self.loss is None:
- self.loss = "bce"
 
- # All features
- self.all_features = dense_features + sparse_features + sequence_features
- # Shared embedding layer
  self.embedding = EmbeddingLayer(features=self.all_features)
- input_dim = (
- self.embedding.input_dim
- ) # Calculate input dimension, better way than below
- # emb_dim_total = sum([f.embedding_dim for f in self.all_features if not isinstance(f, DenseFeature)])
- # dense_input_dim = sum([getattr(f, "embedding_dim", 1) or 1 for f in dense_features])
- # input_dim = emb_dim_total + dense_input_dim
+ input_dim = self.embedding.input_dim
 
  # CTR tower
  self.ctr_tower = MLP(input_dim=input_dim, output_layer=True, **ctr_params)
@@ -73,16 +73,16 @@ class MMOE(BaseModel):
 
  def __init__(
  self,
- dense_features: list[DenseFeature] = [],
- sparse_features: list[SparseFeature] = [],
- sequence_features: list[SequenceFeature] = [],
- expert_params: dict = {},
+ dense_features: list[DenseFeature] | None = None,
+ sparse_features: list[SparseFeature] | None = None,
+ sequence_features: list[SequenceFeature] | None = None,
+ expert_params: dict | None = None,
  num_experts: int = 3,
- tower_params_list: list[dict] = [],
- target: list[str] = [],
- task: str | list[str] | None = None,
+ tower_params_list: list[dict] | None = None,
+ target: list[str] | str | None = None,
+ task: str | list[str] = "binary",
  optimizer: str = "adam",
- optimizer_params: dict = {},
+ optimizer_params: dict | None = None,
  loss: str | nn.Module | list[str | nn.Module] | None = "bce",
  loss_params: dict | list[dict] | None = None,
  device: str = "cpu",
@@ -93,14 +93,39 @@ class MMOE(BaseModel):
  **kwargs,
  ):
 
- self.num_tasks = len(target)
+ dense_features = dense_features or []
+ sparse_features = sparse_features or []
+ sequence_features = sequence_features or []
+ expert_params = expert_params or {}
+ tower_params_list = tower_params_list or []
+ optimizer_params = optimizer_params or {}
+ if loss is None:
+ loss = "bce"
+ if target is None:
+ target = []
+ elif isinstance(target, str):
+ target = [target]
+
+ self.num_tasks = len(target) if target else 1
+
+ resolved_task = task
+ if resolved_task is None:
+ resolved_task = self.default_task
+ elif isinstance(resolved_task, str):
+ resolved_task = [resolved_task] * self.num_tasks
+ elif len(resolved_task) == 1 and self.num_tasks > 1:
+ resolved_task = resolved_task * self.num_tasks
+ elif len(resolved_task) != self.num_tasks:
+ raise ValueError(
+ f"Length of task ({len(resolved_task)}) must match number of targets ({self.num_tasks})."
+ )
 
  super(MMOE, self).__init__(
  dense_features=dense_features,
  sparse_features=sparse_features,
  sequence_features=sequence_features,
  target=target,
- task=task or self.default_task,
+ task=resolved_task,
  device=device,
  embedding_l1_reg=embedding_l1_reg,
  dense_l1_reg=dense_l1_reg,
@@ -110,8 +135,6 @@ class MMOE(BaseModel):
  )
 
  self.loss = loss
- if self.loss is None:
- self.loss = "bce"
 
  # Number of tasks and experts
  self.num_tasks = len(target)
@@ -122,12 +145,8 @@ class MMOE(BaseModel):
  f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})"
  )
 
- self.all_features = dense_features + sparse_features + sequence_features
  self.embedding = EmbeddingLayer(features=self.all_features)
  input_dim = self.embedding.input_dim
- # emb_dim_total = sum([f.embedding_dim for f in self.all_features if not isinstance(f, DenseFeature)])
- # dense_input_dim = sum([getattr(f, "embedding_dim", 1) or 1 for f in dense_features])
- # input_dim = emb_dim_total + dense_input_dim
 
  # Expert networks (shared by all tasks)
  self.experts = nn.ModuleList()
@@ -162,7 +181,7 @@ class MMOE(BaseModel):
  self.compile(
  optimizer=optimizer,
  optimizer_params=optimizer_params,
- loss=loss,
+ loss=self.loss,
  loss_params=loss_params,
  )
 
@@ -51,6 +51,7 @@ import torch.nn as nn
  from nextrec.basic.model import BaseModel
  from nextrec.basic.layers import EmbeddingLayer, MLP, PredictionLayer
  from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
+ from nextrec.utils.model import get_mlp_output_dim
 
 
  class CGCLayer(nn.Module):
@@ -72,13 +73,13 @@ class CGCLayer(nn.Module):
  if num_tasks < 1:
  raise ValueError("num_tasks must be >= 1")
 
- specific_params_list = self._normalize_specific_params(
+ specific_params_list = self.normalize_specific_params(
  specific_expert_params, num_tasks
  )
 
- self.output_dim = self._get_output_dim(shared_expert_params, input_dim)
+ self.output_dim = get_mlp_output_dim(shared_expert_params, input_dim)
  specific_dims = [
- self._get_output_dim(params, input_dim) for params in specific_params_list
+ get_mlp_output_dim(params, input_dim) for params in specific_params_list
  ]
  dims_set = set(specific_dims + [self.output_dim])
  if len(dims_set) != 1:
@@ -165,14 +166,7 @@ class CGCLayer(nn.Module):
  return new_task_fea, new_shared
 
  @staticmethod
- def _get_output_dim(params: dict, fallback: int) -> int:
- dims = params.get("dims")
- if dims:
- return dims[-1]
- return fallback
-
- @staticmethod
- def _normalize_specific_params(
+ def normalize_specific_params(
  params: dict | list[dict], num_tasks: int
  ) -> list[dict]:
  if isinstance(params, list):
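The deleted _get_output_dim staticmethod is replaced by get_mlp_output_dim imported from nextrec.utils.model. Assuming the utility keeps the contract of the removed helper, its behavior amounts to:

    # Assumed behavior of get_mlp_output_dim, mirroring the removed staticmethod:
    # the last entry of params["dims"] when present, otherwise the fallback dimension.
    def get_mlp_output_dim(params: dict, fallback: int) -> int:
        dims = params.get("dims")
        return dims[-1] if dims else fallback

    assert get_mlp_output_dim({"dims": [256, 128]}, 64) == 128
    assert get_mlp_output_dim({}, 64) == 64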
@@ -232,12 +226,24 @@ class PLE(BaseModel):
 
  self.num_tasks = len(target)
 
+ resolved_task = task
+ if resolved_task is None:
+ resolved_task = self.default_task
+ elif isinstance(resolved_task, str):
+ resolved_task = [resolved_task] * self.num_tasks
+ elif len(resolved_task) == 1 and self.num_tasks > 1:
+ resolved_task = resolved_task * self.num_tasks
+ elif len(resolved_task) != self.num_tasks:
+ raise ValueError(
+ f"Length of task ({len(resolved_task)}) must match number of targets ({self.num_tasks})."
+ )
+
  super(PLE, self).__init__(
  dense_features=dense_features,
  sparse_features=sparse_features,
  sequence_features=sequence_features,
  target=target,
- task=task or self.default_task,
+ task=resolved_task,
  device=device,
  embedding_l1_reg=embedding_l1_reg,
  dense_l1_reg=dense_l1_reg,