replay-rec 0.16.0__py3-none-any.whl → 0.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- replay/__init__.py +1 -1
- replay/data/__init__.py +1 -1
- replay/data/dataset.py +45 -42
- replay/data/dataset_utils/dataset_label_encoder.py +6 -7
- replay/data/nn/__init__.py +1 -1
- replay/data/nn/schema.py +20 -33
- replay/data/nn/sequence_tokenizer.py +217 -87
- replay/data/nn/sequential_dataset.py +6 -22
- replay/data/nn/torch_sequential_dataset.py +20 -11
- replay/data/nn/utils.py +7 -9
- replay/data/schema.py +17 -17
- replay/data/spark_schema.py +0 -1
- replay/metrics/base_metric.py +38 -79
- replay/metrics/categorical_diversity.py +24 -58
- replay/metrics/coverage.py +25 -49
- replay/metrics/descriptors.py +4 -13
- replay/metrics/experiment.py +3 -8
- replay/metrics/hitrate.py +3 -6
- replay/metrics/map.py +3 -6
- replay/metrics/mrr.py +1 -4
- replay/metrics/ndcg.py +4 -7
- replay/metrics/novelty.py +10 -29
- replay/metrics/offline_metrics.py +26 -61
- replay/metrics/precision.py +3 -6
- replay/metrics/recall.py +3 -6
- replay/metrics/rocauc.py +7 -10
- replay/metrics/surprisal.py +13 -30
- replay/metrics/torch_metrics_builder.py +0 -4
- replay/metrics/unexpectedness.py +15 -20
- replay/models/__init__.py +1 -2
- replay/models/als.py +7 -15
- replay/models/association_rules.py +12 -28
- replay/models/base_neighbour_rec.py +21 -36
- replay/models/base_rec.py +92 -215
- replay/models/cat_pop_rec.py +9 -22
- replay/models/cluster.py +17 -28
- replay/models/extensions/ann/ann_mixin.py +7 -12
- replay/models/extensions/ann/entities/base_hnsw_param.py +1 -1
- replay/models/extensions/ann/entities/hnswlib_param.py +0 -6
- replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -6
- replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +4 -10
- replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +7 -11
- replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +5 -12
- replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +11 -18
- replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +1 -4
- replay/models/extensions/ann/index_inferers/base_inferer.py +3 -10
- replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +7 -17
- replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +6 -14
- replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +14 -28
- replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +15 -25
- replay/models/extensions/ann/index_inferers/utils.py +2 -9
- replay/models/extensions/ann/index_stores/hdfs_index_store.py +4 -9
- replay/models/extensions/ann/index_stores/shared_disk_index_store.py +2 -6
- replay/models/extensions/ann/index_stores/spark_files_index_store.py +8 -14
- replay/models/extensions/ann/index_stores/utils.py +5 -2
- replay/models/extensions/ann/utils.py +3 -5
- replay/models/kl_ucb.py +16 -22
- replay/models/knn.py +37 -59
- replay/models/nn/optimizer_utils/__init__.py +1 -6
- replay/models/nn/optimizer_utils/optimizer_factory.py +3 -6
- replay/models/nn/sequential/bert4rec/__init__.py +1 -1
- replay/models/nn/sequential/bert4rec/dataset.py +6 -7
- replay/models/nn/sequential/bert4rec/lightning.py +53 -56
- replay/models/nn/sequential/bert4rec/model.py +12 -25
- replay/models/nn/sequential/callbacks/__init__.py +1 -1
- replay/models/nn/sequential/callbacks/prediction_callbacks.py +23 -25
- replay/models/nn/sequential/callbacks/validation_callback.py +27 -30
- replay/models/nn/sequential/postprocessors/postprocessors.py +1 -1
- replay/models/nn/sequential/sasrec/dataset.py +8 -7
- replay/models/nn/sequential/sasrec/lightning.py +53 -48
- replay/models/nn/sequential/sasrec/model.py +4 -17
- replay/models/pop_rec.py +9 -10
- replay/models/query_pop_rec.py +7 -15
- replay/models/random_rec.py +10 -18
- replay/models/slim.py +8 -13
- replay/models/thompson_sampling.py +13 -14
- replay/models/ucb.py +11 -22
- replay/models/wilson.py +5 -14
- replay/models/word2vec.py +24 -69
- replay/optimization/optuna_objective.py +13 -27
- replay/preprocessing/__init__.py +1 -2
- replay/preprocessing/converter.py +2 -7
- replay/preprocessing/filters.py +67 -142
- replay/preprocessing/history_based_fp.py +44 -116
- replay/preprocessing/label_encoder.py +106 -68
- replay/preprocessing/sessionizer.py +1 -11
- replay/scenarios/fallback.py +3 -8
- replay/splitters/base_splitter.py +43 -15
- replay/splitters/cold_user_random_splitter.py +18 -31
- replay/splitters/k_folds.py +14 -24
- replay/splitters/last_n_splitter.py +33 -43
- replay/splitters/new_users_splitter.py +31 -55
- replay/splitters/random_splitter.py +16 -23
- replay/splitters/ratio_splitter.py +30 -54
- replay/splitters/time_splitter.py +13 -18
- replay/splitters/two_stage_splitter.py +44 -79
- replay/utils/__init__.py +1 -1
- replay/utils/common.py +65 -0
- replay/utils/dataframe_bucketizer.py +25 -31
- replay/utils/distributions.py +3 -15
- replay/utils/model_handler.py +36 -33
- replay/utils/session_handler.py +11 -15
- replay/utils/spark_utils.py +51 -85
- replay/utils/time.py +8 -22
- replay/utils/types.py +1 -3
- {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/METADATA +2 -2
- replay_rec-0.17.0.dist-info/RECORD +127 -0
- replay_rec-0.16.0.dist-info/RECORD +0 -126
- {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/LICENSE +0 -0
- {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/WHEEL +0 -0
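
Most of the churn in the file-level diffs reproduced below is stylistic: `# pylint: disable=...` pragmas are dropped, intentionally unused arguments gain ruff-style `# noqa: ARG002` markers, manually wrapped calls are collapsed onto single longer lines, and package-local imports are regrouped after third-party imports. A minimal sketch of that recurring transformation, using hypothetical names rather than code taken from the package:

```python
# Illustration only -- hypothetical names, not replay-rec code.
from typing import Optional


class Inferer:
    def __init__(self, params: dict, store: Optional[str] = None) -> None:
        self.params = params
        self.store = store


# 0.16.0 flavour: pylint pragma plus a manually wrapped call that fits on one line.
# pylint: disable=unused-argument
def produce_inferer_old(params: dict, store: Optional[str] = None, filter_seen: bool = True) -> Inferer:
    return Inferer(
        params, store
    )


# 0.17.0 flavour: pragma removed, call collapsed, unused argument marked for ruff instead.
def produce_inferer_new(
    params: dict,
    store: Optional[str] = None,
    filter_seen: bool = True,  # noqa: ARG002
) -> Inferer:
    return Inferer(params, store)


if __name__ == "__main__":
    produce_inferer_old({"ef_s": 100})
    produce_inferer_new({"ef_s": 100})
```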
replay/models/cat_pop_rec.py
CHANGED
@@ -2,9 +2,10 @@ from os.path import join
 from typing import Iterable, Optional, Union
 
 from replay.data import Dataset
-from .base_rec import IsSavable, RecommenderCommons
 from replay.utils import PYSPARK_AVAILABLE, SparkDataFrame
 
+from .base_rec import IsSavable, RecommenderCommons
+
 if PYSPARK_AVAILABLE:
     from pyspark.sql import functions as sf
 
@@ -18,7 +19,6 @@ if PYSPARK_AVAILABLE:
 )
 
 
-# pylint: disable=too-many-instance-attributes
 class CatPopRec(IsSavable, RecommenderCommons):
     """
     CatPopRec generate recommendation for item categories.
@@ -35,9 +35,7 @@ class CatPopRec(IsSavable, RecommenderCommons):
     can_predict_cold_items: bool = False
     fit_items: SparkDataFrame
 
-    def _generate_mapping(
-        self, cat_tree: SparkDataFrame, max_iter: int = 20
-    ) -> SparkDataFrame:
+    def _generate_mapping(self, cat_tree: SparkDataFrame, max_iter: int = 20) -> SparkDataFrame:
         """
         Create SparkDataFrame with mapping [`category`, `leaf_cat`]
         where `leaf_cat` is the lowest level categories of category tree,
@@ -49,9 +47,7 @@ class CatPopRec(IsSavable, RecommenderCommons):
         :param max_iter: maximal number of iteration of descend through the category tree
         :return: SparkDataFrame with mapping [`category`, `leaf_cat`]
         """
-        current_res = cat_tree.select(
-            sf.col("category"), sf.col("category").alias("leaf_cat")
-        )
+        current_res = cat_tree.select(sf.col("category"), sf.col("category").alias("leaf_cat"))
 
         i = 0
         res_size_growth = current_res.count()
@@ -108,9 +104,7 @@ class CatPopRec(IsSavable, RecommenderCommons):
         """
         self.max_iter = max_iter
         if cat_tree is not None:
-            self.leaf_cat_mapping = self._generate_mapping(
-                cat_tree, max_iter=max_iter
-            )
+            self.leaf_cat_mapping = self._generate_mapping(cat_tree, max_iter=max_iter)
 
     @property
     def _init_args(self):
@@ -165,7 +159,6 @@ class CatPopRec(IsSavable, RecommenderCommons):
         if hasattr(self, "leaf_cat_mapping"):
             self.leaf_cat_mapping.unpersist()
 
-    # pylint: disable=arguments-differ
     def predict(
         self,
         categories: Union[SparkDataFrame, Iterable],
@@ -219,9 +212,7 @@ class CatPopRec(IsSavable, RecommenderCommons):
         item_data = items or self.fit_items
         items = get_unique_entities(item_data, self.item_column)
 
-        num_new, items = filter_cold(
-            items, self.fit_items, col_name=self.item_column
-        )
+        num_new, items = filter_cold(items, self.fit_items, col_name=self.item_column)
         if num_new > 0:
             self.logger.info(
                 "%s model can't predict cold items, they will be ignored",
@@ -267,9 +258,7 @@ class CatPopRec(IsSavable, RecommenderCommons):
         # find number of interactions in all leaf categories after filtering
        num_interactions_in_cat = (
            res.join(
-                unique_leaf_cat_items.groupBy("leaf_cat").agg(
-                    sf.sum(self.rating_column).alias("sum_rating")
-                ),
+                unique_leaf_cat_items.groupBy("leaf_cat").agg(sf.sum(self.rating_column).alias("sum_rating")),
                on="leaf_cat",
            )
            .groupBy("category")
@@ -284,9 +273,7 @@ class CatPopRec(IsSavable, RecommenderCommons):
            .groupBy("category", self.item_column)
            .agg(sf.sum(self.rating_column).alias(self.rating_column))
            .join(num_interactions_in_cat, on="category")
-            .withColumn(
-                self.rating_column, sf.col(self.rating_column) / sf.col("sum_rating")
-            )
+            .withColumn(self.rating_column, sf.col(self.rating_column) / sf.col("sum_rating"))
        )
 
     def _save_model(self, path: str):
@@ -296,7 +283,7 @@ class CatPopRec(IsSavable, RecommenderCommons):
                "item_column": self.item_column,
                "rating_column": self.rating_column,
            },
-            join(path, "params.dump")
+            join(path, "params.dump"),
        )
 
     def _load_model(self, path: str):
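
The `_generate_mapping` docstring above describes descending a category tree until only leaf categories remain. A rough, self-contained sketch of that idea follows; the tree schema (`category`/`parent_cat`), the stopping rule, and the sample data are assumptions for illustration and may differ from the actual implementation:

```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as sf

spark = SparkSession.builder.master("local[1]").getOrCreate()

# Hypothetical category tree: each row maps a category to its parent (None for the root).
cat_tree = spark.createDataFrame(
    [("books", None), ("fiction", "books"), ("sci_fi", "fiction"), ("fantasy", "fiction")],
    schema="category string, parent_cat string",
)

# Start with every category mapped to itself, as _generate_mapping does.
mapping = cat_tree.select(sf.col("category"), sf.col("category").alias("leaf_cat"))

max_iter = 20
prev_size = mapping.count()
for _ in range(max_iter):
    # Descend one level: where the current leaf_cat has children, fan out to them.
    children = cat_tree.select(sf.col("parent_cat").alias("leaf_cat"), sf.col("category").alias("child"))
    mapping = (
        mapping.join(children, on="leaf_cat", how="left")
        .withColumn("leaf_cat", sf.coalesce("child", "leaf_cat"))
        .drop("child")
        .distinct()
    )
    size = mapping.count()
    if size == prev_size:  # no growth -> every leaf_cat is a leaf of the tree
        break
    prev_size = size

mapping.orderBy("category", "leaf_cat").show()
```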
replay/models/cluster.py
CHANGED
@@ -2,9 +2,10 @@ from os.path import join
 from typing import Optional
 
 from replay.data.dataset import Dataset
-from .base_rec import QueryRecommender
 from replay.utils import PYSPARK_AVAILABLE, SparkDataFrame
 
+from .base_rec import QueryRecommender
+
 if PYSPARK_AVAILABLE:
     from pyspark.ml.clustering import KMeans, KMeansModel
     from pyspark.ml.feature import VectorAssembler
@@ -58,12 +59,10 @@ class ClusterRec(QueryRecommender):
             sf.count(self.item_column).alias("item_count")
         )
 
-        max_count_per_cluster = self.item_rel_in_cluster.groupby(
-            "cluster"
-        ).agg(sf.max("item_count").alias("max_count_in_cluster"))
-        self.item_rel_in_cluster = self.item_rel_in_cluster.join(
-            max_count_per_cluster, on="cluster"
+        max_count_per_cluster = self.item_rel_in_cluster.groupby("cluster").agg(
+            sf.max("item_count").alias("max_count_in_cluster")
         )
+        self.item_rel_in_cluster = self.item_rel_in_cluster.join(max_count_per_cluster, on="cluster")
         self.item_rel_in_cluster = self.item_rel_in_cluster.withColumn(
             self.rating_column, sf.col("item_count") / sf.col("max_count_in_cluster")
         ).drop("item_count", "max_count_in_cluster")
@@ -83,47 +82,38 @@ class ClusterRec(QueryRecommender):
         return vec.transform(query_features).select(self.query_column, "features")
 
     def _make_query_clusters(self, queries, query_features):
-
         query_cnt_in_fv = (
-            query_features
-            .select(self.query_column)
-            .distinct()
-            .join(queries.distinct(), on=self.query_column)
-            .count()
+            query_features.select(self.query_column).distinct().join(queries.distinct(), on=self.query_column).count()
         )
 
         query_cnt = queries.distinct().count()
 
         if query_cnt_in_fv < query_cnt:
-            self.logger.info(
-
-
-
+            self.logger.info(
+                "%s query(s) don't have a feature vector. The results will not be calculated for them.",
+                query_cnt - query_cnt_in_fv,
+            )
 
-        query_features_vector = self._transform_features(
-            query_features.join(queries, on=self.query_column)
-        )
+        query_features_vector = self._transform_features(query_features.join(queries, on=self.query_column))
         return (
             self.model.transform(query_features_vector)
             .select(self.query_column, "prediction")
             .withColumnRenamed("prediction", "cluster")
         )
 
-    # pylint: disable=too-many-arguments
     def _predict(
         self,
         dataset: Dataset,
-        k: int,
+        k: int,  # noqa: ARG002
         queries: SparkDataFrame,
         items: SparkDataFrame,
-        filter_seen_items: bool = True,
+        filter_seen_items: bool = True,  # noqa: ARG002
     ) -> SparkDataFrame:
         query_clusters = self._make_query_clusters(queries, dataset.query_features)
         filtered_items = self.item_rel_in_cluster.join(items, on=self.item_column)
         pred = query_clusters.join(filtered_items, on="cluster").drop("cluster")
         return pred
 
-    # pylint: disable=signature-differs
     def _predict_pairs(
         self,
         pairs: SparkDataFrame,
@@ -131,9 +121,8 @@ class ClusterRec(QueryRecommender):
     ) -> SparkDataFrame:
         query_clusters = self._make_query_clusters(pairs.select(self.query_column).distinct(), dataset.query_features)
         pairs_with_clusters = pairs.join(query_clusters, on=self.query_column)
-        filtered_items = (self.
-
-
-
-            .select(self.query_column,self.item_column,self.rating_column))
+        filtered_items = self.item_rel_in_cluster.join(pairs.select(self.item_column).distinct(), on=self.item_column)
+        pred = pairs_with_clusters.join(filtered_items, on=["cluster", self.item_column]).select(
+            self.query_column, self.item_column, self.rating_column
+        )
         return pred
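
For reference, a minimal, self-contained sketch of the clustering machinery `ClusterRec` builds on (the `VectorAssembler` and `KMeans` imported under `PYSPARK_AVAILABLE` above): assemble feature columns into a vector, fit KMeans, and rename the `prediction` column to `cluster`, mirroring what `_transform_features` and `_make_query_clusters` do. Column names and data here are made up:

```python
from pyspark.ml.clustering import KMeans
from pyspark.ml.feature import VectorAssembler
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()

query_features = spark.createDataFrame(
    [(1, 0.1, 1.0), (2, 0.2, 0.9), (3, 5.0, -1.0)],
    ["user_id", "feat_1", "feat_2"],
)

# Assemble the raw feature columns into a single "features" vector column.
vec = VectorAssembler(inputCols=["feat_1", "feat_2"], outputCol="features")
features = vec.transform(query_features).select("user_id", "features")

# Fit KMeans and expose the cluster assignment under the name "cluster".
model = KMeans(k=2, featuresCol="features", seed=42).fit(features)
clusters = (
    model.transform(features)
    .select("user_id", "prediction")
    .withColumnRenamed("prediction", "cluster")
)
clusters.show()
```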
replay/models/extensions/ann/ann_mixin.py
CHANGED
@@ -5,15 +5,17 @@ from typing import Any, Dict, Iterable, Optional, Union
 
 from replay.data import Dataset
 from replay.models.base_rec import BaseRecommender
-from .index_builders.base_index_builder import IndexBuilder
 from replay.utils import PYSPARK_AVAILABLE, SparkDataFrame
 
+from .index_builders.base_index_builder import IndexBuilder
+
 if PYSPARK_AVAILABLE:
     from pyspark.sql import functions as sf
 
-    from .index_stores.spark_files_index_store import SparkFilesIndexStore
     from replay.utils.spark_utils import get_top_k_recs, return_recs
 
+    from .index_stores.spark_files_index_store import SparkFilesIndexStore
+
 
 logger = logging.getLogger("replay")
 
@@ -82,9 +84,7 @@ class ANNMixin(BaseRecommender):
         self.index_builder.build_index(vectors, **ann_params)
 
     @abstractmethod
-    def _get_vectors_to_infer_ann_inner(
-        self, interactions: SparkDataFrame, queries: SparkDataFrame
-    ) -> SparkDataFrame:
+    def _get_vectors_to_infer_ann_inner(self, interactions: SparkDataFrame, queries: SparkDataFrame) -> SparkDataFrame:
         """Implementations of this method must return a dataframe with user vectors.
         User vectors from this method are used to infer the index.
 
@@ -134,7 +134,6 @@ class ANNMixin(BaseRecommender):
 
         """
 
-    # pylint: disable=too-many-arguments, too-many-locals
     def _predict_wrap(
         self,
         dataset: Optional[Dataset],
@@ -144,14 +143,10 @@ class ANNMixin(BaseRecommender):
         filter_seen_items: bool = True,
         recs_file_path: Optional[str] = None,
     ) -> Optional[SparkDataFrame]:
-        dataset, queries, items = self._filter_interactions_queries_items_dataframes(
-            dataset, k, queries, items
-        )
+        dataset, queries, items = self._filter_interactions_queries_items_dataframes(dataset, k, queries, items)
 
         if self._use_ann:
-            vectors = self._get_vectors_to_infer_ann(
-                dataset.interactions, queries, filter_seen_items
-            )
+            vectors = self._get_vectors_to_infer_ann(dataset.interactions, queries, filter_seen_items)
             ann_params = self._get_ann_infer_params()
             inferer = self.index_builder.produce_inferer(filter_seen_items)
             recs = inferer.infer(vectors, ann_params["features_col"], k)
replay/models/extensions/ann/entities/hnswlib_param.py
CHANGED
@@ -59,9 +59,3 @@ class HnswlibParam(BaseHnswParam):
     dim: int = field(default=None, init=False)
     # Max number of elements that will be stored in the index
     max_elements: int = field(default=None, init=False)
-
-    # def init_args_as_dict(self):
-    #     # union dicts
-    #     return dict(
-    #         super().init_args_as_dict()["init_args"], **{"space": self.space}
-    #     )
replay/models/extensions/ann/entities/nmslib_hnsw_param.py
CHANGED
@@ -65,9 +65,3 @@ class NmslibHnswParam(BaseHnswParam):
     items_count: Optional[int] = field(default=None, init=False)
 
     method: ClassVar[str] = "hnsw"
-
-    # def init_args_as_dict(self):
-    #     # union dicts
-    #     return dict(
-    #         super().init_args_as_dict()["init_args"], **{"space": self.space}
-    #     )
replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py
CHANGED
@@ -3,7 +3,6 @@ from typing import Optional
 
 import numpy as np
 
-from .base_index_builder import IndexBuilder
 from replay.models.extensions.ann.index_inferers.base_inferer import IndexInferer
 from replay.models.extensions.ann.index_inferers.hnswlib_filter_index_inferer import HnswlibFilterIndexInferer
 from replay.models.extensions.ann.index_inferers.hnswlib_index_inferer import HnswlibIndexInferer
@@ -11,6 +10,8 @@ from replay.models.extensions.ann.utils import create_hnswlib_index_instance
 from replay.utils import SparkDataFrame
 from replay.utils.spark_utils import spark_to_pandas
 
+from .base_index_builder import IndexBuilder
+
 logger = logging.getLogger("replay")
 
 
@@ -21,13 +22,10 @@ class DriverHnswlibIndexBuilder(IndexBuilder):
 
     def produce_inferer(self, filter_seen_items: bool) -> IndexInferer:
         if filter_seen_items:
-            return HnswlibFilterIndexInferer(
-                self.index_params, self.index_store
-            )
+            return HnswlibFilterIndexInferer(self.index_params, self.index_store)
         else:
             return HnswlibIndexInferer(self.index_params, self.index_store)
 
-    # pylint: disable=no-member
     def build_index(
         self,
         vectors: SparkDataFrame,
@@ -43,8 +41,4 @@ class DriverHnswlibIndexBuilder(IndexBuilder):
         else:
             index.add_items(np.stack(vectors_np))
 
-        self.index_store.save_to_store(
-            lambda path: index.save_index(  # pylint: disable=unnecessary-lambda)
-                path
-            )
-        )
+        self.index_store.save_to_store(lambda path: index.save_index(path))
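
The driver/executor builders here and the inferers further down all wrap plain hnswlib calls (`add_items`, `save_index`, `load_index`, `set_ef`, `knn_query`). A minimal standalone sketch of that lifecycle, with illustrative parameters rather than replay-rec defaults:

```python
import os
import tempfile

import hnswlib
import numpy as np

dim = 16
rng = np.random.default_rng(42)
vectors = rng.random((100, dim), dtype=np.float32)

# Build the index on the stacked vectors; ids default to 0..len(vectors) - 1.
index = hnswlib.Index(space="ip", dim=dim)
index.init_index(max_elements=len(vectors), ef_construction=100, M=16)
index.add_items(vectors)

# Persist it -- the role of save_to_store(lambda path: index.save_index(path)) above.
path = os.path.join(tempfile.mkdtemp(), "hnswlib_index.bin")
index.save_index(path)

# Reload, configure the query-time ef (what index_params.ef_s controls), and query.
loaded = hnswlib.Index(space="ip", dim=dim)
loaded.load_index(path)
loaded.set_ef(50)
labels, distances = loaded.knn_query(vectors[:2], k=5, num_threads=1)
print(labels.shape, distances.shape)  # (2, 5) (2, 5)
```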
replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py
CHANGED
@@ -1,14 +1,15 @@
 import logging
 from typing import Optional
 
-from .base_index_builder import IndexBuilder
-from .nmslib_index_builder_mixin import NmslibIndexBuilderMixin
 from replay.models.extensions.ann.index_inferers.base_inferer import IndexInferer
 from replay.models.extensions.ann.index_inferers.nmslib_filter_index_inferer import NmslibFilterIndexInferer
 from replay.models.extensions.ann.index_inferers.nmslib_index_inferer import NmslibIndexInferer
 from replay.utils import SparkDataFrame
 from replay.utils.spark_utils import spark_to_pandas
 
+from .base_index_builder import IndexBuilder
+from .nmslib_index_builder_mixin import NmslibIndexBuilderMixin
+
 logger = logging.getLogger("replay")
 
 
@@ -19,20 +20,15 @@ class DriverNmslibIndexBuilder(IndexBuilder):
 
     def produce_inferer(self, filter_seen_items: bool) -> IndexInferer:
         if filter_seen_items:
-            return NmslibFilterIndexInferer(
-                self.index_params, self.index_store
-            )
+            return NmslibFilterIndexInferer(self.index_params, self.index_store)
         else:
             return NmslibIndexInferer(self.index_params, self.index_store)
 
-    # pylint: disable=no-member
     def build_index(
         self,
         vectors: SparkDataFrame,
-        features_col: str,
-        ids_col: Optional[str] = None,
+        features_col: str,  # noqa: ARG002
+        ids_col: Optional[str] = None,  # noqa: ARG002
     ):
         vectors = spark_to_pandas(vectors, self.allow_collect_to_master)
-        NmslibIndexBuilderMixin.build_and_save_index(
-            vectors, self.index_params, self.index_store
-        )
+        NmslibIndexBuilderMixin.build_and_save_index(vectors, self.index_params, self.index_store)
replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py
CHANGED
@@ -3,13 +3,14 @@ from typing import Iterator, Optional
 
 import numpy as np
 
-from .base_index_builder import IndexBuilder
 from replay.models.extensions.ann.index_inferers.base_inferer import IndexInferer
 from replay.models.extensions.ann.index_inferers.hnswlib_filter_index_inferer import HnswlibFilterIndexInferer
 from replay.models.extensions.ann.index_inferers.hnswlib_index_inferer import HnswlibIndexInferer
 from replay.models.extensions.ann.utils import create_hnswlib_index_instance
 from replay.utils import PandasDataFrame, SparkDataFrame
 
+from .base_index_builder import IndexBuilder
+
 logger = logging.getLogger("replay")
 
 
@@ -20,9 +21,7 @@ class ExecutorHnswlibIndexBuilder(IndexBuilder):
 
     def produce_inferer(self, filter_seen_items: bool) -> IndexInferer:
         if filter_seen_items:
-            return HnswlibFilterIndexInferer(
-                self.index_params, self.index_store
-            )
+            return HnswlibFilterIndexInferer(self.index_params, self.index_store)
         else:
             return HnswlibIndexInferer(self.index_params, self.index_store)
 
@@ -56,17 +55,11 @@ class ExecutorHnswlibIndexBuilder(IndexBuilder):
             # ids will be from [0, ..., len(vectors_np)]
             index.add_items(np.stack(vectors_np))
 
-            _index_store.save_to_store(
-                lambda path: index.save_index(  # pylint: disable=unnecessary-lambda)
-                    path
-                )
-            )
+            _index_store.save_to_store(lambda path: index.save_index(path))
 
             yield PandasDataFrame(data={"_success": 1}, index=[0])
 
         # Here we perform materialization (`.collect()`) to build the hnsw index.
         cols = [ids_col, features_col] if ids_col else [features_col]
 
-        vectors.select(*cols).mapInPandas(
-            build_index_udf, "_success int"
-        ).collect()
+        vectors.select(*cols).mapInPandas(build_index_udf, "_success int").collect()
replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py
CHANGED
@@ -3,13 +3,14 @@ from typing import Iterator, Optional
 
 import pandas as pd
 
-from .base_index_builder import IndexBuilder
-from .nmslib_index_builder_mixin import NmslibIndexBuilderMixin
 from replay.models.extensions.ann.index_inferers.base_inferer import IndexInferer
 from replay.models.extensions.ann.index_inferers.nmslib_filter_index_inferer import NmslibFilterIndexInferer
 from replay.models.extensions.ann.index_inferers.nmslib_index_inferer import NmslibIndexInferer
 from replay.utils import PandasDataFrame, SparkDataFrame
 
+from .base_index_builder import IndexBuilder
+from .nmslib_index_builder_mixin import NmslibIndexBuilderMixin
+
 logger = logging.getLogger("replay")
 
 
@@ -20,9 +21,7 @@ class ExecutorNmslibIndexBuilder(IndexBuilder):
 
     def produce_inferer(self, filter_seen_items: bool) -> IndexInferer:
         if filter_seen_items:
-            return NmslibFilterIndexInferer(
-                self.index_params, self.index_store
-            )
+            return NmslibFilterIndexInferer(self.index_params, self.index_store)
         else:
             return NmslibIndexInferer(self.index_params, self.index_store)
 
@@ -47,15 +46,9 @@ class ExecutorNmslibIndexBuilder(IndexBuilder):
             # with the same `item_idx_two`.
             # And therefore we cannot call the `addDataPointBatch` iteratively
             # (in build_and_save_index).
-
-            for pdf in iterator:
-                pdfs.append(pdf)
-
-            pdf = pd.concat(pdfs)
+            pdf = pd.concat(list(iterator))
 
-            NmslibIndexBuilderMixin.build_and_save_index(
-                pdf, index_params, index_store
-            )
+            NmslibIndexBuilderMixin.build_and_save_index(pdf, index_params, index_store)
 
             yield PandasDataFrame(data={"_success": 1}, index=[0])
 
@@ -64,8 +57,8 @@ class ExecutorNmslibIndexBuilder(IndexBuilder):
     def build_index(
         self,
         vectors: SparkDataFrame,
-        features_col: str,
-        ids_col: Optional[str] = None,
+        features_col: str,  # noqa: ARG002
+        ids_col: Optional[str] = None,  # noqa: ARG002
     ):
         # to execution in one executor
         vectors = vectors.repartition(1)
@@ -74,6 +67,6 @@ class ExecutorNmslibIndexBuilder(IndexBuilder):
         build_index_udf = self.make_build_index_udf()
 
         # Here we perform materialization (`.collect()`) to build the hnsw index.
-        vectors.select(
-
-        ).
+        vectors.select("similarity", "item_idx_one", "item_idx_two").mapInPandas(
+            build_index_udf, "_success int"
+        ).collect()
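
The change above replaces an explicit accumulation loop with a single `pd.concat` over the batch iterator that `mapInPandas` supplies; `pd.concat` accepts any iterable of DataFrames, so the two are equivalent. A tiny sketch with made-up batches:

```python
import pandas as pd


def batches():
    # Stand-in for the iterator of pandas DataFrames that mapInPandas passes in.
    yield pd.DataFrame({"item_idx_one": [1], "item_idx_two": [2], "similarity": [0.5]})
    yield pd.DataFrame({"item_idx_one": [3], "item_idx_two": [4], "similarity": [0.7]})


pdf = pd.concat(list(batches()))
print(len(pdf))  # 2 rows, one per batch row
```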
replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py
CHANGED
@@ -6,7 +6,6 @@ from replay.models.extensions.ann.utils import create_nmslib_index_instance
 from replay.utils import PandasDataFrame
 
 
-# pylint: disable=too-few-public-methods
 class NmslibIndexBuilderMixin:
     """Provides nmslib index building method for different nmslib index builders"""
 
@@ -49,6 +48,4 @@ class NmslibIndexBuilderMixin:
         index.addDataPointBatch(data=sim_matrix)
         index.createIndex(creation_index_params)
 
-        index_store.save_to_store(
-            lambda path: index.saveIndex(path, save_data=True)
-        )  # pylint: disable=unnecessary-lambda)
+        index_store.save_to_store(lambda path: index.saveIndex(path, save_data=True))
replay/models/extensions/ann/index_inferers/base_inferer.py
CHANGED
@@ -8,7 +8,6 @@ if PYSPARK_AVAILABLE:
     from pyspark.sql import functions as sf
 
 
-# pylint: disable=too-few-public-methods
 class IndexInferer(ABC):
     """Abstract base class that describes a common interface for index inferers
     and provides common methods for them."""
@@ -21,9 +20,7 @@ class IndexInferer(ABC):
         self.index_store = index_store
 
     @abstractmethod
-    def infer(
-        self, vectors: SparkDataFrame, features_col: str, k: int
-    ) -> SparkDataFrame:
+    def infer(self, vectors: SparkDataFrame, features_col: str, k: int) -> SparkDataFrame:
         """Infers index"""
 
     @staticmethod
@@ -51,9 +48,7 @@ class IndexInferer(ABC):
         """
         res = inference_result.select(
             "user_idx",
-            sf.explode(
-                sf.arrays_zip("neighbours.item_idx", "neighbours.distance")
-            ).alias("zip_exp"),
+            sf.explode(sf.arrays_zip("neighbours.item_idx", "neighbours.distance")).alias("zip_exp"),
         )
 
         # Fix arrays_zip random behavior.
@@ -65,8 +60,6 @@ class IndexInferer(ABC):
         res = res.select(
             "user_idx",
             sf.col(f"zip_exp.{item_idx_field_name}").alias("item_idx"),
-            (sf.lit(-1.0) * sf.col(f"zip_exp.{distance_field_name}")).alias(
-                "relevance"
-            ),
+            (sf.lit(-1.0) * sf.col(f"zip_exp.{distance_field_name}")).alias("relevance"),
         )
         return res
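
The reformatted selection above relies on `explode(arrays_zip(...))` to turn per-user arrays of neighbour ids and distances into one row per neighbour, with relevance defined as the negated distance. A self-contained sketch of the same pattern on top-level array columns (where `arrays_zip` field naming is stable; for nested field references it varies across Spark versions, which is why the original code parameterises the field names):

```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as sf

spark = SparkSession.builder.master("local[1]").getOrCreate()

# One user with three neighbours returned by an ANN index (made-up data).
df = spark.createDataFrame(
    [(1, [10, 11, 12], [0.1, 0.4, 0.9])],
    schema="user_idx int, nn_item_idx array<int>, nn_distance array<double>",
)

res = df.select(
    "user_idx",
    sf.explode(sf.arrays_zip("nn_item_idx", "nn_distance")).alias("zip_exp"),
).select(
    "user_idx",
    sf.col("zip_exp.nn_item_idx").alias("item_idx"),
    (sf.lit(-1.0) * sf.col("zip_exp.nn_distance")).alias("relevance"),
)

res.show()  # three rows: (1, 10, -0.1), (1, 11, -0.4), (1, 12, -0.9)
```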
replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py
CHANGED
@@ -1,28 +1,24 @@
 import numpy as np
 import pandas as pd
 
-from .base_inferer import IndexInferer
 from replay.models.extensions.ann.utils import create_hnswlib_index_instance
 from replay.utils import PYSPARK_AVAILABLE, PandasDataFrame, SparkDataFrame
 from replay.utils.session_handler import State
 
+from .base_inferer import IndexInferer
+
 if PYSPARK_AVAILABLE:
     from pyspark.sql.pandas.functions import pandas_udf
 
 
-# pylint: disable=too-few-public-methods
 class HnswlibFilterIndexInferer(IndexInferer):
     """Hnswlib index inferer with filter seen items. Infers hnswlib index."""
 
-    def infer(
-        self, vectors: SparkDataFrame, features_col: str, k: int
-    ) -> SparkDataFrame:
+    def infer(self, vectors: SparkDataFrame, features_col: str, k: int) -> SparkDataFrame:
         _index_store = self.index_store
         index_params = self.index_params
 
-        index_store_broadcast = State().session.sparkContext.broadcast(
-            _index_store
-        )
+        index_store_broadcast = State().session.sparkContext.broadcast(_index_store)
 
         @pandas_udf(self.udf_return_type)
         def infer_index_udf(
@@ -34,9 +30,7 @@ class HnswlibFilterIndexInferer(IndexInferer):
             index = index_store.load_index(
                 init_index=lambda: create_hnswlib_index_instance(index_params),
                 load_index=lambda index, path: index.load_index(path),
-                configure_index=lambda index: index.set_ef(index_params.ef_s)
-                if index_params.ef_s
-                else None,
+                configure_index=lambda index: index.set_ef(index_params.ef_s) if index_params.ef_s else None,
             )
 
             # max number of items to retrieve per batch
@@ -51,13 +45,9 @@ class HnswlibFilterIndexInferer(IndexInferer):
             filtered_labels = []
             filtered_distances = []
             for i, item_ids in enumerate(labels):
-                non_seen_item_indexes = ~np.isin(
-                    item_ids, seen_item_ids[i], assume_unique=True
-                )
+                non_seen_item_indexes = ~np.isin(item_ids, seen_item_ids[i], assume_unique=True)
                 filtered_labels.append((item_ids[non_seen_item_indexes])[:k])
-                filtered_distances.append(
-                    (distances[i][non_seen_item_indexes])[:k]
-                )
+                filtered_distances.append((distances[i][non_seen_item_indexes])[:k])
 
             pd_res = pd.DataFrame(
                 {
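
The loop above drops already-seen items from each user's neighbour list before truncating to `k`. The same filtering step in isolation, with made-up arrays:

```python
import numpy as np

k = 2
item_ids = np.array([10, 11, 12, 13])       # neighbour ids returned by knn_query for one user
distances = np.array([0.1, 0.2, 0.3, 0.4])
seen_item_ids = np.array([11, 13])          # items this user has already interacted with

non_seen = ~np.isin(item_ids, seen_item_ids, assume_unique=True)
print(item_ids[non_seen][:k])   # [10 12]
print(distances[non_seen][:k])  # [0.1 0.3]
```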
replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py
CHANGED
@@ -1,28 +1,24 @@
 import numpy as np
 import pandas as pd
 
-from .base_inferer import IndexInferer
 from replay.models.extensions.ann.utils import create_hnswlib_index_instance
 from replay.utils import PYSPARK_AVAILABLE, PandasDataFrame, SparkDataFrame
 from replay.utils.session_handler import State
 
+from .base_inferer import IndexInferer
+
 if PYSPARK_AVAILABLE:
     from pyspark.sql.pandas.functions import pandas_udf
 
 
-# pylint: disable=too-few-public-methods
 class HnswlibIndexInferer(IndexInferer):
     """Hnswlib index inferer without filter seen items. Infers hnswlib index."""
 
-    def infer(
-        self, vectors: SparkDataFrame, features_col: str, k: int
-    ) -> SparkDataFrame:
+    def infer(self, vectors: SparkDataFrame, features_col: str, k: int) -> SparkDataFrame:
         _index_store = self.index_store
         index_params = self.index_params
 
-        index_store_broadcast = State().session.sparkContext.broadcast(
-            _index_store
-        )
+        index_store_broadcast = State().session.sparkContext.broadcast(_index_store)
 
         @pandas_udf(self.udf_return_type)
         def infer_index_udf(vectors: pd.Series) -> PandasDataFrame:  # pragma: no cover
@@ -30,9 +26,7 @@ class HnswlibIndexInferer(IndexInferer):
             index = index_store.load_index(
                 init_index=lambda: create_hnswlib_index_instance(index_params),
                 load_index=lambda index, path: index.load_index(path),
-                configure_index=lambda index: index.set_ef(index_params.ef_s)
-                if index_params.ef_s
-                else None,
+                configure_index=lambda index: index.set_ef(index_params.ef_s) if index_params.ef_s else None,
             )
 
             labels, distances = index.knn_query(
@@ -41,9 +35,7 @@ class HnswlibIndexInferer(IndexInferer):
                 num_threads=1,
             )
 
-            pd_res = pd.DataFrame(
-                {"item_idx": list(labels), "distance": list(distances)}
-            )
+            pd_res = pd.DataFrame({"item_idx": list(labels), "distance": list(distances)})
 
             return pd_res