replay-rec 0.20.2__py3-none-any.whl → 0.20.3rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- replay/__init__.py +1 -1
- replay/data/nn/sequential_dataset.py +8 -2
- replay/experimental/__init__.py +0 -0
- replay/experimental/metrics/__init__.py +62 -0
- replay/experimental/metrics/base_metric.py +603 -0
- replay/experimental/metrics/coverage.py +97 -0
- replay/experimental/metrics/experiment.py +175 -0
- replay/experimental/metrics/hitrate.py +26 -0
- replay/experimental/metrics/map.py +30 -0
- replay/experimental/metrics/mrr.py +18 -0
- replay/experimental/metrics/ncis_precision.py +31 -0
- replay/experimental/metrics/ndcg.py +49 -0
- replay/experimental/metrics/precision.py +22 -0
- replay/experimental/metrics/recall.py +25 -0
- replay/experimental/metrics/rocauc.py +49 -0
- replay/experimental/metrics/surprisal.py +90 -0
- replay/experimental/metrics/unexpectedness.py +76 -0
- replay/experimental/models/__init__.py +50 -0
- replay/experimental/models/admm_slim.py +257 -0
- replay/experimental/models/base_neighbour_rec.py +200 -0
- replay/experimental/models/base_rec.py +1386 -0
- replay/experimental/models/base_torch_rec.py +234 -0
- replay/experimental/models/cql.py +454 -0
- replay/experimental/models/ddpg.py +932 -0
- replay/experimental/models/dt4rec/__init__.py +0 -0
- replay/experimental/models/dt4rec/dt4rec.py +189 -0
- replay/experimental/models/dt4rec/gpt1.py +401 -0
- replay/experimental/models/dt4rec/trainer.py +127 -0
- replay/experimental/models/dt4rec/utils.py +264 -0
- replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
- replay/experimental/models/extensions/spark_custom_models/als_extension.py +792 -0
- replay/experimental/models/hierarchical_recommender.py +331 -0
- replay/experimental/models/implicit_wrap.py +131 -0
- replay/experimental/models/lightfm_wrap.py +303 -0
- replay/experimental/models/mult_vae.py +332 -0
- replay/experimental/models/neural_ts.py +986 -0
- replay/experimental/models/neuromf.py +406 -0
- replay/experimental/models/scala_als.py +293 -0
- replay/experimental/models/u_lin_ucb.py +115 -0
- replay/experimental/nn/data/__init__.py +1 -0
- replay/experimental/nn/data/schema_builder.py +102 -0
- replay/experimental/preprocessing/__init__.py +3 -0
- replay/experimental/preprocessing/data_preparator.py +839 -0
- replay/experimental/preprocessing/padder.py +229 -0
- replay/experimental/preprocessing/sequence_generator.py +208 -0
- replay/experimental/scenarios/__init__.py +1 -0
- replay/experimental/scenarios/obp_wrapper/__init__.py +8 -0
- replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +74 -0
- replay/experimental/scenarios/obp_wrapper/replay_offline.py +261 -0
- replay/experimental/scenarios/obp_wrapper/utils.py +85 -0
- replay/experimental/scenarios/two_stages/__init__.py +0 -0
- replay/experimental/scenarios/two_stages/reranker.py +117 -0
- replay/experimental/scenarios/two_stages/two_stages_scenario.py +757 -0
- replay/experimental/utils/__init__.py +0 -0
- replay/experimental/utils/logger.py +24 -0
- replay/experimental/utils/model_handler.py +186 -0
- replay/experimental/utils/session_handler.py +44 -0
- {replay_rec-0.20.2.dist-info → replay_rec-0.20.3rc0.dist-info}/METADATA +11 -17
- {replay_rec-0.20.2.dist-info → replay_rec-0.20.3rc0.dist-info}/RECORD +62 -7
- {replay_rec-0.20.2.dist-info → replay_rec-0.20.3rc0.dist-info}/WHEEL +0 -0
- {replay_rec-0.20.2.dist-info → replay_rec-0.20.3rc0.dist-info}/licenses/LICENSE +0 -0
- {replay_rec-0.20.2.dist-info → replay_rec-0.20.3rc0.dist-info}/licenses/NOTICE +0 -0
@@ -0,0 +1,1386 @@
"""
Base abstract classes:
- BaseRecommender - the simplest base class
- Recommender - base class for models that fit on interaction log
- HybridRecommender - base class for models that accept user or item features
- UserRecommender - base class that accepts only user features, but not item features
- NeighbourRec - base class that requires log at prediction time
- ItemVectorModel - class for models which provides items' vectors.
    Implements similar items search.
- NonPersonalizedRecommender - base class for non-personalized recommenders
    with popularity statistics
"""

from abc import ABC, abstractmethod
from collections.abc import Iterable
from os.path import join
from typing import Any, Optional, Union

import numpy as np
from numpy.random import default_rng

from replay.data import get_schema
from replay.experimental.utils.session_handler import State
from replay.models.base_rec import IsSavable
from replay.models.common import RecommenderCommons
from replay.utils import PYSPARK_AVAILABLE, PandasDataFrame, SparkDataFrame
from replay.utils.spark_utils import (
    convert2spark,
    cosine_similarity,
    filter_cold,
    get_top_k,
    get_top_k_recs,
    get_unique_entities,
    load_pickled_from_parquet,
    return_recs,
    save_picklable_to_parquet,
    vector_dot,
    vector_euclidean_distance_similarity,
)

if PYSPARK_AVAILABLE:
    from pyspark.sql import (
        Window,
        functions as sf,
    )


class BaseRecommender(RecommenderCommons, IsSavable, ABC):
    """Base recommender"""

    model: Any
    can_predict_cold_users: bool = False
    can_predict_cold_items: bool = False
    fit_users: SparkDataFrame
    fit_items: SparkDataFrame
    _num_users: int
    _num_items: int
    _user_dim_size: int
    _item_dim_size: int

    def _fit_wrap(
        self,
        log: SparkDataFrame,
        user_features: Optional[SparkDataFrame] = None,
        item_features: Optional[SparkDataFrame] = None,
    ) -> None:
        """
        Wrapper for fit to allow for fewer arguments in a model.

        :param log: historical log of interactions
            ``[user_idx, item_idx, timestamp, relevance]``
        :param user_features: user features
            ``[user_idx, timestamp]`` + feature columns
        :param item_features: item features
            ``[item_idx, timestamp]`` + feature columns
        :return:
        """
        self.logger.debug("Starting fit %s", type(self).__name__)
        if user_features is None:
            users = log.select("user_idx").distinct()
        else:
            users = log.select("user_idx").union(user_features.select("user_idx")).distinct()
        if item_features is None:
            items = log.select("item_idx").distinct()
        else:
            items = log.select("item_idx").union(item_features.select("item_idx")).distinct()
        self.fit_users = sf.broadcast(users)
        self.fit_items = sf.broadcast(items)
        self._num_users = self.fit_users.count()
        self._num_items = self.fit_items.count()
        self._user_dim_size = self.fit_users.agg({"user_idx": "max"}).first()[0] + 1
        self._item_dim_size = self.fit_items.agg({"item_idx": "max"}).first()[0] + 1
        self._fit(log, user_features, item_features)

    @abstractmethod
    def _fit(
        self,
        log: SparkDataFrame,
        user_features: Optional[SparkDataFrame] = None,
        item_features: Optional[SparkDataFrame] = None,
    ) -> None:
        """
        Inner method where model actually fits.

        :param log: historical log of interactions
            ``[user_idx, item_idx, timestamp, relevance]``
        :param user_features: user features
            ``[user_idx, timestamp]`` + feature columns
        :param item_features: item features
            ``[item_idx, timestamp]`` + feature columns
        :return:
        """

    def _filter_seen(self, recs: SparkDataFrame, log: SparkDataFrame, k: int, users: SparkDataFrame):
        """
        Filter seen items (presented in log) out of the users' recommendations.
        For each user return from `k` to `k + number of seen by user` recommendations.
        """
        users_log = log.join(users, on="user_idx")
        self._cache_model_temp_view(users_log, "filter_seen_users_log")
        num_seen = users_log.groupBy("user_idx").agg(sf.count("item_idx").alias("seen_count"))
        self._cache_model_temp_view(num_seen, "filter_seen_num_seen")

        # count maximal number of items seen by users
        max_seen = 0
        if num_seen.count() > 0:
            max_seen = num_seen.select(sf.max("seen_count")).first()[0]

        # crop recommendations to first k + max_seen items for each user
        recs = recs.withColumn(
            "temp_rank",
            sf.row_number().over(Window.partitionBy("user_idx").orderBy(sf.col("relevance").desc())),
        ).filter(sf.col("temp_rank") <= sf.lit(max_seen + k))

        # leave k + number of items seen by user recommendations in recs
        recs = (
            recs.join(num_seen, on="user_idx", how="left")
            .fillna(0)
            .filter(sf.col("temp_rank") <= sf.col("seen_count") + sf.lit(k))
            .drop("temp_rank", "seen_count")
        )

        # filter recommendations presented in interactions log
        recs = recs.join(
            users_log.withColumnRenamed("item_idx", "item")
            .withColumnRenamed("user_idx", "user")
            .select("user", "item"),
            on=(sf.col("user_idx") == sf.col("user")) & (sf.col("item_idx") == sf.col("item")),
            how="anti",
        ).drop("user", "item")

        return recs

    def _filter_log_users_items_dataframes(
        self,
        log: Optional[SparkDataFrame],
        k: int,
        users: Optional[Union[SparkDataFrame, Iterable]] = None,
        items: Optional[Union[SparkDataFrame, Iterable]] = None,
        user_features: Optional[SparkDataFrame] = None,
    ):
        """
        Returns triplet of filtered `log`, `users`, and `items`.
        Filters out cold entities (users/items) from the `users`/`items` and `log` dataframes
        if the model does not predict cold.
        Filters out duplicates from `users` and `items` dataframes,
        and excludes all columns except `user_idx` and `item_idx`.

        :param log: historical log of interactions
            ``[user_idx, item_idx, timestamp, relevance]``
        :param k: number of recommendations for each user
        :param users: users to create recommendations for
            dataframe containing ``[user_idx]`` or ``array-like``;
            if ``None``, recommend to all users from ``log``
        :param items: candidate items for recommendations
            dataframe containing ``[item_idx]`` or ``array-like``;
            if ``None``, take all items from ``log``.
            If it contains new items, ``relevance`` for them will be ``0``.
        :param user_features: user features
            ``[user_idx , timestamp]`` + feature columns
        :return: triplet of filtered `log`, `users`, and `items` dataframes.
        """
        self.logger.debug("Starting predict %s", type(self).__name__)
        user_data = users or log or user_features or self.fit_users
        users = get_unique_entities(user_data, "user_idx")
        users, log = self._filter_cold_for_predict(users, log, "user")

        item_data = items or self.fit_items
        items = get_unique_entities(item_data, "item_idx")
        items, log = self._filter_cold_for_predict(items, log, "item")
        num_items = items.count()
        if num_items < k:
            message = f"k = {k} > number of items = {num_items}"
            self.logger.debug(message)
        return log, users, items

    def _predict_wrap(
        self,
        log: Optional[SparkDataFrame],
        k: int,
        users: Optional[Union[SparkDataFrame, Iterable]] = None,
        items: Optional[Union[SparkDataFrame, Iterable]] = None,
        user_features: Optional[SparkDataFrame] = None,
        item_features: Optional[SparkDataFrame] = None,
        filter_seen_items: bool = True,
        recs_file_path: Optional[str] = None,
    ) -> Optional[SparkDataFrame]:
        """
        Predict wrapper to allow for fewer parameters in models

        :param log: historical log of interactions
            ``[user_idx, item_idx, timestamp, relevance]``
        :param k: number of recommendations for each user
        :param users: users to create recommendations for
            dataframe containing ``[user_idx]`` or ``array-like``;
            if ``None``, recommend to all users from ``log``
        :param items: candidate items for recommendations
            dataframe containing ``[item_idx]`` or ``array-like``;
            if ``None``, take all items from ``log``.
            If it contains new items, ``relevance`` for them will be ``0``.
        :param user_features: user features
            ``[user_idx , timestamp]`` + feature columns
        :param item_features: item features
            ``[item_idx , timestamp]`` + feature columns
        :param filter_seen_items: flag to remove seen items from recommendations based on ``log``.
        :param recs_file_path: save recommendations at the given absolute path as parquet file.
            If None, cached and materialized recommendations dataframe will be returned
        :return: cached recommendation dataframe with columns ``[user_idx, item_idx, relevance]``
            or None if `file_path` is provided
        """
        log, users, items = self._filter_log_users_items_dataframes(log, k, users, items)

        recs = self._predict(
            log,
            k,
            users,
            items,
            user_features,
            item_features,
            filter_seen_items,
        )
        if filter_seen_items and log:
            recs = self._filter_seen(recs=recs, log=log, users=users, k=k)

        recs = get_top_k_recs(recs, k=k).select("user_idx", "item_idx", "relevance")

        output = return_recs(recs, recs_file_path)
        self._clear_model_temp_view("filter_seen_users_log")
        self._clear_model_temp_view("filter_seen_num_seen")
        return output

    def _filter_cold_for_predict(
        self,
        main_df: SparkDataFrame,
        log_df: Optional[SparkDataFrame],
        entity: str,
        suffix: str = "idx",
    ):
        """
        Filter out cold entities (users/items) from the `main_df` and `log_df`
        if the model does not predict cold.
        Warn if cold entities are present in the `main_df`.
        """
        if getattr(self, f"can_predict_cold_{entity}s"):
            return main_df, log_df

        fit_entities = getattr(self, f"fit_{entity}s")

        num_new, main_df = filter_cold(main_df, fit_entities, col_name=f"{entity}_{suffix}")
        if num_new > 0:
            self.logger.info(
                "%s model can't predict cold %ss, they will be ignored",
                self,
                entity,
            )
        _, log_df = filter_cold(log_df, fit_entities, col_name=f"{entity}_{suffix}")
        return main_df, log_df

    @abstractmethod
    def _predict(
        self,
        log: SparkDataFrame,
        k: int,
        users: SparkDataFrame,
        items: SparkDataFrame,
        user_features: Optional[SparkDataFrame] = None,
        item_features: Optional[SparkDataFrame] = None,
        filter_seen_items: bool = True,
    ) -> SparkDataFrame:
        """
        Inner method where model actually predicts.

        :param log: historical log of interactions
            ``[user_idx, item_idx, timestamp, relevance]``
        :param k: number of recommendations for each user
        :param users: users to create recommendations for
            dataframe containing ``[user_idx]`` or ``array-like``;
            if ``None``, recommend to all users from ``log``
        :param items: candidate items for recommendations
            dataframe containing ``[item_idx]`` or ``array-like``;
            if ``None``, take all items from ``log``.
            If it contains new items, ``relevance`` for them will be ``0``.
        :param user_features: user features
            ``[user_idx , timestamp]`` + feature columns
        :param item_features: item features
            ``[item_idx , timestamp]`` + feature columns
        :param filter_seen_items: flag to remove seen items from recommendations based on ``log``.
        :return: recommendation dataframe
            ``[user_idx, item_idx, relevance]``
        """

    def _predict_proba(
        self,
        log: SparkDataFrame,
        k: int,
        users: SparkDataFrame,
        items: SparkDataFrame,
        user_features: Optional[SparkDataFrame] = None,
        item_features: Optional[SparkDataFrame] = None,
        filter_seen_items: bool = True,
    ) -> np.ndarray:
        """
        Inner method where model actually predicts probability estimates.

        Mainly used in ```OBPOfflinePolicyLearner```.

        :param log: historical log of interactions
            ``[user_idx, item_idx, timestamp, rating]``
        :param k: number of recommendations for each user
        :param users: users to create recommendations for
            dataframe containing ``[user_idx]`` or ``array-like``;
            if ``None``, recommend to all users from ``log``
        :param items: candidate items for recommendations
            dataframe containing ``[item_idx]`` or ``array-like``;
            if ``None``, take all items from ``log``.
            If it contains new items, ``rating`` for them will be ``0``.
        :param user_features: user features
            ``[user_idx , timestamp]`` + feature columns
        :param item_features: item features
            ``[item_idx , timestamp]`` + feature columns
        :param filter_seen_items: flag to remove seen items from recommendations based on ``log``.
        :return: distribution over items for each user with shape
            ``(n_users, n_items, k)``
            where we have probability for each user to choose item at fixed position(top-k).
        """

        n_users = users.select("user_idx").count()
        n_items = items.select("item_idx").count()

        recs = self._predict(
            log,
            k,
            users,
            items,
            user_features,
            item_features,
            filter_seen_items,
        )

        recs = get_top_k_recs(recs, k=k).select("user_idx", "item_idx", "relevance")

        cols = [f"k{i}" for i in range(k)]

        recs_items = (
            recs.groupBy("user_idx")
            .agg(sf.collect_list("item_idx").alias("item_idx"))
            .select([sf.col("item_idx")[i].alias(cols[i]) for i in range(k)])
        )

        action_dist = np.zeros(shape=(n_users, n_items, k))

        for i in range(k):
            action_dist[
                np.arange(n_users),
                recs_items.select(cols[i]).toPandas()[cols[i]].to_numpy(),
                np.ones(n_users, dtype=int) * i,
            ] += 1

        return action_dist

    def _filter_interactions_queries_items_dataframes(
        self,
        log: Optional[SparkDataFrame],
        k: int,
        queries: Optional[Union[SparkDataFrame, Iterable]] = None,
        items: Optional[Union[SparkDataFrame, Iterable]] = None,
    ):
        """
        Returns triplet of filtered `dataset`, `queries`, and `items`.
        Filters out cold entities (queries/items) from the `queries`/`items` and `dataset`
        if the model does not predict cold.
        Filters out duplicates from `queries` and `items` dataframes,
        and excludes all columns except `user_idx` and `item_idx`.

        :param dataset: historical interactions with query/item features
            ``[user_idx, item_idx, timestamp, rating]``
        :param k: number of recommendations for each user
        :param queries: queries to create recommendations for
            dataframe containing ``[user_idx]`` or ``array-like``;
            if ``None``, recommend to all queries from ``dataset``
        :param items: candidate items for recommendations
            dataframe containing ``[item_idx]`` or ``array-like``;
            if ``None``, take all items from ``dataset``.
            If it contains new items, ``rating`` for them will be ``0``.
        :return: triplet of filtered `dataset`, `queries`, and `items`.
        """
        self.logger.debug("Starting predict %s", type(self).__name__)

        query_data = queries or log or self.fit_users
        interactions = log

        queries = get_unique_entities(query_data, "user_idx")
        queries, interactions = self._filter_cold_for_predict(queries, interactions, "user")

        item_data = items or self.fit_items
        items = get_unique_entities(item_data, "item_idx")
        items, interactions = self._filter_cold_for_predict(items, interactions, "item")
        num_items = items.count()
        if num_items < k:
            message = f"k = {k} > number of items = {num_items}"
            self.logger.debug(message)

        return log, queries, items

    def _get_fit_counts(self, entity: str) -> int:
        if not hasattr(self, f"_num_{entity}s"):
            setattr(
                self,
                f"_num_{entity}s",
                getattr(self, f"fit_{entity}s").count(),
            )
        return getattr(self, f"_num_{entity}s")

    @property
    def users_count(self) -> int:
        """
        :returns: number of users the model was trained on
        """
        return self._get_fit_counts("user")

    @property
    def items_count(self) -> int:
        """
        :returns: number of items the model was trained on
        """
        return self._get_fit_counts("item")

    def _get_fit_dims(self, entity: str) -> int:
        if not hasattr(self, f"_{entity}_dim_size"):
            setattr(
                self,
                f"_{entity}_dim_size",
                getattr(self, f"fit_{entity}s").agg({f"{entity}_idx": "max"}).first()[0] + 1,
            )
        return getattr(self, f"_{entity}_dim_size")

    @property
    def _user_dim(self) -> int:
        """
        :returns: dimension of users matrix (maximal user idx + 1)
        """
        return self._get_fit_dims("user")

    @property
    def _item_dim(self) -> int:
        """
        :returns: dimension of items matrix (maximal item idx + 1)
        """
        return self._get_fit_dims("item")

    def _fit_predict(
        self,
        log: SparkDataFrame,
        k: int,
        users: Optional[Union[SparkDataFrame, Iterable]] = None,
        items: Optional[Union[SparkDataFrame, Iterable]] = None,
        user_features: Optional[SparkDataFrame] = None,
        item_features: Optional[SparkDataFrame] = None,
        filter_seen_items: bool = True,
        recs_file_path: Optional[str] = None,
    ) -> Optional[SparkDataFrame]:
        self._fit_wrap(log, user_features, item_features)
        return self._predict_wrap(
            log,
            k,
            users,
            items,
            user_features,
            item_features,
            filter_seen_items,
            recs_file_path=recs_file_path,
        )

    def _predict_pairs_wrap(
        self,
        pairs: SparkDataFrame,
        log: Optional[SparkDataFrame] = None,
        user_features: Optional[SparkDataFrame] = None,
        item_features: Optional[SparkDataFrame] = None,
        recs_file_path: Optional[str] = None,
        k: Optional[int] = None,
    ) -> Optional[SparkDataFrame]:
        """
        This method
        1) converts data to spark
        2) removes cold users and items if model does not predict them
        3) calls inner _predict_pairs method of a model

        :param pairs: user-item pairs to get relevance for,
            dataframe containing ``[user_idx, item_idx]``.
        :param log: train data
            ``[user_idx, item_idx, timestamp, relevance]``.
        :param recs_file_path: save recommendations at the given absolute path as parquet file.
            If None, cached and materialized recommendations dataframe will be returned
        :return: cached dataframe with columns ``[user_idx, item_idx, relevance]``
            or None if `file_path` is provided
        """
        log, user_features, item_features, pairs = (
            convert2spark(df) for df in [log, user_features, item_features, pairs]
        )
        if sorted(pairs.columns) != ["item_idx", "user_idx"]:
            msg = "pairs must be a dataframe with columns strictly [user_idx, item_idx]"
            raise ValueError(msg)
        pairs, log = self._filter_cold_for_predict(pairs, log, "user")
        pairs, log = self._filter_cold_for_predict(pairs, log, "item")

        pred = self._predict_pairs(
            pairs=pairs,
            log=log,
            user_features=user_features,
            item_features=item_features,
        )

        if k:
            pred = get_top_k(
                dataframe=pred,
                partition_by_col=sf.col("user_idx"),
                order_by_col=[
                    sf.col("relevance").desc(),
                ],
                k=k,
            )

        if recs_file_path is not None:
            pred.write.parquet(path=recs_file_path, mode="overwrite")
            return None

        pred.cache().count()
        return pred

    def _predict_pairs(
        self,
        pairs: SparkDataFrame,
        log: Optional[SparkDataFrame] = None,
        user_features: Optional[SparkDataFrame] = None,
        item_features: Optional[SparkDataFrame] = None,
    ) -> SparkDataFrame:
        """
        Fallback method to use in case ``_predict_pairs`` is not implemented.
        Simply joins ``predict`` with given ``pairs``.
        :param pairs: user-item pairs to get relevance for,
            dataframe containing ``[user_idx, item_idx]``.
        :param log: train data
            ``[user_idx, item_idx, timestamp, relevance]``.
        """
        message = (
            "native predict_pairs is not implemented for this model. "
            "Falling back to usual predict method and filtering the results."
        )
        self.logger.warning(message)

        users = pairs.select("user_idx").distinct()
        items = pairs.select("item_idx").distinct()
        k = items.count()
        pred = self._predict(
            log=log,
            k=k,
            users=users,
            items=items,
            user_features=user_features,
            item_features=item_features,
            filter_seen_items=False,
        )

        pred = pred.join(
            pairs.select("user_idx", "item_idx"),
            on=["user_idx", "item_idx"],
            how="inner",
        )
        return pred

    def _get_features_wrap(
        self, ids: SparkDataFrame, features: Optional[SparkDataFrame]
    ) -> Optional[tuple[SparkDataFrame, int]]:
        if "user_idx" not in ids.columns and "item_idx" not in ids.columns:
            msg = "user_idx or item_idx missing"
            raise ValueError(msg)
        vectors, rank = self._get_features(ids, features)
        return vectors, rank

    def _get_features(
        self, ids: SparkDataFrame, features: Optional[SparkDataFrame]  # noqa: ARG002
    ) -> tuple[Optional[SparkDataFrame], Optional[int]]:
        """
        Get embeddings from model

        :param ids: ids to get embeddings for, Spark DataFrame containing user_idx or item_idx
        :param features: user or item features
        :return: DataFrame with biases and embeddings, and vector size
        """

        self.logger.info(
            "get_features method is not defined for the model %s. Features will not be returned.",
            str(self),
        )
        return None, None

    def _get_nearest_items_wrap(
        self,
        items: Union[SparkDataFrame, Iterable],
        k: int,
        metric: Optional[str] = "cosine_similarity",
        candidates: Optional[Union[SparkDataFrame, Iterable]] = None,
    ) -> Optional[SparkDataFrame]:
        """
        Convert indexes and leave top-k nearest items for each item in `items`.
        """
        items = get_unique_entities(items, "item_idx")
        if candidates is not None:
            candidates = get_unique_entities(candidates, "item_idx")

        nearest_items_to_filter = self._get_nearest_items(
            items=items,
            metric=metric,
            candidates=candidates,
        )

        rel_col_name = metric if metric is not None else "similarity"
        nearest_items = get_top_k(
            dataframe=nearest_items_to_filter,
            partition_by_col=sf.col("item_idx_one"),
            order_by_col=[
                sf.col(rel_col_name).desc(),
                sf.col("item_idx_two").desc(),
            ],
            k=k,
        )

        nearest_items = nearest_items.withColumnRenamed("item_idx_two", "neighbour_item_idx")
        nearest_items = nearest_items.withColumnRenamed("item_idx_one", "item_idx")
        return nearest_items

    def _get_nearest_items(
        self,
        items: SparkDataFrame,
        metric: Optional[str] = None,
        candidates: Optional[SparkDataFrame] = None,
    ) -> Optional[SparkDataFrame]:
        msg = f"item-to-item prediction is not implemented for {self}"
        raise NotImplementedError(msg)

    def _save_model(self, path: str):
        pass

    def _load_model(self, path: str):
        pass


class ItemVectorModel(BaseRecommender):
    """Parent for models generating items' vector representations"""

    can_predict_item_to_item: bool = True
    item_to_item_metrics: list[str] = [
        "euclidean_distance_sim",
        "cosine_similarity",
        "dot_product",
    ]

    @abstractmethod
    def _get_item_vectors(self) -> SparkDataFrame:
        """
        Return dataframe with items' vectors as a
        spark dataframe with columns ``[item_idx, item_vector]``
        """

    def get_nearest_items(
        self,
        items: Union[SparkDataFrame, Iterable],
        k: int,
        metric: Optional[str] = "cosine_similarity",
        candidates: Optional[Union[SparkDataFrame, Iterable]] = None,
    ) -> Optional[SparkDataFrame]:
        """
        Get k most similar items by the `metric` for each of the `items`.

        :param items: spark dataframe or list of item ids to find neighbors
        :param k: number of neighbors
        :param metric: 'euclidean_distance_sim', 'cosine_similarity', 'dot_product'
        :param candidates: spark dataframe or list of items
            to consider as similar, e.g. popular/new items. If None,
            all items presented during model training are used.
        :return: dataframe with the most similar items,
            where bigger value means greater similarity.
            spark-dataframe with columns ``[item_idx, neighbour_item_idx, similarity]``
        """
        if metric not in self.item_to_item_metrics:
            msg = f"Select one of the valid distance metrics: {self.item_to_item_metrics}"
            raise ValueError(msg)

        return self._get_nearest_items_wrap(
            items=items,
            k=k,
            metric=metric,
            candidates=candidates,
        )

    def _get_nearest_items(
        self,
        items: SparkDataFrame,
        metric: str = "cosine_similarity",
        candidates: Optional[SparkDataFrame] = None,
    ) -> SparkDataFrame:
        """
        Return distance metric value for all available close items filtered by `candidates`.

        :param items: ids to find neighbours, spark dataframe with column ``item_idx``
        :param metric: 'euclidean_distance_sim' calculated as 1/(1 + euclidean_distance),
            'cosine_similarity', 'dot_product'
        :param candidates: items among which we are looking for similar,
            e.g. popular/new items. If None, all items presented during model training are used.
        :return: dataframe with neighbours,
            spark-dataframe with columns ``[item_idx_one, item_idx_two, similarity]``
        """
        dist_function = cosine_similarity
        if metric == "euclidean_distance_sim":
            dist_function = vector_euclidean_distance_similarity
        elif metric == "dot_product":
            dist_function = vector_dot

        items_vectors = self._get_item_vectors()
        left_part = (
            items_vectors.withColumnRenamed("item_idx", "item_idx_one")
            .withColumnRenamed("item_vector", "item_vector_one")
            .join(
                items.select(sf.col("item_idx").alias("item_idx_one")),
                on="item_idx_one",
            )
        )

        right_part = items_vectors.withColumnRenamed("item_idx", "item_idx_two").withColumnRenamed(
            "item_vector", "item_vector_two"
        )

        if candidates is not None:
            right_part = right_part.join(
                candidates.withColumnRenamed("item_idx", "item_idx_two"),
                on="item_idx_two",
            )

        joined_factors = left_part.join(right_part, on=sf.col("item_idx_one") != sf.col("item_idx_two"))

        joined_factors = joined_factors.withColumn(
            metric,
            dist_function(sf.col("item_vector_one"), sf.col("item_vector_two")),
        )

        similarity_matrix = joined_factors.select("item_idx_one", "item_idx_two", metric)

        return similarity_matrix


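# Illustrative usage sketch for ItemVectorModel.get_nearest_items above (not part of the
# package source): `model` is assumed to be a fitted subclass that implements
# `_get_item_vectors`; the item ids are placeholders.
#
#     similar = model.get_nearest_items(
#         items=[1, 2, 3],               # item ids to find neighbours for
#         k=5,                           # neighbours per item
#         metric="cosine_similarity",
#     )
#     # -> spark dataframe with columns [item_idx, neighbour_item_idx, cosine_similarity]

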
class HybridRecommender(BaseRecommender, ABC):
    """Base class for models that can use extra features"""

    def fit(
        self,
        log: SparkDataFrame,
        user_features: Optional[SparkDataFrame] = None,
        item_features: Optional[SparkDataFrame] = None,
    ) -> None:
        """
        Fit a recommendation model

        :param log: historical log of interactions
            ``[user_idx, item_idx, timestamp, relevance]``
        :param user_features: user features
            ``[user_idx, timestamp]`` + feature columns
        :param item_features: item features
            ``[item_idx, timestamp]`` + feature columns
        :return:
        """
        self._fit_wrap(
            log=log,
            user_features=user_features,
            item_features=item_features,
        )

    def predict(
        self,
        log: SparkDataFrame,
        k: int,
        users: Optional[Union[SparkDataFrame, Iterable]] = None,
        items: Optional[Union[SparkDataFrame, Iterable]] = None,
        user_features: Optional[SparkDataFrame] = None,
        item_features: Optional[SparkDataFrame] = None,
        filter_seen_items: bool = True,
        recs_file_path: Optional[str] = None,
    ) -> Optional[SparkDataFrame]:
        """
        Get recommendations

        :param log: historical log of interactions
            ``[user_idx, item_idx, timestamp, relevance]``
        :param k: number of recommendations for each user
        :param users: users to create recommendations for
            dataframe containing ``[user_idx]`` or ``array-like``;
            if ``None``, recommend to all users from ``log``
        :param items: candidate items for recommendations
            dataframe containing ``[item_idx]`` or ``array-like``;
            if ``None``, take all items from ``log``.
            If it contains new items, ``relevance`` for them will be ``0``.
        :param user_features: user features
            ``[user_idx , timestamp]`` + feature columns
        :param item_features: item features
            ``[item_idx , timestamp]`` + feature columns
        :param filter_seen_items: flag to remove seen items from recommendations based on ``log``.
        :param recs_file_path: save recommendations at the given absolute path as parquet file.
            If None, cached and materialized recommendations dataframe will be returned
        :return: cached recommendation dataframe with columns ``[user_idx, item_idx, relevance]``
            or None if `file_path` is provided
        """
        return self._predict_wrap(
            log=log,
            k=k,
            users=users,
            items=items,
            user_features=user_features,
            item_features=item_features,
            filter_seen_items=filter_seen_items,
            recs_file_path=recs_file_path,
        )

    def fit_predict(
        self,
        log: SparkDataFrame,
        k: int,
        users: Optional[Union[SparkDataFrame, Iterable]] = None,
        items: Optional[Union[SparkDataFrame, Iterable]] = None,
        user_features: Optional[SparkDataFrame] = None,
        item_features: Optional[SparkDataFrame] = None,
        filter_seen_items: bool = True,
        recs_file_path: Optional[str] = None,
    ) -> Optional[SparkDataFrame]:
        """
        Fit model and get recommendations

        :param log: historical log of interactions
            ``[user_idx, item_idx, timestamp, relevance]``
        :param k: number of recommendations for each user
        :param users: users to create recommendations for
            dataframe containing ``[user_idx]`` or ``array-like``;
            if ``None``, recommend to all users from ``log``
        :param items: candidate items for recommendations
            dataframe containing ``[item_idx]`` or ``array-like``;
            if ``None``, take all items from ``log``.
            If it contains new items, ``relevance`` for them will be ``0``.
        :param user_features: user features
            ``[user_idx , timestamp]`` + feature columns
        :param item_features: item features
            ``[item_idx , timestamp]`` + feature columns
        :param filter_seen_items: flag to remove seen items from recommendations based on ``log``.
        :param recs_file_path: save recommendations at the given absolute path as parquet file.
            If None, cached and materialized recommendations dataframe will be returned
        :return: cached recommendation dataframe with columns ``[user_idx, item_idx, relevance]``
            or None if `file_path` is provided
        """
        return self._fit_predict(
            log=log,
            k=k,
            users=users,
            items=items,
            user_features=user_features,
            item_features=item_features,
            filter_seen_items=filter_seen_items,
            recs_file_path=recs_file_path,
        )

    def predict_pairs(
        self,
        pairs: SparkDataFrame,
        log: Optional[SparkDataFrame] = None,
        user_features: Optional[SparkDataFrame] = None,
        item_features: Optional[SparkDataFrame] = None,
        recs_file_path: Optional[str] = None,
        k: Optional[int] = None,
    ) -> Optional[SparkDataFrame]:
        """
        Get recommendations for specific user-item ``pairs``.
        If a model can't produce a recommendation
        for a specific pair, it is removed from the resulting dataframe.

        :param pairs: dataframe with pairs to calculate relevance for, ``[user_idx, item_idx]``.
        :param log: historical log of interactions
            ``[user_idx, item_idx, timestamp, relevance]``
        :param user_features: user features
            ``[user_idx , timestamp]`` + feature columns
        :param item_features: item features
            ``[item_idx , timestamp]`` + feature columns
        :param recs_file_path: save recommendations at the given absolute path as parquet file.
            If None, cached and materialized recommendations dataframe will be returned
        :param k: top-k items for each user from pairs.
        :return: cached recommendation dataframe with columns ``[user_idx, item_idx, relevance]``
            or None if `file_path` is provided
        """
        return self._predict_pairs_wrap(
            pairs=pairs,
            log=log,
            user_features=user_features,
            item_features=item_features,
            recs_file_path=recs_file_path,
            k=k,
        )

    def get_features(
        self, ids: SparkDataFrame, features: Optional[SparkDataFrame]
    ) -> Optional[tuple[SparkDataFrame, int]]:
        """
        Returns user or item feature vectors as a Column with type ArrayType
        :param ids: Spark DataFrame with unique ids
        :param features: Spark DataFrame with features for provided ids
        :return: feature vectors
            If a model does not have a vector for some ids they are not present in the final result.
        """
        return self._get_features_wrap(ids, features)


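# Illustrative usage sketch for HybridRecommender subclasses defined above (not part of the
# package source): `SomeHybridModel`, `log`, `user_features` and `item_features` are
# placeholders, not names from this package.
#
#     model = SomeHybridModel()
#     model.fit(log, user_features=user_features, item_features=item_features)
#     recs = model.predict(
#         log,
#         k=10,
#         user_features=user_features,
#         item_features=item_features,
#         filter_seen_items=True,
#     )
#     # -> spark dataframe with columns [user_idx, item_idx, relevance]

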
class Recommender(BaseRecommender, ABC):
    """Usual recommender class for models without features."""

    def fit(self, log: SparkDataFrame) -> None:
        """
        Fit a recommendation model

        :param log: historical log of interactions
            ``[user_idx, item_idx, timestamp, relevance]``
        :return:
        """
        self._fit_wrap(
            log=log,
            user_features=None,
            item_features=None,
        )

    def predict(
        self,
        log: SparkDataFrame,
        k: int,
        users: Optional[Union[SparkDataFrame, Iterable]] = None,
        items: Optional[Union[SparkDataFrame, Iterable]] = None,
        filter_seen_items: bool = True,
        recs_file_path: Optional[str] = None,
    ) -> Optional[SparkDataFrame]:
        """
        Get recommendations

        :param log: historical log of interactions
            ``[user_idx, item_idx, timestamp, relevance]``
        :param k: number of recommendations for each user
        :param users: users to create recommendations for
            dataframe containing ``[user_idx]`` or ``array-like``;
            if ``None``, recommend to all users from ``log``
        :param items: candidate items for recommendations
            dataframe containing ``[item_idx]`` or ``array-like``;
            if ``None``, take all items from ``log``.
            If it contains new items, ``relevance`` for them will be ``0``.
        :param filter_seen_items: flag to remove seen items from recommendations based on ``log``.
        :param recs_file_path: save recommendations at the given absolute path as parquet file.
            If None, cached and materialized recommendations dataframe will be returned
        :return: cached recommendation dataframe with columns ``[user_idx, item_idx, relevance]``
            or None if `file_path` is provided
        """
        return self._predict_wrap(
            log=log,
            k=k,
            users=users,
            items=items,
            user_features=None,
            item_features=None,
            filter_seen_items=filter_seen_items,
            recs_file_path=recs_file_path,
        )

    def predict_pairs(
        self,
        pairs: SparkDataFrame,
        log: Optional[SparkDataFrame] = None,
        recs_file_path: Optional[str] = None,
        k: Optional[int] = None,
    ) -> Optional[SparkDataFrame]:
        """
        Get recommendations for specific user-item ``pairs``.
        If a model can't produce a recommendation
        for a specific pair, it is removed from the resulting dataframe.

        :param pairs: dataframe with pairs to calculate relevance for, ``[user_idx, item_idx]``.
        :param log: historical log of interactions
            ``[user_idx, item_idx, timestamp, relevance]``
        :param recs_file_path: save recommendations at the given absolute path as parquet file.
            If None, cached and materialized recommendations dataframe will be returned
        :param k: top-k items for each user from pairs.
        :return: cached recommendation dataframe with columns ``[user_idx, item_idx, relevance]``
            or None if `file_path` is provided
        """
        return self._predict_pairs_wrap(
            pairs=pairs,
            log=log,
            recs_file_path=recs_file_path,
            k=k,
        )

    def fit_predict(
        self,
        log: SparkDataFrame,
        k: int,
        users: Optional[Union[SparkDataFrame, Iterable]] = None,
        items: Optional[Union[SparkDataFrame, Iterable]] = None,
        filter_seen_items: bool = True,
        recs_file_path: Optional[str] = None,
    ) -> Optional[SparkDataFrame]:
        """
        Fit model and get recommendations

        :param log: historical log of interactions
            ``[user_idx, item_idx, timestamp, relevance]``
        :param k: number of recommendations for each user
        :param users: users to create recommendations for
            dataframe containing ``[user_idx]`` or ``array-like``;
            if ``None``, recommend to all users from ``log``
        :param items: candidate items for recommendations
            dataframe containing ``[item_idx]`` or ``array-like``;
            if ``None``, take all items from ``log``.
            If it contains new items, ``relevance`` for them will be ``0``.
        :param filter_seen_items: flag to remove seen items from recommendations based on ``log``.
        :param recs_file_path: save recommendations at the given absolute path as parquet file.
            If None, cached and materialized recommendations dataframe will be returned
        :return: cached recommendation dataframe with columns ``[user_idx, item_idx, relevance]``
            or None if `file_path` is provided
        """
        return self._fit_predict(
            log=log,
            k=k,
            users=users,
            items=items,
            user_features=None,
            item_features=None,
            filter_seen_items=filter_seen_items,
            recs_file_path=recs_file_path,
        )

    def get_features(self, ids: SparkDataFrame) -> Optional[tuple[SparkDataFrame, int]]:
        """
        Returns user or item feature vectors as a Column with type ArrayType

        :param ids: Spark DataFrame with unique ids
        :return: feature vectors.
            If a model does not have a vector for some ids they are not present in the final result.
        """
        return self._get_features_wrap(ids, None)


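# Minimal sketch of a Recommender subclass wired through the abstract hooks above
# (_fit / _predict). It is illustrative only; `ConstantRec` does not exist in this package.
#
#     class ConstantRec(Recommender):
#         def _fit(self, log, user_features=None, item_features=None) -> None:
#             self.model = None  # nothing to learn
#
#         def _predict(self, log, k, users, items, user_features=None,
#                      item_features=None, filter_seen_items=True):
#             # score every candidate item equally for every requested user;
#             # _predict_wrap then filters seen items and keeps the top-k rows
#             return users.crossJoin(items).withColumn("relevance", sf.lit(1.0))

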
class UserRecommender(BaseRecommender, ABC):
    """Base class for models that use user features
    but not item features. ``log`` is not required for this class."""

    def fit(
        self,
        log: SparkDataFrame,
        user_features: SparkDataFrame,
    ) -> None:
        """
        Finds user clusters and calculates item similarity in that clusters.

        :param log: historical log of interactions
            ``[user_idx, item_idx, timestamp, relevance]``
        :param user_features: user features
            ``[user_idx, timestamp]`` + feature columns
        :return:
        """
        self._fit_wrap(log=log, user_features=user_features)

    def predict(
        self,
        user_features: SparkDataFrame,
        k: int,
        log: Optional[SparkDataFrame] = None,
        users: Optional[Union[SparkDataFrame, Iterable]] = None,
        items: Optional[Union[SparkDataFrame, Iterable]] = None,
        filter_seen_items: bool = True,
        recs_file_path: Optional[str] = None,
    ) -> Optional[SparkDataFrame]:
        """
        Get recommendations

        :param log: historical log of interactions
            ``[user_idx, item_idx, timestamp, relevance]``
        :param k: number of recommendations for each user
        :param users: users to create recommendations for
            dataframe containing ``[user_idx]`` or ``array-like``;
            if ``None``, recommend to all users from ``log``
        :param items: candidate items for recommendations
            dataframe containing ``[item_idx]`` or ``array-like``;
            if ``None``, take all items from ``log``.
            If it contains new items, ``relevance`` for them will be ``0``.
        :param user_features: user features
            ``[user_idx , timestamp]`` + feature columns
        :param filter_seen_items: flag to remove seen items from recommendations based on ``log``.
        :param recs_file_path: save recommendations at the given absolute path as parquet file.
            If None, cached and materialized recommendations dataframe will be returned
        :return: cached recommendation dataframe with columns ``[user_idx, item_idx, relevance]``
            or None if `file_path` is provided
        """
        return self._predict_wrap(
            log=log,
            user_features=user_features,
            k=k,
            filter_seen_items=filter_seen_items,
            users=users,
            items=items,
            recs_file_path=recs_file_path,
        )

    def predict_pairs(
        self,
        pairs: SparkDataFrame,
        user_features: SparkDataFrame,
        log: Optional[SparkDataFrame] = None,
        recs_file_path: Optional[str] = None,
        k: Optional[int] = None,
    ) -> Optional[SparkDataFrame]:
        """
        Get recommendations for specific user-item ``pairs``.
        If a model can't produce a recommendation
        for a specific pair, it is removed from the resulting dataframe.

        :param pairs: dataframe with pairs to calculate relevance for, ``[user_idx, item_idx]``.
        :param user_features: user features
            ``[user_idx , timestamp]`` + feature columns
        :param log: historical log of interactions
            ``[user_idx, item_idx, timestamp, relevance]``
        :param recs_file_path: save recommendations at the given absolute path as parquet file.
            If None, cached and materialized recommendations dataframe will be returned
        :param k: top-k items for each user from pairs.
        :return: cached recommendation dataframe with columns ``[user_idx, item_idx, relevance]``
            or None if `file_path` is provided
        """
        return self._predict_pairs_wrap(
            pairs=pairs,
            log=log,
            user_features=user_features,
            recs_file_path=recs_file_path,
            k=k,
        )


class NonPersonalizedRecommender(Recommender, ABC):
    """Base class for non-personalized recommenders with popularity statistics."""

    can_predict_cold_users = True
    can_predict_cold_items = True
    item_popularity: SparkDataFrame
    add_cold_items: bool
    cold_weight: float
    sample: bool
    fill: float
    seed: Optional[int] = None

    def __init__(self, add_cold_items: bool, cold_weight: float):
        self.add_cold_items = add_cold_items
        if 0 < cold_weight <= 1:
            self.cold_weight = cold_weight
        else:
            msg = "`cold_weight` value should be in interval (0, 1]"
            raise ValueError(msg)

    @property
    def _dataframes(self):
        return {"item_popularity": self.item_popularity}

    def _save_model(self, path: str):
        save_picklable_to_parquet(self.fill, join(path, "params.dump"))

    def _load_model(self, path: str):
        self.fill = load_pickled_from_parquet(join(path, "params.dump"))

    def _clear_cache(self):
        if hasattr(self, "item_popularity"):
            self.item_popularity.unpersist()

    @staticmethod
    def _calc_fill(item_popularity: SparkDataFrame, weight: float) -> float:
        """
        Calculate the fill value as the minimal relevance
        calculated during model training multiplied by weight.
        """
        return item_popularity.select(sf.min("relevance")).first()[0] * weight

    @staticmethod
    def _check_relevance(log: SparkDataFrame):
        vals = log.select("relevance").where((sf.col("relevance") != 1) & (sf.col("relevance") != 0))
        if vals.count() > 0:
            msg = "Relevance values in log must be 0 or 1"
            raise ValueError(msg)

    def _get_selected_item_popularity(self, items: SparkDataFrame) -> SparkDataFrame:
        """
        Choose only the required items from the `item_popularity` dataframe
        for further recommendation generation.
        """
        return self.item_popularity.join(
            items,
            on="item_idx",
            how="right" if self.add_cold_items else "inner",
        ).fillna(value=self.fill, subset=["relevance"])

    @staticmethod
    def _calc_max_hist_len(log: SparkDataFrame, users: SparkDataFrame) -> int:
        max_hist_len = (
            (log.join(users, on="user_idx").groupBy("user_idx").agg(sf.countDistinct("item_idx").alias("items_count")))
            .select(sf.max("items_count"))
            .first()[0]
        )
        # all users have empty history
        if max_hist_len is None:
            max_hist_len = 0

        return max_hist_len

    def _predict_without_sampling(
        self,
        log: SparkDataFrame,
        k: int,
        users: SparkDataFrame,
        items: SparkDataFrame,
        filter_seen_items: bool = True,
    ) -> SparkDataFrame:
        """
        Regular prediction for popularity-based models,
        top-k most relevant items from `items` are chosen for each user
        """
        selected_item_popularity = self._get_selected_item_popularity(items)
        selected_item_popularity = selected_item_popularity.withColumn(
            "rank",
            sf.row_number().over(Window.orderBy(sf.col("relevance").desc(), sf.col("item_idx").desc())),
        )

        if filter_seen_items and log is not None:
            user_to_num_items = (
                log.join(users, on="user_idx").groupBy("user_idx").agg(sf.countDistinct("item_idx").alias("num_items"))
            )
            users = users.join(user_to_num_items, on="user_idx", how="left")
            users = users.fillna(0, "num_items")
            # 'selected_item_popularity' truncation by k + max_seen
            max_seen = users.select(sf.coalesce(sf.max("num_items"), sf.lit(0))).first()[0]
            selected_item_popularity = selected_item_popularity.filter(sf.col("rank") <= k + max_seen)
            return users.join(selected_item_popularity, on=(sf.col("rank") <= k + sf.col("num_items")), how="left")

        return users.crossJoin(selected_item_popularity.filter(sf.col("rank") <= k)).drop("rank")

    def _predict_with_sampling(
        self,
        log: SparkDataFrame,
        k: int,
        users: SparkDataFrame,
        items: SparkDataFrame,
        filter_seen_items: bool = True,
    ) -> SparkDataFrame:
        """
        Randomized prediction for popularity-based models,
        top-k items from `items` are sampled for each user with
        probability proportional to items' popularity
        """
        selected_item_popularity = self._get_selected_item_popularity(items)
        selected_item_popularity = selected_item_popularity.withColumn(
            "relevance",
            sf.when(sf.col("relevance") == sf.lit(0.0), 0.1**6).otherwise(sf.col("relevance")),
        )

        items_pd = selected_item_popularity.withColumn(
            "probability",
            sf.col("relevance") / selected_item_popularity.select(sf.sum("relevance")).first()[0],
        ).toPandas()

        rec_schema = get_schema(
            query_column="user_idx",
            item_column="item_idx",
            rating_column="relevance",
            has_timestamp=False,
        )
        if items_pd.shape[0] == 0:
            return State().session.createDataFrame([], rec_schema)

        seed = self.seed
        class_name = self.__class__.__name__

        def grouped_map(pandas_df: PandasDataFrame) -> PandasDataFrame:
            user_idx = pandas_df["user_idx"][0]
            cnt = pandas_df["cnt"][0]

            local_rng = default_rng(seed + user_idx) if seed is not None else default_rng()

            items_positions = local_rng.choice(
                np.arange(items_pd.shape[0]),
                size=cnt,
                p=items_pd["probability"].values,
                replace=False,
            )

            # workaround to unify RandomRec and UCB
            if class_name == "RandomRec":
                relevance = 1 / np.arange(1, cnt + 1)
            else:
                relevance = items_pd["probability"].values[items_positions]

            return PandasDataFrame(
                {
                    "user_idx": cnt * [user_idx],
                    "item_idx": items_pd["item_idx"].values[items_positions],
                    "relevance": relevance,
                }
            )

        if log is not None and filter_seen_items:
            recs = (
                log.select("user_idx", "item_idx")
                .distinct()
                .join(users, how="right", on="user_idx")
                .groupby("user_idx")
                .agg(sf.countDistinct("item_idx").alias("cnt"))
                .selectExpr(
                    "user_idx",
                    f"LEAST(cnt + {k}, {items_pd.shape[0]}) AS cnt",
                )
            )
        else:
            recs = users.withColumn("cnt", sf.lit(min(k, items_pd.shape[0])))

        return recs.groupby("user_idx").applyInPandas(grouped_map, rec_schema)

    def _predict(
        self,
        log: SparkDataFrame,
        k: int,
        users: SparkDataFrame,
        items: SparkDataFrame,
        user_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
        item_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
        filter_seen_items: bool = True,
    ) -> SparkDataFrame:
        if self.sample:
            return self._predict_with_sampling(
                log=log,
                k=k,
                users=users,
                items=items,
                filter_seen_items=filter_seen_items,
            )
        else:
            return self._predict_without_sampling(log, k, users, items, filter_seen_items)

    def _predict_pairs(
        self,
        pairs: SparkDataFrame,
        log: Optional[SparkDataFrame] = None,  # noqa: ARG002
        user_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
        item_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
    ) -> SparkDataFrame:
        return (
            pairs.join(
                self.item_popularity,
                on="item_idx",
                how="left" if self.add_cold_items else "inner",
            )
            .fillna(value=self.fill, subset=["relevance"])
            .select("user_idx", "item_idx", "relevance")
        )