replay-rec 0.16.0-py3-none-any.whl → 0.17.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- replay/__init__.py +1 -1
- replay/data/__init__.py +1 -1
- replay/data/dataset.py +45 -42
- replay/data/dataset_utils/dataset_label_encoder.py +6 -7
- replay/data/nn/__init__.py +1 -1
- replay/data/nn/schema.py +20 -33
- replay/data/nn/sequence_tokenizer.py +217 -87
- replay/data/nn/sequential_dataset.py +6 -22
- replay/data/nn/torch_sequential_dataset.py +20 -11
- replay/data/nn/utils.py +7 -9
- replay/data/schema.py +17 -17
- replay/data/spark_schema.py +0 -1
- replay/metrics/base_metric.py +38 -79
- replay/metrics/categorical_diversity.py +24 -58
- replay/metrics/coverage.py +25 -49
- replay/metrics/descriptors.py +4 -13
- replay/metrics/experiment.py +3 -8
- replay/metrics/hitrate.py +3 -6
- replay/metrics/map.py +3 -6
- replay/metrics/mrr.py +1 -4
- replay/metrics/ndcg.py +4 -7
- replay/metrics/novelty.py +10 -29
- replay/metrics/offline_metrics.py +26 -61
- replay/metrics/precision.py +3 -6
- replay/metrics/recall.py +3 -6
- replay/metrics/rocauc.py +7 -10
- replay/metrics/surprisal.py +13 -30
- replay/metrics/torch_metrics_builder.py +0 -4
- replay/metrics/unexpectedness.py +15 -20
- replay/models/__init__.py +1 -2
- replay/models/als.py +7 -15
- replay/models/association_rules.py +12 -28
- replay/models/base_neighbour_rec.py +21 -36
- replay/models/base_rec.py +92 -215
- replay/models/cat_pop_rec.py +9 -22
- replay/models/cluster.py +17 -28
- replay/models/extensions/ann/ann_mixin.py +7 -12
- replay/models/extensions/ann/entities/base_hnsw_param.py +1 -1
- replay/models/extensions/ann/entities/hnswlib_param.py +0 -6
- replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -6
- replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +4 -10
- replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +7 -11
- replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +5 -12
- replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +11 -18
- replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +1 -4
- replay/models/extensions/ann/index_inferers/base_inferer.py +3 -10
- replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +7 -17
- replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +6 -14
- replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +14 -28
- replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +15 -25
- replay/models/extensions/ann/index_inferers/utils.py +2 -9
- replay/models/extensions/ann/index_stores/hdfs_index_store.py +4 -9
- replay/models/extensions/ann/index_stores/shared_disk_index_store.py +2 -6
- replay/models/extensions/ann/index_stores/spark_files_index_store.py +8 -14
- replay/models/extensions/ann/index_stores/utils.py +5 -2
- replay/models/extensions/ann/utils.py +3 -5
- replay/models/kl_ucb.py +16 -22
- replay/models/knn.py +37 -59
- replay/models/nn/optimizer_utils/__init__.py +1 -6
- replay/models/nn/optimizer_utils/optimizer_factory.py +3 -6
- replay/models/nn/sequential/bert4rec/__init__.py +1 -1
- replay/models/nn/sequential/bert4rec/dataset.py +6 -7
- replay/models/nn/sequential/bert4rec/lightning.py +53 -56
- replay/models/nn/sequential/bert4rec/model.py +12 -25
- replay/models/nn/sequential/callbacks/__init__.py +1 -1
- replay/models/nn/sequential/callbacks/prediction_callbacks.py +23 -25
- replay/models/nn/sequential/callbacks/validation_callback.py +27 -30
- replay/models/nn/sequential/postprocessors/postprocessors.py +1 -1
- replay/models/nn/sequential/sasrec/dataset.py +8 -7
- replay/models/nn/sequential/sasrec/lightning.py +53 -48
- replay/models/nn/sequential/sasrec/model.py +4 -17
- replay/models/pop_rec.py +9 -10
- replay/models/query_pop_rec.py +7 -15
- replay/models/random_rec.py +10 -18
- replay/models/slim.py +8 -13
- replay/models/thompson_sampling.py +13 -14
- replay/models/ucb.py +11 -22
- replay/models/wilson.py +5 -14
- replay/models/word2vec.py +24 -69
- replay/optimization/optuna_objective.py +13 -27
- replay/preprocessing/__init__.py +1 -2
- replay/preprocessing/converter.py +2 -7
- replay/preprocessing/filters.py +67 -142
- replay/preprocessing/history_based_fp.py +44 -116
- replay/preprocessing/label_encoder.py +106 -68
- replay/preprocessing/sessionizer.py +1 -11
- replay/scenarios/fallback.py +3 -8
- replay/splitters/base_splitter.py +43 -15
- replay/splitters/cold_user_random_splitter.py +18 -31
- replay/splitters/k_folds.py +14 -24
- replay/splitters/last_n_splitter.py +33 -43
- replay/splitters/new_users_splitter.py +31 -55
- replay/splitters/random_splitter.py +16 -23
- replay/splitters/ratio_splitter.py +30 -54
- replay/splitters/time_splitter.py +13 -18
- replay/splitters/two_stage_splitter.py +44 -79
- replay/utils/__init__.py +1 -1
- replay/utils/common.py +65 -0
- replay/utils/dataframe_bucketizer.py +25 -31
- replay/utils/distributions.py +3 -15
- replay/utils/model_handler.py +36 -33
- replay/utils/session_handler.py +11 -15
- replay/utils/spark_utils.py +51 -85
- replay/utils/time.py +8 -22
- replay/utils/types.py +1 -3
- {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/METADATA +2 -2
- replay_rec-0.17.0.dist-info/RECORD +127 -0
- replay_rec-0.16.0.dist-info/RECORD +0 -126
- {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/LICENSE +0 -0
- {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/WHEEL +0 -0
replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py
CHANGED

@@ -11,22 +11,17 @@ if PYSPARK_AVAILABLE:
     from pyspark.sql.pandas.functions import pandas_udf


-# pylint: disable=too-few-public-methods
 class NmslibFilterIndexInferer(IndexInferer):
     """Nmslib index inferer with filter seen items. Infers nmslib hnsw index."""

-    def infer(
-        self, vectors: SparkDataFrame, features_col: str, k: int
-    ) -> SparkDataFrame:
+    def infer(self, vectors: SparkDataFrame, features_col: str, k: int) -> SparkDataFrame:  # noqa: ARG002
         _index_store = self.index_store
         index_params = self.index_params

-        index_store_broadcast = State().session.sparkContext.broadcast(
-            _index_store
-        )
+        index_store_broadcast = State().session.sparkContext.broadcast(_index_store)

         @pandas_udf(self.udf_return_type)
-        def infer_index_udf(
+        def infer_index_udf(
             user_idx: pd.Series,
             vector_items: pd.Series,
             vector_ratings: pd.Series,
@@ -36,12 +31,8 @@ class NmslibFilterIndexInferer(IndexInferer):
             index_store = index_store_broadcast.value
             index = index_store.load_index(
                 init_index=lambda: create_nmslib_index_instance(index_params),
-                load_index=lambda index, path: index.loadIndex(
-                    path, load_data=True
-                ),
-                configure_index=lambda index: index.setQueryTimeParams(
-                    {"efSearch": index_params.ef_s}
-                )
+                load_index=lambda index, path: index.loadIndex(path, load_data=True),
+                configure_index=lambda index: index.setQueryTimeParams({"efSearch": index_params.ef_s})
                 if index_params.ef_s
                 else None,
             )
@@ -49,9 +40,7 @@ class NmslibFilterIndexInferer(IndexInferer):
             # max number of items to retrieve per batch
             max_items_to_retrieve = num_items.max()

-            user_vectors = get_csr_matrix(
-                user_idx, vector_items, vector_ratings
-            )
+            user_vectors = get_csr_matrix(user_idx, vector_items, vector_ratings)

             neighbours = index.knnQueryBatch(
                 user_vectors[user_idx.values, :],
@@ -61,9 +50,7 @@ class NmslibFilterIndexInferer(IndexInferer):

             neighbours_filtered = []
             for i, (item_idxs, distances) in enumerate(neighbours):
-                non_seen_item_indexes = ~np.isin(
-                    item_idxs, seen_item_ids[i], assume_unique=True
-                )
+                non_seen_item_indexes = ~np.isin(item_idxs, seen_item_ids[i], assume_unique=True)
                 neighbours_filtered.append(
                     (
                         (item_idxs[non_seen_item_indexes])[:k],
@@ -71,14 +58,14 @@ class NmslibFilterIndexInferer(IndexInferer):
                     )
                 )

-            pd_res = PandasDataFrame(
-                neighbours_filtered, columns=["item_idx", "distance"]
-            )
+            pd_res = PandasDataFrame(neighbours_filtered, columns=["item_idx", "distance"])

-
-
-
-
+            """
+            pd_res looks like
+            item_idx            distances
+            [1, 2, 3, ...]      [-0.5, -0.3, -0.1, ...]
+            [1, 3, 4, ...]      [-0.1, -0.8, -0.2, ...]
+            """

             return pd_res

@@ -89,7 +76,6 @@ class NmslibFilterIndexInferer(IndexInferer):
             "num_items",
             "seen_item_idxs",
         ]
-        # cols = cols + ["num_items", "seen_item_idxs"]

         res = vectors.select(
             "user_idx",
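For reference, the filtering step in the hunk above is plain NumPy: numpy.isin marks which retrieved candidates the user has already seen, and the top-k cut is applied after the mask. A minimal self-contained sketch with made-up ids (not the package's code):

    import numpy as np

    item_idxs = np.array([7, 3, 9, 1, 5])                 # candidates from the index, best first
    distances = np.array([-0.9, -0.8, -0.6, -0.5, -0.4])  # nmslib inner-product distances are negative
    seen_item_ids = np.array([3, 1])                      # items this user already interacted with
    k = 2

    mask = ~np.isin(item_idxs, seen_item_ids, assume_unique=True)
    top_items, top_dists = item_idxs[mask][:k], distances[mask][:k]
    # top_items -> array([7, 9]): seen items are dropped before the top-k cut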
replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py
CHANGED

@@ -1,28 +1,24 @@
 import pandas as pd

-from .base_inferer import IndexInferer
-from .utils import get_csr_matrix
 from replay.models.extensions.ann.utils import create_nmslib_index_instance
 from replay.utils import PYSPARK_AVAILABLE, PandasDataFrame, SparkDataFrame
 from replay.utils.session_handler import State

+from .base_inferer import IndexInferer
+from .utils import get_csr_matrix
+
 if PYSPARK_AVAILABLE:
     from pyspark.sql.pandas.functions import pandas_udf


-# pylint: disable=too-few-public-methods
 class NmslibIndexInferer(IndexInferer):
     """Nmslib index inferer without filter seen items. Infers nmslib hnsw index."""

-    def infer(
-        self, vectors: SparkDataFrame, features_col: str, k: int
-    ) -> SparkDataFrame:
+    def infer(self, vectors: SparkDataFrame, features_col: str, k: int) -> SparkDataFrame:  # noqa: ARG002
         _index_store = self.index_store
         index_params = self.index_params

-        index_store_broadcast = State().session.sparkContext.broadcast(
-            _index_store
-        )
+        index_store_broadcast = State().session.sparkContext.broadcast(_index_store)

         @pandas_udf(self.udf_return_type)
         def infer_index_udf(
@@ -33,29 +29,23 @@ class NmslibIndexInferer(IndexInferer):
             index_store = index_store_broadcast.value
             index = index_store.load_index(
                 init_index=lambda: create_nmslib_index_instance(index_params),
-                load_index=lambda index, path: index.loadIndex(
-                    path, load_data=True
-                ),
-                configure_index=lambda index: index.setQueryTimeParams(
-                    {"efSearch": index_params.ef_s}
-                )
+                load_index=lambda index, path: index.loadIndex(path, load_data=True),
+                configure_index=lambda index: index.setQueryTimeParams({"efSearch": index_params.ef_s})
                 if index_params.ef_s
                 else None,
             )

-            user_vectors = get_csr_matrix(
-                user_idx, vector_items, vector_ratings
-            )
-            neighbours = index.knnQueryBatch(
-                user_vectors[user_idx.values, :], k=k, num_threads=1
-            )
+            user_vectors = get_csr_matrix(user_idx, vector_items, vector_ratings)
+            neighbours = index.knnQueryBatch(user_vectors[user_idx.values, :], k=k, num_threads=1)

             pd_res = PandasDataFrame(neighbours, columns=["item_idx", "distance"])

-
-
-
-
+            """
+            pd_res looks like
+            item_idx            distances
+            [1, 2, 3, ...]      [-0.5, -0.3, -0.1, ...]
+            [1, 3, 4, ...]      [-0.1, -0.8, -0.2, ...]
+            """

             return pd_res
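Both inferers use the same broadcast-then-UDF pattern: the index store is broadcast once per job, and each executor deserializes it a single time inside the pandas UDF instead of capturing it per task. A minimal sketch of the pattern with a plain dictionary standing in for the index store (illustrative names only):

    import pandas as pd
    from pyspark.sql import SparkSession
    from pyspark.sql.functions import pandas_udf

    spark = SparkSession.builder.master("local[2]").getOrCreate()
    lookup = {1: 10.0, 2: 20.0}                  # stand-in for the broadcast index store
    lookup_broadcast = spark.sparkContext.broadcast(lookup)

    @pandas_udf("double")
    def score(item: pd.Series) -> pd.Series:
        table = lookup_broadcast.value           # fetched from the broadcast, not the closure
        return item.map(table)

    df = spark.createDataFrame([(1,), (2,)], ["item"])
    df.select(score("item").alias("score")).show()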
replay/models/extensions/ann/index_inferers/utils.py
CHANGED

@@ -12,19 +12,12 @@ def get_csr_matrix(
         (
             vector_ratings.explode().values.astype(float),
             (
-                user_idx.repeat(
-                    vector_items.apply(
-                        lambda x: len(x)  # pylint: disable=unnecessary-lambda
-                    )
-                ).values,
+                user_idx.repeat(vector_items.apply(lambda x: len(x))).values,
                 vector_items.explode().values.astype(int),
             ),
         ),
         shape=(
             user_idx.max() + 1,
-            vector_items.apply(
-                lambda x: max(x)  # pylint: disable=unnecessary-lambda
-            ).max()
-            + 1,
+            vector_items.apply(lambda x: max(x)).max() + 1,
         ),
     )
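get_csr_matrix builds a scipy CSR matrix from COO triples: exploded ratings as data, repeated user ids as rows, exploded item ids as columns. A standalone equivalent for two users (made-up values):

    import pandas as pd
    from scipy.sparse import csr_matrix

    user_idx = pd.Series([0, 1])
    vector_items = pd.Series([[0, 2], [1]])          # items rated by each user
    vector_ratings = pd.Series([[5.0, 3.0], [4.0]])  # matching ratings

    rows = user_idx.repeat(vector_items.apply(len)).values
    cols = vector_items.explode().values.astype(int)
    data = vector_ratings.explode().values.astype(float)
    matrix = csr_matrix(
        (data, (rows, cols)),
        shape=(user_idx.max() + 1, vector_items.apply(max).max() + 1),
    )
    # matrix.toarray() -> [[5., 0., 3.], [0., 4., 0.]]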
replay/models/extensions/ann/index_stores/hdfs_index_store.py
CHANGED

@@ -19,12 +19,9 @@ class HdfsIndexStore(IndexStore):
         index_dir_path = os.path.join(warehouse_dir, index_dir)
         self._index_dir_info = get_filesystem(index_dir_path)
         if self._index_dir_info.filesystem != FileSystem.HDFS:
-            raise ValueError(
-                f"Can't recognize path {index_dir_path} as HDFS path!"
-            )
-        self._hadoop_fs = fs.HadoopFileSystem.from_uri(
-            self._index_dir_info.hdfs_uri
-        )
+            msg = f"Can't recognize path {index_dir_path} as HDFS path!"
+            raise ValueError(msg)
+        self._hadoop_fs = fs.HadoopFileSystem.from_uri(self._index_dir_info.hdfs_uri)
         super().__init__()

         if self.cleanup:
@@ -32,9 +29,7 @@ class HdfsIndexStore(IndexStore):
                 "Index directory %s is marked for deletion via weakref.finalize()",
                 self._index_dir_info.path,
             )
-            weakref.finalize(
-                self, self._hadoop_fs.delete_dir, self._index_dir_info.path
-            )
+            weakref.finalize(self, self._hadoop_fs.delete_dir, self._index_dir_info.path)

     def load_index(
         self,
replay/models/extensions/ann/index_stores/shared_disk_index_store.py
CHANGED

@@ -17,9 +17,7 @@ class SharedDiskIndexStore(IndexStore):
     It can also be used with a local disk when the driver and executors
     are running on the same machine."""

-    def __init__(
-        self, warehouse_dir: str, index_dir: str, cleanup: bool = True
-    ):
+    def __init__(self, warehouse_dir: str, index_dir: str, cleanup: bool = True):
         self.index_dir_path = os.path.join(warehouse_dir, index_dir)
         super().__init__(cleanup)
         if self.cleanup:
@@ -52,9 +50,7 @@ class SharedDiskIndexStore(IndexStore):
             save_index(temp_file_path)

     def dump_index(self, target_path: str):
-        destination_filesystem, target_path = fs.FileSystem.from_uri(
-            target_path
-        )
+        destination_filesystem, target_path = fs.FileSystem.from_uri(target_path)
         target_path = os.path.join(target_path, "index_files")
         destination_filesystem.create_dir(target_path)
         fs.copy_files(
replay/models/extensions/ann/index_stores/spark_files_index_store.py
CHANGED

@@ -7,11 +7,12 @@ from typing import Any, Callable

 from pyarrow import fs

-from .base_index_store import IndexStore
-from .utils import FileSystem, get_filesystem
 from replay.utils import PYSPARK_AVAILABLE
 from replay.utils.session_handler import State

+from .base_index_store import IndexStore
+from .utils import FileSystem, get_filesystem
+
 if PYSPARK_AVAILABLE:
     from pyspark import SparkFiles

@@ -20,6 +21,7 @@ logger = logging.getLogger("replay")


 if PYSPARK_AVAILABLE:
+
     class SparkFilesIndexStore(IndexStore):
         """Class that responsible for index store in spark files.
         Works through SparkContext.addFile()."""
@@ -62,14 +64,10 @@ if PYSPARK_AVAILABLE:
             for filename in os.listdir(self.index_dir_path):
                 index_file_path = os.path.join(self.index_dir_path, filename)
                 spark.sparkContext.addFile("file://" + index_file_path)
-                logger.info(
-                    "Index file %s transferred to executors", index_file_path
-                )
+                logger.info("Index file %s transferred to executors", index_file_path)

         def dump_index(self, target_path: str):
-            destination_filesystem, target_path = fs.FileSystem.from_uri(
-                target_path
-            )
+            destination_filesystem, target_path = fs.FileSystem.from_uri(target_path)
             target_path = os.path.join(target_path, "index_files")
             destination_filesystem.create_dir(target_path)
             fs.copy_files(
@@ -83,9 +81,7 @@ if PYSPARK_AVAILABLE:
             """Loads index from `path` directory to spark files."""
             path_info = get_filesystem(path)
             source_filesystem, path = fs.FileSystem.from_uri(
-                path_info.hdfs_uri + path_info.path
-                if path_info.filesystem == FileSystem.HDFS
-                else path_info.path
+                path_info.hdfs_uri + path_info.path if path_info.filesystem == FileSystem.HDFS else path_info.path
             )
             path = os.path.join(path, "index_files")
             self.index_dir_path: str = tempfile.mkdtemp()
@@ -100,6 +96,4 @@ if PYSPARK_AVAILABLE:
             for filename in os.listdir(self.index_dir_path):
                 index_file_path = os.path.join(self.index_dir_path, filename)
                 spark.sparkContext.addFile("file://" + index_file_path)
-                logger.info(
-                    "Index file %s transferred to executors", index_file_path
-                )
+                logger.info("Index file %s transferred to executors", index_file_path)
replay/models/extensions/ann/index_stores/utils.py
CHANGED

@@ -9,6 +9,7 @@ if PYSPARK_AVAILABLE:

 class FileSystem(Enum):
     """File system types"""
+
     HDFS = 1
     LOCAL = 2

@@ -24,6 +25,7 @@ def get_default_fs() -> str:
 @dataclass(frozen=True)
 class FileInfo:
     """File meta-information: filesystem, path and hdfs_uri (optional)"""
+
     path: str
     filesystem: FileSystem
     hdfs_uri: str = None

@@ -83,15 +85,16 @@ or set 'fs.defaultFS' in hadoop configuration.
             if default_fs.startswith("hdfs://"):
                 return FileInfo(path[prefix_len:], FileSystem.HDFS, default_fs)
             else:
-                raise ValueError(
+                msg = (
                     f"Can't get default hdfs uri for path = '{path}'. "
                     "Specify an explicit path, such as 'hdfs://host:port/dir/file', "
                     "or set 'fs.defaultFS' in hadoop configuration."
                 )
+                raise ValueError(msg)
         else:
             hostname = path[prefix_len:].split("/", 1)[0]
             hdfs_uri = "hdfs://" + hostname
-            return FileInfo(path[len(hdfs_uri):], FileSystem.HDFS, hdfs_uri)
+            return FileInfo(path[len(hdfs_uri) :], FileSystem.HDFS, hdfs_uri)
     elif path.startswith("file://"):
         return FileInfo(path[prefix_len:], FileSystem.LOCAL)
     else:
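The hdfs branch of get_filesystem is simple prefix arithmetic: the host part of an explicit hdfs:// URI becomes hdfs_uri and the remainder becomes the path. A rough trace with an invented address:

    path = "hdfs://namenode:9000/warehouse/index_dir"
    hostname = path[len("hdfs://"):].split("/", 1)[0]  # "namenode:9000"
    hdfs_uri = "hdfs://" + hostname                    # "hdfs://namenode:9000"
    rest = path[len(hdfs_uri):]                        # "/warehouse/index_dir"
    # -> FileInfo(path=rest, filesystem=FileSystem.HDFS, hdfs_uri=hdfs_uri)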
replay/models/extensions/ann/utils.py
CHANGED

@@ -15,9 +15,7 @@ def create_hnswlib_index_instance(params: HnswlibParam, init: bool = False):
         If `False` then the index will be used to load index data from a file.
     :return: `hnswlib` index instance
     """
-    index = hnswlib.Index(
-        space=params.space, dim=params.dim
-    )
+    index = hnswlib.Index(space=params.space, dim=params.dim)

     if init:
         # Initializing index - the maximum number of elements should be known beforehand
@@ -37,10 +35,10 @@ def create_nmslib_index_instance(params: NmslibHnswParam):
     :param params: `NmslibHnswParam`
     :return: `nmslib` index
     """
-    index = nmslib.init(
+    index = nmslib.init(
         method=params.method,
         space=params.space,
-        data_type=nmslib.DataType.SPARSE_VECTOR,
+        data_type=nmslib.DataType.SPARSE_VECTOR,
     )

     return index
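For context on the init flag: a fresh hnswlib index must be initialized with a known capacity before items are added, while an index that will be filled by loadIndex can skip that step. A minimal usage sketch with made-up dimensions (standard hnswlib API, not package code):

    import hnswlib
    import numpy as np

    index = hnswlib.Index(space="ip", dim=32)       # same constructor the helper wraps
    index.init_index(max_elements=1000, ef_construction=200, M=16)
    vectors = np.random.rand(1000, 32).astype(np.float32)
    index.add_items(vectors)
    labels, distances = index.knn_query(vectors[:5], k=10)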
replay/models/kl_ucb.py
CHANGED
@@ -1,13 +1,15 @@
 import math
-
 from typing import Optional
-
-from replay.utils import PYSPARK_AVAILABLE
+
 from scipy.optimize import root_scalar

+from replay.utils import PYSPARK_AVAILABLE
+
+from .ucb import UCB
+
 if PYSPARK_AVAILABLE:
-    from pyspark.sql.types import DoubleType
     from pyspark.sql.functions import udf
+    from pyspark.sql.types import DoubleType


 class KLUCB(UCB):
@@ -17,7 +19,7 @@ class KLUCB(UCB):
     computes item relevance as an upper confidence bound of true fraction of
     positive interactions.

-    In a nutshell, KL-UCB
+    In a nutshell, KL-UCB considers the data as the history of interactions
     with items. The interaction may be either positive or negative. For each
     item the model computes empirical frequency of positive interactions
     and estimates the true frequency with an upper confidence bound. The higher
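The rule described in this docstring is the classic KL-UCB index: for an item with empirical success rate p = pos/total, the score is the largest q such that total * KL(p, q) <= log(t) + coef * log(log(t)), where KL is the Bernoulli Kullback-Leibler divergence. A standalone illustration of that computation with invented counts (the package's own version appears in the hunks below):

    import math
    from scipy.optimize import root_scalar

    def bernoulli_kl(p, q):
        return p * math.log(p / q) + (1 - p) * math.log((1 - p) / (1 - q))

    pos, total, full_count, coef = 3, 10, 1000, 0.0
    p = pos / total
    right_hand_side = math.log(full_count) + coef * math.log(math.log(full_count))
    ucb = root_scalar(
        f=lambda q: total * bernoulli_kl(p, q) - right_hand_side,
        bracket=[p, 1 - 1e-12],
        method="brentq",
    ).root
    # ucb ~ 0.83, well above p = 0.3, because only 10 observations back the estimate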
@@ -137,14 +139,11 @@ class KLUCB(UCB):
         super().__init__(exploration_coef, sample, seed)

     def _calc_item_popularity(self):
-
-        right_hand_side = math.log(self.full_count) \
-            + self.coef * math.log(math.log(self.full_count))
+        right_hand_side = math.log(self.full_count) + self.coef * math.log(math.log(self.full_count))
         eps = 1e-12

         def bernoulli_kl(proba_p, proba_q):  # pragma: no cover
-            return proba_p * math.log(proba_p / proba_q)
-                (1 - proba_p) * math.log((1 - proba_p) / (1 - proba_q))
+            return proba_p * math.log(proba_p / proba_q) + (1 - proba_p) * math.log((1 - proba_p) / (1 - proba_q))

         @udf(returnType=DoubleType())
         def get_ucb(pos, total):  # pragma: no cover
@@ -152,27 +151,22 @@ class KLUCB(UCB):

             if proba == 0:
                 ucb = root_scalar(
-                    f=lambda qq: math.log(1 / (1 - qq)) - right_hand_side,
-                    bracket=[0, 1 - eps],
-                    method='brentq').root
+                    f=lambda qq: math.log(1 / (1 - qq)) - right_hand_side, bracket=[0, 1 - eps], method="brentq"
+                ).root
                 return ucb

             if proba == 1:
                 ucb = root_scalar(
-                    f=lambda qq: math.log(1 / qq) - right_hand_side,
-                    bracket=[0 + eps, 1],
-                    method='brentq').root
+                    f=lambda qq: math.log(1 / qq) - right_hand_side, bracket=[0 + eps, 1], method="brentq"
+                ).root
                 return ucb

             ucb = root_scalar(
-                f=lambda q: total * bernoulli_kl(proba, q) - right_hand_side,
-                bracket=[proba, 1 - eps],
-                method='brentq').root
+                f=lambda q: total * bernoulli_kl(proba, q) - right_hand_side, bracket=[proba, 1 - eps], method="brentq"
+            ).root
             return ucb

-        items_counts = self.items_counts_aggr.withColumn(
-            self.rating_column, get_ucb("pos", "total")
-        )
+        items_counts = self.items_counts_aggr.withColumn(self.rating_column, get_ucb("pos", "total"))

         self.item_popularity = items_counts.drop("pos", "total")
replay/models/knn.py
CHANGED
@@ -1,17 +1,17 @@
 from typing import Any, Dict, Optional

 from replay.data import Dataset
-from .base_neighbour_rec import NeighbourRec
-from .extensions.ann.index_builders.base_index_builder import IndexBuilder
 from replay.optimization.optuna_objective import ItemKNNObjective
 from replay.utils import PYSPARK_AVAILABLE, SparkDataFrame

+from .base_neighbour_rec import NeighbourRec
+from .extensions.ann.index_builders.base_index_builder import IndexBuilder
+
 if PYSPARK_AVAILABLE:
     from pyspark.sql import functions as sf
     from pyspark.sql.window import Window


-# pylint: disable=too-many-ancestors, too-many-instance-attributes
 class ItemKNN(NeighbourRec):
     """Item-based ItemKNN with modified cosine similarity measure."""

@@ -29,15 +29,15 @@ class ItemKNN(NeighbourRec):
     _search_space = {
         "num_neighbours": {"type": "int", "args": [1, 100]},
         "shrink": {"type": "int", "args": [0, 100]},
-        "weighting": {"type": "categorical", "args": [None, "tf_idf", "bm25"]}
+        "weighting": {"type": "categorical", "args": [None, "tf_idf", "bm25"]},
     }

-    def __init__(
+    def __init__(
         self,
         num_neighbours: int = 10,
         use_rating: bool = False,
         shrink: float = 0.0,
-        weighting: str = None,
+        weighting: Optional[str] = None,
         index_builder: Optional[IndexBuilder] = None,
     ):
         """
@@ -54,7 +54,8 @@ class ItemKNN(NeighbourRec):

         valid_weightings = self._search_space["weighting"]["args"]
         if weighting not in valid_weightings:
-            raise ValueError(f"weighting must be one of {valid_weightings}")
+            msg = f"weighting must be one of {valid_weightings}"
+            raise ValueError(msg)
         self.weighting = weighting
         if isinstance(index_builder, (IndexBuilder, type(None))):
             self.index_builder = index_builder
@@ -75,8 +76,7 @@ class ItemKNN(NeighbourRec):
     def _shrink(dot_products: SparkDataFrame, shrink: float) -> SparkDataFrame:
         return dot_products.withColumn(
             "similarity",
-            sf.col("dot_product")
-            / (sf.col("norm1") * sf.col("norm2") + shrink),
+            sf.col("dot_product") / (sf.col("norm1") * sf.col("norm2") + shrink),
         ).select("item_idx_one", "item_idx_two", "similarity")

     def _get_similarity(self, interactions: SparkDataFrame) -> SparkDataFrame:
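The _shrink step above is the shrunk cosine from the hunk's formula: similarity = dot_product / (norm1 * norm2 + shrink), so item pairs supported by little evidence are pulled toward zero. A quick numeric check with made-up values:

    dot_product, norm1, norm2 = 4.0, 2.0, 2.5
    for shrink in (0.0, 10.0):
        print(shrink, dot_product / (norm1 * norm2 + shrink))
    # 0.0 -> 0.8 (plain cosine); 10.0 -> ~0.27 (shrunk toward zero)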
@@ -116,25 +116,19 @@ class ItemKNN(NeighbourRec):
         :param interactions: SparkDataFrame with interactions, `[user_id, item_id, rating]`
         :return: interactions `[user_id, item_id, rating]`
         """
-        item_stats = interactions.groupBy(self.item_column).agg(
-            sf.count(self.query_column).alias("n_queries_per_item")
-        )
+        item_stats = interactions.groupBy(self.item_column).agg(sf.count(self.query_column).alias("n_queries_per_item"))
         avgdl = item_stats.select(sf.mean("n_queries_per_item")).take(1)[0][0]
         interactions = interactions.join(item_stats, how="inner", on=self.item_column)

-        interactions = (
-            interactions.withColumn(
-                self.rating_column,
-                sf.col(self.rating_column)
-                * (self.bm25_k1 + 1)
-                / (
-                    sf.col(self.rating_column)
-                    + self.bm25_k1
-                    * (1 - self.bm25_b + self.bm25_b * (sf.col("n_queries_per_item") / avgdl))
-                )
-            )
-            .drop("n_queries_per_item")
-        )
+        interactions = interactions.withColumn(
+            self.rating_column,
+            sf.col(self.rating_column)
+            * (self.bm25_k1 + 1)
+            / (
+                sf.col(self.rating_column)
+                + self.bm25_k1 * (1 - self.bm25_b + self.bm25_b * (sf.col("n_queries_per_item") / avgdl))
+            ),
+        ).drop("n_queries_per_item")

         return interactions
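The reweighting in this hunk is BM25 term-frequency saturation: a rating r becomes r * (k1 + 1) / (r + k1 * (1 - b + b * n / avgdl)), where n is the item's interaction count and avgdl the average count over items. A quick numeric check with illustrative values (the real k1 and b come from the model's bm25_k1/bm25_b attributes):

    k1, b = 1.2, 0.75
    r, avgdl = 1.0, 50.0
    for n in (10.0, 50.0, 250.0):        # rare, average, and very popular item
        print(n, r * (k1 + 1) / (r + k1 * (1 - b + b * n / avgdl)))
    # 10.0 -> ~1.49, 50.0 -> 1.0, 250.0 -> ~0.38: popular items are damped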
@@ -150,23 +144,15 @@ class ItemKNN(NeighbourRec):
         n_items = interactions.select(self.item_column).distinct().count()

         if self.weighting == "tf_idf":
-            idf = (
-                df.withColumn("idf", sf.log1p(sf.lit(n_items) / sf.col("DF")))
-                .drop("DF")
-            )
+            idf = df.withColumn("idf", sf.log1p(sf.lit(n_items) / sf.col("DF"))).drop("DF")
         elif self.weighting == "bm25":
-            idf = (
-                df.withColumn(
-                    "idf",
-                    sf.log1p(
-                        (sf.lit(n_items) - sf.col("DF") + 0.5)
-                        / (sf.col("DF") + 0.5)
-                    ),
-                )
-                .drop("DF")
-            )
+            idf = df.withColumn(
+                "idf",
+                sf.log1p((sf.lit(n_items) - sf.col("DF") + 0.5) / (sf.col("DF") + 0.5)),
+            ).drop("DF")
         else:
-            raise ValueError("weighting must be one of ['tf_idf', 'bm25']")
+            msg = "weighting must be one of ['tf_idf', 'bm25']"
+            raise ValueError(msg)

         return idf
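Both branches compute a log-scaled inverse document frequency from DF. For example, with n_items = 1000 and DF = 10, math.log1p(1000 / 10) ≈ 4.62 for tf_idf and math.log1p((1000 - 10 + 0.5) / (10 + 0.5)) ≈ 4.56 for bm25; the bm25 variant penalizes high-DF items slightly harder as DF grows.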
@@ -180,12 +166,12 @@ class ItemKNN(NeighbourRec):
         if self.weighting:
            interactions = self._reweight_interactions(interactions)

-        left = interactions.withColumnRenamed(
-            self.
-        )
-        right = interactions.withColumnRenamed(
-            self.
-        )
+        left = interactions.withColumnRenamed(self.item_column, "item_idx_one").withColumnRenamed(
+            self.rating_column, "rel_one"
+        )
+        right = interactions.withColumnRenamed(self.item_column, "item_idx_two").withColumnRenamed(
+            self.rating_column, "rel_two"
+        )

         dot_products = (
             left.join(right, how="inner", on=self.query_column)
@@ -201,19 +187,11 @@ class ItemKNN(NeighbourRec):
             .agg(sf.sum(self.rating_column).alias("square_norm"))
             .select(sf.col(self.item_column), sf.sqrt("square_norm").alias("norm"))
         )
-        norm1 = item_norms.withColumnRenamed(
-            self.item_column, "item_id1"
-        ).withColumnRenamed("norm", "norm1")
-        norm2 = item_norms.withColumnRenamed(
-            self.item_column, "item_id2"
-        ).withColumnRenamed("norm", "norm2")
-
-        dot_products = dot_products.join(
-            norm1, how="inner", on=sf.col("item_id1") == sf.col("item_idx_one")
-        )
-        dot_products = dot_products.join(
-            norm2, how="inner", on=sf.col("item_id2") == sf.col("item_idx_two")
-        )
+        norm1 = item_norms.withColumnRenamed(self.item_column, "item_id1").withColumnRenamed("norm", "norm1")
+        norm2 = item_norms.withColumnRenamed(self.item_column, "item_id2").withColumnRenamed("norm", "norm2")
+
+        dot_products = dot_products.join(norm1, how="inner", on=sf.col("item_id1") == sf.col("item_idx_one"))
+        dot_products = dot_products.join(norm2, how="inner", on=sf.col("item_id2") == sf.col("item_idx_two"))

         return dot_products
replay/models/nn/optimizer_utils/__init__.py
CHANGED

@@ -1,9 +1,4 @@
 from replay.utils import TORCH_AVAILABLE

 if TORCH_AVAILABLE:
-    from .optimizer_factory import (
-        FatLRSchedulerFactory,
-        FatOptimizerFactory,
-        LRSchedulerFactory,
-        OptimizerFactory
-    )
+    from .optimizer_factory import FatLRSchedulerFactory, FatOptimizerFactory, LRSchedulerFactory, OptimizerFactory
replay/models/nn/optimizer_utils/optimizer_factory.py
CHANGED

@@ -4,7 +4,6 @@ from typing import Iterator, Tuple
 import torch


-# pylint: disable=too-few-public-methods
 class OptimizerFactory(abc.ABC):
     """
     Interface for optimizer factory
@@ -21,7 +20,6 @@ class OptimizerFactory(abc.ABC):
     """


-# pylint: disable=too-few-public-methods
 class LRSchedulerFactory(abc.ABC):
     """
     Interface for learning rate scheduler factory
@@ -38,13 +36,11 @@ class LRSchedulerFactory(abc.ABC):
     """


-# pylint: disable=too-few-public-methods
 class FatOptimizerFactory(OptimizerFactory):
     """
     Factory that creates optimizer depending on passed parameters
     """

-    # pylint: disable=too-many-arguments
     def __init__(
         self,
         optimizer: str = "adam",
@@ -74,10 +70,11 @@ class FatOptimizerFactory(OptimizerFactory):
             return torch.optim.SGD(
                 parameters, lr=self.learning_rate, weight_decay=self.weight_decay, momentum=self.sgd_momentum
             )
-        raise ValueError("Unexpected optimizer")
+
+        msg = "Unexpected optimizer"
+        raise ValueError(msg)


-# pylint: disable=too-few-public-methods
 class FatLRSchedulerFactory(LRSchedulerFactory):
     """
     Factory that creates learning rate schedule depending on passed parameters