replay-rec 0.16.0-py3-none-any.whl → 0.17.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. replay/__init__.py +1 -1
  2. replay/data/__init__.py +1 -1
  3. replay/data/dataset.py +45 -42
  4. replay/data/dataset_utils/dataset_label_encoder.py +6 -7
  5. replay/data/nn/__init__.py +1 -1
  6. replay/data/nn/schema.py +20 -33
  7. replay/data/nn/sequence_tokenizer.py +217 -87
  8. replay/data/nn/sequential_dataset.py +6 -22
  9. replay/data/nn/torch_sequential_dataset.py +20 -11
  10. replay/data/nn/utils.py +7 -9
  11. replay/data/schema.py +17 -17
  12. replay/data/spark_schema.py +0 -1
  13. replay/metrics/base_metric.py +38 -79
  14. replay/metrics/categorical_diversity.py +24 -58
  15. replay/metrics/coverage.py +25 -49
  16. replay/metrics/descriptors.py +4 -13
  17. replay/metrics/experiment.py +3 -8
  18. replay/metrics/hitrate.py +3 -6
  19. replay/metrics/map.py +3 -6
  20. replay/metrics/mrr.py +1 -4
  21. replay/metrics/ndcg.py +4 -7
  22. replay/metrics/novelty.py +10 -29
  23. replay/metrics/offline_metrics.py +26 -61
  24. replay/metrics/precision.py +3 -6
  25. replay/metrics/recall.py +3 -6
  26. replay/metrics/rocauc.py +7 -10
  27. replay/metrics/surprisal.py +13 -30
  28. replay/metrics/torch_metrics_builder.py +0 -4
  29. replay/metrics/unexpectedness.py +15 -20
  30. replay/models/__init__.py +1 -2
  31. replay/models/als.py +7 -15
  32. replay/models/association_rules.py +12 -28
  33. replay/models/base_neighbour_rec.py +21 -36
  34. replay/models/base_rec.py +92 -215
  35. replay/models/cat_pop_rec.py +9 -22
  36. replay/models/cluster.py +17 -28
  37. replay/models/extensions/ann/ann_mixin.py +7 -12
  38. replay/models/extensions/ann/entities/base_hnsw_param.py +1 -1
  39. replay/models/extensions/ann/entities/hnswlib_param.py +0 -6
  40. replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -6
  41. replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +4 -10
  42. replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +7 -11
  43. replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +5 -12
  44. replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +11 -18
  45. replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +1 -4
  46. replay/models/extensions/ann/index_inferers/base_inferer.py +3 -10
  47. replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +7 -17
  48. replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +6 -14
  49. replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +14 -28
  50. replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +15 -25
  51. replay/models/extensions/ann/index_inferers/utils.py +2 -9
  52. replay/models/extensions/ann/index_stores/hdfs_index_store.py +4 -9
  53. replay/models/extensions/ann/index_stores/shared_disk_index_store.py +2 -6
  54. replay/models/extensions/ann/index_stores/spark_files_index_store.py +8 -14
  55. replay/models/extensions/ann/index_stores/utils.py +5 -2
  56. replay/models/extensions/ann/utils.py +3 -5
  57. replay/models/kl_ucb.py +16 -22
  58. replay/models/knn.py +37 -59
  59. replay/models/nn/optimizer_utils/__init__.py +1 -6
  60. replay/models/nn/optimizer_utils/optimizer_factory.py +3 -6
  61. replay/models/nn/sequential/bert4rec/__init__.py +1 -1
  62. replay/models/nn/sequential/bert4rec/dataset.py +6 -7
  63. replay/models/nn/sequential/bert4rec/lightning.py +53 -56
  64. replay/models/nn/sequential/bert4rec/model.py +12 -25
  65. replay/models/nn/sequential/callbacks/__init__.py +1 -1
  66. replay/models/nn/sequential/callbacks/prediction_callbacks.py +23 -25
  67. replay/models/nn/sequential/callbacks/validation_callback.py +27 -30
  68. replay/models/nn/sequential/postprocessors/postprocessors.py +1 -1
  69. replay/models/nn/sequential/sasrec/dataset.py +8 -7
  70. replay/models/nn/sequential/sasrec/lightning.py +53 -48
  71. replay/models/nn/sequential/sasrec/model.py +4 -17
  72. replay/models/pop_rec.py +9 -10
  73. replay/models/query_pop_rec.py +7 -15
  74. replay/models/random_rec.py +10 -18
  75. replay/models/slim.py +8 -13
  76. replay/models/thompson_sampling.py +13 -14
  77. replay/models/ucb.py +11 -22
  78. replay/models/wilson.py +5 -14
  79. replay/models/word2vec.py +24 -69
  80. replay/optimization/optuna_objective.py +13 -27
  81. replay/preprocessing/__init__.py +1 -2
  82. replay/preprocessing/converter.py +2 -7
  83. replay/preprocessing/filters.py +67 -142
  84. replay/preprocessing/history_based_fp.py +44 -116
  85. replay/preprocessing/label_encoder.py +106 -68
  86. replay/preprocessing/sessionizer.py +1 -11
  87. replay/scenarios/fallback.py +3 -8
  88. replay/splitters/base_splitter.py +43 -15
  89. replay/splitters/cold_user_random_splitter.py +18 -31
  90. replay/splitters/k_folds.py +14 -24
  91. replay/splitters/last_n_splitter.py +33 -43
  92. replay/splitters/new_users_splitter.py +31 -55
  93. replay/splitters/random_splitter.py +16 -23
  94. replay/splitters/ratio_splitter.py +30 -54
  95. replay/splitters/time_splitter.py +13 -18
  96. replay/splitters/two_stage_splitter.py +44 -79
  97. replay/utils/__init__.py +1 -1
  98. replay/utils/common.py +65 -0
  99. replay/utils/dataframe_bucketizer.py +25 -31
  100. replay/utils/distributions.py +3 -15
  101. replay/utils/model_handler.py +36 -33
  102. replay/utils/session_handler.py +11 -15
  103. replay/utils/spark_utils.py +51 -85
  104. replay/utils/time.py +8 -22
  105. replay/utils/types.py +1 -3
  106. {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/METADATA +2 -2
  107. replay_rec-0.17.0.dist-info/RECORD +127 -0
  108. replay_rec-0.16.0.dist-info/RECORD +0 -126
  109. {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/LICENSE +0 -0
  110. {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/WHEEL +0 -0
replay/models/cat_pop_rec.py CHANGED
@@ -2,9 +2,10 @@ from os.path import join
 from typing import Iterable, Optional, Union
 
 from replay.data import Dataset
-from .base_rec import IsSavable, RecommenderCommons
 from replay.utils import PYSPARK_AVAILABLE, SparkDataFrame
 
+from .base_rec import IsSavable, RecommenderCommons
+
 if PYSPARK_AVAILABLE:
     from pyspark.sql import functions as sf
 
@@ -18,7 +19,6 @@ if PYSPARK_AVAILABLE:
     )
 
 
-# pylint: disable=too-many-instance-attributes
 class CatPopRec(IsSavable, RecommenderCommons):
     """
     CatPopRec generate recommendation for item categories.
@@ -35,9 +35,7 @@ class CatPopRec(IsSavable, RecommenderCommons):
     can_predict_cold_items: bool = False
     fit_items: SparkDataFrame
 
-    def _generate_mapping(
-        self, cat_tree: SparkDataFrame, max_iter: int = 20
-    ) -> SparkDataFrame:
+    def _generate_mapping(self, cat_tree: SparkDataFrame, max_iter: int = 20) -> SparkDataFrame:
         """
         Create SparkDataFrame with mapping [`category`, `leaf_cat`]
         where `leaf_cat` is the lowest level categories of category tree,
@@ -49,9 +47,7 @@ class CatPopRec(IsSavable, RecommenderCommons):
         :param max_iter: maximal number of iteration of descend through the category tree
         :return: SparkDataFrame with mapping [`category`, `leaf_cat`]
         """
-        current_res = cat_tree.select(
-            sf.col("category"), sf.col("category").alias("leaf_cat")
-        )
+        current_res = cat_tree.select(sf.col("category"), sf.col("category").alias("leaf_cat"))
 
         i = 0
         res_size_growth = current_res.count()
@@ -108,9 +104,7 @@ class CatPopRec(IsSavable, RecommenderCommons):
         """
         self.max_iter = max_iter
         if cat_tree is not None:
-            self.leaf_cat_mapping = self._generate_mapping(
-                cat_tree, max_iter=max_iter
-            )
+            self.leaf_cat_mapping = self._generate_mapping(cat_tree, max_iter=max_iter)
 
     @property
     def _init_args(self):
@@ -165,7 +159,6 @@ class CatPopRec(IsSavable, RecommenderCommons):
         if hasattr(self, "leaf_cat_mapping"):
            self.leaf_cat_mapping.unpersist()
 
-    # pylint: disable=arguments-differ
     def predict(
         self,
         categories: Union[SparkDataFrame, Iterable],
@@ -219,9 +212,7 @@ class CatPopRec(IsSavable, RecommenderCommons):
         item_data = items or self.fit_items
         items = get_unique_entities(item_data, self.item_column)
 
-        num_new, items = filter_cold(
-            items, self.fit_items, col_name=self.item_column
-        )
+        num_new, items = filter_cold(items, self.fit_items, col_name=self.item_column)
         if num_new > 0:
             self.logger.info(
                 "%s model can't predict cold items, they will be ignored",
@@ -267,9 +258,7 @@ class CatPopRec(IsSavable, RecommenderCommons):
         # find number of interactions in all leaf categories after filtering
         num_interactions_in_cat = (
             res.join(
-                unique_leaf_cat_items.groupBy("leaf_cat").agg(
-                    sf.sum(self.rating_column).alias("sum_rating")
-                ),
+                unique_leaf_cat_items.groupBy("leaf_cat").agg(sf.sum(self.rating_column).alias("sum_rating")),
                 on="leaf_cat",
             )
             .groupBy("category")
@@ -284,9 +273,7 @@ class CatPopRec(IsSavable, RecommenderCommons):
             .groupBy("category", self.item_column)
             .agg(sf.sum(self.rating_column).alias(self.rating_column))
             .join(num_interactions_in_cat, on="category")
-            .withColumn(
-                self.rating_column, sf.col(self.rating_column) / sf.col("sum_rating")
-            )
+            .withColumn(self.rating_column, sf.col(self.rating_column) / sf.col("sum_rating"))
         )
 
     def _save_model(self, path: str):
@@ -296,7 +283,7 @@ class CatPopRec(IsSavable, RecommenderCommons):
                 "item_column": self.item_column,
                 "rating_column": self.rating_column,
             },
-            join(path, "params.dump")
+            join(path, "params.dump"),
         )
 
     def _load_model(self, path: str):
replay/models/cluster.py CHANGED
@@ -2,9 +2,10 @@ from os.path import join
 from typing import Optional
 
 from replay.data.dataset import Dataset
-from .base_rec import QueryRecommender
 from replay.utils import PYSPARK_AVAILABLE, SparkDataFrame
 
+from .base_rec import QueryRecommender
+
 if PYSPARK_AVAILABLE:
     from pyspark.ml.clustering import KMeans, KMeansModel
     from pyspark.ml.feature import VectorAssembler
@@ -58,12 +59,10 @@ class ClusterRec(QueryRecommender):
             sf.count(self.item_column).alias("item_count")
         )
 
-        max_count_per_cluster = self.item_rel_in_cluster.groupby(
-            "cluster"
-        ).agg(sf.max("item_count").alias("max_count_in_cluster"))
-        self.item_rel_in_cluster = self.item_rel_in_cluster.join(
-            max_count_per_cluster, on="cluster"
+        max_count_per_cluster = self.item_rel_in_cluster.groupby("cluster").agg(
+            sf.max("item_count").alias("max_count_in_cluster")
         )
+        self.item_rel_in_cluster = self.item_rel_in_cluster.join(max_count_per_cluster, on="cluster")
         self.item_rel_in_cluster = self.item_rel_in_cluster.withColumn(
             self.rating_column, sf.col("item_count") / sf.col("max_count_in_cluster")
         ).drop("item_count", "max_count_in_cluster")
@@ -83,47 +82,38 @@ class ClusterRec(QueryRecommender):
         return vec.transform(query_features).select(self.query_column, "features")
 
     def _make_query_clusters(self, queries, query_features):
-
         query_cnt_in_fv = (
-            query_features
-            .select(self.query_column)
-            .distinct()
-            .join(queries.distinct(), on=self.query_column)
-            .count()
+            query_features.select(self.query_column).distinct().join(queries.distinct(), on=self.query_column).count()
         )
 
         query_cnt = queries.distinct().count()
 
         if query_cnt_in_fv < query_cnt:
-            self.logger.info("%s query(s) don't "
-                             "have a feature vector. "
-                             "The results will not be calculated for them.",
-                             query_cnt - query_cnt_in_fv)
+            self.logger.info(
+                "%s query(s) don't have a feature vector. The results will not be calculated for them.",
+                query_cnt - query_cnt_in_fv,
+            )
 
-        query_features_vector = self._transform_features(
-            query_features.join(queries, on=self.query_column)
-        )
+        query_features_vector = self._transform_features(query_features.join(queries, on=self.query_column))
         return (
             self.model.transform(query_features_vector)
             .select(self.query_column, "prediction")
             .withColumnRenamed("prediction", "cluster")
         )
 
-    # pylint: disable=too-many-arguments
     def _predict(
         self,
         dataset: Dataset,
-        k: int,
+        k: int,  # noqa: ARG002
         queries: SparkDataFrame,
         items: SparkDataFrame,
-        filter_seen_items: bool = True,
+        filter_seen_items: bool = True,  # noqa: ARG002
     ) -> SparkDataFrame:
         query_clusters = self._make_query_clusters(queries, dataset.query_features)
         filtered_items = self.item_rel_in_cluster.join(items, on=self.item_column)
         pred = query_clusters.join(filtered_items, on="cluster").drop("cluster")
         return pred
 
-    # pylint: disable=signature-differs
     def _predict_pairs(
         self,
         pairs: SparkDataFrame,
@@ -131,9 +121,8 @@ class ClusterRec(QueryRecommender):
     ) -> SparkDataFrame:
         query_clusters = self._make_query_clusters(pairs.select(self.query_column).distinct(), dataset.query_features)
         pairs_with_clusters = pairs.join(query_clusters, on=self.query_column)
-        filtered_items = (self.item_rel_in_cluster
-                          .join(pairs.select(self.item_column).distinct(), on=self.item_column))
-        pred = (pairs_with_clusters
-                .join(filtered_items, on=["cluster", self.item_column])
-                .select(self.query_column,self.item_column,self.rating_column))
+        filtered_items = self.item_rel_in_cluster.join(pairs.select(self.item_column).distinct(), on=self.item_column)
+        pred = pairs_with_clusters.join(filtered_items, on=["cluster", self.item_column]).select(
+            self.query_column, self.item_column, self.rating_column
+        )
         return pred
replay/models/extensions/ann/ann_mixin.py CHANGED
@@ -5,15 +5,17 @@ from typing import Any, Dict, Iterable, Optional, Union
 
 from replay.data import Dataset
 from replay.models.base_rec import BaseRecommender
-from .index_builders.base_index_builder import IndexBuilder
 from replay.utils import PYSPARK_AVAILABLE, SparkDataFrame
 
+from .index_builders.base_index_builder import IndexBuilder
+
 if PYSPARK_AVAILABLE:
     from pyspark.sql import functions as sf
 
-    from .index_stores.spark_files_index_store import SparkFilesIndexStore
     from replay.utils.spark_utils import get_top_k_recs, return_recs
 
+    from .index_stores.spark_files_index_store import SparkFilesIndexStore
+
 
 logger = logging.getLogger("replay")
 
@@ -82,9 +84,7 @@ class ANNMixin(BaseRecommender):
         self.index_builder.build_index(vectors, **ann_params)
 
     @abstractmethod
-    def _get_vectors_to_infer_ann_inner(
-        self, interactions: SparkDataFrame, queries: SparkDataFrame
-    ) -> SparkDataFrame:
+    def _get_vectors_to_infer_ann_inner(self, interactions: SparkDataFrame, queries: SparkDataFrame) -> SparkDataFrame:
         """Implementations of this method must return a dataframe with user vectors.
         User vectors from this method are used to infer the index.
 
@@ -134,7 +134,6 @@ class ANNMixin(BaseRecommender):
 
         """
 
-    # pylint: disable=too-many-arguments, too-many-locals
     def _predict_wrap(
         self,
         dataset: Optional[Dataset],
@@ -144,14 +143,10 @@ class ANNMixin(BaseRecommender):
         filter_seen_items: bool = True,
         recs_file_path: Optional[str] = None,
     ) -> Optional[SparkDataFrame]:
-        dataset, queries, items = self._filter_interactions_queries_items_dataframes(
-            dataset, k, queries, items
-        )
+        dataset, queries, items = self._filter_interactions_queries_items_dataframes(dataset, k, queries, items)
 
         if self._use_ann:
-            vectors = self._get_vectors_to_infer_ann(
-                dataset.interactions, queries, filter_seen_items
-            )
+            vectors = self._get_vectors_to_infer_ann(dataset.interactions, queries, filter_seen_items)
             ann_params = self._get_ann_infer_params()
             inferer = self.index_builder.produce_inferer(filter_seen_items)
             recs = inferer.infer(vectors, ann_params["features_col"], k)
replay/models/extensions/ann/entities/base_hnsw_param.py CHANGED
@@ -9,7 +9,7 @@ class BaseHnswParam:
     """
 
     space: str
-    m: int = 200  # pylint: disable=invalid-name
+    m: int = 200
     ef_c: int = 20000
     post: int = 0
     ef_s: Optional[int] = None
replay/models/extensions/ann/entities/hnswlib_param.py CHANGED
@@ -59,9 +59,3 @@ class HnswlibParam(BaseHnswParam):
     dim: int = field(default=None, init=False)
     # Max number of elements that will be stored in the index
     max_elements: int = field(default=None, init=False)
-
-    # def init_args_as_dict(self):
-    #     # union dicts
-    #     return dict(
-    #         super().init_args_as_dict()["init_args"], **{"space": self.space}
-    #     )
replay/models/extensions/ann/entities/nmslib_hnsw_param.py CHANGED
@@ -65,9 +65,3 @@ class NmslibHnswParam(BaseHnswParam):
     items_count: Optional[int] = field(default=None, init=False)
 
     method: ClassVar[str] = "hnsw"
-
-    # def init_args_as_dict(self):
-    #     # union dicts
-    #     return dict(
-    #         super().init_args_as_dict()["init_args"], **{"space": self.space}
-    #     )
replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py CHANGED
@@ -3,7 +3,6 @@ from typing import Optional
 
 import numpy as np
 
-from .base_index_builder import IndexBuilder
 from replay.models.extensions.ann.index_inferers.base_inferer import IndexInferer
 from replay.models.extensions.ann.index_inferers.hnswlib_filter_index_inferer import HnswlibFilterIndexInferer
 from replay.models.extensions.ann.index_inferers.hnswlib_index_inferer import HnswlibIndexInferer
@@ -11,6 +10,8 @@ from replay.models.extensions.ann.utils import create_hnswlib_index_instance
 from replay.utils import SparkDataFrame
 from replay.utils.spark_utils import spark_to_pandas
 
+from .base_index_builder import IndexBuilder
+
 logger = logging.getLogger("replay")
 
 
@@ -21,13 +22,10 @@ class DriverHnswlibIndexBuilder(IndexBuilder):
 
     def produce_inferer(self, filter_seen_items: bool) -> IndexInferer:
         if filter_seen_items:
-            return HnswlibFilterIndexInferer(
-                self.index_params, self.index_store
-            )
+            return HnswlibFilterIndexInferer(self.index_params, self.index_store)
         else:
             return HnswlibIndexInferer(self.index_params, self.index_store)
 
-    # pylint: disable=no-member
     def build_index(
         self,
         vectors: SparkDataFrame,
@@ -43,8 +41,4 @@ class DriverHnswlibIndexBuilder(IndexBuilder):
         else:
             index.add_items(np.stack(vectors_np))
 
-        self.index_store.save_to_store(
-            lambda path: index.save_index(  # pylint: disable=unnecessary-lambda)
-                path
-            )
-        )
+        self.index_store.save_to_store(lambda path: index.save_index(path))
replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py CHANGED
@@ -1,14 +1,15 @@
 import logging
 from typing import Optional
 
-from .base_index_builder import IndexBuilder
-from .nmslib_index_builder_mixin import NmslibIndexBuilderMixin
 from replay.models.extensions.ann.index_inferers.base_inferer import IndexInferer
 from replay.models.extensions.ann.index_inferers.nmslib_filter_index_inferer import NmslibFilterIndexInferer
 from replay.models.extensions.ann.index_inferers.nmslib_index_inferer import NmslibIndexInferer
 from replay.utils import SparkDataFrame
 from replay.utils.spark_utils import spark_to_pandas
 
+from .base_index_builder import IndexBuilder
+from .nmslib_index_builder_mixin import NmslibIndexBuilderMixin
+
 logger = logging.getLogger("replay")
 
 
@@ -19,20 +20,15 @@ class DriverNmslibIndexBuilder(IndexBuilder):
 
     def produce_inferer(self, filter_seen_items: bool) -> IndexInferer:
         if filter_seen_items:
-            return NmslibFilterIndexInferer(
-                self.index_params, self.index_store
-            )
+            return NmslibFilterIndexInferer(self.index_params, self.index_store)
         else:
             return NmslibIndexInferer(self.index_params, self.index_store)
 
-    # pylint: disable=no-member
     def build_index(
         self,
         vectors: SparkDataFrame,
-        features_col: str,
-        ids_col: Optional[str] = None,
+        features_col: str,  # noqa: ARG002
+        ids_col: Optional[str] = None,  # noqa: ARG002
     ):
         vectors = spark_to_pandas(vectors, self.allow_collect_to_master)
-        NmslibIndexBuilderMixin.build_and_save_index(
-            vectors, self.index_params, self.index_store
-        )
+        NmslibIndexBuilderMixin.build_and_save_index(vectors, self.index_params, self.index_store)
replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py CHANGED
@@ -3,13 +3,14 @@ from typing import Iterator, Optional
 
 import numpy as np
 
-from .base_index_builder import IndexBuilder
 from replay.models.extensions.ann.index_inferers.base_inferer import IndexInferer
 from replay.models.extensions.ann.index_inferers.hnswlib_filter_index_inferer import HnswlibFilterIndexInferer
 from replay.models.extensions.ann.index_inferers.hnswlib_index_inferer import HnswlibIndexInferer
 from replay.models.extensions.ann.utils import create_hnswlib_index_instance
 from replay.utils import PandasDataFrame, SparkDataFrame
 
+from .base_index_builder import IndexBuilder
+
 logger = logging.getLogger("replay")
 
 
@@ -20,9 +21,7 @@ class ExecutorHnswlibIndexBuilder(IndexBuilder):
 
     def produce_inferer(self, filter_seen_items: bool) -> IndexInferer:
         if filter_seen_items:
-            return HnswlibFilterIndexInferer(
-                self.index_params, self.index_store
-            )
+            return HnswlibFilterIndexInferer(self.index_params, self.index_store)
         else:
             return HnswlibIndexInferer(self.index_params, self.index_store)
 
@@ -56,17 +55,11 @@ class ExecutorHnswlibIndexBuilder(IndexBuilder):
             # ids will be from [0, ..., len(vectors_np)]
             index.add_items(np.stack(vectors_np))
 
-            _index_store.save_to_store(
-                lambda path: index.save_index(  # pylint: disable=unnecessary-lambda)
-                    path
-                )
-            )
+            _index_store.save_to_store(lambda path: index.save_index(path))
 
             yield PandasDataFrame(data={"_success": 1}, index=[0])
 
         # Here we perform materialization (`.collect()`) to build the hnsw index.
         cols = [ids_col, features_col] if ids_col else [features_col]
 
-        vectors.select(*cols).mapInPandas(
-            build_index_udf, "_success int"
-        ).collect()
+        vectors.select(*cols).mapInPandas(build_index_udf, "_success int").collect()
replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py CHANGED
@@ -3,13 +3,14 @@ from typing import Iterator, Optional
 
 import pandas as pd
 
-from .base_index_builder import IndexBuilder
-from .nmslib_index_builder_mixin import NmslibIndexBuilderMixin
 from replay.models.extensions.ann.index_inferers.base_inferer import IndexInferer
 from replay.models.extensions.ann.index_inferers.nmslib_filter_index_inferer import NmslibFilterIndexInferer
 from replay.models.extensions.ann.index_inferers.nmslib_index_inferer import NmslibIndexInferer
 from replay.utils import PandasDataFrame, SparkDataFrame
 
+from .base_index_builder import IndexBuilder
+from .nmslib_index_builder_mixin import NmslibIndexBuilderMixin
+
 logger = logging.getLogger("replay")
 
 
@@ -20,9 +21,7 @@ class ExecutorNmslibIndexBuilder(IndexBuilder):
 
     def produce_inferer(self, filter_seen_items: bool) -> IndexInferer:
         if filter_seen_items:
-            return NmslibFilterIndexInferer(
-                self.index_params, self.index_store
-            )
+            return NmslibFilterIndexInferer(self.index_params, self.index_store)
         else:
             return NmslibIndexInferer(self.index_params, self.index_store)
 
@@ -47,15 +46,9 @@ class ExecutorNmslibIndexBuilder(IndexBuilder):
             # with the same `item_idx_two`.
             # And therefore we cannot call the `addDataPointBatch` iteratively
             # (in build_and_save_index).
-            pdfs = []
-            for pdf in iterator:
-                pdfs.append(pdf)
-
-            pdf = pd.concat(pdfs)
+            pdf = pd.concat(list(iterator))
 
-            NmslibIndexBuilderMixin.build_and_save_index(
-                pdf, index_params, index_store
-            )
+            NmslibIndexBuilderMixin.build_and_save_index(pdf, index_params, index_store)
 
             yield PandasDataFrame(data={"_success": 1}, index=[0])
 
@@ -64,8 +57,8 @@ class ExecutorNmslibIndexBuilder(IndexBuilder):
     def build_index(
         self,
         vectors: SparkDataFrame,
-        features_col: str,
-        ids_col: Optional[str] = None,
+        features_col: str,  # noqa: ARG002
+        ids_col: Optional[str] = None,  # noqa: ARG002
     ):
         # to execution in one executor
         vectors = vectors.repartition(1)
@@ -74,6 +67,6 @@ class ExecutorNmslibIndexBuilder(IndexBuilder):
         build_index_udf = self.make_build_index_udf()
 
         # Here we perform materialization (`.collect()`) to build the hnsw index.
-        vectors.select(
-            "similarity", "item_idx_one", "item_idx_two"
-        ).mapInPandas(build_index_udf, "_success int").collect()
+        vectors.select("similarity", "item_idx_one", "item_idx_two").mapInPandas(
+            build_index_udf, "_success int"
+        ).collect()
replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py CHANGED
@@ -6,7 +6,6 @@ from replay.models.extensions.ann.utils import create_nmslib_index_instance
 from replay.utils import PandasDataFrame
 
 
-# pylint: disable=too-few-public-methods
 class NmslibIndexBuilderMixin:
     """Provides nmslib index building method for different nmslib index builders"""
 
@@ -49,6 +48,4 @@ class NmslibIndexBuilderMixin:
         index.addDataPointBatch(data=sim_matrix)
         index.createIndex(creation_index_params)
 
-        index_store.save_to_store(
-            lambda path: index.saveIndex(path, save_data=True)
-        )  # pylint: disable=unnecessary-lambda)
+        index_store.save_to_store(lambda path: index.saveIndex(path, save_data=True))
replay/models/extensions/ann/index_inferers/base_inferer.py CHANGED
@@ -8,7 +8,6 @@ if PYSPARK_AVAILABLE:
     from pyspark.sql import functions as sf
 
 
-# pylint: disable=too-few-public-methods
 class IndexInferer(ABC):
     """Abstract base class that describes a common interface for index inferers
     and provides common methods for them."""
@@ -21,9 +20,7 @@ class IndexInferer(ABC):
         self.index_store = index_store
 
     @abstractmethod
-    def infer(
-        self, vectors: SparkDataFrame, features_col: str, k: int
-    ) -> SparkDataFrame:
+    def infer(self, vectors: SparkDataFrame, features_col: str, k: int) -> SparkDataFrame:
         """Infers index"""
 
     @staticmethod
@@ -51,9 +48,7 @@ class IndexInferer(ABC):
         """
         res = inference_result.select(
             "user_idx",
-            sf.explode(
-                sf.arrays_zip("neighbours.item_idx", "neighbours.distance")
-            ).alias("zip_exp"),
+            sf.explode(sf.arrays_zip("neighbours.item_idx", "neighbours.distance")).alias("zip_exp"),
         )
 
         # Fix arrays_zip random behavior.
@@ -65,8 +60,6 @@ class IndexInferer(ABC):
         res = res.select(
             "user_idx",
             sf.col(f"zip_exp.{item_idx_field_name}").alias("item_idx"),
-            (sf.lit(-1.0) * sf.col(f"zip_exp.{distance_field_name}")).alias(
-                "relevance"
-            ),
+            (sf.lit(-1.0) * sf.col(f"zip_exp.{distance_field_name}")).alias("relevance"),
         )
         return res
replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py CHANGED
@@ -1,28 +1,24 @@
 import numpy as np
 import pandas as pd
 
-from .base_inferer import IndexInferer
 from replay.models.extensions.ann.utils import create_hnswlib_index_instance
 from replay.utils import PYSPARK_AVAILABLE, PandasDataFrame, SparkDataFrame
 from replay.utils.session_handler import State
 
+from .base_inferer import IndexInferer
+
 if PYSPARK_AVAILABLE:
     from pyspark.sql.pandas.functions import pandas_udf
 
 
-# pylint: disable=too-few-public-methods
 class HnswlibFilterIndexInferer(IndexInferer):
     """Hnswlib index inferer with filter seen items. Infers hnswlib index."""
 
-    def infer(
-        self, vectors: SparkDataFrame, features_col: str, k: int
-    ) -> SparkDataFrame:
+    def infer(self, vectors: SparkDataFrame, features_col: str, k: int) -> SparkDataFrame:
         _index_store = self.index_store
         index_params = self.index_params
 
-        index_store_broadcast = State().session.sparkContext.broadcast(
-            _index_store
-        )
+        index_store_broadcast = State().session.sparkContext.broadcast(_index_store)
 
         @pandas_udf(self.udf_return_type)
         def infer_index_udf(
@@ -34,9 +30,7 @@ class HnswlibFilterIndexInferer(IndexInferer):
             index = index_store.load_index(
                 init_index=lambda: create_hnswlib_index_instance(index_params),
                 load_index=lambda index, path: index.load_index(path),
-                configure_index=lambda index: index.set_ef(index_params.ef_s)
-                if index_params.ef_s
-                else None,
+                configure_index=lambda index: index.set_ef(index_params.ef_s) if index_params.ef_s else None,
            )
 
             # max number of items to retrieve per batch
@@ -51,13 +45,9 @@ class HnswlibFilterIndexInferer(IndexInferer):
             filtered_labels = []
             filtered_distances = []
             for i, item_ids in enumerate(labels):
-                non_seen_item_indexes = ~np.isin(
-                    item_ids, seen_item_ids[i], assume_unique=True
-                )
+                non_seen_item_indexes = ~np.isin(item_ids, seen_item_ids[i], assume_unique=True)
                 filtered_labels.append((item_ids[non_seen_item_indexes])[:k])
-                filtered_distances.append(
-                    (distances[i][non_seen_item_indexes])[:k]
-                )
+                filtered_distances.append((distances[i][non_seen_item_indexes])[:k])
 
             pd_res = pd.DataFrame(
                 {
replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py CHANGED
@@ -1,28 +1,24 @@
 import numpy as np
 import pandas as pd
 
-from .base_inferer import IndexInferer
 from replay.models.extensions.ann.utils import create_hnswlib_index_instance
 from replay.utils import PYSPARK_AVAILABLE, PandasDataFrame, SparkDataFrame
 from replay.utils.session_handler import State
 
+from .base_inferer import IndexInferer
+
 if PYSPARK_AVAILABLE:
     from pyspark.sql.pandas.functions import pandas_udf
 
 
-# pylint: disable=too-few-public-methods
 class HnswlibIndexInferer(IndexInferer):
     """Hnswlib index inferer without filter seen items. Infers hnswlib index."""
 
-    def infer(
-        self, vectors: SparkDataFrame, features_col: str, k: int
-    ) -> SparkDataFrame:
+    def infer(self, vectors: SparkDataFrame, features_col: str, k: int) -> SparkDataFrame:
         _index_store = self.index_store
         index_params = self.index_params
 
-        index_store_broadcast = State().session.sparkContext.broadcast(
-            _index_store
-        )
+        index_store_broadcast = State().session.sparkContext.broadcast(_index_store)
 
         @pandas_udf(self.udf_return_type)
         def infer_index_udf(vectors: pd.Series) -> PandasDataFrame:  # pragma: no cover
@@ -30,9 +26,7 @@ class HnswlibIndexInferer(IndexInferer):
             index = index_store.load_index(
                 init_index=lambda: create_hnswlib_index_instance(index_params),
                 load_index=lambda index, path: index.load_index(path),
-                configure_index=lambda index: index.set_ef(index_params.ef_s)
-                if index_params.ef_s
-                else None,
+                configure_index=lambda index: index.set_ef(index_params.ef_s) if index_params.ef_s else None,
             )
 
             labels, distances = index.knn_query(
@@ -41,9 +35,7 @@ class HnswlibIndexInferer(IndexInferer):
                 num_threads=1,
             )
 
-            pd_res = pd.DataFrame(
-                {"item_idx": list(labels), "distance": list(distances)}
-            )
+            pd_res = pd.DataFrame({"item_idx": list(labels), "distance": list(distances)})
 
             return pd_res