nextrec 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. nextrec/__init__.py +41 -0
  2. nextrec/__version__.py +1 -0
  3. nextrec/basic/__init__.py +0 -0
  4. nextrec/basic/activation.py +92 -0
  5. nextrec/basic/callback.py +35 -0
  6. nextrec/basic/dataloader.py +447 -0
  7. nextrec/basic/features.py +87 -0
  8. nextrec/basic/layers.py +985 -0
  9. nextrec/basic/loggers.py +124 -0
  10. nextrec/basic/metrics.py +557 -0
  11. nextrec/basic/model.py +1438 -0
  12. nextrec/data/__init__.py +27 -0
  13. nextrec/data/data_utils.py +132 -0
  14. nextrec/data/preprocessor.py +662 -0
  15. nextrec/loss/__init__.py +35 -0
  16. nextrec/loss/loss_utils.py +136 -0
  17. nextrec/loss/match_losses.py +294 -0
  18. nextrec/models/generative/hstu.py +0 -0
  19. nextrec/models/generative/tiger.py +0 -0
  20. nextrec/models/match/__init__.py +13 -0
  21. nextrec/models/match/dssm.py +200 -0
  22. nextrec/models/match/dssm_v2.py +162 -0
  23. nextrec/models/match/mind.py +210 -0
  24. nextrec/models/match/sdm.py +253 -0
  25. nextrec/models/match/youtube_dnn.py +172 -0
  26. nextrec/models/multi_task/esmm.py +129 -0
  27. nextrec/models/multi_task/mmoe.py +161 -0
  28. nextrec/models/multi_task/ple.py +260 -0
  29. nextrec/models/multi_task/share_bottom.py +126 -0
  30. nextrec/models/ranking/__init__.py +17 -0
  31. nextrec/models/ranking/afm.py +118 -0
  32. nextrec/models/ranking/autoint.py +140 -0
  33. nextrec/models/ranking/dcn.py +120 -0
  34. nextrec/models/ranking/deepfm.py +95 -0
  35. nextrec/models/ranking/dien.py +214 -0
  36. nextrec/models/ranking/din.py +181 -0
  37. nextrec/models/ranking/fibinet.py +130 -0
  38. nextrec/models/ranking/fm.py +87 -0
  39. nextrec/models/ranking/masknet.py +125 -0
  40. nextrec/models/ranking/pnn.py +128 -0
  41. nextrec/models/ranking/widedeep.py +105 -0
  42. nextrec/models/ranking/xdeepfm.py +117 -0
  43. nextrec/utils/__init__.py +18 -0
  44. nextrec/utils/common.py +14 -0
  45. nextrec/utils/embedding.py +19 -0
  46. nextrec/utils/initializer.py +47 -0
  47. nextrec/utils/optimizer.py +75 -0
  48. nextrec-0.1.1.dist-info/METADATA +302 -0
  49. nextrec-0.1.1.dist-info/RECORD +51 -0
  50. nextrec-0.1.1.dist-info/WHEEL +4 -0
  51. nextrec-0.1.1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,162 @@
