replay-rec 0.20.3__py3-none-any.whl → 0.20.3rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. replay/__init__.py +1 -1
  2. replay/experimental/__init__.py +0 -0
  3. replay/experimental/metrics/__init__.py +62 -0
  4. replay/experimental/metrics/base_metric.py +603 -0
  5. replay/experimental/metrics/coverage.py +97 -0
  6. replay/experimental/metrics/experiment.py +175 -0
  7. replay/experimental/metrics/hitrate.py +26 -0
  8. replay/experimental/metrics/map.py +30 -0
  9. replay/experimental/metrics/mrr.py +18 -0
  10. replay/experimental/metrics/ncis_precision.py +31 -0
  11. replay/experimental/metrics/ndcg.py +49 -0
  12. replay/experimental/metrics/precision.py +22 -0
  13. replay/experimental/metrics/recall.py +25 -0
  14. replay/experimental/metrics/rocauc.py +49 -0
  15. replay/experimental/metrics/surprisal.py +90 -0
  16. replay/experimental/metrics/unexpectedness.py +76 -0
  17. replay/experimental/models/__init__.py +50 -0
  18. replay/experimental/models/admm_slim.py +257 -0
  19. replay/experimental/models/base_neighbour_rec.py +200 -0
  20. replay/experimental/models/base_rec.py +1386 -0
  21. replay/experimental/models/base_torch_rec.py +234 -0
  22. replay/experimental/models/cql.py +454 -0
  23. replay/experimental/models/ddpg.py +932 -0
  24. replay/experimental/models/dt4rec/__init__.py +0 -0
  25. replay/experimental/models/dt4rec/dt4rec.py +189 -0
  26. replay/experimental/models/dt4rec/gpt1.py +401 -0
  27. replay/experimental/models/dt4rec/trainer.py +127 -0
  28. replay/experimental/models/dt4rec/utils.py +264 -0
  29. replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
  30. replay/experimental/models/extensions/spark_custom_models/als_extension.py +792 -0
  31. replay/experimental/models/hierarchical_recommender.py +331 -0
  32. replay/experimental/models/implicit_wrap.py +131 -0
  33. replay/experimental/models/lightfm_wrap.py +303 -0
  34. replay/experimental/models/mult_vae.py +332 -0
  35. replay/experimental/models/neural_ts.py +986 -0
  36. replay/experimental/models/neuromf.py +406 -0
  37. replay/experimental/models/scala_als.py +293 -0
  38. replay/experimental/models/u_lin_ucb.py +115 -0
  39. replay/experimental/nn/data/__init__.py +1 -0
  40. replay/experimental/nn/data/schema_builder.py +102 -0
  41. replay/experimental/preprocessing/__init__.py +3 -0
  42. replay/experimental/preprocessing/data_preparator.py +839 -0
  43. replay/experimental/preprocessing/padder.py +229 -0
  44. replay/experimental/preprocessing/sequence_generator.py +208 -0
  45. replay/experimental/scenarios/__init__.py +1 -0
  46. replay/experimental/scenarios/obp_wrapper/__init__.py +8 -0
  47. replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +74 -0
  48. replay/experimental/scenarios/obp_wrapper/replay_offline.py +261 -0
  49. replay/experimental/scenarios/obp_wrapper/utils.py +85 -0
  50. replay/experimental/scenarios/two_stages/__init__.py +0 -0
  51. replay/experimental/scenarios/two_stages/reranker.py +117 -0
  52. replay/experimental/scenarios/two_stages/two_stages_scenario.py +757 -0
  53. replay/experimental/utils/__init__.py +0 -0
  54. replay/experimental/utils/logger.py +24 -0
  55. replay/experimental/utils/model_handler.py +186 -0
  56. replay/experimental/utils/session_handler.py +44 -0
  57. {replay_rec-0.20.3.dist-info → replay_rec-0.20.3rc0.dist-info}/METADATA +11 -17
  58. {replay_rec-0.20.3.dist-info → replay_rec-0.20.3rc0.dist-info}/RECORD +61 -6
  59. {replay_rec-0.20.3.dist-info → replay_rec-0.20.3rc0.dist-info}/WHEEL +0 -0
  60. {replay_rec-0.20.3.dist-info → replay_rec-0.20.3rc0.dist-info}/licenses/LICENSE +0 -0
  61. {replay_rec-0.20.3.dist-info → replay_rec-0.20.3rc0.dist-info}/licenses/NOTICE +0 -0
--- /dev/null
+++ replay/experimental/models/neuromf.py
@@ -0,0 +1,406 @@
+"""
+Generalized Matrix Factorization (GMF),
+Multi-Layer Perceptron (MLP),
+Neural Matrix Factorization (MLP + GMF).
+"""
+
+from typing import Optional
+
+import numpy as np
+import torch
+import torch.nn.functional as sf
+from sklearn.model_selection import train_test_split
+from torch import LongTensor, Tensor, nn
+from torch.optim import Adam
+from torch.optim.lr_scheduler import ReduceLROnPlateau
+from torch.utils.data import DataLoader, TensorDataset
+
+from replay.experimental.models.base_torch_rec import TorchRecommender
+from replay.utils import PandasDataFrame, SparkDataFrame
+
+EMBED_DIM = 128
+
+
+def xavier_init_(layer: nn.Module):
+    """
+    Xavier initialization
+
+    :param layer: net layer
+    """
+    if isinstance(layer, (nn.Embedding, nn.Linear)):
+        nn.init.xavier_normal_(layer.weight.data)
+
+    if isinstance(layer, nn.Linear):
+        layer.bias.data.normal_(0.0, 0.001)
+
+
+class GMF(nn.Module):
+    """Generalized Matrix Factorization"""
+
+    def __init__(self, user_count: int, item_count: int, embedding_dim: int):
+        """
+        :param user_count: number of users
+        :param item_count: number of items
+        :param embedding_dim: embedding size
+        """
+        super().__init__()
+        self.user_embedding = nn.Embedding(num_embeddings=user_count, embedding_dim=embedding_dim)
+        self.item_embedding = nn.Embedding(num_embeddings=item_count, embedding_dim=embedding_dim)
+        self.item_biases = nn.Embedding(num_embeddings=item_count, embedding_dim=1)
+        self.user_biases = nn.Embedding(num_embeddings=user_count, embedding_dim=1)
+
+        xavier_init_(self.user_embedding)
+        xavier_init_(self.item_embedding)
+        self.user_biases.weight.data.zero_()
+        self.item_biases.weight.data.zero_()
+
+    def forward(self, user: Tensor, item: Tensor) -> Tensor:
+        """
+        :param user: user id batch
+        :param item: item id batch
+        :return: model output
+        """
+        user_emb = self.user_embedding(user) + self.user_biases(user)
+        item_emb = self.item_embedding(item) + self.item_biases(item)
+        element_product = torch.mul(user_emb, item_emb)
+
+        return element_product
+
+
+class MLP(nn.Module):
+    """Multi-Layer Perceptron"""
+
+    def __init__(
+        self,
+        user_count: int,
+        item_count: int,
+        embedding_dim: int,
+        hidden_dims: Optional[list[int]] = None,
+    ):
+        """
+        :param user_count: number of users
+        :param item_count: number of items
+        :param embedding_dim: embedding size
+        :param hidden_dims: list of hidden dimension sizes
+        """
+        super().__init__()
+        self.user_embedding = nn.Embedding(num_embeddings=user_count, embedding_dim=embedding_dim)
+        self.item_embedding = nn.Embedding(num_embeddings=item_count, embedding_dim=embedding_dim)
+        self.item_biases = nn.Embedding(num_embeddings=item_count, embedding_dim=1)
+        self.user_biases = nn.Embedding(num_embeddings=user_count, embedding_dim=1)
+
+        if hidden_dims:
+            full_hidden_dims = [2 * embedding_dim, *hidden_dims]
+            self.hidden_layers = nn.ModuleList(
+                [nn.Linear(d_in, d_out) for d_in, d_out in zip(full_hidden_dims[:-1], full_hidden_dims[1:])]
+            )
+
+        else:
+            self.hidden_layers = nn.ModuleList()
+
+        self.activation = nn.ReLU()
+
+        xavier_init_(self.user_embedding)
+        xavier_init_(self.item_embedding)
+        self.user_biases.weight.data.zero_()
+        self.item_biases.weight.data.zero_()
+        for layer in self.hidden_layers:
+            xavier_init_(layer)
+
+    def forward(self, user: Tensor, item: Tensor) -> Tensor:
+        """
+        :param user: user id batch
+        :param item: item id batch
+        :return: output
+        """
+        user_emb = self.user_embedding(user) + self.user_biases(user)
+        item_emb = self.item_embedding(item) + self.item_biases(item)
+        hidden = torch.cat([user_emb, item_emb], dim=-1)
+        for layer in self.hidden_layers:
+            hidden = layer(hidden)
+            hidden = self.activation(hidden)
+        return hidden
+
+
+class NMF(nn.Module):
+    """NMF = MLP + GMF"""
+
+    def __init__(
+        self,
+        user_count: int,
+        item_count: int,
+        embedding_gmf_dim: Optional[int] = None,
+        embedding_mlp_dim: Optional[int] = None,
+        hidden_mlp_dims: Optional[list[int]] = None,
+    ):
+        """
+        :param user_count: number of users
+        :param item_count: number of items
+        :param embedding_gmf_dim: embedding size for gmf
+        :param embedding_mlp_dim: embedding size for mlp
+        :param hidden_mlp_dims: list of hidden dimension sizes for mlp
+        """
+        self.gmf: Optional[GMF] = None
+        self.mlp: Optional[MLP] = None
+
+        super().__init__()
+        merged_dim = 0
+        if embedding_gmf_dim:
+            self.gmf = GMF(user_count, item_count, embedding_gmf_dim)
+            merged_dim += embedding_gmf_dim
+
+        if embedding_mlp_dim:
+            self.mlp = MLP(user_count, item_count, embedding_mlp_dim, hidden_mlp_dims)
+            merged_dim += hidden_mlp_dims[-1] if hidden_mlp_dims else 2 * embedding_mlp_dim
+
+        self.last_layer = nn.Linear(merged_dim, 1)
+        xavier_init_(self.last_layer)
+
+    def forward(self, user: Tensor, item: Tensor) -> Tensor:
+        """
+        :param user: user id batch
+        :param item: item id batch
+        :return: output
+        """
+        batch_size = len(user)
+        gmf_vector = self.gmf(user, item) if self.gmf else torch.zeros(batch_size, 0).to(user.device)
+        mlp_vector = self.mlp(user, item) if self.mlp else torch.zeros(batch_size, 0).to(user.device)
+
+        merged_vector = torch.cat([gmf_vector, mlp_vector], dim=1)
+        merged_vector = self.last_layer(merged_vector).squeeze(dim=1)
+        merged_vector = torch.sigmoid(merged_vector)
+
+        return merged_vector
+
+
+class NeuroMF(TorchRecommender):
+    """
+    Neural Matrix Factorization model (NeuMF, NCF).
+
+    In this implementation MLP and GMF modules are optional.
+    """
+
+    num_workers: int = 0
+    batch_size_users: int = 100000
+    patience: int = 3
+    n_saved: int = 2
+    valid_split_size: float = 0.1
+    seed: int = 42
+    _search_space = {
+        "embedding_gmf_dim": {"type": "int", "args": [EMBED_DIM, EMBED_DIM]},
+        "embedding_mlp_dim": {"type": "int", "args": [EMBED_DIM, EMBED_DIM]},
+        "learning_rate": {"type": "loguniform", "args": [0.0001, 0.5]},
+        "l2_reg": {"type": "loguniform", "args": [1e-9, 5]},
+        "count_negative_sample": {"type": "int", "args": [1, 20]},
+        "factor": {"type": "uniform", "args": [0.2, 0.2]},
+        "patience": {"type": "int", "args": [3, 3]},
+    }
+
+    def __init__(
+        self,
+        learning_rate: float = 0.05,
+        epochs: int = 20,
+        embedding_gmf_dim: Optional[int] = None,
+        embedding_mlp_dim: Optional[int] = None,
+        hidden_mlp_dims: Optional[list[int]] = None,
+        l2_reg: float = 0,
+        count_negative_sample: int = 1,
+        factor: float = 0.2,
+        patience: int = 3,
+    ):
+        """
+        MLP or GMF model can be ignored if
+        its embedding size (embedding_mlp_dim or embedding_gmf_dim) is set to ``None``.
+        Default variant is MLP + GMF with embedding size 128.
+
+        :param learning_rate: learning rate
+        :param epochs: number of epochs to train model
+        :param embedding_gmf_dim: embedding size for gmf
+        :param embedding_mlp_dim: embedding size for mlp
+        :param hidden_mlp_dims: list of hidden dimension sized for mlp
+        :param l2_reg: l2 regularization term
+        :param count_negative_sample: number of negative samples to use
+        :param factor: ReduceLROnPlateau reducing factor. new_lr = lr * factor
+        :param patience: number of non-improved epochs before reducing lr
+        """
+        super().__init__()
+        if not embedding_gmf_dim and not embedding_mlp_dim:
+            embedding_gmf_dim, embedding_mlp_dim = EMBED_DIM, EMBED_DIM
+
+        if (embedding_gmf_dim is None or embedding_gmf_dim < 0) and (
+            embedding_mlp_dim is None or embedding_mlp_dim < 0
+        ):
+            msg = "embedding_gmf_dim and embedding_mlp_dim must be positive"
+            raise ValueError(msg)
+
+        self.learning_rate = learning_rate
+        self.epochs = epochs
+        self.embedding_gmf_dim = embedding_gmf_dim
+        self.embedding_mlp_dim = embedding_mlp_dim
+        self.hidden_mlp_dims = hidden_mlp_dims
+        self.l2_reg = l2_reg
+        self.count_negative_sample = count_negative_sample
+        self.factor = factor
+        self.patience = patience
+
+    @property
+    def _init_args(self):
+        return {
+            "learning_rate": self.learning_rate,
+            "epochs": self.epochs,
+            "embedding_gmf_dim": self.embedding_gmf_dim,
+            "embedding_mlp_dim": self.embedding_mlp_dim,
+            "hidden_mlp_dims": self.hidden_mlp_dims,
+            "l2_reg": self.l2_reg,
+            "count_negative_sample": self.count_negative_sample,
+            "factor": self.factor,
+            "patience": self.patience,
+        }
+
+    def _data_loader(self, data: PandasDataFrame, shuffle: bool = True) -> DataLoader:
+        user_batch = LongTensor(data["user_idx"].values)
+        item_batch = LongTensor(data["item_idx"].values)
+
+        dataset = TensorDataset(user_batch, item_batch)
+
+        loader = DataLoader(
+            dataset,
+            batch_size=self.batch_size_users,
+            shuffle=shuffle,
+            num_workers=self.num_workers,
+        )
+        return loader
+
+    def _get_neg_batch(self, batch: Tensor) -> Tensor:
+        return torch.from_numpy(np.random.choice(self._fit_items_np, batch.shape[0] * self.count_negative_sample))
+
+    def _fit(
+        self,
+        log: SparkDataFrame,
+        user_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
+        item_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
+    ) -> None:
+        self.logger.debug("Create DataLoaders")
+        tensor_data = log.select("user_idx", "item_idx").toPandas()
+        train_tensor_data, valid_tensor_data = train_test_split(
+            tensor_data,
+            test_size=self.valid_split_size,
+            random_state=self.seed,
+        )
+        train_data_loader = self._data_loader(train_tensor_data)
+        valid_data_loader = self._data_loader(valid_tensor_data)
+        self._fit_items_np = self.fit_items.toPandas().to_numpy().ravel()
+
+        self.logger.debug("Training NeuroMF")
+        self.model = NMF(
+            user_count=self._user_dim,
+            item_count=self._item_dim,
+            embedding_gmf_dim=self.embedding_gmf_dim,
+            embedding_mlp_dim=self.embedding_mlp_dim,
+            hidden_mlp_dims=self.hidden_mlp_dims,
+        ).to(self.device)
+        optimizer = Adam(
+            self.model.parameters(),
+            lr=self.learning_rate,
+            weight_decay=self.l2_reg / self.batch_size_users,
+        )
+        lr_scheduler = ReduceLROnPlateau(optimizer, factor=self.factor, patience=self.patience)
+
+        self.train(
+            train_data_loader,
+            valid_data_loader,
+            optimizer,
+            lr_scheduler,
+            self.epochs,
+            "neuromf",
+        )
+
+        del self._fit_items_np
+
+    @staticmethod
+    def _loss(y_pred, y_true):
+        return sf.binary_cross_entropy(y_pred, y_true).mean()
+
+    def _batch_pass(self, batch, model):
+        user_batch, pos_item_batch = batch
+        neg_item_batch = self._get_neg_batch(user_batch)
+        pos_relevance = model(user_batch.to(self.device), pos_item_batch.to(self.device))
+        neg_relevance = model(
+            user_batch.repeat([self.count_negative_sample]).to(self.device),
+            neg_item_batch.to(self.device),
+        )
+        y_pred = torch.cat((pos_relevance, neg_relevance), 0)
+        y_true_pos = torch.ones_like(pos_item_batch).to(self.device)
+        y_true_neg = torch.zeros_like(neg_item_batch).to(self.device)
+        y_true = torch.cat((y_true_pos, y_true_neg), 0).float()
+
+        return {"y_pred": y_pred, "y_true": y_true}
+
+    @staticmethod
+    def _predict_pairs_inner(
+        model: nn.Module,
+        user_idx: int,
+        items_np: np.ndarray,
+        cnt: Optional[int] = None,
+    ) -> SparkDataFrame:
+        model.eval()
+        with torch.no_grad():
+            user_batch = LongTensor([user_idx] * len(items_np))
+            item_batch = LongTensor(items_np)
+            user_recs = torch.reshape(
+                model(user_batch, item_batch).detach(),
+                [
+                    -1,
+                ],
+            )
+            if cnt is not None:
+                best_item_idx = (torch.argsort(user_recs, descending=True)[:cnt]).numpy()
+                user_recs = user_recs[best_item_idx]
+                items_np = items_np[best_item_idx]
+
+            return PandasDataFrame(
+                {
+                    "user_idx": user_recs.shape[0] * [user_idx],
+                    "item_idx": items_np,
+                    "relevance": user_recs,
+                }
+            )
+
+    @staticmethod
+    def _predict_by_user(
+        pandas_df: PandasDataFrame,
+        model: nn.Module,
+        items_np: np.ndarray,
+        k: int,
+        item_count: int,  # noqa: ARG004
+    ) -> PandasDataFrame:
+        return NeuroMF._predict_pairs_inner(
+            model=model,
+            user_idx=pandas_df["user_idx"][0],
+            items_np=items_np,
+            cnt=min(len(pandas_df) + k, len(items_np)),
+        )
+
+    @staticmethod
+    def _predict_by_user_pairs(
+        pandas_df: PandasDataFrame,
+        model: nn.Module,
+        item_count: int,  # noqa: ARG004
+    ) -> PandasDataFrame:
+        return NeuroMF._predict_pairs_inner(
+            model=model,
+            user_idx=pandas_df["user_idx"][0],
+            items_np=np.array(pandas_df["item_idx_to_pred"][0]),
+            cnt=None,
+        )
+
+    def _load_model(self, path: str):
+        self.model = NMF(
+            user_count=self._user_dim,
+            item_count=self._item_dim,
+            embedding_gmf_dim=self.embedding_gmf_dim,
+            embedding_mlp_dim=self.embedding_mlp_dim,
+            hidden_mlp_dims=self.hidden_mlp_dims,
+        ).to(self.device)
+        self.model.load_state_dict(torch.load(path))
+        self.model.eval()
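
NeuroMF follows the experimental TorchRecommender interface: fit consumes a Spark interaction log with user_idx, item_idx and relevance columns, trains the NMF network with sampled negatives and binary cross-entropy, and predict scores candidate items per user. A minimal usage sketch, not part of the diff: it assumes the class is re-exported from replay.experimental.models (see the __init__.py entry in the file list) and that such a Spark log is already prepared.

    from replay.experimental.models import NeuroMF  # assumed re-export

    # Both branches enabled; either can be disabled by leaving its embedding dim as None.
    model = NeuroMF(
        embedding_gmf_dim=128,
        embedding_mlp_dim=128,
        hidden_mlp_dims=[256, 128],
        epochs=20,
        count_negative_sample=4,  # negatives sampled per positive pair
    )
    model.fit(log)                   # log: Spark DataFrame(user_idx, item_idx, relevance)
    recs = model.predict(log, k=10)  # top-10 unseen items per user
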
--- /dev/null
+++ replay/experimental/models/scala_als.py
@@ -0,0 +1,293 @@
+from typing import Any, Optional
+
+from replay.experimental.models.base_rec import ItemVectorModel, Recommender
+from replay.experimental.models.extensions.spark_custom_models.als_extension import ALS, ALSModel
+from replay.models.extensions.ann.ann_mixin import ANNMixin
+from replay.models.extensions.ann.index_builders.base_index_builder import IndexBuilder
+from replay.utils import OPTUNA_AVAILABLE, PYSPARK_AVAILABLE, SparkDataFrame
+from replay.utils.spark_utils import list_to_vector_udf
+
+if PYSPARK_AVAILABLE:
+    import pyspark.sql.functions as sf
+    from pyspark.sql.types import DoubleType
+
+
+class ALSWrap(Recommender, ItemVectorModel):
+    """Wrapper for `Spark ALS
+    <https://spark.apache.org/docs/latest/api/python/pyspark.mllib.html#pyspark.mllib.recommendation.ALS>`_.
+    """
+
+    _seed: Optional[int] = None
+    if OPTUNA_AVAILABLE:
+        _search_space = {
+            "rank": {"type": "loguniform_int", "args": [8, 256]},
+        }
+
+    def __init__(
+        self,
+        rank: int = 10,
+        implicit_prefs: bool = True,
+        seed: Optional[int] = None,
+        num_item_blocks: Optional[int] = None,
+        num_user_blocks: Optional[int] = None,
+    ):
+        """
+        :param rank: hidden dimension for the approximate matrix
+        :param implicit_prefs: flag to use implicit feedback
+        :param seed: random seed
+        :param num_item_blocks: number of blocks the items will be partitioned into in order
+            to parallelize computation.
+            if None then will be init with number of partitions of log.
+        :param num_user_blocks: number of blocks the users will be partitioned into in order
+            to parallelize computation.
+            if None then will be init with number of partitions of log.
+        """
+        self.rank = rank
+        self.implicit_prefs = implicit_prefs
+        self._seed = seed
+        self._num_item_blocks = num_item_blocks
+        self._num_user_blocks = num_user_blocks
+
+    @property
+    def _init_args(self):
+        return {
+            "rank": self.rank,
+            "implicit_prefs": self.implicit_prefs,
+            "seed": self._seed,
+        }
+
+    def _save_model(self, path: str):
+        self.model.write().overwrite().save(path)
+
+    def _load_model(self, path: str):
+        self.model = ALSModel.load(path)
+        self.model.itemFactors.cache()
+        self.model.userFactors.cache()
+
+    def _fit(
+        self,
+        log: SparkDataFrame,
+        user_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
+        item_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
+    ) -> None:
+        if self._num_item_blocks is None:
+            self._num_item_blocks = log.rdd.getNumPartitions()
+        if self._num_user_blocks is None:
+            self._num_user_blocks = log.rdd.getNumPartitions()
+
+        self.model = ALS(
+            rank=self.rank,
+            numItemBlocks=self._num_item_blocks,
+            numUserBlocks=self._num_user_blocks,
+            userCol="user_idx",
+            itemCol="item_idx",
+            ratingCol="relevance",
+            implicitPrefs=self.implicit_prefs,
+            seed=self._seed,
+            coldStartStrategy="drop",
+        ).fit(log)
+        self.model.itemFactors.cache()
+        self.model.userFactors.cache()
+        self.model.itemFactors.count()
+        self.model.userFactors.count()
+
+    def _clear_cache(self):
+        if hasattr(self, "model"):
+            self.model.itemFactors.unpersist()
+            self.model.userFactors.unpersist()
+
+    def _predict(
+        self,
+        log: Optional[SparkDataFrame],
+        k: int,
+        users: SparkDataFrame,
+        items: SparkDataFrame,
+        user_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
+        item_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
+        filter_seen_items: bool = True,
+    ) -> SparkDataFrame:
+        if (items.count() == self.fit_items.count()) and (
+            items.join(self.fit_items, on="item_idx", how="inner").count() == self.fit_items.count()
+        ):
+            max_seen = 0
+            if filter_seen_items and log is not None:
+                max_seen_in_log = (
+                    log.join(users, on="user_idx")
+                    .groupBy("user_idx")
+                    .agg(sf.count("user_idx").alias("num_seen"))
+                    .select(sf.max("num_seen"))
+                    .first()[0]
+                )
+                max_seen = max_seen_in_log if max_seen_in_log is not None else 0
+
+            recs_als = self.model.recommendForUserSubset(users, k + max_seen)
+            return (
+                recs_als.withColumn("recommendations", sf.explode("recommendations"))
+                .withColumn("item_idx", sf.col("recommendations.item_idx"))
+                .withColumn(
+                    "relevance",
+                    sf.col("recommendations.rating").cast(DoubleType()),
+                )
+                .select("user_idx", "item_idx", "relevance")
+            )
+
+        return self._predict_pairs(
+            pairs=users.crossJoin(items).withColumn("relevance", sf.lit(1)),
+            log=log,
+        )
+
+    def _predict_pairs(
+        self,
+        pairs: SparkDataFrame,
+        log: Optional[SparkDataFrame] = None,  # noqa: ARG002
+        user_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
+        item_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
+    ) -> SparkDataFrame:
+        return (
+            self.model.transform(pairs)
+            .withColumn("relevance", sf.col("prediction").cast(DoubleType()))
+            .drop("prediction")
+        )
+
+    def _get_features(
+        self, ids: SparkDataFrame, features: Optional[SparkDataFrame]  # noqa: ARG002
+    ) -> tuple[Optional[SparkDataFrame], Optional[int]]:
+        entity = "user" if "user_idx" in ids.columns else "item"
+        als_factors = getattr(self.model, f"{entity}Factors")
+        als_factors = als_factors.withColumnRenamed("id", f"{entity}_idx").withColumnRenamed(
+            "features", f"{entity}_factors"
+        )
+        return (
+            als_factors.join(ids, how="right", on=f"{entity}_idx"),
+            self.model.rank,
+        )
+
+    def _get_item_vectors(self):
+        return self.model.itemFactors.select(
+            sf.col("id").alias("item_idx"),
+            list_to_vector_udf(sf.col("features")).alias("item_vector"),
+        )
+
+
+class ScalaALSWrap(ALSWrap, ANNMixin):
+    """Wrapper for `Spark ALS
+    <https://spark.apache.org/docs/latest/api/python/pyspark.mllib.html#pyspark.mllib.recommendation.ALS>`_.
+    """
+
+    def _get_ann_infer_params(self) -> dict[str, Any]:
+        self.index_builder.index_params.dim = self.rank
+        return {
+            "features_col": "user_factors",
+        }
+
+    def _get_vectors_to_infer_ann_inner(
+        self, interactions: SparkDataFrame, queries: SparkDataFrame  # noqa: ARG002
+    ) -> SparkDataFrame:
+        user_vectors, _ = self.get_features(queries)
+        return user_vectors
+
+    def _configure_index_builder(self, interactions: SparkDataFrame):
+        item_vectors, _ = self.get_features(interactions.select("item_idx").distinct())
+
+        self.index_builder.index_params.dim = self.rank
+        self.index_builder.index_params.max_elements = interactions.select("item_idx").distinct().count()
+
+        return item_vectors, {
+            "features_col": "item_factors",
+            "ids_col": "item_idx",
+        }
+
+    def __init__(
+        self,
+        rank: int = 10,
+        implicit_prefs: bool = True,
+        seed: Optional[int] = None,
+        num_item_blocks: Optional[int] = None,
+        num_user_blocks: Optional[int] = None,
+        index_builder: Optional[IndexBuilder] = None,
+    ):
+        ALSWrap.__init__(self, rank, implicit_prefs, seed, num_item_blocks, num_user_blocks)
+        self.init_index_builder(index_builder)
+        self.num_elements = None
+
+    @property
+    def _init_args(self):
+        return {
+            "rank": self.rank,
+            "implicit_prefs": self.implicit_prefs,
+            "seed": self._seed,
+            "index_builder": self.index_builder.init_meta_as_dict() if self.index_builder else None,
+        }
+
+    def _fit(
+        self,
+        log: SparkDataFrame,
+        user_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
+        item_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
+    ) -> None:
+        if self._num_item_blocks is None:
+            self._num_item_blocks = log.rdd.getNumPartitions()
+        if self._num_user_blocks is None:
+            self._num_user_blocks = log.rdd.getNumPartitions()
+
+        self.model = ALS(
+            rank=self.rank,
+            numItemBlocks=self._num_item_blocks,
+            numUserBlocks=self._num_user_blocks,
+            userCol="user_idx",
+            itemCol="item_idx",
+            ratingCol="relevance",
+            implicitPrefs=self.implicit_prefs,
+            seed=self._seed,
+            coldStartStrategy="drop",
+        ).fit(log)
+        self.model.itemFactors.cache()
+        self.model.userFactors.cache()
+        self.model.itemFactors.count()
+        self.model.userFactors.count()
+
+    def _save_model(self, path: str):
+        self.model.write().overwrite().save(path)
+
+        if self._use_ann:
+            self._save_index(path)
+
+    def _load_model(self, path: str):
+        self.model = ALSModel.load(path)
+        self.model.itemFactors.cache()
+        self.model.userFactors.cache()
+
+        if self._use_ann:
+            self._load_index(path)
+
+    def _predict(
+        self,
+        log: Optional[SparkDataFrame],
+        k: int,
+        users: SparkDataFrame,
+        items: SparkDataFrame,
+        user_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
+        item_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
+        filter_seen_items: bool = True,
+    ) -> SparkDataFrame:
+        max_seen = 0
+        if filter_seen_items and log is not None:
+            max_seen_in_log = (
+                log.join(users, on="user_idx")
+                .groupBy("user_idx")
+                .agg(sf.count("user_idx").alias("num_seen"))
+                .select(sf.max("num_seen"))
+                .first()[0]
+            )
+            max_seen = max_seen_in_log if max_seen_in_log is not None else 0
+
+        recs_als = self.model.recommendItemsForUserItemSubset(users, items, k + max_seen)
+        return (
+            recs_als.withColumn("recommendations", sf.explode("recommendations"))
+            .withColumn("item_idx", sf.col("recommendations.item_idx"))
+            .withColumn(
+                "relevance",
+                sf.col("recommendations.rating").cast(DoubleType()),
+            )
+            .select("user_idx", "item_idx", "relevance")
+        )