nextrec 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51):
  1. nextrec/__init__.py +41 -0
  2. nextrec/__version__.py +1 -0
  3. nextrec/basic/__init__.py +0 -0
  4. nextrec/basic/activation.py +92 -0
  5. nextrec/basic/callback.py +35 -0
  6. nextrec/basic/dataloader.py +447 -0
  7. nextrec/basic/features.py +87 -0
  8. nextrec/basic/layers.py +985 -0
  9. nextrec/basic/loggers.py +124 -0
  10. nextrec/basic/metrics.py +557 -0
  11. nextrec/basic/model.py +1438 -0
  12. nextrec/data/__init__.py +27 -0
  13. nextrec/data/data_utils.py +132 -0
  14. nextrec/data/preprocessor.py +662 -0
  15. nextrec/loss/__init__.py +35 -0
  16. nextrec/loss/loss_utils.py +136 -0
  17. nextrec/loss/match_losses.py +294 -0
  18. nextrec/models/generative/hstu.py +0 -0
  19. nextrec/models/generative/tiger.py +0 -0
  20. nextrec/models/match/__init__.py +13 -0
  21. nextrec/models/match/dssm.py +200 -0
  22. nextrec/models/match/dssm_v2.py +162 -0
  23. nextrec/models/match/mind.py +210 -0
  24. nextrec/models/match/sdm.py +253 -0
  25. nextrec/models/match/youtube_dnn.py +172 -0
  26. nextrec/models/multi_task/esmm.py +129 -0
  27. nextrec/models/multi_task/mmoe.py +161 -0
  28. nextrec/models/multi_task/ple.py +260 -0
  29. nextrec/models/multi_task/share_bottom.py +126 -0
  30. nextrec/models/ranking/__init__.py +17 -0
  31. nextrec/models/ranking/afm.py +118 -0
  32. nextrec/models/ranking/autoint.py +140 -0
  33. nextrec/models/ranking/dcn.py +120 -0
  34. nextrec/models/ranking/deepfm.py +95 -0
  35. nextrec/models/ranking/dien.py +214 -0
  36. nextrec/models/ranking/din.py +181 -0
  37. nextrec/models/ranking/fibinet.py +130 -0
  38. nextrec/models/ranking/fm.py +87 -0
  39. nextrec/models/ranking/masknet.py +125 -0
  40. nextrec/models/ranking/pnn.py +128 -0
  41. nextrec/models/ranking/widedeep.py +105 -0
  42. nextrec/models/ranking/xdeepfm.py +117 -0
  43. nextrec/utils/__init__.py +18 -0
  44. nextrec/utils/common.py +14 -0
  45. nextrec/utils/embedding.py +19 -0
  46. nextrec/utils/initializer.py +47 -0
  47. nextrec/utils/optimizer.py +75 -0
  48. nextrec-0.1.1.dist-info/METADATA +302 -0
  49. nextrec-0.1.1.dist-info/RECORD +51 -0
  50. nextrec-0.1.1.dist-info/WHEEL +4 -0
  51. nextrec-0.1.1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,120 @@
