replay-rec 0.16.0rc0__py3-none-any.whl → 0.17.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (162)
  1. replay/__init__.py +1 -1
  2. replay/data/__init__.py +1 -1
  3. replay/data/dataset.py +45 -42
  4. replay/data/dataset_utils/dataset_label_encoder.py +6 -7
  5. replay/data/nn/__init__.py +1 -1
  6. replay/data/nn/schema.py +20 -33
  7. replay/data/nn/sequence_tokenizer.py +217 -87
  8. replay/data/nn/sequential_dataset.py +6 -22
  9. replay/data/nn/torch_sequential_dataset.py +20 -11
  10. replay/data/nn/utils.py +7 -9
  11. replay/data/schema.py +17 -17
  12. replay/data/spark_schema.py +0 -1
  13. replay/metrics/base_metric.py +38 -79
  14. replay/metrics/categorical_diversity.py +24 -58
  15. replay/metrics/coverage.py +25 -49
  16. replay/metrics/descriptors.py +4 -13
  17. replay/metrics/experiment.py +3 -8
  18. replay/metrics/hitrate.py +3 -6
  19. replay/metrics/map.py +3 -6
  20. replay/metrics/mrr.py +1 -4
  21. replay/metrics/ndcg.py +4 -7
  22. replay/metrics/novelty.py +10 -29
  23. replay/metrics/offline_metrics.py +26 -61
  24. replay/metrics/precision.py +3 -6
  25. replay/metrics/recall.py +3 -6
  26. replay/metrics/rocauc.py +7 -10
  27. replay/metrics/surprisal.py +13 -30
  28. replay/metrics/torch_metrics_builder.py +0 -4
  29. replay/metrics/unexpectedness.py +15 -20
  30. replay/models/__init__.py +1 -2
  31. replay/models/als.py +7 -15
  32. replay/models/association_rules.py +12 -28
  33. replay/models/base_neighbour_rec.py +21 -36
  34. replay/models/base_rec.py +92 -215
  35. replay/models/cat_pop_rec.py +9 -22
  36. replay/models/cluster.py +17 -28
  37. replay/models/extensions/ann/ann_mixin.py +7 -12
  38. replay/models/extensions/ann/entities/base_hnsw_param.py +1 -1
  39. replay/models/extensions/ann/entities/hnswlib_param.py +0 -6
  40. replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -6
  41. replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +4 -10
  42. replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +7 -11
  43. replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +5 -12
  44. replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +11 -18
  45. replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +1 -4
  46. replay/models/extensions/ann/index_inferers/base_inferer.py +3 -10
  47. replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +7 -17
  48. replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +6 -14
  49. replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +14 -28
  50. replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +15 -25
  51. replay/models/extensions/ann/index_inferers/utils.py +2 -9
  52. replay/models/extensions/ann/index_stores/hdfs_index_store.py +4 -9
  53. replay/models/extensions/ann/index_stores/shared_disk_index_store.py +2 -6
  54. replay/models/extensions/ann/index_stores/spark_files_index_store.py +8 -14
  55. replay/models/extensions/ann/index_stores/utils.py +5 -2
  56. replay/models/extensions/ann/utils.py +3 -5
  57. replay/models/kl_ucb.py +16 -22
  58. replay/models/knn.py +37 -59
  59. replay/models/nn/optimizer_utils/__init__.py +1 -6
  60. replay/models/nn/optimizer_utils/optimizer_factory.py +3 -6
  61. replay/models/nn/sequential/bert4rec/__init__.py +1 -1
  62. replay/models/nn/sequential/bert4rec/dataset.py +6 -7
  63. replay/models/nn/sequential/bert4rec/lightning.py +53 -56
  64. replay/models/nn/sequential/bert4rec/model.py +12 -25
  65. replay/models/nn/sequential/callbacks/__init__.py +1 -1
  66. replay/models/nn/sequential/callbacks/prediction_callbacks.py +23 -25
  67. replay/models/nn/sequential/callbacks/validation_callback.py +27 -30
  68. replay/models/nn/sequential/postprocessors/postprocessors.py +1 -1
  69. replay/models/nn/sequential/sasrec/dataset.py +8 -7
  70. replay/models/nn/sequential/sasrec/lightning.py +53 -48
  71. replay/models/nn/sequential/sasrec/model.py +4 -17
  72. replay/models/pop_rec.py +9 -10
  73. replay/models/query_pop_rec.py +7 -15
  74. replay/models/random_rec.py +10 -18
  75. replay/models/slim.py +8 -13
  76. replay/models/thompson_sampling.py +13 -14
  77. replay/models/ucb.py +11 -22
  78. replay/models/wilson.py +5 -14
  79. replay/models/word2vec.py +24 -69
  80. replay/optimization/optuna_objective.py +13 -27
  81. replay/preprocessing/__init__.py +1 -2
  82. replay/preprocessing/converter.py +2 -7
  83. replay/preprocessing/filters.py +67 -142
  84. replay/preprocessing/history_based_fp.py +44 -116
  85. replay/preprocessing/label_encoder.py +106 -68
  86. replay/preprocessing/sessionizer.py +1 -11
  87. replay/scenarios/fallback.py +3 -8
  88. replay/splitters/base_splitter.py +43 -15
  89. replay/splitters/cold_user_random_splitter.py +18 -31
  90. replay/splitters/k_folds.py +14 -24
  91. replay/splitters/last_n_splitter.py +33 -43
  92. replay/splitters/new_users_splitter.py +31 -55
  93. replay/splitters/random_splitter.py +16 -23
  94. replay/splitters/ratio_splitter.py +30 -54
  95. replay/splitters/time_splitter.py +13 -18
  96. replay/splitters/two_stage_splitter.py +44 -79
  97. replay/utils/__init__.py +1 -1
  98. replay/utils/common.py +65 -0
  99. replay/utils/dataframe_bucketizer.py +25 -31
  100. replay/utils/distributions.py +3 -15
  101. replay/utils/model_handler.py +36 -33
  102. replay/utils/session_handler.py +11 -15
  103. replay/utils/spark_utils.py +51 -85
  104. replay/utils/time.py +8 -22
  105. replay/utils/types.py +1 -3
  106. {replay_rec-0.16.0rc0.dist-info → replay_rec-0.17.0.dist-info}/METADATA +2 -10
  107. replay_rec-0.17.0.dist-info/RECORD +127 -0
  108. {replay_rec-0.16.0rc0.dist-info → replay_rec-0.17.0.dist-info}/WHEEL +1 -1
  109. replay/experimental/__init__.py +0 -0
  110. replay/experimental/metrics/__init__.py +0 -61
  111. replay/experimental/metrics/base_metric.py +0 -661
  112. replay/experimental/metrics/coverage.py +0 -117
  113. replay/experimental/metrics/experiment.py +0 -200
  114. replay/experimental/metrics/hitrate.py +0 -27
  115. replay/experimental/metrics/map.py +0 -31
  116. replay/experimental/metrics/mrr.py +0 -19
  117. replay/experimental/metrics/ncis_precision.py +0 -32
  118. replay/experimental/metrics/ndcg.py +0 -50
  119. replay/experimental/metrics/precision.py +0 -23
  120. replay/experimental/metrics/recall.py +0 -26
  121. replay/experimental/metrics/rocauc.py +0 -50
  122. replay/experimental/metrics/surprisal.py +0 -102
  123. replay/experimental/metrics/unexpectedness.py +0 -74
  124. replay/experimental/models/__init__.py +0 -10
  125. replay/experimental/models/admm_slim.py +0 -216
  126. replay/experimental/models/base_neighbour_rec.py +0 -222
  127. replay/experimental/models/base_rec.py +0 -1361
  128. replay/experimental/models/base_torch_rec.py +0 -247
  129. replay/experimental/models/cql.py +0 -468
  130. replay/experimental/models/ddpg.py +0 -1007
  131. replay/experimental/models/dt4rec/__init__.py +0 -0
  132. replay/experimental/models/dt4rec/dt4rec.py +0 -193
  133. replay/experimental/models/dt4rec/gpt1.py +0 -411
  134. replay/experimental/models/dt4rec/trainer.py +0 -128
  135. replay/experimental/models/dt4rec/utils.py +0 -274
  136. replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
  137. replay/experimental/models/extensions/spark_custom_models/als_extension.py +0 -733
  138. replay/experimental/models/implicit_wrap.py +0 -138
  139. replay/experimental/models/lightfm_wrap.py +0 -327
  140. replay/experimental/models/mult_vae.py +0 -374
  141. replay/experimental/models/neuromf.py +0 -462
  142. replay/experimental/models/scala_als.py +0 -311
  143. replay/experimental/nn/data/__init__.py +0 -1
  144. replay/experimental/nn/data/schema_builder.py +0 -58
  145. replay/experimental/preprocessing/__init__.py +0 -3
  146. replay/experimental/preprocessing/data_preparator.py +0 -929
  147. replay/experimental/preprocessing/padder.py +0 -231
  148. replay/experimental/preprocessing/sequence_generator.py +0 -218
  149. replay/experimental/scenarios/__init__.py +0 -1
  150. replay/experimental/scenarios/obp_wrapper/__init__.py +0 -8
  151. replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -86
  152. replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -271
  153. replay/experimental/scenarios/obp_wrapper/utils.py +0 -88
  154. replay/experimental/scenarios/two_stages/reranker.py +0 -116
  155. replay/experimental/scenarios/two_stages/two_stages_scenario.py +0 -843
  156. replay/experimental/utils/__init__.py +0 -0
  157. replay/experimental/utils/logger.py +0 -24
  158. replay/experimental/utils/model_handler.py +0 -213
  159. replay/experimental/utils/session_handler.py +0 -47
  160. replay_rec-0.16.0rc0.dist-info/NOTICE +0 -41
  161. replay_rec-0.16.0rc0.dist-info/RECORD +0 -178
  162. {replay_rec-0.16.0rc0.dist-info → replay_rec-0.17.0.dist-info}/LICENSE +0 -0
replay/metrics/coverage.py CHANGED
@@ -1,16 +1,20 @@
+ import functools
+ import operator
  from typing import Dict, List, Union
+
  import polars as pl
 
- from replay.utils import PYSPARK_AVAILABLE, PandasDataFrame, SparkDataFrame, PolarsDataFrame
+ from replay.utils import PYSPARK_AVAILABLE, PandasDataFrame, PolarsDataFrame, SparkDataFrame
 
  from .base_metric import Metric, MetricsDataFrameLike, MetricsMeanReturnType, MetricsReturnType
 
  if PYSPARK_AVAILABLE:
-     from pyspark.sql import Window
-     from pyspark.sql import functions as sf
+     from pyspark.sql import (
+         Window,
+         functions as sf,
+     )
 
 
- # pylint: disable=too-few-public-methods
  class Coverage(Metric):
      """
      Metric calculation is as follows:
@@ -54,7 +58,6 @@ class Coverage(Metric):
      <BLANKLINE>
      """
 
-     # pylint: disable=too-many-arguments
      def __init__(
          self,
          topk: Union[List, int],
@@ -79,7 +82,6 @@ class Coverage(Metric):
          )
          self._allow_caching = allow_caching
 
-     # pylint: disable=arguments-differ
      def _get_enriched_recommendations(
          self,
          recommendations: Union[PolarsDataFrame, SparkDataFrame],
@@ -89,16 +91,9 @@ class Coverage(Metric):
          else:
              return self._get_enriched_recommendations_polars(recommendations)
 
-     # pylint: disable=arguments-differ
-     def _get_enriched_recommendations_spark(
-         self, recommendations: SparkDataFrame
-     ) -> SparkDataFrame:
-         window = Window.partitionBy(self.query_column).orderBy(
-             sf.col(self.rating_column).desc()
-         )
-         sorted_by_score_recommendations = recommendations.withColumn(
-             "rank", sf.row_number().over(window)
-         )
+     def _get_enriched_recommendations_spark(self, recommendations: SparkDataFrame) -> SparkDataFrame:
+         window = Window.partitionBy(self.query_column).orderBy(sf.col(self.rating_column).desc())
+         sorted_by_score_recommendations = recommendations.withColumn("rank", sf.row_number().over(window))
          grouped_recs = (
              sorted_by_score_recommendations.select(self.item_column, "rank")
              .groupBy(self.item_column)
@@ -106,10 +101,7 @@ class Coverage(Metric):
          )
          return grouped_recs
 
-     # pylint: disable=arguments-differ
-     def _get_enriched_recommendations_polars(
-         self, recommendations: PolarsDataFrame
-     ) -> PolarsDataFrame:
+     def _get_enriched_recommendations_polars(self, recommendations: PolarsDataFrame) -> PolarsDataFrame:
          sorted_by_score_recommendations = recommendations.select(
              pl.all().sort_by(self.rating_column, descending=True).over(self.query_column)
          )
@@ -119,17 +111,13 @@ class Coverage(Metric):
              )
          )
          grouped_recs = (
-             sorted_by_score_recommendations
-             .select(self.item_column, "rank")
+             sorted_by_score_recommendations.select(self.item_column, "rank")
              .group_by(self.item_column)
              .agg(pl.col("rank").min().alias("best_position"))
          )
          return grouped_recs
 
-     # pylint: disable=arguments-differ
-     def _spark_compute(
-         self, recs: SparkDataFrame, train: SparkDataFrame
-     ) -> MetricsMeanReturnType:
+     def _spark_compute(self, recs: SparkDataFrame, train: SparkDataFrame) -> MetricsMeanReturnType:
          """
          Calculating metrics for PySpark DataFrame.
          """
@@ -144,10 +132,9 @@ class Coverage(Metric):
                  recs.filter(sf.col("best_position") <= k)
                  .select(self.item_column)
                  .distinct()
-                 .join(
-                     train.select(self.item_column).distinct(), on=self.item_column
-                 )
-                 .count() / item_count
+                 .join(train.select(self.item_column).distinct(), on=self.item_column)
+                 .count()
+                 / item_count
              )
              metrics.append(res)
 
@@ -156,10 +143,7 @@ class Coverage(Metric):
 
          return self._aggregate_results(metrics)
 
-     # pylint: disable=arguments-differ
-     def _polars_compute(
-         self, recs: PolarsDataFrame, train: PolarsDataFrame
-     ) -> MetricsMeanReturnType:
+     def _polars_compute(self, recs: PolarsDataFrame, train: PolarsDataFrame) -> MetricsMeanReturnType:
          """
          Calculating metrics for Polars DataFrame.
          """
@@ -172,44 +156,38 @@ class Coverage(Metric):
                  .select(self.item_column)
                  .unique()
                  .join(train.select(self.item_column).unique(), on=self.item_column)
-                 .count() / item_count
+                 .count()
+                 / item_count
              ).rows()[0][0]
              metrics.append(res)
 
          return self._aggregate_results(metrics)
 
-     # pylint: disable=arguments-renamed
-     def _spark_call(
-         self, recommendations: SparkDataFrame, train: SparkDataFrame
-     ) -> MetricsReturnType:
+     def _spark_call(self, recommendations: SparkDataFrame, train: SparkDataFrame) -> MetricsReturnType:
          """
          Implementation for Pyspark DataFrame.
          """
          recs = self._get_enriched_recommendations(recommendations)
          return self._spark_compute(recs, train)
 
-     # pylint: disable=arguments-renamed
-     def _polars_call(
-         self, recommendations: PolarsDataFrame, train: PolarsDataFrame
-     ) -> MetricsReturnType:
+     def _polars_call(self, recommendations: PolarsDataFrame, train: PolarsDataFrame) -> MetricsReturnType:
          """
          Implementation for Polars DataFrame.
          """
          recs = self._get_enriched_recommendations(recommendations)
          return self._polars_compute(recs, train)
 
-     # pylint: disable=arguments-differ
      def _dict_call(self, recommendations: Dict, train: Dict) -> MetricsReturnType:
          """
          Calculating metrics in dict format.
          """
-         train_items = set(sum(train.values(), []))
+         train_items = set(functools.reduce(operator.iconcat, train.values(), []))
 
          len_train_items = len(train_items)
          metrics = []
          for k in self.topk:
              pred_items = set()
-             for _, items in recommendations.items():
+             for items in recommendations.values():
                  for item in items[:k]:
                      pred_items.add(item)
              metrics.append(len(pred_items & train_items) / len_train_items)
@@ -250,9 +228,7 @@ class Coverage(Metric):
              else self._convert_dict_to_dict_with_score(recommendations)
          )
          self._check_duplicates_dict(recommendations)
-         train = (
-             self._convert_pandas_to_dict_without_score(train) if is_pandas else train
-         )
+         train = self._convert_pandas_to_dict_without_score(train) if is_pandas else train
          assert isinstance(train, dict)
          return self._dict_call(recommendations, train)
 
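Note on the _dict_call change above: operator.iconcat extends the accumulator list in place, so flattening all training lists is roughly linear, while sum(train.values(), []) rebuilds the list on every step. A minimal standalone sketch with toy data (not the package's API):

    import functools
    import operator

    train = {1: [10, 11], 2: [11, 12], 3: [13]}

    flat_slow = sum(train.values(), [])  # copies the growing list on each step
    flat_fast = functools.reduce(operator.iconcat, train.values(), [])  # extends in place

    assert flat_slow == flat_fast == [10, 11, 11, 12, 13]
    train_items = set(flat_fast)  # {10, 11, 12, 13}, as used for coverage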
replay/metrics/descriptors.py CHANGED
@@ -4,7 +4,7 @@ from typing import Union
  import numpy as np
  from scipy.stats import norm, sem
 
- from replay.utils import PYSPARK_AVAILABLE, SparkDataFrame, PolarsDataFrame
+ from replay.utils import PYSPARK_AVAILABLE, PolarsDataFrame, SparkDataFrame
 
  if PYSPARK_AVAILABLE:
      from pyspark.sql import functions as sf
@@ -66,9 +66,7 @@ class Median(CalculationDescriptor):
 
      def spark(self, distribution: SparkDataFrame):
          column_name = distribution.columns[0]
-         return distribution.select(
-             sf.expr(f"percentile_approx({column_name}, 0.5)")
-         ).first()[0]
+         return distribution.select(sf.expr(f"percentile_approx({column_name}, 0.5)")).first()[0]
 
      def cpu(self, distribution: Union[np.array, PolarsDataFrame]):
          if isinstance(distribution, PolarsDataFrame):
@@ -119,12 +117,5 @@ class ConfidenceInterval(CalculationDescriptor):
          column_name = distribution.columns[0]
          quantile = norm.ppf((1 + self.alpha) / 2)
          count = distribution.select(column_name).count().rows()[0][0]
-         std = (
-             distribution
-             .select(column_name)
-             .std()
-             .fill_null(0.0)
-             .fill_nan(0.0)
-             .rows()[0][0]
-         )
-         return quantile * std / (count ** 0.5)
+         std = distribution.select(column_name).std().fill_null(0.0).fill_nan(0.0).rows()[0][0]
+         return quantile * std / (count**0.5)
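The rewritten ConfidenceInterval lines compute the half-width of a normal-approximation confidence interval, z * std / sqrt(n). A minimal NumPy/SciPy sketch of the same formula, using a hypothetical helper name (ci_half_width) and assuming the sample standard deviation (ddof=1, the Polars default):

    import numpy as np
    from scipy.stats import norm

    def ci_half_width(values, alpha=0.95):
        quantile = norm.ppf((1 + alpha) / 2)  # z-score for the two-sided interval
        std = np.nan_to_num(np.std(values, ddof=1))
        return quantile * std / len(values) ** 0.5

    print(ci_half_width([0.2, 0.4, 0.6, 0.8]))  # ~0.253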
replay/metrics/experiment.py CHANGED
@@ -6,8 +6,6 @@ from .base_metric import Metric, MetricsDataFrameLike
  from .offline_metrics import OfflineMetrics
 
 
- # pylint: disable=too-many-instance-attributes
- # pylint: disable=too-few-public-methods
  class Experiment:
      """
      The class is designed for calculating, storing and comparing metrics
@@ -102,15 +100,12 @@ class Experiment:
      <BLANKLINE>
      """
 
-     # pylint: disable=too-many-arguments
      def __init__(
          self,
          metrics: List[Metric],
          ground_truth: MetricsDataFrameLike,
          train: Optional[MetricsDataFrameLike] = None,
-         base_recommendations: Optional[
-             Union[MetricsDataFrameLike, Dict[str, MetricsDataFrameLike]]
-         ] = None,
+         base_recommendations: Optional[Union[MetricsDataFrameLike, Dict[str, MetricsDataFrameLike]]] = None,
          query_column: str = "query_id",
          item_column: str = "item_id",
          rating_column: str = "rating",
@@ -182,7 +177,6 @@ class Experiment:
          for metric, value in cur_metrics.items():
              self.results.at[name, metric] = value
 
-     # pylint: disable=not-an-iterable
      def compare(self, name: str) -> pd.DataFrame:
          """
          Show results as a percentage difference to record ``name``.
@@ -191,7 +185,8 @@ class Experiment:
          :return: results table in a percentage format
          """
          if name not in self.results.index:
-             raise ValueError(f"No results for model {name}")
+             msg = f"No results for model {name}"
+             raise ValueError(msg)
          columns = [column for column in self.results.columns if column[-1].isdigit()]
          data_frame = self.results[columns].copy()
          baseline = data_frame.loc[name]
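The compare method, whose first lines are shown above, reports each run as a percentage difference to a chosen baseline row. A rough pandas sketch of that idea on toy numbers; the exact table produced by the library may be formatted differently:

    import pandas as pd

    results = pd.DataFrame(
        {"NDCG@10": [0.30, 0.33], "HitRate@10": [0.50, 0.45]},
        index=["baseline", "candidate"],
    )
    baseline = results.loc["baseline"]
    print((results / baseline - 1) * 100)  # candidate: +10% NDCG@10, -10% HitRate@10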
replay/metrics/hitrate.py CHANGED
@@ -3,17 +3,16 @@ from typing import List
  from .base_metric import Metric
 
 
- # pylint: disable=too-few-public-methods
  class HitRate(Metric):
      """
      Percentage of users that have at least one correctly recommended item\
      among top-k.
 
      .. math::
-         HitRate@K(i) = \max_{j \in [1..K]}\mathbb{1}_{r_{ij}}
+         HitRate@K(i) = \\max_{j \\in [1..K]}\\mathbb{1}_{r_{ij}}
 
      .. math::
-         HitRate@K = \\frac {\sum_{i=1}^{N}HitRate@K(i)}{N}
+         HitRate@K = \\frac {\\sum_{i=1}^{N}HitRate@K(i)}{N}
 
      :math:`\\mathbb{1}_{r_{ij}}` -- indicator function stating that user :math:`i` interacted with item :math:`j`
 
@@ -63,9 +62,7 @@ class HitRate(Metric):
      """
 
      @staticmethod
-     def _get_metric_value_by_user(  # pylint: disable=arguments-differ
-         ks: List[int], ground_truth: List, pred: List
-     ) -> List[float]:
+     def _get_metric_value_by_user(ks: List[int], ground_truth: List, pred: List) -> List[float]:
          if not ground_truth or not pred:
              return [0.0 for _ in ks]
          set_gt = set(ground_truth)
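For reference, the HitRate@K formula in the docstring reduces to a per-user check of whether any of the first k recommendations is relevant. A standalone sketch mirroring the _get_metric_value_by_user signature (hypothetical function name, not the library's implementation):

    from typing import List

    def hit_rate_by_user(ks: List[int], ground_truth: List, pred: List) -> List[float]:
        if not ground_truth or not pred:
            return [0.0 for _ in ks]
        set_gt = set(ground_truth)
        return [1.0 if any(item in set_gt for item in pred[:k]) else 0.0 for k in ks]

    print(hit_rate_by_user([1, 3], ground_truth=[5, 7], pred=[2, 7, 9]))  # [0.0, 1.0]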
replay/metrics/map.py CHANGED
@@ -3,16 +3,15 @@ from typing import List
  from .base_metric import Metric
 
 
- # pylint: disable=too-few-public-methods
  class MAP(Metric):
      """
      Mean Average Precision -- average the ``Precision`` at relevant positions \
      for each user, and then calculate the mean across all users.
 
      .. math::
-         &AP@K(i) = \\frac {1}{\min(K, |Rel_i|)} \sum_{j=1}^{K}\mathbb{1}_{r_{ij}}Precision@j(i)
+         &AP@K(i) = \\frac {1}{\\min(K, |Rel_i|)} \\sum_{j=1}^{K}\\mathbb{1}_{r_{ij}}Precision@j(i)
 
-         &MAP@K = \\frac {\sum_{i=1}^{N}AP@K(i)}{N}
+         &MAP@K = \\frac {\\sum_{i=1}^{N}AP@K(i)}{N}
 
      :math:`\\mathbb{1}_{r_{ij}}` -- indicator function showing if user :math:`i` interacted with item :math:`j`
 
@@ -64,9 +63,7 @@ class MAP(Metric):
      """
 
      @staticmethod
-     def _get_metric_value_by_user(  # pylint: disable=arguments-differ
-         ks: List[int], ground_truth: List, pred: List
-     ) -> List[float]:
+     def _get_metric_value_by_user(ks: List[int], ground_truth: List, pred: List) -> List[float]:
          if not ground_truth or not pred:
              return [0.0 for _ in ks]
          res = []
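The AP@K formula above sums Precision@j over the relevant positions j <= K and normalizes by min(K, |Rel|). A standalone sketch of that computation (hypothetical average_precision_by_user; the body of the library's method is not shown in this hunk):

    from typing import List

    def average_precision_by_user(ks: List[int], ground_truth: List, pred: List) -> List[float]:
        if not ground_truth or not pred:
            return [0.0 for _ in ks]
        set_gt = set(ground_truth)
        res = []
        for k in ks:
            hits, precision_sum = 0, 0.0
            for j, item in enumerate(pred[:k], start=1):
                if item in set_gt:
                    hits += 1
                    precision_sum += hits / j  # Precision@j at a relevant position
            res.append(precision_sum / min(k, len(set_gt)))
        return res

    print(average_precision_by_user([2], ground_truth=[1, 3], pred=[1, 2]))  # [0.5]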
replay/metrics/mrr.py CHANGED
@@ -3,7 +3,6 @@ from typing import List
  from .base_metric import Metric
 
 
- # pylint: disable=too-few-public-methods
  class MRR(Metric):
      """
      Mean Reciprocal Rank -- Reciprocal Rank is the inverse position of the
@@ -56,9 +55,7 @@ class MRR(Metric):
      """
 
      @staticmethod
-     def _get_metric_value_by_user(  # pylint: disable=arguments-differ
-         ks: List[int], ground_truth: List, pred: List
-     ) -> List[float]:
+     def _get_metric_value_by_user(ks: List[int], ground_truth: List, pred: List) -> List[float]:
          if not ground_truth or not pred:
              return [0.0 for _ in ks]
          set_gt = set(ground_truth)
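Reciprocal Rank is the inverse position of the first relevant item within the top-k list. A standalone sketch with a hypothetical name (reciprocal_rank_by_user), not the library's code:

    from typing import List

    def reciprocal_rank_by_user(ks: List[int], ground_truth: List, pred: List) -> List[float]:
        if not ground_truth or not pred:
            return [0.0 for _ in ks]
        set_gt = set(ground_truth)
        res = []
        for k in ks:
            rank = 0.0
            for position, item in enumerate(pred[:k], start=1):
                if item in set_gt:
                    rank = 1.0 / position  # first relevant position wins
                    break
            res.append(rank)
        return res

    print(reciprocal_rank_by_user([1, 3], ground_truth=[7], pred=[2, 7, 9]))  # [0.0, 0.5]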
replay/metrics/ndcg.py CHANGED
@@ -4,7 +4,6 @@ from typing import List
  from .base_metric import Metric
 
 
- # pylint: disable=too-few-public-methods
  class NDCG(Metric):
      """
      Normalized Discounted Cumulative Gain is a metric
@@ -14,7 +13,7 @@ class NDCG(Metric):
      whether the item was consumed or not, relevance value is ignored.
 
      .. math::
-         DCG@K(i) = \sum_{j=1}^{K}\\frac{\mathbb{1}_{r_{ij}}}{\log_2 (j+1)}
+         DCG@K(i) = \\sum_{j=1}^{K}\\frac{\\mathbb{1}_{r_{ij}}}{\\log_2 (j+1)}
 
 
      :math:`\\mathbb{1}_{r_{ij}}` -- indicator function showing that user :math:`i` interacted with item :math:`j`
@@ -23,7 +22,7 @@ class NDCG(Metric):
      for user :math:`i` and recommendation length :math:`K`.
 
      .. math::
-         IDCG@K(i) = max(DCG@K(i)) = \sum_{j=1}^{K}\\frac{\mathbb{1}_{j\le|Rel_i|}}{\log_2 (j+1)}
+         IDCG@K(i) = max(DCG@K(i)) = \\sum_{j=1}^{K}\\frac{\\mathbb{1}_{j\\le|Rel_i|}}{\\log_2 (j+1)}
 
      .. math::
          nDCG@K(i) = \\frac {DCG@K(i)}{IDCG@K(i)}
@@ -33,7 +32,7 @@ class NDCG(Metric):
      Metric is averaged by users.
 
      .. math::
-         nDCG@K = \\frac {\sum_{i=1}^{N}nDCG@K(i)}{N}
+         nDCG@K = \\frac {\\sum_{i=1}^{N}nDCG@K(i)}{N}
 
      >>> recommendations
         query_id  item_id  rating
@@ -81,9 +80,7 @@ class NDCG(Metric):
      """
 
      @staticmethod
-     def _get_metric_value_by_user(  # pylint: disable=arguments-differ
-         ks: List[int], ground_truth: List, pred: List
-     ) -> List[float]:
+     def _get_metric_value_by_user(ks: List[int], ground_truth: List, pred: List) -> List[float]:
          if not pred or not ground_truth:
              return [0.0 for _ in ks]
          set_gt = set(ground_truth)
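With binary relevance, the docstring's DCG@K and IDCG@K reduce to sums of 1 / log2(j + 1) over the hit positions and over the first min(K, |Rel|) positions respectively. A standalone sketch (hypothetical ndcg_by_user, not the library's implementation):

    import math
    from typing import List

    def ndcg_by_user(ks: List[int], ground_truth: List, pred: List) -> List[float]:
        if not pred or not ground_truth:
            return [0.0 for _ in ks]
        set_gt = set(ground_truth)
        res = []
        for k in ks:
            dcg = sum(1.0 / math.log2(j + 1) for j, item in enumerate(pred[:k], start=1) if item in set_gt)
            idcg = sum(1.0 / math.log2(j + 1) for j in range(1, min(k, len(set_gt)) + 1))
            res.append(dcg / idcg)
        return res

    print(ndcg_by_user([2], ground_truth=[1, 3], pred=[2, 1]))  # [0.386...]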
replay/metrics/novelty.py CHANGED
@@ -1,6 +1,6 @@
  from typing import TYPE_CHECKING, List, Type
 
- from replay.utils import PandasDataFrame, SparkDataFrame, PolarsDataFrame
+ from replay.utils import PandasDataFrame, PolarsDataFrame, SparkDataFrame
 
  from .base_metric import Metric, MetricsDataFrameLike, MetricsReturnType
 
@@ -8,7 +8,6 @@ if TYPE_CHECKING: # pragma: no cover
      __class__: Type
 
 
- # pylint: disable=too-few-public-methods
  class Novelty(Metric):
      """
      Measure the fraction of shown items in recommendation list, that users\
@@ -16,11 +15,11 @@ class Novelty(Metric):
 
      .. math::
          Novelty@K(i) = \\frac
-         {\parallel {R^{i}_{1..\min(K, \parallel R^{i} \parallel)} \setminus train^{i}} \parallel}
+         {\\parallel {R^{i}_{1..\\min(K, \\parallel R^{i} \\parallel)} \\setminus train^{i}} \\parallel}
          {K}
 
      .. math::
-         Novelty@K = \\frac {1}{N}\sum_{i=1}^{N}Novelty@K(i)
+         Novelty@K = \\frac {1}{N}\\sum_{i=1}^{N}Novelty@K(i)
 
      :math:`R^{i}` -- the recommendations for the :math:`i`-th user.
 
@@ -114,9 +113,7 @@ class Novelty(Metric):
              else self._convert_dict_to_dict_with_score(recommendations)
          )
          self._check_duplicates_dict(recommendations)
-         train = (
-             self._convert_pandas_to_dict_without_score(train) if is_pandas else train
-         )
+         train = self._convert_pandas_to_dict_without_score(train) if is_pandas else train
          assert isinstance(train, dict)
 
          return self._dict_call(
@@ -125,41 +122,25 @@ class Novelty(Metric):
              train=train,
          )
 
-     # pylint: disable=arguments-renamed
-     def _spark_call(
-         self, recommendations: SparkDataFrame, train: SparkDataFrame
-     ) -> MetricsReturnType:
+     def _spark_call(self, recommendations: SparkDataFrame, train: SparkDataFrame) -> MetricsReturnType:
          """
          Implementation for Pyspark DataFrame.
          """
-         recs = self._get_enriched_recommendations(
-             recommendations, train
-         ).withColumnRenamed("ground_truth", "train")
+         recs = self._get_enriched_recommendations(recommendations, train).withColumnRenamed("ground_truth", "train")
          recs = self._rearrange_columns(recs)
          return self._spark_compute(recs)
 
-     # pylint: disable=arguments-renamed
-     def _polars_call(
-         self, recommendations: PolarsDataFrame, train: PolarsDataFrame
-     ) -> MetricsReturnType:
+     def _polars_call(self, recommendations: PolarsDataFrame, train: PolarsDataFrame) -> MetricsReturnType:
          """
          Implementation for Polars DataFrame.
          """
-         recs = self._get_enriched_recommendations(
-             recommendations, train
-         ).rename({"ground_truth": "train"})
+         recs = self._get_enriched_recommendations(recommendations, train).rename({"ground_truth": "train"})
          recs = self._rearrange_columns(recs)
          return self._polars_compute(recs)
 
-     # pylint: disable=arguments-differ
      @staticmethod
-     def _get_metric_value_by_user(
-         ks: List[int], pred: List, train: List
-     ) -> List[float]:
+     def _get_metric_value_by_user(ks: List[int], pred: List, train: List) -> List[float]:
          if not train or not pred:
              return [1.0 for _ in ks]
          set_train = set(train)
-         res = []
-         for k in ks:
-             res.append(1.0 - len(set(pred[:k]) & set_train) / len(pred[:k]))
-         return res
+         return [1.0 - len(set(pred[:k]) & set_train) / len(pred[:k]) for k in ks]
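The new one-line comprehension in _get_metric_value_by_user can be checked directly on toy data (standalone snippet, not invoking the library):

    ks = [1, 2, 3]
    pred = [10, 20, 30]   # recommended items, best first
    set_train = {20}      # items the user already interacted with in train

    novelty = [1.0 - len(set(pred[:k]) & set_train) / len(pred[:k]) for k in ks]
    print(novelty)  # [1.0, 0.5, 0.666...]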