replay-rec: replay_rec-0.16.0-py3-none-any.whl → replay_rec-0.17.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. replay/__init__.py +1 -1
  2. replay/data/__init__.py +1 -1
  3. replay/data/dataset.py +45 -42
  4. replay/data/dataset_utils/dataset_label_encoder.py +6 -7
  5. replay/data/nn/__init__.py +1 -1
  6. replay/data/nn/schema.py +20 -33
  7. replay/data/nn/sequence_tokenizer.py +217 -87
  8. replay/data/nn/sequential_dataset.py +6 -22
  9. replay/data/nn/torch_sequential_dataset.py +20 -11
  10. replay/data/nn/utils.py +7 -9
  11. replay/data/schema.py +17 -17
  12. replay/data/spark_schema.py +0 -1
  13. replay/metrics/base_metric.py +38 -79
  14. replay/metrics/categorical_diversity.py +24 -58
  15. replay/metrics/coverage.py +25 -49
  16. replay/metrics/descriptors.py +4 -13
  17. replay/metrics/experiment.py +3 -8
  18. replay/metrics/hitrate.py +3 -6
  19. replay/metrics/map.py +3 -6
  20. replay/metrics/mrr.py +1 -4
  21. replay/metrics/ndcg.py +4 -7
  22. replay/metrics/novelty.py +10 -29
  23. replay/metrics/offline_metrics.py +26 -61
  24. replay/metrics/precision.py +3 -6
  25. replay/metrics/recall.py +3 -6
  26. replay/metrics/rocauc.py +7 -10
  27. replay/metrics/surprisal.py +13 -30
  28. replay/metrics/torch_metrics_builder.py +0 -4
  29. replay/metrics/unexpectedness.py +15 -20
  30. replay/models/__init__.py +1 -2
  31. replay/models/als.py +7 -15
  32. replay/models/association_rules.py +12 -28
  33. replay/models/base_neighbour_rec.py +21 -36
  34. replay/models/base_rec.py +92 -215
  35. replay/models/cat_pop_rec.py +9 -22
  36. replay/models/cluster.py +17 -28
  37. replay/models/extensions/ann/ann_mixin.py +7 -12
  38. replay/models/extensions/ann/entities/base_hnsw_param.py +1 -1
  39. replay/models/extensions/ann/entities/hnswlib_param.py +0 -6
  40. replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -6
  41. replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +4 -10
  42. replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +7 -11
  43. replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +5 -12
  44. replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +11 -18
  45. replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +1 -4
  46. replay/models/extensions/ann/index_inferers/base_inferer.py +3 -10
  47. replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +7 -17
  48. replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +6 -14
  49. replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +14 -28
  50. replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +15 -25
  51. replay/models/extensions/ann/index_inferers/utils.py +2 -9
  52. replay/models/extensions/ann/index_stores/hdfs_index_store.py +4 -9
  53. replay/models/extensions/ann/index_stores/shared_disk_index_store.py +2 -6
  54. replay/models/extensions/ann/index_stores/spark_files_index_store.py +8 -14
  55. replay/models/extensions/ann/index_stores/utils.py +5 -2
  56. replay/models/extensions/ann/utils.py +3 -5
  57. replay/models/kl_ucb.py +16 -22
  58. replay/models/knn.py +37 -59
  59. replay/models/nn/optimizer_utils/__init__.py +1 -6
  60. replay/models/nn/optimizer_utils/optimizer_factory.py +3 -6
  61. replay/models/nn/sequential/bert4rec/__init__.py +1 -1
  62. replay/models/nn/sequential/bert4rec/dataset.py +6 -7
  63. replay/models/nn/sequential/bert4rec/lightning.py +53 -56
  64. replay/models/nn/sequential/bert4rec/model.py +12 -25
  65. replay/models/nn/sequential/callbacks/__init__.py +1 -1
  66. replay/models/nn/sequential/callbacks/prediction_callbacks.py +23 -25
  67. replay/models/nn/sequential/callbacks/validation_callback.py +27 -30
  68. replay/models/nn/sequential/postprocessors/postprocessors.py +1 -1
  69. replay/models/nn/sequential/sasrec/dataset.py +8 -7
  70. replay/models/nn/sequential/sasrec/lightning.py +53 -48
  71. replay/models/nn/sequential/sasrec/model.py +4 -17
  72. replay/models/pop_rec.py +9 -10
  73. replay/models/query_pop_rec.py +7 -15
  74. replay/models/random_rec.py +10 -18
  75. replay/models/slim.py +8 -13
  76. replay/models/thompson_sampling.py +13 -14
  77. replay/models/ucb.py +11 -22
  78. replay/models/wilson.py +5 -14
  79. replay/models/word2vec.py +24 -69
  80. replay/optimization/optuna_objective.py +13 -27
  81. replay/preprocessing/__init__.py +1 -2
  82. replay/preprocessing/converter.py +2 -7
  83. replay/preprocessing/filters.py +67 -142
  84. replay/preprocessing/history_based_fp.py +44 -116
  85. replay/preprocessing/label_encoder.py +106 -68
  86. replay/preprocessing/sessionizer.py +1 -11
  87. replay/scenarios/fallback.py +3 -8
  88. replay/splitters/base_splitter.py +43 -15
  89. replay/splitters/cold_user_random_splitter.py +18 -31
  90. replay/splitters/k_folds.py +14 -24
  91. replay/splitters/last_n_splitter.py +33 -43
  92. replay/splitters/new_users_splitter.py +31 -55
  93. replay/splitters/random_splitter.py +16 -23
  94. replay/splitters/ratio_splitter.py +30 -54
  95. replay/splitters/time_splitter.py +13 -18
  96. replay/splitters/two_stage_splitter.py +44 -79
  97. replay/utils/__init__.py +1 -1
  98. replay/utils/common.py +65 -0
  99. replay/utils/dataframe_bucketizer.py +25 -31
  100. replay/utils/distributions.py +3 -15
  101. replay/utils/model_handler.py +36 -33
  102. replay/utils/session_handler.py +11 -15
  103. replay/utils/spark_utils.py +51 -85
  104. replay/utils/time.py +8 -22
  105. replay/utils/types.py +1 -3
  106. {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/METADATA +2 -2
  107. replay_rec-0.17.0.dist-info/RECORD +127 -0
  108. replay_rec-0.16.0.dist-info/RECORD +0 -126
  109. {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/LICENSE +0 -0
  110. {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/WHEEL +0 -0
replay/models/association_rules.py

@@ -3,16 +3,16 @@ from typing import Any, Dict, Iterable, List, Optional, Union
 import numpy as np
 
 from replay.data import Dataset
+from replay.utils import PYSPARK_AVAILABLE, SparkDataFrame
+
 from .base_neighbour_rec import NeighbourRec
 from .extensions.ann.index_builders.base_index_builder import IndexBuilder
-from replay.utils import PYSPARK_AVAILABLE, SparkDataFrame
 
 if PYSPARK_AVAILABLE:
     import pyspark.sql.functions as sf
     from pyspark.sql.window import Window
 
 
-# pylint: disable=too-many-ancestors, too-many-instance-attributes
 class AssociationRulesItemRec(NeighbourRec):
     """
     Item-to-item recommender based on association rules.
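
For context: `PYSPARK_AVAILABLE` is the library's optional-dependency flag, and guarding the pyspark imports behind it lets the module load without Spark installed. A minimal sketch of how such a flag is typically defined; the `try`/`except` body here is an assumption for illustration, not RePlay's actual code:

    try:
        import pyspark  # noqa: F401  # probe only; the import itself is unused here
        PYSPARK_AVAILABLE = True
    except ImportError:
        PYSPARK_AVAILABLE = False

    if PYSPARK_AVAILABLE:
        import pyspark.sql.functions as sf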
@@ -117,7 +117,6 @@ class AssociationRulesItemRec(NeighbourRec):
         },
     }
 
-    # pylint: disable=too-many-arguments,
     def __init__(
         self,
         session_column: str,
@@ -204,14 +203,11 @@ class AssociationRulesItemRec(NeighbourRec):
             frequent_items_interactions.withColumnRenamed(self.item_column, "antecedent")
             .withColumnRenamed(self.rating_column, "antecedent_rel")
             .join(
-                frequent_items_interactions.withColumnRenamed(
-                    self.session_column, self.session_column + "_cons"
-                )
+                frequent_items_interactions.withColumnRenamed(self.session_column, self.session_column + "_cons")
                 .withColumnRenamed(self.item_column, "consequent")
                 .withColumnRenamed(self.rating_column, "consequent_rel"),
                 on=[
-                    sf.col(self.session_column)
-                    == sf.col(self.session_column + "_cons"),
+                    sf.col(self.session_column) == sf.col(self.session_column + "_cons"),
                     sf.col("antecedent") < sf.col("consequent"),
                 ],
             )
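
The self-join above enumerates each unordered pair of distinct items that co-occur in a session; the `antecedent < consequent` condition keeps one orientation per pair, and (as the next hunk shows) the pair's rating is the lesser of the two item ratings. A pure-Python sketch of the same pairing logic over hypothetical session data, not the library's implementation:

    from itertools import combinations

    # Hypothetical sessions mapping item id -> rating.
    sessions = {
        "s1": {1: 1.0, 2: 1.0, 3: 1.0},
        "s2": {1: 1.0, 3: 1.0},
    }

    pair_rating = {}
    for items in sessions.values():
        # combinations over sorted ids keeps one orientation per pair,
        # mirroring the `antecedent < consequent` join condition.
        for a, b in combinations(sorted(items), 2):
            # Pair rating is the lesser of the two item ratings, as in
            # sf.least(consequent_rel, antecedent_rel), summed across sessions.
            pair_rating[(a, b)] = pair_rating.get((a, b), 0.0) + min(items[a], items[b])

    print(pair_rating)  # {(1, 2): 1.0, (1, 3): 2.0, (2, 3): 1.0}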
@@ -220,9 +216,7 @@ class AssociationRulesItemRec(NeighbourRec):
                 self.rating_column,
                 sf.least(sf.col("consequent_rel"), sf.col("antecedent_rel")),
             )
-            .drop(
-                self.session_column + "_cons", "consequent_rel", "antecedent_rel"
-            )
+            .drop(self.session_column + "_cons", "consequent_rel", "antecedent_rel")
         )
 
         pairs_count = (
@@ -243,16 +237,12 @@ class AssociationRulesItemRec(NeighbourRec):
         )
 
         pairs_metrics = pairs_metrics.join(
-            frequent_items_cached.withColumnRenamed(
-                "item_rating", "antecedent_rating"
-            ),
+            frequent_items_cached.withColumnRenamed("item_rating", "antecedent_rating"),
             on=[sf.col("antecedent") == sf.col(self.item_column)],
         ).drop(self.item_column)
 
         pairs_metrics = pairs_metrics.join(
-            frequent_items_cached.withColumnRenamed(
-                "item_rating", "consequent_rating"
-            ),
+            frequent_items_cached.withColumnRenamed("item_rating", "consequent_rating"),
             on=[sf.col("consequent") == sf.col(self.item_column)],
         ).drop(self.item_column)
 
@@ -261,9 +251,7 @@ class AssociationRulesItemRec(NeighbourRec):
             sf.col("pair_rating") / sf.col("antecedent_rating"),
         ).withColumn(
             "lift",
-            num_sessions
-            * sf.col("confidence")
-            / sf.col("consequent_rating"),
+            num_sessions * sf.col("confidence") / sf.col("consequent_rating"),
         )
 
         if self.num_neighbours is not None:
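
The `confidence` and `lift` columns computed above follow the standard association-rule definitions: confidence(A -> B) = rating(A, B) / rating(A), and lift = confidence / (rating(B) / num_sessions), which rearranges to the `num_sessions * confidence / consequent_rating` expression in the hunk. A small numeric check with assumed counts:

    # Assumed counts for a rule A -> B over 100 sessions (illustrative only).
    num_sessions = 100
    antecedent_rating = 20   # sessions containing A
    consequent_rating = 25   # sessions containing B
    pair_rating = 10         # sessions containing both A and B

    confidence = pair_rating / antecedent_rating          # 0.5
    lift = num_sessions * confidence / consequent_rating  # 2.0: A doubles B's base rate

    # Same value via the textbook form: confidence over B's base rate.
    assert lift == confidence / (consequent_rating / num_sessions)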
@@ -331,10 +319,8 @@ class AssociationRulesItemRec(NeighbourRec):
         spark-dataframe with columns ``[item_id, neighbour_item_id, similarity]``
         """
         if metric not in self.item_to_item_metrics:
-            raise ValueError(
-                f"Select one of the valid distance metrics: "
-                f"{self.item_to_item_metrics}"
-            )
+            msg = f"Select one of the valid distance metrics: {self.item_to_item_metrics}"
+            raise ValueError(msg)
 
         return self._get_nearest_items_wrap(
             items=items,
@@ -346,7 +332,7 @@ class AssociationRulesItemRec(NeighbourRec):
     def _get_nearest_items(
         self,
         items: SparkDataFrame,
-        metric: Optional[str] = None,
+        metric: Optional[str] = None,  # noqa: ARG002
         candidates: Optional[SparkDataFrame] = None,
     ) -> SparkDataFrame:
         """
@@ -361,9 +347,7 @@ class AssociationRulesItemRec(NeighbourRec):
         pairs_to_consider = self.similarity
         if candidates is not None:
             pairs_to_consider = self.similarity.join(
-                sf.broadcast(
-                    candidates.withColumnRenamed(self.item_column, "item_idx_two")
-                ),
+                sf.broadcast(candidates.withColumnRenamed(self.item_column, "item_idx_two")),
                 on="item_idx_two",
             )
 
replay/models/base_neighbour_rec.py

@@ -1,4 +1,3 @@
-# pylint: disable=too-many-lines
 """
 NeighbourRec - base class that requires interactions at prediction time.
 Part of set of abstract classes (from base_rec.py)
@@ -8,9 +7,10 @@ from abc import ABC
 from typing import Any, Dict, Iterable, Optional, Union
 
 from replay.data.dataset import Dataset
+from replay.utils import PYSPARK_AVAILABLE, MissingImportType, SparkDataFrame
+
 from .base_rec import Recommender
 from .extensions.ann.ann_mixin import ANNMixin
-from replay.utils import PYSPARK_AVAILABLE, MissingImportType, SparkDataFrame
 
 if PYSPARK_AVAILABLE:
     from pyspark.sql import functions as sf
@@ -37,7 +37,6 @@ class NeighbourRec(Recommender, ANNMixin, ABC):
         if hasattr(self, "similarity"):
             self.similarity.unpersist()
 
-    # pylint: disable=missing-function-docstring
     @property
     def similarity_metric(self):
         return self._similarity_metric
@@ -45,14 +44,11 @@ class NeighbourRec(Recommender, ANNMixin, ABC):
     @similarity_metric.setter
     def similarity_metric(self, value):
         if not self.can_change_metric:
-            raise ValueError(
-                "This class does not support changing similarity metrics"
-            )
+            msg = "This class does not support changing similarity metrics"
+            raise ValueError(msg)
         if value not in self.item_to_item_metrics:
-            raise ValueError(
-                f"Select one of the valid metrics for predict: "
-                f"{self.item_to_item_metrics}"
-            )
+            msg = f"Select one of the valid metrics for predict: {self.item_to_item_metrics}"
+            raise ValueError(msg)
         self._similarity_metric = value
 
     def _predict_pairs_inner(
@@ -76,9 +72,8 @@ class NeighbourRec(Recommender, ANNMixin, ABC):
         :return: SparkDataFrame ``[user_id, item_id, rating]``
         """
         if dataset is None:
-            raise ValueError(
-                "interactions is not provided, but it is required for prediction"
-            )
+            msg = "interactions is not provided, but it is required for prediction"
+            raise ValueError(msg)
 
         recs = (
             dataset.interactions.join(queries, how="inner", on=self.query_column)
@@ -98,16 +93,14 @@ class NeighbourRec(Recommender, ANNMixin, ABC):
         )
         return recs
 
-    # pylint: disable=too-many-arguments
     def _predict(
         self,
         dataset: Dataset,
-        k: int,
+        k: int,  # noqa: ARG002
         queries: SparkDataFrame,
         items: SparkDataFrame,
-        filter_seen_items: bool = True,
+        filter_seen_items: bool = True,  # noqa: ARG002
     ) -> SparkDataFrame:
-
         return self._predict_pairs_inner(
             dataset=dataset,
             filter_df=items.withColumnRenamed(self.item_column, "item_idx_filter"),
@@ -120,13 +113,12 @@ class NeighbourRec(Recommender, ANNMixin, ABC):
         pairs: SparkDataFrame,
         dataset: Optional[Dataset] = None,
     ) -> SparkDataFrame:
-
         return self._predict_pairs_inner(
             dataset=dataset,
             filter_df=(
-                pairs.withColumnRenamed(
-                    self.query_column, "user_idx_filter"
-                ).withColumnRenamed(self.item_column, "item_idx_filter")
+                pairs.withColumnRenamed(self.query_column, "user_idx_filter").withColumnRenamed(
+                    self.item_column, "item_idx_filter"
+                )
             ),
             condition=(sf.col(self.query_column) == sf.col("user_idx_filter"))
             & (sf.col("item_idx_two") == sf.col("item_idx_filter")),
@@ -157,10 +149,8 @@ class NeighbourRec(Recommender, ANNMixin, ABC):
 
         if metric is not None:
             if metric not in self.item_to_item_metrics:
-                raise ValueError(
-                    f"Select one of the valid distance metrics: "
-                    f"{self.item_to_item_metrics}"
-                )
+                msg = f"Select one of the valid distance metrics: {self.item_to_item_metrics}"
+                raise ValueError(msg)
 
             self.logger.debug(
                 "Metric is not used to determine nearest items in %s model",
@@ -180,7 +170,6 @@ class NeighbourRec(Recommender, ANNMixin, ABC):
         metric: Optional[str] = None,
         candidates: Optional[SparkDataFrame] = None,
     ) -> SparkDataFrame:
-
         similarity_filtered = self.similarity.join(
             items.withColumnRenamed(self.item_column, "item_idx_one"),
             on="item_idx_one",
@@ -204,20 +193,16 @@ class NeighbourRec(Recommender, ANNMixin, ABC):
             "features_col": None,
         }
 
-    def _get_vectors_to_build_ann(self, interactions: SparkDataFrame) -> SparkDataFrame:
-        similarity_df = self.similarity.select(
-            "similarity", "item_idx_one", "item_idx_two"
-        )
+    def _get_vectors_to_build_ann(self, interactions: SparkDataFrame) -> SparkDataFrame:  # noqa: ARG002
+        similarity_df = self.similarity.select("similarity", "item_idx_one", "item_idx_two")
         return similarity_df
 
     def _get_vectors_to_infer_ann_inner(
-        self, interactions: SparkDataFrame, queries: SparkDataFrame
+        self, interactions: SparkDataFrame, queries: SparkDataFrame  # noqa: ARG002
     ) -> SparkDataFrame:
-
-        user_vectors = (
-            interactions.groupBy(self.query_column).agg(
-                sf.collect_list(self.item_column).alias("vector_items"),
-                sf.collect_list(self.rating_column).alias("vector_ratings"))
+        user_vectors = interactions.groupBy(self.query_column).agg(
+            sf.collect_list(self.item_column).alias("vector_items"),
+            sf.collect_list(self.rating_column).alias("vector_ratings"),
        )
         return user_vectors
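
For reference, the `groupBy(...).agg(collect_list(...))` call in the final hunk packs each query's interaction history into two aligned array columns. A standalone sketch with a local Spark session and assumed column names (the real names come from the dataset's feature schema):

    import pyspark.sql.functions as sf
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[1]").getOrCreate()

    # Hypothetical interactions: (query, item, rating).
    interactions = spark.createDataFrame(
        [(1, 10, 1.0), (1, 11, 2.0), (2, 10, 1.0)],
        ["user_idx", "item_idx", "rating"],
    )

    user_vectors = interactions.groupBy("user_idx").agg(
        sf.collect_list("item_idx").alias("vector_items"),
        sf.collect_list("rating").alias("vector_ratings"),
    )
    user_vectors.show()
    # user 1 -> vector_items [10, 11], vector_ratings [1.0, 2.0]
    # (collect_list makes no ordering guarantee after a shuffle)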