nextrec 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48) hide show
  1. nextrec/__init__.py +4 -4
  2. nextrec/__version__.py +1 -1
  3. nextrec/basic/activation.py +10 -9
  4. nextrec/basic/callback.py +1 -0
  5. nextrec/basic/dataloader.py +168 -127
  6. nextrec/basic/features.py +24 -27
  7. nextrec/basic/layers.py +328 -159
  8. nextrec/basic/loggers.py +50 -37
  9. nextrec/basic/metrics.py +255 -147
  10. nextrec/basic/model.py +817 -462
  11. nextrec/data/__init__.py +5 -5
  12. nextrec/data/data_utils.py +16 -12
  13. nextrec/data/preprocessor.py +276 -252
  14. nextrec/loss/__init__.py +12 -12
  15. nextrec/loss/loss_utils.py +30 -22
  16. nextrec/loss/match_losses.py +116 -83
  17. nextrec/models/match/__init__.py +5 -5
  18. nextrec/models/match/dssm.py +70 -61
  19. nextrec/models/match/dssm_v2.py +61 -51
  20. nextrec/models/match/mind.py +89 -71
  21. nextrec/models/match/sdm.py +93 -81
  22. nextrec/models/match/youtube_dnn.py +62 -53
  23. nextrec/models/multi_task/esmm.py +49 -43
  24. nextrec/models/multi_task/mmoe.py +65 -56
  25. nextrec/models/multi_task/ple.py +92 -65
  26. nextrec/models/multi_task/share_bottom.py +48 -42
  27. nextrec/models/ranking/__init__.py +7 -7
  28. nextrec/models/ranking/afm.py +39 -30
  29. nextrec/models/ranking/autoint.py +70 -57
  30. nextrec/models/ranking/dcn.py +43 -35
  31. nextrec/models/ranking/deepfm.py +34 -28
  32. nextrec/models/ranking/dien.py +115 -79
  33. nextrec/models/ranking/din.py +84 -60
  34. nextrec/models/ranking/fibinet.py +51 -35
  35. nextrec/models/ranking/fm.py +28 -26
  36. nextrec/models/ranking/masknet.py +31 -31
  37. nextrec/models/ranking/pnn.py +30 -31
  38. nextrec/models/ranking/widedeep.py +36 -31
  39. nextrec/models/ranking/xdeepfm.py +46 -39
  40. nextrec/utils/__init__.py +9 -9
  41. nextrec/utils/embedding.py +1 -1
  42. nextrec/utils/initializer.py +23 -15
  43. nextrec/utils/optimizer.py +14 -10
  44. {nextrec-0.1.1.dist-info → nextrec-0.1.2.dist-info}/METADATA +6 -40
  45. nextrec-0.1.2.dist-info/RECORD +51 -0
  46. nextrec-0.1.1.dist-info/RECORD +0 -51
  47. {nextrec-0.1.1.dist-info → nextrec-0.1.2.dist-info}/WHEEL +0 -0
  48. {nextrec-0.1.1.dist-info → nextrec-0.1.2.dist-info}/licenses/LICENSE +0 -0
@@ -6,6 +6,7 @@ Reference:
6
6
  [1] Li C, Liu Z, Wu M, et al. Multi-interest network with dynamic routing for recommendation at Tmall[C]
7
7
  //Proceedings of the 28th ACM international conference on information and knowledge management. 2019: 2615-2623.
