nextrec 0.4.2__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/layers.py +32 -8
- nextrec/basic/loggers.py +1 -1
- nextrec/basic/metrics.py +2 -1
- nextrec/basic/model.py +3 -3
- nextrec/cli.py +41 -47
- nextrec/data/dataloader.py +1 -1
- nextrec/models/multi_task/esmm.py +23 -16
- nextrec/models/multi_task/mmoe.py +36 -17
- nextrec/models/multi_task/ple.py +18 -12
- nextrec/models/multi_task/poso.py +68 -37
- nextrec/models/multi_task/share_bottom.py +16 -2
- nextrec/models/ranking/afm.py +14 -14
- nextrec/models/ranking/autoint.py +2 -2
- nextrec/models/ranking/dcn.py +61 -19
- nextrec/models/ranking/dcn_v2.py +224 -45
- nextrec/models/ranking/deepfm.py +14 -9
- nextrec/models/ranking/dien.py +215 -82
- nextrec/models/ranking/din.py +95 -57
- nextrec/models/ranking/fibinet.py +92 -30
- nextrec/models/ranking/fm.py +44 -8
- nextrec/models/ranking/masknet.py +7 -7
- nextrec/models/ranking/pnn.py +105 -38
- nextrec/models/ranking/widedeep.py +8 -4
- nextrec/models/ranking/xdeepfm.py +57 -10
- nextrec/utils/config.py +15 -3
- nextrec/utils/file.py +2 -1
- nextrec/utils/initializer.py +12 -16
- nextrec/utils/model.py +22 -0
- {nextrec-0.4.2.dist-info → nextrec-0.4.4.dist-info}/METADATA +57 -22
- {nextrec-0.4.2.dist-info → nextrec-0.4.4.dist-info}/RECORD +34 -34
- {nextrec-0.4.2.dist-info → nextrec-0.4.4.dist-info}/WHEEL +0 -0
- {nextrec-0.4.2.dist-info → nextrec-0.4.4.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.2.dist-info → nextrec-0.4.4.dist-info}/licenses/LICENSE +0 -0
nextrec/__version__.py
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
__version__ = "0.4.
|
|
1
|
+
__version__ = "0.4.4"
|
nextrec/basic/layers.py
CHANGED
|
@@ -51,11 +51,11 @@ class PredictionLayer(nn.Module):
|
|
|
51
51
|
|
|
52
52
|
# slice offsets per task
|
|
53
53
|
start = 0
|
|
54
|
-
self.
|
|
54
|
+
self.task_slices: list[tuple[int, int]] = []
|
|
55
55
|
for dim in self.task_dims:
|
|
56
56
|
if dim < 1:
|
|
57
57
|
raise ValueError("Each task dimension must be >= 1.")
|
|
58
|
-
self.
|
|
58
|
+
self.task_slices.append((start, start + dim))
|
|
59
59
|
start += dim
|
|
60
60
|
if use_bias:
|
|
61
61
|
self.bias = nn.Parameter(torch.zeros(self.total_dim))
|
|
@@ -71,7 +71,7 @@ class PredictionLayer(nn.Module):
|
|
|
71
71
|
)
|
|
72
72
|
logits = x if self.bias is None else x + self.bias
|
|
73
73
|
outputs = []
|
|
74
|
-
for task_type, (start, end) in zip(self.task_types, self.
|
|
74
|
+
for task_type, (start, end) in zip(self.task_types, self.task_slices):
|
|
75
75
|
task_logits = logits[..., start:end] # logits for the current task
|
|
76
76
|
if self.return_logits:
|
|
77
77
|
outputs.append(task_logits)
|
|
@@ -367,20 +367,29 @@ class MLP(nn.Module):
|
|
|
367
367
|
dims: list[int] | None = None,
|
|
368
368
|
dropout: float = 0.0,
|
|
369
369
|
activation: str = "relu",
|
|
370
|
+
use_norm: bool = True,
|
|
371
|
+
norm_type: str = "layer_norm",
|
|
370
372
|
):
|
|
371
373
|
super().__init__()
|
|
372
374
|
if dims is None:
|
|
373
375
|
dims = []
|
|
374
376
|
layers = []
|
|
375
377
|
current_dim = input_dim
|
|
376
|
-
|
|
377
378
|
for i_dim in dims:
|
|
378
379
|
layers.append(nn.Linear(current_dim, i_dim))
|
|
379
|
-
|
|
380
|
+
if use_norm:
|
|
381
|
+
if norm_type == "batch_norm":
|
|
382
|
+
# **IMPORTANT** be careful when using BatchNorm1d in distributed training, nextrec does not support sync batch norm now
|
|
383
|
+
layers.append(nn.BatchNorm1d(i_dim))
|
|
384
|
+
elif norm_type == "layer_norm":
|
|
385
|
+
layers.append(nn.LayerNorm(i_dim))
|
|
386
|
+
else:
|
|
387
|
+
raise ValueError(f"Unsupported norm_type: {norm_type}")
|
|
388
|
+
|
|
380
389
|
layers.append(activation_layer(activation))
|
|
381
390
|
layers.append(nn.Dropout(p=dropout))
|
|
382
391
|
current_dim = i_dim
|
|
383
|
-
|
|
392
|
+
# output layer
|
|
384
393
|
if output_layer:
|
|
385
394
|
layers.append(nn.Linear(current_dim, 1))
|
|
386
395
|
self.output_dim = 1
|
|
@@ -471,6 +480,21 @@ class BiLinearInteractionLayer(nn.Module):
|
|
|
471
480
|
return torch.cat(bilinear_list, dim=1)
|
|
472
481
|
|
|
473
482
|
|
|
483
|
+
class HadamardInteractionLayer(nn.Module):
|
|
484
|
+
"""Hadamard interaction layer for Deep-FiBiNET (0 case in 01/11)."""
|
|
485
|
+
|
|
486
|
+
def __init__(self, num_fields: int):
|
|
487
|
+
super().__init__()
|
|
488
|
+
self.num_fields = num_fields
|
|
489
|
+
|
|
490
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
491
|
+
# x: [B, F, D]
|
|
492
|
+
feature_emb = torch.split(x, 1, dim=1) # list of F tensors [B,1,D]
|
|
493
|
+
|
|
494
|
+
hadamard_list = [v_i * v_j for (v_i, v_j) in combinations(feature_emb, 2)]
|
|
495
|
+
return torch.cat(hadamard_list, dim=1) # [B, num_pairs, D]
|
|
496
|
+
|
|
497
|
+
|
|
474
498
|
class MultiHeadSelfAttention(nn.Module):
|
|
475
499
|
def __init__(
|
|
476
500
|
self,
|
|
@@ -542,7 +566,7 @@ class AttentionPoolingLayer(nn.Module):
|
|
|
542
566
|
embedding_dim: int,
|
|
543
567
|
hidden_units: list = [80, 40],
|
|
544
568
|
activation: str = "sigmoid",
|
|
545
|
-
use_softmax: bool =
|
|
569
|
+
use_softmax: bool = False,
|
|
546
570
|
):
|
|
547
571
|
super().__init__()
|
|
548
572
|
self.embedding_dim = embedding_dim
|
|
@@ -553,7 +577,7 @@ class AttentionPoolingLayer(nn.Module):
|
|
|
553
577
|
layers = []
|
|
554
578
|
for hidden_unit in hidden_units:
|
|
555
579
|
layers.append(nn.Linear(input_dim, hidden_unit))
|
|
556
|
-
layers.append(activation_layer(activation))
|
|
580
|
+
layers.append(activation_layer(activation, emb_size=hidden_unit))
|
|
557
581
|
input_dim = hidden_unit
|
|
558
582
|
layers.append(nn.Linear(input_dim, 1))
|
|
559
583
|
self.attention_net = nn.Sequential(*layers)
|
nextrec/basic/loggers.py
CHANGED
|
@@ -103,7 +103,7 @@ def setup_logger(session_id: str | os.PathLike | None = None):
|
|
|
103
103
|
session = create_session(str(session_id) if session_id is not None else None)
|
|
104
104
|
log_dir = session.logs_dir
|
|
105
105
|
log_dir.mkdir(parents=True, exist_ok=True)
|
|
106
|
-
log_file = log_dir /
|
|
106
|
+
log_file = log_dir / "runs.log"
|
|
107
107
|
|
|
108
108
|
console_format = "%(message)s"
|
|
109
109
|
file_format = "%(asctime)s - %(levelname)s - %(message)s"
|
nextrec/basic/metrics.py
CHANGED
|
@@ -260,7 +260,7 @@ def compute_mrr_at_k(
|
|
|
260
260
|
order = np.argsort(scores)[::-1]
|
|
261
261
|
k_user = min(k, idx.size)
|
|
262
262
|
topk = order[:k_user]
|
|
263
|
-
ranked_labels = labels[
|
|
263
|
+
ranked_labels = labels[topk]
|
|
264
264
|
rr = 0.0
|
|
265
265
|
for rank, lab in enumerate(ranked_labels[:k_user], start=1):
|
|
266
266
|
if lab > 0:
|
|
@@ -612,6 +612,7 @@ def evaluate_metrics(
|
|
|
612
612
|
if task_type in ["binary", "multilabel"]:
|
|
613
613
|
should_compute = metric_lower in {
|
|
614
614
|
"auc",
|
|
615
|
+
"gauc",
|
|
615
616
|
"ks",
|
|
616
617
|
"logloss",
|
|
617
618
|
"accuracy",
|
nextrec/basic/model.py
CHANGED
|
@@ -455,7 +455,7 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
455
455
|
if hasattr(
|
|
456
456
|
self, "prediction_layer"
|
|
457
457
|
): # we need to use registered task_slices for multi-task and multi-class
|
|
458
|
-
slices = self.prediction_layer.
|
|
458
|
+
slices = self.prediction_layer.task_slices # type: ignore
|
|
459
459
|
else:
|
|
460
460
|
slices = [(i, i + 1) for i in range(self.nums_task)]
|
|
461
461
|
task_losses = []
|
|
@@ -1369,7 +1369,7 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
1369
1369
|
pred_columns: list[str] = []
|
|
1370
1370
|
if self.target_columns:
|
|
1371
1371
|
for name in self.target_columns[:num_outputs]:
|
|
1372
|
-
pred_columns.append(f"{name}
|
|
1372
|
+
pred_columns.append(f"{name}")
|
|
1373
1373
|
while len(pred_columns) < num_outputs:
|
|
1374
1374
|
pred_columns.append(f"pred_{len(pred_columns)}")
|
|
1375
1375
|
if include_ids and predict_id_columns:
|
|
@@ -1496,7 +1496,7 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
1496
1496
|
pred_columns = []
|
|
1497
1497
|
if self.target_columns:
|
|
1498
1498
|
for name in self.target_columns[:num_outputs]:
|
|
1499
|
-
pred_columns.append(f"{name}
|
|
1499
|
+
pred_columns.append(f"{name}")
|
|
1500
1500
|
while len(pred_columns) < num_outputs:
|
|
1501
1501
|
pred_columns.append(f"pred_{len(pred_columns)}")
|
|
1502
1502
|
|
nextrec/cli.py
CHANGED
|
@@ -8,10 +8,10 @@ following script to execute the desired operations.
|
|
|
8
8
|
|
|
9
9
|
Examples:
|
|
10
10
|
# Train a model
|
|
11
|
-
nextrec --mode=train --train_config=
|
|
11
|
+
nextrec --mode=train --train_config=nextrec_cli_preset/train_config.yaml
|
|
12
12
|
|
|
13
13
|
# Run prediction
|
|
14
|
-
nextrec --mode=predict --predict_config=
|
|
14
|
+
nextrec --mode=predict --predict_config=nextrec_cli_preset/predict_config.yaml
|
|
15
15
|
|
|
16
16
|
Date: create on 06/12/2025
|
|
17
17
|
Author: Yang Zhou, zyaztec@gmail.com
|
|
@@ -33,7 +33,6 @@ from nextrec.data.preprocessor import DataProcessor
|
|
|
33
33
|
from nextrec.utils.config import (
|
|
34
34
|
build_feature_objects,
|
|
35
35
|
build_model_instance,
|
|
36
|
-
extract_feature_groups,
|
|
37
36
|
register_processor_features,
|
|
38
37
|
resolve_path,
|
|
39
38
|
select_features,
|
|
@@ -115,16 +114,13 @@ def train_model(train_config_path: str) -> None:
|
|
|
115
114
|
df = read_table(data_path, data_cfg.get("format"))
|
|
116
115
|
df_columns = list(df.columns)
|
|
117
116
|
|
|
118
|
-
# for some models have independent feature groups, we need to extract them here
|
|
119
|
-
feature_groups, grouped_columns = extract_feature_groups(feature_cfg, df_columns)
|
|
120
|
-
if feature_groups:
|
|
121
|
-
model_cfg.setdefault("params", {})
|
|
122
|
-
model_cfg["params"].setdefault("feature_groups", feature_groups)
|
|
123
|
-
|
|
124
117
|
dense_names, sparse_names, sequence_names = select_features(feature_cfg, df_columns)
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
)
|
|
118
|
+
|
|
119
|
+
# Extract id_column from data config for GAUC metrics
|
|
120
|
+
id_column = data_cfg.get("id_column") or data_cfg.get("user_id_column")
|
|
121
|
+
id_columns = [id_column] if id_column else []
|
|
122
|
+
|
|
123
|
+
used_columns = dense_names + sparse_names + sequence_names + target + id_columns
|
|
128
124
|
|
|
129
125
|
# keep order but drop duplicates
|
|
130
126
|
seen = set()
|
|
@@ -183,17 +179,16 @@ def train_model(train_config_path: str) -> None:
|
|
|
183
179
|
streaming_valid_files = file_paths[-val_count:]
|
|
184
180
|
streaming_train_files = file_paths[:-val_count]
|
|
185
181
|
logger.info(
|
|
186
|
-
"
|
|
187
|
-
ratio,
|
|
188
|
-
len(streaming_train_files),
|
|
189
|
-
len(streaming_valid_files),
|
|
182
|
+
f"Split files for streaming training and validation using valid_ratio={ratio:.3f}: training {len(streaming_train_files)} files, validation {len(streaming_valid_files)} files"
|
|
190
183
|
)
|
|
191
184
|
train_data: Dict[str, Any]
|
|
192
185
|
valid_data: Dict[str, Any] | None
|
|
193
186
|
|
|
194
187
|
if val_data_path and not streaming:
|
|
195
188
|
# Use specified validation dataset path
|
|
196
|
-
logger.info(
|
|
189
|
+
logger.info(
|
|
190
|
+
f"Validation using specified validation dataset path: {val_data_path}"
|
|
191
|
+
)
|
|
197
192
|
val_data_resolved = resolve_path(val_data_path, config_dir)
|
|
198
193
|
val_df = read_table(val_data_resolved, data_cfg.get("format"))
|
|
199
194
|
val_df = val_df[unique_used_columns]
|
|
@@ -206,17 +201,21 @@ def train_model(train_config_path: str) -> None:
|
|
|
206
201
|
valid_data = valid_data_result
|
|
207
202
|
train_size = len(list(train_data.values())[0])
|
|
208
203
|
valid_size = len(list(valid_data.values())[0])
|
|
209
|
-
logger.info(
|
|
204
|
+
logger.info(
|
|
205
|
+
f"Sample count - Training set: {train_size}, Validation set: {valid_size}"
|
|
206
|
+
)
|
|
210
207
|
elif streaming:
|
|
211
208
|
train_data = None # type: ignore[assignment]
|
|
212
209
|
valid_data = None
|
|
213
210
|
if not val_data_path and not streaming_valid_files:
|
|
214
211
|
logger.info(
|
|
215
|
-
"
|
|
212
|
+
"Streaming training mode: No validation dataset path specified and valid_ratio not configured, skipping validation dataset creation"
|
|
216
213
|
)
|
|
217
214
|
else:
|
|
218
215
|
# Split data using valid_ratio
|
|
219
|
-
logger.info(
|
|
216
|
+
logger.info(
|
|
217
|
+
f"Splitting data using valid_ratio: {data_cfg.get('valid_ratio', 0.2)}"
|
|
218
|
+
)
|
|
220
219
|
if not isinstance(processed, dict):
|
|
221
220
|
raise TypeError("Processed data must be a dictionary for splitting")
|
|
222
221
|
train_data, valid_data = split_dict_random(
|
|
@@ -230,6 +229,7 @@ def train_model(train_config_path: str) -> None:
|
|
|
230
229
|
sparse_features=sparse_features,
|
|
231
230
|
sequence_features=sequence_features,
|
|
232
231
|
target=target,
|
|
232
|
+
id_columns=id_columns,
|
|
233
233
|
processor=processor if streaming else None,
|
|
234
234
|
)
|
|
235
235
|
if streaming:
|
|
@@ -240,6 +240,7 @@ def train_model(train_config_path: str) -> None:
|
|
|
240
240
|
shuffle=dataloader_cfg.get("train_shuffle", True),
|
|
241
241
|
load_full=False,
|
|
242
242
|
chunk_size=dataloader_chunk_size,
|
|
243
|
+
num_workers=dataloader_cfg.get("num_workers", 0),
|
|
243
244
|
)
|
|
244
245
|
valid_loader = None
|
|
245
246
|
if val_data_path:
|
|
@@ -250,6 +251,7 @@ def train_model(train_config_path: str) -> None:
|
|
|
250
251
|
shuffle=dataloader_cfg.get("valid_shuffle", False),
|
|
251
252
|
load_full=False,
|
|
252
253
|
chunk_size=dataloader_chunk_size,
|
|
254
|
+
num_workers=dataloader_cfg.get("num_workers", 0),
|
|
253
255
|
)
|
|
254
256
|
elif streaming_valid_files:
|
|
255
257
|
valid_loader = dataloader.create_dataloader(
|
|
@@ -258,17 +260,20 @@ def train_model(train_config_path: str) -> None:
|
|
|
258
260
|
shuffle=dataloader_cfg.get("valid_shuffle", False),
|
|
259
261
|
load_full=False,
|
|
260
262
|
chunk_size=dataloader_chunk_size,
|
|
263
|
+
num_workers=dataloader_cfg.get("num_workers", 0),
|
|
261
264
|
)
|
|
262
265
|
else:
|
|
263
266
|
train_loader = dataloader.create_dataloader(
|
|
264
267
|
data=train_data,
|
|
265
268
|
batch_size=dataloader_cfg.get("train_batch_size", 512),
|
|
266
269
|
shuffle=dataloader_cfg.get("train_shuffle", True),
|
|
270
|
+
num_workers=dataloader_cfg.get("num_workers", 0),
|
|
267
271
|
)
|
|
268
272
|
valid_loader = dataloader.create_dataloader(
|
|
269
273
|
data=valid_data,
|
|
270
274
|
batch_size=dataloader_cfg.get("valid_batch_size", 512),
|
|
271
275
|
shuffle=dataloader_cfg.get("valid_shuffle", False),
|
|
276
|
+
num_workers=dataloader_cfg.get("num_workers", 0),
|
|
272
277
|
)
|
|
273
278
|
|
|
274
279
|
model_cfg.setdefault("session_id", session_id)
|
|
@@ -300,6 +305,9 @@ def train_model(train_config_path: str) -> None:
|
|
|
300
305
|
"batch_size", dataloader_cfg.get("train_batch_size", 512)
|
|
301
306
|
),
|
|
302
307
|
shuffle=train_cfg.get("shuffle", True),
|
|
308
|
+
num_workers=dataloader_cfg.get("num_workers", 0),
|
|
309
|
+
user_id_column=id_column,
|
|
310
|
+
tensorboard=False,
|
|
303
311
|
)
|
|
304
312
|
|
|
305
313
|
|
|
@@ -325,19 +333,15 @@ def predict_model(predict_config_path: str) -> None:
|
|
|
325
333
|
model_cfg_path = resolve_path(
|
|
326
334
|
cfg.get("model_config", "model_config.yaml"), config_dir
|
|
327
335
|
)
|
|
328
|
-
feature_cfg_path = resolve_path(
|
|
329
|
-
|
|
330
|
-
)
|
|
336
|
+
# feature_cfg_path = resolve_path(
|
|
337
|
+
# cfg.get("feature_config", "feature_config.yaml"), config_dir
|
|
338
|
+
# )
|
|
331
339
|
|
|
332
340
|
model_cfg = read_yaml(model_cfg_path)
|
|
333
|
-
feature_cfg = read_yaml(feature_cfg_path)
|
|
341
|
+
# feature_cfg = read_yaml(feature_cfg_path)
|
|
334
342
|
model_cfg.setdefault("session_id", session_id)
|
|
335
|
-
feature_groups_raw = feature_cfg.get("feature_groups") or {}
|
|
336
343
|
model_cfg.setdefault("params", {})
|
|
337
344
|
|
|
338
|
-
# attach feature_groups in predict phase to avoid missing bindings
|
|
339
|
-
model_cfg["params"]["feature_groups"] = feature_groups_raw
|
|
340
|
-
|
|
341
345
|
processor = DataProcessor.load(processor_path)
|
|
342
346
|
|
|
343
347
|
# Load checkpoint and ensure required parameters are passed
|
|
@@ -383,13 +387,6 @@ def predict_model(predict_config_path: str) -> None:
|
|
|
383
387
|
if target_override:
|
|
384
388
|
target_cols = normalize_to_list(target_override)
|
|
385
389
|
|
|
386
|
-
# Recompute feature_groups with available feature names to drive bindings
|
|
387
|
-
feature_group_names = [f.name for f in all_features if hasattr(f, "name")]
|
|
388
|
-
parsed_feature_groups, _ = extract_feature_groups(feature_cfg, feature_group_names)
|
|
389
|
-
if parsed_feature_groups:
|
|
390
|
-
model_cfg.setdefault("params", {})
|
|
391
|
-
model_cfg["params"]["feature_groups"] = parsed_feature_groups
|
|
392
|
-
|
|
393
390
|
model = build_model_instance(
|
|
394
391
|
model_cfg=model_cfg,
|
|
395
392
|
model_cfg_path=model_cfg_path,
|
|
@@ -440,6 +437,7 @@ def predict_model(predict_config_path: str) -> None:
|
|
|
440
437
|
return_dataframe=False,
|
|
441
438
|
save_path=output_path,
|
|
442
439
|
save_format=predict_cfg.get("save_format", "csv"),
|
|
440
|
+
num_workers=predict_cfg.get("num_workers", 0),
|
|
443
441
|
)
|
|
444
442
|
duration = time.time() - start
|
|
445
443
|
logger.info(f"Prediction completed, results saved to: {output_path}")
|
|
@@ -448,7 +446,7 @@ def predict_model(predict_config_path: str) -> None:
|
|
|
448
446
|
preview_rows = predict_cfg.get("preview_rows", 0)
|
|
449
447
|
if preview_rows > 0:
|
|
450
448
|
try:
|
|
451
|
-
preview = pd.read_csv(output_path, nrows=preview_rows)
|
|
449
|
+
preview = pd.read_csv(output_path, nrows=preview_rows, low_memory=False)
|
|
452
450
|
logger.info(f"Output preview:\n{preview}")
|
|
453
451
|
except Exception as exc: # pragma: no cover
|
|
454
452
|
logger.warning(f"Failed to read output preview: {exc}")
|
|
@@ -472,25 +470,21 @@ Examples:
|
|
|
472
470
|
"--mode",
|
|
473
471
|
choices=["train", "predict"],
|
|
474
472
|
required=True,
|
|
475
|
-
help="
|
|
476
|
-
)
|
|
477
|
-
parser.add_argument("--train_config", help="训练配置文件路径")
|
|
478
|
-
parser.add_argument("--predict_config", help="预测配置文件路径")
|
|
479
|
-
parser.add_argument(
|
|
480
|
-
"--config",
|
|
481
|
-
help="通用配置文件路径(已废弃,建议使用 --train_config 或 --predict_config)",
|
|
473
|
+
help="Running mode: train or predict",
|
|
482
474
|
)
|
|
475
|
+
parser.add_argument("--train_config", help="Training configuration file path")
|
|
476
|
+
parser.add_argument("--predict_config", help="Prediction configuration file path")
|
|
483
477
|
args = parser.parse_args()
|
|
484
478
|
|
|
485
479
|
if args.mode == "train":
|
|
486
|
-
config_path = args.train_config
|
|
480
|
+
config_path = args.train_config
|
|
487
481
|
if not config_path:
|
|
488
|
-
parser.error("train
|
|
482
|
+
parser.error("[NextRec CLI Error] train mode requires --train_config")
|
|
489
483
|
train_model(config_path)
|
|
490
484
|
else:
|
|
491
|
-
config_path = args.predict_config
|
|
485
|
+
config_path = args.predict_config
|
|
492
486
|
if not config_path:
|
|
493
|
-
parser.error("predict
|
|
487
|
+
parser.error("[NextRec CLI Error] predict mode requires --predict_config")
|
|
494
488
|
predict_model(config_path)
|
|
495
489
|
|
|
496
490
|
|
nextrec/data/dataloader.py
CHANGED
|
@@ -76,10 +76,10 @@ class ESMM(BaseModel):
|
|
|
76
76
|
sequence_features: list[SequenceFeature],
|
|
77
77
|
ctr_params: dict,
|
|
78
78
|
cvr_params: dict,
|
|
79
|
-
target: list[str] =
|
|
79
|
+
target: list[str] | None = None, # Note: ctcvr = ctr * cvr
|
|
80
80
|
task: list[str] | None = None,
|
|
81
81
|
optimizer: str = "adam",
|
|
82
|
-
optimizer_params: dict =
|
|
82
|
+
optimizer_params: dict | None = None,
|
|
83
83
|
loss: str | nn.Module | list[str | nn.Module] | None = "bce",
|
|
84
84
|
loss_params: dict | list[dict] | None = None,
|
|
85
85
|
device: str = "cpu",
|
|
@@ -90,19 +90,36 @@ class ESMM(BaseModel):
|
|
|
90
90
|
**kwargs,
|
|
91
91
|
):
|
|
92
92
|
|
|
93
|
-
|
|
93
|
+
target = target or ["ctr", "ctcvr"]
|
|
94
|
+
optimizer_params = optimizer_params or {}
|
|
95
|
+
if loss is None:
|
|
96
|
+
loss = "bce"
|
|
97
|
+
|
|
94
98
|
if len(target) != 2:
|
|
95
99
|
raise ValueError(
|
|
96
100
|
f"ESMM requires exactly 2 targets (ctr and ctcvr), got {len(target)}"
|
|
97
101
|
)
|
|
98
102
|
|
|
103
|
+
self.num_tasks = len(target)
|
|
104
|
+
resolved_task = task
|
|
105
|
+
if resolved_task is None:
|
|
106
|
+
resolved_task = self.default_task
|
|
107
|
+
elif isinstance(resolved_task, str):
|
|
108
|
+
resolved_task = [resolved_task] * self.num_tasks
|
|
109
|
+
elif len(resolved_task) == 1 and self.num_tasks > 1:
|
|
110
|
+
resolved_task = resolved_task * self.num_tasks
|
|
111
|
+
elif len(resolved_task) != self.num_tasks:
|
|
112
|
+
raise ValueError(
|
|
113
|
+
f"Length of task ({len(resolved_task)}) must match number of targets ({self.num_tasks})."
|
|
114
|
+
)
|
|
115
|
+
# resolved_task is now guaranteed to be a list[str]
|
|
116
|
+
|
|
99
117
|
super(ESMM, self).__init__(
|
|
100
118
|
dense_features=dense_features,
|
|
101
119
|
sparse_features=sparse_features,
|
|
102
120
|
sequence_features=sequence_features,
|
|
103
121
|
target=target,
|
|
104
|
-
task=
|
|
105
|
-
or self.default_task, # Both CTR and CTCVR are binary classification
|
|
122
|
+
task=resolved_task, # Both CTR and CTCVR are binary classification
|
|
106
123
|
device=device,
|
|
107
124
|
embedding_l1_reg=embedding_l1_reg,
|
|
108
125
|
dense_l1_reg=dense_l1_reg,
|
|
@@ -112,19 +129,9 @@ class ESMM(BaseModel):
|
|
|
112
129
|
)
|
|
113
130
|
|
|
114
131
|
self.loss = loss
|
|
115
|
-
if self.loss is None:
|
|
116
|
-
self.loss = "bce"
|
|
117
132
|
|
|
118
|
-
# All features
|
|
119
|
-
self.all_features = dense_features + sparse_features + sequence_features
|
|
120
|
-
# Shared embedding layer
|
|
121
133
|
self.embedding = EmbeddingLayer(features=self.all_features)
|
|
122
|
-
input_dim =
|
|
123
|
-
self.embedding.input_dim
|
|
124
|
-
) # Calculate input dimension, better way than below
|
|
125
|
-
# emb_dim_total = sum([f.embedding_dim for f in self.all_features if not isinstance(f, DenseFeature)])
|
|
126
|
-
# dense_input_dim = sum([getattr(f, "embedding_dim", 1) or 1 for f in dense_features])
|
|
127
|
-
# input_dim = emb_dim_total + dense_input_dim
|
|
134
|
+
input_dim = self.embedding.input_dim
|
|
128
135
|
|
|
129
136
|
# CTR tower
|
|
130
137
|
self.ctr_tower = MLP(input_dim=input_dim, output_layer=True, **ctr_params)
|
|
@@ -73,16 +73,16 @@ class MMOE(BaseModel):
|
|
|
73
73
|
|
|
74
74
|
def __init__(
|
|
75
75
|
self,
|
|
76
|
-
dense_features: list[DenseFeature] =
|
|
77
|
-
sparse_features: list[SparseFeature] =
|
|
78
|
-
sequence_features: list[SequenceFeature] =
|
|
79
|
-
expert_params: dict =
|
|
76
|
+
dense_features: list[DenseFeature] | None = None,
|
|
77
|
+
sparse_features: list[SparseFeature] | None = None,
|
|
78
|
+
sequence_features: list[SequenceFeature] | None = None,
|
|
79
|
+
expert_params: dict | None = None,
|
|
80
80
|
num_experts: int = 3,
|
|
81
|
-
tower_params_list: list[dict] =
|
|
82
|
-
target: list[str] =
|
|
83
|
-
task: str | list[str]
|
|
81
|
+
tower_params_list: list[dict] | None = None,
|
|
82
|
+
target: list[str] | str | None = None,
|
|
83
|
+
task: str | list[str] = "binary",
|
|
84
84
|
optimizer: str = "adam",
|
|
85
|
-
optimizer_params: dict =
|
|
85
|
+
optimizer_params: dict | None = None,
|
|
86
86
|
loss: str | nn.Module | list[str | nn.Module] | None = "bce",
|
|
87
87
|
loss_params: dict | list[dict] | None = None,
|
|
88
88
|
device: str = "cpu",
|
|
@@ -93,14 +93,39 @@ class MMOE(BaseModel):
|
|
|
93
93
|
**kwargs,
|
|
94
94
|
):
|
|
95
95
|
|
|
96
|
-
|
|
96
|
+
dense_features = dense_features or []
|
|
97
|
+
sparse_features = sparse_features or []
|
|
98
|
+
sequence_features = sequence_features or []
|
|
99
|
+
expert_params = expert_params or {}
|
|
100
|
+
tower_params_list = tower_params_list or []
|
|
101
|
+
optimizer_params = optimizer_params or {}
|
|
102
|
+
if loss is None:
|
|
103
|
+
loss = "bce"
|
|
104
|
+
if target is None:
|
|
105
|
+
target = []
|
|
106
|
+
elif isinstance(target, str):
|
|
107
|
+
target = [target]
|
|
108
|
+
|
|
109
|
+
self.num_tasks = len(target) if target else 1
|
|
110
|
+
|
|
111
|
+
resolved_task = task
|
|
112
|
+
if resolved_task is None:
|
|
113
|
+
resolved_task = self.default_task
|
|
114
|
+
elif isinstance(resolved_task, str):
|
|
115
|
+
resolved_task = [resolved_task] * self.num_tasks
|
|
116
|
+
elif len(resolved_task) == 1 and self.num_tasks > 1:
|
|
117
|
+
resolved_task = resolved_task * self.num_tasks
|
|
118
|
+
elif len(resolved_task) != self.num_tasks:
|
|
119
|
+
raise ValueError(
|
|
120
|
+
f"Length of task ({len(resolved_task)}) must match number of targets ({self.num_tasks})."
|
|
121
|
+
)
|
|
97
122
|
|
|
98
123
|
super(MMOE, self).__init__(
|
|
99
124
|
dense_features=dense_features,
|
|
100
125
|
sparse_features=sparse_features,
|
|
101
126
|
sequence_features=sequence_features,
|
|
102
127
|
target=target,
|
|
103
|
-
task=
|
|
128
|
+
task=resolved_task,
|
|
104
129
|
device=device,
|
|
105
130
|
embedding_l1_reg=embedding_l1_reg,
|
|
106
131
|
dense_l1_reg=dense_l1_reg,
|
|
@@ -110,8 +135,6 @@ class MMOE(BaseModel):
|
|
|
110
135
|
)
|
|
111
136
|
|
|
112
137
|
self.loss = loss
|
|
113
|
-
if self.loss is None:
|
|
114
|
-
self.loss = "bce"
|
|
115
138
|
|
|
116
139
|
# Number of tasks and experts
|
|
117
140
|
self.num_tasks = len(target)
|
|
@@ -122,12 +145,8 @@ class MMOE(BaseModel):
|
|
|
122
145
|
f"Number of tower params ({len(tower_params_list)}) must match number of tasks ({self.num_tasks})"
|
|
123
146
|
)
|
|
124
147
|
|
|
125
|
-
self.all_features = dense_features + sparse_features + sequence_features
|
|
126
148
|
self.embedding = EmbeddingLayer(features=self.all_features)
|
|
127
149
|
input_dim = self.embedding.input_dim
|
|
128
|
-
# emb_dim_total = sum([f.embedding_dim for f in self.all_features if not isinstance(f, DenseFeature)])
|
|
129
|
-
# dense_input_dim = sum([getattr(f, "embedding_dim", 1) or 1 for f in dense_features])
|
|
130
|
-
# input_dim = emb_dim_total + dense_input_dim
|
|
131
150
|
|
|
132
151
|
# Expert networks (shared by all tasks)
|
|
133
152
|
self.experts = nn.ModuleList()
|
|
@@ -162,7 +181,7 @@ class MMOE(BaseModel):
|
|
|
162
181
|
self.compile(
|
|
163
182
|
optimizer=optimizer,
|
|
164
183
|
optimizer_params=optimizer_params,
|
|
165
|
-
loss=loss,
|
|
184
|
+
loss=self.loss,
|
|
166
185
|
loss_params=loss_params,
|
|
167
186
|
)
|
|
168
187
|
|
nextrec/models/multi_task/ple.py
CHANGED
|
@@ -51,6 +51,7 @@ import torch.nn as nn
|
|
|
51
51
|
from nextrec.basic.model import BaseModel
|
|
52
52
|
from nextrec.basic.layers import EmbeddingLayer, MLP, PredictionLayer
|
|
53
53
|
from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
|
|
54
|
+
from nextrec.utils.model import get_mlp_output_dim
|
|
54
55
|
|
|
55
56
|
|
|
56
57
|
class CGCLayer(nn.Module):
|
|
@@ -72,13 +73,13 @@ class CGCLayer(nn.Module):
|
|
|
72
73
|
if num_tasks < 1:
|
|
73
74
|
raise ValueError("num_tasks must be >= 1")
|
|
74
75
|
|
|
75
|
-
specific_params_list = self.
|
|
76
|
+
specific_params_list = self.normalize_specific_params(
|
|
76
77
|
specific_expert_params, num_tasks
|
|
77
78
|
)
|
|
78
79
|
|
|
79
|
-
self.output_dim =
|
|
80
|
+
self.output_dim = get_mlp_output_dim(shared_expert_params, input_dim)
|
|
80
81
|
specific_dims = [
|
|
81
|
-
|
|
82
|
+
get_mlp_output_dim(params, input_dim) for params in specific_params_list
|
|
82
83
|
]
|
|
83
84
|
dims_set = set(specific_dims + [self.output_dim])
|
|
84
85
|
if len(dims_set) != 1:
|
|
@@ -165,14 +166,7 @@ class CGCLayer(nn.Module):
|
|
|
165
166
|
return new_task_fea, new_shared
|
|
166
167
|
|
|
167
168
|
@staticmethod
|
|
168
|
-
def
|
|
169
|
-
dims = params.get("dims")
|
|
170
|
-
if dims:
|
|
171
|
-
return dims[-1]
|
|
172
|
-
return fallback
|
|
173
|
-
|
|
174
|
-
@staticmethod
|
|
175
|
-
def _normalize_specific_params(
|
|
169
|
+
def normalize_specific_params(
|
|
176
170
|
params: dict | list[dict], num_tasks: int
|
|
177
171
|
) -> list[dict]:
|
|
178
172
|
if isinstance(params, list):
|
|
@@ -232,12 +226,24 @@ class PLE(BaseModel):
|
|
|
232
226
|
|
|
233
227
|
self.num_tasks = len(target)
|
|
234
228
|
|
|
229
|
+
resolved_task = task
|
|
230
|
+
if resolved_task is None:
|
|
231
|
+
resolved_task = self.default_task
|
|
232
|
+
elif isinstance(resolved_task, str):
|
|
233
|
+
resolved_task = [resolved_task] * self.num_tasks
|
|
234
|
+
elif len(resolved_task) == 1 and self.num_tasks > 1:
|
|
235
|
+
resolved_task = resolved_task * self.num_tasks
|
|
236
|
+
elif len(resolved_task) != self.num_tasks:
|
|
237
|
+
raise ValueError(
|
|
238
|
+
f"Length of task ({len(resolved_task)}) must match number of targets ({self.num_tasks})."
|
|
239
|
+
)
|
|
240
|
+
|
|
235
241
|
super(PLE, self).__init__(
|
|
236
242
|
dense_features=dense_features,
|
|
237
243
|
sparse_features=sparse_features,
|
|
238
244
|
sequence_features=sequence_features,
|
|
239
245
|
target=target,
|
|
240
|
-
task=
|
|
246
|
+
task=resolved_task,
|
|
241
247
|
device=device,
|
|
242
248
|
embedding_l1_reg=embedding_l1_reg,
|
|
243
249
|
dense_l1_reg=dense_l1_reg,
|