replay-rec 0.19.0rc0__py3-none-any.whl → 0.20.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- replay/__init__.py +6 -2
- replay/data/dataset.py +9 -9
- replay/data/nn/__init__.py +6 -6
- replay/data/nn/sequence_tokenizer.py +44 -38
- replay/data/nn/sequential_dataset.py +13 -8
- replay/data/nn/torch_sequential_dataset.py +14 -13
- replay/data/nn/utils.py +1 -1
- replay/metrics/base_metric.py +1 -1
- replay/metrics/coverage.py +7 -11
- replay/metrics/experiment.py +3 -3
- replay/metrics/offline_metrics.py +2 -2
- replay/models/__init__.py +19 -0
- replay/models/association_rules.py +1 -4
- replay/models/base_neighbour_rec.py +6 -9
- replay/models/base_rec.py +44 -293
- replay/models/cat_pop_rec.py +2 -1
- replay/models/common.py +69 -0
- replay/models/extensions/ann/ann_mixin.py +30 -25
- replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +1 -1
- replay/models/extensions/ann/utils.py +4 -3
- replay/models/knn.py +18 -17
- replay/models/nn/sequential/bert4rec/dataset.py +1 -1
- replay/models/nn/sequential/callbacks/prediction_callbacks.py +2 -2
- replay/models/nn/sequential/compiled/__init__.py +10 -0
- replay/models/nn/sequential/compiled/base_compiled_model.py +3 -1
- replay/models/nn/sequential/compiled/bert4rec_compiled.py +11 -2
- replay/models/nn/sequential/compiled/sasrec_compiled.py +5 -1
- replay/models/nn/sequential/sasrec/dataset.py +1 -1
- replay/models/nn/sequential/sasrec/model.py +1 -1
- replay/models/optimization/__init__.py +14 -0
- replay/models/optimization/optuna_mixin.py +279 -0
- replay/{optimization → models/optimization}/optuna_objective.py +13 -15
- replay/models/slim.py +2 -4
- replay/models/word2vec.py +7 -12
- replay/preprocessing/discretizer.py +1 -2
- replay/preprocessing/history_based_fp.py +1 -1
- replay/preprocessing/label_encoder.py +1 -1
- replay/splitters/cold_user_random_splitter.py +13 -7
- replay/splitters/last_n_splitter.py +17 -10
- replay/utils/__init__.py +6 -2
- replay/utils/common.py +4 -2
- replay/utils/model_handler.py +11 -31
- replay/utils/session_handler.py +2 -2
- replay/utils/spark_utils.py +2 -2
- replay/utils/types.py +28 -18
- replay/utils/warnings.py +26 -0
- {replay_rec-0.19.0rc0.dist-info → replay_rec-0.20.0.dist-info}/METADATA +56 -40
- replay_rec-0.20.0.dist-info/RECORD +139 -0
- {replay_rec-0.19.0rc0.dist-info → replay_rec-0.20.0.dist-info}/WHEEL +1 -1
- replay/experimental/__init__.py +0 -0
- replay/experimental/metrics/__init__.py +0 -62
- replay/experimental/metrics/base_metric.py +0 -602
- replay/experimental/metrics/coverage.py +0 -97
- replay/experimental/metrics/experiment.py +0 -175
- replay/experimental/metrics/hitrate.py +0 -26
- replay/experimental/metrics/map.py +0 -30
- replay/experimental/metrics/mrr.py +0 -18
- replay/experimental/metrics/ncis_precision.py +0 -31
- replay/experimental/metrics/ndcg.py +0 -49
- replay/experimental/metrics/precision.py +0 -22
- replay/experimental/metrics/recall.py +0 -25
- replay/experimental/metrics/rocauc.py +0 -49
- replay/experimental/metrics/surprisal.py +0 -90
- replay/experimental/metrics/unexpectedness.py +0 -76
- replay/experimental/models/__init__.py +0 -13
- replay/experimental/models/admm_slim.py +0 -205
- replay/experimental/models/base_neighbour_rec.py +0 -204
- replay/experimental/models/base_rec.py +0 -1340
- replay/experimental/models/base_torch_rec.py +0 -234
- replay/experimental/models/cql.py +0 -454
- replay/experimental/models/ddpg.py +0 -923
- replay/experimental/models/dt4rec/__init__.py +0 -0
- replay/experimental/models/dt4rec/dt4rec.py +0 -189
- replay/experimental/models/dt4rec/gpt1.py +0 -401
- replay/experimental/models/dt4rec/trainer.py +0 -127
- replay/experimental/models/dt4rec/utils.py +0 -265
- replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
- replay/experimental/models/extensions/spark_custom_models/als_extension.py +0 -792
- replay/experimental/models/hierarchical_recommender.py +0 -331
- replay/experimental/models/implicit_wrap.py +0 -131
- replay/experimental/models/lightfm_wrap.py +0 -302
- replay/experimental/models/mult_vae.py +0 -332
- replay/experimental/models/neural_ts.py +0 -986
- replay/experimental/models/neuromf.py +0 -406
- replay/experimental/models/scala_als.py +0 -296
- replay/experimental/models/u_lin_ucb.py +0 -115
- replay/experimental/nn/data/__init__.py +0 -1
- replay/experimental/nn/data/schema_builder.py +0 -102
- replay/experimental/preprocessing/__init__.py +0 -3
- replay/experimental/preprocessing/data_preparator.py +0 -839
- replay/experimental/preprocessing/padder.py +0 -229
- replay/experimental/preprocessing/sequence_generator.py +0 -208
- replay/experimental/scenarios/__init__.py +0 -1
- replay/experimental/scenarios/obp_wrapper/__init__.py +0 -8
- replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -74
- replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -261
- replay/experimental/scenarios/obp_wrapper/utils.py +0 -87
- replay/experimental/scenarios/two_stages/__init__.py +0 -0
- replay/experimental/scenarios/two_stages/reranker.py +0 -117
- replay/experimental/scenarios/two_stages/two_stages_scenario.py +0 -757
- replay/experimental/utils/__init__.py +0 -0
- replay/experimental/utils/logger.py +0 -24
- replay/experimental/utils/model_handler.py +0 -186
- replay/experimental/utils/session_handler.py +0 -44
- replay/optimization/__init__.py +0 -5
- replay_rec-0.19.0rc0.dist-info/RECORD +0 -191
- {replay_rec-0.19.0rc0.dist-info → replay_rec-0.20.0.dist-info/licenses}/LICENSE +0 -0
- {replay_rec-0.19.0rc0.dist-info → replay_rec-0.20.0.dist-info/licenses}/NOTICE +0 -0
replay/models/base_neighbour_rec.py
CHANGED

@@ -7,7 +7,7 @@ from abc import ABC
 from typing import Any, Dict, Iterable, Optional, Union
 
 from replay.data.dataset import Dataset
-from replay.utils import PYSPARK_AVAILABLE,
+from replay.utils import PYSPARK_AVAILABLE, MissingImport, SparkDataFrame
 
 from .base_rec import Recommender
 from .extensions.ann.ann_mixin import ANNMixin
@@ -16,10 +16,10 @@ if PYSPARK_AVAILABLE:
     from pyspark.sql import functions as sf
     from pyspark.sql.column import Column
 else:
-    Column =
+    Column = MissingImport
 
 
-class NeighbourRec(
+class NeighbourRec(ANNMixin, Recommender, ABC):
     """Base class that requires interactions at prediction time"""
 
     similarity: Optional[SparkDataFrame]
@@ -187,16 +187,13 @@ class NeighbourRec(Recommender, ANNMixin, ABC):
             "similarity" if metric is None else metric,
         )
 
-    def
+    def _configure_index_builder(self, interactions: SparkDataFrame) -> Dict[str, Any]:
+        similarity_df = self.similarity.select("similarity", "item_idx_one", "item_idx_two")
         self.index_builder.index_params.items_count = interactions.select(sf.max(self.item_column)).first()[0] + 1
-        return {
+        return similarity_df, {
             "features_col": None,
         }
 
-    def _get_vectors_to_build_ann(self, interactions: SparkDataFrame) -> SparkDataFrame:  # noqa: ARG002
-        similarity_df = self.similarity.select("similarity", "item_idx_one", "item_idx_two")
-        return similarity_df
-
     def _get_vectors_to_infer_ann_inner(
         self, interactions: SparkDataFrame, queries: SparkDataFrame  # noqa: ARG002
     ) -> SparkDataFrame:
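
The two removed hooks above (`_get_vectors_to_build_ann` and the separate build-params method) are folded into the single `_configure_index_builder` call. A minimal sketch of the new contract, assuming only what this hunk and the `ann_mixin.py` hunk further down show; names other than the method itself are placeholders:

    # Sketch of the merged hook: one call returns the vectors *and* the
    # keyword arguments that ANNMixin passes to IndexBuilder.build_index.
    from typing import Any, Dict, Tuple

    def _configure_index_builder(self, interactions) -> Tuple["SparkDataFrame", Dict[str, Any]]:
        # vectors used to build the index (previously _get_vectors_to_build_ann)
        vectors = self.similarity.select("similarity", "item_idx_one", "item_idx_two")
        # build parameters (previously returned by the separate params hook)
        return vectors, {"features_col": None}

    # consumed by ANNMixin._fit_wrap (see the ann_mixin.py hunk below):
    #     vectors, ann_params = self._configure_index_builder(dataset.interactions)
    #     self.index_builder.build_index(vectors, **ann_params)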
replay/models/base_rec.py
CHANGED

@@ -11,22 +11,18 @@ Base abstract classes:
     with popularity statistics
 """
 
-import logging
 import warnings
 from abc import ABC, abstractmethod
-from copy import deepcopy
 from os.path import join
-from typing import Any, Dict, Iterable, List, Optional,
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import numpy as np
 import pandas as pd
 from numpy.random import default_rng
-from optuna import create_study
-from optuna.samplers import TPESampler
 
 from replay.data import Dataset, get_schema
-from replay.
-from replay.optimization
+from replay.models.common import RecommenderCommons
+from replay.models.optimization import IsOptimizible
 from replay.utils import PYSPARK_AVAILABLE, PandasDataFrame, SparkDataFrame
 from replay.utils.session_handler import State
 from replay.utils.spark_utils import SparkCollectToMasterWarning
@@ -38,10 +34,8 @@ if PYSPARK_AVAILABLE:
     )
 
     from replay.utils.spark_utils import (
-        cache_temp_view,
         convert2spark,
         cosine_similarity,
-        drop_temp_view,
         filter_cold,
         get_top_k,
         get_top_k_recs,
@@ -88,80 +82,12 @@ class IsSavable(ABC):
     """
 
 
-class RecommenderCommons:
-    """
-    Common methods and attributes of RePlay models for caching, setting parameters and logging
-    """
-
-    _logger: Optional[logging.Logger] = None
-    cached_dfs: Optional[Set] = None
-    query_column: str
-    item_column: str
-    rating_column: str
-    timestamp_column: str
-
-    def set_params(self, **params: Dict[str, Any]) -> None:
-        """
-        Set model parameters
-
-        :param params: dictionary param name - param value
-        :return:
-        """
-        for param, value in params.items():
-            setattr(self, param, value)
-        self._clear_cache()
-
-    def _clear_cache(self):
-        """
-        Clear spark cache
-        """
-
-    def __str__(self):
-        return type(self).__name__
-
-    @property
-    def logger(self) -> logging.Logger:
-        """
-        :returns: get library logger
-        """
-        if self._logger is None:
-            self._logger = logging.getLogger("replay")
-        return self._logger
-
-    def _cache_model_temp_view(self, df: SparkDataFrame, df_name: str) -> None:
-        """
-        Create Spark SQL temporary view for df, cache it and add temp view name to self.cached_dfs.
-        Temp view name is : "id_<python object id>_model_<RePlay model name>_<df_name>"
-        """
-        full_name = f"id_{id(self)}_model_{self!s}_{df_name}"
-        cache_temp_view(df, full_name)
-
-        if self.cached_dfs is None:
-            self.cached_dfs = set()
-        self.cached_dfs.add(full_name)
-
-    def _clear_model_temp_view(self, df_name: str) -> None:
-        """
-        Uncache and drop Spark SQL temporary view and remove from self.cached_dfs
-        Temp view to replace will be constructed as
-        "id_<python object id>_model_<RePlay model name>_<df_name>"
-        """
-        full_name = f"id_{id(self)}_model_{self!s}_{df_name}"
-        drop_temp_view(full_name)
-        if self.cached_dfs is not None:
-            self.cached_dfs.discard(full_name)
-
-
-class BaseRecommender(RecommenderCommons, IsSavable, ABC):
+class BaseRecommender(IsSavable, IsOptimizible, RecommenderCommons, ABC):
     """Base recommender"""
 
     model: Any
     can_predict_cold_queries: bool = False
     can_predict_cold_items: bool = False
-    _search_space: Optional[Dict[str, Union[str, Sequence[Union[str, int, float]]]]] = None
-    _objective = MainObjective
-    study = None
-    criterion = None
     fit_queries: SparkDataFrame
     fit_items: SparkDataFrame
     _num_queries: int
@@ -169,202 +95,6 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
     _query_dim_size: int
     _item_dim_size: int
 
-    def optimize(
-        self,
-        train_dataset: Dataset,
-        test_dataset: Dataset,
-        param_borders: Optional[Dict[str, List[Any]]] = None,
-        criterion: Metric = NDCG,
-        k: int = 10,
-        budget: int = 10,
-        new_study: bool = True,
-    ) -> Optional[Dict[str, Any]]:
-        """
-        Searches the best parameters with optuna.
-
-        :param train_dataset: train data
-        :param test_dataset: test data
-        :param param_borders: a dictionary with search borders, where
-            key is the parameter name and value is the range of possible values
-            ``{param: [low, high]}``. In case of categorical parameters it is
-            all possible values: ``{cat_param: [cat_1, cat_2, cat_3]}``.
-        :param criterion: metric to use for optimization
-        :param k: recommendation list length
-        :param budget: number of points to try
-        :param new_study: keep searching with previous study or start a new study
-        :return: dictionary with best parameters
-        """
-        self.query_column = train_dataset.feature_schema.query_id_column
-        self.item_column = train_dataset.feature_schema.item_id_column
-        self.rating_column = train_dataset.feature_schema.interactions_rating_column
-        self.timestamp_column = train_dataset.feature_schema.interactions_timestamp_column
-
-        self.criterion = criterion(
-            topk=k,
-            query_column=self.query_column,
-            item_column=self.item_column,
-            rating_column=self.rating_column,
-        )
-
-        if self._search_space is None:
-            self.logger.warning("%s has no hyper parameters to optimize", str(self))
-            return None
-
-        if self.study is None or new_study:
-            self.study = create_study(direction="maximize", sampler=TPESampler())
-
-        search_space = self._prepare_param_borders(param_borders)
-        if self._init_params_in_search_space(search_space) and not self._params_tried():
-            self.study.enqueue_trial(self._init_args)
-
-        split_data = self._prepare_split_data(train_dataset, test_dataset)
-        objective = self._objective(
-            search_space=search_space,
-            split_data=split_data,
-            recommender=self,
-            criterion=self.criterion,
-            k=k,
-        )
-
-        self.study.optimize(objective, budget)
-        best_params = self.study.best_params
-        self.set_params(**best_params)
-        return best_params
-
-    def _init_params_in_search_space(self, search_space):
-        """Check if model params are inside search space"""
-        params = self._init_args
-        outside_search_space = {}
-        for param, value in params.items():
-            if param not in search_space:
-                continue
-            borders = search_space[param]["args"]
-            param_type = search_space[param]["type"]
-
-            extra_category = param_type == "categorical" and value not in borders
-            param_out_of_bounds = param_type != "categorical" and (value < borders[0] or value > borders[1])
-            if extra_category or param_out_of_bounds:
-                outside_search_space[param] = {
-                    "borders": borders,
-                    "value": value,
-                }
-
-        if outside_search_space:
-            self.logger.debug(
-                "Model is initialized with parameters outside the search space: %s."
-                "Initial parameters will not be evaluated during optimization."
-                "Change search spare with 'param_borders' argument if necessary",
-                outside_search_space,
-            )
-            return False
-        else:
-            return True
-
-    def _prepare_param_borders(
-        self, param_borders: Optional[Dict[str, List[Any]]] = None
-    ) -> Dict[str, Dict[str, List[Any]]]:
-        """
-        Checks if param borders are valid and convert them to a search_space format
-
-        :param param_borders: a dictionary with search grid, where
-            key is the parameter name and value is the range of possible values
-            ``{param: [low, high]}``.
-        :return:
-        """
-        search_space = deepcopy(self._search_space)
-        if param_borders is None:
-            return search_space
-
-        for param, borders in param_borders.items():
-            self._check_borders(param, borders)
-            search_space[param]["args"] = borders
-
-        # Optuna trials should contain all searchable parameters
-        # to be able to correctly return best params
-        # If used didn't specify some params to be tested optuna still needs to suggest them
-        # This part makes sure this suggestion will be constant
-        args = self._init_args
-        missing_borders = {param: args[param] for param in search_space if param not in param_borders}
-        for param, value in missing_borders.items():
-            if search_space[param]["type"] == "categorical":
-                search_space[param]["args"] = [value]
-            else:
-                search_space[param]["args"] = [value, value]
-
-        return search_space
-
-    def _check_borders(self, param, borders):
-        """Raise value error if param borders are not valid"""
-        if param not in self._search_space:
-            msg = f"Hyper parameter {param} is not defined for {self!s}"
-            raise ValueError(msg)
-        if not isinstance(borders, list):
-            msg = f"Parameter {param} borders are not a list"
-            raise ValueError()
-        if self._search_space[param]["type"] != "categorical" and len(borders) != 2:
-            msg = f"Hyper parameter {param} is numerical but bounds are not in ([lower, upper]) format"
-            raise ValueError(msg)
-
-    def _prepare_split_data(
-        self,
-        train_dataset: Dataset,
-        test_dataset: Dataset,
-    ) -> SplitData:
-        """
-        This method converts data to spark and packs it into a named tuple to pass into optuna.
-
-        :param train_dataset: train data
-        :param test_dataset: test data
-        :return: packed PySpark DataFrames
-        """
-        train = self._filter_dataset_features(train_dataset)
-        test = self._filter_dataset_features(test_dataset)
-        queries = test_dataset.interactions.select(self.query_column).distinct()
-        items = test_dataset.interactions.select(self.item_column).distinct()
-
-        split_data = SplitData(
-            train,
-            test,
-            queries,
-            items,
-        )
-        return split_data
-
-    @staticmethod
-    def _filter_dataset_features(
-        dataset: Dataset,
-    ) -> Dataset:
-        """
-        Filter features of dataset to match with items and queries of interactions
-
-        :param dataset: dataset with interactions and features
-        :return: filtered dataset
-        """
-        if dataset.query_features is None and dataset.item_features is None:
-            return dataset
-
-        query_features = None
-        item_features = None
-        if dataset.query_features is not None:
-            query_features = dataset.query_features.join(
-                dataset.interactions.select(dataset.feature_schema.query_id_column).distinct(),
-                on=dataset.feature_schema.query_id_column,
-            )
-        if dataset.item_features is not None:
-            item_features = dataset.item_features.join(
-                dataset.interactions.select(dataset.feature_schema.item_id_column).distinct(),
-                on=dataset.feature_schema.item_id_column,
-            )
-
-        return Dataset(
-            feature_schema=dataset.feature_schema,
-            interactions=dataset.interactions,
-            query_features=query_features,
-            item_features=item_features,
-            check_consistency=False,
-            categorical_encoded=False,
-        )
-
     def _fit_wrap(
         self,
         dataset: Dataset,
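
With this hunk the whole Optuna machinery (`optimize`, `_prepare_param_borders`, `_params_tried`, the `_search_space`/`study`/`criterion` attributes) leaves `base_rec.py`; per the file list it now lives in `replay/models/optimization/optuna_mixin.py` behind the `IsOptimizible` mixin that `BaseRecommender` inherits. Assuming the mixin keeps the signature of the removed `optimize` shown above, caller-side usage stays the same; a hedged sketch (the model class, the parameter name, and the pre-built datasets are illustrative):

    # Hedged sketch: the call shape mirrors the optimize() removed above and is
    # assumed to be preserved by IsOptimizible.
    from replay.metrics import NDCG
    from replay.models import ItemKNN

    model = ItemKNN()
    best_params = model.optimize(
        train_dataset,                                # pre-built replay Dataset (train split)
        test_dataset,                                 # pre-built replay Dataset (validation split)
        param_borders={"num_neighbours": [5, 100]},   # hypothetical search range
        criterion=NDCG,
        k=10,
        budget=10,
    )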
@@ -418,7 +148,13 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
         :return:
         """
 
-    def _filter_seen(
+    def _filter_seen(
+        self,
+        recs: SparkDataFrame,
+        interactions: SparkDataFrame,
+        k: int,
+        queries: SparkDataFrame,
+    ):
         """
         Filter seen items (presented in interactions) out of the queries' recommendations.
         For each query return from `k` to `k + number of seen by query` recommendations.
@@ -579,11 +315,12 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
         Warn if cold entities are present in the `main_df`.
         """
         can_predict_cold = self.can_predict_cold_queries if entity == "query" else self.can_predict_cold_items
-        fit_entities = self.fit_queries if entity == "query" else self.fit_items
-        column = self.query_column if entity == "query" else self.item_column
         if can_predict_cold:
             return main_df, interactions_df
 
+        fit_entities = self.fit_queries if entity == "query" else self.fit_items
+        column = self.query_column if entity == "query" else self.item_column
+
         num_new, main_df = filter_cold(main_df, fit_entities, col_name=column)
         if num_new > 0:
             self.logger.info(
@@ -622,7 +359,12 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
         """
 
     def _predict_proba(
-        self,
+        self,
+        dataset: Dataset,
+        k: int,
+        queries: SparkDataFrame,
+        items: SparkDataFrame,
+        filter_seen_items: bool = True,
     ) -> np.ndarray:
         """
         Inner method where model actually predicts probability estimates.
@@ -767,7 +509,13 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
         """
         if dataset is not None:
             interactions, query_features, item_features, pairs = [
-                convert2spark(df)
+                convert2spark(df)
+                for df in [
+                    dataset.interactions,
+                    dataset.query_features,
+                    dataset.item_features,
+                    pairs,
+                ]
             ]
             if set(pairs.columns) != {self.item_column, self.query_column}:
                 msg = "pairs must be a dataframe with columns strictly [user_idx, item_idx]"
@@ -903,21 +651,13 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
 
     def _get_nearest_items(
         self,
-        items: SparkDataFrame,
-        metric: Optional[str] = None,
-        candidates: Optional[SparkDataFrame] = None,
+        items: SparkDataFrame,
+        metric: Optional[str] = None,
+        candidates: Optional[SparkDataFrame] = None,
    ) -> Optional[SparkDataFrame]:
         msg = f"item-to-item prediction is not implemented for {self}"
         raise NotImplementedError(msg)
 
-    def _params_tried(self):
-        """check if current parameters were already evaluated"""
-        if self.study is None:
-            return False
-
-        params = {name: value for name, value in self._init_args.items() if name in self._search_space}
-        return any(params == trial.params for trial in self.study.trials)
-
     def _save_model(self, path: str, additional_params: Optional[dict] = None):
         saved_params = {
             "query_column": self.query_column,
@@ -1496,7 +1236,11 @@ class NonPersonalizedRecommender(Recommender, ABC):
             # 'selected_item_popularity' truncation by k + max_seen
             max_seen = queries.select(sf.coalesce(sf.max("num_items"), sf.lit(0))).first()[0]
             selected_item_popularity = selected_item_popularity.filter(sf.col("rank") <= k + max_seen)
-            return queries.join(
+            return queries.join(
+                selected_item_popularity,
+                on=(sf.col("rank") <= k + sf.col("num_items")),
+                how="left",
+            )
 
         return queries.crossJoin(selected_item_popularity.filter(sf.col("rank") <= k)).drop("rank")
 
@@ -1555,7 +1299,9 @@ class NonPersonalizedRecommender(Recommender, ABC):
         rating_column = self.rating_column
         class_name = self.__class__.__name__
 
-        def grouped_map(
+        def grouped_map(
+            pandas_df: PandasDataFrame,
+        ) -> PandasDataFrame:  # pragma: no cover
             query_idx = pandas_df[query_column][0]
             cnt = pandas_df["cnt"][0]
 
@@ -1640,7 +1386,12 @@ class NonPersonalizedRecommender(Recommender, ABC):
         )
 
     def _predict_proba(
-        self,
+        self,
+        dataset: Dataset,
+        k: int,
+        queries: SparkDataFrame,
+        items: SparkDataFrame,
+        filter_seen_items: bool = True,
     ) -> np.ndarray:
         """
         Inner method where model actually predicts probability estimates.
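
The `grouped_map` helper that gains a type annotation above is Spark's grouped-map pandas UDF pattern: each group arrives as one pandas DataFrame and one pandas DataFrame comes back. A standalone illustration of that pattern with toy data, independent of RePlay:

    # Grouped-map pandas UDF pattern: aggregate per query_idx group.
    import pandas as pd
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    df = spark.createDataFrame([(1, 3), (1, 5), (2, 4)], ["query_idx", "cnt"])

    def grouped_map(pandas_df: pd.DataFrame) -> pd.DataFrame:
        query_idx = pandas_df["query_idx"].iloc[0]
        return pd.DataFrame({"query_idx": [query_idx], "total": [int(pandas_df["cnt"].sum())]})

    result = df.groupBy("query_idx").applyInPandas(grouped_map, schema="query_idx long, total long")
    result.show()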
replay/models/cat_pop_rec.py
CHANGED

@@ -4,7 +4,8 @@ from typing import Iterable, Optional, Union
 from replay.data import Dataset
 from replay.utils import PYSPARK_AVAILABLE, SparkDataFrame
 
-from .base_rec import IsSavable
+from .base_rec import IsSavable
+from .common import RecommenderCommons
 
 if PYSPARK_AVAILABLE:
     from pyspark.sql import functions as sf

replay/models/common.py
ADDED

@@ -0,0 +1,69 @@
+import logging
+from typing import Any, Optional
+
+from replay.utils import SparkDataFrame
+from replay.utils.spark_utils import cache_temp_view, drop_temp_view
+
+
+class RecommenderCommons:
+    """
+    Common methods and attributes of RePlay models for caching, setting parameters and logging
+    """
+
+    _logger: Optional[logging.Logger] = None
+    cached_dfs: Optional[set] = None
+    query_column: str
+    item_column: str
+    rating_column: str
+    timestamp_column: str
+
+    def set_params(self, **params: dict[str, Any]) -> None:
+        """
+        Set model parameters
+
+        :param params: dictionary param name - param value
+        :return:
+        """
+        for param, value in params.items():
+            setattr(self, param, value)
+        self._clear_cache()
+
+    def _clear_cache(self):
+        """
+        Clear spark cache
+        """
+
+    def __str__(self):
+        return type(self).__name__
+
+    @property
+    def logger(self) -> logging.Logger:
+        """
+        :returns: get library logger
+        """
+        if self._logger is None:
+            self._logger = logging.getLogger("replay")
+        return self._logger
+
+    def _cache_model_temp_view(self, df: SparkDataFrame, df_name: str) -> None:
+        """
+        Create Spark SQL temporary view for df, cache it and add temp view name to self.cached_dfs.
+        Temp view name is : "id_<python object id>_model_<RePlay model name>_<df_name>"
+        """
+        full_name = f"id_{id(self)}_model_{self!s}_{df_name}"
+        cache_temp_view(df, full_name)
+
+        if self.cached_dfs is None:
+            self.cached_dfs = set()
+        self.cached_dfs.add(full_name)
+
+    def _clear_model_temp_view(self, df_name: str) -> None:
+        """
+        Uncache and drop Spark SQL temporary view and remove from self.cached_dfs
+        Temp view to replace will be constructed as
+        "id_<python object id>_model_<RePlay model name>_<df_name>"
+        """
+        full_name = f"id_{id(self)}_model_{self!s}_{df_name}"
+        drop_temp_view(full_name)
+        if self.cached_dfs is not None:
+            self.cached_dfs.discard(full_name)
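
With `RecommenderCommons` extracted into its own module, the shared caching/logging/`set_params` behaviour can be reused without importing `base_rec`. A small usage sketch, assuming only the import path added above (`ToyModel` is illustrative):

    from replay.models.common import RecommenderCommons

    class ToyModel(RecommenderCommons):
        def __init__(self, alpha: float = 0.5):
            self.alpha = alpha

    model = ToyModel()
    model.logger.info("alpha before: %s", model.alpha)  # lazily creates the "replay" logger
    model.set_params(alpha=0.9)                         # setattr + _clear_cache()
    print(model)                                        # __str__ returns the class name: "ToyModel"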
replay/models/extensions/ann/ann_mixin.py
CHANGED

@@ -1,11 +1,12 @@
 import importlib
 import logging
+import sys
 from abc import abstractmethod
-from typing import Any,
+from typing import Any, Iterable, Optional, Union
 
 from replay.data import Dataset
-from replay.models.
-from replay.utils import PYSPARK_AVAILABLE, SparkDataFrame
+from replay.models.common import RecommenderCommons
+from replay.utils import ANN_AVAILABLE, PYSPARK_AVAILABLE, FeatureUnavailableError, SparkDataFrame
 
 from .index_builders.base_index_builder import IndexBuilder
 
@@ -16,18 +17,32 @@ if PYSPARK_AVAILABLE:
 
     from .index_stores.spark_files_index_store import SparkFilesIndexStore
 
-
 logger = logging.getLogger("replay")
 
 
-class ANNMixin(
+class ANNMixin(RecommenderCommons):
     """
     This class overrides the `_fit_wrap` and `_predict_wrap` methods of the base class,
     adding an index construction in the `_fit_wrap` step
     and an index inference in the `_predict_wrap` step.
     """
 
-    index_builder: Optional[IndexBuilder] = None
+    index_builder: Optional["IndexBuilder"] = None
+
+    def init_index_builder(self, index_builder: Optional[IndexBuilder] = None) -> None:
+        if index_builder is not None and not ANN_AVAILABLE:
+            err = FeatureUnavailableError(
+                "`index_builder` can only be provided when all ANN dependencies are installed."
+            )
+            if sys.version_info >= (3, 11):  # pragma: py-lt-311
+                err.add_note(
+                    "To enable ANN, ensure you have both 'hnswlib' and 'fixed-install-nmslib' packages installed."
+                )
+            raise err
+        elif isinstance(index_builder, IndexBuilder):
+            self.index_builder = index_builder
+        elif isinstance(index_builder, dict):
+            self.init_builder_from_dict(index_builder)
 
     @property
     def _use_ann(self) -> bool:
@@ -39,26 +54,17 @@ class ANNMixin(BaseRecommender):
         return self.index_builder is not None
 
     @abstractmethod
-    def
-        """Implementations of this method must return a dataframe with item vectors.
-        Item vectors from this method are used to build the index.
-
-        Args:
-            log: DataFrame with interactions
-
-        Returns: DataFrame[item_idx int, vector array<double>] or DataFrame[vector array<double>].
-            Column names in dataframe can be anything.
-        """
-
-    @abstractmethod
-    def _get_ann_build_params(self, interactions: SparkDataFrame) -> Dict[str, Any]:
+    def _configure_index_builder(self, interactions: SparkDataFrame) -> tuple[SparkDataFrame, dict]:
         """Implementation of this method must return dictionary
-        with arguments for `
+        with arguments for an index builder's `build_index` method.
 
         Args:
             interactions: DataFrame with interactions
 
-        Returns:
+        Returns:
+            vectors: DataFrame[item_idx int, vector array<double>] or DataFrame[vector array<double>].
+                Column names in dataframe can be anything.
+            ann_params: Dictionary with arguments to build index. For example: {
                 "id_col": "item_idx",
                 "features_col": "item_factors",
                 ...
@@ -79,8 +85,7 @@ class ANNMixin(BaseRecommender):
         super()._fit_wrap(dataset)
 
         if self._use_ann:
-            vectors = self.
-            ann_params = self._get_ann_build_params(dataset.interactions)
+            vectors, ann_params = self._configure_index_builder(dataset.interactions)
             self.index_builder.build_index(vectors, **ann_params)
 
     @abstractmethod
@@ -123,11 +128,11 @@ class ANNMixin(BaseRecommender):
         return queries
 
     @abstractmethod
-    def _get_ann_infer_params(self) ->
+    def _get_ann_infer_params(self) -> dict[str, Any]:
         """Implementation of this method must return dictionary
         with arguments for `_infer_ann_index` method.
 
-        Returns:
+        Returns: dictionary with arguments to infer index. For example: {
             "features_col": "user_vector",
             ...
         }
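
The new `init_index_builder` centralizes the dependency check: passing a builder while the ANN extras are missing raises `FeatureUnavailableError` (with an `add_note` hint on Python 3.11+), otherwise an `IndexBuilder` instance is stored and a dict configuration is routed to `init_builder_from_dict`. A hedged sketch of the calling side (`model` stands in for any ANN-capable recommender):

    # Hedged sketch of handling the guard added above.
    from replay.utils import ANN_AVAILABLE, FeatureUnavailableError

    def attach_index(model, index_builder=None):
        try:
            model.init_index_builder(index_builder)
        except FeatureUnavailableError as err:
            # raised when a builder is passed but hnswlib / fixed-install-nmslib are absent
            print("ANN disabled:", err)

    print("ANN extras installed:", ANN_AVAILABLE)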
replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py
CHANGED

@@ -36,7 +36,7 @@ class DriverHnswlibIndexBuilder(IndexBuilder):
         vectors_np = np.squeeze(vectors[features_col].values)
         index = create_hnswlib_index_instance(self.index_params, init=True)
 
-        if ids_col:
+        if ids_col is not None:
             index.add_items(np.stack(vectors_np), vectors[ids_col].values)
         else:
             index.add_items(np.stack(vectors_np))