8
8
  """
9
+
9
10
  import torch
10
11
  import torch.nn as nn
11
12
  import torch.nn.functional as F
@@ -20,39 +21,41 @@ class MIND(BaseMatchModel):
20
21
  @property
21
22
  def model_name(self) -> str:
22
23
  return "MIND"
23
-
24
+
24
25
  @property
25
26
  def support_training_modes(self) -> list[str]:
26
27
  """MIND only supports pointwise training mode"""
27
- return ['pointwise']
28
-
29
- def __init__(self,
30
- user_dense_features: list[DenseFeature] | None = None,
31
- user_sparse_features: list[SparseFeature] | None = None,
32
- user_sequence_features: list[SequenceFeature] | None = None,
33
- item_dense_features: list[DenseFeature] | None = None,
34
- item_sparse_features: list[SparseFeature] | None = None,
35
- item_sequence_features: list[SequenceFeature] | None = None,
36
- embedding_dim: int = 64,
37
- num_interests: int = 4,
38
- capsule_bilinear_type: int = 2,
39
- routing_times: int = 3,
40
- relu_layer: bool = False,
41
- item_dnn_hidden_units: list[int] = [256, 128],
42
- dnn_activation: str = 'relu',
43
- dnn_dropout: float = 0.0,
44
- training_mode: Literal['pointwise', 'pairwise', 'listwise'] = 'listwise',
45
- num_negative_samples: int = 100,
46
- temperature: float = 1.0,
47
- similarity_metric: Literal['dot', 'cosine', 'euclidean'] = 'dot',
48
- device: str = 'cpu',
49
- embedding_l1_reg: float = 0.0,
50
- dense_l1_reg: float = 0.0,
51
- embedding_l2_reg: float = 0.0,
52
- dense_l2_reg: float = 0.0,
53
- early_stop_patience: int = 20,
54
- model_id: str = 'mind'):
55
-
28
+ return ["pointwise"]
29
+
30
+ def __init__(
31
+ self,
32
+ user_dense_features: list[DenseFeature] | None = None,
33
+ user_sparse_features: list[SparseFeature] | None = None,
34
+ user_sequence_features: list[SequenceFeature] | None = None,
35
+ item_dense_features: list[DenseFeature] | None = None,
36
+ item_sparse_features: list[SparseFeature] | None = None,
37
+ item_sequence_features: list[SequenceFeature] | None = None,
38
+ embedding_dim: int = 64,
39
+ num_interests: int = 4,
40
+ capsule_bilinear_type: int = 2,
41
+ routing_times: int = 3,
42
+ relu_layer: bool = False,
43
+ item_dnn_hidden_units: list[int] = [256, 128],
44
+ dnn_activation: str = "relu",
45
+ dnn_dropout: float = 0.0,
46
+ training_mode: Literal["pointwise", "pairwise", "listwise"] = "listwise",
47
+ num_negative_samples: int = 100,
48
+ temperature: float = 1.0,
49
+ similarity_metric: Literal["dot", "cosine", "euclidean"] = "dot",
50
+ device: str = "cpu",
51
+ embedding_l1_reg: float = 0.0,
52
+ dense_l1_reg: float = 0.0,
53
+ embedding_l2_reg: float = 0.0,
54
+ dense_l2_reg: float = 0.0,
55
+ early_stop_patience: int = 20,
56
+ model_id: str = "mind",
57
+ ):
58
+
56
59
  super(MIND, self).__init__(
57
60
  user_dense_features=user_dense_features,
58
61
  user_sparse_features=user_sparse_features,
@@ -70,9 +73,9 @@ class MIND(BaseMatchModel):
70
73
  embedding_l2_reg=embedding_l2_reg,
71
74
  dense_l2_reg=dense_l2_reg,
72
75
  early_stop_patience=early_stop_patience,
73
- model_id=model_id
76
+ model_id=model_id,
74
77
  )
75
-
78
+
76
79
  self.embedding_dim = embedding_dim
77
80
  self.num_interests = num_interests
78
81
  self.item_dnn_hidden_units = item_dnn_hidden_units
@@ -84,16 +87,20 @@ class MIND(BaseMatchModel):
84
87
  user_features.extend(user_sparse_features)
85
88
  if user_sequence_features:
86
89
  user_features.extend(user_sequence_features)
87
-
90
+
88
91
  if len(user_features) > 0:
89
92
  self.user_embedding = EmbeddingLayer(user_features)
90
-
93
+
91
94
  if not user_sequence_features or len(user_sequence_features) == 0:
92
95
  raise ValueError("MIND requires at least one user sequence feature")
93
-
94
- seq_max_len = user_sequence_features[0].max_len if user_sequence_features[0].max_len else 50
96
+
97
+ seq_max_len = (
98
+ user_sequence_features[0].max_len
99
+ if user_sequence_features[0].max_len
100
+ else 50
101
+ )
95
102
  seq_embedding_dim = user_sequence_features[0].embedding_dim
96
-
103
+
97
104
  # Capsule Network for multi-interest extraction
98
105
  self.capsule_network = CapsuleNetwork(
99
106
  embedding_dim=seq_embedding_dim,
@@ -101,15 +108,17 @@ class MIND(BaseMatchModel):
101
108
  bilinear_type=capsule_bilinear_type,
102
109
  interest_num=num_interests,
103
110
  routing_times=routing_times,
104
- relu_layer=relu_layer
111
+ relu_layer=relu_layer,
105
112
  )
106
-
113
+
107
114
  if seq_embedding_dim != embedding_dim:
108
- self.interest_projection = nn.Linear(seq_embedding_dim, embedding_dim, bias=False)
115
+ self.interest_projection = nn.Linear(
116
+ seq_embedding_dim, embedding_dim, bias=False
117
+ )
109
118
  nn.init.xavier_uniform_(self.interest_projection.weight)
110
119
  else:
111
120
  self.interest_projection = None
112
-
121
+
113
122
  # Item tower
114
123
  item_features = []
115
124
  if item_dense_features:
@@ -118,10 +127,10 @@ class MIND(BaseMatchModel):
118
127
  item_features.extend(item_sparse_features)
119
128
  if item_sequence_features:
120
129
  item_features.extend(item_sequence_features)
121
-
130
+
122
131
  if len(item_features) > 0:
123
132
  self.item_embedding = EmbeddingLayer(item_features)
124
-
133
+
125
134
  item_input_dim = 0
126
135
  for feat in item_dense_features or []:
127
136
  item_input_dim += 1
@@ -129,7 +138,7 @@ class MIND(BaseMatchModel):
129
138
  item_input_dim += feat.embedding_dim
130
139
  for feat in item_sequence_features or []:
131
140
  item_input_dim += feat.embedding_dim
132
-
141
+
133
142
  # Item DNN
134
143
  if len(item_dnn_hidden_units) > 0:
135
144
  item_dnn_units = item_dnn_hidden_units + [embedding_dim]
@@ -138,26 +147,25 @@ class MIND(BaseMatchModel):
138
147
  dims=item_dnn_units,
139
148
  output_layer=False,
140
149
  dropout=dnn_dropout,
141
- activation=dnn_activation
150
+ activation=dnn_activation,
142
151
  )
143
152
  else:
144
153
  self.item_dnn = None
145
-
154
+
146
155
  self._register_regularization_weights(
147
- embedding_attr='user_embedding',
148
- include_modules=['capsule_network']
156
+ embedding_attr="user_embedding", include_modules=["capsule_network"]
149
157
  )
150
158
  self._register_regularization_weights(
151
- embedding_attr='item_embedding',
152
- include_modules=['item_dnn'] if self.item_dnn else []
159
+ embedding_attr="item_embedding",
160
+ include_modules=["item_dnn"] if self.item_dnn else [],
153
161
  )
154
-
162
+
155
163
  self.to(device)
156
-
164
+
157
165
  def user_tower(self, user_input: dict) -> torch.Tensor:
158
166
  """