1
+ """
2
+ Date: create on 09/11/2025
3
+ Author:
4
+ Yang Zhou,zyaztec@gmail.com
5
+ Reference:
6
+ DSSM v2 - DSSM with pairwise training using BPR loss
7
+ """
8
+ import torch
9
+ import torch.nn as nn
10
+ from typing import Literal
11
+
12
+ from nextrec.basic.model import BaseMatchModel
13
+ from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
14
+ from nextrec.basic.layers import MLP, EmbeddingLayer
15
+
16
+
17
class DSSM_v2(BaseMatchModel):
    """DSSM two-tower match model, trained pairwise (BPR-style) by default.

    User and item features are each embedded, concatenated, and pushed
    through an independent MLP tower into a shared ``embedding_dim`` space.
    Both tower outputs are L2-normalized, so 'dot' similarity behaves like
    cosine similarity.
    """

    @property
    def model_name(self) -> str:
        """Model identifier used by the framework."""
        return "DSSM_v2"

    def __init__(self,
                 user_dense_features: list[DenseFeature] | None = None,
                 user_sparse_features: list[SparseFeature] | None = None,
                 user_sequence_features: list[SequenceFeature] | None = None,
                 item_dense_features: list[DenseFeature] | None = None,
                 item_sparse_features: list[SparseFeature] | None = None,
                 item_sequence_features: list[SequenceFeature] | None = None,
                 user_dnn_hidden_units: list[int] | None = None,
                 item_dnn_hidden_units: list[int] | None = None,
                 embedding_dim: int = 64,
                 dnn_activation: str = 'relu',
                 dnn_dropout: float = 0.0,
                 training_mode: Literal['pointwise', 'pairwise', 'listwise'] = 'pairwise',
                 num_negative_samples: int = 4,
                 temperature: float = 1.0,
                 similarity_metric: Literal['dot', 'cosine', 'euclidean'] = 'dot',
                 device: str = 'cpu',
                 embedding_l1_reg: float = 0.0,
                 dense_l1_reg: float = 0.0,
                 embedding_l2_reg: float = 0.0,
                 dense_l2_reg: float = 0.0,
                 early_stop_patience: int = 20,
                 model_id: str = 'dssm_v2'):
        """Build the user and item towers.

        Args:
            user_dense_features / user_sparse_features / user_sequence_features:
                feature descriptors for the user tower.
            item_dense_features / item_sparse_features / item_sequence_features:
                feature descriptors for the item tower.
            user_dnn_hidden_units / item_dnn_hidden_units: hidden layer sizes
                of each tower; default ``[256, 128, 64]``. (Previously a
                mutable default argument shared across instances; now a
                ``None`` sentinel.)
            embedding_dim: output dimension of both towers.
            dnn_activation / dnn_dropout: MLP activation name and dropout rate.
            Remaining arguments are forwarded unchanged to ``BaseMatchModel``.
        """
        # Fix: avoid mutable default arguments shared across all instances.
        if user_dnn_hidden_units is None:
            user_dnn_hidden_units = [256, 128, 64]
        if item_dnn_hidden_units is None:
            item_dnn_hidden_units = [256, 128, 64]

        super().__init__(
            user_dense_features=user_dense_features,
            user_sparse_features=user_sparse_features,
            user_sequence_features=user_sequence_features,
            item_dense_features=item_dense_features,
            item_sparse_features=item_sparse_features,
            item_sequence_features=item_sequence_features,
            training_mode=training_mode,
            num_negative_samples=num_negative_samples,
            temperature=temperature,
            similarity_metric=similarity_metric,
            device=device,
            embedding_l1_reg=embedding_l1_reg,
            dense_l1_reg=dense_l1_reg,
            embedding_l2_reg=embedding_l2_reg,
            dense_l2_reg=dense_l2_reg,
            early_stop_patience=early_stop_patience,
            model_id=model_id
        )

        self.embedding_dim = embedding_dim
        self.user_dnn_hidden_units = user_dnn_hidden_units
        self.item_dnn_hidden_units = item_dnn_hidden_units

        # ---- User tower ----
        user_features = [
            *(user_dense_features or []),
            *(user_sparse_features or []),
            *(user_sequence_features or []),
        ]
        if user_features:
            self.user_embedding = EmbeddingLayer(user_features)
        # NOTE(review): with no user features, self.user_embedding is never
        # created and user_tower() would raise AttributeError — confirm the
        # base class rejects that configuration.

        # Each dense feature contributes one scalar; sparse/sequence features
        # contribute their embedding dimension.
        user_input_dim = (
            len(user_dense_features or [])
            + sum(f.embedding_dim for f in (user_sparse_features or []))
            + sum(f.embedding_dim for f in (user_sequence_features or []))
        )

        self.user_dnn = MLP(
            input_dim=user_input_dim,
            dims=user_dnn_hidden_units + [embedding_dim],
            output_layer=False,
            dropout=dnn_dropout,
            activation=dnn_activation
        )

        # ---- Item tower ----
        item_features = [
            *(item_dense_features or []),
            *(item_sparse_features or []),
            *(item_sequence_features or []),
        ]
        if item_features:
            self.item_embedding = EmbeddingLayer(item_features)

        item_input_dim = (
            len(item_dense_features or [])
            + sum(f.embedding_dim for f in (item_sparse_features or []))
            + sum(f.embedding_dim for f in (item_sequence_features or []))
        )

        self.item_dnn = MLP(
            input_dim=item_input_dim,
            dims=item_dnn_hidden_units + [embedding_dim],
            output_layer=False,
            dropout=dnn_dropout,
            activation=dnn_activation
        )

        self._register_regularization_weights(
            embedding_attr='user_embedding',
            include_modules=['user_dnn']
        )
        self._register_regularization_weights(
            embedding_attr='item_embedding',
            include_modules=['item_dnn']
        )

        self.to(device)

    def user_tower(self, user_input: dict) -> torch.Tensor:
        """Encode a batch of user features into L2-normalized embeddings.

        Returns:
            Tensor of shape [batch_size, embedding_dim].
        """
        all_user_features = self.user_dense_features + self.user_sparse_features + self.user_sequence_features
        user_emb = self.user_embedding(user_input, all_user_features, squeeze_dim=True)
        user_emb = self.user_dnn(user_emb)
        # L2-normalize so pairwise scores stay bounded and comparable.
        return torch.nn.functional.normalize(user_emb, p=2, dim=1)

    def item_tower(self, item_input: dict) -> torch.Tensor:
        """Encode a batch of item features into L2-normalized embeddings.

        Returns:
            Tensor of shape [batch_size, embedding_dim].
        """
        all_item_features = self.item_dense_features + self.item_sparse_features + self.item_sequence_features
        item_emb = self.item_embedding(item_input, all_item_features, squeeze_dim=True)
        item_emb = self.item_dnn(item_emb)
        # Same normalization as the user tower.
        return torch.nn.functional.normalize(item_emb, p=2, dim=1)
