nextrec 0.4.25__py3-none-any.whl → 0.4.28__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52) hide show
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/asserts.py +72 -0
  3. nextrec/basic/loggers.py +18 -1
  4. nextrec/basic/model.py +54 -51
  5. nextrec/data/batch_utils.py +23 -3
  6. nextrec/data/dataloader.py +3 -8
  7. nextrec/models/multi_task/[pre]aitm.py +173 -0
  8. nextrec/models/multi_task/[pre]snr_trans.py +232 -0
  9. nextrec/models/multi_task/[pre]star.py +192 -0
  10. nextrec/models/multi_task/apg.py +330 -0
  11. nextrec/models/multi_task/cross_stitch.py +229 -0
  12. nextrec/models/multi_task/escm.py +290 -0
  13. nextrec/models/multi_task/esmm.py +8 -21
  14. nextrec/models/multi_task/hmoe.py +203 -0
  15. nextrec/models/multi_task/mmoe.py +20 -28
  16. nextrec/models/multi_task/pepnet.py +81 -76
  17. nextrec/models/multi_task/ple.py +30 -44
  18. nextrec/models/multi_task/poso.py +13 -22
  19. nextrec/models/multi_task/share_bottom.py +14 -25
  20. nextrec/models/ranking/afm.py +2 -2
  21. nextrec/models/ranking/autoint.py +2 -4
  22. nextrec/models/ranking/dcn.py +2 -3
  23. nextrec/models/ranking/dcn_v2.py +2 -3
  24. nextrec/models/ranking/deepfm.py +2 -3
  25. nextrec/models/ranking/dien.py +7 -9
  26. nextrec/models/ranking/din.py +8 -10
  27. nextrec/models/ranking/eulernet.py +1 -2
  28. nextrec/models/ranking/ffm.py +1 -2
  29. nextrec/models/ranking/fibinet.py +2 -3
  30. nextrec/models/ranking/fm.py +1 -1
  31. nextrec/models/ranking/lr.py +1 -1
  32. nextrec/models/ranking/masknet.py +1 -2
  33. nextrec/models/ranking/pnn.py +1 -2
  34. nextrec/models/ranking/widedeep.py +2 -3
  35. nextrec/models/ranking/xdeepfm.py +2 -4
  36. nextrec/models/representation/rqvae.py +4 -4
  37. nextrec/models/retrieval/dssm.py +18 -26
  38. nextrec/models/retrieval/dssm_v2.py +15 -22
  39. nextrec/models/retrieval/mind.py +9 -15
  40. nextrec/models/retrieval/sdm.py +36 -33
  41. nextrec/models/retrieval/youtube_dnn.py +16 -24
  42. nextrec/models/sequential/hstu.py +2 -2
  43. nextrec/utils/__init__.py +5 -1
  44. nextrec/utils/model.py +9 -14
  45. {nextrec-0.4.25.dist-info → nextrec-0.4.28.dist-info}/METADATA +72 -62
  46. nextrec-0.4.28.dist-info/RECORD +90 -0
  47. nextrec/models/multi_task/aitm.py +0 -0
  48. nextrec/models/multi_task/snr_trans.py +0 -0
  49. nextrec-0.4.25.dist-info/RECORD +0 -86
  50. {nextrec-0.4.25.dist-info → nextrec-0.4.28.dist-info}/WHEEL +0 -0
  51. {nextrec-0.4.25.dist-info → nextrec-0.4.28.dist-info}/entry_points.txt +0 -0
  52. {nextrec-0.4.25.dist-info → nextrec-0.4.28.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,232 @@
1
+ """
2
+ Date: create on 01/01/2026 - prerelease version: still need to align with the source paper
3
+ Checkpoint: edit on 01/01/2026
4
+ Author: Yang Zhou, zyaztec@gmail.com
5
+ Reference:
6
+ - [1] Ma J, Zhao Z, Chen J, Li A, Hong L, Chi EH. SNR: Sub-Network Routing for Flexible Parameter Sharing in Multi-Task Learning in E-Commerce by Exploiting Task Relationships in the Label Space. Proceedings of the 33rd AAAI Conference on Artificial Intelligence (AAAI 2019), 2019, pp. 216-223.
7
+ URL: https://ojs.aaai.org/index.php/AAAI/article/view/3788
8
+ - [2] MMLRec-A-Unified-Multi-Task-and-Multi-Scenario-Learning-Benchmark-for-Recommendation: https://github.com/alipay/MMLRec-A-Unified-Multi-Task-and-Multi-Scenario-Learning-Benchmark-for-Recommendation/
9
+
10
+ SNR-Trans stacks multiple expert layers and applies sparse routing with
11
+ learnable per-output transform matrices. Intermediate gates route expert
12
+ outputs to the next expert stage, while the final gate routes to each task
13
+ tower for prediction.
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
22
+ from nextrec.basic.layers import EmbeddingLayer, MLP
23
+ from nextrec.basic.heads import TaskHead
24
+ from nextrec.basic.model import BaseModel
25
+ from nextrec.utils.types import TaskTypeName
26
+
27
+
28
class SNRTransGate(nn.Module):
    """Sparse routing gate that mixes expert outputs through learnable transforms.

    Each of the ``num_outputs`` routed outputs is a gated sum over the
    ``num_inputs`` incoming expert tensors, where every (output, input) pair
    owns its own ``units x units`` transform matrix and a hard-concrete-style
    gate value derived from ``u`` and ``alpha``.

    NOTE(review): ``u`` and ``alpha`` are unconstrained parameters passed to
    ``torch.log`` — if optimisation pushes ``u`` outside (0, 1) or ``alpha``
    below 0 the gate emits NaNs. The module header marks this model as
    prerelease; confirm against the SNR paper before hardening.
    """

    def __init__(self, num_inputs: int, num_outputs: int, units: int) -> None:
        super().__init__()
        if num_inputs < 1 or num_outputs < 1:
            raise ValueError("num_inputs and num_outputs must be >= 1")

        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.units = units

        # Gate scale parameter plus fixed stretch/temperature constants.
        self.alpha = nn.Parameter(torch.rand(1), requires_grad=True)
        self.beta = 0.9
        self.gamma = -0.1
        self.epsilon = 1.1
        self.eps = 1e-8

        # Per-(output, input) gate logits, initialised strictly inside (0, 1)
        # so the log-odds below are finite at start-up.
        gate_init = nn.init.uniform_(
            torch.empty(num_outputs, num_inputs), self.eps, 1 - self.eps
        )
        self.u = nn.Parameter(gate_init, requires_grad=True)

        # One units x units transform per (output, input) route.
        transform_init = torch.empty(num_outputs, num_inputs, units, units)
        nn.init.xavier_normal_(transform_init)
        self.trans_matrix = nn.Parameter(transform_init, requires_grad=True)

    def forward(self, inputs: list[torch.Tensor]) -> list[torch.Tensor]:
        """Route ``inputs`` (each ``[B, units]``) into ``num_outputs`` tensors."""
        if len(inputs) != self.num_inputs:
            raise ValueError(
                f"SNRTransGate expects {self.num_inputs} inputs, got {len(inputs)}"
            )

        # Relaxed gate value: sigmoid of the stretched log-odds of u, then
        # rescaled to [gamma, epsilon] and clipped to [0, 1].
        log_odds = torch.log(self.u) - torch.log(1 - self.u)
        s = torch.sigmoid(log_odds + torch.log(self.alpha) / self.beta)
        stretched = s * (self.epsilon - self.gamma) + self.gamma
        z = stretched.clamp(0.0, 1.0)  # [num_outputs, num_inputs]

        stacked = torch.stack(inputs, dim=1)  # [B, num_inputs, units]
        # Apply every route's transform: [B, num_outputs, num_inputs, units].
        routed = torch.einsum("bnu,onuv->bonv", stacked, self.trans_matrix)
        # Gate each route, then sum over the input axis.
        mixed = (routed * z[None, :, :, None]).sum(dim=2)  # [B, num_outputs, units]
        return list(mixed.unbind(dim=1))
77
+
78
+
79
class SNRTrans(BaseModel):
    """
    SNR-Trans multi-task model with sparse expert routing.

    Stacks one expert stage per entry in ``expert_mlp_params["hidden_dims"]``,
    each stage holding ``num_experts`` parallel single-layer MLP experts.
    After every stage an :class:`SNRTransGate` mixes the expert outputs:
    intermediate gates emit one tensor per expert of the next stage, while
    the final gate emits one tensor per task, which feeds that task's tower.
    """

    @property
    def model_name(self) -> str:
        # Identifier consumed by BaseModel infrastructure.
        return "SNRTrans"

    @property
    def default_task(self) -> TaskTypeName | list[TaskTypeName]:
        # ``nums_task`` is assigned in __init__ *before* super().__init__ runs,
        # so this property can be queried during base-class construction.
        # Fall back to a single binary task if it is not set yet.
        nums_task = self.nums_task if hasattr(self, "nums_task") else None
        if nums_task is not None and nums_task > 0:
            return ["binary"] * nums_task
        return ["binary"]

    def __init__(
        self,
        dense_features: list[DenseFeature] | None = None,
        sparse_features: list[SparseFeature] | None = None,
        sequence_features: list[SequenceFeature] | None = None,
        expert_mlp_params: dict | None = None,
        num_experts: int = 4,
        tower_mlp_params_list: list[dict] | None = None,
        target: list[str] | str | None = None,
        task: TaskTypeName | list[TaskTypeName] | None = None,
        **kwargs,
    ) -> None:
        """Build the SNR-Trans network.

        Args:
            dense_features: Dense feature definitions (may be empty).
            sparse_features: Sparse feature definitions (may be empty).
            sequence_features: Sequence feature definitions (may be empty).
            expert_mlp_params: MLP config shared by all experts; defaults are
                filled in for ``hidden_dims``/``activation``/``dropout``/
                ``norm_type``. One expert stage is built per hidden dim.
            num_experts: Number of parallel experts per stage; must be > 1.
            tower_mlp_params_list: Optional per-task tower MLP configs; if
                given, its length must equal the number of tasks.
            target: Label column name(s); the number of targets defines the
                number of tasks (must be >= 2).
            task: Task type(s) forwarded to BaseModel; defaults via
                ``default_task``.
            **kwargs: Forwarded to :class:`BaseModel`.

        Raises:
            ValueError: On fewer than 2 tasks, ``num_experts <= 1``, empty
                expert hidden dims, or a tower-config/task count mismatch.
        """
        dense_features = dense_features or []
        sparse_features = sparse_features or []
        sequence_features = sequence_features or []
        expert_mlp_params = expert_mlp_params or {}
        tower_mlp_params_list = tower_mlp_params_list or []

        # NOTE(review): setdefault mutates a caller-supplied dict when one is
        # passed in — confirm callers do not reuse the same dict across models.
        expert_mlp_params.setdefault("hidden_dims", [256, 128])
        expert_mlp_params.setdefault("activation", "relu")
        expert_mlp_params.setdefault("dropout", 0.0)
        expert_mlp_params.setdefault("norm_type", "none")
        expert_hidden_dims = expert_mlp_params["hidden_dims"]

        # Normalize target to a list of label names.
        if target is None:
            target = []
        elif isinstance(target, str):
            target = [target]

        # Set before super().__init__ so default_task sees the task count
        # during base-class construction.
        self.nums_task = len(target) if target else 1

        super().__init__(
            dense_features=dense_features,
            sparse_features=sparse_features,
            sequence_features=sequence_features,
            target=target,
            task=task,
            **kwargs,
        )

        # Re-assign after super().__init__ in case the base class touched it.
        self.nums_task = len(target) if target else 1
        self.num_experts = num_experts

        if self.nums_task <= 1:
            raise ValueError("SNRTrans requires at least 2 tasks.")
        if self.num_experts <= 1:
            raise ValueError("num_experts must be greater than 1.")
        if not expert_hidden_dims:
            raise ValueError("expert_mlp_params['hidden_dims'] must not be empty.")

        if tower_mlp_params_list:
            if len(tower_mlp_params_list) != self.nums_task:
                raise ValueError(
                    "Number of tower mlp params "
                    f"({len(tower_mlp_params_list)}) must match number of tasks ({self.nums_task})."
                )
            # Copy so later mutation of caller dicts cannot affect the towers.
            tower_params = [params.copy() for params in tower_mlp_params_list]
        else:
            tower_params = [{} for _ in range(self.nums_task)]

        self.embedding = EmbeddingLayer(features=self.all_features)
        input_dim = self.embedding.input_dim

        # One ModuleList of experts plus one routing gate per stage.
        self.expert_layers = nn.ModuleList()
        self.gates = nn.ModuleList()
        prev_dim = input_dim
        for idx, hidden_dim in enumerate(expert_hidden_dims):
            layer_experts = nn.ModuleList(
                [
                    MLP(
                        input_dim=prev_dim,
                        hidden_dims=[hidden_dim],
                        output_dim=None,
                        dropout=expert_mlp_params["dropout"],
                        activation=expert_mlp_params["activation"],
                        norm_type=expert_mlp_params["norm_type"],
                    )
                    for _ in range(self.num_experts)
                ]
            )
            self.expert_layers.append(layer_experts)
            # Last stage routes to tasks; earlier stages route to the next
            # stage's experts.
            output_dim = (
                self.nums_task
                if idx == len(expert_hidden_dims) - 1
                else self.num_experts
            )
            self.gates.append(
                SNRTransGate(
                    num_inputs=self.num_experts,
                    num_outputs=output_dim,
                    units=hidden_dim,
                )
            )
            prev_dim = hidden_dim

        # One scalar-output tower per task on top of the final gate outputs.
        self.towers = nn.ModuleList(
            [
                MLP(input_dim=expert_hidden_dims[-1], output_dim=1, **params)
                for params in tower_params
            ]
        )
        self.prediction_layer = TaskHead(
            task_type=self.task, task_dims=[1] * self.nums_task
        )
        # Modules treated as shared for gradient normalization.
        self.grad_norm_shared_modules = ["embedding", "expert_layers", "gates"]
        self.register_regularization_weights(
            embedding_attr="embedding",
            include_modules=["expert_layers", "gates", "towers"],
        )

    def forward(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
        """Run routed experts and task towers; returns per-task predictions.

        Args:
            x: Mapping of feature name to input tensor.

        Returns:
            Output of ``prediction_layer`` over the concatenated per-task
            tower logits (one column per task).
        """
        # Flattened concatenation of all feature embeddings.
        input_flat = self.embedding(x=x, features=self.all_features, squeeze_dim=True)

        gate_outputs: list[torch.Tensor] | None = None
        for layer_idx, (layer_experts, gate) in enumerate(
            zip(self.expert_layers, self.gates)
        ):
            expert_outputs = []
            if layer_idx == 0:
                # Every first-stage expert consumes the same flat input.
                expert_inputs = [input_flat] * self.num_experts
            else:
                if gate_outputs is None:
                    raise RuntimeError("SNRTrans gate outputs are not initialized.")
                # Later stages consume the previous gate's routed outputs.
                expert_inputs = gate_outputs
            for expert, expert_input in zip(layer_experts, expert_inputs):
                expert_outputs.append(expert(expert_input))
            gate_outputs = gate(expert_outputs)

        if gate_outputs is None or len(gate_outputs) != self.nums_task:
            raise RuntimeError("SNRTrans gate outputs do not match task count.")

        # Final gate emits one representation per task; apply each tower.
        task_outputs = []
        for task_idx in range(self.nums_task):
            tower_output = self.towers[task_idx](gate_outputs[task_idx])
            task_outputs.append(tower_output)

        y = torch.cat(task_outputs, dim=1)
        return self.prediction_layer(y)
@@ -0,0 +1,192 @@
1
+ """
2
+ Date: create on 01/01/2026 - prerelease version: still need to align with the source paper
3
+ Checkpoint: edit on 01/01/2026
4
+ Author: Yang Zhou, zyaztec@gmail.com
5
+ Reference:
6
+ - [1] Sheng XR, Zhao L, Zhou G, Ding X, Dai B, Luo Q, Yang S, Lv J, Zhang C, Deng H, Zhu X. One Model to Serve All: Star Topology Adaptive Recommender for Multi-Domain CTR Prediction. arXiv preprint arXiv:2101.11427, 2021.
7
+ URL: https://arxiv.org/abs/2101.11427
8
+ - [2] MMLRec-A-Unified-Multi-Task-and-Multi-Scenario-Learning-Benchmark-for-Recommendation: https://github.com/alipay/MMLRec-A-Unified-Multi-Task-and-Multi-Scenario-Learning-Benchmark-for-Recommendation/
9
+
10
+ STAR uses shared-specific linear layers to adapt representations per task while
11
+ optionally reusing shared parameters. It can also apply domain-specific batch
12
+ normalization on the first hidden layer when a domain mask is provided.
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from nextrec.basic.activation import activation_layer
21
+ from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
22
+ from nextrec.basic.heads import TaskHead
23
+ from nextrec.basic.layers import DomainBatchNorm, EmbeddingLayer
24
+ from nextrec.basic.model import BaseModel
25
+ from nextrec.utils.types import TaskTypeName
26
+
27
+
28
+ class SharedSpecificLinear(nn.Module):
29
+ """
30
+ Shared-specific linear layer: task-specific projection plus optional shared one.
31
+ """
32
+
33
+ def __init__(
34
+ self,
35
+ input_dim: int,
36
+ output_dim: int,
37
+ nums_task: int,
38
+ use_shared: bool = True,
39
+ ) -> None:
40
+ super().__init__()
41
+ self.use_shared = use_shared
42
+ self.shared = nn.Linear(input_dim, output_dim) if use_shared else None
43
+ self.specific = nn.ModuleList(
44
+ [nn.Linear(input_dim, output_dim) for _ in range(nums_task)]
45
+ )
46
+
47
+ def forward(self, x: torch.Tensor, task_idx: int) -> torch.Tensor:
48
+ output = self.specific[task_idx](x)
49
+ if self.use_shared and self.shared is not None:
50
+ output = output + self.shared(x)
51
+ return output
52
+
53
+
54
+ class STAR(BaseModel):
55
+ """
56
+ STAR: shared-specific multi-task tower with optional domain-specific batch norm.
57
+ """
58
+
59
+ @property
60
+ def model_name(self) -> str:
61
+ return "STAR"
62
+
63
+ @property
64
+ def default_task(self) -> TaskTypeName | list[TaskTypeName]:
65
+ nums_task = self.nums_task if hasattr(self, "nums_task") else None
66
+ if nums_task is not None and nums_task > 0:
67
+ return ["binary"] * nums_task
68
+ return ["binary"]
69
+
70
+ def __init__(
71
+ self,
72
+ dense_features: list[DenseFeature] | None = None,
73
+ sparse_features: list[SparseFeature] | None = None,
74
+ sequence_features: list[SequenceFeature] | None = None,
75
+ target: list[str] | str | None = None,
76
+ task: TaskTypeName | list[TaskTypeName] | None = None,
77
+ mlp_params: dict | None = None,
78
+ use_shared: bool = True,
79
+ **kwargs,
80
+ ) -> None:
81
+ dense_features = dense_features or []
82
+ sparse_features = sparse_features or []
83
+ sequence_features = sequence_features or []
84
+ mlp_params = mlp_params or {}
85
+ mlp_params.setdefault("hidden_dims", [256, 128])
86
+ mlp_params.setdefault("activation", "relu")
87
+ mlp_params.setdefault("dropout", 0.0)
88
+ mlp_params.setdefault("norm_type", "none")
89
+
90
+ if target is None:
91
+ target = []
92
+ elif isinstance(target, str):
93
+ target = [target]
94
+
95
+ self.nums_task = len(target) if target else 1
96
+
97
+ super().__init__(
98
+ dense_features=dense_features,
99
+ sparse_features=sparse_features,
100
+ sequence_features=sequence_features,
101
+ target=target,
102
+ task=task,
103
+ **kwargs,
104
+ )
105
+
106
+ if not mlp_params["hidden_dims"]:
107
+ raise ValueError("mlp_params['hidden_dims'] must not be empty.")
108
+
109
+ norm_type = mlp_params["norm_type"]
110
+ self.dnn_use_bn = norm_type == "batch_norm"
111
+ self.dnn_dropout = mlp_params["dropout"]
112
+
113
+ self.embedding = EmbeddingLayer(features=self.all_features)
114
+ input_dim = self.embedding.input_dim
115
+
116
+ layer_units = [input_dim] + list(mlp_params["hidden_dims"])
117
+ self.star_layers = nn.ModuleList(
118
+ [
119
+ SharedSpecificLinear(
120
+ input_dim=layer_units[idx],
121
+ output_dim=layer_units[idx + 1],
122
+ nums_task=self.nums_task,
123
+ use_shared=use_shared,
124
+ )
125
+ for idx in range(len(mlp_params["hidden_dims"]))
126
+ ]
127
+ )
128
+ self.activation_layers = nn.ModuleList(
129
+ [
130
+ activation_layer(mlp_params["activation"])
131
+ for _ in range(len(mlp_params["hidden_dims"]))
132
+ ]
133
+ )
134
+ if mlp_params["dropout"] > 0:
135
+ self.dropout_layers = nn.ModuleList(
136
+ [
137
+ nn.Dropout(mlp_params["dropout"])
138
+ for _ in range(len(mlp_params["hidden_dims"]))
139
+ ]
140
+ )
141
+ else:
142
+ self.dropout_layers = nn.ModuleList(
143
+ [nn.Identity() for _ in range(len(mlp_params["hidden_dims"]))]
144
+ )
145
+
146
+ self.domain_bn = (
147
+ DomainBatchNorm(
148
+ num_features=mlp_params["hidden_dims"][0], num_domains=self.nums_task
149
+ )
150
+ if self.dnn_use_bn
151
+ else None
152
+ )
153
+
154
+ self.final_layer = SharedSpecificLinear(
155
+ input_dim=mlp_params["hidden_dims"][-1],
156
+ output_dim=1,
157
+ nums_task=self.nums_task,
158
+ use_shared=use_shared,
159
+ )
160
+ self.prediction_layer = TaskHead(
161
+ task_type=self.task, task_dims=[1] * self.nums_task
162
+ )
163
+
164
+ self.grad_norm_shared_modules = ["embedding", "star_layers", "final_layer"]
165
+ self.register_regularization_weights(
166
+ embedding_attr="embedding",
167
+ include_modules=["star_layers", "final_layer"],
168
+ )
169
+
170
+ def forward(
171
+ self, x: dict[str, torch.Tensor], domain_mask: torch.Tensor | None = None
172
+ ) -> torch.Tensor:
173
+ input_flat = self.embedding(x=x, features=self.all_features, squeeze_dim=True)
174
+
175
+ task_outputs = []
176
+ for task_idx in range(self.nums_task):
177
+ output = input_flat
178
+ for layer_idx, layer in enumerate(self.star_layers):
179
+ output = layer(output, task_idx)
180
+ output = self.activation_layers[layer_idx](output)
181
+ output = self.dropout_layers[layer_idx](output)
182
+ if (
183
+ layer_idx == 0
184
+ and self.dnn_use_bn
185
+ and self.domain_bn is not None
186
+ and domain_mask is not None
187
+ ):
188
+ output = self.domain_bn(output, domain_mask)
189
+ task_outputs.append(self.final_layer(output, task_idx))
190
+
191
+ logits = torch.cat(task_outputs, dim=1)
192
+ return self.prediction_layer(logits)