159
167
  User tower with multi-interest extraction
160
-
168
+
161
169
  Returns:
162
170
  user_interests: [batch_size, num_interests, embedding_dim]
163
171
  """
@@ -168,43 +176,53 @@ class MIND(BaseMatchModel):
168
176
  seq_emb = embed(seq_input.long()) # [batch_size, seq_len, embedding_dim]
169
177
 
170
178
  mask = (seq_input != seq_feature.padding_idx).float() # [batch_size, seq_len]
171
-
172
- multi_interests = self.capsule_network(seq_emb, mask) # [batch_size, num_interests, seq_embedding_dim]
173
-
179
+
180
+ multi_interests = self.capsule_network(
181
+ seq_emb, mask
182
+ ) # [batch_size, num_interests, seq_embedding_dim]
183
+
174
184
  if self.interest_projection is not None:
175
- multi_interests = self.interest_projection(multi_interests) # [batch_size, num_interests, embedding_dim]
176
-
185
+ multi_interests = self.interest_projection(
186
+ multi_interests
187
+ ) # [batch_size, num_interests, embedding_dim]
188
+
177
189
  # L2 normalization
178
190
  multi_interests = F.normalize(multi_interests, p=2, dim=-1)
179
-
191
+
180
192
  return multi_interests
181
-
193
+
182
194
  def item_tower(self, item_input: dict) -> torch.Tensor:
183
195
  """Item tower"""
184
- all_item_features = self.item_dense_features + self.item_sparse_features + self.item_sequence_features
196
+ all_item_features = (
197
+ self.item_dense_features
198
+ + self.item_sparse_features
199
+ + self.item_sequence_features
200
+ )
185
201
  item_emb = self.item_embedding(item_input, all_item_features, squeeze_dim=True)
186
-
202
+
187
203
  if self.item_dnn is not None:
188
204
  item_emb = self.item_dnn(item_emb)
189
-
205
+
190
206
  # L2 normalization
191
207
  item_emb = F.normalize(item_emb, p=2, dim=1)
192
-
208
+
193
209
  return item_emb
194
-
195
- def compute_similarity(self, user_emb: torch.Tensor, item_emb: torch.Tensor) -> torch.Tensor:
210
+
211
+ def compute_similarity(
212
+ self, user_emb: torch.Tensor, item_emb: torch.Tensor
213
+ ) -> torch.Tensor:
196
214
  item_emb_expanded = item_emb.unsqueeze(1)
197
-
198
- if self.similarity_metric == 'dot':
215
+
216
+ if self.similarity_metric == "dot":
199
217
  similarities = torch.sum(user_emb * item_emb_expanded, dim=-1)
200
- elif self.similarity_metric == 'cosine':
218
+ elif self.similarity_metric == "cosine":
201
219
  similarities = F.cosine_similarity(user_emb, item_emb_expanded, dim=-1)
202
- elif self.similarity_metric == 'euclidean':
220
+ elif self.similarity_metric == "euclidean":
203
221
  similarities = -torch.sum((user_emb - item_emb_expanded) ** 2, dim=-1)
204
222
  else:
205
223
  raise ValueError(f"Unknown similarity metric: {self.similarity_metric}")
206
224
 
207
225
  max_similarity, _ = torch.max(similarities, dim=1) # [batch_size]
208
226
  max_similarity = max_similarity / self.temperature
209
-
227
+
210
228
  return max_similarity
@@ -6,6 +6,7 @@ Reference:
6
6
  [1] Ying H, Zhuang F, Zhang F, et al. Sequential recommender system based on hierarchical attention networks[C]
7
7
  //IJCAI. 2018: 3926-3932.
8
8
  """
