replay-rec 0.16.0-py3-none-any.whl → 0.17.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- replay/__init__.py +1 -1
- replay/data/__init__.py +1 -1
- replay/data/dataset.py +45 -42
- replay/data/dataset_utils/dataset_label_encoder.py +6 -7
- replay/data/nn/__init__.py +1 -1
- replay/data/nn/schema.py +20 -33
- replay/data/nn/sequence_tokenizer.py +217 -87
- replay/data/nn/sequential_dataset.py +6 -22
- replay/data/nn/torch_sequential_dataset.py +20 -11
- replay/data/nn/utils.py +7 -9
- replay/data/schema.py +17 -17
- replay/data/spark_schema.py +0 -1
- replay/metrics/base_metric.py +38 -79
- replay/metrics/categorical_diversity.py +24 -58
- replay/metrics/coverage.py +25 -49
- replay/metrics/descriptors.py +4 -13
- replay/metrics/experiment.py +3 -8
- replay/metrics/hitrate.py +3 -6
- replay/metrics/map.py +3 -6
- replay/metrics/mrr.py +1 -4
- replay/metrics/ndcg.py +4 -7
- replay/metrics/novelty.py +10 -29
- replay/metrics/offline_metrics.py +26 -61
- replay/metrics/precision.py +3 -6
- replay/metrics/recall.py +3 -6
- replay/metrics/rocauc.py +7 -10
- replay/metrics/surprisal.py +13 -30
- replay/metrics/torch_metrics_builder.py +0 -4
- replay/metrics/unexpectedness.py +15 -20
- replay/models/__init__.py +1 -2
- replay/models/als.py +7 -15
- replay/models/association_rules.py +12 -28
- replay/models/base_neighbour_rec.py +21 -36
- replay/models/base_rec.py +92 -215
- replay/models/cat_pop_rec.py +9 -22
- replay/models/cluster.py +17 -28
- replay/models/extensions/ann/ann_mixin.py +7 -12
- replay/models/extensions/ann/entities/base_hnsw_param.py +1 -1
- replay/models/extensions/ann/entities/hnswlib_param.py +0 -6
- replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -6
- replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +4 -10
- replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +7 -11
- replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +5 -12
- replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +11 -18
- replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +1 -4
- replay/models/extensions/ann/index_inferers/base_inferer.py +3 -10
- replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +7 -17
- replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +6 -14
- replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +14 -28
- replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +15 -25
- replay/models/extensions/ann/index_inferers/utils.py +2 -9
- replay/models/extensions/ann/index_stores/hdfs_index_store.py +4 -9
- replay/models/extensions/ann/index_stores/shared_disk_index_store.py +2 -6
- replay/models/extensions/ann/index_stores/spark_files_index_store.py +8 -14
- replay/models/extensions/ann/index_stores/utils.py +5 -2
- replay/models/extensions/ann/utils.py +3 -5
- replay/models/kl_ucb.py +16 -22
- replay/models/knn.py +37 -59
- replay/models/nn/optimizer_utils/__init__.py +1 -6
- replay/models/nn/optimizer_utils/optimizer_factory.py +3 -6
- replay/models/nn/sequential/bert4rec/__init__.py +1 -1
- replay/models/nn/sequential/bert4rec/dataset.py +6 -7
- replay/models/nn/sequential/bert4rec/lightning.py +53 -56
- replay/models/nn/sequential/bert4rec/model.py +12 -25
- replay/models/nn/sequential/callbacks/__init__.py +1 -1
- replay/models/nn/sequential/callbacks/prediction_callbacks.py +23 -25
- replay/models/nn/sequential/callbacks/validation_callback.py +27 -30
- replay/models/nn/sequential/postprocessors/postprocessors.py +1 -1
- replay/models/nn/sequential/sasrec/dataset.py +8 -7
- replay/models/nn/sequential/sasrec/lightning.py +53 -48
- replay/models/nn/sequential/sasrec/model.py +4 -17
- replay/models/pop_rec.py +9 -10
- replay/models/query_pop_rec.py +7 -15
- replay/models/random_rec.py +10 -18
- replay/models/slim.py +8 -13
- replay/models/thompson_sampling.py +13 -14
- replay/models/ucb.py +11 -22
- replay/models/wilson.py +5 -14
- replay/models/word2vec.py +24 -69
- replay/optimization/optuna_objective.py +13 -27
- replay/preprocessing/__init__.py +1 -2
- replay/preprocessing/converter.py +2 -7
- replay/preprocessing/filters.py +67 -142
- replay/preprocessing/history_based_fp.py +44 -116
- replay/preprocessing/label_encoder.py +106 -68
- replay/preprocessing/sessionizer.py +1 -11
- replay/scenarios/fallback.py +3 -8
- replay/splitters/base_splitter.py +43 -15
- replay/splitters/cold_user_random_splitter.py +18 -31
- replay/splitters/k_folds.py +14 -24
- replay/splitters/last_n_splitter.py +33 -43
- replay/splitters/new_users_splitter.py +31 -55
- replay/splitters/random_splitter.py +16 -23
- replay/splitters/ratio_splitter.py +30 -54
- replay/splitters/time_splitter.py +13 -18
- replay/splitters/two_stage_splitter.py +44 -79
- replay/utils/__init__.py +1 -1
- replay/utils/common.py +65 -0
- replay/utils/dataframe_bucketizer.py +25 -31
- replay/utils/distributions.py +3 -15
- replay/utils/model_handler.py +36 -33
- replay/utils/session_handler.py +11 -15
- replay/utils/spark_utils.py +51 -85
- replay/utils/time.py +8 -22
- replay/utils/types.py +1 -3
- {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/METADATA +2 -2
- replay_rec-0.17.0.dist-info/RECORD +127 -0
- replay_rec-0.16.0.dist-info/RECORD +0 -126
- {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/LICENSE +0 -0
- {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/WHEEL +0 -0
replay/models/association_rules.py

@@ -3,16 +3,16 @@ from typing import Any, Dict, Iterable, List, Optional, Union
 import numpy as np

 from replay.data import Dataset
+from replay.utils import PYSPARK_AVAILABLE, SparkDataFrame
+
 from .base_neighbour_rec import NeighbourRec
 from .extensions.ann.index_builders.base_index_builder import IndexBuilder
-from replay.utils import PYSPARK_AVAILABLE, SparkDataFrame

 if PYSPARK_AVAILABLE:
     import pyspark.sql.functions as sf
     from pyspark.sql.window import Window


-# pylint: disable=too-many-ancestors, too-many-instance-attributes
 class AssociationRulesItemRec(NeighbourRec):
     """
     Item-to-item recommender based on association rules.
@@ -117,7 +117,6 @@ class AssociationRulesItemRec(NeighbourRec):
         },
     }

-    # pylint: disable=too-many-arguments,
     def __init__(
         self,
         session_column: str,
@@ -204,14 +203,11 @@
             frequent_items_interactions.withColumnRenamed(self.item_column, "antecedent")
             .withColumnRenamed(self.rating_column, "antecedent_rel")
             .join(
-                frequent_items_interactions.withColumnRenamed(
-                    self.session_column, self.session_column + "_cons"
-                )
+                frequent_items_interactions.withColumnRenamed(self.session_column, self.session_column + "_cons")
                 .withColumnRenamed(self.item_column, "consequent")
                 .withColumnRenamed(self.rating_column, "consequent_rel"),
                 on=[
-                    sf.col(self.session_column)
-                    == sf.col(self.session_column + "_cons"),
+                    sf.col(self.session_column) == sf.col(self.session_column + "_cons"),
                     sf.col("antecedent") < sf.col("consequent"),
                 ],
             )
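The hunk above reformats a session self-join: items that share a session are paired, and the strict `antecedent < consequent` inequality keeps each unordered pair exactly once. A minimal, self-contained sketch of the same pattern (toy data and column names, not the library's API):

```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as sf

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([("s1", 1), ("s1", 2), ("s1", 3)], ["session", "item"])

# Self-join on the session key; the strict inequality emits each unordered
# item pair once: (1, 2), (1, 3), (2, 3) but never (2, 1) or (1, 1).
pairs = df.withColumnRenamed("item", "antecedent").join(
    df.withColumnRenamed("session", "session_cons").withColumnRenamed("item", "consequent"),
    on=[
        sf.col("session") == sf.col("session_cons"),
        sf.col("antecedent") < sf.col("consequent"),
    ],
)
```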
@@ -220,9 +216,7 @@
                 self.rating_column,
                 sf.least(sf.col("consequent_rel"), sf.col("antecedent_rel")),
             )
-            .drop(
-                self.session_column + "_cons", "consequent_rel", "antecedent_rel"
-            )
+            .drop(self.session_column + "_cons", "consequent_rel", "antecedent_rel")
         )

         pairs_count = (
@@ -243,16 +237,12 @@
         )

         pairs_metrics = pairs_metrics.join(
-            frequent_items_cached.withColumnRenamed(
-                "item_rating", "antecedent_rating"
-            ),
+            frequent_items_cached.withColumnRenamed("item_rating", "antecedent_rating"),
             on=[sf.col("antecedent") == sf.col(self.item_column)],
         ).drop(self.item_column)

         pairs_metrics = pairs_metrics.join(
-            frequent_items_cached.withColumnRenamed(
-                "item_rating", "consequent_rating"
-            ),
+            frequent_items_cached.withColumnRenamed("item_rating", "consequent_rating"),
             on=[sf.col("consequent") == sf.col(self.item_column)],
         ).drop(self.item_column)

@@ -261,9 +251,7 @@
             sf.col("pair_rating") / sf.col("antecedent_rating"),
         ).withColumn(
             "lift",
-            num_sessions
-            * sf.col("confidence")
-            / sf.col("consequent_rating"),
+            num_sessions * sf.col("confidence") / sf.col("consequent_rating"),
         )

         if self.num_neighbours is not None:
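The `confidence` and `lift` columns computed above follow the standard association-rule definitions: confidence(A → B) = sessions(A, B) / sessions(A), and lift(A → B) = N · confidence(A → B) / sessions(B), where N is the total number of sessions. The same arithmetic in plain Python, with made-up counts:

```python
# Hypothetical counts for one antecedent -> consequent pair.
num_sessions = 1000      # total sessions (N)
antecedent_rating = 120  # sessions containing the antecedent
consequent_rating = 80   # sessions containing the consequent
pair_rating = 60         # sessions containing both items

confidence = pair_rating / antecedent_rating          # 0.5
lift = num_sessions * confidence / consequent_rating  # 6.25
# lift > 1: the pair co-occurs more often than independent items would.
```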
@@ -331,10 +319,8 @@
             spark-dataframe with columns ``[item_id, neighbour_item_id, similarity]``
         """
         if metric not in self.item_to_item_metrics:
-            raise ValueError(
-                "Select one of the valid distance metrics: "
-                f"{self.item_to_item_metrics}"
-            )
+            msg = f"Select one of the valid distance metrics: {self.item_to_item_metrics}"
+            raise ValueError(msg)

         return self._get_nearest_items_wrap(
             items=items,
@@ -346,7 +332,7 @@
     def _get_nearest_items(
         self,
         items: SparkDataFrame,
-        metric: Optional[str] = None,
+        metric: Optional[str] = None,  # noqa: ARG002
         candidates: Optional[SparkDataFrame] = None,
     ) -> SparkDataFrame:
         """
@@ -361,9 +347,7 @@
         pairs_to_consider = self.similarity
         if candidates is not None:
             pairs_to_consider = self.similarity.join(
-                sf.broadcast(
-                    candidates.withColumnRenamed(self.item_column, "item_idx_two")
-                ),
+                sf.broadcast(candidates.withColumnRenamed(self.item_column, "item_idx_two")),
                 on="item_idx_two",
             )

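One substantive detail preserved through the reformatting above is the `sf.broadcast(...)` hint around the candidates frame: `pyspark.sql.functions.broadcast` marks a small DataFrame to be shipped whole to every executor, so the join avoids a shuffle. A short sketch under the assumption that `candidates` is small (toy data; column names taken from the diff):

```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as sf

spark = SparkSession.builder.master("local[1]").getOrCreate()
similarity = spark.createDataFrame(
    [(1, 2, 0.9), (1, 3, 0.4)], ["item_idx_one", "item_idx_two", "similarity"]
)
candidates = spark.createDataFrame([(2,)], ["item_idx"])

# Broadcasting the small candidate list keeps the join map-side only.
pairs_to_consider = similarity.join(
    sf.broadcast(candidates.withColumnRenamed("item_idx", "item_idx_two")),
    on="item_idx_two",
)
```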
replay/models/base_neighbour_rec.py

@@ -1,4 +1,3 @@
-# pylint: disable=too-many-lines
 """
 NeighbourRec - base class that requires interactions at prediction time.
 Part of set of abstract classes (from base_rec.py)
@@ -8,9 +7,10 @@ from abc import ABC
 from typing import Any, Dict, Iterable, Optional, Union

 from replay.data.dataset import Dataset
+from replay.utils import PYSPARK_AVAILABLE, MissingImportType, SparkDataFrame
+
 from .base_rec import Recommender
 from .extensions.ann.ann_mixin import ANNMixin
-from replay.utils import PYSPARK_AVAILABLE, MissingImportType, SparkDataFrame

 if PYSPARK_AVAILABLE:
     from pyspark.sql import functions as sf
@@ -37,7 +37,6 @@ class NeighbourRec(Recommender, ANNMixin, ABC):
         if hasattr(self, "similarity"):
             self.similarity.unpersist()

-    # pylint: disable=missing-function-docstring
     @property
     def similarity_metric(self):
         return self._similarity_metric
@@ -45,14 +44,11 @@
     @similarity_metric.setter
     def similarity_metric(self, value):
         if not self.can_change_metric:
-            raise ValueError(
-                "This class does not support changing similarity metrics"
-            )
+            msg = "This class does not support changing similarity metrics"
+            raise ValueError(msg)
         if value not in self.item_to_item_metrics:
-            raise ValueError(
-                "Select one of the valid metrics for predict: "
-                f"{self.item_to_item_metrics}"
-            )
+            msg = f"Select one of the valid metrics for predict: {self.item_to_item_metrics}"
+            raise ValueError(msg)
         self._similarity_metric = value

     def _predict_pairs_inner(
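The error-handling hunks in this release all apply one mechanical refactor: the message is bound to a local `msg` before `raise ValueError(msg)`, rather than built inline in the `raise`. This matches ruff's flake8-errmsg rules (EM101 for string literals, EM102 for f-strings in exception constructors), which keep message construction off the raise line so tracebacks stay readable. A minimal before/after sketch with invented names:

```python
VALID_METRICS = ("lift", "confidence", "confidence_gain")

def check_metric(value: str) -> None:
    # Before: raise ValueError(f"Select one of the valid metrics: {VALID_METRICS}")
    # After: bind the message first, then raise it.
    if value not in VALID_METRICS:
        msg = f"Select one of the valid metrics for predict: {VALID_METRICS}"
        raise ValueError(msg)
```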
@@ -76,9 +72,8 @@
         :return: SparkDataFrame ``[user_id, item_id, rating]``
         """
         if dataset is None:
-            raise ValueError(
-                "interactions is not provided, but it is required for prediction"
-            )
+            msg = "interactions is not provided, but it is required for prediction"
+            raise ValueError(msg)

         recs = (
             dataset.interactions.join(queries, how="inner", on=self.query_column)
@@ -98,16 +93,14 @@
         )
         return recs

-    # pylint: disable=too-many-arguments
     def _predict(
         self,
         dataset: Dataset,
-        k: int,
+        k: int,  # noqa: ARG002
         queries: SparkDataFrame,
         items: SparkDataFrame,
-        filter_seen_items: bool = True,
+        filter_seen_items: bool = True,  # noqa: ARG002
     ) -> SparkDataFrame:
-
         return self._predict_pairs_inner(
             dataset=dataset,
             filter_df=items.withColumnRenamed(self.item_column, "item_idx_filter"),
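The `# noqa: ARG002` comments introduced above suppress ruff's flake8-unused-arguments check for unused method arguments: `k` and `filter_seen_items` must stay in `_predict`'s signature to match the base-class interface even though this implementation ignores them. A hedged sketch of the situation (class names invented for illustration):

```python
from abc import ABC, abstractmethod

class BaseRec(ABC):
    @abstractmethod
    def _predict(self, k: int, filter_seen_items: bool = True): ...

class NeighbourLikeRec(BaseRec):
    # The signature must mirror the abstract method, so the unused
    # arguments are kept and the lint warning is silenced per argument.
    def _predict(self, k: int, filter_seen_items: bool = True):  # noqa: ARG002
        return "scores for all items; top-k selection happens in the caller"
```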
@@ -120,13 +113,12 @@
         pairs: SparkDataFrame,
         dataset: Optional[Dataset] = None,
     ) -> SparkDataFrame:
-
         return self._predict_pairs_inner(
             dataset=dataset,
             filter_df=(
-                pairs.withColumnRenamed(
-                    self.query_column, "user_idx_filter"
-                ).withColumnRenamed(self.item_column, "item_idx_filter")
+                pairs.withColumnRenamed(self.query_column, "user_idx_filter").withColumnRenamed(
+                    self.item_column, "item_idx_filter"
+                )
             ),
             condition=(sf.col(self.query_column) == sf.col("user_idx_filter"))
             & (sf.col("item_idx_two") == sf.col("item_idx_filter")),
@@ -157,10 +149,8 @@

         if metric is not None:
             if metric not in self.item_to_item_metrics:
-                raise ValueError(
-                    "Select one of the valid distance metrics: "
-                    f"{self.item_to_item_metrics}"
-                )
+                msg = f"Select one of the valid distance metrics: {self.item_to_item_metrics}"
+                raise ValueError(msg)

             self.logger.debug(
                 "Metric is not used to determine nearest items in %s model",
@@ -180,7 +170,6 @@
         metric: Optional[str] = None,
         candidates: Optional[SparkDataFrame] = None,
     ) -> SparkDataFrame:
-
         similarity_filtered = self.similarity.join(
             items.withColumnRenamed(self.item_column, "item_idx_one"),
             on="item_idx_one",
@@ -204,20 +193,16 @@
             "features_col": None,
         }

-    def _get_vectors_to_build_ann(self, interactions: SparkDataFrame) -> SparkDataFrame:
-        similarity_df = self.similarity.select(
-            "similarity", "item_idx_one", "item_idx_two"
-        )
+    def _get_vectors_to_build_ann(self, interactions: SparkDataFrame) -> SparkDataFrame:  # noqa: ARG002
+        similarity_df = self.similarity.select("similarity", "item_idx_one", "item_idx_two")
         return similarity_df

     def _get_vectors_to_infer_ann_inner(
-        self, interactions: SparkDataFrame, queries: SparkDataFrame
+        self, interactions: SparkDataFrame, queries: SparkDataFrame  # noqa: ARG002
     ) -> SparkDataFrame:
-
-        user_vectors = (
-            interactions.groupBy(self.query_column).agg(
-                sf.collect_list(self.item_column).alias("vector_items"),
-                sf.collect_list(self.rating_column).alias("vector_ratings"))
+        user_vectors = interactions.groupBy(self.query_column).agg(
+            sf.collect_list(self.item_column).alias("vector_items"),
+            sf.collect_list(self.rating_column).alias("vector_ratings"),
         )
         return user_vectors

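For context on the final hunk: `sf.collect_list` gathers a column's values into one array per group, so `_get_vectors_to_infer_ann_inner` produces one row per query holding parallel arrays of its items and ratings. A standalone sketch (toy data; column names follow the diff):

```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as sf

spark = SparkSession.builder.master("local[1]").getOrCreate()
interactions = spark.createDataFrame(
    [(1, 10, 5.0), (1, 11, 3.0), (2, 10, 4.0)],
    ["user_idx", "item_idx", "rating"],
)

# One row per user, with items and ratings collected into parallel arrays.
user_vectors = interactions.groupBy("user_idx").agg(
    sf.collect_list("item_idx").alias("vector_items"),
    sf.collect_list("rating").alias("vector_ratings"),
)
# Resulting schema: user_idx, vector_items: array<long>, vector_ratings: array<double>
```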