nextrec-0.4.25-py3-none-any.whl → nextrec-0.4.27-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/asserts.py +72 -0
- nextrec/basic/loggers.py +18 -1
- nextrec/basic/model.py +54 -51
- nextrec/models/multi_task/[pre]aitm.py +173 -0
- nextrec/models/multi_task/[pre]snr_trans.py +232 -0
- nextrec/models/multi_task/[pre]star.py +192 -0
- nextrec/models/multi_task/apg.py +330 -0
- nextrec/models/multi_task/cross_stitch.py +229 -0
- nextrec/models/multi_task/escm.py +290 -0
- nextrec/models/multi_task/esmm.py +8 -21
- nextrec/models/multi_task/hmoe.py +203 -0
- nextrec/models/multi_task/mmoe.py +20 -28
- nextrec/models/multi_task/pepnet.py +68 -66
- nextrec/models/multi_task/ple.py +30 -44
- nextrec/models/multi_task/poso.py +13 -22
- nextrec/models/multi_task/share_bottom.py +14 -25
- nextrec/models/ranking/afm.py +2 -2
- nextrec/models/ranking/autoint.py +2 -4
- nextrec/models/ranking/dcn.py +2 -3
- nextrec/models/ranking/dcn_v2.py +2 -3
- nextrec/models/ranking/deepfm.py +2 -3
- nextrec/models/ranking/dien.py +7 -9
- nextrec/models/ranking/din.py +8 -10
- nextrec/models/ranking/eulernet.py +1 -2
- nextrec/models/ranking/ffm.py +1 -2
- nextrec/models/ranking/fibinet.py +2 -3
- nextrec/models/ranking/fm.py +1 -1
- nextrec/models/ranking/lr.py +1 -1
- nextrec/models/ranking/masknet.py +1 -2
- nextrec/models/ranking/pnn.py +1 -2
- nextrec/models/ranking/widedeep.py +2 -3
- nextrec/models/ranking/xdeepfm.py +2 -4
- nextrec/models/representation/rqvae.py +4 -4
- nextrec/models/retrieval/dssm.py +18 -26
- nextrec/models/retrieval/dssm_v2.py +15 -22
- nextrec/models/retrieval/mind.py +9 -15
- nextrec/models/retrieval/sdm.py +36 -33
- nextrec/models/retrieval/youtube_dnn.py +16 -24
- nextrec/models/sequential/hstu.py +2 -2
- nextrec/utils/__init__.py +5 -1
- nextrec/utils/model.py +9 -14
- {nextrec-0.4.25.dist-info → nextrec-0.4.27.dist-info}/METADATA +72 -62
- nextrec-0.4.27.dist-info/RECORD +90 -0
- nextrec/models/multi_task/aitm.py +0 -0
- nextrec/models/multi_task/snr_trans.py +0 -0
- nextrec-0.4.25.dist-info/RECORD +0 -86
- {nextrec-0.4.25.dist-info → nextrec-0.4.27.dist-info}/WHEEL +0 -0
- {nextrec-0.4.25.dist-info → nextrec-0.4.27.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.25.dist-info → nextrec-0.4.27.dist-info}/licenses/LICENSE +0 -0
nextrec/__version__.py
CHANGED
@@ -1 +1 @@
-__version__ = "0.4.25"
+__version__ = "0.4.27"
nextrec/basic/asserts.py
ADDED
@@ -0,0 +1,72 @@
+"""
+Assert function definitions for NextRec models.
+
+Date: create on 01/01/2026
+Checkpoint: edit on 01/01/2026
+Author: Yang Zhou, zyaztec@gmail.com
+"""
+
+from __future__ import annotations
+
+from nextrec.utils.types import TaskTypeName, TrainingModeName
+
+
+def assert_task(
+    task: list[TaskTypeName] | TaskTypeName | None,
+    nums_task: int,
+    *,
+    model_name: str,
+) -> None:
+    if task is None:
+        raise ValueError(f"{model_name} requires task to be specified.")
+
+    # case 1: task is str
+    if isinstance(task, str):
+        if nums_task != 1:
+            raise ValueError(
+                f"{model_name} received task='{task}' but nums_task={nums_task}. "
+                "String task is only allowed for single-task models."
+            )
+        return  # single-task, valid
+
+    # case 2: task is list
+    if not isinstance(task, list):
+        raise TypeError(
+            f"{model_name} requires task to be a string or a list of strings."
+        )
+
+    # list but length == 1
+    if len(task) == 1:
+        if nums_task != 1:
+            raise ValueError(
+                f"{model_name} received task list of length 1 but nums_task={nums_task}. "
+                "Length-1 task list is only allowed for single-task models."
+            )
+        return  # single-task, valid
+
+    # multi-task: length must match nums_task
+    if len(task) != nums_task:
+        raise ValueError(
+            f"{model_name} requires task length {nums_task}, got {len(task)}."
+        )
+
+
+def assert_training_mode(
+    training_mode: TrainingModeName | list[TrainingModeName],
+    nums_task: int,
+    *,
+    model_name: str,
+) -> None:
+    valid_modes = {"pointwise", "pairwise", "listwise"}
+    if not isinstance(training_mode, list):
+        raise TypeError(
+            f"[{model_name}-init Error] training_mode must be a list with length {nums_task}."
+        )
+    if len(training_mode) != nums_task:
+        raise ValueError(
+            f"[{model_name}-init Error] training_mode list length must match number of tasks."
+        )
+    if any(mode not in valid_modes for mode in training_mode):
+        raise ValueError(
+            f"[{model_name}-init Error] training_mode must be one of {'pointwise', 'pairwise', 'listwise'}."
+        )
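
A quick sketch of how the two new validators behave, assuming nextrec >= 0.4.27 is installed (signatures and messages are taken from the added file above):

from nextrec.basic.asserts import assert_task, assert_training_mode

# Single-task: a plain string or a length-1 list both pass when nums_task == 1.
assert_task("binary", nums_task=1, model_name="DeepFM")
assert_task(["binary"], nums_task=1, model_name="DeepFM")

# Multi-task: the list length must match nums_task.
assert_task(["binary", "regression"], nums_task=2, model_name="MMoE")

try:
    assert_task(["binary"], nums_task=2, model_name="MMoE")  # mismatched length
except ValueError as err:
    print(err)  # MMoE received task list of length 1 but nums_task=2. ...

# assert_training_mode only accepts a list, one mode per task.
assert_training_mode(["pointwise", "pairwise"], nums_task=2, model_name="MMoE")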
nextrec/basic/loggers.py
CHANGED
@@ -2,7 +2,7 @@
 NextRec Basic Loggers
 
 Date: create on 27/10/2025
-Checkpoint: edit on
+Checkpoint: edit on 01/01/2026
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
@@ -190,6 +190,19 @@ class BasicLogger:
     def close(self) -> None:
         for backend in self.backends:
             backend.close()
+        for backend in self.backends:
+            if isinstance(backend, SwanLabLogger):
+                swanlab = backend.swanlab
+                if not backend.enabled or swanlab is None:
+                    continue
+                finish_fn = getattr(swanlab, "finish", None)
+                if finish_fn is None:
+                    continue
+                try:
+                    finish_fn()
+                except TypeError:
+                    finish_fn()
+                break
 
 
 class TensorBoardLogger(MetricsLoggerBackend):
@@ -369,10 +382,14 @@ class TrainingLogger(BasicLogger):
         wandb_kwargs = dict(wandb_kwargs or {})
         wandb_kwargs.setdefault("config", {})
         wandb_kwargs["config"].update(config)
+        if "notes" in wandb_kwargs:
+            wandb_kwargs["config"].pop("note", None)
 
         swanlab_kwargs = dict(swanlab_kwargs or {})
         swanlab_kwargs.setdefault("config", {})
         swanlab_kwargs["config"].update(config)
+        if "description" in swanlab_kwargs:
+            swanlab_kwargs["config"].pop("note", None)
 
         self.wandb_logger = None
         if use_wandb:
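
Both TrainingLogger hunks follow one pattern: the run note travels through a dedicated kwarg ("notes" for wandb, "description" for swanlab) and the duplicate "note" key is dropped from the config dict. A minimal sketch of that flow with plain dicts (names mirror the diff; this is not the full TrainingLogger):

note = "baseline run"
config = {"lr": 1e-3, "note": note}

wandb_kwargs = {"notes": note}            # set by fit() when self.note is present
wandb_kwargs.setdefault("config", {})
wandb_kwargs["config"].update(config)
if "notes" in wandb_kwargs:               # new: avoid duplicating the note in config
    wandb_kwargs["config"].pop("note", None)

swanlab_kwargs = {"description": note}    # swanlab's counterpart of wandb notes
swanlab_kwargs.setdefault("config", {})
swanlab_kwargs["config"].update(config)
if "description" in swanlab_kwargs:
    swanlab_kwargs["config"].pop("note", None)

print(wandb_kwargs)    # {'notes': 'baseline run', 'config': {'lr': 0.001}}
print(swanlab_kwargs)  # {'description': 'baseline run', 'config': {'lr': 0.001}}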
nextrec/basic/model.py
CHANGED
@@ -2,7 +2,7 @@
 Base Model & Base Match Model Class
 
 Date: create on 27/10/2025
-Checkpoint: edit on
+Checkpoint: edit on 01/01/2026
 Author: Yang Zhou,zyaztec@gmail.com
 """
 
@@ -36,6 +36,7 @@ from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 
 from nextrec import __version__
+from nextrec.basic.asserts import assert_task
 from nextrec.basic.callback import (
     CallbackList,
     CheckpointSaver,
@@ -101,6 +102,7 @@ from nextrec.utils.types import (
 
 from nextrec.utils.data import FILE_FORMAT_CONFIG
 
+
 class BaseModel(SummarySet, FeatureSet, nn.Module):
     @property
     def model_name(self) -> str:
@@ -110,30 +112,6 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
     def default_task(self) -> TaskTypeName | list[TaskTypeName]:
         raise NotImplementedError
 
-    @property
-    def training_mode(self) -> TrainingModeName | list[TrainingModeName]:
-        if self.nums_task > 1:
-            return self.training_modes
-        return self.training_modes[0] if self.training_modes else "pointwise"
-
-
-    @training_mode.setter
-    def training_mode(self, training_mode: TrainingModeName | list[TrainingModeName]):
-        valid_modes = {"pointwise", "pairwise", "listwise"}
-        if isinstance(training_mode, list):
-            training_modes = list(training_mode)
-            if len(training_modes) != self.nums_task:
-                raise ValueError(
-                    "[BaseModel-init Error] training_mode list length must match number of tasks."
-                )
-        else:
-            training_modes = [training_mode] * self.nums_task
-        if any(mode not in valid_modes for mode in training_modes):
-            raise ValueError(
-                "[BaseModel-init Error] training_mode must be one of {'pointwise', 'pairwise', 'listwise'}."
-            )
-        self.training_modes = list(training_modes)
-
     def __init__(
         self,
         dense_features: list[DenseFeature] | None = None,
@@ -142,7 +120,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         target: list[str] | str | None = None,
         id_columns: list[str] | str | None = None,
         task: TaskTypeName | list[TaskTypeName] | None = None,
-        training_mode: TrainingModeName | list[TrainingModeName] = "pointwise",
+        training_mode: TrainingModeName | list[TrainingModeName] | None = None,
         embedding_l1_reg: float = 0.0,
         dense_l1_reg: float = 0.0,
         embedding_l2_reg: float = 0.0,
@@ -162,10 +140,10 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
            dense_features: DenseFeature definitions.
            sparse_features: SparseFeature definitions.
            sequence_features: SequenceFeature definitions.
-           target: Target column name. e.g., '
+           target: Target column name. e.g., 'label_ctr' or ['label_ctr', 'label_cvr'].
            id_columns: Identifier column name, only need to specify if GAUC is required. e.g., 'user_id'.
            task: Task types, e.g., 'binary', 'regression', or ['binary', 'regression']. If None, falls back to self.default_task.
-           training_mode: Training mode for
+           training_mode: Training mode for different tasks. e.g., 'pointwise', ['pointwise', 'pairwise'].
 
            embedding_l1_reg: L1 regularization strength for embedding params. e.g., 1e-6.
            dense_l1_reg: L1 regularization strength for dense params. e.g., 1e-5.
@@ -218,7 +196,11 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         self.task = task or self.default_task
         self.nums_task = len(self.task) if isinstance(self.task, list) else 1
 
-
+        training_mode = training_mode or "pointwise"
+        if isinstance(training_mode, list):
+            self.training_modes = list(training_mode)
+        else:
+            self.training_modes = [training_mode] * self.nums_task
 
         self.embedding_l1_reg = embedding_l1_reg
         self.dense_l1_reg = dense_l1_reg
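
The removed property/setter pair validated training_mode eagerly on assignment; 0.4.27 instead stores a plain training_modes list in __init__ and leaves validation to assert_training_mode. A standalone sketch of the new normalization (normalize_training_modes is a hypothetical name for illustration):

def normalize_training_modes(training_mode, nums_task: int) -> list[str]:
    training_mode = training_mode or "pointwise"  # None now falls back to pointwise
    if isinstance(training_mode, list):
        return list(training_mode)                # one mode per task
    return [training_mode] * nums_task            # broadcast a single mode to all tasks

print(normalize_training_modes(None, 2))                       # ['pointwise', 'pointwise']
print(normalize_training_modes("pairwise", 1))                 # ['pairwise']
print(normalize_training_modes(["pointwise", "pairwise"], 2))  # ['pointwise', 'pairwise']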
@@ -328,13 +310,13 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
     def get_input(self, input_data: dict, require_labels: bool = True):
         """
         Prepare unified input features and labels from the given input data.
-
+
 
         Args:
             input_data: Input data dictionary containing 'features' and optionally 'labels', e.g., {'features': {'feat1': [...], 'feat2': [...]}, 'labels': {'label': [...]}}.
             require_labels: Whether labels are required in the input data. Default is True: for training and evaluation with labels.
-
-        Note:
+
+        Note:
             target tensor shape will always be (batch_size, num_targets)
         """
         feature_source = input_data.get("features", {})
@@ -491,9 +473,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
            ignore_label: Label value to ignore when computing loss. Use this to skip gradients for unknown labels.
        """
        self.ignore_label = ignore_label
-
-
-        )
+
+        # get loss list
+        loss_list = get_loss_list(loss, self.training_modes, self.nums_task)
 
        self.loss_params = {} if loss_params is None else loss_params
        self.optimizer_params = optimizer_params or {}
@@ -546,7 +528,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
            raise ValueError(
                "[BaseModel-compile Error] GradNorm requires multi-task setup."
            )
-            grad_norm_params = dict(loss_weights) if isinstance(loss_weights, dict) else {}
+            grad_norm_params = (
+                dict(loss_weights) if isinstance(loss_weights, dict) else {}
+            )
            grad_norm_params.pop("method", None)
            self.grad_norm = GradNormLossWeighting(
                nums_task=self.nums_task, device=self.device, **grad_norm_params
@@ -594,7 +578,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
            y_true = y_true.view(-1, 1)
 
        loss_fn = self.loss_fn[0]
-
+
        # mask ignored labels
        # we don't suggest using ignore_label for single task training
        if self.ignore_label is not None:
@@ -685,6 +669,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
        batch_size: int = 32,
        shuffle: bool = True,
        num_workers: int = 0,
+        prefetch_factor: int | None = None,
        sampler=None,
        return_dataset: bool = False,
    ):
@@ -696,6 +681,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
            batch_size: Batch size.
            shuffle: Whether to shuffle the data (ignored when a sampler is provided).
            num_workers: Number of DataLoader workers.
+            prefetch_factor: Number of batches loaded in advance by each worker.
            sampler: Optional sampler for DataLoader.
            return_dataset: Whether to return the tensor dataset along with the DataLoader, used for valid data
        Returns:
@@ -715,6 +701,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
                "[BaseModel-prepare_data_loader Error] No data available to create DataLoader."
            )
        dataset = TensorDictDataset(tensors)
+        loader_kwargs = {}
+        if num_workers > 0 and prefetch_factor is not None:
+            loader_kwargs["prefetch_factor"] = prefetch_factor
        loader = DataLoader(
            dataset,
            batch_size=batch_size,
@@ -724,6 +713,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
            num_workers=num_workers,
            pin_memory=self.device.type == "cuda",
            persistent_workers=num_workers > 0,
+            **loader_kwargs,
        )
        return (loader, dataset) if return_dataset else loader
 
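The gating above matters because torch.utils.data.DataLoader only accepts a non-None prefetch_factor when num_workers > 0; passing it with num_workers=0 raises a ValueError. A self-contained sketch of the same guard (a plain TensorDataset standing in for nextrec's TensorDictDataset):

import torch
from torch.utils.data import DataLoader, TensorDataset

def make_loader(dataset, batch_size=32, num_workers=0, prefetch_factor=None):
    loader_kwargs = {}
    # Only forward prefetch_factor when worker processes actually exist.
    if num_workers > 0 and prefetch_factor is not None:
        loader_kwargs["prefetch_factor"] = prefetch_factor
    return DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        persistent_workers=num_workers > 0,
        **loader_kwargs,
    )

dataset = TensorDataset(torch.randn(128, 4), torch.randint(0, 2, (128, 1)))
loader = make_loader(dataset, num_workers=0, prefetch_factor=4)  # kwarg is dropped safely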
@@ -798,6 +788,8 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
        )
        self.to(self.device)
 
+        assert_task(self.task, len(self.target_columns), model_name=self.model_name)
+
        if not self.compiled:
            self.compile(
                optimizer="adam",
@@ -902,6 +894,14 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
        else:
            swanlab.login(api_key=swanlab_api)
 
+        if use_wandb and self.note:
+            wandb_kwargs = dict(wandb_kwargs or {})
+            wandb_kwargs.setdefault("notes", self.note)
+
+        if use_swanlab and self.note:
+            swanlab_kwargs = dict(swanlab_kwargs or {})
+            swanlab_kwargs.setdefault("description", self.note)
+
        self.training_logger = (
            TrainingLogger(
                session=self.session,
@@ -1649,7 +1649,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
            stream_chunk_size: Number of rows per chunk when using streaming mode for large datasets.
            num_workers: DataLoader worker count.
 
-        Note:
+        Note:
            predict does not support distributed mode currently, consider it as a single-process operation.
        """
        self.eval()
@@ -1837,7 +1837,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
    ):
        """
        Make predictions on the given data using streaming mode for large datasets.
-
+
        Args:
            data: Input data for prediction (file path, dict, DataFrame, or DataLoader).
            batch_size: Batch size for prediction.
@@ -2279,9 +2279,10 @@ class BaseMatchModel(BaseModel):
        self.num_negative_samples = num_negative_samples
        self.temperature = temperature
        self.similarity_metric = similarity_metric
-        if self.training_mode not in self.support_training_modes:
+        primary_mode = self.training_modes[0] if self.training_modes else "pointwise"
+        if primary_mode not in self.support_training_modes:
            raise ValueError(
-                f"{self.model_name.upper()} does not support training_mode='{self.training_mode}'. Supported modes: {self.support_training_modes}"
+                f"{self.model_name.upper()} does not support training_mode='{primary_mode}'. Supported modes: {self.support_training_modes}"
            )
        self.user_features_all = (
            self.user_dense_features
@@ -2298,7 +2299,7 @@ class BaseMatchModel(BaseModel):
        self.head = RetrievalHead(
            similarity_metric=self.similarity_metric,
            temperature=self.temperature,
-            training_mode=self.training_mode,
+            training_mode=primary_mode,
            apply_sigmoid=True,
        )
 
@@ -2338,26 +2339,27 @@ class BaseMatchModel(BaseModel):
        }
 
        effective_loss = loss
+        primary_mode = self.training_modes[0] if self.training_modes else "pointwise"
        if effective_loss is None:
-            effective_loss = default_loss_by_mode[self.training_mode]
+            effective_loss = default_loss_by_mode[primary_mode]
        elif isinstance(effective_loss, str):
-            if self.training_mode in {"pairwise", "listwise"} and effective_loss in {
+            if primary_mode in {"pairwise", "listwise"} and effective_loss in {
                "bce",
                "binary_crossentropy",
            }:
-                effective_loss = default_loss_by_mode[self.training_mode]
+                effective_loss = default_loss_by_mode[primary_mode]
        elif isinstance(effective_loss, list):
            if not effective_loss:
-                effective_loss = [default_loss_by_mode[self.training_mode]]
+                effective_loss = [default_loss_by_mode[primary_mode]]
            else:
                first = effective_loss[0]
                if (
-                    self.training_mode in {"pairwise", "listwise"}
+                    primary_mode in {"pairwise", "listwise"}
                    and isinstance(first, str)
                    and first in {"bce", "binary_crossentropy"}
                ):
                    effective_loss = [
-                        default_loss_by_mode[self.training_mode],
+                        default_loss_by_mode[primary_mode],
                        *effective_loss[1:],
                    ]
        return super().compile(
@@ -2435,11 +2437,12 @@ class BaseMatchModel(BaseModel):
        return self.head(user_emb, item_emb, similarity_fn=self.compute_similarity)
 
    def compute_loss(self, y_pred, y_true):
-        if self.training_mode == "pointwise":
+        primary_mode = self.training_modes[0] if self.training_modes else "pointwise"
+        if primary_mode == "pointwise":
            return super().compute_loss(y_pred, y_true)
 
        # pairwise / listwise using inbatch neg
-        elif self.training_mode in ["pairwise", "listwise"]:
+        elif primary_mode in ["pairwise", "listwise"]:
            if not isinstance(y_pred, (tuple, list)) or len(y_pred) != 2:
                raise ValueError(
                    "For pairwise/listwise training, forward should return (user_emb, item_emb). Please check BaseMatchModel.forward implementation."
@@ -2482,7 +2485,7 @@ class BaseMatchModel(BaseModel):
            loss *= float(self.loss_weights[0])
            return loss
        else:
-            raise ValueError(f"Unknown training mode: {self.training_mode}")
+            raise ValueError(f"Unknown training mode: {primary_mode}")
 
    def prepare_feature_data(
        self,
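
BaseMatchModel.compile now derives a single primary_mode from training_modes[0] and silently swaps a pointwise BCE loss for the mode's default when the primary mode is pairwise or listwise. A standalone sketch of that remapping (default_loss_by_mode is not shown in the diff, so its values here are placeholders):

default_loss_by_mode = {
    "pointwise": "bce",
    "pairwise": "pairwise_default",   # placeholder, not the real default
    "listwise": "listwise_default",   # placeholder, not the real default
}

def resolve_loss(loss, training_modes):
    primary_mode = training_modes[0] if training_modes else "pointwise"
    if loss is None:
        return default_loss_by_mode[primary_mode]
    if isinstance(loss, str):
        # BCE is a pointwise loss, so it is replaced under pairwise/listwise.
        if primary_mode in {"pairwise", "listwise"} and loss in {"bce", "binary_crossentropy"}:
            return default_loss_by_mode[primary_mode]
    return loss

print(resolve_loss("bce", ["pairwise"]))   # pairwise_default
print(resolve_loss("bce", ["pointwise"]))  # bce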
nextrec/models/multi_task/[pre]aitm.py
ADDED
@@ -0,0 +1,173 @@
+"""
+Date: create on 01/01/2026 - prerelease version: need to overwrite compute_loss later
+Checkpoint: edit on 01/01/2026
+Author: Yang Zhou, zyaztec@gmail.com
+Reference:
+- [1] Xi D, Chen Z, Yan P, Zhang Y, Zhu Y, Zhuang F, Chen Y. Modeling the Sequential Dependence among Audience Multi-step Conversions with Multi-task Learning in Targeted Display Advertising. Proceedings of the 27th ACM SIGKDD Conference on Knowledge Discovery & Data Mining (KDD ’21), 2021, pp. 3745–3755.
+  URL: https://arxiv.org/abs/2105.08489
+- [2] MMLRec-A-Unified-Multi-Task-and-Multi-Scenario-Learning-Benchmark-for-Recommendation: https://github.com/alipay/MMLRec-A-Unified-Multi-Task-and-Multi-Scenario-Learning-Benchmark-for-Recommendation/
+
+"""
+
+from __future__ import annotations
+
+import math
+import torch
+import torch.nn as nn
+
+from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
+from nextrec.basic.layers import MLP, EmbeddingLayer
+from nextrec.basic.heads import TaskHead
+from nextrec.basic.model import BaseModel
+from nextrec.utils.model import get_mlp_output_dim
+from nextrec.utils.types import TaskTypeName
+
+
+class AITMTransfer(nn.Module):
+    """Attentive information transfer from previous task to current task."""
+
+    def __init__(self, input_dim: int):
+        super().__init__()
+        self.input_dim = input_dim
+        self.prev_proj = nn.Linear(input_dim, input_dim)
+        self.value = nn.Linear(input_dim, input_dim)
+        self.key = nn.Linear(input_dim, input_dim)
+        self.query = nn.Linear(input_dim, input_dim)
+
+    def forward(self, prev_feat: torch.Tensor, curr_feat: torch.Tensor) -> torch.Tensor:
+        prev = self.prev_proj(prev_feat).unsqueeze(1)
+        curr = curr_feat.unsqueeze(1)
+        stacked = torch.cat([prev, curr], dim=1)
+        value = self.value(stacked)
+        key = self.key(stacked)
+        query = self.query(stacked)
+        attn_scores = torch.sum(key * query, dim=2, keepdim=True) / math.sqrt(
+            self.input_dim
+        )
+        attn = torch.softmax(attn_scores, dim=1)
+        return torch.sum(attn * value, dim=1)
+
+
+class AITM(BaseModel):
+    """
+    Attentive Information Transfer Multi-Task model.
+
+    AITM learns task-specific representations and transfers information from
+    task i-1 to task i via attention, enabling sequential task dependency modeling.
+    """
+
+    @property
+    def model_name(self):
+        return "AITM"
+
+    @property
+    def default_task(self):
+        nums_task = getattr(self, "nums_task", None)
+        if nums_task is not None and nums_task > 0:
+            return ["binary"] * nums_task
+        return ["binary"]
+
+    def __init__(
+        self,
+        dense_features: list[DenseFeature] | None = None,
+        sparse_features: list[SparseFeature] | None = None,
+        sequence_features: list[SequenceFeature] | None = None,
+        bottom_mlp_params: dict | list[dict] | None = None,
+        tower_mlp_params_list: list[dict] | None = None,
+        calibrator_alpha: float = 0.1,
+        target: list[str] | str | None = None,
+        task: list[TaskTypeName] | None = None,
+        **kwargs,
+    ):
+        dense_features = dense_features or []
+        sparse_features = sparse_features or []
+        sequence_features = sequence_features or []
+        bottom_mlp_params = bottom_mlp_params or {}
+        tower_mlp_params_list = tower_mlp_params_list or []
+        self.calibrator_alpha = calibrator_alpha
+
+        if target is None:
+            raise ValueError("AITM requires target names for all tasks.")
+        if isinstance(target, str):
+            target = [target]
+
+        self.nums_task = len(target)
+        if self.nums_task < 2:
+            raise ValueError("AITM requires at least 2 tasks.")
+
+        super(AITM, self).__init__(
+            dense_features=dense_features,
+            sparse_features=sparse_features,
+            sequence_features=sequence_features,
+            target=target,
+            task=task,
+            **kwargs,
+        )
+
+        if len(tower_mlp_params_list) != self.nums_task:
+            raise ValueError(
+                "Number of tower mlp params "
+                f"({len(tower_mlp_params_list)}) must match number of tasks ({self.nums_task})."
+            )
+
+        bottom_mlp_params_list: list[dict]
+        if isinstance(bottom_mlp_params, list):
+            if len(bottom_mlp_params) != self.nums_task:
+                raise ValueError(
+                    "Number of bottom mlp params "
+                    f"({len(bottom_mlp_params)}) must match number of tasks ({self.nums_task})."
+                )
+            bottom_mlp_params_list = [params.copy() for params in bottom_mlp_params]
+        else:
+            bottom_mlp_params_list = [
+                bottom_mlp_params.copy() for _ in range(self.nums_task)
+            ]
+
+        self.embedding = EmbeddingLayer(features=self.all_features)
+        input_dim = self.embedding.input_dim
+
+        self.bottoms = nn.ModuleList(
+            [
+                MLP(input_dim=input_dim, output_dim=None, **params)
+                for params in bottom_mlp_params_list
+            ]
+        )
+        bottom_dims = [
+            get_mlp_output_dim(params, input_dim) for params in bottom_mlp_params_list
+        ]
+        if len(set(bottom_dims)) != 1:
+            raise ValueError(f"All bottom output dims must match, got {bottom_dims}.")
+        bottom_output_dim = bottom_dims[0]
+
+        self.transfers = nn.ModuleList(
+            [AITMTransfer(bottom_output_dim) for _ in range(self.nums_task - 1)]
+        )
+        self.grad_norm_shared_modules = ["embedding", "transfers"]
+
+        self.towers = nn.ModuleList(
+            [
+                MLP(input_dim=bottom_output_dim, output_dim=1, **params)
+                for params in tower_mlp_params_list
+            ]
+        )
+        self.prediction_layer = TaskHead(
+            task_type=self.task, task_dims=[1] * self.nums_task
+        )
+
+        self.register_regularization_weights(
+            embedding_attr="embedding",
+            include_modules=["bottoms", "transfers", "towers"],
+        )
+
+    def forward(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
+        input_flat = self.embedding(x=x, features=self.all_features, squeeze_dim=True)
+        task_feats = [bottom(input_flat) for bottom in self.bottoms]
+
+        for idx in range(1, self.nums_task):
+            task_feats[idx] = self.transfers[idx - 1](
+                task_feats[idx - 1], task_feats[idx]
+            )
+
+        task_outputs = [tower(task_feats[idx]) for idx, tower in enumerate(self.towers)]
+        logits = torch.cat(task_outputs, dim=1)
+        return self.prediction_layer(logits)
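
The transfer module is small enough to sanity-check numerically. A self-contained sketch that mirrors AITMTransfer.forward with standalone Linear layers (shapes only, not the packaged class):

import math
import torch

batch, dim = 4, 8
prev_feat = torch.randn(batch, dim)   # bottom output of task i-1
curr_feat = torch.randn(batch, dim)   # bottom output of task i

w_prev, w_v, w_k, w_q = (torch.nn.Linear(dim, dim) for _ in range(4))

# Stack the projected previous feature and the current feature as a
# length-2 "sequence", then attend over that axis with a single head.
stacked = torch.cat([w_prev(prev_feat).unsqueeze(1), curr_feat.unsqueeze(1)], dim=1)  # (B, 2, D)
value, key, query = w_v(stacked), w_k(stacked), w_q(stacked)
scores = (key * query).sum(dim=2, keepdim=True) / math.sqrt(dim)  # (B, 2, 1)
attn = torch.softmax(scores, dim=1)
fused = (attn * value).sum(dim=1)  # (B, D): fused representation fed to tower i

print(fused.shape)  # torch.Size([4, 8])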
|