1
+ """
2
+ Date: create on 09/11/2025
3
+ Author:
4
+ Yang Zhou,zyaztec@gmail.com
5
+ Reference:
6
+ [1] Wang R, Fu B, Fu G, et al. Deep & cross network for ad click predictions[C]
7
+ //Proceedings of the ADKDD'17. 2017: 1-7.
8
+ (https://arxiv.org/abs/1708.05123)
9
+ """
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+
14
+ from nextrec.basic.model import BaseModel
15
+ from nextrec.basic.layers import EmbeddingLayer, MLP, CrossNetwork, PredictionLayer
16
+ from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
17
+
18
+
19
class DCN(BaseModel):
    """Deep & Cross Network (DCN) for binary CTR prediction.

    An explicit feature-crossing network runs in parallel with an optional
    deep MLP tower; a final linear layer fuses their outputs into a single
    logit.

    Reference:
        [1] Wang R, Fu B, Fu G, et al. Deep & cross network for ad click
            predictions. ADKDD'17. (https://arxiv.org/abs/1708.05123)
    """

    @property
    def model_name(self):
        # Identifier used by BaseModel for logging / checkpoints.
        return "DCN"

    @property
    def task_type(self):
        # Binary classification task.
        return "binary"

    def __init__(self,
                 dense_features: list[DenseFeature],
                 sparse_features: list[SparseFeature],
                 sequence_features: list[SequenceFeature],
                 cross_num: int = 3,
                 mlp_params: dict | None = None,
                 target: list[str] | None = None,
                 optimizer: str = "adam",
                 optimizer_params: dict | None = None,
                 loss: str | nn.Module | None = "bce",
                 device: str = 'cpu',
                 model_id: str = "baseline",
                 embedding_l1_reg=1e-6,
                 dense_l1_reg=1e-5,
                 embedding_l2_reg=1e-5,
                 dense_l2_reg=1e-4):
        """Build the DCN model.

        Args:
            dense_features: numeric feature definitions.
            sparse_features: categorical feature definitions.
            sequence_features: variable-length categorical feature definitions.
            cross_num: number of layers in the cross network.
            mlp_params: kwargs forwarded to ``MLP``; if None the deep tower is
                disabled and only the cross network feeds the final layer.
            target: label column name(s); defaults to an empty list.
            optimizer: optimizer name understood by ``compile``.
            optimizer_params: kwargs forwarded to the optimizer.
            loss: loss name or module; None falls back to "bce".
            device: torch device string.
            model_id: experiment identifier.
            embedding_l1_reg, dense_l1_reg, embedding_l2_reg, dense_l2_reg:
                regularization strengths passed to BaseModel.
        """
        # None sentinels replace mutable default arguments so instances
        # never share the same list/dict object.
        target = [] if target is None else target
        optimizer_params = {} if optimizer_params is None else optimizer_params

        super(DCN, self).__init__(
            dense_features=dense_features,
            sparse_features=sparse_features,
            sequence_features=sequence_features,
            target=target,
            task=self.task_type,
            device=device,
            embedding_l1_reg=embedding_l1_reg,
            dense_l1_reg=dense_l1_reg,
            embedding_l2_reg=embedding_l2_reg,
            dense_l2_reg=dense_l2_reg,
            early_stop_patience=20,
            model_id=model_id
        )

        # Normalize the loss once and reuse the normalized value everywhere.
        self.loss = loss if loss is not None else "bce"

        # All features share a single embedding layer.
        self.all_features = dense_features + sparse_features + sequence_features
        self.embedding = EmbeddingLayer(features=self.all_features)

        # Flattened input width: embeddings for sparse/sequence features plus
        # dense features (a DenseFeature may not define embedding_dim — fall
        # back to width 1 per feature).
        emb_dim_total = sum(f.embedding_dim for f in self.all_features
                            if not isinstance(f, DenseFeature))
        dense_input_dim = sum((getattr(f, "embedding_dim", 1) or 1)
                              for f in dense_features)
        input_dim = emb_dim_total + dense_input_dim

        # Explicit feature crossing tower.
        self.cross_network = CrossNetwork(input_dim=input_dim, num_layers=cross_num)

        if mlp_params is not None:
            self.use_dnn = True
            self.mlp = MLP(input_dim=input_dim, **mlp_params)
            # NOTE(review): assumes the MLP emits a single unit ([B, 1]),
            # hence the +1 — confirm mlp_params always yields output dim 1.
            self.final_layer = nn.Linear(input_dim + 1, 1)
        else:
            self.use_dnn = False
            # Cross-network output alone feeds the final layer.
            self.final_layer = nn.Linear(input_dim, 1)

        self.prediction_layer = PredictionLayer(task_type=self.task_type)

        # Register weights for L1/L2 regularization.
        self._register_regularization_weights(
            embedding_attr='embedding',
            include_modules=['cross_network', 'mlp', 'final_layer']
        )

        # Pass the normalized loss (the original forwarded the raw argument,
        # which could still be None).
        self.compile(
            optimizer=optimizer,
            optimizer_params=optimizer_params,
            loss=self.loss
        )

    def forward(self, x):
        """Score a batch dict of feature tensors; returns predictions [B, 1]."""
        # Flattened concatenation of all feature embeddings: [B, input_dim].
        input_flat = self.embedding(x=x, features=self.all_features, squeeze_dim=True)

        cross_output = self.cross_network(input_flat)  # [B, input_dim]

        if self.use_dnn:
            deep_output = self.mlp(input_flat)  # [B, 1]
            # Fuse cross and deep towers.
            combined = torch.cat([cross_output, deep_output], dim=-1)  # [B, input_dim + 1]
        else:
            combined = cross_output

        y = self.final_layer(combined)  # [B, 1] logit
        return self.prediction_layer(y)
