replay-rec 0.16.0rc0__py3-none-any.whl → 0.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. replay/__init__.py +1 -1
  2. replay/data/__init__.py +1 -1
  3. replay/data/dataset.py +45 -42
  4. replay/data/dataset_utils/dataset_label_encoder.py +6 -7
  5. replay/data/nn/__init__.py +1 -1
  6. replay/data/nn/schema.py +20 -33
  7. replay/data/nn/sequence_tokenizer.py +217 -87
  8. replay/data/nn/sequential_dataset.py +6 -22
  9. replay/data/nn/torch_sequential_dataset.py +20 -11
  10. replay/data/nn/utils.py +7 -9
  11. replay/data/schema.py +17 -17
  12. replay/data/spark_schema.py +0 -1
  13. replay/metrics/base_metric.py +38 -79
  14. replay/metrics/categorical_diversity.py +24 -58
  15. replay/metrics/coverage.py +25 -49
  16. replay/metrics/descriptors.py +4 -13
  17. replay/metrics/experiment.py +3 -8
  18. replay/metrics/hitrate.py +3 -6
  19. replay/metrics/map.py +3 -6
  20. replay/metrics/mrr.py +1 -4
  21. replay/metrics/ndcg.py +4 -7
  22. replay/metrics/novelty.py +10 -29
  23. replay/metrics/offline_metrics.py +26 -61
  24. replay/metrics/precision.py +3 -6
  25. replay/metrics/recall.py +3 -6
  26. replay/metrics/rocauc.py +7 -10
  27. replay/metrics/surprisal.py +13 -30
  28. replay/metrics/torch_metrics_builder.py +0 -4
  29. replay/metrics/unexpectedness.py +15 -20
  30. replay/models/__init__.py +1 -2
  31. replay/models/als.py +7 -15
  32. replay/models/association_rules.py +12 -28
  33. replay/models/base_neighbour_rec.py +21 -36
  34. replay/models/base_rec.py +92 -215
  35. replay/models/cat_pop_rec.py +9 -22
  36. replay/models/cluster.py +17 -28
  37. replay/models/extensions/ann/ann_mixin.py +7 -12
  38. replay/models/extensions/ann/entities/base_hnsw_param.py +1 -1
  39. replay/models/extensions/ann/entities/hnswlib_param.py +0 -6
  40. replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -6
  41. replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +4 -10
  42. replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +7 -11
  43. replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +5 -12
  44. replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +11 -18
  45. replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +1 -4
  46. replay/models/extensions/ann/index_inferers/base_inferer.py +3 -10
  47. replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +7 -17
  48. replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +6 -14
  49. replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +14 -28
  50. replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +15 -25
  51. replay/models/extensions/ann/index_inferers/utils.py +2 -9
  52. replay/models/extensions/ann/index_stores/hdfs_index_store.py +4 -9
  53. replay/models/extensions/ann/index_stores/shared_disk_index_store.py +2 -6
  54. replay/models/extensions/ann/index_stores/spark_files_index_store.py +8 -14
  55. replay/models/extensions/ann/index_stores/utils.py +5 -2
  56. replay/models/extensions/ann/utils.py +3 -5
  57. replay/models/kl_ucb.py +16 -22
  58. replay/models/knn.py +37 -59
  59. replay/models/nn/optimizer_utils/__init__.py +1 -6
  60. replay/models/nn/optimizer_utils/optimizer_factory.py +3 -6
  61. replay/models/nn/sequential/bert4rec/__init__.py +1 -1
  62. replay/models/nn/sequential/bert4rec/dataset.py +6 -7
  63. replay/models/nn/sequential/bert4rec/lightning.py +53 -56
  64. replay/models/nn/sequential/bert4rec/model.py +12 -25
  65. replay/models/nn/sequential/callbacks/__init__.py +1 -1
  66. replay/models/nn/sequential/callbacks/prediction_callbacks.py +23 -25
  67. replay/models/nn/sequential/callbacks/validation_callback.py +27 -30
  68. replay/models/nn/sequential/postprocessors/postprocessors.py +1 -1
  69. replay/models/nn/sequential/sasrec/dataset.py +8 -7
  70. replay/models/nn/sequential/sasrec/lightning.py +53 -48
  71. replay/models/nn/sequential/sasrec/model.py +4 -17
  72. replay/models/pop_rec.py +9 -10
  73. replay/models/query_pop_rec.py +7 -15
  74. replay/models/random_rec.py +10 -18
  75. replay/models/slim.py +8 -13
  76. replay/models/thompson_sampling.py +13 -14
  77. replay/models/ucb.py +11 -22
  78. replay/models/wilson.py +5 -14
  79. replay/models/word2vec.py +24 -69
  80. replay/optimization/optuna_objective.py +13 -27
  81. replay/preprocessing/__init__.py +1 -2
  82. replay/preprocessing/converter.py +2 -7
  83. replay/preprocessing/filters.py +67 -142
  84. replay/preprocessing/history_based_fp.py +44 -116
  85. replay/preprocessing/label_encoder.py +106 -68
  86. replay/preprocessing/sessionizer.py +1 -11
  87. replay/scenarios/fallback.py +3 -8
  88. replay/splitters/base_splitter.py +43 -15
  89. replay/splitters/cold_user_random_splitter.py +18 -31
  90. replay/splitters/k_folds.py +14 -24
  91. replay/splitters/last_n_splitter.py +33 -43
  92. replay/splitters/new_users_splitter.py +31 -55
  93. replay/splitters/random_splitter.py +16 -23
  94. replay/splitters/ratio_splitter.py +30 -54
  95. replay/splitters/time_splitter.py +13 -18
  96. replay/splitters/two_stage_splitter.py +44 -79
  97. replay/utils/__init__.py +1 -1
  98. replay/utils/common.py +65 -0
  99. replay/utils/dataframe_bucketizer.py +25 -31
  100. replay/utils/distributions.py +3 -15
  101. replay/utils/model_handler.py +36 -33
  102. replay/utils/session_handler.py +11 -15
  103. replay/utils/spark_utils.py +51 -85
  104. replay/utils/time.py +8 -22
  105. replay/utils/types.py +1 -3
  106. {replay_rec-0.16.0rc0.dist-info → replay_rec-0.17.0.dist-info}/METADATA +2 -10
  107. replay_rec-0.17.0.dist-info/RECORD +127 -0
  108. {replay_rec-0.16.0rc0.dist-info → replay_rec-0.17.0.dist-info}/WHEEL +1 -1
  109. replay/experimental/__init__.py +0 -0
  110. replay/experimental/metrics/__init__.py +0 -61
  111. replay/experimental/metrics/base_metric.py +0 -661
  112. replay/experimental/metrics/coverage.py +0 -117
  113. replay/experimental/metrics/experiment.py +0 -200
  114. replay/experimental/metrics/hitrate.py +0 -27
  115. replay/experimental/metrics/map.py +0 -31
  116. replay/experimental/metrics/mrr.py +0 -19
  117. replay/experimental/metrics/ncis_precision.py +0 -32
  118. replay/experimental/metrics/ndcg.py +0 -50
  119. replay/experimental/metrics/precision.py +0 -23
  120. replay/experimental/metrics/recall.py +0 -26
  121. replay/experimental/metrics/rocauc.py +0 -50
  122. replay/experimental/metrics/surprisal.py +0 -102
  123. replay/experimental/metrics/unexpectedness.py +0 -74
  124. replay/experimental/models/__init__.py +0 -10
  125. replay/experimental/models/admm_slim.py +0 -216
  126. replay/experimental/models/base_neighbour_rec.py +0 -222
  127. replay/experimental/models/base_rec.py +0 -1361
  128. replay/experimental/models/base_torch_rec.py +0 -247
  129. replay/experimental/models/cql.py +0 -468
  130. replay/experimental/models/ddpg.py +0 -1007
  131. replay/experimental/models/dt4rec/__init__.py +0 -0
  132. replay/experimental/models/dt4rec/dt4rec.py +0 -193
  133. replay/experimental/models/dt4rec/gpt1.py +0 -411
  134. replay/experimental/models/dt4rec/trainer.py +0 -128
  135. replay/experimental/models/dt4rec/utils.py +0 -274
  136. replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
  137. replay/experimental/models/extensions/spark_custom_models/als_extension.py +0 -733
  138. replay/experimental/models/implicit_wrap.py +0 -138
  139. replay/experimental/models/lightfm_wrap.py +0 -327
  140. replay/experimental/models/mult_vae.py +0 -374
  141. replay/experimental/models/neuromf.py +0 -462
  142. replay/experimental/models/scala_als.py +0 -311
  143. replay/experimental/nn/data/__init__.py +0 -1
  144. replay/experimental/nn/data/schema_builder.py +0 -58
  145. replay/experimental/preprocessing/__init__.py +0 -3
  146. replay/experimental/preprocessing/data_preparator.py +0 -929
  147. replay/experimental/preprocessing/padder.py +0 -231
  148. replay/experimental/preprocessing/sequence_generator.py +0 -218
  149. replay/experimental/scenarios/__init__.py +0 -1
  150. replay/experimental/scenarios/obp_wrapper/__init__.py +0 -8
  151. replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -86
  152. replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -271
  153. replay/experimental/scenarios/obp_wrapper/utils.py +0 -88
  154. replay/experimental/scenarios/two_stages/reranker.py +0 -116
  155. replay/experimental/scenarios/two_stages/two_stages_scenario.py +0 -843
  156. replay/experimental/utils/__init__.py +0 -0
  157. replay/experimental/utils/logger.py +0 -24
  158. replay/experimental/utils/model_handler.py +0 -213
  159. replay/experimental/utils/session_handler.py +0 -47
  160. replay_rec-0.16.0rc0.dist-info/NOTICE +0 -41
  161. replay_rec-0.16.0rc0.dist-info/RECORD +0 -178
  162. {replay_rec-0.16.0rc0.dist-info → replay_rec-0.17.0.dist-info}/LICENSE +0 -0
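
Note on entries 109–159 above: they are pure deletions, meaning the whole replay.experimental subpackage (experimental metrics, models such as NeuroMF, CQL and DDPG, the OBP wrapper, and the two-stage scenario) was dropped from the 0.17.0 wheel; the RECORD accordingly shrinks from 178 to 127 lines. A minimal sketch of a version-tolerant import guard follows. It assumes NeuroMF was re-exported from replay.experimental.models (entry 124 suggests so, but this diff does not show it directly):

    import importlib.util

    # replay.experimental ships in 0.16.0rc0 but is absent from the
    # 0.17.0 wheel, so probe for it before importing anything experimental.
    if importlib.util.find_spec("replay.experimental") is not None:
        from replay.experimental.models import NeuroMF  # 0.16.0rc0 only
    else:
        NeuroMF = None  # removed in 0.17.0; no drop-in replacement in replay.models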
replay/experimental/models/neuromf.py (deleted; entry 141 above)
@@ -1,462 +0,0 @@
- """
- Generalized Matrix Factorization (GMF),
- Multi-Layer Perceptron (MLP),
- Neural Matrix Factorization (MLP + GMF).
- """
- from typing import List, Optional
-
- import numpy as np
- import torch
- import torch.nn.functional as F
- from sklearn.model_selection import train_test_split
- from torch import LongTensor, Tensor, nn
- from torch.optim import Adam
- from torch.optim.lr_scheduler import ReduceLROnPlateau
- from torch.utils.data import DataLoader, TensorDataset
-
- from replay.experimental.models.base_torch_rec import TorchRecommender
- from replay.utils import PandasDataFrame, SparkDataFrame
-
- EMBED_DIM = 128
-
-
- def xavier_init_(layer: nn.Module):
-     """
-     Xavier initialization
-
-     :param layer: net layer
-     """
-     if isinstance(layer, (nn.Embedding, nn.Linear)):
-         nn.init.xavier_normal_(layer.weight.data)
-
-     if isinstance(layer, nn.Linear):
-         layer.bias.data.normal_(0.0, 0.001)
-
-
- # pylint: disable=too-few-public-methods
- class GMF(nn.Module):
-     """Generalized Matrix Factorization"""
-
-     def __init__(self, user_count: int, item_count: int, embedding_dim: int):
-         """
-         :param user_count: number of users
-         :param item_count: number of items
-         :param embedding_dim: embedding size
-         """
-         super().__init__()
-         self.user_embedding = nn.Embedding(
-             num_embeddings=user_count, embedding_dim=embedding_dim
-         )
-         self.item_embedding = nn.Embedding(
-             num_embeddings=item_count, embedding_dim=embedding_dim
-         )
-         self.item_biases = nn.Embedding(
-             num_embeddings=item_count, embedding_dim=1
-         )
-         self.user_biases = nn.Embedding(
-             num_embeddings=user_count, embedding_dim=1
-         )
-
-         xavier_init_(self.user_embedding)
-         xavier_init_(self.item_embedding)
-         self.user_biases.weight.data.zero_()
-         self.item_biases.weight.data.zero_()
-
-     # pylint: disable=arguments-differ
-     def forward(self, user: Tensor, item: Tensor) -> Tensor:  # type: ignore
-         """
-         :param user: user id batch
-         :param item: item id batch
-         :return: model output
-         """
-         user_emb = self.user_embedding(user) + self.user_biases(user)
-         item_emb = self.item_embedding(item) + self.item_biases(item)
-         element_product = torch.mul(user_emb, item_emb)
-
-         return element_product
-
-
- # pylint: disable=too-few-public-methods
- class MLP(nn.Module):
-     """Multi-Layer Perceptron"""
-
-     def __init__(
-         self,
-         user_count: int,
-         item_count: int,
-         embedding_dim: int,
-         hidden_dims: Optional[List[int]] = None,
-     ):
-         """
-         :param user_count: number of users
-         :param item_count: number of items
-         :param embedding_dim: embedding size
-         :param hidden_dims: list of hidden dimension sizes
-         """
-         super().__init__()
-         self.user_embedding = nn.Embedding(
-             num_embeddings=user_count, embedding_dim=embedding_dim
-         )
-         self.item_embedding = nn.Embedding(
-             num_embeddings=item_count, embedding_dim=embedding_dim
-         )
-         self.item_biases = nn.Embedding(
-             num_embeddings=item_count, embedding_dim=1
-         )
-         self.user_biases = nn.Embedding(
-             num_embeddings=user_count, embedding_dim=1
-         )
-
-         if hidden_dims:
-             full_hidden_dims = [2 * embedding_dim] + hidden_dims
-             self.hidden_layers = nn.ModuleList(
-                 [
-                     nn.Linear(d_in, d_out)
-                     for d_in, d_out in zip(
-                         full_hidden_dims[:-1], full_hidden_dims[1:]
-                     )
-                 ]
-             )
-
-         else:
-             self.hidden_layers = nn.ModuleList()
-
-         self.activation = nn.ReLU()
-
-         xavier_init_(self.user_embedding)
-         xavier_init_(self.item_embedding)
-         self.user_biases.weight.data.zero_()
-         self.item_biases.weight.data.zero_()
-         for layer in self.hidden_layers:
-             xavier_init_(layer)
-
-     # pylint: disable=arguments-differ
-     def forward(self, user: Tensor, item: Tensor) -> Tensor:  # type: ignore
-         """
-         :param user: user id batch
-         :param item: item id batch
-         :return: output
-         """
-         user_emb = self.user_embedding(user) + self.user_biases(user)
-         item_emb = self.item_embedding(item) + self.item_biases(item)
-         hidden = torch.cat([user_emb, item_emb], dim=-1)
-         for layer in self.hidden_layers:
-             hidden = layer(hidden)
-             hidden = self.activation(hidden)
-         return hidden
-
-
- # pylint: disable=too-few-public-methods
- class NMF(nn.Module):
-     """NMF = MLP + GMF"""
-
-     # pylint: disable=too-many-arguments
-     def __init__(
-         self,
-         user_count: int,
-         item_count: int,
-         embedding_gmf_dim: Optional[int] = None,
-         embedding_mlp_dim: Optional[int] = None,
-         hidden_mlp_dims: Optional[List[int]] = None,
-     ):
-         """
-         :param user_count: number of users
-         :param item_count: number of items
-         :param embedding_gmf_dim: embedding size for gmf
-         :param embedding_mlp_dim: embedding size for mlp
-         :param hidden_mlp_dims: list of hidden dimension sizes for mlp
-         """
-         self.gmf: Optional[GMF] = None
-         self.mlp: Optional[MLP] = None
-
-         super().__init__()
-         merged_dim = 0
-         if embedding_gmf_dim:
-             self.gmf = GMF(user_count, item_count, embedding_gmf_dim)
-             merged_dim += embedding_gmf_dim
-
-         if embedding_mlp_dim:
-             self.mlp = MLP(
-                 user_count, item_count, embedding_mlp_dim, hidden_mlp_dims
-             )
-             merged_dim += (
-                 hidden_mlp_dims[-1]
-                 if hidden_mlp_dims
-                 else 2 * embedding_mlp_dim
-             )
-
-         self.last_layer = nn.Linear(merged_dim, 1)
-         xavier_init_(self.last_layer)
-
-     # pylint: disable=arguments-differ
-     def forward(self, user: Tensor, item: Tensor) -> Tensor:  # type: ignore
-         """
-         :param user: user id batch
-         :param item: item id batch
-         :return: output
-         """
-         batch_size = len(user)
-         if self.gmf:
-             gmf_vector = self.gmf(user, item)
-         else:
-             gmf_vector = torch.zeros(batch_size, 0).to(user.device)
-
-         if self.mlp:
-             mlp_vector = self.mlp(user, item)
-         else:
-             mlp_vector = torch.zeros(batch_size, 0).to(user.device)
-
-         merged_vector = torch.cat([gmf_vector, mlp_vector], dim=1)
-         merged_vector = self.last_layer(merged_vector).squeeze(dim=1)
-         merged_vector = torch.sigmoid(merged_vector)
-
-         return merged_vector
-
-
- # pylint: disable=too-many-instance-attributes
- class NeuroMF(TorchRecommender):
-     """
-     Neural Matrix Factorization model (NeuMF, NCF).
-
-     In this implementation MLP and GMF modules are optional.
-     """
-
-     num_workers: int = 0
-     batch_size_users: int = 100000
-     patience: int = 3
-     n_saved: int = 2
-     valid_split_size: float = 0.1
-     seed: int = 42
-     _search_space = {
-         "embedding_gmf_dim": {"type": "int", "args": [EMBED_DIM, EMBED_DIM]},
-         "embedding_mlp_dim": {"type": "int", "args": [EMBED_DIM, EMBED_DIM]},
-         "learning_rate": {"type": "loguniform", "args": [0.0001, 0.5]},
-         "l2_reg": {"type": "loguniform", "args": [1e-9, 5]},
-         "count_negative_sample": {"type": "int", "args": [1, 20]},
-         "factor": {"type": "uniform", "args": [0.2, 0.2]},
-         "patience": {"type": "int", "args": [3, 3]},
-     }
-
-     # pylint: disable=too-many-arguments
-     def __init__(
-         self,
-         learning_rate: float = 0.05,
-         epochs: int = 20,
-         embedding_gmf_dim: Optional[int] = None,
-         embedding_mlp_dim: Optional[int] = None,
-         hidden_mlp_dims: Optional[List[int]] = None,
-         l2_reg: float = 0,
-         count_negative_sample: int = 1,
-         factor: float = 0.2,
-         patience: int = 3,
-     ):
-         """
-         MLP or GMF model can be ignored if
-         its embedding size (embedding_mlp_dim or embedding_gmf_dim) is set to ``None``.
-         Default variant is MLP + GMF with embedding size 128.
-
-         :param learning_rate: learning rate
-         :param epochs: number of epochs to train model
-         :param embedding_gmf_dim: embedding size for gmf
-         :param embedding_mlp_dim: embedding size for mlp
-         :param hidden_mlp_dims: list of hidden dimension sizes for mlp
-         :param l2_reg: l2 regularization term
-         :param count_negative_sample: number of negative samples to use
-         :param factor: ReduceLROnPlateau reducing factor. new_lr = lr * factor
-         :param patience: number of non-improved epochs before reducing lr
-         """
-         super().__init__()
-         if not embedding_gmf_dim and not embedding_mlp_dim:
-             embedding_gmf_dim, embedding_mlp_dim = EMBED_DIM, EMBED_DIM
-
-         if (embedding_gmf_dim is None or embedding_gmf_dim < 0) and (
-             embedding_mlp_dim is None or embedding_mlp_dim < 0
-         ):
-             raise ValueError(
-                 "embedding_gmf_dim and embedding_mlp_dim must be positive"
-             )
-
-         self.learning_rate = learning_rate
-         self.epochs = epochs
-         self.embedding_gmf_dim = embedding_gmf_dim
-         self.embedding_mlp_dim = embedding_mlp_dim
-         self.hidden_mlp_dims = hidden_mlp_dims
-         self.l2_reg = l2_reg
-         self.count_negative_sample = count_negative_sample
-         self.factor = factor
-         self.patience = patience
-
-     @property
-     def _init_args(self):
-         return {
-             "learning_rate": self.learning_rate,
-             "epochs": self.epochs,
-             "embedding_gmf_dim": self.embedding_gmf_dim,
-             "embedding_mlp_dim": self.embedding_mlp_dim,
-             "hidden_mlp_dims": self.hidden_mlp_dims,
-             "l2_reg": self.l2_reg,
-             "count_negative_sample": self.count_negative_sample,
-             "factor": self.factor,
-             "patience": self.patience,
-         }
-
-     def _data_loader(
-         self, data: PandasDataFrame, shuffle: bool = True
-     ) -> DataLoader:
-
-         user_batch = LongTensor(data["user_idx"].values)  # type: ignore
-         item_batch = LongTensor(data["item_idx"].values)  # type: ignore
-
-         dataset = TensorDataset(user_batch, item_batch)
-
-         loader = DataLoader(
-             dataset,
-             batch_size=self.batch_size_users,
-             shuffle=shuffle,
-             num_workers=self.num_workers,
-         )
-         return loader
-
-     def _get_neg_batch(self, batch: Tensor) -> Tensor:
-         return torch.from_numpy(
-             np.random.choice(
-                 self._fit_items_np, batch.shape[0] * self.count_negative_sample
-             )
-         )
-
-     def _fit(
-         self,
-         log: SparkDataFrame,
-         user_features: Optional[SparkDataFrame] = None,
-         item_features: Optional[SparkDataFrame] = None,
-     ) -> None:
-         self.logger.debug("Create DataLoaders")
-         tensor_data = log.select("user_idx", "item_idx").toPandas()
-         train_tensor_data, valid_tensor_data = train_test_split(
-             tensor_data,
-             test_size=self.valid_split_size,
-             random_state=self.seed,
-         )
-         train_data_loader = self._data_loader(train_tensor_data)
-         valid_data_loader = self._data_loader(valid_tensor_data)
-         # pylint: disable=attribute-defined-outside-init
-         self._fit_items_np = self.fit_items.toPandas().to_numpy().ravel()
-
-         self.logger.debug("Training NeuroMF")
-         self.model = NMF(
-             user_count=self._user_dim,
-             item_count=self._item_dim,
-             embedding_gmf_dim=self.embedding_gmf_dim,
-             embedding_mlp_dim=self.embedding_mlp_dim,
-             hidden_mlp_dims=self.hidden_mlp_dims,
-         ).to(self.device)
-         optimizer = Adam(
-             self.model.parameters(),
-             lr=self.learning_rate,
-             weight_decay=self.l2_reg / self.batch_size_users,
-         )
-         lr_scheduler = ReduceLROnPlateau(
-             optimizer, factor=self.factor, patience=self.patience
-         )
-
-         self.train(
-             train_data_loader,
-             valid_data_loader,
-             optimizer,
-             lr_scheduler,
-             self.epochs,
-             "neuromf",
-         )
-
-         del self._fit_items_np
-
-     # pylint: disable=arguments-differ
-     @staticmethod
-     def _loss(y_pred, y_true):
-         return F.binary_cross_entropy(y_pred, y_true).mean()
-
-     def _batch_pass(self, batch, model):
-         user_batch, pos_item_batch = batch
-         neg_item_batch = self._get_neg_batch(user_batch)
-         pos_relevance = model(
-             user_batch.to(self.device), pos_item_batch.to(self.device)
-         )
-         neg_relevance = model(
-             user_batch.repeat([self.count_negative_sample]).to(self.device),
-             neg_item_batch.to(self.device),
-         )
-         y_pred = torch.cat((pos_relevance, neg_relevance), 0)
-         y_true_pos = torch.ones_like(pos_item_batch).to(self.device)
-         y_true_neg = torch.zeros_like(neg_item_batch).to(self.device)
-         y_true = torch.cat((y_true_pos, y_true_neg), 0).float()
-
-         return {"y_pred": y_pred, "y_true": y_true}
-
-     @staticmethod
-     def _predict_pairs_inner(
-         model: nn.Module,
-         user_idx: int,
-         items_np: np.ndarray,
-         cnt: Optional[int] = None,
-     ) -> SparkDataFrame:
-         model.eval()
-         with torch.no_grad():
-             user_batch = LongTensor([user_idx] * len(items_np))
-             item_batch = LongTensor(items_np)
-             user_recs = torch.reshape(
-                 model(user_batch, item_batch).detach(),
-                 [
-                     -1,
-                 ],
-             )
-             if cnt is not None:
-                 best_item_idx = (
-                     torch.argsort(user_recs, descending=True)[:cnt]
-                 ).numpy()
-                 user_recs = user_recs[best_item_idx]
-                 items_np = items_np[best_item_idx]
-
-             return PandasDataFrame(
-                 {
-                     "user_idx": user_recs.shape[0] * [user_idx],
-                     "item_idx": items_np,
-                     "relevance": user_recs,
-                 }
-             )
-
-     @staticmethod
-     def _predict_by_user(
-         pandas_df: PandasDataFrame,
-         model: nn.Module,
-         items_np: np.ndarray,
-         k: int,
-         item_count: int,
-     ) -> PandasDataFrame:
-         return NeuroMF._predict_pairs_inner(
-             model=model,
-             user_idx=pandas_df["user_idx"][0],
-             items_np=items_np,
-             cnt=min(len(pandas_df) + k, len(items_np)),
-         )
-
-     @staticmethod
-     def _predict_by_user_pairs(
-         pandas_df: PandasDataFrame, model: nn.Module, item_count: int
-     ) -> PandasDataFrame:
-         return NeuroMF._predict_pairs_inner(
-             model=model,
-             user_idx=pandas_df["user_idx"][0],
-             items_np=np.array(pandas_df["item_idx_to_pred"][0]),
-             cnt=None,
-         )
-
-     def _load_model(self, path: str):
-         self.model = NMF(
-             user_count=self._user_dim,
-             item_count=self._item_dim,
-             embedding_gmf_dim=self.embedding_gmf_dim,
-             embedding_mlp_dim=self.embedding_mlp_dim,
-             hidden_mlp_dims=self.hidden_mlp_dims,
-         ).to(self.device)
-         self.model.load_state_dict(torch.load(path))
-         self.model.eval()
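
For reference, a minimal sketch of how the deleted model was constructed in 0.16.0rc0, based on the NeuroMF constructor and docstring above. The fit/predict calls are assumptions about the TorchRecommender base class (not part of this diff), and the hidden_mlp_dims values are hypothetical:

    from replay.experimental.models.neuromf import NeuroMF

    # Default variant is MLP + GMF with embedding size 128 (EMBED_DIM);
    # either branch is disabled by leaving its embedding dim as None.
    model = NeuroMF(
        learning_rate=0.05,
        epochs=20,
        embedding_gmf_dim=128,       # GMF branch
        embedding_mlp_dim=128,       # MLP branch
        hidden_mlp_dims=[256, 128],  # MLP tower sizes (hypothetical)
        count_negative_sample=1,     # negatives drawn per positive in _batch_pass
    )
    # Assumed base-class API, not shown in this diff:
    # model.fit(log)                  # log: Spark DataFrame with user_idx/item_idx
    # recs = model.predict(log, k=10)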