9
+
9
10
  import torch
10
11
  import torch.nn as nn
11
12
  import torch.nn.functional as F
@@ -20,40 +21,42 @@ class SDM(BaseMatchModel):
20
21
  @property
21
22
  def model_name(self) -> str:
22
23
  return "SDM"
23
-
24
+
24
25
  @property
25
26
  def support_training_modes(self) -> list[str]:
26
- return ['pointwise']
27
-
28
- def __init__(self,
29
- user_dense_features: list[DenseFeature] | None = None,
30
- user_sparse_features: list[SparseFeature] | None = None,
31
- user_sequence_features: list[SequenceFeature] | None = None,
32
- item_dense_features: list[DenseFeature] | None = None,
33
- item_sparse_features: list[SparseFeature] | None = None,
34
- item_sequence_features: list[SequenceFeature] | None = None,
35
- embedding_dim: int = 64,
36
- rnn_type: Literal['GRU', 'LSTM'] = 'GRU',
37
- rnn_hidden_size: int = 64,
38
- rnn_num_layers: int = 1,
39
- rnn_dropout: float = 0.0,
40
- use_short_term: bool = True,
41
- use_long_term: bool = True,
42
- item_dnn_hidden_units: list[int] = [256, 128],
43
- dnn_activation: str = 'relu',
44
- dnn_dropout: float = 0.0,
45
- training_mode: Literal['pointwise', 'pairwise', 'listwise'] = 'pointwise',
46
- num_negative_samples: int = 4,
47
- temperature: float = 1.0,
48
- similarity_metric: Literal['dot', 'cosine', 'euclidean'] = 'dot',
49
- device: str = 'cpu',
50
- embedding_l1_reg: float = 0.0,
51
- dense_l1_reg: float = 0.0,
52
- embedding_l2_reg: float = 0.0,
53
- dense_l2_reg: float = 0.0,
54
- early_stop_patience: int = 20,
55
- model_id: str = 'sdm'):
56
-
27
+ return ["pointwise"]
28
+
29
+ def __init__(
30
+ self,
31
+ user_dense_features: list[DenseFeature] | None = None,
32
+ user_sparse_features: list[SparseFeature] | None = None,
33
+ user_sequence_features: list[SequenceFeature] | None = None,
34
+ item_dense_features: list[DenseFeature] | None = None,
35
+ item_sparse_features: list[SparseFeature] | None = None,
36
+ item_sequence_features: list[SequenceFeature] | None = None,
37
+ embedding_dim: int = 64,
38
+ rnn_type: Literal["GRU", "LSTM"] = "GRU",
39
+ rnn_hidden_size: int = 64,
40
+ rnn_num_layers: int = 1,
41
+ rnn_dropout: float = 0.0,
42
+ use_short_term: bool = True,
43
+ use_long_term: bool = True,
44
+ item_dnn_hidden_units: list[int] = [256, 128],
45
+ dnn_activation: str = "relu",
46
+ dnn_dropout: float = 0.0,
47
+ training_mode: Literal["pointwise", "pairwise", "listwise"] = "pointwise",
48
+ num_negative_samples: int = 4,
49
+ temperature: float = 1.0,
50
+ similarity_metric: Literal["dot", "cosine", "euclidean"] = "dot",
51
+ device: str = "cpu",
52
+ embedding_l1_reg: float = 0.0,
53
+ dense_l1_reg: float = 0.0,
54
+ embedding_l2_reg: float = 0.0,
55
+ dense_l2_reg: float = 0.0,
56
+ early_stop_patience: int = 20,
57
+ model_id: str = "sdm",
58
+ ):
59
+
57
60
  super(SDM, self).__init__(
58
61
  user_dense_features=user_dense_features,
59
62
  user_sparse_features=user_sparse_features,
@@ -71,16 +74,16 @@ class SDM(BaseMatchModel):
71
74
  embedding_l2_reg=embedding_l2_reg,
72
75
  dense_l2_reg=dense_l2_reg,
73
76
  early_stop_patience=early_stop_patience,
74
- model_id=model_id
77
+ model_id=model_id,
75
78
  )