@@ -0,0 +1,95 @@
1
+ """
2
+ Date: create on 27/10/2025
3
+ Author:
4
+ Yang Zhou,zyaztec@gmail.com
5
+ Reference:
6
+ [1] Guo H, Tang R, Ye Y, et al. Deepfm: a factorization-machine based neural network for ctr prediction[J]. arXiv preprint arXiv:1703.04247, 2017.(https://arxiv.org/abs/1703.04247)
7
+ """
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+ from nextrec.basic.model import BaseModel
13
+ from nextrec.basic.layers import FM, LR, EmbeddingLayer, MLP, PredictionLayer
14
+ from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
15
+
16
class DeepFM(BaseModel):
    """DeepFM: factorization machine + deep network for binary CTR prediction.

    A shared embedding feeds three towers — a linear (first-order) term, an
    FM second-order interaction term, and an MLP — whose logits are summed.

    Reference:
        [1] Guo H, Tang R, Ye Y, et al. DeepFM: a factorization-machine based
            neural network for CTR prediction. arXiv:1703.04247, 2017.
            (https://arxiv.org/abs/1703.04247)
    """

    @property
    def model_name(self):
        # Identifier used by BaseModel for logging / checkpoints.
        return "DeepFM"

    @property
    def task_type(self):
        # Binary classification task.
        return "binary"

    def __init__(self,
                 dense_features: list[DenseFeature] | None = None,
                 sparse_features: list[SparseFeature] | None = None,
                 sequence_features: list[SequenceFeature] | None = None,
                 mlp_params: dict | None = None,
                 target: list[str] | str | None = None,
                 optimizer: str = "adam",
                 optimizer_params: dict | None = None,
                 loss: str | nn.Module | None = "bce",
                 device: str = 'cpu',
                 model_id: str = "baseline",
                 embedding_l1_reg=1e-6,
                 dense_l1_reg=1e-5,
                 embedding_l2_reg=1e-5,
                 dense_l2_reg=1e-4):
        """Build the DeepFM model.

        Args:
            dense_features: numeric feature definitions (deep tower only).
            sparse_features: categorical feature definitions.
            sequence_features: variable-length categorical feature definitions.
            mlp_params: kwargs forwarded to ``MLP``.
            target: label column name(s); defaults to an empty list.
            optimizer: optimizer name understood by ``compile``.
            optimizer_params: kwargs forwarded to the optimizer.
            loss: loss name or module; None falls back to "bce".
            device: torch device string.
            model_id: experiment identifier.
            embedding_l1_reg, dense_l1_reg, embedding_l2_reg, dense_l2_reg:
                regularization strengths passed to BaseModel.
        """
        # None sentinels replace mutable default arguments so instances
        # never share the same list/dict object.
        dense_features = [] if dense_features is None else dense_features
        sparse_features = [] if sparse_features is None else sparse_features
        sequence_features = [] if sequence_features is None else sequence_features
        mlp_params = {} if mlp_params is None else mlp_params
        target = [] if target is None else target
        optimizer_params = {} if optimizer_params is None else optimizer_params

        super(DeepFM, self).__init__(
            dense_features=dense_features,
            sparse_features=sparse_features,
            sequence_features=sequence_features,
            target=target,
            task=self.task_type,
            device=device,
            embedding_l1_reg=embedding_l1_reg,
            dense_l1_reg=dense_l1_reg,
            embedding_l2_reg=embedding_l2_reg,
            dense_l2_reg=dense_l2_reg,
            early_stop_patience=20,
            model_id=model_id
        )

        # Normalize the loss once and reuse the normalized value everywhere.
        self.loss = loss if loss is not None else "bce"

        # FM towers see only categorical features; the deep tower sees all.
        self.fm_features = sparse_features + sequence_features
        self.deep_features = dense_features + sparse_features + sequence_features

        self.embedding = EmbeddingLayer(features=self.deep_features)

        # Input widths per tower (DenseFeature may not define embedding_dim —
        # fall back to width 1 per feature).
        fm_emb_dim_total = sum(f.embedding_dim for f in self.fm_features)
        deep_emb_dim_total = sum(f.embedding_dim for f in self.deep_features
                                 if not isinstance(f, DenseFeature))
        dense_input_dim = sum((getattr(f, "embedding_dim", 1) or 1)
                              for f in dense_features)

        self.linear = LR(fm_emb_dim_total)            # first-order term
        self.fm = FM(reduce_sum=True)                 # second-order term
        self.mlp = MLP(input_dim=deep_emb_dim_total + dense_input_dim, **mlp_params)
        self.prediction_layer = PredictionLayer(task_type=self.task_type)

        # Register weights for L1/L2 regularization.
        self._register_regularization_weights(
            embedding_attr='embedding',
            include_modules=['linear', 'mlp']
        )

        # Pass the normalized loss (the original forwarded the raw argument,
        # which could still be None).
        self.compile(
            optimizer=optimizer,
            optimizer_params=optimizer_params,
            loss=self.loss
        )

    def forward(self, x):
        """Score a batch dict of feature tensors; returns predictions [B, 1]."""
        # Deep tower input: flattened embeddings + dense values.
        input_deep = self.embedding(x=x, features=self.deep_features, squeeze_dim=True)
        # FM tower input: stacked field embeddings [B, num_fields, emb_dim].
        input_fm = self.embedding(x=x, features=self.fm_features, squeeze_dim=False)

        y_linear = self.linear(input_fm.flatten(start_dim=1))  # first-order logit
        y_fm = self.fm(input_fm)                               # second-order logit
        y_deep = self.mlp(input_deep)                          # [B, 1] deep logit

        # Sum the three logits, then apply the task-specific output transform.
        y = y_linear + y_fm + y_deep
        return self.prediction_layer(y)
