replay-rec 0.18.0__py3-none-any.whl → 0.18.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. replay/__init__.py +1 -1
  2. replay/experimental/__init__.py +0 -0
  3. replay/experimental/metrics/__init__.py +62 -0
  4. replay/experimental/metrics/base_metric.py +602 -0
  5. replay/experimental/metrics/coverage.py +97 -0
  6. replay/experimental/metrics/experiment.py +175 -0
  7. replay/experimental/metrics/hitrate.py +26 -0
  8. replay/experimental/metrics/map.py +30 -0
  9. replay/experimental/metrics/mrr.py +18 -0
  10. replay/experimental/metrics/ncis_precision.py +31 -0
  11. replay/experimental/metrics/ndcg.py +49 -0
  12. replay/experimental/metrics/precision.py +22 -0
  13. replay/experimental/metrics/recall.py +25 -0
  14. replay/experimental/metrics/rocauc.py +49 -0
  15. replay/experimental/metrics/surprisal.py +90 -0
  16. replay/experimental/metrics/unexpectedness.py +76 -0
  17. replay/experimental/models/__init__.py +10 -0
  18. replay/experimental/models/admm_slim.py +205 -0
  19. replay/experimental/models/base_neighbour_rec.py +204 -0
  20. replay/experimental/models/base_rec.py +1271 -0
  21. replay/experimental/models/base_torch_rec.py +234 -0
  22. replay/experimental/models/cql.py +454 -0
  23. replay/experimental/models/ddpg.py +923 -0
  24. replay/experimental/models/dt4rec/__init__.py +0 -0
  25. replay/experimental/models/dt4rec/dt4rec.py +189 -0
  26. replay/experimental/models/dt4rec/gpt1.py +401 -0
  27. replay/experimental/models/dt4rec/trainer.py +127 -0
  28. replay/experimental/models/dt4rec/utils.py +265 -0
  29. replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
  30. replay/experimental/models/extensions/spark_custom_models/als_extension.py +792 -0
  31. replay/experimental/models/implicit_wrap.py +131 -0
  32. replay/experimental/models/lightfm_wrap.py +302 -0
  33. replay/experimental/models/mult_vae.py +332 -0
  34. replay/experimental/models/neuromf.py +406 -0
  35. replay/experimental/models/scala_als.py +296 -0
  36. replay/experimental/nn/data/__init__.py +1 -0
  37. replay/experimental/nn/data/schema_builder.py +55 -0
  38. replay/experimental/preprocessing/__init__.py +3 -0
  39. replay/experimental/preprocessing/data_preparator.py +839 -0
  40. replay/experimental/preprocessing/padder.py +229 -0
  41. replay/experimental/preprocessing/sequence_generator.py +208 -0
  42. replay/experimental/scenarios/__init__.py +1 -0
  43. replay/experimental/scenarios/obp_wrapper/__init__.py +8 -0
  44. replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +74 -0
  45. replay/experimental/scenarios/obp_wrapper/replay_offline.py +248 -0
  46. replay/experimental/scenarios/obp_wrapper/utils.py +87 -0
  47. replay/experimental/scenarios/two_stages/__init__.py +0 -0
  48. replay/experimental/scenarios/two_stages/reranker.py +117 -0
  49. replay/experimental/scenarios/two_stages/two_stages_scenario.py +757 -0
  50. replay/experimental/utils/__init__.py +0 -0
  51. replay/experimental/utils/logger.py +24 -0
  52. replay/experimental/utils/model_handler.py +186 -0
  53. replay/experimental/utils/session_handler.py +44 -0
  54. {replay_rec-0.18.0.dist-info → replay_rec-0.18.0rc0.dist-info}/METADATA +11 -3
  55. replay_rec-0.18.0rc0.dist-info/NOTICE +41 -0
  56. {replay_rec-0.18.0.dist-info → replay_rec-0.18.0rc0.dist-info}/RECORD +58 -5
  57. {replay_rec-0.18.0.dist-info → replay_rec-0.18.0rc0.dist-info}/WHEEL +1 -1
  58. {replay_rec-0.18.0.dist-info → replay_rec-0.18.0rc0.dist-info}/LICENSE +0 -0
replay/experimental/models/base_rec.py
@@ -0,0 +1,1271 @@
1
+ """
2
+ Base abstract classes:
3
+ - BaseRecommender - the simplest base class
4
+ - Recommender - base class for models that fit on interaction log
5
+ - HybridRecommender - base class for models that accept user or item features
6
+ - UserRecommender - base class that accepts only user features, but not item features
7
+ - NeighbourRec - base class that requires log at prediction time
8
+ - ItemVectorModel - class for models which provide items' vectors.
9
+ Implements similar items search.
10
+ - NonPersonalizedRecommender - base class for non-personalized recommenders
11
+ with popularity statistics
12
+ """
13
+
14
+ from abc import ABC, abstractmethod
15
+ from os.path import join
16
+ from typing import Any, Iterable, List, Optional, Tuple, Union
17
+
18
+ import numpy as np
19
+ from numpy.random import default_rng
20
+
21
+ from replay.data import get_schema
22
+ from replay.experimental.utils.session_handler import State
23
+ from replay.models.base_rec import IsSavable, RecommenderCommons
24
+ from replay.utils import PYSPARK_AVAILABLE, PandasDataFrame, SparkDataFrame
25
+ from replay.utils.spark_utils import (
26
+ convert2spark,
27
+ cosine_similarity,
28
+ filter_cold,
29
+ get_top_k,
30
+ get_top_k_recs,
31
+ get_unique_entities,
32
+ load_pickled_from_parquet,
33
+ return_recs,
34
+ save_picklable_to_parquet,
35
+ vector_dot,
36
+ vector_euclidean_distance_similarity,
37
+ )
38
+
39
+ if PYSPARK_AVAILABLE:
40
+ from pyspark.sql import (
41
+ Window,
42
+ functions as sf,
43
+ )
44
+
45
+
46
+ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
47
+ """Base recommender"""
48
+
49
+ model: Any
50
+ can_predict_cold_users: bool = False
51
+ can_predict_cold_items: bool = False
52
+ fit_users: SparkDataFrame
53
+ fit_items: SparkDataFrame
54
+ _num_users: int
55
+ _num_items: int
56
+ _user_dim_size: int
57
+ _item_dim_size: int
58
+
59
+ def _fit_wrap(
60
+ self,
61
+ log: SparkDataFrame,
62
+ user_features: Optional[SparkDataFrame] = None,
63
+ item_features: Optional[SparkDataFrame] = None,
64
+ ) -> None:
65
+ """
66
+ Wrapper for fit to allow for fewer arguments in a model.
67
+
68
+ :param log: historical log of interactions
69
+ ``[user_idx, item_idx, timestamp, relevance]``
70
+ :param user_features: user features
71
+ ``[user_idx, timestamp]`` + feature columns
72
+ :param item_features: item features
73
+ ``[item_idx, timestamp]`` + feature columns
74
+ :return:
75
+ """
76
+ self.logger.debug("Starting fit %s", type(self).__name__)
77
+ if user_features is None:
78
+ users = log.select("user_idx").distinct()
79
+ else:
80
+ users = log.select("user_idx").union(user_features.select("user_idx")).distinct()
81
+ if item_features is None:
82
+ items = log.select("item_idx").distinct()
83
+ else:
84
+ items = log.select("item_idx").union(item_features.select("item_idx")).distinct()
85
+ self.fit_users = sf.broadcast(users)
86
+ self.fit_items = sf.broadcast(items)
87
+ self._num_users = self.fit_users.count()
88
+ self._num_items = self.fit_items.count()
89
+ self._user_dim_size = self.fit_users.agg({"user_idx": "max"}).first()[0] + 1
90
+ self._item_dim_size = self.fit_items.agg({"item_idx": "max"}).first()[0] + 1
91
+ self._fit(log, user_features, item_features)
92
+
93
+ @abstractmethod
94
+ def _fit(
95
+ self,
96
+ log: SparkDataFrame,
97
+ user_features: Optional[SparkDataFrame] = None,
98
+ item_features: Optional[SparkDataFrame] = None,
99
+ ) -> None:
100
+ """
101
+ Inner method where model actually fits.
102
+
103
+ :param log: historical log of interactions
104
+ ``[user_idx, item_idx, timestamp, relevance]``
105
+ :param user_features: user features
106
+ ``[user_idx, timestamp]`` + feature columns
107
+ :param item_features: item features
108
+ ``[item_idx, timestamp]`` + feature columns
109
+ :return:
110
+ """
111
+
112
+ def _filter_seen(self, recs: SparkDataFrame, log: SparkDataFrame, k: int, users: SparkDataFrame):
113
+ """
114
+ Filter seen items (present in the log) out of the users' recommendations.
115
+ For each user, return between `k` and `k + number of items seen by the user` recommendations.
116
+ """
117
+ users_log = log.join(users, on="user_idx")
118
+ self._cache_model_temp_view(users_log, "filter_seen_users_log")
119
+ num_seen = users_log.groupBy("user_idx").agg(sf.count("item_idx").alias("seen_count"))
120
+ self._cache_model_temp_view(num_seen, "filter_seen_num_seen")
121
+
122
+ # count the maximum number of items seen by any user
123
+ max_seen = 0
124
+ if num_seen.count() > 0:
125
+ max_seen = num_seen.select(sf.max("seen_count")).first()[0]
126
+
127
+ # crop recommendations to first k + max_seen items for each user
128
+ recs = recs.withColumn(
129
+ "temp_rank",
130
+ sf.row_number().over(Window.partitionBy("user_idx").orderBy(sf.col("relevance").desc())),
131
+ ).filter(sf.col("temp_rank") <= sf.lit(max_seen + k))
132
+
133
+ # keep at most k + (number of items seen by the user) recommendations per user
134
+ recs = (
135
+ recs.join(num_seen, on="user_idx", how="left")
136
+ .fillna(0)
137
+ .filter(sf.col("temp_rank") <= sf.col("seen_count") + sf.lit(k))
138
+ .drop("temp_rank", "seen_count")
139
+ )
140
+
141
+ # filter recommendations presented in interactions log
142
+ recs = recs.join(
143
+ users_log.withColumnRenamed("item_idx", "item")
144
+ .withColumnRenamed("user_idx", "user")
145
+ .select("user", "item"),
146
+ on=(sf.col("user_idx") == sf.col("user")) & (sf.col("item_idx") == sf.col("item")),
147
+ how="anti",
148
+ ).drop("user", "item")
149
+
150
+ return recs
151
+
152
+ def _filter_log_users_items_dataframes(
153
+ self,
154
+ log: Optional[SparkDataFrame],
155
+ k: int,
156
+ users: Optional[Union[SparkDataFrame, Iterable]] = None,
157
+ items: Optional[Union[SparkDataFrame, Iterable]] = None,
158
+ user_features: Optional[SparkDataFrame] = None,
159
+ ):
160
+ """
161
+ Returns triplet of filtered `log`, `users`, and `items`.
162
+ Filters out cold entities (users/items) from the `users`/`items` and `log` dataframes
163
+ if the model cannot predict cold entities.
164
+ Filters out duplicates from `users` and `items` dataframes,
165
+ and excludes all columns except `user_idx` and `item_idx`.
166
+
167
+ :param log: historical log of interactions
168
+ ``[user_idx, item_idx, timestamp, relevance]``
169
+ :param k: number of recommendations for each user
170
+ :param users: users to create recommendations for
171
+ dataframe containing ``[user_idx]`` or ``array-like``;
172
+ if ``None``, recommend to all users from ``log``
173
+ :param items: candidate items for recommendations
174
+ dataframe containing ``[item_idx]`` or ``array-like``;
175
+ if ``None``, take all items from ``log``.
176
+ If it contains new items, ``relevance`` for them will be ``0``.
177
+ :param user_features: user features
178
+ ``[user_idx , timestamp]`` + feature columns
179
+ :return: triplet of filtered `log`, `users`, and `items` dataframes.
180
+ """
181
+ self.logger.debug("Starting predict %s", type(self).__name__)
182
+ user_data = users or log or user_features or self.fit_users
183
+ users = get_unique_entities(user_data, "user_idx")
184
+ users, log = self._filter_cold_for_predict(users, log, "user")
185
+
186
+ item_data = items or self.fit_items
187
+ items = get_unique_entities(item_data, "item_idx")
188
+ items, log = self._filter_cold_for_predict(items, log, "item")
189
+ num_items = items.count()
190
+ if num_items < k:
191
+ message = f"k = {k} > number of items = {num_items}"
192
+ self.logger.debug(message)
193
+ return log, users, items
194
+
195
+ def _predict_wrap(
196
+ self,
197
+ log: Optional[SparkDataFrame],
198
+ k: int,
199
+ users: Optional[Union[SparkDataFrame, Iterable]] = None,
200
+ items: Optional[Union[SparkDataFrame, Iterable]] = None,
201
+ user_features: Optional[SparkDataFrame] = None,
202
+ item_features: Optional[SparkDataFrame] = None,
203
+ filter_seen_items: bool = True,
204
+ recs_file_path: Optional[str] = None,
205
+ ) -> Optional[SparkDataFrame]:
206
+ """
207
+ Predict wrapper to allow for fewer parameters in models
208
+
209
+ :param log: historical log of interactions
210
+ ``[user_idx, item_idx, timestamp, relevance]``
211
+ :param k: number of recommendations for each user
212
+ :param users: users to create recommendations for
213
+ dataframe containing ``[user_idx]`` or ``array-like``;
214
+ if ``None``, recommend to all users from ``log``
215
+ :param items: candidate items for recommendations
216
+ dataframe containing ``[item_idx]`` or ``array-like``;
217
+ if ``None``, take all items from ``log``.
218
+ If it contains new items, ``relevance`` for them will be ``0``.
219
+ :param user_features: user features
220
+ ``[user_idx , timestamp]`` + feature columns
221
+ :param item_features: item features
222
+ ``[item_idx , timestamp]`` + feature columns
223
+ :param filter_seen_items: flag to remove seen items from recommendations based on ``log``.
224
+ :param recs_file_path: save recommendations at the given absolute path as parquet file.
225
+ If None, cached and materialized recommendations dataframe will be returned
226
+ :return: cached recommendation dataframe with columns ``[user_idx, item_idx, relevance]``
227
+ or None if `recs_file_path` is provided
228
+ """
229
+ log, users, items = self._filter_log_users_items_dataframes(log, k, users, items)
230
+
231
+ recs = self._predict(
232
+ log,
233
+ k,
234
+ users,
235
+ items,
236
+ user_features,
237
+ item_features,
238
+ filter_seen_items,
239
+ )
240
+ if filter_seen_items and log:
241
+ recs = self._filter_seen(recs=recs, log=log, users=users, k=k)
242
+
243
+ recs = get_top_k_recs(recs, k=k).select("user_idx", "item_idx", "relevance")
244
+
245
+ output = return_recs(recs, recs_file_path)
246
+ self._clear_model_temp_view("filter_seen_users_log")
247
+ self._clear_model_temp_view("filter_seen_num_seen")
248
+ return output
249
+
250
+ def _filter_cold_for_predict(
251
+ self,
252
+ main_df: SparkDataFrame,
253
+ log_df: Optional[SparkDataFrame],
254
+ entity: str,
255
+ suffix: str = "idx",
256
+ ):
257
+ """
258
+ Filter out cold entities (users/items) from the `main_df` and `log_df`
259
+ if the model cannot predict cold entities.
260
+ Warn if cold entities are present in the `main_df`.
261
+ """
262
+ if getattr(self, f"can_predict_cold_{entity}s"):
263
+ return main_df, log_df
264
+
265
+ fit_entities = getattr(self, f"fit_{entity}s")
266
+
267
+ num_new, main_df = filter_cold(main_df, fit_entities, col_name=f"{entity}_{suffix}")
268
+ if num_new > 0:
269
+ self.logger.info(
270
+ "%s model can't predict cold %ss, they will be ignored",
271
+ self,
272
+ entity,
273
+ )
274
+ _, log_df = filter_cold(log_df, fit_entities, col_name=f"{entity}_{suffix}")
275
+ return main_df, log_df
276
+
277
+ @abstractmethod
278
+ def _predict(
279
+ self,
280
+ log: SparkDataFrame,
281
+ k: int,
282
+ users: SparkDataFrame,
283
+ items: SparkDataFrame,
284
+ user_features: Optional[SparkDataFrame] = None,
285
+ item_features: Optional[SparkDataFrame] = None,
286
+ filter_seen_items: bool = True,
287
+ ) -> SparkDataFrame:
288
+ """
289
+ Inner method where model actually predicts.
290
+
291
+ :param log: historical log of interactions
292
+ ``[user_idx, item_idx, timestamp, relevance]``
293
+ :param k: number of recommendations for each user
294
+ :param users: users to create recommendations for
295
+ dataframe containing ``[user_idx]`` or ``array-like``;
296
+ if ``None``, recommend to all users from ``log``
297
+ :param items: candidate items for recommendations
298
+ dataframe containing ``[item_idx]`` or ``array-like``;
299
+ if ``None``, take all items from ``log``.
300
+ If it contains new items, ``relevance`` for them will be ``0``.
301
+ :param user_features: user features
302
+ ``[user_idx , timestamp]`` + feature columns
303
+ :param item_features: item features
304
+ ``[item_idx , timestamp]`` + feature columns
305
+ :param filter_seen_items: flag to remove seen items from recommendations based on ``log``.
306
+ :return: recommendation dataframe
307
+ ``[user_idx, item_idx, relevance]``
308
+ """
309
+
310
+ def _get_fit_counts(self, entity: str) -> int:
311
+ if not hasattr(self, f"_num_{entity}s"):
312
+ setattr(
313
+ self,
314
+ f"_num_{entity}s",
315
+ getattr(self, f"fit_{entity}s").count(),
316
+ )
317
+ return getattr(self, f"_num_{entity}s")
318
+
319
+ @property
320
+ def users_count(self) -> int:
321
+ """
322
+ :returns: number of users the model was trained on
323
+ """
324
+ return self._get_fit_counts("user")
325
+
326
+ @property
327
+ def items_count(self) -> int:
328
+ """
329
+ :returns: number of items the model was trained on
330
+ """
331
+ return self._get_fit_counts("item")
332
+
333
+ def _get_fit_dims(self, entity: str) -> int:
334
+ if not hasattr(self, f"_{entity}_dim_size"):
335
+ setattr(
336
+ self,
337
+ f"_{entity}_dim_size",
338
+ getattr(self, f"fit_{entity}s").agg({f"{entity}_idx": "max"}).first()[0] + 1,
339
+ )
340
+ return getattr(self, f"_{entity}_dim_size")
341
+
342
+ @property
343
+ def _user_dim(self) -> int:
344
+ """
345
+ :returns: dimension of users matrix (maximal user idx + 1)
346
+ """
347
+ return self._get_fit_dims("user")
348
+
349
+ @property
350
+ def _item_dim(self) -> int:
351
+ """
352
+ :returns: dimension of items matrix (maximal item idx + 1)
353
+ """
354
+ return self._get_fit_dims("item")
355
+
356
+ def _fit_predict(
357
+ self,
358
+ log: SparkDataFrame,
359
+ k: int,
360
+ users: Optional[Union[SparkDataFrame, Iterable]] = None,
361
+ items: Optional[Union[SparkDataFrame, Iterable]] = None,
362
+ user_features: Optional[SparkDataFrame] = None,
363
+ item_features: Optional[SparkDataFrame] = None,
364
+ filter_seen_items: bool = True,
365
+ recs_file_path: Optional[str] = None,
366
+ ) -> Optional[SparkDataFrame]:
367
+ self._fit_wrap(log, user_features, item_features)
368
+ return self._predict_wrap(
369
+ log,
370
+ k,
371
+ users,
372
+ items,
373
+ user_features,
374
+ item_features,
375
+ filter_seen_items,
376
+ recs_file_path=recs_file_path,
377
+ )
378
+
379
+ def _predict_pairs_wrap(
380
+ self,
381
+ pairs: SparkDataFrame,
382
+ log: Optional[SparkDataFrame] = None,
383
+ user_features: Optional[SparkDataFrame] = None,
384
+ item_features: Optional[SparkDataFrame] = None,
385
+ recs_file_path: Optional[str] = None,
386
+ k: Optional[int] = None,
387
+ ) -> Optional[SparkDataFrame]:
388
+ """
389
+ This method
390
+ 1) converts data to spark
391
+ 2) removes cold users and items if the model cannot predict them
392
+ 3) calls inner _predict_pairs method of a model
393
+
394
+ :param pairs: user-item pairs to get relevance for,
395
+ dataframe containing``[user_idx, item_idx]``.
396
+ :param log: train data
397
+ ``[user_idx, item_idx, timestamp, relevance]``.
398
+ :param recs_file_path: save recommendations at the given absolute path as parquet file.
399
+ If None, cached and materialized recommendations dataframe will be returned
400
+ :return: cached dataframe with columns ``[user_idx, item_idx, relevance]``
401
+ or None if `recs_file_path` is provided
402
+ """
403
+ log, user_features, item_features, pairs = [
404
+ convert2spark(df) for df in [log, user_features, item_features, pairs]
405
+ ]
406
+ if sorted(pairs.columns) != ["item_idx", "user_idx"]:
407
+ msg = "pairs must be a dataframe with columns strictly [user_idx, item_idx]"
408
+ raise ValueError(msg)
409
+ pairs, log = self._filter_cold_for_predict(pairs, log, "user")
410
+ pairs, log = self._filter_cold_for_predict(pairs, log, "item")
411
+
412
+ pred = self._predict_pairs(
413
+ pairs=pairs,
414
+ log=log,
415
+ user_features=user_features,
416
+ item_features=item_features,
417
+ )
418
+
419
+ if k:
420
+ pred = get_top_k(
421
+ dataframe=pred,
422
+ partition_by_col=sf.col("user_idx"),
423
+ order_by_col=[
424
+ sf.col("relevance").desc(),
425
+ ],
426
+ k=k,
427
+ )
428
+
429
+ if recs_file_path is not None:
430
+ pred.write.parquet(path=recs_file_path, mode="overwrite")
431
+ return None
432
+
433
+ pred.cache().count()
434
+ return pred
435
+
436
+ def _predict_pairs(
437
+ self,
438
+ pairs: SparkDataFrame,
439
+ log: Optional[SparkDataFrame] = None,
440
+ user_features: Optional[SparkDataFrame] = None,
441
+ item_features: Optional[SparkDataFrame] = None,
442
+ ) -> SparkDataFrame:
443
+ """
444
+ Fallback method to use in case ``_predict_pairs`` is not implemented.
445
+ Simply joins ``predict`` with given ``pairs``.
446
+ :param pairs: user-item pairs to get relevance for,
447
+ dataframe containing``[user_idx, item_idx]``.
448
+ :param log: train data
449
+ ``[user_idx, item_idx, timestamp, relevance]``.
450
+ """
451
+ message = (
452
+ "native predict_pairs is not implemented for this model. "
453
+ "Falling back to usual predict method and filtering the results."
454
+ )
455
+ self.logger.warning(message)
456
+
457
+ users = pairs.select("user_idx").distinct()
458
+ items = pairs.select("item_idx").distinct()
459
+ k = items.count()
460
+ pred = self._predict(
461
+ log=log,
462
+ k=k,
463
+ users=users,
464
+ items=items,
465
+ user_features=user_features,
466
+ item_features=item_features,
467
+ filter_seen_items=False,
468
+ )
469
+
470
+ pred = pred.join(
471
+ pairs.select("user_idx", "item_idx"),
472
+ on=["user_idx", "item_idx"],
473
+ how="inner",
474
+ )
475
+ return pred
476
+
477
+ def _get_features_wrap(
478
+ self, ids: SparkDataFrame, features: Optional[SparkDataFrame]
479
+ ) -> Optional[Tuple[SparkDataFrame, int]]:
480
+ if "user_idx" not in ids.columns and "item_idx" not in ids.columns:
481
+ msg = "user_idx or item_idx missing"
482
+ raise ValueError(msg)
483
+ vectors, rank = self._get_features(ids, features)
484
+ return vectors, rank
485
+
486
+ def _get_features(
487
+ self, ids: SparkDataFrame, features: Optional[SparkDataFrame] # noqa: ARG002
488
+ ) -> Tuple[Optional[SparkDataFrame], Optional[int]]:
489
+ """
490
+ Get embeddings from model
491
+
492
+ :param ids: ids to get embeddings for, a Spark DataFrame containing user_idx or item_idx
493
+ :param features: user or item features
494
+ :return: DataFrame with biases and embeddings, and vector size
495
+ """
496
+
497
+ self.logger.info(
498
+ "get_features method is not defined for the model %s. Features will not be returned.",
499
+ str(self),
500
+ )
501
+ return None, None
502
+
503
+ def _get_nearest_items_wrap(
504
+ self,
505
+ items: Union[SparkDataFrame, Iterable],
506
+ k: int,
507
+ metric: Optional[str] = "cosine_similarity",
508
+ candidates: Optional[Union[SparkDataFrame, Iterable]] = None,
509
+ ) -> Optional[SparkDataFrame]:
510
+ """
511
+ Convert indexes and leave top-k nearest items for each item in `items`.
512
+ """
513
+ items = get_unique_entities(items, "item_idx")
514
+ if candidates is not None:
515
+ candidates = get_unique_entities(candidates, "item_idx")
516
+
517
+ nearest_items_to_filter = self._get_nearest_items(
518
+ items=items,
519
+ metric=metric,
520
+ candidates=candidates,
521
+ )
522
+
523
+ rel_col_name = metric if metric is not None else "similarity"
524
+ nearest_items = get_top_k(
525
+ dataframe=nearest_items_to_filter,
526
+ partition_by_col=sf.col("item_idx_one"),
527
+ order_by_col=[
528
+ sf.col(rel_col_name).desc(),
529
+ sf.col("item_idx_two").desc(),
530
+ ],
531
+ k=k,
532
+ )
533
+
534
+ nearest_items = nearest_items.withColumnRenamed("item_idx_two", "neighbour_item_idx")
535
+ nearest_items = nearest_items.withColumnRenamed("item_idx_one", "item_idx")
536
+ return nearest_items
537
+
538
+ def _get_nearest_items(
539
+ self,
540
+ items: SparkDataFrame, # noqa: ARG002
541
+ metric: Optional[str] = None, # noqa: ARG002
542
+ candidates: Optional[SparkDataFrame] = None, # noqa: ARG002
543
+ ) -> Optional[SparkDataFrame]:
544
+ msg = f"item-to-item prediction is not implemented for {self}"
545
+ raise NotImplementedError(msg)
546
+
547
+ def _save_model(self, path: str):
548
+ pass
549
+
550
+ def _load_model(self, path: str):
551
+ pass
552
+
553
+
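To make the `_fit`/`_predict` contract above concrete, here is a minimal illustrative sketch (not part of the diff); the class name ConstantRec and its constant-relevance logic are hypothetical.

# Hypothetical sketch, not part of the package: a minimal BaseRecommender
# subclass that scores every candidate item for every user with relevance 1.0.
from typing import Optional

from pyspark.sql import functions as sf

from replay.experimental.models.base_rec import BaseRecommender
from replay.utils import SparkDataFrame


class ConstantRec(BaseRecommender):
    def _fit(
        self,
        log: SparkDataFrame,
        user_features: Optional[SparkDataFrame] = None,
        item_features: Optional[SparkDataFrame] = None,
    ) -> None:
        # nothing to learn; fit_users/fit_items were already set by _fit_wrap
        pass

    def _predict(
        self,
        log: SparkDataFrame,
        k: int,
        users: SparkDataFrame,
        items: SparkDataFrame,
        user_features: Optional[SparkDataFrame] = None,
        item_features: Optional[SparkDataFrame] = None,
        filter_seen_items: bool = True,
    ) -> SparkDataFrame:
        # every (user, item) pair gets the same score; _predict_wrap later
        # filters seen items and keeps the top-k rows per user
        return users.crossJoin(items).withColumn("relevance", sf.lit(1.0))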
554
+ class ItemVectorModel(BaseRecommender):
555
+ """Parent for models generating items' vector representations"""
556
+
557
+ can_predict_item_to_item: bool = True
558
+ item_to_item_metrics: List[str] = [
559
+ "euclidean_distance_sim",
560
+ "cosine_similarity",
561
+ "dot_product",
562
+ ]
563
+
564
+ @abstractmethod
565
+ def _get_item_vectors(self) -> SparkDataFrame:
566
+ """
567
+ Return dataframe with items' vectors as a
568
+ spark dataframe with columns ``[item_idx, item_vector]``
569
+ """
570
+
571
+ def get_nearest_items(
572
+ self,
573
+ items: Union[SparkDataFrame, Iterable],
574
+ k: int,
575
+ metric: Optional[str] = "cosine_similarity",
576
+ candidates: Optional[Union[SparkDataFrame, Iterable]] = None,
577
+ ) -> Optional[SparkDataFrame]:
578
+ """
579
+ Get the k most similar items by the `metric` for each of the `items`.
580
+
581
+ :param items: spark dataframe or list of item ids to find neighbors
582
+ :param k: number of neighbors
583
+ :param metric: 'euclidean_distance_sim', 'cosine_similarity', 'dot_product'
584
+ :param candidates: spark dataframe or list of items
585
+ to consider as similar, e.g. popular/new items. If None,
586
+ all items presented during model training are used.
587
+ :return: dataframe with the most similar items,
588
+ where bigger value means greater similarity.
589
+ spark-dataframe with columns ``[item_idx, neighbour_item_idx, similarity]``
590
+ """
591
+ if metric not in self.item_to_item_metrics:
592
+ msg = f"Select one of the valid distance metrics: {self.item_to_item_metrics}"
593
+ raise ValueError(msg)
594
+
595
+ return self._get_nearest_items_wrap(
596
+ items=items,
597
+ k=k,
598
+ metric=metric,
599
+ candidates=candidates,
600
+ )
601
+
602
+ def _get_nearest_items(
603
+ self,
604
+ items: SparkDataFrame,
605
+ metric: str = "cosine_similarity",
606
+ candidates: Optional[SparkDataFrame] = None,
607
+ ) -> SparkDataFrame:
608
+ """
609
+ Return distance metric value for all available close items filtered by `candidates`.
610
+
611
+ :param items: ids to find neighbours, spark dataframe with column ``item_idx``
612
+ :param metric: 'euclidean_distance_sim' calculated as 1/(1 + euclidean_distance),
613
+ 'cosine_similarity', 'dot_product'
614
+ :param candidates: items among which we are looking for similar,
615
+ e.g. popular/new items. If None, all items presented during model training are used.
616
+ :return: dataframe with neighbours,
617
+ spark-dataframe with columns ``[item_idx_one, item_idx_two, similarity]``
618
+ """
619
+ dist_function = cosine_similarity
620
+ if metric == "euclidean_distance_sim":
621
+ dist_function = vector_euclidean_distance_similarity
622
+ elif metric == "dot_product":
623
+ dist_function = vector_dot
624
+
625
+ items_vectors = self._get_item_vectors()
626
+ left_part = (
627
+ items_vectors.withColumnRenamed("item_idx", "item_idx_one")
628
+ .withColumnRenamed("item_vector", "item_vector_one")
629
+ .join(
630
+ items.select(sf.col("item_idx").alias("item_idx_one")),
631
+ on="item_idx_one",
632
+ )
633
+ )
634
+
635
+ right_part = items_vectors.withColumnRenamed("item_idx", "item_idx_two").withColumnRenamed(
636
+ "item_vector", "item_vector_two"
637
+ )
638
+
639
+ if candidates is not None:
640
+ right_part = right_part.join(
641
+ candidates.withColumnRenamed("item_idx", "item_idx_two"),
642
+ on="item_idx_two",
643
+ )
644
+
645
+ joined_factors = left_part.join(right_part, on=sf.col("item_idx_one") != sf.col("item_idx_two"))
646
+
647
+ joined_factors = joined_factors.withColumn(
648
+ metric,
649
+ dist_function(sf.col("item_vector_one"), sf.col("item_vector_two")),
650
+ )
651
+
652
+ similarity_matrix = joined_factors.select("item_idx_one", "item_idx_two", metric)
653
+
654
+ return similarity_matrix
655
+
656
+
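As a usage illustration (not part of the diff), querying item neighbours on a fitted ItemVectorModel subclass could look like the sketch below; `model` is a hypothetical, already fitted instance.

# Hypothetical sketch: `model` is some fitted ItemVectorModel subclass.
neighbours = model.get_nearest_items(
    items=[10, 42, 7],           # item_idx values to find neighbours for
    k=5,                         # number of neighbours per item
    metric="cosine_similarity",  # or "euclidean_distance_sim" / "dot_product"
)
# Spark DataFrame with columns [item_idx, neighbour_item_idx, cosine_similarity],
# where a larger value means a more similar item.
neighbours.show()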
657
+ class HybridRecommender(BaseRecommender, ABC):
658
+ """Base class for models that can use extra features"""
659
+
660
+ def fit(
661
+ self,
662
+ log: SparkDataFrame,
663
+ user_features: Optional[SparkDataFrame] = None,
664
+ item_features: Optional[SparkDataFrame] = None,
665
+ ) -> None:
666
+ """
667
+ Fit a recommendation model
668
+
669
+ :param log: historical log of interactions
670
+ ``[user_idx, item_idx, timestamp, relevance]``
671
+ :param user_features: user features
672
+ ``[user_idx, timestamp]`` + feature columns
673
+ :param item_features: item features
674
+ ``[item_idx, timestamp]`` + feature columns
675
+ :return:
676
+ """
677
+ self._fit_wrap(
678
+ log=log,
679
+ user_features=user_features,
680
+ item_features=item_features,
681
+ )
682
+
683
+ def predict(
684
+ self,
685
+ log: SparkDataFrame,
686
+ k: int,
687
+ users: Optional[Union[SparkDataFrame, Iterable]] = None,
688
+ items: Optional[Union[SparkDataFrame, Iterable]] = None,
689
+ user_features: Optional[SparkDataFrame] = None,
690
+ item_features: Optional[SparkDataFrame] = None,
691
+ filter_seen_items: bool = True,
692
+ recs_file_path: Optional[str] = None,
693
+ ) -> Optional[SparkDataFrame]:
694
+ """
695
+ Get recommendations
696
+
697
+ :param log: historical log of interactions
698
+ ``[user_idx, item_idx, timestamp, relevance]``
699
+ :param k: number of recommendations for each user
700
+ :param users: users to create recommendations for
701
+ dataframe containing ``[user_idx]`` or ``array-like``;
702
+ if ``None``, recommend to all users from ``log``
703
+ :param items: candidate items for recommendations
704
+ dataframe containing ``[item_idx]`` or ``array-like``;
705
+ if ``None``, take all items from ``log``.
706
+ If it contains new items, ``relevance`` for them will be ``0``.
707
+ :param user_features: user features
708
+ ``[user_idx , timestamp]`` + feature columns
709
+ :param item_features: item features
710
+ ``[item_idx , timestamp]`` + feature columns
711
+ :param filter_seen_items: flag to remove seen items from recommendations based on ``log``.
712
+ :param recs_file_path: save recommendations at the given absolute path as parquet file.
713
+ If None, cached and materialized recommendations dataframe will be returned
714
+ :return: cached recommendation dataframe with columns ``[user_idx, item_idx, relevance]``
715
+ or None if `recs_file_path` is provided
716
+
717
+ """
718
+ return self._predict_wrap(
719
+ log=log,
720
+ k=k,
721
+ users=users,
722
+ items=items,
723
+ user_features=user_features,
724
+ item_features=item_features,
725
+ filter_seen_items=filter_seen_items,
726
+ recs_file_path=recs_file_path,
727
+ )
728
+
729
+ def fit_predict(
730
+ self,
731
+ log: SparkDataFrame,
732
+ k: int,
733
+ users: Optional[Union[SparkDataFrame, Iterable]] = None,
734
+ items: Optional[Union[SparkDataFrame, Iterable]] = None,
735
+ user_features: Optional[SparkDataFrame] = None,
736
+ item_features: Optional[SparkDataFrame] = None,
737
+ filter_seen_items: bool = True,
738
+ recs_file_path: Optional[str] = None,
739
+ ) -> Optional[SparkDataFrame]:
740
+ """
741
+ Fit model and get recommendations
742
+
743
+ :param log: historical log of interactions
744
+ ``[user_idx, item_idx, timestamp, relevance]``
745
+ :param k: number of recommendations for each user
746
+ :param users: users to create recommendations for
747
+ dataframe containing ``[user_idx]`` or ``array-like``;
748
+ if ``None``, recommend to all users from ``log``
749
+ :param items: candidate items for recommendations
750
+ dataframe containing ``[item_idx]`` or ``array-like``;
751
+ if ``None``, take all items from ``log``.
752
+ If it contains new items, ``relevance`` for them will be ``0``.
753
+ :param user_features: user features
754
+ ``[user_idx , timestamp]`` + feature columns
755
+ :param item_features: item features
756
+ ``[item_idx , timestamp]`` + feature columns
757
+ :param filter_seen_items: flag to remove seen items from recommendations based on ``log``.
758
+ :param recs_file_path: save recommendations at the given absolute path as parquet file.
759
+ If None, cached and materialized recommendations dataframe will be returned
760
+ :return: cached recommendation dataframe with columns ``[user_idx, item_idx, relevance]``
761
+ or None if `recs_file_path` is provided
762
+ """
763
+ return self._fit_predict(
764
+ log=log,
765
+ k=k,
766
+ users=users,
767
+ items=items,
768
+ user_features=user_features,
769
+ item_features=item_features,
770
+ filter_seen_items=filter_seen_items,
771
+ recs_file_path=recs_file_path,
772
+ )
773
+
774
+ def predict_pairs(
775
+ self,
776
+ pairs: SparkDataFrame,
777
+ log: Optional[SparkDataFrame] = None,
778
+ user_features: Optional[SparkDataFrame] = None,
779
+ item_features: Optional[SparkDataFrame] = None,
780
+ recs_file_path: Optional[str] = None,
781
+ k: Optional[int] = None,
782
+ ) -> Optional[SparkDataFrame]:
783
+ """
784
+ Get recommendations for specific user-item ``pairs``.
785
+ If a model can't produce a recommendation
786
+ for a specific pair, that pair is removed from the resulting dataframe.
787
+
788
+ :param pairs: dataframe with pairs to calculate relevance for, ``[user_idx, item_idx]``.
789
+ :param log: historical log of interactions
790
+ ``[user_idx, item_idx, timestamp, relevance]``
791
+ :param user_features: user features
792
+ ``[user_idx , timestamp]`` + feature columns
793
+ :param item_features: item features
794
+ ``[item_idx , timestamp]`` + feature columns
795
+ :param recs_file_path: save recommendations at the given absolute path as parquet file.
796
+ If None, cached and materialized recommendations dataframe will be returned
797
+ :param k: top-k items for each user from pairs.
798
+ :return: cached recommendation dataframe with columns ``[user_idx, item_idx, relevance]``
799
+ or None if `recs_file_path` is provided
800
+ """
801
+ return self._predict_pairs_wrap(
802
+ pairs=pairs,
803
+ log=log,
804
+ user_features=user_features,
805
+ item_features=item_features,
806
+ recs_file_path=recs_file_path,
807
+ k=k,
808
+ )
809
+
810
+ def get_features(
811
+ self, ids: SparkDataFrame, features: Optional[SparkDataFrame]
812
+ ) -> Optional[Tuple[SparkDataFrame, int]]:
813
+ """
814
+ Returns user or item feature vectors as a Column with type ArrayType
815
+ :param ids: Spark DataFrame with unique ids
816
+ :param features: Spark DataFrame with features for provided ids
817
+ :return: feature vectors
818
+ If a model does not have a vector for some ids they are not present in the final result.
819
+ """
820
+ return self._get_features_wrap(ids, features)
821
+
822
+
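As an illustration of the predict_pairs contract (not part of the diff), the sketch below scores explicit user-item pairs; `spark`, `model`, `log`, `user_features`, and `item_features` are hypothetical objects assumed to exist already.

# Hypothetical sketch: score explicit (user, item) pairs with a fitted
# HybridRecommender subclass; all names below are placeholders created elsewhere.
pairs = spark.createDataFrame(
    [(0, 3), (0, 5), (1, 3)],
    schema=["user_idx", "item_idx"],  # exactly these two columns are required
)
scored = model.predict_pairs(
    pairs,
    log=log,
    user_features=user_features,
    item_features=item_features,
    k=1,  # optionally keep only the top-1 pair per user
)
# scored: Spark DataFrame [user_idx, item_idx, relevance]; pairs the model
# cannot score are dropped from the result.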
823
+ class Recommender(BaseRecommender, ABC):
824
+ """Usual recommender class for models without features."""
825
+
826
+ def fit(self, log: SparkDataFrame) -> None:
827
+ """
828
+ Fit a recommendation model
829
+
830
+ :param log: historical log of interactions
831
+ ``[user_idx, item_idx, timestamp, relevance]``
832
+ :return:
833
+ """
834
+ self._fit_wrap(
835
+ log=log,
836
+ user_features=None,
837
+ item_features=None,
838
+ )
839
+
840
+ def predict(
841
+ self,
842
+ log: SparkDataFrame,
843
+ k: int,
844
+ users: Optional[Union[SparkDataFrame, Iterable]] = None,
845
+ items: Optional[Union[SparkDataFrame, Iterable]] = None,
846
+ filter_seen_items: bool = True,
847
+ recs_file_path: Optional[str] = None,
848
+ ) -> Optional[SparkDataFrame]:
849
+ """
850
+ Get recommendations
851
+
852
+ :param log: historical log of interactions
853
+ ``[user_idx, item_idx, timestamp, relevance]``
854
+ :param k: number of recommendations for each user
855
+ :param users: users to create recommendations for
856
+ dataframe containing ``[user_idx]`` or ``array-like``;
857
+ if ``None``, recommend to all users from ``log``
858
+ :param items: candidate items for recommendations
859
+ dataframe containing ``[item_idx]`` or ``array-like``;
860
+ if ``None``, take all items from ``log``.
861
+ If it contains new items, ``relevance`` for them will be ``0``.
862
+ :param filter_seen_items: flag to remove seen items from recommendations based on ``log``.
863
+ :param recs_file_path: save recommendations at the given absolute path as parquet file.
864
+ If None, cached and materialized recommendations dataframe will be returned
865
+ :return: cached recommendation dataframe with columns ``[user_idx, item_idx, relevance]``
866
+ or None if `recs_file_path` is provided
867
+ """
868
+ return self._predict_wrap(
869
+ log=log,
870
+ k=k,
871
+ users=users,
872
+ items=items,
873
+ user_features=None,
874
+ item_features=None,
875
+ filter_seen_items=filter_seen_items,
876
+ recs_file_path=recs_file_path,
877
+ )
878
+
879
+ def predict_pairs(
880
+ self,
881
+ pairs: SparkDataFrame,
882
+ log: Optional[SparkDataFrame] = None,
883
+ recs_file_path: Optional[str] = None,
884
+ k: Optional[int] = None,
885
+ ) -> Optional[SparkDataFrame]:
886
+ """
887
+ Get recommendations for specific user-item ``pairs``.
888
+ If a model can't produce a recommendation
889
+ for a specific pair, that pair is removed from the resulting dataframe.
890
+
891
+ :param pairs: dataframe with pairs to calculate relevance for, ``[user_idx, item_idx]``.
892
+ :param log: historical log of interactions
893
+ ``[user_idx, item_idx, timestamp, relevance]``
894
+ :param recs_file_path: save recommendations at the given absolute path as parquet file.
895
+ If None, cached and materialized recommendations dataframe will be returned
896
+ :param k: top-k items for each user from pairs.
897
+ :return: cached recommendation dataframe with columns ``[user_idx, item_idx, relevance]``
898
+ or None if `recs_file_path` is provided
899
+ """
900
+ return self._predict_pairs_wrap(
901
+ pairs=pairs,
902
+ log=log,
903
+ recs_file_path=recs_file_path,
904
+ k=k,
905
+ )
906
+
907
+ def fit_predict(
908
+ self,
909
+ log: SparkDataFrame,
910
+ k: int,
911
+ users: Optional[Union[SparkDataFrame, Iterable]] = None,
912
+ items: Optional[Union[SparkDataFrame, Iterable]] = None,
913
+ filter_seen_items: bool = True,
914
+ recs_file_path: Optional[str] = None,
915
+ ) -> Optional[SparkDataFrame]:
916
+ """
917
+ Fit model and get recommendations
918
+
919
+ :param log: historical log of interactions
920
+ ``[user_idx, item_idx, timestamp, relevance]``
921
+ :param k: number of recommendations for each user
922
+ :param users: users to create recommendations for
923
+ dataframe containing ``[user_idx]`` or ``array-like``;
924
+ if ``None``, recommend to all users from ``log``
925
+ :param items: candidate items for recommendations
926
+ dataframe containing ``[item_idx]`` or ``array-like``;
927
+ if ``None``, take all items from ``log``.
928
+ If it contains new items, ``relevance`` for them will be ``0``.
929
+ :param filter_seen_items: flag to remove seen items from recommendations based on ``log``.
930
+ :param recs_file_path: save recommendations at the given absolute path as parquet file.
931
+ If None, cached and materialized recommendations dataframe will be returned
932
+ :return: cached recommendation dataframe with columns ``[user_idx, item_idx, relevance]``
933
+ or None if `file_path` is provided
934
+ """
935
+ return self._fit_predict(
936
+ log=log,
937
+ k=k,
938
+ users=users,
939
+ items=items,
940
+ user_features=None,
941
+ item_features=None,
942
+ filter_seen_items=filter_seen_items,
943
+ recs_file_path=recs_file_path,
944
+ )
945
+
946
+ def get_features(self, ids: SparkDataFrame) -> Optional[Tuple[SparkDataFrame, int]]:
947
+ """
948
+ Returns user or item feature vectors as a Column with type ArrayType
949
+
950
+ :param ids: Spark DataFrame with unique ids
951
+ :return: feature vectors.
952
+ If a model does not have a vector for some ids they are not present in the final result.
953
+ """
954
+ return self._get_features_wrap(ids, None)
955
+
956
+
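For the feature-free interface, a minimal end-to-end sketch (not part of the diff) might look as follows; `SomeRecommender` stands in for any concrete Recommender subclass and is hypothetical.

# Hypothetical sketch: the feature-free Recommender workflow.
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
log = spark.createDataFrame(
    [(0, 1, 0, 1.0), (0, 2, 1, 1.0), (1, 2, 2, 1.0)],
    schema=["user_idx", "item_idx", "timestamp", "relevance"],
)

model = SomeRecommender()       # placeholder for a concrete subclass
model.fit(log)
recs = model.predict(log, k=2)  # [user_idx, item_idx, relevance];
                                # fit_predict(log, k=2) does both steps in one call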
957
+ class UserRecommender(BaseRecommender, ABC):
958
+ """Base class for models that use user features
959
+ but not item features. ``log`` is not required for this class."""
960
+
961
+ def fit(
962
+ self,
963
+ log: SparkDataFrame,
964
+ user_features: SparkDataFrame,
965
+ ) -> None:
966
+ """
967
+ Finds user clusters and calculates item similarity within those clusters.
968
+
969
+ :param log: historical log of interactions
970
+ ``[user_idx, item_idx, timestamp, relevance]``
971
+ :param user_features: user features
972
+ ``[user_idx, timestamp]`` + feature columns
973
+ :return:
974
+ """
975
+ self._fit_wrap(log=log, user_features=user_features)
976
+
977
+ def predict(
978
+ self,
979
+ user_features: SparkDataFrame,
980
+ k: int,
981
+ log: Optional[SparkDataFrame] = None,
982
+ users: Optional[Union[SparkDataFrame, Iterable]] = None,
983
+ items: Optional[Union[SparkDataFrame, Iterable]] = None,
984
+ filter_seen_items: bool = True,
985
+ recs_file_path: Optional[str] = None,
986
+ ) -> Optional[SparkDataFrame]:
987
+ """
988
+ Get recommendations
989
+
990
+ :param log: historical log of interactions
991
+ ``[user_idx, item_idx, timestamp, relevance]``
992
+ :param k: number of recommendations for each user
993
+ :param users: users to create recommendations for
994
+ dataframe containing ``[user_idx]`` or ``array-like``;
995
+ if ``None``, recommend to all users from ``log``
996
+ :param items: candidate items for recommendations
997
+ dataframe containing ``[item_idx]`` or ``array-like``;
998
+ if ``None``, take all items from ``log``.
999
+ If it contains new items, ``relevance`` for them will be ``0``.
1000
+ :param user_features: user features
1001
+ ``[user_idx , timestamp]`` + feature columns
1002
+ :param filter_seen_items: flag to remove seen items from recommendations based on ``log``.
1003
+ :param recs_file_path: save recommendations at the given absolute path as parquet file.
1004
+ If None, cached and materialized recommendations dataframe will be returned
1005
+ :return: cached recommendation dataframe with columns ``[user_idx, item_idx, relevance]``
1006
+ or None if `recs_file_path` is provided
1007
+ """
1008
+ return self._predict_wrap(
1009
+ log=log,
1010
+ user_features=user_features,
1011
+ k=k,
1012
+ filter_seen_items=filter_seen_items,
1013
+ users=users,
1014
+ items=items,
1015
+ recs_file_path=recs_file_path,
1016
+ )
1017
+
1018
+ def predict_pairs(
1019
+ self,
1020
+ pairs: SparkDataFrame,
1021
+ user_features: SparkDataFrame,
1022
+ log: Optional[SparkDataFrame] = None,
1023
+ recs_file_path: Optional[str] = None,
1024
+ k: Optional[int] = None,
1025
+ ) -> Optional[SparkDataFrame]:
1026
+ """
1027
+ Get recommendations for specific user-item ``pairs``.
1028
+ If a model can't produce a recommendation
1029
+ for a specific pair, that pair is removed from the resulting dataframe.
1030
+
1031
+ :param pairs: dataframe with pairs to calculate relevance for, ``[user_idx, item_idx]``.
1032
+ :param user_features: user features
1033
+ ``[user_idx , timestamp]`` + feature columns
1034
+ :param log: historical log of interactions
1035
+ ``[user_idx, item_idx, timestamp, relevance]``
1036
+ :param recs_file_path: save recommendations at the given absolute path as parquet file.
1037
+ If None, cached and materialized recommendations dataframe will be returned
1038
+ :param k: top-k items for each user from pairs.
1039
+ :return: cached recommendation dataframe with columns ``[user_idx, item_idx, relevance]``
1040
+ or None if `recs_file_path` is provided
1041
+ """
1042
+ return self._predict_pairs_wrap(
1043
+ pairs=pairs,
1044
+ log=log,
1045
+ user_features=user_features,
1046
+ recs_file_path=recs_file_path,
1047
+ k=k,
1048
+ )
1049
+
1050
+
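The distinguishing property of UserRecommender is that prediction can be driven by user features alone; a hypothetical sketch (not part of the diff), with `model` and `user_features` assumed to exist.

# Hypothetical sketch: `model` is a fitted UserRecommender subclass and
# `user_features` is a Spark DataFrame with a user_idx column plus feature columns.
recs = model.predict(
    user_features=user_features,
    k=10,
    log=None,                  # the interaction log is optional here
    filter_seen_items=False,   # nothing to filter without a log
)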
1051
+ class NonPersonalizedRecommender(Recommender, ABC):
1052
+ """Base class for non-personalized recommenders with popularity statistics."""
1053
+
1054
+ can_predict_cold_users = True
1055
+ can_predict_cold_items = True
1056
+ item_popularity: SparkDataFrame
1057
+ add_cold_items: bool
1058
+ cold_weight: float
1059
+ sample: bool
1060
+ fill: float
1061
+ seed: Optional[int] = None
1062
+
1063
+ def __init__(self, add_cold_items: bool, cold_weight: float):
1064
+ self.add_cold_items = add_cold_items
1065
+ if 0 < cold_weight <= 1:
1066
+ self.cold_weight = cold_weight
1067
+ else:
1068
+ msg = "`cold_weight` value should be in interval (0, 1]"
1069
+ raise ValueError(msg)
1070
+
1071
+ @property
1072
+ def _dataframes(self):
1073
+ return {"item_popularity": self.item_popularity}
1074
+
1075
+ def _save_model(self, path: str):
1076
+ save_picklable_to_parquet(self.fill, join(path, "params.dump"))
1077
+
1078
+ def _load_model(self, path: str):
1079
+ self.fill = load_pickled_from_parquet(join(path, "params.dump"))
1080
+
1081
+ def _clear_cache(self):
1082
+ if hasattr(self, "item_popularity"):
1083
+ self.item_popularity.unpersist()
1084
+
1085
+ @staticmethod
1086
+ def _calc_fill(item_popularity: SparkDataFrame, weight: float) -> float:
1087
+ """
1088
+ Calculate the fill value as the minimal relevance
1089
+ observed during model training multiplied by `weight`.
1090
+ """
1091
+ return item_popularity.select(sf.min("relevance")).first()[0] * weight
1092
+
1093
+ @staticmethod
1094
+ def _check_relevance(log: SparkDataFrame):
1095
+ vals = log.select("relevance").where((sf.col("relevance") != 1) & (sf.col("relevance") != 0))
1096
+ if vals.count() > 0:
1097
+ msg = "Relevance values in log must be 0 or 1"
1098
+ raise ValueError(msg)
1099
+
1100
+ def _get_selected_item_popularity(self, items: SparkDataFrame) -> SparkDataFrame:
1101
+ """
1102
+ Choose only the required items from the `item_popularity` dataframe
1103
+ for further recommendation generation.
1104
+ """
1105
+ return self.item_popularity.join(
1106
+ items,
1107
+ on="item_idx",
1108
+ how="right" if self.add_cold_items else "inner",
1109
+ ).fillna(value=self.fill, subset=["relevance"])
1110
+
1111
+ @staticmethod
1112
+ def _calc_max_hist_len(log: SparkDataFrame, users: SparkDataFrame) -> int:
1113
+ max_hist_len = (
1114
+ (log.join(users, on="user_idx").groupBy("user_idx").agg(sf.countDistinct("item_idx").alias("items_count")))
1115
+ .select(sf.max("items_count"))
1116
+ .first()[0]
1117
+ )
1118
+ # all users have empty history
1119
+ if max_hist_len is None:
1120
+ max_hist_len = 0
1121
+
1122
+ return max_hist_len
1123
+
1124
+ def _predict_without_sampling(
1125
+ self,
1126
+ log: SparkDataFrame,
1127
+ k: int,
1128
+ users: SparkDataFrame,
1129
+ items: SparkDataFrame,
1130
+ filter_seen_items: bool = True,
1131
+ ) -> SparkDataFrame:
1132
+ """
1133
+ Regular prediction for popularity-based models:
1134
+ the top-k most relevant items from `items` are chosen for each user
1135
+ """
1136
+ selected_item_popularity = self._get_selected_item_popularity(items)
1137
+ selected_item_popularity = selected_item_popularity.withColumn(
1138
+ "rank",
1139
+ sf.row_number().over(Window.orderBy(sf.col("relevance").desc(), sf.col("item_idx").desc())),
1140
+ )
1141
+
1142
+ if filter_seen_items and log is not None:
1143
+ user_to_num_items = (
1144
+ log.join(users, on="user_idx").groupBy("user_idx").agg(sf.countDistinct("item_idx").alias("num_items"))
1145
+ )
1146
+ users = users.join(user_to_num_items, on="user_idx", how="left")
1147
+ users = users.fillna(0, "num_items")
1148
+ # truncate 'selected_item_popularity' to the first k + max_seen rows
1149
+ max_seen = users.select(sf.coalesce(sf.max("num_items"), sf.lit(0))).first()[0]
1150
+ selected_item_popularity = selected_item_popularity.filter(sf.col("rank") <= k + max_seen)
1151
+ return users.join(selected_item_popularity, on=(sf.col("rank") <= k + sf.col("num_items")), how="left")
1152
+
1153
+ return users.crossJoin(selected_item_popularity.filter(sf.col("rank") <= k)).drop("rank")
1154
+
1155
+ def _predict_with_sampling(
1156
+ self,
1157
+ log: SparkDataFrame,
1158
+ k: int,
1159
+ users: SparkDataFrame,
1160
+ items: SparkDataFrame,
1161
+ filter_seen_items: bool = True,
1162
+ ) -> SparkDataFrame:
1163
+ """
1164
+ Randomized prediction for popularity-based models:
1165
+ top-k items from `items` are sampled for each user with
1166
+ probability proportional to the items' popularity
1167
+ """
1168
+ selected_item_popularity = self._get_selected_item_popularity(items)
1169
+ selected_item_popularity = selected_item_popularity.withColumn(
1170
+ "relevance",
1171
+ sf.when(sf.col("relevance") == sf.lit(0.0), 0.1**6).otherwise(sf.col("relevance")),
1172
+ )
1173
+
1174
+ items_pd = selected_item_popularity.withColumn(
1175
+ "probability",
1176
+ sf.col("relevance") / selected_item_popularity.select(sf.sum("relevance")).first()[0],
1177
+ ).toPandas()
1178
+
1179
+ rec_schema = get_schema(
1180
+ query_column="user_idx",
1181
+ item_column="item_idx",
1182
+ rating_column="relevance",
1183
+ has_timestamp=False,
1184
+ )
1185
+ if items_pd.shape[0] == 0:
1186
+ return State().session.createDataFrame([], rec_schema)
1187
+
1188
+ seed = self.seed
1189
+ class_name = self.__class__.__name__
1190
+
1191
+ def grouped_map(pandas_df: PandasDataFrame) -> PandasDataFrame:
1192
+ user_idx = pandas_df["user_idx"][0]
1193
+ cnt = pandas_df["cnt"][0]
1194
+
1195
+ local_rng = default_rng(seed + user_idx) if seed is not None else default_rng()
1196
+
1197
+ items_positions = local_rng.choice(
1198
+ np.arange(items_pd.shape[0]),
1199
+ size=cnt,
1200
+ p=items_pd["probability"].values,
1201
+ replace=False,
1202
+ )
1203
+
1204
+ # workaround to unify RandomRec and UCB
1205
+ if class_name == "RandomRec":
1206
+ relevance = 1 / np.arange(1, cnt + 1)
1207
+ else:
1208
+ relevance = items_pd["probability"].values[items_positions]
1209
+
1210
+ return PandasDataFrame(
1211
+ {
1212
+ "user_idx": cnt * [user_idx],
1213
+ "item_idx": items_pd["item_idx"].values[items_positions],
1214
+ "relevance": relevance,
1215
+ }
1216
+ )
1217
+
1218
+ if log is not None and filter_seen_items:
1219
+ recs = (
1220
+ log.select("user_idx", "item_idx")
1221
+ .distinct()
1222
+ .join(users, how="right", on="user_idx")
1223
+ .groupby("user_idx")
1224
+ .agg(sf.countDistinct("item_idx").alias("cnt"))
1225
+ .selectExpr(
1226
+ "user_idx",
1227
+ f"LEAST(cnt + {k}, {items_pd.shape[0]}) AS cnt",
1228
+ )
1229
+ )
1230
+ else:
1231
+ recs = users.withColumn("cnt", sf.lit(min(k, items_pd.shape[0])))
1232
+
1233
+ return recs.groupby("user_idx").applyInPandas(grouped_map, rec_schema)
1234
+
1235
+ def _predict(
1236
+ self,
1237
+ log: SparkDataFrame,
1238
+ k: int,
1239
+ users: SparkDataFrame,
1240
+ items: SparkDataFrame,
1241
+ user_features: Optional[SparkDataFrame] = None, # noqa: ARG002
1242
+ item_features: Optional[SparkDataFrame] = None, # noqa: ARG002
1243
+ filter_seen_items: bool = True,
1244
+ ) -> SparkDataFrame:
1245
+ if self.sample:
1246
+ return self._predict_with_sampling(
1247
+ log=log,
1248
+ k=k,
1249
+ users=users,
1250
+ items=items,
1251
+ filter_seen_items=filter_seen_items,
1252
+ )
1253
+ else:
1254
+ return self._predict_without_sampling(log, k, users, items, filter_seen_items)
1255
+
1256
+ def _predict_pairs(
1257
+ self,
1258
+ pairs: SparkDataFrame,
1259
+ log: Optional[SparkDataFrame] = None, # noqa: ARG002
1260
+ user_features: Optional[SparkDataFrame] = None, # noqa: ARG002
1261
+ item_features: Optional[SparkDataFrame] = None, # noqa: ARG002
1262
+ ) -> SparkDataFrame:
1263
+ return (
1264
+ pairs.join(
1265
+ self.item_popularity,
1266
+ on="item_idx",
1267
+ how="left" if self.add_cold_items else "inner",
1268
+ )
1269
+ .fillna(value=self.fill, subset=["relevance"])
1270
+ .select("user_idx", "item_idx", "relevance")
1271
+ )
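To illustrate the cold-item handling above (not part of the diff): the fill relevance assigned to cold items is the minimal training relevance scaled by cold_weight, as in this sketch with made-up numbers.

# Hypothetical numbers mirroring _calc_fill above.
min_relevance = 0.2                  # smallest relevance seen during training
cold_weight = 0.5                    # constructor argument, must be in (0, 1]
fill = min_relevance * cold_weight   # cold items get relevance 0.1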