replay-rec 0.19.0rc0__py3-none-any.whl → 0.20.0__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. replay/__init__.py +6 -2
  2. replay/data/dataset.py +9 -9
  3. replay/data/nn/__init__.py +6 -6
  4. replay/data/nn/sequence_tokenizer.py +44 -38
  5. replay/data/nn/sequential_dataset.py +13 -8
  6. replay/data/nn/torch_sequential_dataset.py +14 -13
  7. replay/data/nn/utils.py +1 -1
  8. replay/metrics/base_metric.py +1 -1
  9. replay/metrics/coverage.py +7 -11
  10. replay/metrics/experiment.py +3 -3
  11. replay/metrics/offline_metrics.py +2 -2
  12. replay/models/__init__.py +19 -0
  13. replay/models/association_rules.py +1 -4
  14. replay/models/base_neighbour_rec.py +6 -9
  15. replay/models/base_rec.py +44 -293
  16. replay/models/cat_pop_rec.py +2 -1
  17. replay/models/common.py +69 -0
  18. replay/models/extensions/ann/ann_mixin.py +30 -25
  19. replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +1 -1
  20. replay/models/extensions/ann/utils.py +4 -3
  21. replay/models/knn.py +18 -17
  22. replay/models/nn/sequential/bert4rec/dataset.py +1 -1
  23. replay/models/nn/sequential/callbacks/prediction_callbacks.py +2 -2
  24. replay/models/nn/sequential/compiled/__init__.py +10 -0
  25. replay/models/nn/sequential/compiled/base_compiled_model.py +3 -1
  26. replay/models/nn/sequential/compiled/bert4rec_compiled.py +11 -2
  27. replay/models/nn/sequential/compiled/sasrec_compiled.py +5 -1
  28. replay/models/nn/sequential/sasrec/dataset.py +1 -1
  29. replay/models/nn/sequential/sasrec/model.py +1 -1
  30. replay/models/optimization/__init__.py +14 -0
  31. replay/models/optimization/optuna_mixin.py +279 -0
  32. replay/{optimization → models/optimization}/optuna_objective.py +13 -15
  33. replay/models/slim.py +2 -4
  34. replay/models/word2vec.py +7 -12
  35. replay/preprocessing/discretizer.py +1 -2
  36. replay/preprocessing/history_based_fp.py +1 -1
  37. replay/preprocessing/label_encoder.py +1 -1
  38. replay/splitters/cold_user_random_splitter.py +13 -7
  39. replay/splitters/last_n_splitter.py +17 -10
  40. replay/utils/__init__.py +6 -2
  41. replay/utils/common.py +4 -2
  42. replay/utils/model_handler.py +11 -31
  43. replay/utils/session_handler.py +2 -2
  44. replay/utils/spark_utils.py +2 -2
  45. replay/utils/types.py +28 -18
  46. replay/utils/warnings.py +26 -0
  47. {replay_rec-0.19.0rc0.dist-info → replay_rec-0.20.0.dist-info}/METADATA +56 -40
  48. replay_rec-0.20.0.dist-info/RECORD +139 -0
  49. {replay_rec-0.19.0rc0.dist-info → replay_rec-0.20.0.dist-info}/WHEEL +1 -1
  50. replay/experimental/__init__.py +0 -0
  51. replay/experimental/metrics/__init__.py +0 -62
  52. replay/experimental/metrics/base_metric.py +0 -602
  53. replay/experimental/metrics/coverage.py +0 -97
  54. replay/experimental/metrics/experiment.py +0 -175
  55. replay/experimental/metrics/hitrate.py +0 -26
  56. replay/experimental/metrics/map.py +0 -30
  57. replay/experimental/metrics/mrr.py +0 -18
  58. replay/experimental/metrics/ncis_precision.py +0 -31
  59. replay/experimental/metrics/ndcg.py +0 -49
  60. replay/experimental/metrics/precision.py +0 -22
  61. replay/experimental/metrics/recall.py +0 -25
  62. replay/experimental/metrics/rocauc.py +0 -49
  63. replay/experimental/metrics/surprisal.py +0 -90
  64. replay/experimental/metrics/unexpectedness.py +0 -76
  65. replay/experimental/models/__init__.py +0 -13
  66. replay/experimental/models/admm_slim.py +0 -205
  67. replay/experimental/models/base_neighbour_rec.py +0 -204
  68. replay/experimental/models/base_rec.py +0 -1340
  69. replay/experimental/models/base_torch_rec.py +0 -234
  70. replay/experimental/models/cql.py +0 -454
  71. replay/experimental/models/ddpg.py +0 -923
  72. replay/experimental/models/dt4rec/__init__.py +0 -0
  73. replay/experimental/models/dt4rec/dt4rec.py +0 -189
  74. replay/experimental/models/dt4rec/gpt1.py +0 -401
  75. replay/experimental/models/dt4rec/trainer.py +0 -127
  76. replay/experimental/models/dt4rec/utils.py +0 -265
  77. replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
  78. replay/experimental/models/extensions/spark_custom_models/als_extension.py +0 -792
  79. replay/experimental/models/hierarchical_recommender.py +0 -331
  80. replay/experimental/models/implicit_wrap.py +0 -131
  81. replay/experimental/models/lightfm_wrap.py +0 -302
  82. replay/experimental/models/mult_vae.py +0 -332
  83. replay/experimental/models/neural_ts.py +0 -986
  84. replay/experimental/models/neuromf.py +0 -406
  85. replay/experimental/models/scala_als.py +0 -296
  86. replay/experimental/models/u_lin_ucb.py +0 -115
  87. replay/experimental/nn/data/__init__.py +0 -1
  88. replay/experimental/nn/data/schema_builder.py +0 -102
  89. replay/experimental/preprocessing/__init__.py +0 -3
  90. replay/experimental/preprocessing/data_preparator.py +0 -839
  91. replay/experimental/preprocessing/padder.py +0 -229
  92. replay/experimental/preprocessing/sequence_generator.py +0 -208
  93. replay/experimental/scenarios/__init__.py +0 -1
  94. replay/experimental/scenarios/obp_wrapper/__init__.py +0 -8
  95. replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -74
  96. replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -261
  97. replay/experimental/scenarios/obp_wrapper/utils.py +0 -87
  98. replay/experimental/scenarios/two_stages/__init__.py +0 -0
  99. replay/experimental/scenarios/two_stages/reranker.py +0 -117
  100. replay/experimental/scenarios/two_stages/two_stages_scenario.py +0 -757
  101. replay/experimental/utils/__init__.py +0 -0
  102. replay/experimental/utils/logger.py +0 -24
  103. replay/experimental/utils/model_handler.py +0 -186
  104. replay/experimental/utils/session_handler.py +0 -44
  105. replay/optimization/__init__.py +0 -5
  106. replay_rec-0.19.0rc0.dist-info/RECORD +0 -191
  107. {replay_rec-0.19.0rc0.dist-info → replay_rec-0.20.0.dist-info/licenses}/LICENSE +0 -0
  108. {replay_rec-0.19.0rc0.dist-info → replay_rec-0.20.0.dist-info/licenses}/NOTICE +0 -0
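
Two structural changes stand out in the list above: the Optuna tuning code moved from replay/optimization into the new replay/models/optimization package (entries 30-32 and 105), and the entire experimental subpackage was dropped (entries 50-104). For code that imported the tuning objective from its old location, a guarded import is one way to work across both versions. The snippet below is a minimal sketch, not part of the diff; it assumes the importable module path mirrors the file layout listed above, so verify it against the installed release.

# Compatibility sketch (assumption: module import paths follow the file moves
# listed above; individual symbols inside optuna_objective are not enumerated here).
try:
    # replay-rec 0.20.0: tuning objective lives under replay.models.optimization
    from replay.models.optimization import optuna_objective
except ImportError:
    # replay-rec 0.19.0rc0: old location (replay/optimization/), removed in 0.20.0
    from replay.optimization import optuna_objective

The expanded hunk that follows corresponds to entry 68, the removal of replay/experimental/models/base_rec.py; its header @@ -1,1340 +0,0 @@ indicates that all 1,340 lines of the old file were deleted and nothing was added.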
@@ -1,1340 +0,0 @@
1
- """
2
- Base abstract classes:
3
- - BaseRecommender - the simplest base class
4
- - Recommender - base class for models that fit on interaction log
5
- - HybridRecommender - base class for models that accept user or item features
6
- - UserRecommender - base class that accepts only user features, but not item features
7
- - NeighbourRec - base class that requires log at prediction time
8
- - ItemVectorModel - class for models which provides items' vectors.
9
- Implements similar items search.
10
- - NonPersonalizedRecommender - base class for non-personalized recommenders
11
- with popularity statistics
12
- """
13
-
14
- from abc import ABC, abstractmethod
15
- from os.path import join
16
- from typing import Any, Iterable, List, Optional, Tuple, Union
17
-
18
- import numpy as np
19
- from numpy.random import default_rng
20
-
21
- from replay.data import get_schema
22
- from replay.experimental.utils.session_handler import State
23
- from replay.models.base_rec import IsSavable, RecommenderCommons
24
- from replay.utils import PYSPARK_AVAILABLE, PandasDataFrame, SparkDataFrame
25
- from replay.utils.spark_utils import (
26
- convert2spark,
27
- cosine_similarity,
28
- filter_cold,
29
- get_top_k,
30
- get_top_k_recs,
31
- get_unique_entities,
32
- load_pickled_from_parquet,
33
- return_recs,
34
- save_picklable_to_parquet,
35
- vector_dot,
36
- vector_euclidean_distance_similarity,
37
- )
38
-
39
- if PYSPARK_AVAILABLE:
40
- from pyspark.sql import (
41
- Window,
42
- functions as sf,
43
- )
44
-
45
-
46
- class BaseRecommender(RecommenderCommons, IsSavable, ABC):
47
- """Base recommender"""
48
-
49
- model: Any
50
- can_predict_cold_users: bool = False
51
- can_predict_cold_items: bool = False
52
- fit_users: SparkDataFrame
53
- fit_items: SparkDataFrame
54
- _num_users: int
55
- _num_items: int
56
- _user_dim_size: int
57
- _item_dim_size: int
58
-
59
- def _fit_wrap(
60
- self,
61
- log: SparkDataFrame,
62
- user_features: Optional[SparkDataFrame] = None,
63
- item_features: Optional[SparkDataFrame] = None,
64
- ) -> None:
65
- """
66
- Wrapper for fit to allow for fewer arguments in a model.
67
-
68
- :param log: historical log of interactions
69
- ``[user_idx, item_idx, timestamp, relevance]``
70
- :param user_features: user features
71
- ``[user_idx, timestamp]`` + feature columns
72
- :param item_features: item features
73
- ``[item_idx, timestamp]`` + feature columns
74
- :return:
75
- """
76
- self.logger.debug("Starting fit %s", type(self).__name__)
77
- if user_features is None:
78
- users = log.select("user_idx").distinct()
79
- else:
80
- users = log.select("user_idx").union(user_features.select("user_idx")).distinct()
81
- if item_features is None:
82
- items = log.select("item_idx").distinct()
83
- else:
84
- items = log.select("item_idx").union(item_features.select("item_idx")).distinct()
85
- self.fit_users = sf.broadcast(users)
86
- self.fit_items = sf.broadcast(items)
87
- self._num_users = self.fit_users.count()
88
- self._num_items = self.fit_items.count()
89
- self._user_dim_size = self.fit_users.agg({"user_idx": "max"}).first()[0] + 1
90
- self._item_dim_size = self.fit_items.agg({"item_idx": "max"}).first()[0] + 1
91
- self._fit(log, user_features, item_features)
92
-
93
- @abstractmethod
94
- def _fit(
95
- self,
96
- log: SparkDataFrame,
97
- user_features: Optional[SparkDataFrame] = None,
98
- item_features: Optional[SparkDataFrame] = None,
99
- ) -> None:
100
- """
101
- Inner method where model actually fits.
102
-
103
- :param log: historical log of interactions
104
- ``[user_idx, item_idx, timestamp, relevance]``
105
- :param user_features: user features
106
- ``[user_idx, timestamp]`` + feature columns
107
- :param item_features: item features
108
- ``[item_idx, timestamp]`` + feature columns
109
- :return:
110
- """
111
-
112
- def _filter_seen(self, recs: SparkDataFrame, log: SparkDataFrame, k: int, users: SparkDataFrame):
113
- """
114
- Filter seen items (presented in log) out of the users' recommendations.
115
- For each user return from `k` to `k + number of seen by user` recommendations.
116
- """
117
- users_log = log.join(users, on="user_idx")
118
- self._cache_model_temp_view(users_log, "filter_seen_users_log")
119
- num_seen = users_log.groupBy("user_idx").agg(sf.count("item_idx").alias("seen_count"))
120
- self._cache_model_temp_view(num_seen, "filter_seen_num_seen")
121
-
122
- # count maximal number of items seen by users
123
- max_seen = 0
124
- if num_seen.count() > 0:
125
- max_seen = num_seen.select(sf.max("seen_count")).first()[0]
126
-
127
- # crop recommendations to first k + max_seen items for each user
128
- recs = recs.withColumn(
129
- "temp_rank",
130
- sf.row_number().over(Window.partitionBy("user_idx").orderBy(sf.col("relevance").desc())),
131
- ).filter(sf.col("temp_rank") <= sf.lit(max_seen + k))
132
-
133
- # leave k + number of items seen by user recommendations in recs
134
- recs = (
135
- recs.join(num_seen, on="user_idx", how="left")
136
- .fillna(0)
137
- .filter(sf.col("temp_rank") <= sf.col("seen_count") + sf.lit(k))
138
- .drop("temp_rank", "seen_count")
139
- )
140
-
141
- # filter recommendations presented in interactions log
142
- recs = recs.join(
143
- users_log.withColumnRenamed("item_idx", "item")
144
- .withColumnRenamed("user_idx", "user")
145
- .select("user", "item"),
146
- on=(sf.col("user_idx") == sf.col("user")) & (sf.col("item_idx") == sf.col("item")),
147
- how="anti",
148
- ).drop("user", "item")
149
-
150
- return recs
151
-
152
- def _filter_log_users_items_dataframes(
153
- self,
154
- log: Optional[SparkDataFrame],
155
- k: int,
156
- users: Optional[Union[SparkDataFrame, Iterable]] = None,
157
- items: Optional[Union[SparkDataFrame, Iterable]] = None,
158
- user_features: Optional[SparkDataFrame] = None,
159
- ):
160
- """
161
- Returns triplet of filtered `log`, `users`, and `items`.
162
- Filters out cold entities (users/items) from the `users`/`items` and `log` dataframes
163
- if the model does not predict cold.
164
- Filters out duplicates from `users` and `items` dataframes,
165
- and excludes all columns except `user_idx` and `item_idx`.
166
-
167
- :param log: historical log of interactions
168
- ``[user_idx, item_idx, timestamp, relevance]``
169
- :param k: number of recommendations for each user
170
- :param users: users to create recommendations for
171
- dataframe containing ``[user_idx]`` or ``array-like``;
172
- if ``None``, recommend to all users from ``log``
173
- :param items: candidate items for recommendations
174
- dataframe containing ``[item_idx]`` or ``array-like``;
175
- if ``None``, take all items from ``log``.
176
- If it contains new items, ``relevance`` for them will be ``0``.
177
- :param user_features: user features
178
- ``[user_idx , timestamp]`` + feature columns
179
- :return: triplet of filtered `log`, `users`, and `items` dataframes.
180
- """
181
- self.logger.debug("Starting predict %s", type(self).__name__)
182
- user_data = users or log or user_features or self.fit_users
183
- users = get_unique_entities(user_data, "user_idx")
184
- users, log = self._filter_cold_for_predict(users, log, "user")
185
-
186
- item_data = items or self.fit_items
187
- items = get_unique_entities(item_data, "item_idx")
188
- items, log = self._filter_cold_for_predict(items, log, "item")
189
- num_items = items.count()
190
- if num_items < k:
191
- message = f"k = {k} > number of items = {num_items}"
192
- self.logger.debug(message)
193
- return log, users, items
194
-
195
- def _predict_wrap(
196
- self,
197
- log: Optional[SparkDataFrame],
198
- k: int,
199
- users: Optional[Union[SparkDataFrame, Iterable]] = None,
200
- items: Optional[Union[SparkDataFrame, Iterable]] = None,
201
- user_features: Optional[SparkDataFrame] = None,
202
- item_features: Optional[SparkDataFrame] = None,
203
- filter_seen_items: bool = True,
204
- recs_file_path: Optional[str] = None,
205
- ) -> Optional[SparkDataFrame]:
206
- """
207
- Predict wrapper to allow for fewer parameters in models
208
-
209
- :param log: historical log of interactions
210
- ``[user_idx, item_idx, timestamp, relevance]``
211
- :param k: number of recommendations for each user
212
- :param users: users to create recommendations for
213
- dataframe containing ``[user_idx]`` or ``array-like``;
214
- if ``None``, recommend to all users from ``log``
215
- :param items: candidate items for recommendations
216
- dataframe containing ``[item_idx]`` or ``array-like``;
217
- if ``None``, take all items from ``log``.
218
- If it contains new items, ``relevance`` for them will be ``0``.
219
- :param user_features: user features
220
- ``[user_idx , timestamp]`` + feature columns
221
- :param item_features: item features
222
- ``[item_idx , timestamp]`` + feature columns
223
- :param filter_seen_items: flag to remove seen items from recommendations based on ``log``.
224
- :param recs_file_path: save recommendations at the given absolute path as parquet file.
225
- If None, cached and materialized recommendations dataframe will be returned
226
- :return: cached recommendation dataframe with columns ``[user_idx, item_idx, relevance]``
227
- or None if `file_path` is provided
228
- """
229
- log, users, items = self._filter_log_users_items_dataframes(log, k, users, items)
230
-
231
- recs = self._predict(
232
- log,
233
- k,
234
- users,
235
- items,
236
- user_features,
237
- item_features,
238
- filter_seen_items,
239
- )
240
- if filter_seen_items and log:
241
- recs = self._filter_seen(recs=recs, log=log, users=users, k=k)
242
-
243
- recs = get_top_k_recs(recs, k=k).select("user_idx", "item_idx", "relevance")
244
-
245
- output = return_recs(recs, recs_file_path)
246
- self._clear_model_temp_view("filter_seen_users_log")
247
- self._clear_model_temp_view("filter_seen_num_seen")
248
- return output
249
-
250
- def _filter_cold_for_predict(
251
- self,
252
- main_df: SparkDataFrame,
253
- log_df: Optional[SparkDataFrame],
254
- entity: str,
255
- suffix: str = "idx",
256
- ):
257
- """
258
- Filter out cold entities (users/items) from the `main_df` and `log_df`
259
- if the model does not predict cold.
260
- Warn if cold entities are present in the `main_df`.
261
- """
262
- if getattr(self, f"can_predict_cold_{entity}s"):
263
- return main_df, log_df
264
-
265
- fit_entities = getattr(self, f"fit_{entity}s")
266
-
267
- num_new, main_df = filter_cold(main_df, fit_entities, col_name=f"{entity}_{suffix}")
268
- if num_new > 0:
269
- self.logger.info(
270
- "%s model can't predict cold %ss, they will be ignored",
271
- self,
272
- entity,
273
- )
274
- _, log_df = filter_cold(log_df, fit_entities, col_name=f"{entity}_{suffix}")
275
- return main_df, log_df
276
-
277
- @abstractmethod
278
- def _predict(
279
- self,
280
- log: SparkDataFrame,
281
- k: int,
282
- users: SparkDataFrame,
283
- items: SparkDataFrame,
284
- user_features: Optional[SparkDataFrame] = None,
285
- item_features: Optional[SparkDataFrame] = None,
286
- filter_seen_items: bool = True,
287
- ) -> SparkDataFrame:
288
- """
289
- Inner method where model actually predicts.
290
-
291
- :param log: historical log of interactions
292
- ``[user_idx, item_idx, timestamp, relevance]``
293
- :param k: number of recommendations for each user
294
- :param users: users to create recommendations for
295
- dataframe containing ``[user_idx]`` or ``array-like``;
296
- if ``None``, recommend to all users from ``log``
297
- :param items: candidate items for recommendations
298
- dataframe containing ``[item_idx]`` or ``array-like``;
299
- if ``None``, take all items from ``log``.
300
- If it contains new items, ``relevance`` for them will be ``0``.
301
- :param user_features: user features
302
- ``[user_idx , timestamp]`` + feature columns
303
- :param item_features: item features
304
- ``[item_idx , timestamp]`` + feature columns
305
- :param filter_seen_items: flag to remove seen items from recommendations based on ``log``.
306
- :return: recommendation dataframe
307
- ``[user_idx, item_idx, relevance]``
308
- """
309
-
310
- def _predict_proba(
311
- self,
312
- log: SparkDataFrame,
313
- k: int,
314
- users: SparkDataFrame,
315
- items: SparkDataFrame,
316
- user_features: Optional[SparkDataFrame] = None,
317
- item_features: Optional[SparkDataFrame] = None,
318
- filter_seen_items: bool = True,
319
- ) -> np.ndarray:
320
- """
321
- Inner method where model actually predicts probability estimates.
322
-
323
- Mainly used in ```OBPOfflinePolicyLearner```.
324
-
325
- :param log: historical log of interactions
326
- ``[user_idx, item_idx, timestamp, rating]``
327
- :param k: number of recommendations for each user
328
- :param users: users to create recommendations for
329
- dataframe containing ``[user_idx]`` or ``array-like``;
330
- if ``None``, recommend to all users from ``log``
331
- :param items: candidate items for recommendations
332
- dataframe containing ``[item_idx]`` or ``array-like``;
333
- if ``None``, take all items from ``log``.
334
- If it contains new items, ``rating`` for them will be ``0``.
335
- :param user_features: user features
336
- ``[user_idx , timestamp]`` + feature columns
337
- :param item_features: item features
338
- ``[item_idx , timestamp]`` + feature columns
339
- :param filter_seen_items: flag to remove seen items from recommendations based on ``log``.
340
- :return: distribution over items for each user with shape
341
- ``(n_users, n_items, k)``
342
- where we have probability for each user to choose item at fixed position(top-k).
343
- """
344
-
345
- n_users = users.select("user_idx").count()
346
- n_items = items.select("item_idx").count()
347
-
348
- recs = self._predict(
349
- log,
350
- k,
351
- users,
352
- items,
353
- user_features,
354
- item_features,
355
- filter_seen_items,
356
- )
357
-
358
- recs = get_top_k_recs(recs, k=k).select("user_idx", "item_idx", "relevance")
359
-
360
- cols = [f"k{i}" for i in range(k)]
361
-
362
- recs_items = (
363
- recs.groupBy("user_idx")
364
- .agg(sf.collect_list("item_idx").alias("item_idx"))
365
- .select([sf.col("item_idx")[i].alias(cols[i]) for i in range(k)])
366
- )
367
-
368
- action_dist = np.zeros(shape=(n_users, n_items, k))
369
-
370
- for i in range(k):
371
- action_dist[
372
- np.arange(n_users),
373
- recs_items.select(cols[i]).toPandas()[cols[i]].to_numpy(),
374
- np.ones(n_users, dtype=int) * i,
375
- ] += 1
376
-
377
- return action_dist
378
-
379
- def _get_fit_counts(self, entity: str) -> int:
380
- if not hasattr(self, f"_num_{entity}s"):
381
- setattr(
382
- self,
383
- f"_num_{entity}s",
384
- getattr(self, f"fit_{entity}s").count(),
385
- )
386
- return getattr(self, f"_num_{entity}s")
387
-
388
- @property
389
- def users_count(self) -> int:
390
- """
391
- :returns: number of users the model was trained on
392
- """
393
- return self._get_fit_counts("user")
394
-
395
- @property
396
- def items_count(self) -> int:
397
- """
398
- :returns: number of items the model was trained on
399
- """
400
- return self._get_fit_counts("item")
401
-
402
- def _get_fit_dims(self, entity: str) -> int:
403
- if not hasattr(self, f"_{entity}_dim_size"):
404
- setattr(
405
- self,
406
- f"_{entity}_dim_size",
407
- getattr(self, f"fit_{entity}s").agg({f"{entity}_idx": "max"}).first()[0] + 1,
408
- )
409
- return getattr(self, f"_{entity}_dim_size")
410
-
411
- @property
412
- def _user_dim(self) -> int:
413
- """
414
- :returns: dimension of users matrix (maximal user idx + 1)
415
- """
416
- return self._get_fit_dims("user")
417
-
418
- @property
419
- def _item_dim(self) -> int:
420
- """
421
- :returns: dimension of items matrix (maximal item idx + 1)
422
- """
423
- return self._get_fit_dims("item")
424
-
425
- def _fit_predict(
426
- self,
427
- log: SparkDataFrame,
428
- k: int,
429
- users: Optional[Union[SparkDataFrame, Iterable]] = None,
430
- items: Optional[Union[SparkDataFrame, Iterable]] = None,
431
- user_features: Optional[SparkDataFrame] = None,
432
- item_features: Optional[SparkDataFrame] = None,
433
- filter_seen_items: bool = True,
434
- recs_file_path: Optional[str] = None,
435
- ) -> Optional[SparkDataFrame]:
436
- self._fit_wrap(log, user_features, item_features)
437
- return self._predict_wrap(
438
- log,
439
- k,
440
- users,
441
- items,
442
- user_features,
443
- item_features,
444
- filter_seen_items,
445
- recs_file_path=recs_file_path,
446
- )
447
-
448
- def _predict_pairs_wrap(
449
- self,
450
- pairs: SparkDataFrame,
451
- log: Optional[SparkDataFrame] = None,
452
- user_features: Optional[SparkDataFrame] = None,
453
- item_features: Optional[SparkDataFrame] = None,
454
- recs_file_path: Optional[str] = None,
455
- k: Optional[int] = None,
456
- ) -> Optional[SparkDataFrame]:
457
- """
458
- This method
459
- 1) converts data to spark
460
- 2) removes cold users and items if model does not predict them
461
- 3) calls inner _predict_pairs method of a model
462
-
463
- :param pairs: user-item pairs to get relevance for,
464
- dataframe containing``[user_idx, item_idx]``.
465
- :param log: train data
466
- ``[user_idx, item_idx, timestamp, relevance]``.
467
- :param recs_file_path: save recommendations at the given absolute path as parquet file.
468
- If None, cached and materialized recommendations dataframe will be returned
469
- :return: cached dataframe with columns ``[user_idx, item_idx, relevance]``
470
- or None if `file_path` is provided
471
- """
472
- log, user_features, item_features, pairs = [
473
- convert2spark(df) for df in [log, user_features, item_features, pairs]
474
- ]
475
- if sorted(pairs.columns) != ["item_idx", "user_idx"]:
476
- msg = "pairs must be a dataframe with columns strictly [user_idx, item_idx]"
477
- raise ValueError(msg)
478
- pairs, log = self._filter_cold_for_predict(pairs, log, "user")
479
- pairs, log = self._filter_cold_for_predict(pairs, log, "item")
480
-
481
- pred = self._predict_pairs(
482
- pairs=pairs,
483
- log=log,
484
- user_features=user_features,
485
- item_features=item_features,
486
- )
487
-
488
- if k:
489
- pred = get_top_k(
490
- dataframe=pred,
491
- partition_by_col=sf.col("user_idx"),
492
- order_by_col=[
493
- sf.col("relevance").desc(),
494
- ],
495
- k=k,
496
- )
497
-
498
- if recs_file_path is not None:
499
- pred.write.parquet(path=recs_file_path, mode="overwrite")
500
- return None
501
-
502
- pred.cache().count()
503
- return pred
504
-
505
- def _predict_pairs(
506
- self,
507
- pairs: SparkDataFrame,
508
- log: Optional[SparkDataFrame] = None,
509
- user_features: Optional[SparkDataFrame] = None,
510
- item_features: Optional[SparkDataFrame] = None,
511
- ) -> SparkDataFrame:
512
- """
513
- Fallback method to use in case ``_predict_pairs`` is not implemented.
514
- Simply joins ``predict`` with given ``pairs``.
515
- :param pairs: user-item pairs to get relevance for,
516
- dataframe containing``[user_idx, item_idx]``.
517
- :param log: train data
518
- ``[user_idx, item_idx, timestamp, relevance]``.
519
- """
520
- message = (
521
- "native predict_pairs is not implemented for this model. "
522
- "Falling back to usual predict method and filtering the results."
523
- )
524
- self.logger.warning(message)
525
-
526
- users = pairs.select("user_idx").distinct()
527
- items = pairs.select("item_idx").distinct()
528
- k = items.count()
529
- pred = self._predict(
530
- log=log,
531
- k=k,
532
- users=users,
533
- items=items,
534
- user_features=user_features,
535
- item_features=item_features,
536
- filter_seen_items=False,
537
- )
538
-
539
- pred = pred.join(
540
- pairs.select("user_idx", "item_idx"),
541
- on=["user_idx", "item_idx"],
542
- how="inner",
543
- )
544
- return pred
545
-
546
- def _get_features_wrap(
547
- self, ids: SparkDataFrame, features: Optional[SparkDataFrame]
548
- ) -> Optional[Tuple[SparkDataFrame, int]]:
549
- if "user_idx" not in ids.columns and "item_idx" not in ids.columns:
550
- msg = "user_idx or item_idx missing"
551
- raise ValueError(msg)
552
- vectors, rank = self._get_features(ids, features)
553
- return vectors, rank
554
-
555
- def _get_features(
556
- self, ids: SparkDataFrame, features: Optional[SparkDataFrame] # noqa: ARG002
557
- ) -> Tuple[Optional[SparkDataFrame], Optional[int]]:
558
- """
559
- Get embeddings from model
560
-
561
- :param ids: id ids to get embeddings for Spark DataFrame containing user_idx or item_idx
562
- :param features: user or item features
563
- :return: DataFrame with biases and embeddings, and vector size
564
- """
565
-
566
- self.logger.info(
567
- "get_features method is not defined for the model %s. Features will not be returned.",
568
- str(self),
569
- )
570
- return None, None
571
-
572
- def _get_nearest_items_wrap(
573
- self,
574
- items: Union[SparkDataFrame, Iterable],
575
- k: int,
576
- metric: Optional[str] = "cosine_similarity",
577
- candidates: Optional[Union[SparkDataFrame, Iterable]] = None,
578
- ) -> Optional[SparkDataFrame]:
579
- """
580
- Convert indexes and leave top-k nearest items for each item in `items`.
581
- """
582
- items = get_unique_entities(items, "item_idx")
583
- if candidates is not None:
584
- candidates = get_unique_entities(candidates, "item_idx")
585
-
586
- nearest_items_to_filter = self._get_nearest_items(
587
- items=items,
588
- metric=metric,
589
- candidates=candidates,
590
- )
591
-
592
- rel_col_name = metric if metric is not None else "similarity"
593
- nearest_items = get_top_k(
594
- dataframe=nearest_items_to_filter,
595
- partition_by_col=sf.col("item_idx_one"),
596
- order_by_col=[
597
- sf.col(rel_col_name).desc(),
598
- sf.col("item_idx_two").desc(),
599
- ],
600
- k=k,
601
- )
602
-
603
- nearest_items = nearest_items.withColumnRenamed("item_idx_two", "neighbour_item_idx")
604
- nearest_items = nearest_items.withColumnRenamed("item_idx_one", "item_idx")
605
- return nearest_items
606
-
607
- def _get_nearest_items(
608
- self,
609
- items: SparkDataFrame, # noqa: ARG002
610
- metric: Optional[str] = None, # noqa: ARG002
611
- candidates: Optional[SparkDataFrame] = None, # noqa: ARG002
612
- ) -> Optional[SparkDataFrame]:
613
- msg = f"item-to-item prediction is not implemented for {self}"
614
- raise NotImplementedError(msg)
615
-
616
- def _save_model(self, path: str):
617
- pass
618
-
619
- def _load_model(self, path: str):
620
- pass
621
-
622
-
623
- class ItemVectorModel(BaseRecommender):
624
- """Parent for models generating items' vector representations"""
625
-
626
- can_predict_item_to_item: bool = True
627
- item_to_item_metrics: List[str] = [
628
- "euclidean_distance_sim",
629
- "cosine_similarity",
630
- "dot_product",
631
- ]
632
-
633
- @abstractmethod
634
- def _get_item_vectors(self) -> SparkDataFrame:
635
- """
636
- Return dataframe with items' vectors as a
637
- spark dataframe with columns ``[item_idx, item_vector]``
638
- """
639
-
640
- def get_nearest_items(
641
- self,
642
- items: Union[SparkDataFrame, Iterable],
643
- k: int,
644
- metric: Optional[str] = "cosine_similarity",
645
- candidates: Optional[Union[SparkDataFrame, Iterable]] = None,
646
- ) -> Optional[SparkDataFrame]:
647
- """
648
- Get k most similar items be the `metric` for each of the `items`.
649
-
650
- :param items: spark dataframe or list of item ids to find neighbors
651
- :param k: number of neighbors
652
- :param metric: 'euclidean_distance_sim', 'cosine_similarity', 'dot_product'
653
- :param candidates: spark dataframe or list of items
654
- to consider as similar, e.g. popular/new items. If None,
655
- all items presented during model training are used.
656
- :return: dataframe with the most similar items,
657
- where bigger value means greater similarity.
658
- spark-dataframe with columns ``[item_idx, neighbour_item_idx, similarity]``
659
- """
660
- if metric not in self.item_to_item_metrics:
661
- msg = f"Select one of the valid distance metrics: {self.item_to_item_metrics}"
662
- raise ValueError(msg)
663
-
664
- return self._get_nearest_items_wrap(
665
- items=items,
666
- k=k,
667
- metric=metric,
668
- candidates=candidates,
669
- )
670
-
671
- def _get_nearest_items(
672
- self,
673
- items: SparkDataFrame,
674
- metric: str = "cosine_similarity",
675
- candidates: Optional[SparkDataFrame] = None,
676
- ) -> SparkDataFrame:
677
- """
678
- Return distance metric value for all available close items filtered by `candidates`.
679
-
680
- :param items: ids to find neighbours, spark dataframe with column ``item_idx``
681
- :param metric: 'euclidean_distance_sim' calculated as 1/(1 + euclidean_distance),
682
- 'cosine_similarity', 'dot_product'
683
- :param candidates: items among which we are looking for similar,
684
- e.g. popular/new items. If None, all items presented during model training are used.
685
- :return: dataframe with neighbours,
686
- spark-dataframe with columns ``[item_idx_one, item_idx_two, similarity]``
687
- """
688
- dist_function = cosine_similarity
689
- if metric == "euclidean_distance_sim":
690
- dist_function = vector_euclidean_distance_similarity
691
- elif metric == "dot_product":
692
- dist_function = vector_dot
693
-
694
- items_vectors = self._get_item_vectors()
695
- left_part = (
696
- items_vectors.withColumnRenamed("item_idx", "item_idx_one")
697
- .withColumnRenamed("item_vector", "item_vector_one")
698
- .join(
699
- items.select(sf.col("item_idx").alias("item_idx_one")),
700
- on="item_idx_one",
701
- )
702
- )
703
-
704
- right_part = items_vectors.withColumnRenamed("item_idx", "item_idx_two").withColumnRenamed(
705
- "item_vector", "item_vector_two"
706
- )
707
-
708
- if candidates is not None:
709
- right_part = right_part.join(
710
- candidates.withColumnRenamed("item_idx", "item_idx_two"),
711
- on="item_idx_two",
712
- )
713
-
714
- joined_factors = left_part.join(right_part, on=sf.col("item_idx_one") != sf.col("item_idx_two"))
715
-
716
- joined_factors = joined_factors.withColumn(
717
- metric,
718
- dist_function(sf.col("item_vector_one"), sf.col("item_vector_two")),
719
- )
720
-
721
- similarity_matrix = joined_factors.select("item_idx_one", "item_idx_two", metric)
722
-
723
- return similarity_matrix
724
-
725
-
726
- class HybridRecommender(BaseRecommender, ABC):
727
- """Base class for models that can use extra features"""
728
-
729
- def fit(
730
- self,
731
- log: SparkDataFrame,
732
- user_features: Optional[SparkDataFrame] = None,
733
- item_features: Optional[SparkDataFrame] = None,
734
- ) -> None:
735
- """
736
- Fit a recommendation model
737
-
738
- :param log: historical log of interactions
739
- ``[user_idx, item_idx, timestamp, relevance]``
740
- :param user_features: user features
741
- ``[user_idx, timestamp]`` + feature columns
742
- :param item_features: item features
743
- ``[item_idx, timestamp]`` + feature columns
744
- :return:
745
- """
746
- self._fit_wrap(
747
- log=log,
748
- user_features=user_features,
749
- item_features=item_features,
750
- )
751
-
752
- def predict(
753
- self,
754
- log: SparkDataFrame,
755
- k: int,
756
- users: Optional[Union[SparkDataFrame, Iterable]] = None,
757
- items: Optional[Union[SparkDataFrame, Iterable]] = None,
758
- user_features: Optional[SparkDataFrame] = None,
759
- item_features: Optional[SparkDataFrame] = None,
760
- filter_seen_items: bool = True,
761
- recs_file_path: Optional[str] = None,
762
- ) -> Optional[SparkDataFrame]:
763
- """
764
- Get recommendations
765
-
766
- :param log: historical log of interactions
767
- ``[user_idx, item_idx, timestamp, relevance]``
768
- :param k: number of recommendations for each user
769
- :param users: users to create recommendations for
770
- dataframe containing ``[user_idx]`` or ``array-like``;
771
- if ``None``, recommend to all users from ``log``
772
- :param items: candidate items for recommendations
773
- dataframe containing ``[item_idx]`` or ``array-like``;
774
- if ``None``, take all items from ``log``.
775
- If it contains new items, ``relevance`` for them will be ``0``.
776
- :param user_features: user features
777
- ``[user_idx , timestamp]`` + feature columns
778
- :param item_features: item features
779
- ``[item_idx , timestamp]`` + feature columns
780
- :param filter_seen_items: flag to remove seen items from recommendations based on ``log``.
781
- :param recs_file_path: save recommendations at the given absolute path as parquet file.
782
- If None, cached and materialized recommendations dataframe will be returned
783
- :return: cached recommendation dataframe with columns ``[user_idx, item_idx, relevance]``
784
- or None if `file_path` is provided
785
-
786
- """
787
- return self._predict_wrap(
788
- log=log,
789
- k=k,
790
- users=users,
791
- items=items,
792
- user_features=user_features,
793
- item_features=item_features,
794
- filter_seen_items=filter_seen_items,
795
- recs_file_path=recs_file_path,
796
- )
797
-
798
- def fit_predict(
799
- self,
800
- log: SparkDataFrame,
801
- k: int,
802
- users: Optional[Union[SparkDataFrame, Iterable]] = None,
803
- items: Optional[Union[SparkDataFrame, Iterable]] = None,
804
- user_features: Optional[SparkDataFrame] = None,
805
- item_features: Optional[SparkDataFrame] = None,
806
- filter_seen_items: bool = True,
807
- recs_file_path: Optional[str] = None,
808
- ) -> Optional[SparkDataFrame]:
809
- """
810
- Fit model and get recommendations
811
-
812
- :param log: historical log of interactions
813
- ``[user_idx, item_idx, timestamp, relevance]``
814
- :param k: number of recommendations for each user
815
- :param users: users to create recommendations for
816
- dataframe containing ``[user_idx]`` or ``array-like``;
817
- if ``None``, recommend to all users from ``log``
818
- :param items: candidate items for recommendations
819
- dataframe containing ``[item_idx]`` or ``array-like``;
820
- if ``None``, take all items from ``log``.
821
- If it contains new items, ``relevance`` for them will be ``0``.
822
- :param user_features: user features
823
- ``[user_idx , timestamp]`` + feature columns
824
- :param item_features: item features
825
- ``[item_idx , timestamp]`` + feature columns
826
- :param filter_seen_items: flag to remove seen items from recommendations based on ``log``.
827
- :param recs_file_path: save recommendations at the given absolute path as parquet file.
828
- If None, cached and materialized recommendations dataframe will be returned
829
- :return: cached recommendation dataframe with columns ``[user_idx, item_idx, relevance]``
830
- or None if `file_path` is provided
831
- """
832
- return self._fit_predict(
833
- log=log,
834
- k=k,
835
- users=users,
836
- items=items,
837
- user_features=user_features,
838
- item_features=item_features,
839
- filter_seen_items=filter_seen_items,
840
- recs_file_path=recs_file_path,
841
- )
842
-
843
- def predict_pairs(
844
- self,
845
- pairs: SparkDataFrame,
846
- log: Optional[SparkDataFrame] = None,
847
- user_features: Optional[SparkDataFrame] = None,
848
- item_features: Optional[SparkDataFrame] = None,
849
- recs_file_path: Optional[str] = None,
850
- k: Optional[int] = None,
851
- ) -> Optional[SparkDataFrame]:
852
- """
853
- Get recommendations for specific user-item ``pairs``.
854
- If a model can't produce recommendation
855
- for specific pair it is removed from the resulting dataframe.
856
-
857
- :param pairs: dataframe with pairs to calculate relevance for, ``[user_idx, item_idx]``.
858
- :param log: historical log of interactions
859
- ``[user_idx, item_idx, timestamp, relevance]``
860
- :param user_features: user features
861
- ``[user_idx , timestamp]`` + feature columns
862
- :param item_features: item features
863
- ``[item_idx , timestamp]`` + feature columns
864
- :param recs_file_path: save recommendations at the given absolute path as parquet file.
865
- If None, cached and materialized recommendations dataframe will be returned
866
- :param k: top-k items for each user from pairs.
867
- :return: cached recommendation dataframe with columns ``[user_idx, item_idx, relevance]``
868
- or None if `file_path` is provided
869
- """
870
- return self._predict_pairs_wrap(
871
- pairs=pairs,
872
- log=log,
873
- user_features=user_features,
874
- item_features=item_features,
875
- recs_file_path=recs_file_path,
876
- k=k,
877
- )
878
-
879
- def get_features(
880
- self, ids: SparkDataFrame, features: Optional[SparkDataFrame]
881
- ) -> Optional[Tuple[SparkDataFrame, int]]:
882
- """
883
- Returns user or item feature vectors as a Column with type ArrayType
884
- :param ids: Spark DataFrame with unique ids
885
- :param features: Spark DataFrame with features for provided ids
886
- :return: feature vectors
887
- If a model does not have a vector for some ids they are not present in the final result.
888
- """
889
- return self._get_features_wrap(ids, features)
890
-
891
-
892
- class Recommender(BaseRecommender, ABC):
893
- """Usual recommender class for models without features."""
894
-
895
- def fit(self, log: SparkDataFrame) -> None:
896
- """
897
- Fit a recommendation model
898
-
899
- :param log: historical log of interactions
900
- ``[user_idx, item_idx, timestamp, relevance]``
901
- :return:
902
- """
903
- self._fit_wrap(
904
- log=log,
905
- user_features=None,
906
- item_features=None,
907
- )
908
-
909
- def predict(
910
- self,
911
- log: SparkDataFrame,
912
- k: int,
913
- users: Optional[Union[SparkDataFrame, Iterable]] = None,
914
- items: Optional[Union[SparkDataFrame, Iterable]] = None,
915
- filter_seen_items: bool = True,
916
- recs_file_path: Optional[str] = None,
917
- ) -> Optional[SparkDataFrame]:
918
- """
919
- Get recommendations
920
-
921
- :param log: historical log of interactions
922
- ``[user_idx, item_idx, timestamp, relevance]``
923
- :param k: number of recommendations for each user
924
- :param users: users to create recommendations for
925
- dataframe containing ``[user_idx]`` or ``array-like``;
926
- if ``None``, recommend to all users from ``log``
927
- :param items: candidate items for recommendations
928
- dataframe containing ``[item_idx]`` or ``array-like``;
929
- if ``None``, take all items from ``log``.
930
- If it contains new items, ``relevance`` for them will be ``0``.
931
- :param filter_seen_items: flag to remove seen items from recommendations based on ``log``.
932
- :param recs_file_path: save recommendations at the given absolute path as parquet file.
933
- If None, cached and materialized recommendations dataframe will be returned
934
- :return: cached recommendation dataframe with columns ``[user_idx, item_idx, relevance]``
935
- or None if `file_path` is provided
936
- """
937
- return self._predict_wrap(
938
- log=log,
939
- k=k,
940
- users=users,
941
- items=items,
942
- user_features=None,
943
- item_features=None,
944
- filter_seen_items=filter_seen_items,
945
- recs_file_path=recs_file_path,
946
- )
947
-
948
- def predict_pairs(
949
- self,
950
- pairs: SparkDataFrame,
951
- log: Optional[SparkDataFrame] = None,
952
- recs_file_path: Optional[str] = None,
953
- k: Optional[int] = None,
954
- ) -> Optional[SparkDataFrame]:
955
- """
956
- Get recommendations for specific user-item ``pairs``.
957
- If a model can't produce recommendation
958
- for specific pair it is removed from the resulting dataframe.
959
-
960
- :param pairs: dataframe with pairs to calculate relevance for, ``[user_idx, item_idx]``.
961
- :param log: historical log of interactions
962
- ``[user_idx, item_idx, timestamp, relevance]``
963
- :param recs_file_path: save recommendations at the given absolute path as parquet file.
964
- If None, cached and materialized recommendations dataframe will be returned
965
- :param k: top-k items for each user from pairs.
966
- :return: cached recommendation dataframe with columns ``[user_idx, item_idx, relevance]``
967
- or None if `file_path` is provided
968
- """
969
- return self._predict_pairs_wrap(
970
- pairs=pairs,
971
- log=log,
972
- recs_file_path=recs_file_path,
973
- k=k,
974
- )
975
-
976
- def fit_predict(
977
- self,
978
- log: SparkDataFrame,
979
- k: int,
980
- users: Optional[Union[SparkDataFrame, Iterable]] = None,
981
- items: Optional[Union[SparkDataFrame, Iterable]] = None,
982
- filter_seen_items: bool = True,
983
- recs_file_path: Optional[str] = None,
984
- ) -> Optional[SparkDataFrame]:
985
- """
986
- Fit model and get recommendations
987
-
988
- :param log: historical log of interactions
989
- ``[user_idx, item_idx, timestamp, relevance]``
990
- :param k: number of recommendations for each user
991
- :param users: users to create recommendations for
992
- dataframe containing ``[user_idx]`` or ``array-like``;
993
- if ``None``, recommend to all users from ``log``
994
- :param items: candidate items for recommendations
995
- dataframe containing ``[item_idx]`` or ``array-like``;
996
- if ``None``, take all items from ``log``.
997
- If it contains new items, ``relevance`` for them will be ``0``.
998
- :param filter_seen_items: flag to remove seen items from recommendations based on ``log``.
999
- :param recs_file_path: save recommendations at the given absolute path as parquet file.
1000
- If None, cached and materialized recommendations dataframe will be returned
1001
- :return: cached recommendation dataframe with columns ``[user_idx, item_idx, relevance]``
1002
- or None if `file_path` is provided
1003
- """
1004
- return self._fit_predict(
1005
- log=log,
1006
- k=k,
1007
- users=users,
1008
- items=items,
1009
- user_features=None,
1010
- item_features=None,
1011
- filter_seen_items=filter_seen_items,
1012
- recs_file_path=recs_file_path,
1013
- )
1014
-
1015
- def get_features(self, ids: SparkDataFrame) -> Optional[Tuple[SparkDataFrame, int]]:
1016
- """
1017
- Returns user or item feature vectors as a Column with type ArrayType
1018
-
1019
- :param ids: Spark DataFrame with unique ids
1020
- :return: feature vectors.
1021
- If a model does not have a vector for some ids they are not present in the final result.
1022
- """
1023
- return self._get_features_wrap(ids, None)
1024
-
1025
-
1026
- class UserRecommender(BaseRecommender, ABC):
1027
- """Base class for models that use user features
1028
- but not item features. ``log`` is not required for this class."""
1029
-
1030
- def fit(
1031
- self,
1032
- log: SparkDataFrame,
1033
- user_features: SparkDataFrame,
1034
- ) -> None:
1035
- """
1036
- Finds user clusters and calculates item similarity in that clusters.
1037
-
1038
- :param log: historical log of interactions
1039
- ``[user_idx, item_idx, timestamp, relevance]``
1040
- :param user_features: user features
1041
- ``[user_idx, timestamp]`` + feature columns
1042
- :return:
1043
- """
1044
- self._fit_wrap(log=log, user_features=user_features)
1045
-
1046
- def predict(
1047
- self,
1048
- user_features: SparkDataFrame,
1049
- k: int,
1050
- log: Optional[SparkDataFrame] = None,
1051
- users: Optional[Union[SparkDataFrame, Iterable]] = None,
1052
- items: Optional[Union[SparkDataFrame, Iterable]] = None,
1053
- filter_seen_items: bool = True,
1054
- recs_file_path: Optional[str] = None,
1055
- ) -> Optional[SparkDataFrame]:
1056
- """
1057
- Get recommendations
1058
-
1059
- :param log: historical log of interactions
1060
- ``[user_idx, item_idx, timestamp, relevance]``
1061
- :param k: number of recommendations for each user
1062
- :param users: users to create recommendations for
1063
- dataframe containing ``[user_idx]`` or ``array-like``;
1064
- if ``None``, recommend to all users from ``log``
1065
- :param items: candidate items for recommendations
1066
- dataframe containing ``[item_idx]`` or ``array-like``;
1067
- if ``None``, take all items from ``log``.
1068
- If it contains new items, ``relevance`` for them will be ``0``.
1069
- :param user_features: user features
1070
- ``[user_idx , timestamp]`` + feature columns
1071
- :param filter_seen_items: flag to remove seen items from recommendations based on ``log``.
1072
- :param recs_file_path: save recommendations at the given absolute path as parquet file.
1073
- If None, cached and materialized recommendations dataframe will be returned
1074
- :return: cached recommendation dataframe with columns ``[user_idx, item_idx, relevance]``
1075
- or None if `file_path` is provided
1076
- """
1077
- return self._predict_wrap(
1078
- log=log,
1079
- user_features=user_features,
1080
- k=k,
1081
- filter_seen_items=filter_seen_items,
1082
- users=users,
1083
- items=items,
1084
- recs_file_path=recs_file_path,
1085
- )
1086
-
1087
- def predict_pairs(
1088
- self,
1089
- pairs: SparkDataFrame,
1090
- user_features: SparkDataFrame,
1091
- log: Optional[SparkDataFrame] = None,
1092
- recs_file_path: Optional[str] = None,
1093
- k: Optional[int] = None,
1094
- ) -> Optional[SparkDataFrame]:
1095
- """
1096
- Get recommendations for specific user-item ``pairs``.
1097
- If a model can't produce recommendation
1098
- for specific pair it is removed from the resulting dataframe.
1099
-
1100
- :param pairs: dataframe with pairs to calculate relevance for, ``[user_idx, item_idx]``.
1101
- :param user_features: user features
1102
- ``[user_idx , timestamp]`` + feature columns
1103
- :param log: historical log of interactions
1104
- ``[user_idx, item_idx, timestamp, relevance]``
1105
- :param recs_file_path: save recommendations at the given absolute path as parquet file.
1106
- If None, cached and materialized recommendations dataframe will be returned
1107
- :param k: top-k items for each user from pairs.
1108
- :return: cached recommendation dataframe with columns ``[user_idx, item_idx, relevance]``
1109
- or None if `file_path` is provided
1110
- """
1111
- return self._predict_pairs_wrap(
1112
- pairs=pairs,
1113
- log=log,
1114
- user_features=user_features,
1115
- recs_file_path=recs_file_path,
1116
- k=k,
1117
- )
1118
-
1119
-
1120
- class NonPersonalizedRecommender(Recommender, ABC):
1121
- """Base class for non-personalized recommenders with popularity statistics."""
1122
-
1123
- can_predict_cold_users = True
1124
- can_predict_cold_items = True
1125
- item_popularity: SparkDataFrame
1126
- add_cold_items: bool
1127
- cold_weight: float
1128
- sample: bool
1129
- fill: float
1130
- seed: Optional[int] = None
1131
-
1132
- def __init__(self, add_cold_items: bool, cold_weight: float):
1133
- self.add_cold_items = add_cold_items
1134
- if 0 < cold_weight <= 1:
1135
- self.cold_weight = cold_weight
1136
- else:
1137
- msg = "`cold_weight` value should be in interval (0, 1]"
1138
- raise ValueError(msg)
1139
-
1140
- @property
1141
- def _dataframes(self):
1142
- return {"item_popularity": self.item_popularity}
1143
-
1144
- def _save_model(self, path: str):
1145
- save_picklable_to_parquet(self.fill, join(path, "params.dump"))
1146
-
1147
- def _load_model(self, path: str):
1148
- self.fill = load_pickled_from_parquet(join(path, "params.dump"))
1149
-
1150
- def _clear_cache(self):
1151
- if hasattr(self, "item_popularity"):
1152
- self.item_popularity.unpersist()
1153
-
1154
- @staticmethod
1155
- def _calc_fill(item_popularity: SparkDataFrame, weight: float) -> float:
1156
- """
1157
- Calculating a fill value a the minimal relevance
1158
- calculated during model training multiplied by weight.
1159
- """
1160
- return item_popularity.select(sf.min("relevance")).first()[0] * weight
1161
-
1162
- @staticmethod
1163
- def _check_relevance(log: SparkDataFrame):
1164
- vals = log.select("relevance").where((sf.col("relevance") != 1) & (sf.col("relevance") != 0))
1165
- if vals.count() > 0:
1166
- msg = "Relevance values in log must be 0 or 1"
1167
- raise ValueError(msg)
1168
-
1169
- def _get_selected_item_popularity(self, items: SparkDataFrame) -> SparkDataFrame:
1170
- """
1171
- Choose only required item from `item_popularity` dataframe
1172
- for further recommendations generation.
1173
- """
1174
- return self.item_popularity.join(
1175
- items,
1176
- on="item_idx",
1177
- how="right" if self.add_cold_items else "inner",
1178
- ).fillna(value=self.fill, subset=["relevance"])
1179
-
1180
- @staticmethod
1181
- def _calc_max_hist_len(log: SparkDataFrame, users: SparkDataFrame) -> int:
1182
- max_hist_len = (
1183
- (log.join(users, on="user_idx").groupBy("user_idx").agg(sf.countDistinct("item_idx").alias("items_count")))
1184
- .select(sf.max("items_count"))
1185
- .first()[0]
1186
- )
1187
- # all users have empty history
1188
- if max_hist_len is None:
1189
- max_hist_len = 0
1190
-
1191
- return max_hist_len
1192
-
1193
- def _predict_without_sampling(
1194
- self,
1195
- log: SparkDataFrame,
1196
- k: int,
1197
- users: SparkDataFrame,
1198
- items: SparkDataFrame,
1199
- filter_seen_items: bool = True,
1200
- ) -> SparkDataFrame:
1201
- """
1202
- Regular prediction for popularity-based models,
1203
- top-k most relevant items from `items` are chosen for each user
1204
- """
1205
- selected_item_popularity = self._get_selected_item_popularity(items)
1206
- selected_item_popularity = selected_item_popularity.withColumn(
1207
- "rank",
1208
- sf.row_number().over(Window.orderBy(sf.col("relevance").desc(), sf.col("item_idx").desc())),
1209
- )
1210
-
1211
- if filter_seen_items and log is not None:
1212
- user_to_num_items = (
1213
- log.join(users, on="user_idx").groupBy("user_idx").agg(sf.countDistinct("item_idx").alias("num_items"))
1214
- )
1215
- users = users.join(user_to_num_items, on="user_idx", how="left")
1216
- users = users.fillna(0, "num_items")
1217
- # 'selected_item_popularity' truncation by k + max_seen
1218
- max_seen = users.select(sf.coalesce(sf.max("num_items"), sf.lit(0))).first()[0]
1219
- selected_item_popularity = selected_item_popularity.filter(sf.col("rank") <= k + max_seen)
1220
- return users.join(selected_item_popularity, on=(sf.col("rank") <= k + sf.col("num_items")), how="left")
1221
-
1222
- return users.crossJoin(selected_item_popularity.filter(sf.col("rank") <= k)).drop("rank")
1223
-
1224
- def _predict_with_sampling(
1225
- self,
1226
- log: SparkDataFrame,
1227
- k: int,
1228
- users: SparkDataFrame,
1229
- items: SparkDataFrame,
1230
- filter_seen_items: bool = True,
1231
- ) -> SparkDataFrame:
1232
- """
1233
- Randomized prediction for popularity-based models,
1234
- top-k items from `items` are sampled for each user based with
1235
- probability proportional to items' popularity
1236
- """
1237
- selected_item_popularity = self._get_selected_item_popularity(items)
1238
- selected_item_popularity = selected_item_popularity.withColumn(
1239
- "relevance",
1240
- sf.when(sf.col("relevance") == sf.lit(0.0), 0.1**6).otherwise(sf.col("relevance")),
1241
- )
1242
-
1243
- items_pd = selected_item_popularity.withColumn(
1244
- "probability",
1245
- sf.col("relevance") / selected_item_popularity.select(sf.sum("relevance")).first()[0],
1246
- ).toPandas()
1247
-
1248
- rec_schema = get_schema(
1249
- query_column="user_idx",
1250
- item_column="item_idx",
1251
- rating_column="relevance",
1252
- has_timestamp=False,
1253
- )
1254
- if items_pd.shape[0] == 0:
1255
- return State().session.createDataFrame([], rec_schema)
1256
-
1257
- seed = self.seed
1258
- class_name = self.__class__.__name__
1259
-
1260
- def grouped_map(pandas_df: PandasDataFrame) -> PandasDataFrame:
1261
- user_idx = pandas_df["user_idx"][0]
1262
- cnt = pandas_df["cnt"][0]
1263
-
1264
- local_rng = default_rng(seed + user_idx) if seed is not None else default_rng()
1265
-
1266
- items_positions = local_rng.choice(
1267
- np.arange(items_pd.shape[0]),
1268
- size=cnt,
1269
- p=items_pd["probability"].values,
1270
- replace=False,
1271
- )
1272
-
1273
- # workaround to unify RandomRec and UCB
1274
- if class_name == "RandomRec":
1275
- relevance = 1 / np.arange(1, cnt + 1)
1276
- else:
1277
- relevance = items_pd["probability"].values[items_positions]
1278
-
1279
- return PandasDataFrame(
1280
- {
1281
- "user_idx": cnt * [user_idx],
1282
- "item_idx": items_pd["item_idx"].values[items_positions],
1283
- "relevance": relevance,
1284
- }
1285
- )
1286
-
1287
- if log is not None and filter_seen_items:
1288
- recs = (
1289
- log.select("user_idx", "item_idx")
1290
- .distinct()
1291
- .join(users, how="right", on="user_idx")
1292
- .groupby("user_idx")
1293
- .agg(sf.countDistinct("item_idx").alias("cnt"))
1294
- .selectExpr(
1295
- "user_idx",
1296
- f"LEAST(cnt + {k}, {items_pd.shape[0]}) AS cnt",
1297
- )
1298
- )
1299
- else:
1300
- recs = users.withColumn("cnt", sf.lit(min(k, items_pd.shape[0])))
1301
-
1302
- return recs.groupby("user_idx").applyInPandas(grouped_map, rec_schema)
1303
-
1304
- def _predict(
1305
- self,
1306
- log: SparkDataFrame,
1307
- k: int,
1308
- users: SparkDataFrame,
1309
- items: SparkDataFrame,
1310
- user_features: Optional[SparkDataFrame] = None, # noqa: ARG002
1311
- item_features: Optional[SparkDataFrame] = None, # noqa: ARG002
1312
- filter_seen_items: bool = True,
1313
- ) -> SparkDataFrame:
1314
- if self.sample:
1315
- return self._predict_with_sampling(
1316
- log=log,
1317
- k=k,
1318
- users=users,
1319
- items=items,
1320
- filter_seen_items=filter_seen_items,
1321
- )
1322
- else:
1323
- return self._predict_without_sampling(log, k, users, items, filter_seen_items)
1324
-
1325
- def _predict_pairs(
1326
- self,
1327
- pairs: SparkDataFrame,
1328
- log: Optional[SparkDataFrame] = None, # noqa: ARG002
1329
- user_features: Optional[SparkDataFrame] = None, # noqa: ARG002
1330
- item_features: Optional[SparkDataFrame] = None, # noqa: ARG002
1331
- ) -> SparkDataFrame:
1332
- return (
1333
- pairs.join(
1334
- self.item_popularity,
1335
- on="item_idx",
1336
- how="left" if self.add_cold_items else "inner",
1337
- )
1338
- .fillna(value=self.fill, subset=["relevance"])
1339
- .select("user_idx", "item_idx", "relevance")
1340
- )
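
The hunk above removes the experimental Spark recommender hierarchy (BaseRecommender, Recommender, HybridRecommender, UserRecommender, ItemVectorModel, NonPersonalizedRecommender). For reference, the sketch below shows how a concrete subclass of the removed Recommender class was typically driven under 0.19.0rc0, following the fit/predict/predict_pairs signatures visible in the deleted code. It is illustrative only: SomeExperimentalModel is a hypothetical placeholder for any of the removed concrete models, and the snippet assumes a local Spark session and an installed 0.19.0rc0 build with its experimental extras.

# Illustrative sketch of the removed 0.19.0rc0 experimental API.
# SomeExperimentalModel is a hypothetical stand-in for a concrete subclass of
# the deleted replay.experimental.models.base_rec.Recommender.
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()

# Interaction log in the column layout expected by _fit_wrap/_predict_wrap above.
log = spark.createDataFrame(
    [(0, 1, 1, 1.0), (0, 2, 2, 1.0), (1, 1, 3, 1.0)],
    ["user_idx", "item_idx", "timestamp", "relevance"],
)

model = SomeExperimentalModel()                # hypothetical concrete model
model.fit(log)                                 # Recommender.fit(log)
recs = model.predict(log, k=2)                 # [user_idx, item_idx, relevance]
pairs = log.select("user_idx", "item_idx").distinct()
scores = model.predict_pairs(pairs, log=log)   # relevance for specific pairs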