@@ -0,0 +1,214 @@
1
+ """
2
+ Date: create on 09/11/2025
3
+ Author:
4
+ Yang Zhou,zyaztec@gmail.com
5
+ Reference:
6
+ [1] Zhou G, Mou N, Fan Y, et al. Deep interest evolution network for click-through rate prediction[C]
7
+ //Proceedings of the AAAI conference on artificial intelligence. 2019, 33(01): 5941-5948.
8
+ (https://arxiv.org/abs/1809.03672)
9
+ """
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+
14
+ from nextrec.basic.model import BaseModel
15
+ from nextrec.basic.layers import EmbeddingLayer, MLP, AttentionPoolingLayer, DynamicGRU, AUGRU, PredictionLayer
16
+ from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
17
+
18
+
19
class DIEN(BaseModel):
    """Deep Interest Evolution Network (DIEN) for binary CTR prediction.

    A GRU extracts per-step interest states from the user behavior sequence,
    and an attention-gated GRU (AUGRU) evolves those states toward the
    candidate item; the final interest state joins the other features in an
    MLP.

    Reference:
        [1] Zhou G, Mou N, Fan Y, et al. Deep interest evolution network for
            click-through rate prediction. AAAI 2019, 33(01): 5941-5948.
            (https://arxiv.org/abs/1809.03672)
    """

    @property
    def model_name(self):
        # Identifier used by BaseModel for logging / checkpoints.
        return "DIEN"

    @property
    def task_type(self):
        # Binary classification task.
        return "binary"

    def __init__(self,
                 dense_features: list[DenseFeature],
                 sparse_features: list[SparseFeature],
                 sequence_features: list[SequenceFeature],
                 mlp_params: dict,
                 gru_hidden_size: int = 64,
                 attention_hidden_units: list[int] | None = None,
                 attention_activation: str = 'sigmoid',
                 use_negsampling: bool = False,
                 target: list[str] | None = None,
                 optimizer: str = "adam",
                 optimizer_params: dict | None = None,
                 loss: str | nn.Module | None = "bce",
                 device: str = 'cpu',
                 model_id: str = "baseline",
                 embedding_l1_reg=1e-6,
                 dense_l1_reg=1e-5,
                 embedding_l2_reg=1e-5,
                 dense_l2_reg=1e-4):
        """Build the DIEN model.

        Args:
            dense_features: numeric feature definitions.
            sparse_features: categorical features; by convention the LAST one
                is treated as the candidate item.
            sequence_features: by convention the FIRST one is the user
                behavior history.
            mlp_params: kwargs forwarded to the final ``MLP``.
            gru_hidden_size: hidden size of both the extractor GRU and AUGRU.
            attention_hidden_units: hidden layer sizes of the attention net;
                defaults to [80, 40].
            attention_activation: activation for the attention net.
            use_negsampling: reserved flag for auxiliary-loss negative
                sampling.
            target: label column name(s); defaults to an empty list.
            optimizer: optimizer name understood by ``compile``.
            optimizer_params: kwargs forwarded to the optimizer.
            loss: loss name or module; None falls back to "bce".
            device: torch device string.
            model_id: experiment identifier.
            embedding_l1_reg, dense_l1_reg, embedding_l2_reg, dense_l2_reg:
                regularization strengths passed to BaseModel.

        Raises:
            ValueError: if no sequence feature is provided.
        """
        # None sentinels replace mutable default arguments so instances
        # never share the same list/dict object.
        attention_hidden_units = [80, 40] if attention_hidden_units is None else attention_hidden_units
        target = [] if target is None else target
        optimizer_params = {} if optimizer_params is None else optimizer_params

        super(DIEN, self).__init__(
            dense_features=dense_features,
            sparse_features=sparse_features,
            sequence_features=sequence_features,
            target=target,
            task=self.task_type,
            device=device,
            embedding_l1_reg=embedding_l1_reg,
            dense_l1_reg=dense_l1_reg,
            embedding_l2_reg=embedding_l2_reg,
            dense_l2_reg=dense_l2_reg,
            early_stop_patience=20,
            model_id=model_id
        )

        # Normalize the loss once and reuse the normalized value everywhere.
        self.loss = loss if loss is not None else "bce"

        # NOTE(review): stored but never read in this file — auxiliary-loss
        # negative sampling appears unimplemented; confirm upstream usage.
        self.use_negsampling = use_negsampling

        if len(sequence_features) == 0:
            raise ValueError("DIEN requires at least one sequence feature for user behavior history")

        # Feature roles: first sequence feature = behavior history,
        # last sparse feature = candidate item.
        self.behavior_feature = sequence_features[0]
        self.candidate_feature = sparse_features[-1] if sparse_features else None
        self.other_sparse_features = sparse_features[:-1] if self.candidate_feature else sparse_features
        self.dense_features_list = dense_features

        # One embedding layer shared by every feature.
        self.all_features = dense_features + sparse_features + sequence_features
        self.embedding = EmbeddingLayer(features=self.all_features)

        behavior_emb_dim = self.behavior_feature.embedding_dim
        # Project the candidate embedding into the GRU hidden space when the
        # dims differ so it can be compared against interest states.
        self.candidate_proj = None
        if self.candidate_feature is not None and self.candidate_feature.embedding_dim != gru_hidden_size:
            self.candidate_proj = nn.Linear(self.candidate_feature.embedding_dim, gru_hidden_size)

        # Interest Extractor Layer: plain GRU over behavior embeddings.
        self.interest_extractor = DynamicGRU(
            input_size=behavior_emb_dim,
            hidden_size=gru_hidden_size
        )

        # Attention net scoring each interest state against the candidate;
        # raw scores (no softmax) feed the AUGRU update gates.
        self.attention_layer = AttentionPoolingLayer(
            embedding_dim=gru_hidden_size,
            hidden_units=attention_hidden_units,
            activation=attention_activation,
            use_softmax=False
        )

        # Interest Evolution Layer: attention-gated GRU.
        self.interest_evolution = AUGRU(
            input_size=gru_hidden_size,
            hidden_size=gru_hidden_size
        )

        # MLP input: candidate emb + final interest + other sparse + dense
        # (a DenseFeature may not define embedding_dim — width 1 fallback).
        mlp_input_dim = 0
        if self.candidate_feature:
            mlp_input_dim += self.candidate_feature.embedding_dim
        mlp_input_dim += gru_hidden_size
        mlp_input_dim += sum(f.embedding_dim for f in self.other_sparse_features)
        mlp_input_dim += sum((getattr(f, "embedding_dim", 1) or 1)
                             for f in dense_features)

        self.mlp = MLP(input_dim=mlp_input_dim, **mlp_params)
        self.prediction_layer = PredictionLayer(task_type=self.task_type)

        # Register weights for L1/L2 regularization.
        self._register_regularization_weights(
            embedding_attr='embedding',
            include_modules=['interest_extractor', 'interest_evolution', 'attention_layer', 'mlp', 'candidate_proj']
        )

        # Pass the normalized loss (the original forwarded the raw argument,
        # which could still be None).
        self.compile(
            optimizer=optimizer,
            optimizer_params=optimizer_params,
            loss=self.loss
        )

    def forward(self, x):
        """Score a batch dict of feature tensors; returns predictions [B, 1].

        Raises:
            ValueError: if the model was built without a candidate feature.
        """
        # Candidate item embedding — mandatory for DIEN's attention.
        if self.candidate_feature:
            candidate_emb = self.embedding.embed_dict[self.candidate_feature.embedding_name](
                x[self.candidate_feature.name].long()
            )  # [B, emb_dim]
        else:
            raise ValueError("DIEN requires a candidate item feature")

        # Behavior sequence embedding.
        behavior_seq = x[self.behavior_feature.name].long()  # [B, seq_len]
        behavior_emb = self.embedding.embed_dict[self.behavior_feature.embedding_name](
            behavior_seq
        )  # [B, seq_len, emb_dim]

        # Padding mask; index 0 is assumed to be padding when no explicit
        # padding_idx is configured.
        if self.behavior_feature.padding_idx is not None:
            mask = (behavior_seq != self.behavior_feature.padding_idx).unsqueeze(-1).float()
        else:
            mask = (behavior_seq != 0).unsqueeze(-1).float()

        # Step 1: Interest Extractor (GRU).
        interest_states, _ = self.interest_extractor(behavior_emb)  # [B, seq_len, hidden]

        batch_size, seq_len, hidden_size = interest_states.shape

        # Project candidate to hidden_size if necessary (layer from __init__).
        if self.candidate_proj is not None:
            candidate_for_attention = self.candidate_proj(candidate_emb)
        else:
            candidate_for_attention = candidate_emb

        # Step 2: per-step attention scores using the DIN-style interaction
        # [q, k, q-k, q*k].  NOTE(review): python loop over seq_len — kept for
        # fidelity with attention_net's per-step interface; candidate for
        # vectorization later.
        attention_scores = []
        for t in range(seq_len):
            step_state = interest_states[:, t, :]
            score = self.attention_layer.attention_net(
                torch.cat([
                    candidate_for_attention,
                    step_state,
                    candidate_for_attention - step_state,
                    candidate_for_attention * step_state
                ], dim=-1)
            )  # [B, 1]
            attention_scores.append(score)

        attention_scores = torch.cat(attention_scores, dim=1).unsqueeze(-1)  # [B, seq_len, 1]
        attention_scores = torch.sigmoid(attention_scores)  # squash to (0, 1)
        # Zero out padded timesteps so they cannot gate the AUGRU.
        attention_scores = attention_scores * mask

        # Step 3: Interest Evolution (AUGRU).
        final_states, final_interest = self.interest_evolution(
            interest_states,
            attention_scores
        )  # final_interest: [B, hidden]

        # Assemble the MLP input: candidate, evolved interest, other sparse
        # embeddings, then dense values (reshaped to [B, 1] when needed).
        other_embeddings = [candidate_emb, final_interest]

        for feat in self.other_sparse_features:
            feat_emb = self.embedding.embed_dict[feat.embedding_name](x[feat.name].long())
            other_embeddings.append(feat_emb)

        for feat in self.dense_features_list:
            val = x[feat.name].float()
            if val.dim() == 1:
                val = val.unsqueeze(1)
            other_embeddings.append(val)

        concat_input = torch.cat(other_embeddings, dim=-1)  # [B, total_dim]

        y = self.mlp(concat_input)  # [B, 1]
        return self.prediction_layer(y)