76
-
79
+
77
80
  self.embedding_dim = embedding_dim
78
81
  self.rnn_type = rnn_type
79
82
  self.rnn_hidden_size = rnn_hidden_size
80
83
  self.use_short_term = use_short_term
81
84
  self.use_long_term = use_long_term
82
85
  self.item_dnn_hidden_units = item_dnn_hidden_units
83
-
86
+
84
87
  # User tower
85
88
  user_features = []
86
89
  if user_dense_features:
@@ -89,54 +92,54 @@ class SDM(BaseMatchModel):
89
92
  user_features.extend(user_sparse_features)
90
93
  if user_sequence_features:
91
94
  user_features.extend(user_sequence_features)
92
-
95
+
93
96
  if len(user_features) > 0:
94
97
  self.user_embedding = EmbeddingLayer(user_features)
95
-
98
+
96
99
  if not user_sequence_features or len(user_sequence_features) == 0:
97
100
  raise ValueError("SDM requires at least one user sequence feature")
98
-
101
+
99
102
  seq_emb_dim = user_sequence_features[0].embedding_dim
100
-
101
- if rnn_type == 'GRU':
103
+
104
+ if rnn_type == "GRU":
102
105
  self.rnn = nn.GRU(
103
106
  input_size=seq_emb_dim,
104
107
  hidden_size=rnn_hidden_size,
105
108
  num_layers=rnn_num_layers,
106
109
  batch_first=True,
107
- dropout=rnn_dropout if rnn_num_layers > 1 else 0.0
110
+ dropout=rnn_dropout if rnn_num_layers > 1 else 0.0,
108
111
  )
