nextrec 0.3.6__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. nextrec/__init__.py +1 -1
  2. nextrec/__version__.py +1 -1
  3. nextrec/basic/activation.py +10 -5
  4. nextrec/basic/callback.py +1 -0
  5. nextrec/basic/features.py +30 -22
  6. nextrec/basic/layers.py +244 -113
  7. nextrec/basic/loggers.py +62 -43
  8. nextrec/basic/metrics.py +268 -119
  9. nextrec/basic/model.py +1373 -443
  10. nextrec/basic/session.py +10 -3
  11. nextrec/cli.py +498 -0
  12. nextrec/data/__init__.py +19 -25
  13. nextrec/data/batch_utils.py +11 -3
  14. nextrec/data/data_processing.py +42 -24
  15. nextrec/data/data_utils.py +26 -15
  16. nextrec/data/dataloader.py +303 -96
  17. nextrec/data/preprocessor.py +320 -199
  18. nextrec/loss/listwise.py +17 -9
  19. nextrec/loss/loss_utils.py +7 -8
  20. nextrec/loss/pairwise.py +2 -0
  21. nextrec/loss/pointwise.py +30 -12
  22. nextrec/models/generative/hstu.py +106 -40
  23. nextrec/models/match/dssm.py +82 -69
  24. nextrec/models/match/dssm_v2.py +72 -58
  25. nextrec/models/match/mind.py +175 -108
  26. nextrec/models/match/sdm.py +104 -88
  27. nextrec/models/match/youtube_dnn.py +73 -60
  28. nextrec/models/multi_task/esmm.py +53 -39
  29. nextrec/models/multi_task/mmoe.py +70 -47
  30. nextrec/models/multi_task/ple.py +107 -50
  31. nextrec/models/multi_task/poso.py +121 -41
  32. nextrec/models/multi_task/share_bottom.py +54 -38
  33. nextrec/models/ranking/afm.py +172 -45
  34. nextrec/models/ranking/autoint.py +84 -61
  35. nextrec/models/ranking/dcn.py +59 -42
  36. nextrec/models/ranking/dcn_v2.py +64 -23
  37. nextrec/models/ranking/deepfm.py +36 -26
  38. nextrec/models/ranking/dien.py +158 -102
  39. nextrec/models/ranking/din.py +88 -60
  40. nextrec/models/ranking/fibinet.py +55 -35
  41. nextrec/models/ranking/fm.py +32 -26
  42. nextrec/models/ranking/masknet.py +95 -34
  43. nextrec/models/ranking/pnn.py +34 -31
  44. nextrec/models/ranking/widedeep.py +37 -29
  45. nextrec/models/ranking/xdeepfm.py +63 -41
  46. nextrec/utils/__init__.py +61 -32
  47. nextrec/utils/config.py +490 -0
  48. nextrec/utils/device.py +52 -12
  49. nextrec/utils/distributed.py +141 -0
  50. nextrec/utils/embedding.py +1 -0
  51. nextrec/utils/feature.py +1 -0
  52. nextrec/utils/file.py +32 -11
  53. nextrec/utils/initializer.py +61 -16
  54. nextrec/utils/optimizer.py +25 -9
  55. nextrec/utils/synthetic_data.py +531 -0
  56. nextrec/utils/tensor.py +24 -13
  57. {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/METADATA +15 -5
  58. nextrec-0.4.2.dist-info/RECORD +69 -0
  59. nextrec-0.4.2.dist-info/entry_points.txt +2 -0
  60. nextrec-0.3.6.dist-info/RECORD +0 -64
  61. {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/WHEEL +0 -0
  62. {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/licenses/LICENSE +0 -0
nextrec/models/match/mind.py

@@ -6,6 +6,7 @@ Reference:
 [1] Li C, Liu Z, Wu M, et al. Multi-interest network with dynamic routing for recommendation at Tmall[C]
 //Proceedings of the 28th ACM international conference on information and knowledge management. 2019: 2615-2623.
 """
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -15,6 +16,7 @@ from nextrec.basic.model import BaseMatchModel
 from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
 from nextrec.basic.layers import MLP, EmbeddingLayer
 
+
 class MultiInterestSA(nn.Module):
     """Multi-interest self-attention extractor from MIND (Li et al., 2019)."""
 
@@ -22,19 +24,25 @@ class MultiInterestSA(nn.Module):
         super(MultiInterestSA, self).__init__()
         self.embedding_dim = embedding_dim
         self.interest_num = interest_num
-        if hidden_dim == None:
+        if hidden_dim is None:
             self.hidden_dim = self.embedding_dim * 4
-        self.W1 = torch.nn.Parameter(torch.rand(self.embedding_dim, self.hidden_dim), requires_grad=True)
-        self.W2 = torch.nn.Parameter(torch.rand(self.hidden_dim, self.interest_num), requires_grad=True)
-        self.W3 = torch.nn.Parameter(torch.rand(self.embedding_dim, self.embedding_dim), requires_grad=True)
+        self.W1 = torch.nn.Parameter(
+            torch.rand(self.embedding_dim, self.hidden_dim), requires_grad=True
+        )
+        self.W2 = torch.nn.Parameter(
+            torch.rand(self.hidden_dim, self.interest_num), requires_grad=True
+        )
+        self.W3 = torch.nn.Parameter(
+            torch.rand(self.embedding_dim, self.embedding_dim), requires_grad=True
+        )
 
     def forward(self, seq_emb, mask=None):
-        H = torch.einsum('bse, ed -> bsd', seq_emb, self.W1).tanh()
-        if mask != None:
-            A = torch.einsum('bsd, dk -> bsk', H, self.W2) + -1.e9 * (1 - mask.float())
+        H = torch.einsum("bse, ed -> bsd", seq_emb, self.W1).tanh()
+        if mask is not None:
+            A = torch.einsum("bsd, dk -> bsk", H, self.W2) + -1.0e9 * (1 - mask.float())
             A = F.softmax(A, dim=1)
         else:
-            A = F.softmax(torch.einsum('bsd, dk -> bsk', H, self.W2), dim=1)
+            A = F.softmax(torch.einsum("bsd, dk -> bsk", H, self.W2), dim=1)
         A = A.permute(0, 2, 1)
         multi_interest_emb = torch.matmul(A, seq_emb)
         return multi_interest_emb
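
The hunk above is a pure readability cleanup (is None / is not None instead of == / !=, Black-style wrapping); behavior is unchanged. As a quick shape check, here is a minimal sketch assuming the constructor keywords embedding_dim, interest_num, and hidden_dim=None that are visible in the __init__ body (the def __init__ line itself is outside this hunk):

    import torch
    from nextrec.models.match.mind import MultiInterestSA

    extractor = MultiInterestSA(embedding_dim=64, interest_num=4)  # hidden_dim falls back to 4 * embedding_dim
    seq_emb = torch.randn(8, 50, 64)  # [batch_size, seq_len, embedding_dim]
    mask = torch.ones(8, 50, 1)       # broadcasts against the [batch, seq, interest] attention logits
    interests = extractor(seq_emb, mask)
    print(interests.shape)            # torch.Size([8, 4, 64]): one vector per interest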
@@ -43,7 +51,15 @@ class MultiInterestSA(nn.Module):
 class CapsuleNetwork(nn.Module):
     """Dynamic routing capsule network used in MIND (Li et al., 2019)."""
 
-    def __init__(self, embedding_dim, seq_len, bilinear_type=2, interest_num=4, routing_times=3, relu_layer=False):
+    def __init__(
+        self,
+        embedding_dim,
+        seq_len,
+        bilinear_type=2,
+        interest_num=4,
+        routing_times=3,
+        relu_layer=False,
+    ):
         super(CapsuleNetwork, self).__init__()
         self.embedding_dim = embedding_dim  # h
         self.seq_len = seq_len  # s
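
Only the formatting of the signature changes here; parameters and defaults are identical. A minimal usage sketch with illustrative shapes (input conventions taken from the forward pass shown below):

    import torch
    from nextrec.models.match.mind import CapsuleNetwork

    capsule = CapsuleNetwork(
        embedding_dim=64, seq_len=50, bilinear_type=2, interest_num=4, routing_times=3
    )
    item_eb = torch.randn(8, 50, 64)          # [batch_size, seq_len, embedding_dim]
    mask = (torch.rand(8, 50) > 0.2).float()  # 1 = real position, 0 = padding
    interests = capsule(item_eb, mask)        # [batch_size, interest_num, embedding_dim]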
@@ -53,13 +69,24 @@ class CapsuleNetwork(nn.Module):
 
         self.relu_layer = relu_layer
         self.stop_grad = True
-        self.relu = nn.Sequential(nn.Linear(self.embedding_dim, self.embedding_dim, bias=False), nn.ReLU())
+        self.relu = nn.Sequential(
+            nn.Linear(self.embedding_dim, self.embedding_dim, bias=False), nn.ReLU()
+        )
         if self.bilinear_type == 0:  # MIND
             self.linear = nn.Linear(self.embedding_dim, self.embedding_dim, bias=False)
         elif self.bilinear_type == 1:
-            self.linear = nn.Linear(self.embedding_dim, self.embedding_dim * self.interest_num, bias=False)
+            self.linear = nn.Linear(
+                self.embedding_dim, self.embedding_dim * self.interest_num, bias=False
+            )
         else:
-            self.w = nn.Parameter(torch.Tensor(1, self.seq_len, self.interest_num * self.embedding_dim, self.embedding_dim))
+            self.w = nn.Parameter(
+                torch.Tensor(
+                    1,
+                    self.seq_len,
+                    self.interest_num * self.embedding_dim,
+                    self.embedding_dim,
+                )
+            )
             nn.init.xavier_uniform_(self.w)
 
     def forward(self, item_eb, mask):
@@ -70,11 +97,15 @@ class CapsuleNetwork(nn.Module):
             item_eb_hat = self.linear(item_eb)
         else:
             u = torch.unsqueeze(item_eb, dim=2)
-            item_eb_hat = torch.sum(self.w[:, :self.seq_len, :, :] * u, dim=3)
+            item_eb_hat = torch.sum(self.w[:, : self.seq_len, :, :] * u, dim=3)
 
-        item_eb_hat = torch.reshape(item_eb_hat, (-1, self.seq_len, self.interest_num, self.embedding_dim))
+        item_eb_hat = torch.reshape(
+            item_eb_hat, (-1, self.seq_len, self.interest_num, self.embedding_dim)
+        )
         item_eb_hat = torch.transpose(item_eb_hat, 1, 2).contiguous()
-        item_eb_hat = torch.reshape(item_eb_hat, (-1, self.interest_num, self.seq_len, self.embedding_dim))
+        item_eb_hat = torch.reshape(
+            item_eb_hat, (-1, self.interest_num, self.seq_len, self.embedding_dim)
+        )
 
         if self.stop_grad:
             item_eb_hat_iter = item_eb_hat.detach()
@@ -82,34 +113,47 @@
             item_eb_hat_iter = item_eb_hat
 
         if self.bilinear_type > 0:
-            capsule_weight = torch.zeros(item_eb_hat.shape[0],
-                                         self.interest_num,
-                                         self.seq_len,
-                                         device=item_eb.device,
-                                         requires_grad=False)
+            capsule_weight = torch.zeros(
+                item_eb_hat.shape[0],
+                self.interest_num,
+                self.seq_len,
+                device=item_eb.device,
+                requires_grad=False,
+            )
         else:
-            capsule_weight = torch.randn(item_eb_hat.shape[0],
-                                         self.interest_num,
-                                         self.seq_len,
-                                         device=item_eb.device,
-                                         requires_grad=False)
+            capsule_weight = torch.randn(
+                item_eb_hat.shape[0],
+                self.interest_num,
+                self.seq_len,
+                device=item_eb.device,
+                requires_grad=False,
+            )
 
         for i in range(self.routing_times):  # dynamic routing, 3 iterations
             atten_mask = torch.unsqueeze(mask, 1).repeat(1, self.interest_num, 1)
             paddings = torch.zeros_like(atten_mask, dtype=torch.float)
 
             capsule_softmax_weight = F.softmax(capsule_weight, dim=-1)
-            capsule_softmax_weight = torch.where(torch.eq(atten_mask, 0), paddings, capsule_softmax_weight)
+            capsule_softmax_weight = torch.where(
+                torch.eq(atten_mask, 0), paddings, capsule_softmax_weight
+            )
             capsule_softmax_weight = torch.unsqueeze(capsule_softmax_weight, 2)
 
             if i < 2:
-                interest_capsule = torch.matmul(capsule_softmax_weight, item_eb_hat_iter)
+                interest_capsule = torch.matmul(
+                    capsule_softmax_weight, item_eb_hat_iter
+                )
                 cap_norm = torch.sum(torch.square(interest_capsule), -1, True)
                 scalar_factor = cap_norm / (1 + cap_norm) / torch.sqrt(cap_norm + 1e-9)
                 interest_capsule = scalar_factor * interest_capsule
 
-                delta_weight = torch.matmul(item_eb_hat_iter, torch.transpose(interest_capsule, 2, 3).contiguous())
-                delta_weight = torch.reshape(delta_weight, (-1, self.interest_num, self.seq_len))
+                delta_weight = torch.matmul(
+                    item_eb_hat_iter,
+                    torch.transpose(interest_capsule, 2, 3).contiguous(),
+                )
+                delta_weight = torch.reshape(
+                    delta_weight, (-1, self.interest_num, self.seq_len)
+                )
                 capsule_weight = capsule_weight + delta_weight
             else:
                 interest_capsule = torch.matmul(capsule_softmax_weight, item_eb_hat)
@@ -117,7 +161,9 @@
                 scalar_factor = cap_norm / (1 + cap_norm) / torch.sqrt(cap_norm + 1e-9)
                 interest_capsule = scalar_factor * interest_capsule
 
-        interest_capsule = torch.reshape(interest_capsule, (-1, self.interest_num, self.embedding_dim))
+        interest_capsule = torch.reshape(
+            interest_capsule, (-1, self.interest_num, self.embedding_dim)
+        )
 
         if self.relu_layer:
             interest_capsule = self.relu(interest_capsule)
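
The scalar_factor in these routing steps is the capsule "squash" nonlinearity from Sabour et al. (2017), squash(s) = (||s||^2 / (1 + ||s||^2)) * s / ||s||: it preserves a capsule's direction while squeezing its norm below 1. A standalone sketch of the same computation, for reference only:

    import torch

    def squash(s: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
        # s: [..., dim]; squared L2 norm over the last axis, kept for broadcasting
        sq_norm = torch.sum(torch.square(s), dim=-1, keepdim=True)
        scale = sq_norm / (1.0 + sq_norm) / torch.sqrt(sq_norm + eps)
        return scale * s  # direction preserved, norm squashed into [0, 1)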
@@ -129,45 +175,52 @@ class MIND(BaseMatchModel):
     @property
     def model_name(self) -> str:
         return "MIND"
-
+
     @property
     def support_training_modes(self) -> list[str]:
         """MIND only supports pointwise training mode"""
-        return ['pointwise']
-
-    def __init__(self,
-                 user_dense_features: list[DenseFeature] | None = None,
-                 user_sparse_features: list[SparseFeature] | None = None,
-                 user_sequence_features: list[SequenceFeature] | None = None,
-                 item_dense_features: list[DenseFeature] | None = None,
-                 item_sparse_features: list[SparseFeature] | None = None,
-                 item_sequence_features: list[SequenceFeature] | None = None,
-                 embedding_dim: int = 64,
-                 num_interests: int = 4,
-                 capsule_bilinear_type: int = 2,
-                 routing_times: int = 3,
-                 relu_layer: bool = False,
-                 item_dnn_hidden_units: list[int] = [256, 128],
-                 dnn_activation: str = 'relu',
-                 dnn_dropout: float = 0.0,
-                 training_mode: Literal['pointwise', 'pairwise', 'listwise'] = 'pointwise',
-                 num_negative_samples: int = 100,
-                 temperature: float = 1.0,
-                 similarity_metric: Literal['dot', 'cosine', 'euclidean'] = 'dot',
-                 device: str = 'cpu',
-                 embedding_l1_reg: float = 0.0,
-                 dense_l1_reg: float = 0.0,
-                 embedding_l2_reg: float = 0.0,
-                 dense_l2_reg: float = 0.0,
-                 early_stop_patience: int = 20,
-                 optimizer: str | torch.optim.Optimizer = "adam",
-                 optimizer_params: dict | None = None,
-                 scheduler: str | torch.optim.lr_scheduler._LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | None = None,
-                 scheduler_params: dict | None = None,
-                 loss: str | nn.Module | list[str | nn.Module] | None = "bce",
-                 loss_params: dict | list[dict] | None = None,
-                 **kwargs):
-
+        return ["pointwise"]
+
+    def __init__(
+        self,
+        user_dense_features: list[DenseFeature] | None = None,
+        user_sparse_features: list[SparseFeature] | None = None,
+        user_sequence_features: list[SequenceFeature] | None = None,
+        item_dense_features: list[DenseFeature] | None = None,
+        item_sparse_features: list[SparseFeature] | None = None,
+        item_sequence_features: list[SequenceFeature] | None = None,
+        embedding_dim: int = 64,
+        num_interests: int = 4,
+        capsule_bilinear_type: int = 2,
+        routing_times: int = 3,
+        relu_layer: bool = False,
+        item_dnn_hidden_units: list[int] = [256, 128],
+        dnn_activation: str = "relu",
+        dnn_dropout: float = 0.0,
+        training_mode: Literal["pointwise", "pairwise", "listwise"] = "pointwise",
+        num_negative_samples: int = 100,
+        temperature: float = 1.0,
+        similarity_metric: Literal["dot", "cosine", "euclidean"] = "dot",
+        device: str = "cpu",
+        embedding_l1_reg: float = 0.0,
+        dense_l1_reg: float = 0.0,
+        embedding_l2_reg: float = 0.0,
+        dense_l2_reg: float = 0.0,
+        early_stop_patience: int = 20,
+        optimizer: str | torch.optim.Optimizer = "adam",
+        optimizer_params: dict | None = None,
+        scheduler: (
+            str
+            | torch.optim.lr_scheduler._LRScheduler
+            | type[torch.optim.lr_scheduler._LRScheduler]
+            | None
+        ) = None,
+        scheduler_params: dict | None = None,
+        loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+        loss_params: dict | list[dict] | None = None,
+        **kwargs,
+    ):
+
         super(MIND, self).__init__(
             user_dense_features=user_dense_features,
             user_sparse_features=user_sparse_features,
@@ -184,10 +237,9 @@ class MIND(BaseMatchModel):
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
-            early_stop_patience=early_stop_patience,
-            **kwargs
+            **kwargs,
         )
-
+
         self.embedding_dim = embedding_dim
         self.num_interests = num_interests
         self.item_dnn_hidden_units = item_dnn_hidden_units
@@ -199,16 +251,20 @@ class MIND(BaseMatchModel):
             user_features.extend(user_sparse_features)
         if user_sequence_features:
             user_features.extend(user_sequence_features)
-
+
         if len(user_features) > 0:
             self.user_embedding = EmbeddingLayer(user_features)
-
+
         if not user_sequence_features or len(user_sequence_features) == 0:
             raise ValueError("MIND requires at least one user sequence feature")
-
-        seq_max_len = user_sequence_features[0].max_len if user_sequence_features[0].max_len else 50
+
+        seq_max_len = (
+            user_sequence_features[0].max_len
+            if user_sequence_features[0].max_len
+            else 50
+        )
         seq_embedding_dim = user_sequence_features[0].embedding_dim
-
+
         # Capsule Network for multi-interest extraction
         self.capsule_network = CapsuleNetwork(
             embedding_dim=seq_embedding_dim,
@@ -216,15 +272,17 @@ class MIND(BaseMatchModel):
             bilinear_type=capsule_bilinear_type,
             interest_num=num_interests,
             routing_times=routing_times,
-            relu_layer=relu_layer
+            relu_layer=relu_layer,
         )
-
+
         if seq_embedding_dim != embedding_dim:
-            self.interest_projection = nn.Linear(seq_embedding_dim, embedding_dim, bias=False)
+            self.interest_projection = nn.Linear(
+                seq_embedding_dim, embedding_dim, bias=False
+            )
             nn.init.xavier_uniform_(self.interest_projection.weight)
         else:
             self.interest_projection = None
-
+
         # Item tower
         item_features = []
         if item_dense_features:
@@ -233,10 +291,10 @@ class MIND(BaseMatchModel):
             item_features.extend(item_sparse_features)
         if item_sequence_features:
             item_features.extend(item_sequence_features)
-
+
         if len(item_features) > 0:
             self.item_embedding = EmbeddingLayer(item_features)
-
+
         item_input_dim = 0
         for feat in item_dense_features or []:
             item_input_dim += 1
@@ -244,7 +302,7 @@ class MIND(BaseMatchModel):
             item_input_dim += feat.embedding_dim
         for feat in item_sequence_features or []:
             item_input_dim += feat.embedding_dim
-
+
         # Item DNN
         if len(item_dnn_hidden_units) > 0:
             item_dnn_units = item_dnn_hidden_units + [embedding_dim]
@@ -253,20 +311,19 @@ class MIND(BaseMatchModel):
             dims=item_dnn_units,
             output_layer=False,
             dropout=dnn_dropout,
-            activation=dnn_activation
+            activation=dnn_activation,
         )
         else:
             self.item_dnn = None
-
+
         self.register_regularization_weights(
-            embedding_attr='user_embedding',
-            include_modules=['capsule_network']
+            embedding_attr="user_embedding", include_modules=["capsule_network"]
         )
         self.register_regularization_weights(
-            embedding_attr='item_embedding',
-            include_modules=['item_dnn'] if self.item_dnn else []
+            embedding_attr="item_embedding",
+            include_modules=["item_dnn"] if self.item_dnn else [],
         )
-
+
         self.compile(
             optimizer=optimizer,
             optimizer_params=optimizer_params,
@@ -277,11 +334,11 @@ class MIND(BaseMatchModel):
         )
 
         self.to(device)
-
+
     def user_tower(self, user_input: dict) -> torch.Tensor:
         """
         User tower with multi-interest extraction
-
+
         Returns:
             user_interests: [batch_size, num_interests, embedding_dim]
         """
@@ -292,43 +349,53 @@ class MIND(BaseMatchModel):
         seq_emb = embed(seq_input.long())  # [batch_size, seq_len, embedding_dim]
 
         mask = (seq_input != seq_feature.padding_idx).float()  # [batch_size, seq_len]
-
-        multi_interests = self.capsule_network(seq_emb, mask)  # [batch_size, num_interests, seq_embedding_dim]
-
+
+        multi_interests = self.capsule_network(
+            seq_emb, mask
+        )  # [batch_size, num_interests, seq_embedding_dim]
+
         if self.interest_projection is not None:
-            multi_interests = self.interest_projection(multi_interests)  # [batch_size, num_interests, embedding_dim]
-
+            multi_interests = self.interest_projection(
+                multi_interests
+            )  # [batch_size, num_interests, embedding_dim]
+
         # L2 normalization
         multi_interests = F.normalize(multi_interests, p=2, dim=-1)
-
+
         return multi_interests
-
+
     def item_tower(self, item_input: dict) -> torch.Tensor:
         """Item tower"""
-        all_item_features = self.item_dense_features + self.item_sparse_features + self.item_sequence_features
+        all_item_features = (
+            self.item_dense_features
+            + self.item_sparse_features
+            + self.item_sequence_features
+        )
         item_emb = self.item_embedding(item_input, all_item_features, squeeze_dim=True)
-
+
         if self.item_dnn is not None:
             item_emb = self.item_dnn(item_emb)
-
+
         # L2 normalization
         item_emb = F.normalize(item_emb, p=2, dim=1)
-
+
         return item_emb
-
-    def compute_similarity(self, user_emb: torch.Tensor, item_emb: torch.Tensor) -> torch.Tensor:
+
+    def compute_similarity(
+        self, user_emb: torch.Tensor, item_emb: torch.Tensor
+    ) -> torch.Tensor:
         item_emb_expanded = item_emb.unsqueeze(1)
-
-        if self.similarity_metric == 'dot':
+
+        if self.similarity_metric == "dot":
             similarities = torch.sum(user_emb * item_emb_expanded, dim=-1)
-        elif self.similarity_metric == 'cosine':
+        elif self.similarity_metric == "cosine":
             similarities = F.cosine_similarity(user_emb, item_emb_expanded, dim=-1)
-        elif self.similarity_metric == 'euclidean':
+        elif self.similarity_metric == "euclidean":
             similarities = -torch.sum((user_emb - item_emb_expanded) ** 2, dim=-1)
         else:
             raise ValueError(f"Unknown similarity metric: {self.similarity_metric}")
 
         max_similarity, _ = torch.max(similarities, dim=1)  # [batch_size]
         max_similarity = max_similarity / self.temperature
-
+
         return max_similarity
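
compute_similarity scores an item against each of the user's interest vectors and keeps only the best match (the max over the interest axis) before temperature scaling, which corresponds to the hard (argmax) limit of MIND's label-aware attention. A small worked sketch of the "dot" branch, with illustrative shapes:

    import torch

    user_emb = torch.randn(2, 4, 8)  # [batch_size, num_interests, embedding_dim]
    item_emb = torch.randn(2, 8)     # [batch_size, embedding_dim]

    # dot product of the item against each interest vector
    similarities = torch.sum(user_emb * item_emb.unsqueeze(1), dim=-1)  # [2, 4]
    max_similarity, _ = torch.max(similarities, dim=1)  # best-matching interest, [2]
    max_similarity = max_similarity / 1.0               # temperature = 1.0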