replay-rec 0.16.0rc0__py3-none-any.whl → 0.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- replay/__init__.py +1 -1
- replay/data/__init__.py +1 -1
- replay/data/dataset.py +45 -42
- replay/data/dataset_utils/dataset_label_encoder.py +6 -7
- replay/data/nn/__init__.py +1 -1
- replay/data/nn/schema.py +20 -33
- replay/data/nn/sequence_tokenizer.py +217 -87
- replay/data/nn/sequential_dataset.py +6 -22
- replay/data/nn/torch_sequential_dataset.py +20 -11
- replay/data/nn/utils.py +7 -9
- replay/data/schema.py +17 -17
- replay/data/spark_schema.py +0 -1
- replay/metrics/base_metric.py +38 -79
- replay/metrics/categorical_diversity.py +24 -58
- replay/metrics/coverage.py +25 -49
- replay/metrics/descriptors.py +4 -13
- replay/metrics/experiment.py +3 -8
- replay/metrics/hitrate.py +3 -6
- replay/metrics/map.py +3 -6
- replay/metrics/mrr.py +1 -4
- replay/metrics/ndcg.py +4 -7
- replay/metrics/novelty.py +10 -29
- replay/metrics/offline_metrics.py +26 -61
- replay/metrics/precision.py +3 -6
- replay/metrics/recall.py +3 -6
- replay/metrics/rocauc.py +7 -10
- replay/metrics/surprisal.py +13 -30
- replay/metrics/torch_metrics_builder.py +0 -4
- replay/metrics/unexpectedness.py +15 -20
- replay/models/__init__.py +1 -2
- replay/models/als.py +7 -15
- replay/models/association_rules.py +12 -28
- replay/models/base_neighbour_rec.py +21 -36
- replay/models/base_rec.py +92 -215
- replay/models/cat_pop_rec.py +9 -22
- replay/models/cluster.py +17 -28
- replay/models/extensions/ann/ann_mixin.py +7 -12
- replay/models/extensions/ann/entities/base_hnsw_param.py +1 -1
- replay/models/extensions/ann/entities/hnswlib_param.py +0 -6
- replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -6
- replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +4 -10
- replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +7 -11
- replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +5 -12
- replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +11 -18
- replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +1 -4
- replay/models/extensions/ann/index_inferers/base_inferer.py +3 -10
- replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +7 -17
- replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +6 -14
- replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +14 -28
- replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +15 -25
- replay/models/extensions/ann/index_inferers/utils.py +2 -9
- replay/models/extensions/ann/index_stores/hdfs_index_store.py +4 -9
- replay/models/extensions/ann/index_stores/shared_disk_index_store.py +2 -6
- replay/models/extensions/ann/index_stores/spark_files_index_store.py +8 -14
- replay/models/extensions/ann/index_stores/utils.py +5 -2
- replay/models/extensions/ann/utils.py +3 -5
- replay/models/kl_ucb.py +16 -22
- replay/models/knn.py +37 -59
- replay/models/nn/optimizer_utils/__init__.py +1 -6
- replay/models/nn/optimizer_utils/optimizer_factory.py +3 -6
- replay/models/nn/sequential/bert4rec/__init__.py +1 -1
- replay/models/nn/sequential/bert4rec/dataset.py +6 -7
- replay/models/nn/sequential/bert4rec/lightning.py +53 -56
- replay/models/nn/sequential/bert4rec/model.py +12 -25
- replay/models/nn/sequential/callbacks/__init__.py +1 -1
- replay/models/nn/sequential/callbacks/prediction_callbacks.py +23 -25
- replay/models/nn/sequential/callbacks/validation_callback.py +27 -30
- replay/models/nn/sequential/postprocessors/postprocessors.py +1 -1
- replay/models/nn/sequential/sasrec/dataset.py +8 -7
- replay/models/nn/sequential/sasrec/lightning.py +53 -48
- replay/models/nn/sequential/sasrec/model.py +4 -17
- replay/models/pop_rec.py +9 -10
- replay/models/query_pop_rec.py +7 -15
- replay/models/random_rec.py +10 -18
- replay/models/slim.py +8 -13
- replay/models/thompson_sampling.py +13 -14
- replay/models/ucb.py +11 -22
- replay/models/wilson.py +5 -14
- replay/models/word2vec.py +24 -69
- replay/optimization/optuna_objective.py +13 -27
- replay/preprocessing/__init__.py +1 -2
- replay/preprocessing/converter.py +2 -7
- replay/preprocessing/filters.py +67 -142
- replay/preprocessing/history_based_fp.py +44 -116
- replay/preprocessing/label_encoder.py +106 -68
- replay/preprocessing/sessionizer.py +1 -11
- replay/scenarios/fallback.py +3 -8
- replay/splitters/base_splitter.py +43 -15
- replay/splitters/cold_user_random_splitter.py +18 -31
- replay/splitters/k_folds.py +14 -24
- replay/splitters/last_n_splitter.py +33 -43
- replay/splitters/new_users_splitter.py +31 -55
- replay/splitters/random_splitter.py +16 -23
- replay/splitters/ratio_splitter.py +30 -54
- replay/splitters/time_splitter.py +13 -18
- replay/splitters/two_stage_splitter.py +44 -79
- replay/utils/__init__.py +1 -1
- replay/utils/common.py +65 -0
- replay/utils/dataframe_bucketizer.py +25 -31
- replay/utils/distributions.py +3 -15
- replay/utils/model_handler.py +36 -33
- replay/utils/session_handler.py +11 -15
- replay/utils/spark_utils.py +51 -85
- replay/utils/time.py +8 -22
- replay/utils/types.py +1 -3
- {replay_rec-0.16.0rc0.dist-info → replay_rec-0.17.0.dist-info}/METADATA +2 -10
- replay_rec-0.17.0.dist-info/RECORD +127 -0
- {replay_rec-0.16.0rc0.dist-info → replay_rec-0.17.0.dist-info}/WHEEL +1 -1
- replay/experimental/__init__.py +0 -0
- replay/experimental/metrics/__init__.py +0 -61
- replay/experimental/metrics/base_metric.py +0 -661
- replay/experimental/metrics/coverage.py +0 -117
- replay/experimental/metrics/experiment.py +0 -200
- replay/experimental/metrics/hitrate.py +0 -27
- replay/experimental/metrics/map.py +0 -31
- replay/experimental/metrics/mrr.py +0 -19
- replay/experimental/metrics/ncis_precision.py +0 -32
- replay/experimental/metrics/ndcg.py +0 -50
- replay/experimental/metrics/precision.py +0 -23
- replay/experimental/metrics/recall.py +0 -26
- replay/experimental/metrics/rocauc.py +0 -50
- replay/experimental/metrics/surprisal.py +0 -102
- replay/experimental/metrics/unexpectedness.py +0 -74
- replay/experimental/models/__init__.py +0 -10
- replay/experimental/models/admm_slim.py +0 -216
- replay/experimental/models/base_neighbour_rec.py +0 -222
- replay/experimental/models/base_rec.py +0 -1361
- replay/experimental/models/base_torch_rec.py +0 -247
- replay/experimental/models/cql.py +0 -468
- replay/experimental/models/ddpg.py +0 -1007
- replay/experimental/models/dt4rec/__init__.py +0 -0
- replay/experimental/models/dt4rec/dt4rec.py +0 -193
- replay/experimental/models/dt4rec/gpt1.py +0 -411
- replay/experimental/models/dt4rec/trainer.py +0 -128
- replay/experimental/models/dt4rec/utils.py +0 -274
- replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
- replay/experimental/models/extensions/spark_custom_models/als_extension.py +0 -733
- replay/experimental/models/implicit_wrap.py +0 -138
- replay/experimental/models/lightfm_wrap.py +0 -327
- replay/experimental/models/mult_vae.py +0 -374
- replay/experimental/models/neuromf.py +0 -462
- replay/experimental/models/scala_als.py +0 -311
- replay/experimental/nn/data/__init__.py +0 -1
- replay/experimental/nn/data/schema_builder.py +0 -58
- replay/experimental/preprocessing/__init__.py +0 -3
- replay/experimental/preprocessing/data_preparator.py +0 -929
- replay/experimental/preprocessing/padder.py +0 -231
- replay/experimental/preprocessing/sequence_generator.py +0 -218
- replay/experimental/scenarios/__init__.py +0 -1
- replay/experimental/scenarios/obp_wrapper/__init__.py +0 -8
- replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -86
- replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -271
- replay/experimental/scenarios/obp_wrapper/utils.py +0 -88
- replay/experimental/scenarios/two_stages/reranker.py +0 -116
- replay/experimental/scenarios/two_stages/two_stages_scenario.py +0 -843
- replay/experimental/utils/__init__.py +0 -0
- replay/experimental/utils/logger.py +0 -24
- replay/experimental/utils/model_handler.py +0 -213
- replay/experimental/utils/session_handler.py +0 -47
- replay_rec-0.16.0rc0.dist-info/NOTICE +0 -41
- replay_rec-0.16.0rc0.dist-info/RECORD +0 -178
- {replay_rec-0.16.0rc0.dist-info → replay_rec-0.17.0.dist-info}/LICENSE +0 -0
replay/metrics/coverage.py
CHANGED
@@ -1,16 +1,20 @@
+import functools
+import operator
 from typing import Dict, List, Union
+
 import polars as pl
 
-from replay.utils import PYSPARK_AVAILABLE, PandasDataFrame,
+from replay.utils import PYSPARK_AVAILABLE, PandasDataFrame, PolarsDataFrame, SparkDataFrame
 
 from .base_metric import Metric, MetricsDataFrameLike, MetricsMeanReturnType, MetricsReturnType
 
 if PYSPARK_AVAILABLE:
-    from pyspark.sql import
-
+    from pyspark.sql import (
+        Window,
+        functions as sf,
+    )
 
 
-# pylint: disable=too-few-public-methods
 class Coverage(Metric):
     """
     Metric calculation is as follows:
@@ -54,7 +58,6 @@ class Coverage(Metric):
     <BLANKLINE>
     """
 
-    # pylint: disable=too-many-arguments
     def __init__(
         self,
         topk: Union[List, int],
@@ -79,7 +82,6 @@ class Coverage(Metric):
         )
         self._allow_caching = allow_caching
 
-    # pylint: disable=arguments-differ
     def _get_enriched_recommendations(
         self,
         recommendations: Union[PolarsDataFrame, SparkDataFrame],
@@ -89,16 +91,9 @@ class Coverage(Metric):
         else:
             return self._get_enriched_recommendations_polars(recommendations)
 
-
-
-
-    ) -> SparkDataFrame:
-        window = Window.partitionBy(self.query_column).orderBy(
-            sf.col(self.rating_column).desc()
-        )
-        sorted_by_score_recommendations = recommendations.withColumn(
-            "rank", sf.row_number().over(window)
-        )
+    def _get_enriched_recommendations_spark(self, recommendations: SparkDataFrame) -> SparkDataFrame:
+        window = Window.partitionBy(self.query_column).orderBy(sf.col(self.rating_column).desc())
+        sorted_by_score_recommendations = recommendations.withColumn("rank", sf.row_number().over(window))
         grouped_recs = (
             sorted_by_score_recommendations.select(self.item_column, "rank")
             .groupBy(self.item_column)
@@ -106,10 +101,7 @@ class Coverage(Metric):
         )
         return grouped_recs
 
-
-    def _get_enriched_recommendations_polars(
-        self, recommendations: PolarsDataFrame
-    ) -> PolarsDataFrame:
+    def _get_enriched_recommendations_polars(self, recommendations: PolarsDataFrame) -> PolarsDataFrame:
         sorted_by_score_recommendations = recommendations.select(
             pl.all().sort_by(self.rating_column, descending=True).over(self.query_column)
         )
@@ -119,17 +111,13 @@ class Coverage(Metric):
             )
         )
         grouped_recs = (
-            sorted_by_score_recommendations
-            .select(self.item_column, "rank")
+            sorted_by_score_recommendations.select(self.item_column, "rank")
             .group_by(self.item_column)
             .agg(pl.col("rank").min().alias("best_position"))
         )
         return grouped_recs
 
-
-    def _spark_compute(
-        self, recs: SparkDataFrame, train: SparkDataFrame
-    ) -> MetricsMeanReturnType:
+    def _spark_compute(self, recs: SparkDataFrame, train: SparkDataFrame) -> MetricsMeanReturnType:
         """
         Calculating metrics for PySpark DataFrame.
         """
@@ -144,10 +132,9 @@ class Coverage(Metric):
                 recs.filter(sf.col("best_position") <= k)
                 .select(self.item_column)
                 .distinct()
-                .join(
-
-
-                .count() / item_count
+                .join(train.select(self.item_column).distinct(), on=self.item_column)
+                .count()
+                / item_count
             )
             metrics.append(res)
 
@@ -156,10 +143,7 @@ class Coverage(Metric):
 
        return self._aggregate_results(metrics)
 
-
-    def _polars_compute(
-        self, recs: PolarsDataFrame, train: PolarsDataFrame
-    ) -> MetricsMeanReturnType:
+    def _polars_compute(self, recs: PolarsDataFrame, train: PolarsDataFrame) -> MetricsMeanReturnType:
         """
         Calculating metrics for Polars DataFrame.
         """
@@ -172,44 +156,38 @@ class Coverage(Metric):
                 .select(self.item_column)
                 .unique()
                 .join(train.select(self.item_column).unique(), on=self.item_column)
-                .count()
+                .count()
+                / item_count
            ).rows()[0][0]
             metrics.append(res)
 
         return self._aggregate_results(metrics)
 
-
-    def _spark_call(
-        self, recommendations: SparkDataFrame, train: SparkDataFrame
-    ) -> MetricsReturnType:
+    def _spark_call(self, recommendations: SparkDataFrame, train: SparkDataFrame) -> MetricsReturnType:
         """
         Implementation for Pyspark DataFrame.
         """
         recs = self._get_enriched_recommendations(recommendations)
         return self._spark_compute(recs, train)
 
-
-    def _polars_call(
-        self, recommendations: PolarsDataFrame, train: PolarsDataFrame
-    ) -> MetricsReturnType:
+    def _polars_call(self, recommendations: PolarsDataFrame, train: PolarsDataFrame) -> MetricsReturnType:
         """
         Implementation for Polars DataFrame.
         """
         recs = self._get_enriched_recommendations(recommendations)
         return self._polars_compute(recs, train)
 
-    # pylint: disable=arguments-differ
     def _dict_call(self, recommendations: Dict, train: Dict) -> MetricsReturnType:
         """
         Calculating metrics in dict format.
         """
-        train_items = set(
+        train_items = set(functools.reduce(operator.iconcat, train.values(), []))
 
         len_train_items = len(train_items)
         metrics = []
         for k in self.topk:
             pred_items = set()
-            for
+            for items in recommendations.values():
                 for item in items[:k]:
                     pred_items.add(item)
             metrics.append(len(pred_items & train_items) / len_train_items)
@@ -250,9 +228,7 @@ class Coverage(Metric):
             else self._convert_dict_to_dict_with_score(recommendations)
         )
         self._check_duplicates_dict(recommendations)
-        train = (
-            self._convert_pandas_to_dict_without_score(train) if is_pandas else train
-        )
+        train = self._convert_pandas_to_dict_without_score(train) if is_pandas else train
         assert isinstance(train, dict)
         return self._dict_call(recommendations, train)
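
For context, a minimal sketch (toy data, not from the package) of the dict-based Coverage@K computation that the reworked _dict_call performs: flatten all training interactions into one item set, then measure which share of those items appears among the top-k recommendations.

import functools
import operator

# hypothetical toy data: train interactions and ranked recommendations per query
train = {1: [10, 11, 12], 2: [11, 13]}
recommendations = {1: [11, 14, 10], 2: [13, 12, 15]}
k = 2

# flatten train values into a single item set, as the new _dict_call does
train_items = set(functools.reduce(operator.iconcat, train.values(), []))
# collect items that appear in any user's top-k recommendations
pred_items = {item for items in recommendations.values() for item in items[:k]}

coverage_at_k = len(pred_items & train_items) / len(train_items)
print(coverage_at_k)  # 0.75 -> three of the four train items show up in some top-2 list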
replay/metrics/descriptors.py
CHANGED
@@ -4,7 +4,7 @@ from typing import Union
 import numpy as np
 from scipy.stats import norm, sem
 
-from replay.utils import PYSPARK_AVAILABLE,
+from replay.utils import PYSPARK_AVAILABLE, PolarsDataFrame, SparkDataFrame
 
 if PYSPARK_AVAILABLE:
     from pyspark.sql import functions as sf
@@ -66,9 +66,7 @@ class Median(CalculationDescriptor):
 
     def spark(self, distribution: SparkDataFrame):
         column_name = distribution.columns[0]
-        return distribution.select(
-            sf.expr(f"percentile_approx({column_name}, 0.5)")
-        ).first()[0]
+        return distribution.select(sf.expr(f"percentile_approx({column_name}, 0.5)")).first()[0]
 
     def cpu(self, distribution: Union[np.array, PolarsDataFrame]):
         if isinstance(distribution, PolarsDataFrame):
@@ -119,12 +117,5 @@ class ConfidenceInterval(CalculationDescriptor):
         column_name = distribution.columns[0]
         quantile = norm.ppf((1 + self.alpha) / 2)
         count = distribution.select(column_name).count().rows()[0][0]
-        std = (
-
-            .select(column_name)
-            .std()
-            .fill_null(0.0)
-            .fill_nan(0.0)
-            .rows()[0][0]
-        )
-        return quantile * std / (count ** 0.5)
+        std = distribution.select(column_name).std().fill_null(0.0).fill_nan(0.0).rows()[0][0]
+        return quantile * std / (count**0.5)
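
For context, a minimal sketch (toy data, not from the package) of the confidence-interval half-width computed in the branch shown above: quantile * std / sqrt(count), with quantile = norm.ppf((1 + alpha) / 2).

import numpy as np
from scipy.stats import norm

alpha = 0.95
distribution = np.array([0.2, 0.4, 0.6, 0.8])  # hypothetical per-user metric values

quantile = norm.ppf((1 + alpha) / 2)  # ~1.96 for alpha = 0.95
std = distribution.std(ddof=1)        # sample std, matching the Polars std() default ddof=1
half_width = quantile * std / (len(distribution) ** 0.5)
print(round(half_width, 3))           # ~0.253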
replay/metrics/experiment.py
CHANGED
@@ -6,8 +6,6 @@ from .base_metric import Metric, MetricsDataFrameLike
 from .offline_metrics import OfflineMetrics
 
 
-# pylint: disable=too-many-instance-attributes
-# pylint: disable=too-few-public-methods
 class Experiment:
     """
     The class is designed for calculating, storing and comparing metrics
@@ -102,15 +100,12 @@
     <BLANKLINE>
     """
 
-    # pylint: disable=too-many-arguments
     def __init__(
         self,
         metrics: List[Metric],
         ground_truth: MetricsDataFrameLike,
         train: Optional[MetricsDataFrameLike] = None,
-        base_recommendations: Optional[
-            Union[MetricsDataFrameLike, Dict[str, MetricsDataFrameLike]]
-        ] = None,
+        base_recommendations: Optional[Union[MetricsDataFrameLike, Dict[str, MetricsDataFrameLike]]] = None,
         query_column: str = "query_id",
         item_column: str = "item_id",
         rating_column: str = "rating",
@@ -182,7 +177,6 @@
         for metric, value in cur_metrics.items():
             self.results.at[name, metric] = value
 
-    # pylint: disable=not-an-iterable
     def compare(self, name: str) -> pd.DataFrame:
         """
         Show results as a percentage difference to record ``name``.
@@ -191,7 +185,8 @@
         :return: results table in a percentage format
         """
         if name not in self.results.index:
-
+            msg = f"No results for model {name}"
+            raise ValueError(msg)
         columns = [column for column in self.results.columns if column[-1].isdigit()]
         data_frame = self.results[columns].copy()
         baseline = data_frame.loc[name]
replay/metrics/hitrate.py
CHANGED
@@ -3,17 +3,16 @@ from typing import List
 from .base_metric import Metric
 
 
-# pylint: disable=too-few-public-methods
 class HitRate(Metric):
     """
     Percentage of users that have at least one correctly recommended item\
     among top-k.
 
     .. math::
-        HitRate@K(i) =
+        HitRate@K(i) = \\max_{j \\in [1..K]}\\mathbb{1}_{r_{ij}}
 
     .. math::
-        HitRate@K = \\frac {
+        HitRate@K = \\frac {\\sum_{i=1}^{N}HitRate@K(i)}{N}
 
     :math:`\\mathbb{1}_{r_{ij}}` -- indicator function stating that user :math:`i` interacted with item :math:`j`
 
@@ -63,9 +62,7 @@
     """
 
     @staticmethod
-    def _get_metric_value_by_user(
-        ks: List[int], ground_truth: List, pred: List
-    ) -> List[float]:
+    def _get_metric_value_by_user(ks: List[int], ground_truth: List, pred: List) -> List[float]:
         if not ground_truth or not pred:
             return [0.0 for _ in ks]
         set_gt = set(ground_truth)
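
For context, a minimal sketch (toy data, not from the package) of the per-user HitRate@K value from the restored docstring formula: it is 1.0 as soon as any of the top-k recommended items was interacted with.

ground_truth = [10, 12]  # hypothetical items the user interacted with
pred = [11, 14, 10]      # hypothetical ranked recommendations
ks = [1, 2, 3]

set_gt = set(ground_truth)
hitrate = [float(any(item in set_gt for item in pred[:k])) for k in ks]
print(hitrate)  # [0.0, 0.0, 1.0]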
replay/metrics/map.py
CHANGED
@@ -3,16 +3,15 @@ from typing import List
 from .base_metric import Metric
 
 
-# pylint: disable=too-few-public-methods
 class MAP(Metric):
     """
     Mean Average Precision -- average the ``Precision`` at relevant positions \
     for each user, and then calculate the mean across all users.
 
     .. math::
-        &AP@K(i) = \\frac {1}{
+        &AP@K(i) = \\frac {1}{\\min(K, |Rel_i|)} \\sum_{j=1}^{K}\\mathbb{1}_{r_{ij}}Precision@j(i)
 
-        &MAP@K = \\frac {
+        &MAP@K = \\frac {\\sum_{i=1}^{N}AP@K(i)}{N}
 
     :math:`\\mathbb{1}_{r_{ij}}` -- indicator function showing if user :math:`i` interacted with item :math:`j`
 
@@ -64,9 +63,7 @@
     """
 
     @staticmethod
-    def _get_metric_value_by_user(
-        ks: List[int], ground_truth: List, pred: List
-    ) -> List[float]:
+    def _get_metric_value_by_user(ks: List[int], ground_truth: List, pred: List) -> List[float]:
         if not ground_truth or not pred:
             return [0.0 for _ in ks]
         res = []
replay/metrics/mrr.py
CHANGED
@@ -3,7 +3,6 @@ from typing import List
 from .base_metric import Metric
 
 
-# pylint: disable=too-few-public-methods
 class MRR(Metric):
     """
     Mean Reciprocal Rank -- Reciprocal Rank is the inverse position of the
@@ -56,9 +55,7 @@ class MRR(Metric):
     """
 
     @staticmethod
-    def _get_metric_value_by_user(
-        ks: List[int], ground_truth: List, pred: List
-    ) -> List[float]:
+    def _get_metric_value_by_user(ks: List[int], ground_truth: List, pred: List) -> List[float]:
         if not ground_truth or not pred:
             return [0.0 for _ in ks]
         set_gt = set(ground_truth)
replay/metrics/ndcg.py
CHANGED
@@ -4,7 +4,6 @@ from typing import List
 from .base_metric import Metric
 
 
-# pylint: disable=too-few-public-methods
 class NDCG(Metric):
     """
     Normalized Discounted Cumulative Gain is a metric
@@ -14,7 +13,7 @@ class NDCG(Metric):
     whether the item was consumed or not, relevance value is ignored.
 
     .. math::
-        DCG@K(i) =
+        DCG@K(i) = \\sum_{j=1}^{K}\\frac{\\mathbb{1}_{r_{ij}}}{\\log_2 (j+1)}
 
 
     :math:`\\mathbb{1}_{r_{ij}}` -- indicator function showing that user :math:`i` interacted with item :math:`j`
@@ -23,7 +22,7 @@ class NDCG(Metric):
     for user :math:`i` and recommendation length :math:`K`.
 
     .. math::
-        IDCG@K(i) = max(DCG@K(i)) =
+        IDCG@K(i) = max(DCG@K(i)) = \\sum_{j=1}^{K}\\frac{\\mathbb{1}_{j\\le|Rel_i|}}{\\log_2 (j+1)}
 
     .. math::
         nDCG@K(i) = \\frac {DCG@K(i)}{IDCG@K(i)}
@@ -33,7 +32,7 @@ class NDCG(Metric):
     Metric is averaged by users.
 
     .. math::
-        nDCG@K = \\frac {
+        nDCG@K = \\frac {\\sum_{i=1}^{N}nDCG@K(i)}{N}
 
     >>> recommendations
        query_id  item_id  rating
@@ -81,9 +80,7 @@ class NDCG(Metric):
     """
 
     @staticmethod
-    def _get_metric_value_by_user(
-        ks: List[int], ground_truth: List, pred: List
-    ) -> List[float]:
+    def _get_metric_value_by_user(ks: List[int], ground_truth: List, pred: List) -> List[float]:
         if not pred or not ground_truth:
             return [0.0 for _ in ks]
         set_gt = set(ground_truth)
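
For context, a minimal sketch (toy data, not from the package) of the binary nDCG@K described by the restored docstring formulas: DCG@K sums 1/log2(j+1) over hit positions, and IDCG@K assumes the first min(K, |Rel_i|) positions are all hits.

import math

ground_truth = [10, 11, 12]  # hypothetical relevant items
pred = [11, 14, 10, 15]      # hypothetical ranked recommendations
k = 3

set_gt = set(ground_truth)
# positions are 1-based in the formula, hence log2(j + 2) for 0-based enumerate
dcg = sum(1.0 / math.log2(j + 2) for j, item in enumerate(pred[:k]) if item in set_gt)
idcg = sum(1.0 / math.log2(j + 2) for j in range(min(k, len(set_gt))))
print(round(dcg / idcg, 3))  # ~0.704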
replay/metrics/novelty.py
CHANGED
@@ -1,6 +1,6 @@
 from typing import TYPE_CHECKING, List, Type
 
-from replay.utils import PandasDataFrame,
+from replay.utils import PandasDataFrame, PolarsDataFrame, SparkDataFrame
 
 from .base_metric import Metric, MetricsDataFrameLike, MetricsReturnType
 
@@ -8,7 +8,6 @@ if TYPE_CHECKING:  # pragma: no cover
     __class__: Type
 
 
-# pylint: disable=too-few-public-methods
 class Novelty(Metric):
     """
     Measure the fraction of shown items in recommendation list, that users\
@@ -16,11 +15,11 @@ class Novelty(Metric):
 
     .. math::
         Novelty@K(i) = \\frac
-        {
+        {\\parallel {R^{i}_{1..\\min(K, \\parallel R^{i} \\parallel)} \\setminus train^{i}} \\parallel}
         {K}
 
     .. math::
-        Novelty@K = \\frac {1}{N}
+        Novelty@K = \\frac {1}{N}\\sum_{i=1}^{N}Novelty@K(i)
 
     :math:`R^{i}` -- the recommendations for the :math:`i`-th user.
 
@@ -114,9 +113,7 @@ class Novelty(Metric):
             else self._convert_dict_to_dict_with_score(recommendations)
         )
         self._check_duplicates_dict(recommendations)
-        train = (
-            self._convert_pandas_to_dict_without_score(train) if is_pandas else train
-        )
+        train = self._convert_pandas_to_dict_without_score(train) if is_pandas else train
         assert isinstance(train, dict)
 
         return self._dict_call(
@@ -125,41 +122,25 @@ class Novelty(Metric):
             train=train,
         )
 
-
-    def _spark_call(
-        self, recommendations: SparkDataFrame, train: SparkDataFrame
-    ) -> MetricsReturnType:
+    def _spark_call(self, recommendations: SparkDataFrame, train: SparkDataFrame) -> MetricsReturnType:
         """
         Implementation for Pyspark DataFrame.
         """
-        recs = self._get_enriched_recommendations(
-            recommendations, train
-        ).withColumnRenamed("ground_truth", "train")
+        recs = self._get_enriched_recommendations(recommendations, train).withColumnRenamed("ground_truth", "train")
         recs = self._rearrange_columns(recs)
         return self._spark_compute(recs)
 
-
-    def _polars_call(
-        self, recommendations: PolarsDataFrame, train: PolarsDataFrame
-    ) -> MetricsReturnType:
+    def _polars_call(self, recommendations: PolarsDataFrame, train: PolarsDataFrame) -> MetricsReturnType:
         """
         Implementation for Polars DataFrame.
         """
-        recs = self._get_enriched_recommendations(
-            recommendations, train
-        ).rename({"ground_truth": "train"})
+        recs = self._get_enriched_recommendations(recommendations, train).rename({"ground_truth": "train"})
         recs = self._rearrange_columns(recs)
         return self._polars_compute(recs)
 
-    # pylint: disable=arguments-differ
     @staticmethod
-    def _get_metric_value_by_user(
-        ks: List[int], pred: List, train: List
-    ) -> List[float]:
+    def _get_metric_value_by_user(ks: List[int], pred: List, train: List) -> List[float]:
         if not train or not pred:
             return [1.0 for _ in ks]
         set_train = set(train)
-
-        for k in ks:
-            res.append(1.0 - len(set(pred[:k]) & set_train) / len(pred[:k]))
-        return res
+        return [1.0 - len(set(pred[:k]) & set_train) / len(pred[:k]) for k in ks]