109
- elif rnn_type == 'LSTM':
112
+ elif rnn_type == "LSTM":
110
113
  self.rnn = nn.LSTM(
111
114
  input_size=seq_emb_dim,
112
115
  hidden_size=rnn_hidden_size,
113
116
  num_layers=rnn_num_layers,
114
117
  batch_first=True,
115
- dropout=rnn_dropout if rnn_num_layers > 1 else 0.0
118
+ dropout=rnn_dropout if rnn_num_layers > 1 else 0.0,
116
119
  )
117
120
  else:
118
121
  raise ValueError(f"Unknown RNN type: {rnn_type}")
119
-
122
+
120
123
  user_final_dim = 0
121
124
  if use_long_term:
122
- user_final_dim += rnn_hidden_size
125
+ user_final_dim += rnn_hidden_size
123
126
  if use_short_term:
124
- user_final_dim += seq_emb_dim
125
-
127
+ user_final_dim += seq_emb_dim
128
+
126
129
  for feat in user_dense_features or []:
127
130
  user_final_dim += 1
128
131
  for feat in user_sparse_features or []:
129
132
  user_final_dim += feat.embedding_dim
130
-
133
+
131
134
  # User DNN to final embedding
132
135
  self.user_dnn = MLP(
133
136
  input_dim=user_final_dim,
134
137
  dims=[rnn_hidden_size * 2, embedding_dim],
135
138
  output_layer=False,
136
139
  dropout=dnn_dropout,
137
- activation=dnn_activation
140
+ activation=dnn_activation,
138
141
  )
139
-
142
+
140
143
  # Item tower
141
144
  item_features = []
142
145
  if item_dense_features:
@@ -145,10 +148,10 @@ class SDM(BaseMatchModel):
145
148
  item_features.extend(item_sparse_features)
146
149
  if item_sequence_features:
147
150
  item_features.extend(item_sequence_features)
148
-
151
+
149
152
  if len(item_features) > 0:
150
153
  self.item_embedding = EmbeddingLayer(item_features)
151
-
154
+
152
155
  item_input_dim = 0
153
156
  for feat in item_dense_features or []:
154
157
  item_input_dim += 1
@@ -156,7 +159,7 @@ class SDM(BaseMatchModel):
156
159
  item_input_dim += feat.embedding_dim
157
160
  for feat in item_sequence_features or []:
158
161
  item_input_dim += feat.embedding_dim
159
-
162
+
160
163
  # Item DNN
161
164
  if len(item_dnn_hidden_units) > 0:
162
165
  item_dnn_units = item_dnn_hidden_units + [embedding_dim]
@@ -165,53 +168,58 @@ class SDM(BaseMatchModel):
165
168
  dims=item_dnn_units,
166
169
  output_layer=False,
167
170
  dropout=dnn_dropout,
168
- activation=dnn_activation
171
+ activation=dnn_activation,
169
172
  )
170
173
  else:
171
174
  self.item_dnn = None
172
-
175
+
173
176
  self._register_regularization_weights(
174
- embedding_attr='user_embedding',
175
- include_modules=['rnn', 'user_dnn']
177
+ embedding_attr="user_embedding", include_modules=["rnn", "user_dnn"]
176
178
  )
177
179
  self._register_regularization_weights(
178
- embedding_attr='item_embedding',
179
- include_modules=['item_dnn'] if self.item_dnn else []
180
+ embedding_attr="item_embedding",
181
+ include_modules=["item_dnn"] if self.item_dnn else [],
180
182
  )
181
-
183
+
182
184
  self.to(device)
