replay-rec 0.16.0__py3-none-any.whl → 0.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- replay/__init__.py +1 -1
- replay/data/__init__.py +1 -1
- replay/data/dataset.py +45 -42
- replay/data/dataset_utils/dataset_label_encoder.py +6 -7
- replay/data/nn/__init__.py +1 -1
- replay/data/nn/schema.py +20 -33
- replay/data/nn/sequence_tokenizer.py +217 -87
- replay/data/nn/sequential_dataset.py +6 -22
- replay/data/nn/torch_sequential_dataset.py +20 -11
- replay/data/nn/utils.py +7 -9
- replay/data/schema.py +17 -17
- replay/data/spark_schema.py +0 -1
- replay/metrics/base_metric.py +38 -79
- replay/metrics/categorical_diversity.py +24 -58
- replay/metrics/coverage.py +25 -49
- replay/metrics/descriptors.py +4 -13
- replay/metrics/experiment.py +3 -8
- replay/metrics/hitrate.py +3 -6
- replay/metrics/map.py +3 -6
- replay/metrics/mrr.py +1 -4
- replay/metrics/ndcg.py +4 -7
- replay/metrics/novelty.py +10 -29
- replay/metrics/offline_metrics.py +26 -61
- replay/metrics/precision.py +3 -6
- replay/metrics/recall.py +3 -6
- replay/metrics/rocauc.py +7 -10
- replay/metrics/surprisal.py +13 -30
- replay/metrics/torch_metrics_builder.py +0 -4
- replay/metrics/unexpectedness.py +15 -20
- replay/models/__init__.py +1 -2
- replay/models/als.py +7 -15
- replay/models/association_rules.py +12 -28
- replay/models/base_neighbour_rec.py +21 -36
- replay/models/base_rec.py +92 -215
- replay/models/cat_pop_rec.py +9 -22
- replay/models/cluster.py +17 -28
- replay/models/extensions/ann/ann_mixin.py +7 -12
- replay/models/extensions/ann/entities/base_hnsw_param.py +1 -1
- replay/models/extensions/ann/entities/hnswlib_param.py +0 -6
- replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -6
- replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +4 -10
- replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +7 -11
- replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +5 -12
- replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +11 -18
- replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +1 -4
- replay/models/extensions/ann/index_inferers/base_inferer.py +3 -10
- replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +7 -17
- replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +6 -14
- replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +14 -28
- replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +15 -25
- replay/models/extensions/ann/index_inferers/utils.py +2 -9
- replay/models/extensions/ann/index_stores/hdfs_index_store.py +4 -9
- replay/models/extensions/ann/index_stores/shared_disk_index_store.py +2 -6
- replay/models/extensions/ann/index_stores/spark_files_index_store.py +8 -14
- replay/models/extensions/ann/index_stores/utils.py +5 -2
- replay/models/extensions/ann/utils.py +3 -5
- replay/models/kl_ucb.py +16 -22
- replay/models/knn.py +37 -59
- replay/models/nn/optimizer_utils/__init__.py +1 -6
- replay/models/nn/optimizer_utils/optimizer_factory.py +3 -6
- replay/models/nn/sequential/bert4rec/__init__.py +1 -1
- replay/models/nn/sequential/bert4rec/dataset.py +6 -7
- replay/models/nn/sequential/bert4rec/lightning.py +53 -56
- replay/models/nn/sequential/bert4rec/model.py +12 -25
- replay/models/nn/sequential/callbacks/__init__.py +1 -1
- replay/models/nn/sequential/callbacks/prediction_callbacks.py +23 -25
- replay/models/nn/sequential/callbacks/validation_callback.py +27 -30
- replay/models/nn/sequential/postprocessors/postprocessors.py +1 -1
- replay/models/nn/sequential/sasrec/dataset.py +8 -7
- replay/models/nn/sequential/sasrec/lightning.py +53 -48
- replay/models/nn/sequential/sasrec/model.py +4 -17
- replay/models/pop_rec.py +9 -10
- replay/models/query_pop_rec.py +7 -15
- replay/models/random_rec.py +10 -18
- replay/models/slim.py +8 -13
- replay/models/thompson_sampling.py +13 -14
- replay/models/ucb.py +11 -22
- replay/models/wilson.py +5 -14
- replay/models/word2vec.py +24 -69
- replay/optimization/optuna_objective.py +13 -27
- replay/preprocessing/__init__.py +1 -2
- replay/preprocessing/converter.py +2 -7
- replay/preprocessing/filters.py +67 -142
- replay/preprocessing/history_based_fp.py +44 -116
- replay/preprocessing/label_encoder.py +106 -68
- replay/preprocessing/sessionizer.py +1 -11
- replay/scenarios/fallback.py +3 -8
- replay/splitters/base_splitter.py +43 -15
- replay/splitters/cold_user_random_splitter.py +18 -31
- replay/splitters/k_folds.py +14 -24
- replay/splitters/last_n_splitter.py +33 -43
- replay/splitters/new_users_splitter.py +31 -55
- replay/splitters/random_splitter.py +16 -23
- replay/splitters/ratio_splitter.py +30 -54
- replay/splitters/time_splitter.py +13 -18
- replay/splitters/two_stage_splitter.py +44 -79
- replay/utils/__init__.py +1 -1
- replay/utils/common.py +65 -0
- replay/utils/dataframe_bucketizer.py +25 -31
- replay/utils/distributions.py +3 -15
- replay/utils/model_handler.py +36 -33
- replay/utils/session_handler.py +11 -15
- replay/utils/spark_utils.py +51 -85
- replay/utils/time.py +8 -22
- replay/utils/types.py +1 -3
- {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/METADATA +2 -2
- replay_rec-0.17.0.dist-info/RECORD +127 -0
- replay_rec-0.16.0.dist-info/RECORD +0 -126
- {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/LICENSE +0 -0
- {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/WHEEL +0 -0
replay/data/nn/utils.py
CHANGED
@@ -2,11 +2,11 @@ from typing import Optional
 
 import polars as pl
 
-from replay.utils.spark_utils import spark_to_pandas
 from replay.utils import PYSPARK_AVAILABLE, DataFrameLike, PandasDataFrame, PolarsDataFrame
+from replay.utils.spark_utils import spark_to_pandas
 
 if PYSPARK_AVAILABLE:  # pragma: no cover
-    import pyspark.sql.functions as F
+    import pyspark.sql.functions as sf
 
 
 def groupby_sequences(events: DataFrameLike, groupby_col: str, sort_col: Optional[str] = None) -> DataFrameLike:
@@ -38,9 +38,7 @@ def groupby_sequences(events: DataFrameLike, groupby_col: str, sort_col: Optional[str] = None) -> DataFrameLike:
         event_cols_without_groupby.insert(0, sort_col)
         events = events.sort(event_cols_without_groupby)
 
-        grouped_sequences = events.group_by(groupby_col).agg(
-            *[pl.col(x) for x in event_cols_without_groupby]
-        )
+        grouped_sequences = events.group_by(groupby_col).agg(*[pl.col(x) for x in event_cols_without_groupby])
     else:
         event_cols_without_groupby = events.columns.copy()
         event_cols_without_groupby.remove(groupby_col)
@@ -49,16 +47,16 @@ def groupby_sequences(events: DataFrameLike, groupby_col: str, sort_col: Optional[str] = None) -> DataFrameLike:
         event_cols_without_groupby.remove(sort_col)
         event_cols_without_groupby.insert(0, sort_col)
 
-        all_cols_struct = F.struct(event_cols_without_groupby)
+        all_cols_struct = sf.struct(event_cols_without_groupby)
 
-        collect_fn = F.collect_list(all_cols_struct)
+        collect_fn = sf.collect_list(all_cols_struct)
         if sort_col:
-            collect_fn = F.sort_array(collect_fn)
+            collect_fn = sf.sort_array(collect_fn)
 
         grouped_sequences = (
             events.groupby(groupby_col)
             .agg(collect_fn.alias("_"))
-            .select([F.col(groupby_col)] + [F.col(f"_.{col}").alias(col) for col in event_cols_without_groupby])
+            .select([sf.col(groupby_col)] + [sf.col(f"_.{col}").alias(col) for col in event_cols_without_groupby])
             .drop("_")
         )
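Taken together, the utils.py changes are mechanical: multi-line calls are collapsed under a longer line limit, and the PySpark functions import is aliased as sf throughout. For readers unfamiliar with the aggregation this function performs, here is a minimal polars sketch of the same pattern; the toy frame and column names are illustrative assumptions, not the library's fixtures.

import polars as pl

# Toy event log: one row per (user, item) interaction with a timestamp.
events = pl.DataFrame(
    {
        "user_id": [1, 1, 1, 2, 2],
        "item_id": [10, 11, 12, 20, 21],
        "timestamp": [3, 1, 2, 2, 1],
    }
)

# Mirror of the polars branch above: move the sort column to the front,
# sort the events, then collect every non-grouping column into per-user lists.
groupby_col, sort_col = "user_id", "timestamp"
other_cols = [c for c in events.columns if c != groupby_col]
other_cols.remove(sort_col)
other_cols.insert(0, sort_col)

grouped = events.sort(other_cols).group_by(groupby_col).agg(*[pl.col(x) for x in other_cols])
print(grouped.sort(groupby_col))
# Each output row holds one user's timestamps and items, ordered by timestamp.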
replay/data/schema.py
CHANGED
@@ -45,7 +45,6 @@ class FeatureInfo:
     Information about a feature.
     """
 
-    # pylint: disable=too-many-arguments
     def __init__(
         self,
         column: str,
@@ -72,7 +71,8 @@ class FeatureInfo:
         self._feature_hint = feature_hint
 
         if feature_type == FeatureType.NUMERICAL and cardinality:
-            raise ValueError("Cardinality is needed only with categorical feature_type.")
+            msg = "Cardinality is needed only with categorical feature_type."
+            raise ValueError(msg)
         self._cardinality = cardinality
 
     @property
@@ -112,14 +112,12 @@ class FeatureInfo:
         :returns: cardinality of the feature.
         """
         if self.feature_type != FeatureType.CATEGORICAL:
-            raise RuntimeError(
-                f"Can not get cardinality because feature_type of {self.column} column is not categorical."
-            )
+            msg = f"Can not get cardinality because feature_type of {self.column} column is not categorical."
+            raise RuntimeError(msg)
         if hasattr(self, "_cardinality_callback") and self._cardinality is None:
             self._cardinality = self._cardinality_callback(self._column)
         return self._cardinality
 
-    # pylint: disable=attribute-defined-outside-init
     def _set_cardinality_callback(self, callback: Callable) -> None:
         self._cardinality_callback = callback
 
@@ -130,7 +128,6 @@ class FeatureInfo:
         self._cardinality = None
 
 
-# pylint: disable=too-many-public-methods
 class FeatureSchema(Mapping[str, FeatureInfo]):
     """
     Key-value like collection with information about all dataset features.
@@ -174,8 +171,9 @@ class FeatureSchema(Mapping[str, FeatureInfo]):
         :returns: extract a feature information from a schema.
         """
         if len(self._features_schema) > 1:
-            raise ValueError("Only one element feature schema can be converted to single feature")
-        return next(iter(self._features_schema.values()))
+            msg = "Only one element feature schema can be converted to single feature"
+            raise ValueError(msg)
+        return next(iter(self._features_schema.values()))
 
     def items(self) -> ItemsView[str, FeatureInfo]:
         return self._features_schema.items()
@@ -186,7 +184,7 @@ class FeatureSchema(Mapping[str, FeatureInfo]):
     def values(self) -> ValuesView[FeatureInfo]:
         return self._features_schema.values()
 
-    def get(
+    def get(
         self,
         key: str,
         default: Optional[FeatureInfo] = None,
@@ -358,7 +356,7 @@ class FeatureSchema(Mapping[str, FeatureInfo]):
         for filtration_func, filtration_param in zip(filter_functions, filter_parameters):
             filtered_features = list(
                 filter(
-                    lambda x: filtration_func(x, filtration_param),
+                    lambda x: filtration_func(x, filtration_param),
                     filtered_features,
                 )
             )
@@ -391,7 +389,7 @@ class FeatureSchema(Mapping[str, FeatureInfo]):
         for filtration_func, filtration_param in zip(filter_functions, filter_parameters):
             filtered_features = list(
                 filter(
-                    lambda x: filtration_func(x, filtration_param),
+                    lambda x: filtration_func(x, filtration_param),
                     filtered_features,
                )
            )
@@ -426,7 +424,6 @@ class FeatureSchema(Mapping[str, FeatureInfo]):
     def _type_drop(value: FeatureInfo, feature_type: FeatureType) -> bool:
         return value.feature_type != feature_type if feature_type else True
 
-    # pylint: disable=no-self-use
     @staticmethod
     def _hint_drop(value: FeatureInfo, feature_hint: FeatureHint) -> bool:
         return value.feature_hint != feature_hint if feature_hint else True
@@ -451,13 +448,16 @@ class FeatureSchema(Mapping[str, FeatureInfo]):
             item_query_names[feature.feature_hint] += [feature.column]
 
         if len(duplicates) > 0:
-            raise ValueError(
+            msg = (
                 "Features column names should be unique, exept ITEM_ID and QUERY_ID columns. "
-                f"{duplicates} columns are not unique."
+                f"{duplicates} columns are not unique."
             )
+            raise ValueError(msg)
 
         if len(item_query_names[FeatureHint.ITEM_ID]) > 1:
-            raise ValueError(f"ITEM_ID must be present only once. Rename {item_query_names[FeatureHint.ITEM_ID]}")
+            msg = f"ITEM_ID must be present only once. Rename {item_query_names[FeatureHint.ITEM_ID]}"
+            raise ValueError(msg)
 
         if len(item_query_names[FeatureHint.QUERY_ID]) > 1:
-            raise ValueError(f"QUERY_ID must be present only once. Rename {item_query_names[FeatureHint.QUERY_ID]}")
+            msg = f"QUERY_ID must be present only once. Rename {item_query_names[FeatureHint.QUERY_ID]}"
+            raise ValueError(msg)
replay/data/spark_schema.py
CHANGED
replay/metrics/base_metric.py
CHANGED
@@ -1,11 +1,11 @@
 import warnings
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Mapping, Union
+from typing import Any, Dict, List, Mapping, Optional, Union
 
 import numpy as np
 import polars as pl
 
-from replay.utils import PYSPARK_AVAILABLE, DataFrameLike, PandasDataFrame, SparkDataFrame, PolarsDataFrame
+from replay.utils import PYSPARK_AVAILABLE, DataFrameLike, PandasDataFrame, PolarsDataFrame, SparkDataFrame
 
 from .descriptors import CalculationDescriptor, Mean
 
@@ -27,7 +27,7 @@ class MetricDuplicatesWarning(Warning):
 class Metric(ABC):
     """Base metric class"""
 
-    def __init__(
+    def __init__(
         self,
         topk: Union[List[int], int],
         query_column: str = "query_id",
@@ -46,11 +46,13 @@ class Metric(ABC):
         if isinstance(topk, list):
             for item in topk:
                 if not isinstance(item, int):
-                    raise ValueError(f"{item} is not int")
+                    msg = f"{item} is not int"
+                    raise ValueError(msg)
         elif isinstance(topk, int):
             topk = [topk]
         else:
-            raise ValueError("topk not list or int")
+            msg = "topk not list or int"
+            raise ValueError(msg)
         self.topk = sorted(topk)
         self.query_column = query_column
         self.item_column = item_column
@@ -60,11 +62,8 @@ class Metric(ABC):
     @property
     def __name__(self) -> str:
         mode_name = self._mode.__name__
-        return str(type(self).__name__) + (
-            f"-{mode_name}" if mode_name != "Mean" else ""
-        )
+        return str(type(self).__name__) + (f"-{mode_name}" if mode_name != "Mean" else "")
 
-    # pylint: disable=no-self-use
     def _check_dataframes_equal_types(
         self,
         recommendations: MetricsDataFrameLike,
@@ -74,39 +73,31 @@ class Metric(ABC):
         Types of all data frames must be the same.
         """
         if not isinstance(recommendations, type(ground_truth)):
-            raise ValueError("All given data frames must have the same type")
+            msg = "All given data frames must have the same type"
+            raise ValueError(msg)
 
     def _duplicate_warn(self):
         warnings.warn(
-            "The recommendations contain duplicated users and items."
-            "The metrics may be higher than the actual ones.",
+            "The recommendations contain duplicated users and items.The metrics may be higher than the actual ones.",
            MetricDuplicatesWarning,
        )
 
     def _check_duplicates_spark(self, recommendations: SparkDataFrame) -> None:
         duplicates_count = (
-            recommendations.groupBy(self.query_column, self.item_column)
-            .count()
-            .filter("count >= 2")
-            .count()
+            recommendations.groupBy(self.query_column, self.item_column).count().filter("count >= 2").count()
         )
         if duplicates_count:
             self._duplicate_warn()
 
     def _check_duplicates_dict(self, recommendations: Dict) -> None:
-        for _, items in recommendations.items():
+        for items in recommendations.values():
             items_set = set(items)
             if len(items) != len(items_set):
                 self._duplicate_warn()
                 return
 
     def _check_duplicates_polars(self, recommendations: PolarsDataFrame) -> None:
-        duplicates_count = (
-            recommendations
-            .group_by(self.query_column, self.item_column)
-            .len()
-            .filter(pl.col("len") > 1)
-        )
+        duplicates_count = recommendations.group_by(self.query_column, self.item_column).len().filter(pl.col("len") > 1)
         if not duplicates_count.is_empty():
             self._duplicate_warn()
 
@@ -144,11 +135,7 @@ class Metric(ABC):
             else self._convert_dict_to_dict_with_score(recommendations)
         )
         self._check_duplicates_dict(recommendations)
-        ground_truth = (
-            self._convert_pandas_to_dict_without_score(ground_truth)
-            if is_pandas
-            else ground_truth
-        )
+        ground_truth = self._convert_pandas_to_dict_without_score(ground_truth) if is_pandas else ground_truth
         assert isinstance(ground_truth, dict)
         return self._dict_call(
             list(ground_truth),
@@ -164,7 +151,6 @@ class Metric(ABC):
             .to_dict()
         )
 
-    # pylint: disable=no-self-use
     def _convert_dict_to_dict_with_score(self, data: Dict) -> Dict:
         converted_data = {}
         for user, items in data.items():
@@ -191,31 +177,21 @@ class Metric(ABC):
         distribution_per_user = {}
         for user in users:
             args = [kwargs[key].get(user, None) for key in keys_list]
-            distribution_per_user[user] = self._get_metric_value_by_user(
-                self.topk, *args
-            )  # pylint: disable=protected-access
+            distribution_per_user[user] = self._get_metric_value_by_user(self.topk, *args)
         if self._mode.__name__ == "PerUser":
             return self._aggregate_results_per_user(distribution_per_user)
         distribution = np.stack(list(distribution_per_user.values()))
         assert distribution.shape[1] == len(self.topk)
-        metrics = []
-        for k in range(distribution.shape[1]):
-            metrics.append(self._mode.cpu(distribution[:, k]))
+        metrics = [self._mode.cpu(distribution[:, k]) for k in range(distribution.shape[1])]
         return self._aggregate_results(metrics)
 
     def _get_items_list_per_user_spark(
-        self, recommendations: SparkDataFrame, extra_column: str = None
+        self, recommendations: SparkDataFrame, extra_column: Optional[str] = None
     ) -> SparkDataFrame:
         recommendations = recommendations.groupby(self.query_column).agg(
             sf.sort_array(
                 sf.collect_list(
-                    sf.struct(
-                        *[
-                            c
-                            for c in [self.rating_column, self.item_column, extra_column]
-                            if c is not None
-                        ]
-                    )
+                    sf.struct(*[c for c in [self.rating_column, self.item_column, extra_column] if c is not None])
                 ),
                 False,
             ).alias("pred")
@@ -231,7 +207,7 @@ class Metric(ABC):
         return recommendations
 
     def _get_items_list_per_user_polars(
-        self, recommendations: PolarsDataFrame, extra_column: str = None
+        self, recommendations: PolarsDataFrame, extra_column: Optional[str] = None
     ) -> PolarsDataFrame:
         selection = [self.query_column, "pred_item_id"]
         sorting = [self.rating_column, self.item_column]
@@ -242,8 +218,7 @@ class Metric(ABC):
             selection.append(extra_column)
 
         recommendations = (
-            recommendations
-            .sort(sorting, descending=True)
+            recommendations.sort(sorting, descending=True)
             .group_by(self.query_column)
             .agg(*agg)
             .rename({self.item_column: "pred_item_id"})
@@ -253,7 +228,7 @@ class Metric(ABC):
         return recommendations
 
     def _get_items_list_per_user(
-        self, recommendations: Union[SparkDataFrame, PolarsDataFrame], extra_column: str = None
+        self, recommendations: Union[SparkDataFrame, PolarsDataFrame], extra_column: Optional[str] = None
     ) -> Union[SparkDataFrame, PolarsDataFrame]:
         if isinstance(recommendations, SparkDataFrame):
             return self._get_items_list_per_user_spark(recommendations, extra_column)
@@ -265,7 +240,7 @@ class Metric(ABC):
     ) -> Union[SparkDataFrame, PolarsDataFrame]:
         cols = data.columns
         cols.remove(self.query_column)
-        cols = [self.query_column] + sorted(cols)
+        cols = [self.query_column, *sorted(cols)]
         return data.select(*cols)
 
     def _get_enriched_recommendations(
@@ -300,8 +275,7 @@ class Metric(ABC):
         ground_truth: PolarsDataFrame,
     ) -> PolarsDataFrame:
         true_items_by_users = (
-            ground_truth
-            .group_by(self.query_column)
+            ground_truth.group_by(self.query_column)
             .agg(pl.col(self.item_column))
             .rename({self.item_column: "ground_truth"})
         )
@@ -313,9 +287,7 @@ class Metric(ABC):
         )
         return self._rearrange_columns(enriched_recommendations)
 
-    def _aggregate_results_per_user(
-        self, distribution_per_user: Dict[Any, List[float]]
-    ) -> MetricsPerUserReturnType:
+    def _aggregate_results_per_user(self, distribution_per_user: Dict[Any, List[float]]) -> MetricsPerUserReturnType:
         res: MetricsPerUserReturnType = {}
         for index, val in enumerate(self.topk):
             metric_name = f"{self.__name__}@{val}"
@@ -335,18 +307,12 @@ class Metric(ABC):
         """
         Calculating metrics for PySpark DataFrame.
         """
-        recs_with_topk_list = recs.withColumn(
-            "k", sf.array(*[sf.lit(x) for x in self.topk])
-        )
+        recs_with_topk_list = recs.withColumn("k", sf.array(*[sf.lit(x) for x in self.topk]))
         distribution = self._get_metric_distribution(recs_with_topk_list)
         if self._mode.__name__ == "PerUser":
             return self._aggregate_results_per_user(distribution.rdd.collectAsMap())
         metrics = [
-            self._mode.spark(
-                distribution.select(sf.col("value").getItem(i)).withColumnRenamed(
-                    f"value[{i}]", "val"
-                )
-            )
+            self._mode.spark(distribution.select(sf.col("value").getItem(i)).withColumnRenamed(f"value[{i}]", "val"))
             for i in range(len(self.topk))
         ]
         return self._aggregate_results(metrics)
@@ -355,27 +321,23 @@ class Metric(ABC):
         distribution = self._get_metric_distribution(recs)
         if self._mode.__name__ == "PerUser":
             return self._aggregate_results_per_user(
-                dict(
-                    distribution.select(self.query_column, value=pl.concat_list(pl.exclude(self.query_column)))
-                    .iter_rows()
-                )
+                dict(
+                    distribution.select(
+                        self.query_column, value=pl.concat_list(pl.exclude(self.query_column))
+                    ).iter_rows()
+                )
             )
-        metrics = [self._mode.cpu(distribution.select(column))
-                   for column in distribution.columns[1:]]
+        metrics = [self._mode.cpu(distribution.select(column)) for column in distribution.columns[1:]]
         return self._aggregate_results(metrics)
 
-    def _spark_call(
-        self, recommendations: SparkDataFrame, ground_truth: SparkDataFrame
-    ) -> MetricsReturnType:
+    def _spark_call(self, recommendations: SparkDataFrame, ground_truth: SparkDataFrame) -> MetricsReturnType:
         """
         Implementation for PySpark DataFrame.
         """
         recs = self._get_enriched_recommendations(recommendations, ground_truth)
         return self._spark_compute(recs)
 
-    def _polars_call(
-        self, recommendations: PolarsDataFrame, ground_truth: PolarsDataFrame
-    ) -> MetricsReturnType:
+    def _polars_call(self, recommendations: PolarsDataFrame, ground_truth: PolarsDataFrame) -> MetricsReturnType:
         """
         Implementation for Polars DataFrame.
         """
@@ -383,7 +345,7 @@ class Metric(ABC):
         return self._polars_compute(recs)
 
     def _get_metric_distribution(
-        self, recs: Union[PolarsDataFrame, SparkDataFrame]
+        self, recs: Union[PolarsDataFrame, SparkDataFrame]
     ) -> Union[PolarsDataFrame, SparkDataFrame]:
         if isinstance(recs, SparkDataFrame):
             return self._get_metric_distribution_spark(recs)
@@ -406,16 +368,13 @@ class Metric(ABC):
         distribution = recs.map_rows(lambda x: (x[0], *cur_class._get_metric_value_by_user(self.topk, *x[1:])))
         distribution = distribution.rename({"column_0": self.query_column})
         distribution = distribution.rename(
-            {distribution.columns[x + 1]: f"value_{self.topk[x]}"
-             for x in range(len(self.topk))}
+            {distribution.columns[x + 1]: f"value_{self.topk[x]}" for x in range(len(self.topk))}
        )
        return distribution
 
     @staticmethod
     @abstractmethod
-    def _get_metric_value_by_user(  #
-        ks: List[int], *args: List
-    ) -> List[float]:  # pragma: no cover
+    def _get_metric_value_by_user(ks: List[int], *args: List) -> List[float]:  # pragma: no cover
         """
         Metric calculation for one user.
 
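Besides the same collapsing and msg-variable rewrites, base_metric.py fixes implicit-Optional defaults: parameters declared as extra_column: str = None become extra_column: Optional[str] = None, which is what PEP 484 requires for a None default. A short sketch of the corrected signature together with the filtering idiom from _get_items_list_per_user_spark; the helper itself is hypothetical.

from typing import List, Optional


def select_struct_columns(rating_column: str, item_column: str, extra_column: Optional[str] = None) -> List[str]:
    """Hypothetical helper: `extra_column: str = None` is implicit Optional and is
    rejected by PEP 484 type checkers; a None default must be spelled
    Optional[str], as in the 0.17.0 signatures above."""
    # Same idiom as the sf.struct(...) argument in _get_items_list_per_user_spark.
    return [c for c in [rating_column, item_column, extra_column] if c is not None]


assert select_struct_columns("rating", "item_id") == ["rating", "item_id"]
assert select_struct_columns("rating", "item_id", "category_id")[-1] == "category_id"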
replay/metrics/categorical_diversity.py
CHANGED

@@ -4,7 +4,7 @@ from typing import Dict, List, Union
 import numpy as np
 import polars as pl
 
-from replay.utils import PYSPARK_AVAILABLE, PandasDataFrame, SparkDataFrame, PolarsDataFrame
+from replay.utils import PYSPARK_AVAILABLE, PandasDataFrame, PolarsDataFrame, SparkDataFrame
 
 from .base_metric import (
     Metric,
@@ -16,11 +16,12 @@ from .base_metric import (
 from .descriptors import CalculationDescriptor, Mean
 
 if PYSPARK_AVAILABLE:
-    from pyspark.sql import Window
-    from pyspark.sql import functions as F
+    from pyspark.sql import (
+        Window,
+        functions as sf,
+    )
 
 
-# pylint: disable=too-few-public-methods
 class CategoricalDiversity(Metric):
     """
     Metric calculation is as follows:
@@ -59,7 +60,6 @@ class CategoricalDiversity(Metric):
     <BLANKLINE>
     """
 
-    # pylint: disable=too-many-arguments
     def __init__(
         self,
         topk: Union[List, int],
@@ -108,31 +108,21 @@ class CategoricalDiversity(Metric):
         precalculated_answer = self._precalculate_unique_cats(recommendations)
         return self._dict_call(precalculated_answer)
 
-    # pylint: disable=arguments-differ
     def _get_enriched_recommendations(
-        self, recommendations: Union[PolarsDataFrame, SparkDataFrame]
+        self,
+        recommendations: Union[PolarsDataFrame, SparkDataFrame],
     ) -> Union[PolarsDataFrame, SparkDataFrame]:
         if isinstance(recommendations, SparkDataFrame):
             return self._get_enriched_recommendations_spark(recommendations)
         else:
             return self._get_enriched_recommendations_polars(recommendations)
 
-    def _get_enriched_recommendations_spark(
-        self,
-        recommendations: SparkDataFrame,
-    ) -> SparkDataFrame:
-        window = Window.partitionBy(self.query_column).orderBy(
-            F.col(self.rating_column).desc()
-        )
-        sorted_by_score_recommendations = recommendations.withColumn(
-            "rank", F.row_number().over(window)
-        )
+    def _get_enriched_recommendations_spark(self, recommendations: SparkDataFrame) -> SparkDataFrame:
+        window = Window.partitionBy(self.query_column).orderBy(sf.col(self.rating_column).desc())
+        sorted_by_score_recommendations = recommendations.withColumn("rank", sf.row_number().over(window))
         return sorted_by_score_recommendations
 
-
-    def _get_enriched_recommendations_polars(
-        self, recommendations: PolarsDataFrame
-    ) -> PolarsDataFrame:
+    def _get_enriched_recommendations_polars(self, recommendations: PolarsDataFrame) -> PolarsDataFrame:
         sorted_by_score_recommendations = recommendations.select(
             pl.all().sort_by(self.rating_column, descending=True).over(self.query_column)
         )
@@ -146,13 +136,9 @@ class CategoricalDiversity(Metric):
     def _spark_compute_per_user(self, recs: SparkDataFrame) -> MetricsPerUserReturnType:
         distribution_per_user = defaultdict(list)
         for k in self.topk:
-            filtered_recs = recs.filter(F.col("rank") <= k)
-            aggreagated_by_user = filtered_recs.groupBy(self.query_column).agg(
-                F.countDistinct(self.category_column)
-            )
-            aggreagated_by_user_dict = (
-                aggreagated_by_user.rdd.collectAsMap()
-            )  # type:ignore
+            filtered_recs = recs.filter(sf.col("rank") <= k)
+            aggreagated_by_user = filtered_recs.groupBy(self.query_column).agg(sf.countDistinct(self.category_column))
+            aggreagated_by_user_dict = aggreagated_by_user.rdd.collectAsMap()
             for user, metric in aggreagated_by_user_dict.items():
                 distribution_per_user[user].append(metric / k)
         return self._aggregate_results_per_user(dict(distribution_per_user))
@@ -161,12 +147,8 @@ class CategoricalDiversity(Metric):
         distribution_per_user = defaultdict(list)
         for k in self.topk:
             filtered_recs = recs.filter(pl.col("rank") <= k)
-            aggreagated_by_user = filtered_recs.group_by(self.query_column).agg(
-                pl.col(self.category_column).n_unique()
-            )
-            aggreagated_by_user_dict = (
-                dict(aggreagated_by_user.iter_rows())
-            )  # type:ignore
+            aggreagated_by_user = filtered_recs.group_by(self.query_column).agg(pl.col(self.category_column).n_unique())
+            aggreagated_by_user_dict = dict(aggreagated_by_user.iter_rows())
             for user, metric in aggreagated_by_user_dict.items():
                 distribution_per_user[user].append(metric / k)
         return self._aggregate_results_per_user(dict(distribution_per_user))
@@ -174,10 +156,10 @@ class CategoricalDiversity(Metric):
     def _spark_compute_agg(self, recs: SparkDataFrame) -> MetricsMeanReturnType:
         metrics = []
         for k in self.topk:
-            filtered_recs = recs.filter(F.col("rank") <= k)
+            filtered_recs = recs.filter(sf.col("rank") <= k)
             aggregated_by_user = (
                 filtered_recs.groupBy(self.query_column)
-                .agg(F.countDistinct(self.category_column))
+                .agg(sf.countDistinct(self.category_column))
                 .drop(self.query_column)
             )
             metrics.append(self._mode.spark(aggregated_by_user) / k)
@@ -195,7 +177,6 @@ class CategoricalDiversity(Metric):
             metrics.append(self._mode.cpu(aggregated_by_user) / k)
         return self._aggregate_results(metrics)
 
-    # pylint: disable=arguments-differ
     def _spark_call(self, recommendations: SparkDataFrame) -> MetricsReturnType:
         """
         Implementation for Pyspark DataFrame.
@@ -205,7 +186,6 @@ class CategoricalDiversity(Metric):
             return self._spark_compute_per_user(recs)
         return self._spark_compute_agg(recs)
 
-    # pylint: disable=arguments-differ
     def _polars_call(self, recommendations: PolarsDataFrame) -> MetricsReturnType:
         """
         Implementation for Polars DataFrame.
@@ -223,7 +203,6 @@ class CategoricalDiversity(Metric):
             .to_dict()
         )
 
-    # pylint: disable=no-self-use
     def _precalculate_unique_cats(self, recommendations: Dict) -> Dict:
         """
         Precalculate unique categories for each prefix for each user.
@@ -238,24 +217,16 @@ class CategoricalDiversity(Metric):
             answer[user] = unique_len
         return answer
 
-
-    def _dict_compute_per_user(
-        self, precalculated_answer: Dict
-    ) -> MetricsPerUserReturnType:  # type:ignore
+    def _dict_compute_per_user(self, precalculated_answer: Dict) -> MetricsPerUserReturnType:
         distribution_per_user = defaultdict(list)
         for k in self.topk:
             for user, unique_cats in precalculated_answer.items():
-                distribution_per_user[user].append(
-                    unique_cats[min(len(unique_cats), k) - 1] / k
-                )
+                distribution_per_user[user].append(unique_cats[min(len(unique_cats), k) - 1] / k)
         return self._aggregate_results_per_user(distribution_per_user)
 
-
-    def _dict_compute_mean(
-        self, precalculated_answer: Dict
-    ) -> MetricsMeanReturnType:  # type:ignore
+    def _dict_compute_mean(self, precalculated_answer: Dict) -> MetricsMeanReturnType:
         distribution_list = []
-        for _, unique_cats in precalculated_answer.items():
+        for unique_cats in precalculated_answer.values():
             metrics_per_user = []
             for k in self.topk:
                 metric = unique_cats[min(len(unique_cats), k) - 1] / k
@@ -264,12 +235,9 @@ class CategoricalDiversity(Metric):
 
         distribution = np.stack(distribution_list)
         assert distribution.shape[1] == len(self.topk)
-        metrics = []
-        for k in range(distribution.shape[1]):
-            metrics.append(self._mode.cpu(distribution[:, k]))
+        metrics = [self._mode.cpu(distribution[:, k]) for k in range(distribution.shape[1])]
         return self._aggregate_results(metrics)
 
-    # pylint: disable=arguments-differ
     def _dict_call(self, precalculated_answer: Dict) -> MetricsReturnType:
         """
         Calculating metrics in dict format.
@@ -279,7 +247,5 @@ class CategoricalDiversity(Metric):
         return self._dict_compute_mean(precalculated_answer)
 
     @staticmethod
-    def _get_metric_value_by_user(
-        ks: List[int], *args: List
-    ) -> List[float]:  # pragma: no cover
+    def _get_metric_value_by_user(ks: List[int], *args: List) -> List[float]:  # pragma: no cover
         pass