@@ -0,0 +1,210 @@
1
+ """
2
+ Date: create on 09/11/2025
3
+ Author:
4
+ Yang Zhou,zyaztec@gmail.com
5
+ Reference:
6
+ [1] Li C, Liu Z, Wu M, et al. Multi-interest network with dynamic routing for recommendation at Tmall[C]
7
+ //Proceedings of the 28th ACM international conference on information and knowledge management. 2019: 2615-2623.
8
+ """
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from typing import Literal
13
+
14
+ from nextrec.basic.model import BaseMatchModel
15
+ from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
16
+ from nextrec.basic.layers import MLP, EmbeddingLayer, CapsuleNetwork
17
+
18
+
19
class MIND(BaseMatchModel):
    """Multi-Interest Network with Dynamic routing (MIND).

    Extracts several user interest vectors from the first user sequence
    feature via a capsule network with dynamic routing, and scores an item
    against the best-matching interest.

    Reference:
        Li C, Liu Z, Wu M, et al. Multi-interest network with dynamic
        routing for recommendation at Tmall. CIKM 2019: 2615-2623.
    """

    @property
    def model_name(self) -> str:
        """Model identifier used by the framework."""
        return "MIND"

    @property
    def support_training_modes(self) -> list[str]:
        """MIND only supports pointwise training mode."""
        return ['pointwise']

    def __init__(self,
                 user_dense_features: list[DenseFeature] | None = None,
                 user_sparse_features: list[SparseFeature] | None = None,
                 user_sequence_features: list[SequenceFeature] | None = None,
                 item_dense_features: list[DenseFeature] | None = None,
                 item_sparse_features: list[SparseFeature] | None = None,
                 item_sequence_features: list[SequenceFeature] | None = None,
                 embedding_dim: int = 64,
                 num_interests: int = 4,
                 capsule_bilinear_type: int = 2,
                 routing_times: int = 3,
                 relu_layer: bool = False,
                 item_dnn_hidden_units: list[int] | None = None,
                 dnn_activation: str = 'relu',
                 dnn_dropout: float = 0.0,
                 training_mode: Literal['pointwise', 'pairwise', 'listwise'] = 'listwise',
                 num_negative_samples: int = 100,
                 temperature: float = 1.0,
                 similarity_metric: Literal['dot', 'cosine', 'euclidean'] = 'dot',
                 device: str = 'cpu',
                 embedding_l1_reg: float = 0.0,
                 dense_l1_reg: float = 0.0,
                 embedding_l2_reg: float = 0.0,
                 dense_l2_reg: float = 0.0,
                 early_stop_patience: int = 20,
                 model_id: str = 'mind'):
        """Build the MIND model.

        Args:
            num_interests: number of interest capsules extracted per user.
            capsule_bilinear_type / routing_times / relu_layer: forwarded to
                ``CapsuleNetwork``.
            item_dnn_hidden_units: item-tower hidden sizes; default
                ``[256, 128]``; pass ``[]`` to skip the item DNN entirely.
                (Previously a mutable default argument; now a None sentinel.)
            Remaining arguments are forwarded unchanged to ``BaseMatchModel``.

        Raises:
            ValueError: if no user sequence feature is supplied.
        """
        # Fix: avoid a mutable default argument shared across instances.
        if item_dnn_hidden_units is None:
            item_dnn_hidden_units = [256, 128]
        # Fail fast: the capsule network consumes a behavior sequence.
        if not user_sequence_features:
            raise ValueError("MIND requires at least one user sequence feature")
        # NOTE(review): the default training_mode='listwise' conflicts with
        # support_training_modes == ['pointwise']; confirm how BaseMatchModel
        # validates this — callers likely must pass training_mode='pointwise'.

        super().__init__(
            user_dense_features=user_dense_features,
            user_sparse_features=user_sparse_features,
            user_sequence_features=user_sequence_features,
            item_dense_features=item_dense_features,
            item_sparse_features=item_sparse_features,
            item_sequence_features=item_sequence_features,
            training_mode=training_mode,
            num_negative_samples=num_negative_samples,
            temperature=temperature,
            similarity_metric=similarity_metric,
            device=device,
            embedding_l1_reg=embedding_l1_reg,
            dense_l1_reg=dense_l1_reg,
            embedding_l2_reg=embedding_l2_reg,
            dense_l2_reg=dense_l2_reg,
            early_stop_patience=early_stop_patience,
            model_id=model_id
        )

        self.embedding_dim = embedding_dim
        self.num_interests = num_interests
        self.item_dnn_hidden_units = item_dnn_hidden_units

        # ---- User side ----
        user_features = [
            *(user_dense_features or []),
            *(user_sparse_features or []),
            *(user_sequence_features or []),
        ]
        if user_features:
            self.user_embedding = EmbeddingLayer(user_features)

        # Fall back to 50 when the sequence feature declares no max_len.
        seq_max_len = user_sequence_features[0].max_len if user_sequence_features[0].max_len else 50
        seq_embedding_dim = user_sequence_features[0].embedding_dim

        # Capsule network performs the multi-interest extraction.
        self.capsule_network = CapsuleNetwork(
            embedding_dim=seq_embedding_dim,
            seq_len=seq_max_len,
            bilinear_type=capsule_bilinear_type,
            interest_num=num_interests,
            routing_times=routing_times,
            relu_layer=relu_layer
        )

        # Project interests into embedding_dim only when the dims differ.
        if seq_embedding_dim != embedding_dim:
            self.interest_projection = nn.Linear(seq_embedding_dim, embedding_dim, bias=False)
            nn.init.xavier_uniform_(self.interest_projection.weight)
        else:
            self.interest_projection = None

        # ---- Item tower ----
        item_features = [
            *(item_dense_features or []),
            *(item_sparse_features or []),
            *(item_sequence_features or []),
        ]
        if item_features:
            self.item_embedding = EmbeddingLayer(item_features)

        # Dense features contribute one scalar each; sparse/sequence features
        # contribute their embedding dimension.
        item_input_dim = (
            len(item_dense_features or [])
            + sum(f.embedding_dim for f in (item_sparse_features or []))
            + sum(f.embedding_dim for f in (item_sequence_features or []))
        )

        if item_dnn_hidden_units:
            self.item_dnn = MLP(
                input_dim=item_input_dim,
                dims=item_dnn_hidden_units + [embedding_dim],
                output_layer=False,
                dropout=dnn_dropout,
                activation=dnn_activation
            )
        else:
            self.item_dnn = None

        self._register_regularization_weights(
            embedding_attr='user_embedding',
            include_modules=['capsule_network']
        )
        self._register_regularization_weights(
            embedding_attr='item_embedding',
            include_modules=['item_dnn'] if self.item_dnn else []
        )

        self.to(device)

    def user_tower(self, user_input: dict) -> torch.Tensor:
        """Extract multiple L2-normalized interest vectors for each user.

        Only the first user sequence feature feeds the capsule network.

        Returns:
            Tensor of shape [batch_size, num_interests, embedding_dim].
        """
        seq_feature = self.user_sequence_features[0]
        seq_input = user_input[seq_feature.name]

        embed = self.user_embedding.embed_dict[seq_feature.embedding_name]
        seq_emb = embed(seq_input.long())  # [batch_size, seq_len, seq_embedding_dim]

        # 1.0 for real tokens, 0.0 for padding positions.
        mask = (seq_input != seq_feature.padding_idx).float()  # [batch_size, seq_len]

        multi_interests = self.capsule_network(seq_emb, mask)  # [batch_size, num_interests, seq_embedding_dim]

        if self.interest_projection is not None:
            multi_interests = self.interest_projection(multi_interests)  # [batch_size, num_interests, embedding_dim]

        # L2-normalize each interest vector.
        return F.normalize(multi_interests, p=2, dim=-1)

    def item_tower(self, item_input: dict) -> torch.Tensor:
        """Encode a batch of item features into L2-normalized embeddings.

        Returns:
            Tensor of shape [batch_size, embedding_dim].
        """
        all_item_features = self.item_dense_features + self.item_sparse_features + self.item_sequence_features
        item_emb = self.item_embedding(item_input, all_item_features, squeeze_dim=True)

        if self.item_dnn is not None:
            item_emb = self.item_dnn(item_emb)

        return F.normalize(item_emb, p=2, dim=1)

    def compute_similarity(self, user_emb: torch.Tensor, item_emb: torch.Tensor) -> torch.Tensor:
        """Score an item against the best-matching user interest.

        Args:
            user_emb: [batch_size, num_interests, embedding_dim].
            item_emb: [batch_size, embedding_dim].

        Returns:
            Temperature-scaled max similarity, shape [batch_size].

        Raises:
            ValueError: on an unknown similarity metric.
        """
        # Broadcast the single item vector against every interest.
        item_emb_expanded = item_emb.unsqueeze(1)

        if self.similarity_metric == 'dot':
            similarities = torch.sum(user_emb * item_emb_expanded, dim=-1)
        elif self.similarity_metric == 'cosine':
            similarities = F.cosine_similarity(user_emb, item_emb_expanded, dim=-1)
        elif self.similarity_metric == 'euclidean':
            # Negative squared distance so "larger is more similar" holds.
            similarities = -torch.sum((user_emb - item_emb_expanded) ** 2, dim=-1)
        else:
            raise ValueError(f"Unknown similarity metric: {self.similarity_metric}")

        # Hard attention over interests: keep only the best match.
        max_similarity, _ = torch.max(similarities, dim=1)  # [batch_size]
        return max_similarity / self.temperature
