nextrec 0.4.1-py3-none-any.whl → 0.4.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. nextrec/__init__.py +1 -1
  2. nextrec/__version__.py +1 -1
  3. nextrec/basic/activation.py +10 -5
  4. nextrec/basic/callback.py +1 -0
  5. nextrec/basic/features.py +30 -22
  6. nextrec/basic/layers.py +220 -106
  7. nextrec/basic/loggers.py +62 -43
  8. nextrec/basic/metrics.py +268 -119
  9. nextrec/basic/model.py +1082 -400
  10. nextrec/basic/session.py +10 -3
  11. nextrec/cli.py +498 -0
  12. nextrec/data/__init__.py +19 -25
  13. nextrec/data/batch_utils.py +11 -3
  14. nextrec/data/data_processing.py +51 -45
  15. nextrec/data/data_utils.py +26 -15
  16. nextrec/data/dataloader.py +272 -95
  17. nextrec/data/preprocessor.py +320 -199
  18. nextrec/loss/listwise.py +17 -9
  19. nextrec/loss/loss_utils.py +7 -8
  20. nextrec/loss/pairwise.py +2 -0
  21. nextrec/loss/pointwise.py +30 -12
  22. nextrec/models/generative/hstu.py +103 -38
  23. nextrec/models/match/dssm.py +82 -68
  24. nextrec/models/match/dssm_v2.py +72 -57
  25. nextrec/models/match/mind.py +175 -107
  26. nextrec/models/match/sdm.py +104 -87
  27. nextrec/models/match/youtube_dnn.py +73 -59
  28. nextrec/models/multi_task/esmm.py +53 -37
  29. nextrec/models/multi_task/mmoe.py +64 -45
  30. nextrec/models/multi_task/ple.py +101 -48
  31. nextrec/models/multi_task/poso.py +113 -36
  32. nextrec/models/multi_task/share_bottom.py +48 -35
  33. nextrec/models/ranking/afm.py +72 -37
  34. nextrec/models/ranking/autoint.py +72 -55
  35. nextrec/models/ranking/dcn.py +55 -35
  36. nextrec/models/ranking/dcn_v2.py +64 -23
  37. nextrec/models/ranking/deepfm.py +32 -22
  38. nextrec/models/ranking/dien.py +155 -99
  39. nextrec/models/ranking/din.py +85 -57
  40. nextrec/models/ranking/fibinet.py +52 -32
  41. nextrec/models/ranking/fm.py +29 -23
  42. nextrec/models/ranking/masknet.py +91 -29
  43. nextrec/models/ranking/pnn.py +31 -28
  44. nextrec/models/ranking/widedeep.py +34 -26
  45. nextrec/models/ranking/xdeepfm.py +60 -38
  46. nextrec/utils/__init__.py +59 -34
  47. nextrec/utils/config.py +490 -0
  48. nextrec/utils/device.py +30 -20
  49. nextrec/utils/distributed.py +36 -9
  50. nextrec/utils/embedding.py +1 -0
  51. nextrec/utils/feature.py +1 -0
  52. nextrec/utils/file.py +32 -11
  53. nextrec/utils/initializer.py +61 -16
  54. nextrec/utils/optimizer.py +25 -9
  55. nextrec/utils/synthetic_data.py +283 -165
  56. nextrec/utils/tensor.py +24 -13
  57. {nextrec-0.4.1.dist-info → nextrec-0.4.2.dist-info}/METADATA +4 -4
  58. nextrec-0.4.2.dist-info/RECORD +69 -0
  59. nextrec-0.4.2.dist-info/entry_points.txt +2 -0
  60. nextrec-0.4.1.dist-info/RECORD +0 -66
  61. {nextrec-0.4.1.dist-info → nextrec-0.4.2.dist-info}/WHEEL +0 -0
  62. {nextrec-0.4.1.dist-info → nextrec-0.4.2.dist-info}/licenses/LICENSE +0 -0
@@ -6,6 +6,7 @@ Reference:
 [1] Li C, Liu Z, Wu M, et al. Multi-interest network with dynamic routing for recommendation at Tmall[C]
 //Proceedings of the 28th ACM international conference on information and knowledge management. 2019: 2615-2623.
 """
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -15,6 +16,7 @@ from nextrec.basic.model import BaseMatchModel
 from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
 from nextrec.basic.layers import MLP, EmbeddingLayer
 
+
 class MultiInterestSA(nn.Module):
     """Multi-interest self-attention extractor from MIND (Li et al., 2019)."""
 
@@ -22,19 +24,25 @@ class MultiInterestSA(nn.Module):
         super(MultiInterestSA, self).__init__()
         self.embedding_dim = embedding_dim
         self.interest_num = interest_num
-        if hidden_dim == None:
+        if hidden_dim is None:
             self.hidden_dim = self.embedding_dim * 4
-        self.W1 = torch.nn.Parameter(torch.rand(self.embedding_dim, self.hidden_dim), requires_grad=True)
-        self.W2 = torch.nn.Parameter(torch.rand(self.hidden_dim, self.interest_num), requires_grad=True)
-        self.W3 = torch.nn.Parameter(torch.rand(self.embedding_dim, self.embedding_dim), requires_grad=True)
+        self.W1 = torch.nn.Parameter(
+            torch.rand(self.embedding_dim, self.hidden_dim), requires_grad=True
+        )
+        self.W2 = torch.nn.Parameter(
+            torch.rand(self.hidden_dim, self.interest_num), requires_grad=True
+        )
+        self.W3 = torch.nn.Parameter(
+            torch.rand(self.embedding_dim, self.embedding_dim), requires_grad=True
+        )
 
     def forward(self, seq_emb, mask=None):
-        H = torch.einsum('bse, ed -> bsd', seq_emb, self.W1).tanh()
-        if mask != None:
-            A = torch.einsum('bsd, dk -> bsk', H, self.W2) + -1.e9 * (1 - mask.float())
+        H = torch.einsum("bse, ed -> bsd", seq_emb, self.W1).tanh()
+        if mask is not None:
+            A = torch.einsum("bsd, dk -> bsk", H, self.W2) + -1.0e9 * (1 - mask.float())
             A = F.softmax(A, dim=1)
         else:
-            A = F.softmax(torch.einsum('bsd, dk -> bsk', H, self.W2), dim=1)
+            A = F.softmax(torch.einsum("bsd, dk -> bsk", H, self.W2), dim=1)
         A = A.permute(0, 2, 1)
         multi_interest_emb = torch.matmul(A, seq_emb)
         return multi_interest_emb
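For orientation: the forward pass above pools a padded behavior sequence into interest_num vectors by computing attention A = softmax(tanh(seq_emb · W1) · W2) over the sequence axis. A minimal standalone shape check (the sizes are hypothetical, and the bare W1/W2 tensors stand in for the module's parameters):

import torch
import torch.nn.functional as F

batch, seq_len, emb_dim, k = 2, 10, 16, 4        # hypothetical sizes
seq_emb = torch.rand(batch, seq_len, emb_dim)
W1 = torch.rand(emb_dim, emb_dim * 4)            # hidden_dim defaults to 4 * embedding_dim
W2 = torch.rand(emb_dim * 4, k)

H = torch.einsum("bse, ed -> bsd", seq_emb, W1).tanh()
A = F.softmax(torch.einsum("bsd, dk -> bsk", H, W2), dim=1)   # softmax over the sequence axis
multi_interest = torch.matmul(A.permute(0, 2, 1), seq_emb)
assert multi_interest.shape == (batch, k, emb_dim)            # [batch, interest_num, embedding_dim]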
@@ -43,7 +51,15 @@ class MultiInterestSA(nn.Module):
 class CapsuleNetwork(nn.Module):
     """Dynamic routing capsule network used in MIND (Li et al., 2019)."""
 
-    def __init__(self, embedding_dim, seq_len, bilinear_type=2, interest_num=4, routing_times=3, relu_layer=False):
+    def __init__(
+        self,
+        embedding_dim,
+        seq_len,
+        bilinear_type=2,
+        interest_num=4,
+        routing_times=3,
+        relu_layer=False,
+    ):
         super(CapsuleNetwork, self).__init__()
         self.embedding_dim = embedding_dim  # h
         self.seq_len = seq_len  # s
@@ -53,13 +69,24 @@ class CapsuleNetwork(nn.Module):
 
         self.relu_layer = relu_layer
         self.stop_grad = True
-        self.relu = nn.Sequential(nn.Linear(self.embedding_dim, self.embedding_dim, bias=False), nn.ReLU())
+        self.relu = nn.Sequential(
+            nn.Linear(self.embedding_dim, self.embedding_dim, bias=False), nn.ReLU()
+        )
         if self.bilinear_type == 0:  # MIND
             self.linear = nn.Linear(self.embedding_dim, self.embedding_dim, bias=False)
         elif self.bilinear_type == 1:
-            self.linear = nn.Linear(self.embedding_dim, self.embedding_dim * self.interest_num, bias=False)
+            self.linear = nn.Linear(
+                self.embedding_dim, self.embedding_dim * self.interest_num, bias=False
+            )
         else:
-            self.w = nn.Parameter(torch.Tensor(1, self.seq_len, self.interest_num * self.embedding_dim, self.embedding_dim))
+            self.w = nn.Parameter(
+                torch.Tensor(
+                    1,
+                    self.seq_len,
+                    self.interest_num * self.embedding_dim,
+                    self.embedding_dim,
+                )
+            )
             nn.init.xavier_uniform_(self.w)
 
     def forward(self, item_eb, mask):
@@ -70,11 +97,15 @@ class CapsuleNetwork(nn.Module):
             item_eb_hat = self.linear(item_eb)
         else:
             u = torch.unsqueeze(item_eb, dim=2)
-            item_eb_hat = torch.sum(self.w[:, :self.seq_len, :, :] * u, dim=3)
+            item_eb_hat = torch.sum(self.w[:, : self.seq_len, :, :] * u, dim=3)
 
-        item_eb_hat = torch.reshape(item_eb_hat, (-1, self.seq_len, self.interest_num, self.embedding_dim))
+        item_eb_hat = torch.reshape(
+            item_eb_hat, (-1, self.seq_len, self.interest_num, self.embedding_dim)
+        )
         item_eb_hat = torch.transpose(item_eb_hat, 1, 2).contiguous()
-        item_eb_hat = torch.reshape(item_eb_hat, (-1, self.interest_num, self.seq_len, self.embedding_dim))
+        item_eb_hat = torch.reshape(
+            item_eb_hat, (-1, self.interest_num, self.seq_len, self.embedding_dim)
+        )
 
         if self.stop_grad:
             item_eb_hat_iter = item_eb_hat.detach()
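The reshape/transpose chain above, together with the bilinear_type branches in __init__, maps each of the s sequence positions to interest_num candidate capsules: type 0 shares one h×h linear map (the original MIND formulation), type 1 uses a single h×(k·h) map, and type 2 learns a separate map per position via w. A shape sketch of the type-2 broadcast under hypothetical sizes:

import torch

batch, seq_len, emb_dim, k = 2, 10, 16, 4         # hypothetical sizes
item_eb = torch.rand(batch, seq_len, emb_dim)
w = torch.rand(1, seq_len, k * emb_dim, emb_dim)  # per-position bilinear weights

u = item_eb.unsqueeze(2)                          # [B, S, 1, h] broadcasts against w
item_eb_hat = torch.sum(w * u, dim=3)             # [B, S, k*h]: k candidate capsules per position
item_eb_hat = item_eb_hat.reshape(batch, seq_len, k, emb_dim).transpose(1, 2)
assert item_eb_hat.shape == (batch, k, seq_len, emb_dim)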
@@ -82,34 +113,47 @@ class CapsuleNetwork(nn.Module):
             item_eb_hat_iter = item_eb_hat
 
         if self.bilinear_type > 0:
-            capsule_weight = torch.zeros(item_eb_hat.shape[0],
-                                         self.interest_num,
-                                         self.seq_len,
-                                         device=item_eb.device,
-                                         requires_grad=False)
+            capsule_weight = torch.zeros(
+                item_eb_hat.shape[0],
+                self.interest_num,
+                self.seq_len,
+                device=item_eb.device,
+                requires_grad=False,
+            )
         else:
-            capsule_weight = torch.randn(item_eb_hat.shape[0],
-                                         self.interest_num,
-                                         self.seq_len,
-                                         device=item_eb.device,
-                                         requires_grad=False)
+            capsule_weight = torch.randn(
+                item_eb_hat.shape[0],
+                self.interest_num,
+                self.seq_len,
+                device=item_eb.device,
+                requires_grad=False,
+            )
 
         for i in range(self.routing_times):  # dynamic routing, propagated 3 times by default
             atten_mask = torch.unsqueeze(mask, 1).repeat(1, self.interest_num, 1)
             paddings = torch.zeros_like(atten_mask, dtype=torch.float)
 
             capsule_softmax_weight = F.softmax(capsule_weight, dim=-1)
-            capsule_softmax_weight = torch.where(torch.eq(atten_mask, 0), paddings, capsule_softmax_weight)
+            capsule_softmax_weight = torch.where(
+                torch.eq(atten_mask, 0), paddings, capsule_softmax_weight
+            )
             capsule_softmax_weight = torch.unsqueeze(capsule_softmax_weight, 2)
 
             if i < 2:
-                interest_capsule = torch.matmul(capsule_softmax_weight, item_eb_hat_iter)
+                interest_capsule = torch.matmul(
+                    capsule_softmax_weight, item_eb_hat_iter
+                )
                 cap_norm = torch.sum(torch.square(interest_capsule), -1, True)
                 scalar_factor = cap_norm / (1 + cap_norm) / torch.sqrt(cap_norm + 1e-9)
                 interest_capsule = scalar_factor * interest_capsule
 
-                delta_weight = torch.matmul(item_eb_hat_iter, torch.transpose(interest_capsule, 2, 3).contiguous())
-                delta_weight = torch.reshape(delta_weight, (-1, self.interest_num, self.seq_len))
+                delta_weight = torch.matmul(
+                    item_eb_hat_iter,
+                    torch.transpose(interest_capsule, 2, 3).contiguous(),
+                )
+                delta_weight = torch.reshape(
+                    delta_weight, (-1, self.interest_num, self.seq_len)
+                )
                 capsule_weight = capsule_weight + delta_weight
             else:
                 interest_capsule = torch.matmul(capsule_softmax_weight, item_eb_hat)
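One iteration of the loop above is standard dynamic routing: coupling coefficients come from a softmax over the (masked) routing logits, each interest capsule is the coupled sum of candidate capsules, and the logits are then raised by the agreement ⟨û, v⟩. A single-iteration sketch with hypothetical sizes (mask handling omitted):

import torch
import torch.nn.functional as F

batch, k, s, d = 2, 4, 10, 16                       # hypothetical sizes
item_eb_hat = torch.rand(batch, k, s, d)            # candidate capsules u-hat
capsule_weight = torch.zeros(batch, k, s)           # routing logits b

c = F.softmax(capsule_weight, dim=-1).unsqueeze(2)  # coupling coefficients, [B, k, 1, s]
v = torch.matmul(c, item_eb_hat)                    # coupled capsules, [B, k, 1, d]
delta = torch.matmul(item_eb_hat, v.transpose(2, 3)).reshape(batch, k, s)
capsule_weight = capsule_weight + delta             # agreement <u-hat, v> raises matching logits
assert capsule_weight.shape == (batch, k, s)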
@@ -117,7 +161,9 @@ class CapsuleNetwork(nn.Module):
                 scalar_factor = cap_norm / (1 + cap_norm) / torch.sqrt(cap_norm + 1e-9)
                 interest_capsule = scalar_factor * interest_capsule
 
-        interest_capsule = torch.reshape(interest_capsule, (-1, self.interest_num, self.embedding_dim))
+        interest_capsule = torch.reshape(
+            interest_capsule, (-1, self.interest_num, self.embedding_dim)
+        )
 
         if self.relu_layer:
             interest_capsule = self.relu(interest_capsule)
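The scalar_factor expression in both routing branches is the capsule "squash" nonlinearity, squash(s) = (‖s‖² / (1 + ‖s‖²)) · (s / ‖s‖): it preserves each capsule's direction while bounding its norm below 1, and the 1e-9 term guards the division at ‖s‖ = 0. A quick numerical check:

import torch

s = torch.rand(2, 4, 1, 16)                                  # interest capsules, [B, k, 1, h]
cap_norm = torch.sum(torch.square(s), -1, True)              # ||s||^2, kept broadcastable
scalar_factor = cap_norm / (1 + cap_norm) / torch.sqrt(cap_norm + 1e-9)
squashed = scalar_factor * s

assert torch.all(squashed.norm(dim=-1) < 1.0)                # norms squashed into (0, 1)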
@@ -129,45 +175,52 @@ class MIND(BaseMatchModel):
     @property
     def model_name(self) -> str:
         return "MIND"
-
+
     @property
     def support_training_modes(self) -> list[str]:
         """MIND only supports pointwise training mode"""
-        return ['pointwise']
-
-    def __init__(self,
-                 user_dense_features: list[DenseFeature] | None = None,
-                 user_sparse_features: list[SparseFeature] | None = None,
-                 user_sequence_features: list[SequenceFeature] | None = None,
-                 item_dense_features: list[DenseFeature] | None = None,
-                 item_sparse_features: list[SparseFeature] | None = None,
-                 item_sequence_features: list[SequenceFeature] | None = None,
-                 embedding_dim: int = 64,
-                 num_interests: int = 4,
-                 capsule_bilinear_type: int = 2,
-                 routing_times: int = 3,
-                 relu_layer: bool = False,
-                 item_dnn_hidden_units: list[int] = [256, 128],
-                 dnn_activation: str = 'relu',
-                 dnn_dropout: float = 0.0,
-                 training_mode: Literal['pointwise', 'pairwise', 'listwise'] = 'pointwise',
-                 num_negative_samples: int = 100,
-                 temperature: float = 1.0,
-                 similarity_metric: Literal['dot', 'cosine', 'euclidean'] = 'dot',
-                 device: str = 'cpu',
-                 embedding_l1_reg: float = 0.0,
-                 dense_l1_reg: float = 0.0,
-                 embedding_l2_reg: float = 0.0,
-                 dense_l2_reg: float = 0.0,
-                 early_stop_patience: int = 20,
-                 optimizer: str | torch.optim.Optimizer = "adam",
-                 optimizer_params: dict | None = None,
-                 scheduler: str | torch.optim.lr_scheduler._LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | None = None,
-                 scheduler_params: dict | None = None,
-                 loss: str | nn.Module | list[str | nn.Module] | None = "bce",
-                 loss_params: dict | list[dict] | None = None,
-                 **kwargs):
-
+        return ["pointwise"]
+
+    def __init__(
+        self,
+        user_dense_features: list[DenseFeature] | None = None,
+        user_sparse_features: list[SparseFeature] | None = None,
+        user_sequence_features: list[SequenceFeature] | None = None,
+        item_dense_features: list[DenseFeature] | None = None,
+        item_sparse_features: list[SparseFeature] | None = None,
+        item_sequence_features: list[SequenceFeature] | None = None,
+        embedding_dim: int = 64,
+        num_interests: int = 4,
+        capsule_bilinear_type: int = 2,
+        routing_times: int = 3,
+        relu_layer: bool = False,
+        item_dnn_hidden_units: list[int] = [256, 128],
+        dnn_activation: str = "relu",
+        dnn_dropout: float = 0.0,
+        training_mode: Literal["pointwise", "pairwise", "listwise"] = "pointwise",
+        num_negative_samples: int = 100,
+        temperature: float = 1.0,
+        similarity_metric: Literal["dot", "cosine", "euclidean"] = "dot",
+        device: str = "cpu",
+        embedding_l1_reg: float = 0.0,
+        dense_l1_reg: float = 0.0,
+        embedding_l2_reg: float = 0.0,
+        dense_l2_reg: float = 0.0,
+        early_stop_patience: int = 20,
+        optimizer: str | torch.optim.Optimizer = "adam",
+        optimizer_params: dict | None = None,
+        scheduler: (
+            str
+            | torch.optim.lr_scheduler._LRScheduler
+            | type[torch.optim.lr_scheduler._LRScheduler]
+            | None
+        ) = None,
+        scheduler_params: dict | None = None,
+        loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+        loss_params: dict | list[dict] | None = None,
+        **kwargs,
+    ):
+
         super(MIND, self).__init__(
             user_dense_features=user_dense_features,
             user_sparse_features=user_sparse_features,
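The reformatted signature above is the full public constructor. For context, a hedged construction sketch: the MIND keyword arguments are taken directly from the signature, but the SequenceFeature/SparseFeature constructor arguments below are hypothetical, since their definitions are not part of this diff.

from nextrec.basic.features import SparseFeature, SequenceFeature
from nextrec.models.match.mind import MIND

# Hypothetical feature specs: argument names here are illustrative only.
hist_items = SequenceFeature(name="hist_item_ids", vocab_size=10_000,
                             embedding_dim=64, max_len=50)
item_id = SparseFeature(name="item_id", vocab_size=10_000, embedding_dim=64)

model = MIND(
    user_sequence_features=[hist_items],   # MIND requires at least one user sequence feature
    item_sparse_features=[item_id],
    embedding_dim=64,
    num_interests=4,
    routing_times=3,
    similarity_metric="dot",
    device="cpu",
)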
@@ -184,9 +237,9 @@ class MIND(BaseMatchModel):
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
-            **kwargs
+            **kwargs,
         )
-
+
         self.embedding_dim = embedding_dim
         self.num_interests = num_interests
         self.item_dnn_hidden_units = item_dnn_hidden_units
@@ -198,16 +251,20 @@ class MIND(BaseMatchModel):
             user_features.extend(user_sparse_features)
         if user_sequence_features:
             user_features.extend(user_sequence_features)
-
+
         if len(user_features) > 0:
             self.user_embedding = EmbeddingLayer(user_features)
-
+
         if not user_sequence_features or len(user_sequence_features) == 0:
             raise ValueError("MIND requires at least one user sequence feature")
-
-        seq_max_len = user_sequence_features[0].max_len if user_sequence_features[0].max_len else 50
+
+        seq_max_len = (
+            user_sequence_features[0].max_len
+            if user_sequence_features[0].max_len
+            else 50
+        )
         seq_embedding_dim = user_sequence_features[0].embedding_dim
-
+
         # Capsule Network for multi-interest extraction
         self.capsule_network = CapsuleNetwork(
             embedding_dim=seq_embedding_dim,
@@ -215,15 +272,17 @@ class MIND(BaseMatchModel):
             bilinear_type=capsule_bilinear_type,
             interest_num=num_interests,
             routing_times=routing_times,
-            relu_layer=relu_layer
+            relu_layer=relu_layer,
         )
-
+
         if seq_embedding_dim != embedding_dim:
-            self.interest_projection = nn.Linear(seq_embedding_dim, embedding_dim, bias=False)
+            self.interest_projection = nn.Linear(
+                seq_embedding_dim, embedding_dim, bias=False
+            )
             nn.init.xavier_uniform_(self.interest_projection.weight)
         else:
             self.interest_projection = None
-
+
         # Item tower
         item_features = []
         if item_dense_features:
@@ -232,10 +291,10 @@ class MIND(BaseMatchModel):
             item_features.extend(item_sparse_features)
         if item_sequence_features:
             item_features.extend(item_sequence_features)
-
+
         if len(item_features) > 0:
             self.item_embedding = EmbeddingLayer(item_features)
-
+
         item_input_dim = 0
         for feat in item_dense_features or []:
             item_input_dim += 1
@@ -243,7 +302,7 @@ class MIND(BaseMatchModel):
             item_input_dim += feat.embedding_dim
         for feat in item_sequence_features or []:
             item_input_dim += feat.embedding_dim
-
+
         # Item DNN
         if len(item_dnn_hidden_units) > 0:
             item_dnn_units = item_dnn_hidden_units + [embedding_dim]
@@ -252,20 +311,19 @@ class MIND(BaseMatchModel):
                 dims=item_dnn_units,
                 output_layer=False,
                 dropout=dnn_dropout,
-                activation=dnn_activation
+                activation=dnn_activation,
             )
         else:
             self.item_dnn = None
-
+
         self.register_regularization_weights(
-            embedding_attr='user_embedding',
-            include_modules=['capsule_network']
+            embedding_attr="user_embedding", include_modules=["capsule_network"]
         )
         self.register_regularization_weights(
-            embedding_attr='item_embedding',
-            include_modules=['item_dnn'] if self.item_dnn else []
+            embedding_attr="item_embedding",
+            include_modules=["item_dnn"] if self.item_dnn else [],
         )
-
+
         self.compile(
             optimizer=optimizer,
             optimizer_params=optimizer_params,
@@ -276,11 +334,11 @@ class MIND(BaseMatchModel):
         )
 
         self.to(device)
-
+
     def user_tower(self, user_input: dict) -> torch.Tensor:
         """
         User tower with multi-interest extraction
-
+
         Returns:
             user_interests: [batch_size, num_interests, embedding_dim]
         """
@@ -291,43 +349,53 @@ class MIND(BaseMatchModel):
         seq_emb = embed(seq_input.long())  # [batch_size, seq_len, embedding_dim]
 
         mask = (seq_input != seq_feature.padding_idx).float()  # [batch_size, seq_len]
-
-        multi_interests = self.capsule_network(seq_emb, mask)  # [batch_size, num_interests, seq_embedding_dim]
-
+
+        multi_interests = self.capsule_network(
+            seq_emb, mask
+        )  # [batch_size, num_interests, seq_embedding_dim]
+
         if self.interest_projection is not None:
-            multi_interests = self.interest_projection(multi_interests)  # [batch_size, num_interests, embedding_dim]
-
+            multi_interests = self.interest_projection(
+                multi_interests
+            )  # [batch_size, num_interests, embedding_dim]
+
         # L2 normalization
         multi_interests = F.normalize(multi_interests, p=2, dim=-1)
-
+
         return multi_interests
-
+
     def item_tower(self, item_input: dict) -> torch.Tensor:
         """Item tower"""
-        all_item_features = self.item_dense_features + self.item_sparse_features + self.item_sequence_features
+        all_item_features = (
+            self.item_dense_features
+            + self.item_sparse_features
+            + self.item_sequence_features
+        )
         item_emb = self.item_embedding(item_input, all_item_features, squeeze_dim=True)
-
+
         if self.item_dnn is not None:
             item_emb = self.item_dnn(item_emb)
-
+
         # L2 normalization
         item_emb = F.normalize(item_emb, p=2, dim=1)
-
+
         return item_emb
-
-    def compute_similarity(self, user_emb: torch.Tensor, item_emb: torch.Tensor) -> torch.Tensor:
+
+    def compute_similarity(
+        self, user_emb: torch.Tensor, item_emb: torch.Tensor
+    ) -> torch.Tensor:
         item_emb_expanded = item_emb.unsqueeze(1)
-
-        if self.similarity_metric == 'dot':
+
+        if self.similarity_metric == "dot":
             similarities = torch.sum(user_emb * item_emb_expanded, dim=-1)
-        elif self.similarity_metric == 'cosine':
+        elif self.similarity_metric == "cosine":
             similarities = F.cosine_similarity(user_emb, item_emb_expanded, dim=-1)
-        elif self.similarity_metric == 'euclidean':
+        elif self.similarity_metric == "euclidean":
             similarities = -torch.sum((user_emb - item_emb_expanded) ** 2, dim=-1)
         else:
             raise ValueError(f"Unknown similarity metric: {self.similarity_metric}")
 
         max_similarity, _ = torch.max(similarities, dim=1)  # [batch_size]
         max_similarity = max_similarity / self.temperature
-
+
         return max_similarity
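compute_similarity is MIND's label-aware attention in its hardest form: the candidate item is scored against each of the k user interests, and only the best-matching interest contributes to the logit, scaled by temperature. A dot-metric shape sketch with hypothetical sizes:

import torch
import torch.nn.functional as F

batch, k, dim = 2, 4, 64                                          # hypothetical sizes
user_interests = F.normalize(torch.rand(batch, k, dim), dim=-1)   # [B, k, D], as user_tower returns
item_emb = F.normalize(torch.rand(batch, dim), dim=-1)            # [B, D], as item_tower returns

sims = torch.sum(user_interests * item_emb.unsqueeze(1), dim=-1)  # [B, k], one score per interest
max_sim, _ = torch.max(sims, dim=1)                               # keep only the best interest
assert max_sim.shape == (batch,)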