replay-rec 0.20.0-py3-none-any.whl → 0.20.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- replay/__init__.py +1 -1
- replay/data/dataset.py +10 -9
- replay/data/dataset_utils/dataset_label_encoder.py +5 -4
- replay/data/nn/schema.py +9 -18
- replay/data/nn/sequence_tokenizer.py +26 -18
- replay/data/nn/sequential_dataset.py +22 -18
- replay/data/nn/torch_sequential_dataset.py +17 -16
- replay/data/nn/utils.py +2 -1
- replay/data/schema.py +3 -12
- replay/metrics/base_metric.py +11 -10
- replay/metrics/categorical_diversity.py +8 -8
- replay/metrics/coverage.py +4 -4
- replay/metrics/experiment.py +3 -3
- replay/metrics/hitrate.py +1 -3
- replay/metrics/map.py +1 -3
- replay/metrics/mrr.py +1 -3
- replay/metrics/ndcg.py +1 -2
- replay/metrics/novelty.py +3 -3
- replay/metrics/offline_metrics.py +16 -16
- replay/metrics/precision.py +1 -3
- replay/metrics/recall.py +1 -3
- replay/metrics/rocauc.py +1 -3
- replay/metrics/surprisal.py +4 -4
- replay/metrics/torch_metrics_builder.py +13 -12
- replay/metrics/unexpectedness.py +2 -2
- replay/models/als.py +2 -2
- replay/models/association_rules.py +4 -3
- replay/models/base_neighbour_rec.py +3 -2
- replay/models/base_rec.py +11 -10
- replay/models/cat_pop_rec.py +2 -1
- replay/models/extensions/ann/ann_mixin.py +2 -1
- replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +2 -1
- replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +2 -1
- replay/models/lin_ucb.py +57 -11
- replay/models/nn/optimizer_utils/optimizer_factory.py +2 -2
- replay/models/nn/sequential/bert4rec/dataset.py +5 -18
- replay/models/nn/sequential/bert4rec/lightning.py +3 -3
- replay/models/nn/sequential/bert4rec/model.py +2 -2
- replay/models/nn/sequential/callbacks/prediction_callbacks.py +12 -12
- replay/models/nn/sequential/callbacks/validation_callback.py +9 -9
- replay/models/nn/sequential/compiled/base_compiled_model.py +5 -5
- replay/models/nn/sequential/postprocessors/_base.py +2 -3
- replay/models/nn/sequential/postprocessors/postprocessors.py +11 -11
- replay/models/nn/sequential/sasrec/dataset.py +3 -16
- replay/models/nn/sequential/sasrec/lightning.py +3 -3
- replay/models/nn/sequential/sasrec/model.py +8 -8
- replay/models/slim.py +2 -2
- replay/models/ucb.py +2 -2
- replay/models/word2vec.py +3 -3
- replay/preprocessing/discretizer.py +8 -7
- replay/preprocessing/filters.py +4 -4
- replay/preprocessing/history_based_fp.py +6 -6
- replay/preprocessing/label_encoder.py +8 -7
- replay/scenarios/fallback.py +4 -3
- replay/splitters/base_splitter.py +3 -3
- replay/splitters/cold_user_random_splitter.py +4 -4
- replay/splitters/k_folds.py +4 -4
- replay/splitters/last_n_splitter.py +10 -10
- replay/splitters/new_users_splitter.py +4 -4
- replay/splitters/random_splitter.py +4 -4
- replay/splitters/ratio_splitter.py +10 -10
- replay/splitters/time_splitter.py +6 -6
- replay/splitters/two_stage_splitter.py +4 -4
- replay/utils/__init__.py +1 -1
- replay/utils/common.py +1 -1
- replay/utils/session_handler.py +2 -2
- replay/utils/spark_utils.py +6 -5
- replay/utils/types.py +3 -1
- {replay_rec-0.20.0.dist-info → replay_rec-0.20.1.dist-info}/METADATA +7 -1
- {replay_rec-0.20.0.dist-info → replay_rec-0.20.1.dist-info}/RECORD +73 -74
- replay/utils/warnings.py +0 -26
- {replay_rec-0.20.0.dist-info → replay_rec-0.20.1.dist-info}/WHEEL +0 -0
- {replay_rec-0.20.0.dist-info → replay_rec-0.20.1.dist-info}/licenses/LICENSE +0 -0
- {replay_rec-0.20.0.dist-info → replay_rec-0.20.1.dist-info}/licenses/NOTICE +0 -0
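Most of the hunks below share one mechanical change: `typing.List`/`typing.Dict`/`typing.Tuple` aliases are replaced by the built-in generics available since Python 3.9 (PEP 585), and `Mapping`/`Iterable` now come from `collections.abc`. A minimal sketch of the annotation style being migrated to (illustrative only; the function below is not code from the package):

```python
# Illustrative sketch of the annotation pattern adopted in 0.20.1;
# the function name and data are made up, only the style mirrors the diff.
from collections.abc import Mapping          # was: from typing import Mapping
from typing import Optional, Union           # List/Dict aliases no longer needed


def top_items(scores: dict[int, float], k: Optional[int] = None) -> list[int]:
    """Return item ids sorted by score, truncated to k if given."""
    ranked = sorted(scores, key=scores.get, reverse=True)
    return ranked if k is None else ranked[:k]


per_user: Mapping[Union[int, str], list[int]] = {1: top_items({10: 0.9, 12: 0.4})}
```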
replay/metrics/base_metric.py
CHANGED

@@ -1,6 +1,7 @@
 import warnings
 from abc import ABC, abstractmethod
-from typing import
+from collections.abc import Mapping
+from typing import Any, Optional, Union

 import numpy as np
 import polars as pl
@@ -14,7 +15,7 @@ if PYSPARK_AVAILABLE:
     from pyspark.sql.types import ArrayType, DoubleType, StructType


-MetricsDataFrameLike = Union[DataFrameLike,
+MetricsDataFrameLike = Union[DataFrameLike, dict]
 MetricsMeanReturnType = Mapping[str, float]
 MetricsPerUserReturnType = Mapping[str, Mapping[Any, float]]
 MetricsReturnType = Union[MetricsMeanReturnType, MetricsPerUserReturnType]
@@ -29,7 +30,7 @@ class Metric(ABC):

     def __init__(
         self,
-        topk: Union[
+        topk: Union[list[int], int],
         query_column: str = "query_id",
         item_column: str = "item_id",
         rating_column: str = "rating",
@@ -89,7 +90,7 @@ class Metric(ABC):
         if duplicates_count:
             self._duplicate_warn()

-    def _check_duplicates_dict(self, recommendations:
+    def _check_duplicates_dict(self, recommendations: dict) -> None:
         for items in recommendations.values():
             items_set = set(items)
             if len(items) != len(items_set):
@@ -143,7 +144,7 @@ class Metric(ABC):
             ground_truth=ground_truth,
         )

-    def _convert_pandas_to_dict_with_score(self, data: PandasDataFrame) ->
+    def _convert_pandas_to_dict_with_score(self, data: PandasDataFrame) -> dict:
         return (
             data.sort_values(by=[self.rating_column, self.item_column], ascending=False, kind="stable")
             .groupby(self.query_column)[self.item_column]
@@ -151,7 +152,7 @@ class Metric(ABC):
             .to_dict()
         )

-    def _convert_dict_to_dict_with_score(self, data:
+    def _convert_dict_to_dict_with_score(self, data: dict) -> dict:
         converted_data = {}
         for user, items in data.items():
             is_sorted = True
@@ -164,10 +165,10 @@ class Metric(ABC):
             converted_data[user] = [item for item, _ in items]
         return converted_data

-    def _convert_pandas_to_dict_without_score(self, data: PandasDataFrame) ->
+    def _convert_pandas_to_dict_without_score(self, data: PandasDataFrame) -> dict:
         return data.groupby(self.query_column)[self.item_column].apply(list).to_dict()

-    def _dict_call(self, users:
+    def _dict_call(self, users: list, **kwargs: dict) -> MetricsReturnType:
         """
         Calculating metrics in dict format.
         kwargs can contain different dicts (for example, ground_truth or train), it depends on the metric.
@@ -287,7 +288,7 @@ class Metric(ABC):
         )
         return self._rearrange_columns(enriched_recommendations)

-    def _aggregate_results_per_user(self, distribution_per_user:
+    def _aggregate_results_per_user(self, distribution_per_user: dict[Any, list[float]]) -> MetricsPerUserReturnType:
         res: MetricsPerUserReturnType = {}
         for index, val in enumerate(self.topk):
             metric_name = f"{self.__name__}@{val}"
@@ -374,7 +375,7 @@ class Metric(ABC):

     @staticmethod
     @abstractmethod
-    def _get_metric_value_by_user(ks:
+    def _get_metric_value_by_user(ks: list[int], *args: list) -> list[float]:  # pragma: no cover
         """
         Metric calculation for one user.

replay/metrics/categorical_diversity.py
CHANGED

@@ -1,5 +1,5 @@
 from collections import defaultdict
-from typing import
+from typing import Union

 import numpy as np
 import polars as pl
@@ -62,7 +62,7 @@ class CategoricalDiversity(Metric):

     def __init__(
         self,
-        topk: Union[
+        topk: Union[list, int],
         query_column: str = "query_id",
         category_column: str = "category_id",
         rating_column: str = "rating",
@@ -195,7 +195,7 @@ class CategoricalDiversity(Metric):
             return self._polars_compute_per_user(recs)
         return self._polars_compute_agg(recs)

-    def _convert_pandas_to_dict_with_score(self, data: PandasDataFrame) ->
+    def _convert_pandas_to_dict_with_score(self, data: PandasDataFrame) -> dict:
         return (
             data.sort_values(by=self.rating_column, ascending=False)
             .groupby(self.query_column)[self.category_column]
@@ -203,7 +203,7 @@ class CategoricalDiversity(Metric):
             .to_dict()
         )

-    def _precalculate_unique_cats(self, recommendations:
+    def _precalculate_unique_cats(self, recommendations: dict) -> dict:
         """
         Precalculate unique categories for each prefix for each user.
         """
@@ -217,14 +217,14 @@ class CategoricalDiversity(Metric):
             answer[user] = unique_len
         return answer

-    def _dict_compute_per_user(self, precalculated_answer:
+    def _dict_compute_per_user(self, precalculated_answer: dict) -> MetricsPerUserReturnType:
         distribution_per_user = defaultdict(list)
         for k in self.topk:
             for user, unique_cats in precalculated_answer.items():
                 distribution_per_user[user].append(unique_cats[min(len(unique_cats), k) - 1] / k)
         return self._aggregate_results_per_user(distribution_per_user)

-    def _dict_compute_mean(self, precalculated_answer:
+    def _dict_compute_mean(self, precalculated_answer: dict) -> MetricsMeanReturnType:
         distribution_list = []
         for unique_cats in precalculated_answer.values():
             metrics_per_user = []
@@ -238,7 +238,7 @@ class CategoricalDiversity(Metric):
         metrics = [self._mode.cpu(distribution[:, k]) for k in range(distribution.shape[1])]
         return self._aggregate_results(metrics)

-    def _dict_call(self, precalculated_answer:
+    def _dict_call(self, precalculated_answer: dict) -> MetricsReturnType:
         """
         Calculating metrics in dict format.
         """
@@ -247,5 +247,5 @@ class CategoricalDiversity(Metric):
         return self._dict_compute_mean(precalculated_answer)

     @staticmethod
-    def _get_metric_value_by_user(ks:
+    def _get_metric_value_by_user(ks: list[int], *args: list) -> list[float]:  # pragma: no cover
         pass
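The per-user computation visible above divides the number of unique categories in the top-k prefix by k. A toy reproduction of that prefix bookkeeping (hypothetical data and variable names, not the package's code):

```python
# Toy reproduction of the prefix logic shown in _precalculate_unique_cats /
# _dict_compute_per_user; the data is made up.
recs_categories = {"u1": ["drama", "drama", "comedy", "horror"]}
topk = [2, 4]

unique_len_by_user = {}
for user, cats in recs_categories.items():
    seen, unique_len = set(), []
    for cat in cats:                      # unique_len[i] = #unique categories in prefix i+1
        seen.add(cat)
        unique_len.append(len(seen))
    unique_len_by_user[user] = unique_len

for k in topk:
    for user, unique_len in unique_len_by_user.items():
        value = unique_len[min(len(unique_len), k) - 1] / k
        print(f"CategoricalDiversity@{k} for {user} = {value}")
# -> 0.5 at k=2 (1 unique category in top-2), 0.75 at k=4 (3 unique in top-4)
```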
replay/metrics/coverage.py
CHANGED

@@ -1,6 +1,6 @@
 import functools
 import operator
-from typing import
+from typing import Union

 import polars as pl

@@ -60,7 +60,7 @@ class Coverage(Metric):

     def __init__(
         self,
-        topk: Union[
+        topk: Union[list, int],
         query_column: str = "query_id",
         item_column: str = "item_id",
         rating_column: str = "rating",
@@ -173,7 +173,7 @@ class Coverage(Metric):
         recs = self._get_enriched_recommendations(recommendations)
         return self._polars_compute(recs, train)

-    def _dict_call(self, recommendations:
+    def _dict_call(self, recommendations: dict, train: dict) -> MetricsReturnType:
         """
         Calculating metrics in dict format.
         """
@@ -229,5 +229,5 @@ class Coverage(Metric):
         return self._dict_call(recommendations, train)

     @staticmethod
-    def _get_metric_value_by_user(ks, *args) ->
+    def _get_metric_value_by_user(ks, *args) -> list[float]:  # pragma: no cover
         pass
replay/metrics/experiment.py
CHANGED

@@ -1,4 +1,4 @@
-from typing import
+from typing import Optional, Union

 import pandas as pd

@@ -102,10 +102,10 @@ class Experiment:

     def __init__(
         self,
-        metrics:
+        metrics: list[Metric],
         ground_truth: MetricsDataFrameLike,
         train: Optional[MetricsDataFrameLike] = None,
-        base_recommendations: Optional[Union[MetricsDataFrameLike,
+        base_recommendations: Optional[Union[MetricsDataFrameLike, dict[str, MetricsDataFrameLike]]] = None,
         query_column: str = "query_id",
         item_column: str = "item_id",
         rating_column: str = "rating",
replay/metrics/hitrate.py
CHANGED

@@ -1,5 +1,3 @@
-from typing import List
-
 from .base_metric import Metric


@@ -62,7 +60,7 @@ class HitRate(Metric):
     """

     @staticmethod
-    def _get_metric_value_by_user(ks:
+    def _get_metric_value_by_user(ks: list[int], ground_truth: list, pred: list) -> list[float]:
         if not ground_truth or not pred:
             return [0.0 for _ in ks]
         set_gt = set(ground_truth)
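The re-typed `_get_metric_value_by_user(ks, ground_truth, pred)` hook computes the metric for a single user at every k. A hedged sketch of what HitRate@k means for one user; only the empty-input guard is taken from the hunk above, the rest follows the standard definition:

```python
# Hedged sketch of per-user HitRate@k; not copied from the package beyond the
# guard shown in the diff. A "hit" means any relevant item appears in the top-k.
def hitrate_by_user(ks: list[int], ground_truth: list, pred: list) -> list[float]:
    if not ground_truth or not pred:
        return [0.0 for _ in ks]
    set_gt = set(ground_truth)
    return [float(any(item in set_gt for item in pred[:k])) for k in ks]


print(hitrate_by_user([1, 3], ground_truth=[10, 14], pred=[12, 14, 20]))  # [0.0, 1.0]
```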
replay/metrics/map.py
CHANGED

@@ -1,5 +1,3 @@
-from typing import List
-
 from .base_metric import Metric


@@ -63,7 +61,7 @@ class MAP(Metric):
     """

     @staticmethod
-    def _get_metric_value_by_user(ks:
+    def _get_metric_value_by_user(ks: list[int], ground_truth: list, pred: list) -> list[float]:
         if not ground_truth or not pred:
             return [0.0 for _ in ks]
         res = []
replay/metrics/mrr.py
CHANGED

@@ -1,5 +1,3 @@
-from typing import List
-
 from .base_metric import Metric


@@ -55,7 +53,7 @@ class MRR(Metric):
     """

     @staticmethod
-    def _get_metric_value_by_user(ks:
+    def _get_metric_value_by_user(ks: list[int], ground_truth: list, pred: list) -> list[float]:
         if not ground_truth or not pred:
             return [0.0 for _ in ks]
         set_gt = set(ground_truth)
replay/metrics/ndcg.py
CHANGED

@@ -1,5 +1,4 @@
 import math
-from typing import List

 from .base_metric import Metric

@@ -80,7 +79,7 @@ class NDCG(Metric):
     """

     @staticmethod
-    def _get_metric_value_by_user(ks:
+    def _get_metric_value_by_user(ks: list[int], ground_truth: list, pred: list) -> list[float]:
         if not pred or not ground_truth:
             return [0.0 for _ in ks]
         set_gt = set(ground_truth)
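For reference, the textbook binary-relevance NDCG@k that this per-user hook evaluates; the exact implementation is outside the hunk above, so treat this as the standard definition rather than the package's code:

```python
import math

# Textbook binary-relevance NDCG@k for one user; not copied from the package.
def ndcg_by_user(ks: list[int], ground_truth: list, pred: list) -> list[float]:
    if not pred or not ground_truth:
        return [0.0 for _ in ks]
    set_gt = set(ground_truth)
    result = []
    for k in ks:
        # DCG: discounted gain of hits in the top-k; IDCG: best possible DCG
        dcg = sum(1.0 / math.log2(i + 2) for i, item in enumerate(pred[:k]) if item in set_gt)
        idcg = sum(1.0 / math.log2(i + 2) for i in range(min(k, len(set_gt))))
        result.append(dcg / idcg if idcg > 0 else 0.0)
    return result


print(ndcg_by_user([2], ground_truth=[10], pred=[12, 10]))  # [~0.63], hit at rank 2
```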
replay/metrics/novelty.py
CHANGED

@@ -1,11 +1,11 @@
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING

 from replay.utils import PandasDataFrame, PolarsDataFrame, SparkDataFrame

 from .base_metric import Metric, MetricsDataFrameLike, MetricsReturnType

 if TYPE_CHECKING:  # pragma: no cover
-    __class__:
+    __class__: type


 class Novelty(Metric):
@@ -139,7 +139,7 @@ class Novelty(Metric):
         return self._polars_compute(recs)

     @staticmethod
-    def _get_metric_value_by_user(ks:
+    def _get_metric_value_by_user(ks: list[int], pred: list, train: list) -> list[float]:
         if not train or not pred:
             return [1.0 for _ in ks]
         set_train = set(train)
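Novelty rewards recommending items absent from the user's train history; the hunk above only shows the 1.0 default for empty inputs. A hedged toy illustration of the ratio this implies (standard definition, not the package's exact body):

```python
# Hedged toy illustration of per-user Novelty@k: the share of the top-k
# recommendations not present in the user's train interactions. Only the
# empty-input default (1.0) comes from the hunk above.
def novelty_by_user(ks: list[int], pred: list, train: list) -> list[float]:
    if not train or not pred:
        return [1.0 for _ in ks]
    set_train = set(train)
    return [len([item for item in pred[:k] if item not in set_train]) / k for k in ks]


print(novelty_by_user([2, 4], pred=[1, 2, 3, 4], train=[2, 4]))  # [0.5, 0.5]
```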
replay/metrics/offline_metrics.py
CHANGED

@@ -1,5 +1,5 @@
 import warnings
-from typing import
+from typing import Optional, Union

 from replay.utils import PandasDataFrame, PolarsDataFrame, SparkDataFrame

@@ -132,7 +132,7 @@ class OfflineMetrics:
     <BLANKLINE>
     """

-    _metrics_call_requirement_map:
+    _metrics_call_requirement_map: dict[str, list[str]] = {
        "HitRate": ["ground_truth"],
        "MAP": ["ground_truth"],
        "NDCG": ["ground_truth"],
@@ -147,7 +147,7 @@ class OfflineMetrics:

     def __init__(
         self,
-        metrics:
+        metrics: list[Metric],
         query_column: str = "query_id",
         item_column: str = "item_id",
         rating_column: str = "rating",
@@ -174,9 +174,9 @@ class OfflineMetrics:
         :param allow_caching: (bool): The flag for using caching to optimize calculations.
             Default: ``True``.
         """
-        self.unexpectedness_metric:
-        self.diversity_metric:
-        self.main_metrics:
+        self.unexpectedness_metric: list[Metric] = []
+        self.diversity_metric: list[Metric] = []
+        self.main_metrics: list[Metric] = []
         self._allow_caching = allow_caching

         for metric in metrics:
@@ -198,7 +198,7 @@ class OfflineMetrics:
         recommendations: Union[SparkDataFrame, PolarsDataFrame],
         ground_truth: Union[SparkDataFrame, PolarsDataFrame],
         train: Optional[Union[SparkDataFrame, PolarsDataFrame]],
-    ) ->
+    ) -> tuple[dict[str, Union[SparkDataFrame, PolarsDataFrame]], Optional[Union[SparkDataFrame, PolarsDataFrame]]]:
         if len(self.main_metrics) == 0:
             return {}, train
         result_dict = {}
@@ -257,21 +257,21 @@ class OfflineMetrics:

         return result_dict, train

-    def _cache_dataframes(self, dataframes:
+    def _cache_dataframes(self, dataframes: dict[str, SparkDataFrame]) -> None:
         for data in dataframes.values():
             data.cache()

-    def _unpersist_dataframes(self, dataframes:
+    def _unpersist_dataframes(self, dataframes: dict[str, SparkDataFrame]) -> None:
         for data in dataframes.values():
             data.unpersist()

     def _calculate_metrics(
         self,
-        enriched_recs_dict:
+        enriched_recs_dict: dict[str, Union[SparkDataFrame, PolarsDataFrame]],
         train: Optional[Union[SparkDataFrame, PolarsDataFrame]] = None,
         is_spark: bool = True,
     ) -> MetricsReturnType:
-        result:
+        result: dict = {}
         for metric in self.metrics:
             metric_args = {}
             if metric.__class__.__name__ == "Coverage" and train is not None:
@@ -295,7 +295,7 @@ class OfflineMetrics:
         recommendations: MetricsDataFrameLike,
         ground_truth: MetricsDataFrameLike,
         train: Optional[MetricsDataFrameLike],
-        base_recommendations: Optional[Union[MetricsDataFrameLike,
+        base_recommendations: Optional[Union[MetricsDataFrameLike, dict[str, MetricsDataFrameLike]]],
     ) -> None:
         types = set()
         types.add(type(recommendations))
@@ -379,8 +379,8 @@ class OfflineMetrics:
         recommendations: MetricsDataFrameLike,
         ground_truth: MetricsDataFrameLike,
         train: Optional[MetricsDataFrameLike] = None,
-        base_recommendations: Optional[Union[MetricsDataFrameLike,
-    ) ->
+        base_recommendations: Optional[Union[MetricsDataFrameLike, dict[str, MetricsDataFrameLike]]] = None,
+    ) -> dict[str, float]:
         """
         Compute metrics.

@@ -450,12 +450,12 @@ class OfflineMetrics:
         if is_spark and self._allow_caching:
             self._unpersist_dataframes(enriched_recs_dict)
         else:  # Calculating metrics in dict format
-            current_map:
+            current_map: dict[str, Union[PandasDataFrame, dict]] = {
                 "ground_truth": ground_truth,
                 "train": train,
             }
             for metric in self.metrics:
-                args_to_call:
+                args_to_call: dict[str, Union[PandasDataFrame, dict]] = {"recommendations": recommendations}
                 for data_name in self._metrics_call_requirement_map[str(metric.__class__.__name__)]:
                     args_to_call[data_name] = current_map[data_name]
                 result.update(metric(**args_to_call))
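`_metrics_call_requirement_map` tells `OfflineMetrics` which extra inputs each metric needs when everything is passed as plain dicts, and the loop at the end of the hunk builds each call's arguments from it. A reduced sketch of that dispatch pattern (hypothetical metric callables, not the package's classes):

```python
# Reduced sketch of the dict-mode dispatch shown at the end of the hunk;
# the two lambda "metrics" are placeholders, only the lookup pattern matches.
metrics_call_requirement_map = {"HitRate": ["ground_truth"], "Novelty": ["train"]}

recommendations = {1: [10, 12]}
current_map = {"ground_truth": {1: [10]}, "train": {1: [12]}}
metric_fns = {
    "HitRate": lambda **kw: {"HitRate@1": 1.0},
    "Novelty": lambda **kw: {"Novelty@1": 0.5},
}

result: dict[str, float] = {}
for name, fn in metric_fns.items():
    args_to_call = {"recommendations": recommendations}
    for data_name in metrics_call_requirement_map[name]:
        args_to_call[data_name] = current_map[data_name]
    result.update(fn(**args_to_call))
print(result)  # {'HitRate@1': 1.0, 'Novelty@1': 0.5}
```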
replay/metrics/precision.py
CHANGED

@@ -1,5 +1,3 @@
-from typing import List
-
 from .base_metric import Metric


@@ -61,7 +59,7 @@ class Precision(Metric):
     """

     @staticmethod
-    def _get_metric_value_by_user(ks:
+    def _get_metric_value_by_user(ks: list[int], ground_truth: list, pred: list) -> list[float]:
         if not ground_truth or not pred:
             return [0.0 for _ in ks]
         set_gt = set(ground_truth)
replay/metrics/recall.py
CHANGED

@@ -1,5 +1,3 @@
-from typing import List
-
 from .base_metric import Metric


@@ -65,7 +63,7 @@ class Recall(Metric):
     """

     @staticmethod
-    def _get_metric_value_by_user(ks:
+    def _get_metric_value_by_user(ks: list[int], ground_truth: list, pred: list) -> list[float]:
         if not ground_truth or not pred:
             return [0.0 for _ in ks]
         set_gt = set(ground_truth)
replay/metrics/rocauc.py
CHANGED

@@ -1,5 +1,3 @@
-from typing import List
-
 from .base_metric import Metric


@@ -74,7 +72,7 @@ class RocAuc(Metric):
     """

     @staticmethod
-    def _get_metric_value_by_user(ks:
+    def _get_metric_value_by_user(ks: list[int], ground_truth: list, pred: list) -> list[float]:
         if not ground_truth or not pred:
             return [0.0 for _ in ks]
         set_gt = set(ground_truth)
replay/metrics/surprisal.py
CHANGED

@@ -1,5 +1,5 @@
 from collections import defaultdict
-from typing import
+from typing import Union

 import numpy as np
 import polars as pl
@@ -82,7 +82,7 @@ class Surprisal(Metric):
    <BLANKLINE>
    """

-    def _get_weights(self, train:
+    def _get_weights(self, train: dict) -> dict:
         n_users = len(train.keys())
         items_counter = defaultdict(set)
         for user, items in train.items():
@@ -93,7 +93,7 @@ class Surprisal(Metric):
             weights[item] = np.log2(n_users / len(users)) / np.log2(n_users)
         return weights

-    def _get_recommendation_weights(self, recommendations:
+    def _get_recommendation_weights(self, recommendations: dict, train: dict) -> dict:
         weights = self._get_weights(train)
         recs_with_weights = {}
         for user, items in recommendations.items():
@@ -183,7 +183,7 @@ class Surprisal(Metric):
         )

     @staticmethod
-    def _get_metric_value_by_user(ks:
+    def _get_metric_value_by_user(ks: list[int], pred_item_ids: list, pred_weights: list) -> list[float]:
         if not pred_item_ids:
             return [0.0 for _ in ks]
         res = []
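The item weight computed in `_get_weights` is self-information normalised by the corpus size: `log2(n_users / n_users_who_saw_item) / log2(n_users)`. A quick numeric check of that line with made-up counts:

```python
import numpy as np

# Toy check of the weight formula shown in _get_weights: rarer items get
# weights closer to 1, items seen by every user get 0. Numbers are made up.
n_users = 8
for seen_by in (1, 2, 8):
    weight = np.log2(n_users / seen_by) / np.log2(n_users)
    print(f"item seen by {seen_by}/{n_users} users -> weight {weight:.2f}")
# -> 1.00, 0.67, 0.00
```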
replay/metrics/torch_metrics_builder.py
CHANGED

@@ -1,6 +1,7 @@
 import abc
+from collections.abc import Mapping
 from dataclasses import dataclass
-from typing import Any,
+from typing import Any, Literal, Optional

 import numpy as np

@@ -19,13 +20,13 @@ MetricName = Literal[
     "coverage",
 ]

-DEFAULT_METRICS:
+DEFAULT_METRICS: list[MetricName] = [
     "map",
     "ndcg",
     "recall",
 ]

-DEFAULT_KS:
+DEFAULT_KS: list[int] = [1, 5, 10, 20]


 @dataclass
@@ -34,7 +35,7 @@ class _MetricRequirements:
     Stores description of metrics which need to be computed
     """

-    top_k:
+    top_k: list[int]
     need_recall: bool
     need_precision: bool
     need_ndcg: bool
@@ -68,14 +69,14 @@ class _MetricRequirements:
         self._metric_names = metrics

     @property
-    def metric_names(self) ->
+    def metric_names(self) -> list[str]:
         """
         Getting metric names
         """
         return self._metric_names

     @classmethod
-    def from_metrics(cls, metrics:
+    def from_metrics(cls, metrics: set[str], top_k: list[int]) -> "_MetricRequirements":
         """
         Creating a class based on a given list of metrics and K values
         """
@@ -96,7 +97,7 @@ class _CoverageHelper:
     Computes coverage metric over multiple batches
     """

-    def __init__(self, top_k:
+    def __init__(self, top_k: list[int], item_count: Optional[int]) -> None:
         """
         :param top_k: (list): Consider the highest k scores in the ranking.
         :param item_count: (optional, int): the total number of items in the dataset.
@@ -110,7 +111,7 @@ class _CoverageHelper:
         Reload the metric counter
         """
         self._train_hist = torch.zeros(self.item_count)
-        self._pred_hist:
+        self._pred_hist: dict[int, torch.Tensor] = {k: torch.zeros(self.item_count) for k in self._top_k}

     def _ensure_hists_on_device(self, device: torch.device) -> None:
         self._train_hist = self._train_hist.to(device)
@@ -197,8 +198,8 @@ class TorchMetricsBuilder(_MetricBuilder):

     def __init__(
         self,
-        metrics:
-        top_k: Optional[
+        metrics: list[MetricName] = DEFAULT_METRICS,
+        top_k: Optional[list[int]] = DEFAULT_KS,
         item_count: Optional[int] = None,
     ) -> None:
         """
@@ -331,8 +332,8 @@ class TorchMetricsBuilder(_MetricBuilder):

     def _compute_metrics_sum(
         self, predictions: torch.LongTensor, ground_truth: torch.LongTensor, train: Optional[torch.LongTensor]
-    ) ->
-        result:
+    ) -> list[float]:
+        result: list[float] = []

         # Getting a tensor of the same size as predictions
         # The tensor contains information about whether the item from the prediction is present in the test set
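`_CoverageHelper` now keeps one prediction histogram per k (`dict[int, torch.Tensor]`). The accumulation itself is not part of these hunks; the following is only a guess at the underlying idea, with hypothetical tensors and names:

```python
import torch

# Hypothetical sketch of per-k item histograms for a coverage-style metric;
# the real accumulation logic in _CoverageHelper is not shown in the hunks above.
item_count, top_k = 6, [1, 3]
pred_hist = {k: torch.zeros(item_count) for k in top_k}   # mirrors _pred_hist

batch_top_items = torch.tensor([[2, 4, 5], [2, 1, 0]])    # top-3 items for 2 users
for k in top_k:
    ids = batch_top_items[:, :k].reshape(-1)
    pred_hist[k] += torch.bincount(ids, minlength=item_count).float()

# share of the catalogue that was ever recommended within the top-k
coverage = {k: (hist > 0).float().mean().item() for k, hist in pred_hist.items()}
print(coverage)  # {1: ~0.17, 3: ~0.83}
```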
replay/metrics/unexpectedness.py
CHANGED

@@ -1,4 +1,4 @@
-from typing import
+from typing import Optional, Union

 from replay.utils import PandasDataFrame, PolarsDataFrame, SparkDataFrame

@@ -152,7 +152,7 @@ class Unexpectedness(Metric):
         )

     @staticmethod
-    def _get_metric_value_by_user(ks:
+    def _get_metric_value_by_user(ks: list[int], base_recs: Optional[list], recs: Optional[list]) -> list[float]:
         if not base_recs or not recs:
             return [0.0 for _ in ks]
         return [1.0 - len(set(recs[:k]) & set(base_recs[:k])) / k for k in ks]
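The one-liner above defines unexpectedness as the share of the top-k that does not overlap a baseline model's top-k. A quick check of that formula with toy lists (the lists are made up; the function body mirrors the hunk):

```python
# Toy check of the overlap formula shown above.
def unexpectedness_by_user(ks: list[int], base_recs: list, recs: list) -> list[float]:
    if not base_recs or not recs:
        return [0.0 for _ in ks]
    return [1.0 - len(set(recs[:k]) & set(base_recs[:k])) / k for k in ks]


print(unexpectedness_by_user([2, 4], base_recs=[1, 2, 3, 4], recs=[1, 9, 3, 8]))
# -> [0.5, 0.5]: half of each top-k list is shared with the baseline
```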
replay/models/als.py
CHANGED

@@ -1,5 +1,5 @@
 from os.path import join
-from typing import Optional
+from typing import Optional

 from replay.data import Dataset
 from replay.utils import PYSPARK_AVAILABLE, SparkDataFrame
@@ -148,7 +148,7 @@ class ALSWrap(Recommender, ItemVectorModel):

     def _get_features(
         self, ids: SparkDataFrame, features: Optional[SparkDataFrame]  # noqa: ARG002
-    ) ->
+    ) -> tuple[Optional[SparkDataFrame], Optional[int]]:
         entity = "user" if self.query_column in ids.columns else "item"
         entity_col = self.query_column if self.query_column in ids.columns else self.item_column

replay/models/association_rules.py
CHANGED

@@ -1,4 +1,5 @@
-from
+from collections.abc import Iterable
+from typing import Any, Optional, Union

 import numpy as np

@@ -97,13 +98,13 @@ class AssociationRulesItemRec(NeighbourRec):
         In this case all items in sessions should have the same rating.
         """

-    def _get_ann_infer_params(self) ->
+    def _get_ann_infer_params(self) -> dict[str, Any]:
         return {
             "features_col": None,
         }

     can_predict_item_to_item = True
-    item_to_item_metrics:
+    item_to_item_metrics: list[str] = ["lift", "confidence", "confidence_gain"]
     similarity: SparkDataFrame
     can_change_metric = True
     _search_space = {
replay/models/base_neighbour_rec.py
CHANGED

@@ -4,7 +4,8 @@ Part of set of abstract classes (from base_rec.py)
 """

 from abc import ABC
-from
+from collections.abc import Iterable
+from typing import Any, Optional, Union

 from replay.data.dataset import Dataset
 from replay.utils import PYSPARK_AVAILABLE, MissingImport, SparkDataFrame
@@ -187,7 +188,7 @@ class NeighbourRec(ANNMixin, Recommender, ABC):
             "similarity" if metric is None else metric,
         )

-    def _configure_index_builder(self, interactions: SparkDataFrame) ->
+    def _configure_index_builder(self, interactions: SparkDataFrame) -> dict[str, Any]:
         similarity_df = self.similarity.select("similarity", "item_idx_one", "item_idx_two")
         self.index_builder.index_params.items_count = interactions.select(sf.max(self.item_column)).first()[0] + 1
         return similarity_df, {