@@ -0,0 +1,253 @@
1
+ """
2
+ Date: create on 09/11/2025
3
+ Author:
4
+ Yang Zhou,zyaztec@gmail.com
5
+ Reference:
6
+ [1] Ying H, Zhuang F, Zhang F, et al. Sequential recommender system based on hierarchical attention networks[C]
7
+ //IJCAI. 2018: 3926-3932.
8
+ """
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from typing import Literal
13
+
14
+ from nextrec.basic.model import BaseMatchModel
15
+ from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
16
+ from nextrec.basic.layers import MLP, EmbeddingLayer
17
+
18
+
19
class SDM(BaseMatchModel):
    """Sequential Deep Matching model (SDM).

    Combines a long-term interest (final RNN hidden state over the behavior
    sequence) with a short-term interest (embedding of the last non-padding
    behavior), concatenates them with user profile features, and maps the
    result into the item embedding space.

    Reference:
        Ying H, Zhuang F, Zhang F, et al. Sequential recommender system
        based on hierarchical attention networks. IJCAI 2018: 3926-3932.
    """

    @property
    def model_name(self) -> str:
        """Model identifier used by the framework."""
        return "SDM"

    @property
    def support_training_modes(self) -> list[str]:
        """SDM only supports pointwise training mode."""
        return ['pointwise']

    def __init__(self,
                 user_dense_features: list[DenseFeature] | None = None,
                 user_sparse_features: list[SparseFeature] | None = None,
                 user_sequence_features: list[SequenceFeature] | None = None,
                 item_dense_features: list[DenseFeature] | None = None,
                 item_sparse_features: list[SparseFeature] | None = None,
                 item_sequence_features: list[SequenceFeature] | None = None,
                 embedding_dim: int = 64,
                 rnn_type: Literal['GRU', 'LSTM'] = 'GRU',
                 rnn_hidden_size: int = 64,
                 rnn_num_layers: int = 1,
                 rnn_dropout: float = 0.0,
                 use_short_term: bool = True,
                 use_long_term: bool = True,
                 item_dnn_hidden_units: list[int] | None = None,
                 dnn_activation: str = 'relu',
                 dnn_dropout: float = 0.0,
                 training_mode: Literal['pointwise', 'pairwise', 'listwise'] = 'pointwise',
                 num_negative_samples: int = 4,
                 temperature: float = 1.0,
                 similarity_metric: Literal['dot', 'cosine', 'euclidean'] = 'dot',
                 device: str = 'cpu',
                 embedding_l1_reg: float = 0.0,
                 dense_l1_reg: float = 0.0,
                 embedding_l2_reg: float = 0.0,
                 dense_l2_reg: float = 0.0,
                 early_stop_patience: int = 20,
                 model_id: str = 'sdm'):
        """Build the SDM model.

        Args:
            rnn_type / rnn_hidden_size / rnn_num_layers / rnn_dropout:
                configuration of the sequence encoder (dropout only applies
                between stacked layers, matching torch.nn.GRU/LSTM semantics).
            use_short_term / use_long_term: which interest components feed the
                user DNN; at least one should be True.
            item_dnn_hidden_units: item-tower hidden sizes; default
                ``[256, 128]``; pass ``[]`` to skip the item DNN. (Previously
                a mutable default argument; now a None sentinel.)
            Remaining arguments are forwarded unchanged to ``BaseMatchModel``.

        Raises:
            ValueError: if no user sequence feature is given or rnn_type is
                unknown.
        """
        # Fix: avoid a mutable default argument shared across instances.
        if item_dnn_hidden_units is None:
            item_dnn_hidden_units = [256, 128]
        # Fail fast: the RNN consumes a behavior sequence.
        if not user_sequence_features:
            raise ValueError("SDM requires at least one user sequence feature")

        super().__init__(
            user_dense_features=user_dense_features,
            user_sparse_features=user_sparse_features,
            user_sequence_features=user_sequence_features,
            item_dense_features=item_dense_features,
            item_sparse_features=item_sparse_features,
            item_sequence_features=item_sequence_features,
            training_mode=training_mode,
            num_negative_samples=num_negative_samples,
            temperature=temperature,
            similarity_metric=similarity_metric,
            device=device,
            embedding_l1_reg=embedding_l1_reg,
            dense_l1_reg=dense_l1_reg,
            embedding_l2_reg=embedding_l2_reg,
            dense_l2_reg=dense_l2_reg,
            early_stop_patience=early_stop_patience,
            model_id=model_id
        )

        self.embedding_dim = embedding_dim
        self.rnn_type = rnn_type
        self.rnn_hidden_size = rnn_hidden_size
        self.use_short_term = use_short_term
        self.use_long_term = use_long_term
        self.item_dnn_hidden_units = item_dnn_hidden_units

        # ---- User side ----
        user_features = [
            *(user_dense_features or []),
            *(user_sparse_features or []),
            *(user_sequence_features or []),
        ]
        if user_features:
            self.user_embedding = EmbeddingLayer(user_features)

        seq_emb_dim = user_sequence_features[0].embedding_dim

        # Sequence encoder; dropout between layers only (torch warns otherwise).
        rnn_cls = {'GRU': nn.GRU, 'LSTM': nn.LSTM}.get(rnn_type)
        if rnn_cls is None:
            raise ValueError(f"Unknown RNN type: {rnn_type}")
        self.rnn = rnn_cls(
            input_size=seq_emb_dim,
            hidden_size=rnn_hidden_size,
            num_layers=rnn_num_layers,
            batch_first=True,
            dropout=rnn_dropout if rnn_num_layers > 1 else 0.0
        )

        # Input width of the user DNN depends on the enabled components:
        # long-term (RNN hidden) + short-term (last behavior embedding)
        # + dense scalars + sparse embeddings.
        user_final_dim = 0
        if use_long_term:
            user_final_dim += rnn_hidden_size
        if use_short_term:
            user_final_dim += seq_emb_dim
        user_final_dim += len(user_dense_features or [])
        user_final_dim += sum(f.embedding_dim for f in (user_sparse_features or []))

        # Fuse all user signals into the final embedding.
        self.user_dnn = MLP(
            input_dim=user_final_dim,
            dims=[rnn_hidden_size * 2, embedding_dim],
            output_layer=False,
            dropout=dnn_dropout,
            activation=dnn_activation
        )

        # ---- Item tower ----
        item_features = [
            *(item_dense_features or []),
            *(item_sparse_features or []),
            *(item_sequence_features or []),
        ]
        if item_features:
            self.item_embedding = EmbeddingLayer(item_features)

        item_input_dim = (
            len(item_dense_features or [])
            + sum(f.embedding_dim for f in (item_sparse_features or []))
            + sum(f.embedding_dim for f in (item_sequence_features or []))
        )

        if item_dnn_hidden_units:
            self.item_dnn = MLP(
                input_dim=item_input_dim,
                dims=item_dnn_hidden_units + [embedding_dim],
                output_layer=False,
                dropout=dnn_dropout,
                activation=dnn_activation
            )
        else:
            self.item_dnn = None

        self._register_regularization_weights(
            embedding_attr='user_embedding',
            include_modules=['rnn', 'user_dnn']
        )
        self._register_regularization_weights(
            embedding_attr='item_embedding',
            include_modules=['item_dnn'] if self.item_dnn else []
        )

        self.to(device)

    def user_tower(self, user_input: dict) -> torch.Tensor:
        """Encode a batch of users into L2-normalized embeddings.

        Uses only the first user sequence feature for the RNN / short-term
        signal; other profile features are appended afterwards.

        Returns:
            Tensor of shape [batch_size, embedding_dim].
        """
        seq_feature = self.user_sequence_features[0]
        seq_input = user_input[seq_feature.name]

        embed = self.user_embedding.embed_dict[seq_feature.embedding_name]
        seq_emb = embed(seq_input.long())  # [batch_size, seq_len, seq_emb_dim]

        if self.rnn_type == 'GRU':
            rnn_output, hidden = self.rnn(seq_emb)  # hidden: [num_layers, batch, hidden_size]
        elif self.rnn_type == 'LSTM':
            rnn_output, (hidden, cell) = self.rnn(seq_emb)
        else:
            # Constructor rejects other values; guard keeps `hidden` defined.
            raise ValueError(f"Unknown RNN type: {self.rnn_type}")

        features_list = []

        if self.use_long_term:
            # Long-term interest: final hidden state of the top RNN layer.
            if self.rnn.num_layers > 1:
                long_term = hidden[-1, :, :]  # [batch_size, hidden_size]
            else:
                long_term = hidden.squeeze(0)  # [batch_size, hidden_size]
            features_list.append(long_term)

        if self.use_short_term:
            # Short-term interest: embedding of the last non-padding behavior.
            mask = (seq_input != seq_feature.padding_idx).float()  # [batch_size, seq_len]
            seq_lengths = mask.sum(dim=1).long() - 1  # index of last real token
            seq_lengths = torch.clamp(seq_lengths, min=0)  # all-padding rows fall back to index 0

            batch_size = seq_emb.size(0)
            batch_indices = torch.arange(batch_size, device=seq_emb.device)
            short_term = seq_emb[batch_indices, seq_lengths, :]  # [batch_size, seq_emb_dim]
            features_list.append(short_term)

        if self.user_dense_features:
            dense_features = []
            for feat in self.user_dense_features:
                if feat.name in user_input:
                    val = user_input[feat.name].float()
                    if val.dim() == 1:
                        # Promote scalars to column vectors for concatenation.
                        val = val.unsqueeze(1)
                    dense_features.append(val)
            if dense_features:
                features_list.append(torch.cat(dense_features, dim=1))

        if self.user_sparse_features:
            sparse_features = []
            for feat in self.user_sparse_features:
                if feat.name in user_input:
                    embed = self.user_embedding.embed_dict[feat.embedding_name]
                    sparse_emb = embed(user_input[feat.name].long())
                    sparse_features.append(sparse_emb)
            if sparse_features:
                features_list.append(torch.cat(sparse_features, dim=1))

        user_features = torch.cat(features_list, dim=1)
        user_emb = self.user_dnn(user_features)
        return F.normalize(user_emb, p=2, dim=1)

    def item_tower(self, item_input: dict) -> torch.Tensor:
        """Encode a batch of item features into L2-normalized embeddings.

        Returns:
            Tensor of shape [batch_size, embedding_dim].
        """
        all_item_features = self.item_dense_features + self.item_sparse_features + self.item_sequence_features
        item_emb = self.item_embedding(item_input, all_item_features, squeeze_dim=True)

        if self.item_dnn is not None:
            item_emb = self.item_dnn(item_emb)

        return F.normalize(item_emb, p=2, dim=1)