@@ -0,0 +1,181 @@
1
+ """
2
+ Date: create on 09/11/2025
3
+ Author:
4
+ Yang Zhou,zyaztec@gmail.com
5
+ Reference:
6
+ [1] Zhou G, Zhu X, Song C, et al. Deep interest network for click-through rate prediction[C]
7
+ //Proceedings of the 24th ACM SIGKDD international conference on knowledge discovery & data mining. 2018: 1059-1068.
8
+ (https://arxiv.org/abs/1706.06978)
9
+ """
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+
14
+ from nextrec.basic.model import BaseModel
15
+ from nextrec.basic.layers import EmbeddingLayer, MLP, AttentionPoolingLayer, PredictionLayer
16
+ from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
17
+
18
+
19
class DIN(BaseModel):
    """Deep Interest Network (DIN) for binary CTR prediction.

    Attention-pools the user behavior sequence against the candidate item,
    then concatenates the pooled interest with the remaining features and
    feeds an MLP.

    Reference:
        [1] Zhou G, Zhu X, Song C, et al. Deep interest network for
            click-through rate prediction. KDD 2018: 1059-1068.
            (https://arxiv.org/abs/1706.06978)
    """

    @property
    def model_name(self):
        # Identifier used by BaseModel for logging / checkpoints.
        return "DIN"

    @property
    def task_type(self):
        # Binary classification task.
        return "binary"

    def __init__(self,
                 dense_features: list[DenseFeature],
                 sparse_features: list[SparseFeature],
                 sequence_features: list[SequenceFeature],
                 mlp_params: dict,
                 attention_hidden_units: list[int] | None = None,
                 attention_activation: str = 'sigmoid',
                 attention_use_softmax: bool = True,
                 target: list[str] | None = None,
                 optimizer: str = "adam",
                 optimizer_params: dict | None = None,
                 loss: str | nn.Module | None = "bce",
                 device: str = 'cpu',
                 model_id: str = "baseline",
                 embedding_l1_reg=1e-6,
                 dense_l1_reg=1e-5,
                 embedding_l2_reg=1e-5,
                 dense_l2_reg=1e-4):
        """Build the DIN model.

        Args:
            dense_features: numeric feature definitions.
            sparse_features: categorical features; by convention the LAST one
                is treated as the candidate item.
            sequence_features: by convention the FIRST one is the user
                behavior history.
            mlp_params: kwargs forwarded to the final ``MLP``.
            attention_hidden_units: hidden layer sizes of the attention net;
                defaults to [80, 40].
            attention_activation: activation for the attention net.
            attention_use_softmax: normalize attention weights with softmax.
            target: label column name(s); defaults to an empty list.
            optimizer: optimizer name understood by ``compile``.
            optimizer_params: kwargs forwarded to the optimizer.
            loss: loss name or module; None falls back to "bce".
            device: torch device string.
            model_id: experiment identifier.
            embedding_l1_reg, dense_l1_reg, embedding_l2_reg, dense_l2_reg:
                regularization strengths passed to BaseModel.

        Raises:
            ValueError: if no sequence feature is provided.
        """
        # None sentinels replace mutable default arguments so instances
        # never share the same list/dict object.
        attention_hidden_units = [80, 40] if attention_hidden_units is None else attention_hidden_units
        target = [] if target is None else target
        optimizer_params = {} if optimizer_params is None else optimizer_params

        super(DIN, self).__init__(
            dense_features=dense_features,
            sparse_features=sparse_features,
            sequence_features=sequence_features,
            target=target,
            task=self.task_type,
            device=device,
            embedding_l1_reg=embedding_l1_reg,
            dense_l1_reg=dense_l1_reg,
            embedding_l2_reg=embedding_l2_reg,
            dense_l2_reg=dense_l2_reg,
            early_stop_patience=20,
            model_id=model_id
        )

        # Normalize the loss once and reuse the normalized value everywhere.
        self.loss = loss if loss is not None else "bce"

        # Feature roles: DIN needs candidate item + behavior sequence + rest.
        if len(sequence_features) == 0:
            raise ValueError("DIN requires at least one sequence feature for user behavior history")

        self.behavior_feature = sequence_features[0]
        self.candidate_feature = sparse_features[-1] if sparse_features else None

        # Remaining features enter the final concatenation directly.
        self.other_sparse_features = sparse_features[:-1] if self.candidate_feature else sparse_features
        self.dense_features_list = dense_features

        # One embedding layer shared by every feature.
        self.all_features = dense_features + sparse_features + sequence_features
        self.embedding = EmbeddingLayer(features=self.all_features)

        # Attention over the behavior sequence; project the candidate query
        # into the behavior embedding space when dims differ.
        behavior_emb_dim = self.behavior_feature.embedding_dim
        self.candidate_attention_proj = None
        if self.candidate_feature is not None and self.candidate_feature.embedding_dim != behavior_emb_dim:
            self.candidate_attention_proj = nn.Linear(self.candidate_feature.embedding_dim, behavior_emb_dim)
        self.attention = AttentionPoolingLayer(
            embedding_dim=behavior_emb_dim,
            hidden_units=attention_hidden_units,
            activation=attention_activation,
            use_softmax=attention_use_softmax
        )

        # MLP input: candidate + pooled behavior + other sparse + dense
        # (a DenseFeature may not define embedding_dim — width 1 fallback).
        mlp_input_dim = 0
        if self.candidate_feature:
            mlp_input_dim += self.candidate_feature.embedding_dim
        mlp_input_dim += behavior_emb_dim
        mlp_input_dim += sum(f.embedding_dim for f in self.other_sparse_features)
        mlp_input_dim += sum((getattr(f, "embedding_dim", 1) or 1)
                             for f in dense_features)

        self.mlp = MLP(input_dim=mlp_input_dim, **mlp_params)
        self.prediction_layer = PredictionLayer(task_type=self.task_type)

        # Register weights for L1/L2 regularization.
        self._register_regularization_weights(
            embedding_attr='embedding',
            include_modules=['attention', 'mlp', 'candidate_attention_proj']
        )

        # Pass the normalized loss (the original forwarded the raw argument,
        # which could still be None).
        self.compile(
            optimizer=optimizer,
            optimizer_params=optimizer_params,
            loss=self.loss
        )

    def forward(self, x):
        """Score a batch dict of feature tensors; returns predictions [B, 1]."""
        # Candidate item embedding (optional — mean pooling is used without it).
        if self.candidate_feature:
            candidate_emb = self.embedding.embed_dict[self.candidate_feature.embedding_name](
                x[self.candidate_feature.name].long()
            )  # [B, emb_dim]
        else:
            candidate_emb = None

        # Behavior sequence embedding.
        behavior_seq = x[self.behavior_feature.name].long()  # [B, seq_len]
        behavior_emb = self.embedding.embed_dict[self.behavior_feature.embedding_name](
            behavior_seq
        )  # [B, seq_len, emb_dim]

        # Padding mask; index 0 is assumed to be padding when no explicit
        # padding_idx is configured.
        if self.behavior_feature.padding_idx is not None:
            mask = (behavior_seq != self.behavior_feature.padding_idx).unsqueeze(-1).float()
        else:
            mask = (behavior_seq != 0).unsqueeze(-1).float()

        # Attention pooling against the candidate; fall back to masked mean
        # pooling when there is no candidate feature.
        if candidate_emb is not None:
            candidate_query = candidate_emb
            if self.candidate_attention_proj is not None:
                candidate_query = self.candidate_attention_proj(candidate_query)
            pooled_behavior = self.attention(
                query=candidate_query,
                keys=behavior_emb,
                mask=mask
            )  # [B, emb_dim]
        else:
            # Epsilon guards against an all-padding sequence.
            pooled_behavior = torch.sum(behavior_emb * mask, dim=1) / (mask.sum(dim=1) + 1e-9)

        # Assemble the MLP input: candidate (if any), pooled interest, other
        # sparse embeddings, then dense values (reshaped to [B, 1] if needed).
        other_embeddings = []

        if candidate_emb is not None:
            other_embeddings.append(candidate_emb)

        other_embeddings.append(pooled_behavior)

        for feat in self.other_sparse_features:
            feat_emb = self.embedding.embed_dict[feat.embedding_name](x[feat.name].long())
            other_embeddings.append(feat_emb)

        for feat in self.dense_features_list:
            val = x[feat.name].float()
            if val.dim() == 1:
                val = val.unsqueeze(1)
            other_embeddings.append(val)

        concat_input = torch.cat(other_embeddings, dim=-1)  # [B, total_dim]

        y = self.mlp(concat_input)  # [B, 1]
        return self.prediction_layer(y)