183
-
185
+
184
186
  def user_tower(self, user_input: dict) -> torch.Tensor:
185
187
  seq_feature = self.user_sequence_features[0]
186
188
  seq_input = user_input[seq_feature.name]
187
-
189
+
188
190
  embed = self.user_embedding.embed_dict[seq_feature.embedding_name]
189
191
  seq_emb = embed(seq_input.long()) # [batch_size, seq_len, seq_emb_dim]
190
-
191
- if self.rnn_type == 'GRU':
192
- rnn_output, hidden = self.rnn(seq_emb) # hidden: [num_layers, batch, hidden_size]
193
- elif self.rnn_type == 'LSTM':
192
+
193
+ if self.rnn_type == "GRU":
194
+ rnn_output, hidden = self.rnn(
195
+ seq_emb
196
+ ) # hidden: [num_layers, batch, hidden_size]
197
+ elif self.rnn_type == "LSTM":
194
198
  rnn_output, (hidden, cell) = self.rnn(seq_emb)
195
-
199
+
196
200
  features_list = []
197
-
201
+
198
202
  if self.use_long_term:
199
203
  if self.rnn.num_layers > 1:
200
204
  long_term = hidden[-1, :, :] # [batch_size, hidden_size]
201
205
  else:
202
206
  long_term = hidden.squeeze(0) # [batch_size, hidden_size]
203
207
  features_list.append(long_term)
204
-
208
+
205
209
  if self.use_short_term:
206
- mask = (seq_input != seq_feature.padding_idx).float() # [batch_size, seq_len]
210
+ mask = (
211
+ seq_input != seq_feature.padding_idx
212
+ ).float() # [batch_size, seq_len]
207
213
  seq_lengths = mask.sum(dim=1).long() - 1 # [batch_size]
208
214
  seq_lengths = torch.clamp(seq_lengths, min=0)
209
-
215
+
210
216
  batch_size = seq_emb.size(0)
211
217
  batch_indices = torch.arange(batch_size, device=seq_emb.device)
212
- short_term = seq_emb[batch_indices, seq_lengths, :] # [batch_size, seq_emb_dim]
218
+ short_term = seq_emb[
219
+ batch_indices, seq_lengths, :
220
+ ] # [batch_size, seq_emb_dim]
213
221
  features_list.append(short_term)
214
-
222
+
215
223
  if self.user_dense_features:
216
224
  dense_features = []
217
225
  for feat in self.user_dense_features:
@@ -222,7 +230,7 @@ class SDM(BaseMatchModel):
222
230
  dense_features.append(val)
223
231
  if dense_features:
224
232
  features_list.append(torch.cat(dense_features, dim=1))
225
-
233
+
226
234
  if self.user_sparse_features:
227
235
  sparse_features = []
228
236
  for feat in self.user_sparse_features:
@@ -232,22 +240,26 @@ class SDM(BaseMatchModel):
232
240
  sparse_features.append(sparse_emb)
233
241
  if sparse_features:
234
242
  features_list.append(torch.cat(sparse_features, dim=1))
235
-
243
+
236
244
  user_features = torch.cat(features_list, dim=1)
237
245
  user_emb = self.user_dnn(user_features)
238
246
  user_emb = F.normalize(user_emb, p=2, dim=1)
239
-
247
+
240
248
  return user_emb
241
-
249
+
242
250
  def item_tower(self, item_input: dict) -> torch.Tensor:
243
251
  """Item tower"""
244
- all_item_features = self.item_dense_features + self.item_sparse_features + self.item_sequence_features
252
+ all_item_features = (
253
+ self.item_dense_features
254
+ + self.item_sparse_features
255
+ + self.item_sequence_features
256
+ )
245
257
  item_emb = self.item_embedding(item_input, all_item_features, squeeze_dim=True)
246
-
258
+
247
259
  if self.item_dnn is not None:
248
260
  item_emb = self.item_dnn(item_emb)
249
-
261
+
250
262
  # L2 normalization
251
263
  item_emb = F.normalize(item_emb, p=2, dim=1)
252
-
264
+
253
265
  return item_emb