replay-rec 0.19.0__py3-none-any.whl → 0.20.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. replay/__init__.py +6 -2
  2. replay/data/dataset.py +9 -9
  3. replay/data/nn/__init__.py +6 -6
  4. replay/data/nn/sequence_tokenizer.py +44 -38
  5. replay/data/nn/sequential_dataset.py +13 -8
  6. replay/data/nn/torch_sequential_dataset.py +14 -13
  7. replay/data/nn/utils.py +1 -1
  8. replay/metrics/base_metric.py +1 -1
  9. replay/metrics/coverage.py +7 -11
  10. replay/metrics/experiment.py +3 -3
  11. replay/metrics/offline_metrics.py +2 -2
  12. replay/models/__init__.py +19 -0
  13. replay/models/association_rules.py +1 -4
  14. replay/models/base_neighbour_rec.py +6 -9
  15. replay/models/base_rec.py +44 -293
  16. replay/models/cat_pop_rec.py +2 -1
  17. replay/models/common.py +69 -0
  18. replay/models/extensions/ann/ann_mixin.py +30 -25
  19. replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +1 -1
  20. replay/models/extensions/ann/utils.py +4 -3
  21. replay/models/knn.py +18 -17
  22. replay/models/nn/sequential/bert4rec/dataset.py +1 -1
  23. replay/models/nn/sequential/callbacks/prediction_callbacks.py +2 -2
  24. replay/models/nn/sequential/compiled/__init__.py +10 -0
  25. replay/models/nn/sequential/compiled/base_compiled_model.py +3 -1
  26. replay/models/nn/sequential/compiled/bert4rec_compiled.py +11 -2
  27. replay/models/nn/sequential/compiled/sasrec_compiled.py +5 -1
  28. replay/models/nn/sequential/sasrec/dataset.py +1 -1
  29. replay/models/nn/sequential/sasrec/model.py +1 -1
  30. replay/models/optimization/__init__.py +14 -0
  31. replay/models/optimization/optuna_mixin.py +279 -0
  32. replay/{optimization → models/optimization}/optuna_objective.py +13 -15
  33. replay/models/slim.py +2 -4
  34. replay/models/word2vec.py +7 -12
  35. replay/preprocessing/discretizer.py +1 -2
  36. replay/preprocessing/history_based_fp.py +1 -1
  37. replay/preprocessing/label_encoder.py +1 -1
  38. replay/splitters/cold_user_random_splitter.py +13 -7
  39. replay/splitters/last_n_splitter.py +17 -10
  40. replay/utils/__init__.py +6 -2
  41. replay/utils/common.py +4 -2
  42. replay/utils/model_handler.py +11 -31
  43. replay/utils/session_handler.py +2 -2
  44. replay/utils/spark_utils.py +2 -2
  45. replay/utils/types.py +28 -18
  46. replay/utils/warnings.py +26 -0
  47. {replay_rec-0.19.0.dist-info → replay_rec-0.20.0.dist-info}/METADATA +56 -32
  48. {replay_rec-0.19.0.dist-info → replay_rec-0.20.0.dist-info}/RECORD +51 -47
  49. {replay_rec-0.19.0.dist-info → replay_rec-0.20.0.dist-info}/WHEEL +1 -1
  50. replay_rec-0.20.0.dist-info/licenses/NOTICE +41 -0
  51. replay/optimization/__init__.py +0 -5
  52. {replay_rec-0.19.0.dist-info → replay_rec-0.20.0.dist-info/licenses}/LICENSE +0 -0
replay/models/base_neighbour_rec.py CHANGED
@@ -7,7 +7,7 @@ from abc import ABC
 from typing import Any, Dict, Iterable, Optional, Union

 from replay.data.dataset import Dataset
- from replay.utils import PYSPARK_AVAILABLE, MissingImportType, SparkDataFrame
+ from replay.utils import PYSPARK_AVAILABLE, MissingImport, SparkDataFrame

 from .base_rec import Recommender
 from .extensions.ann.ann_mixin import ANNMixin
@@ -16,10 +16,10 @@ if PYSPARK_AVAILABLE:
 from pyspark.sql import functions as sf
 from pyspark.sql.column import Column
 else:
- Column = MissingImportType
+ Column = MissingImport


- class NeighbourRec(Recommender, ANNMixin, ABC):
+ class NeighbourRec(ANNMixin, Recommender, ABC):
 """Base class that requires interactions at prediction time"""

 similarity: Optional[SparkDataFrame]
@@ -187,16 +187,13 @@ class NeighbourRec(Recommender, ANNMixin, ABC):
 "similarity" if metric is None else metric,
 )

- def _get_ann_build_params(self, interactions: SparkDataFrame) -> Dict[str, Any]:
+ def _configure_index_builder(self, interactions: SparkDataFrame) -> Dict[str, Any]:
+ similarity_df = self.similarity.select("similarity", "item_idx_one", "item_idx_two")
 self.index_builder.index_params.items_count = interactions.select(sf.max(self.item_column)).first()[0] + 1
- return {
+ return similarity_df, {
 "features_col": None,
 }

- def _get_vectors_to_build_ann(self, interactions: SparkDataFrame) -> SparkDataFrame: # noqa: ARG002
- similarity_df = self.similarity.select("similarity", "item_idx_one", "item_idx_two")
- return similarity_df
-
 def _get_vectors_to_infer_ann_inner(
 self, interactions: SparkDataFrame, queries: SparkDataFrame # noqa: ARG002
 ) -> SparkDataFrame:
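The NeighbourRec hunks above fold the former `_get_vectors_to_build_ann` / `_get_ann_build_params` pair into a single `_configure_index_builder` that returns both the vectors DataFrame and the index build parameters in one call. A minimal sketch of what an ANNMixin subclass now provides, mirroring the new NeighbourRec body (the class name `MyNeighbourModel` is illustrative, not part of the package):

    from typing import Any, Dict, Tuple

    from pyspark.sql import functions as sf

    from replay.utils import SparkDataFrame


    class MyNeighbourModel:  # hypothetical ANNMixin subclass
        def _configure_index_builder(self, interactions: SparkDataFrame) -> Tuple[SparkDataFrame, Dict[str, Any]]:
            # The vectors used to build the index and the kwargs forwarded to
            # index_builder.build_index() now come from the same method.
            vectors = self.similarity.select("similarity", "item_idx_one", "item_idx_two")
            self.index_builder.index_params.items_count = (
                interactions.select(sf.max(self.item_column)).first()[0] + 1
            )
            return vectors, {"features_col": None}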
replay/models/base_rec.py CHANGED
@@ -11,22 +11,18 @@ Base abstract classes:
 with popularity statistics
 """

- import logging
 import warnings
 from abc import ABC, abstractmethod
- from copy import deepcopy
 from os.path import join
- from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union
+ from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

 import numpy as np
 import pandas as pd
 from numpy.random import default_rng
- from optuna import create_study
- from optuna.samplers import TPESampler

 from replay.data import Dataset, get_schema
- from replay.metrics import NDCG, Metric
- from replay.optimization.optuna_objective import MainObjective, SplitData
+ from replay.models.common import RecommenderCommons
+ from replay.models.optimization import IsOptimizible
 from replay.utils import PYSPARK_AVAILABLE, PandasDataFrame, SparkDataFrame
 from replay.utils.session_handler import State
 from replay.utils.spark_utils import SparkCollectToMasterWarning
@@ -38,10 +34,8 @@ if PYSPARK_AVAILABLE:
 )

 from replay.utils.spark_utils import (
- cache_temp_view,
 convert2spark,
 cosine_similarity,
- drop_temp_view,
 filter_cold,
 get_top_k,
 get_top_k_recs,
@@ -88,80 +82,12 @@ class IsSavable(ABC):
 """


- class RecommenderCommons:
- """
- Common methods and attributes of RePlay models for caching, setting parameters and logging
- """
-
- _logger: Optional[logging.Logger] = None
- cached_dfs: Optional[Set] = None
- query_column: str
- item_column: str
- rating_column: str
- timestamp_column: str
-
- def set_params(self, **params: Dict[str, Any]) -> None:
- """
- Set model parameters
-
- :param params: dictionary param name - param value
- :return:
- """
- for param, value in params.items():
- setattr(self, param, value)
- self._clear_cache()
-
- def _clear_cache(self):
- """
- Clear spark cache
- """
-
- def __str__(self):
- return type(self).__name__
-
- @property
- def logger(self) -> logging.Logger:
- """
- :returns: get library logger
- """
- if self._logger is None:
- self._logger = logging.getLogger("replay")
- return self._logger
-
- def _cache_model_temp_view(self, df: SparkDataFrame, df_name: str) -> None:
- """
- Create Spark SQL temporary view for df, cache it and add temp view name to self.cached_dfs.
- Temp view name is : "id_<python object id>_model_<RePlay model name>_<df_name>"
- """
- full_name = f"id_{id(self)}_model_{self!s}_{df_name}"
- cache_temp_view(df, full_name)
-
- if self.cached_dfs is None:
- self.cached_dfs = set()
- self.cached_dfs.add(full_name)
-
- def _clear_model_temp_view(self, df_name: str) -> None:
- """
- Uncache and drop Spark SQL temporary view and remove from self.cached_dfs
- Temp view to replace will be constructed as
- "id_<python object id>_model_<RePlay model name>_<df_name>"
- """
- full_name = f"id_{id(self)}_model_{self!s}_{df_name}"
- drop_temp_view(full_name)
- if self.cached_dfs is not None:
- self.cached_dfs.discard(full_name)
-
-
- class BaseRecommender(RecommenderCommons, IsSavable, ABC):
+ class BaseRecommender(IsSavable, IsOptimizible, RecommenderCommons, ABC):
 """Base recommender"""

 model: Any
 can_predict_cold_queries: bool = False
 can_predict_cold_items: bool = False
- _search_space: Optional[Dict[str, Union[str, Sequence[Union[str, int, float]]]]] = None
- _objective = MainObjective
- study = None
- criterion = None
 fit_queries: SparkDataFrame
 fit_items: SparkDataFrame
 _num_queries: int
@@ -169,202 +95,6 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
 _query_dim_size: int
 _item_dim_size: int

- def optimize(
- self,
- train_dataset: Dataset,
- test_dataset: Dataset,
- param_borders: Optional[Dict[str, List[Any]]] = None,
- criterion: Metric = NDCG,
- k: int = 10,
- budget: int = 10,
- new_study: bool = True,
- ) -> Optional[Dict[str, Any]]:
- """
- Searches the best parameters with optuna.
-
- :param train_dataset: train data
- :param test_dataset: test data
- :param param_borders: a dictionary with search borders, where
- key is the parameter name and value is the range of possible values
- ``{param: [low, high]}``. In case of categorical parameters it is
- all possible values: ``{cat_param: [cat_1, cat_2, cat_3]}``.
- :param criterion: metric to use for optimization
- :param k: recommendation list length
- :param budget: number of points to try
- :param new_study: keep searching with previous study or start a new study
- :return: dictionary with best parameters
- """
- self.query_column = train_dataset.feature_schema.query_id_column
- self.item_column = train_dataset.feature_schema.item_id_column
- self.rating_column = train_dataset.feature_schema.interactions_rating_column
- self.timestamp_column = train_dataset.feature_schema.interactions_timestamp_column
-
- self.criterion = criterion(
- topk=k,
- query_column=self.query_column,
- item_column=self.item_column,
- rating_column=self.rating_column,
- )
-
- if self._search_space is None:
- self.logger.warning("%s has no hyper parameters to optimize", str(self))
- return None
-
- if self.study is None or new_study:
- self.study = create_study(direction="maximize", sampler=TPESampler())
-
- search_space = self._prepare_param_borders(param_borders)
- if self._init_params_in_search_space(search_space) and not self._params_tried():
- self.study.enqueue_trial(self._init_args)
-
- split_data = self._prepare_split_data(train_dataset, test_dataset)
- objective = self._objective(
- search_space=search_space,
- split_data=split_data,
- recommender=self,
- criterion=self.criterion,
- k=k,
- )
-
- self.study.optimize(objective, budget)
- best_params = self.study.best_params
- self.set_params(**best_params)
- return best_params
-
- def _init_params_in_search_space(self, search_space):
- """Check if model params are inside search space"""
- params = self._init_args
- outside_search_space = {}
- for param, value in params.items():
- if param not in search_space:
- continue
- borders = search_space[param]["args"]
- param_type = search_space[param]["type"]
-
- extra_category = param_type == "categorical" and value not in borders
- param_out_of_bounds = param_type != "categorical" and (value < borders[0] or value > borders[1])
- if extra_category or param_out_of_bounds:
- outside_search_space[param] = {
- "borders": borders,
- "value": value,
- }
-
- if outside_search_space:
- self.logger.debug(
- "Model is initialized with parameters outside the search space: %s."
- "Initial parameters will not be evaluated during optimization."
- "Change search spare with 'param_borders' argument if necessary",
- outside_search_space,
- )
- return False
- else:
- return True
-
- def _prepare_param_borders(
- self, param_borders: Optional[Dict[str, List[Any]]] = None
- ) -> Dict[str, Dict[str, List[Any]]]:
- """
- Checks if param borders are valid and convert them to a search_space format
-
- :param param_borders: a dictionary with search grid, where
- key is the parameter name and value is the range of possible values
- ``{param: [low, high]}``.
- :return:
- """
- search_space = deepcopy(self._search_space)
- if param_borders is None:
- return search_space
-
- for param, borders in param_borders.items():
- self._check_borders(param, borders)
- search_space[param]["args"] = borders
-
- # Optuna trials should contain all searchable parameters
- # to be able to correctly return best params
- # If used didn't specify some params to be tested optuna still needs to suggest them
- # This part makes sure this suggestion will be constant
- args = self._init_args
- missing_borders = {param: args[param] for param in search_space if param not in param_borders}
- for param, value in missing_borders.items():
- if search_space[param]["type"] == "categorical":
- search_space[param]["args"] = [value]
- else:
- search_space[param]["args"] = [value, value]
-
- return search_space
-
- def _check_borders(self, param, borders):
- """Raise value error if param borders are not valid"""
- if param not in self._search_space:
- msg = f"Hyper parameter {param} is not defined for {self!s}"
- raise ValueError(msg)
- if not isinstance(borders, list):
- msg = f"Parameter {param} borders are not a list"
- raise ValueError()
- if self._search_space[param]["type"] != "categorical" and len(borders) != 2:
- msg = f"Hyper parameter {param} is numerical but bounds are not in ([lower, upper]) format"
- raise ValueError(msg)
-
- def _prepare_split_data(
- self,
- train_dataset: Dataset,
- test_dataset: Dataset,
- ) -> SplitData:
- """
- This method converts data to spark and packs it into a named tuple to pass into optuna.
-
- :param train_dataset: train data
- :param test_dataset: test data
- :return: packed PySpark DataFrames
- """
- train = self._filter_dataset_features(train_dataset)
- test = self._filter_dataset_features(test_dataset)
- queries = test_dataset.interactions.select(self.query_column).distinct()
- items = test_dataset.interactions.select(self.item_column).distinct()
-
- split_data = SplitData(
- train,
- test,
- queries,
- items,
- )
- return split_data
-
- @staticmethod
- def _filter_dataset_features(
- dataset: Dataset,
- ) -> Dataset:
- """
- Filter features of dataset to match with items and queries of interactions
-
- :param dataset: dataset with interactions and features
- :return: filtered dataset
- """
- if dataset.query_features is None and dataset.item_features is None:
- return dataset
-
- query_features = None
- item_features = None
- if dataset.query_features is not None:
- query_features = dataset.query_features.join(
- dataset.interactions.select(dataset.feature_schema.query_id_column).distinct(),
- on=dataset.feature_schema.query_id_column,
- )
- if dataset.item_features is not None:
- item_features = dataset.item_features.join(
- dataset.interactions.select(dataset.feature_schema.item_id_column).distinct(),
- on=dataset.feature_schema.item_id_column,
- )
-
- return Dataset(
- feature_schema=dataset.feature_schema,
- interactions=dataset.interactions,
- query_features=query_features,
- item_features=item_features,
- check_consistency=False,
- categorical_encoded=False,
- )
-
 def _fit_wrap(
 self,
 dataset: Dataset,
@@ -418,7 +148,13 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
 :return:
 """

- def _filter_seen(self, recs: SparkDataFrame, interactions: SparkDataFrame, k: int, queries: SparkDataFrame):
+ def _filter_seen(
+ self,
+ recs: SparkDataFrame,
+ interactions: SparkDataFrame,
+ k: int,
+ queries: SparkDataFrame,
+ ):
 """
 Filter seen items (presented in interactions) out of the queries' recommendations.
 For each query return from `k` to `k + number of seen by query` recommendations.
@@ -579,11 +315,12 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
 Warn if cold entities are present in the `main_df`.
 """
 can_predict_cold = self.can_predict_cold_queries if entity == "query" else self.can_predict_cold_items
- fit_entities = self.fit_queries if entity == "query" else self.fit_items
- column = self.query_column if entity == "query" else self.item_column
 if can_predict_cold:
 return main_df, interactions_df

+ fit_entities = self.fit_queries if entity == "query" else self.fit_items
+ column = self.query_column if entity == "query" else self.item_column
+
 num_new, main_df = filter_cold(main_df, fit_entities, col_name=column)
 if num_new > 0:
 self.logger.info(
@@ -622,7 +359,12 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
 """

 def _predict_proba(
- self, dataset: Dataset, k: int, queries: SparkDataFrame, items: SparkDataFrame, filter_seen_items: bool = True
+ self,
+ dataset: Dataset,
+ k: int,
+ queries: SparkDataFrame,
+ items: SparkDataFrame,
+ filter_seen_items: bool = True,
 ) -> np.ndarray:
 """
 Inner method where model actually predicts probability estimates.
@@ -767,7 +509,13 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
 """
 if dataset is not None:
 interactions, query_features, item_features, pairs = [
- convert2spark(df) for df in [dataset.interactions, dataset.query_features, dataset.item_features, pairs]
+ convert2spark(df)
+ for df in [
+ dataset.interactions,
+ dataset.query_features,
+ dataset.item_features,
+ pairs,
+ ]
 ]
 if set(pairs.columns) != {self.item_column, self.query_column}:
 msg = "pairs must be a dataframe with columns strictly [user_idx, item_idx]"
@@ -903,21 +651,13 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):

 def _get_nearest_items(
 self,
- items: SparkDataFrame, # noqa: ARG002
- metric: Optional[str] = None, # noqa: ARG002
- candidates: Optional[SparkDataFrame] = None, # noqa: ARG002
+ items: SparkDataFrame,
+ metric: Optional[str] = None,
+ candidates: Optional[SparkDataFrame] = None,
 ) -> Optional[SparkDataFrame]:
 msg = f"item-to-item prediction is not implemented for {self}"
 raise NotImplementedError(msg)

- def _params_tried(self):
- """check if current parameters were already evaluated"""
- if self.study is None:
- return False
-
- params = {name: value for name, value in self._init_args.items() if name in self._search_space}
- return any(params == trial.params for trial in self.study.trials)
-
 def _save_model(self, path: str, additional_params: Optional[dict] = None):
 saved_params = {
 "query_column": self.query_column,
@@ -1496,7 +1236,11 @@ class NonPersonalizedRecommender(Recommender, ABC):
 # 'selected_item_popularity' truncation by k + max_seen
 max_seen = queries.select(sf.coalesce(sf.max("num_items"), sf.lit(0))).first()[0]
 selected_item_popularity = selected_item_popularity.filter(sf.col("rank") <= k + max_seen)
- return queries.join(selected_item_popularity, on=(sf.col("rank") <= k + sf.col("num_items")), how="left")
+ return queries.join(
+ selected_item_popularity,
+ on=(sf.col("rank") <= k + sf.col("num_items")),
+ how="left",
+ )

 return queries.crossJoin(selected_item_popularity.filter(sf.col("rank") <= k)).drop("rank")

@@ -1555,7 +1299,9 @@ class NonPersonalizedRecommender(Recommender, ABC):
 rating_column = self.rating_column
 class_name = self.__class__.__name__

- def grouped_map(pandas_df: PandasDataFrame) -> PandasDataFrame: # pragma: no cover
+ def grouped_map(
+ pandas_df: PandasDataFrame,
+ ) -> PandasDataFrame: # pragma: no cover
 query_idx = pandas_df[query_column][0]
 cnt = pandas_df["cnt"][0]

@@ -1640,7 +1386,12 @@ class NonPersonalizedRecommender(Recommender, ABC):
 )

 def _predict_proba(
- self, dataset: Dataset, k: int, queries: SparkDataFrame, items: SparkDataFrame, filter_seen_items: bool = True
+ self,
+ dataset: Dataset,
+ k: int,
+ queries: SparkDataFrame,
+ items: SparkDataFrame,
+ filter_seen_items: bool = True,
 ) -> np.ndarray:
 """
 Inner method where model actually predicts probability estimates.
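Most of the base_rec.py delta above is the relocation of the Optuna search: `optimize`, `_prepare_param_borders`, `_prepare_split_data`, `_params_tried` and related attributes leave `BaseRecommender` and move into the new `IsOptimizible` mixin (`replay/models/optimization/optuna_mixin.py`, +279 lines), which `BaseRecommender` now inherits alongside the relocated `RecommenderCommons`. Assuming the mixin keeps the signature of the removed method shown above, caller-side code stays the same; a hedged usage sketch (`model`, the datasets, and the "rank" hyperparameter are placeholders):

    from replay.metrics import NDCG

    # model is any RePlay recommender; train_dataset / test_dataset are replay Dataset objects
    best_params = model.optimize(
        train_dataset,
        test_dataset,
        param_borders={"rank": [8, 128]},  # {param: [low, high]}, per the removed docstring
        criterion=NDCG,
        k=10,
        budget=10,       # number of Optuna trials
        new_study=True,
    )
    # optimize() applies the best parameters via set_params() and returns them,
    # or returns None when the model defines no search space.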
replay/models/cat_pop_rec.py CHANGED
@@ -4,7 +4,8 @@ from typing import Iterable, Optional, Union
 from replay.data import Dataset
 from replay.utils import PYSPARK_AVAILABLE, SparkDataFrame

- from .base_rec import IsSavable, RecommenderCommons
+ from .base_rec import IsSavable
+ from .common import RecommenderCommons

 if PYSPARK_AVAILABLE:
 from pyspark.sql import functions as sf
replay/models/common.py ADDED
@@ -0,0 +1,69 @@
+ import logging
+ from typing import Any, Optional
+
+ from replay.utils import SparkDataFrame
+ from replay.utils.spark_utils import cache_temp_view, drop_temp_view
+
+
+ class RecommenderCommons:
+ """
+ Common methods and attributes of RePlay models for caching, setting parameters and logging
+ """
+
+ _logger: Optional[logging.Logger] = None
+ cached_dfs: Optional[set] = None
+ query_column: str
+ item_column: str
+ rating_column: str
+ timestamp_column: str
+
+ def set_params(self, **params: dict[str, Any]) -> None:
+ """
+ Set model parameters
+
+ :param params: dictionary param name - param value
+ :return:
+ """
+ for param, value in params.items():
+ setattr(self, param, value)
+ self._clear_cache()
+
+ def _clear_cache(self):
+ """
+ Clear spark cache
+ """
+
+ def __str__(self):
+ return type(self).__name__
+
+ @property
+ def logger(self) -> logging.Logger:
+ """
+ :returns: get library logger
+ """
+ if self._logger is None:
+ self._logger = logging.getLogger("replay")
+ return self._logger
+
+ def _cache_model_temp_view(self, df: SparkDataFrame, df_name: str) -> None:
+ """
+ Create Spark SQL temporary view for df, cache it and add temp view name to self.cached_dfs.
+ Temp view name is : "id_<python object id>_model_<RePlay model name>_<df_name>"
+ """
+ full_name = f"id_{id(self)}_model_{self!s}_{df_name}"
+ cache_temp_view(df, full_name)
+
+ if self.cached_dfs is None:
+ self.cached_dfs = set()
+ self.cached_dfs.add(full_name)
+
+ def _clear_model_temp_view(self, df_name: str) -> None:
+ """
+ Uncache and drop Spark SQL temporary view and remove from self.cached_dfs
+ Temp view to replace will be constructed as
+ "id_<python object id>_model_<RePlay model name>_<df_name>"
+ """
+ full_name = f"id_{id(self)}_model_{self!s}_{df_name}"
+ drop_temp_view(full_name)
+ if self.cached_dfs is not None:
+ self.cached_dfs.discard(full_name)
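The new `replay/models/common.py` is essentially a verbatim relocation of `RecommenderCommons` out of base_rec.py (compare the removed block in the base_rec.py hunk above); only the typing switches from `Dict`/`Set` to the built-in `dict`/`set` generics. For orientation, the temp-view bookkeeping it documents works roughly like this (a sketch; `model` and the "candidates" name are placeholders):

    # _cache_model_temp_view(df, "candidates") caches df under a per-instance view name:
    view_name = f"id_{id(model)}_model_{model!s}_candidates"
    # and records it in model.cached_dfs so that
    # _clear_model_temp_view("candidates") can later drop the view and discard the name.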
replay/models/extensions/ann/ann_mixin.py CHANGED
@@ -1,11 +1,12 @@
 import importlib
 import logging
+ import sys
 from abc import abstractmethod
- from typing import Any, Dict, Iterable, Optional, Union
+ from typing import Any, Iterable, Optional, Union

 from replay.data import Dataset
- from replay.models.base_rec import BaseRecommender
- from replay.utils import PYSPARK_AVAILABLE, SparkDataFrame
+ from replay.models.common import RecommenderCommons
+ from replay.utils import ANN_AVAILABLE, PYSPARK_AVAILABLE, FeatureUnavailableError, SparkDataFrame

 from .index_builders.base_index_builder import IndexBuilder

@@ -16,18 +17,32 @@ if PYSPARK_AVAILABLE:

 from .index_stores.spark_files_index_store import SparkFilesIndexStore


-
 logger = logging.getLogger("replay")


- class ANNMixin(BaseRecommender):
+ class ANNMixin(RecommenderCommons):
 """
 This class overrides the `_fit_wrap` and `_predict_wrap` methods of the base class,
 adding an index construction in the `_fit_wrap` step
 and an index inference in the `_predict_wrap` step.
 """

- index_builder: Optional[IndexBuilder] = None
+ index_builder: Optional["IndexBuilder"] = None
+
+ def init_index_builder(self, index_builder: Optional[IndexBuilder] = None) -> None:
+ if index_builder is not None and not ANN_AVAILABLE:
+ err = FeatureUnavailableError(
+ "`index_builder` can only be provided when all ANN dependencies are installed."
+ )
+ if sys.version_info >= (3, 11): # pragma: py-lt-311
+ err.add_note(
+ "To enable ANN, ensure you have both 'hnswlib' and 'fixed-install-nmslib' packages installed."
+ )
+ raise err
+ elif isinstance(index_builder, IndexBuilder):
+ self.index_builder = index_builder
+ elif isinstance(index_builder, dict):
+ self.init_builder_from_dict(index_builder)

 @property
 def _use_ann(self) -> bool:
@@ -39,26 +54,17 @@ class ANNMixin(BaseRecommender):
 return self.index_builder is not None

 @abstractmethod
- def _get_vectors_to_build_ann(self, interactions: SparkDataFrame) -> SparkDataFrame:
- """Implementations of this method must return a dataframe with item vectors.
- Item vectors from this method are used to build the index.
-
- Args:
- log: DataFrame with interactions
-
- Returns: DataFrame[item_idx int, vector array<double>] or DataFrame[vector array<double>].
- Column names in dataframe can be anything.
- """
-
- @abstractmethod
- def _get_ann_build_params(self, interactions: SparkDataFrame) -> Dict[str, Any]:
+ def _configure_index_builder(self, interactions: SparkDataFrame) -> tuple[SparkDataFrame, dict]:
 """Implementation of this method must return dictionary
- with arguments for `_build_ann_index` method.
+ with arguments for for an index builder's`build_index` method.

 Args:
 interactions: DataFrame with interactions

- Returns: Dictionary with arguments to build index. For example: {
+ Returns:
+ vectors: DataFrame[item_idx int, vector array<double>] or DataFrame[vector array<double>].
+ Column names in dataframe can be anything.
+ ann_params: Dictionary with arguments to build index. For example: {
 "id_col": "item_idx",
 "features_col": "item_factors",
 ...
@@ -79,8 +85,7 @@ class ANNMixin(BaseRecommender):
 super()._fit_wrap(dataset)

 if self._use_ann:
- vectors = self._get_vectors_to_build_ann(dataset.interactions)
- ann_params = self._get_ann_build_params(dataset.interactions)
+ vectors, ann_params = self._configure_index_builder(dataset.interactions)
 self.index_builder.build_index(vectors, **ann_params)

 @abstractmethod
@@ -123,11 +128,11 @@ class ANNMixin(BaseRecommender):
 return queries

 @abstractmethod
- def _get_ann_infer_params(self) -> Dict[str, Any]:
+ def _get_ann_infer_params(self) -> dict[str, Any]:
 """Implementation of this method must return dictionary
 with arguments for `_infer_ann_index` method.

- Returns: Dictionary with arguments to infer index. For example: {
+ Returns: dictionary with arguments to infer index. For example: {
 "features_col": "user_vector",
 ...
 }
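With `ANNMixin` now extending `RecommenderCommons` instead of `BaseRecommender`, index wiring goes through the new `init_index_builder`, which accepts either an `IndexBuilder` instance or a dict config and raises `FeatureUnavailableError` when the ANN extras are missing (attaching an `add_note` hint on Python 3.11+). A hedged call-site sketch (`model` and `index_builder` are placeholders for an ANN-capable recommender and its builder config):

    from replay.utils import FeatureUnavailableError

    try:
        model.init_index_builder(index_builder)  # IndexBuilder instance or dict config
    except FeatureUnavailableError:
        # Raised when a builder is supplied but hnswlib / fixed-install-nmslib
        # are not installed; passing None keeps the model on exact search.
        model.init_index_builder(None)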
replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py CHANGED
@@ -36,7 +36,7 @@ class DriverHnswlibIndexBuilder(IndexBuilder):
 vectors_np = np.squeeze(vectors[features_col].values)
 index = create_hnswlib_index_instance(self.index_params, init=True)

- if ids_col:
+ if ids_col is not None:
 index.add_items(np.stack(vectors_np), vectors[ids_col].values)
 else:
 index.add_items(np.stack(vectors_np))
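The `if ids_col:` to `if ids_col is not None:` change in DriverHnswlibIndexBuilder distinguishes an omitted ids column from a present-but-falsy value; a quick illustration of the difference (the column names are hypothetical):

    for ids_col in (None, "", "item_idx"):
        print(repr(ids_col), bool(ids_col), ids_col is not None)
    # None       -> False, False   (both checks skip the ids branch)
    # ""         -> False, True    (only the new check takes the ids branch)
    # "item_idx" -> True,  True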